Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add the Bootstrapped Dual Policy Iteration algorithm for discrete action spaces #35

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ See documentation for the full list of included features.
**RL Algorithms**:
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [Bootstrapped Dual Policy Iteration](https://arxiv.org/abs/1903.04193)

**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
Expand Down
3 changes: 2 additions & 1 deletion docs/guide/algos.rst
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,12 @@ Name ``Box`` ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Pr
============ =========== ============ ================= =============== ================
TQC ✔️ ❌ ❌ ❌ ❌
QR-DQN ️❌ ️✔️ ❌ ❌ ❌
BDPI ❌ ✔️ ❌ ❌ ✔️
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it says "multiprocessing" but I only see tests with one environment in the code...
and you probably need DLR-RM/stable-baselines3#439 to make it work with multiple envs.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I may have mis-understood what "Multiprocessing" in this document means, and what PPO in stable-baselines is doing.

BDPI distributes its training updates on several processes, even if only one environment is used. To me, this is multi-processing, comparable to PPO that uses MPI to distribute compute. But if PPO uses multiple environments to be able to do multiprocessing, then I understand that "Multiprocessing" in the documentation means "compatible with multiple envs", not just "fast because it uses several processes".

Should I add a note, or a second column to distinguish "multiple envs" from "multi-processing with one env"?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

To me, this is multi-processing, comparable to PPO that uses MPI to distribute compute.

PPO with MPI distributes env and training compute (and is currently not implemented in SB3).

then I understand that "Multiprocessing" in the documentation means "compatible with multiple envs"

yes, that's the meaning (because we can use SubProcEnv to distribute data collection)

Should I add a note, or a second column to distinguish "multiple envs" from "multi-processing with one env"?

no, I think we have already enough columns in this table.

============ =========== ============ ================= =============== ================


.. note::
Non-array spaces such as ``Dict`` or ``Tuple`` are not currently supported by any algorithm.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it seems that we forgot to update the contrib doc when adding support for Dict obs, the correct formulation should be the one from https://github.com/DLR-RM/stable-baselines3/blob/master/docs/guide/algos.rst

Non-array spaces such as ``Dict`` or ``Tuple`` are only supported by BDPI, using ``MultiInputPolicy`` instead of ``MlpPolicy`` as the policy.

Actions ``gym.spaces``:

Expand Down
13 changes: 13 additions & 0 deletions docs/guide/examples.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,19 @@ Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
model.learn(total_timesteps=10000, log_interval=4)
model.save("qrdqn_cartpole")

BDPI
----

Train a Bootstrapped Dual Policy Iteration (BDPI) agent on the LunarLander environment

.. code-block:: python

from sb3_contrib import BDPI

policy_kwargs = dict(n_critics=8)
model = BDPI("MlpPolicy", "LunarLander-v2", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=50000, log_interval=4)
model.save("bdpi_lunarlander")

.. PyBullet: Normalizing input features
.. ------------------------------------
Expand Down
1 change: 1 addition & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d

modules/tqc
modules/qrdqn
modules/bdpi

.. toctree::
:maxdepth: 1
Expand Down
136 changes: 136 additions & 0 deletions docs/modules/bdpi.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
.. _bdpi:

.. automodule:: sb3_contrib.bdpi

BDPI
====

`Bootstrapped Dual Policy Iteration <https://arxiv.org/abs/1903.04193>`_ is an actor-critic algorithm for
discrete action spaces. The distinctive components of BDPI are as follows:

- Like Bootstrapped DQN, it uses several critics, with each critic having a Qa and Qb network
(like Clipped DQN).
- The BDPI critics, inspired from the DQN literature, are therefore off-policy. They don't know
about the actor, and do not use any form of off-policy corrections to evaluate the actor. They
instead directly approximate Q*, the optimal value function.
- The actor is trained with an equation inspired from Conservative Policy Iteration, instead of
Policy Gradient (as used by A2C, PPO, SAC, DDPG, etc). This use of Conservative Policy Iteration
is what allows the BDPI actor to be compatible with off-policy critics.

As a result, BDPI can be configured to be highly sample-efficient, at the cost of compute efficiency.
The off-policy critics can learn aggressively (many samples, many gradient steps), as they don't have
to remain close to the actor. The actor then learns from a mixture of high-quality critics, leading
to good exploration even in challenging environments (see the Table environment described in the paper
linked above).

Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ✔️ ✔️
Box ❌ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
Dict ❌ ✔️
============= ====== ===========

Example
-------

Train a BDPI agent on ``LunarLander-v2``, with hyper-parameters tuned by Optuna in rl-baselines3-zoo:

.. code-block:: python

import gym

from sb3_contrib import BDPI

model = BDPI(
"MlpPolicy",
'LunarLander-v2',
actor_lr=0.01, # How fast the actor pursues the greedy policy of the critics
critic_lr=0.234, # Q-Learning learning rate of the critics
batch_size=256, # 256 experiences sampled from the buffer every time-step, for every critic
buffer_size=100000,
gradient_steps=64, # Actor and critics fit for 64 gradient steps per time-step
learning_rate=0.00026, # Adam optimizer learning rate
policy_kwargs=dict(net_arch=[64, 64], n_critics=8), # 8 critics
verbose=1,
tensorboard_log='./tb_log'
)

model.learn(total_timesteps=50000)
model.save("bdpi_lunarlander")

del model # remove to demonstrate saving and loading

model = BDPI.load("bdpi_lunarlander")

obs = env.reset()
while True:
action, _states = model.predict(obs)
obs, rewards, dones, info = env.step(action)
env.render()


Results
-------

LunarLander
^^^^^^^^^^^

Results for BDPI are available in `this Github issue <https://github.com/DLR-RM/stable-baselines3/issues/499>`_.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

as you are aiming for sample efficiency, I would prefer a comparison to DQN, QR-DQN (with tuned hyperparameters, results are already linked in the documentation: #13)

Regarding which envs to compare too, please do the classic control ones + 2 Atari games at least (Pong, Breakout) using the zoo, so we can compare the results with QR-DQN and DQN.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would also like a comparison of the compromise sample efficiency vs training time (how much more time does it take to train?)


How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone the `rl-zoo repo <https://github.com/DLR-RM/rl-baselines3-zoo>`_:

.. code-block:: bash

git clone https://github.com/DLR-RM/rl-baselines3-zoo
cd rl-baselines3-zoo/


Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above, and ``$N`` with
the number of CPU cores in your machine):

.. code-block:: bash

python train.py --algo bdpi --env $ENV_ID --eval-episodes 10 --eval-freq 10000 -params threads:$N


Plot the results (here for LunarLander only):

.. code-block:: bash

python scripts/all_plots.py -a bdpi -e LunarLander -f logs/ -o logs/bdpi_results
python scripts/plot_from_file.py -i logs/bdpi_results.pkl -latex -l BDPI


Parameters
----------

.. autoclass:: BDPI
:members:
:inherited-members:


BDPI Policies
-------------

.. autoclass:: MlpPolicy
:members:

.. autoclass:: CnnPolicy
:members:

.. autoclass:: MultiInputPolicy
:members:
1 change: 1 addition & 0 deletions sb3_contrib/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

from sb3_contrib.bdpi import BDPI
from sb3_contrib.qrdqn import QRDQN
from sb3_contrib.tqc import TQC

Expand Down
2 changes: 2 additions & 0 deletions sb3_contrib/bdpi/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from sb3_contrib.bdpi.bdpi import BDPI
from sb3_contrib.bdpi.policies import CnnPolicy, MlpPolicy, MultiInputPolicy
Loading