Skip to content

Commit

Permalink
fix(torch): reload solver params on API device
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and Bycob committed Nov 5, 2020
1 parent 5ab90c7 commit 30fa16f
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 5 deletions.
2 changes: 1 addition & 1 deletion src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -471,7 +471,7 @@ namespace dd

int it = 0;
// reload solver and set it value accordingly
it = tsolver.load(this->_mlmodel._sstate);
it = tsolver.load(this->_mlmodel._sstate, _device);
tsolver.zero_grad();
_module.train();

Expand Down
3 changes: 2 additions & 1 deletion src/backends/torch/torchmodule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,7 @@ namespace dd
try
{
_graph = std::make_shared<CaffeToTorch>(model._proto);
_graph->to(_device);
}
catch (std::exception &e)
{
Expand Down Expand Up @@ -107,7 +108,7 @@ namespace dd
_logger->info("loading " + tmodel._native);
try
{
torch::load(_native, tmodel._native);
torch::load(_native, tmodel._native, _device);
}
catch (std::exception &e)
{
Expand Down
4 changes: 2 additions & 2 deletions src/backends/torch/torchsolver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ namespace dd
torch::save(*_optimizer, sfile);
}

int TorchSolver::load(std::string sstate)
int TorchSolver::load(std::string sstate, torch::Device device)
{
if (!sstate.empty())
{
Expand All @@ -140,7 +140,7 @@ namespace dd
_logger->info("loading " + sstate);
try
{
torch::load(*_optimizer, sstate);
torch::load(*_optimizer, sstate, device);
}
catch (std::exception &e)
{
Expand Down
2 changes: 1 addition & 1 deletion src/backends/torch/torchsolver.h
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ namespace dd
void configure(APIData ad_solver);
void create(TorchModule &module);

int load(std::string sstate);
int load(std::string sstate, torch::Device device);
void save(std::string sfile);

void zero_grad()
Expand Down

0 comments on commit 30fa16f

Please sign in to comment.