fix(torch): load weights only once
fantes authored and sileht committed Oct 14, 2020
1 parent cc086ff commit 0052a03
Showing 6 changed files with 167 additions and 47 deletions.
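
The change removes a double load: weights were loaded once in TorchModule::load() and then loaded again after the graph was finalized, since finalize() may reallocate modules when input dimensions change. The fix tracks whether allocate_modules() actually (re)allocated anything (_allocation_done) and exposes it via needs_reload(), so weights are reloaded only in that case. A minimal sketch of the guard pattern, with hypothetical names (this is not the DeepDetect API):

#include <iostream>
#include <string>

// Hypothetical stand-in for TorchGraphBackend: allocation raises a flag,
// and weight loading is skipped unless an allocation actually happened.
class GraphBackend
{
public:
  void allocate_modules(bool alloc_needed)
  {
    _allocation_done = false;
    if (alloc_needed)
      {
        // ... allocate torch modules here ...
        _allocation_done = true; // raised only when something was (re)allocated
      }
  }
  bool needs_reload() const
  {
    return _allocation_done;
  }

private:
  bool _allocation_done = false;
};

// Mirrors graph_model_load() below: load weights only once per allocation.
void reload_weights(GraphBackend &g, const std::string &weights_file)
{
  if (!weights_file.empty() && g.needs_reload())
    std::cout << "loading " << weights_file << std::endl; // torch::load(...) in the real code
}
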
4 changes: 4 additions & 0 deletions src/backends/torch/torchgraphbackend.cc
@@ -189,6 +189,7 @@ namespace dd

void TorchGraphBackend::allocate_modules()
{
_allocation_done = false;
for (BaseGraph::Vertex v : _sortedOps)
{
if (!_graph[v].alloc_needed)
@@ -220,6 +221,7 @@ namespace dd
_modules[opname] = AnyModule(m);
_graph[v].alloc_needed = false;
_rnn_has_memories[opname] = false;
_allocation_done = true;
}
else if (optype == "RNN")
{
@@ -233,6 +235,7 @@ namespace dd
_modules[opname] = AnyModule(m);
_graph[v].alloc_needed = false;
_rnn_has_memories[opname] = false;
_allocation_done = true;
}
else if (optype == "InnerProduct")
{
@@ -243,6 +246,7 @@
Linear(LinearOptions(dim(v, 0, 2), num_output(v)).bias(true)));
_modules[opname] = AnyModule(m);
_graph[v].alloc_needed = false;
_allocation_done = true;
}
else if (optype == "Tile")
_graph[v].alloc_needed = false;
13 changes: 11 additions & 2 deletions src/backends/torch/torchgraphbackend.h
@@ -177,10 +177,18 @@ namespace dd
_parameters_used = false;
}

/**
 * \brief tells whether some allocation was done (needs to be called just
 * after set_inputdim or finalize)
 */
bool needs_reload()
{
return _allocation_done;
}

protected:
/**
 * \brief internal torch module allocation, called within finalize()
 */
void allocate_modules();

@@ -215,8 +223,9 @@ namespace dd
std::unordered_map<std::string, bool>
_rnn_has_memories; /**< true if previous hidden values are available
*/
};

bool _allocation_done = false;
};
}

#endif
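
Per the new doc comment, needs_reload() is only meaningful immediately after set_inputdim or finalize, since any later query would still report the previous allocation. A condensed view of the intended sequence, using the member names from the post_transform() diff below:

// finalize() may reallocate modules when input dimensions change,
// discarding any weights loaded earlier; reload only in that case.
std::vector<long int> dims = inputc._dataset.datasize(0);
dims.insert(dims.begin(), 1); // dummy batch size
_graph->finalize(dims);
if (_graph->needs_reload())
  _logger->info("net was reallocated due to input dim changes");
graph_model_load(tmodel); // internally guarded by needs_reload()
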
170 changes: 131 additions & 39 deletions src/backends/torch/torchlib.cc
@@ -163,29 +163,119 @@ namespace dd
_classif->to(device, dtype);
}

void TorchModule::proto_model_load(const TorchModel &model)
{
_logger->info("loading " + model._proto);
try
{
_graph = std::make_shared<CaffeToTorch>(model._proto);
}
catch (std::exception &e)
{
_logger->error("unable to load " + model._proto);
throw;
}
}

void TorchModule::graph_model_load(const TorchModel &tmodel)
{
if (!tmodel._traced.empty() && _graph->needs_reload())
{
_logger->info("loading " + tmodel._traced);
try
{
torch::load(_graph, tmodel._traced, _device);
}
catch (std::exception &e)
{
_logger->error("unable to load " + tmodel._traced);
throw;
}
}
}

void TorchModule::native_model_load(const TorchModel &tmodel)
{
if (!tmodel._native.empty())
{
_logger->info("loading " + tmodel._native);
try
{
torch::load(_native, tmodel._native);
}
catch (std::exception &e)
{
_logger->error("unable to load " + tmodel._native);
throw;
}
}
}

void TorchModule::classif_model_load(const TorchModel &model)
{
_logger->info("loading " + model._weights);
try
{
torch::load(_classif, model._weights, _device);
}
catch (std::exception &e)
{
_logger->error("unable to load " + model._weights);
throw;
}
}

void TorchModule::classif_layer_load()
{
if (!_classif_layer_file.empty())
{
_logger->info("loading " + _classif_layer_file);
torch::load(_classif, _classif_layer_file, _device);
}
}

void TorchModule::traced_model_load(TorchModel &model)
{
_logger->info("loading " + model._traced);
try
{
_traced = std::make_shared<torch::jit::script::Module>(
torch::jit::load(model._traced, _device));
}
catch (std::exception &e)
{
_logger->error("unable to load " + model._traced);
throw;
}
}

template <class TInputConnectorStrategy>
void TorchModule::post_transform(const std::string tmpl,
const APIData &template_params,
const TInputConnectorStrategy &inputc,
const TorchModel &tmodel,
const torch::Device &device)
{
_device = device;
this->_native = std::shared_ptr<NativeModule>(
NativeFactory::from_template<TInputConnectorStrategy>(
tmpl, template_params, inputc));

if (_native)
if (!tmodel._native.empty())
torch::load(_native, tmodel._native, device);
{
_logger->info("created net using template " + tmpl);
native_model_load(tmodel);
}

if (_graph)
{
std::vector<long int> dims = inputc._dataset.datasize(0);
dims.insert(dims.begin(), 1); // dummy batch size
_graph->finalize(dims);
if (_graph->needs_reload())
_logger->info("net was reallocated due to input dim changes");
// reload params after finalize
if (!tmodel._traced.empty())
torch::load(_graph, tmodel._traced, _device);
graph_model_load(tmodel);
}
to(_device);
}
@@ -361,11 +451,7 @@ namespace dd
// First dimension is batch id
int outdim = to_tensor_safe(forward(input_example)).sizes()[1];
_classif = torch::nn::Linear(outdim, nclasses);

if (!_classif_layer_file.empty())
{
torch::load(_classif, _classif_layer_file, _device);
}
classif_layer_load();
}

std::vector<Tensor> TorchModule::parameters()
@@ -401,13 +487,13 @@ namespace dd
void TorchModule::load(TorchModel &model)
{
if (!model._traced.empty() && model._proto.empty())
_traced = std::make_shared<torch::jit::script::Module>(
torch::jit::load(model._traced, _device));
traced_model_load(model);

if (!model._weights.empty())
{
if (_classif)
{
torch::load(_classif, model._weights, _device);
classif_model_load(model);
}
else if (_require_classif_layer)
{
@@ -416,16 +502,12 @@
}
if (!model._proto.empty())
{
_graph = std::make_shared<CaffeToTorch>(model._proto);
if (!model._traced.empty())
torch::load(_graph, model._traced, _device);
proto_model_load(model);
graph_model_load(model);
}

if (!model._native.empty())
{
std::shared_ptr<NativeModule> m;
torch::load(m, model._native);
_native = m;
}
native_model_load(model);
}

void TorchModule::eval()
@@ -544,6 +626,33 @@
}
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
void
TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::solver_load(std::unique_ptr<optim::Optimizer> &optimizer)
{
if (!this->_mlmodel._sstate.empty())
{

this->_logger->info("Reload solver from {}", this->_mlmodel._sstate);
size_t start = this->_mlmodel._sstate.rfind("-") + 1;
size_t end = this->_mlmodel._sstate.rfind(".");
int it = std::stoi(this->_mlmodel._sstate.substr(start, end - start));
this->_logger->info("Restarting optimization from iter {}", it);
this->_logger->info("loading " + this->_mlmodel._sstate);
try
{
torch::load(*optimizer, this->_mlmodel._sstate);
}
catch (std::exception &e)
{
this->_logger->error("unable to load " + this->_mlmodel._sstate);
throw;
}
}
}

/*- from mllib -*/
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
@@ -581,6 +690,7 @@
_device = gpu ? torch::Device(DeviceType::CUDA, gpuid)
: torch::Device(DeviceType::CPU);
_module._device = _device;
_module._logger = this->_logger;

if (_template.find("recurrent") != std::string::npos)
{
@@ -665,15 +775,6 @@
}

// Load weights
if (!this->_mlmodel._traced.empty())
this->_logger->info("Loading ml model from file {}.",
this->_mlmodel._traced);
if (!this->_mlmodel._proto.empty())
this->_logger->info("Loading ml model from file {}.",
this->_mlmodel._proto);
if (!this->_mlmodel._weights.empty())
this->_logger->info("Loading weights from file {}.",
this->_mlmodel._weights);
_module.load(this->_mlmodel);
_module.freeze_traced(freeze_traced);

@@ -919,15 +1020,7 @@

int it = 0;
// reload solver and set it value accordingly
if (!this->_mlmodel._sstate.empty())
{
this->_logger->info("Reload solver from {}", this->_mlmodel._sstate);
size_t start = this->_mlmodel._sstate.rfind("-") + 1;
size_t end = this->_mlmodel._sstate.rfind(".");
it = std::stoi(this->_mlmodel._sstate.substr(start, end - start));
this->_logger->info("Restarting optimization from iter {}", it);
torch::load(*optimizer, this->_mlmodel._sstate);
}
solver_load(optimizer);
optimizer->zero_grad();
_module.train();

@@ -1422,7 +1515,6 @@
unsupo.finalize(ad.getobj("parameters").getobj("output"), out,
static_cast<MLModel *>(&this->_mlmodel));
}

out.add("status", 0);
return 0;
}
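
solver_load() recovers the starting iteration by slicing the snapshot filename between the last '-' and the last '.'. A self-contained sketch of that parse; the "solver-1500.pt" naming is an assumption consistent with the rfind calls above:

#include <iostream>
#include <string>

// Extract the iteration count embedded in a snapshot filename,
// e.g. "solver-1500.pt" -> 1500, as solver_load() does.
int iter_from_sstate(const std::string &sstate)
{
  size_t start = sstate.rfind('-') + 1;
  size_t end = sstate.rfind('.');
  return std::stoi(sstate.substr(start, end - start));
}

int main()
{
  std::cout << iter_from_sstate("solver-1500.pt") << std::endl; // prints 1500
}
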
13 changes: 13 additions & 0 deletions src/backends/torch/torchlib.h
@@ -142,8 +142,16 @@ namespace dd
the file where the weights are stored */
unsigned int _nclasses = 0;

std::shared_ptr<spdlog::logger> _logger; /**< mllib logger. */

private:
bool _freeze_traced = false; /**< Freeze weights of the traced module */
void proto_model_load(const TorchModel &tmodel);
void graph_model_load(const TorchModel &tmodel);
void native_model_load(const TorchModel &tmodel);
void classif_model_load(const TorchModel &tmodel);
void traced_model_load(TorchModel &model);
void classif_layer_load();
};

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
@@ -203,6 +211,11 @@

void snapshot(int64_t elapsed_it, torch::optim::Optimizer &optimizer);

/**
* \brief (re) load solver state
*/
void solver_load(std::unique_ptr<torch::optim::Optimizer> &optimizer);

void remove_model(int64_t it);

double unscale(double val, unsigned int k,
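
The *_model_load helpers declared above all share one shape: log the path, attempt the load, and on failure log an error and rethrow. A generic sketch of that pattern (a simplification; the real helpers call torch::load or torch::jit::load directly and log through spdlog):

#include <exception>
#include <functional>
#include <iostream>
#include <string>

// Log-attempt-rethrow pattern shared by proto_model_load, graph_model_load,
// native_model_load, classif_model_load and traced_model_load.
void checked_load(const std::string &path,
                  const std::function<void(const std::string &)> &do_load)
{
  std::cout << "loading " << path << std::endl;
  try
    {
      do_load(path);
    }
  catch (const std::exception &e)
    {
      std::cerr << "unable to load " << path << ": " << e.what() << std::endl;
      throw; // rethrow so the caller still sees the failure
    }
}
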
12 changes: 7 additions & 5 deletions src/caffegraphinput.cc
@@ -28,6 +28,8 @@
#include <fcntl.h>
#include <unistd.h>

#include "mllibstrategy.h"

using google::protobuf::io::CodedInputStream;
using google::protobuf::io::CodedOutputStream;
using google::protobuf::io::FileInputStream;
@@ -177,19 +179,19 @@ namespace dd
return true;
}

int CaffeGraphInput::from_proto(std::string filename)
void CaffeGraphInput::from_proto(std::string filename)
{
caffe::NetParameter net;
if (!read_proto(filename, &net))
return -1;
throw MLLibBadParamException("unable to parse protofile");

bool simple_lstm = is_simple_lstm(net);
if (simple_lstm)
{
parse_simple_lstm(net);
return 0;
return;
}
return 0;
throw MLLibBadParamException(
"proto file do not contain a proper LSTM/autoencoder");
}

}
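
from_proto() now reports failure by throwing MLLibBadParamException instead of returning an int, so a parse failure can no longer be silently ignored. A hypothetical caller (construction of the input object elided):

try
  {
    input.from_proto("net.prototxt");
  }
catch (const MLLibBadParamException &e)
  {
    // the proto was unreadable, or not a supported LSTM/autoencoder
  }
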
2 changes: 1 addition & 1 deletion src/caffegraphinput.h
@@ -53,7 +53,7 @@ namespace dd
/**
* create basegraph from proto
*/
int from_proto(std::string filename);
void from_proto(std::string filename);

/**
* read protofile
