24 #include "../matrix_utils.h"
25 #include "../numerics/utils.h"
69 auto labels = unique(trainY);
70 for(size_t i = 0; i < x.rows(); ++i) {
72 auto current_dists = norm(trainX - xi, 0);
73 auto k_nearest = trainY.GetSlicesByIndex(argsort(current_dists));
75 auto gamma = zeros(labels.rows(), 1);
76 for(size_t k = 0; k < labels.rows(); ++k) {
77 std::function<bool(double)> condition = [k](double x) { return bool(x == k); };
79 auto yis = where(condition, k_nearest, { { 1 } }, { { 0 } });
81 gamma(k, 0) = yis.sumElements();
83 predictions(i, 0) = argmax(gamma);
Definition: Classifier.h:22
Matrix< double > weights
Vector holding weights.
Definition: Classifier.h:29
KNN(int neighbors)
Definition: KNN.h:36
int nearest_neighbors
Definition: KNN.h:32
Matrix< double > predict(const Matrix< double > &x) override
Definition: KNN.h:65
Matrix< double > trainLabels
Definition: KNN.h:33
void fit(const Matrix< double > &X, const Matrix< double > &y) override
Definition: KNN.h:49
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
Matrix GetSlice(size_t rowStart) const
Definition: Matrix.h:609
size_t argmax(const Matrix< T > &mat)
Definition: matrix_utils.h:143
Matrix< T > unique(const Matrix< T > &in, int axis=0)
Definition: matrix_utils.h:466
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194
Matrix< size_t > argsort(const Matrix< double > &in)
double norm(const Matrix< double > &in)
Matrix< double > zerosV(size_t rows)
Matrix< double > zeros(size_t rows, size_t columns, size_t elements=1)