LLM REINFORCE
REINFORCE is the classic
score-function policy-gradient method. LLMREINFORCE brings this approach to
causal language model finetuning with turn-aware trajectories.
In AgileRL, the algorithm uses Return Batch Normalization (ReBN), as popularized by the GEM paper, to improve stability in practice:
- Turn-level Monte Carlo returns: discounted returns are computed across turns for each sampled trajectory.
- Batch-normalized returns (ReBN): turn returns are z-scored across valid (sample, turn) pairs before being broadcast to token-level advantages.
- Value-head-free training: unlike PPO-style actor-critic updates, this path optimizes the policy directly from normalized returns.
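The three steps above can be sketched in plain Python. This is a minimal illustration, not AgileRL's implementation: the function names, the use of the population standard deviation, the `eps` stabiliser, and the zero advantage assigned to non-action tokens are all assumptions made for the example.

```python
import math


def discounted_turn_returns(rewards, gamma):
    """Monte Carlo return-to-go for one trajectory of per-turn rewards."""
    returns = [0.0] * len(rewards)
    running = 0.0
    for t in reversed(range(len(rewards))):
        running = rewards[t] + gamma * running
        returns[t] = running
    return returns


def rebn_advantages(batch_rewards, gamma=0.99, eps=1e-8):
    """Z-score turn returns over every (sample, turn) pair in the batch."""
    batch_returns = [discounted_turn_returns(r, gamma) for r in batch_rewards]
    flat = [g for traj in batch_returns for g in traj]
    mean = sum(flat) / len(flat)
    std = math.sqrt(sum((g - mean) ** 2 for g in flat) / len(flat))
    return [[(g - mean) / (std + eps) for g in traj] for traj in batch_returns]


def broadcast_to_tokens(turn_advantages, turn_ids):
    """Copy each turn's advantage onto its tokens; -1 marks non-action tokens."""
    return [turn_advantages[t] if t >= 0 else 0.0 for t in turn_ids]


# Three trajectories with different numbers of turns (hypothetical rewards).
batch_rewards = [[0.0, 1.0], [0.0, 0.0, 1.0], [0.0]]
advs = rebn_advantages(batch_rewards, gamma=0.5)

# Token-level advantages for the first trajectory: prompt tokens carry -1.
token_advs = broadcast_to_tokens(advs[0], turn_ids=[-1, 0, 0, -1, 1, 1])
```

Because trajectories here are variable-length lists, every entry is a valid (sample, turn) pair; a padded tensor implementation would need a mask to exclude padding from the mean and standard deviation.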
Variance Reduction
LLM policy-gradient algorithms differ mostly in how they reduce the variance of the Monte Carlo return signal. Three families show up in this codebase:
- Learned value baseline (LLM PPO): subtract a learned state-value estimate to form an advantage. Strong asymptotic variance reduction, but it spends parameters and compute on a value head and is sensitive to value-function staleness.
- Group-relative normalization (GRPO and variants): sample a group of G rollouts per prompt and z-score their returns within the group. No critic to train; effective when rewards are sparse and rollouts are cheap, but the baseline degenerates as the group’s returns collapse, and variance reduction depends on having a large group size.
- Return Batch Normalization (this algorithm): z-score returns across every valid (sample, turn) pair in the batch. No critic and no group requirement, and it remains well-defined under arbitrary discount factors and per-step dense rewards (where group-relative normalization is awkward). The trade-off is that the baseline is global to the batch, so it reduces less variance than a state-conditioned critic on tasks where reward depends sharply on the prompt.
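The difference between the group-relative and batch-level baselines can be seen on a toy batch. The sketch below is illustrative only: the helper names and the `eps` term are assumptions, and the rewards are made up. It shows the degenerate case mentioned above, where a group whose rollouts all earn the same return receives no learning signal under group-relative normalization but still does under batch normalization.

```python
import math


def zscore(values, eps=1e-8):
    """Standardise a list of returns to zero mean and unit variance."""
    mean = sum(values) / len(values)
    std = math.sqrt(sum((v - mean) ** 2 for v in values) / len(values))
    return [(v - mean) / (std + eps) for v in values]


def group_relative(groups):
    """Normalise each prompt's rollouts against that group only."""
    return [zscore(group) for group in groups]


def batch_normalized(groups):
    """Normalise all returns against the whole batch, then regroup."""
    flat = zscore([v for group in groups for v in group])
    out, start = [], 0
    for group in groups:
        out.append(flat[start:start + len(group)])
        start += len(group)
    return out


# Two prompts, three rollouts each; the first group has collapsed to reward 1.
groups = [[1.0, 1.0, 1.0], [0.0, 0.0, 1.0]]
group_adv = group_relative(groups)
batch_adv = batch_normalized(groups)
```

Here `group_adv[0]` is all zeros, while every entry of `batch_adv[0]` is positive because those rollouts beat the batch mean.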
Note
ReBN itself is a specific application of the long-standing “advantage normalization” trick — z-scoring the policy-gradient signal across a batch — that has been standard in PPO implementations since OpenAI Baselines and was systematically studied by Engstrom et al. 2020 and Andrychowicz et al. 2021. The GEM paper names the specific variant that operates on per-transition Monte Carlo returns across the whole batch (rather than on GAE advantages or within a per-prompt group).
Example
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from agilerl.algorithms import LLMREINFORCE

model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-0.5B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B")

agent = LLMREINFORCE(
    actor_network=model,
    pad_token_id=tokenizer.eos_token_id,
    pad_token=tokenizer.eos_token,
    device="cuda" if torch.cuda.is_available() else "cpu",
    batch_size=8,
    update_epochs=1,
    gamma=0.99,
    max_output_tokens=128,
    max_model_len=1024,
)
Training
Typical training entry points are finetune_llm_reasoning and
finetune_llm_multiturn in agilerl.training.train_llm.
from datasets import Dataset

from agilerl.training.train_llm import (
    finetune_llm_multiturn,
    finetune_llm_reasoning,
)
from agilerl.llm_envs import ReasoningGym, TokenObservationWrapper

# Tiny mock reasoning dataset
train_ds = Dataset.from_dict(
    {
        "question": ["2+2?", "Capital of France?"],
        "answer": ["4", "Paris"],
    }
)
test_ds = Dataset.from_dict(
    {
        "question": ["3+3?"],
        "answer": ["6"],
    }
)


def reward_fn(completion: str, answer: str, question: str) -> float:
    del question
    return float(answer.lower() in completion.lower())


reasoning_env = ReasoningGym(
    train_dataset=train_ds,
    test_dataset=test_ds,
    tokenizer=tokenizer,
    reward_fn=reward_fn,
    conversation_template=[{"role": "user", "content": "Q: {question}\nA:"}],
    data_batch_size_per_gpu=2,
)

# 1) Single-turn / reasoning datasets (ReasoningGym)
trained_pop = finetune_llm_reasoning(
    pop=[agent],
    env=reasoning_env,
    max_steps=2000,
    evaluation_interval=50,
)
# 2) Multi-turn text environments (factory + wrapper)
class ToyMultiTurnEnv:
    def reset(self, seed=None):
        del seed
        return "Start: What is 2+2?", {}

    def step(self, action: str):
        reward = 1.0 if "4" in action else 0.0
        return "Done.", reward, True, False, {"correct": bool(reward)}


def env_factory():
    return TokenObservationWrapper(
        env=ToyMultiTurnEnv(),
        tokenizer=tokenizer,
        max_turns=4,
        pad_id=tokenizer.eos_token_id,
        max_model_len=1024,
        max_output_tokens=128,
    )


trained_pop = finetune_llm_multiturn(
    pop=[agent],
    max_turns=4,
    env_factory=env_factory,
    max_steps=2000,
    evaluation_interval=50,
)
Saving and Loading Agents
To save an agent, use the save_llm_checkpoint function:
from agilerl.utils.utils import save_llm_checkpoint
save_llm_checkpoint(agent, "path/to/checkpoint")
Loading follows the standard Hugging Face from_pretrained flow for the base
model and any finetuned adapter.
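A minimal sketch of that loading flow, using the standard `transformers` and `peft` APIs, might look as follows. The helper name `load_finetuned_policy` is hypothetical, and the adapter directory passed to it is an assumption: check the layout actually written by your checkpoint before relying on a particular sub-directory name.

```python
def load_finetuned_policy(base_name: str, adapter_dir: str):
    """Load a base model, then attach a saved LoRA adapter on top of it."""
    # Imports are local so the sketch can be read without the heavy dependencies.
    import torch
    from peft import PeftModel
    from transformers import AutoModelForCausalLM, AutoTokenizer

    base = AutoModelForCausalLM.from_pretrained(
        base_name,
        torch_dtype=torch.bfloat16,
    )
    tokenizer = AutoTokenizer.from_pretrained(base_name)
    # PeftModel.from_pretrained wraps the base model with the adapter weights.
    model = PeftModel.from_pretrained(base, adapter_dir)
    return model, tokenizer


# Hypothetical usage; the adapter path below is an assumption, not a documented layout:
# model, tokenizer = load_finetuned_policy("Qwen/Qwen2.5-0.5B", "path/to/checkpoint/actor")
```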
Parameters
- class agilerl.algorithms.reinforce_llm.REINFORCE(*args: Any, **kwargs: Any)¶
Turn-level REINFORCE with Return Batch Normalization (ReBN) for LLM finetuning.
ReBN normalizes per-turn Monte Carlo returns across the entire batch of transitions. This gives per-turn credit assignment with arbitrary discount factors.
Optionally uses PPO-style clipped surrogate objectives for safe multi-epoch updates (controlled by clip_coef and update_epochs).
- Parameters:
pad_token_id (int) – Pad token id.
pad_token (str) – Pad token string.
model_name (str | None) – Model name or path.
actor_network (PreTrainedModelProtocol | None) – Pre-instantiated HuggingFace model.
model_config (dict[str, Any] | None) – Model configuration dict.
hp_config (HyperparameterConfig | None) – RL hyperparameter mutation configuration.
index (int) – Instance index for tournament selection.
batch_size (int) – Mini-batch size for learning.
beta (float) – KL penalty coefficient against the reference policy.
clip_coef (float) – PPO-style surrogate clipping coefficient.
gamma (float) – Discount factor for multi-turn returns.
lr (float) – Learning rate for the actor optimizer.
max_grad_norm (float) – Maximum gradient norm for clipping.
update_epochs (int) – Number of policy update epochs per batch.
temperature (float) – Sampling temperature for generation.
repetition_penalty (float) – Repetition penalty for generation.
top_p (float) – Top-p (nucleus) sampling parameter.
top_k (int) – Top-k sampling parameter.
min_p (float) – Min-p sampling parameter.
use_separate_reference_adapter (bool) – Use a dedicated LoRA adapter for the frozen reference policy.
calc_position_embeddings (bool) – Calculate position embeddings explicitly.
micro_batch_size_per_gpu (int | None) – Micro-batch size for gradient accumulation.
max_output_tokens (int | None) – Maximum new tokens per generation.
min_output_tokens (int | None) – Minimum new tokens per generation.
max_model_len (int | None) – Maximum context window length.
hf_generate_chunk_size (int | None, optional) – Number of prompts per HuggingFace generation chunk. Ignored when use_vllm=True.
use_memory_efficient_params (bool) – For colocated vLLM, offload the trainer’s own base to CPU during rollout (and bring it back for the training step) so the rollout engine and the trainer never both hold a base on the GPU. Defaults to True; inert without colocated vLLM, and disabled under DeepSpeed ZeRO-3.
lora_config (LoraConfigProtocol | None) – LoRA adapter configuration.
cosine_lr_schedule_config (CosineLRScheduleConfig | None) – Cosine LR schedule configuration.
accelerator (Accelerator | None) – HuggingFace Accelerator for distributed training.
device (str) – Device string.
wrap (bool) – Wrap models for distributed training upon creation.
clone (bool) – Whether this is a clone instantiation.
use_vllm (bool) – Use vLLM for generation.
vllm_config (VLLMConfig | None) – vLLM configuration.
seed (int) – Random seed.
advantage_granularity (Literal["turn", "token", "auto"]) – Policy-action granularity (ReBN advantage axis). "turn" enforces turn-level advantages, "token" enforces token-level advantages, and "auto" uses token-level only when all samples are single-turn.
action_granularity (str | None, optional) – Deprecated alias for advantage_granularity; when set it overrides advantage_granularity and emits a DeprecationWarning.
importance_sampling_level (Literal["token", "turn", "trajectory"], optional) – IS / ratio-pooling level for the clipped surrogate, orthogonal to advantage_granularity. "token" (default) clips per token; "turn" pools the ratio per turn (requires turn_ids in learn()); "trajectory" pools over the whole completion; the advantage is pooled to the same bucket. Turn/trajectory pooling cannot be token-chunked in the fused kernel, so set use_liger_loss=False there (the standard path is always memory-bounded).
turn_ratio_pooling (Literal["sum", "mean"], optional) – Reduction used to pool per-token log-ratios into a per-turn ratio when importance_sampling_level="turn"; ignored at token/trajectory level. "sum" (default) yields the product ratio per turn, which is the standard, paper-aligned per-turn importance weight. "mean" yields a length-normalized geometric-mean ratio (GSPO-style); reach for it on long or highly variable-length turns, where the product ratio lands far outside the clip band on every turn and saturates the clipped surrogate. Length-normalizing keeps the per-turn ratio in range so the surrogate stays informative.
gradient_checkpointing (bool) – Enable gradient checkpointing.
torch_compiler (str | None) – Torch compiler mode.
reduce_memory_peak (bool, optional) – Deprecated and ignored; previously hinted peak-memory batching. Configure micro_batch_size_per_gpu instead.
cast_logprobs_to_fp32 (bool, optional) – When True (default), run the per-token log-prob reduction (gather/logsumexp) in fp32 before casting back to the input dtype, for numerically stable log-probs. False runs it in the input dtype, saving a little memory at the cost of a per-token bf16 quantisation error that can bias importance-sampling ratios.
fused_logprobs_chunk_rows (int | None, optional) – Standard (non-Liger) path only. Rows (tokens) per (chunk_rows, vocab) logit tile when computing per-token log-probs via the fused-linear-logprob path. Peak logits memory is O(chunk_rows * vocab) regardless of batch/sequence length. None (default) auto-tunes to a ~256 MB fp32 tile.
use_liger_loss (bool, optional) – Use the Liger fused policy loss, defaults to False (requires liger-kernel). Recommended for REINFORCE: via AgileRL’s LigerFusedLinearPolicyLossFunction (the same liger-based path as PPO, not the upstream Liger GRPO kernel), it is roughly memory-neutral with a mild speedup that grows with sequence length at token-level IS. Separate from the Liger model patches (fused RMSNorm/RoPE/SwiGLU), which apply whenever liger-kernel is installed.
quantization_config (BitsAndBytesConfig | None, optional) – Optional transformers.BitsAndBytesConfig for loading the base model in 4-/8-bit (QLoRA). lm_head is kept unquantized so the fused-linear-logprob path stays numerically exact.
activation_offload (bool, optional) – When True, run the training forward inside torch.autograd.graph.save_on_cpu so tensors saved for backward live in pinned host RAM instead of GPU memory. Trades PCIe bandwidth for GPU memory (the win grows with sequence length); a no-op during rollout / reference forwards.
fused_loss_chunk_rows (int | None, optional) – Rows per (chunk_rows, vocab) logit tile in the token-level Liger fused policy loss. None (default) auto-tunes to a ~256 MB fp32 logit workspace, using the same heuristic as fused_logprobs_chunk_rows on the standard path; pass an int to override.
vllm_importance_sampling_correction (bool, optional) – When True (default) and use_vllm=True, correct the rollout/trainer log-prob mismatch by weighting each training token by clamp(exp(trainer - sampling), max=vllm_importance_sampling_cap). Active only for training rollouts; inert on the HuggingFace path and at eval.
vllm_importance_sampling_cap (float, optional) – Upper clamp on the vLLM importance-sampling ratio (default 2.0), bounding the correction weight to limit variance from outlier tokens. Must be > 0.
use_sequence_packing (bool, optional) – Opt in to padding-free sequence packing for the gradient forward pass. Only honoured under a FlashAttention-2 backend; otherwise inert.
lora_target_scope (str | None, optional) – Optional PEFT LoRA path scope for multimodal models (e.g. "language_model"). Passed to adapt_lora_config_for_model().
- clone(index: int | None = None, wrap: bool = True) Self¶
Create a clone of the algorithm.
- Parameters:
- Returns:
A clone of the algorithm
- Return type:
- static copy_attributes(agent: EvolvableAlgorithm, clone: EvolvableAlgorithm) EvolvableAlgorithm¶
Copy the non-evolvable attributes of the algorithm to a clone.
- Parameters:
clone (EvolvableAlgorithm) – The clone of the algorithm.
- Returns:
The clone of the algorithm.
- Return type:
- evolvable_attributes(networks_only: bool = False) dict[str, EvolvableModuleProtocol | ModuleDictProtocol | Optimizer | dict[str, Optimizer] | OptimizerWrapperProtocol]¶
Return the attributes related to the evolvable networks in the algorithm. Includes attributes that are either EvolvableModule or ModuleDict objects, as well as the optimizers associated with the networks.
- get_action(obs: list[ReasoningPrompts] | ReasoningPrompts, training: bool = True, **kwargs: Any) ActionResult¶
Generate completion tokens for each prompt in the batch.
- Parameters:
obs (LLMObsType) – A single prompt dict or a list of HF-style prompt dicts.
training (bool) – If False, use near-deterministic decoding where applicable.
kwargs (Any) – Additional keyword arguments accepted for base-class signature compatibility. Unused in this implementation.
- Returns:
An ActionResult of per-prompt completion token IDs and masks. When the vLLM sampling-mismatch correction is enabled (training rollouts on the vLLM path), sampling_logps carries the captured per-row sampling logprobs; otherwise it is None.
- Return type:
ActionResult
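The sampling log-probs returned here feed the vLLM sampling-mismatch correction described for vllm_importance_sampling_correction: each training token is weighted by clamp(exp(trainer - sampling), max=cap). The sketch below shows that formula on plain Python lists; the function name `vllm_is_weights` and the example log-probs are hypothetical, and how the weights enter the loss inside AgileRL is not shown.

```python
import math


def vllm_is_weights(trainer_logps, sampling_logps, cap=2.0):
    """Per-token weight min(exp(trainer - sampling), cap) for one trajectory."""
    if cap <= 0:
        raise ValueError("cap must be > 0")
    return [
        min(math.exp(trainer - sampling), cap)
        for trainer, sampling in zip(trainer_logps, sampling_logps)
    ]


# Hypothetical per-token log-probs for three generated tokens.
trainer_logps = [-1.0, -0.5, -0.2]   # recomputed by the trainer
sampling_logps = [-1.0, -2.0, -0.1]  # captured from the rollout engine
weights = vllm_is_weights(trainer_logps, sampling_logps, cap=2.0)
```

The first token has matching log-probs and keeps weight 1; the second has an uncapped ratio of exp(1.5) and is clamped to the cap of 2.0, which is what bounds the variance from outlier tokens.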
- static get_action_dim(action_space: Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary | list[Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary]) tuple[int, ...]¶
Return the dimension of the action space as it pertains to the underlying networks (i.e. the output size of the networks).
- Parameters:
action_space (spaces.Space or list[spaces.Space]) – The action space of the environment.
- Returns:
The dimension of the action space.
- Return type:
tuple[int, ...]
- get_policy() EvolvableModuleProtocol¶
Return the policy network of the algorithm.
- static get_state_dim(observation_space: Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary | list[Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary]) tuple[int, ...]¶
Return the dimension of the state space as it pertains to the underlying networks (i.e. the input size of the networks).
- static inspect_attributes(agent: EvolvableAlgorithm, input_args_only: bool = False) dict[str, Any]¶
Inspect and retrieve the attributes of the current object, excluding attributes related to the underlying evolvable networks (i.e. EvolvableModule, torch.optim.Optimizer) and with an option to include only the attributes that are input arguments to the constructor.
- learn(experiences: dict[str, ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts] | tuple[ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, ...], turn_ids: Tensor | None = None, sampling_logps: list[Tensor | None] | None = None) dict[str, float]¶
Update actor using REINFORCE with Return Batch Normalization.
- Parameters:
experiences (ExperiencesType) – (completion_ids, action_masks, rewards). For single-turn, rewards is a flat tensor of scalars; for multi-turn, shape [batch, max_turns] per-turn rewards.
turn_ids (torch.Tensor | None) – Optional [batch, seq_len - 1] tensor of turn indices per token; -1 for non-action tokens. If None, all action tokens are treated as turn 0.
sampling_logps (list[torch.Tensor | None] | None) – Optional per-row flat vLLM sampling logprobs (one 1-D tensor per trajectory, generated tokens only; concatenated across turns for multi-turn) for the vLLM sampling-mismatch correction. Parallel to the stacked completion_ids rows. None disables the correction for this update.
- Returns:
Dict with keys mean_loss, mean_kl, mean_pg_loss, mean_entropy, averaged over all minibatch updates.
- Return type:
dict[str, float]
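To make the role of turn_ids concrete, the sketch below pools per-token log-ratios into one importance ratio per turn, as described for importance_sampling_level="turn" and turn_ratio_pooling, and feeds the result into a PPO-style clipped surrogate. It is a plain-Python illustration under stated assumptions: the function names are hypothetical, the log-probs are made up, and the real update operates on batched tensors.

```python
import math


def pooled_turn_ratios(new_logps, old_logps, turn_ids, pooling="sum"):
    """Pool per-token log-ratios into one importance ratio per turn."""
    sums, counts = {}, {}
    for new, old, turn in zip(new_logps, old_logps, turn_ids):
        if turn < 0:  # -1 marks non-action tokens
            continue
        sums[turn] = sums.get(turn, 0.0) + (new - old)
        counts[turn] = counts.get(turn, 0) + 1
    if pooling == "mean":
        # Length-normalised (geometric-mean) ratio.
        return {t: math.exp(sums[t] / counts[t]) for t in sums}
    # Product of per-token ratios.
    return {t: math.exp(sums[t]) for t in sums}


def clipped_surrogate(ratio, advantage, clip_coef=0.2):
    """PPO-style clipped policy loss for a single ratio/advantage pair."""
    clipped_ratio = max(min(ratio, 1.0 + clip_coef), 1.0 - clip_coef)
    return -min(ratio * advantage, clipped_ratio * advantage)


# One 50-token turn where every token's log-prob moved by +0.01.
n_tokens = 50
old_logps = [-1.00] * n_tokens
new_logps = [-0.99] * n_tokens
turn_ids = [0] * n_tokens

sum_ratio = pooled_turn_ratios(new_logps, old_logps, turn_ids, "sum")[0]
mean_ratio = pooled_turn_ratios(new_logps, old_logps, turn_ids, "mean")[0]
```

With these numbers the product ratio is about exp(0.5), well outside a 0.2 clip band, while the length-normalized ratio is about exp(0.01) and stays inside it. This is the saturation effect that motivates "mean" pooling on long turns.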
- classmethod load(path: str, device: str | device = 'cpu', accelerator: Accelerator | None = None) None¶
Load an algorithm from a checkpoint.
- Parameters:
path (string) – Location to load checkpoint from.
device (str, optional) – Device to load the algorithm on, defaults to ‘cpu’
accelerator (Accelerator | None, optional) – Accelerator object for distributed computing, defaults to None
- Returns:
An instance of the algorithm
- Return type:
- load_checkpoint(path: str, load_optimizer: bool = False, overwrite_reference_adapter: bool = False, overwrite_critic_adapter: bool = True) None¶
Load adapter weights and algorithm state from a checkpoint directory.
Adapter roles restored on load:
- actor: the trained policy. Always loaded.
- reference: the fixed policy used for KL / comparison. The checkpoint’s actor adapter is copied onto reference so that SFT -> DPO -> GRPO chains work out of the box: the stage-N actor becomes the stage-N+1 reference.
- critic: optional value head. Loaded from disk if a critic/ adapter is present, else copied from actor, else left as the live fresh LoRA init.
The checkpoint’s LoRA config must match the live algorithm’s config; a mismatch raises ValueError (re-create the agent with the checkpoint’s LoRA config to load it).
- No DeepSpeed:
  - lora_only=T, load_optimizer=T -> PEFT adapter load + optimizer state from attributes.pt
  - lora_only=T, load_optimizer=F -> PEFT adapter load only
  - lora_only=F, load_optimizer=T -> torch load of actor + optimizer from attributes.pt
  - lora_only=F, load_optimizer=F -> torch load of actor only
- DeepSpeed:
  - lora_only=T, load_optimizer=T -> DeepSpeed engine load from <path>/save_checkpoint
  - lora_only=T, load_optimizer=F -> PEFT adapter load
  - lora_only=F, load_optimizer=T -> DeepSpeed engine load from <path>/save_checkpoint
  - lora_only=F, load_optimizer=F -> actor.load_state_dict(...) from attributes.pt
When load_optimizer=True but the checkpoint contains no optimizer state (e.g. it was saved with save_optimizer=False), a UserWarning is emitted and a freshly-initialised optimizer is used.
- Parameters:
path (str) – Directory containing a checkpoint written by save_checkpoint().
load_optimizer (bool) – If True, also load the optimizer and LR scheduler state so training can resume; defaults to False in the signature above. On DeepSpeed ZeRO ≥ 2 this reads a sharded checkpoint from <path>/save_checkpoint; otherwise optimizer state is read from attributes.pt.
- classmethod population(size: int, observation_space: Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary | list[Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary], action_space: Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary | list[Box | Discrete | MultiDiscrete | Dict | Tuple | MultiBinary], wrapper_cls: type[SelfAgentWrapper] | None = None, wrapper_kwargs: dict[str, Any] | None = None, **kwargs) list[Self | SelfAgentWrapper]¶
Create a population of algorithms.
- Parameters:
size (int) – The size of the population.
- Returns:
A list of algorithms.
- Return type:
- preprocess_observation(observation: ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts) Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor]¶
Preprocess observations (dummy) for forward pass through neural network.
- recompile() None¶
Recompile evolvable modules with torch.compile.
Iterates over evolvable_attributes and compiles each one. Skipped when DeepSpeed is active because DeepSpeedEngine is not compatible with OptimizedModule wrapping.
- register_mutation_hook(hook: Callable) None¶
Register a hook to be executed after a mutation is performed on the algorithm.
- Parameters:
hook (Callable) – The hook to be executed after mutation.
- register_network_group(group: NetworkGroup) None¶
Set the evaluation network for the algorithm.
- Parameters:
name (str) – The name of the evaluation network.
- reinit_optimizers(optimizer: OptimizerConfig | None = None) None¶
Reinitialize the optimizers of an algorithm. If no optimizer is passed, all optimizers are reinitialized.
- Parameters:
optimizer (OptimizerConfig | None, optional) – The optimizer to reinitialize, defaults to None, in which case all optimizers are reinitialized.
- save_checkpoint(path: str, lora_only: bool = True, save_optimizer: bool = True, **kwargs: Any) None¶
Save adapter weights and algorithm state to a directory.
AgileRL never persists base-model weights when lora_only=True for LLM algorithms: a checkpoint is a directory containing
- <adapter>/adapter_model.safetensors + adapter_config.json: one subdirectory per adapter in selected_adapters (always actor, plus reference / critic when those adapters are configured). Written only when lora_only=True.
- attributes.pt: algorithm hyperparameters, plus (optionally) the actor state dict and/or optimizer state dict depending on the cell below. Always present.
- save_checkpoint/: DeepSpeed ZeRO ≥ 2 sharded-checkpoint output. Present only when an Accelerator is attached and save_optimizer=True.
Behaviour per cell of the (lora_only, save_optimizer, deepspeed) grid:
- Plain (no accelerator):
  - lora_only=T, save_optimizer=T -> PEFT adapter dirs on disk + optimizer state in attributes.pt
  - lora_only=T, save_optimizer=F -> PEFT adapter dirs only
  - lora_only=F, save_optimizer=T -> full actor state_dict + optimizer state in attributes.pt
  - lora_only=F, save_optimizer=F -> full actor state_dict in attributes.pt
- DeepSpeed:
  - lora_only=T, save_optimizer=T -> engine tag dir (frozen params excluded) + PEFT adapter dirs
  - lora_only=T, save_optimizer=F -> PEFT adapter dirs only
  - lora_only=F, save_optimizer=T -> engine tag dir (frozen params included)
  - lora_only=F, save_optimizer=F -> gathered (ZeRO-3 aware) actor state_dict injected into attributes.pt
- Parameters:
path (str) – Directory to write the checkpoint into.
lora_only (bool) – If True (default) only adapter weights are written to disk via save_pretrained; the base model is shared across checkpoints and not serialised. If False, the full actor state dict is persisted (into attributes.pt on the plain path, or into the DeepSpeed engine’s tag dir / gathered dict on the distributed path).
save_optimizer (bool) – If True (default) also persist the optimizer and LR scheduler state so training can resume. On DeepSpeed ZeRO ≥ 2 this writes a sharded checkpoint into <path>/save_checkpoint; otherwise optimizer state is included in attributes.pt.
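The (lora_only, save_optimizer, deepspeed) grid above can be read as a small decision function. The helper below is not part of AgileRL; `checkpoint_artifacts` and its label strings are hypothetical and simply encode the eight cells of the documented grid, which can be handy when writing tests or tooling that inspects a checkpoint directory.

```python
def checkpoint_artifacts(lora_only: bool, save_optimizer: bool, deepspeed: bool):
    """Return what a checkpoint is expected to contain for one grid cell."""
    # attributes.pt is documented as always present (hyperparameters).
    artifacts = ["attributes.pt (hyperparameters)"]
    if lora_only:
        artifacts.append("PEFT adapter dirs")
    if deepspeed:
        if save_optimizer:
            artifacts.append("DeepSpeed engine tag dir")
        elif not lora_only:
            artifacts.append("gathered actor state_dict in attributes.pt")
    else:
        if not lora_only:
            artifacts.append("full actor state_dict in attributes.pt")
        if save_optimizer:
            artifacts.append("optimizer state in attributes.pt")
    return artifacts


# Enumerate every cell of the grid.
grid = {
    (lora, opt, ds): checkpoint_artifacts(lora, opt, ds)
    for lora in (True, False)
    for opt in (True, False)
    for ds in (True, False)
}
```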
- select_adapter(adapter_name: str) None¶
Temporarily switch adapter; restores the actor adapter on exit.
- Parameters:
adapter_name (str) – Name of the adapter to activate (“actor”, “critic”, “reference”).
- set_reference_policy(reference_update_tracker: int) None¶
Update the reference policy when the tracker advances past the stored value.
Base weights are immutable in AgileRL’s LoRA-only training: with use_separate_reference_adapter=True the actor adapter is copied onto the reference adapter; without one the implicit reference (the base model with adapters disabled) cannot move, so the update request is acknowledged with a one-time warning and the KL anchor stays the initial policy.
- Parameters:
reference_update_tracker (int) – The reference policy update tracker
- set_training_mode(training: bool) None¶
Set the training mode of the algorithm.
- Parameters:
training (bool) – If True, set the algorithm to training mode.
- test(env: ReasoningGym | MultiTurnEnv, loop: int = 1) Tensor¶
Return the fitness (test) score tensor of the LLM on the test subset.
ReasoningGym (and compatible dataset envs): reset returns a batch of prompt dicts; each step accepts completion id tensors and returns the next batch plus rewards. loop iterations advance the test dataloader that many times.
- Parameters:
env (ReasoningGym | MultiTurnEnv) – A ReasoningGym or TokenObservationWrapper.
loop (int) – Number of outer test iterations (dataloader passes or episodes).
- Returns:
Concatenated per-step rewards from the test loop.
- Return type:
torch.Tensor
- to_device(*experiences: Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor]) tuple[Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor], ...]¶
Move experiences to the device.
- static update_lr(optimizer: torch.optim.Optimizer, lr: float | tuple[float, float], accelerator: Accelerator | None = None, scheduler_config: CosineLRScheduleConfig | None = None) tuple[Accelerator | None, SequentialLR | None]¶
Update the learning rate of the optimizer.
- Parameters:
optimizer (Optimizer) – Optimizer
lr (float | tuple[float, float]) – Learning rate value, or actor/critic pair.
accelerator (Accelerator | None) – Accelerator
scheduler_config (CosineLRScheduleConfig | None) – Scheduler configuration
- Returns:
Tuple of accelerator and scheduler
- Return type:
tuple[Accelerator | None, SequentialLR | None]
- use_adapter(adapter_name: str) None¶
Switch the active PEFT adapter, handling all side-effects.
For “reference”: switches adapter and freezes reference params (never trained). For all others: switches adapter and restores requires_grad=True on all training adapter LoRA params so that DeepSpeed ZeRO-2 gradient bucket hooks keep firing correctly.
- Parameters:
adapter_name (str) – Name of the adapter to activate (“actor”, “critic”, “reference”).