#ifndef MATRIX_HPP
#define MATRIX_HPP 1
#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <ostream>
#include <span>
#include <type_traits>
#include <utility>
/// Fixed-size M-by-N matrix of arithmetic values.
///
/// Elements are stored contiguously in row-major order: element (m, n) lives
/// at index `m * N + n` of the underlying `std::array`.
///
/// @tparam T  arithmetic element type
/// @tparam M  number of rows
/// @tparam N  number of columns
template<typename T, std::size_t M, std::size_t N>
requires(std::is_arithmetic_v<T>)
class matrix
{
protected:
    /// Row-major element storage.
    std::array<T, N * M> v;

private:
    /// Builds an array of references to the elements
    /// `v[first + 0 * stride], v[first + 1 * stride], ...`, one per index in
    /// the pack. `std::reference_wrapper` has no default constructor, so the
    /// array has to be initialised in a single expression.
    template<std::size_t... I>
    constexpr std::array<std::reference_wrapper<T>, sizeof...(I)> refs(
        std::size_t first,
        std::size_t stride,
        std::index_sequence<I...>)
    {
        return {std::ref(v[first + (I * stride)])...};
    }

public:
    /// Constructs from a nested array of M rows with N values each.
    constexpr matrix(const std::array<std::array<T, N>, M> &vs) :
        v{}
    {
        for(std::size_t m = 0; m < M; m++) {
            for(std::size_t n = 0; n < N; n++) {
                v[(m * N) + n] = vs[m][n];
            }
        }
    }

    /// Constructs from a flat array that is already in row-major order.
    constexpr matrix(const std::array<T, M * N> &vs) :
        v(vs)
    {}

    /// Constructs a matrix with every element set to `a`.
    /// Intentionally left implicit so existing scalar conversions keep working.
    constexpr matrix(T a) :
        v{}
    {
        v.fill(a);
    }

    /// Element-wise addition. @return reference to this matrix.
    constexpr matrix &operator+=(const matrix &rhs)
    {
        for(std::size_t i = 0; i < M * N; i++)
            v[i] += rhs.v[i];
        return *this;
    }

    /// Element-wise sum of two matrices of the same shape.
    constexpr friend matrix operator+(
        matrix lhs,
        const matrix &rhs)
    {
        lhs += rhs;
        return lhs;
    }

    /// Element-wise negation.
    constexpr friend matrix operator-(matrix lhs)
    {
        for(std::size_t i = 0; i < M * N; i++)
            lhs.v[i] = -lhs.v[i];
        return lhs;
    }

    /// Element-wise subtraction. @return reference to this matrix.
    constexpr matrix &operator-=(const matrix &rhs)
    {
        for(std::size_t i = 0; i < M * N; i++)
            v[i] -= rhs.v[i];
        return *this;
    }

    /// Element-wise difference of two matrices of the same shape.
    constexpr friend matrix operator-(
        matrix lhs,
        const matrix &rhs)
    {
        lhs -= rhs;
        return lhs;
    }

    /// Matrix product of an M-by-N matrix with an N-by-P matrix.
    ///
    /// @return the M-by-P matrix whose (m, p) element is the sum over n of
    ///         `lhs(m, n) * rhs(n, p)`.
    template<std::size_t P>
    constexpr friend matrix<T, M, P> operator*(
        const matrix<T, M, N> &lhs,
        const matrix<T, N, P> &rhs)
    {
        matrix<T, M, P> ret(static_cast<T>(0));
        for(std::size_t m = 0; m < M; m++) {
            for(std::size_t p = 0; p < P; p++) {
                for(std::size_t n = 0; n < N; n++) {
                    ret.at(m, p) += lhs.at(m, n) * rhs.at(n, p);
                }
            }
        }
        return ret;
    }

    /// Multiplies every element by the scalar `rhs`.
    /// @return reference to this matrix.
    constexpr matrix &operator*=(const T &rhs)
    {
        // The scalar is captured by value: `rhs` may refer to an element of
        // this very matrix, which is overwritten during the transform.
        std::ranges::transform(
            v.begin(),
            v.end(),
            v.begin(),
            [rhs](T t) -> T { return t * rhs; }
        );
        return *this;
    }

    /// Product of a matrix and a scalar.
    constexpr friend matrix operator*(
        matrix lhs,
        const T &rhs)
    {
        lhs *= rhs;
        return lhs;
    }

    /// References to the N elements of row `m`, in column order.
    /// Not bounds-checked: `m` must be less than M. The references are only
    /// valid while this matrix is alive.
    constexpr std::array<std::reference_wrapper<T>, N> row(std::size_t m)
    {
        return refs(m * N, 1, std::make_index_sequence<N>{});
    }

