Training
If you are using a Gym-style environment, it is easiest to use our online training functions, which return a population of trained agents and logged training metrics: train_off_policy for off-policy algorithms and train_on_policy for on-policy algorithms.
If you are training on static, offline data, you can use our offline RL training function, train_offline.
The multi-agent training function, train_multi_agent, handles PettingZoo-style environments and multi-agent algorithms.
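The choice among the four functions documented below can be summarised as a small lookup. This is a minimal sketch: `choose_training_function` is a hypothetical helper that only returns the qualified names listed on this page; it does not import or call AgileRL.

```python
def choose_training_function(offline: bool, multi_agent: bool, on_policy: bool) -> str:
    """Return the qualified name of the training function that fits the setting.

    Hypothetical helper for illustration only; it maps flags to documented names.
    """
    if offline:
        # Learning from a static dataset rather than from live interaction
        return "agilerl.training.train_offline.train_offline"
    if multi_agent:
        # PettingZoo-style environment with a multi-agent algorithm
        return "agilerl.training.train_multi_agent.train_multi_agent"
    if on_policy:
        # Online, single-agent, on-policy algorithm (no replay buffer argument)
        return "agilerl.training.train_on_policy.train_on_policy"
    # Online, single-agent, off-policy algorithm (takes a replay buffer)
    return "agilerl.training.train_off_policy.train_off_policy"
```

For example, `choose_training_function(offline=False, multi_agent=False, on_policy=True)` returns the on-policy entry point.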
- agilerl.training.train_off_policy.train_off_policy(env, env_name, algo, pop, memory, INIT_HP=None, MUT_P=None, swap_channels=False, n_episodes=2000, max_steps=500, evo_epochs=5, evo_loop=1, eps_start=1.0, eps_end=0.1, eps_decay=0.995, target=200.0, n_step=False, per=False, noisy=False, n_step_memory=None, tournament=None, mutation=None, checkpoint=None, checkpoint_path=None, save_elite=False, elite_path=None, wb=False, verbose=True, accelerator=None, wandb_api_key=None)
The general online off-policy RL training function. Returns the trained population of agents and their fitnesses.
- Parameters:
env (Gym-style environment) – The environment to train in. Can be vectorized.
env_name (str) – Environment name
algo (str) – RL algorithm name
pop (list[object]) – Population of agents
memory (object) – Experience Replay Buffer
INIT_HP (dict, optional) – Dictionary containing initial hyperparameters, defaults to None
MUT_P (dict, optional) – Dictionary containing mutation parameters, defaults to None
swap_channels (bool, optional) – Swap image channels dimension from last to first [H, W, C] -> [C, H, W], defaults to False
n_episodes (int, optional) – Maximum number of training episodes, defaults to 2000
max_steps (int, optional) – Maximum number of steps in environment per episode, defaults to 500
evo_epochs (int, optional) – Evolution frequency (episodes), defaults to 5
evo_loop (int, optional) – Number of evaluation episodes, defaults to 1
eps_start (float, optional) – Initial (maximum) exploration epsilon value, defaults to 1.0
eps_end (float, optional) – Final (minimum) exploration epsilon value, defaults to 0.1
eps_decay (float, optional) – Epsilon decay per episode, defaults to 0.995
target (float, optional) – Target score for early stopping, defaults to 200.0
n_step (bool, optional) – Use a multi-step experience replay buffer, defaults to False
per (bool, optional) – Use a prioritized experience replay buffer, defaults to False
noisy (bool, optional) – Use noisy network exploration, defaults to False
n_step_memory (object, optional) – Multi-step Experience Replay Buffer to be used alongside the Prioritized ERB, defaults to None
tournament (object, optional) – Tournament selection object, defaults to None
mutation (object, optional) – Mutation object, defaults to None
checkpoint (int, optional) – Checkpoint frequency (episodes), defaults to None
checkpoint_path (str, optional) – Location to save checkpoint, defaults to None
save_elite (bool, optional) – Boolean flag indicating whether to save elite member at the end of training, defaults to False
elite_path (str, optional) – Location to save elite agent, defaults to None
wb (bool, optional) – Weights & Biases tracking, defaults to False
verbose (bool, optional) – Display training stats, defaults to True
accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None
wandb_api_key (str, optional) – API key for Weights & Biases, defaults to None
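The three epsilon parameters describe an epsilon-greedy exploration schedule. The sketch below assumes epsilon is multiplied by eps_decay once per episode and clamped at eps_end, which is what the parameter descriptions suggest; the helper names are hypothetical and not part of AgileRL.

```python
def epsilon_schedule(eps_start: float = 1.0, eps_end: float = 0.1,
                     eps_decay: float = 0.995, n_episodes: int = 2000) -> list[float]:
    """Epsilon in effect at the start of each episode (assumed decay rule)."""
    eps = eps_start
    schedule = []
    for _ in range(n_episodes):
        schedule.append(eps)
        # Decay once per episode, never dropping below the exploration floor
        eps = max(eps_end, eps * eps_decay)
    return schedule


def episodes_until_floor(eps_start: float = 1.0, eps_end: float = 0.1,
                         eps_decay: float = 0.995) -> int:
    """Number of per-episode decay steps before epsilon reaches eps_end."""
    eps, steps = eps_start, 0
    while eps > eps_end:
        eps = max(eps_end, eps * eps_decay)
        steps += 1
    return steps
```

Under this assumption, the defaults bring epsilon to its 0.1 floor after 460 episodes, so the rest of a 2000-episode run keeps taking random actions 10% of the time.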
- agilerl.training.train_offline.train_offline(env, env_name, dataset, algo, pop, memory, INIT_HP=None, MUT_P=None, swap_channels=False, n_episodes=2000, max_steps=500, evo_epochs=5, evo_loop=1, target=200.0, tournament=None, mutation=None, checkpoint=None, checkpoint_path=None, save_elite=False, elite_path=None, wb=False, verbose=True, accelerator=None, minari_dataset_id=None, remote=False, wandb_api_key=None)
The general offline RL training function. Returns the trained population of agents and their fitnesses.
- Parameters:
env (Gym-style environment) – The environment to train in
env_name (str) – Environment name
dataset (h5py-style dataset) – Offline RL dataset
algo (str) – RL algorithm name
pop (list[object]) – Population of agents
memory (object) – Experience Replay Buffer
INIT_HP (dict, optional) – Dictionary containing initial hyperparameters, defaults to None
MUT_P (dict, optional) – Dictionary containing mutation parameters, defaults to None
swap_channels (bool, optional) – Swap image channels dimension from last to first [H, W, C] -> [C, H, W], defaults to False
n_episodes (int, optional) – Maximum number of training episodes, defaults to 2000
max_steps (int, optional) – Maximum number of steps in environment per episode, defaults to 500
evo_epochs (int, optional) – Evolution frequency (episodes), defaults to 5
evo_loop (int, optional) – Number of evaluation episodes, defaults to 1
target (float, optional) – Target score for early stopping, defaults to 200.0
tournament (object, optional) – Tournament selection object, defaults to None
mutation (object, optional) – Mutation object, defaults to None
checkpoint (int, optional) – Checkpoint frequency (episodes), defaults to None
checkpoint_path (str, optional) – Location to save checkpoint, defaults to None
save_elite (bool, optional) – Boolean flag indicating whether to save elite member at the end of training, defaults to False
elite_path (str, optional) – Location to save elite agent, defaults to None
wb (bool, optional) – Weights & Biases tracking, defaults to False
verbose (bool, optional) – Display training stats, defaults to True
accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None
wandb_api_key (str, optional) – API key for Weights & Biases, defaults to None
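Before offline training can sample from memory, the dataset has to be turned into transitions. A minimal sketch, assuming a D4RL-style layout in which "observations", "actions", "rewards" and "terminals" are index-aligned and the next observation is the following row; the key names and the helper `dataset_to_transitions` are assumptions, not AgileRL's API. Plain lists stand in for h5py datasets, which support the same indexing.

```python
from typing import Any, Iterator, Mapping, Sequence, Tuple

# (observation, action, reward, next_observation, done)
Transition = Tuple[Any, Any, float, Any, bool]


def dataset_to_transitions(dataset: Mapping[str, Sequence]) -> Iterator[Transition]:
    """Yield transitions from an index-aligned offline dataset (assumed layout)."""
    obs = dataset["observations"]
    actions = dataset["actions"]
    rewards = dataset["rewards"]
    terminals = dataset["terminals"]
    # The final row has no successor observation, so it is skipped.
    # When terminals[i] is True, obs[i + 1] starts a new episode; the done
    # flag lets the learner ignore that bootstrap target.
    for i in range(len(obs) - 1):
        yield obs[i], actions[i], float(rewards[i]), obs[i + 1], bool(terminals[i])
```

Each yielded tuple could then be written to the replay buffer passed as memory; how that buffer stores transitions is outside this sketch.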
- agilerl.training.train_on_policy.train_on_policy(env, env_name, algo, pop, INIT_HP=None, MUT_P=None, swap_channels=False, n_episodes=2000, max_steps=500, evo_epochs=5, evo_loop=1, target=200.0, tournament=None, mutation=None, checkpoint=None, checkpoint_path=None, save_elite=False, elite_path=None, wb=False, verbose=True, accelerator=None, wandb_api_key=None)
The general on-policy RL training function. Returns the trained population of agents and their fitnesses.
- Parameters:
env (Gym-style environment) – The environment to train in. Can be vectorized.
env_name (str) – Environment name
algo (str) – RL algorithm name
pop (list[object]) – Population of agents
INIT_HP (dict, optional) – Dictionary containing initial hyperparameters, defaults to None
MUT_P (dict, optional) – Dictionary containing mutation parameters, defaults to None
swap_channels (bool, optional) – Swap image channels dimension from last to first [H, W, C] -> [C, H, W], defaults to False
n_episodes (int, optional) – Maximum number of training episodes, defaults to 2000
max_steps (int, optional) – Maximum number of steps in environment per episode, defaults to 500
evo_epochs (int, optional) – Evolution frequency (episodes), defaults to 5
evo_loop (int, optional) – Number of evaluation episodes, defaults to 1
target (float, optional) – Target score for early stopping, defaults to 200.0
tournament (object, optional) – Tournament selection object, defaults to None
mutation (object, optional) – Mutation object, defaults to None
checkpoint (int, optional) – Checkpoint frequency (episodes), defaults to None
checkpoint_path (str, optional) – Location to save checkpoint, defaults to None
save_elite (bool, optional) – Boolean flag indicating whether to save elite member at the end of training, defaults to False
elite_path (str, optional) – Location to save elite agent, defaults to None
wb (bool, optional) – Weights & Biases tracking, defaults to False
verbose (bool, optional) – Display training stats, defaults to True
accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None
wandb_api_key (str, optional) – API key for Weights & Biases, defaults to None
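The evo_epochs, evo_loop, target and checkpoint parameters, shared by all four functions, combine into a schedule over episodes. The sketch below simulates that schedule without any RL. `evaluate(agent_idx, episode)` is a hypothetical stand-in for one evaluation episode, each agent's fitness is its mean over evo_loop evaluations, and the sketch assumes training stops early once the best agent's fitness reaches target. AgileRL's exact stopping rule, and where it runs tournament selection and mutation, may differ.

```python
from statistics import mean
from typing import Callable, Optional


def training_schedule(
    n_episodes: int,
    evo_epochs: int,
    evo_loop: int,
    target: float,
    evaluate: Callable[[int, int], float],
    pop_size: int,
    checkpoint: Optional[int] = None,
) -> dict:
    """Simulate when evolution, checkpointing and early stopping would happen."""
    evolutions, checkpoints = [], []
    for episode in range(1, n_episodes + 1):
        # A real loop would train every agent in the population here.
        if checkpoint is not None and episode % checkpoint == 0:
            checkpoints.append(episode)
        if episode % evo_epochs == 0:
            # Fitness of each agent: mean score over evo_loop evaluation episodes.
            fitnesses = [
                mean(evaluate(i, episode) for _ in range(evo_loop))
                for i in range(pop_size)
            ]
            evolutions.append(episode)
            # Assumed early-stopping rule: the best agent reaches the target.
            if max(fitnesses) >= target:
                return {"stopped_at": episode, "evolutions": evolutions,
                        "checkpoints": checkpoints}
            # Tournament selection and mutation would build the next population here.
    return {"stopped_at": None, "evolutions": evolutions, "checkpoints": checkpoints}
```

With evo_epochs=5, an evaluation round happens on episodes 5, 10, 15 and so on, so early stopping can only trigger on those episodes even if an agent crossed the target in between.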
- agilerl.training.train_multi_agent.train_multi_agent(env, env_name, algo, pop, memory, INIT_HP=None, MUT_P=None, net_config=None, swap_channels=False, n_episodes=2000, max_steps=25, evo_epochs=5, evo_loop=5, eps_start=1.0, eps_end=0.1, eps_decay=0.995, target=200.0, tournament=None, mutation=None, checkpoint=None, checkpoint_path=None, save_elite=False, elite_path=None, wb=False, verbose=True, accelerator=None, wandb_api_key=None)
The general online multi-agent RL training function. Returns the trained population of agents and their fitnesses.
- Parameters:
env (PettingZoo-style environment) – The environment to train in. Can be vectorized.
env_name (str) – Environment name
algo (str) – RL algorithm name
pop (list[object]) – Population of agents
memory (object) – Experience Replay Buffer
INIT_HP (dict) – Dictionary containing initial hyperparameters.
MUT_P (dict, optional) – Dictionary containing mutation parameters, defaults to None
net_config (dict) – Network configuration dictionary, defaults to None
swap_channels (bool, optional) – Swap image channels dimension from last to first [H, W, C] -> [C, H, W], defaults to False
n_episodes (int, optional) – Maximum number of training episodes, defaults to 2000
max_steps (int, optional) – Maximum number of steps in environment per episode, defaults to 25
evo_epochs (int, optional) – Evolution frequency (episodes), defaults to 5
evo_loop (int, optional) – Number of evaluation episodes, defaults to 5
eps_start (float, optional) – Initial (maximum) exploration epsilon value, defaults to 1.0
eps_end (float, optional) – Final (minimum) exploration epsilon value, defaults to 0.1
eps_decay (float, optional) – Epsilon decay per episode, defaults to 0.995
target (float, optional) – Target score for early stopping, defaults to 200.0
tournament (object, optional) – Tournament selection object, defaults to None
mutation (object, optional) – Mutation object, defaults to None
checkpoint (int, optional) – Checkpoint frequency (episodes), defaults to None
checkpoint_path (str, optional) – Location to save checkpoint, defaults to None
save_elite (bool, optional) – Boolean flag indicating whether to save elite member at the end of training, defaults to False
elite_path (str, optional) – Location to save elite agent, defaults to None
wb (bool, optional) – Weights & Biases tracking, defaults to False
verbose (bool, optional) – Display training stats, defaults to True
accelerator (accelerate.Accelerator(), optional) – Accelerator for distributed computing, defaults to None
wandb_api_key (str, optional) – API key for Weights & Biases, defaults to None
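In a PettingZoo-style parallel environment, actions, observations and rewards are dictionaries keyed by agent id, and step returns observations, rewards, terminations, truncations and infos. The sketch below shows how one episode of at most max_steps (25 by default here) could be rolled out and scored per agent. `rollout_scores` and `ToyParallelEnv` are hypothetical, and summing per-agent returns into a single score is an assumption about how a team score might be formed, not necessarily how train_multi_agent computes fitness.

```python
from typing import Callable, Dict, Tuple

StepResult = Tuple[dict, Dict[str, float], Dict[str, bool], Dict[str, bool], dict]


def rollout_scores(
    step_fn: Callable[[Dict[str, int]], StepResult],
    policy: Callable[[str, object], int],
    first_obs: Dict[str, object],
    max_steps: int = 25,
) -> Tuple[Dict[str, float], float]:
    """Run one episode; return per-agent returns and their sum."""
    obs = first_obs
    totals = {agent: 0.0 for agent in first_obs}
    for _ in range(max_steps):
        # One action per agent, keyed by agent id.
        actions = {agent: policy(agent, o) for agent, o in obs.items()}
        obs, rewards, terminations, truncations, _infos = step_fn(actions)
        for agent, reward in rewards.items():
            totals[agent] = totals.get(agent, 0.0) + reward
        # End the episode once every agent has terminated or been truncated.
        if all(terminations.get(a, False) or truncations.get(a, False) for a in totals):
            break
    return totals, sum(totals.values())


class ToyParallelEnv:
    """Toy stand-in: each agent's reward equals its action; episode ends after 10 steps."""

    def __init__(self, agents=("agent_0", "agent_1")):
        self.agents = list(agents)
        self.t = 0

    def reset(self) -> Dict[str, int]:
        self.t = 0
        return {a: 0 for a in self.agents}

    def step(self, actions: Dict[str, int]) -> StepResult:
        self.t += 1
        obs = {a: self.t for a in actions}
        rewards = {a: float(act) for a, act in actions.items()}
        terminations = {a: self.t >= 10 for a in actions}
        truncations = {a: False for a in actions}
        return obs, rewards, terminations, truncations, {}
```

With a policy that always plays 1 for "agent_0" and 0 for "agent_1", the toy episode terminates after 10 steps with returns of 10.0 and 0.0 and a summed score of 10.0.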