summaryrefslogblamecommitdiffstats
path: root/matrix.hpp
blob: ada303bacdd17f68658f6ac4afce2a5db48989e7 (plain) (tree)









































































































































































































                                                                                   
#ifndef MATRIX_HPP
#define MATRIX_HPP 1

#include <algorithm>
#include <array>
#include <cstddef>
#include <functional>
#include <ostream>
#include <span>
#include <stdexcept>
#include <type_traits>
#include <utility>

/// Dense M x N matrix of arithmetic values.
///
/// Elements live in a single std::array in row-major order: element (m, n)
/// is stored at index m * N + n.
///
/// @tparam T element type (constrained to arithmetic types).
/// @tparam M number of rows.
/// @tparam N number of columns.
template<typename T, std::size_t M, std::size_t N>
    requires(std::is_arithmetic_v<T>)
class matrix
{
protected:
    /// Row-major element storage; element (m, n) is at index m * N + n.
    std::array<T, N * M> v;
public:
    /// Builds the matrix from M rows of N values each.
    constexpr matrix(const std::array<std::array<T, N>, M> &vs)
    {
        for(std::size_t m = 0; m < M; m++)
        for(std::size_t n = 0; n < N; n++) {
            v[(m * N) + n] = vs[m][n];
        }
    }

    /// Builds the matrix from M * N values given in row-major order.
    constexpr matrix(const std::array<T, M * N> &vs) :
        v(vs)
    {}

    /// Builds a matrix with every element equal to `a`.
    /// Not `explicit`, so a scalar converts implicitly to a constant-filled
    /// matrix wherever a matrix is expected (e.g. `m + T(1)`).
    constexpr matrix(T a) {
        v.fill(a);
    }

    /// Element-wise addition in place.
    /// @return *this, to allow chaining.
    constexpr matrix &operator+=(const matrix &rhs)
    {
        for(std::size_t i = 0; i < M * N; i++)
            v[i] += rhs.v[i];
        return *this;
    }

    /// Element-wise sum of two equally sized matrices.
    constexpr friend matrix operator+(
            matrix lhs,
            const matrix &rhs)
    {
        lhs += rhs;
        return lhs;
    }

    /// Element-wise negation.
    constexpr friend matrix operator-(matrix lhs)
    {
        for(std::size_t i = 0; i < M * N; i++)
            lhs.v[i] = -lhs.v[i];
        return lhs;
    }

    /// Element-wise subtraction in place.
    /// @return *this, to allow chaining.
    constexpr matrix &operator-=(const matrix &rhs)
    {
        for(std::size_t i = 0; i < M * N; i++)
            v[i] -= rhs.v[i];
        return *this;
    }

    /// Element-wise difference of two equally sized matrices.
    constexpr friend matrix operator-(
            matrix lhs,
            const matrix &rhs)
    {
        lhs -= rhs;
        return lhs;
    }

    /// Matrix product: (M x N) * (N x P) -> (M x P).
    /// Each output cell sums lhs(m, n) * rhs(n, p) over n in increasing order.
    template<std::size_t P>
    constexpr friend matrix<T, M, P> operator*(
            const matrix<T, M, N> &lhs,
            const matrix<T, N, P> &rhs)
    {
        matrix<T, M, P> ret(0);
        // Visit output cells in row-major order so `ret` (and each row of
        // `lhs`) is walked in storage order.
        for(std::size_t m = 0; m < M; m++) {
            for(std::size_t p = 0; p < P; p++) {
                T cell = 0;
                for(std::size_t n = 0; n < N; n++)
                    cell += lhs.at(m, n) * rhs.at(n, p);
                ret.at(m, p) = cell;
            }
        }
        return ret;
    }

    /// Multiplies every element by the scalar `rhs` in place.
    /// @return *this, to allow chaining.
    constexpr matrix &operator*=(const T &rhs)
    {
        std::ranges::transform(
                v.begin(),
                v.end(),
                v.begin(),
                [rhs](T t) -> T { return t * rhs; }
            );
        return *this;
    }

    /// Returns a copy of `lhs` with every element multiplied by `rhs`.
    constexpr friend matrix operator*(
            matrix lhs,
            const T &rhs)
    {
        lhs *= rhs;
        return lhs;
    }

    /// Returns references to the N elements of row `m`, left to right.
    /// @throws std::out_of_range if m >= M.
    constexpr std::array<std::reference_wrapper<T>, N> row(std::size_t m)
    {
        if(m >= M) throw std::out_of_range("matrix::row");
        // Row elements are contiguous: indices m*N, m*N + 1, ...
        return refs(m * N, 1, std::make_index_sequence<N>{});
    }

    /// Returns references to the M elements of column `n`, top to bottom.
    /// @throws std::out_of_range if n >= N.
    constexpr std::array<std::reference_wrapper<T>, M> col(std::size_t n)
    {
        if(n >= N) throw std::out_of_range("matrix::col");
        // Column elements are one row (N slots) apart: n, n + N, n + 2N, ...
        return refs(n, N, std::make_index_sequence<M>{});
    }

