philsupertramp/game-math
Loading...
Searching...
No Matches
Classifier.h
Go to the documentation of this file.
#pragma once

#include <stdexcept>

#include "../Matrix.h"
#include "Predictor.h"
6
21class Classifier : public Predictor
22{
23protected:
25 bool w_initialized = false;
26
27public:
32
39
45 void initialize_weights(size_t numRows, size_t numColumns = 1) {
46 weights = Matrix<double>(0, numRows + 1, numColumns);
47 w_initialized = true;
48 }
49
50 void update_weights(const Matrix<double>& update, const Matrix<double>& delta) {
51 for(size_t i = 0; i < weights.rows(); i++) {
52 for(size_t j = 0; j < weights.columns(); j++) {
53 if(i == 0) weights(i, j) += update(i, j);
54 else
55 weights(i, j) += delta(i - 1, j);
56 }
57 }
58 }
59 Matrix<double> transform(const Matrix<double>& in) override { return in; }
60};
61
62
70{
71protected:
73 double eta;
75 int n_iter;
76
77public:
78 ANNClassifier(double _eta, int _n_iter)
79 : Classifier()
80 , eta(_eta)
81 , n_iter(_n_iter) { }
82
88
94
99 virtual double costFunction(const Matrix<double>&) = 0;
100};
Definition: Classifier.h:70
double eta
Learning rate.
Definition: Classifier.h:73
virtual Matrix< double > netInput(const Matrix< double > &)=0
virtual double costFunction(const Matrix< double > &)=0
int n_iter
Number of epochs.
Definition: Classifier.h:75
virtual Matrix< double > activation(const Matrix< double > &)=0
ANNClassifier(double _eta, int _n_iter)
Definition: Classifier.h:78
Definition: Classifier.h:22
bool w_initialized
Flag used to initialize the weights only once.
Definition: Classifier.h:25
Matrix< double > weights
Vector holding weights.
Definition: Classifier.h:29
void initialize_weights(size_t numRows, size_t numColumns=1)
Definition: Classifier.h:45
Matrix< double > costs
Vector holding classification error per epoch.
Definition: Classifier.h:31
Matrix< double > transform(const Matrix< double > &in) override
Definition: Classifier.h:59
void update_weights(const Matrix< double > &update, const Matrix< double > &delta)
Definition: Classifier.h:50
Classifier()
Definition: Classifier.h:38
Definition: Matrix.h:42
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
Definition: Predictor.h:5