Skip to content

Commit

Permalink
feat(torch): nbeats
Browse files Browse the repository at this point in the history
feat(nbeats): make nbeats able to handle signals that are more than 1D

feat(nbeats): expose nbeats net definition in api
  • Loading branch information
fantes authored and sileht committed Sep 24, 2020
1 parent a471b82 commit f288665
Show file tree
Hide file tree
Showing 15 changed files with 1,289 additions and 86 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -773,7 +773,7 @@ if (USE_TORCH)
if (NOT USE_CPU_ONLY AND CUDA_FOUND)
list(APPEND TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libc10_cuda.so ${TORCH_LOCATION}/lib/libtorch_cuda.so)
else()
list(APPEND TORCH_LIB_DEPS iomp5)
list(APPEND TORCH_LIB_DEPS ${TORCH_LOCATION}/lib/libtorch_cpu.so iomp5)
endif()

set(TORCH_INC_DIR ${TORCH_LOCATION}/include/ ${TORCH_LOCATION}/include/torch/csrc/api/include/ ${CMAKE_BINARY_DIR}/pytorch/src/pytorch/torch/include/torch/csrc/api/include ${TORCH_LOCATION}/.. ${CMAKE_BINARY_DIR}/src)
Expand Down
1 change: 0 additions & 1 deletion examples/all/sinus/gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
os.remove("predict/"+f)
os.rmdir("predict")


os.mkdir("train")
os.mkdir("test")
os.mkdir("predict")
Expand Down
1 change: 1 addition & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,7 @@ if (USE_TORCH)
backends/torch/torchinputconns.cc
backends/torch/db.cpp
backends/torch/db_lmdb.cpp
backends/torch/native/templates/nbeats.cc
basegraph.cc
caffegraphinput.cc
backends/torch/torchgraphbackend.cc
Expand Down
7 changes: 7 additions & 0 deletions src/backends/torch/native/native.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef NATIVE_H
#define NATIVE_H

#include "native_net.h"
#include "native_factory.h"

#endif
55 changes: 55 additions & 0 deletions src/backends/torch/native/native_factory.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
#ifndef NATIVE_FACTORY_H
#define NATIVE_FACTORY_H

#include "native_net.h"
#include "./templates/nbeats.h"
#include "../torchinputconns.h"
#include "apidata.h"

namespace dd
{
class NativeFactory
{
public:
template <class TInputConnectorStrategy>
static NativeModule *from_template(const std::string tdef,
const APIData template_params,
const TInputConnectorStrategy &inputc)
{
(void)(tdef);
(void)(template_params);
(void)(inputc);
return nullptr;
}

static bool valid_template_def(std::string tdef)
{
if (tdef.find("nbeats") != std::string::npos)
return true;
return false;
}

static bool is_timeserie(std::string tdef)
{
if (tdef.find("nbeats") != std::string::npos)
return true;
return false;
}
};

template <>
NativeModule *NativeFactory::from_template<CSVTSTorchInputFileConn>(
const std::string tdef, const APIData template_params,
const CSVTSTorchInputFileConn &inputc)
{
if (tdef.find("nbeats") != std::string::npos)
{
std::vector<std::string> p = template_params.get("template_params")
.get<std::vector<std::string>>();
return new NBeats(inputc, p);
}
else
return nullptr;
}
}
#endif
74 changes: 74 additions & 0 deletions src/backends/torch/native/native_net.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#ifndef NATIVE_NET_H
#define NATIVE_NET_H

#include "torch/torch.h"

namespace dd
{

class NativeModule : public torch::nn::Module
{
public:
virtual torch::Tensor forward(torch::Tensor x) = 0;
virtual ~NativeModule()
{
}
/**
* \brief see torch::module::to
* @param device cpu / gpu
* @param non_blocking
*/
virtual void to(torch::Device device, bool non_blocking = false)
{
torch::nn::Module::to(device, non_blocking);
_device = device;
}

/**
* \brief see torch::module::to
* @param dtype : torch::kFloat32 or torch::kFloat64
* @param non_blocking
*/
virtual void to(torch::Dtype dtype, bool non_blocking = false)
{
torch::nn::Module::to(dtype, non_blocking);
_dtype = dtype;
}

/**
* \brief see torch::module::to
* @param device cpu / gpu
* @param dtype : torch::kFloat32 or torch::kFloat64
* @param non_blocking
*/
virtual void to(torch::Device device, torch::Dtype dtype,
bool non_blocking = false)
{
torch::nn::Module::to(device, dtype, non_blocking);
_device = device;
_dtype = dtype;
}

virtual torch::Tensor cleanup_output(torch::Tensor output)
{
return output;
}

virtual torch::Tensor loss(std::string loss, torch::Tensor input,
torch::Tensor output, torch::Tensor target)
= 0;

virtual void update_input_connector(TorchInputInterface &inputc)
{
(void)(inputc);
}

protected:
torch::Dtype _dtype
= torch::kFloat32; /**< type of data stored in tensors */
torch::Device _device
= torch::DeviceType::CPU; /**< device to compute on */
};
}

#endif
Loading

0 comments on commit f288665

Please sign in to comment.