philsupertramp/game-math
DecisionTree Class Reference

#include <DecisionTree.h>

Inheritance diagram for DecisionTree: DecisionTree derives from the base class Predictor.

Public Member Functions

 DecisionTree (int min_samples_per_split=2, int max_depth=10, ImpurityMeasure decision_method=ImpurityMeasure::GINI)
 
double impurity (const Matrix< double > &x)
 
double information_gain (const Matrix< double > &parent, const Matrix< double > &left_child, const Matrix< double > &right_child)
 
struct DataSplit best_split (const Matrix< double > &X, const Matrix< double > &y)
 
DecisionNode * GetRootNode () const
 
void fit (const Matrix< double > &X, const Matrix< double > &y) override
 
Matrix< double > predict (const Matrix< double > &X) override
 
Matrix< double > transform (const Matrix< double > &in) override
 
Public Member Functions inherited from Predictor

virtual void fit (const Matrix< double > &X, const Matrix< double > &y)=0

virtual Matrix< double > predict (const Matrix< double > &)=0

virtual Matrix< double > transform (const Matrix< double > &)=0
 

Private Member Functions

DecisionNode * build_tree (const Matrix< double > &X, const Matrix< double > &y, int depth=0)
 
double _predict (const Matrix< double > &x, DecisionNode *node)
 
void render_node (std::ostream &ostr, DecisionNode *node, const char *side, int depth) const
 

Private Attributes

ImpurityMeasure _decision_method
 
int _max_depth = 10
 
int _min_samples_per_split = 2
 
DecisionNode * base_node
 

Friends

std::ostream & operator<< (std::ostream &ostr, const DecisionTree &tree)
 

Detailed Description

Binary decision tree implementation.

Constructor & Destructor Documentation

◆ DecisionTree()

DecisionTree::DecisionTree ( int  min_samples_per_split = 2,
int  max_depth = 10,
ImpurityMeasure  decision_method = ImpurityMeasure::GINI 
)
inline

Member Function Documentation

◆ _predict()

double DecisionTree::_predict ( const Matrix< double > &  x,
DecisionNode *  node 
)
inline private

Recursively traverses the tree to compute a prediction.

Parameters
x: input value(s)
node: the current node to evaluate
Returns
prediction for the given node
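The recursive traversal described above can be sketched as follows. This is a minimal standalone illustration, not the library's code: the Node struct, its fields (feature_index, threshold, value, left, right) and the rule "go left if x[feature_index] <= threshold" are assumptions, and std::vector<double> stands in for Matrix<double>.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <vector>

// Hypothetical node layout (the real DecisionNode may differ):
// internal nodes test one feature against a threshold, leaves hold a value.
struct Node {
    std::size_t feature_index = 0;  // feature tested at an internal node
    double threshold = 0.0;         // go left if x[feature_index] <= threshold
    double value = 0.0;             // prediction stored at a leaf
    std::unique_ptr<Node> left;
    std::unique_ptr<Node> right;

    bool is_leaf() const { return !left && !right; }
};

// Walk from `node` down to a leaf and return that leaf's value.
// Assumes every internal node has both children.
double predict_one(const std::vector<double>& x, const Node* node) {
    assert(node != nullptr);
    if (node->is_leaf()) {
        return node->value;
    }
    const Node* next = x[node->feature_index] <= node->threshold
                           ? node->left.get()
                           : node->right.get();
    return predict_one(x, next);
}
```

Calling such a function once per row of X is one plausible way a batch predict() could be layered on top of the recursive helper.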

◆ best_split()

struct DataSplit DecisionTree::best_split ( const Matrix< double > &  X,
const Matrix< double > &  y 
)
inline
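best_split() carries no description on this page. A common approach, and a plausible one given information_gain() below, is an exhaustive search over every feature and every observed value as a candidate threshold, keeping the split with the largest gain. The sketch assumes that approach; it uses Gini impurity, std::vector rows instead of Matrix<double>, and a hypothetical Split struct in place of DataSplit, whose fields are not documented here.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Gini impurity of a set of class labels: 1 - sum_k p_k^2.
double gini(const std::vector<double>& labels) {
    if (labels.empty()) return 0.0;
    std::map<double, std::size_t> counts;
    for (double label : labels) ++counts[label];
    double sum_sq = 0.0;
    for (const auto& entry : counts) {
        const double p = static_cast<double>(entry.second) / labels.size();
        sum_sq += p * p;
    }
    return 1.0 - sum_sq;
}

// Hypothetical result type; the library's DataSplit may hold other fields.
struct Split {
    std::size_t feature_index = 0;
    double threshold = 0.0;
    double gain = 0.0;
    bool found = false;  // stays false if no split separates the samples
};

// Exhaustive search: every feature, every observed value as threshold.
// Keeps the (feature, threshold) pair with the largest information gain.
Split best_split(const std::vector<std::vector<double>>& X,
                 const std::vector<double>& y) {
    assert(X.size() == y.size());
    Split best;
    if (X.empty()) return best;
    const double n = static_cast<double>(y.size());
    const double parent_impurity = gini(y);
    for (std::size_t f = 0; f < X[0].size(); ++f) {
        for (const auto& row : X) {
            const double threshold = row[f];
            std::vector<double> left, right;
            for (std::size_t i = 0; i < X.size(); ++i) {
                (X[i][f] <= threshold ? left : right).push_back(y[i]);
            }
            if (left.empty() || right.empty()) continue;  // not a real split
            const double gain = parent_impurity
                              - (left.size() / n) * gini(left)
                              - (right.size() / n) * gini(right);
            if (!best.found || gain > best.gain) {
                best = Split{f, threshold, gain, true};
            }
        }
    }
    return best;
}
```

For labels {0, 0, 1, 1} on a single feature {1, 2, 3, 4}, the threshold 2 separates the classes perfectly and wins with a gain of 0.5.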

◆ build_tree()

DecisionNode * DecisionTree::build_tree ( const Matrix< double > &  X,
const Matrix< double > &  y,
int  depth = 0 
)
inline private

Recursively builds the decision tree nodes.

Uses pre-pruning to decide whether to split the data further. Splitting is limited by the class hyperparameters:

  • max_depth
  • min_samples_per_split
Parameters
X: input data of the current node
y: related class labels
depth: current recursion depth (0 at the root)
Returns
A pointer to the DecisionNode built for this subset of the data.
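The pre-pruning check driven by max_depth and min_samples_per_split can be sketched as below. This is an illustration under assumptions: whether the library compares depth with < or <=, and how it chooses a leaf's value, is not documented here; majority vote is used as a stand-in, and the function names are hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Pre-pruning check: a node may only be split further if it holds enough
// samples and the tree has not yet reached its maximum depth.
// Parameter names mirror the documented hyperparameters.
bool should_split(std::size_t n_samples, int depth,
                  int min_samples_per_split, int max_depth) {
    return static_cast<int>(n_samples) >= min_samples_per_split
        && depth < max_depth;
}

// Value stored in a leaf once splitting stops. Majority vote is an
// assumption; on ties the smallest label wins (std::map iterates in order).
double majority_label(const std::vector<double>& labels) {
    assert(!labels.empty());
    std::map<double, std::size_t> counts;
    for (double label : labels) ++counts[label];
    double best_label = labels.front();
    std::size_t best_count = 0;
    for (const auto& entry : counts) {
        if (entry.second > best_count) {
            best_label = entry.first;
            best_count = entry.second;
        }
    }
    return best_label;
}
```

A recursive build would typically call should_split first and create a leaf whenever it returns false or no split with positive gain exists.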

◆ fit()

void DecisionTree::fit ( const Matrix< double > &  X,
const Matrix< double > &  y 
)
inline override virtual

Implements the training algorithm by building the tree from the training data.

Parameters
X: array-like with shape [n_samples, n_features]
y: array-like with shape [n_samples, 1]

Implements Predictor.

◆ GetRootNode()

DecisionNode * DecisionTree::GetRootNode ( ) const
inline

◆ impurity()

double DecisionTree::impurity ( const Matrix< double > &  x)
inline
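impurity() has no description on this page; the constructor selects ImpurityMeasure::GINI by default. The sketch below computes Gini impurity and, for comparison, Shannon entropy of a set of labels. Measure is a hypothetical stand-in for ImpurityMeasure, and whether the library offers an entropy option is not stated here.

```cpp
#include <cmath>
#include <cstddef>
#include <map>
#include <vector>

// Hypothetical stand-in for the library's ImpurityMeasure enum.
enum class Measure { Gini, Entropy };

// Relative frequency of each distinct label.
std::vector<double> class_probabilities(const std::vector<double>& labels) {
    std::map<double, std::size_t> counts;
    for (double label : labels) ++counts[label];
    std::vector<double> probs;
    for (const auto& entry : counts) {
        probs.push_back(static_cast<double>(entry.second) / labels.size());
    }
    return probs;
}

// Gini: 1 - sum p_k^2.  Entropy: -sum p_k * log2(p_k).
// Both are 0 for a pure node and grow as the classes mix.
double impurity(const std::vector<double>& labels, Measure measure) {
    if (labels.empty()) return 0.0;
    double result = measure == Measure::Gini ? 1.0 : 0.0;
    for (double p : class_probabilities(labels)) {  // p > 0 for every entry
        if (measure == Measure::Gini) {
            result -= p * p;
        } else {
            result -= p * std::log2(p);
        }
    }
    return result;
}
```

For a 50/50 two-class node, Gini gives 0.5 and entropy gives 1 bit; both give 0 for a pure node.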

◆ information_gain()

double DecisionTree::information_gain ( const Matrix< double > &  parent,
const Matrix< double > &  left_child,
const Matrix< double > &  right_child 
)
inline

Calculates the information gain of a set of label values, given its split into two children.

Note: pass your labels here, not your input data!

Parameters
parent: superset of label values
left_child: left split of the parent set
right_child: right split of the parent set
Returns
information gain of the given split
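The usual definition of information gain is the parent's impurity minus the size-weighted impurity of the children: IG = I(parent) - (|left| / |parent|) * I(left) - (|right| / |parent|) * I(right). The sketch below assumes that weighting and uses entropy as I, the classic choice; the class itself defaults to Gini, so its impurity() may differ. std::vector<double> replaces Matrix<double>.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <map>
#include <vector>

// Shannon entropy (in bits) of a set of class labels.
double entropy(const std::vector<double>& labels) {
    if (labels.empty()) return 0.0;
    std::map<double, std::size_t> counts;
    for (double label : labels) ++counts[label];
    double h = 0.0;
    for (const auto& entry : counts) {
        const double p = static_cast<double>(entry.second) / labels.size();
        h -= p * std::log2(p);
    }
    return h;
}

// Information gain of splitting `parent` into `left` and `right`:
// parent impurity minus the size-weighted impurity of the two children.
// Expects label values, not input features.
double information_gain(const std::vector<double>& parent,
                        const std::vector<double>& left,
                        const std::vector<double>& right) {
    assert(left.size() + right.size() == parent.size());
    if (parent.empty()) return 0.0;
    const double n = static_cast<double>(parent.size());
    return entropy(parent)
         - (left.size() / n) * entropy(left)
         - (right.size() / n) * entropy(right);
}
```

A split that separates two balanced classes perfectly gains the full 1 bit, while a split whose children keep the parent's class mix gains nothing.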

◆ predict()

Matrix< double > DecisionTree::predict ( const Matrix< double > &  X)
inline override virtual

Makes predictions for the given input.

Parameters
X: input data with shape [n_samples, n_features]
Returns
predictions for X

Implements Predictor.

◆ render_node()

void DecisionTree::render_node ( std::ostream &  ostr,
DecisionNode *  node,
const char *  side,
int  depth 
) const
inline private

◆ transform()

Matrix< double > DecisionTree::transform ( const Matrix< double > &  in)
inline override virtual

Implements Predictor.

Friends And Related Function Documentation

◆ operator<<

std::ostream & operator<< ( std::ostream &  ostr,
const DecisionTree &  tree 
)
friend

Member Data Documentation

◆ _decision_method

ImpurityMeasure DecisionTree::_decision_method
private

◆ _max_depth

int DecisionTree::_max_depth = 10
private

◆ _min_samples_per_split

int DecisionTree::_min_samples_per_split = 2
private

◆ base_node

DecisionNode* DecisionTree::base_node
private

The documentation for this class was generated from the following file: