IQN #139

Draft: wants to merge 6 commits into base: master

14 changes: 14 additions & 0 deletions docs/guide/examples.rst
@@ -30,6 +30,20 @@ Train a Quantile Regression DQN (QR-DQN) agent on the CartPole environment.
model.learn(total_timesteps=10_000, log_interval=4)
model.save("qrdqn_cartpole")

IQN
---

Train an Implicit Quantile Network (IQN) agent on the CartPole environment.

.. code-block:: python

from sb3_contrib import IQN

policy_kwargs = dict(n_quantiles=32)
model = IQN("MlpPolicy", "CartPole-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("iqn_cartpole")

MaskablePPO
-----------

1 change: 1 addition & 0 deletions docs/index.rst
@@ -32,6 +32,7 @@ RL Baselines3 Zoo also offers a simple interface to train, evaluate agents and d
:caption: RL Algorithms

modules/ars
modules/iqn
modules/ppo_mask
modules/ppo_recurrent
modules/qrdqn
2 changes: 2 additions & 0 deletions sb3_contrib/__init__.py
@@ -1,6 +1,7 @@
import os

from sb3_contrib.ars import ARS
from sb3_contrib.iqn import IQN
from sb3_contrib.ppo_mask import MaskablePPO
from sb3_contrib.ppo_recurrent import RecurrentPPO
from sb3_contrib.qrdqn import QRDQN
@@ -14,6 +15,7 @@

__all__ = [
"ARS",
"IQN",
"MaskablePPO",
"RecurrentPPO",
"QRDQN",
4 changes: 4 additions & 0 deletions sb3_contrib/iqn/__init__.py
@@ -0,0 +1,4 @@
from sb3_contrib.iqn.iqn import IQN
from sb3_contrib.iqn.policies import CnnPolicy, MlpPolicy, MultiInputPolicy

__all__ = ["IQN", "CnnPolicy", "MlpPolicy", "MultiInputPolicy"]
282 changes: 282 additions & 0 deletions sb3_contrib/iqn/iqn.py
@@ -0,0 +1,282 @@
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Union

import numpy as np
import torch as th
from gym import spaces
from stable_baselines3.common.buffers import ReplayBuffer
from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm
from stable_baselines3.common.policies import BasePolicy
from stable_baselines3.common.preprocessing import maybe_transpose
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import get_linear_fn, get_parameters_by_name, is_vectorized_observation, polyak_update

from sb3_contrib.common.utils import quantile_huber_loss
from sb3_contrib.iqn.policies import CnnPolicy, IQNPolicy, MlpPolicy, MultiInputPolicy

SelfIQN = TypeVar("SelfIQN", bound="IQN")


class IQN(OffPolicyAlgorithm):
"""
Implicit Quantile Network (IQN)
Paper: https://arxiv.org/abs/1806.06923
Default hyperparameters are taken from the paper and are tuned for Atari games.

:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...)
:param env: The environment to learn from (if registered in Gym, can be str)
:param learning_rate: The learning rate, it can be a function
of the current progress remaining (from 1 to 0)
:param buffer_size: size of the replay buffer
:param learning_starts: how many steps of the model to collect transitions for before learning starts
:param batch_size: Minibatch size for each gradient update
:param tau: the soft update coefficient ("Polyak update", between 0 and 1) default 1 for hard update
:param gamma: the discount factor
:param num_tau_samples: Number of samples used to estimate the current quantiles
:param num_tau_prime_samples: Number of samples used to estimate the next quantiles
:param train_freq: Update the model every ``train_freq`` steps. Alternatively pass a tuple of frequency and unit
like ``(5, "step")`` or ``(2, "episode")``.
:param gradient_steps: How many gradient steps to do after each rollout
(see ``train_freq`` and ``n_episodes_rollout``)
Setting it to ``-1`` means doing as many gradient steps as steps done in the environment
during the rollout.
:param replay_buffer_class: Replay buffer class to use (for instance ``HerReplayBuffer``).
If ``None``, it will be automatically selected.
:param replay_buffer_kwargs: Keyword arguments to pass to the replay buffer on creation.
:param optimize_memory_usage: Enable a memory efficient variant of the replay buffer
at a cost of more complexity.
See https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
:param target_update_interval: update the target network every ``target_update_interval``
environment steps.
:param exploration_fraction: fraction of entire training period over which the exploration rate is reduced
:param exploration_initial_eps: initial value of random action probability
:param exploration_final_eps: final value of random action probability
:param max_grad_norm: The maximum value for the gradient clipping (if None, no clipping)
:param tensorboard_log: the log location for tensorboard (if None, no logging)
:param policy_kwargs: additional arguments to be passed to the policy on creation
:param verbose: the verbosity level: 0 no output, 1 info, 2 debug
:param seed: Seed for the pseudo random generators
:param device: Device (cpu, cuda, ...) on which the code should be run.
Setting it to auto, the code will be run on the GPU if possible.
:param _init_setup_model: Whether or not to build the network at the creation of the instance
"""

policy_aliases: Dict[str, Type[BasePolicy]] = {
"MlpPolicy": MlpPolicy,
"CnnPolicy": CnnPolicy,
"MultiInputPolicy": MultiInputPolicy,
}

def __init__(
self,
policy: Union[str, Type[IQNPolicy]],
env: Union[GymEnv, str],
learning_rate: Union[float, Schedule] = 5e-5,
buffer_size: int = 1000000, # 1e6
learning_starts: int = 50000,
batch_size: int = 32,
tau: float = 1.0,
gamma: float = 0.99,
num_tau_samples: int = 32,
num_tau_prime_samples: int = 64,
train_freq: int = 4,
gradient_steps: int = 1,
replay_buffer_class: Optional[Type[ReplayBuffer]] = None,
replay_buffer_kwargs: Optional[Dict[str, Any]] = None,
optimize_memory_usage: bool = False,
target_update_interval: int = 10000,
exploration_fraction: float = 0.005,
exploration_initial_eps: float = 1.0,
exploration_final_eps: float = 0.01,
max_grad_norm: Optional[float] = None,
tensorboard_log: Optional[str] = None,
policy_kwargs: Optional[Dict[str, Any]] = None,
verbose: int = 0,
seed: Optional[int] = None,
device: Union[th.device, str] = "auto",
_init_setup_model: bool = True,
):

super().__init__(
policy,
env,
learning_rate,
buffer_size,
learning_starts,
batch_size,
tau,
gamma,
train_freq,
gradient_steps,
action_noise=None, # No action noise
replay_buffer_class=replay_buffer_class,
replay_buffer_kwargs=replay_buffer_kwargs,
policy_kwargs=policy_kwargs,
tensorboard_log=tensorboard_log,
verbose=verbose,
device=device,
seed=seed,
sde_support=False,
optimize_memory_usage=optimize_memory_usage,
supported_action_spaces=(spaces.Discrete,),
support_multi_env=True,
)

self.num_tau_samples = num_tau_samples
self.num_tau_prime_samples = num_tau_prime_samples

self.exploration_initial_eps = exploration_initial_eps
self.exploration_final_eps = exploration_final_eps
self.exploration_fraction = exploration_fraction
self.target_update_interval = target_update_interval
self.max_grad_norm = max_grad_norm
# "epsilon" for the epsilon-greedy exploration
self.exploration_rate = 0.0
# Linear schedule will be defined in `_setup_model()`
self.exploration_schedule: Schedule
self.policy: IQNPolicy # type: ignore[assignment]

if "optimizer_class" not in self.policy_kwargs:
self.policy_kwargs["optimizer_class"] = th.optim.Adam
# Proposed in the QR-DQN paper where `batch_size = 32`
self.policy_kwargs["optimizer_kwargs"] = dict(eps=0.01 / batch_size)

if _init_setup_model:
self._setup_model()

def _setup_model(self) -> None:
super()._setup_model()
self._create_aliases()
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
self.batch_norm_stats = get_parameters_by_name(self.quantile_net, ["running_"])
self.batch_norm_stats_target = get_parameters_by_name(self.quantile_net_target, ["running_"])
self.exploration_schedule = get_linear_fn(
self.exploration_initial_eps, self.exploration_final_eps, self.exploration_fraction
)
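
# For reference, a rough sketch of the schedule built by SB3's get_linear_fn above
# (an assumption about the helper's behaviour, not code added in this PR):
#
#   def schedule(progress_remaining: float) -> float:
#       # progress_remaining goes from 1 (start of training) to 0 (end)
#       if (1 - progress_remaining) > exploration_fraction:
#           return exploration_final_eps
#       return exploration_initial_eps + (1 - progress_remaining) * (
#           exploration_final_eps - exploration_initial_eps
#       ) / exploration_fraction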

def _create_aliases(self) -> None:
self.quantile_net = self.policy.quantile_net
self.quantile_net_target = self.policy.quantile_net_target
self.n_quantiles = self.policy.n_quantiles

def _on_step(self) -> None:
"""
Update the exploration rate and target network if needed.
This method is called in ``collect_rollouts()`` after each step in the environment.
"""
if self.num_timesteps % self.target_update_interval == 0:
polyak_update(self.quantile_net.parameters(), self.quantile_net_target.parameters(), self.tau)
# Copy running stats, see https://github.com/DLR-RM/stable-baselines3/issues/996
polyak_update(self.batch_norm_stats, self.batch_norm_stats_target, 1.0)

self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
self.logger.record("rollout/exploration_rate", self.exploration_rate)

def train(self, gradient_steps: int, batch_size: int = 100) -> None:
# Switch to train mode (this affects batch norm / dropout)
self.policy.set_training_mode(True)
# Update learning rate according to schedule
self._update_learning_rate(self.policy.optimizer)

losses = []
for _ in range(gradient_steps):
# Sample replay buffer
replay_data = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

with th.no_grad():
# Compute the quantiles of the next observations (used here only to select the greedy actions)
next_quantiles = self.quantile_net_target(replay_data.next_observations, self.n_quantiles)
# Compute the greedy actions which maximize the next Q values
next_greedy_actions = next_quantiles.mean(dim=1, keepdim=True).argmax(dim=2, keepdim=True)
# Make "num_tau_prime_samples" copies of actions, and reshape to (batch_size, num_tau_prime_samples, 1)
next_greedy_actions = next_greedy_actions.expand(batch_size, self.num_tau_prime_samples, 1)
# Re-compute the quantiles of the next observations, this time with num_tau_prime_samples tau samples
next_quantiles = self.quantile_net_target(replay_data.next_observations, self.num_tau_prime_samples)
# Follow greedy policy: use the one with the highest Q values
next_quantiles = next_quantiles.gather(dim=2, index=next_greedy_actions).squeeze(dim=2)
# 1-step TD target
target_quantiles = replay_data.rewards + (1 - replay_data.dones) * self.gamma * next_quantiles

# Get current quantile estimates
current_quantiles = self.quantile_net(replay_data.observations, self.num_tau_samples)
# Make "num_tau_samples" copies of actions, and reshape to (batch_size, num_tau_samples, 1).
actions = replay_data.actions[..., None].long().expand(batch_size, self.num_tau_samples, 1)
# Retrieve the quantiles for the actions from the replay buffer
current_quantiles = th.gather(current_quantiles, dim=2, index=actions).squeeze(dim=2)

# Compute Quantile Huber loss, summing over a quantile dimension as in the paper.
loss = quantile_huber_loss(current_quantiles, target_quantiles, sum_over_quantiles=True)
losses.append(loss.item())

# Optimize the policy
self.policy.optimizer.zero_grad()
loss.backward()
# Clip gradient norm
if self.max_grad_norm is not None:
th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
self.policy.optimizer.step()

# Increase update counter
self._n_updates += gradient_steps

self.logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
self.logger.record("train/loss", np.mean(losses))

def predict(
self,
observation: Union[np.ndarray, Dict[str, np.ndarray]],
state: Optional[Tuple[np.ndarray, ...]] = None,
episode_start: Optional[np.ndarray] = None,
deterministic: bool = False,
) -> Tuple[np.ndarray, Optional[Tuple[np.ndarray, ...]]]:
"""
Get the policy action from an observation (and optional hidden state).
Includes sugar-coating to handle different observations (e.g. normalizing images).

:param observation: the input observation
:param state: The last hidden states (can be None, used in recurrent policies)
:param episode_start: The last masks (can be None, used in recurrent policies)
this corresponds to the beginning of episodes,
where the hidden states of the RNN must be reset.
:param deterministic: Whether or not to return deterministic actions.
:return: the model's action and the next hidden state
(used in recurrent policies)
"""
if not deterministic and np.random.rand() < self.exploration_rate:
if is_vectorized_observation(maybe_transpose(observation, self.observation_space), self.observation_space):
if isinstance(self.observation_space, spaces.Dict):
n_batch = observation[list(observation.keys())[0]].shape[0]
else:
n_batch = observation.shape[0]
action = np.array([self.action_space.sample() for _ in range(n_batch)])
else:
action = np.array(self.action_space.sample())
else:
action, state = self.policy.predict(observation, state, episode_start, deterministic)
return action, state

def learn(
self: SelfIQN,
total_timesteps: int,
callback: MaybeCallback = None,
log_interval: int = 4,
tb_log_name: str = "IQN",
reset_num_timesteps: bool = True,
progress_bar: bool = False,
) -> SelfIQN:

return super().learn(
total_timesteps=total_timesteps,
callback=callback,
log_interval=log_interval,
tb_log_name=tb_log_name,
reset_num_timesteps=reset_num_timesteps,
progress_bar=progress_bar,
)

def _excluded_save_params(self) -> List[str]:
return super()._excluded_save_params() + ["quantile_net", "quantile_net_target"]

def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
state_dicts = ["policy", "policy.optimizer"]

return state_dicts, []