4#include "../matrix_utils.h"
39 bool _shuffle =
false,
45 if(update_weights_fun !=
nullptr) {
weight_update = update_weights_fun; }
61 for(
size_t iter = 0; iter <
n_iter; iter++) {
64 xCopy = fooPair.first;
65 yCopy = fooPair.second;
68 for(
const auto& elem :
zip(xCopy, yCopy)) {
76 cost(iter, 0) = costSum;
88 for(
const auto& elem :
zip(X, y)) {
update_weights(elem.first, elem.second, weights); }
108 auto error = target - output;
109 weights.
SetRow(0, weights(0) +
eta * error);
111 for(
size_t i = 0; i < weights.
rows() - 1; i++) { weights.
SetRow(i + 1, weights(i + 1) + delta(i)); }
112 return ((error * error) * 0.5).sumElements() / (double)target.
rows();
135 for(
size_t i = 0; i < weights.
rows(); i++) {
136 for(
size_t j = 0; j < weights.
columns(); j++) {
138 B(i, j) = weights(i, j);
140 A(i - 1, j) = weights(i, j);
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
void SetRow(size_t index, const Matrix< T > &other)
Definition: Matrix.h:541
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
double update_weights(const Matrix< double > &xi, const Matrix< double > &target, Matrix< double > &weights) const
Definition: SGD.h:101
std::function< Matrix< double >(const Matrix< double > &)> net_input_fun
represents net input function
Definition: SGD.h:22
bool shuffle
shuffle during training
Definition: SGD.h:18
size_t n_iter
number of iterations (epochs) during fit
Definition: SGD.h:14
void partial_fit(const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights) const
Definition: SGD.h:86
std::pair< Matrix< double >, Matrix< double > > shuffleData(const Matrix< double > &X, const Matrix< double > &y)
Definition: SGD.h:122
static Matrix< double > netInput(const Matrix< double > &X, const Matrix< double > &weights)
Definition: SGD.h:132
Matrix< double > cost
matrix holding cost per epoch
Definition: SGD.h:26
std::function< double(const Matrix< double > &, const Matrix< double > &)> weight_update
represents weight update function
Definition: SGD.h:20
SGD(double _eta=0.01, size_t iter=10, bool _shuffle=false, const std::function< double(const Matrix< double > &, const Matrix< double > &)> &update_weights_fun=nullptr, const std::function< Matrix< double >(const Matrix< double > &)> &netInputFun=nullptr)
Definition: SGD.h:36
void fit(const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights)
Definition: SGD.h:57
double eta
learning rate (factor that scales the error term added to the weights on each update)
Definition: SGD.h:16
std::vector< std::pair< Matrix< T >, Matrix< T > > > zip(const Matrix< T > &a, const Matrix< T > &b)
Definition: matrix_utils.h:271