feat(caffe): add new optimizer flavors to API
fantes authored and sileht committed Oct 23, 2020
1 parent 18ba916 commit d534a16
Showing 2 changed files with 24 additions and 1 deletion.
docs/api.md (4 changes: 3 additions & 1 deletion)
@@ -689,8 +689,10 @@ Parameter | Type | Optional | Default | Description
iterations | int | yes | N/A | Max number of solver's iterations
snapshot | int | yes | N/A | Iterations between model snapshots
snapshot_prefix | string | yes | empty | Prefix to snapshot file, supports repository
- solver_type | string | yes | SGD | from "SGD", "ADAGRAD", "NESTEROV", "RMSPROP", "ADADELTA", "ADAM", "AMSGRAD", "ADAMW", "SGDW", "AMSGRADW" (*W versions use decoupled weight decay)
+ solver_type | string | yes | SGD | from "SGD", "ADAGRAD", "NESTEROV", "RMSPROP", "ADADELTA", "ADAM", "AMSGRAD", "RANGER", "RANGER_PLUS", "ADAMW", "SGDW", "AMSGRADW" (*W versions use decoupled weight decay; RANGER_PLUS is RANGER + adabelief + gradient centralization)
rectified | bool | yes | false | rectified momentum variance, i.e. https://arxiv.org/abs/1908.03265, valid for ADAM[W] and AMSGRAD[W]
+ adabelief | bool | yes | false | AdaBelief modification for ADAM, https://arxiv.org/abs/2010.07468
+ gradient_centralization | bool | yes | false | gradient centralization modification for ADAM, i.e. https://arxiv.org/abs/2004.01461v2
test_interval | int | yes | N/A | Number of iterations between testing phases
test_initialization | bool | yes | N/A | Whether to start training by testing the network
lr_policy | string | yes | N/A | learning rate policy ("step", "inv", "fixed", "sgdr", ...)
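For context, these are plain solver options in a DeepDetect training call. A minimal sketch of how they might appear in the `mllib.solver` object of a `/train` request body (service name and iteration count are placeholders, and the surrounding request shape is assumed from the rest of the API doc, not shown in this diff):

```json
{
  "service": "myservice",
  "parameters": {
    "mllib": {
      "solver": {
        "iterations": 10000,
        "solver_type": "ADAM",
        "adabelief": true,
        "gradient_centralization": true
      }
    }
  }
}
```

Setting `"solver_type": "RANGER_PLUS"` instead should enable the same two mods plus rectified momentum variance and lookahead in one go, per the dispatch added in caffelib.cc below.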
src/backends/caffe/caffelib.cc (21 changes: 21 additions & 0 deletions)
@@ -1480,6 +1480,22 @@ namespace dd
              caffe::SolverParameter_SolverType_ADAM);
          solver_param.set_amsgrad(true);
        }
+     else if (strcasecmp(solver_type.c_str(), "RANGER") == 0)
+       {
+         solver_param.set_solver_type(
+             caffe::SolverParameter_SolverType_ADAM);
+         solver_param.set_rectified(true);
+         solver_param.set_lookahead(true);
+       }
+     else if (strcasecmp(solver_type.c_str(), "RANGER_PLUS") == 0)
+       {
+         solver_param.set_solver_type(
+             caffe::SolverParameter_SolverType_ADAM);
+         solver_param.set_rectified(true);
+         solver_param.set_lookahead(true);
+         solver_param.set_adabelief(true);
+         solver_param.set_gc(true);
+       }
      else if (strcasecmp(solver_type.c_str(), "ADAMW") == 0)
        {
          solver_param.set_solver_type(
@@ -1585,6 +1601,11 @@ namespace dd
          solver_param.set_rms_decay(ad_solver.get("rms_decay").get<double>());
        if (ad_solver.has("iter_size"))
          solver_param.set_iter_size(ad_solver.get("iter_size").get<int>());
+       if (ad_solver.has("adabelief"))
+         solver_param.set_adabelief(ad_solver.get("adabelief").get<bool>());
+       if (ad_solver.has("gradient_centralization"))
+         solver_param.set_gc(
+             ad_solver.get("gradient_centralization").get<bool>());
        if (ad_solver.has("lookahead"))
          solver_param.set_lookahead(ad_solver.get("lookahead").get<bool>());
        if (ad_solver.has("lookahead_steps"))
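As the diff shows, RANGER and RANGER_PLUS are implemented as presets over the ADAM solver type rather than as new `SolverType` enum values: the dispatch keeps `SolverParameter_SolverType_ADAM` and only toggles boolean flags. A standalone sketch of that mapping, for illustration (`SolverFlags` and `preset_flags` are hypothetical names, not DeepDetect code):

```cpp
#include <strings.h> // strcasecmp (POSIX)
#include <string>

// Illustrative mirror of the caffe::SolverParameter flags touched above.
struct SolverFlags
{
  bool rectified = false; // RAdam-style variance rectification
  bool lookahead = false; // Lookahead wrapper
  bool adabelief = false; // AdaBelief second-moment variant
  bool gc = false;        // gradient centralization
};

// RANGER = ADAM + rectified + lookahead;
// RANGER_PLUS = RANGER + adabelief + gradient centralization.
inline SolverFlags preset_flags(const std::string &solver_type)
{
  SolverFlags f;
  const char *s = solver_type.c_str();
  if (strcasecmp(s, "RANGER") == 0 || strcasecmp(s, "RANGER_PLUS") == 0)
    {
      f.rectified = true;
      f.lookahead = true;
    }
  if (strcasecmp(s, "RANGER_PLUS") == 0)
    {
      f.adabelief = true;
      f.gc = true;
    }
  return f;
}
```

Keeping ADAM as the underlying solver type avoids touching the `SolverType` enum: the variants compose as independent flags, which is also why `adabelief` and `gradient_centralization` can be set directly on a plain ADAM run.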
