philsupertramp/game-math
Loading...
Searching...
No Matches
AdalineGD.h
Go to the documentation of this file.
1#pragma once
2
3#include "../Matrix.h"
4#include "Classifier.h"
5#include "SGD.h"
6
11{
12public:
18 explicit AdalineGD(double _eta = 0.01, int iter = 10)
19 : ANNClassifier(_eta, iter) { }
20
// Batch gradient-descent training: every epoch evaluates the model on all of X,
// takes one weight step from the residuals and records the epoch's cost.
// NOTE(review): listing lines 27-28 are absent from this excerpt. The page's
// symbol list references initialize_weights(numRows, numColumns) and costs,
// which suggests those lines set up `weights` and `costs` before the loop.
// Confirm against the real AdalineGD.h before relying on this excerpt.
26 void fit(const Matrix<double>& X, const Matrix<double>& y) override {
29 for(int iter = 0; iter < n_iter; iter++) {
// Model output for every sample under the current weights.
30 auto output = netInput(X);
// Residuals between targets and model output.
31 auto errors = y - output;
32
// Weight step eta * X^T * errors, i.e. a descent step on the
// sum-of-squared-residuals / 2 cost computed by costFunction.
33 auto delta_w = (X.Transpose() * errors) * eta;
34
// First argument: presumably a weights.rows() x weights.columns() matrix filled
// with eta * sum(errors). It looks like the bias step, but how update_weights
// combines it with delta_w is defined in Classifier.h; confirm there.
35 update_weights(Matrix<double>(eta * errors.sumElements(), weights.rows(), weights.columns()), delta_w);
// Record this epoch's cost (sum of squared residuals / 2).
36 auto cost = costFunction(errors);
37 costs(iter, 0) = cost;
38 }
39 }
40
47 Matrix<double> netInput(const Matrix<double>& X) override { return SGD::netInput(X, weights); }
48
54 Matrix<double> activation(const Matrix<double>& X) override { return netInput(X); }
55
62 std::function<bool(double)> condition = [](double x) { return bool(x >= 0.0); };
63 return where(condition, activation(X), { { 1 } }, { { -1 } });
64 }
65
73 double costFunction(const Matrix<double>& X) override { return HadamardMulti<double>(X, X).sumElements() / 2.0; }
74};
Definition: Classifier.h:70
double eta
Learning rate.
Definition: Classifier.h:73
int n_iter
Number of epochs.
Definition: Classifier.h:75
Definition: AdalineGD.h:11
void fit(const Matrix< double > &X, const Matrix< double > &y) override
Definition: AdalineGD.h:26
Matrix< double > predict(const Matrix< double > &X) override
Definition: AdalineGD.h:61
AdalineGD(double _eta=0.01, int iter=10)
Definition: AdalineGD.h:18
Matrix< double > netInput(const Matrix< double > &X) override
Definition: AdalineGD.h:47
double costFunction(const Matrix< double > &X) override
Definition: AdalineGD.h:73
Matrix< double > activation(const Matrix< double > &X) override
Definition: AdalineGD.h:54
Matrix< double > weights
Vector holding weights.
Definition: Classifier.h:29
void initialize_weights(size_t numRows, size_t numColumns=1)
Definition: Classifier.h:45
Matrix< double > costs
Vector holding classification error per epoch.
Definition: Classifier.h:31
void update_weights(const Matrix< double > &update, const Matrix< double > &delta)
Definition: Classifier.h:50
Definition: Matrix.h:42
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
static Matrix< double > netInput(const Matrix< double > &X, const Matrix< double > &weights)
Definition: SGD.h:132
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194