Skip to content

Commit

Permalink
fix(torch): Fix conditions to add classification head.
Browse files Browse the repository at this point in the history
  • Loading branch information
Bycob authored and sileht committed Oct 8, 2020
1 parent c1f4ef9 commit f46a710
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 17 deletions.
26 changes: 14 additions & 12 deletions src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,18 @@ namespace dd
if (!tmodel._traced.empty())
torch::load(_graph, tmodel._traced, _device);
}
to(_device);
}

template <class TInputConnectorStrategy>
void TorchModule::post_transform_train(const std::string tmpl,
const APIData &template_params,
const TInputConnectorStrategy &inputc,
const TorchModel &tmodel,
const torch::Device &device)
{
post_transform(tmpl, template_params, inputc, tmodel, device);

if (_require_classif_layer && !_classif)
{
try
Expand All @@ -197,24 +209,14 @@ namespace dd
setup_classification(_nclasses,
const_cast<TInputConnectorStrategy &>(inputc)
.get_input_example(device));
_classif->to(_device);
}
catch (std::exception &e)
{
throw MLLibInternalException(std::string("Libtorch error: ")
+ e.what());
}
}
to(_device);
}

template <class TInputConnectorStrategy>
void TorchModule::post_transform_train(const std::string tmpl,
const APIData &template_params,
const TInputConnectorStrategy &inputc,
const TorchModel &tmodel,
const torch::Device &device)
{
post_transform(tmpl, template_params, inputc, tmodel, device);
}

template <class TInputConnectorStrategy>
Expand Down Expand Up @@ -654,7 +656,7 @@ namespace dd
this->_mltype = "classification";
_module._nclasses = _nclasses;

if (!this->_mlmodel._traced.empty() && !_module._classif)
if (_finetuning)
{
_module._require_classif_layer = true;
this->_logger->info(
Expand Down
7 changes: 4 additions & 3 deletions src/imginputfileconn.h
Original file line number Diff line number Diff line change
Expand Up @@ -519,9 +519,10 @@ namespace dd
ImgInputFileConn(const ImgInputFileConn &i)
: InputConnectorStrategy(i), _width(i._width), _height(i._height),
_crop_width(i._crop_width), _crop_height(i._crop_height), _bw(i._bw),
_unchanged_data(i._unchanged_data), _test_split(i._test_split),
_mean(i._mean), _has_mean_scalar(i._has_mean_scalar),
_scale(i._scale), _scaled(i._scaled), _scale_min(i._scale_min),
_rgb(i._rgb), _unchanged_data(i._unchanged_data),
_test_split(i._test_split), _mean(i._mean),
_has_mean_scalar(i._has_mean_scalar), _scale(i._scale),
_scaled(i._scaled), _scale_min(i._scale_min),
_scale_max(i._scale_max), _keep_orig(i._keep_orig),
_interp(i._interp)
#ifdef USE_CUDA_CV
Expand Down
3 changes: 1 addition & 2 deletions tests/ut-torchapi.cc
Original file line number Diff line number Diff line change
Expand Up @@ -78,8 +78,7 @@ TEST(torchapi, service_predict)
// predict
std::string jpredictstr
= "{\"service\":\"imgserv\",\"parameters\":{\"input\":{\"height\":224,"
"\"width\":224,\"rgb\":true,\"scale\":0.0039},\"output\":{\"best\":1}}"
",\"data\":[\""
"\"width\":224},\"output\":{\"best\":1}},\"data\":[\""
+ incept_repo + "cat.jpg\"]}";
joutstr = japi.jrender(japi.service_predict(jpredictstr));
JDoc jd;
Expand Down

0 comments on commit f46a710

Please sign in to comment.