Implicit Language Q-Learning (ILQL)

ILQL is an extension of Implicit Q-Learning (IQL) that can be used to finetune large language models (LLMs) with reinforcement learning from human feedback (RLHF).

Parameters

class agilerl.algorithms.ilql.ILQL(dataset: RL_Dataset, net_config: dict[str, dict[str, Any] | Any] | None = None, index: int = 0, batch_size: int = 64, lr: float = 1e-05, alpha: float = 0.005, beta: float = 0.0, gamma: float = 0.99, tau: float = 0.6, mutation: str | None = None, transition_weight: float = 0.0, clip_weight: float | None = None, value_max: float | None = None, value_min: float | None = None, detach_v: bool = False, detach_q: bool = False, detach_pi: bool = False, double_q: bool = True, per_token: bool = True, exp_weights: bool = True, dm_margin: float = 0.0, cql_temp: float = 1.0, weight_decay: float = 0.0, device: str | device = 'cpu')

The Implicit Language Q-Learning algorithm class. ILQL paper: https://arxiv.org/pdf/2206.11871.pdf.

Parameters:
  • dataset (torch.utils.data.Dataset) – Language dataset to perform ILQL on

  • net_config (dict, optional) – Network configuration, defaults to GPT2 configuration

  • index (int, optional) – Index to keep track of object instance during tournament selection and mutation, defaults to 0

  • batch_size (int, optional) – Size of batched sample from replay buffer for learning, defaults to 64

  • lr (float, optional) – Learning rate for optimizer, defaults to 1e-5

  • alpha (float, optional) – Interpolation coefficient for the soft update of target network parameters, defaults to 0.005

  • beta (float, optional) – For AWR policy extraction, defaults to 0.0

  • gamma (float, optional) – Discount factor, defaults to 0.99

  • tau (float, optional) – Expectile used in the value network loss, defaults to 0.6

  • mutation (str, optional) – Most recent mutation to agent, defaults to None

  • transition_weight (float, optional) – Value to use temporarily for weights in transition, defaults to 0.0

  • clip_weight (float, optional) – Maximum value to clip weights at, defaults to None

  • value_max (float, optional) – Maximum Q value for clipping, defaults to None

  • value_min (float, optional) – Minimum Q value for clipping, defaults to None

  • detach_v (bool, optional) – Detach V network, defaults to False

  • detach_q (bool, optional) – Detach Q network, defaults to False

  • detach_pi (bool, optional) – Detach Policy network, defaults to False

  • double_q (bool, optional) – Use double Q learning, defaults to True

  • per_token (bool, optional) – Do per_token ILQL, defaults to True

  • exp_weights (bool, optional) – Exponential advantage weights, defaults to True

  • dm_margin (float, optional) – Margin for DM loss, defaults to 0.0

  • cql_temp (float, optional) – Temperature parameter for CQL loss, defaults to 1.0

  • weight_decay (float, optional) – Weight decay for optimizer, defaults to 0.0

  • device (str, optional) – Device for accelerated computing, ‘cpu’ or ‘cuda’, defaults to ‘cpu’
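
The role of tau in the value network loss can be illustrated with a minimal sketch. It assumes the loss is the expectile regression used in the IQL and ILQL papers, where positive and negative errors are weighted asymmetrically. The function names below are hypothetical and are not part of the AgileRL API; plain Python floats stand in for tensors.

```python
def expectile_loss(diff: float, tau: float = 0.6) -> float:
    """Asymmetric squared error: |tau - 1(diff < 0)| * diff**2.

    `diff` is assumed to be (target Q - V). With tau > 0.5,
    underestimating V (diff >= 0) is penalised more than overestimating it.
    """
    weight = tau if diff >= 0 else 1.0 - tau
    return weight * diff * diff


def mean_expectile_loss(q_values, v_values, tau: float = 0.6) -> float:
    """Average the expectile loss over paired Q and V estimates."""
    losses = [expectile_loss(q - v, tau) for q, v in zip(q_values, v_values)]
    return sum(losses) / len(losses)
```

With tau = 0.5 this reduces to half the ordinary mean squared error; values closer to 1 push V toward the upper range of the Q-values seen in the dataset.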

clean_up() → None

Clean up the networks.

clone(index: int | None = None) → Self

Return cloned agent identical to self.

Parameters:
  • index (int, optional) – Index to keep track of agent for tournament selection and mutation, defaults to None

forward(tokens: Tensor, state_idxs: Tensor, action_idxs: Tensor, attn_mask: Tensor | None = None, prefix_embs: Tensor | None = None, prefix_attn_mask: Tensor | None = None, remove_prefix_position_embs: bool = False, qv_kwargs: dict[str, Any] | None = None, policy_kwargs: dict[str, Any] | None = None, target_kwargs: dict[str, Any] | None = None, skip_policy_on_train: bool = False, detach_full_policy: bool = False) → dict[str, Any]

Forward pass through transformers.

Parameters:
  • tokens (torch.Tensor) – Tokens to input to model

  • state_idxs (torch.Tensor) – State indexes

  • action_idxs (torch.Tensor) – Action indexes

  • attn_mask (torch.Tensor, optional) – Attention mask for transformers, defaults to None

  • prefix_embs (torch.Tensor, optional) – Prefix embeddings, defaults to None

  • skip_policy_on_train (bool, optional) – Skip policy language model when training, defaults to False

  • detach_full_policy (bool, optional) – Use policy language model without gradients, defaults to False

hard_update() → None

Hard update target networks.

load_checkpoint(path: str | Path) → None

Load saved agent properties and network weights from checkpoint.

Parameters:
  • path (string) – Location to load checkpoint from

save_checkpoint(path: str | Path) → None

Save a checkpoint of agent properties and network weights to path.

Parameters:
  • path (string) – Location to save checkpoint at
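
The checkpoint methods above can be exercised with a simple save-then-load round trip. The sketch below only relies on the two documented method names and their single path argument. DummyAgent is a hypothetical stand-in so the example runs without a dataset or model, and its JSON file format is an assumption for illustration, not the format AgileRL writes.

```python
import json
import tempfile
from pathlib import Path


class DummyAgent:
    """Hypothetical stand-in exposing save_checkpoint / load_checkpoint."""

    def __init__(self, lr: float = 1e-5):
        self.lr = lr

    def save_checkpoint(self, path) -> None:
        # Assumption: a tiny JSON file, purely for illustration.
        Path(path).write_text(json.dumps({"lr": self.lr}))

    def load_checkpoint(self, path) -> None:
        self.lr = json.loads(Path(path).read_text())["lr"]


def checkpoint_roundtrip(agent, path):
    """Save the agent's state to `path`, then restore it from the same file."""
    agent.save_checkpoint(path)
    agent.load_checkpoint(path)
    return agent


with tempfile.TemporaryDirectory() as tmp_dir:
    restored = checkpoint_roundtrip(DummyAgent(lr=3e-4), Path(tmp_dir) / "ilql.ckpt")
```

With a real ILQL agent, the same helper would be called with the agent in place of DummyAgent.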

soft_update() → None

Soft update target networks.
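
The difference between the soft and hard updates can be sketched in a few lines. The sketch assumes the common Polyak-averaging convention, in which alpha is the fraction of the online weights mixed into the target on each call; whether AgileRL applies alpha in exactly this direction is an assumption. The helper name and the list-of-floats representation of parameters are hypothetical.

```python
def soft_update_params(online, target, alpha: float = 0.005):
    """Return new target parameters: alpha * online + (1 - alpha) * target.

    Assumption: alpha weights the online network, so small values
    make the target network track the online network slowly.
    """
    return [alpha * o + (1.0 - alpha) * t for o, t in zip(online, target)]


online_params = [1.0, -2.0]
target_params = [0.0, 0.0]

# One soft step moves the target a small distance toward the online weights.
soft_target = soft_update_params(online_params, target_params, alpha=0.005)

# Under this convention, alpha = 1.0 copies the online weights outright,
# which is what a hard update does.
hard_target = soft_update_params(online_params, target_params, alpha=1.0)
```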