Conservative Q-Learning (CQL)
CQL is an extension of Q-learning that addresses the overestimation of values typical of offline RL algorithms, which is caused by the distributional shift between the dataset and the learned policy. CQL learns a conservative Q-function such that the expected value of a policy under this Q-function lower-bounds its true value.
CQL paper: https://arxiv.org/abs/2006.04779
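For discrete actions, the conservative term can be made concrete. The sketch below follows the CQL(H) objective from the paper for a DQN-style Q-network, up to a constant scaling of the TD term: a mean squared TD error plus a penalty that pushes down a log-sum-exp over all actions' Q-values and pushes up the Q-value of the action actually taken in the dataset. It is a plain-Python illustration on lists of floats, not AgileRL's CQN code; the function names and the penalty weight alpha are assumptions made for illustration.

```python
import math
from typing import List, Sequence


def logsumexp(values: Sequence[float]) -> float:
    """Numerically stable log(sum(exp(v) for v in values))."""
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cql_dqn_loss(
    q_values: List[List[float]],       # Q(s, .) from the online network, one row per sample
    actions: Sequence[int],            # actions a taken in the dataset
    rewards: Sequence[float],
    next_q_target: List[List[float]],  # Q_target(s', .) from the target network
    dones: Sequence[float],            # 1.0 if the episode terminated at this step, else 0.0
    gamma: float = 0.99,
    alpha: float = 1.0,                # hypothetical weight of the conservative penalty
) -> float:
    batch_size = len(actions)
    td_loss = 0.0
    cql_penalty = 0.0
    for q, a, r, q_next, d in zip(q_values, actions, rewards, next_q_target, dones):
        # One-step TD target, bootstrapping from the greedy target-network value
        target = r + gamma * (1.0 - d) * max(q_next)
        td_loss += (q[a] - target) ** 2
        # Conservative penalty: soft maximum over all actions minus the dataset action's value
        cql_penalty += logsumexp(q) - q[a]
    return td_loss / batch_size + alpha * cql_penalty / batch_size
```

Because logsumexp(q) is at least max(q), which is at least q[a], the penalty is never negative; it is smallest when actions that do not appear in the dataset have low Q-values relative to the dataset action, which is what keeps the learned values conservative.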
Can I use it?
| | Action | Observation |
|---|---|---|
| Discrete | ✔️ | ✔️ |
| Continuous | ❌ | ✔️ |
So far, we have implemented CQN (CQL applied to DQN), which cannot be used with continuous action spaces because DQN selects actions by taking the maximum Q-value over a finite set of actions. We will soon be adding CQL extensions of other algorithms for offline RL.
Example
import gymnasium as gym
import h5py
import numpy as np
from agilerl.algorithms.cqn import CQN
from agilerl.components.replay_buffer import ReplayBuffer
# Create environment and Experience Replay Buffer, and load dataset
env = gym.make('CartPole-v1')
try:
    state_dim = env.observation_space.n  # Discrete observation space
    one_hot = True  # Requires one-hot encoding
except AttributeError:
    state_dim = env.observation_space.shape  # Continuous observation space
    one_hot = False  # Does not require one-hot encoding
try:
    action_dim = env.action_space.n  # Discrete action space
except AttributeError:
    action_dim = env.action_space.shape[0]  # Continuous action space
channels_last = False  # Swap image channels dimension from last to first [H, W, C] -> [C, H, W]
if channels_last:
    state_dim = (state_dim[2], state_dim[0], state_dim[1])
field_names = ["state", "action", "reward", "next_state", "done"]
memory = ReplayBuffer(action_dim=action_dim, memory_size=10000, field_names=field_names)
dataset = h5py.File('data/cartpole/cartpole_random_v1.1.0.h5', 'r')  # Load dataset
# Save transitions to replay buffer
dataset_length = dataset['rewards'].shape[0]
for i in range(dataset_length - 1):
    state = dataset['observations'][i]
    next_state = dataset['observations'][i + 1]
    if channels_last:
        # Each stored observation is a single [H, W, C] frame, so move the last axis to the front
        state = np.moveaxis(state, -1, 0)
        next_state = np.moveaxis(next_state, -1, 0)
    action = dataset['actions'][i]
    reward = dataset['rewards'][i]
    done = bool(dataset['terminals'][i])
    memory.save2memory(state, action, reward, next_state, done)
agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot)  # Create CQN agent
# Offline training: learn only from the fixed dataset, without stepping the environment
while True:
    experiences = memory.sample(agent.batch_size)  # Sample replay buffer
    # Learn according to agent's RL algorithm
    agent.learn(experiences)
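The loading loop above pairs each observation with the one that follows it. A minimal plain-Python sketch of that pairing makes the resulting transition tuples explicit; the helper name is hypothetical, and plain lists stand in for the HDF5 datasets.

```python
from typing import List, Sequence, Tuple

# (state, action, reward, next_state, done), matching field_names in the example
Transition = Tuple[object, int, float, object, bool]


def dataset_to_transitions(
    observations: Sequence[object],
    actions: Sequence[int],
    rewards: Sequence[float],
    terminals: Sequence[int],
) -> List[Transition]:
    """Hypothetical helper: build transitions by pairing step i with step i + 1."""
    transitions: List[Transition] = []
    for i in range(len(rewards) - 1):
        transitions.append(
            (observations[i], actions[i], rewards[i], observations[i + 1], bool(terminals[i]))
        )
    return transitions
```

When terminals[i] is true, observations[i + 1] is the first observation of the next episode. This pairing is harmless because a done transition does not bootstrap from its next state in the TD target.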
To configure the network architecture, pass a dict to the CQN net_config argument. For an MLP, this can be as simple as:
NET_CONFIG = {
    'arch': 'mlp',  # Network architecture
    'hidden_size': [32, 32]  # Network hidden size
}
Or for a CNN:
NET_CONFIG = {
    'arch': 'cnn',  # Network architecture
    'hidden_size': [128],  # Network hidden size
    'channel_size': [32, 32],  # CNN channel size
    'kernel_size': [8, 4],  # CNN kernel size
    'stride_size': [4, 2],  # CNN stride size
    'normalize': True  # Normalize image from range [0,255] to [0,1]
}
agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot, net_config=NET_CONFIG)  # Create CQN agent
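To see what the CNN kernel and stride settings imply, the spatial size after each convolution can be worked out with the usual formula for an unpadded convolution, floor((size - kernel) / stride) + 1. The sketch below assumes square inputs, no padding and no dilation (an assumption about how the layers are built, not something stated here), and the 84x84 input size is only an illustrative value.

```python
from typing import Sequence


def conv_output_size(input_size: int, kernel_sizes: Sequence[int], stride_sizes: Sequence[int]) -> int:
    """Spatial size after a stack of unpadded convolutions (assumption: padding=0, dilation=1)."""
    size = input_size
    for kernel, stride in zip(kernel_sizes, stride_sizes):
        size = (size - kernel) // stride + 1
        if size <= 0:
            raise ValueError("input too small for this kernel/stride configuration")
    return size


# Using kernel_size [8, 4] and stride_size [4, 2] from the CNN config
print(conv_output_size(84, [8, 4], [4, 2]))  # 84 -> 20 -> 9, prints 9
```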
Saving and loading agents
To save an agent, use the saveCheckpoint method:
from agilerl.algorithms.cqn import CQN
agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot) # Create CQN agent
checkpoint_path = "path/to/checkpoint"
agent.saveCheckpoint(checkpoint_path)
To load a saved agent, use the load method:
from agilerl.algorithms.cqn import CQN
checkpoint_path = "path/to/checkpoint"
agent = CQN.load(checkpoint_path)
Parameters
- class agilerl.algorithms.cqn.CQN(state_dim, action_dim, one_hot, index=0, net_config={'arch': 'mlp', 'hidden_size': [64, 64]}, batch_size=64, lr=0.0001, learn_step=5, gamma=0.99, tau=0.001, mut=None, double=False, actor_network=None, device='cpu', accelerator=None, wrap=True)
The CQN algorithm class. CQL paper: https://arxiv.org/abs/2006.04779
- Parameters:
state_dim (list[int]) – State observation dimension
action_dim (int) – Action dimension
one_hot (bool) – One-hot encoding, used with discrete observation spaces
index (int, optional) – Index to keep track of object instance during tournament selection and mutation, defaults to 0
net_config (dict, optional) – Network configuration, defaults to mlp with hidden size [64,64]
batch_size (int, optional) – Size of batched sample from replay buffer for learning, defaults to 64
lr (float, optional) – Learning rate for optimizer, defaults to 1e-4
learn_step (int, optional) – Learning frequency, defaults to 5
gamma (float, optional) – Discount factor, defaults to 0.99
tau (float, optional) – For soft update of target network parameters, defaults to 1e-3
mut (str, optional) – Most recent mutation to agent, defaults to None
double (bool, optional) – Use double Q-learning, defaults to False
actor_network (nn.Module, optional) – Custom actor network, defaults to None
device (str, optional) – Device for accelerated computing, ‘cpu’ or ‘cuda’, defaults to ‘cpu’
accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None
wrap (bool, optional) – Wrap models for distributed training upon creation, defaults to True
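The double flag selects between two standard ways of forming the bootstrap target. In the textbook formulations sketched below (assumed, not taken from AgileRL's source), vanilla DQN takes the maximum of the target network's next-state Q-values, while double Q-learning picks the next action with the online network and evaluates it with the target network, which reduces overestimation bias.

```python
from typing import Sequence


def td_target(
    reward: float,
    done: bool,
    next_q_online: Sequence[float],  # Q_online(s', .)
    next_q_target: Sequence[float],  # Q_target(s', .)
    gamma: float = 0.99,
    double: bool = False,
) -> float:
    if double:
        # Double Q-learning: select with the online network, evaluate with the target network
        best_action = max(range(len(next_q_online)), key=lambda a: next_q_online[a])
        bootstrap = next_q_target[best_action]
    else:
        # Vanilla DQN: select and evaluate with the target network
        bootstrap = max(next_q_target)
    # No bootstrapping past a terminal transition
    return reward + gamma * (0.0 if done else bootstrap)
```

With next_q_online = [1.0, 0.0] and next_q_target = [0.5, 2.0], vanilla DQN bootstraps from 2.0, whereas double Q-learning bootstraps from 0.5, because the online network prefers action 0.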
- clone(index=None, wrap=True)
Returns cloned agent identical to self.
- Parameters:
index (int, optional) – Index to keep track of agent for tournament selection and mutation, defaults to None
wrap (bool, optional) – Wrap models for distributed training upon creation, defaults to True
- getAction(state, epsilon=0, action_mask=None)
Returns the next action to take in the environment. Epsilon is the probability of taking a random action, used for exploration. Set epsilon to 0 for purely greedy behaviour.
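The role of epsilon can be illustrated with a small plain-Python epsilon-greedy selector. It is a sketch, not AgileRL's getAction; in particular, the mask convention used here (1 = legal, 0 = illegal) is an assumption about action_mask.

```python
import random
from typing import Optional, Sequence


def epsilon_greedy(
    q_values: Sequence[float],
    epsilon: float = 0.0,
    action_mask: Optional[Sequence[int]] = None,
    rng: Optional[random.Random] = None,
) -> int:
    rng = rng or random.Random()
    # Assumed mask convention: 1 = legal action, 0 = illegal action
    legal = [a for a in range(len(q_values)) if action_mask is None or action_mask[a] == 1]
    if not legal:
        raise ValueError("action_mask must allow at least one action")
    if rng.random() < epsilon:
        return rng.choice(legal)  # Explore: uniform over legal actions
    return max(legal, key=lambda a: q_values[a])  # Exploit: greedy over legal actions
```

Since rng.random() returns values in [0, 1), epsilon=0 never explores and always returns the highest-valued legal action, while epsilon=1 always picks uniformly among legal actions.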
- learn(experiences)
Updates agent network parameters to learn from experiences.
- Parameters:
experiences – List of batched states, actions, rewards, next_states, dones in that order.
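The expected layout of experiences (one batch per field, in the order states, actions, rewards, next_states, dones) can be sketched by sampling stored transitions uniformly and transposing them. The helper below is hypothetical and only mirrors that layout; it is not AgileRL's ReplayBuffer.sample.

```python
import random
from typing import List, Sequence, Tuple


def sample_experiences(
    transitions: Sequence[Tuple], batch_size: int, rng: random.Random
) -> List[list]:
    """Uniformly sample transitions without replacement and regroup them by field."""
    batch = rng.sample(list(transitions), batch_size)
    # zip(*batch) turns [(s, a, r, s', d), ...] into (states, actions, rewards, next_states, dones)
    states, actions, rewards, next_states, dones = (list(field) for field in zip(*batch))
    return [states, actions, rewards, next_states, dones]
```

Transposing keeps each sampled transition aligned across fields: position k of every list belongs to the same transition.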
- classmethod load(path, device='cpu', accelerator=None)
Creates agent with properties and network weights loaded from path.
- Parameters:
path (string) – Location to load checkpoint from
device (str, optional) – Device for accelerated computing, ‘cpu’ or ‘cuda’, defaults to ‘cpu’
accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None
- loadCheckpoint(path)
Loads saved agent properties and network weights from checkpoint.
- Parameters:
path (string) – Location to load checkpoint from
- saveCheckpoint(path)
Saves a checkpoint of agent properties and network weights to path.
- Parameters:
path (string) – Location to save checkpoint at
- softUpdate()
Soft updates target network.
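A soft (Polyak) update moves each target-network parameter a fraction tau toward the corresponding online parameter, target ← tau * online + (1 - tau) * target, so with the default tau of 1e-3 the target network changes slowly. The sketch below assumes this standard form and uses flat lists of floats as stand-ins for network weights.

```python
from typing import List, Sequence


def soft_update(target: Sequence[float], online: Sequence[float], tau: float) -> List[float]:
    """Return updated target parameters: tau * online + (1 - tau) * target."""
    if len(target) != len(online):
        raise ValueError("parameter lists must have the same length")
    return [tau * o + (1.0 - tau) * t for t, o in zip(target, online)]
```

Setting tau to 1 copies the online parameters outright (a hard update), while tau of 0 leaves the target unchanged.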
- test(env, swap_channels=False, max_steps=500, loop=3)
Returns mean test score of agent in environment with epsilon-greedy policy.
- Parameters:
env (Gym-style environment) – The environment to be tested in
swap_channels (bool, optional) – Swap image channels dimension from last to first [H, W, C] -> [C, H, W], defaults to False
max_steps (int, optional) – Maximum number of testing steps, defaults to 500
loop (int, optional) – Number of testing loops/episodes to complete. The returned score is the mean. Defaults to 3
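The max_steps and loop parameters can be read as: run loop episodes, stop each after at most max_steps steps, and return the mean episode return. The plain-Python sketch below illustrates that procedure with a hypothetical toy environment that follows the Gymnasium reset/step return signature, and a fixed policy function in place of the agent; it is not AgileRL's test method.

```python
from typing import Callable, Tuple


class CountdownEnv:
    """Hypothetical toy env: reward 1.0 per step, terminates after `horizon` steps."""

    def __init__(self, horizon: int) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> Tuple[int, dict]:
        self.t = 0
        return self.t, {}

    def step(self, action: int) -> Tuple[int, float, bool, bool, dict]:
        self.t += 1
        terminated = self.t >= self.horizon
        return self.t, 1.0, terminated, False, {}


def mean_test_score(env, policy: Callable[[int], int], max_steps: int = 500, loop: int = 3) -> float:
    """Average undiscounted return over `loop` episodes, each capped at `max_steps` steps."""
    total = 0.0
    for _ in range(loop):
        state, _info = env.reset()
        score = 0.0
        for _step in range(max_steps):
            action = policy(state)  # e.g. a greedy action from the trained agent
            state, reward, terminated, truncated, _info = env.step(action)
            score += reward
            if terminated or truncated:
                break
        total += score
    return total / loop
```

With CountdownEnv(5), each episode scores 5.0; lowering max_steps to 3 truncates every episode, so the mean drops to 3.0.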