AgentWrapper


class agilerl.wrappers.agent.AgentWrapper(agent: RLAlgorithm | MultiAgentRLAlgorithm)

Base class for all agent wrappers. Agent wrappers add functionality to the get_action() and learn() methods of an EvolvableAlgorithm instance.

Parameters:

agent (EvolvableAlgorithm) – Agent to be wrapped
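The following is an illustrative sketch of a concrete wrapper that fills in the abstract get_action() and learn() methods. The ActionClipWrapper class, the self.agent attribute name, and the clipping behaviour are assumptions for demonstration, not part of the documented API.

    import numpy as np

    from agilerl.wrappers.agent import AgentWrapper


    class ActionClipWrapper(AgentWrapper):
        """Hypothetical wrapper that clips continuous actions to [-1, 1]."""

        def get_action(self, obs, *args, **kwargs):
            # Delegate to the wrapped agent, then post-process the returned action
            action = self.agent.get_action(obs, *args, **kwargs)
            return np.clip(action, -1.0, 1.0)

        def learn(self, experiences, *args, **kwargs):
            # No extra behaviour on the learning side; forward unchanged
            return self.agent.learn(experiences, *args, **kwargs)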

clone(index: int | None = None, wrap: bool = True) → SelfAgentWrapper

Clones the wrapper with the underlying agent.

Parameters:
  • index (Optional[int], optional) – Index of the agent in a population, defaults to None

  • wrap (bool, optional) – If True, wrap the models in the clone with the accelerator, defaults to True

Returns:

Cloned agent wrapper

Return type:

SelfAgentWrapper

property device: str | device

Returns the device of the agent.

Returns:

Device of the agent

Return type:

DeviceType

abstract get_action(obs: ObservationType | MARLObservationType, *args: Any, **kwargs: Any) → Any

Returns the action from the agent.

Parameters:
  • obs (Union[ObservationType, MARLObservationType]) – Observation from the environment

  • args (Any) – Additional positional arguments

  • kwargs (Any) – Additional keyword arguments

Returns:

Action from the agent

Return type:

Any

abstract learn(experiences: ExperiencesType, *args: Any, **kwargs: Any) → Any

Learns from the experiences.

Parameters:
  • experiences (ExperiencesType) – Experiences from the environment

  • args (Any) – Additional positional arguments

  • kwargs (Any) – Additional keyword arguments

Returns:

Learning information

Return type:

Any

load_checkpoint(path: str) → None

Loads a checkpoint of agent properties and network weights from path.

Parameters:

path (string) – Location to load checkpoint from

save_checkpoint(path: str) → None

Saves a checkpoint of agent properties and network weights to path.

Parameters:

path (string) – Location to save checkpoint at

property training: bool

Returns the training status of the agent.

Returns:

Training status of the agent

Return type:

bool

RSNorm


class agilerl.wrappers.agent.RSNorm(agent: RLAlgorithm | MultiAgentRLAlgorithm, epsilon: float = 0.0001, norm_obs_keys: List[str] | None = None)

Wrapper to normalize observations such that each coordinate is centered with unit variance. Handles both single and multi-agent settings, as well as Dict and Tuple observation spaces.

The normalization statistics are only updated when the agent is in training mode. Statistics updates can be disabled during inference by calling agent.set_training_mode(False).

Warning

This wrapper is currently only supported for off-policy algorithms, since it relies on experiences being passed as a tuple of PyTorch tensors. AgileRL does not yet use a Buffer class to store experiences for on-policy algorithms; support for this will arrive in an upcoming release.

Parameters:
  • agent (RLAlgorithm, MultiAgentRLAlgorithm) – Agent to be wrapped

  • epsilon (float, optional) – Small value to avoid division by zero, defaults to 1e-4

  • norm_obs_keys (Optional[List]) – List of observation keys to normalize, defaults to None
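A minimal usage sketch is shown below. The DQN import path and constructor arguments are assumed for illustration and may differ between AgileRL versions; the wrapping itself follows the constructor documented above.

    import gymnasium as gym

    from agilerl.algorithms import DQN  # import path assumed for illustration
    from agilerl.wrappers.agent import RSNorm

    env = gym.make("CartPole-v1")

    # Constructor arguments are illustrative; check the DQN documentation for
    # the exact signature in your AgileRL version.
    agent = DQN(observation_space=env.observation_space, action_space=env.action_space)

    # Wrap the agent so observations are normalized inside get_action() and learn()
    agent = RSNorm(agent, epsilon=1e-4)

    obs, _ = env.reset()
    action = agent.get_action(obs)  # observation is normalized before the forward pass

    # Stop updating the running statistics during evaluation
    agent.set_training_mode(False)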

static build_rms(observation_space: Space, epsilon: float = 0.0001, norm_obs_keys: List[str] | None = None, device: str | device = 'cpu') → RunningMeanStd | Dict[str, RunningMeanStd] | Tuple[RunningMeanStd, ...]

Builds the RunningMeanStd object(s) based on the observation space.

Parameters:

  • observation_space (spaces.Space) – Observation space of the agent

  • epsilon (float, optional) – Small value to avoid division by zero, defaults to 1e-4

  • norm_obs_keys (Optional[List[str]], optional) – List of observation keys to normalize, defaults to None

  • device (Union[str, torch.device], optional) – Device on which the running statistics are stored, defaults to 'cpu'

Returns:

RunningMeanStd object(s)

Return type:

Union[RunningMeanStd, Dict[str, RunningMeanStd], Tuple[RunningMeanStd, …]]
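For composite observation spaces the helper returns one running-statistics object per sub-space. The sketch below calls it on a Dict space; exactly how norm_obs_keys restricts which keys are tracked is assumed from the constructor documentation.

    from gymnasium import spaces

    from agilerl.wrappers.agent import RSNorm

    obs_space = spaces.Dict(
        {
            "position": spaces.Box(low=-1.0, high=1.0, shape=(3,)),
            "velocity": spaces.Box(low=-5.0, high=5.0, shape=(3,)),
        }
    )

    # For a Dict space this yields a dict of RunningMeanStd objects keyed like the
    # observation space; norm_obs_keys is assumed to limit which keys are tracked.
    rms = RSNorm.build_rms(obs_space, epsilon=1e-4, norm_obs_keys=["position"])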

get_action(obs: ObservationType, *args: Any, **kwargs: Any) → Any

Returns the action from the agent after normalizing the observation.

Parameters:

  • obs (ObservationType) – Observation from the environment

  • args (Any) – Additional positional arguments

  • kwargs (Any) – Additional keyword arguments

Returns:

Action from the agent

Return type:

Any

learn(experiences: ExperiencesType, *args: Any, **kwargs: Any) → Any

Learns from the experiences after normalizing the observations.

Parameters:
  • experiences (ExperiencesType) – Experiences from the environment

  • args (Any) – Additional positional arguments

  • kwargs (Any) – Additional keyword arguments

Returns:

Learning information

Return type:

Any

normalize_observation(observation: ObservationType) → ObservationType

Normalizes the observation using the RunningMeanStd object(s).

Parameters:

observation (ObservationType) – Observation from the environment

Returns:

Normalized observation

Return type:

ObservationType
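Assuming the conventional running z-score formulation (each coordinate is centered by the running mean and scaled by the running standard deviation, with epsilon guarding against division by zero), the per-coordinate computation is equivalent to the sketch below; the mean and var arrays stand in for the wrapper's running statistics.

    import numpy as np


    def normalize(obs: np.ndarray, mean: np.ndarray, var: np.ndarray, epsilon: float = 1e-4) -> np.ndarray:
        # Center each coordinate and scale it to roughly unit variance
        return (obs - mean) / np.sqrt(var + epsilon)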

update_statistics(observation: ObservationType) → None

Updates the running statistics using the observation.

Parameters:

observation (ObservationType) – Observation from the environment
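Running mean and variance of this kind are typically maintained with a parallel (Welford-style) update that merges batch statistics into the running totals. The function below is an illustrative sketch of that update, not the wrapper's internal implementation:

    import numpy as np


    def update_running_stats(mean, var, count, batch):
        # Statistics of the incoming batch of observations
        batch_mean = batch.mean(axis=0)
        batch_var = batch.var(axis=0)
        batch_count = batch.shape[0]

        # Merge the batch statistics into the running statistics
        delta = batch_mean - mean
        total_count = count + batch_count
        new_mean = mean + delta * batch_count / total_count
        m_a = var * count
        m_b = batch_var * batch_count
        m2 = m_a + m_b + np.square(delta) * count * batch_count / total_count
        new_var = m2 / total_count
        return new_mean, new_var, total_count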