feat(graph): lstm autoencoder
fantes authored and sileht committed Oct 1, 2020
1 parent d0705ab commit 038a74c
Showing 8 changed files with 335 additions and 47 deletions.
77 changes: 60 additions & 17 deletions src/backends/caffe/caffelib.cc
@@ -3814,8 +3814,18 @@ namespace dd
->mutable_filler()
->set_value(1.0 / (float)_ntargets / (float)batch_size);
}
for (int i = 0; i < np->layers_size(); i++)
{
caffe::LayerParameter *lp = np->mutable_layer(i);
if (lp->has_dummy_data_param())
{
lp->mutable_dummy_data_param()->mutable_shape(0)->set_dim(
0, batch_size);
lp->mutable_dummy_data_param()->mutable_shape(1)->set_dim(
0, batch_size);
}
}
}

// caffe::WriteProtoToTextFile(*np,sp.net().c_str());
sp.clear_net();
}
@@ -3834,6 +3844,16 @@ namespace dd
if (orig_timesteps == timesteps)
return false;
lparam->mutable_memory_data_param()->set_channels(timesteps);
// also update timesteps in case of autoencoder
for (unsigned int i = 0; i < deploy_net_param.layer_size(); ++i)
{
caffe::LayerParameter *lp = deploy_net_param.mutable_layer(i);
if (lp->has_tile_param())
{
lp->mutable_tile_param()->set_tiles(timesteps);
break;
}
}
caffe::WriteProtoToTextFile(deploy_net_param, deploy_file);
return true;
}
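
The loops above keep the Tile layer's tiles count equal to the number of timesteps: in the LSTM autoencoder the decoder consumes the encoder's last hidden state repeated once per timestep, so the two values must stay in sync whenever the user changes timesteps. A minimal sketch of such a layer built through the generated protobuf API (header path, layer and blob names are illustrative, not the ones from the actual template):

#include "caffe.pb.h" // generated protobuf header, path may differ

// Hypothetical decoder-side Tile layer; only the tile_param fields
// (axis, tiles) matter for the update loops in this commit.
int timesteps = 20;                                    // illustrative value
caffe::LayerParameter tile_layer;
tile_layer.set_name("tile_encoded");                   // illustrative name
tile_layer.set_type("Tile");
tile_layer.add_bottom("encoded_state");                // illustrative blob names
tile_layer.add_top("tiled_state");
tile_layer.mutable_tile_param()->set_axis(1);          // tile along the time axis
tile_layer.mutable_tile_param()->set_tiles(timesteps);
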
@@ -3995,8 +4015,8 @@ namespace dd
->set_dim(2, inputc._sequence_txt);
}

// if autoencoder, set the last inner product layer output number to input
// size (i.e. inputc.channels())
// if autoencoder, set the last inner product layer output number to
// input size (i.e. inputc.channels())
if (_autoencoder && this->_inputc._timeserie)
{
int k = net_param.layer_size();
@@ -4053,6 +4073,17 @@ namespace dd
}
}

// update tile number with time steps (autoencoder)
for (unsigned int i = 0; i < net_param.layer_size(); ++i)
{
caffe::LayerParameter *lp = net_param.mutable_layer(i);
if (lp->has_tile_param())
{
lp->mutable_tile_param()->set_tiles(timesteps);
break;
}
}

caffe::NetParameter deploy_net_param;
caffe::ReadProtoFromTextFile(deploy_file, &deploy_net_param);
#ifdef USE_CUDNN
@@ -4155,6 +4186,17 @@ namespace dd
->set_crop_size(_crop_size);
}

// update tile number with time steps (autoencoder)
for (unsigned int i = 0; i < deploy_net_param.layer_size(); ++i)
{
caffe::LayerParameter *lp = deploy_net_param.mutable_layer(i);
if (lp->has_tile_param())
{
lp->mutable_tile_param()->set_tiles(timesteps);
break;
}
}

caffe::WriteProtoToTextFile(net_param, net_file);
caffe::WriteProtoToTextFile(deploy_net_param, deploy_file);
}
@@ -4384,8 +4426,8 @@ namespace dd
}
// fix class numbers
// this procedure looks for the first bottom layer with a 'num_output'
// field and rename the layer so that its weights can be reinitialized and
// the net finetuned
// field and rename the layer so that its weights can be reinitialized
// and the net finetuned
int k = net_param.layer_size();
for (int l = net_param.layer_size() - 1; l > 0; l--)
{
@@ -4433,8 +4475,8 @@ namespace dd
APIData ad_net = ad.getobj("parameters").getobj("mllib").getobj("net");
if (ad_net.has("batch_size"))
{
// adjust batch size so that it is a multiple of the number of training
// samples (Caffe requirement)
// adjust batch size so that it is a multiple of the number of
// training samples (Caffe requirement)
user_batch_size = batch_size = test_batch_size
= ad_net.get("batch_size").get<int>();
if (ad_net.has("test_batch_size"))
@@ -4444,8 +4486,8 @@ namespace dd
this->_logger->info("user batch_size={} / inputc batch_size={}",
batch_size, inputc.batch_size());

// code below is required when Caffe (weirdly) requires the batch size
// to be a multiple of the training dataset size.
// code below is required when Caffe (weirdly) requires the batch
// size to be a multiple of the training dataset size.
if (!inputc._ctc && !inputc._segmentation
&& !(!inputc._db
&& typeid(inputc) == typeid(ImgCaffeInputFileConn)))
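
The comment above describes the constraint; the adjustment code itself is collapsed in this view. Purely as an illustration of that constraint (a sketch, not the actual DeepDetect logic), a conforming batch size can be chosen like this:

#include <algorithm>

// Sketch only: largest batch size <= the requested one that divides the
// training set size evenly, so every Caffe batch is full.
int fit_batch_size(int requested, int train_size)
{
  int bs = std::min(requested, train_size);
  while (bs > 1 && train_size % bs != 0)
    --bs;
  return bs;
}
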
@@ -4569,9 +4611,9 @@ namespace dd
mltype = "rois";
break;
}
if (ltype == "ContinuationIndicator") // XXX: CTC layer does not appear
// in deploy file, this is a hack
// used by our LSTMs
if (ltype == "ContinuationIndicator") // XXX: CTC layer does not
// appear in deploy file, this
// is a hack used by our LSTMs
{
mltype = "ctc";
break;
@@ -4583,7 +4625,8 @@ namespace dd
{
mltype = "segmentation";
has_deconv = true;
// we don't break since some detection tasks may use deconvolutions
// we don't break since some detection tasks may use
// deconvolutions
}
if (!has_deconv && ltype == "Sigmoid" && l == net->layers().size() - 1)
{
@@ -4770,8 +4813,8 @@ namespace dd
}
catch (std::exception &e)
{
this->_logger->warn(
"dice contour size unrecognized, (odd) int expected");
this->_logger->warn("dice contour size unrecognized, "
"(odd) int expected");
}
}
if (contourdata.has("amplitude"))
Expand All @@ -4783,8 +4826,8 @@ namespace dd
}
catch (std::exception &e)
{
this->_logger->warn(
"dice contour amplitude unrecognized, float expected");
this->_logger->warn("dice contour amplitude "
"unrecognized, float expected");
}
}
this->_logger->warn(
38 changes: 28 additions & 10 deletions src/backends/torch/torchgraphbackend.cc
@@ -100,11 +100,9 @@ namespace dd
auto opname_v = opname(v);
std::vector<torch::Tensor> output;
std::string optype = this->optype(v);

if (optype == "LSTM" || optype == "RNN")
{
// get<0>(f()) for all outputs / hidden
// get<0>(get<1>(f)) for last output
// get<1>(get<1>(f)) for last internal state
std::tuple<torch::Tensor, std::tuple<torch::Tensor, torch::Tensor>>
full_output;
if (_lstm_continuation && _rnn_has_memories[opname_v])
@@ -122,7 +120,11 @@ namespace dd
= _modules[opname_v]
.forward<std::tuple<Tensor, std::tuple<Tensor, Tensor>>>(
inputsTensor[0]);
output.push_back(std::get<0>(full_output));
output.push_back(std::get<0>(full_output)); // all outputs
output.push_back(
std::get<0>(std::get<1>(full_output))); // last hidden value
output.push_back(
std::get<1>(std::get<1>(full_output))); // last memory / c value
if (_lstm_continuation)
{
_rnn_memories[opname_v] = std::get<1>(full_output);
@@ -131,6 +133,18 @@ namespace dd
}
else if (optype == "InnerProduct")
output.push_back(_modules[opname_v].forward(inputsTensor[0]));
else if (optype == "Tile")
{
torch::Tensor x = inputsTensor[0];
std::vector<long int> rssizes = x.sizes().vec();
rssizes.erase(rssizes.begin()); // remove first dim because it is
// 1 : num_layers * num_directions
rssizes.insert(rssizes.begin() + _graph[v].axis, 1L);
torch::Tensor y = x.reshape(rssizes);
std::vector<long int> tiless(rssizes.size(), 1);
tiless[_graph[v].axis] = _graph[v].dim[_graph[v].axis];
output.push_back(y.repeat(tiless));
}
else
throw TorchGraphException("unknown optype " + optype + " for operator "
+ opname_v);
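
A shape walk-through of the Tile branch above, with made-up sizes (batch 8, hidden 32, 20 timesteps) and the time axis at position 1; the leading dimension is dropped because the incoming last-hidden tensor carries num_layers * num_directions == 1 in front:

#include <torch/torch.h>

torch::Tensor h = torch::randn({1, 8, 32});       // [1, batch, hidden]
std::vector<long int> sizes = h.sizes().vec();    // {1, 8, 32}
sizes.erase(sizes.begin());                       // {8, 32}
sizes.insert(sizes.begin() + 1, 1L);              // {8, 1, 32}, axis == 1
torch::Tensor y = h.reshape(sizes);
std::vector<long int> tiles(sizes.size(), 1);
tiles[1] = 20;                                    // timesteps along the time axis
torch::Tensor out = y.repeat(tiles);              // [8, 20, 32]
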
@@ -146,7 +160,8 @@ namespace dd
if (_parameters_used)
throw TorchGraphException(
"parameters reallocation necessary while they are used "
"elsewhere. You should module.forward() / module.set_input() / "
"elsewhere. You should module.forward() / module.set_input() "
"/ "
"module.finalize() with correct input dimensions before "
"modules.parameters() or module.parameters_release() if you "
"know what you are doing");
@@ -158,7 +173,8 @@ namespace dd
torch::nn::AnyModule m;
if (optype == "LSTM")
{
// dim(v,0,2) is 2nd dimension of input 0 of v, ie datadim for lstm
// dim(v,0,2) is 2nd dimension of input 0 of v, ie datadim for
// lstm
LSTM m = register_module(
opname, LSTM(LSTMOptions(dim(v, 0, 2), num_output(v))
.num_layers(1)
@@ -171,7 +187,8 @@ namespace dd
}
else if (optype == "RNN")
{
// dim(v,0,2) is 2nd dimension of input 0 of v, ie datadim for lstm
// dim(v,0,2) is 2nd dimension of input 0 of v, ie datadim for
// lstm
RNN m = register_module(opname,
RNN(RNNOptions(dim(v, 0, 2), num_output(v))
.num_layers(1)
@@ -183,14 +200,16 @@ namespace dd
}
else if (optype == "InnerProduct")
{
// dim(v,0,2) is 2nd dimension of input 0 of v, ie datadim for lstm
// output
// dim(v,0,2) is 2nd dimension of input 0 of v, ie datadim for
// lstm output
Linear m = register_module(
opname,
Linear(LinearOptions(dim(v, 0, 2), num_output(v)).bias(true)));
_modules[opname] = AnyModule(m);
_graph[v].alloc_needed = false;
}
else if (optype == "Tile")
_graph[v].alloc_needed = false;
}
to(_device, _dtype);
}
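
Taken together, the allocated modules form an encoder-tile-decoder chain. A self-contained libtorch sketch of that data flow (dimensions and module names are assumptions for illustration, not the graph's actual configuration):

#include <torch/torch.h>

int main()
{
  const int batch = 8, timesteps = 20, input_dim = 4, hidden = 32;
  torch::nn::LSTM enc(torch::nn::LSTMOptions(input_dim, hidden)
                          .num_layers(1)
                          .batch_first(true));
  torch::nn::LSTM dec(torch::nn::LSTMOptions(hidden, hidden)
                          .num_layers(1)
                          .batch_first(true));
  torch::nn::Linear proj(hidden, input_dim);

  torch::Tensor x = torch::randn({batch, timesteps, input_dim});
  auto enc_out = enc->forward(x);
  // last hidden state of the encoder: [num_layers, batch, hidden]
  torch::Tensor h_last = std::get<0>(std::get<1>(enc_out));
  // tile it across the time axis so the decoder sees one copy per timestep
  torch::Tensor tiled
      = h_last.reshape({batch, 1, hidden}).repeat({1, timesteps, 1});
  auto dec_out = dec->forward(tiled);
  // project back to the input dimension to reconstruct the sequence
  torch::Tensor recon = proj->forward(std::get<0>(dec_out));
  // recon: [batch, timesteps, input_dim]
  return 0;
}
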
@@ -248,5 +267,4 @@ namespace dd
return torch::nn::Module::parameters(recurse);
}
}

}
2 changes: 1 addition & 1 deletion src/backends/torch/torchlib.cc
@@ -489,7 +489,7 @@ namespace dd
{
caffe::NetParameter net_param;
configure_recurrent_template(lib_ad, this->_inputc, net_param,
this->_logger);
this->_logger, true);
torch_write_proto_to_text_file(net_param, dest_net);
}
else
8 changes: 8 additions & 0 deletions src/basegraph.cc
@@ -257,6 +257,13 @@ namespace dd
outputdim.push_back(_graph[inputs[0]].dim[i]);
outputdim.push_back(_graph[producer].num_output);
}
else if (_graph[producer].type == "Tile")
{
// BIG FAT WARNING : this is a hack for tile in autoencoders only
// because exposed hidden outputs are computed at full size
outputdim.push_back(_graph[inputs[0]].dim[1]); // timesteps
outputdim.push_back(_graph[inputs[0]].dim[2]); // hidden_size
}
return std::tie(inputdim, outputdim);
}

@@ -265,6 +272,7 @@ namespace dd
auto es = boost::in_edges(v, _graph);
auto eit = es.first;
BaseGraph::Vertex producer = boost::source(*eit, _graph);

auto newdims = compute_dims_from_producer(producer);
update_alloc_status(newdims, producer);
_graph[producer].dim = std::get<1>(newdims);
18 changes: 17 additions & 1 deletion src/caffegraphinput.cc
@@ -82,6 +82,9 @@ namespace dd
else
firstl = ninput + 3;

if (net.layer(firstl).type() == "DummyData")
firstl++;

caffe::LayerParameter lparam = net.layer(firstl);
if (lparam.type() != "LSTM" && lparam.type() != "RNN"
&& lparam.type() != "InnerProduct")
@@ -131,7 +134,8 @@ namespace dd
std::vector<BaseGraph::Vertex> vi = add_inputs(v, inputs);

std::vector<std::string> outputs;
outputs.push_back(lparam.top(0));
for (unsigned int i = 0; i < lparam.top_size(); ++i)
outputs.push_back(lparam.top(i));
std::vector<BaseGraph::Vertex> vo = add_outputs(v, outputs);
set_output_name(lparam.top(0));
}
@@ -156,6 +160,18 @@ namespace dd
std::vector<BaseGraph::Vertex> vo = add_outputs(v, outputs);
set_output_name(lparam.top(0));
}
else if (lparam.type() == "Tile")
{
Vertex v = add_layer(lparam.name(), lparam.type());
std::vector<std::string> inputs;
inputs.push_back(lparam.bottom(0));
add_inputs(v, inputs);
std::vector<std::string> outputs;
outputs.push_back(lparam.top(0));
add_outputs(v, outputs);
set_output_name(lparam.top(0));
_graph[v].axis = lparam.tile_param().axis();
}
}

return true;
