diff options
-rw-r--r-- | matrix.hpp | 202 |
1 files changed, 202 insertions, 0 deletions
diff --git a/matrix.hpp b/matrix.hpp new file mode 100644 index 0000000..ada303b --- /dev/null +++ b/matrix.hpp @@ -0,0 +1,202 @@ +#ifndef MATRIX_HPP +#define MATRIX_HPP 1 + +#include <array> +#include <algorithm> +#include <span> +#include <ostream> + +template<typename T, std::size_t M, std::size_t N> + requires(std::is_arithmetic_v<T>) +class matrix +{ +protected: + std::array<T, N * M> v; +public: + constexpr matrix(const std::array<std::array<T, N>, M> &vs) + { + for(std::size_t m = 0; m < M; m++) + for(std::size_t n = 0; n < N; n++) { + v[(m * N) + n] = vs[m][n]; + } + } + constexpr matrix(const std::array<T, M * N> &vs) : + v(vs) + {} + + constexpr matrix(T a) { + v.fill(a); + } + + constexpr matrix &operator+=(const matrix &rhs) + { + for(std::size_t i = 0; i < M * N; i++) + v[i] += rhs.v[i]; + return v; + } + + constexpr friend matrix operator+( + matrix lhs, + const matrix &rhs) + { + lhs += rhs; + return lhs; + } + + constexpr friend matrix operator-(matrix lhs) + { + for(std::size_t i = 0; i < M * N; i++) + lhs.v[i] = -lhs.v[i]; + return lhs; + } + + constexpr matrix &operator-=(const matrix &rhs) + { + for(std::size_t i = 0; i < M * N; i++) + v[i] -= rhs.v[i]; + return *this; + } + + constexpr friend matrix operator-( + matrix lhs, + const matrix &rhs) + { + lhs -= rhs; + return lhs; + } + + template<std::size_t P> + constexpr friend matrix<T, M, P> operator*( + const matrix<T, M, N> &rhs, + const matrix<T, N, P> &lhs) + { + matrix<T, M, P> ret(0); + for(std::size_t p = 0; p < P; p++) { + for(std::size_t m = 0; m < M; m++) { + for(std::size_t n = 0; n < N; n++) { + ret.at(m, p) += rhs.at(m, n) * lhs.at(n, p); + } + } + } + return ret; + } + + constexpr matrix &operator*=(const T &rhs) + { + std::ranges::transform( + v.begin(), + v.end(), + v.begin(), + [rhs](T t) -> T { return t * rhs; } + ); + return *this; + } + + constexpr friend matrix operator*( + matrix lhs, + const T &rhs) + { + lhs *= rhs; + return lhs; + } + + constexpr std::array<std::reference_wrapper<T>, N> row(std::size_t m) + { + std::array<std::reference_wrapper<T>, N> ret; + for(std::size_t n = 0; n < N; n++) ret[n] = v[(m * N) + n]; + return ret; + } + + constexpr std::array<std::reference_wrapper<T>, M> col(std::size_t n) + { + std::array<std::reference_wrapper<T>, M> ret; + for(std::size_t m = 0; m < M; m++) ret[m] = v[(m * M) + n]; + return ret; + } + + constexpr T &at(std::size_t m, std::size_t n) + { + return v.at((m * N) + n); + } + constexpr const T &at(std::size_t m, std::size_t n) const + { + return v.at((m * N) + n); + } + + constexpr std::array<T, M * N> &data() noexcept { return v; } + constexpr std::array<T, M * N> &data() const noexcept { return v; } + + constexpr friend std::ostream &operator<<(std::ostream &lhs, const matrix &rhs) + { + for(std::size_t m = 0; m < M; m++) { + lhs << '\n'; + for(std::size_t n = 0; n < N; n++) { + lhs << rhs.at(m, n) << ", "; + } + } + return lhs; + } +}; + +template<typename T, std::size_t S> +class sqmatrix final : public matrix<T, S, S> +{ + static constexpr sqmatrix identity() + { + sqmatrix ret(0); + for(std::size_t i = 0; i < S; i++) { + ret.at(i, i) = 1; + } + return ret; + } + + constexpr sqmatrix operator*=(const sqmatrix &rhs) + { + sqmatrix copy(*this); + for(std::size_t i = 0; i < S; i++) { + for(std::size_t j = 0; j < S; j++) { + T cell = 0; + for(std::size_t k = 0; k < S; k++) + cell += copy.at(k, j) * rhs.at(i, k); + this->at(i, j) = cell; + } + } + return *this; + } + + constexpr friend sqmatrix operator*( + sqmatrix lhs, + const sqmatrix &rhs) + { + lhs *= rhs; + return lhs; + } +}; + +template<typename T, std::size_t L> +class rowmatrix final : public matrix<T, L, 1> +{ + constexpr T &operator[](std::size_t i) const + { + return this->v[i]; + } + constexpr T &operator[](std::size_t i) + { + return this->v[i]; + } +}; + +template<typename T, std::size_t L> +class colmatrix final : public matrix<T, 1, L> +{ + constexpr T &operator[](std::size_t i) const + { + return this->v[i]; + } + constexpr T &operator[](std::size_t i) + { + return this->v[i]; + } +}; + +#endif |