Skip to content

Commit

Permalink
fix(nbeats): much lower memory use in case of large dim signals
Browse files (browse the repository at this point in the history)
  • Loading branch information
fantes authored and sileht committed Sep 29, 2020
1 parent f513f17 commit 639e222
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 77 deletions.
138 changes: 65 additions & 73 deletions src/backends/torch/native/templates/nbeats.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,16 @@ namespace dd
_fc3 = register_module("fc3", torch::nn::Linear(_units, _units));
_fc4 = register_module("fc4", torch::nn::Linear(_units, _units));
_theta_f_fc = register_module(
"theta_f_fc", torch::nn::Linear(torch::nn::LinearOptions(
_units, _thetas_dim * _data_size)
.bias(false)));
"theta_f_fc",
torch::nn::Linear(
torch::nn::LinearOptions(_units, _thetas_dim).bias(false)));
if (_share_thetas)
_theta_b_fc = _theta_f_fc;
else
_theta_b_fc = register_module(
"theta_b_fc", torch::nn::Linear(torch::nn::LinearOptions(
_units, _thetas_dim * _data_size)
.bias(false)));
"theta_b_fc",
torch::nn::Linear(
torch::nn::LinearOptions(_units, _thetas_dim).bias(false)));
}

torch::Tensor NBeats::BlockImpl::first_forward(torch::Tensor x)
Expand Down Expand Up @@ -56,60 +56,56 @@ namespace dd
unsigned int p2 = (p % 2 == 0) ? p / 2 : p / 2 + 1;
std::vector<float> tdata;

for (unsigned int d1 = 0; d1 < _data_size; ++d1)
for (unsigned int i = 0; i < p1; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _forecast_linspace.size(); ++j)
tdata.push_back(std::cos(2 * M_PI * i * _forecast_linspace[j]));
for (unsigned int i = 0; i < p1; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _forecast_linspace.size(); ++j)
tdata.push_back(std::cos(2 * M_PI * i * _forecast_linspace[j]));
torch::Tensor s1
= torch::from_blob(tdata.data(),
{ _data_size * p1,
static_cast<long int>(_forecast_linspace.size())
* _data_size },
options)
= torch::from_blob(
tdata.data(),
{ p1, static_cast<long int>(_forecast_linspace.size())
* _data_size },
options)
.clone();

tdata.clear();
for (unsigned int d1 = 0; d1 < _data_size; ++d1)
for (unsigned int i = 0; i < p2; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _forecast_linspace.size(); ++j)
tdata.push_back(std::sin(2 * M_PI * i * _forecast_linspace[j]));
for (unsigned int i = 0; i < p2; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _forecast_linspace.size(); ++j)
tdata.push_back(std::sin(2 * M_PI * i * _forecast_linspace[j]));
torch::Tensor s2
= torch::from_blob(tdata.data(),
{ _data_size * p2,
static_cast<long int>(_forecast_linspace.size())
* _data_size },
options)
= torch::from_blob(
tdata.data(),
{ p2, static_cast<long int>(_forecast_linspace.size())
* _data_size },
options)
.clone();
torch::Tensor fS = torch::cat({ s1, s2 });

tdata.clear();
for (unsigned int d1 = 0; d1 < _data_size; ++d1)
for (unsigned int i = 0; i < p1; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _backcast_linspace.size(); ++j)
tdata.push_back(std::cos(2 * M_PI * i * _backcast_linspace[j]));
for (unsigned int i = 0; i < p1; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _backcast_linspace.size(); ++j)
tdata.push_back(std::cos(2 * M_PI * i * _backcast_linspace[j]));
torch::Tensor ss1
= torch::from_blob(tdata.data(),
{ _data_size * p1,
static_cast<long int>(_backcast_linspace.size())
* _data_size },
options)
= torch::from_blob(
tdata.data(),
{ p1, static_cast<long int>(_backcast_linspace.size())
* _data_size },
options)
.clone();

