LLM Utils

Helpers for LLM training (DeepSpeed ZeRO, model loading, eval sampling). For Gymnasium-style LLM datasets, see agilerl.llm_envs (also re-exported here for backwards compatibility).

agilerl.utils.llm_utils.gather_if_zero3(zero_stage: int, params: list[Tensor], modifier_rank: int | None = None) → Generator[None, None, None]

Conditional context manager that gathers the given parameters when the model is trained with ZeRO stage 3.

Parameters:
  • zero_stage (int) – The DeepSpeed ZeRO stage in use

  • params (list[torch.Tensor]) – The parameters to gather

  • modifier_rank (int | None) – The modifier rank
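The conditional-context-manager pattern described above can be sketched in plain Python. This is a minimal illustration and not AgileRL's implementation: gather_if_stage3_sketch and fake_gather are hypothetical names, and the assumption that gathering happens only at stage 3 is inferred from the function name. In a DeepSpeed setup the stand-in would presumably be something like deepspeed.zero.GatheredParameters(params, modifier_rank=modifier_rank).

```python
from contextlib import AbstractContextManager, contextmanager
from typing import Callable, Iterator


@contextmanager
def gather_if_stage3_sketch(
    zero_stage: int,
    params: list,
    gather: Callable[[list], AbstractContextManager],
) -> Iterator[None]:
    """Hypothetical sketch: enter `gather(params)` only when zero_stage == 3."""
    if zero_stage == 3:
        # Parameters are partitioned across ranks, so gather them for the body.
        with gather(params):
            yield
    else:
        # Stages 0-2 keep full parameters on every rank; nothing to gather.
        yield


events = []


@contextmanager
def fake_gather(params):
    """Stand-in for a real gather context manager (assumption, for illustration)."""
    events.append("gathered")
    try:
        yield
    finally:
        events.append("released")


with gather_if_stage3_sketch(2, ["w"], fake_gather):
    events.append("body@stage2")

with gather_if_stage3_sketch(3, ["w"], fake_gather):
    events.append("body@stage3")
```

With this shape, calling code can always wrap parameter access in the same with block and let the stage decide whether any gathering takes place.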

agilerl.utils.llm_utils.get_state_dict(model: Module) → dict[str, Tensor]

Get the state dict of a model trained with DeepSpeed ZeRO stage 3.

Parameters:

model (nn.Module) – The model to get the state dict of.

Returns:

The state dict of the model.

Return type:

dict[str, torch.Tensor]

agilerl.utils.llm_utils.create_model_from_name_or_path(model_name_or_path: str, model_config: dict[str, Any] | None = None, add_value_head: bool = False, use_accelerator: bool = False) → PreTrainedModel

Create a model from a name or path.

Parameters:
  • model_name_or_path (str) – The name or path of the model to create.

  • model_config (dict[str, Any] | None) – The configuration of the model to create.

  • add_value_head (bool, optional) – Flag to indicate if a value head should be added to the model, defaults to False

  • use_accelerator (bool, optional) – Flag to indicate if the model should be created with the accelerator, defaults to False

Returns:

The created model.

Return type:

PreTrainedModel

agilerl.utils.llm_utils.sample_eval_prompts(env: Any, n: int = 5, seed: int = 0) → list[tuple[str, str | None, str | None]]

Randomly sample n (prompt, chosen, rejected) triples from env’s held-out test dataset.

Columns are resolved automatically per gym type:

  • SFTGym: chosen is env.response_column; rejected is None (SFT has no negative example).

  • PreferenceGym: chosen and rejected map to the dataset’s "chosen" / "rejected" columns.

  • Any other gym: both are None.

Parameters:
  • env – AgileRL gym environment with a test_dataloader attribute.

  • n (int, optional) – Number of samples to draw, defaults to 5.

  • seed (int, optional) – Random seed for reproducible sampling, defaults to 0.

Returns:

List of (prompt, chosen, rejected) tuples; unused fields are None.

Return type:

list[tuple[str, str | None, str | None]]
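The seeded sampling and per-gym column mapping can be illustrated with a small pure-Python sketch. It is not the library's code: sample_triples_sketch, the "prompt" and "answer" column names, and the in-memory rows list are all hypothetical stand-ins for the env's test dataset.

```python
import random
from typing import Optional

Triple = tuple[str, Optional[str], Optional[str]]


def sample_triples_sketch(
    rows: list[dict],
    n: int = 5,
    seed: int = 0,
    chosen_col: Optional[str] = None,
    rejected_col: Optional[str] = None,
) -> list[Triple]:
    """Hypothetical sketch: draw up to n rows reproducibly as (prompt, chosen, rejected)."""
    rng = random.Random(seed)  # local RNG, so the global random state is untouched
    picked = rng.sample(rows, k=min(n, len(rows)))
    return [
        (
            row["prompt"],
            row[chosen_col] if chosen_col else None,      # None when no column applies
            row[rejected_col] if rejected_col else None,  # e.g. SFT has no negative example
        )
        for row in picked
    ]


# Hypothetical in-memory "test split" with an SFT-style response column.
rows = [{"prompt": f"Q{i}", "answer": f"A{i}"} for i in range(10)]

sft_like = sample_triples_sketch(rows, n=3, seed=0, chosen_col="answer")
sft_again = sample_triples_sketch(rows, n=3, seed=0, chosen_col="answer")
other_gym = sample_triples_sketch(rows, n=3, seed=0)
```

Here sft_like and sft_again contain the same triples because the seed is the same, and other_gym carries None in both the chosen and rejected positions.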

agilerl.utils.llm_utils.compare_responses(agent: Any, tokenizer: Any, samples: list[tuple[str, str | None, str | None]], max_new_tokens: int = 200, temperature: float = 1.0, do_sample: bool = False, skip_special_tokens: bool = True, show_base_model: bool = True) → None

Run each prompt through the base model and the fine-tuned LoRA model, printing a formatted comparison to the terminal one sample at a time.

After each sample the user is prompted to press Enter to continue or q + Enter to quit early. Intended to be called at the end of a training script for a quick qualitative sanity-check.

Works with any LoRA-adapted LLMAlgorithm (SFT, DPO, …). When the model has no LoRA adapter the base-model column is omitted and only the current model’s output is shown.

Parameters:
  • agent (LLMAlgorithm) – Trained AgileRL LLM agent exposing agent.actor and agent.device.

  • tokenizer – HuggingFace tokenizer matching the model.

  • samples (list[tuple[str, str | None, str | None]]) – (prompt, chosen, rejected) triples as returned by sample_eval_prompts(). None fields are silently skipped.

  • max_new_tokens (int, optional) – Maximum tokens to generate per response, defaults to 200.

  • temperature (float, optional) – Sampling temperature, defaults to 1.0.

  • do_sample (bool, optional) – Use sampling instead of greedy decoding, defaults to False. Set to True, optionally together with a temperature other than 1.0, for stochastic outputs.

  • skip_special_tokens (bool, optional) – Strip special tokens when decoding, defaults to True.

  • show_base_model (bool, optional) – If False, skip the base-model generation block (only the current model output is printed). Useful when the adapter is merged or base vs. adapter outputs are identical.
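A possible end-of-training call sequence, combining sample_eval_prompts() and compare_responses() with the signatures documented above, might look like the following. The wrapper name run_qualitative_check and its default argument values are hypothetical; agent, tokenizer, and env are assumed to be the objects already built by the training script.

```python
def run_qualitative_check(agent, tokenizer, env, n=3, seed=42):
    """Hypothetical helper: sample held-out prompts, then print a comparison."""
    # Imported lazily so this snippet can be defined without AgileRL installed.
    from agilerl.utils.llm_utils import compare_responses, sample_eval_prompts

    # Reproducible (prompt, chosen, rejected) triples from the env's test data.
    samples = sample_eval_prompts(env, n=n, seed=seed)

    # Greedy decoding keeps the base vs. fine-tuned comparison deterministic.
    compare_responses(
        agent,
        tokenizer,
        samples,
        max_new_tokens=100,
        do_sample=False,
        show_base_model=True,
    )
    return samples
```

Because compare_responses() waits for Enter (or q + Enter) after each sample, a helper like this is best suited to interactive runs rather than unattended jobs.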