philsupertramp/game-math
Loading...
Searching...
No Matches
DecisionTree.h
Go to the documentation of this file.
1#pragma once
2
#include "../Matrix.h"
#include "../numerics/utils.h"
#include "Predictor.h"
#include <cmath>
#include <functional>
#include <ostream>
#include <string>
7
8
/// Impurity criterion a DecisionTree uses to score candidate splits.
enum ImpurityMeasure { ENTROPY = 0, GINI = 1 };

/**
 * Human-readable name of an impurity measure (used when printing a tree).
 *
 * @param measure impurity measure to name
 * @return "entropy", "gini", or "unknown" for any other value
 *
 * Declared `inline` because it is defined in a header: without it every translation unit
 * that includes DecisionTree.h emits its own external definition and linking fails with
 * multiple-definition errors.
 */
inline const char* measure_to_name(ImpurityMeasure measure) {
    switch(measure) {
        case ImpurityMeasure::ENTROPY: return "entropy";
        case ImpurityMeasure::GINI: return "gini";
        default: return "unknown";
    }
}
18/*
19 def entropy(s):
20 counts = np.bincount(np.array(s, dtype=np.int64))
21 percentages = counts / len(s)
22 entropy = 0
23 for pct in percentages:
24 if pct > 0:
25 entropy += pct*np.log2(pct)
26 return -entropy
27 */
28double entropy(const Matrix<double>& in) {
29 auto unique_values = unique(in);
30 // TODO: Implement log2 to implement entropy... sadge
31 return 0;
32}
33
35 auto unique_values = sort(unique(in.rows() > in.columns() ? in : in.Transpose()));
36 auto bins = zeros(unique_values.rows(), 2);
37 for(size_t i = 0; i < unique_values.rows(); ++i) {
38 auto label = unique_values(i, 0);
39 std::function<bool(double)> cond = [label](double xi) { return bool(xi == label); };
40 bins(i, 0) = label;
41 bins(i, 1) = where_true(where(cond, in, { { 1 } }, { { 0 } })).elements_total();
42 }
43 return bins;
44}
45
46/*
47
48 def gini_impurity(s):
49 counts = np.bincount(np.array(s, dtype=np.int64))
50 percentages = counts / len(s)
51 return 1 - (percentages**2).sum()
52 */
53double gini(const Matrix<double>& in) {
54 auto bins = count_bins(in);
55 bins = bins.GetSlice(0,bins.rows()-1,1,1);
56 auto pct = ((1.0 / (double)in.elements_total()) * bins).Apply([](double xi){return xi * xi;});
57 return 1.0 - pct.sumElements();
58}
59
/// Role of a DecisionNode:
///  - NONE:     not yet initialised (the default type of a node),
///  - LEAF:     stores a class label in `value`,
///  - DECISION: splits on `feature` at `threshold` and has left/right children.
enum DecisionNodeType {
    NONE = -1,
    LEAF = 0,
    DECISION = 1,
};
61
71{
72public:
73 // Leaf node
74 double value = -1; // class label for leaf node
75
76 // Decision node
77 int feature = -1;
78 double threshold = -1;
79 double information_gain = -1;
80
82
83 // general
84 DecisionNodeType type = DecisionNodeType::NONE;
85
86public:
88
89 DecisionNode(double val)
90 : value(val) {
91 type = DecisionNodeType::LEAF;
92 }
93
94 DecisionNode(int feat, double thresh, double ig, DecisionNode* left, DecisionNode* right)
95 : feature(feat)
96 , threshold(thresh)
98 , left(left)
99 , right(right) {
100 type = DecisionNodeType::DECISION;
101 }
102};
103
104
108struct DataSplit {
110 double threshold;
115
130{
132 int _max_depth = 10;
135
136public:
138 int min_samples_per_split = 2, int max_depth = 10, ImpurityMeasure decision_method = ImpurityMeasure::GINI)
139 : Predictor()
140 , _decision_method(decision_method)
141 , _min_samples_per_split(min_samples_per_split)
142 , _max_depth(max_depth) { }
143
144
145 double impurity(const Matrix<double>& x) {
146 switch(_decision_method) {
147 case ImpurityMeasure::ENTROPY: return entropy(x);
148 case ImpurityMeasure::GINI: return gini(x);
149 default: return -1;
150 }
151 }
162 double
163 information_gain(const Matrix<double>& parent, const Matrix<double>& left_child, const Matrix<double>& right_child) {
164 double p_left = (double)left_child.elements_total() / parent.elements_total();
165 double p_right = (double)right_child.elements_total() / parent.elements_total();
166 return impurity(parent) - (p_left * impurity(left_child) + p_right * impurity(right_child));
167 }
168
169 struct DataSplit best_split(const Matrix<double>& X, const Matrix<double>& y) {
170 struct DataSplit best_split { };
171 double best_info_gain = -1.;
172
173 for(size_t idf = 0; idf < X.columns(); ++idf) {
174 auto eV = zerosV(X.columns());
175 eV(idf, 0) = 1;
176 auto features = X * eV;
177 for(size_t idx = 0; idx < features.rows(); ++idx) {
178 auto df = HorizontalConcat(X, y);
179 auto threshold = features(idx, 0);
180
181 std::function<bool(double)> cond = [threshold](double xi) { return bool(xi <= threshold); };
182 auto split_indices = where(cond, features, { { 1 } }, { { 0 } });
183 auto left_indices = where_true(split_indices);
184 auto right_indices = where_false(split_indices);
185 if(left_indices.rows() > 0 && right_indices.rows() > 0) {
186 auto left_split = df.GetSlicesByIndex(left_indices);
187 auto right_split = df.GetSlicesByIndex(right_indices);
188 auto gain = information_gain(
189 y,
190 left_split.GetSlice(0, left_split.rows() - 1, left_split.columns() - 1),
191 right_split.GetSlice(0, right_split.rows() - 1, right_split.columns() - 1));
192 if(gain > best_info_gain) {
193 best_info_gain = gain;
195 best_split.feature = idf;
196 best_split.threshold = threshold;
197 best_split.left_split = left_split;
198 best_split.right_split = right_split;
199 }
200 }
201 }
202 }
203 return best_split;
204 }
205
206private:
219 DecisionNode* build_tree(const Matrix<double>& X, const Matrix<double>& y, int depth = 0) {
220 DecisionNode* out = (DecisionNode*)malloc(sizeof(DecisionNode));
221 if(X.rows() >= _min_samples_per_split && depth < _max_depth) {
222 auto best = best_split(X, y);
223 if(best.information_gain > 0) {
224 auto left = build_tree(
225 best.left_split.GetSlice(0, best.left_split.rows() - 1, 0, best.left_split.columns() - 2),
226 best.left_split.GetSlice(0, best.left_split.rows() - 1, best.left_split.columns() - 1),
227 depth + 1);
228 auto right = build_tree(
229 best.right_split.GetSlice(0, best.right_split.rows() - 1, 0, best.right_split.columns() - 2),
230 best.right_split.GetSlice(0, best.right_split.rows() - 1, best.right_split.columns() - 1),
231 depth + 1);
232
233 out = new DecisionNode(best.feature, best.threshold, best.information_gain, left, right);
234 return out;
235 }
236 }
237
238 // found leaf node
239 auto bins = count_bins(y);
240 auto argmax_value = bins(argmax(bins.GetSlice(0, bins.rows() - 1, 1, 1)), 0);
241 out = new DecisionNode(argmax_value);
242 return out;
243 }
244
245
253 double _predict(const Matrix<double>& x, DecisionNode* node) {
254 if(node->type == DecisionNodeType::LEAF) { return node->value; }
255 if(x(0, node->feature) <= node->threshold) { return _predict(x, node->left); }
256 return _predict(x, node->right);
257 }
258
259 void render_node(std::ostream& ostr, DecisionNode* node, const char* side, int depth) const {
260 ostr << std::string(depth, '\t') << side;
261 if(node->type == DecisionNodeType::LEAF) {
262 ostr << "label = " << node->value << std::endl;
263 } else {
264 ostr << "feature(" << node->feature << ") threshold(" << node->threshold << ")" << measure_to_name(_decision_method) << " = " << node->information_gain << std::endl;
265 if(node->left) {
267 ostr,
268 node->left,
269 ("L (x_" + std::to_string(node->feature) + " <= " + std::to_string(node->threshold) + "): ").c_str(),
270 depth + 1);
271 }
272 if(node->right) {
274 ostr,
275 node->right,
276 ("R (x_" + std::to_string(node->feature) + " > " + std::to_string(node->threshold) + "): ").c_str(),
277 depth + 1);
278 }
279 }
280 }
281
282public:
283 DecisionNode* GetRootNode() const { return base_node; }
290 void fit(const Matrix<double>& X, const Matrix<double>& y) override {
291 base_node = (DecisionNode*)malloc(sizeof(DecisionNode));
292 base_node = build_tree(X, y, 0);
293 }
294
300 auto out = zeros(X.rows(), 1);
301 for(size_t i = 0; i < X.rows(); ++i) { out(i, 0) = _predict(X.GetSlice(i), base_node); }
302 return out;
303 }
304
309 Matrix<double> transform(const Matrix<double>& in) override { return in; };
310
311 friend std::ostream& operator<<(std::ostream& ostr, const DecisionTree& tree) {
312 ostr.precision(17);
313 tree.render_node(ostr, tree.base_node, "root: ", 0);
314 return ostr;
315 }
316};
double entropy(const Matrix< double > &in)
Definition: DecisionTree.h:28
const char * measure_to_name(ImpurityMeasure measure)
Definition: DecisionTree.h:11
DecisionNodeType
Definition: DecisionTree.h:60
@ DECISION
Definition: DecisionTree.h:60
@ LEAF
Definition: DecisionTree.h:60
@ NONE
Definition: DecisionTree.h:60
Matrix< double > count_bins(const Matrix< double > &in)
Definition: DecisionTree.h:34
ImpurityMeasure
Definition: DecisionTree.h:9
@ ENTROPY
Definition: DecisionTree.h:9
@ GINI
Definition: DecisionTree.h:9
double gini(const Matrix< double > &in)
Definition: DecisionTree.h:53
Definition: DecisionTree.h:71
DecisionNode * right
Definition: DecisionTree.h:81
DecisionNode()
Definition: DecisionTree.h:87
double value
Definition: DecisionTree.h:74
double threshold
Definition: DecisionTree.h:78
int feature
Definition: DecisionTree.h:77
DecisionNode(int feat, double thresh, double ig, DecisionNode *left, DecisionNode *right)
Definition: DecisionTree.h:94
DecisionNode(double val)
Definition: DecisionTree.h:89
double information_gain
Definition: DecisionTree.h:79
DecisionNode * left
Definition: DecisionTree.h:81
DecisionNodeType type
Definition: DecisionTree.h:84
Definition: DecisionTree.h:130
friend std::ostream & operator<<(std::ostream &ostr, const DecisionTree &tree)
Definition: DecisionTree.h:311
DecisionNode * GetRootNode() const
Definition: DecisionTree.h:283
Matrix< double > predict(const Matrix< double > &X) override
Definition: DecisionTree.h:299
double information_gain(const Matrix< double > &parent, const Matrix< double > &left_child, const Matrix< double > &right_child)
Definition: DecisionTree.h:163
DecisionNode * build_tree(const Matrix< double > &X, const Matrix< double > &y, int depth=0)
Definition: DecisionTree.h:219
Matrix< double > transform(const Matrix< double > &in) override
Definition: DecisionTree.h:309
int _min_samples_per_split
Definition: DecisionTree.h:133
struct DataSplit best_split(const Matrix< double > &X, const Matrix< double > &y)
Definition: DecisionTree.h:169
double impurity(const Matrix< double > &x)
Definition: DecisionTree.h:145
int _max_depth
Definition: DecisionTree.h:132
void fit(const Matrix< double > &X, const Matrix< double > &y) override
Definition: DecisionTree.h:290
ImpurityMeasure _decision_method
Definition: DecisionTree.h:131
double _predict(const Matrix< double > &x, DecisionNode *node)
Definition: DecisionTree.h:253
void render_node(std::ostream &ostr, DecisionNode *node, const char *side, int depth) const
Definition: DecisionTree.h:259
DecisionTree(int min_samples_per_split=2, int max_depth=10, ImpurityMeasure decision_method=ImpurityMeasure::GINI)
Definition: DecisionTree.h:137
DecisionNode * base_node
Definition: DecisionTree.h:134
Definition: Matrix.h:42
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
size_t elements_total() const
Definition: Matrix.h:210
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
Matrix GetSlice(size_t rowStart) const
Definition: Matrix.h:609
Definition: Predictor.h:5
Matrix< size_t > where_true(const Matrix< T > &in)
Definition: matrix_utils.h:246
Matrix< T > HorizontalConcat(const Matrix< T > &lhs, const Matrix< T > &rhs)
Definition: matrix_utils.h:81
size_t argmax(const Matrix< T > &mat)
Definition: matrix_utils.h:143
Matrix< size_t > where_false(const Matrix< T > &in)
Definition: matrix_utils.h:257
Matrix< T > unique(const Matrix< T > &in, int axis=0)
Definition: matrix_utils.h:466
Matrix< T > where(const std::function< bool(T)> &condition, const Matrix< T > &in, const Matrix< T > &valIfTrue, const Matrix< T > &valIfFalse)
Definition: matrix_utils.h:194
Matrix< double > zerosV(size_t rows)
Matrix< double > zeros(size_t rows, size_t columns, size_t elements=1)
std::vector< T > sort(const std::vector< T > &in)
Definition: sorting.h:33
Definition: DecisionTree.h:108
double information_gain
Definition: DecisionTree.h:111
Matrix< double > right_split
Definition: DecisionTree.h:113
double threshold
Definition: DecisionTree.h:110
int feature
Definition: DecisionTree.h:109
Matrix< double > left_split
Definition: DecisionTree.h:112