Fused linear log-prob optimizations
When you train an LLM with reinforcement learning, the model still has to turn
each hidden vector into a score for every vocabulary token. That intermediate
usually has shape (batch, sequence length, vocab size). For large vocabularies
(typically more than 100k tokens), that tensor alone can dominate GPU memory,
even though many algorithms only need the log-probability of the token that was
actually chosen at each position, a much smaller (batch, sequence length) result.
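To make the shape reduction concrete, here is a minimal pure-Python sketch (not AgileRL code) of how one log-probability per position is obtained from a row of vocabulary scores: the chosen token's logit minus the logsumexp of the row. The function name `selected_token_logprobs` and the toy numbers are illustrative assumptions; real training code would do the same arithmetic on GPU tensors.

```python
import math


def selected_token_logprobs(logits, chosen):
    """Return log p(chosen[t]) for every position t.

    logits: list of T rows, each holding V vocabulary scores
    chosen: list of T token ids (the tokens that were actually sampled)
    """
    out = []
    for row, tok in zip(logits, chosen):
        m = max(row)  # subtract the row max so exp() cannot overflow
        lse = m + math.log(sum(math.exp(x - m) for x in row))
        out.append(row[tok] - lse)  # one scalar survives per position
    return out


# Toy input (hypothetical): T = 2 positions, V = 3 vocabulary entries.
logits = [[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]]
chosen = [0, 2]
logprobs = selected_token_logprobs(logits, chosen)
```

The input holds T × V numbers while the result holds only T, which is the gap the fused paths exploit.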
AgileRL offers two optional speed/memory switches. They are meant to implement the same training objective as the standard code, up to normal floating-point differences.
- use_liger_loss defaults to None, which resolves to True when liger-kernel is importable and False otherwise. Pass False explicitly to force the standard path even when Liger is installed.
- use_fused_linear_logprobs defaults to None, which resolves to the same value as the resolved use_liger_loss. Without Liger the gradient-time path already materializes the full (B, T, V) logits, so fusing the no-grad rollout alone wouldn't lower overall peak, and it would force the model to expose a discoverable lm_head. Pass True explicitly to fuse only the rollout side without enabling Liger.
| Flag | What it does | When it runs |
|---|---|---|
| use_fused_linear_logprobs | Chunked … | Rollout-side work only (e.g. “old” and reference log-probs) when gradients are off; no impact on how the policy loss backprops. |
| use_liger_loss | Fused chunking for the policy / KL part of the loss, including backward through … | While the loss is being differentiated (PPO, REINFORCE, GRPO, CISPO, GSPO family). |
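The rollout-side idea can be sketched in plain Python. This is an illustration under stated assumptions, not AgileRL's implementation: it assumes chunking runs over token positions, and `project`, `logprob_of`, and `chunked_logprobs` are hypothetical names. The point is that logits exist only for the current chunk, yet the per-token log-probs match the unchunked result.

```python
import math


def project(hidden_row, head_weight):
    """Hypothetical stand-in for the LM head: one hidden vector -> V logits."""
    return [sum(h * w for h, w in zip(hidden_row, w_row)) for w_row in head_weight]


def logprob_of(row, tok):
    """Log-probability of token `tok` given one row of logits."""
    m = max(row)
    return row[tok] - (m + math.log(sum(math.exp(x - m) for x in row)))


def chunked_logprobs(hidden, head_weight, chosen, chunk_size):
    """Compute per-position log-probs while holding at most chunk_size rows of logits."""
    out = []
    peak_rows = 0
    for start in range(0, len(hidden), chunk_size):
        stop = start + chunk_size
        # Logits are materialized for this chunk of positions only.
        chunk_logits = [project(h, head_weight) for h in hidden[start:stop]]
        peak_rows = max(peak_rows, len(chunk_logits))
        out.extend(logprob_of(r, t) for r, t in zip(chunk_logits, chosen[start:stop]))
        del chunk_logits  # release this chunk before the next one is built
    return out, peak_rows


# Toy data (hypothetical): 5 positions, hidden size 2, vocabulary size 3.
hidden = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0], [0.5, -0.5], [-1.0, 2.0]]
head_weight = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
chosen = [0, 1, 2, 0, 1]

full, full_peak = chunked_logprobs(hidden, head_weight, chosen, chunk_size=len(hidden))
chunked, chunk_peak = chunked_logprobs(hidden, head_weight, chosen, chunk_size=2)
```

Here `full_peak` is 5 rows and `chunk_peak` is 2 rows, while `full` and `chunked` hold the same values. Because nothing is differentiated in this path, no intermediate needs to be kept for a backward pass.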
use_fused_linear_logprobs is pure AgileRL code and does not require
liger-kernel. use_liger_loss does require liger-kernel; if you
explicitly pass True without it installed you get a warning and the flag
is turned off. Leaving it at the default None simply opts out of the
fused loss when Liger is missing, with no warning. If you combine use_liger_loss
with LoRA, lm_head is excluded from the LoRA adapters (with a warning) because
the fused kernel expects a single full head weight matrix.
cast_logprobs_to_fp32 (on LLMAlgorithm) controls whether the
chunked log-prob reductions in the standard and
use_fused_linear_logprobs paths run the numerically sensitive
logsumexp / gather steps in fp32, then cast back. It defaults to None,
which resolves to the same value as the resolved use_liger_loss: Liger
users get fp32 (consistent with Liger’s own gradient-time math and cheap
because the fused rollout workspace is small), while non-Liger users get
bf16 to avoid a ~10 GB fp32 chunk workspace landing on top of the full
(B, T, V) bf16 logits on the unfused grad path. Set True explicitly
if you want fp32 precision and have the memory headroom; set False to
force bf16 even with Liger. Note that the Liger gradient-time kernels use
their own fused math and ignore this flag for the loss backward — it
only governs the rollout-side log-prob reductions.
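The defaulting rules for the three flags can be summarized in a small sketch. `resolve_flags` and its `liger_available` argument are hypothetical names used only for illustration; this mirrors the behaviour described in the text and is not the library's actual code or warning message.

```python
import warnings


def resolve_flags(use_liger_loss=None, use_fused_linear_logprobs=None,
                  cast_logprobs_to_fp32=None, liger_available=False):
    """Hypothetical helper: turn None defaults into concrete booleans."""
    if use_liger_loss is None:
        # Default: follow whether liger-kernel is importable, silently.
        use_liger_loss = liger_available
    elif use_liger_loss and not liger_available:
        # Explicit True without Liger: warn, then fall back.
        warnings.warn("liger-kernel is not installed; disabling use_liger_loss")
        use_liger_loss = False

    # Both remaining defaults follow the *resolved* use_liger_loss.
    if use_fused_linear_logprobs is None:
        use_fused_linear_logprobs = use_liger_loss
    if cast_logprobs_to_fp32 is None:
        cast_logprobs_to_fp32 = use_liger_loss

    return use_liger_loss, use_fused_linear_logprobs, cast_logprobs_to_fp32


with_liger = resolve_flags(liger_available=True)
without_liger = resolve_flags(liger_available=False)
rollout_only = resolve_flags(use_fused_linear_logprobs=True, liger_available=False)
```

Under these assumptions `with_liger` is `(True, True, True)`, `without_liger` is `(False, False, False)`, and `rollout_only` is `(False, True, False)`, the last being the case where only the rollout side is fused and the reductions stay in bf16.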
Usage
from agilerl.algorithms import GRPO, CISPO, GSPO, LLMPPO, LLMREINFORCE
# Both fused paths are on by default when liger-kernel is installed.
agent = GRPO(...)
# Force the standard PyTorch paths even if Liger is available.
agent = GRPO(
    ...,
    use_liger_loss=False,
    use_fused_linear_logprobs=False,
)
Tiny batches (only a few hundred tokens total) may not see much benefit from chunking and can even be slightly slower; very large sequences may still run out of memory for non-vocabulary reasons (attention, backbone activations).
Example: what changes in memory?
Illustrative peak workspace for the vocabulary projection only: same batch, sequence, and vocab, comparing storing full logits once versus fusing/chunking so only a thin slice of vocab scores exists at a time. Numbers are order-of-magnitude; real runs add the rest of the model on top.
| Setting | Dominant temporary tensor | Rough size (bf16) for … |
|---|---|---|
| Standard | Logits | ~5 GB for that tensor alone |
| Fused / chunked | One chunk of logits | ~0.3 GB peak for that slice (≈10–50× smaller slice, depending on chunk size) |
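The arithmetic behind figures of this magnitude is simple to reproduce. The shape and chunk size below are assumptions chosen for illustration (batch 8, sequence 2048, vocabulary 151,936, chunks of 1024 token positions); they are not the configuration the table was measured with.

```python
def tensor_bytes(shape, bytes_per_element):
    """Size in bytes of a dense tensor with the given shape."""
    n = 1
    for dim in shape:
        n *= dim
    return n * bytes_per_element


# Hypothetical sizes, not taken from the table above.
B, T, V = 8, 2048, 151_936
BF16_BYTES = 2
CHUNK_TOKENS = 1024  # hypothetical number of token positions per chunk

full_logits = tensor_bytes((B, T, V), BF16_BYTES)
one_chunk = tensor_bytes((CHUNK_TOKENS, V), BF16_BYTES)
selected = tensor_bytes((B, T), BF16_BYTES)

print(f"full logits : {full_logits / 1e9:.2f} GB")   # 4.98 GB
print(f"one chunk   : {one_chunk / 1e9:.2f} GB")     # 0.31 GB
print(f"ratio       : {full_logits // one_chunk}x")  # 16x
print(f"log-probs   : {selected / 1e3:.1f} kB")      # 32.8 kB
```

Under these assumed sizes the full logits tensor is about 16 times larger than one chunk, and the final (B, T) log-prob result is negligible next to either.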