Torch Utils

agilerl.utils.torch_utils.map_pytree(f: Callable[[ndarray | Tensor], Any], item: Any) → Any

Apply a function to all tensors/arrays in a nested data structure.

Recursively traverses nested dictionaries, lists, tuples, and sets, applying the given function to any numpy arrays or PyTorch tensors found.

Parameters:
  • f (Callable[[np.ndarray | torch.Tensor], Any]) – Function to apply to arrays/tensors

  • item (Any) – Nested data structure to traverse

Returns:

Data structure with function applied to all arrays/tensors

Return type:

Any
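The traversal described above can be illustrated with a small stand-in. This is a minimal sketch, not the library's implementation: `map_pytree_sketch` is a hypothetical name, and it assumes that leaves which are neither arrays nor tensors are returned unchanged and that each container is rebuilt with its original type.

```python
from typing import Any, Callable

import numpy as np
import torch


def map_pytree_sketch(f: Callable[[Any], Any], item: Any) -> Any:
    """Hypothetical stand-in mirroring the documented traversal."""
    if isinstance(item, dict):
        return {k: map_pytree_sketch(f, v) for k, v in item.items()}
    if isinstance(item, (list, tuple, set)):
        # Rebuild the container with the same type it had on input
        return type(item)(map_pytree_sketch(f, v) for v in item)
    if isinstance(item, (np.ndarray, torch.Tensor)):
        return f(item)
    # Assumption: any other leaf (str, int, None, ...) passes through untouched
    return item


batch = {
    "obs": np.ones((2, 3), dtype=np.float32),
    "extras": [torch.zeros(2), ("label", 7)],
}
doubled = map_pytree_sketch(lambda x: x * 2, batch)
```

Here `doubled` has the same nesting as `batch`: the array and the tensor have been doubled, while the `("label", 7)` tuple is unchanged because it holds no arrays or tensors.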

agilerl.utils.torch_utils.to(item: Any, device: device | str) → Any

Move all tensors/arrays in a nested data structure to the specified device.

Parameters:
  • item (Any) – Nested data structure containing tensors/arrays

  • device (torch.device | str) – Target device to move tensors to

Returns:

Data structure with tensors moved to device

Return type:

Any
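As a rough illustration of the behaviour, the sketch below moves every tensor in a nested batch to one device. `to_device_sketch` is a hypothetical stand-in, not the library function. How the real function treats numpy arrays is not stated here, so converting them with `torch.as_tensor` is an assumption.

```python
from typing import Any

import numpy as np
import torch


def to_device_sketch(item: Any, device) -> Any:
    """Hypothetical stand-in: move tensors/arrays in a nested structure."""
    if isinstance(item, dict):
        return {k: to_device_sketch(v, device) for k, v in item.items()}
    if isinstance(item, (list, tuple)):
        return type(item)(to_device_sketch(v, device) for v in item)
    if isinstance(item, (np.ndarray, torch.Tensor)):
        # Assumption: arrays are converted to tensors before being moved
        return torch.as_tensor(item).to(device)
    return item


# Fall back to the CPU when no GPU is present so the example runs anywhere
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
batch = {
    "obs": np.zeros((4, 2), dtype=np.float32),
    "action": torch.tensor([0, 1, 1, 0]),
}
moved = to_device_sketch(batch, device)
```

Both a `torch.device` and a plain string such as `"cpu"` are accepted by `Tensor.to`, which matches the `device | str` annotation in the signature.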

agilerl.utils.torch_utils.to_decorator(f: Callable[[...], Any], device: device | str) → Callable[[...], Any]

Wrap a function so that its output is moved to the specified device (decorator).

Parameters:
  • f (Callable) – Function whose output should be moved to device

  • device (torch.device | str) – Target device

Returns:

Decorated function

Return type:

Callable
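The wrapping pattern can be sketched as follows. `to_decorator_sketch` and the `_move` helper are hypothetical names used only for illustration, and `sample_batch` is an invented example function. The sketch assumes the wrapper forwards all arguments unchanged and only post-processes the return value.

```python
import functools
from typing import Any, Callable

import torch


def _move(item: Any, device) -> Any:
    """Hypothetical helper: move tensors inside dicts, lists and tuples."""
    if isinstance(item, torch.Tensor):
        return item.to(device)
    if isinstance(item, dict):
        return {k: _move(v, device) for k, v in item.items()}
    if isinstance(item, (list, tuple)):
        return type(item)(_move(v, device) for v in item)
    return item


def to_decorator_sketch(f: Callable[..., Any], device) -> Callable[..., Any]:
    """Hypothetical stand-in: return f wrapped so its output lands on device."""

    @functools.wraps(f)  # keep the wrapped function's name and docstring
    def wrapped(*args: Any, **kwargs: Any) -> Any:
        return _move(f(*args, **kwargs), device)

    return wrapped


def sample_batch(n: int) -> dict:
    return {"obs": torch.zeros(n, 4), "reward": torch.ones(n)}


sample_on_device = to_decorator_sketch(sample_batch, "cpu")
batch = sample_on_device(3)
```

Because the device is a second positional argument, the function is applied as a call (`to_decorator_sketch(f, device)`) rather than with bare `@` syntax. `functools.partial` could bind the device first if decorator syntax is preferred.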

agilerl.utils.torch_utils.parameter_norm(model: Module) → float

Calculate the L2 norm of all parameters in a model.

Parameters:

model (nn.Module) – PyTorch model

Returns:

L2 norm of all model parameters

Return type:

float
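For reference, a global L2 norm over all parameters is usually taken to be the square root of the sum of squared entries across every parameter tensor. The sketch below computes that quantity. `parameter_norm_sketch` is a hypothetical name, and treating the library's result as this global norm is an assumption.

```python
import torch
from torch import nn


def parameter_norm_sketch(model: nn.Module) -> float:
    """Hypothetical stand-in: sqrt of the sum of squares of all parameters."""
    total = 0.0
    for p in model.parameters():
        total += p.detach().pow(2).sum().item()
    return total ** 0.5


# Tiny model with hand-set weights so the result is easy to check by hand
model = nn.Linear(2, 1)
with torch.no_grad():
    model.weight.copy_(torch.tensor([[3.0, 4.0]]))
    model.bias.zero_()

norm = parameter_norm_sketch(model)  # sqrt(3^2 + 4^2 + 0^2) = 5.0
```

Tracking this value over training is a cheap way to spot weights that grow without bound or collapse toward zero.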

agilerl.utils.torch_utils.get_transformer_logs(attentions: list[Tensor], model: Module, attn_mask: Tensor) → dict[str, tuple[float, int]]

Extract logging information from transformer attention weights.

Computes attention entropy and parameter norm for transformer models, which can be useful for monitoring training dynamics.

Parameters:
  • attentions (list[torch.Tensor]) – List of attention weight tensors from transformer layers

  • model (nn.Module) – Transformer model

  • attn_mask (torch.Tensor) – Attention mask tensor

Returns:

Dictionary containing attention entropy and parameter norm

Return type:

dict[str, tuple[float, int]]
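To make "attention entropy" concrete, the sketch below computes the mean Shannon entropy of attention rows over unmasked query positions. Everything in it is an assumption made for illustration: the function name, the tensor shapes (`(batch, heads, query, key)` for each attention tensor and `(batch, query)` for the mask, with 1 marking real tokens), and the reading of each returned tuple as a (value, count) pair.

```python
import torch


def attention_entropy_sketch(
    attentions: list[torch.Tensor], attn_mask: torch.Tensor
) -> tuple[float, int]:
    """Hypothetical stand-in: mean row entropy over unmasked query positions."""
    mask = attn_mask.to(attentions[0].dtype)
    n_valid = int(mask.sum().item())
    total = 0.0
    for attn in attentions:
        # Entropy of each attention row: -sum_k p_k * log(p_k)
        row_entropy = -(attn * torch.log(attn + 1e-9)).sum(dim=-1)
        # Average over heads, then drop padded query positions
        total += (row_entropy.mean(dim=1) * mask).sum().item()
    mean_entropy = total / (len(attentions) * max(n_valid, 1))
    return mean_entropy, n_valid


# Two layers of uniform attention over 4 keys; the last query is padding
attentions = [torch.full((1, 2, 4, 4), 0.25) for _ in range(2)]
attn_mask = torch.tensor([[1, 1, 1, 0]])
entropy, count = attention_entropy_sketch(attentions, attn_mask)
# Uniform attention over 4 keys gives entropy close to log(4)
```

Low entropy means attention concentrates on a few tokens, while values near log(number of keys) indicate nearly uniform attention. A count alongside each value is what a logger would need to form weighted averages across batches, though that use is also an assumption here.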