LLM Proximal Policy Optimization (LLM PPO)
PPO (Proximal Policy Optimization) is a policy-gradient method that keeps each update inside a clipped trust region.
LLMPPO adapts this idea to causal language models and is designed for both
single-turn and multi-turn fine-tuning.
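To make the "clipped trust region" concrete, the following is a minimal sketch of the standard PPO clipped surrogate loss for a single action. It is a generic illustration in plain Python, not AgileRL's implementation; the function name `clipped_policy_loss` is hypothetical, and `clip_coef` mirrors the constructor argument of the same name.

```python
import math


def clipped_policy_loss(new_logp, old_logp, advantage, clip_coef=0.2):
    """PPO clipped surrogate loss for one action (a value to minimise)."""
    # Probability ratio between the updated policy and the sampling policy.
    ratio = math.exp(new_logp - old_logp)
    # Restrict the ratio to [1 - clip_coef, 1 + clip_coef].
    clipped_ratio = max(1.0 - clip_coef, min(ratio, 1.0 + clip_coef))
    # Keep the more pessimistic objective, then negate it to obtain a loss.
    return -min(ratio * advantage, clipped_ratio * advantage)


# A ratio of 1.5 with a positive advantage is capped at 1 + clip_coef.
loss_pos = clipped_policy_loss(math.log(1.5), 0.0, advantage=1.0)
# A ratio of 0.5 with a negative advantage is capped at 1 - clip_coef.
loss_neg = clipped_policy_loss(math.log(0.5), 0.0, advantage=-1.0)
```

Once the ratio leaves the clipping interval in the direction favoured by the advantage, the objective stops improving, which is what bounds the size of each update.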
In AgileRL, the implementation is turn-aware:
- Turn-level credit assignment: each generated turn is treated as one RL action, with discounting across turns.
- Actor-critic optimization: policy and value adapters are updated jointly, with clipped policy/value losses plus entropy regularization.
- Single-turn and multi-turn parity: single-turn prompting is treated as the special case where all action tokens belong to turn 0.
This algorithm can therefore be used for multi-turn agentic fine-tuning or single-turn reasoning tasks.
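The turn-level credit assignment described above can be sketched as generalized advantage estimation (GAE) computed over turns rather than tokens. This is an illustrative plain-Python version under the assumption that the episode terminates after the last turn; the function name `turn_level_gae` is hypothetical, while `gamma` and `gae_lambda` mirror the constructor arguments of the same names.

```python
def turn_level_gae(rewards, values, gamma=0.99, gae_lambda=0.95):
    """Compute one advantage per turn.

    rewards[t] and values[t] are the reward and critic value of turn t.
    Assumption: the episode ends after the last turn, so the bootstrap
    value beyond it is zero.
    """
    advantages = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        next_value = values[t + 1] if t + 1 < len(rewards) else 0.0
        # One-step TD error at the turn level.
        delta = rewards[t] + gamma * next_value - values[t]
        # Discounted accumulation across turns, not across tokens.
        running = delta + gamma * gae_lambda * running
        advantages[t] = running
    return advantages


# Single-turn is the one-element special case.
single_turn = turn_level_gae([1.0], [0.25])
# Two turns with the reward arriving only at the end.
two_turns = turn_level_gae([0.0, 1.0], [0.0, 0.0], gamma=0.5, gae_lambda=1.0)
```

With a single turn the advantage reduces to `reward - value`, which is why single-turn prompting needs no separate code path.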
Example
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from agilerl.algorithms import LLMPPO

# Load the base causal language model and its tokenizer
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

# Build the agent; the EOS token is reused as the padding token
agent = LLMPPO(
    actor_network=model,
    pad_token_id=tokenizer.eos_token_id,
    pad_token=tokenizer.eos_token,
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=8,
    update_epochs=1,
    clip_coef=0.2,
    max_output_tokens=128,
    max_model_len=1024,
)
Training
Typical training entry points are finetune_llm_reasoning and
finetune_llm_multiturn in agilerl.training.train_llm.
from datasets import Dataset

from agilerl.training.train_llm import (
    finetune_llm_multiturn,
    finetune_llm_reasoning,
)
from agilerl.llm_envs import ReasoningGym, TokenObservationWrapper

# Tiny mock reasoning dataset
train_ds = Dataset.from_dict(
    {
        "question": ["2+2?", "Capital of France?"],
        "answer": ["4", "Paris"],
    }
)
test_ds = Dataset.from_dict(
    {
        "question": ["3+3?"],
        "answer": ["6"],
    }
)


def reward_fn(completion: str, answer: str, question: str) -> float:
    # Reward 1.0 when the reference answer appears in the completion
    del question
    return float(answer.lower() in completion.lower())


reasoning_env = ReasoningGym(
    train_dataset=train_ds,
    test_dataset=test_ds,
    tokenizer=tokenizer,
    reward_fn=reward_fn,
    conversation_template=[{"role": "user", "content": "Q: {question}\nA:"}],
    data_batch_size_per_gpu=2,
)

# 1) Single-turn / reasoning datasets (ReasoningGym)
trained_pop = finetune_llm_reasoning(
    pop=[agent],
    env=reasoning_env,
    max_steps=2000,
    evaluation_interval=50,
)

# 2) Multi-turn text environments (factory + wrapper)
class ToyMultiTurnEnv:
    def reset(self, seed=None):
        del seed
        return "Start: What is 2+2?", {}

    def step(self, action: str):
        # The episode terminates after a single step
        reward = 1.0 if "4" in action else 0.0
        return "Done.", reward, True, False, {"correct": bool(reward)}


def env_factory():
    return TokenObservationWrapper(
        env=ToyMultiTurnEnv(),
        tokenizer=tokenizer,
        max_turns=4,
        pad_id=tokenizer.eos_token_id,
        max_model_len=1024,
        max_output_tokens=128,
    )


trained_pop = finetune_llm_multiturn(
    pop=[agent],
    max_turns=4,
    env_factory=env_factory,
    max_steps=2000,
    evaluation_interval=50,
)
Saving and Loading Agents
To save an agent, use the save_llm_checkpoint function:
from agilerl.utils.utils import save_llm_checkpoint
save_llm_checkpoint(agent, "path/to/checkpoint")
As with other AgileRL LLM algorithms, loading is done with Hugging Face
from_pretrained APIs for the base model and adapter.
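A minimal sketch of such a load is shown below. It assumes the adapter was written in PEFT format and that it lives in a sub-directory of the checkpoint path (the `actor` sub-directory in the example call is an assumption about the layout, not something this page confirms for `save_llm_checkpoint`). The helper name `load_finetuned_actor` is hypothetical.

```python
def load_finetuned_actor(base_model_name: str, adapter_dir: str):
    """Load a base model and attach a saved LoRA adapter to it."""
    # Imported lazily so that defining this helper does not require the
    # heavy dependencies to be importable.
    from peft import PeftModel
    from transformers import AutoModelForCausalLM

    base = AutoModelForCausalLM.from_pretrained(base_model_name)
    # PeftModel.from_pretrained reads adapter_config.json and the adapter
    # weights from adapter_dir and wraps the base model with them.
    return PeftModel.from_pretrained(base, adapter_dir)


# Example call (assumed layout: one sub-directory per adapter):
# model = load_finetuned_actor("Qwen/Qwen2.5-0.5B", "path/to/checkpoint/actor")
```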
Parameters
- class agilerl.algorithms.ppo_llm.PPO(*args: Any, **kwargs: Any)
Turn-level PPO for LLM finetuning with actor/reference adapters.
Each generation sequence (turn) is treated as a single RL action. GAE discounts between turns, not between tokens within a turn. Single-turn is the special case where all action tokens share turn 0.
- Parameters:
pad_token_id (int) – Token id used for sequence padding.
pad_token (str) – Padding token string.
model_name (str | None, optional) – HF model name or local path used when building internally.
actor_network (Any | None, optional) – Pre-built actor model. If omitted, model_name is used.
model_config (dict[str, Any] | None, optional) – Extra kwargs passed when constructing a model from model_name.
hp_config (HyperparameterConfig | None, optional) – Hyperparameter mutation configuration.
index (int, optional) – Population index used by evolutionary workflows.
batch_size (int, optional) – Batch size used for PPO updates.
beta (float, optional) – KL penalty coefficient against the reference policy.
vf_coef (float, optional) – Value loss coefficient.
clip_coef (float, optional) – PPO clipping coefficient.
gamma (float, optional) – Discount factor across turns.
gae_lambda (float, optional) – GAE lambda used for turn-level advantage estimation.
lr_actor (float, optional) – Actor learning rate.
lr_critic (float | None, optional) – Critic/value-head learning rate. If None, lr_actor is used.
max_grad_norm (float, optional) – Gradient clipping norm.
update_epochs (int, optional) – Number of PPO epochs per update.
temperature (float, optional) – Sampling temperature for generation.
repetition_penalty (float, optional) – Repetition penalty used during generation.
top_p (float, optional) – Nucleus sampling threshold.
top_k (int, optional) – Top-k sampling threshold.
min_p (float, optional) – Minimum probability cutoff for sampling.
use_separate_reference_adapter (bool, optional) – Whether to keep a separate reference adapter.
calc_position_embeddings (bool, optional) – Whether to compute position embeddings.
micro_batch_size_per_gpu (int | None, optional) – Optional target micro-batch size per GPU.
max_output_tokens (int | None, optional) – Maximum newly generated tokens per completion.
min_output_tokens (int | None, optional) – Minimum newly generated tokens per completion.
max_model_len (int | None, optional) – Maximum model context length.
hf_generate_chunk_size (int | None, optional) – Number of prompts per HuggingFace generation chunk. Ignored when use_vllm=True.
lora_config (LoraConfigProtocol | None, optional) – LoRA configuration.
cosine_lr_schedule_config (CosineLRScheduleConfig | None, optional) – Cosine LR scheduler configuration.
accelerator (Accelerator | None, optional) – Optional HuggingFace Accelerator instance.
device (str, optional) – Device string used when no accelerator is provided.
wrap (bool, optional) – Whether to wrap models for distributed execution.
clone (bool, optional) – Whether this instance is being created as a clone.
use_vllm (bool, optional) – Whether to route generation through vLLM.
use_memory_efficient_params (bool, optional) – Enable memory-efficient parameter handling.
vllm_config (VLLMConfig | None, optional) – vLLM runtime configuration.
seed (int, optional) – Random seed.
turn_level_clip (bool, optional) – Apply clipping at per-turn ratio level.
action_granularity (Literal["turn", "token", "auto"], optional) – PPO action granularity. "turn" enforces turn-level updates, "token" enforces token-level updates, and "auto" uses token-level only when all samples are single-turn.
turn_value_reduction (str, optional) – Aggregation used to map token critic values to turn values. "mean" reproduces existing behavior, "final_value" uses the final action token value in each turn.
adv_whitening (bool, optional) – Whether to whiten computed advantages before PPO optimization.
gradient_checkpointing (bool, optional) – Enable gradient checkpointing.
torch_compiler (str | None, optional) – Optional torch compile mode.
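The two `turn_value_reduction` options can be illustrated with a small plain-Python sketch. It is not the library's tensor code: the function name `reduce_token_values_to_turns` is hypothetical, and it assumes each token carries a turn index with `-1` marking non-action tokens, as in the `turn_ids` argument of `learn`.

```python
def reduce_token_values_to_turns(token_values, turn_ids, reduction="mean"):
    """Map per-token critic values to one value per turn.

    turn_ids[i] is the turn index of token i, or -1 for non-action tokens.
    """
    per_turn = {}
    for value, turn in zip(token_values, turn_ids):
        if turn < 0:
            continue  # prompt / observation tokens carry no action value
        per_turn.setdefault(turn, []).append(value)

    if reduction == "mean":
        return {turn: sum(vals) / len(vals) for turn, vals in per_turn.items()}
    if reduction == "final_value":
        return {turn: vals[-1] for turn, vals in per_turn.items()}
    raise ValueError(f"Unknown reduction: {reduction!r}")


token_values = [9.0, 1.0, 3.0, 9.0, 2.0, 6.0]
turn_ids = [-1, 0, 0, -1, 1, 1]
mean_values = reduce_token_values_to_turns(token_values, turn_ids, "mean")
final_values = reduce_token_values_to_turns(token_values, turn_ids, "final_value")
```

In this toy input the two action tokens of turn 0 average to 2.0, whereas `"final_value"` keeps only the last one (3.0).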
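- clone(index: int | None = None, wrap: bool = True) Self
Create a clone of the algorithm.
- Parameters:
index (int | None, optional) – Population index for the clone. Defaults to None.
wrap (bool, optional) – Whether to wrap the clone's models for distributed execution. Defaults to True.
- Returns:
A clone of the algorithm
- Return type:
Self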
- static copy_attributes(agent: EvolvableAlgorithm, clone: EvolvableAlgorithm) EvolvableAlgorithm
Copy the non-evolvable attributes of the algorithm to a clone.
- Parameters:
agent (EvolvableAlgorithm) – The algorithm to copy attributes from.
clone (EvolvableAlgorithm) – The clone of the algorithm.
- Returns:
The clone of the algorithm.
- Return type:
EvolvableAlgorithm
- evolvable_attributes(networks_only: bool = False) dict[str, EvolvableModuleProtocol | ModuleDictProtocol | Optimizer | dict[str, Optimizer] | OptimizerWrapperProtocol]
Return the attributes related to the evolvable networks in the algorithm. Includes attributes that are either EvolvableModule or ModuleDict objects, as well as the optimizers associated with the networks.
- get_action(obs: list[ReasoningPrompts] | ReasoningPrompts, training: bool = True, **kwargs: Any) tuple[list[Tensor], list[Tensor]]
Generate completion tokens for each prompt in the batch.
- Parameters:
obs (LLMObsType) – A single prompt dict or a list of HF-style prompt dicts.
training (bool) – If False, use near-deterministic decoding where applicable.
kwargs (Any) – Additional keyword arguments accepted for base-class compatibility.
- Returns:
Per-prompt completion token IDs and masks over generated positions.
- Return type:
tuple[list[Tensor], list[Tensor]]
- static get_action_dim(action_space: Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary | list[Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary]) tuple[int, ...]
Return the dimension of the action space as it pertains to the underlying networks (i.e. the output size of the networks).
- Parameters:
action_space (spaces.Space or list[spaces.Space]) – The action space of the environment.
- Returns:
The dimension of the action space.
- Return type:
tuple[int, ...]
- get_policy() EvolvableModuleProtocol
Return the policy network of the algorithm.
- static get_state_dim(observation_space: Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary | list[Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary]) tuple[int, ...]
Return the dimension of the state space as it pertains to the underlying networks (i.e. the input size of the networks).
- static inspect_attributes(agent: EvolvableAlgorithm, input_args_only: bool = False) dict[str, Any]
Inspect and retrieve the attributes of the current object, excluding attributes related to the underlying evolvable networks (i.e. EvolvableModule, torch.optim.Optimizer) and with an option to include only the attributes that are input arguments to the constructor.
- learn(experiences: dict[str, ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts] | tuple[ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, ...], turn_ids: Tensor | None = None) dict[str, float]
Update actor and critic adapters using configured PPO granularity.
- Parameters:
experiences (ExperiencesType) – (completion_ids, action_masks, rewards). For single-turn, rewards is a flat tensor of scalars; for multi-turn, rewards has shape [batch, max_turns] and holds per-turn rewards.
turn_ids (torch.Tensor | None) – Optional [batch, seq_len - 1] tensor of turn indices; -1 for non-action tokens. If None, all action tokens are turn 0.
- Returns:
Mean training metrics across PPO minibatch updates.
- Return type:
dict[str, float]
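The `turn_ids` convention can be sketched for a single sequence in plain Python. Both helper names are hypothetical, lists stand in for tensors, and the one-position shift implied by the `seq_len - 1` shape is ignored for simplicity.

```python
def default_turn_ids(action_mask):
    """Single-turn fallback: action tokens map to turn 0, all others to -1."""
    return [0 if is_action else -1 for is_action in action_mask]


def turn_ids_from_segments(segments):
    """Build turn ids from ordered (num_tokens, is_action) segments.

    Each action segment is one turn and receives the next turn index;
    prompt / observation segments are marked with -1.
    """
    turn_ids = []
    turn = 0
    for num_tokens, is_action in segments:
        if is_action:
            turn_ids.extend([turn] * num_tokens)
            turn += 1
        else:
            turn_ids.extend([-1] * num_tokens)
    return turn_ids


# Prompt (2 tokens), turn 0 (3 tokens), observation (1 token), turn 1 (2 tokens).
multi_turn_ids = turn_ids_from_segments(
    [(2, False), (3, True), (1, False), (2, True)]
)
single_turn_ids = default_turn_ids([0, 0, 1, 1])
```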
- classmethod load(path: str, device: str | device = 'cpu', accelerator: Accelerator | None = None) None
Load an algorithm from a checkpoint.
- Parameters:
path (string) – Location to load checkpoint from.
device (str, optional) – Device to load the algorithm on, defaults to ‘cpu’
accelerator (Accelerator | None, optional) – Accelerator object for distributed computing, defaults to None
- Returns:
An instance of the algorithm
- Return type:
- load_checkpoint(path: str, load_optimizer: bool = False, overwrite_reference_adapter: bool = False, overwrite_critic_adapter: bool = True, merge_lora_configs: bool = False) None
Load adapter weights and algorithm state from a checkpoint directory.
Adapter roles restored on load:
  - actor: the trained policy. Always loaded.
  - reference: the fixed policy used for KL / comparison. The checkpoint's actor adapter is copied onto reference so that SFT -> DPO -> GRPO chains work out of the box: the stage-N actor becomes the stage-N+1 reference.
  - critic: optional value head. Loaded from disk if a critic/ adapter is present, else copied from actor, else left as the live fresh LoRA init.
LoRA config reconciliation: when the checkpoint's config and the live algorithm's config disagree, loading fails fast by default. Pass merge_lora_configs=True to merge them for compatibility:
  - r (rank) -> max(current, checkpoint); the smaller side's weights are padded into the top-left rank slice of the larger adapter (see _pad_adapter_state_to_live_shape()).
  - target_modules / modules_to_save -> union.
  - Any other mismatched field -> current value wins, with a warning.
Any adapter whose live config ends up differing from the selected target config is rebuilt via _reconfigure_adapters_to_match() before weights are loaded, so tensors always land in the correct shape.
- No DeepSpeed:
  - lora_only=T, load_optimizer=T -> PEFT adapter load + optimizer state from attributes.pt
  - lora_only=T, load_optimizer=F -> PEFT adapter load only
  - lora_only=F, load_optimizer=T -> torch load of actor + optimizer from attributes.pt
  - lora_only=F, load_optimizer=F -> torch load of actor only
- DeepSpeed:
  - lora_only=T, load_optimizer=T -> DeepSpeed engine load from <path>/save_checkpoint
  - lora_only=T, load_optimizer=F -> PEFT adapter load
  - lora_only=F, load_optimizer=T -> DeepSpeed engine load from <path>/save_checkpoint
  - lora_only=F, load_optimizer=F -> actor.load_state_dict(...) from attributes.pt
When load_optimizer=True but the checkpoint contains no optimizer state (e.g. it was saved with save_optimizer=False), a UserWarning is emitted and a freshly-initialised optimizer is used.
- Parameters:
path (str) – Directory containing a checkpoint written by save_checkpoint().
load_optimizer (bool) – If True, also load the optimizer and LR scheduler state so training can resume. On DeepSpeed ZeRO ≥ 2 this reads a sharded checkpoint from <path>/save_checkpoint; otherwise optimizer state is read from attributes.pt. Defaults to False.
merge_lora_configs (bool) – If True, allow loading checkpoints whose LoRA config differs from the live agent by reconciling them. If False (default), mismatched LoRA configs raise ValueError.
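The reconciliation rules for merge_lora_configs=True can be sketched with plain dictionaries and nested lists. This is an illustration of the documented rules only, not the library's code: the helper names `merge_lora_configs` and `pad_into_top_left` are hypothetical, configs are modelled as dicts, and weight matrices as lists of rows.

```python
import warnings


def merge_lora_configs(current: dict, checkpoint: dict) -> dict:
    """Merge two LoRA configs following the rules listed above."""
    merged = dict(current)
    for key, cur_value in current.items():
        if key not in checkpoint or checkpoint[key] == cur_value:
            continue
        ckpt_value = checkpoint[key]
        if key == "r":
            # Rank: keep the larger of the two.
            merged[key] = max(cur_value, ckpt_value)
        elif key in ("target_modules", "modules_to_save"):
            # Module lists: take the union.
            merged[key] = sorted(set(cur_value or []) | set(ckpt_value or []))
        else:
            # Any other field: the current value wins, with a warning.
            warnings.warn(f"LoRA field {key!r} differs; keeping {cur_value!r}")
    return merged


def pad_into_top_left(small, live):
    """Copy a smaller matrix into the top-left slice of a larger one."""
    padded = [list(row) for row in live]
    for i, row in enumerate(small):
        padded[i][: len(row)] = row
    return padded


merged = merge_lora_configs(
    {"r": 8, "target_modules": ["q_proj"]},
    {"r": 16, "target_modules": ["q_proj", "v_proj"]},
)
padded = pad_into_top_left([[1.0, 2.0]], [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0]])
```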
- classmethod population(size: int, observation_space: Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary | list[Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary], action_space: Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary | list[Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary], wrapper_cls: type[SelfAgentWrapper] | None = None, wrapper_kwargs: dict[str, Any] | None = None, **kwargs) list[Self | SelfAgentWrapper]
Create a population of algorithms.
- Parameters:
size (int) – The size of the population.
- Returns:
A list of algorithms.
- Return type:
list[Self | SelfAgentWrapper]
- preprocess_observation(observation: ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts) Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor]
Preprocess observations (dummy) for forward pass through neural network.
- recompile() None
Recompile evolvable modules with torch.compile.
Iterates over evolvable_attributes and compiles each one. Skipped when DeepSpeed is active because DeepSpeedEngine is not compatible with OptimizedModule wrapping.
- register_mutation_hook(hook: Callable) None
Register a hook to be executed after a mutation is performed on the algorithm.
- Parameters:
hook (Callable) – The hook to be executed after mutation.
- register_network_group(group: NetworkGroup) None
Set the evaluation network for the algorithm.
- Parameters:
name (str) – The name of the evaluation network.
- reinit_optimizers(optimizer: OptimizerConfig | None = None) None
Reinitialize the optimizers of an algorithm. If no optimizer is passed, all optimizers are reinitialized.
- Parameters:
optimizer (OptimizerConfig | None, optional) – The optimizer to reinitialize, defaults to None, in which case all optimizers are reinitialized.
- save_checkpoint(path: str, lora_only: bool = True, save_optimizer: bool = True, **kwargs: Any) None
Save adapter weights and algorithm state to a directory.
AgileRL never persists base-model weights when lora_only=True for LLM algorithms: a checkpoint is a directory containing
  - <adapter>/adapter_model.safetensors + adapter_config.json: one subdirectory per adapter in selected_adapters (always actor, plus reference / critic when those adapters are configured). Written only when lora_only=True.
  - attributes.pt: algorithm hyperparameters, plus (optionally) the actor state dict and/or optimizer state dict depending on the cell below. Always present.
  - save_checkpoint/: DeepSpeed ZeRO ≥ 2 sharded-checkpoint output. Present only when an Accelerator is attached and save_optimizer=True.
Behaviour per cell of the (lora_only, save_optimizer, deepspeed) grid:
- Plain (no accelerator):
  - lora_only=T, save_optimizer=T -> PEFT adapter dirs on disk + optimizer state in attributes.pt
  - lora_only=T, save_optimizer=F -> PEFT adapter dirs only
  - lora_only=F, save_optimizer=T -> full actor state_dict + optimizer state in attributes.pt
  - lora_only=F, save_optimizer=F -> full actor state_dict in attributes.pt
- DeepSpeed:
  - lora_only=T, save_optimizer=T -> engine tag dir (frozen params excluded) + PEFT adapter dirs
  - lora_only=T, save_optimizer=F -> PEFT adapter dirs only
  - lora_only=F, save_optimizer=T -> engine tag dir (frozen params included)
  - lora_only=F, save_optimizer=F -> gathered (ZeRO-3 aware) actor state_dict injected into attributes.pt
- Parameters:
path (str) – Directory to write the checkpoint into.
lora_only (bool) – If True (default) only adapter weights are written to disk via save_pretrained; the base model is shared across checkpoints and not serialised. If False, the full actor state dict is persisted (into attributes.pt on the plain path, or into the DeepSpeed engine's tag dir / gathered dict on the distributed path).
save_optimizer (bool) – If True (default) also persist the optimizer and LR scheduler state so training can resume. On DeepSpeed ZeRO ≥ 2 this writes a sharded checkpoint into <path>/save_checkpoint; otherwise optimizer state is included in attributes.pt.
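Given the directory layout described above, a small inspection helper can report what a checkpoint directory contains. This is a sketch based only on the file names listed in this section; `describe_checkpoint` is a hypothetical helper and not part of AgileRL.

```python
from pathlib import Path


def describe_checkpoint(path):
    """Summarise a checkpoint directory using the layout described above."""
    root = Path(path)
    # An adapter sub-directory is recognised by its adapter_config.json.
    adapters = sorted(
        child.name
        for child in root.iterdir()
        if child.is_dir() and (child / "adapter_config.json").is_file()
    )
    return {
        "adapters": adapters,
        "has_attributes": (root / "attributes.pt").is_file(),
        "has_deepspeed_shards": (root / "save_checkpoint").is_dir(),
    }
```

For example, a checkpoint saved with lora_only=True and no accelerator would be expected to report its adapter names, `has_attributes=True`, and `has_deepspeed_shards=False`.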
- select_adapter(adapter_name: str) None
Temporarily switch adapter; restores the actor adapter on exit.
- Parameters:
adapter_name (str) – Name of the adapter to activate (“actor”, “critic”, “reference”).
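The "switch temporarily, restore on exit" behaviour described for select_adapter follows the usual context-manager pattern. The sketch below shows that pattern on a toy class; `ToyAdapterHost` and its `active_adapter` attribute are hypothetical stand-ins, and whether the real method is used in a `with` statement is an assumption drawn from the description rather than from the signature.

```python
from contextlib import contextmanager


class ToyAdapterHost:
    """Toy stand-in for an object holding several named adapters."""

    def __init__(self):
        self.active_adapter = "actor"

    @contextmanager
    def select_adapter(self, adapter_name):
        # Switch to the requested adapter for the duration of the block.
        self.active_adapter = adapter_name
        try:
            yield self
        finally:
            # Always restore the actor adapter, even if the block raises.
            self.active_adapter = "actor"


host = ToyAdapterHost()
with host.select_adapter("reference"):
    inside = host.active_adapter
after = host.active_adapter
```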
- set_reference_policy(reference_update_tracker: int) None
Update the reference policy when the reference policy update tracker is greater than the current reference policy update tracker.
- Parameters:
reference_update_tracker (int) – The reference policy update tracker
- set_training_mode(training: bool) None
Set the training mode of the algorithm.
- Parameters:
training (bool) – If True, set the algorithm to training mode.
- test(env: ReasoningGym | MultiTurnEnv, loop: int = 1) Tensor
Return the fitness (test) score tensor of the LLM on the test subset.
ReasoningGym (and compatible dataset envs): reset returns a batch of prompt dicts; each step accepts completion id tensors and returns the next batch plus rewards. loop iterations advance the test dataloader that many times.
- Parameters:
env (ReasoningGym | MultiTurnEnv) – A ReasoningGym or TokenObservationWrapper.
loop (int) – Number of outer test iterations (dataloader passes or episodes).
- Returns:
Concatenated per-step rewards from the test loop.
- Return type:
torch.Tensor
- to_device(*experiences: Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor]) tuple[Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor], ...]
Move experiences to the device.
- static update_lr(optimizer: torch.optim.Optimizer, lr: float | tuple[float, float], accelerator: Accelerator | None = None, scheduler_config: CosineLRScheduleConfig | None = None) tuple[Accelerator | None, SequentialLR | None]
Update the learning rate of the optimizer.
- Parameters:
optimizer (Optimizer) – Optimizer
lr (float | tuple[float, float]) – Learning rate value, or actor/critic pair.
accelerator (Accelerator | None) – Accelerator
scheduler_config (CosineLRScheduleConfig | None) – Scheduler configuration
- Returns:
Tuple of accelerator and scheduler
- Return type:
tuple[Accelerator | None, SequentialLR | None]
- use_adapter(adapter_name: str) None
Switch the active PEFT adapter, handling all side-effects.
For “reference”: switches adapter and freezes reference params (never trained). For all others: switches adapter and restores requires_grad=True on all training adapter LoRA params so that DeepSpeed ZeRO-2 gradient bucket hooks keep firing correctly.
- Parameters:
adapter_name (str) – Name of the adapter to activate (“actor”, “critic”, “reference”).