
Implemented CrossQ #243

Open

Wants to merge 29 commits into base: master

29 commits
9afecf5
Implemented CrossQ
danielpalen May 3, 2024
4fa78a7
Fixed code style
danielpalen May 5, 2024
7ce57de
Clean up, comments and refactored to sbx variable names
danielpalen May 12, 2024
9c339b8
1024 neuron Q function (sbx default)
danielpalen May 12, 2024
2b1ff5e
batch norm parameters as function arguments
danielpalen May 12, 2024
aace2ac
clean up. reshape instead of split
danielpalen May 12, 2024
4df7111
Added policy delay
danielpalen May 12, 2024
5583225
fixed commit-checks
danielpalen May 12, 2024
567c2fb
Fix f-string
araffin May 13, 2024
8970ed0
Update documentation
araffin May 13, 2024
8792621
Rename to torch layers
araffin May 13, 2024
230a948
Fix for policy delay and minor edits
araffin May 13, 2024
cd8bd7d
Update tests
araffin May 13, 2024
27a96f6
Update documentation
araffin May 13, 2024
7d6c642
Merge branch 'master' into feat/crossq
araffin Jul 1, 2024
3927a70
Update doc
araffin Jul 6, 2024
2019327
Add more tests for crossQ
araffin Jul 6, 2024
b0213ec
Improve doc and expose batchnorm params
araffin Jul 6, 2024
9772ecf
Merge branch 'master' into feat/crossq
araffin Jul 6, 2024
454224d
Add some comments and todos and fix type check
araffin Jul 6, 2024
a7bbac9
Merge branch 'feat/crossq' of github.com:danielpalen/stable-baselines…
araffin Jul 6, 2024
bbd654c
Use torch module for BN
araffin Jul 19, 2024
bb80218
Re-organize losses
araffin Jul 19, 2024
a717d13
Add set_bn_training_mode
araffin Jul 19, 2024
cb1bc8f
Simplify network creation with new SB3 version, and fix default momentum
araffin Jul 20, 2024
a88a19b
Use different b1 for Adam as in original implementation
araffin Jul 20, 2024
32f66fe
Reformat TOML file
araffin Jul 20, 2024
03db09e
Update CI workflow, skip mypy for 3.8
araffin Jul 22, 2024
244b930
Merge branch 'master' into feat/crossq
araffin Aug 13, 2024
1 change: 1 addition & 0 deletions README.md
@@ -31,6 +31,7 @@ See documentation for the full list of included features.
- [PPO with recurrent policy (RecurrentPPO aka PPO LSTM)](https://ppo-details.cleanrl.dev//2021/11/05/ppo-implementation-details/)
- [Truncated Quantile Critics (TQC)](https://arxiv.org/abs/2005.04269)
- [Trust Region Policy Optimization (TRPO)](https://arxiv.org/abs/1502.05477)
- [Batch Normalization in Deep Reinforcement Learning (CrossQ)](https://openreview.net/forum?id=PczQtTsTIX)

**Gym Wrappers**:
- [Time Feature Wrapper](https://arxiv.org/abs/1712.00378)
7 changes: 7 additions & 0 deletions docs/common/torch_layers.rst
@@ -0,0 +1,7 @@
.. _th_layers:

Torch Layers
============

.. automodule:: sb3_contrib.common.torch_layers
:members:
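
A minimal usage sketch of the ``BatchRenorm1d`` layer provided by this module (illustrative only; see the class docstring in the ``sb3_contrib/common/torch_layers.py`` diff below for the exact warmup and running-statistics behavior):

.. code-block:: python

import torch

from sb3_contrib.common.torch_layers import BatchRenorm1d

layer = BatchRenorm1d(32)  # 32 input features
x = torch.randn(256, 32)

layer.train()  # training mode: batch statistics during warmup, running stats are updated
y_train = layer(x)

layer.eval()  # eval mode: running statistics are used and not updated
y_eval = layer(x)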
13 changes: 13 additions & 0 deletions docs/guide/examples.rst
@@ -113,3 +113,16 @@ Train a PPO agent with a recurrent policy on the CartPole environment.
obs, rewards, dones, info = vec_env.step(action)
episode_starts = dones
vec_env.render("human")

CrossQ
------

Train a CrossQ agent on the Pendulum environment.

.. code-block:: python

from sb3_contrib import CrossQ

model = CrossQ("MlpPolicy", "Pendulum-v1", verbose=1, policy_kwargs=dict(net_arch=dict(pi=[256, 256], qf=[1024, 1024])))
model.learn(total_timesteps=5_000, log_interval=4)
model.save("crossq_pendulum")
Binary file added docs/images/crossQ_performance.png
2 changes: 2 additions & 0 deletions docs/index.rst
@@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
:caption: RL Algorithms

modules/ars
modules/crossq
modules/ppo_mask
modules/ppo_recurrent
modules/qrdqn
@@ -42,6 +43,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
:maxdepth: 1
:caption: Common

common/torch_layers
common/utils
common/wrappers

26 changes: 25 additions & 1 deletion docs/misc/changelog.rst
@@ -3,6 +3,30 @@
Changelog
==========


Release 2.4.0a0 (WIP)
--------------------------

Breaking Changes:
^^^^^^^^^^^^^^^^^

New Features:
^^^^^^^^^^^^^
- Added CrossQ (@danielpalen)

Bug Fixes:
^^^^^^^^^^

Deprecations:
^^^^^^^^^^^^^

Others:
^^^^^^^

Documentation:
^^^^^^^^^^^^^^


Release 2.3.0 (2024-03-31)
--------------------------

@@ -554,4 +578,4 @@ Contributors:
-------------

@ku2482 @guyk1971 @minhlong94 @ayeright @kronion @glmcdona @cyprienc @sgillen @Gregwar @rnederstigt @qgallouedec
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl
@mlodel @CppMaster @burakdmb @honglu2875 @ZikangXiong @AlexPasqua @jonasreiher @icheered @Armandpl @danielpalen
100 changes: 100 additions & 0 deletions docs/modules/crossq.rst
@@ -0,0 +1,100 @@
.. _crossq:

.. automodule:: sb3_contrib.crossq


CrossQ
======

Implementation of CrossQ proposed in:

`Bhatt A.* & Palenicek D.* et al. Batch Normalization in Deep Reinforcement Learning for Greater Sample Efficiency and Simplicity. ICLR 2024.`

CrossQ is a simple and efficient algorithm that uses batch normalization to improve the sample efficiency of off-policy deep reinforcement learning algorithms.
It is based on the idea of carefully introducing batch normalization layers in the critic network and dropping target networks.
This yields a simpler and more sample-efficient algorithm that does not require high update-to-data ratios.
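
The key trick can be sketched as follows. This is a simplified, illustrative snippet, not this PR's actual ``train()`` implementation; the dimensions and the stand-in ``q_net`` are hypothetical. Current and next state-action pairs are passed through the batch-normalized critic in a single joint forward pass, so that the normalization statistics cover both distributions, and the output is split back afterwards.

.. code-block:: python

import torch
import torch.nn as nn

# Hypothetical dimensions: obs dim 17, action dim 6, batch of 256 transitions
obs, next_obs = torch.randn(256, 17), torch.randn(256, 17)
actions, next_actions = torch.randn(256, 6), torch.randn(256, 6)

# Stand-in critic with batch norm (this PR uses BatchRenorm1d instead of BatchNorm1d)
q_net = nn.Sequential(
    nn.Linear(17 + 6, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Linear(1024, 1)
)

# Joint forward pass: the normalization statistics cover both the current
# and the next state-action distribution
cat_obs = torch.cat([obs, next_obs], dim=0)
cat_actions = torch.cat([actions, next_actions], dim=0)
q_joint = q_net(torch.cat([cat_obs, cat_actions], dim=1))

# Split back into Q(s, a) and Q(s', a'); the TD target is built from the
# detached q_next, so no separate target network is needed
q_current, q_next = torch.split(q_joint, obs.shape[0], dim=0)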

.. rubric:: Available Policies

.. autosummary::
:nosignatures:

MlpPolicy
Member

Could you add at least the multi input policy? (so we can try it in combination with HER)
Only the feature extractor should be changed normally.

And what do you think about adding CnnPolicy?

Author

This is a good point. I looked into it but have not added it yet. If I am not mistaken, this would also require some changes to the CrossQ train() function, since concatenating and splitting the batches would then need control flow that depends on the policy being used.
For simplicity's sake (for now), and since I did not have time to try and evaluate the multi-input policy, I have not added it yet.


.. note::

Compared to the original implementation, the default network architecture for the q-value function is ``[1024, 1024]``
instead of ``[2048, 2048]``, which provides a good compromise between speed and performance.
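
To recover the critic size used in the paper, the architecture can be overridden via ``policy_kwargs``, for example (a minimal sketch):

.. code-block:: python

from sb3_contrib import CrossQ

# Use the paper's [2048, 2048] critic instead of the default [1024, 1024]
model = CrossQ(
    "MlpPolicy",
    "Walker2d-v4",
    policy_kwargs=dict(net_arch=dict(pi=[256, 256], qf=[2048, 2048])),
)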

Notes
-----

- Original paper: https://openreview.net/pdf?id=PczQtTsTIX
- Original Implementation: https://github.com/adityab/CrossQ
- SBX Implementation: https://github.com/araffin/sbx


Can I use?
----------

- Recurrent policies: ❌
- Multi processing: ✔️
- Gym spaces:


============= ====== ===========
Space Action Observation
============= ====== ===========
Discrete ❌ ✔️
Box ✔️ ✔️
MultiDiscrete ❌ ✔️
MultiBinary ❌ ✔️
Dict ❌ ✔️
============= ====== ===========


Example
-------

.. code-block:: python

from sb3_contrib import CrossQ

model = CrossQ("MlpPolicy", "Walker2d-v4")
model.learn(total_timesteps=1_000_000)
model.save("crossq_walker")


Results
-------

Performance evaluation of CrossQ on six MuJoCo environments, compared to the results from the original paper as well as to the `SBX <https://github.com/araffin/sbx>`_ implementation.

.. image:: ../images/crossQ_performance.png

Comments
--------

This implementation is based on the SB3 SAC implementation.


Parameters
----------

.. autoclass:: CrossQ
:members:
:inherited-members:

.. _crossq_policies:

CrossQ Policies
---------------

.. autoclass:: MlpPolicy
:members:
:inherited-members:

.. autoclass:: sb3_contrib.crossq.policies.CrossQPolicy
:members:
:noindex:
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.crossq import CrossQ
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
@@ -14,6 +15,7 @@

__all__ = [
"ARS",
"CrossQ",
"MaskablePPO",
"RecurrentPPO",
"QRDQN",
114 changes: 114 additions & 0 deletions sb3_contrib/common/torch_layers.py
@@ -0,0 +1,114 @@
import torch

__all__ = ["BatchRenorm1d", "BatchRenorm"]


class BatchRenorm(torch.jit.ScriptModule):
"""
BatchRenorm Module (https://arxiv.org/abs/1702.03275).
Adapted to PyTorch from sbx.sbx.common.jax_layers.BatchRenorm

BatchRenorm is an improved version of vanilla BatchNorm. Contrary to BatchNorm,
BatchRenorm uses the running statistics for normalizing the batches after a warmup phase.
This makes it less prone to suffering from "outlier" batches that can happen
during very long training runs and, therefore, more robust over long training runs.

During the warmup phase, it behaves exactly like a BatchNorm layer. After the warmup phase,
the running statistics are used for normalization. The running statistics are updated during
training mode. During evaluation mode, the running statistics are used for normalization but
not updated.

:param num_features: Number of features in the input tensor.
:param eps: A value added to the variance for numerical stability.
:param momentum: The value used for the ra_mean and ra_var computation.
:param affine: A boolean value that when set to True, this module has learnable
affine parameters. Default: True
:param warmup_steps: Number of warmup steps that are performed before the running statistics
are used for normalization. During the warmup phase, the batch statistics are used.
"""

def __init__(
self,
num_features: int,
eps: float = 0.001,
momentum: float = 0.01,
affine: bool = True,
warmup_steps: int = 100_000,
):
super().__init__()
# Running average mean and variance
self.register_buffer("ra_mean", torch.zeros(num_features, dtype=torch.float))
self.register_buffer("ra_var", torch.ones(num_features, dtype=torch.float))
self.register_buffer("steps", torch.tensor(0, dtype=torch.long))
self.scale = torch.nn.Parameter(torch.ones(num_features, dtype=torch.float))
self.bias = torch.nn.Parameter(torch.zeros(num_features, dtype=torch.float))

self.affine = affine
self.eps = eps
self.step = 0
self.momentum = momentum
self.rmax = 3.0
self.dmax = 5.0
self.warmup_steps = warmup_steps

def _check_input_dim(self, x: torch.Tensor) -> None:
raise NotImplementedError()

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Normalize the input tensor.

:param x: Input tensor
:return: Normalized tensor.
"""

if self.training:
batch_mean = x.mean(0)
batch_var = x.var(0)
batch_std = (batch_var + self.eps).sqrt()

# Use batch statistics during initial warm up phase.
# Note: in the original paper, after some warmup phase (batch norm phase of 5k steps)
# the constraints are linearly relaxed to r_max/d_max over 40k steps
# Here we only have a warmup phase
if self.steps > self.warmup_steps:

running_std = (self.ra_var + self.eps).sqrt()
# scale
r = (batch_std / running_std).detach()
r = r.clamp(1 / self.rmax, self.rmax)
# bias
d = ((batch_mean - self.ra_mean) / running_std).detach()
d = d.clamp(-self.dmax, self.dmax)

# BatchNorm normalization, using minibatch stats and running average stats
# Because we normalize with custom_mean and custom_var below, this is equivalent to
# ((x - x_mean) / sigma) * r + d = ((x - x_mean) * r + d * sigma) / sigma
# where sigma = sqrt(var)
custom_mean = batch_mean - d * batch_var.sqrt() / r
custom_var = batch_var / (r**2)

else:
custom_mean, custom_var = batch_mean, batch_var

# Update Running Statistics
self.ra_mean += self.momentum * (batch_mean.detach() - self.ra_mean)
self.ra_var += self.momentum * (batch_var.detach() - self.ra_var)
self.steps += 1

else:
custom_mean, custom_var = self.ra_mean, self.ra_var

# Normalize
x = (x - custom_mean[None]) / (custom_var[None] + self.eps).sqrt()

if self.affine:
x = self.scale * x + self.bias

return x


class BatchRenorm1d(BatchRenorm):
def _check_input_dim(self, x: torch.Tensor) -> None:
if x.dim() == 1:
raise ValueError(f"Expected 2D or 3D input (got {x.dim()}D input)")
4 changes: 4 additions & 0 deletions sb3_contrib/crossq/__init__.py
@@ -0,0 +1,4 @@
from sb3_contrib.crossq.crossq import CrossQ
from sb3_contrib.crossq.policies import MlpPolicy

__all__ = ["CrossQ", "MlpPolicy"]