5#include "../matrix_utils.h"
40 Set(
size_t inputCount,
size_t outputCount) {
51 Set(
const char* fileName,
size_t inputCount,
size_t outputCount) {
69 std::ifstream dataFile2(fileName);
70 if(dataFile2.is_open()) {
71 while(getline(dataFile2, line,
'\n')) { lineCount++; }
74 std::ifstream dataFile(fileName);
78 if(dataFile.is_open()) {
80 while(getline(dataFile, line,
'\n')) {
82 std::stringstream lineStream(line);
85 std::getline(lineStream, val,
'\t');
87 std::getline(lineStream, val);
100 std::cerr <<
"DataSet::ReadFromFile: Unable to open file " << fileName << std::endl;
110 if(batchSize == -1 || batchSize == (
int)
count) {
return *
this; }
115 for(
int i = 0; i < batchSize; i++) {
117 for(
size_t in = 0; in <
InputCount; in++) { newDS.Input(i, in) =
Input(index, in); }
118 for(
size_t out = 0; out <
OutputCount; out++) { newDS.Output(i, out) =
Output(index, out); }
132 auto minMax = maxVal - minVal;
134 for(
size_t j = 0; j <
Input.
rows(); j++) {
Input(i, j) = (
Input(i, j) - minVal) / minMax; }
143 for(
size_t i = 0; i <
Input.
rows(); ++i) {
148 for(
size_t i = 0; i <
Input.
rows(); ++i) {
149 for(
size_t j = 0; j <
Input.
columns(); ++j) { stds(i, 0) +=
pow(
Input(i, j) - means(i, 0), 2); }
152 for(
size_t i = 0; i <
Input.
rows(); ++i) {
163 for(
size_t i = 0; i <
Input.
rows(); ++i) {
168 for(
size_t i = 0; i <
Input.
rows(); ++i) {
169 for(
size_t j = 0; j <
Input.
columns(); ++j) { stds(i, 0) +=
pow(
Input(i, j) - means(i, 0), 2); }
172 for(
size_t i = 0; i <
Input.
rows(); ++i) {
205 DataSet(
const char* filePath,
size_t inputCount,
size_t outputCount)
208 ,
Training(
format(
"%s%s", filePath,
"training.dat").c_str(), inputCount, outputCount)
209 ,
Validation(
format(
"%s%s", filePath,
"validation.dat").c_str(), inputCount, outputCount)
210 ,
Test(
format(
"%s%s", filePath,
"test.dat").c_str(), inputCount, outputCount) { }
217 DataSet(
size_t inputCount,
size_t outputCount)
222 ,
Test(inputCount, outputCount) { }
NormalizerMethod
Definition: DataSet.h:13
@ SET_MEAN
Definition: DataSet.h:13
@ ROW_MEAN
Definition: DataSet.h:13
@ COL_MEAN
Definition: DataSet.h:13
double pow(double x, int exponent)
Definition: DataSet.h:195
DataSet(const char *filePath, size_t inputCount, size_t outputCount)
Definition: DataSet.h:205
Set Test
set to test after training
Definition: DataSet.h:241
double eta
learning rate
Definition: DataSet.h:248
int batchSize
number of elements per batch
Definition: DataSet.h:250
size_t OutputCount
number of output elements
Definition: DataSet.h:234
Set Validation
set for validation of training
Definition: DataSet.h:239
size_t InputCount
number of input elements
Definition: DataSet.h:232
Set Training
set for training
Definition: DataSet.h:237
virtual void PrepareDirectory(const char *filePath)
Definition: DataSet.h:228
double stopThreshold
threshold for loss to prevent over-fitting
Definition: DataSet.h:246
int maxEpoch
number of epochs while training
Definition: DataSet.h:244
DataSet(size_t inputCount, size_t outputCount)
Definition: DataSet.h:217
bool verbose
use verbose output during fitting
Definition: DataSet.h:252
only include Magick++ if needed
Definition: ImageDataSet.h:17
constexpr Matrix< T > Transpose() const
Definition: Matrix.h:256
size_t rows() const
Definition: Matrix.h:193
size_t columns() const
Definition: Matrix.h:198
static double Get(double l=0.0, double r=1.0)
Definition: Random.h:40
size_t argmin(const Matrix< T > &mat)
Definition: matrix_utils.h:167
size_t argmax(const Matrix< T > &mat)
Definition: matrix_utils.h:143
Matrix< double > Input
input data
Definition: DataSet.h:20
size_t count
number of input-output pairs
Definition: DataSet.h:28
Set(const char *fileName, size_t inputCount, size_t outputCount)
Definition: DataSet.h:51
size_t InputCount
number of input elements
Definition: DataSet.h:24
Set GetBatch(int batchSize) const
Definition: DataSet.h:109
void ReadFromFile(const char *fileName)
Definition: DataSet.h:65
size_t OutputCount
number of output elements
Definition: DataSet.h:26
void NormalizeRowMean()
Definition: DataSet.h:141
void NormalizeColMean()
Definition: DataSet.h:160
Matrix< double > Output
expected output data
Definition: DataSet.h:22
void NormalizeSetMean()
Definition: DataSet.h:126
void Normalize(NormalizerMethod method=NormalizerMethod::SET_MEAN)
Definition: DataSet.h:182
Set(size_t inputCount, size_t outputCount)
Definition: DataSet.h:40