    /// References to the M elements of column `n`, in row order.
    /// Not bounds-checked: `n` must be less than N. The references are only
    /// valid while this matrix is alive.
    constexpr std::array<std::reference_wrapper<T>, M> col(std::size_t n)
    {
        // Consecutive elements of a column are one full row (N values) apart.
        return refs(n, N, std::make_index_sequence<M>{});
    }

    /// Access to element (m, n).
    /// The flat index `m * N + n` is checked by `std::array::at`, which throws
    /// `std::out_of_range` when it is not less than M * N.
    constexpr T &at(std::size_t m, std::size_t n)
    {
        return v.at((m * N) + n);
    }

    /// Read-only access to element (m, n); see the non-const overload.
    constexpr const T &at(std::size_t m, std::size_t n) const
    {
        return v.at((m * N) + n);
    }

    /// The underlying row-major storage.
    constexpr std::array<T, M * N> &data() noexcept { return v; }

    /// The underlying row-major storage (read-only).
    constexpr const std::array<T, M * N> &data() const noexcept { return v; }

    /// Streams the matrix: each row is preceded by a newline and every
    /// element is followed by ", ".
    constexpr friend std::ostream &operator<<(std::ostream &lhs, const matrix &rhs)
    {
        for(std::size_t m = 0; m < M; m++) {
            lhs << '\n';
            for(std::size_t n = 0; n < N; n++) {
                lhs << rhs.at(m, n) << ", ";
            }
        }
        return lhs;
    }
};
/// Square S-by-S matrix built on `matrix<T, S, S>`.
///
/// Adds an identity factory and in-place matrix multiplication. The class
/// declares no constructors of its own, so it is initialised from its base,
/// e.g. `sqmatrix<int, 2>{matrix<int, 2, 2>(0)}`.
///
/// @tparam T  arithmetic element type
/// @tparam S  number of rows and of columns
template<typename T, std::size_t S>
class sqmatrix final : public matrix<T, S, S>
{
public:
    // Keep the base class' scalar `operator*=` visible; it would otherwise be
    // hidden by the matrix overload declared below.
    using matrix<T, S, S>::operator*=;

    /// @return the S-by-S matrix with ones on the main diagonal and zeros
    ///         everywhere else.
    static constexpr sqmatrix identity()
    {
        sqmatrix ret{matrix<T, S, S>(static_cast<T>(0))};
        for(std::size_t i = 0; i < S; i++) {
            ret.at(i, i) = 1;
        }
        return ret;
    }

    /// In-place matrix product: replaces this matrix with `(*this) * rhs`.
    /// @return reference to this matrix.
    constexpr sqmatrix &operator*=(const sqmatrix &rhs)
    {
        // Work from a snapshot of the left operand: the destination cells are
        // overwritten while the old values are still needed.
        const sqmatrix lhs(*this);
        for(std::size_t i = 0; i < S; i++) {
            for(std::size_t j = 0; j < S; j++) {
                T cell = 0;
                for(std::size_t k = 0; k < S; k++)
                    cell += lhs.at(i, k) * rhs.at(k, j);
                this->at(i, j) = cell;
            }
        }
        return *this;
    }

    /// Matrix product `lhs * rhs` of two square matrices.
    constexpr friend sqmatrix operator*(
        sqmatrix lhs,
        const sqmatrix &rhs)
    {
        lhs *= rhs;
        return lhs;
    }
};
/// Single-index view of `matrix<T, L, 1>`, i.e. a matrix with L rows and one
/// column.
///
/// @tparam T  arithmetic element type
/// @tparam L  number of elements
template<typename T, std::size_t L>
class rowmatrix final : public matrix<T, L, 1>
{
public:
    /// Read-only access to element `i` of the underlying storage, which is
    /// element (i, 0) of the matrix. Not bounds-checked: `i` must be less
    /// than L.
    constexpr const T &operator[](std::size_t i) const
    {
        return this->v[i];
    }

    /// Mutable access to element `i` of the underlying storage, which is
    /// element (i, 0) of the matrix. Not bounds-checked: `i` must be less
    /// than L.
    constexpr T &operator[](std::size_t i)
    {
        return this->v[i];
    }
};
/// Single-index view of `matrix<T, 1, L>`, i.e. a matrix with one row and L
/// columns.
///
/// @tparam T  arithmetic element type
/// @tparam L  number of elements
template<typename T, std::size_t L>
class colmatrix final : public matrix<T, 1, L>
{
public:
    /// Read-only access to element `i` of the underlying storage, which is
    /// element (0, i) of the matrix. Not bounds-checked: `i` must be less
    /// than L.
    constexpr const T &operator[](std::size_t i) const
    {
        return this->v[i];
    }

    /// Mutable access to element `i` of the underlying storage, which is
    /// element (0, i) of the matrix. Not bounds-checked: `i` must be less
    /// than L.
    constexpr T &operator[](std::size_t i)
    {
        return this->v[i];
    }
};
#endif