philsupertramp/game-math
SGD Class Reference

#include <SGD.h>

Public Member Functions

 SGD (double _eta=0.01, size_t iter=10, bool _shuffle=false, const std::function< double(const Matrix< double > &, const Matrix< double > &)> &update_weights_fun=nullptr, const std::function< Matrix< double >(const Matrix< double > &)> &netInputFun=nullptr)
 
void fit (const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights)
 
void partial_fit (const Matrix< double > &X, const Matrix< double > &y, Matrix< double > &weights) const
 
double update_weights (const Matrix< double > &xi, const Matrix< double > &target, Matrix< double > &weights) const
 
std::pair< Matrix< double >, Matrix< double > > shuffleData (const Matrix< double > &X, const Matrix< double > &y)
 

Static Public Member Functions

static Matrix< double > netInput (const Matrix< double > &X, const Matrix< double > &weights)
 

Public Attributes

Matrix< double > cost
 matrix holding cost per epoch
 

Private Attributes

size_t n_iter
 number of iterations (epochs) during fit
 
double eta
 cost factor (learning rate)
 
bool shuffle = false
 whether to shuffle the data during training
 
std::function< double(const Matrix< double > &, const Matrix< double > &)> weight_update = nullptr
 weight update function
 
std::function< Matrix< double >(const Matrix< double > &)> net_input_fun = nullptr
 net input function
 

Detailed Description

Implements the stochastic gradient descent method to fit a model.
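For orientation, stochastic gradient descent on a linear model with a squared-error cost is commonly written as below. The symbols are chosen here for illustration ($\eta$ for the step size, $w$ for the weights, $x_i$ and $y_i$ for one sample and its target); this page does not state the exact cost the class uses, so treat this as a sketch of the usual formulation rather than a description of the implementation.

```latex
\hat{y}_i = w^\top x_i, \qquad
J_i(w) = \tfrac{1}{2}\,(y_i - \hat{y}_i)^2, \qquad
w \leftarrow w - \eta\,\nabla_w J_i(w) = w + \eta\,(y_i - \hat{y}_i)\,x_i
```

The last equality follows because $\nabla_w J_i(w) = -(y_i - \hat{y}_i)\,x_i$.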

Constructor & Destructor Documentation

◆ SGD()

SGD::SGD ( double  _eta = 0.01,
size_t  iter = 10,
bool  _shuffle = false,
const std::function< double(const Matrix< double > &, const Matrix< double > &)> &  update_weights_fun = nullptr,
const std::function< Matrix< double >(const Matrix< double > &)> &  netInputFun = nullptr 
)
inline explicit

default constructor

Parameters
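_eta  cost factor (learning rate)
iter  number of iterations (epochs) during fit
_shuffle  whether to shuffle the data during training
update_weights_fun  weight update function; defaults to nullptr
netInputFun  net input function; defaults to nullptr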
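The two callback parameters default to nullptr, which yields an empty std::function. The sketch below shows one way such an optional callback can be handled. The type OptimizerSketch and its step method are hypothetical and are not part of SGD.h. This page does not say how the class treats an empty callback, so the fallback to a built-in rule is an assumption.

```cpp
#include <cassert>
#include <functional>

// Hypothetical holder, NOT the library's SGD class: it only illustrates
// an optional std::function parameter that defaults to nullptr.
struct OptimizerSketch {
  double eta;
  std::function<double(double)> step_fun;  // empty when built from nullptr

  explicit OptimizerSketch(double eta_ = 0.01,
                           const std::function<double(double)>& fn = nullptr)
      : eta(eta_), step_fun(fn) {
    assert(eta > 0.0);  // a non-positive step size would not make progress
  }

  // Assumed behavior: use the supplied callback if there is one,
  // otherwise fall back to a built-in default step.
  double step(double error) const {
    return step_fun ? step_fun(error) : eta * error;
  }
};
```

An empty std::function converts to false in a boolean context, so the check costs one branch; calling an empty std::function without such a check throws std::bad_function_call.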

Member Function Documentation

◆ fit()

void SGD::fit ( const Matrix< double > &  X,
const Matrix< double > &  y,
Matrix< double > &  weights 
)
inline

Fits the given weights to the values X and y, using the configured weight update and net input functions.

Parameters
X
y
weights
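As a rough picture of what an epoch-based fit can look like, the sketch below runs a fixed number of passes over the data and records one cost value per epoch, echoing the cost attribute described on this page. It uses std::vector and a one-feature linear model instead of Matrix<double>. The name fit_sketch, the bias-at-index-0 layout, and the averaged squared-error cost are all assumptions made for illustration.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical epoch loop for a one-feature linear model y = w[0] + w[1] * x.
// Returns the average squared-error cost of each epoch.
std::vector<double> fit_sketch(const std::vector<double>& x,
                               const std::vector<double>& y,
                               std::vector<double>& w, double eta,
                               std::size_t n_iter) {
  assert(x.size() == y.size() && !x.empty() && w.size() == 2);
  std::vector<double> cost(n_iter, 0.0);
  for (std::size_t epoch = 0; epoch < n_iter; ++epoch) {
    double total = 0.0;
    for (std::size_t i = 0; i < x.size(); ++i) {
      const double error = y[i] - (w[0] + w[1] * x[i]);
      w[0] += eta * error;         // bias update
      w[1] += eta * error * x[i];  // feature weight update
      total += 0.5 * error * error;
    }
    cost[epoch] = total / static_cast<double>(x.size());
  }
  return cost;
}
```

On noise-free data such as y = 2x + 1 with a small eta, the per-epoch cost should shrink as the weights approach the intercept and slope.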

◆ netInput()

static Matrix< double > SGD::netInput ( const Matrix< double > &  X,
const Matrix< double > &  weights 
)
inline static

computes the net-input for given values

Parameters
X
weights
Returns
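To illustrate what a net-input computation typically looks like for a linear model, the sketch below computes one weighted sum per sample row. It uses plain std::vector rather than Matrix<double>, and it assumes the bias is stored as the first weight. Neither choice is confirmed by this page, and the name net_input_sketch is hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a net-input computation: one weighted sum per row.
// Assumption: weights[0] is the bias, weights[1..] pair with the feature columns.
std::vector<double> net_input_sketch(const std::vector<std::vector<double>>& X,
                                     const std::vector<double>& weights) {
  std::vector<double> out;
  out.reserve(X.size());
  for (const auto& row : X) {
    assert(row.size() + 1 == weights.size());  // one weight per feature plus bias
    double z = weights[0];
    for (std::size_t j = 0; j < row.size(); ++j) {
      z += row[j] * weights[j + 1];
    }
    out.push_back(z);
  }
  return out;
}
```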

◆ partial_fit()

void SGD::partial_fit ( const Matrix< double > &  X,
const Matrix< double > &  y,
Matrix< double > &  weights 
) const
inline

Performs partial fit of given weights to input values

Parameters
X
y
weights
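A partial fit is usually understood as an update that continues from the current weights instead of starting over, which suits data that arrives in chunks. The sketch below makes a single pass over one batch under that reading. The name partial_fit_sketch, the returned average cost, and the one-feature model are assumptions; the documented method returns void.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical single pass over one batch for y = w[0] + w[1] * x.
// The weights are updated in place and never reset, so repeated calls
// continue from wherever the previous call stopped.
double partial_fit_sketch(const std::vector<double>& x,
                          const std::vector<double>& y,
                          std::vector<double>& w, double eta) {
  assert(x.size() == y.size() && !x.empty() && w.size() == 2);
  double total = 0.0;
  for (std::size_t i = 0; i < x.size(); ++i) {
    const double error = y[i] - (w[0] + w[1] * x[i]);
    w[0] += eta * error;
    w[1] += eta * error * x[i];
    total += 0.5 * error * error;
  }
  return total / static_cast<double>(x.size());  // average cost of this pass
}
```

Calling the function once per incoming batch keeps refining the same weight vector.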

◆ shuffleData()

std::pair< Matrix< double >, Matrix< double > > SGD::shuffleData ( const Matrix< double > &  X,
const Matrix< double > &  y 
)
inline

shuffles given data

Parameters
X
y
Returns
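When shuffling training data, the rows of X and y have to be permuted with the same order so each sample keeps its target. The sketch below does this with a shared index permutation. The alias Rows, the function shuffle_rows_sketch, and the explicit random engine parameter are hypothetical; how the library seeds or draws its permutation is not stated on this page.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

using Rows = std::vector<std::vector<double>>;

// Hypothetical helper: permute the rows of X and y with the same random
// order, so that every sample stays paired with its target.
std::pair<Rows, Rows> shuffle_rows_sketch(const Rows& X, const Rows& y,
                                          std::mt19937& rng) {
  assert(X.size() == y.size());
  std::vector<std::size_t> idx(X.size());
  std::iota(idx.begin(), idx.end(), std::size_t{0});  // 0, 1, ..., n-1
  std::shuffle(idx.begin(), idx.end(), rng);
  Rows Xs, ys;
  Xs.reserve(X.size());
  ys.reserve(y.size());
  for (std::size_t i : idx) {
    Xs.push_back(X[i]);
    ys.push_back(y[i]);
  }
  return {std::move(Xs), std::move(ys)};
}
```

Passing the engine in by reference lets the caller fix the seed for reproducible runs.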

◆ update_weights()

double SGD::update_weights ( const Matrix< double > &  xi,
const Matrix< double > &  target,
Matrix< double > &  weights 
) const
inline

calculates the update values (mean squared error)

Parameters
xi
target
weights
Returns
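A single-sample update for a linear model with a squared-error cost can be sketched as follows. The function update_weights_sketch, the explicit eta parameter, the bias-at-index-0 layout, and the returned value of half the squared error are assumptions for illustration; the page only states that the method calculates update values and returns a double.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical single-sample update for a linear model with squared error.
// Assumptions: weights[0] is the bias; the returned cost is 0.5 * error^2.
double update_weights_sketch(const std::vector<double>& xi, double target,
                             std::vector<double>& weights, double eta) {
  assert(xi.size() + 1 == weights.size());
  double output = weights[0];
  for (std::size_t j = 0; j < xi.size(); ++j) {
    output += xi[j] * weights[j + 1];
  }
  const double error = target - output;
  weights[0] += eta * error;  // bias moves along the error
  for (std::size_t j = 0; j < xi.size(); ++j) {
    weights[j + 1] += eta * error * xi[j];  // each weight scaled by its input
  }
  return 0.5 * error * error;
}
```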

Member Data Documentation

◆ cost

Matrix<double> SGD::cost

matrix holding cost per epoch

◆ eta

double SGD::eta
private

cost factor (learning rate)

◆ n_iter

size_t SGD::n_iter
private

number of iterations (epochs) during fit

◆ net_input_fun

std::function<Matrix<double>(const Matrix<double>&)> SGD::net_input_fun = nullptr
private

net input function

◆ shuffle

bool SGD::shuffle = false
private

whether to shuffle the data during training

◆ weight_update

std::function<double(const Matrix<double>&, const Matrix<double>&)> SGD::weight_update = nullptr
private

weight update function


The documentation for this class was generated from the following file: