fix(torch): clip gradient in rectified adam as stated in annex B of original paper
fantes authored and sileht committed Nov 10, 2020
1 parent f5c0abb commit 1561269
Showing 5 changed files with 34 additions and 3 deletions.
2 changes: 2 additions & 0 deletions docs/api.md
@@ -690,6 +690,8 @@ iterations | int | yes | N/A | Max number of solver'
snapshot | int | yes | N/A | Iterations between model snapshots
snapshot_prefix | string | yes | empty | Prefix to snapshot file, supports repository
solver_type | string | yes | SGD | from "SGD", "ADAGRAD", "NESTEROV", "RMSPROP", "ADADELTA", "ADAM", "AMSGRAD", "RANGER", "RANGER_PLUS", "ADAMW", "SGDW", "AMSGRADW" (*W version for decoupled weight decay, RANGER_PLUS is ranger + adabelief + centralized_gradient)
clip | bool | yes | false (true if RANGER* selected) | clip gradients; currently implemented only for the RANGER* solvers
clip_value | real | yes | 5.0 | value to clip gradients to (used only by RANGER*)
rectified | bool | yes | false | rectified momentum variance, i.e. RAdam (https://arxiv.org/abs/1908.03265); valid for ADAM[W] and AMSGRAD[W]
adabelief | bool | yes | false | adabelief mod for ADAM (https://arxiv.org/abs/2010.07468)
gradient_centralization | bool | yes | false | centralized gradient mod for ADAM, i.e. https://arxiv.org/abs/2004.01461v2
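
For readers of the table above: with clip enabled and the default clip_value of 5.0, every gradient component is clamped into [-5, 5] before the optimizer update. A minimal libtorch illustration with made-up values (not DeepDetect code):

#include <torch/torch.h>
#include <iostream>

int main()
{
  torch::Tensor grad = torch::tensor({ -12.0, 0.3, 7.5 });
  grad.clamp_(-5.0, 5.0);         // what clip=true, clip_value=5.0 requests
  std::cout << grad << std::endl; // -5.0, 0.3, 5.0
}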
15 changes: 13 additions & 2 deletions src/backends/torch/optim/ranger.cc
@@ -47,7 +47,10 @@ namespace dd
&& (std::get<1>(lhs.betas()) == std::get<1>(rhs.betas()))
&& (lhs.eps() == rhs.eps())
&& (lhs.weight_decay() == rhs.weight_decay())
&& (lhs.clip() == rhs.clip())
&& (lhs.clip_value() == rhs.clip_value())
&& (lhs.rectified() == rhs.rectified())
&& (lhs.decoupled_wd() == rhs.decoupled_wd())
&& (lhs.lookahead() == rhs.lookahead())
&& (lhs.adabelief() == rhs.adabelief())
&& (lhs.gradient_centralization() == rhs.gradient_centralization())
@@ -60,6 +63,8 @@ namespace dd
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(betas);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(eps);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(weight_decay);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(clip);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(clip_value);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(decoupled_wd);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(rectified);
_TORCH_OPTIM_SERIALIZE_TORCH_ARG(lookahead);
@@ -75,6 +80,8 @@ namespace dd
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(betas_t, betas);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, eps);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, weight_decay);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, clip);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(double, clip_value);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, decoupled_wd);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, rectified);
_TORCH_OPTIM_DESERIALIZE_TORCH_ARG(bool, lookahead);
@@ -126,8 +133,9 @@ namespace dd
continue;
}
auto grad = p.grad();
TORCH_CHECK(
!grad.is_sparse(), "Ranger does not support sparse gradients" /*, please consider SparseRanger instead*/);

TORCH_CHECK(!grad.is_sparse(),
"Ranger does not support sparse gradients");
auto param_state
= state_.find(c10::guts::to_string(p.unsafeGetTensorImpl()));
auto &options = static_cast<RangerOptions &>(group.options());
@@ -174,6 +182,9 @@ namespace dd
grad.add_(-grad.mean(torch::IntArrayRef(dim), true));
}

if (options.clip())
grad.clamp_(-options.clip_value(), options.clip_value());

exp_avg.mul_(beta1).add_(grad, 1 - beta1); // m_t

if (options.adabelief())
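
The hunk above applies clipping after gradient centralization and before the first-moment update, so the clamped gradient is what feeds m_t. A condensed, self-contained sketch of that ordering, assuming a dense gradient; the second-moment (v_t) line is the standard Adam form and is not part of the hunk shown above:

#include <torch/torch.h>

// Simplified sketch, not a drop-in replacement for Ranger::step().
void clipped_moment_update(torch::Tensor &grad, torch::Tensor &exp_avg,
                           torch::Tensor &exp_avg_sq, double beta1,
                           double beta2, bool clip, double clip_value)
{
  if (clip)
    grad.clamp_(-clip_value, clip_value); // bound each component first
  exp_avg.mul_(beta1).add_(grad, 1 - beta1);              // m_t
  exp_avg_sq.mul_(beta2).addcmul_(grad, grad, 1 - beta2); // v_t
}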
4 changes: 4 additions & 0 deletions src/backends/torch/optim/ranger.h
@@ -50,6 +50,8 @@ namespace dd
TORCH_ARG(betas_t, betas) = std::make_tuple(0.9, 0.999);
TORCH_ARG(double, eps) = 1e-8;
TORCH_ARG(double, weight_decay) = 0.0;
TORCH_ARG(bool, clip) = false;
TORCH_ARG(double, clip_value) = 5.0;
TORCH_ARG(bool, decoupled_wd) = false;
TORCH_ARG(bool, rectified) = true;
TORCH_ARG(bool, lookahead) = true;
@@ -103,6 +105,8 @@ namespace dd
"Invalid beta parameter at index 1: ", std::get<1>(betas));
TORCH_CHECK(defaults.weight_decay() >= 0,
"Invalid weight_decay value: ", defaults.weight_decay());
TORCH_CHECK(!defaults.clip() || defaults.clip_value() >= 0,
"Invalid clip value: ", defaults.clip_value());
TORCH_CHECK(defaults.lsteps() >= 0,
"Invalid lookahead steps: ", defaults.lsteps());
TORCH_CHECK(defaults.lalpha() >= 0,
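
Because the new fields are declared with TORCH_ARG, they get the usual fluent setters and getters. A hypothetical usage sketch (the lr-taking constructor and the include path are assumptions; the setter names come from the declarations above):

#include "backends/torch/optim/ranger.h" // assumed include path

void build_options()
{
  auto opts = dd::RangerOptions(/*lr=*/1e-3) // constructor signature assumed
                  .clip(true)                // enable gradient clipping
                  .clip_value(5.0);          // clamp gradients to [-5, 5]
  // With clip enabled, a negative clip_value would now fail the TORCH_CHECK above.
}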
14 changes: 13 additions & 1 deletion src/backends/torch/torchsolver.cc
@@ -29,6 +29,9 @@ namespace dd
if (ad_solver.has("solver_type"))
_solver_type = ad_solver.get("solver_type").get<std::string>();

if (_solver_type == "RANGER" || _solver_type == "RANGER_PLUS")
_clip = true;

if (_solver_type == "RANGER_PLUS")
{
_adabelief = true;
@@ -41,6 +44,10 @@ namespace dd
_beta1 = ad_solver.get("beta1").get<double>();
if (ad_solver.has("beta"))
_beta2 = ad_solver.get("beta2").get<double>();
if (ad_solver.has("clip"))
_clip = ad_solver.get("clip").get<bool>();
if (ad_solver.has("clip_value"))
_clip_value = ad_solver.get("clip_value").get<double>();
if (ad_solver.has("rectified"))
_rectified = ad_solver.get("rectified").get<bool>();
if (ad_solver.has("lookahead"))
@@ -102,10 +109,15 @@ namespace dd
.adabelief(_adabelief)
.gradient_centralization(_gc)
.lsteps(_lsteps)
.lalpha(_lalpha)));
.lalpha(_lalpha)
.clip(_clip)
.clip_value(_clip_value)));
this->_logger->info("base_lr: {}", _base_lr);
this->_logger->info("beta_1: {}", _beta1);
this->_logger->info("beta_2: {}", _beta2);
this->_logger->info("clip: {}", _clip);
if (_clip)
this->_logger->info("clip_value: {}", _clip_value);
this->_logger->info("weight_decay: {}", _weight_decay);
this->_logger->info("rectified: {}", _rectified);
this->_logger->info("lookahead: {}", _lookahead);
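
Note the ordering above: clipping is switched on as soon as RANGER or RANGER_PLUS is selected, while the explicit "clip" key is read afterwards, so a user-provided value still overrides the default. A small sketch of that precedence using hypothetical types (not the APIData calls in torchsolver.cc):

#include <map>
#include <string>

bool resolve_clip(const std::string &solver_type,
                  const std::map<std::string, bool> &user_params)
{
  bool clip = false;
  if (solver_type == "RANGER" || solver_type == "RANGER_PLUS")
    clip = true; // clipping on by default for the Ranger family
  auto it = user_params.find("clip");
  if (it != user_params.end())
    clip = it->second; // explicit user setting wins, read after the default
  return clip;
}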
2 changes: 2 additions & 0 deletions src/backends/torch/torchsolver.h
@@ -105,6 +105,8 @@ namespace dd
int _lsteps
= 5; /**< for RANGER, if lookahead: number of lookahead steps */
double _lalpha = 0.5; /**< for RANGER, if lookahead: weight of lookahead */
bool _clip = false; /**< for RANGER, clip gradients */
double _clip_value = 5.0; /**< for RANGER, value to clip gradients to */
double _weight_decay = 0.0; /**< weight decay value*/
bool _decoupled_wd = false; /**< for RANGER : use decoupled weight decay,
NOT YET IMPLEMENTED */
