Experience Sampler

The Sampler class provides a unified interface for sampling experiences from replay buffers in both standard and distributed training scenarios. It supports standard, multi-agent, prioritized, and multi-step replay buffers.

For distributed training, the sampler can work with PyTorch DataLoaders and provides custom collate functions to properly handle TensorDict objects. This enables efficient batch processing of experiences across multiple workers.

The sampler automatically detects whether it’s being used in a standard training setup (with a replay buffer) or distributed setup (with a dataset and dataloader) and provides the appropriate sampling interface.

from agilerl.components.data import ReplayDataset
from agilerl.components.replay_buffer import ReplayBuffer
from agilerl.components.sampler import Sampler

# Device that sampled tensors are placed on ("cpu" here; use "cuda" if available)
device = "cpu"

# Standard training setup: wrap a replay buffer
buffer = ReplayBuffer(max_size=10000, device=device)
sampler = Sampler(memory=buffer)

# Sample a batch of experiences (the buffer must already hold transitions)
batch = sampler.sample(batch_size=32)

# Distributed training setup: wrap the buffer in a dataset and dataloader
dataset = ReplayDataset(buffer, batch_size=32)
dataloader = Sampler.create_dataloader(dataset, batch_size=32)
distributed_sampler = Sampler(dataset=dataset, dataloader=dataloader)

Classes

class agilerl.components.sampler.Sampler(memory: ReplayBuffer | MultiAgentReplayBuffer | PrioritizedReplayBuffer | MultiStepReplayBuffer | None = None, dataset: ReplayDataset | None = None, dataloader: DataLoader | None = None)

Sampler class to handle both standard and distributed training.

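Parameters:
  • memory (ReplayBuffer | MultiAgentReplayBuffer | PrioritizedReplayBuffer | MultiStepReplayBuffer | None, optional) – Replay buffer to sample from in standard training, defaults to None

  • dataset (ReplayDataset | None, optional) – Dataset to sample from in distributed training, defaults to None

  • dataloader (DataLoader | None, optional) – DataLoader that wraps the dataset in distributed training, defaults to None

Raises:

AssertionError – If neither memory nor both dataset and dataloader are provided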
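The mode detection and the assertion described above could look roughly like the sketch below. It is not the AgileRL implementation: `ToySampler`, its private methods, and the rule that a provided `memory` selects standard mode are all assumptions made for illustration.

```python
from typing import Any, Iterable, Iterator, Optional


class ToySampler:
    """Hypothetical stand-in that mimics the documented mode detection."""

    def __init__(
        self,
        memory: Optional[Any] = None,
        dataset: Optional[Any] = None,
        dataloader: Optional[Iterable] = None,
    ) -> None:
        # Documented contract: a buffer, or a dataset together with a dataloader.
        assert memory is not None or (
            dataset is not None and dataloader is not None
        ), "Provide memory, or both dataset and dataloader"

        self.memory = memory
        self.dataset = dataset
        self.dataloader = dataloader
        # Assumption: a provided buffer selects standard mode.
        self.distributed = memory is None

        # Bind the strategy once so callers only ever use sample().
        if self.distributed:
            self._iterator: Iterator = iter(dataloader)
            self.sample = self._sample_distributed
        else:
            self.sample = self._sample_standard

    def _sample_standard(self, batch_size: int) -> Any:
        return self.memory.sample(batch_size)

    def _sample_distributed(self, batch_size: int) -> Any:
        # The dataloader already fixes the batch size, so the argument is unused.
        try:
            return next(self._iterator)
        except StopIteration:
            # Restart the loader once it is exhausted.
            self._iterator = iter(self.dataloader)
            return next(self._iterator)
```

Binding `sample` at construction time keeps the training loop identical in both setups, which is the benefit the unified interface is meant to provide.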

classmethod create_dataloader(dataset: ReplayDataset, batch_size: int | None = None, **kwargs) → DataLoader

Create a DataLoader with the appropriate collate function.

Parameters:
  • dataset (ReplayDataset) – Dataset to create a DataLoader for

  • batch_size (int | None, optional) – Batch size for the DataLoader, defaults to None

  • kwargs – Additional arguments to pass to the DataLoader

Returns:

DataLoader with tensordict_collate_fn

Return type:

DataLoader

sample_distributed(batch_size: int, return_idx: bool | None = None) → TensorDict

Sample a batch of experiences from the distributed dataset.

Parameters:
  • batch_size (int) – Size of the batch to sample

  • return_idx (bool | None, optional) – Not used in distributed sampling, defaults to None

Returns:

Sampled batch of experiences

Return type:

TensorDict

sample_n_step(idxs: Any) → dict[str, ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts] | tuple[ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, ...]

Sample a batch of experiences from the n-step replay buffer.

Parameters:

idxs (Any) – Indices to sample from

Returns:

Sampled batch of experiences

Return type:

TensorDict
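Because sample_n_step takes indices rather than a batch size, it can return the n-step transitions that line up with a batch already drawn from another buffer. The source does not show how the n-step buffer computes its stored rewards; the sketch below shows the usual n-step discounted return for a set of indices, with `n_step_return`, `gamma`, and the toy data all being illustrative assumptions.

```python
from typing import List, Sequence


def n_step_return(
    rewards: Sequence[float],
    dones: Sequence[bool],
    start: int,
    n: int,
    gamma: float,
) -> float:
    """Discounted sum of up to n rewards from `start`, stopping at episode end."""
    total = 0.0
    for k in range(n):
        i = start + k
        if i >= len(rewards):
            break
        total += (gamma ** k) * rewards[i]
        if dones[i]:
            # Do not accumulate rewards across an episode boundary.
            break
    return total


# Toy trajectory: the episode terminates at step 2.
rewards = [1.0, 1.0, 1.0, 1.0]
dones = [False, False, True, False]

# Indices that would come from a matching one-step sample.
idxs = [0, 1, 2]
returns: List[float] = [
    n_step_return(rewards, dones, i, n=3, gamma=0.5) for i in idxs
]
print(returns)  # [1.75, 1.5, 1.0]
```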

sample_per(batch_size: int, beta: float) → dict[str, ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts] | tuple[ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, ...]

Sample a batch of experiences from the Prioritized Experience Replay buffer.

Parameters:
  • batch_size (int) – Size of the batch to sample

  • beta (float) – Importance-sampling exponent; 0 applies no correction and 1 fully compensates for the non-uniform sampling probabilities

Returns:

Sampled batch of experiences, indices, and importance-sampling weights

Return type:

TensorDict
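To make the role of beta concrete, here is a minimal sketch of the standard prioritized-replay weight computation, w_i = (N * P(i))^(-beta), normalized so that the largest weight is 1. The function name, the `alpha` priority exponent, and the choice to normalize over the whole buffer are assumptions for illustration; the source does not specify how the buffer computes its weights.

```python
from typing import List, Sequence


def importance_weights(
    priorities: Sequence[float],
    sampled: Sequence[int],
    alpha: float,
    beta: float,
) -> List[float]:
    """Importance-sampling weights for the sampled indices."""
    scaled = [p ** alpha for p in priorities]
    total = sum(scaled)
    probs = [s / total for s in scaled]  # P(i): probability of sampling index i
    n = len(priorities)

    # The least likely transition gets the largest weight; use it to normalize.
    max_weight = (n * min(probs)) ** (-beta)
    return [((n * probs[i]) ** (-beta)) / max_weight for i in sampled]


priorities = [1.0, 1.0, 2.0, 4.0]
weights = importance_weights(priorities, sampled=[0, 2, 3], alpha=1.0, beta=1.0)
print(weights)  # [1.0, 0.5, 0.25]
```

With beta set to 0 every weight is 1, so no correction is applied. Training loops commonly anneal beta toward 1 so that the correction becomes exact late in training.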

sample_standard(batch_size: int, return_idx: bool = False) → dict[str, ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts] | tuple[ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, ...]

Sample a batch of experiences from the standard replay buffer.

Parameters:
  • batch_size (int) – Size of the batch to sample

  • return_idx (bool, optional) – Whether to return indices, defaults to False

Returns:

Sampled batch of experiences

Return type:

TensorDict
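The effect of return_idx can be illustrated with a plain-Python stand-in for a buffer. This is only a sketch: `sample_uniform`, the dict-of-lists storage, sampling without replacement, and returning the indices as a second tuple element are assumptions, not the behaviour of the AgileRL buffers.

```python
import random
from typing import Dict, List, Optional, Tuple, Union

Storage = Dict[str, List[float]]


def sample_uniform(
    storage: Storage,
    batch_size: int,
    return_idx: bool = False,
    rng: Optional[random.Random] = None,
) -> Union[Storage, Tuple[Storage, List[int]]]:
    """Draw a uniform batch; optionally also return the chosen indices."""
    rng = rng or random.Random()
    size = len(next(iter(storage.values())))
    idxs = rng.sample(range(size), batch_size)  # without replacement (assumption)
    batch = {key: [values[i] for i in idxs] for key, values in storage.items()}
    return (batch, idxs) if return_idx else batch


storage = {
    "obs": [float(i) for i in range(10)],
    "reward": [0.1 * i for i in range(10)],
}
batch, idxs = sample_uniform(storage, batch_size=4, return_idx=True)

# The indices identify which stored transitions make up the batch, so the
# same transitions can be fetched from a second buffer afterwards.
assert batch["obs"] == [storage["obs"][i] for i in idxs]
```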

static tensordict_collate_fn(batch: list[TensorDict]) → TensorDict | list[TensorDict]

Provide a custom collate function that properly handles TensorDict objects.

Parameters:

batch (list[TensorDict]) – List of TensorDict objects to collate

Returns:

Either a single TensorDict or a list of TensorDicts

Return type:

TensorDict | list[TensorDict]
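One plausible reading of the "single TensorDict or list" return value is sketched below with plain dicts standing in for TensorDict objects. The function `collate_dicts` and its three rules (unwrap a single pre-batched item, merge items with matching keys, otherwise return the list unchanged) are assumptions for illustration and are not taken from the AgileRL source.

```python
from typing import Dict, List, Union

Sample = Dict[str, List[float]]


def collate_dicts(batch: List[Sample]) -> Union[Sample, List[Sample]]:
    """Merge per-item dicts into one batch dict when that is possible."""
    # A dataset that already yields whole batches produces a one-element list.
    if len(batch) == 1:
        return batch[0]

    keys = set(batch[0])
    if any(set(item) != keys for item in batch[1:]):
        # Items with different fields cannot be merged; hand back the list.
        return batch

    # Concatenate each field across items, preserving item order.
    return {key: [v for item in batch for v in item[key]] for key in batch[0]}


merged = collate_dicts(
    [
        {"obs": [1.0], "reward": [0.0]},
        {"obs": [2.0], "reward": [1.0]},
    ]
)
print(merged)  # {'obs': [1.0, 2.0], 'reward': [0.0, 1.0]}
```

A function with this shape is what PyTorch's DataLoader accepts through its `collate_fn` argument, which is how create_dataloader is described as attaching tensordict_collate_fn.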