philsupertramp/game-math
Loading...
Searching...
No Matches
NCC.h
Go to the documentation of this file.
1
#include "../matrix_utils.h"
#include "../numerics/utils.h"
#include "../utils.h"
#include "Classifier.h"
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>
42
43
48class NCC : public Classifier
49{
50 bool use_iterative = true;
51
52public:
53 NCC(bool useIterative)
54 : Classifier()
55 , use_iterative(useIterative) { }
56
60 void fit(const Matrix<double>& X, const Matrix<double>& y) override {
61 assert(y.rows() == 1 || y.columns() == 1);
62 auto labels = unique(y);
63 weights = zeros(labels.rows(), X.columns());
64 if(use_iterative) {
65 fit_iterative(X, y, labels);
66 } else {
67 fit_batch(X, y, labels);
68 }
69 };
70
82 void fit_batch(const Matrix<double>& X, const Matrix<double>& y, const Matrix<double>& labels) {
83 // TODO: Technically we can omit labels here and only pass the number of unique labels. Or we keep a vector
84 // of unique labels as a class attribute. Both options might make sense.
85 // O(K)
86 for(size_t i = 0; i < labels.rows(); ++i) {
87 std::function<bool(double)> condition = [i](double x) { return bool(x == i); };
88 // O(N)
89 auto yis = where(condition, y, { { 1 } }, { { 0 } });
90 // O(N)
91 auto Nk = yis.sumElements();
92 // O(2N)
93 auto Xis = X.GetSlicesByIndex(where_true(yis));
94 // O(N * N * N)
95 weights.SetSlice(i, Xis.sum(1) * (1. / Nk));
96 }
97 }
98
111 void fit_iterative(const Matrix<double>& X, const Matrix<double>& y, const Matrix<double>& labels) {
112 auto counters = Matrix<int>(0, labels.rows(), 1);
113 // O(N)
114 for(size_t i = 0; i < X.rows(); ++i) {
115 auto k = y(i, 0);
116 auto xi = X.GetSlice(i, i);
117 auto current_counter = counters(k, 0);
118 // O(D)
120 k,
121 // O(D)
122 weights.GetSlice(k, k) * (current_counter / double(current_counter + 1))
123 // O(D)
124 + xi * (1. / double(current_counter + 1)));
125 counters(k, 0) += 1;
126 }
127 }
128
142 auto predictions = zerosV(x.rows());
143 for(size_t i = 0; i < x.rows(); ++i) {
144 auto xi = x.GetSlice(i, i);
145 auto distances = norm(weights - xi, 0);
146 predictions(i, 0) = argmin(distances);
147 }
148 return predictions;
149 };
150};
Definition: Classifier.h:22
Matrix< double > weights
Vector holding weights.
Definition: Classifier.h:29
Definition: Matrix.h:42
void SetSlice(const size_t &row_start, const size_t &row_end, const size_t &col_start, const size_t &col_end, const Matrix< T > &slice)
Definition: Matrix.h:641
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
Matrix GetSlice(size_t rowStart) const
Definition: Matrix.h:609
Matrix< T > GetSlicesByIndex(const Matrix< size_t > &indices) const
Definition: Matrix.h:679
Definition: NCC.h:49
Matrix< double > predict(const Matrix< double > &x) override
Definition: NCC.h:141
void fit_iterative(const Matrix< double > &X, const Matrix< double > &y, const Matrix< double > &labels)
Definition: NCC.h:111
NCC(bool useIterative)
Definition: NCC.h:53
void fit(const Matrix< double > &X, const Matrix< double > &y) override
Definition: NCC.h:60
void fit_batch(const Matrix< double > &X, const Matrix< double > &y, const Matrix< double > &labels)
Definition: NCC.h:82
bool use_iterative
Definition: NCC.h:50
Matrix< size_t > where_true(const Matrix< T > &in)
Definition: matrix_utils.h:246
size_t argmin(const Matrix< T > &mat)
Definition: matrix_utils.h:167
Matrix< T > unique(const Matrix< T > &in, int axis=0)
Definition: matrix_utils.h:466
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194
double norm(const Matrix< double > &in)
Matrix< double > zerosV(size_t rows)
Matrix< double > zeros(size_t rows, size_t columns, size_t elements=1)