Experience Sampler
The Sampler class provides a unified interface for sampling experiences from replay buffers in both standard and distributed
training scenarios. It can work with various types of replay buffers including standard replay buffers, multi-agent buffers,
prioritized buffers, and multi-step buffers.
For distributed training, the sampler can work with PyTorch DataLoaders and provides custom collate functions to properly handle TensorDict objects. This enables efficient batch processing of experiences across multiple workers.
The sampler automatically detects whether it’s being used in a standard training setup (with a replay buffer) or distributed setup (with a dataset and dataloader) and provides the appropriate sampling interface.
from agilerl.components.sampler import Sampler
from agilerl.components.replay_buffer import ReplayBuffer
# Standard training setup
device = "cpu"  # assumption: replace with "cuda" (or your accelerator) as needed
buffer = ReplayBuffer(max_size=10000, device=device)
sampler = Sampler(memory=buffer)
# Sample experiences
batch = sampler.sample(batch_size=32)
# Distributed training setup
from agilerl.components.data import ReplayDataset
dataset = ReplayDataset(buffer, batch_size=32)
dataloader = Sampler.create_dataloader(dataset, batch_size=32)
distributed_sampler = Sampler(dataset=dataset, dataloader=dataloader)
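The mode detection described above can be pictured with a small stdlib-only sketch. `SketchSampler` and `ListBuffer` are hypothetical stand-ins written for illustration, not the AgileRL implementation; the sketch only assumes that the choice of mode follows from which constructor arguments are supplied, and it lets `memory` take precedence if both are given.

```python
import random
from typing import Any, Callable, Iterable, Optional


class ListBuffer:
    """Hypothetical minimal buffer: stores items and samples uniformly."""

    def __init__(self, items: Iterable[Any]) -> None:
        self.items = list(items)

    def sample(self, batch_size: int) -> list:
        return random.sample(self.items, batch_size)


class SketchSampler:
    """Hypothetical stand-in that mimics the documented mode detection."""

    def __init__(
        self,
        memory: Optional[Any] = None,
        dataset: Optional[Any] = None,
        dataloader: Optional[Iterable] = None,
    ) -> None:
        # Documented contract: a buffer, or a dataset/dataloader pair.
        assert memory is not None or (
            dataset is not None and dataloader is not None
        ), "Provide either memory or both dataset and dataloader"

        self.memory = memory
        self.distributed = memory is None
        self.sample: Callable[..., Any]
        if self.distributed:
            # Keep one iterator so successive calls yield successive batches.
            self._iterator = iter(dataloader)
            self.sample = self._sample_distributed
        else:
            self.sample = self._sample_standard

    def _sample_standard(self, batch_size: int) -> Any:
        return self.memory.sample(batch_size)

    def _sample_distributed(self, batch_size: int) -> Any:
        # In this sketch the batch size is fixed by the dataloader itself.
        return next(self._iterator)


standard = SketchSampler(memory=ListBuffer(range(100)))
distributed = SketchSampler(dataset=[1, 2, 3, 4], dataloader=[[1, 2], [3, 4]])
```

Binding `sample` once in the constructor keeps the calling code identical in both setups, which is the point of a unified interface.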
Classes
- class agilerl.components.sampler.Sampler(memory: ReplayBuffer | MultiAgentReplayBuffer | PrioritizedReplayBuffer | MultiStepReplayBuffer | None = None, dataset: ReplayDataset | None = None, dataloader: DataLoader | None = None)
Sampler class to handle both standard and distributed training.
- Parameters:
memory (ReplayBuffer | MultiAgentReplayBuffer | PrioritizedReplayBuffer | MultiStepReplayBuffer | None, optional) – Replay buffer memory, defaults to None
dataset (ReplayDataset | None, optional) – Dataset for distributed sampling, defaults to None
dataloader (DataLoader | None, optional) – DataLoader for distributed sampling, defaults to None
- Raises:
AssertionError – If neither memory nor both dataset and dataloader are provided
- classmethod create_dataloader(dataset: ReplayDataset, batch_size: int | None = None, **kwargs) → DataLoader
Create a DataLoader with the appropriate collate function.
- Parameters:
dataset (ReplayDataset) – Dataset to create a DataLoader for
batch_size (int | None, optional) – Batch size for the DataLoader, defaults to None
kwargs – Additional arguments to pass to the DataLoader
- Returns:
DataLoader with tensordict_collate_fn
- Return type:
DataLoader
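A collate function turns a list of per-sample items into one batch, and PyTorch's DataLoader accepts one through its `collate_fn` argument. The stdlib-only sketch below shows the idea with plain dicts standing in for TensorDict objects. `dict_collate_fn` is a hypothetical name used for illustration; it is not `tensordict_collate_fn` and makes no claim about how that function is implemented.

```python
from typing import Any, Dict, List


def dict_collate_fn(samples: List[Dict[str, Any]]) -> Dict[str, List[Any]]:
    """Merge per-sample dicts into a single dict of batched fields."""
    if not samples:
        raise ValueError("cannot collate an empty list of samples")
    keys = samples[0].keys()
    for sample in samples:
        # Every transition must expose the same fields to be batched together.
        if sample.keys() != keys:
            raise KeyError("all samples must share the same keys")
    return {key: [sample[key] for sample in samples] for key in keys}


batch = dict_collate_fn(
    [
        {"obs": [0.0, 1.0], "action": 1, "reward": 0.5},
        {"obs": [1.0, 0.0], "action": 0, "reward": -0.5},
    ]
)
```

With real tensors, the list comprehension would be replaced by a stacking operation along a new batch dimension.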
- sample_distributed(batch_size: int, return_idx: bool | None = None) → TensorDict
Sample a batch of experiences from the distributed dataset.
- sample_n_step(idxs: Any) → dict[str, ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts] | tuple[ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, ...]
Sample a batch of experiences from the n-step replay buffer.
- Parameters:
idxs (Any) – Indices to sample from
- Returns:
Sampled batch of experiences
- Return type:
TensorDict
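Because this method takes indices rather than a batch size, a plausible use is to draw indices once and read the matching entries from an n-step store, so that one-step and n-step data describe the same transitions. The sketch below illustrates that pattern with plain lists. The helpers `n_step_returns` and `gather`, the toy rewards, and the values `n=3` and `gamma=0.9` are all assumptions for illustration, and the return computation ignores episode boundaries for brevity.

```python
import random
from typing import Dict, List, Sequence


def n_step_returns(rewards: Sequence[float], n: int, gamma: float) -> List[float]:
    """Discounted sum of up to n rewards starting at each timestep."""
    return [
        sum(gamma**k * r for k, r in enumerate(rewards[t : t + n]))
        for t in range(len(rewards))
    ]


def gather(store: Dict[str, List[float]], idxs: Sequence[int]) -> Dict[str, List[float]]:
    """Read the entries at idxs from every field of a dict-of-lists store."""
    return {key: [values[i] for i in idxs] for key, values in store.items()}


rewards = [1.0, 0.0, 2.0, 1.0, 0.5]  # toy data
one_step_store = {"reward": rewards}
n_step_store = {"reward": n_step_returns(rewards, n=3, gamma=0.9)}

# Draw indices once, then reuse them for both stores.
idxs = random.Random(0).sample(range(len(rewards)), k=3)
one_step_batch = gather(one_step_store, idxs)
n_step_batch = gather(n_step_store, idxs)
```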
- sample_per(batch_size: int, beta: float) → dict[str, ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts] | tuple[ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, ...]
Sample a batch of experiences from the Prioritized Experience Replay buffer.
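For readers unfamiliar with the `beta` argument, the sketch below follows the textbook Prioritized Experience Replay formulation: transitions are drawn with probability proportional to `priority ** alpha`, and `beta` controls how strongly importance-sampling weights correct for that bias. The function `sample_prioritized`, the `alpha=0.6` default, and the normalisation by the largest possible weight are assumptions for illustration; the library's buffer may differ in detail.

```python
import random
from typing import List, Optional, Sequence, Tuple


def sample_prioritized(
    priorities: Sequence[float],
    batch_size: int,
    beta: float,
    alpha: float = 0.6,
    rng: Optional[random.Random] = None,
) -> Tuple[List[int], List[float]]:
    """Return sampled indices and their importance-sampling weights.

    Priorities must be strictly positive.
    """
    rng = rng or random.Random()
    scaled = [p**alpha for p in priorities]
    total = sum(scaled)
    probs = [s / total for s in scaled]
    n = len(probs)

    # Sample with replacement, proportionally to the scaled priorities.
    idxs = rng.choices(range(n), weights=probs, k=batch_size)

    # w_i = (N * P(i)) ** -beta, normalised so the largest weight is 1.
    max_weight = (n * min(probs)) ** (-beta)
    weights = [((n * probs[i]) ** (-beta)) / max_weight for i in idxs]
    return idxs, weights


idxs, weights = sample_prioritized(
    [1.0, 2.0, 3.0, 4.0], batch_size=5, beta=0.4, rng=random.Random(0)
)
```

With `beta=0` every weight equals 1 (no correction); as `beta` approaches 1 the weights fully compensate for the non-uniform sampling.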
- sample_standard(batch_size: int, return_idx: bool = False) → dict[str, ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts] | tuple[ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, ...]
Sample a batch of experiences from the standard replay buffer.
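As a rough picture of uniform sampling with an index option, the stdlib-only sketch below draws distinct indices and gathers each field of a dict-of-lists store. `sample_uniform` is a hypothetical helper, and the way `return_idx` is modelled here (adding an `"idxs"` entry to the batch) is an assumption rather than the documented behaviour.

```python
import random
from typing import Any, Dict, List, Optional


def sample_uniform(
    store: Dict[str, List[Any]],
    batch_size: int,
    return_idx: bool = False,
    rng: Optional[random.Random] = None,
) -> Dict[str, List[Any]]:
    """Uniformly sample batch_size distinct transitions from the store."""
    rng = rng or random.Random()
    size = len(next(iter(store.values())))
    idxs = rng.sample(range(size), k=batch_size)  # without replacement
    batch: Dict[str, List[Any]] = {
        key: [values[i] for i in idxs] for key, values in store.items()
    }
    if return_idx:
        # Assumed convention: expose the indices alongside the data.
        batch["idxs"] = idxs
    return batch


store = {"obs": list(range(10)), "reward": [0.1 * i for i in range(10)]}
plain = sample_uniform(store, batch_size=4, rng=random.Random(0))
indexed = sample_uniform(store, batch_size=4, return_idx=True, rng=random.Random(0))
```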