4#include "../numerics/utils.h"
15 default:
return "unknown";
29 auto unique_values =
unique(in);
36 auto bins =
zeros(unique_values.rows(), 2);
37 for(
size_t i = 0; i < unique_values.rows(); ++i) {
38 auto label = unique_values(i, 0);
39 std::function<bool(
double)> cond = [label](
double xi) {
return bool(xi == label); };
41 bins(i, 1) =
where_true(
where(cond, in, { { 1 } }, { { 0 } })).elements_total();
55 bins = bins.GetSlice(0,bins.rows()-1,1,1);
56 auto pct = ((1.0 / (double)in.
elements_total()) * bins).Apply([](
double xi){
return xi * xi;});
57 return 1.0 - pct.sumElements();
91 type = DecisionNodeType::LEAF;
100 type = DecisionNodeType::DECISION;
138 int min_samples_per_split = 2,
int max_depth = 10,
ImpurityMeasure decision_method = ImpurityMeasure::GINI)
147 case ImpurityMeasure::ENTROPY:
return entropy(x);
148 case ImpurityMeasure::GINI:
return gini(x);
171 double best_info_gain = -1.;
173 for(
size_t idf = 0; idf < X.columns(); ++idf) {
174 auto eV =
zerosV(X.columns());
176 auto features = X * eV;
177 for(
size_t idx = 0; idx < features.rows(); ++idx) {
179 auto threshold = features(idx, 0);
181 std::function<bool(
double)> cond = [threshold](
double xi) {
return bool(xi <= threshold); };
182 auto split_indices =
where(cond, features, { { 1 } }, { { 0 } });
183 auto left_indices =
where_true(split_indices);
185 if(left_indices.rows() > 0 && right_indices.rows() > 0) {
186 auto left_split = df.GetSlicesByIndex(left_indices);
187 auto right_split = df.GetSlicesByIndex(right_indices);
190 left_split.GetSlice(0, left_split.rows() - 1, left_split.columns() - 1),
191 right_split.GetSlice(0, right_split.rows() - 1, right_split.columns() - 1));
192 if(gain > best_info_gain) {
193 best_info_gain = gain;
223 if(best.information_gain > 0) {
225 best.left_split.GetSlice(0, best.left_split.rows() - 1, 0, best.left_split.columns() - 2),
226 best.left_split.GetSlice(0, best.left_split.rows() - 1, best.left_split.columns() - 1),
229 best.right_split.GetSlice(0, best.right_split.rows() - 1, 0, best.right_split.columns() - 2),
230 best.right_split.GetSlice(0, best.right_split.rows() - 1, best.right_split.columns() - 1),
233 out =
new DecisionNode(best.feature, best.threshold, best.information_gain, left, right);
240 auto argmax_value = bins(
argmax(bins.GetSlice(0, bins.rows() - 1, 1, 1)), 0);
254 if(node->
type == DecisionNodeType::LEAF) {
return node->
value; }
260 ostr << std::string(depth,
'\t') << side;
261 if(node->
type == DecisionNodeType::LEAF) {
262 ostr <<
"label = " << node->
value << std::endl;
269 (
"L (x_" + std::to_string(node->
feature) +
" <= " + std::to_string(node->
threshold) +
"): ").c_str(),
276 (
"R (x_" + std::to_string(node->
feature) +
" > " + std::to_string(node->
threshold) +
"): ").c_str(),
double entropy(const Matrix< double > &in)
Definition: DecisionTree.h:28
const char * measure_to_name(ImpurityMeasure measure)
Definition: DecisionTree.h:11
DecisionNodeType
Definition: DecisionTree.h:60
@ DECISION
Definition: DecisionTree.h:60
@ LEAF
Definition: DecisionTree.h:60
@ NONE
Definition: DecisionTree.h:60
Matrix< double > count_bins(const Matrix< double > &in)
Definition: DecisionTree.h:34
ImpurityMeasure
Definition: DecisionTree.h:9
@ ENTROPY
Definition: DecisionTree.h:9
@ GINI
Definition: DecisionTree.h:9
double gini(const Matrix< double > &in)
Definition: DecisionTree.h:53
Definition: DecisionTree.h:71
DecisionNode * right
Definition: DecisionTree.h:81
DecisionNode()
Definition: DecisionTree.h:87
double value
Definition: DecisionTree.h:74
double threshold
Definition: DecisionTree.h:78
int feature
Definition: DecisionTree.h:77
DecisionNode(int feat, double thresh, double ig, DecisionNode *left, DecisionNode *right)
Definition: DecisionTree.h:94
DecisionNode(double val)
Definition: DecisionTree.h:89
double information_gain
Definition: DecisionTree.h:79
DecisionNode * left
Definition: DecisionTree.h:81
DecisionNodeType type
Definition: DecisionTree.h:84
Definition: DecisionTree.h:130
friend std::ostream & operator<<(std::ostream &ostr, const DecisionTree &tree)
Definition: DecisionTree.h:311
DecisionNode * GetRootNode() const
Definition: DecisionTree.h:283
Matrix< double > predict(const Matrix< double > &X) override
Definition: DecisionTree.h:299
double information_gain(const Matrix< double > &parent, const Matrix< double > &left_child, const Matrix< double > &right_child)
Definition: DecisionTree.h:163
DecisionNode * build_tree(const Matrix< double > &X, const Matrix< double > &y, int depth=0)
Definition: DecisionTree.h:219
Matrix< double > transform(const Matrix< double > &in) override
Definition: DecisionTree.h:309
int _min_samples_per_split
Definition: DecisionTree.h:133
struct DataSplit best_split(const Matrix< double > &X, const Matrix< double > &y)
Definition: DecisionTree.h:169
double impurity(const Matrix< double > &x)
Definition: DecisionTree.h:145
int _max_depth
Definition: DecisionTree.h:132
void fit(const Matrix< double > &X, const Matrix< double > &y) override
Definition: DecisionTree.h:290
ImpurityMeasure _decision_method
Definition: DecisionTree.h:131
double _predict(const Matrix< double > &x, DecisionNode *node)
Definition: DecisionTree.h:253
void render_node(std::ostream &ostr, DecisionNode *node, const char *side, int depth) const
Definition: DecisionTree.h:259
DecisionTree(int min_samples_per_split=2, int max_depth=10, ImpurityMeasure decision_method=ImpurityMeasure::GINI)
Definition: DecisionTree.h:137
DecisionNode * base_node
Definition: DecisionTree.h:134
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
size_t elements_total() const
Definition: Matrix.h:210
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
Matrix GetSlice(size_t rowStart) const
Definition: Matrix.h:609
Definition: Predictor.h:5
Matrix< size_t > where_true(const Matrix< T > &in)
Definition: matrix_utils.h:246
Matrix< T > HorizontalConcat(const Matrix< T > &lhs, const Matrix< T > &rhs)
Definition: matrix_utils.h:81
size_t argmax(const Matrix< T > &mat)
Definition: matrix_utils.h:143
Matrix< size_t > where_false(const Matrix< T > &in)
Definition: matrix_utils.h:257
Matrix< T > unique(const Matrix< T > &in, int axis=0)
Definition: matrix_utils.h:466
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194
Matrix< double > zerosV(size_t rows)
Matrix< double > zeros(size_t rows, size_t columns, size_t elements=1)
std::vector< T > sort(const std::vector< T > &in)
Definition: sorting.h:33
Definition: DecisionTree.h:108
double information_gain
Definition: DecisionTree.h:111
Matrix< double > right_split
Definition: DecisionTree.h:113
double threshold
Definition: DecisionTree.h:110
int feature
Definition: DecisionTree.h:109
Matrix< double > left_split
Definition: DecisionTree.h:112