Torch Utils
- agilerl.utils.torch_utils.map_pytree(f: Callable[[ndarray | Tensor], Any], item: Any) → Any
Apply a function to all tensors/arrays in a nested data structure.
Recursively traverses nested dictionaries, lists, tuples, and sets, applying the given function to any numpy arrays or PyTorch tensors found.
- Parameters:
f (Callable[[np.ndarray | torch.Tensor], Any]) – Function to apply to arrays/tensors
item (Any) – Nested data structure to traverse
- Returns:
Data structure with function applied to all arrays/tensors
- Return type:
Any
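The traversal described above can be sketched as follows. This is a minimal stand-in written for illustration, not the library's source: the name `map_pytree_sketch` is hypothetical, and returning lists for tuples and sets is an assumption, since the page does not say whether container types are preserved.

```python
import numpy as np
import torch


def map_pytree_sketch(f, item):
    """Hypothetical stand-in that mirrors the documented behaviour of map_pytree."""
    if isinstance(item, dict):
        return {k: map_pytree_sketch(f, v) for k, v in item.items()}
    if isinstance(item, (list, tuple, set)):
        # Assumption: every sequence-like container comes back as a list.
        return [map_pytree_sketch(f, v) for v in item]
    if isinstance(item, (np.ndarray, torch.Tensor)):
        return f(item)
    # Anything else (strings, ints, None, ...) is returned unchanged.
    return item


batch = {"obs": np.ones((2, 3)), "extras": [torch.zeros(4), "label"], "step": 7}

# Replace every array/tensor leaf by its shape; other leaves are left alone.
shapes = map_pytree_sketch(lambda x: tuple(x.shape), batch)
```

Only the array and tensor leaves are passed to `f`; the string and the integer in `batch` pass through untouched.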
- agilerl.utils.torch_utils.to(item: Any, device: device | str) → Any
Move all tensors/arrays in a nested data structure to the specified device.
- Parameters:
item (Any) – Nested data structure containing tensors/arrays
device (torch.device | str) – Target device to move tensors to
- Returns:
Data structure with tensors moved to the target device
- Return type:
Any
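A minimal sketch of the same idea, under stated assumptions: `to_sketch` is a hypothetical name, and converting NumPy arrays to tensors before moving them is an assumption (arrays have no device of their own, and the page does not say how they are handled).

```python
import numpy as np
import torch


def to_sketch(item, device):
    """Hypothetical stand-in: move every tensor/array leaf to `device`."""
    if isinstance(item, dict):
        return {k: to_sketch(v, device) for k, v in item.items()}
    if isinstance(item, (list, tuple)):
        return [to_sketch(v, device) for v in item]
    if isinstance(item, np.ndarray):
        # Assumption: arrays are converted to tensors so they can live on a device.
        return torch.from_numpy(item).to(device)
    if isinstance(item, torch.Tensor):
        return item.to(device)
    return item


# Fall back to the CPU when no GPU is present so the example runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"

batch = {"obs": np.ones((2, 3)), "actions": [torch.zeros(2)], "done": False}
moved = to_sketch(batch, device)
```

Both a `torch.device` and a plain string such as `"cpu"` are accepted by `Tensor.to`, which matches the `device | str` annotation in the signature.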
- agilerl.utils.torch_utils.to_decorator(f: Callable[[...], Any], device: device | str) → Callable[[...], Any]
Wrap a function so that its output is moved to the specified device.
- Parameters:
f (Callable) – Function whose output should be moved to device
device (torch.device | str) – Target device
- Returns:
Decorated function
- Return type:
Callable
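Because the signature takes both `f` and `device`, it cannot be applied with a bare `@` line. The sketch below shows two ways such a wrapper could be used. `to_decorator_sketch` and `move_output` are hypothetical names written for this example, not the library's implementation.

```python
import functools

import torch


def move_output(item, device):
    """Hypothetical helper: move tensors inside dicts/lists/tuples to `device`."""
    if isinstance(item, dict):
        return {k: move_output(v, device) for k, v in item.items()}
    if isinstance(item, (list, tuple)):
        return [move_output(v, device) for v in item]
    if isinstance(item, torch.Tensor):
        return item.to(device)
    return item


def to_decorator_sketch(f, device):
    """Return a wrapper around `f` whose output is moved to `device`."""
    @functools.wraps(f)
    def wrapped(*args, **kwargs):
        return move_output(f(*args, **kwargs), device)
    return wrapped


device = "cuda" if torch.cuda.is_available() else "cpu"


def make_batch(n):
    return {"obs": torch.zeros(n, 3), "mask": torch.ones(n)}


# Direct call form, following the documented (f, device) argument order.
make_batch_on_device = to_decorator_sketch(make_batch, device)


# Decorator form: bind `device` first so that `@` only has to supply `f`.
@functools.partial(to_decorator_sketch, device=device)
def make_targets(n):
    return torch.arange(n)


batch = make_batch_on_device(2)
targets = make_targets(4)
```

Using `functools.wraps` keeps the wrapped function's name and docstring, which is a design choice of this sketch rather than something the page documents.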
- agilerl.utils.torch_utils.parameter_norm(model: Module) → float
Calculate the L2 norm of all parameters in a model.
- Parameters:
model (nn.Module) – PyTorch model
- Returns:
L2 norm of all model parameters
- Return type:
float
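The L2 norm over all parameters is the square root of the sum of squared entries across every parameter tensor. A minimal sketch of that computation, assuming nothing about the library's internals (`parameter_norm_sketch` is a hypothetical name):

```python
import math

import torch
from torch import nn


def parameter_norm_sketch(model):
    """Hypothetical stand-in: global L2 norm over all parameters of `model`."""
    total = 0.0
    for p in model.parameters():
        # Accumulate the squared entries of each parameter tensor.
        total += p.detach().pow(2).sum().item()
    return math.sqrt(total)


model = nn.Linear(3, 2)
with torch.no_grad():
    model.weight.fill_(1.0)  # 6 entries, each contributing 1
    model.bias.fill_(2.0)    # 2 entries, each contributing 4

norm = parameter_norm_sketch(model)  # sqrt(6 * 1 + 2 * 4) = sqrt(14)

# Cross-check against flattening all parameters into a single vector.
flat_norm = torch.nn.utils.parameters_to_vector(model.parameters()).norm().item()
```

The global norm is not the sum of the per-tensor norms; the squares are summed first and the root is taken once.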
- agilerl.utils.torch_utils.get_transformer_logs(attentions: list[Tensor], model: Module, attn_mask: Tensor) → dict[str, tuple[float, int]]
Extract logging information from transformer attention weights.
Computes attention entropy and parameter norm for transformer models, which can be useful for monitoring training dynamics.
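The page does not give the entropy formula, the tensor shapes, or the meaning of the `(float, int)` pairs in the returned dictionary. The sketch below is therefore illustrative only. It assumes each attention tensor has shape `(batch, heads, query, key)` with rows that sum to 1, that `attn_mask` has shape `(batch, query)` with 1 marking real tokens, and that the integer is the number of rows averaged over. The name `attention_entropy_sketch` is hypothetical.

```python
import math

import torch


def attention_entropy_sketch(attentions, attn_mask, eps=1e-7):
    """Hypothetical sketch: mean Shannon entropy of unmasked attention rows."""
    n_heads = attentions[0].shape[1]
    n_valid = int(attn_mask.sum().item())
    total = 0.0
    for attn in attentions:
        # Entropy of each attention distribution over keys: (batch, heads, query).
        row_entropy = -(attn * torch.log(attn + eps)).sum(dim=-1)
        # Zero out rows that belong to padded query positions.
        total += (row_entropy * attn_mask.unsqueeze(1)).sum().item()
    count = n_valid * n_heads * len(attentions)
    return total / count, count


# Two layers of uniform attention over 4 keys: each row has entropy log(4).
uniform = torch.full((1, 2, 4, 4), 0.25)
mask = torch.tensor([[1.0, 1.0, 1.0, 0.0]])  # last query position is padding

entropy, count = attention_entropy_sketch([uniform, uniform], mask)
```

Low entropy means attention is concentrated on a few keys, while values near the log of the number of keys mean it is spread almost uniformly. Tracking this alongside the parameter norm is one way to monitor the training dynamics mentioned above.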