Skip to content

Commit

Permalink
fix(native): do not raise exception if no template_param is given
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and sileht committed Oct 1, 2020
1 parent 64d3c9f commit d0705ab
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 21 deletions.
6 changes: 4 additions & 2 deletions src/backends/torch/native/native_factory.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ namespace dd
{
if (tdef.find("nbeats") != std::string::npos)
{
std::vector<std::string> p = template_params.get("template_params")
.get<std::vector<std::string>>();
std::vector<std::string> p;
if (template_params.has("template_params"))
p = template_params.get("template_params")
.get<std::vector<std::string>>();
return new NBeats(inputc, p);
}
else
Expand Down
61 changes: 42 additions & 19 deletions src/backends/torch/native/templates/nbeats.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,22 @@
#include "mllibstrategy.h"
#include "../native_net.h"

#define NBEATS_DEFAULT_STACK_TYPES \
{ \
trend, seasonality, generic \
}
#define NBEATS_DEFAULT_NB_BLOCKS 3
#define NBEATS_DEFAULT_DATA_SIZE 1
#define NBEATS_DEFAULT_OUTPUT_SIZE 1
#define NBEATS_DEFAULT_BACKCAST_LENGTH 50
#define NBEATS_DEFAULT_FORECAST_LENGTH 50
#define NBEATS_DEFAULT_THETAS \
{ \
2, 8, 3 \
}
#define NBEATS_DEFAULT_SHARE_WEIGHTS false
#define NBEATS_DEFAULT_HIDDEN_LAYER_UNITS 10

namespace dd
{
class NBeats : public NativeModule
Expand Down Expand Up @@ -182,11 +198,15 @@ namespace dd
public:
NBeats(const CSVTSTorchInputFileConn &inputc,
std::vector<std::string> stackdef,
std::vector<BlockType> stackTypes = { trend, seasonality, generic },
int nb_blocks_per_stack = 3, int data_size = 1, int output_size = 1,
int backcast_length = 50, int forecast_length = 10,
std::vector<int> thetas_dims = { 2, 8, 3 },
bool share_weights_in_stack = false, int hidden_layer_units = 10)
std::vector<BlockType> stackTypes = NBEATS_DEFAULT_STACK_TYPES,
int nb_blocks_per_stack = NBEATS_DEFAULT_NB_BLOCKS,
int data_size = NBEATS_DEFAULT_DATA_SIZE,
int output_size = NBEATS_DEFAULT_OUTPUT_SIZE,
int backcast_length = NBEATS_DEFAULT_BACKCAST_LENGTH,
int forecast_length = NBEATS_DEFAULT_FORECAST_LENGTH,
std::vector<int> thetas_dims = NBEATS_DEFAULT_THETAS,
bool share_weights_in_stack = NBEATS_DEFAULT_SHARE_WEIGHTS,
int hidden_layer_units = NBEATS_DEFAULT_HIDDEN_LAYER_UNITS)
: _data_size(data_size), _output_size(output_size),
_backcast_length(backcast_length), _forecast_length(forecast_length),
_hidden_layer_units(hidden_layer_units),
Expand All @@ -199,11 +219,14 @@ namespace dd
create_nbeats();
}
NBeats()
: _data_size(1), _output_size(1), _backcast_length(50),
_forecast_length(10), _hidden_layer_units(1024),
_nb_blocks_per_stack(3), _share_weights_in_stack(false),
_stack_types({ trend, seasonality, generic }),
_thetas_dims({ 2, 8, 3 })
: _data_size(1), _output_size(1),
_backcast_length(NBEATS_DEFAULT_BACKCAST_LENGTH),
_forecast_length(NBEATS_DEFAULT_FORECAST_LENGTH),
_hidden_layer_units(NBEATS_DEFAULT_HIDDEN_LAYER_UNITS),
_nb_blocks_per_stack(NBEATS_DEFAULT_NB_BLOCKS),
_share_weights_in_stack(NBEATS_DEFAULT_SHARE_WEIGHTS),
_stack_types(NBEATS_DEFAULT_STACK_TYPES),
_thetas_dims(NBEATS_DEFAULT_THETAS)
{
create_nbeats();
}
Expand Down Expand Up @@ -393,16 +416,16 @@ namespace dd
}

protected:
unsigned int _data_size;
unsigned int _output_size;
unsigned int _backcast_length;
unsigned int _forecast_length;
unsigned int _hidden_layer_units;
unsigned int _nb_blocks_per_stack;
bool _share_weights_in_stack;
std::vector<BlockType> _stack_types;
unsigned int _data_size = NBEATS_DEFAULT_DATA_SIZE;
unsigned int _output_size = NBEATS_DEFAULT_OUTPUT_SIZE;
unsigned int _backcast_length = NBEATS_DEFAULT_BACKCAST_LENGTH;
unsigned int _forecast_length = NBEATS_DEFAULT_FORECAST_LENGTH;
unsigned int _hidden_layer_units = NBEATS_DEFAULT_HIDDEN_LAYER_UNITS;
unsigned int _nb_blocks_per_stack = NBEATS_DEFAULT_NB_BLOCKS;
bool _share_weights_in_stack = NBEATS_DEFAULT_SHARE_WEIGHTS;
std::vector<BlockType> _stack_types = NBEATS_DEFAULT_STACK_TYPES;
std::vector<Stack> _stacks;
std::vector<int> _thetas_dims;
std::vector<int> _thetas_dims = NBEATS_DEFAULT_THETAS;
torch::nn::Linear _fcn{ nullptr };
torch::Device _device = torch::Device("cpu");
std::vector<float> _backcast_linspace;
Expand Down

0 comments on commit d0705ab

Please sign in to comment.