philsupertramp/game-math
Loading...
Searching...
No Matches
LogRegSGD.h
Go to the documentation of this file.
1#pragma once
2
3#include "../Matrix.h"
4#include "Classifier.h"
5#include "SGD.h"
6#include "utils.h"
7
12{
13public:
15 bool shuffle;
19 SGD* sgd = nullptr;
20
21private:
26 void initialize_weights(size_t m) {
28 else
29 weights = Matrix<double>(0, m, 1);
30 w_initialized = true;
31 }
32
33
40 double update_weights(const Matrix<double>& xi, const Matrix<double>& target) {
41 auto output = netInput(xi);
42 auto error = target - output;
43 auto class1_cost = (target * -1.0) * Log(output);
44 auto class2_cost = (Matrix<double>(1, output.rows(), output.columns()) - target)
45 * Log(Matrix<double>(1, output.rows(), output.columns()) - output);
46 auto cost = class1_cost - class2_cost;
47
48 auto gradient = (xi.Transpose() * error) * eta;
49 gradient = gradient * (1.0 / xi.rows());
50
51 this->weights = this->weights - gradient;
52
53 return cost.sumElements() / xi.rows();
54 }
55
62 static Matrix<double> Log(const Matrix<double>& in) {
63 auto out = in;
64 for(size_t i = 0; i < out.rows(); i++) {
65 for(size_t j = 0; j < out.columns(); j++) { out(i, j) = log(out(i, j)); }
66 }
67 return out;
68 }
69
70public:
78 explicit LogRegSGD(double _eta = 0.01, int iter = 10, bool _shuffle = false, int _randomState = 0)
79 : ANNClassifier(_eta, iter)
80 , shuffle(_shuffle)
81 , randomState(_randomState) {
83 }
84
91 void fit(const Matrix<double>& X, const Matrix<double>& y) override {
92 if(sgd == nullptr) {
93 sgd = new SGD(
94 eta,
95 n_iter,
96 shuffle,
97 [this](const Matrix<double>& x, const Matrix<double>& y) { return this->update_weights(x, y); },
98 [this](const Matrix<double>& x) { return this->netInput(x); });
99 }
101 sgd->fit(X, y, weights);
102 }
103
104
110 void partial_fit(const Matrix<double>& X, const Matrix<double>& y) {
112 sgd->partial_fit(X, y, weights);
113 }
114
120 Matrix<double> netInput(const Matrix<double>& X) override { return Sigmoid(X * weights); }
121
128 Matrix<double> activation(const Matrix<double>& X) override { return netInput(X); }
129
136 std::function<bool(double)> condition = [](double x) { return bool(x >= EPS); };
137 return where(condition, activation(X), { { 1 } }, { { -1 } });
138 // return activation(X);
139 }
140
146 double costFunction([[maybe_unused]] const Matrix<double>& mat) override { return 0; }
147};
148
149
Definition: Classifier.h:70
double eta
Learning rate.
Definition: Classifier.h:73
int n_iter
number of epochs
Definition: Classifier.h:75
bool w_initialized
flag to initialize weights only once
Definition: Classifier.h:25
Matrix< double > weights
Vector holding weights.
Definition: Classifier.h:29
Definition: LogRegSGD.h:12
static Matrix< double > Log(const Matrix< double > &in)
Definition: LogRegSGD.h:62
double costFunction(const Matrix< double > &mat) override
Definition: LogRegSGD.h:146
bool shuffle
signals whether the given dataset should be shuffled while fitting
Definition: LogRegSGD.h:15
Matrix< double > netInput(const Matrix< double > &X) override
Definition: LogRegSGD.h:120
Matrix< double > predict(const Matrix< double > &X) override
Definition: LogRegSGD.h:135
void initialize_weights(size_t m)
Definition: LogRegSGD.h:26
int randomState
random state (seed) used to initialize the weights
Definition: LogRegSGD.h:17
void partial_fit(const Matrix< double > &X, const Matrix< double > &y)
Definition: LogRegSGD.h:110
Matrix< double > activation(const Matrix< double > &X) override
Definition: LogRegSGD.h:128
double update_weights(const Matrix< double > &xi, const Matrix< double > &target)
Definition: LogRegSGD.h:40
void fit(const Matrix< double > &X, const Matrix< double > &y) override
Definition: LogRegSGD.h:91
SGD * sgd
object implementing the SGD fitting algorithm
Definition: LogRegSGD.h:19
LogRegSGD(double _eta=0.01, int iter=10, bool _shuffle=false, int _randomState=0)
Definition: LogRegSGD.h:78
Definition: Matrix.h:42
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
static Matrix Random(size_t rows, size_t columns, size_t element_size=1, double minValue=0.0, double maxValue=1.0)
Definition: Matrix.h:158
T sumElements() const
Definition: Matrix.h:422
static void SetSeed(int seed)
Definition: Random.h:30
Definition: SGD.h:12
void partial_fit(const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights) const
Definition: SGD.h:86
void fit(const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights)
Definition: SGD.h:57
Matrix< double > Sigmoid(const Matrix< double > &in)
Definition: utils.h:9
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194
#define EPS
accuracy of calculated results
Definition: utils.h:24