philsupertramp/game-math
Loading...
Searching...
No Matches
KNN.h
Go to the documentation of this file.
1
23#include "../Matrix.h"
24#include "../matrix_utils.h"
25#include "../numerics/utils.h"
26#include "Classifier.h"
27#include <cstddef>
28
29
30class KNN : public Classifier
31{
34
35public:
36 KNN(int neighbors)
37 : Classifier()
38 , nearest_neighbors(neighbors) { }
39
40
49 void fit(const Matrix<double>& X, const Matrix<double>& y) override {
50 assert(y.rows() == 1 || y.columns() == 1);
51 weights = X;
52 trainLabels = y;
53 };
54
66 auto predictions = zerosV(x.rows());
67 auto trainX = weights;
68 auto trainY = trainLabels;
69 auto labels = unique(trainY);
70 for(size_t i = 0; i < x.rows(); ++i) {
71 auto xi = x.GetSlice(i);
72 auto current_dists = norm(trainX - xi, 0);
73 auto k_neares = trainY.GetSlicesByIndex(argsort(current_dists));
74 auto k_nearest = k_neares.GetSlice(0, nearest_neighbors - 1);
75 auto gamma = zeros(labels.rows(), 1);
76 for(size_t k = 0; k < labels.rows(); ++k) {
77 std::function<bool(double)> condition = [k](double x) { return bool(x == k); };
78 // O(N)
79 auto yis = where(condition, k_nearest, { { 1 } }, { { 0 } });
80 // O(N)
81 gamma(k, 0) = yis.sumElements();
82 }
83 predictions(i, 0) = argmax(gamma);
84 }
85 return predictions;
86 };
87};
Definition: Classifier.h:22
Matrix< double > weights
Vector holding weights.
Definition: Classifier.h:29
Definition: KNN.h:31
KNN(int neighbors)
Definition: KNN.h:36
int nearest_neighbors
Definition: KNN.h:32
Matrix< double > predict(const Matrix< double > &x) override
Definition: KNN.h:65
Matrix< double > trainLabels
Definition: KNN.h:33
void fit(const Matrix< double > &X, const Matrix< double > &y) override
Definition: KNN.h:49
Definition: Matrix.h:42
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
Matrix GetSlice(size_t rowStart) const
Definition: Matrix.h:609
size_t argmax(const Matrix< T > &mat)
Definition: matrix_utils.h:143
Matrix< T > unique(const Matrix< T > &in, int axis=0)
Definition: matrix_utils.h:466
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194
Matrix< size_t > argsort(const Matrix< double > &in)
double norm(const Matrix< double > &in)
Matrix< double > zerosV(size_t rows)
Matrix< double > zeros(size_t rows, size_t columns, size_t elements=1)