Data Structures and Utilities
This module provides data structures and utility functions for storing, converting and sampling experiences in reinforcement learning.
The main components are the Transition tensorclass, which represents environment transitions; the ReplayDataset, which creates iterable datasets from replay buffers; and the to_tensordict and to_torch_tensor utility functions, which convert between data formats.
The Transition class bundles observations, actions, rewards, next observations and done flags into a single structured container and automatically converts between the supported data types and formats. Because ReplayDataset is an iterable dataset, it can be passed to PyTorch's DataLoader, for example in distributed training.
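```python
import numpy as np

from agilerl.components.data import Transition, ReplayDataset, to_tensordict
from agilerl.components.replay_buffer import ReplayBuffer

device = "cpu"
num_envs = 4
# Placeholder data for one step of num_envs vectorized environments
obs = np.random.rand(num_envs, 8).astype(np.float32)
action = np.random.randint(0, 2, size=(num_envs,))
reward = np.random.rand(num_envs).astype(np.float32)
next_obs = np.random.rand(num_envs, 8).astype(np.float32)
done = np.zeros(num_envs, dtype=np.float32)
# Create a transition (batch_size is a required keyword argument)
transition = Transition(
    obs=obs,
    action=action,
    reward=reward,
    next_obs=next_obs,
    done=done,
    batch_size=[num_envs],
)
# Create a dataset from a replay buffer
buffer = ReplayBuffer(max_size=10000, device=device)
dataset = ReplayDataset(buffer, batch_size=32)
```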
Functions
- agilerl.components.data.to_tensordict(data: ndarray | dict[str, ndarray] | tuple[ndarray, ...] | Tensor | TensorDict | tuple[Tensor, ...] | dict[str, Tensor] | Number | list[ReasoningPrompts] | ReasoningPrompts, dtype: dtype = torch.float32) → TensorDict
Convert a tuple or dict of torch.Tensor or np.ndarray to a TensorDict.
- Parameters:
data (ObservationType) – Tuple or dict of torch.Tensor or np.ndarray.
dtype (torch.dtype, optional) – Data type of the TensorDict, defaults to torch.float32
- Returns:
A TensorDict, regardless of whether the input data was a tuple.
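The conversion described above can be pictured with a plain-dict analogue. The sketch below is hypothetical: `to_tensor_mapping` is not part of AgileRL, it returns a `dict` rather than a `TensorDict`, and the keys it assigns to tuple elements ("0", "1", ...) and to a single array ("data") are assumptions; AgileRL's own key naming may differ.

```python
from collections.abc import Mapping
from typing import Dict, Tuple, Union

import numpy as np
import torch

ArrayLike = Union[np.ndarray, torch.Tensor]


def to_tensor_mapping(
    data: Union[Mapping, Tuple[ArrayLike, ...], ArrayLike],
    dtype: torch.dtype = torch.float32,
) -> Dict[str, torch.Tensor]:
    """Hypothetical stand-in for to_tensordict that returns a plain dict."""
    if isinstance(data, Mapping):
        items = list(data.items())
    elif isinstance(data, tuple):
        # Assumption: tuple elements are keyed by their position
        items = [(str(i), value) for i, value in enumerate(data)]
    else:
        # Assumption: a single array is stored under a "data" key
        items = [("data", data)]
    # Every leaf is cast to the requested dtype
    return {key: torch.as_tensor(value, dtype=dtype) for key, value in items}


obs = {"image": np.zeros((2, 3), dtype=np.uint8), "vector": np.ones(2)}
converted = to_tensor_mapping(obs)
print({k: tuple(v.shape) for k, v in converted.items()})  # {'image': (2, 3), 'vector': (2,)}
```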
- agilerl.components.data.to_torch_tensor(data: ndarray | Tensor, dtype: dtype = torch.float32) → Tensor
Convert a numpy array or Python number to a torch tensor.
- Parameters:
data (ArrayOrTensor) – Numpy array or Python number.
dtype (torch.dtype, optional) – Data type of the torch tensor, defaults to torch.float32
- Returns:
Torch tensor.
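A minimal sketch of this conversion, assuming it behaves like `torch.as_tensor` with an explicit dtype. `to_torch_tensor_sketch` is a hypothetical name, and AgileRL's implementation may handle devices or edge cases differently.

```python
from numbers import Number
from typing import Union

import numpy as np
import torch


def to_torch_tensor_sketch(
    data: Union[np.ndarray, torch.Tensor, Number],
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Hypothetical equivalent of to_torch_tensor built on torch.as_tensor."""
    # torch.as_tensor reuses the input's memory when it already is a CPU
    # array/tensor of the requested dtype, and copies otherwise.
    return torch.as_tensor(data, dtype=dtype)


print(to_torch_tensor_sketch(np.array([1, 2, 3])))  # tensor([1., 2., 3.])
print(to_torch_tensor_sketch(0.5).shape)  # torch.Size([])
```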
Classes
- class agilerl.components.data.Transition(obs: numpy.ndarray | dict[str, numpy.ndarray] | tuple[numpy.ndarray, ...] | torch.Tensor | tensordict._td.TensorDict | tuple[torch.Tensor, ...] | dict[str, torch.Tensor] | numbers.Number | list[agilerl.typing.ReasoningPrompts] | agilerl.typing.ReasoningPrompts, action: numpy.ndarray | torch.Tensor, next_obs: numpy.ndarray | dict[str, numpy.ndarray] | tuple[numpy.ndarray, ...] | torch.Tensor | tensordict._td.TensorDict | tuple[torch.Tensor, ...] | dict[str, torch.Tensor] | numbers.Number | list[agilerl.typing.ReasoningPrompts] | agilerl.typing.ReasoningPrompts, reward: numpy.ndarray | torch.Tensor, done: numpy.ndarray | torch.Tensor, *, batch_size, device=None, names=None)
- property device: device
Returns the device of the tensorclass.
- dumps(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) → Any
Saves the tensordict to disk.
This function is a proxy to memmap().
- classmethod fields()
Return a tuple describing the fields of this dataclass.
Accepts a dataclass or an instance of one. Tuple elements are of type Field.
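This docstring is inherited from Python's `dataclasses.fields`. As an illustration only, a plain dataclass with the same field names as the Transition signature (a stand-in, not the tensorclass itself) shows what the returned tuple contains:

```python
from dataclasses import dataclass, fields
from typing import Any


@dataclass
class TransitionLike:
    """Plain dataclass mirroring Transition's field names (illustrative stand-in)."""

    obs: Any
    action: Any
    next_obs: Any
    reward: Any
    done: Any


# fields() accepts the class or an instance and returns a tuple of Field objects
names = tuple(f.name for f in fields(TransitionLike))
print(names)  # ('obs', 'action', 'next_obs', 'reward', 'done')
```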
- from_schema(*, batch_size: Sequence[int] | Size | None = None, storage: str | None = None, device=None, **kwargs) → TensorDictBase
Pre-allocate a zero-filled TensorDict from a schema.
Creates a TensorDictBase whose storage backend is selected by storage. Each entry in schema maps a field name to an (element_shape, dtype) pair; the full stored shape is [*batch_size, *element_shape].
- Args:
  - schema: Mapping from field name to (element_shape, dtype). element_shape is the per-element shape (excluding batch_size).
- Keyword Args:
  - batch_size: Overall batch dimensions prepended to every element shape. Defaults to ().
  - storage (str or None): Backend selector:
    - None: plain TensorDict with regular tensors.
    - "memmap": memory-mapped tensors on disk. Pass prefix=<dir> in kwargs.
    - "h5": HDF5 via PersistentTensorDict. Pass filename=<path> in kwargs.
    - "shared": CPU shared-memory tensors.
    - "redis" / "dragonfly": delegates to TensorDictStore.from_schema().
  - device: Device for the resulting tensors (ignored by some backends).
  - **kwargs: Backend-specific arguments forwarded to the underlying constructor (e.g. prefix for memmap, filename for h5, host/port for redis).
- Returns:
  A new TensorDictBase subclass instance with pre-allocated (zero-filled) keys.
- Examples:
```python
>>> td = TensorDict.from_schema(
...     {"obs": ([84, 84, 3], torch.uint8),
...      "reward": ([], torch.float32)},
...     batch_size=[1000],
... )
>>> td["obs"].shape
torch.Size([1000, 84, 84, 3])

>>> import tempfile
>>> with tempfile.TemporaryDirectory() as d:
...     td_mm = TensorDict.from_schema(
...         {"obs": ([4], torch.float32)},
...         batch_size=[8],
...         storage="memmap",
...         prefix=d,
...     )
...     assert td_mm.is_memmap()
```
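The shape rule [*batch_size, *element_shape] can be checked without any storage backend. The sketch below reuses the schema from the first docstring example but builds a plain dict of zero-filled `torch.zeros` buffers; it does not create a TensorDict and only approximates what the storage=None case holds.

```python
import torch

# Schema: field name -> (element_shape, dtype), as in the from_schema example
schema = {
    "obs": ([84, 84, 3], torch.uint8),
    "reward": ([], torch.float32),
}
batch_size = [1000]

# Each stored tensor has shape [*batch_size, *element_shape] and starts at zero
buffers = {
    name: torch.zeros((*batch_size, *shape), dtype=dtype)
    for name, (shape, dtype) in schema.items()
}
print(buffers["obs"].shape)  # torch.Size([1000, 84, 84, 3])
print(buffers["reward"].shape)  # torch.Size([1000])
```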
- classmethod from_tensordict(tensordict: TensorDictBase, non_tensordict: dict | None = None, safe: bool = True) → Any
Tensor class wrapper to instantiate a new tensor class object.
- Args:
  - tensordict (TensorDictBase): Dictionary of tensor types.
  - non_tensordict (dict): Dictionary with non-tensor and nested tensor class objects.
  - safe (bool): Whether to raise an error if the tensordict is not a TensorDictBase instance.
- get(key: NestedKey, *args, **kwargs)
Gets the value stored with the input key.
- Args:
  - key (str, tuple of str): key to be queried. If a tuple of str, it is equivalent to chained calls of getattr.
  - default: default value if the key is not found in the tensorclass.
- Returns:
value stored with the input key
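The tuple-key rule ("equivalent to chained calls of getattr") can be sketched on plain Python objects. `get_nested` is a hypothetical helper, and falling back to `default` on AttributeError is an assumption about how the default argument behaves, not the tensorclass code.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Tuple, Union

_MISSING = object()


def get_nested(obj: Any, key: Union[str, Tuple[str, ...]], default: Any = _MISSING) -> Any:
    """Hypothetical helper: resolve a (possibly nested) key via chained getattr."""
    path = (key,) if isinstance(key, str) else key
    try:
        # ("obs", "vector") -> getattr(getattr(obj, "obs"), "vector")
        return reduce(getattr, path, obj)
    except AttributeError:
        if default is _MISSING:
            raise
        return default


data = SimpleNamespace(obs=SimpleNamespace(image="img", vector="vec"))
print(get_nested(data, ("obs", "vector")))  # vec
print(get_nested(data, ("obs", "depth"), default=None))  # None
```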
- classmethod load(prefix: str | Path, *args, **kwargs) → Any
Loads a tensordict from disk.
This class method is a proxy to load_memmap().
- load_(prefix: str | Path, *args, **kwargs)
Loads a tensordict from disk within the current tensordict.
This method is a proxy to load_memmap_().
- classmethod load_memmap(prefix: str | Path, device: device | None = None, non_blocking: bool = False, *, out: TensorDictBase | None = None, robust_key: bool | None = None) → Any
Loads a memory-mapped tensordict from disk.
- Args:
  - prefix (str or Path to folder): the path to the folder where the saved tensordict should be fetched.
  - device (torch.device or equivalent, optional): if provided, the data will be asynchronously cast to that device. Supports “meta” device, in which case the data isn’t loaded but a set of empty “meta” tensors are created. This is useful to get a sense of the total model size and structure without actually opening any file.
  - non_blocking (bool, optional): if True, synchronize won’t be called after loading tensors on device. Defaults to False.
  - out (TensorDictBase, optional): optional tensordict where the data should be written.
  - robust_key (bool, optional): if True, expects robust key encoding was used when saving and decodes filenames accordingly. If False, uses legacy behavior. If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.
- Examples:
```python
>>> from tensordict import TensorDict
>>> td = TensorDict.fromkeys(["a", "b", "c", ("nested", "e")], 0)
>>> td.memmap("./saved_td")
>>> td_load = TensorDict.load_memmap("./saved_td")
>>> assert (td == td_load).all()
```
This method also allows loading nested tensordicts.
- Examples:
```python
>>> nested = TensorDict.load_memmap("./saved_td/nested")
>>> assert nested["e"] == 0
```
A tensordict can also be loaded on the “meta” device or, alternatively, as fake tensors.
- Examples:
```python
>>> import tempfile
>>> td = TensorDict({"a": torch.zeros(()), "b": {"c": torch.zeros(())}})
>>> with tempfile.TemporaryDirectory() as path:
...     td.save(path)
...     td_load = TensorDict.load_memmap(path, device="meta")
...     print("meta:", td_load)
...     from torch._subclasses import FakeTensorMode
...     with FakeTensorMode():
...         td_load = TensorDict.load_memmap(path)
...         print("fake:", td_load)
meta: TensorDict(
    fields={
        a: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: Tensor(shape=torch.Size([]), device=meta, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=meta,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=meta,
    is_shared=False)
fake: TensorDict(
    fields={
        a: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False),
        b: TensorDict(
            fields={
                c: FakeTensor(shape=torch.Size([]), device=cpu, dtype=torch.float32, is_shared=False)},
            batch_size=torch.Size([]),
            device=cpu,
            is_shared=False)},
    batch_size=torch.Size([]),
    device=cpu,
    is_shared=False)
```
- load_state_dict(state_dict: dict[str, Any], strict=True, assign=False, from_flatten=None)
Loads a state_dict into the tensorclass.
Supports both the new format (logical keys with _metadata) and the legacy format (_tensordict/_non_tensordict wrapper keys).
- memmap(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True, robust_key: bool | None = None) → Any
Writes all tensors onto a corresponding memory-mapped Tensor in a new tensordict.
- Args:
  - prefix (str): directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
  - copy_existing (bool): If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Args:
  - num_threads (int, optional): the number of threads used to write the memmap tensors. Defaults to 0.
  - return_early (bool, optional): if True and num_threads>0, the method will return a future of the tensordict.
  - share_non_tensor (bool, optional): if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
  - existsok (bool, optional): if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
  - robust_key (bool, optional): if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place (e.g., renaming, setting or removing an entry) will throw an exception. Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
  A new tensordict with the tensors stored on disk if return_early=False, otherwise a TensorDictFuture instance.
- Note:
  Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
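Memory mapping is what lets memmap() keep tensors on disk while still exposing them as arrays. The sketch below swaps in `numpy.memmap` to show the write-then-reload cycle for a single array; it is not the tensordict implementation, the file name obs.memmap is an arbitrary choice, and tensordict's directory layout under prefix is not reproduced.

```python
import os
import tempfile

import numpy as np

data = np.arange(12, dtype=np.float32).reshape(3, 4)

with tempfile.TemporaryDirectory() as prefix:
    path = os.path.join(prefix, "obs.memmap")
    # Write: create a file-backed array and copy the data into it
    on_disk = np.memmap(path, dtype=data.dtype, mode="w+", shape=data.shape)
    on_disk[:] = data
    on_disk.flush()
    del on_disk  # release the mapping before reopening the file

    # Read back: map the same file read-only instead of loading it eagerly
    reloaded = np.memmap(path, dtype=np.float32, mode="r", shape=(3, 4))
    matches = bool(np.array_equal(reloaded, data))
    del reloaded

print(matches)  # True
```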
- memmap_(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, existsok: bool = True, robust_key: bool | None = None) → Any
Writes all tensors onto a corresponding memory-mapped Tensor, in-place.
- Args:
  - prefix (str): directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
  - copy_existing (bool): If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Args:
  - num_threads (int, optional): the number of threads used to write the memmap tensors. Defaults to 0.
  - return_early (bool, optional): if True and num_threads>0, the method will return a future of the tensordict. The resulting tensordict can be queried using future.result().
  - share_non_tensor (bool, optional): if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
  - existsok (bool, optional): if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
  - robust_key (bool, optional): if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place (e.g., renaming, setting or removing an entry) will throw an exception. Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
  self if return_early=False, otherwise a TensorDictFuture instance.
- Note:
  Serialising in this fashion might be slow with deeply nested tensordicts, so it is not recommended to call this method inside a training loop.
- memmap_like(prefix: str | None = None, copy_existing: bool = False, *, existsok: bool = True, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) → Any
Creates a contentless memory-mapped tensordict with the same shapes as the original one.
- Args:
  - prefix (str): directory prefix where the memory-mapped tensors will be stored. The directory tree structure will mimic the tensordict’s.
  - copy_existing (bool): If False (default), an exception will be raised if an entry in the tensordict is already a tensor stored on disk with an associated file, but is not saved in the correct location according to prefix. If True, any existing Tensor will be copied to the new location.
- Keyword Args:
  - num_threads (int, optional): the number of threads used to write the memmap tensors. Defaults to 0.
  - return_early (bool, optional): if True and num_threads>0, the method will return a future of the tensordict.
  - share_non_tensor (bool, optional): if True, the non-tensor data will be shared between the processes, and writing operations (such as in-place updates or set) on any of the workers within a single node will update the value on all other workers. If the number of non-tensor leaves is high (e.g., sharing large stacks of non-tensor data), this may result in OOM or similar errors. Defaults to False.
  - existsok (bool, optional): if False, an exception will be raised if a tensor already exists in the same path. Defaults to True.
  - robust_key (bool, optional): if True, uses robust key encoding that safely handles keys with path separators and special characters. If False, uses legacy behavior (keys used as-is). If None (default), emits a deprecation warning and falls back to legacy behavior. Will default to True in v0.12.
The TensorDict is then locked, meaning that any writing operation that isn’t in-place (e.g., renaming, setting or removing an entry) will throw an exception. Once the tensordict is unlocked, the memory-mapped attribute is turned to False, because cross-process identity is not guaranteed anymore.
- Returns:
  A new TensorDict instance with data stored as memory-mapped tensors if return_early=False, otherwise a TensorDictFuture instance.
- Note:
  This is the recommended method to write a set of large buffers on disk, as memmap_() will copy the information, which can be slow for large content.
- Examples:
```python
>>> td = TensorDict({
...     "a": torch.zeros((3, 64, 64), dtype=torch.uint8),
...     "b": torch.zeros(1, dtype=torch.int64),
... }, batch_size=[]).expand(1_000_000)  # expand does not allocate new memory
>>> buffer = td.memmap_like("/path/to/dataset")
```
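What distinguishes memmap_like() from memmap() is that it only allocates; nothing is copied. The sketch below is a swapped-in `numpy.memmap` analogue, not the tensordict code: mode="w+" allocates a zero-filled, file-backed buffer, which is then filled chunk by chunk. The file name, shapes and chunk size are arbitrary choices.

```python
import os
import tempfile

import numpy as np

num_items, chunk = 1_000, 250

with tempfile.TemporaryDirectory() as prefix:
    path = os.path.join(prefix, "a.memmap")
    # Allocate a zero-filled, file-backed buffer; no source data is copied
    buffer = np.memmap(path, dtype=np.uint8, mode="w+", shape=(num_items, 3, 8, 8))
    initially_empty = not buffer.any()

    # Fill the buffer incrementally, e.g. as new data arrives
    for start in range(0, num_items, chunk):
        buffer[start:start + chunk] = start // chunk + 1
    buffer.flush()
    last_value = int(buffer[-1, 0, 0, 0])
    del buffer  # release the mapping before the directory is removed

print(initially_empty, last_value)  # True 4
```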
- memmap_refresh_()
Refreshes the content of the memory-mapped tensordict if it has a saved_path. This method will raise an exception if no path is associated with it.
- save(prefix: str | None = None, copy_existing: bool = False, *, num_threads: int = 0, return_early: bool = False, share_non_tensor: bool = False, robust_key: bool | None = None) → Any
Saves the tensordict to disk.
This function is a proxy to memmap().
- select(*keys, inplace: bool = False, strict: bool = True, as_tensordict: bool = False)
TensorClass-specific select that supports as_tensordict.
- set(key: NestedKey, value: Any, inplace: bool = False, non_blocking: bool = False)
Sets a new key-value pair.
- Args:
  - key (str, tuple of str): name of the key to be set. If a tuple of str, it is equivalent to chained calls of getattr followed by a final setattr.
  - value (Any): value to be stored in the tensorclass.
  - inplace (bool, optional): if True, set will tentatively try to update the value in-place. If False or if the key isn’t present, the value will simply be written at its destination.
- Returns:
  self
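The tuple-key rule for set ("chained calls of getattr followed by a final setattr") can be sketched on plain Python objects. `set_nested` is a hypothetical helper that ignores the inplace and non_blocking options; it only illustrates how a nested key is resolved.

```python
from functools import reduce
from types import SimpleNamespace
from typing import Any, Tuple, Union


def set_nested(obj: Any, key: Union[str, Tuple[str, ...]], value: Any) -> Any:
    """Hypothetical helper: chained getattr to the parent, then one setattr."""
    path = (key,) if isinstance(key, str) else key
    parent = reduce(getattr, path[:-1], obj)  # walk down to the parent object
    setattr(parent, path[-1], value)
    return obj  # mirrors set() returning self


data = SimpleNamespace(obs=SimpleNamespace(vector=0))
set_nested(data, ("obs", "vector"), 42)
print(data.obs.vector)  # 42
```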
- state_dict(destination=None, prefix='', keep_vars=False, flatten=True) → dict[str, Any]
Returns a state_dict with logical keys, matching TensorDictBase conventions.
Tensor fields appear as data keys. Non-tensor fields (strings, ints, etc.) and the tensorclass type are stored in _metadata. This replaces the legacy _tensordict/_non_tensordict wrapper format.
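The split between data keys and _metadata can be sketched with a plain dict. This is a hypothetical layout: `split_state` and the field names env_name and step are illustrative, and the sketch omits the tensorclass-type entry because the exact structure of _metadata is not documented here.

```python
from typing import Any, Dict

import torch


def split_state(fields: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical layout: tensors as top-level keys, everything else in _metadata."""
    state = {k: v for k, v in fields.items() if isinstance(v, torch.Tensor)}
    state["_metadata"] = {
        k: v for k, v in fields.items() if not isinstance(v, torch.Tensor)
    }
    return state


state = split_state({"reward": torch.zeros(4), "env_name": "CartPole-v1", "step": 7})
print(sorted(state))  # ['_metadata', 'reward']
print(state["_metadata"])  # {'env_name': 'CartPole-v1', 'step': 7}
```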
- to_tensordict(*, retain_none: bool | None = None) → TensorDict
Convert the tensorclass into a regular TensorDict.
Makes a copy of all entries. Memmap and shared memory tensors are converted to regular tensors.
- Args:
  - retain_none (bool): if True, the None values will be written in the tensordict. Otherwise they will be discarded. Default: True.
- Returns:
A new TensorDict object containing the same values as the tensorclass.
- class agilerl.components.data.ReplayDataset(buffer: ReplayBuffer, batch_size: int = 256)
Iterable Dataset containing the ReplayBuffer which will be updated with new experiences during training.
- Parameters:
buffer (agilerl.components.replay_buffer.ReplayBuffer()) – Experience replay buffer
batch_size (int, optional) – Number of experiences to sample at a time, defaults to 256
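The DataLoader integration mentioned in the overview relies on PyTorch's IterableDataset protocol. The sketch below is an analogue of ReplayDataset under stated assumptions: `ToyBuffer` and `BufferDataset` are hypothetical, the buffer is assumed to expose a `sample(batch_size)` method returning a dict of tensors, and AgileRL's buffer sampling and batch format may differ.

```python
from typing import Dict, Iterator

import torch
from torch.utils.data import DataLoader, IterableDataset


class ToyBuffer:
    """Hypothetical stand-in for a replay buffer with a sample() method."""

    def __init__(self, obs_dim: int = 4, size: int = 100) -> None:
        self.obs = torch.randn(size, obs_dim)
        self.reward = torch.randn(size)

    def sample(self, batch_size: int) -> Dict[str, torch.Tensor]:
        idx = torch.randint(len(self.obs), (batch_size,))
        return {"obs": self.obs[idx], "reward": self.reward[idx]}


class BufferDataset(IterableDataset):
    """Sketch of an iterable dataset that keeps sampling from a live buffer."""

    def __init__(self, buffer: ToyBuffer, batch_size: int = 256) -> None:
        self.buffer = buffer
        self.batch_size = batch_size

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        while True:  # infinite stream; the training loop decides when to stop
            yield self.buffer.sample(self.batch_size)


# batch_size=None: each item is already a batch, so DataLoader must not re-batch
loader = DataLoader(BufferDataset(ToyBuffer(), batch_size=32), batch_size=None)
batch = next(iter(loader))
print(batch["obs"].shape)  # torch.Size([32, 4])
```

Because the buffer is sampled lazily inside `__iter__`, experiences added to it during training become visible to later batches without rebuilding the dataset.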