On-Policy Rollout Functions

These helpers gather transitions from an environment using an on-policy agent. They fill the agent’s rollout buffer so that calling agent.learn() will update the policy from the collected data.

agilerl.rollouts.collect_rollouts(agent: PPO, env: str | Env | VectorEnv | AsyncVectorEnv, n_steps: int | None = None, **kwargs) → list[float]

Collect rollouts for non-recurrent on-policy algorithms.

Parameters:
  • agent (RLAlgorithm) – The agent to collect rollouts for.

  • env (GymEnvType) – The environment to collect rollouts from.

  • n_steps (int | None) – The number of steps to collect rollouts for.

Returns:

The scores of the episodes completed during the rollout.

Return type:

list[float]

agilerl.rollouts.collect_rollouts_recurrent(agent: PPO, env: str | Env | VectorEnv | AsyncVectorEnv, n_steps: int | None = None, **kwargs) → list[float]

Collect rollouts for recurrent on-policy algorithms.

Parameters:
  • agent (RLAlgorithm) – The agent to collect rollouts for.

  • env (GymEnvType) – The environment to collect rollouts from.

  • n_steps (int | None) – The number of steps to collect rollouts for.

Returns:

The scores of the episodes completed during the rollout.

Return type:

list[float]
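
To make the return value concrete, the following is a minimal, library-free sketch of a fixed-step collection loop. It is not AgileRL's implementation: ToyEnv and toy_collect are hypothetical names invented for this illustration, and the toy environment simply ends every episode after a fixed number of steps. The point it shows is that only episodes that finish within n_steps contribute a score, while a partially completed episode does not.

```python
from typing import Callable


class ToyEnv:
    """Hypothetical stand-in environment: each episode lasts `episode_len`
    steps and pays a reward of 1.0 per step."""

    def __init__(self, episode_len: int = 3) -> None:
        self.episode_len = episode_len
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> tuple[int, float, bool]:
        self.t += 1
        done = self.t >= self.episode_len
        return self.t, 1.0, done


def toy_collect(env: ToyEnv, policy: Callable[[int], int], n_steps: int) -> list[float]:
    """Run exactly `n_steps` transitions and return the scores of the
    episodes that finished within them (illustrative only)."""
    completed: list[float] = []
    running = 0.0  # score of the episode currently in progress
    obs = env.reset()
    for _ in range(n_steps):
        obs, reward, done = env.step(policy(obs))
        running += reward
        if done:
            # Episode finished: record its score and start a new one.
            completed.append(running)
            running = 0.0
            obs = env.reset()
    return completed


scores = toy_collect(ToyEnv(episode_len=3), policy=lambda obs: 0, n_steps=8)
print(scores)  # [3.0, 3.0]; the third episode is still in progress
```

With 8 steps and 3-step episodes, two episodes finish and the last two steps belong to an unfinished episode, so the list has two entries.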

Example

Using a non-recurrent PPO agent:

import gymnasium as gym
from agilerl.algorithms import PPO
from agilerl.rollouts import collect_rollouts

env = gym.make("CartPole-v1")
agent = PPO(env.observation_space, env.action_space, use_rollout_buffer=True)

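# Fill the agent's rollout buffer with learn_step transitions
collect_rollouts(agent, env, n_steps=agent.learn_step)
# Update the policy from the collected data
agent.learn()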

For recurrent policies, use collect_rollouts_recurrent:

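import gymnasium as gym
from agilerl.algorithms import PPO
from agilerl.rollouts import collect_rollouts_recurrent

num_envs = 4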
env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * num_envs)
agent = PPO(
    env.single_observation_space,
    env.single_action_space,
    use_rollout_buffer=True,
    recurrent=True,
    num_envs=num_envs,
)

collect_rollouts_recurrent(agent, env, n_steps=5)
agent.learn()
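
Because the returned list only covers completed episodes, a short rollout (such as n_steps=5 above) may well return an empty list. This is an inference from the documented return value, not a confirmed behavior. A small helper for logging can guard against that case; mean_score below is a hypothetical name for this sketch and is not part of AgileRL.

```python
from statistics import mean
from typing import Optional


def mean_score(scores: list[float]) -> Optional[float]:
    """Return the mean episode score, or None if no episode finished."""
    return mean(scores) if scores else None


print(mean_score([]))            # None
print(mean_score([20.0, 30.0]))  # 25.0
```

The list returned by collect_rollouts or collect_rollouts_recurrent could be passed to such a helper after each collection phase.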