Skip to content

Commit

Permalink
feat(torch,native): extract_layer
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and sileht committed Oct 5, 2020
1 parent 28247b4 commit d37e182
Show file tree
Hide file tree
Showing 9 changed files with 923 additions and 103 deletions.
28 changes: 28 additions & 0 deletions src/backends/torch/native/native_net.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,35 @@ namespace dd
class NativeModule : public torch::nn::Module
{
public:
/**
* \brief forward pass over the
* @param input tensor
* @return value of output
*/
virtual torch::Tensor forward(torch::Tensor x) = 0;

/**
* \brief extract layer from net
* @param input
* @param name of data to extract
* @return extracted tensor
*/
virtual torch::Tensor extract(torch::Tensor x, std::string extract_layer)
= 0;

/**
* \brief check is string correspond to some layer in the net
* @param the name of the data node
* @return true if it exists in the net
*/
virtual bool extractable(std::string extract_layer) const = 0;

/**
* \brief return all candidates for extraction, ie all data nodes of the
* net
*/
virtual std::vector<std::string> extractable_layers() const = 0;

virtual ~NativeModule()
{
}
Expand Down
183 changes: 168 additions & 15 deletions src/backends/torch/native/templates/nbeats.cc
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
#include "nbeats.h"
#include <cmath>
#include <string>

namespace dd
{
Expand Down Expand Up @@ -34,6 +35,29 @@ namespace dd
return x;
}

torch::Tensor NBeats::BlockImpl::first_extract(torch::Tensor x,
std::string extract_layer)
{
x = x.reshape({ x.size(0), x.size(1) * x.size(2) });
x = _fc1->forward(x);
if (extract_layer == "fc1")
return x;
x = torch::relu(x);
x = _fc2->forward(x);
if (extract_layer == "fc2")
return x;
x = torch::relu(x);
x = _fc3->forward(x);
if (extract_layer == "fc3")
return x;
x = torch::relu(x);
x = _fc4->forward(x);
if (extract_layer == "fc4")
return x;
x = torch::relu(x);
return x;
}

std::tuple<torch::Tensor, torch::Tensor>
NBeats::SeasonalityBlockImpl::forward(torch::Tensor x)
{
Expand All @@ -44,7 +68,24 @@ namespace dd
torch::Tensor forecast = tffc.mm(_fS.to(_device));
return std::make_tuple(
backcast.reshape({ backcast.size(0), _backcast_length, _data_size }),
forecast.reshape({ forecast.size(0), _forecast_length, _data_size }));
forecast.reshape({ forecast.size(0), _backcast_length, _data_size }));
}

torch::Tensor
NBeats::SeasonalityBlockImpl::extract(torch::Tensor x,
std::string extract_layer)
{
x = BlockImpl::first_extract(x, extract_layer);
if (extract_layer != "theta_b_fc" && extract_layer != "theta_f_fc")
return x;
torch::Tensor tbfc = _theta_b_fc->forward(x);
torch::Tensor backcast = tbfc.mm(_bS.to(_device));
if (extract_layer == "theta_b_fc")
return backcast;

torch::Tensor tffc = _theta_f_fc->forward(x);
torch::Tensor forecast = tffc.mm(_fS.to(_device));
return forecast;
}

std::tuple<torch::Tensor, torch::Tensor>
Expand Down Expand Up @@ -150,6 +191,23 @@ namespace dd
return std::make_tuple(bT.to(_device), fT.to(_device));
}

torch::Tensor NBeats::TrendBlockImpl::extract(torch::Tensor x,
std::string extract_layer)
{
x = BlockImpl::first_extract(x, extract_layer);
if (extract_layer != "theta_b_fc" && extract_layer != "theta_f_fc")
return x;

torch::Tensor tbfc = _theta_b_fc->forward(x);
torch::Tensor backcast = tbfc.mm(_bT.to(_device));
if (extract_layer == "theta_b_fc")
return backcast;

torch::Tensor tffc = _theta_b_fc->forward(x);
torch::Tensor forecast = tffc.mm(_fT.to(_device));
return forecast;
}

std::tuple<torch::Tensor, torch::Tensor>
NBeats::TrendBlockImpl::forward(torch::Tensor x)
{
Expand All @@ -161,7 +219,30 @@ namespace dd

return std::make_tuple(
backcast.reshape({ backcast.size(0), _backcast_length, _data_size }),
forecast.reshape({ forecast.size(0), _forecast_length, _data_size }));
forecast.reshape({ forecast.size(0), _backcast_length, _data_size }));
}

torch::Tensor NBeats::GenericBlockImpl::extract(torch::Tensor x,
std::string extract_layer)
{
x = BlockImpl::first_extract(x, extract_layer);
if (extract_layer != "theta_b_fc" && extract_layer != "theta_f_fc")
return x;

x = _theta_b_fc->forward(x);
if (extract_layer == "theta_b_fc")
return x;

torch::Tensor theta_b = torch::relu(x);
torch::Tensor y = _theta_f_fc->forward(x);
if (extract_layer == "theta_f_fc")
return y;
torch::Tensor theta_f = torch::relu(y);
torch::Tensor backcast = _backcast_fc->forward(theta_b);
if (extract_layer == "backcast_fc")
return backcast;
torch::Tensor forecast = _forecast_fc->forward(theta_f);
return forecast;
}

std::tuple<torch::Tensor, torch::Tensor>
Expand All @@ -174,25 +255,23 @@ namespace dd
torch::Tensor forecast = _forecast_fc->forward(theta_f);
return std::make_tuple(
backcast.reshape({ backcast.size(0), _backcast_length, _data_size }),
forecast.reshape({ forecast.size(0), _forecast_length, _data_size }));
forecast.reshape({ forecast.size(0), _backcast_length, _data_size }));
}

