#include "../Test.h"
// Test case for a k-nearest-neighbours style classifier, built on the project's
// Test base class. The individual tests are private members (class default
// access); only run() is public.
//
// NOTE(review): several statements in this class reference names whose
// declarations are not visible here (`clf`, `input_data`, `expected_labels`,
// `acc`), and two statements start mid-expression. The missing pieces look like
// they were lost before this file reached review -- confirm against the
// original source before relying on this listing.
class KNNTestCase : public Test
{
// Placeholder test: performs no checks and always reports success.
bool TestConstructor() {
return true;
}
// Fits the classifier twice on the same nine 2-D samples and, after each fit,
// asserts that clf.weights equals the training input.
// Returns true if execution reaches the end of the test.
bool TestFit() {
// NOTE(review): the next line is the tail of a statement whose beginning is
// missing -- presumably the declaration of `input_data`; the declaration of
// `clf` is not visible either. TODO confirm.
{ { 2, 1 }, { 4, 1 }, { 2, 3 }, { 4, 3 }, { 8, 6 }, { 10, 8 }, { 2, 12 }, { 2, 14 }, { 4, 14 } });
// One label per sample row: four rows of 0, two of 1, three of 2.
Matrix<double> labels({ { 0 }, { 0 }, { 0 }, { 0 }, { 1 }, { 1 }, { 2 }, { 2 }, { 2 } });
clf.fit(input_data, labels);
AssertEqual(clf.weights, input_data);
// Fit a second time with identical data and repeat the same check.
clf.fit(input_data, labels);
AssertEqual(clf.weights, input_data);
return true;
}
// Fits on the same nine samples as TestFit, predicts labels for four unseen
// points, and asserts the predictions equal `expected_labels`.
// Returns true if execution reaches the end of the test.
bool TestPredict() {
// NOTE(review): as in TestFit, the next line is the tail of a statement whose
// beginning is missing -- presumably the declaration of `input_data`. TODO confirm.
{ { 2, 1 }, { 4, 1 }, { 2, 3 }, { 4, 3 }, { 8, 6 }, { 10, 8 }, { 2, 12 }, { 2, 14 }, { 4, 14 } });
Matrix<double> labels({ { 0 }, { 0 }, { 0 }, { 0 }, { 1 }, { 1 }, { 2 }, { 2 }, { 2 } });
// Query points to classify.
Matrix<double> test_data({ { 1, 0 }, { 8, 8 }, { 0, 1 }, { 3, 16 } });
clf.fit(input_data, labels);
auto predictions = clf.predict(test_data);
// NOTE(review): `expected_labels` is not declared anywhere in the visible code.
AssertEqual(predictions, expected_labels);
return true;
}
// Loads train/test sets from TSV files, predicts on the test inputs, prints
// the accuracy as a percentage, and asserts that it equals 1.0.
// Returns true if execution reaches the end of the test.
bool TestPredictionOnIrisData() {
// The paths are relative, so they resolve against the process working directory.
// NOTE(review): the trailing arguments (4, 1) are presumably the number of
// input and output columns -- confirm against the Set constructor.
Set dataset =
Set(
"../../tests/ds/train.tsv", 4, 1);
Set test_set =
Set(
"../../tests/ds/test.tsv", 4, 1);
// NOTE(review): no fit call on `dataset` is visible before this prediction,
// and the declaration of `clf` is not visible either. TODO confirm.
auto preds = clf.predict(test_set.
Input);
// NOTE(review): `acc` is printed and checked below, but its computation is not
// visible in this listing.
std::cout << "Iris flower classification, achieved " << acc * 100. << "% accuracy." << std::endl;
AssertEqual(acc, 1.0);
return true;
}
public:
// Runs every test in declaration order. The bool results of the individual
// tests are discarded here.
virtual void run() {
TestConstructor();
TestFit();
TestPredict();
TestPredictionOnIrisData();
}
};
int main() {
KNNTestCase().run();
return 0;
}
// Declarations referenced by the tests above (defined elsewhere in the project):
//
//   double accuracy(const Matrix< double > &predictions, const Matrix< double > &ground_truth)
//       Definition: utils.h:14
//
//   Matrix< double > Input
//       input data
//       Definition: DataSet.h:20
//
//   Matrix< double > Output
//       expected output data
//       Definition: DataSet.h:22