On-Policy Rollout Functions
These helpers gather transitions from an environment using an on-policy agent.
They fill the agent’s rollout buffer so that calling agent.learn() will
update the policy from the collected data.
- agilerl.rollouts.collect_rollouts(agent: PPO, env: str | Env | VectorEnv | AsyncVectorEnv, n_steps: int | None = None, **kwargs) → list[float]
Collect rollouts for non-recurrent on-policy algorithms.
- Parameters:
agent (RLAlgorithm) – The agent to collect rollouts for.
env (GymEnvType) – The environment to collect rollouts from.
n_steps (int | None) – The number of steps to collect rollouts for.
- Returns:
The list of scores for the episodes completed during the rollout.
- Return type:
list[float]
- agilerl.rollouts.collect_rollouts_recurrent(agent: PPO, env: str | Env | VectorEnv | AsyncVectorEnv, n_steps: int | None = None, **kwargs) → list[float]
Collect rollouts for recurrent on-policy algorithms.
- Parameters:
agent (RLAlgorithm) – The agent to collect rollouts for.
env (GymEnvType) – The environment to collect rollouts from.
n_steps (int | None) – The number of steps to collect rollouts for.
- Returns:
The list of scores for the episodes completed during the rollout.
- Return type:
list[float]
Example
Using a non-recurrent PPO agent:
import gymnasium as gym
from agilerl.algorithms import PPO
from agilerl.rollouts import collect_rollouts

env = gym.make("CartPole-v1")
agent = PPO(env.observation_space, env.action_space, use_rollout_buffer=True)

# Collect agent.learn_step steps into the agent's rollout buffer
collect_rollouts(agent, env, n_steps=agent.learn_step)
# Update the policy from the collected data
agent.learn()
For recurrent policies, use collect_rollouts_recurrent:
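import gymnasium as gym
from agilerl.algorithms import PPO
from agilerl.rollouts import collect_rollouts_recurrent

num_envs = 4
env = gym.vector.SyncVectorEnv([lambda: gym.make("CartPole-v1")] * num_envs)
agent = PPO(
    env.single_observation_space,
    env.single_action_space,
    use_rollout_buffer=True,
    recurrent=True,
    num_envs=num_envs,
)

# Collect 5 steps with the recurrent policy, then update it
collect_rollouts_recurrent(agent, env, n_steps=5)
agent.learn()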
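To make the return value concrete, the following is a minimal, library-free sketch of what a rollout collector of this shape might do. It is not AgileRL's implementation: CountdownEnv, toy_collect_rollouts, and the plain-list buffer are hypothetical names introduced only for illustration. It assumes a single environment whose step() returns (observation, reward, done). The point it shows is that every transition is stored, but only episodes that finish within n_steps contribute a score to the returned list.

```python
import random
from typing import Callable, List, Tuple


class CountdownEnv:
    """Hypothetical toy environment: every episode lasts `horizon` steps
    and pays a reward of 1.0 per step."""

    def __init__(self, horizon: int = 3) -> None:
        self.horizon = horizon
        self.t = 0

    def reset(self) -> int:
        self.t = 0
        return self.t

    def step(self, action: int) -> Tuple[int, float, bool]:
        self.t += 1
        done = self.t >= self.horizon
        return self.t, 1.0, done


def toy_collect_rollouts(
    policy: Callable[[int], int],
    env: CountdownEnv,
    n_steps: int,
    buffer: list,
) -> List[float]:
    """Step `env` for `n_steps`, store each transition in `buffer`, and
    return the scores of the episodes that finished during the rollout."""
    completed_scores: List[float] = []
    score = 0.0
    obs = env.reset()
    for _ in range(n_steps):
        action = policy(obs)
        next_obs, reward, done = env.step(action)
        buffer.append((obs, action, reward, done))
        score += reward
        if done:
            # Episode finished: record its score and start a new episode
            completed_scores.append(score)
            score = 0.0
            obs = env.reset()
        else:
            obs = next_obs
    # A partially completed episode at the end contributes no score
    return completed_scores


buffer: list = []
scores = toy_collect_rollouts(
    policy=lambda obs: random.randint(0, 1),  # random stand-in policy
    env=CountdownEnv(horizon=3),
    n_steps=8,
    buffer=buffer,
)
print(scores, len(buffer))
```

With a horizon of 3 and n_steps=8, two episodes finish and the third is cut off after two steps, so the buffer holds eight transitions while the score list has two entries.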