Rainbow DQN

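Rainbow DQN is an extension of DQN that combines several independent improvements to the original algorithm in a single agent. In the original paper, this combination achieved state-of-the-art performance on the Atari 2600 benchmark. These improvements include: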

  • Double DQN (DDQN): Reduces the overestimation bias of Q-values by decoupling action selection from action evaluation in the Q-learning target. The online network selects the greedy next action, and the target network evaluates it.

  • Prioritized Experience Replay: Instead of uniformly sampling from the replay buffer, it samples more important transitions more frequently based on the magnitude of their temporal difference (TD) error.

  • Dueling Networks: Splits the Q-network into two streams, one estimating the state value function and the other estimating the advantage of each action. The two streams are then combined to produce Q-values.

  • Multi-step Learning (n-step returns): Instead of bootstrapping after a single reward, the learning target sums the discounted rewards of the next n steps and then bootstraps from the value estimate n steps ahead.

  • Distributional RL: Instead of estimating the expected value of the cumulative future reward, it predicts the entire distribution of the cumulative future reward.

  • Noisy Nets: Adds learnable, parametric noise to the network weights, which drives exploration without the need for epsilon-greedy action selection.

  • Categorical DQN (C51): A specific form of distributional RL in which the range of possible returns, from v_min to v_max, is discretized into a fixed set of atoms (51 in the original paper, hence the name), and the network predicts a probability for each atom.
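To make three of these components concrete, here is a minimal scalar sketch in plain Python of the dueling aggregation, the n-step return, and the Double DQN target. It is an illustration under stated assumptions, not AgileRL's implementation: Rainbow applies these ideas to return distributions rather than scalar Q-values, and all function names below are hypothetical.

```python
from typing import List, Sequence


def dueling_q_values(value: float, advantages: Sequence[float]) -> List[float]:
    """Combine V(s) and A(s, a) with the mean-subtracted dueling aggregation.

    Q(s, a) = V(s) + A(s, a) - mean over a' of A(s, a')
    """
    mean_adv = sum(advantages) / len(advantages)
    return [value + adv - mean_adv for adv in advantages]


def n_step_return(rewards: Sequence[float], gamma: float) -> float:
    """Discounted sum of the next n rewards: sum_k gamma**k * r_{t+k}."""
    return sum((gamma ** k) * r for k, r in enumerate(rewards))


def double_dqn_n_step_target(
    rewards: Sequence[float],
    done: bool,
    next_q_online: Sequence[float],
    next_q_target: Sequence[float],
    gamma: float,
) -> float:
    """Scalar n-step Double DQN target for a single transition.

    The online network's Q-values select the greedy action in the state reached
    after n = len(rewards) steps; the target network's Q-values evaluate it.
    """
    g = n_step_return(rewards, gamma)
    if done:  # episode ended within the n steps, so there is nothing to bootstrap from
        return g
    best_action = max(range(len(next_q_online)), key=lambda a: next_q_online[a])
    return g + (gamma ** len(rewards)) * next_q_target[best_action]
```

For example, `double_dqn_n_step_target([1.0, 1.0, 1.0], False, [0.2, 0.9], [5.0, 2.0], 0.5)` returns 2.0: the online values pick action 1, so the target network's value of 2.0 is used even though its own maximum is 5.0.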

Rainbow DQN paper: https://arxiv.org/abs/1710.02298
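The Noisy Nets component can be sketched as a linear layer with factorised Gaussian noise, following the Noisy Networks scheme: each weight is mu + sigma * epsilon, where the noise matrix is the outer product of two small noise vectors passed through f(x) = sign(x) * sqrt(|x|). This is a forward-pass-only illustration in plain Python under that assumption; in practice mu and sigma are trained by gradient descent, and `NoisyLinear`, `reset_noise`, `scale_noise`, and `sigma0` are hypothetical names, not AgileRL's API.

```python
import math
import random
from typing import List


def scale_noise(x: float) -> float:
    """f(x) = sign(x) * sqrt(|x|), applied to factorised noise samples."""
    return math.copysign(math.sqrt(abs(x)), x)


class NoisyLinear:
    """Forward-pass-only linear layer with factorised Gaussian weight noise.

    Weights are mu_w + sigma_w * eps_w with eps_w[i][j] = eps_out[i] * eps_in[j],
    and biases are mu_b + sigma_b * eps_out. In a real network mu and sigma
    are learned; here they are only initialised.
    """

    def __init__(self, in_features: int, out_features: int, sigma0: float = 0.5, seed: int = 0) -> None:
        self.in_features = in_features
        self.out_features = out_features
        self.rng = random.Random(seed)
        bound = 1.0 / math.sqrt(in_features)
        self.mu_w = [[self.rng.uniform(-bound, bound) for _ in range(in_features)]
                     for _ in range(out_features)]
        self.mu_b = [self.rng.uniform(-bound, bound) for _ in range(out_features)]
        sigma_init = sigma0 / math.sqrt(in_features)
        self.sigma_w = [[sigma_init] * in_features for _ in range(out_features)]
        self.sigma_b = [sigma_init] * out_features
        self.reset_noise()

    def reset_noise(self) -> None:
        """Draw in_features + out_features Gaussians instead of one per weight."""
        self.eps_in = [scale_noise(self.rng.gauss(0.0, 1.0)) for _ in range(self.in_features)]
        self.eps_out = [scale_noise(self.rng.gauss(0.0, 1.0)) for _ in range(self.out_features)]

    def forward(self, x: List[float], noisy: bool = True) -> List[float]:
        """Compute W x + b, with or without the noise term."""
        outputs = []
        for i in range(self.out_features):
            acc = self.mu_b[i] + (self.sigma_b[i] * self.eps_out[i] if noisy else 0.0)
            for j in range(self.in_features):
                weight = self.mu_w[i][j]
                if noisy:
                    weight += self.sigma_w[i][j] * self.eps_out[i] * self.eps_in[j]
                acc += weight * x[j]
            outputs.append(acc)
        return outputs
```

Resampling the noise can change which action the network ranks highest, which is what takes the place of epsilon-greedy exploration; passing `noisy=False` gives the deterministic mean network.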

Can I use it?

import gymnasium as gym
import numpy as np
from agilerl.utils.utils import makeVectEnvs
from agilerl.components.replay_buffer import ReplayBuffer
from agilerl.algorithms.dqn_rainbow import RainbowDQN

# Create environment and Experience Replay Buffer
env = makeVectEnvs('LunarLander-v2', num_envs=1)
try:
    state_dim = env.single_observation_space.n          # Discrete observation space
    one_hot = True                                      # Requires one-hot encoding
except Exception:
    state_dim = env.single_observation_space.shape      # Continuous observation space
    one_hot = False                                     # Does not require one-hot encoding
action_dim = env.single_action_space.n                  # Rainbow DQN requires a discrete action space

channels_last = False # Swap image channels dimension from last to first [H, W, C] -> [C, H, W]

if channels_last:
    state_dim = (state_dim[2], state_dim[0], state_dim[1])

field_names = ["state", "action", "reward", "next_state", "done"]
memory = ReplayBuffer(action_dim=action_dim, memory_size=10000, field_names=field_names)

agent = RainbowDQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot)   # Create agent

state = env.reset()[0]  # Reset environment at start of episode
while True:
    if channels_last:
        state = np.moveaxis(state, [3], [1])
    action = agent.getAction(state)    # Get next action; noisy nets handle exploration, so no epsilon is needed
    next_state, reward, done, _, _ = env.step(action)   # Act in environment

    # Save experience to replay buffer
    if channels_last:
        memory.save2memoryVectEnvs(state, action, reward, np.moveaxis(next_state, [3], [1]), done)
    else:
        memory.save2memoryVectEnvs(state, action, reward, next_state, done)

    # Learn according to learning frequency
    if memory.counter % agent.learn_step == 0 and len(memory) >= agent.batch_size:
        experiences = memory.sample(agent.batch_size) # Sample replay buffer
        agent.learn(experiences)    # Learn according to agent's RL algorithm

    state = next_state  # Continue from the new observation

To configure the network architecture, pass a dict to the RainbowDQN net_config field. For an MLP, this can be as simple as:

NET_CONFIG = {
      'arch': 'mlp',      # Network architecture
      'hidden_size': [32, 32]  # Network hidden size
  }

Or for a CNN:

NET_CONFIG = {
      'arch': 'cnn',      # Network architecture
      'hidden_size': [128],    # Network hidden size
      'channel_size': [32, 32], # CNN channel size
      'kernel_size': [8, 4],   # CNN kernel size
      'stride_size': [4, 2],   # CNN stride size
      'normalize': True   # Normalize image from range [0,255] to [0,1]
  }

agent = RainbowDQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot, net_config=NET_CONFIG)   # Create agent
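When choosing kernel_size and stride_size for a CNN config, it helps to work out the spatial size each convolution produces. The sketch below assumes square inputs, no padding, and dilation 1, so each layer maps a side length `size` to `(size - kernel) // stride + 1`; whether AgileRL's CNN pads its convolutions is not stated here, so treat that as an assumption. The 84 x 84 input is a hypothetical Atari-style example, and the helper names are hypothetical.

```python
from typing import List, Sequence


def conv_output_sizes(input_size: int, kernel_size: Sequence[int], stride_size: Sequence[int]) -> List[int]:
    """Spatial side length after each conv layer, assuming no padding and dilation 1."""
    if len(kernel_size) != len(stride_size):
        raise ValueError("kernel_size and stride_size need one entry per conv layer")
    sizes = []
    size = input_size
    for kernel, stride in zip(kernel_size, stride_size):
        size = (size - kernel) // stride + 1
        if size <= 0:
            raise ValueError(f"kernel {kernel} with stride {stride} shrinks the input to {size}")
        sizes.append(size)
    return sizes


def flattened_features(
    input_size: int,
    channel_size: Sequence[int],
    kernel_size: Sequence[int],
    stride_size: Sequence[int],
) -> int:
    """Number of features fed to the first hidden layer after flattening."""
    side = conv_output_sizes(input_size, kernel_size, stride_size)[-1]
    return channel_size[-1] * side * side
```

With the CNN values above and an 84 x 84 input, the two layers produce 20 x 20 and then 9 x 9 feature maps, so 32 * 9 * 9 = 2592 features reach the hidden layer of size 128.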

Saving and loading agents

To save an agent, use the saveCheckpoint method:

from agilerl.algorithms.dqn_rainbow import RainbowDQN

agent = RainbowDQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot)   # Create Rainbow DQN agent

checkpoint_path = "path/to/checkpoint"
agent.saveCheckpoint(checkpoint_path)

To load a saved agent, use the load method:

from agilerl.algorithms.dqn_rainbow import RainbowDQN

checkpoint_path = "path/to/checkpoint"
agent = RainbowDQN.load(checkpoint_path)


class agilerl.algorithms.dqn_rainbow.RainbowDQN(state_dim, action_dim, one_hot, index=0, net_config={'arch': 'mlp', 'hidden_size': [64, 64]}, batch_size=64, lr=0.0001, learn_step=5, gamma=0.99, tau=0.001, beta=0.4, prior_eps=1e-06, num_atoms=51, v_min=-10, v_max=10, n_step=3, mut=None, actor_network=None, device='cpu', accelerator=None, wrap=True)

The Rainbow DQN algorithm class. Rainbow DQN paper: https://arxiv.org/abs/1710.02298

  • state_dim (list[int]) – State observation dimension

  • action_dim (int) – Action dimension

  • one_hot (bool) – One-hot encoding, used with discrete observation spaces

  • index (int, optional) – Index to keep track of object instance during tournament selection and mutation, defaults to 0

  • net_config (dict, optional) – Network configuration, defaults to mlp with hidden size [64,64]

  • batch_size (int, optional) – Size of batched sample from replay buffer for learning, defaults to 64

  • lr (float, optional) – Learning rate for optimizer, defaults to 1e-4

  • learn_step (int, optional) – Learning frequency, i.e. the number of environment steps between learning updates, defaults to 5

  • gamma (float, optional) – Discount factor, defaults to 0.99

  • tau (float, optional) – For soft update of target network parameters, defaults to 1e-3

  • beta (float, optional) – Importance sampling coefficient, defaults to 0.4

  • prior_eps (float, optional) – Minimum priority for sampling, defaults to 1e-6

  • num_atoms (int, optional) – Number of atoms in the discrete support of the return distribution, defaults to 51

  • v_min (float, optional) – Minimum value of support, defaults to -10

  • v_max (float, optional) – Maximum value of support, defaults to 10

  • n_step (int, optional) – Number of steps used to compute the n-step TD error, defaults to 3

  • mut (str, optional) – Most recent mutation to agent, defaults to None

  • actor_network (nn.Module, optional) – Custom actor network, defaults to None

  • device (str, optional) – Device for accelerated computing, ‘cpu’ or ‘cuda’, defaults to ‘cpu’

  • accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None

  • wrap (bool, optional) – Wrap models for distributed training upon creation, defaults to True
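The num_atoms, v_min, and v_max parameters define the fixed support of the categorical return distribution. The plain-Python sketch below builds that support and projects a Bellman-updated distribution back onto it, as in C51. It assumes a single transition stored as lists; with n-step learning, the reward would be the n-step return and gamma would be replaced by gamma ** n_step. The function names are hypothetical and this is not AgileRL's implementation.

```python
import math
from typing import List, Sequence


def make_support(num_atoms: int, v_min: float, v_max: float) -> List[float]:
    """Evenly spaced atoms z_i = v_min + i * delta_z, i = 0..num_atoms-1."""
    delta_z = (v_max - v_min) / (num_atoms - 1)
    return [v_min + i * delta_z for i in range(num_atoms)]


def project_distribution(
    next_probs: Sequence[float],
    reward: float,
    done: bool,
    gamma: float,
    num_atoms: int,
    v_min: float,
    v_max: float,
) -> List[float]:
    """Project the distribution of reward + gamma * Z(s') onto the fixed support."""
    support = make_support(num_atoms, v_min, v_max)
    delta_z = (v_max - v_min) / (num_atoms - 1)
    projected = [0.0] * num_atoms
    for prob, z_j in zip(next_probs, support):
        # Apply the Bellman update to each atom, then clip to [v_min, v_max].
        tz = reward + (0.0 if done else gamma * z_j)
        tz = min(max(tz, v_min), v_max)
        # Fractional position of tz on the support, clamped against rounding error.
        b = min(max((tz - v_min) / delta_z, 0.0), num_atoms - 1)
        lower, upper = math.floor(b), math.ceil(b)
        if lower == upper:
            projected[lower] += prob  # lands exactly on an atom
        else:
            # Split the probability mass between the two neighbouring atoms.
            projected[lower] += prob * (upper - b)
            projected[upper] += prob * (b - lower)
    return projected
```

With the signature defaults num_atoms=51, v_min=-10, and v_max=10, the atoms are spaced (10 - (-10)) / 50 = 0.4 apart.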

clone(index=None, wrap=True)

Returns cloned agent identical to self.


  • index (int, optional) – Index to keep track of agent for tournament selection and mutation, defaults to None

  • wrap (bool, optional) – Wrap models for distributed training upon creation, defaults to True

getAction(state, action_mask=None, training=True)

Returns the next action to take in the environment.

  • state (numpy.ndarray[float]) – State observation, or multiple observations in a batch

  • action_mask (numpy.ndarray, optional) – Mask of legal actions 1=legal 0=illegal, defaults to None
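The action_mask convention (1 = legal, 0 = illegal) is usually combined with greedy selection by giving illegal actions a value of negative infinity before taking the argmax. The sketch below shows that general technique on a plain list of Q-values; it is an assumption about how masking works in general, not a description of AgileRL's internals, and `masked_greedy_action` is a hypothetical helper.

```python
import math
from typing import Optional, Sequence


def masked_greedy_action(q_values: Sequence[float], action_mask: Optional[Sequence[int]] = None) -> int:
    """Return the index of the highest-valued legal action (mask: 1 = legal, 0 = illegal)."""
    if action_mask is None:
        action_mask = [1] * len(q_values)  # no mask: every action is legal
    if len(action_mask) != len(q_values):
        raise ValueError("action_mask needs one entry per action")
    if not any(action_mask):
        raise ValueError("action_mask marks every action as illegal")
    # Illegal actions get -inf so argmax can never choose them.
    masked = [q if legal else -math.inf for q, legal in zip(q_values, action_mask)]
    return max(range(len(masked)), key=lambda a: masked[a])
```

For Q-values [1.0, 5.0, 3.0], the unmasked choice is action 1; with the mask [1, 0, 1], action 1 is illegal and action 2 is chosen instead.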

learn(experiences, n_step=False, per=False)

Updates agent network parameters to learn from experiences.

  • experiences – List of batched states, actions, rewards, next_states, dones in that order.

  • n_step (bool, optional) – Use multi-step learning, defaults to False

  • per (bool, optional) – Use prioritized experience replay buffer, defaults to False
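When per is enabled, the beta and prior_eps parameters enter through the sampling probabilities, importance-sampling weights, and priority updates of prioritized experience replay. The plain-Python sketch below shows those three formulas under stated assumptions: `alpha` is a prioritization exponent that does not appear in the RainbowDQN signature, weights are normalized by the largest possible weight in the buffer (a common choice), and the helper names are hypothetical. The Rainbow paper prioritizes by the distributional (KL) loss rather than the scalar TD error; the TD-error version here matches the description at the top of this page.

```python
from typing import List, Sequence


def sampling_probabilities(priorities: Sequence[float], alpha: float) -> List[float]:
    """P(i) = p_i**alpha / sum_k p_k**alpha."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    return [s / total for s in scaled]


def importance_weights(probs: Sequence[float], indices: Sequence[int], beta: float) -> List[float]:
    """w_i = (N * P(i))**(-beta), divided by the largest possible weight so w_i <= 1."""
    n = len(probs)
    max_weight = (n * min(probs)) ** (-beta)
    return [((n * probs[i]) ** (-beta)) / max_weight for i in indices]


def new_priorities(td_errors: Sequence[float], prior_eps: float) -> List[float]:
    """p_i = |delta_i| + prior_eps, so no transition ends up with zero probability."""
    return [abs(delta) + prior_eps for delta in td_errors]
```

The weights scale each sampled transition's loss to correct the bias introduced by non-uniform sampling; a larger beta applies a stronger correction.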

classmethod load(path, device='cpu', accelerator=None)

Creates agent with properties and network weights loaded from path.

  • path (string) – Location to load checkpoint from

  • device (str, optional) – Device for accelerated computing, ‘cpu’ or ‘cuda’, defaults to ‘cpu’

  • accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None


loadCheckpoint(path)

Loads saved agent properties and network weights from checkpoint.

  • path (string) – Location to load checkpoint from


saveCheckpoint(path)

Saves a checkpoint of agent properties and network weights to path.

  • path (string) – Location to save checkpoint at


softUpdate()

Soft updates the target network parameters towards the online network, using the tau coefficient.
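A soft (Polyak) update moves each target parameter a fraction tau of the way towards the matching online parameter. The sketch below assumes parameters are stored as plain Python lists keyed by name, which is a simplification for illustration; a real implementation would iterate over network tensors, and `soft_update` is a hypothetical helper.

```python
from typing import Dict, List


def soft_update(target: Dict[str, List[float]], online: Dict[str, List[float]], tau: float) -> None:
    """Polyak averaging: theta_target <- tau * theta_online + (1 - tau) * theta_target.

    Updates `target` in place; a small tau makes the target network trail the
    online network slowly, which stabilises the bootstrapped learning target.
    """
    for name, online_params in online.items():
        target[name] = [tau * o + (1.0 - tau) * t for o, t in zip(online_params, target[name])]
```

With the default tau of 1e-3, a target parameter closes only 0.1% of its gap to the online parameter per update, so the target network changes slowly.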

test(env, swap_channels=False, max_steps=None, loop=3)

Returns the mean test score of the agent in the environment.

  • env (Gym-style environment) – The environment to be tested in

  • swap_channels (bool, optional) – Swap image channels dimension from last to first [H, W, C] -> [C, H, W], defaults to False

  • max_steps (int, optional) – Maximum number of testing steps, defaults to None

  • loop (int, optional) – Number of testing loops/episodes to complete. The returned score is the mean over these tests. Defaults to 3
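The max_steps and loop parameters can be illustrated with a minimal evaluation loop. This sketch assumes a single Gymnasium-style environment and a policy callable; it is not AgileRL's implementation (which works with the agent's own getAction), and `mean_test_score` and `CountdownEnv` are hypothetical names used only for illustration.

```python
from typing import Any, Callable, Optional


def mean_test_score(env: Any, policy: Callable[[Any], Any], max_steps: Optional[int] = None, loop: int = 3) -> float:
    """Run `loop` episodes and return the mean undiscounted episode score.

    Assumes a Gymnasium-style API: reset() -> (obs, info) and
    step(action) -> (obs, reward, terminated, truncated, info).
    """
    scores = []
    for _ in range(loop):
        state, _ = env.reset()
        score, steps = 0.0, 0
        while max_steps is None or steps < max_steps:
            action = policy(state)
            state, reward, terminated, truncated, _ = env.step(action)
            score += reward
            steps += 1
            if terminated or truncated:
                break
        scores.append(score)
    return sum(scores) / len(scores)


class CountdownEnv:
    """Hypothetical toy environment: reward 1.0 per step, terminates after 5 steps."""

    def reset(self):
        self.t = 0
        return self.t, {}

    def step(self, action):
        self.t += 1
        return self.t, 1.0, self.t >= 5, False, {}
```

Here `mean_test_score(CountdownEnv(), lambda s: 0)` returns 5.0, while `max_steps=3` caps every episode at three steps and gives 3.0.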