DuelingDQN #127

Draft
wants to merge 16 commits into base: master
1 change: 1 addition & 0 deletions README.md
@@ -26,6 +26,7 @@ See documentation for the full list of included features.

**RL Algorithms**:
- [Augmented Random Search (ARS)](https://arxiv.org/abs/1803.07055)
- [Dueling DQN](https://arxiv.org/abs/1511.06581)
- [Quantile Regression DQN (QR-DQN)](https://arxiv.org/abs/1710.10044)
- [PPO with invalid action masking (MaskablePPO)](https://arxiv.org/abs/2006.14171)
- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
1 change: 1 addition & 0 deletions docs/guide/algos.rst
@@ -9,6 +9,7 @@ along with some useful characteristics: support for discrete/continuous actions,
Name         ``Box``     ``Discrete`` ``MultiDiscrete`` ``MultiBinary`` Multi Processing
============ =========== ============ ================= =============== ================
ARS          ✔️           ❌           ❌                 ❌               ✔️
Dueling DQN  ❌           ✔️           ❌                 ❌               ✔️
MaskablePPO  ❌           ✔️           ✔️                 ✔️               ✔️
QR-DQN       ❌           ✔️           ❌                 ❌               ✔️
RecurrentPPO ✔️           ✔️           ✔️                 ✔️               ✔️
13 changes: 13 additions & 0 deletions docs/guide/examples.rst
@@ -16,6 +16,19 @@ Train a Truncated Quantile Critics (TQC) agent on the Pendulum environment.
    model.learn(total_timesteps=10_000, log_interval=4)
    model.save("tqc_pendulum")

DuelingDQN
----------

Train a Dueling DQN agent on the CartPole environment.

.. code-block:: python

    from sb3_contrib import DuelingDQN

    model = DuelingDQN("MlpPolicy", "CartPole-v1", verbose=1)
    model.learn(total_timesteps=10_000, log_interval=4)
    model.save("dueling_dqn_cartpole")

QR-DQN
------

1 change: 1 addition & 0 deletions docs/index.rst
@@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
:caption: RL Algorithms

modules/ars
modules/dueling_dqn
modules/ppo_mask
modules/ppo_recurrent
modules/qrdqn
1 change: 1 addition & 0 deletions docs/misc/changelog.rst
@@ -17,6 +17,7 @@ New Features:
^^^^^^^^^^^^^
- Introduced mypy type checking
- Added ``with_bias`` parameter to ``ARSPolicy``
- Added ``DuelingDQN``

Bug Fixes:
^^^^^^^^^^
153 changes: 153 additions & 0 deletions docs/modules/dueling_dqn.rst
@@ -0,0 +1,153 @@
.. _dueling_dqn:

.. automodule:: sb3_contrib.dueling_dqn


Dueling-DQN
===========

`Dueling DQN <https://arxiv.org/abs/1511.06581>`_ builds on `Deep Q-Network (DQN) <https://arxiv.org/abs/1312.5602>`_
and splits the Q-network into two streams: one estimating the state value and one estimating the state-dependent
action advantages. The streams are combined (subtracting the mean advantage) into Q-values, which lets the agent
learn which states are valuable without having to evaluate the effect of every action in every state.
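
A minimal sketch of this aggregation is shown below; the ``DuelingHead`` module, its name and its layer shapes are
purely illustrative and do not reflect the actual ``DuelingDQNPolicy`` implementation:

.. code-block:: python

    import torch as th
    from torch import nn


    class DuelingHead(nn.Module):
        """Illustrative dueling head: Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""

        def __init__(self, features_dim: int, n_actions: int):
            super().__init__()
            self.value = nn.Linear(features_dim, 1)  # state-value stream V(s)
            self.advantage = nn.Linear(features_dim, n_actions)  # advantage stream A(s, a)

        def forward(self, features: th.Tensor) -> th.Tensor:
            value = self.value(features)
            advantage = self.advantage(features)
            # Subtracting the mean advantage keeps V and A identifiable
            return value + advantage - advantage.mean(dim=1, keepdim=True)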


.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
CnnPolicy
MultiInputPolicy


Notes
-----

- Original paper: https://arxiv.org/abs/1511.06581


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space         Action Observation
============= ====== ===========
Discrete      ✔️      ✔️
Box           ❌      ✔️
MultiDiscrete ❌      ✔️
MultiBinary   ❌      ✔️
Dict          ❌      ✔️
============= ====== ===========


Example
-------

.. code-block:: python

    import gym

    from sb3_contrib import DuelingDQN

    env = gym.make("CartPole-v1")

    model = DuelingDQN("MlpPolicy", env, verbose=1)
    model.learn(total_timesteps=10000, log_interval=4)
    model.save("dueling_dqn_cartpole")

    del model  # remove to demonstrate saving and loading

    model = DuelingDQN.load("dueling_dqn_cartpole")

    obs = env.reset()
    while True:
        action, _states = model.predict(obs, deterministic=True)
        obs, reward, done, info = env.step(action)
        env.render()
        if done:
            obs = env.reset()


Results
-------

Results on Atari environments (Pong and Breakout, 10M steps) and on classic control tasks, using 3 and 5 seeds.

The complete learning curves are available in the `associated PR <https://github.com/Stable-Baselines-Team/stable-baselines3-contrib/pull/126>`_. #TODO:


.. note::

    The DuelingDQN implementation was validated against #TODO: validate the results


============ =========== ===========
Environments DuelingDQN  DQN
============ =========== ===========
Breakout     ~300
Pong         ~20
CartPole     500 +/- 0
MountainCar  -107 +/- 4
LunarLander  195 +/- 28
Acrobot      -74 +/- 2
============ =========== ===========

#TODO: Fill in the table

How to replicate the results?
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Clone RL-Zoo fork and checkout the branch ``feat/dueling-dqn``:

.. code-block:: bash

    git clone https://github.com/DLR-RM/rl-baselines3-zoo/
    cd rl-baselines3-zoo/
    git checkout feat/dueling-dqn  #TODO: create this branch

Run the benchmark (replace ``$ENV_ID`` by the envs mentioned above):

.. code-block:: bash

    python train.py --algo dueling_dqn --env $ENV_ID --eval-episodes 10 --eval-freq 10000  #TODO: check if that command line works


Plot the results:

.. code-block:: bash

    python scripts/all_plots.py -a dueling_dqn -e Breakout Pong -f logs/ -o logs/dueling_dqn_results  #TODO: check if that command line works
    python scripts/plot_from_file.py -i logs/dueling_dqn_results.pkl -latex -l "Dueling DQN"  #TODO: check if that command line works



Parameters
----------

.. autoclass:: DuelingDQN
:members:
:inherited-members:

.. _dueling_dqn_policies:

Dueling DQN Policies
--------------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: sb3_contrib.dueling_dqn.policies.DuelingDQNPolicy
:members:
:noindex:

.. autoclass:: CnnPolicy
:members:

.. autoclass:: MultiInputPolicy
:members:
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.dueling_dqn import DuelingDQN
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
@@ -14,6 +15,7 @@

__all__ = [
    "ARS",
    "DuelingDQN",
    "MaskablePPO",
    "RecurrentPPO",
    "QRDQN",
4 changes: 4 additions & 0 deletions sb3_contrib/dueling_dqn/__init__.py
@@ -0,0 +1,4 @@
from sb3_contrib.dueling_dqn.dueling_dqn import DuelingDQN
from sb3_contrib.dueling_dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy

__all__ = ["DuelingDQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
132 changes: 132 additions & 0 deletions sb3_contrib/dueling_dqn/dueling_dqn.py
@@ -0,0 +1,132 @@
from typing import Any, Dict, Optional, Tuple, Type, TypeVar, Union

import torch as th
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.dqn.dqn import DQN

from sb3_contrib.dueling_dqn.policies import CnnPolicy, DuelingDQNPolicy, MlpPolicy, MultiInputPolicy

SelfDuelingDQN = TypeVar("SelfDuelingDQN", bound="DuelingDQN")


class DuelingDQN(DQN):
"""
Dueling Deep Q-Network (Dueling DQN)

Paper: https://arxiv.org/abs/1511.06581

:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
:param gamma: the discount factor
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout (see ``train_freq``)
Set to ``-1`` means to do as many gradient steps as steps done in the environment
during the rollout.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param target_update_interval: update the target network every ``target_update_interval``
environment steps.
:param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
:param exploration_initial_eps: initial value of random action probability
:param exploration_final_eps: final value of random action probability
:param max_grad_norm: The maximum value for the gradient clipping
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 for
debug messages
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""

policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}

def __init__(
Member: if init is the same as DQN, I guess you can drop it.

Contributor Author: The only (but necessary) diff is the policy type hint.

Member: I see...
        self,
        policy: Union[str, Type[DuelingDQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 0.0001,
        buffer_size: int = 1000000,
        learning_starts: int = 50000,
        batch_size: int = 32,
        tau: float = 1,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
        replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            replay_buffer_class,
            replay_buffer_kwargs,
            optimize_memory_usage,
            target_update_interval,
            exploration_fraction,
            exploration_initial_eps,
            exploration_final_eps,
            max_grad_norm,
            tensorboard_log,
            policy_kwargs,
            verbose,
            seed,
            device,
            _init_setup_model,
        )

    def learn(
        self: SelfDuelingDQN,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        tb_log_name: str = "DuelingDQN",
        reset_num_timesteps: bool = True,
        progress_bar: bool = False,
    ) -> SelfDuelingDQN:
        return super().learn(
            total_timesteps,
            callback,
            log_interval,
            tb_log_name,
            reset_num_timesteps,
            progress_bar,
        )
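
As noted in the review thread above, the dueling architecture itself lives in the policy classes, so the alternative
the reviewer suggests (dropping the `__init__` override) would look roughly like the sketch below. The
`DuelingDQNMinimal` name is hypothetical; the trade-off, as the author points out, is losing the
`Type[DuelingDQNPolicy]` annotation on the `policy` argument.

from typing import Dict, Type

from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.dqn.dqn import DQN

from sb3_contrib.dueling_dqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy


class DuelingDQNMinimal(DQN):
    """Hypothetical variant that reuses DQN.__init__ unchanged.

    Only the policy aliases differ; ``policy`` keeps DQN's ``Type[DQNPolicy]``
    type hint instead of ``Type[DuelingDQNPolicy]``.
    """

    policy_aliases: Dict[str, Type[BasePolicy]] = {
        "MlpPolicy": MlpPolicy,
        "CnnPolicy": CnnPolicy,
        "MultiInputPolicy": MultiInputPolicy,
    }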