#include "../Test.h"
// Test case exercising LogRegSGD (the name suggests logistic regression
// trained by stochastic gradient descent - NOTE(review): confirm against
// LogRegSGD.h).
//
// NOTE(review): this listing appears to be an extraction of the original file
// with lines missing. A, B, minVal, maxVal and logRegSgd are used below but
// their declarations are not visible, and the opening of the initializer for
// the training data is absent. Restore them from version control before
// building.
class LogRegSGDTestCase : public Test
{
// Fits logRegSgd on a 10-sample, single-column training set A with targets B,
// then checks the values predicted for inputs -1.0 and 2.0.
// Returns true unconditionally; failures are presumably reported by
// AssertEqual (declared in Test.h, not visible here).
bool TestLogReg() {
// Training inputs: 10 rows x 1 column (the start of this initializer,
// presumably the declaration of A, is missing from the listing).
{ 0.29651885541310041 },
{ 0.64581698997155579 },
{ 0.73420390202980546 },
{ -0.33842438014317311 },
{ 0.73079510849631024 },
{ -0.90637656297433933 },
{ -0.58136213318226426 },
{ -0.12520400256483499 },
{ 0.41344311936636324 },
{ 0.0089382791988343868 },
});
// Predicate x >= 0.0. Presumably passed to where() to build the targets B
// (that line is not visible here).
std::function<bool(double)> condition = [](double x) { return bool(x >= 0.0); };
// Rescale every entry of A as (in - minVal) / (maxVal - minVal). If minVal and
// maxVal are A's min and max (their declarations are not visible), this maps A
// onto [0, 1].
// NOTE(review): divides by zero if maxVal == minVal.
A = A.Apply([minVal, maxVal](const double& in) { return (in - minVal) / (maxVal - minVal); });
logRegSgd.fit(A, B);
// predict() is given a 1x1 input matrix; (0, 0) reads the single result entry.
// NOTE(review): -1.0 and 2.0 are NOT rescaled with minVal/maxVal the way the
// training data was - confirm this is intended.
auto val = logRegSgd.predict({ { -1.0 } })(0, 0);
auto val2 = logRegSgd.predict({ { 2.0 } })(0, 0);
// Expected outputs: 1 for input -1.0 and -1 for input 2.0.
// NOTE(review): exact equality on the predicted value is only safe if predict
// returns exact class labels rather than probabilities/scores.
AssertEqual(val, 1);
AssertEqual(val2, -1);
return true;
}
// Checks that costFunction applied to the 1x1 matrix {{1}} returns exactly 0.0.
// Relies on logRegSgd being reachable here (presumably a member whose
// declaration is missing from this listing).
// NOTE(review): this test is never invoked from run().
bool TestNoCost() {
AssertEqual(logRegSgd.costFunction({ { 1 } }), 0.0);
return true;
}
public:
// Entry point used by main(). Runs TestLogReg only.
// NOTE(review): TestNoCost is not called, and TestLogReg's bool result is
// discarded.
void run() override {
TestLogReg();
}
};
int main() {
LogRegSGDTestCase().run();
return 0;
}
/*
 * Doxygen cross-reference residue (documentation-page tooltips, not program code):
 *   (identifier line missing; presumably LogRegSGD) - defined at LogRegSGD.h:12
 *   T min(const Matrix<T>& mat) - defined at matrix_utils.h:324
 *   T max(const Matrix<T>& mat) - defined at matrix_utils.h:292
 *   Matrix<T> where(const std::function<bool(T)>& condition, const Matrix<T>& in,
 *                   const Matrix<T>& valIfTrue, const Matrix<T>& valIfFalse)
 *     - defined at matrix_utils.h:194
 */