Skip to content

Commit

Permalink
feat(torch): ranger optimizer (ie rectified ADAM + lookahead) + \
Browse files Browse the repository at this point in the history
adabelief + gradient centralization
  • Loading branch information
fantes authored and Bycob committed Oct 22, 2020
1 parent b517910 commit a3004f0
Show file tree
Hide file tree
Showing 16 changed files with 1,664 additions and 720 deletions.
11 changes: 10 additions & 1 deletion docs/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -752,7 +752,16 @@ Parameter | Type | Optional | Default | Description
--------- | ---- | -------- | ------- | -----------
iterations | int | yes | N/A | Max number of solver's iterations
snapshot | int | yes | N/A | Iterations between model snapshots
solver_type | string | yes | SGD | from "SGD", "ADAGRAD", "RMSPROP", "ADAM"
solver_type | string | yes | SGD | from "SGD", "ADAGRAD", "RMSPROP", "ADAM", "RANGER", "RANGER_PLUS"
beta1 | real | yes | 0.9 | for RANGER* : beta1 param
beta2 | real | yes | 0.999 | for RANGER* : beta2 param
weight_decay | real | yes | 0.0 | for RANGER* : weight decay
rectified | bool | yes | true | for RANGER* : enable/disable rectified ADAM
lookahead | bool | yes | true | for RANGER* : enable/disable lookahead
lookahead_steps | int | yes | 6 | for RANGER* : if lookahead enabled, number of steps
lookahead_alpha | real | yes | 0.5 | for RANGER* : if lookahead enabled, alpha param
adabelief | bool | yes | false for RANGER, true for RANGER_PLUS | for RANGER* : enable/disable adabelief
gradient_centralization | bool | yes | false for RANGER, true for RANGER_PLUS | for RANGER* : enable/disable gradient centralization
test_interval | int | yes | N/A | Number of iterations between testing phases
base_lr | real | yes | N/A | Initial learning rate
iter_size | int | yes | 1 | Number of passes (iter_size * batch_size) at every iteration
Expand Down
5 changes: 5 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,15 @@ if (USE_TORCH)
backends/torch/db.cpp
backends/torch/db_lmdb.cpp
backends/torch/native/templates/nbeats.cc
backends/torch/native/native_factory.cc
basegraph.cc
caffegraphinput.cc
backends/torch/torchgraphbackend.cc
graph.cc
backends/torch/torchsolver.cc
backends/torch/torchmodule.cc
backends/torch/torchutils.cc
backends/torch/optim/ranger.cc
)
if (NOT EXISTS ${CMAKE_SOURCE_DIR}/src/caffe.proto)
file(DOWNLOAD https://raw.githubusercontent.com/jolibrain/caffe/master/src/caffe/proto/caffe.proto ${CMAKE_SOURCE_DIR}/src/caffe.proto)
Expand Down
65 changes: 65 additions & 0 deletions src/backends/torch/native/native_factory.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
/**
* DeepDetect
* Copyright (c) 2019-2020 Jolibrain
* Author: Guillaume Infantes <[email protected]>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#include "native_factory.h"

namespace dd
{

/**
 * Generic fallback used for input connectors that have no dedicated
 * specialization: every argument is ignored and no module is built.
 *
 * @param tdef             template definition string (unused here)
 * @param template_params  template parameters (unused here)
 * @param inputc           input connector (unused here)
 * @return always nullptr
 */
template <class TInputConnectorStrategy>
NativeModule *
NativeFactory::from_template(const std::string tdef,
                             const APIData template_params,
                             const TInputConnectorStrategy &inputc)
{
  // Explicitly discard the arguments so that unused-parameter warnings
  // stay quiet while the parameter names remain in the signature.
  static_cast<void>(tdef);
  static_cast<void>(template_params);
  static_cast<void>(inputc);
  return nullptr;
}

/**
 * Specialization for the CSV timeseries input connector.
 *
 * When `tdef` contains the substring "nbeats", an NBeats module is
 * allocated with `new` and returned as a raw pointer. The optional
 * "template_params" entry of `template_params` is read as a vector of
 * strings and passed to the NBeats constructor (an empty vector is passed
 * when the entry is absent).
 *
 * @param tdef             template definition string
 * @param template_params  API data, optionally holding "template_params"
 * @param inputc           timeseries input connector given to NBeats
 * @return a newly allocated NBeats, or nullptr when `tdef` does not
 *         contain "nbeats"
 */
template <>
NativeModule *NativeFactory::from_template<CSVTSTorchInputFileConn>(
    const std::string tdef, const APIData template_params,
    const CSVTSTorchInputFileConn &inputc)
{
  // Guard clause: anything that does not mention "nbeats" is unsupported.
  if (tdef.find("nbeats") == std::string::npos)
    return nullptr;

  std::vector<std::string> nbeats_params;
  if (template_params.has("template_params"))
    {
      nbeats_params = template_params.get("template_params")
                          .get<std::vector<std::string>>();
    }
  return new NBeats(inputc, nbeats_params);
}

// Explicit instantiations of the generic from_template for the text and
// image input connectors. The generic definition lives in this source file
// rather than in the header, so each connector type that relies on the
// generic (nullptr-returning) version has to be instantiated here to be
// available at link time.
// NOTE(review): any further connector type calling from_template without a
// specialization presumably needs its own instantiation here -- confirm
// against the callers.
template NativeModule *
NativeFactory::from_template(const std::string tdef,
const APIData template_params,
const TxtTorchInputFileConn &inputc);

// Same generic fallback, instantiated for the image input connector.
template NativeModule *
NativeFactory::from_template(const std::string tdef,
const APIData template_params,
const ImgTorchInputFileConn &inputc);
}
46 changes: 22 additions & 24 deletions src/backends/torch/native/native_factory.h
Original file line number Diff line number Diff line change
@@ -1,3 +1,24 @@
/**
* DeepDetect
* Copyright (c) 2019-2020 Jolibrain
* Author: Guillaume Infantes <[email protected]>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef NATIVE_FACTORY_H
#define NATIVE_FACTORY_H

Expand All @@ -14,13 +35,7 @@ namespace dd
template <class TInputConnectorStrategy>
static NativeModule *from_template(const std::string tdef,
const APIData template_params,
const TInputConnectorStrategy &inputc)
{
(void)(tdef);
(void)(template_params);
(void)(inputc);
return nullptr;
}
const TInputConnectorStrategy &inputc);

static bool valid_template_def(std::string tdef)
{
Expand All @@ -36,22 +51,5 @@ namespace dd
return false;
}
};

/**
 * Specialization for the CSV timeseries input connector.
 *
 * When `tdef` contains the substring "nbeats", an NBeats module is
 * allocated with `new` and returned as a raw pointer; the optional
 * "template_params" entry of `template_params` is read as a vector of
 * strings and passed to the NBeats constructor.
 *
 * A full specialization of a function template is an ordinary function
 * with external linkage. Because this definition sits in a header, it is
 * marked `inline` so that including the header from several translation
 * units does not produce multiple definitions of the same symbol.
 *
 * @param tdef             template definition string
 * @param template_params  API data, optionally holding "template_params"
 * @param inputc           timeseries input connector given to NBeats
 * @return a newly allocated NBeats, or nullptr when `tdef` does not
 *         contain "nbeats"
 */
