Conservative Q-Learning (CQL)
CQL is an extension of Q-learning for offline RL. It addresses the overestimation of values caused by the distributional shift between the dataset and the learned policy. CQL learns a conservative Q-function, such that the expected value of a policy under this Q-function lower-bounds its true value.
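To make the lower-bound idea concrete, here is a minimal NumPy sketch of the CQL(H) objective for discrete actions as described in the CQL paper: the usual TD loss plus a penalty, weighted by a coefficient alpha, equal to the log-sum-exp of Q over all actions minus Q on the dataset action. This is not AgileRL's CQN code; the function names and the alpha default are hypothetical.

```python
import numpy as np


def logsumexp(x: np.ndarray, axis: int = -1) -> np.ndarray:
    """Numerically stable log(sum(exp(x))) along one axis."""
    m = np.max(x, axis=axis, keepdims=True)
    return np.squeeze(m, axis=axis) + np.log(np.sum(np.exp(x - m), axis=axis))


def cql_loss(q_values, actions, td_targets, alpha=1.0):
    """CQL(H) loss for discrete actions (illustrative sketch).

    q_values:   (batch, n_actions) array of Q(s, .) from the online network
    actions:    (batch,) integer actions taken in the dataset
    td_targets: (batch,) Bellman targets r + gamma * (1 - done) * max_a' Q_target(s', a')
    alpha:      weight of the conservative penalty (hypothetical default)
    """
    # Q(s, a) for the actions actually present in the dataset
    q_data = q_values[np.arange(len(actions)), actions]
    # Conservative penalty: soft-max of Q over all actions minus Q on dataset actions
    penalty = np.mean(logsumexp(q_values, axis=1) - q_data)
    # Standard squared TD error on dataset actions
    td_loss = 0.5 * np.mean((q_data - td_targets) ** 2)
    return td_loss + alpha * penalty, penalty


q = np.array([[1.0, 2.0], [0.5, 0.5]])
loss, penalty = cql_loss(q, np.array([1, 0]), np.array([2.0, 0.5]))
print(loss, penalty)
```

Because log-sum-exp is always at least the maximum Q-value, the penalty is non-negative. Minimising it pushes Q down on actions the dataset does not support while keeping Q up on the actions it does contain, which is what makes the learned estimate conservative.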
CQL paper: https://arxiv.org/abs/2006.04779
Can I use it?
|            | Action | Observation |
|------------|--------|-------------|
| Discrete   | ✔️     | ✔️          |
| Continuous | ❌     | ✔️          |
So far, we have implemented CQN (CQL applied to DQN), which cannot be used with continuous action spaces. We will soon add CQL extensions of other algorithms for offline RL.
Example
import gymnasium as gym
import h5py
import numpy as np
from agilerl.components.replay_buffer import ReplayBuffer
from agilerl.algorithms.cqn import CQN

# Create environment and Experience Replay Buffer, and load dataset
env = gym.make('CartPole-v1')
try:
    state_dim = env.observation_space.n        # Discrete observation space
    one_hot = True                             # Requires one-hot encoding
except Exception:
    state_dim = env.observation_space.shape    # Continuous observation space
    one_hot = False                            # Does not require one-hot encoding
try:
    action_dim = env.action_space.n            # Discrete action space
except Exception:
    action_dim = env.action_space.shape[0]     # Continuous action space

channels_last = False  # Swap image channels dimension from last to first [H, W, C] -> [C, H, W]
if channels_last:
    state_dim = (state_dim[2], state_dim[0], state_dim[1])

field_names = ["state", "action", "reward", "next_state", "done"]
memory = ReplayBuffer(action_dim=action_dim, memory_size=10000, field_names=field_names)
dataset = h5py.File('data/cartpole/cartpole_random_v1.1.0.h5', 'r')  # Load dataset

# Save transitions to replay buffer
dataset_length = dataset['rewards'].shape[0]
for i in range(dataset_length - 1):
    state = dataset['observations'][i]
    next_state = dataset['observations'][i + 1]
    if channels_last:
        # Move the channels axis of a single image from last to first: [H, W, C] -> [C, H, W]
        state = np.moveaxis(state, [-1], [-3])
        next_state = np.moveaxis(next_state, [-1], [-3])
    action = dataset['actions'][i]
    reward = dataset['rewards'][i]
    done = bool(dataset['terminals'][i])
    memory.save2memory(state, action, reward, next_state, done)

agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot)  # Create CQN agent

# Learn purely from the offline dataset; no environment interaction is needed
for _ in range(10000):  # Number of learning iterations (illustrative value)
    experiences = memory.sample(agent.batch_size)  # Sample replay buffer
    agent.learn(experiences)  # Learn according to agent's RL algorithm
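The transition loop in the example can be sketched as a standalone NumPy function, which makes two details explicit: N stored observations yield N - 1 transitions, and the channel swap targets the last and third-from-last axes so that it works on a single [H, W, C] image. The helper name dataset_to_transitions is hypothetical and not part of AgileRL, and the sketch assumes episodes are stored contiguously in the dataset.

```python
import numpy as np


def dataset_to_transitions(observations, actions, rewards, terminals, channels_last=False):
    """Pair consecutive observations into (state, action, reward, next_state, done) tuples.

    The final observation has no successor, so N observations give N - 1 transitions.
    """
    transitions = []
    for i in range(len(rewards) - 1):
        state, next_state = observations[i], observations[i + 1]
        if channels_last:
            # [H, W, C] -> [C, H, W]; also maps batched [N, H, W, C] -> [N, C, H, W]
            state = np.moveaxis(state, -1, -3)
            next_state = np.moveaxis(next_state, -1, -3)
        done = bool(terminals[i])
        transitions.append((state, actions[i], rewards[i], next_state, done))
    return transitions


obs = np.zeros((4, 2, 3, 5))  # 4 images with H=2, W=3, C=5
trans = dataset_to_transitions(
    obs, np.array([0, 1, 0, 1]), np.ones(4), np.array([0, 0, 1, 0]), channels_last=True
)
print(len(trans), trans[0][0].shape)
```

When done is True, next_state is really the first observation of the following episode. This is harmless as long as the learning target multiplies the bootstrap term by (1 - done).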
To configure the network architecture, pass a dict to the CQN net_config argument. For an MLP, this can be as simple as:
NET_CONFIG = {
    'arch': 'mlp',       # Network architecture
    'h_size': [32, 32]   # Network hidden size
}
Or for a CNN:
NET_CONFIG = {
    'arch': 'cnn',       # Network architecture
    'h_size': [128],     # Network hidden size
    'c_size': [32, 32],  # CNN channel size
    'k_size': [8, 4],    # CNN kernel size
    's_size': [4, 2],    # CNN stride size
    'normalize': True    # Normalize image from range [0,255] to [0,1]
}
agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot, net_config=NET_CONFIG) # Create CQN agent
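When choosing k_size and s_size for a CNN config, it helps to check the spatial size each convolution produces. The sketch below assumes unpadded convolutions, out = floor((n + 2p - k) / s) + 1 with p = 0, and a hypothetical 84 × 84 input; AgileRL's own layer construction may use different padding, so treat the numbers as a sanity check rather than a guarantee.

```python
def conv_output_size(size, kernel_sizes, strides, padding=0):
    """Spatial side length after a stack of conv layers (square inputs assumed)."""
    for k, s in zip(kernel_sizes, strides):
        size = (size + 2 * padding - k) // s + 1
    return size


# Same channel, kernel and stride values as the CNN config above
cnn_config = {'c_size': [32, 32], 'k_size': [8, 4], 's_size': [4, 2]}

side = conv_output_size(84, cnn_config['k_size'], cnn_config['s_size'])
flat_features = cnn_config['c_size'][-1] * side * side
print(side, flat_features)
```

Under these assumptions, an 84 × 84 image shrinks to 20 × 20 and then 9 × 9, so 32 × 9 × 9 = 2592 features would feed the 128-unit hidden layer.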
Saving and loading agents
To save an agent, use the saveCheckpoint method:
from agilerl.algorithms.cqn import CQN
agent = CQN(state_dim=state_dim, action_dim=action_dim, one_hot=one_hot) # Create CQN agent
checkpoint_path = "path/to/checkpoint"
agent.saveCheckpoint(checkpoint_path)
To load a saved agent, use the load method:
from agilerl.algorithms.cqn import CQN
checkpoint_path = "path/to/checkpoint"
agent = CQN.load(checkpoint_path)
Parameters
- class agilerl.algorithms.cqn.CQN(state_dim, action_dim, one_hot, index=0, net_config={'arch': 'mlp', 'h_size': [64, 64]}, batch_size=64, lr=0.0001, learn_step=5, gamma=0.99, tau=0.001, mut=None, double=False, actor_network=None, device='cpu', accelerator=None, wrap=True)
The CQN algorithm class. CQN paper: https://arxiv.org/abs/2006.04779
- Parameters:
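state_dim (int or tuple[int]) – State observation dimension: the number of possible observations for a discrete observation space, or the observation shape for a continuous one
action_dim (int) – Action dimension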
one_hot (bool) – One-hot encoding, used with discrete observation spaces
index (int, optional) – Index to keep track of object instance during tournament selection and mutation, defaults to 0
net_config (dict, optional) – Network configuration, defaults to mlp with hidden size [64,64]
batch_size (int, optional) – Size of batched sample from replay buffer for learning, defaults to 64
lr (float, optional) – Learning rate for optimizer, defaults to 1e-4
learn_step (int, optional) – Learning frequency, defaults to 5
gamma (float, optional) – Discount factor, defaults to 0.99
tau (float, optional) – For soft update of target network parameters, defaults to 1e-3
mut (str, optional) – Most recent mutation to agent, defaults to None
double (bool, optional) – Use double Q-learning, defaults to False
actor_network (nn.Module, optional) – Custom actor network, defaults to None
device (str, optional) – Device for accelerated computing, ‘cpu’ or ‘cuda’, defaults to ‘cpu’
accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None
wrap (bool, optional) – Wrap models for distributed training upon creation, defaults to True
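The gamma and double parameters both enter the one-step Q-learning target. The sketch below shows the textbook DQN target and the double Q-learning variant side by side in NumPy; it illustrates what the double flag is meant to change and is not taken from AgileRL's source. The function name q_targets is hypothetical.

```python
import numpy as np


def q_targets(rewards, dones, next_q_target, next_q_online=None, gamma=0.99):
    """One-step targets y = r + gamma * (1 - done) * Q_target(s', a*).

    Standard DQN (double=False): a* = argmax_a Q_target(s', a).
    Double Q-learning (double=True): a* = argmax_a Q_online(s', a), evaluated by Q_target.
    """
    if next_q_online is None:
        next_values = next_q_target.max(axis=1)
    else:
        best_actions = next_q_online.argmax(axis=1)
        next_values = next_q_target[np.arange(len(best_actions)), best_actions]
    # Terminal transitions (done = 1) do not bootstrap from the next state
    return rewards + gamma * (1.0 - dones) * next_values


rewards = np.array([1.0, 1.0])
dones = np.array([0.0, 1.0])
next_q_target = np.array([[1.0, 3.0], [5.0, 0.0]])
next_q_online = np.array([[2.0, 0.0], [0.0, 1.0]])

standard = q_targets(rewards, dones, next_q_target, gamma=0.5)
double = q_targets(rewards, dones, next_q_target, next_q_online, gamma=0.5)
print(standard, double)
```

Double Q-learning decouples choosing the next action (online network) from evaluating it (target network), which reduces the upward bias of taking a max over noisy estimates.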
- clone(index=None, wrap=True)
Returns cloned agent identical to self.
- Parameters:
index (int, optional) – Index to keep track of agent for tournament selection and mutation, defaults to None
wrap (bool, optional) – Wrap models for distributed training upon creation, defaults to True
- getAction(state, epsilon=0, action_mask=None)
Returns the next action to take in the environment. Epsilon is the probability of taking a random action, used for exploration. For purely greedy behaviour, set epsilon to 0.
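The epsilon and action_mask arguments of getAction can be illustrated with a small NumPy sketch of masked epsilon-greedy selection for a single state. The mask convention (1 = legal, 0 = illegal) and the function name epsilon_greedy are assumptions for illustration, not a description of AgileRL's implementation.

```python
import numpy as np


def epsilon_greedy(q_values, epsilon=0.0, action_mask=None, rng=None):
    """Random legal action with probability epsilon, otherwise the legal argmax of Q.

    action_mask (assumed convention): 1 = legal, 0 = illegal.
    """
    rng = np.random.default_rng() if rng is None else rng
    q_values = np.asarray(q_values, dtype=float)
    if action_mask is None:
        legal = np.ones_like(q_values, dtype=bool)
    else:
        legal = np.asarray(action_mask, dtype=bool)
    if rng.random() < epsilon:
        # Explore: sample uniformly among legal actions only
        return int(rng.choice(np.flatnonzero(legal)))
    # Exploit: illegal actions get -inf so they can never be the argmax
    masked_q = np.where(legal, q_values, -np.inf)
    return int(np.argmax(masked_q))


print(epsilon_greedy([1.0, 5.0, 3.0]))                          # greedy
print(epsilon_greedy([1.0, 5.0, 3.0], action_mask=[1, 0, 1]))   # best legal action
```

With epsilon = 0 the random branch is never taken, so the choice is deterministic; with epsilon = 1 every action is sampled uniformly from the legal set.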
- learn(experiences)
Updates agent network parameters to learn from experiences.
- Parameters:
experiences – List of batched states, actions, rewards, next_states, dones in that order.
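The learn method expects experiences as batched fields in a fixed order. As a hypothetical illustration of that layout (AgileRL's ReplayBuffer.sample may return a different container or tensor type), the sketch below uniformly samples stored (state, action, reward, next_state, done) tuples and stacks each field into a NumPy array. The name sample_batch is an assumption.

```python
import numpy as np


def sample_batch(transitions, batch_size, rng=None):
    """Uniformly sample transitions without replacement and stack each field.

    Returns (states, actions, rewards, next_states, dones) in that order.
    """
    rng = np.random.default_rng() if rng is None else rng
    idx = rng.choice(len(transitions), size=batch_size, replace=False)
    # zip(*) regroups the sampled tuples field by field
    fields = zip(*(transitions[i] for i in idx))
    return tuple(np.stack(field) for field in fields)


transitions = [
    (np.zeros(4) + i, i % 2, float(i), np.zeros(4) + i + 1, i == 4) for i in range(5)
]
states, actions, rewards, next_states, dones = sample_batch(
    transitions, 3, np.random.default_rng(0)
)
print(states.shape, rewards.shape, dones.dtype)
```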
- classmethod load(path, device='cpu', accelerator=None)
Creates agent with properties and network weights loaded from path.
- Parameters:
path (string) – Location to load checkpoint from
device (str, optional) – Device for accelerated computing, ‘cpu’ or ‘cuda’, defaults to ‘cpu’
accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None
- loadCheckpoint(path)
Loads saved agent properties and network weights from checkpoint.
- Parameters:
path (string) – Location to load checkpoint from
- saveCheckpoint(path)
Saves a checkpoint of agent properties and network weights to path.
- Parameters:
path (string) – Location to save checkpoint at
- softUpdate()
Soft-updates the target network parameters towards the online network parameters, at a rate set by tau.
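Soft updates of this kind are commonly implemented as Polyak averaging of the target parameters towards the online parameters. The NumPy sketch below assumes that standard form and uses plain arrays in place of network weights; soft_update is a hypothetical name, not AgileRL's method.

```python
import numpy as np


def soft_update(target_params, online_params, tau=1e-3):
    """Polyak averaging: theta_target <- tau * theta_online + (1 - tau) * theta_target."""
    return [tau * w + (1.0 - tau) * w_t for w_t, w in zip(target_params, online_params)]


target = [np.zeros(2)]
online = [np.ones(2)]
for _ in range(10):
    target = soft_update(target, online, tau=0.5)
print(target[0])  # approaches the online weights geometrically
```

With the default tau = 1e-3, the target moves 0.1% of the way towards the online weights on each update, so it changes slowly and keeps the bootstrap targets stable.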
- test(env, swap_channels=False, max_steps=500, loop=3)
Returns mean test score of agent in environment with epsilon-greedy policy.
- Parameters:
env (Gym-style environment) – The environment to be tested in
swap_channels (bool, optional) – Swap image channels dimension from last to first [H, W, C] -> [C, H, W], defaults to False
max_steps (int, optional) – Maximum number of testing steps, defaults to 500
loop (int, optional) – Number of testing loops/episodes to complete; the returned score is the mean over them, defaults to 3
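The max_steps and loop parameters of test can be read as: run loop episodes, cap each at max_steps steps, and average the episode returns. The sketch below assumes the Gymnasium reset/step return signatures and uses a hypothetical toy environment (CountdownEnv) and a fixed policy in place of the agent; it illustrates the averaging and is not AgileRL's test method.

```python
import numpy as np


class CountdownEnv:
    """Hypothetical toy env: reward 1 per step, terminates after `length` steps."""

    def __init__(self, length=10):
        self.length = length
        self.t = 0

    def reset(self):
        self.t = 0
        return np.array([0.0]), {}

    def step(self, action):
        self.t += 1
        terminated = self.t >= self.length
        return np.array([float(self.t)]), 1.0, terminated, False, {}


def mean_test_score(env, policy, max_steps=500, loop=3):
    """Run `loop` episodes of at most `max_steps` steps and return the mean return."""
    scores = []
    for episode in range(loop):
        state, info = env.reset()
        score = 0.0
        for step in range(max_steps):
            state, reward, terminated, truncated, info = env.step(policy(state))
            score += reward
            if terminated or truncated:
                break
        scores.append(score)
    return float(np.mean(scores))


print(mean_test_score(CountdownEnv(10), lambda s: 0))               # full episodes
print(mean_test_score(CountdownEnv(10), lambda s: 0, max_steps=4))  # capped episodes
```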