tdata.clear();
for (unsigned int d1 = 0; d1 < _data_size; ++d1)
for (unsigned int i = 0; i < p2; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _backcast_linspace.size(); ++j)
tdata.push_back(std::sin(2 * M_PI * i * _backcast_linspace[j]));
for (unsigned int i = 0; i < p2; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _backcast_linspace.size(); ++j)
tdata.push_back(std::sin(2 * M_PI * i * _backcast_linspace[j]));
torch::Tensor ss2
= torch::from_blob(tdata.data(),
{ _data_size * p2,
static_cast<long int>(_backcast_linspace.size())
* _data_size },
options)
= torch::from_blob(
tdata.data(),
{ p2, static_cast<long int>(_backcast_linspace.size())
* _data_size },
options)
.clone();

torch::Tensor bS = torch::cat({ ss1, ss2 });
Expand All @@ -124,36 +120,32 @@ namespace dd
unsigned int p = thetas_dim;
std::vector<float> tdata;

for (unsigned int d1 = 0; d1 < _data_size; ++d1)
for (unsigned int i = 0; i < p; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _forecast_linspace.size(); ++j)
{
tdata.push_back(static_cast<float>(
powf(_forecast_linspace[j], static_cast<float>(i))));
;
}
fT = torch::from_blob(
tdata.data(),
{ static_cast<long int>(p) * static_cast<long int>(_data_size),
static_cast<long int>(_forecast_linspace.size())
* static_cast<long int>(_data_size) },
options)
for (unsigned int i = 0; i < p; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _forecast_linspace.size(); ++j)
{
tdata.push_back(static_cast<float>(
powf(_forecast_linspace[j], static_cast<float>(i))));
;
}
fT = torch::from_blob(tdata.data(),
{ static_cast<long int>(p),
static_cast<long int>(_forecast_linspace.size())
* static_cast<long int>(_data_size) },
options)
.clone();

tdata.clear();
for (unsigned int d1 = 0; d1 < _data_size; ++d1)
for (unsigned int i = 0; i < p; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _backcast_linspace.size(); ++j)
tdata.push_back(static_cast<float>(
powf(_backcast_linspace[j], static_cast<float>(i))));
bT = torch::from_blob(
tdata.data(),
{ static_cast<long int>(p) * static_cast<long int>(_data_size),
static_cast<long int>(_backcast_linspace.size())
* static_cast<long int>(_data_size) },
options)
for (unsigned int i = 0; i < p; ++i)
for (unsigned int d2 = 0; d2 < _data_size; ++d2)
for (unsigned int j = 0; j < _backcast_linspace.size(); ++j)
tdata.push_back(static_cast<float>(
powf(_backcast_linspace[j], static_cast<float>(i))));
bT = torch::from_blob(tdata.data(),
{ static_cast<long int>(p),
static_cast<long int>(_backcast_linspace.size())
* static_cast<long int>(_data_size) },
options)
.clone();
return std::make_tuple(bT.to(_device), fT.to(_device));
}
Expand Down
8 changes: 4 additions & 4 deletions src/backends/torch/native/templates/nbeats.h
Original file line number Diff line number Diff line change
Expand Up @@ -155,11 +155,11 @@ namespace dd
false, data_size)
{
_backcast_fc = register_module(
"backcast_fc", torch::nn::Linear(_thetas_dim * _data_size,
_backcast_length * _data_size));
"backcast_fc",
torch::nn::Linear(_thetas_dim, _backcast_length * _data_size));
_forecast_fc = register_module(
"forecast_fc", torch::nn::Linear(_thetas_dim * _data_size,
_forecast_length * _data_size));
"forecast_fc",
torch::nn::Linear(_thetas_dim, _forecast_length * _data_size));
}

GenericBlockImpl(GenericBlockImpl &b) : BlockImpl(b)
Expand Down

0 comments on commit 639e222

Please sign in to comment.