9#include <initializer_list>
40template<
typename T =
double>
51 explicit Matrix(T val,
size_t rowCount,
size_t colCount,
size_t elementDimension = 1) {
52 Resize(rowCount, colCount, elementDimension);
61 explicit Matrix(T* val,
size_t colCount) {
75 Matrix(
const std::initializer_list<std::initializer_list<T>>& lst) {
77 auto rows = lst.size();
78 auto cols = lst.begin()->size();
80 for(
const auto& l : lst) {
81 for(
const auto& v : l) {
94 Matrix(
const std::initializer_list<std::initializer_list<std::initializer_list<T>>>& lst) {
95 int i = 0, j = 0, c = 0;
96 auto rows = lst.size();
97 auto cols = lst.begin()->size();
98 auto elems = lst.begin()->begin()->size();
100 for(
const auto& l : lst) {
101 for(
const auto& v : l) {
102 for(
const auto k : v) {
135 _data[i] =
static_cast<T
>(other(i / (index_factor), i % (index_factor), i %
_element_size));
158 Random(
size_t rows,
size_t columns,
size_t element_size = 1,
double minValue = 0.0,
double maxValue = 1.0) {
174 constexpr double two_pi = 2.0 * M_PI;
181 auto mag = sigma * sqrt(-2.0 * log(u1));
182 out.
_data[i] = mag * cos(two_pi * u2) + mu;
183 out.
_data[i + 1] = mag * sin(two_pi * u2) + mu;
193 [[nodiscard]]
inline size_t rows()
const {
return _rows; }
235 for(
size_t c = 0; c <
_columns; c++) {
237 for(
size_t i = 1; i <
_rows; i++) {
239 for(
size_t j = 0; j <
_columns; j++) {
271 assert(this->
rows() == other.
rows());
273 for(
size_t i = 0; i <
rows(); ++i) {
304 return !(rhs == *
this);
346 if(
_data[i] != rhs.
_data[i]) {
return false; }
403 for(
size_t m = 0; m <
rows(); m++) {
404 for(
size_t n = 0; n <
columns(); n++) {
405 for(
size_t p = 0; p < other.
rows(); p++) {
406 for(
size_t q = 0; q < other.
columns(); q++) {
408 (*result)(m * other.
rows() + p, n * other.
columns() + q, elem) =
437 out(axis == 0 ? i : 0, axis == 1 ? i : 0) =
438 GetSlice(axis == 0 ? i : 0, axis == 0 ? i :
_rows - 1, axis == 1 ? i : 0, axis == 1 ? i :
_columns - 1)
450 (*this) = *
this * rhs;
460 (*this) = (*this) + rhs;
469 (*this) = (*this) - rhs;
526 auto rowCount = isInColumns ? other.
columns() : other.
rows();
527 assert(rowCount ==
_rows);
529 for(
size_t i = 0; i < rowCount; ++i) {
530 for(
size_t elem = 0; elem <
elements(); ++elem) {
531 _data[
GetIndex(i, index, elem)] = other(isInColumns ? 0 : i, isInColumns ? i : 0, elem);
543 auto colCount = isInColumns ? other.
columns() : other.
rows();
546 for(
size_t i = 0; i < colCount; ++i) {
547 for(
size_t elem = 0; elem <
elements(); ++elem) {
548 _data[
GetIndex(index, i, elem)] = other(isInColumns ? 0 : i, isInColumns ? i : 0, elem);
562 for(
size_t row = 0; row < m.
rows(); row++) {
564 for(
size_t col = 0; col < m.
columns(); col++) {
565 if(m.
elements() > 1) { ostr <<
"( "; }
566 for(
size_t elem = 0; elem < m.
elements(); elem++) {
568 if(elem < m.
elements() - 1) ostr <<
", ";
570 if(m.
elements() > 1) { ostr <<
" )"; }
571 if(col < m.
columns() - 1) ostr <<
", ";
585 void Resize(
size_t rows,
size_t cols,
size_t elementSize = 1) {
590 _data = (T*)realloc(
_data,
rows * cols * elementSize *
sizeof(T));
592 _data = (T*)malloc(
rows * cols * elementSize *
sizeof(T));
605 [[nodiscard]]
inline int GetIndex(
size_t row,
size_t col,
size_t elem = 0)
const {
613 [[nodiscard]]
inline Matrix GetSlice(
size_t rowStart,
size_t rowEnd,
size_t colStart)
const {
625 [[nodiscard]]
inline Matrix GetSlice(
size_t rowStart,
size_t rowEnd,
size_t colStart,
size_t colEnd)
const {
626 size_t numRows = (rowEnd - rowStart) + 1;
627 size_t numCols = (colEnd - colStart) + 1;
631 for(
size_t i = 0; i < numRows; ++i) {
632 for(
size_t j = 0; j < numCols; ++j) {
634 out(i, j, elem) =
_data[
GetIndex(rowStart + i, colStart + j, elem)];
642 const size_t& row_start,
643 const size_t& row_end,
644 const size_t& col_start,
645 const size_t& col_end,
647 size_t numRows = (row_end - row_start) + 1;
648 size_t numCols = (col_end - col_start) + 1;
649 assert(numRows == slice.
rows());
650 assert(numCols == slice.
columns());
652 for(
size_t i = 0; i < numRows; ++i) {
653 for(
size_t j = 0; j < numCols; ++j) {
_data[
GetIndex(row_start + i, col_start + j)] = slice(i, j); }
673 for(
size_t i = 0; i <
_rows; ++i) {
683 for(
size_t i = 0; i < _indices.rows(); ++i) {
684 auto idx = _indices(i, 0);
733 for(
size_t i = 0; i < lhs.
rows(); i++) {
734 for(
size_t j = 0; j < lhs.
columns(); j++) {
735 for(
size_t elem = 0; elem < lhs.
elements(); elem++) {
736 result(i, j, elem) = lhs(i, j, elem) + rhs(row_wise ? i : 0, row_wise ? 0 : j, elem);
745 for(
size_t i = 0; i < lhs.
rows(); i++) {
746 for(
size_t j = 0; j < lhs.
columns(); j++) { result(i, j) = lhs(i, j) + rhs(i, j); }
764 for(
size_t i = 0; i < lhs.
rows(); i++) {
765 for(
size_t j = 0; j < lhs.
columns(); j++) {
766 for(
size_t elem = 0; elem < lhs.
elements(); elem++) {
776 for(
size_t i = 0; i < lhs.
rows(); i++) {
777 for(
size_t j = 0; j < lhs.
columns(); j++) {
778 for(
size_t elem = 0; elem < lhs.
elements(); elem++) { result(i, j, elem) = lhs(i, j, elem) - rhs(i, j, elem); }
790template<
typename T,
typename U>
793 for(
size_t i = 0; i < rhs.
rows(); i++) {
794 for(
size_t j = 0; j < rhs.
columns(); j++) {
795 for(
size_t elem = 0; elem < rhs.
elements(); elem++) { result(i, j, elem) = lhs / rhs(i, j, elem); }
807template<
typename T,
typename U>
810 for(
size_t i = 0; i < lhs.
rows(); i++) {
811 for(
size_t j = 0; j < lhs.
columns(); j++) {
812 for(
size_t elem = 0; elem < lhs.
elements(); elem++) { result(i, j, elem) = lhs(i, j, elem) / rhs; }
843 for(
size_t i = 0; i < lhs.
rows(); i++) {
844 for(
size_t j = 0; j < lhs.
columns(); j++) {
845 for(
size_t elem = 0; elem < lhs.
elements(); elem++) {
846 result(i, j, elem) = lhs(i, j, elem) / rhs(row_wise ? i : 0, row_wise ? 0 : j, elem);
853 return lhs * (1.0 / rhs);
862template<
typename T,
typename U>
865 for(
size_t i = 0; i < lhs.
rows(); i++) {
866 for(
size_t j = 0; j < lhs.
columns(); j++) {
867 for(
size_t elem = 0; elem < lhs.
elements(); elem++) { result(i, j, elem) = lhs(i, j, elem) * rhs; }
878template<
typename T,
typename U>
881 for(
size_t i = 0; i < A.
rows(); i++) {
882 for(
size_t j = 0; j < A.
columns(); j++) {
883 for(
size_t elem = 0; elem < A.
elements(); elem++) { result(i, j, elem) = A(i, j, elem) * lambda; }
900 for(
size_t i = 0; i < lhs.
rows(); i++) {
901 for(
size_t j = 0; j < rhs.
columns(); j++) {
902 for(
size_t k = 0; k < rhs.
rows(); k++) {
903 for(
size_t elem = 0; elem < rhs.
elements(); elem++) {
904 result(i, j, elem) += (T)(lhs(i, k, elem) * rhs(k, j, elem));
915 for(
size_t i = 0; i < lhs.
rows(); i++) {
916 for(
size_t j = 0; j < lhs.
columns(); j++) {
917 for(
size_t k = 0; k < lhs.
elements(); k++) {
918 result(i, j, k) = (T)(lhs(i, j, k) * rhs(row_wise ? i : 0, row_wise ? 0 : j, k));
Matrix< T > operator-(const Matrix< T > &lhs, const Matrix< T > &rhs)
Definition: Matrix.h:758
Matrix< T > operator/(U lhs, const Matrix< T > &rhs)
Definition: Matrix.h:791
Matrix< T > operator*(const Matrix< T > &lhs, const U &rhs)
Definition: Matrix.h:863
Matrix< T > operator+(const Matrix< T > &lhs, const Matrix< T > &rhs)
Definition: Matrix.h:724
Matrix(Matrix const &other)
Definition: Matrix.h:118
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
void SetRow(size_t index, const Matrix< T > &other)
Definition: Matrix.h:541
void SetSlice(const size_t &row_start, const size_t &row_end, const size_t &col_start, const size_t &col_end, const Matrix< T > &slice)
Definition: Matrix.h:641
Matrix(const std::initializer_list< std::initializer_list< T > > &lst)
Definition: Matrix.h:75
void SetSlice(const size_t &row_start, const Matrix< T > &slice)
Definition: Matrix.h:661
Matrix(const std::initializer_list< std::initializer_list< std::initializer_list< T > > > &lst)
Definition: Matrix.h:94
bool HasDet() const
Definition: Matrix.h:696
Matrix GetSlice(size_t rowStart, size_t rowEnd) const
Definition: Matrix.h:610
Matrix< T > & operator*=(T rhs)
Definition: Matrix.h:449
Matrix(const Matrix< V > &other)
Definition: Matrix.h:131
Matrix & HadamardMulti(const Matrix &other)
Definition: Matrix.h:389
T & operator*()
Definition: Matrix.h:512
Matrix< T > & operator+=(const Matrix< T > &rhs)
Definition: Matrix.h:459
bool operator==(const Matrix< T > &rhs) const
Definition: Matrix.h:292
static Matrix Normal(size_t rows, size_t columns, double mu, double sigma)
Definition: Matrix.h:171
size_t elements_total() const
Definition: Matrix.h:210
bool operator<(const Matrix< T > &rhs) const
Definition: Matrix.h:307
Matrix GetSlice(size_t rowStart, size_t rowEnd, size_t colStart, size_t colEnd) const
Definition: Matrix.h:625
bool operator!=(const Matrix< T > &rhs) const
Definition: Matrix.h:303
Matrix< T > & operator-=(const Matrix< T > &rhs)
Definition: Matrix.h:468
size_t rows() const
Definition: Matrix.h:193
Matrix(T val, size_t rowCount, size_t colCount, size_t elementDimension=1)
Definition: Matrix.h:51
T Determinant() const
Definition: Matrix.h:216
Matrix< T > operator()(size_t row)
Definition: Matrix.h:500
size_t columns() const
Definition: Matrix.h:198
Matrix GetSlice(size_t rowStart) const
Definition: Matrix.h:609
Matrix< T > operator=(const Matrix< T > &other)
Definition: Matrix.h:360
void assertSize(const Matrix< T > &other) const
Definition: Matrix.h:334
T & operator()(size_t row, size_t column, size_t elem=0) const
Definition: Matrix.h:490
bool needsFree
Definition: Matrix.h:710
T & operator()(size_t row, size_t column, size_t elem=0)
Definition: Matrix.h:482
size_t _element_size
number elements
Definition: Matrix.h:703
size_t _columns
number columns
Definition: Matrix.h:701
~Matrix()
Definition: Matrix.h:144
Matrix GetSlice(size_t rowStart, size_t rowEnd, size_t colStart) const
Definition: Matrix.h:613
void Resize(size_t rows, size_t cols, size_t elementSize=1)
Definition: Matrix.h:585
Matrix()
Definition: Matrix.h:69
Matrix< T > & KroneckerMulti(const Matrix< T > &other)
Definition: Matrix.h:400
T * _data
ongoing array representing data
Definition: Matrix.h:706
Matrix< T > GetComponents(const size_t &index) const
Definition: Matrix.h:670
Matrix< T > operator()(size_t row) const
Definition: Matrix.h:506
Matrix< T > Apply(const std::function< T(T)> &fun) const
Definition: Matrix.h:375
static Matrix Random(size_t rows, size_t columns, size_t element_size=1, double minValue=0.0, double maxValue=1.0)
Definition: Matrix.h:158
Matrix< T > sum(size_t axis) const
Definition: Matrix.h:434
friend std::ostream & operator<<(std::ostream &ostr, const Matrix &m)
Definition: Matrix.h:559
size_t _dataSize
total number of elements
Definition: Matrix.h:708
size_t elements() const
Definition: Matrix.h:204
Matrix(T *val, size_t colCount)
Definition: Matrix.h:61
Matrix< T > GetSlicesByIndex(const Matrix< size_t > &indices) const
Definition: Matrix.h:679
T & operator*() const
Definition: Matrix.h:517
bool operator>(const Matrix< T > &rhs) const
Definition: Matrix.h:315
int GetIndex(size_t row, size_t col, size_t elem=0) const
Definition: Matrix.h:605
size_t _rows
number rows
Definition: Matrix.h:699
void SetColumn(size_t index, const Matrix< T > &other)
Definition: Matrix.h:524
T sumElements() const
Definition: Matrix.h:422
bool IsVector() const
Definition: Matrix.h:328
Matrix< T > HorizontalConcat(const Matrix< T > &other)
Definition: Matrix.h:270
bool elementWiseCompare(const Matrix< T > &rhs) const
Definition: Matrix.h:343
static double Get(double l=0.0, double r=1.0)
Definition: Random.h:40
size_t elemDim
number elements
Definition: Matrix.h:26
size_t rows
number rows
Definition: Matrix.h:22
size_t columns
number columns
Definition: Matrix.h:24