philsupertramp/game-math
#include <SGD.h>
Public Member Functions
SGD (double _eta=0.01, size_t iter=10, bool _shuffle=false, const std::function< double(const Matrix< double > &, const Matrix< double > &)> &update_weights_fun=nullptr, const std::function< Matrix< double >(const Matrix< double > &)> &netInputFun=nullptr)
void fit (const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights)
void partial_fit (const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights) const
double update_weights (const Matrix< double > &xi, const Matrix< double > &target, Matrix< double > &weights) const
std::pair< Matrix< double >, Matrix< double > > shuffleData (const Matrix< double > &X, const Matrix< double > &y)
Static Public Member Functions
static Matrix< double > netInput (const Matrix< double > &X, const Matrix< double > &weights)
Public Attributes
Matrix< double > cost
  matrix holding the cost per epoch
Private Attributes
size_t n_iter
  number of iterations (epochs) during fit
double eta
  learning rate (step size)
bool shuffle = false
  whether to shuffle the data during training
std::function< double(const Matrix< double > &, const Matrix< double > &)> weight_update = nullptr
  weight update function
std::function< Matrix< double >(const Matrix< double > &)> net_input_fun = nullptr
  net input function
Implements the stochastic gradient descent (SGD) method to fit a model.
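SGD::SGD (double _eta = 0.01, size_t iter = 10, bool _shuffle = false, const std::function< double(const Matrix< double > &, const Matrix< double > &)> &update_weights_fun = nullptr, const std::function< Matrix< double >(const Matrix< double > &)> &netInputFun = nullptr)
inline explicit
Constructor. Every parameter has a default value, so SGD can also be default-constructed.
Parameters
  _eta  learning rate (eta)
  iter  number of iterations (epochs) during fit (n_iter)
  _shuffle  whether to shuffle the data during training (shuffle)
  update_weights_fun  weight update function (weight_update)
  netInputFun  net input function (net_input_fun)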
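void SGD::fit (const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights)
inline
Fits the given weights to the values X and y using the configured weight update and net input functions.
Parameters
  X  input values
  y  target values
  weights  weights to fit (updated in place)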
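The fit loop can be illustrated with a minimal, self-contained sketch. It does not use the library's Matrix<double>; the names `Vec`, `Mat` and `fitSketch` are hypothetical, and the Adaline-style model (bias stored in weights[0], per-sample cost 0.5 * error^2, weights updated after every sample) is an assumption, not confirmed as what SGD::fit does internally. It returns the average cost of each epoch, which plays the role of the `cost` attribute.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <vector>

// Hypothetical stand-ins for Matrix<double>: each inner vector is one sample.
using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Sketch of an Adaline-style SGD fit (assumption: weights[0] is the bias,
// per-sample cost is 0.5 * error^2). Returns the mean cost per epoch.
Vec fitSketch(const Mat& X, const Vec& y, Vec& weights,
              double eta, std::size_t nIter, bool shuffle, unsigned seed = 1) {
    assert(X.size() == y.size() && !X.empty());
    std::mt19937 rng(seed);
    std::vector<std::size_t> order(X.size());
    std::iota(order.begin(), order.end(), std::size_t{0});
    Vec costPerEpoch;
    for (std::size_t epoch = 0; epoch < nIter; ++epoch) {
        if (shuffle) {
            // Visit the samples in a new random order each epoch.
            std::shuffle(order.begin(), order.end(), rng);
        }
        double total = 0.0;
        for (std::size_t i : order) {
            const Vec& xi = X[i];
            assert(weights.size() == xi.size() + 1);
            // Net input of the linear model for this sample.
            double output = weights[0];
            for (std::size_t j = 0; j < xi.size(); ++j) output += xi[j] * weights[j + 1];
            const double error = y[i] - output;
            // Update immediately after each sample (the "stochastic" part).
            weights[0] += eta * error;
            for (std::size_t j = 0; j < xi.size(); ++j) weights[j + 1] += eta * error * xi[j];
            total += 0.5 * error * error;
        }
        costPerEpoch.push_back(total / static_cast<double>(X.size()));
    }
    return costPerEpoch;
}
```

Shuffling an index vector avoids copying X and y every epoch; the class's shuffleData, by contrast, returns the shuffled data by value.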
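static Matrix< double > SGD::netInput (const Matrix< double > &X, const Matrix< double > &weights)
inline static
Computes the net input for the given values X and weights.
Parameters
  X  input values
  weights  model weights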
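For a linear model, the net input is a weighted sum of each sample's features. The following sketch assumes that convention, with the bias stored in weights[0]; `Vec`, `Mat` and this free function `netInput` are hypothetical stand-ins, not the library's Matrix-based implementation.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins for Matrix<double>: each inner vector is one sample.
using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Net input z_i = w[0] + sum_j X[i][j] * w[j + 1], one value per sample.
// Assumption: weights[0] holds the bias term.
Vec netInput(const Mat& X, const Vec& weights) {
    Vec z;
    z.reserve(X.size());
    for (const Vec& row : X) {
        assert(weights.size() == row.size() + 1);
        double sum = weights[0];
        for (std::size_t j = 0; j < row.size(); ++j) {
            sum += row[j] * weights[j + 1];
        }
        z.push_back(sum);
    }
    return z;
}
```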
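void SGD::partial_fit (const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights) const
inline
Performs a partial fit of the given weights to the input values X and y.
Parameters
  X  input values
  y  target values
  weights  weights to fit (updated in place)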
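One common reading of a partial fit is a single pass over the given samples that continues from the weights' current values without resetting them, which suits online or streaming data. The sketch below assumes that reading and the same Adaline-style update as above (bias in weights[0], per-sample cost 0.5 * error^2); `partialFitSketch` and the vector types are hypothetical, and the library's partial_fit may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Hypothetical partial fit: one pass over X/y that starts from whatever
// `weights` already holds (nothing is reset). Assumptions: weights[0] is the
// bias, per-sample cost is 0.5 * error^2. Returns the mean cost of the pass.
double partialFitSketch(const Mat& X, const Vec& y, Vec& weights, double eta) {
    assert(X.size() == y.size() && !X.empty());
    double total = 0.0;
    for (std::size_t i = 0; i < X.size(); ++i) {
        assert(weights.size() == X[i].size() + 1);
        double output = weights[0];
        for (std::size_t j = 0; j < X[i].size(); ++j) output += X[i][j] * weights[j + 1];
        const double error = y[i] - output;
        weights[0] += eta * error;
        for (std::size_t j = 0; j < X[i].size(); ++j) weights[j + 1] += eta * error * X[i][j];
        total += 0.5 * error * error;
    }
    return total / static_cast<double>(X.size());
}
```

Calling it repeatedly on new batches keeps refining the same weights, so the mean cost of later passes on the same data should drop.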
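std::pair< Matrix< double >, Matrix< double > > SGD::shuffleData (const Matrix< double > &X, const Matrix< double > &y)
inline
Shuffles the given data and returns the shuffled X and y as a pair.
Parameters
  X  input values
  y  target values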
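When shuffling training data, X and y must be permuted with the same permutation so that every sample stays paired with its target. A minimal sketch of that idea, assuming row-per-sample vectors and a caller-supplied `std::mt19937`; the function and type names here are hypothetical, not the library's.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

using Vec = std::vector<double>;
using Mat = std::vector<Vec>;

// Shuffle samples X and targets y with ONE shared permutation so that each
// row of X stays paired with its target in y.
std::pair<Mat, Vec> shuffleData(const Mat& X, const Vec& y, std::mt19937& rng) {
    assert(X.size() == y.size());
    std::vector<std::size_t> idx(X.size());
    std::iota(idx.begin(), idx.end(), std::size_t{0});  // 0, 1, ..., n-1
    std::shuffle(idx.begin(), idx.end(), rng);
    Mat Xs;
    Vec ys;
    Xs.reserve(X.size());
    ys.reserve(y.size());
    for (std::size_t i : idx) {
        Xs.push_back(X[i]);
        ys.push_back(y[i]);
    }
    return {Xs, ys};
}
```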
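double SGD::update_weights (const Matrix< double > &xi, const Matrix< double > &target, Matrix< double > &weights) const
inline
Calculates the update values (mean squared error) for the sample xi and its target.
Parameters
  xi  input sample
  target  target value(s) for xi
  weights  weights to update (passed by non-const reference)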
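A single squared-error SGD step can be sketched as follows. This assumes the Adaline-style rule (error = target - output, each weight moves by eta * error * feature, bias in weights[0]) and that the returned double is the sample's cost 0.5 * error^2; the library's exact formula and return value are not confirmed by this page, and `updateWeights` is a hypothetical name.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Vec = std::vector<double>;

// Sketch of one Adaline-style SGD step for sample xi with a scalar target.
// Assumptions: weights[0] is the bias, the model is linear, and the returned
// cost is the squared-error term 0.5 * error^2.
double updateWeights(const Vec& xi, double target, Vec& weights, double eta) {
    assert(weights.size() == xi.size() + 1);
    // Net input for this single sample.
    double output = weights[0];
    for (std::size_t j = 0; j < xi.size(); ++j) {
        output += xi[j] * weights[j + 1];
    }
    const double error = target - output;
    // Gradient step on 0.5 * error^2 with respect to each weight.
    weights[0] += eta * error;
    for (std::size_t j = 0; j < xi.size(); ++j) {
        weights[j + 1] += eta * error * xi[j];
    }
    return 0.5 * error * error;
}
```

For example, with weights {0, 0}, xi = {2}, target 1 and eta 0.5, the error is 1, the weights become {0.5, 1.0}, and the returned cost is 0.5.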
Matrix<double> SGD::cost
Matrix holding the cost per epoch.
double SGD::eta
private
Learning rate (step size).
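size_t SGD::n_iter
private
Number of iterations (epochs) during fit.

std::function< Matrix< double >(const Matrix< double > &)> SGD::net_input_fun = nullptr
private
Represents the net input function.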
bool SGD::shuffle = false
private
Whether to shuffle the data during training.
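std::function< double(const Matrix< double > &, const Matrix< double > &)> SGD::weight_update = nullptr
private
Represents the weight update function.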