void NBeats::update_params(const CSVTSTorchInputFileConn &inputc)
{
_output_size = inputc._label.size();
_data_size = inputc._datadim - _output_size;
_backcast_length = inputc._timesteps;
// per dd timeserie / LSTM logic, there is one output per input
_forecast_length = inputc._timesteps;
}

void NBeats::create_nbeats()
{
float back_step = 1.0 / (float)(_backcast_length);
for (unsigned int i = 0; i < _backcast_length; ++i)
_backcast_linspace.push_back(back_step * static_cast<float>(i));
float fore_step = 1.0 / (float)(_forecast_length);
for (unsigned int i = 0; i < _forecast_length; ++i)
float fore_step = 1.0 / (float)(_backcast_length);
for (unsigned int i = 0; i < _backcast_length; ++i)
_forecast_linspace.push_back(fore_step * static_cast<float>(i));

std::tuple<torch::Tensor, torch::Tensor> S;
Expand All @@ -212,9 +291,8 @@ namespace dd
"seasonalityBlock_" + std::to_string(block_id) + "_stack_"
+ std::to_string(stack_id),
SeasonalityBlock(_hidden_layer_units, _thetas_dims[stack_id],
_backcast_length, _forecast_length,
_data_size, std::get<0>(S),
std::get<1>(S)))));
_backcast_length, _data_size,
std::get<0>(S), std::get<1>(S)))));
break;
case trend:
T = create_exp_basis(_thetas_dims[stack_id]);
Expand All @@ -224,8 +302,8 @@ namespace dd
"trendBlock_" + std::to_string(block_id) + "_stack_"
+ std::to_string(stack_id),
TrendBlock(_hidden_layer_units, _thetas_dims[stack_id],
_backcast_length, _forecast_length, _data_size,
std::get<0>(T), std::get<1>(T)))));
_backcast_length, _data_size, std::get<0>(T),
std::get<1>(T)))));
break;
case generic:
for (unsigned int block_id = 0; block_id < _nb_blocks_per_stack;
Expand All @@ -234,8 +312,7 @@ namespace dd
"genericBlock_" + std::to_string(block_id) + "_stack_"
+ std::to_string(stack_id),
GenericBlock(_hidden_layer_units, _thetas_dims[stack_id],
_backcast_length, _forecast_length,
_data_size))));
_backcast_length, _data_size))));
break;
default:
break;
Expand All @@ -248,7 +325,7 @@ namespace dd
torch::Tensor NBeats::forward(torch::Tensor x)
{
torch::Tensor b = x;
torch::Tensor f = torch::zeros({ x.size(0), _forecast_length, _data_size })
torch::Tensor f = torch::zeros({ x.size(0), _backcast_length, _data_size })
.to(_device);

int stack_counter = 0;
Expand All @@ -264,4 +341,80 @@ namespace dd
}
return torch::stack({ b, f }, 0);
}

bool NBeats::extractable(std::string extract_layer) const
{
std::vector<std::string> els = extractable_layers();
return std::find(els.begin(), els.end(), extract_layer) != els.end();
}

std::vector<std::string> NBeats::extractable_layers() const
{
std::vector<std::string> els;
for (unsigned long int si = 0; si < _stacks.size(); ++si)
for (unsigned long int bi = 0; bi < _stacks[si].size(); ++bi)
{
for (auto item : _stacks[si][bi].ptr()->named_children().keys())
els.push_back(std::to_string(si) + ":" + std::to_string(bi) + ":"
+ item);
els.push_back(std::to_string(si) + ":" + std::to_string(bi)
+ ":end");
}
return els;
}

torch::Tensor NBeats::extract(torch::Tensor x, std::string extract_layer)
{

std::vector<std::string> subst;
std::string item;
size_t pos_start = 0, pos_end;
while ((pos_end = extract_layer.find(":", pos_start)) != std::string::npos)
{
subst.push_back(extract_layer.substr(pos_start, pos_end - pos_start));
pos_start = pos_end + 1;
}
subst.push_back(extract_layer.substr(pos_start));

int num_stack = std::stoi(subst[0]);
int num_block = std::stoi(subst[1]);
bool endofblock = subst[2] == "end";

torch::Tensor b = x;
torch::Tensor f = torch::zeros({ x.size(0), _backcast_length, _data_size })
.to(_device);

int stack_counter = 0;
for (Stack s : _stacks)
{
int block_counter = 0;
for (torch::nn::AnyModule m : s)
{
if (num_stack == stack_counter && num_block == block_counter
&& !endofblock)
{
if (_stack_types[stack_counter] == trend)
return m.get<TrendBlock>()->extract(b, subst[2]);
if (_stack_types[stack_counter] == seasonality)
return m.get<SeasonalityBlock>()->extract(b, subst[2]);
if (_stack_types[stack_counter] == generic)
return m.get<GenericBlock>()->extract(b, subst[2]);
}
else
{
auto bf
= m.forward<std::tuple<torch::Tensor, torch::Tensor>>(b);

b = b - std::get<0>(bf);
f = f + std::get<1>(bf);
if (num_stack == stack_counter && num_block == block_counter
&& endofblock)
return torch::stack({ b, f }, 0);
block_counter++;
}
}
stack_counter++;
}
return torch::stack({ b, f }, 0);
}
}
Loading

0 comments on commit d37e182

Please sign in to comment.