template <>
inline NativeModule *NativeFactory::from_template<CSVTSTorchInputFileConn>(
    const std::string tdef, const APIData template_params,
    const CSVTSTorchInputFileConn &inputc)
{
  if (tdef.find("nbeats") != std::string::npos)
    {
      std::vector<std::string> p;
      if (template_params.has("template_params"))
        p = template_params.get("template_params")
                .get<std::vector<std::string>>();
      return new NBeats(inputc, p);
    }
  else
    return nullptr;
}
}
#endif
28 changes: 27 additions & 1 deletion src/backends/torch/native/native_net.h
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
/**
* DeepDetect
* Copyright (c) 2019-2020 Jolibrain
* Author: Guillaume Infantes <[email protected]>
*
* This file is part of deepdetect.
*
* deepdetect is free software: you can redistribute it and/or modify
* it under the terms of the GNU Lesser General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* deepdetect is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Lesser General Public License for more details.
*
* You should have received a copy of the GNU Lesser General Public License
* along with deepdetect. If not, see <http://www.gnu.org/licenses/>.
*/

#ifndef NATIVE_NET_H
#define NATIVE_NET_H

#include "torch/torch.h"
#pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wunused-parameter"
#include <torch/torch.h>
#pragma GCC diagnostic pop

#include "../torchinputconns.h"

namespace dd
{
Expand Down
Loading

0 comments on commit a3004f0

Please sign in to comment.