AgentWrapper¶
Parameters¶
- class agilerl.wrappers.agent.AgentWrapper(agent: RLAlgorithm | MultiAgentRLAlgorithm)¶
Base class for all agent wrappers. Agent wrappers are used to apply additional functionality to the get_action() and learn() methods of an EvolvableAlgorithm instance; a minimal subclass sketch is shown after the parameter list below.
- Parameters:
agent (EvolvableAlgorithm) – Agent to be wrapped
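As a concrete illustration, here is a minimal, hypothetical subclass that logs each call before delegating to the wrapped agent. The ActionLogger name, the print-based logging, and the assumption that the wrapper exposes the wrapped agent as self.agent (as suggested by the constructor signature) are illustrative; only AgentWrapper, get_action() and learn() come from the documented API.

```python
from agilerl.wrappers.agent import AgentWrapper


class ActionLogger(AgentWrapper):
    """Hypothetical wrapper: log each call, then delegate to the wrapped agent."""

    def get_action(self, obs, *args, **kwargs):
        # Delegate action selection to the underlying agent (assumed self.agent).
        action = self.agent.get_action(obs, *args, **kwargs)
        print(f"get_action -> {action}")
        return action

    def learn(self, experiences, *args, **kwargs):
        # Delegate the learning step to the underlying agent.
        info = self.agent.learn(experiences, *args, **kwargs)
        print(f"learn -> {info}")
        return info
```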
- clone(index: int | None = None, wrap: bool = True) → SelfAgentWrapper¶
Clones the wrapper with the underlying agent.
- property device: str | device¶
Returns the device of the agent.
- Returns:
Device of the agent
- Return type:
DeviceType
- abstract get_action(obs: ObservationType | Dict[str, ObservationType], *args: Any, **kwargs: Any) → Any¶
Returns the action from the agent.
- Parameters:
obs (Union[ObservationType, MARLObservationType]) – Observation from the environment
args (Any) – Additional positional arguments
kwargs (Any) – Additional keyword arguments
- Returns:
Action from the agent
- Return type:
Any
- abstract learn(experiences: Tuple[ObservationType, ...], *args: Any, **kwargs: Any) → Any¶
Learns from the experiences.
- Parameters:
experiences (ExperiencesType) – Experiences from the environment
args (Any) – Additional positional arguments
kwargs (Any) – Additional keyword arguments
- Returns:
Learning information
- Return type:
Any
- load_checkpoint(path: str) → None¶
Loads a checkpoint of agent properties and network weights from path.
- Parameters:
path (str) – Location to load checkpoint from
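A brief usage sketch of the concrete helpers above, assuming wrapped_agent is an AgentWrapper instance (e.g. RSNorm) created elsewhere and that "checkpoint.pt" is a hypothetical path:

```python
# Clone the wrapper together with its underlying agent; wrap=True returns the
# clone still wrapped, and index optionally assigns a new agent index.
clone = wrapped_agent.clone(index=0, wrap=True)

# Restore agent properties and network weights from a previously saved checkpoint.
wrapped_agent.load_checkpoint("checkpoint.pt")

print(wrapped_agent.device)  # device of the underlying agent
```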
RSNorm¶
Parameters¶
- class agilerl.wrappers.agent.RSNorm(agent: RLAlgorithm | MultiAgentRLAlgorithm, epsilon: float = 0.0001, norm_obs_keys: List[str] | None = None)¶
Wrapper to normalize observations such that each coordinate is centered with unit variance. Handles both single- and multi-agent settings, as well as Dict and Tuple observation spaces. A usage sketch is shown after the parameter list below.
The normalization statistics are only updated when the agent is in training mode. This can be disabled during inference through agent.set_training_mode(False).
Warning
This wrapper is currently only supported for off-policy algorithms, since it relies on the passed experiences being formatted as a tuple of PyTorch tensors. AgileRL does not yet use a Buffer class to store experiences for on-policy algorithms, although support for this is planned for an upcoming release.
- Parameters:
agent (RLAlgorithm, MultiAgentRLAlgorithm) – Agent to be wrapped
epsilon (float, optional) – Small value to avoid division by zero, defaults to 1e-4
norm_obs_keys (Optional[List[str]]) – List of observation keys to normalize, defaults to None
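As a quick illustration, the sketch below wraps a single-agent algorithm with RSNorm so that observations are normalized automatically inside get_action() and learn(). It is a minimal sketch: the DQN import path, its constructor arguments, and the CartPole-v1 environment are assumptions to be adapted to your setup; only RSNorm, its documented parameters, and set_training_mode() come from this page.

```python
import gymnasium as gym
from agilerl.algorithms import DQN  # assumed import path; see the DQN docs
from agilerl.wrappers.agent import RSNorm

env = gym.make("CartPole-v1")

# Assumed constructor: recent AgileRL algorithms are built from gymnasium spaces.
agent = DQN(env.observation_space, env.action_space)

# Wrap the agent: observations passed to get_action() and learn() are now
# normalized with running mean/std statistics (epsilon avoids division by zero).
agent = RSNorm(agent, epsilon=1e-4)

obs, _ = env.reset()
action = agent.get_action(obs)  # observation is normalized internally

# During evaluation, freeze the running statistics:
agent.set_training_mode(False)
```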
- static build_rms(observation_space: Space, epsilon: float = 0.0001, norm_obs_keys: List[str] | None = None, device: str | device = 'cpu') → RunningMeanStd | Dict[str, RunningMeanStd] | Tuple[RunningMeanStd, ...]¶
Builds the RunningMeanStd object(s) based on the observation space.
- Parameters:
observation_space (spaces.Space) – Observation space of the agent
epsilon (float, optional) – Small value to avoid division by zero, defaults to 1e-4
norm_obs_keys (Optional[List[str]]) – List of observation keys to normalize, defaults to None
device (Union[str, torch.device], optional) – Device for the running statistics, defaults to "cpu"
- Returns:
RunningMeanStd object(s)
- Return type:
Union[RunningMeanStd, Dict[str, RunningMeanStd], Tuple[RunningMeanStd, …]]
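For example, with a Dict observation space the helper returns one RunningMeanStd per key. A minimal sketch, assuming gymnasium spaces (the shapes below are arbitrary illustrations):

```python
from gymnasium import spaces
from agilerl.wrappers.agent import RSNorm

obs_space = spaces.Dict({
    "image": spaces.Box(low=0, high=255, shape=(3, 84, 84)),
    "vector": spaces.Box(low=-1.0, high=1.0, shape=(8,)),
})

# For a Dict space this returns a Dict[str, RunningMeanStd], one entry per key;
# pass norm_obs_keys to restrict normalization to a subset of keys.
rms = RSNorm.build_rms(obs_space, epsilon=1e-4, device="cpu")
```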
- get_action(obs: ObservationType, *args: Any, **kwargs: Any) → Any¶
Returns the action from the agent after normalizing the observation.
- Parameters:
obs (ObservationType) – Observation from the environment
- Returns:
Action from the agent
- Return type:
Any
- learn(experiences: Tuple[ObservationType, ...], *args: Any, **kwargs: Any) → Any¶
Learns from the experiences after normalizing the observations.
- Parameters:
experiences (ExperiencesType) – Experiences from the environment
args (Any) – Additional positional arguments
kwargs (Any) – Additional keyword arguments
- Returns:
Learning information
- Return type:
Any
- normalize_observation(observation: ObservationType) → ObservationType¶
Normalizes the observation using the RunningMeanStd object(s).
- Parameters:
observation (ObservationType) – Observation from the environment
- Returns:
Normalized observation
- Return type:
ObservationType
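Concretely, each coordinate is centered and scaled using the running statistics held by the corresponding RunningMeanStd object. A minimal sketch of the operation (variable names are illustrative, not AgileRL's internals):

```python
import numpy as np

def normalize(obs: np.ndarray, mean: np.ndarray, var: np.ndarray, epsilon: float = 1e-4) -> np.ndarray:
    # Center each coordinate and scale towards unit variance; epsilon guards
    # against division by zero for near-constant coordinates.
    return (obs - mean) / np.sqrt(var + epsilon)
```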
- update_statistics(observation: ObservationType) → None¶
Updates the running statistics using the observation.
- Parameters:
observation (ObservationType) – Observation from the environment
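For intuition, the running mean and variance are maintained incrementally from incoming observations. The sketch below shows a standard parallel (Chan et al.) merge of batch statistics; it is an illustration of the technique, not AgileRL's exact RunningMeanStd implementation:

```python
import numpy as np

class RunningStats:
    """Illustrative running mean/variance tracker."""

    def __init__(self, shape: tuple):
        self.mean = np.zeros(shape)
        self.var = np.ones(shape)
        self.count = 1e-4  # small initial count for numerical stability

    def update(self, batch: np.ndarray) -> None:
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]

        delta = batch_mean - self.mean
        total = self.count + batch_count

        # Merge the existing statistics with those of the new batch.
        new_mean = self.mean + delta * batch_count / total
        m_a = self.var * self.count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + delta**2 * self.count * batch_count / total

        self.mean, self.var, self.count = new_mean, m2 / total, total
```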