summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--matrix.hpp202
1 files changed, 202 insertions, 0 deletions
diff --git a/matrix.hpp b/matrix.hpp
new file mode 100644
index 0000000..ada303b
--- /dev/null
+++ b/matrix.hpp
@@ -0,0 +1,202 @@
+#ifndef MATRIX_HPP
+#define MATRIX_HPP 1
+
+#include <array>
+#include <algorithm>
+#include <span>
+#include <ostream>
+
+template<typename T, std::size_t M, std::size_t N>
+ requires(std::is_arithmetic_v<T>)
+class matrix
+{
+protected:
+ std::array<T, N * M> v;
+public:
+ constexpr matrix(const std::array<std::array<T, N>, M> &vs)
+ {
+ for(std::size_t m = 0; m < M; m++)
+ for(std::size_t n = 0; n < N; n++) {
+ v[(m * N) + n] = vs[m][n];
+ }
+ }
+ constexpr matrix(const std::array<T, M * N> &vs) :
+ v(vs)
+ {}
+
+ constexpr matrix(T a) {
+ v.fill(a);
+ }
+
+ constexpr matrix &operator+=(const matrix &rhs)
+ {
+ for(std::size_t i = 0; i < M * N; i++)
+ v[i] += rhs.v[i];
+ return v;
+ }
+
+ constexpr friend matrix operator+(
+ matrix lhs,
+ const matrix &rhs)
+ {
+ lhs += rhs;
+ return lhs;
+ }
+
+ constexpr friend matrix operator-(matrix lhs)
+ {
+ for(std::size_t i = 0; i < M * N; i++)
+ lhs.v[i] = -lhs.v[i];
+ return lhs;
+ }
+
+ constexpr matrix &operator-=(const matrix &rhs)
+ {
+ for(std::size_t i = 0; i < M * N; i++)
+ v[i] -= rhs.v[i];
+ return *this;
+ }
+
+ constexpr friend matrix operator-(
+ matrix lhs,
+ const matrix &rhs)
+ {
+ lhs -= rhs;
+ return lhs;
+ }
+
+ template<std::size_t P>
+ constexpr friend matrix<T, M, P> operator*(
+ const matrix<T, M, N> &rhs,
+ const matrix<T, N, P> &lhs)
+ {
+ matrix<T, M, P> ret(0);
+ for(std::size_t p = 0; p < P; p++) {
+ for(std::size_t m = 0; m < M; m++) {
+ for(std::size_t n = 0; n < N; n++) {
+ ret.at(m, p) += rhs.at(m, n) * lhs.at(n, p);
+ }
+ }
+ }
+ return ret;
+ }
+
+ constexpr matrix &operator*=(const T &rhs)
+ {
+ std::ranges::transform(
+ v.begin(),
+ v.end(),
+ v.begin(),
+ [rhs](T t) -> T { return t * rhs; }
+ );
+ return *this;
+ }
+
+ constexpr friend matrix operator*(
+ matrix lhs,
+ const T &rhs)
+ {
+ lhs *= rhs;
+ return lhs;
+ }
+
+ constexpr std::array<std::reference_wrapper<T>, N> row(std::size_t m)
+ {
+ std::array<std::reference_wrapper<T>, N> ret;
+ for(std::size_t n = 0; n < N; n++) ret[n] = v[(m * N) + n];
+ return ret;
+ }
+
+ constexpr std::array<std::reference_wrapper<T>, M> col(std::size_t n)
+ {
+ std::array<std::reference_wrapper<T>, M> ret;
+ for(std::size_t m = 0; m < M; m++) ret[m] = v[(m * M) + n];
+ return ret;
+ }
+
+ constexpr T &at(std::size_t m, std::size_t n)
+ {
+ return v.at((m * N) + n);
+ }
+ constexpr const T &at(std::size_t m, std::size_t n) const
+ {
+ return v.at((m * N) + n);
+ }
+
+ constexpr std::array<T, M * N> &data() noexcept { return v; }
+ constexpr std::array<T, M * N> &data() const noexcept { return v; }
+
+ constexpr friend std::ostream &operator<<(std::ostream &lhs, const matrix &rhs)
+ {
+ for(std::size_t m = 0; m < M; m++) {
+ lhs << '\n';
+ for(std::size_t n = 0; n < N; n++) {
+ lhs << rhs.at(m, n) << ", ";
+ }
+ }
+ return lhs;
+ }
+};
+
+template<typename T, std::size_t S>
+class sqmatrix final : public matrix<T, S, S>
+{
+ static constexpr sqmatrix identity()
+ {
+ sqmatrix ret(0);
+ for(std::size_t i = 0; i < S; i++) {
+ ret.at(i, i) = 1;
+ }
+ return ret;
+ }
+
+ constexpr sqmatrix operator*=(const sqmatrix &rhs)
+ {
+ sqmatrix copy(*this);
+ for(std::size_t i = 0; i < S; i++) {
+ for(std::size_t j = 0; j < S; j++) {
+ T cell = 0;
+ for(std::size_t k = 0; k < S; k++)
+ cell += copy.at(k, j) * rhs.at(i, k);
+ this->at(i, j) = cell;
+ }
+ }
+ return *this;
+ }
+
+ constexpr friend sqmatrix operator*(
+ sqmatrix lhs,
+ const sqmatrix &rhs)
+ {
+ lhs *= rhs;
+ return lhs;
+ }
+};
+
+template<typename T, std::size_t L>
+class rowmatrix final : public matrix<T, L, 1>
+{
+ constexpr T &operator[](std::size_t i) const
+ {
+ return this->v[i];
+ }
+ constexpr T &operator[](std::size_t i)
+ {
+ return this->v[i];
+ }
+};
+
+template<typename T, std::size_t L>
+class colmatrix final : public matrix<T, 1, L>
+{
+ constexpr T &operator[](std::size_t i) const
+ {
+ return this->v[i];
+ }
+ constexpr T &operator[](std::size_t i)
+ {
+ return this->v[i];
+ }
+};
+
+#endif