philsupertramp/game-math
Loading...
Searching...
No Matches
DataSet.h
Go to the documentation of this file.
1#pragma once
2
3#include "../Matrix.h"
4#include "../format.h"
5#include "../matrix_utils.h"
6#include <fstream>
7#include <map>
8#include <sstream>
9
14
18struct Set {
24 size_t InputCount = 0;
26 size_t OutputCount = 0;
28 size_t count = 0;
29
33 Set() = default;
34
40 Set(size_t inputCount, size_t outputCount) {
41 InputCount = inputCount;
42 OutputCount = outputCount;
43 }
44
51 Set(const char* fileName, size_t inputCount, size_t outputCount) {
52 InputCount = inputCount;
53 OutputCount = outputCount;
54 ReadFromFile(fileName);
55 }
56
65 void ReadFromFile(const char* fileName) {
66 std::string line;
67 size_t lineCount = 0;
68
69 std::ifstream dataFile2(fileName);
70 if(dataFile2.is_open()) {
71 while(getline(dataFile2, line, '\n')) { lineCount++; }
72 dataFile2.close();
73 }
74 std::ifstream dataFile(fileName);
75 count = 0;
76 Input = Matrix<double>(0, lineCount, InputCount);
77 Output = Matrix<double>(0, lineCount, OutputCount);
78 if(dataFile.is_open()) {
79 line = "";
80 while(getline(dataFile, line, '\n')) {
81 std::string val;
82 std::stringstream lineStream(line);
83 for(size_t i = 0; i < InputCount + OutputCount; i++) {
84 if(i < InputCount + OutputCount - 1) {
85 std::getline(lineStream, val, '\t');
86 } else {
87 std::getline(lineStream, val);
88 }
89 if(i < InputCount) {
90 Input(count, i) = std::atof(val.c_str());
91 } else {
92 Output(count, i - InputCount) = std::atof(val.c_str());
93 }
94 }
95 ++count;
96 }
97 dataFile.close();
98 // Classes = OutputToClass(Output);
99 } else {
100 std::cerr << "DataSet::ReadFromFile: Unable to open file " << fileName << std::endl;
101 }
102 }
103
109 [[nodiscard]] Set GetBatch(int batchSize) const {
110 if(batchSize == -1 || batchSize == (int)count) { return *this; }
111
112 auto newDS = Set(InputCount, OutputCount);
113 newDS.Input = Matrix<double>(0, batchSize, InputCount);
114 newDS.Output = Matrix<double>(0, batchSize, OutputCount);
115 for(int i = 0; i < batchSize; i++) {
116 int index = (int)(Random::Get() * count);
117 for(size_t in = 0; in < InputCount; in++) { newDS.Input(i, in) = Input(index, in); }
118 for(size_t out = 0; out < OutputCount; out++) { newDS.Output(i, out) = Output(index, out); }
119 }
120 return newDS;
121 }
122
127 auto maxIndex = argmax(Input);
128 auto minIndex = argmin(Input);
129 auto maxVal = Input(int(maxIndex / Input.columns()), int(maxIndex % Input.columns()));
130 auto minVal = Input(int(minIndex / Input.columns()), int(minIndex % Input.columns()));
131
132 auto minMax = maxVal - minVal;
133 for(size_t i = 0; i < Input.columns(); i++) {
134 for(size_t j = 0; j < Input.rows(); j++) { Input(i, j) = (Input(i, j) - minVal) / minMax; }
135 }
136 }
137
142 auto means = Matrix<double>(0, Input.rows(), 1);
143 for(size_t i = 0; i < Input.rows(); ++i) {
144 for(size_t j = 0; j < Input.columns(); ++j) { means(i, 0) += Input(i, j); }
145 means(i, 0) /= Input.columns();
146 }
147 auto stds = Matrix<double>(0, Input.rows(), 1);
148 for(size_t i = 0; i < Input.rows(); ++i) {
149 for(size_t j = 0; j < Input.columns(); ++j) { stds(i, 0) += pow(Input(i, j) - means(i, 0), 2); }
150 stds(i, 0) /= Input.columns();
151 }
152 for(size_t i = 0; i < Input.rows(); ++i) {
153 for(size_t j = 0; j < Input.columns(); ++j) { Input(i, j) = (Input(i, j) - means(i, 0)) / stds(i, 0); }
154 }
155 }
156
162 auto means = Matrix<double>(0, Input.rows(), 1);
163 for(size_t i = 0; i < Input.rows(); ++i) {
164 for(size_t j = 0; j < Input.columns(); ++j) { means(i, 0) += Input(i, j); }
165 means(i, 0) /= Input.columns();
166 }
167 auto stds = Matrix<double>(0, Input.rows(), 1);
168 for(size_t i = 0; i < Input.rows(); ++i) {
169 for(size_t j = 0; j < Input.columns(); ++j) { stds(i, 0) += pow(Input(i, j) - means(i, 0), 2); }
170 stds(i, 0) /= Input.columns();
171 }
172 for(size_t i = 0; i < Input.rows(); ++i) {
173 for(size_t j = 0; j < Input.columns(); ++j) { Input(i, j) = (Input(i, j) - means(i, 0)) / stds(i, 0); }
174 }
176 }
177
182 void Normalize(NormalizerMethod method = NormalizerMethod::SET_MEAN) {
183 switch(method) {
184 case SET_MEAN: NormalizeSetMean(); break;
185 case COL_MEAN: NormalizeColMean(); break;
186 case ROW_MEAN: NormalizeRowMean(); break;
187 }
188 }
189};
190
195{
196 friend class ImageDataSet;
197
198public:
205 DataSet(const char* filePath, size_t inputCount, size_t outputCount)
206 : InputCount(inputCount)
207 , OutputCount(outputCount)
208 , Training(format("%s%s", filePath, "training.dat").c_str(), inputCount, outputCount)
209 , Validation(format("%s%s", filePath, "validation.dat").c_str(), inputCount, outputCount)
210 , Test(format("%s%s", filePath, "test.dat").c_str(), inputCount, outputCount) { }
211
217 DataSet(size_t inputCount, size_t outputCount)
218 : InputCount(inputCount)
219 , OutputCount(outputCount)
220 , Training(inputCount, outputCount)
221 , Validation(inputCount, outputCount)
222 , Test(inputCount, outputCount) { }
223
228 virtual void PrepareDirectory([[maybe_unused]] const char* filePath) { }
229
230public:
235
242
/// number of epochs while training
int maxEpoch{1000};
/// threshold for loss to prevent over-fitting
double stopThreshold{0.001};
/// learning rate
double eta{0.0051};
/// number of elements per batch
int batchSize{5};
/// use verbose output during fitting
bool verbose{false};
253};
NormalizerMethod
Definition: DataSet.h:13
@ SET_MEAN
Definition: DataSet.h:13
@ ROW_MEAN
Definition: DataSet.h:13
@ COL_MEAN
Definition: DataSet.h:13
double pow(double x, int exponent)
Definition: DataSet.h:195
DataSet(const char *filePath, size_t inputCount, size_t outputCount)
Definition: DataSet.h:205
Set Test
set to test after training
Definition: DataSet.h:241
double eta
learning rate
Definition: DataSet.h:248
int batchSize
number of elements per batch
Definition: DataSet.h:250
size_t OutputCount
number output elements
Definition: DataSet.h:234
Set Validation
set for validation of training
Definition: DataSet.h:239
size_t InputCount
number input elements
Definition: DataSet.h:232
Set Training
set for training
Definition: DataSet.h:237
virtual void PrepareDirectory(const char *filePath)
Definition: DataSet.h:228
double stopThreshold
threshold for loss to prevent over-fitting
Definition: DataSet.h:246
int maxEpoch
number of epochs while training
Definition: DataSet.h:244
DataSet(size_t inputCount, size_t outputCount)
Definition: DataSet.h:217
bool verbose
use verbose output during fitting
Definition: DataSet.h:252
only include Magick++ if needed
Definition: ImageDataSet.h:17
Definition: Matrix.h:42
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
static double Get(double l=0.0, double r=1.0)
Definition: Random.h:40
std::string format(const char *fmt,...)
Definition: format.h:22
size_t argmin(const Matrix< T > &mat)
Definition: matrix_utils.h:167
size_t argmax(const Matrix< T > &mat)
Definition: matrix_utils.h:143
Definition: DataSet.h:18
Matrix< double > Input
input data
Definition: DataSet.h:20
size_t count
number of input-output pairs
Definition: DataSet.h:28
Set(const char *fileName, size_t inputCount, size_t outputCount)
Definition: DataSet.h:51
size_t InputCount
number of input elements
Definition: DataSet.h:24
Set GetBatch(int batchSize) const
Definition: DataSet.h:109
Set()=default
void ReadFromFile(const char *fileName)
Definition: DataSet.h:65
size_t OutputCount
number of output elements
Definition: DataSet.h:26
void NormalizeRowMean()
Definition: DataSet.h:141
void NormalizeColMean()
Definition: DataSet.h:160
Matrix< double > Output
expected output data
Definition: DataSet.h:22
void NormalizeSetMean()
Definition: DataSet.h:126
void Normalize(NormalizerMethod method=NormalizerMethod::SET_MEAN)
Definition: DataSet.h:182
Set(size_t inputCount, size_t outputCount)
Definition: DataSet.h:40