philsupertramp/game-math
Loading...
Searching...
No Matches
SGD.h
Go to the documentation of this file.
1#pragma once
2
3#include "../Matrix.h"
4#include "../matrix_utils.h"
5#include <functional>
6
7
11class SGD
12{
14 size_t n_iter;
16 double eta;
18 bool shuffle = false;
20 std::function<double(const Matrix<double>&, const Matrix<double>&)> weight_update = nullptr;
22 std::function<Matrix<double>(const Matrix<double>&)> net_input_fun = nullptr;
23
24public:
27
36 explicit SGD(
37 double _eta = 0.01,
38 size_t iter = 10,
39 bool _shuffle = false,
40 const std::function<double(const Matrix<double>&, const Matrix<double>&)>& update_weights_fun = nullptr,
41 const std::function<Matrix<double>(const Matrix<double>&)>& netInputFun = nullptr)
42 : n_iter(iter)
43 , eta(_eta)
44 , shuffle(_shuffle) {
45 if(update_weights_fun != nullptr) { weight_update = update_weights_fun; }
46 if(netInputFun != nullptr) { net_input_fun = netInputFun; }
47 cost = Matrix<double>(0.0, 0, 0);
48 }
49
57 void fit(const Matrix<double>& X, const Matrix<double>& y, Matrix<double>& weights) {
58 cost = Matrix<double>(0, n_iter, 1);
59 auto xCopy = X;
60 auto yCopy = y;
61 for(size_t iter = 0; iter < n_iter; iter++) {
62 if(shuffle) {
63 auto fooPair = shuffleData(X, y);
64 xCopy = fooPair.first;
65 yCopy = fooPair.second;
66 }
67 double costSum = 0;
68 for(const auto& elem : zip(xCopy, yCopy)) {
69 if(weight_update != nullptr) {
70 costSum += weight_update(elem.first, elem.second);
71 } else {
72 costSum += update_weights(elem.first, elem.second, weights);
73 }
74 }
75
76 cost(iter, 0) = costSum;
77 }
78 }
79
86 void partial_fit(const Matrix<double>& X, const Matrix<double>& y, Matrix<double>& weights) const {
87 if(y.rows() > 1) {
88 for(const auto& elem : zip(X, y)) { update_weights(elem.first, elem.second, weights); }
89 } else {
90 update_weights(X, y, weights);
91 }
92 }
93
101 double update_weights(const Matrix<double>& xi, const Matrix<double>& target, Matrix<double>& weights) const {
102 auto output = Matrix<double>();
103 if(net_input_fun != nullptr) {
104 output = net_input_fun(xi);
105 } else {
106 output = netInput(xi, weights);
107 }
108 auto error = target - output;
109 weights.SetRow(0, weights(0) + eta * error);
110 auto delta = eta * (xi.Transpose() * error);
111 for(size_t i = 0; i < weights.rows() - 1; i++) { weights.SetRow(i + 1, weights(i + 1) + delta(i)); }
112 return ((error * error) * 0.5).sumElements() / (double)target.rows();
113 }
114
121 std::pair<Matrix<double>, Matrix<double>>
122 shuffleData([[maybe_unused]] const Matrix<double>& X, [[maybe_unused]] const Matrix<double>& y) {
123 return { X, y };
124 }
125
132 [[nodiscard]] static Matrix<double> netInput(const Matrix<double>& X, const Matrix<double>& weights) {
133 Matrix<double> A(0, weights.rows() - 1, weights.columns());
134 Matrix<double> B(0, 1, weights.columns());
135 for(size_t i = 0; i < weights.rows(); i++) {
136 for(size_t j = 0; j < weights.columns(); j++) {
137 if(i == 0) {
138 B(i, j) = weights(i, j);
139 } else {
140 A(i - 1, j) = weights(i, j);
141 }
142 }
143 }
144
145 // Note: since Matrix<T> will never allow Matrix-Scalar Addition
146 // we need to create a vector of size of rows of the input
147 // values (jeez, horrible sentence). Therefore we rescale
148 // B and assign the bias to each element.
149 B = Matrix<double>(B(0, 0), X.rows(), 1);
150 return (X * A) + B;
151 }
152};
Definition: Matrix.h:42
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
void SetRow(size_t index, const Matrix< T > &other)
Definition: Matrix.h:541
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
Definition: SGD.h:12
double update_weights(const Matrix< double > &xi, const Matrix< double > &target, Matrix< double > &weights) const
Definition: SGD.h:101
std::function< Matrix< double >(const Matrix< double > &)> net_input_fun
represents net input function
Definition: SGD.h:22
bool shuffle
shuffle during training
Definition: SGD.h:18
size_t n_iter
number of iterations (epochs) during fit
Definition: SGD.h:14
void partial_fit(const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights) const
Definition: SGD.h:86
std::pair< Matrix< double >, Matrix< double > > shuffleData(const Matrix< double > &X, const Matrix< double > &y)
Definition: SGD.h:122
static Matrix< double > netInput(const Matrix< double > &X, const Matrix< double > &weights)
Definition: SGD.h:132
Matrix< double > cost
matrix holding cost per epoch
Definition: SGD.h:26
std::function< double(const Matrix< double > &, const Matrix< double > &)> weight_update
represents weight update function
Definition: SGD.h:20
SGD(double _eta=0.01, size_t iter=10, bool _shuffle=false, const std::function< double(const Matrix< double > &, const Matrix< double > &)> &update_weights_fun=nullptr, const std::function< Matrix< double >(const Matrix< double > &)> &netInputFun=nullptr)
Definition: SGD.h:36
void fit(const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights)
Definition: SGD.h:57
double eta
learning rate (step size multiplied into every weight update)
Definition: SGD.h:16
std::vector< std::pair< Matrix< T >, Matrix< T > > > zip(const Matrix< T > &a, const Matrix< T > &b)
Definition: matrix_utils.h:271