philsupertramp/game-math
Loading...
Searching...
No Matches
matrix_utils.h
Go to the documentation of this file.
1#pragma once
2
3#include "Matrix.h"
4
16template<typename T>
17Matrix<T> HadamardMulti(const Matrix<T>& lhs, const Matrix<T>& rhs) {
18 lhs.assertSize(rhs);
19 auto result = Matrix<T>(0, lhs.rows(), lhs.columns(), lhs.elements());
20 for(size_t i = 0; i < result.rows(); i++) {
21 for(size_t j = 0; j < result.columns(); j++) {
22 for(size_t elem = 0; elem < result.elements(); elem++) { result(i, j, elem) = lhs(i, j, elem) * rhs(i, j, elem); }
23 }
24 }
25 return result;
26}
27
35template<typename T>
36Matrix<T> HadamardDiv(const Matrix<T>& lhs, const Matrix<T>& rhs) {
37 lhs.assertSize(rhs);
38 auto result = Matrix<T>(0, lhs.rows(), lhs.columns(), lhs.elements());
39 for(size_t i = 0; i < result.rows(); i++) {
40 for(size_t j = 0; j < result.columns(); j++) {
41 for(size_t elem = 0; elem < result.elements(); elem++) { result(i, j, elem) = lhs(i, j, elem) / rhs(i, j, elem); }
42 }
43 }
44 return result;
45}
46
55template<typename T>
57 assert(lhs.elements() == rhs.elements());
58 auto result = Matrix<T>(0, lhs.rows() * rhs.rows(), lhs.columns() * rhs.columns(), rhs.elements());
59 for(size_t m = 0; m < lhs.rows(); m++) {
60 for(size_t n = 0; n < lhs.columns(); n++) {
61 for(size_t p = 0; p < rhs.rows(); p++) {
62 for(size_t q = 0; q < rhs.columns(); q++) {
63 for(size_t elem = 0; elem < rhs.elements(); elem++) {
64 result(m * rhs.rows() + p, n * rhs.columns() + q, elem) = lhs(m, n, elem) * rhs(p, q, elem);
65 }
66 }
67 }
68 }
69 }
70 return result;
71}
72
80template<typename T>
82 assert(lhs.rows() == rhs.rows());
83 assert(lhs.elements() == rhs.elements());
84 auto result = Matrix<T>(0.0, lhs.rows(), lhs.columns() + rhs.columns(), lhs.elements());
85 for(size_t i = 0; i < lhs.rows(); ++i) {
86 for(size_t j = 0; j < lhs.columns() + rhs.columns(); ++j) {
87 for(size_t elem = 0; elem < lhs.elements(); ++elem) {
88 result(i, j, elem) = j < lhs.columns() ? lhs(i, j, elem) : rhs(i, j - lhs.columns(), elem);
89 }
90 }
91 }
92 return result;
93}
94
102template<typename T>
103size_t Corr(const Matrix<T>& A, const Matrix<T>& B) {
104 A.assertSize(B);
105 size_t count = 0;
106 for(size_t i = 0; i < A.rows(); i++) {
107 for(size_t j = 0; j < A.columns(); j++) {
108 for(size_t elem = 0; elem < A.elements(); elem++) { count += (A(i, j, elem) == B(i, j, elem)); }
109 }
110 }
111 return count;
112}
113
121template<typename T>
122Matrix<T> from_vptr(const T* value, MatrixDimension size) {
123 auto out = Matrix<T>(0, size.rows, size.columns);
124 for(size_t i = 0; i < size.rows; i++) {
125 for(size_t j = 0; j < size.columns; j++) {
126 for(size_t elem = 0; elem < size.elemDim; elem++) {
127 out(i, j, elem) = value[elem + j * size.elemDim + i * size.columns * size.elemDim];
128 }
129 }
130 }
131 return out;
132}
133
142template<typename T>
143size_t argmax(const Matrix<T>& mat) {
144 T maxVal = std::numeric_limits<T>::min();
145 size_t maxIndex = -1;
146 for(size_t i = 0; i < mat.rows(); i++) {
147 for(size_t j = 0; j < mat.columns(); j++) {
148 for(size_t elem = 0; elem < mat.elements(); elem++) {
149 if(mat(i, j, elem) > maxVal) {
150 maxVal = mat(i, j, elem);
151 maxIndex = elem + j * mat.elements() + i * mat.columns() * mat.elements();
152 }
153 }
154 }
155 }
156 return maxIndex;
157}
158
166template<typename T>
167size_t argmin(const Matrix<T>& mat) {
168 T maxVal = std::numeric_limits<T>::max();
169 size_t maxIndex = -1;
170 for(size_t i = 0; i < mat.rows(); i++) {
171 for(size_t j = 0; j < mat.columns(); j++) {
172 for(size_t elem = 0; elem < mat.elements(); ++elem) {
173 if(mat(i, j, elem) < maxVal) {
174 maxVal = mat(i, j, elem);
175 maxIndex = elem + j * mat.elements() + i * mat.columns() * mat.elements();
176 }
177 }
178 }
179 }
180 return maxIndex;
181}
182
183
193template<typename T>
195const std::function<bool(T)>& condition, const Matrix<T>& in, const Matrix<T>& valIfTrue, const Matrix<T>& valIfFalse) {
196 assert(valIfTrue.columns() == valIfFalse.columns() && valIfTrue.rows() == valIfFalse.rows());
197 bool refVector = true;
198 if((valIfTrue.columns() == valIfTrue.rows()) == 1) { refVector = false; }
199 auto out = refVector ? valIfTrue : Matrix<T>(0, in.rows(), in.columns(), in.elements());
200
201 for(size_t i = 0; i < in.rows(); i++) {
202 for(size_t j = 0; j < in.columns(); j++) {
203 for(size_t elem = 0; elem < in.elements(); elem++) {
204 if(refVector) {
205 if(!condition(in(i, j, elem))) out(i, j, elem) = valIfFalse(i, j);
206 else
207 out(i, j, elem) = in(i, j, elem);
208 } else {
209 out(i, j, elem) = condition(in(i, j, elem)) ? valIfTrue(0, 0) : valIfFalse(0, 0);
210 }
211 }
212 }
213 }
214 return out;
215}
216
224template<typename T>
226 assert(in.IsVector());
227 bool requires_transposition = in.rows() < in.columns();
228 Matrix<size_t> out = Matrix<size_t>(0, !requires_transposition ? in.rows() : in.columns(), 1);
229 size_t found_vals = 0;
230 for(size_t i = 0; i < (!requires_transposition ? in.rows() : in.columns()); ++i) {
231 if(in(requires_transposition ? 0 : i, requires_transposition ? i : 0) == value) {
232 out(found_vals, 0) = i;
233 found_vals++;
234 }
235 }
236 return requires_transposition ? out.GetSlice(0, found_vals - 1).Transpose() : out.GetSlice(0, found_vals - 1);
237}
245template<typename T>
247 return where_value(in, 1.0);
248}
256template<typename T>
258 return where_value(in, 0.0);
259}
260
261
270template<typename T>
271std::vector<std::pair<Matrix<T>, Matrix<T>>> zip(const Matrix<T>& a, const Matrix<T>& b) {
272 std::vector<std::pair<Matrix<T>, Matrix<T>>> out(a.rows());
273 for(size_t i = 0; i < a.rows(); i++) {
274 Matrix<T> subA, subB;
275 subA.Resize(1, a.columns());
276 subB.Resize(1, b.columns());
277 for(size_t j = 0; j < a.columns(); j++) { subA(0, j) = a(i, j); }
278 for(size_t j = 0; j < b.columns(); j++) { subB(0, j) = b(i, j); }
279
280 out[i] = { subA, subB };
281 }
282 return out;
283}
284
291template<typename T>
292T max(const Matrix<T>& mat) {
293 T maxVal = std::numeric_limits<T>::min();
294 for(size_t i = 0; i < mat.rows(); i++) {
295 for(size_t j = 0; j < mat.columns(); j++) {
296 for(size_t k = 0; k < mat.elements(); k++) {
297 if(mat(i, j, k) > maxVal) { maxVal = mat(i, j, k); }
298 }
299 }
300 }
301 return maxVal;
302}
303
304
305template<typename T>
306Matrix<T> max(const Matrix<T>& mat, int axis) {
307 bool row_wise = axis == 0;
308 Matrix<T> out = Matrix<T>(0, row_wise ? 1 : mat.rows(), row_wise ? mat.columns() : 1);
309
310 for(size_t i = 0; i < (row_wise ? mat.columns() : mat.rows()); i++) {
311 out(row_wise ? 0 : i, row_wise ? i : 0) = elemMax(
312 mat.GetSlice(row_wise ? 0 : i, row_wise ? mat.rows() - 1 : i, row_wise ? i : 0, row_wise ? i : mat.columns() - 1),
313 0);
314 }
315 return out;
316}
323template<typename T>
324T min(const Matrix<T>& mat) {
325 T minVal = std::numeric_limits<T>::max();
326 for(size_t i = 0; i < mat.rows(); i++) {
327 for(size_t j = 0; j < mat.columns(); j++) {
328 for(size_t k = 0; k < mat.elements(); k++) {
329 if(mat(i, j, k) < minVal) { minVal = mat(i, j, k); }
330 }
331 }
332 }
333 return minVal;
334}
335
336template<typename T>
337Matrix<T> min(const Matrix<T>& mat, int axis) {
338 bool row_wise = axis == 0;
339 Matrix<T> out = Matrix<T>(0, row_wise ? 1 : mat.rows(), row_wise ? mat.columns() : 1);
340
341 for(size_t i = 0; i < (row_wise ? mat.columns() : mat.rows()); i++) {
342 out(row_wise ? 0 : i, row_wise ? i : 0) = elemMin(
343 mat.GetSlice(row_wise ? 0 : i, row_wise ? mat.rows() - 1 : i, row_wise ? i : 0, row_wise ? i : mat.columns() - 1),
344 0);
345 }
346 return out;
347}
348
349
357template<typename T>
358T elemMax(const Matrix<T>& mat, const size_t& elemIndex) {
359 assert(mat.elements() - 1 >= elemIndex);
360 T maxVal = std::numeric_limits<T>::min();
361 size_t index = 0;
362 for(size_t i = 0; i < mat.rows(); i++) {
363 for(size_t j = 0; j < mat.columns(); j++) {
364 if(mat(i, j, elemIndex) > maxVal) { maxVal = mat(i, j, elemIndex); }
365 index++;
366 }
367 }
368 return maxVal;
369}
370
378template<typename T>
379T elemMin(const Matrix<T>& mat, const size_t& elemIndex) {
380 assert(mat.elements() - 1 >= elemIndex);
381 T maxVal = std::numeric_limits<T>::max();
382 size_t index = 0;
383 for(size_t i = 0; i < mat.rows(); i++) {
384 for(size_t j = 0; j < mat.columns(); j++) {
385 if(mat(i, j, elemIndex) < maxVal) { maxVal = mat(i, j, elemIndex); }
386 index++;
387 }
388 }
389 return maxVal;
390}
391
400template<typename T>
401Matrix<T> mean(const Matrix<T>& mat, int axis = -1) {
402 if(axis == -1) {
403 Matrix<T> sum = Matrix<T>(0, 1, 1);
404 T index = 0;
405 for(size_t i = 0; i < mat.rows(); i++) {
406 for(size_t j = 0; j < mat.columns(); j++) {
407 sum(0, 0) += mat(i, j);
408 index++;
409 }
410 }
411 return (1.0 / index) * sum;
412 }
413 bool row_wise = axis == 0;
414
415 Matrix<T> sum = Matrix<T>(0, row_wise ? 1 : mat.rows(), row_wise ? mat.columns() : 1);
416 for(size_t i = 0; i < (row_wise ? mat.rows() : mat.columns()); i++) {
417 sum +=
418 mat.GetSlice(row_wise ? i : 0, row_wise ? i : mat.rows() - 1, row_wise ? 0 : i, row_wise ? mat.columns() - 1 : i);
419 }
420 return (1.0 / (row_wise ? mat.rows() : mat.columns())) * sum;
421}
422
430template<typename T>
431T elemMean(const Matrix<T>& mat, const size_t& elemIndex) {
432 assert(mat.elements() - 1 >= elemIndex);
433 T sum(0);
434 size_t index = 0;
435 for(size_t i = 0; i < mat.rows(); i++) {
436 for(size_t j = 0; j < mat.columns(); j++) {
437 sum += mat(i, j, elemIndex);
438 index++;
439 }
440 }
441 return sum / index;
442}
443
450template<typename T>
452 Matrix<T> out(0, in.rows(), 1, in.elements());
453 for(size_t i = 0; i < in.rows(); i++) {
454 for(size_t elem = 0; elem < in.elements(); elem++) { out(i, 0, elem) = in(i, i, elem); }
455 }
456 return out;
457}
458
465template<typename T>
466Matrix<T> unique(const Matrix<T>& in, [[maybe_unused]] int axis = 0) {
467 bool row_wise = axis == 0;
468 Matrix<T> out = Matrix<T>(0, in.rows(), in.columns());
469 size_t found_vals = 0;
470 for(size_t i = 0; i < (row_wise ? in.rows() : in.columns()); ++i) {
471 bool found = false;
472 auto xi =
473 in.GetSlice(row_wise ? i : 0, row_wise ? i : in.rows() - 1, row_wise ? 0 : i, row_wise ? in.columns() - 1 : i);
474 for(size_t j = 0; j < found_vals; ++j) {
475 found = xi
476 == out.GetSlice(
477 row_wise ? j : 0, row_wise ? j : out.rows() - 1, row_wise ? 0 : j, row_wise ? out.columns() - 1 : j);
478 if(found) { break; }
479 }
480 if(!found) {
481 out.SetSlice(found_vals, xi);
482 found_vals++;
483 }
484 }
485 return out.GetSlice(0, found_vals - 1);
486}
Definition: Matrix.h:42
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
void SetSlice(const size_t &row_start, const size_t &row_end, const size_t &col_start, const size_t &col_end, const Matrix< T > &slice)
Definition: Matrix.h:641
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
Matrix GetSlice(size_t rowStart) const
Definition: Matrix.h:609
void assertSize(const Matrix< T > &other) const
Definition: Matrix.h:334
void Resize(size_t rows, size_t cols, size_t elementSize=1)
Definition: Matrix.h:585
size_t elements() const
Definition: Matrix.h:204
bool IsVector() const
Definition: Matrix.h:328
Matrix< size_t > where_true(const Matrix< T > &in)
Definition: matrix_utils.h:246
size_t Corr(const Matrix< T > &A, const Matrix< T > &B)
Definition: matrix_utils.h:103
size_t argmin(const Matrix< T > &mat)
Definition: matrix_utils.h:167
Matrix< T > HorizontalConcat(const Matrix< T > &lhs, const Matrix< T > &rhs)
Definition: matrix_utils.h:81
Matrix< T > mean(const Matrix< T > &mat, int axis=-1)
Definition: matrix_utils.h:401
size_t argmax(const Matrix< T > &mat)
Definition: matrix_utils.h:143
Matrix< T > diag_elements(const Matrix< T > &in)
Definition: matrix_utils.h:451
T elemMin(const Matrix< T > &mat, const size_t &elemIndex)
Definition: matrix_utils.h:379
T min(const Matrix< T > &mat)
Definition: matrix_utils.h:324
T max(const Matrix< T > &mat)
Definition: matrix_utils.h:292
Matrix< T > KroneckerMulti(const Matrix< T > &lhs, const Matrix< T > &rhs)
Definition: matrix_utils.h:56
std::vector< std::pair< Matrix< T >, Matrix< T > > > zip(const Matrix< T > &a, const Matrix< T > &b)
Definition: matrix_utils.h:271
Matrix< T > HadamardDiv(const Matrix< T > &lhs, const Matrix< T > &rhs)
Definition: matrix_utils.h:36
T elemMean(const Matrix< T > &mat, const size_t &elemIndex)
Definition: matrix_utils.h:431
Matrix< T > from_vptr(const T *value, MatrixDimension size)
Definition: matrix_utils.h:122
Matrix< T > HadamardMulti(const Matrix< T > &lhs, const Matrix< T > &rhs)
Definition: matrix_utils.h:17
Matrix< size_t > where_false(const Matrix< T > &in)
Definition: matrix_utils.h:257
T elemMax(const Matrix< T > &mat, const size_t &elemIndex)
Definition: matrix_utils.h:358
Matrix< size_t > where_value(const Matrix< T > &in, T value)
Definition: matrix_utils.h:225
Matrix< T > unique(const Matrix< T > &in, int axis=0)
Definition: matrix_utils.h:466
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194
Definition: Matrix.h:20
size_t elemDim
number of elements
Definition: Matrix.h:26
size_t rows
number of rows
Definition: Matrix.h:22
size_t columns
number of columns
Definition: Matrix.h:24