Implicit Language Q-Learning (ILQL)
ILQL is an extension of Implicit Q-Learning (IQL) that can be used to fine-tune large language models (LLMs) with reinforcement learning from human feedback (RLHF).
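Implicit Q-Learning, which ILQL builds on, avoids querying actions outside the dataset in two ways. It fits the value function V to Q with expectile regression, and it trains Q toward a TD target that bootstraps from V(s') rather than from a max over actions. The plain-Python sketch below shows both losses for scalar inputs. The function names are hypothetical. The `tau` and `gamma` arguments are assumed to play the roles of the class parameters of the same name.

```python
def expectile_loss(q: float, v: float, tau: float) -> float:
    """Asymmetric squared loss |tau - 1(u < 0)| * u^2 with u = q - v.

    With tau > 0.5, cases where v underestimates q (u > 0) are penalized
    more than overestimates, so v is pushed toward an upper expectile of q.
    """
    u = q - v
    weight = tau if u >= 0 else 1.0 - tau
    return weight * u * u


def q_td_target(reward: float, next_v: float, gamma: float, done: bool) -> float:
    """TD target for Q that bootstraps from V(s') instead of max_a Q(s', a)."""
    return reward + gamma * (0.0 if done else next_v)


def q_loss(q: float, reward: float, next_v: float, gamma: float, done: bool) -> float:
    """Squared error between Q(s, a) and its TD target."""
    return (q - q_td_target(reward, next_v, gamma, done)) ** 2


# Same residual size, opposite sign: tau = 0.6 weights the two cases unequally.
print(expectile_loss(q=1.0, v=0.0, tau=0.6))  # 0.6
print(expectile_loss(q=0.0, v=1.0, tau=0.6))  # 0.4
```

With tau = 0.5 both sides get the same weight, so the value loss reduces to a scaled mean-squared error. Values of tau closer to 1 push V toward the best-supported Q values in the data.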
Parameters
- class agilerl.algorithms.ilql.ILQL(dataset: RL_Dataset, net_config: dict[str, dict[str, Any] | Any] | None = None, index: int = 0, batch_size: int = 64, lr: float = 1e-05, alpha: float = 0.005, beta: float = 0.0, gamma: float = 0.99, tau: float = 0.6, mutation: str | None = None, transition_weight: float = 0.0, clip_weight: float | None = None, value_max: float | None = None, value_min: float | None = None, detach_v: bool = False, detach_q: bool = False, detach_pi: bool = False, double_q: bool = True, per_token: bool = True, exp_weights: bool = True, dm_margin: float = 0.0, cql_temp: float = 1.0, weight_decay: float = 0.0, device: str | device = 'cpu')
The Implicit Language Q-Learning algorithm class. ILQL paper: https://arxiv.org/pdf/2206.11871.pdf.
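The ILQL paper extracts a policy at generation time by reweighting the language model's next-token distribution with the learned advantage: π(a|h) ∝ π_β(a|h) · exp(β(Q(h, a) − V(h))). A minimal plain-Python sketch of one decoding step follows. `perturbed_next_token_probs` is a hypothetical helper. Its `beta` is the paper's inference-time scaling factor, which is not necessarily the same quantity as this class's `beta` argument.

```python
import math
from typing import Sequence


def perturbed_next_token_probs(
    logits: Sequence[float],
    q_values: Sequence[float],
    v_value: float,
    beta: float,
) -> list[float]:
    """Return softmax(logits + beta * (Q - V)) over the candidate tokens.

    logits:   behavior-policy logits, one per candidate token
    q_values: Q(h, a) for each candidate token a
    v_value:  V(h) for the current history h
    """
    if len(logits) != len(q_values):
        raise ValueError("logits and q_values must have the same length")
    scores = [logit + beta * (q - v_value) for logit, q in zip(logits, q_values)]
    m = max(scores)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


# Two tokens the base model finds equally likely; the second has the higher Q.
print(perturbed_next_token_probs([0.0, 0.0], [0.0, 1.0], v_value=0.5, beta=0.0))
# beta = 0 recovers the base distribution: [0.5, 0.5]
print(perturbed_next_token_probs([0.0, 0.0], [0.0, 1.0], v_value=0.5, beta=2.0))
```

Because V(h) adds the same constant to every candidate's score, it cancels in the softmax. Only the Q values change which tokens are favored.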
- Parameters:
dataset (torch.utils.data.Dataset) – Language dataset to perform ILQL on
net_config (dict, optional) – Network configuration, defaults to GPT-2 configuration
index (int, optional) – Index to keep track of object instance during tournament selection and mutation, defaults to 0
batch_size (int, optional) – Size of batched sample from replay buffer for learning, defaults to 64
lr (float, optional) – Learning rate for optimizer, defaults to 1e-5
alpha (float, optional) – Soft update coefficient for target network parameters, defaults to 0.005
beta (float, optional) – Coefficient for advantage-weighted regression (AWR) policy extraction, defaults to 0.0
gamma (float, optional) – Discount factor, defaults to 0.99
tau (float, optional) – Expectile used in the value network loss, defaults to 0.6
mutation (str, optional) – Most recent mutation to agent, defaults to None
transition_weight (float, optional) – Weight assigned to transition (non-action) token positions in the weighted policy loss, defaults to 0.0
clip_weight (float, optional) – Maximum value at which to clip advantage weights, defaults to None
value_max (float, optional) – Maximum Q value for clipping, defaults to None
value_min (float, optional) – Minimum Q value for clipping, defaults to None
detach_v (bool, optional) – Detach V network, defaults to False
detach_q (bool, optional) – Detach Q network, defaults to False
detach_pi (bool, optional) – Detach Policy network, defaults to False
double_q (bool, optional) – Use double Q-learning, defaults to True
per_token (bool, optional) – Use per-token ILQL, defaults to True
exp_weights (bool, optional) – Use exponential advantage weights, defaults to True
dm_margin (float, optional) – Margin for DM loss, defaults to 0.0
cql_temp (float, optional) – Temperature parameter for CQL loss, defaults to 1.0
weight_decay (float, optional) – Weight decay for optimizer, defaults to 0.0
device (str, optional) – Device for accelerated computing, ‘cpu’ or ‘cuda’, defaults to ‘cpu’
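Several parameters above correspond to standard pieces of ILQL-style training. The plain-Python sketch below makes three assumptions. First, `alpha` is a Polyak (soft-update) coefficient. Second, when `exp_weights` is true, advantage weights take the AWR-style form exp(beta · (Q − V)) and `clip_weight` caps them. Third, `value_min`/`value_max` clamp value estimates. The helper names are hypothetical, and the library's tensor implementation may differ in detail.

```python
import math
from typing import Optional


def soft_update(target: list[float], online: list[float], alpha: float) -> list[float]:
    """Polyak averaging: target <- alpha * online + (1 - alpha) * target."""
    return [alpha * o + (1.0 - alpha) * t for t, o in zip(target, online)]


def advantage_weight(q: float, v: float, beta: float,
                     clip_weight: Optional[float] = None) -> float:
    """Assumed AWR-style weight exp(beta * (Q - V)), optionally capped."""
    w = math.exp(beta * (q - v))
    if clip_weight is not None:
        w = min(w, clip_weight)
    return w


def clip_value(x: float, value_min: Optional[float],
               value_max: Optional[float]) -> float:
    """Clamp a Q or V estimate into [value_min, value_max]; None means unbounded."""
    if value_min is not None:
        x = max(x, value_min)
    if value_max is not None:
        x = min(x, value_max)
    return x


# Under this assumed form, beta = 0.0 gives every action token weight 1,
# so the weighted policy loss reduces to ordinary likelihood training.
print(advantage_weight(q=3.0, v=1.0, beta=0.0))                    # 1.0
print(advantage_weight(q=3.0, v=1.0, beta=1.0, clip_weight=5.0))   # 5.0 (exp(2) is capped)
print(soft_update([0.0, 1.0], [1.0, 1.0], alpha=0.005))
```

A small `alpha` such as the default 0.005 moves the target network only slightly toward the online network at each update, which keeps bootstrapped targets stable.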
- clone(index: int | None = None) → Self
Return cloned agent identical to self.
- Parameters:
index (int, optional) – Index to keep track of agent for tournament selection and mutation, defaults to None
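During tournament selection, a clone must be independent of its parent so that mutating one does not affect the other. The toy sketch below shows that contract with `copy.deepcopy`. `ToyAgent` is a hypothetical stand-in, and it assumes that `index=None` keeps the parent's index. The real method also handles networks and optimizer state, which this toy omits.

```python
import copy
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class ToyAgent:
    """Hypothetical stand-in for an evolvable agent."""
    index: int = 0
    lr: float = 1e-5
    weights: list[float] = field(default_factory=lambda: [0.0, 0.0])

    def clone(self, index: Optional[int] = None) -> "ToyAgent":
        # Deep copy so the clone's weights can be mutated independently.
        cloned = copy.deepcopy(self)
        # Assumption: index=None keeps the parent's index.
        if index is not None:
            cloned.index = index
        return cloned


parent = ToyAgent(index=3, weights=[0.1, 0.2])
child = parent.clone(index=7)
child.weights[0] = 9.9     # mutating the clone...
print(parent.weights[0])   # ...leaves the parent untouched: 0.1
print(child.index)         # 7
```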
- forward(tokens: Tensor, state_idxs: Tensor, action_idxs: Tensor, attn_mask: Tensor | None = None, prefix_embs: Tensor | None = None, prefix_attn_mask: Tensor | None = None, remove_prefix_position_embs: bool = False, qv_kwargs: dict[str, Any] | None = None, policy_kwargs: dict[str, Any] | None = None, target_kwargs: dict[str, Any] | None = None, skip_policy_on_train: bool = False, detach_full_policy: bool = False) → dict[str, Any]
Forward pass through transformers.
- Parameters:
tokens (torch.Tensor) – Tokens to input to model
state_idxs (torch.Tensor) – Indices of state positions within tokens
action_idxs (torch.Tensor) – Indices of action positions within tokens
attn_mask (torch.Tensor, optional) – Attention mask for transformers, defaults to None
prefix_embs (torch.Tensor, optional) – Prefix embeddings, defaults to None
skip_policy_on_train (bool, optional) – Skip policy language model when training, defaults to False
detach_full_policy (bool, optional) – Use policy language model without gradients, defaults to False
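The transformer produces an output at every token position, but value and Q estimates are needed only at particular positions. The sketch below assumes that `state_idxs` and `action_idxs` pick out those positions through a batched gather over per-position scalar outputs. `gather_positions` is a hypothetical helper, not part of the library. For 2-D tensors, the same selection is what `torch.gather(outputs, 1, idxs)` computes.

```python
from typing import Sequence


def gather_positions(per_token: Sequence[Sequence[float]],
                     idxs: Sequence[Sequence[int]]) -> list[list[float]]:
    """Batched gather: out[b][j] = per_token[b][idxs[b][j]].

    per_token: [batch][seq_len] outputs, one value per token position
    idxs:      [batch][n] positions to read (e.g. state or action positions)
    """
    return [[row[i] for i in row_idxs] for row, row_idxs in zip(per_token, idxs)]


# One sequence of 5 tokens with hypothetical per-position head outputs.
outputs = [[0.1, 0.2, 0.3, 0.4, 0.5]]
state_idxs = [[0, 2, 4]]   # positions read as states
action_idxs = [[1, 3]]     # positions read as actions
print(gather_positions(outputs, state_idxs))   # [[0.1, 0.3, 0.5]]
print(gather_positions(outputs, action_idxs))  # [[0.2, 0.4]]
```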
- load_checkpoint(path: str | Path) → None
Load saved agent properties and network weights from checkpoint.
- Parameters:
path (str or Path) – Location to load checkpoint from