    /// Bounds-checked element access.
    /// @throws std::out_of_range if (m * N + n) is outside the storage.
    constexpr T &at(std::size_t m, std::size_t n)
    {
        return v.at((m * N) + n);
    }
    /// Bounds-checked read-only element access.
    /// @throws std::out_of_range if (m * N + n) is outside the storage.
    constexpr const T &at(std::size_t m, std::size_t n) const
    {
        return v.at((m * N) + n);
    }

    /// Direct access to the row-major storage.
    constexpr std::array<T, M * N> &data() noexcept { return v; }
    /// Read-only access to the row-major storage.
    constexpr const std::array<T, M * N> &data() const noexcept { return v; }

    /// Writes the matrix to `lhs`: each row is preceded by '\n' and every
    /// element is followed by ", ".
    friend std::ostream &operator<<(std::ostream &lhs, const matrix &rhs)
    {
        for(std::size_t m = 0; m < M; m++) {
            lhs << '\n';
            for(std::size_t n = 0; n < N; n++) {
                lhs << rhs.at(m, n) << ", ";
            }
        }
        return lhs;
    }

private:
    /// Collects references to v[first + I * stride] for each I in the pack.
    /// std::reference_wrapper has no default constructor, so the result array
    /// must be built in one aggregate initialization.
    template<std::size_t... I>
    constexpr std::array<std::reference_wrapper<T>, sizeof...(I)> refs(
            std::size_t first,
            std::size_t stride,
            std::index_sequence<I...>)
    {
        return {std::ref(v[first + (I * stride)])...};
    }
};

/// Square S x S matrix: adds an identity factory and in-place matrix
/// multiplication on top of matrix<T, S, S>.
template<typename T, std::size_t S>
class sqmatrix final : public matrix<T, S, S>
{
public:
    // Reuse every matrix<T, S, S> constructor; sqmatrix declares none of its
    // own, and identity() below needs the scalar-fill constructor.
    using matrix<T, S, S>::matrix;
    // The sqmatrix overload of operator*= below would otherwise hide the
    // base-class scalar overload.
    using matrix<T, S, S>::operator*=;

    /// Returns the S x S identity matrix.
    static constexpr sqmatrix identity()
    {
        sqmatrix ret(0);
        for(std::size_t i = 0; i < S; i++) {
            ret.at(i, i) = 1;
        }
        return ret;
    }

    /// In-place matrix product: *this = *this * rhs.
    /// @return *this, to allow chaining.
    constexpr sqmatrix &operator*=(const sqmatrix &rhs)
    {
        // Accumulate into a separate result so neither operand is read after
        // being partially overwritten (this also covers `a *= a`).
        sqmatrix result(0);
        for(std::size_t i = 0; i < S; i++) {
            for(std::size_t j = 0; j < S; j++) {
                T cell = 0;
                for(std::size_t k = 0; k < S; k++)
                    cell += this->at(i, k) * rhs.at(k, j);
                result.at(i, j) = cell;
            }
        }
        *this = result;
        return *this;
    }

    /// Matrix product of two square matrices: lhs * rhs.
    constexpr friend sqmatrix operator*(
            sqmatrix lhs,
            const sqmatrix &rhs)
    {
        lhs *= rhs;
        return lhs;
    }

    /// Returns a copy of `lhs` with every element multiplied by `rhs`.
    /// Declared here so `sq * scalar` keeps the sqmatrix type and is not
    /// ambiguous between the overload above (reachable through the inherited
    /// scalar constructor) and the base-class scalar operator*.
    constexpr friend sqmatrix operator*(
            sqmatrix lhs,
            const T &rhs)
    {
        lhs *= rhs;
        return lhs;
    }
};

/// Matrix of shape L x 1 (matrix<T, L, 1>) that also supports flat
/// indexing with operator[].
/// NOTE(review): despite the name, the shape is L rows by one column;
/// confirm the intended orientation against callers.
template<typename T, std::size_t L>
class rowmatrix final : public matrix<T, L, 1>
{
public:
    // Reuse the matrix<T, L, 1> constructors; rowmatrix declares none itself.
    using matrix<T, L, 1>::matrix;

    /// Read-only access to element i (std::array::operator[], unchecked).
    constexpr const T &operator[](std::size_t i) const
    {
        return this->v[i];
    }

    /// Mutable access to element i (std::array::operator[], unchecked).
    constexpr T &operator[](std::size_t i)
    {
        return this->v[i];
    }
};

/// Matrix of shape 1 x L (matrix<T, 1, L>) that also supports flat
/// indexing with operator[].
/// NOTE(review): despite the name, the shape is one row by L columns;
/// confirm the intended orientation against callers.
template<typename T, std::size_t L>
class colmatrix final : public matrix<T, 1, L>
{
public:
    // Reuse the matrix<T, 1, L> constructors; colmatrix declares none itself.
    using matrix<T, 1, L>::matrix;

    /// Read-only access to element i (std::array::operator[], unchecked).
    constexpr const T &operator[](std::size_t i) const
    {
        return this->v[i];
    }

    /// Mutable access to element i (std::array::operator[], unchecked).
    constexpr T &operator[](std::size_t i)
    {
        return this->v[i];
    }
};

#endif