philsupertramp/game-math
Loading...
Searching...
No Matches
Perceptron.h
Go to the documentation of this file.
1#pragma once
2
3#include "../Matrix.h"
4#include "../matrix_utils.h"
5#include "Classifier.h"
6#include "SGD.h"
7
25{
26public:
32 explicit Perceptron(double _eta = 0.01, int iter = 10)
33 : ANNClassifier(_eta, iter) { }
34
/**
 * Trains the classifier for n_iter epochs.
 *
 * In every epoch each pair produced by zip(X, y) is visited: the first element
 * is passed to predict(), the difference between the second element and that
 * prediction is used to update the weights, and costFunction() of the
 * difference is added to a per-epoch error counter that is finally written to
 * costs(epoch, 0).
 *
 * @param X input matrix (presumably one sample per row -- confirm against zip())
 * @param y target matrix paired with X by zip()
 */
41 void fit(const Matrix<double>& X, const Matrix<double>& y) override {
// NOTE(review): the embedded listing numbers jump from 41 to 44 here, so two
// lines of the original file are not visible in this view (possibly the
// initialization of weights and costs) -- confirm against the repository.
44 for(int iter = 0; iter < n_iter; iter++) {
// Error counter for the current epoch.
45 int _errors = 0;
46 auto iterable = zip(X, y);
47 for(const auto& elem : iterable) {
48 auto xi = elem.first;
49 auto target = elem.second;
50
51 auto output = predict(xi);
// Difference between the expected target and the current prediction.
52 auto delta_w = (target - output);
// Hands update_weights the raw difference and the transposed, eta-scaled
// product of the difference with the input.
53 update_weights(delta_w, (delta_w * xi).Transpose() * eta);
// costFunction yields 1.0 when element (0, 0) of the difference is non-zero,
// so this accumulates the number of such samples in the epoch.
54 _errors += costFunction(delta_w);
55 }
// Store the epoch's error count.
56 costs(iter, 0) = _errors;
57 }
58 }
59
68 Matrix<double> netInput(const Matrix<double>& X) override { return SGD::netInput(X, weights); }
69
75 Matrix<double> activation(const Matrix<double>& X) override { return netInput(X); }
76
// NOTE(review): the signature line of this method (listing line 82) is not
// visible in this view; only the body is shown below.
// Threshold predicate: true for values greater than or equal to zero.
83 std::function<bool(double)> condition = [](double x) { return bool(x >= 0.0); };
// Feed the activation of X through where() with {{1}} as valIfTrue and {{-1}}
// as valIfFalse (presumably yielding 1 where the predicate holds and -1
// elsewhere -- confirm against matrix_utils.h).
84 return where(condition, activation(X), { { 1 } }, { { -1 } });
85 }
86
92 double costFunction(const Matrix<double>& X) override { return (double)(X(0, 0) != 0.0); }
93};
94
Definition: Classifier.h:70
double eta
Learning rate.
Definition: Classifier.h:73
int n_iter
Number of epochs.
Definition: Classifier.h:75
Matrix< double > weights
Vector holding weights.
Definition: Classifier.h:29
void initialize_weights(size_t numRows, size_t numColumns=1)
Definition: Classifier.h:45
Matrix< double > costs
Vector holding classification error per epoch.
Definition: Classifier.h:31
void update_weights(const Matrix< double > &update, const Matrix< double > &delta)
Definition: Classifier.h:50
Definition: Matrix.h:42
size_t columns() const
Definition: Matrix.h:198
Definition: Perceptron.h:25
void fit(const Matrix< double > &X, const Matrix< double > &y) override
Definition: Perceptron.h:41
Perceptron(double _eta=0.01, int iter=10)
Definition: Perceptron.h:32
Matrix< double > activation(const Matrix< double > &X) override
Definition: Perceptron.h:75
Matrix< double > predict(const Matrix< double > &X) override
Definition: Perceptron.h:82
double costFunction(const Matrix< double > &X) override
Definition: Perceptron.h:92
Matrix< double > netInput(const Matrix< double > &X) override
Definition: Perceptron.h:68
static Matrix< double > netInput(const Matrix< double > &X, const Matrix< double > &weights)
Definition: SGD.h:132
std::vector< std::pair< Matrix< T >, Matrix< T > > > zip(const Matrix< T > &a, const Matrix< T > &b)
Definition: matrix_utils.h:271
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194