Skip to content

Commit

Permalink
feat(nbeats): expose hidden size param in API
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and sileht committed Oct 6, 2020
1 parent 44a0db5 commit d7e5515
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 1 deletion.
11 changes: 10 additions & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -727,7 +727,16 @@ nclasses | int | yes | none | if set to some int, add a classi
self_supervised | string | yes | "" | self-supervised mode: "mask" for masked language model
embedding_size | int | yes | 768 | embedding size for NLP models
freeze_traced | bool | yes | false | Freeze the traced part of the net during finetuning (e.g. for classification)
template | string | yes | "" | for language models, either "bert" or "gpt2"
template | string | yes | "" | for language models, either "bert" or "gpt2", "recurrent" for LSTM-like models (including autoencoder), "nbeats" for nbeats model


Model instantiation parameters:

Parameter | Template | Type | Default | Description
--------- | --------- | ------ | ---------------------------- | -----------
template_params | nbeats | array of string | ["t2","s8","g3","b3","h10" ] | default means: trend stack with theta = 2, seasonal stack with theta = 8 , generic stack with theta = 3, 3 blocks per stacks, hidden unit size of 10 everywhere
layers | recurrent | array of string | [] | ["L50","L50"] means 2 layers of LSTMs with hidden size of 50. ["L100","L100", "T", "L300"] means an lstm autoencoder with encoder composed of 2 LSTM layers of hidden size 100 and decoder is one LSTM layer of hidden size 300


Solver:

Expand Down
6 changes: 6 additions & 0 deletions src/backends/torch/native/templates/nbeats.h
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,11 @@ namespace dd
_nb_blocks_per_stack
= std::stoi(s.substr(pos + _nbblock_str.size()));
}
else if ((pos = s.find(_hsize_str)) != std::string::npos)
{
_hidden_layer_units
= std::stoi(s.substr(pos + _hsize_str.size()));
}
else
{
throw MLLibBadParamException(
Expand Down Expand Up @@ -460,6 +465,7 @@ namespace dd
std::string _season_str = "s";
std::string _generic_str = "g";
std::string _nbblock_str = "b";
std::string _hsize_str = "h";
};
}
#endif

0 comments on commit d7e5515

Please sign in to comment.