Skip to content

Commit

Permalink
fix(torch/timeseries): unscale prediction output if needed
Browse files Browse the repository at this point in the history
  • Loading branch information
fantes authored and sileht committed Oct 8, 2020
1 parent 19e9674 commit aa30e88
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 1 deletion.
43 changes: 42 additions & 1 deletion src/backends/torch/torchlib.cc
Original file line number Diff line number Diff line change
Expand Up @@ -502,6 +502,46 @@ namespace dd
empty_cuda_cache();
}

template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
double TorchLib<TInputConnectorStrategy, TOutputConnectorStrategy,
TMLModel>::unscale(double val, unsigned int k,
const TInputConnectorStrategy &inputc)
{
(void)inputc;
(void)k;
// unscaling is input connector specific
return val;
}

// full template specialization
template <>
double
TorchLib<CSVTSTorchInputFileConn, SupervisedOutput, TorchModel>::unscale(
double val, unsigned int k, const CSVTSTorchInputFileConn &inputc)

{
if (inputc._min_vals.empty() || inputc._max_vals.empty())
{
this->_logger->info("not unscaling output because no bounds "
"data found");
return val;
}
else
{

if (!inputc._dont_scale_labels)
{
double max = inputc._max_vals[inputc._label_pos[k]];
double min = inputc._min_vals[inputc._label_pos[k]];
if (inputc._scale_between_minus1_and_1)
val += 0.5;
val = val * (max - min) + min;
}
return val;
}
}

/*- from mllib -*/
template <class TInputConnectorStrategy, class TOutputConnectorStrategy,
class TMLModel>
Expand Down Expand Up @@ -1342,7 +1382,8 @@ namespace dd
for (unsigned int k = 0; k < this->_inputc._ntargets;
++k)
{
preds.push_back(output_acc[j][t][k]);
double res = output_acc[j][t][k];
preds.push_back(unscale(res, k, inputc));
}
APIData ts;
ts.add("out", preds);
Expand Down
3 changes: 3 additions & 0 deletions src/backends/torch/torchlib.h
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,9 @@ namespace dd
void snapshot(int64_t elapsed_it, torch::optim::Optimizer &optimizer);

void remove_model(int64_t it);

double unscale(double val, unsigned int k,
const TInputConnectorStrategy &inputc);
};
}

Expand Down

0 comments on commit aa30e88

Please sign in to comment.