Skip to content

Commit

Permalink
fix: svm with db training
Browse files Browse the repository at this point in the history
* svm with db training
* added test_split support with svm db
* added svm training with db to unit tests
  • Loading branch information
beniz authored and sileht committed Sep 22, 2020
1 parent 8da9074 commit 6e925f2
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 23 deletions.
2 changes: 1 addition & 1 deletion src/backends/caffe/caffeinputconns.h
Original file line number Diff line number Diff line change
Expand Up @@ -1489,7 +1489,6 @@ namespace dd
APIData ad_input = ad_param.getobj("input");
APIData ad_mllib = ad_param.getobj("mllib");

_test_dbfullname = "";
if (_train && ad_input.has("db") && ad_input.get("db").get<bool>())
{
_dbfullname = _model_repo + "/" + _dbfullname;
Expand All @@ -1511,6 +1510,7 @@ namespace dd
}
else
{
_test_dbfullname = "";
try
{
SVMInputFileConn::transform(ad);
Expand Down
20 changes: 17 additions & 3 deletions src/svminputfileconn.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ namespace dd
}
if (vals.empty())
throw InputConnectorBadParamException(
"Issue while reading svm example (index might be out of boundsi)");
"Issue while reading svm example (index might be out of bounds)");
}

void SVMInputFileConn::read_svm(const APIData &ad, const std::string &fname)
Expand All @@ -126,6 +126,7 @@ namespace dd

// first pass to get max index
std::string col;
int total_lines = 0;
while (std::getline(svm_file, hline))
{
bool fpos = true;
Expand All @@ -146,28 +147,41 @@ namespace dd
_fids.insert(fid);
}
}
++total_lines;
}
svm_file.clear();
svm_file.seekg(0, std::ios::beg);

_logger->info("total number of dimensions={}", _fids.size());

int train_lines = 0;
if (_test_split > 0.0)
{
train_lines = total_lines * (1.0 - _test_split);
}

// read data
int nlines = 0;
int tnlines = 0;
while (std::getline(svm_file, hline))
{
std::unordered_map<int, double> vals;
int label;
read_svm_line(hline, vals, label);
add_train_svmline(label, vals, nlines);
if (train_lines > 0 && nlines < train_lines)
add_train_svmline(label, vals, nlines);
else
{
add_test_svmline(label, vals, tnlines);
++tnlines;
}
++nlines;
}
svm_file.close();
_logger->info("read {} lines from SVM data file", nlines);

if (!_svm_test_fname.empty())
{
int tnlines = 0;
std::ifstream svm_test_file(_svm_test_fname, std::ios::binary);
if (!svm_test_file.is_open())
throw InputConnectorBadParamException("cannot open SVM test file "
Expand Down
37 changes: 18 additions & 19 deletions tests/ut-caffeapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -851,25 +851,27 @@ TEST(caffeapi, service_train_svm)
{
// create service
JsonAPI japi;
std::string farm_repo_loc = "farm";
std::string sname = "my_service";
std::string jstr
= "{\"mllib\":\"caffe\",\"description\":\"my "
"classifier\",\"type\":\"supervised\",\"model\":{\"repository\":\""
+ farm_repo + "\",\"templates\":\"" + model_templates_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"svm\"},\"mllib\":{"
"\"template\":\"mlp\",\"nclasses\":2,\"activation\":\"prelu\"}}}";
+ farm_repo_loc + "\",\"templates\":\"" + model_templates_repo
+ "\",\"create_repository\":true},\"parameters\":{\"input\":{"
"\"connector\":\"svm\"},\"mllib\":{\"template\":\"mlp\","
"\"nclasses\":2,\"activation\":\"prelu\",\"db\":true}}}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// assert json blob file
ASSERT_TRUE(
fileops::file_exists(farm_repo + "/" + JsonAPI::_json_blob_fname));
fileops::file_exists(farm_repo_loc + "/" + JsonAPI::_json_blob_fname));

// train
std::string jtrainstr
= "{\"service\":\"" + sname
+ "\",\"async\":false,\"parameters\":{\"input\":{\"test_split\":0.1,"
"\"shuffle\":true},\"mllib\":{\"gpu\":true,\"gpuid\":"
+ "\",\"async\":false,\"parameters\":{\"input\":{\"db\":true,\"test_"
"split\":0.1,\"shuffle\":true},\"mllib\":{\"gpu\":true,\"gpuid\":"
+ gpuid + ",\"solver\":{\"iterations\":" + iterations_farm
+ ",\"base_lr\":0.01},\"net\":{\"batch_size\":100}},\"output\":{"
"\"measure\":[\"acc\",\"mcll\",\"f1\",\"cmdiag\",\"cmfull\"]}},"
Expand Down Expand Up @@ -897,12 +899,10 @@ TEST(caffeapi, service_train_svm)
#endif
ASSERT_EQ(jd["body"]["measure"]["accp"].GetDouble(),
jd["body"]["measure"]["acc"].GetDouble());
ASSERT_TRUE(jd["body"].HasMember("parameters"));
ASSERT_TRUE(jd["body"]["measure"].HasMember("cmdiag"));
ASSERT_EQ(2, jd["body"]["measure"]["cmdiag"].Size());
ASSERT_TRUE(jd["body"]["measure"]["cmdiag"][0].GetDouble() >= 0);
ASSERT_TRUE(jd["body"]["measure"]["cmfull"].Size());
ASSERT_EQ(16, jd["body"]["parameters"]["mllib"]["batch_size"].GetInt());

std::string mem_data
= "8:1 9:1 23:1 31:1 32:1 34:1 45:1 46:1 49:1 50:1 52:1 54:1 57:1 60:1 "
Expand Down Expand Up @@ -1006,34 +1006,34 @@ TEST(caffeapi, service_train_svm)
ASSERT_TRUE("1" == cat0);

// remove service
jstr = "{\"clear\":\"lib\"}";
jstr = "{\"clear\":\"full\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);

// assert json blob file is gone after deleting the service with clear=full
ASSERT_TRUE(
!fileops::file_exists(farm_repo + "/" + JsonAPI::_json_blob_fname));
ASSERT_TRUE(!fileops::remove_directory_files(farm_repo, { ".prototxt" }));
!fileops::file_exists(farm_repo_loc + "/" + JsonAPI::_json_blob_fname));
}

TEST(caffeapi, service_train_svm_resnet)
{
// create service
JsonAPI japi;
std::string farm_repo_loc = "farm";
std::string sname = "my_service";
std::string jstr
= "{\"mllib\":\"caffe\",\"description\":\"my "
"classifier\",\"type\":\"supervised\",\"model\":{\"repository\":\""
+ farm_repo + "\",\"templates\":\"" + model_templates_repo
+ "\"},\"parameters\":{\"input\":{\"connector\":\"svm\"},\"mllib\":{"
"\"template\":\"resnet\",\"nclasses\":2,\"activation\":\"prelu\","
"\"layers\":[30,25,15]}}}";
+ farm_repo_loc + "\",\"templates\":\"" + model_templates_repo
+ "\",\"create_repository\":true},\"parameters\":{\"input\":{"
"\"connector\":\"svm\"},\"mllib\":{\"template\":\"resnet\","
"\"nclasses\":2,\"activation\":\"prelu\",\"layers\":[30,25,15]}}}";
std::string joutstr = japi.jrender(japi.service_create(sname, jstr));
ASSERT_EQ(created_str, joutstr);

// assert json blob file
ASSERT_TRUE(
fileops::file_exists(farm_repo + "/" + JsonAPI::_json_blob_fname));
fileops::file_exists(farm_repo_loc + "/" + JsonAPI::_json_blob_fname));

// train
std::string jtrainstr
Expand Down Expand Up @@ -1174,14 +1174,13 @@ TEST(caffeapi, service_train_svm_resnet)
ASSERT_TRUE("1" == cat0);

// remove service
jstr = "{\"clear\":\"lib\"}";
jstr = "{\"clear\":\"full\"}";
joutstr = japi.jrender(japi.service_delete(sname, jstr));
ASSERT_EQ(ok_str, joutstr);

// assert json blob file is gone after deleting the service with clear=full
ASSERT_TRUE(
!fileops::file_exists(farm_repo + "/" + JsonAPI::_json_blob_fname));
ASSERT_TRUE(!fileops::remove_directory_files(farm_repo, { ".prototxt" }));
!fileops::file_exists(farm_repo_loc + "/" + JsonAPI::_json_blob_fname));
}

TEST(caffeapi, service_train_images)
Expand Down

0 comments on commit 6e925f2

Please sign in to comment.