Probe Environments

For more information on how to use probe environments, see Debugging RL.

Single-agent probe environments

class agilerl.utils.probe_envs.ConstantRewardEnv
class ConstantRewardEnv(gym.Env):
    """Probe environment with one state, one action and a constant reward of 1.

    Every episode terminates after a single step, so the only state (and the
    only state-action pair) is worth exactly 1.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Discrete(1)

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.zeros((1, 1))]
        self.q_values = [[1.0]]  # Expected Q values, state x action table
        self.v_values = [[1.0]]  # Expected V values, state table
        self.policy_values = [None]  # No policy target is defined

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[int, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and finish the episode with a reward of 1.

        Returns:
            ``(observation, reward, terminated, truncated, info)``, which is
            always ``(0, 1, True, False, {})``.
        """
        empty_info: dict[str, Any] = {}
        return 0, 1, True, False, empty_info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[int, dict[str, Any]]:
        """Begin a new episode; ``seed`` and ``options`` are accepted but unused.

        Returns:
            The fixed observation ``0`` and an empty info dict.
        """
        empty_info: dict[str, Any] = {}
        return 0, empty_info
class agilerl.utils.probe_envs.ConstantRewardImageEnv
class ConstantRewardImageEnv(gym.Env):
    """Image-observation probe env: all-zero (1, 3, 3) frame, constant reward 1.

    Every episode terminates after a single step regardless of the action.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 0.0, (1, 3, 3))
        self.action_space = spaces.Discrete(1)

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.zeros((1, 1, 3, 3))]
        self.q_values = [[1.0]]  # Expected Q values, state x action table
        self.v_values = [[1.0]]  # Expected V values, state table
        self.policy_values = [None]  # No policy target is defined

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and finish the episode with a reward of 1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation is a fresh all-zero (1, 3, 3) array.
        """
        blank_frame = np.zeros((1, 3, 3))
        return blank_frame, 1, True, False, {}

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Begin a new episode; ``seed`` and ``options`` are accepted but unused.

        Returns:
            A fresh all-zero (1, 3, 3) observation and an empty info dict.
        """
        blank_frame = np.zeros((1, 3, 3))
        return blank_frame, {}
class agilerl.utils.probe_envs.ConstantRewardContActionsEnv
class ConstantRewardContActionsEnv(gym.Env):
    """Continuous-action probe env with one state and a constant reward of 1.

    The action (a single value in [0, 1]) has no influence on the outcome and
    every episode terminates after one step.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Box(0.0, 1.0, (1,))

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.zeros((1, 1))]
        self.sample_actions = [[[1.0]]]
        self.q_values = [[1.0]]  # Expected Q values for (sample_obs, sample_actions)
        self.v_values = [[1.0]]  # Expected V values, state table
        self.policy_values = [None]  # No policy target is defined

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and finish the episode with a reward of 1.

        Returns:
            ``(observation, reward, terminated, truncated, info)``, which is
            always ``(0, 1, True, False, {})``.
        """
        return 0, 1, True, False, {}

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Begin a new episode; ``seed`` and ``options`` are accepted but unused.

        Returns:
            The fixed observation ``0`` and an empty info dict.
        """
        return 0, {}
class agilerl.utils.probe_envs.ConstantRewardContActionsImageEnv
class ConstantRewardContActionsImageEnv(gym.Env):
    """Image-observation, continuous-action probe env with constant reward 1.

    Observations are all-zero (1, 3, 3) frames, the action is ignored and every
    episode terminates after one step.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 0.0, (1, 3, 3))
        self.action_space = spaces.Box(0.0, 1.0, (1,))

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.zeros((1, 1, 3, 3))]
        self.sample_actions = [[[1.0]]]
        self.q_values = [[1.0]]  # Expected Q values for (sample_obs, sample_actions)
        self.v_values = [[1.0]]  # Expected V values, state table
        self.policy_values = [None]  # No policy target is defined

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and finish the episode with a reward of 1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation is a fresh all-zero (1, 3, 3) array.
        """
        blank_frame = np.zeros((1, 3, 3))
        return blank_frame, 1, True, False, {}

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Begin a new episode; ``seed`` and ``options`` are accepted but unused.

        Returns:
            A fresh all-zero (1, 3, 3) observation and an empty info dict.
        """
        blank_frame = np.zeros((1, 3, 3))
        return blank_frame, {}
class agilerl.utils.probe_envs.ObsDependentRewardEnv
class ObsDependentRewardEnv(gym.Env):
    """Probe env whose one-step reward is determined by the observation.

    ``reset`` draws the observation uniformly from {0, 1}; the single step then
    pays -1 for observation 0 and +1 for observation 1 and terminates. The
    action (only one exists) is ignored.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Discrete(1)
        self.last_obs = 1

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.q_values = [[-1.0], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[-1.0], [1.0]]  # Correct V values to learn, s table
        # The on-policy check reads ``env.policy_values``; without this
        # attribute it fails with AttributeError. There is a single action, so
        # no policy target is defined for either sample.
        self.policy_values = [None, None]

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action``; pay -1 if the current observation is 0, else +1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        reward = -1 if self.last_obs == 0 else 1  # Reward depends on observation
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Sample a new observation from {0, 1} using the global ``random`` module.

        ``seed`` and ``options`` are accepted but unused.
        """
        self.last_obs = random.choice([0, 1])
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.ObsDependentRewardImageEnv
class ObsDependentRewardImageEnv(gym.Env):
    """Image probe env whose one-step reward is determined by the observation.

    ``reset`` draws either an all-zero or an all-one (1, 3, 3) frame; the single
    step pays -1 for the zero frame and +1 otherwise, then terminates. The
    action (only one exists) is ignored.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 1.0, (1, 3, 3))
        self.action_space = spaces.Discrete(1)
        self.last_obs = np.ones((1, 3, 3))

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.zeros((1, 1, 3, 3)), np.ones((1, 1, 3, 3))]
        self.q_values = [[-1.0], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[-1.0], [1.0]]  # Correct V values to learn, s table
        # The on-policy check reads ``env.policy_values``; without this
        # attribute it fails with AttributeError. There is a single action, so
        # no policy target is defined for either sample.
        self.policy_values = [None, None]

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action``; pay -1 if the current frame averages 0, else +1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        # Reward depends on observation: the mean tells the two frames apart.
        reward = -1 if np.mean(self.last_obs) == 0.0 else 1
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Pick the zero or the one frame using the global ``random`` module.

        ``seed`` and ``options`` are accepted but unused.
        """
        self.last_obs = random.choice([np.zeros((1, 3, 3)), np.ones((1, 3, 3))])
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.ObsDependentRewardContActionsEnv
class ObsDependentRewardContActionsEnv(gym.Env):
    """Continuous-action probe env whose reward depends only on the observation.

    ``reset`` draws the observation uniformly from {0, 1}; the single step pays
    -1 for observation 0 and +1 for observation 1 and terminates. The action
    (one value in [0, 1]) is ignored.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = 1

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.sample_actions = [[[1.0]], [[1.0]]]
        self.q_values = [[-1.0], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[-1.0], [1.0]]  # Correct V values to learn, s table
        self.policy_values = [None, None]  # No policy target: action is irrelevant

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action``; pay -1 if the current observation is 0, else +1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        reward = -1 if self.last_obs == 0 else 1  # Reward depends on observation
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Sample a new observation from {0, 1} using the global ``random`` module.

        ``seed`` and ``options`` are accepted but unused.
        """
        self.last_obs = random.choice([0, 1])
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.ObsDependentRewardContActionsImageEnv
class ObsDependentRewardContActionsImageEnv(gym.Env):
    """Image, continuous-action probe env whose reward depends on the observation.

    ``reset`` draws either an all-zero or an all-one (1, 3, 3) frame; the single
    step pays -1 for the zero frame and +1 otherwise, then terminates. The
    action (one value in [0, 1]) is ignored.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 1.0, (1, 3, 3))
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = np.ones((1, 3, 3))

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [np.zeros((1, 1, 3, 3)), np.ones((1, 1, 3, 3))]
        self.sample_actions = [[[1.0]], [[1.0]]]
        self.q_values = [[-1.0], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[-1.0], [1.0]]  # Correct V values to learn, s table
        self.policy_values = [None, None]  # No policy target: action is irrelevant

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action``; pay -1 if the current frame averages 0, else +1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        # Reward depends on observation: the mean tells the two frames apart.
        reward = -1 if np.mean(self.last_obs) == 0.0 else 1
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Pick the zero or the one frame using the global ``random`` module.

        ``seed`` and ``options`` are accepted but unused.
        """
        self.last_obs = random.choice([np.zeros((1, 3, 3)), np.ones((1, 3, 3))])
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.DiscountedRewardEnv
class DiscountedRewardEnv(gym.Env):
    """Two-step probe env for checking that rewards are discounted correctly.

    An episode starts in state 0. The first step pays 0 and moves to state 1;
    the second step pays 1 and terminates. The expected value of state 0 is
    therefore the discount factor itself (the reference tables use 0.99).
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Discrete(1)
        self.last_obs = 0

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.q_values = [[0.99], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[0.99], [1.0]]  # Correct V values to learn, s table
        # The on-policy check reads ``env.policy_values``; without this
        # attribute it fails with AttributeError. There is a single action, so
        # no policy target is defined for either sample.
        self.policy_values = [None, None]

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and advance the two-step episode.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The
            observation is always 1; reward and ``terminated`` are 0/False on
            the first step and 1/True on the second.
        """
        observation = 1
        reward = self.last_obs  # Reward depends on observation
        # Terminate after the second step. Converted to a real bool so the
        # value matches the annotated return type and the other probe envs.
        terminated = bool(self.last_obs)
        truncated = False
        info: dict[str, Any] = {}
        self.last_obs = 1
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Return to state 0; ``seed`` and ``options`` are accepted but unused."""
        self.last_obs = 0
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.DiscountedRewardImageEnv
class DiscountedRewardImageEnv(gym.Env):
    """Two-step image probe env for checking that rewards are discounted.

    An episode starts on an all-zero (1, 3, 3) frame. The first step pays 0 and
    moves to the all-one frame; the second step pays 1 and terminates. The
    expected value of the zero frame is therefore the discount factor itself
    (the reference tables use 0.99).
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 1.0, (1, 3, 3))
        self.action_space = spaces.Discrete(1)
        self.last_obs = np.zeros((1, 3, 3))

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.zeros((1, 1, 3, 3)), np.ones((1, 1, 3, 3))]
        self.q_values = [[0.99], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[0.99], [1.0]]  # Correct V values to learn, s table
        # The on-policy check reads ``env.policy_values``; without this
        # attribute it fails with AttributeError. There is a single action, so
        # no policy target is defined for either sample.
        self.policy_values = [None, None]

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and advance the two-step episode.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The
            observation is always the all-one frame; the reward is the mean of
            the previous frame (0.0, then 1.0) and ``terminated`` is True only
            on the second step.
        """
        observation = np.ones((1, 3, 3))
        previous_mean = np.mean(self.last_obs)
        reward = previous_mean  # Reward depends on observation
        # Terminate after the second step. Converted to a real bool so the
        # value matches the annotated return type and the other probe envs.
        terminated = bool(int(previous_mean))
        truncated = False
        info: dict[str, Any] = {}
        self.last_obs = np.ones((1, 3, 3))
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Return to the zero frame; ``seed`` and ``options`` are unused."""
        self.last_obs = np.zeros((1, 3, 3))
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.DiscountedRewardContActionsEnv
class DiscountedRewardContActionsEnv(gym.Env):
    """Two-step, continuous-action probe env for checking reward discounting.

    An episode starts in state 0. The first step pays 0 and moves to state 1;
    the second step pays 1 and terminates. The action (one value in [0, 1]) is
    ignored. The expected value of state 0 is the discount factor itself (the
    reference tables use 0.99).
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = 0

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.sample_actions = [[[1.0]], [[1.0]]]
        self.q_values = [[0.99], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[0.99], [1.0]]  # Correct V values to learn, s table
        self.policy_values = [None, None]  # No policy target: action is irrelevant

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and advance the two-step episode.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The
            observation is always 1; reward and ``terminated`` are 0/False on
            the first step and 1/True on the second.
        """
        observation = 1
        reward = self.last_obs  # Reward depends on observation
        # Terminate after the second step. Converted to a real bool so the
        # value matches the annotated return type and the other probe envs.
        terminated = bool(self.last_obs)
        truncated = False
        info: dict[str, Any] = {}
        self.last_obs = 1
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Return to state 0; ``seed`` and ``options`` are accepted but unused."""
        self.last_obs = 0
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.DiscountedRewardContActionsImageEnv
class DiscountedRewardContActionsImageEnv(gym.Env):
    """Two-step image, continuous-action probe env for checking discounting.

    An episode starts on an all-zero (1, 3, 3) frame. The first step pays 0 and
    moves to the all-one frame; the second step pays 1 and terminates. The
    action (one value in [0, 1]) is ignored. The expected value of the zero
    frame is the discount factor itself (the reference tables use 0.99).
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 1.0, (1, 3, 3))
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = np.zeros((1, 3, 3))

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [np.zeros((1, 1, 3, 3)), np.ones((1, 1, 3, 3))]
        self.sample_actions = [[[1.0]], [[1.0]]]
        self.q_values = [[0.99], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[0.99], [1.0]]  # Correct V values to learn, s table
        self.policy_values = [None, None]  # No policy target: action is irrelevant

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Ignore ``action`` and advance the two-step episode.

        Returns:
            ``(observation, reward, terminated, truncated, info)``. The
            observation is always the all-one frame; the reward is the mean of
            the previous frame (0.0, then 1.0) and ``terminated`` is True only
            on the second step.
        """
        observation = np.ones((1, 3, 3))
        previous_mean = np.mean(self.last_obs)
        reward = previous_mean  # Reward depends on observation
        # Terminate after the second step. Converted to a real bool so the
        # value matches the annotated return type and the other probe envs.
        terminated = bool(int(previous_mean))
        truncated = False
        info: dict[str, Any] = {}
        self.last_obs = np.ones((1, 3, 3))
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Return to the zero frame; ``seed`` and ``options`` are unused."""
        self.last_obs = np.zeros((1, 3, 3))
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.FixedObsPolicyEnv
class FixedObsPolicyEnv(gym.Env):
    """Single-state probe env where the reward depends only on the action.

    Action 0 pays -1 and action 1 pays +1; every episode lasts one step.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Discrete(2)

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.array([[0]])]
        self.q_values = [[-1.0, 1.0]]  # Expected Q values, state x action table
        self.v_values = [None]  # No V target is defined
        self.policy_values = [[0.0, 1.0]]  # Expected policy: always pick action 1

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Pay -1 for action 0 and +1 for action 1, then terminate.

        Array or list actions are unwrapped by taking their first element.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation fixed at 0 and ``terminated`` always ``True``.
        """
        chosen = action[0] if isinstance(action, (np.ndarray, list)) else action
        payoff = [-1, 1][chosen]  # Reward depends on action
        return 0, payoff, True, False, {}

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Begin a new episode; ``seed`` and ``options`` are accepted but unused.

        Returns:
            The fixed observation ``0`` and an empty info dict.
        """
        return 0, {}
class agilerl.utils.probe_envs.FixedObsPolicyImageEnv
class FixedObsPolicyImageEnv(gym.Env):
    """Image probe env with a fixed frame where reward depends on the action.

    Observations are all-zero (1, 3, 3) frames. Action 0 pays -1 and action 1
    pays +1; every episode lasts one step.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 0.0, (1, 3, 3))
        self.action_space = spaces.Discrete(2)

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.zeros((1, 1, 3, 3))]
        self.q_values = [[-1.0, 1.0]]  # Expected Q values, state x action table
        self.v_values = [None]  # No V target is defined
        self.policy_values = [[0.0, 1.0]]  # Expected policy: always pick action 1

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Pay -1 for action 0 and +1 for action 1, then terminate.

        Array or list actions are unwrapped by taking their first element.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation is a fresh all-zero (1, 3, 3) array.
        """
        blank_frame = np.zeros((1, 3, 3))
        chosen = action[0] if isinstance(action, (np.ndarray, list)) else action
        payoff = [-1, 1][chosen]  # Reward depends on action
        return blank_frame, payoff, True, False, {}

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Begin a new episode; ``seed`` and ``options`` are accepted but unused.

        Returns:
            A fresh all-zero (1, 3, 3) observation and an empty info dict.
        """
        blank_frame = np.zeros((1, 3, 3))
        return blank_frame, {}
class agilerl.utils.probe_envs.FixedObsPolicyContActionsEnv
class FixedObsPolicyContActionsEnv(gym.Env):
    """Single-state, continuous-action probe env with a quadratic action penalty.

    The reward is ``-(1 - action[0]) ** 2``, so the best action is 1.0 (reward
    0). Every episode lasts one step.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Box(0.0, 1.0, (1,))

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.array([[0]])]
        self.sample_actions = [np.array([[1.0]])]
        self.q_values = np.array([[0.0]])  # Expected Q value of the sample action
        self.v_values = [None]  # No V target is defined
        self.policy_values = [[1.0]]  # Expected policy output

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Penalise the squared distance of ``action[0]`` from 1, then terminate.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation fixed at 0 and ``terminated`` always ``True``.
        """
        shortfall = 1 - action[0]
        penalty = -(shortfall ** 2)  # Reward depends on action
        return 0, penalty, True, False, {}

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Begin a new episode; ``seed`` and ``options`` are accepted but unused.

        Returns:
            The fixed observation ``0`` and an empty info dict.
        """
        return 0, {}
class agilerl.utils.probe_envs.FixedObsPolicyContActionsImageEnv
class FixedObsPolicyContActionsImageEnv(gym.Env):
    """Image, continuous-action probe env with a quadratic action penalty.

    Observations are all-zero (1, 3, 3) frames and the reward is
    ``-(1 - action[0]) ** 2``, so the best action is 1.0 (reward 0). Every
    episode lasts one step.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 1.0, (1, 3, 3))
        self.action_space = spaces.Box(0.0, 1.0, (1,))

        # Reference data read by the probe check functions, one entry per sample.
        self.sample_obs = [np.zeros((1, 1, 3, 3))]
        self.sample_actions = [np.array([[1.0]])]
        self.q_values = np.array([[0.0]])  # Expected Q value of the sample action
        self.v_values = [None]  # No V target is defined
        self.policy_values = [[1.0]]  # Expected policy output

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Penalise the squared distance of ``action[0]`` from 1, then terminate.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` where the
            observation is a fresh all-zero (1, 3, 3) array.
        """
        blank_frame = np.zeros((1, 3, 3))
        shortfall = 1 - action[0]
        penalty = -(shortfall ** 2)  # Reward depends on action
        return blank_frame, penalty, True, False, {}

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Begin a new episode; ``seed`` and ``options`` are accepted but unused.

        Returns:
            A fresh all-zero (1, 3, 3) observation and an empty info dict.
        """
        blank_frame = np.zeros((1, 3, 3))
        return blank_frame, {}
class agilerl.utils.probe_envs.PolicyEnv
class PolicyEnv(gym.Env):
    """Probe env where the rewarded action depends on the observation.

    ``reset`` draws the observation uniformly from {0, 1}; the single step pays
    +1 when the action equals the observation and -1 otherwise, then terminates.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Discrete(2)
        self.last_obs = 0

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.q_values = [
            [1.0, -1.0],
            [-1.0, 1.0],
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None, None]  # No V target is defined
        # The on-policy check reads ``env.policy_values``; without this
        # attribute it fails with AttributeError. The rewarded action equals
        # the observation, hence one-hot targets.
        self.policy_values = [[1.0, 0.0], [0.0, 1.0]]  # Correct policy to learn

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Pay +1 if ``action`` equals the current observation, else -1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        # Reward depends on action in observation
        reward = 1 if action == self.last_obs else -1
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Sample a new observation from {0, 1} using the global ``random`` module.

        ``seed`` and ``options`` are accepted but unused.
        """
        self.last_obs = random.choice([0, 1])
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.PolicyImageEnv
class PolicyImageEnv(gym.Env):
    """Image probe env where the rewarded action depends on the observation.

    ``reset`` draws either an all-zero or an all-one (1, 3, 3) frame; the single
    step pays +1 when the action equals the frame's (integer) mean and -1
    otherwise, then terminates.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 1.0, (1, 3, 3))
        self.action_space = spaces.Discrete(2)
        self.last_obs = np.ones((1, 3, 3))

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [np.zeros((1, 1, 3, 3)), np.ones((1, 1, 3, 3))]
        self.q_values = [
            [1.0, -1.0],
            [-1.0, 1.0],
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None, None]  # No V target is defined
        # The on-policy check reads ``env.policy_values``; without this
        # attribute it fails with AttributeError. The rewarded action equals
        # the frame value, hence one-hot targets.
        self.policy_values = [[1.0, 0.0], [0.0, 1.0]]  # Correct policy to learn

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Pay +1 if ``action`` matches the current frame's value, else -1.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        # Reward depends on action in observation: the truncated mean maps the
        # zero frame to action 0 and the one frame to action 1.
        reward = 1 if action == int(np.mean(self.last_obs)) else -1
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Pick the zero or the one frame using the global ``random`` module.

        ``seed`` and ``options`` are accepted but unused.
        """
        self.last_obs = random.choice([np.zeros((1, 3, 3)), np.ones((1, 3, 3))])
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.PolicyContActionsEnv
class PolicyContActionsEnv(gym.Env):
    """Continuous-action probe env whose optimal action depends on the observation.

    ``reset`` draws the observation uniformly from {0, 1}. The single step pays
    the negative squared distance between the 2-D action and a target of
    ``[1, 0]`` for observation 0 or ``[0, 1]`` for observation 1, then
    terminates.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Box(0.0, 1.0, (2,))
        self.last_obs = 0

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.sample_actions = [np.array([[1.0, 0.0]]), np.array([[0.0, 1.0]])]
        self.q_values = [[0.0], [0.0]]  # Correct Q values to learn
        self.v_values = [None, None]  # No V target is defined
        self.policy_values = [[1.0, 0.0], [0.0, 1.0]]  # Correct policy to learn

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Penalise the squared distance of ``action`` from the target policy.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        if self.last_obs:  # last obs = 1, policy should be [0, 1]
            reward = -((0 - action[0]) ** 2) - (1 - action[1]) ** 2
        else:  # last obs = 0, policy should be [1, 0]
            reward = -((1 - action[0]) ** 2) - (0 - action[1]) ** 2
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Sample a new observation from {0, 1} using the global ``random`` module.

        ``seed`` and ``options`` are accepted but unused.
        """
        self.last_obs = random.choice([0, 1])
        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.PolicyContActionsImageEnvSimple
class PolicyContActionsImageEnvSimple(gym.Env):
    """Image probe env with a 1-D continuous action that should match the frame.

    Observations are all-zero or all-one (1, 3, 3) frames. The single step pays
    the negative squared distance between ``action[0]`` and the frame value
    (0 or 1), then terminates. Unlike the other probe envs, ``reset`` does not
    sample: it alternates deterministically between the two frames.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 1.0, (1, 3, 3))
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = np.zeros((1, 3, 3))

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [
            np.zeros((1, 1, 3, 3)),
            np.zeros((1, 1, 3, 3)),
            np.ones((1, 1, 3, 3)),
            np.ones((1, 1, 3, 3)),
        ]
        self.sample_actions = [
            np.array([[0.0]]),
            np.array([[1.0]]),
            np.array([[0.0]]),
            np.array([[1.0]]),
        ]
        self.q_values = [[0.0], [-1.0], [-1.0], [0.0]]  # Correct Q values to learn
        self.policy_values = [[0.0], [0.0], [1.0], [1.0]]  # Correct policy to learn
        self.v_values = [None, None, None, None]  # No V target is defined

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Penalise the squared distance of ``action[0]`` from the frame value.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        if int(np.mean(self.last_obs)):  # last obs = 1, policy should be [1]
            reward = -((1 - action[0]) ** 2)
        else:  # last obs = 0, policy should be [0]
            reward = -((0 - action[0]) ** 2)
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Flip to the other frame; ``seed`` and ``options`` are unused.

        The frame alternates on every call (ones after zeros, zeros after ones)
        rather than being drawn at random.
        """
        if int(np.mean(self.last_obs)):
            self.last_obs = np.zeros((1, 3, 3))
        else:
            self.last_obs = np.ones((1, 3, 3))

        info: dict[str, Any] = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.PolicyContActionsImageEnv
class PolicyContActionsImageEnv(gym.Env):
    """Image probe env with a 2-D continuous action that depends on the frame.

    ``reset`` draws either an all-zero or an all-one (1, 3, 3) frame. The single
    step pays the negative squared distance between the action and a target of
    ``[1, 0]`` for the zero frame or ``[0, 1]`` for the one frame, then
    terminates.
    """

    def __init__(self) -> None:
        self.observation_space = spaces.Box(0.0, 1.0, (1, 3, 3))
        self.action_space = spaces.Box(0.0, 1.0, (2,))
        self.last_obs = np.zeros((1, 3, 3))

        # Reference data read by the probe check functions. Every list holds
        # one entry per sample: the checks zip them with ``strict=False``, so a
        # shorter list would silently drop the trailing samples.
        self.sample_obs = [np.zeros((1, 1, 3, 3)), np.ones((1, 1, 3, 3))]
        self.sample_actions = [np.array([[1.0, 0.0]]), np.array([[0.0, 1.0]])]
        self.q_values = [[0.0], [0.0]]  # Correct Q values to learn
        self.policy_values = [[1.0, 0.0], [0.0, 1.0]]  # Correct policy to learn
        self.v_values = [None, None]  # No V target is defined

    def step(
        self,
        action: int | np.ndarray,
    ) -> tuple[Any, float, bool, bool, dict[str, Any]]:
        """Penalise the squared distance of ``action`` from the target policy.

        Returns:
            ``(observation, reward, terminated, truncated, info)`` with the
            observation unchanged and ``terminated`` always ``True``.
        """
        observation = self.last_obs
        if int(np.mean(self.last_obs)):  # last obs = 1, policy should be [0, 1]
            reward = -((0 - action[0]) ** 2) - (1 - action[1]) ** 2
        else:  # last obs = 0, policy should be [1, 0]
            reward = -((1 - action[0]) ** 2) - (0 - action[1]) ** 2
        terminated = True
        truncated = False
        info: dict[str, Any] = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        """Pick the zero or the one frame using the global ``random`` module.

        ``seed`` and ``options`` are accepted but unused.
        """
        self.last_obs = random.choice([np.zeros((1, 3, 3)), np.ones((1, 3, 3))])
        info: dict[str, Any] = {}
        return self.last_obs, info
agilerl.utils.probe_envs.check_q_learning_with_probe_env(env: Env[Any, Any], algo_class: type[Any], algo_args: dict[str, Any], memory: Any, learn_steps: int = 10000, device: str = 'cpu') -> None
def check_q_learning_with_probe_env(
    env: gym.Env[Any, Any],
    algo_class: type[Any],
    algo_args: dict[str, Any],
    memory: Any,
    learn_steps: int = 10000,
    device: str = "cpu",
) -> None:
    """Train a Q-learning agent on a probe env and evaluate it on the samples.

    The agent is built as ``algo_class(**algo_args, device=device)``. 1000
    environment steps taken with ``epsilon=1`` are stored in ``memory``, the
    agent then learns for ``learn_steps`` sampled batches, and finally its
    actor is run on every entry of ``env.sample_obs``.

    NOTE: the predictions are computed but not compared against
    ``env.q_values``; no assertion is made here.

    Args:
        env: Probe environment exposing ``sample_obs`` and ``q_values``.
        algo_class: Agent class to instantiate.
        algo_args: Keyword arguments forwarded to ``algo_class``.
        memory: Buffer providing ``add(transition)`` and ``sample(batch_size)``.
        learn_steps: Number of ``agent.learn`` calls.
        device: Device string forwarded to the agent.
    """
    agent = algo_class(**algo_args, device=device)

    # Phase 1: collect experience.
    obs, _ = env.reset()
    for _ in range(1000):
        # Prepend a batch axis to the observation (per key for dict observations).
        if isinstance(obs, dict):
            obs = {key: np.expand_dims(value, 0) for key, value in obs.items()}
        else:
            obs = np.expand_dims(obs, 0)

        action = agent.get_action(obs, epsilon=1)
        next_obs, reward, done, _, _ = env.step(action)

        step_record = Transition(
            obs=obs,
            action=action,
            reward=reward,
            next_obs=next_obs,
            done=done,
        ).to_tensordict()
        step_record = step_record.unsqueeze(0)
        step_record.batch_size = [1]
        memory.add(step_record)

        if done:
            obs, _ = env.reset()
        else:
            obs = next_obs

    # Phase 2: learn from sampled experience using the agent's own algorithm.
    for _ in trange(learn_steps):
        batch = memory.sample(agent.batch_size)
        agent.learn(batch)

    # Phase 3: run the actor on each reference observation.
    for probe_obs, _expected_q in zip(env.sample_obs, env.q_values, strict=False):
        agent.actor(probe_obs).detach().cpu().numpy()[0]
agilerl.utils.probe_envs.check_policy_q_learning_with_probe_env(env: Env[Any, Any], algo_class: type[Any], algo_args: dict[str, Any], memory: Any, learn_steps: int = 10000, device: str = 'cpu') -> None
def check_policy_q_learning_with_probe_env(
    env: gym.Env[Any, Any],
    algo_class: type[Any],
    algo_args: dict[str, Any],
    memory: Any,
    learn_steps: int = 10000,
    device: str = "cpu",
) -> None:
    """Train an actor-critic Q-learning agent on a probe env and evaluate it.

    The agent is built as ``algo_class(**algo_args, device=device)``. 5000
    environment steps with actions drawn uniformly between the agent's action
    bounds are stored in ``memory``, the agent then learns for ``learn_steps``
    sampled batches, and finally its critic (and, where a policy target
    exists, its actor) is run on the env's reference samples.

    NOTE: the predictions are computed but not compared against
    ``env.q_values`` / ``env.policy_values``; no assertion is made here.

    Args:
        env: Probe environment exposing ``sample_obs``, ``sample_actions``,
            ``q_values`` and ``policy_values``.
        algo_class: Agent class to instantiate.
        algo_args: Keyword arguments forwarded to ``algo_class``.
        memory: Buffer providing ``add(transition)`` and ``sample(batch_size)``.
        learn_steps: Number of ``agent.learn`` calls.
        device: Device string forwarded to the agent and used for tensors.
    """
    agent = algo_class(**algo_args, device=device)

    # Phase 1: collect experience with uniformly random actions.
    obs, _ = env.reset()
    for _ in range(5000):
        span = agent.action_space.high - agent.action_space.low
        unit_noise = np.random.rand(1, agent.action_dim).astype("float32")
        action = (span * unit_noise) + agent.action_space.low
        action = action[0]  # drop the leading batch axis

        next_obs, reward, done, _, _ = env.step(action)

        step_record = Transition(
            obs=obs,
            action=action,
            reward=reward,
            next_obs=next_obs,
            done=done,
        ).to_tensordict()
        step_record = step_record.unsqueeze(0)
        step_record.batch_size = [1]
        memory.add(step_record)

        if done:
            obs, _ = env.reset()
        else:
            obs = next_obs

    # Phase 2: learn from sampled experience using the agent's own algorithm.
    for _ in trange(learn_steps):
        batch = memory.sample(agent.batch_size)
        agent.learn(batch)

    # Phase 3: run critic / actor on each reference sample.
    reference_rows = zip(
        env.sample_obs,
        env.sample_actions,
        env.q_values,
        env.policy_values,
        strict=False,
    )
    for probe_obs, probe_action, _expected_q, expected_policy in reference_rows:
        if isinstance(probe_obs, dict):
            obs_tensor = {
                key: torch.tensor(value).float().to(device)
                for key, value in probe_obs.items()
            }
        else:
            obs_tensor = torch.tensor(probe_obs).float().to(device)

        agent.critic.eval()
        agent.actor.eval()
        action_tensor = torch.tensor(probe_action).float().to(device)
        # The comparison of this prediction with the expected Q value
        # (atol=0.15) is currently disabled.
        agent.critic(obs_tensor, action_tensor).detach().cpu().numpy()[0]

        if expected_policy is not None:
            agent.actor(probe_obs).detach().cpu().numpy()[0]
agilerl.utils.probe_envs.check_policy_on_policy_with_probe_env(env: Env[Any, Any], algo_class: type[Any], algo_args: dict[str, Any], learn_steps: int = 5000, device: str = 'cpu', discrete: bool = True) -> None
def check_policy_on_policy_with_probe_env(
    env: gym.Env[Any, Any],
    algo_class: type[Any],
    algo_args: dict[str, Any],
    learn_steps: int = 5000,
    device: str = "cpu",
    discrete: bool = True,
) -> None:
    """Train an on-policy agent on a probe env and evaluate it on the samples.

    The agent is built as ``algo_class(**algo_args, device=device)``. Each of
    the ``learn_steps`` iterations gathers a 200-step rollout (resetting the
    env whenever an episode ends) and passes it to ``agent.learn``. Afterwards
    the critic and actor are run on the env's reference observations wherever
    a value / policy target exists.

    NOTE: the predictions are computed but not compared against
    ``env.v_values`` / ``env.policy_values``; no assertion is made here.

    Args:
        env: Probe environment exposing ``sample_obs``, ``v_values`` and
            ``policy_values``.
        algo_class: Agent class to instantiate.
        algo_args: Keyword arguments forwarded to ``algo_class``.
        learn_steps: Number of rollout + learn iterations.
        device: Device string forwarded to the agent and used for tensors.
        discrete: Currently not used by this function.
    """
    agent = algo_class(**algo_args, device=device)

    for _ in trange(learn_steps):
        obs, _ = env.reset()
        states, actions, log_probs = [], [], []
        rewards, dones, values = [], [], []

        # ``done`` stored with a step is the flag produced by the previous
        # step, so the first entry of a rollout is always 0.
        done = 0

        for _ in range(200):
            # Prepend a batch axis to the observation (per key for dicts).
            if isinstance(obs, dict):
                obs = {key: np.expand_dims(val, 0) for key, val in obs.items()}
            else:
                obs = np.expand_dims(obs, 0)

            action, log_prob, _, value = agent.get_action(obs)

            # Strip the batch axis again before stepping / storing.
            action = action[0]
            log_prob = log_prob[0]
            value = value[0]
            next_state, reward, term, trunc, _ = env.step(action)
            next_done = np.logical_or(term, trunc).astype(np.int8)

            states.append(obs)
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(reward)
            dones.append(done)
            values.append(value)

            done = next_done
            if done:
                obs, _ = env.reset()
            else:
                obs = next_state

        rollout = (
            states,
            actions,
            log_probs,
            rewards,
            dones,
            values,
            next_state,
            next_done,
        )
        agent.learn(rollout)

    reference_rows = zip(
        env.sample_obs,
        env.v_values,
        env.policy_values,
        strict=False,
    )
    for probe_obs, expected_v, expected_policy in reference_rows:
        if isinstance(probe_obs, dict):
            obs_tensor = {
                key: torch.tensor(val).float().to(device)
                for key, val in probe_obs.items()
            }
        else:
            obs_tensor = torch.tensor(probe_obs).float().to(device)

        if expected_v is not None:
            # The comparison of this prediction with the expected V value
            # (atol=0.2) is currently disabled.
            agent.critic(obs_tensor).detach().cpu().numpy()[0]

        if expected_policy is not None:
            # Assumes it is always a discrete action space
            _, _, _ = agent.actor(obs_tensor)

Multi-agent probe environments

class agilerl.utils.probe_envs_ma.ConstantRewardEnv
class ConstantRewardEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.observation_space = {
            "agent_0": spaces.Discrete(1),
            "other_agent_0": spaces.Discrete(1),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[0.2, 0.8]]),
                "other_agent_0": np.array([[0.8, 0.2]]),
            },
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
        ]  # Correct V values to learn, s table
        self.policy_values = [None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        reward = {"agent_0": 1, "other_agent_0": 0}  # Constant reward of 1
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        observation = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.ConstantRewardImageEnv
class ConstantRewardImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[0.2, 0.8]]),
                "other_agent_0": np.array([[0.8, 0.2]]),
            },
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
        ]  # Correct V values to learn, s table
        self.policy_values = [None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        reward = {"agent_0": 1, "other_agent_0": 0}  # Constant reward of 1
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        observation = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.ConstantRewardContActionsEnv
class ConstantRewardContActionsEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.observation_space = {
            "agent_0": spaces.Discrete(1),
            "other_agent_0": spaces.Discrete(1),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.0]]), "other_agent_0": np.array([[1.0]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
        ]  # Correct V values to learn, s table
        self.policy_values = [None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        reward = {"agent_0": 1, "other_agent_0": 0}  # Constant reward
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        observation = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.ConstantRewardContActionsImageEnv
class ConstantRewardContActionsImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.0]]), "other_agent_0": np.array([[1.0]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
        ]  # Correct V values to learn, s table
        self.policy_values = [None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        reward = {"agent_0": 1, "other_agent_0": 0}  # Constant reward
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        observation = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.ObsDependentRewardEnv
class ObsDependentRewardEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[1]])},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[0.2, 0.8]]),
                "other_agent_0": np.array([[0.8, 0.2]]),
            },
            {
                "agent_0": np.array([[0.8, 0.2]]),
                "other_agent_0": np.array([[0.2, 0.8]]),
            },
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 1.0},
        ]  # Correct V values to learn, s table
        self.policy_values = [None, None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = (
            {"agent_0": 1, "other_agent_0": 0}
            if self.last_obs["agent_0"] == 0
            else {"agent_0": 0, "other_agent_0": 1}
        )  # Reward depends on observation
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "other_agent_0": np.array([0])},
                {"agent_0": np.array([1]), "other_agent_0": np.array([1])},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.ObsDependentRewardImageEnv
class ObsDependentRewardImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[0.2, 0.8]]),
                "other_agent_0": np.array([[0.8, 0.2]]),
            },
            {
                "agent_0": np.array([[0.8, 0.2]]),
                "other_agent_0": np.array([[0.2, 0.8]]),
            },
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 1.0},
        ]  # Correct V values to learn, s table
        self.policy_values = [None, None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = (
            {"agent_0": 1, "other_agent_0": 0}
            if np.mean(self.last_obs["agent_0"]) == 0
            else {"agent_0": 0, "other_agent_0": 1}
        )  # Reward depends on observation
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((1, 3, 3)), "other_agent_0": np.zeros((1, 3, 3))},
                {"agent_0": np.ones((1, 3, 3)), "other_agent_0": np.ones((1, 3, 3))},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.ObsDependentRewardContActionsEnv
class ObsDependentRewardContActionsEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[1]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2]]), "other_agent_0": np.array([[0.0]])},
            {"agent_0": np.array([[0.8]]), "other_agent_0": np.array([[0.6]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 1.0},
        ]  # Correct V values to learn, s table
        self.policy_values = [None, None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = (
            {"agent_0": 1, "other_agent_0": 0}
            if self.last_obs["agent_0"] == 0
            else {"agent_0": 0, "other_agent_0": 1}
        )  # Reward depends on observation
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "other_agent_0": np.array([0])},
                {"agent_0": np.array([1]), "other_agent_0": np.array([1])},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.ObsDependentRewardContActionsImageEnv
class ObsDependentRewardContActionsImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2]]), "other_agent_0": np.array([[0.0]])},
            {"agent_0": np.array([[0.8]]), "other_agent_0": np.array([[0.6]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 1.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 1.0},
        ]  # Correct V values to learn, s table
        self.policy_values = [None, None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = (
            {"agent_0": 1, "other_agent_0": 0}
            if np.mean(self.last_obs["agent_0"]) == 0
            else {"agent_0": 0, "other_agent_0": 1}
        )  # Reward depends on observation
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((1, 3, 3)), "other_agent_0": np.zeros((1, 3, 3))},
                {"agent_0": np.ones((1, 3, 3)), "other_agent_0": np.ones((1, 3, 3))},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.DiscountedRewardEnv
class DiscountedRewardEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[1]])},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[0.2, 0.8]]),
                "other_agent_0": np.array([[0.8, 0.2]]),
            },
            {
                "agent_0": np.array([[0.8, 0.2]]),
                "other_agent_0": np.array([[0.2, 0.8]]),
            },
        ]
        self.q_values = [
            {"agent_0": 0.99, "other_agent_0": 0.495},
            {"agent_0": 1.0, "other_agent_0": 0.5},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 0.99, "other_agent_0": 0.495},
            {"agent_0": 1.0, "other_agent_0": 0.5},
        ]  # Correct V values to learn, s table
        self.policy_values = [None, None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {"agent_0": np.array([1]), "other_agent_0": np.array([1])}
        reward = (
            {"agent_0": 1, "other_agent_0": 0.5}
            if self.last_obs["agent_0"] == 1
            else {"agent_0": 0, "other_agent_0": 0}
        )  # Reward depends on observation  # Reward depends on observation
        terminated = {
            agent: obs[0] for agent, obs in self.last_obs.items()
        }  # Terminate after second step
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        self.last_obs = {"agent_0": np.array([1]), "other_agent_0": np.array([1])}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.DiscountedRewardImageEnv
class DiscountedRewardImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[0.2, 0.8]]),
                "other_agent_0": np.array([[0.8, 0.2]]),
            },
            {
                "agent_0": np.array([[0.8, 0.2]]),
                "other_agent_0": np.array([[0.2, 0.8]]),
            },
        ]
        self.q_values = [
            {"agent_0": 0.99, "other_agent_0": 0.495},
            {"agent_0": 1.0, "other_agent_0": 0.5},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 0.99, "other_agent_0": 0.495},
            {"agent_0": 1.0, "other_agent_0": 0.5},
        ]  # Correct V values to learn, s table
        self.policy_values = [None, None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {
            "agent_0": np.ones((1, 3, 3)),
            "other_agent_0": np.ones((1, 3, 3)),
        }
        reward = (
            {"agent_0": 1, "other_agent_0": 0.5}
            if np.mean(self.last_obs["agent_0"]) == 1
            else {"agent_0": 0, "other_agent_0": 0}
        )  # Reward depends on observation  # Reward depends on observation
        terminated = {
            "agent_0": int(np.mean(self.last_obs["agent_0"])),
            "other_agent_0": int(np.mean(self.last_obs["agent_0"])),
        }  # Terminate after second step
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        self.last_obs = {
            "agent_0": np.ones((1, 3, 3)),
            "other_agent_0": np.ones((1, 3, 3)),
        }
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.DiscountedRewardContActionsEnv
class DiscountedRewardContActionsEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[1]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2]]), "other_agent_0": np.array([[0.4]])},
            {"agent_0": np.array([[0.8]]), "other_agent_0": np.array([[0.1]])},
        ]
        self.q_values = [
            {"agent_0": 0.99, "other_agent_0": 0.495},
            {"agent_0": 1.0, "other_agent_0": 0.5},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 0.99, "other_agent_0": 0.495},
            {"agent_0": 1.0, "other_agent_0": 0.5},
        ]  # Correct V values to learn, s table
        self.policy_values = [None, None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {"agent_0": np.array([1]), "other_agent_0": np.array([1])}
        reward = (
            {"agent_0": 1, "other_agent_0": 0.5}
            if self.last_obs["agent_0"] == 1
            else {"agent_0": 0, "other_agent_0": 0}
        )  # Reward depends on observation  # Reward depends on observation
        terminated = {
            agent: obs[0] for agent, obs in self.last_obs.items()
        }  # Terminate after second step
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        self.last_obs = {"agent_0": np.array([1]), "other_agent_0": np.array([1])}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.DiscountedRewardContActionsImageEnv
class DiscountedRewardContActionsImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2]]), "other_agent_0": np.array([[0.4]])},
            {"agent_0": np.array([[0.8]]), "other_agent_0": np.array([[0.1]])},
        ]
        self.q_values = [
            {"agent_0": 0.99, "other_agent_0": 0.495},
            {"agent_0": 1.0, "other_agent_0": 0.5},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [
            {"agent_0": 0.99, "other_agent_0": 0.495},
            {"agent_0": 1.0, "other_agent_0": 0.5},
        ]  # Correct V values to learn, s table
        self.policy_values = [None, None]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {
            "agent_0": np.ones((1, 3, 3)),
            "other_agent_0": np.ones((1, 3, 3)),
        }
        reward = (
            {"agent_0": 1, "other_agent_0": 0.5}
            if np.mean(self.last_obs["agent_0"]) == 1
            else {"agent_0": 0, "other_agent_0": 0}
        )  # Reward depends on observation  # Reward depends on observation
        terminated = {
            "agent_0": int(np.mean(self.last_obs["agent_0"])),
            "other_agent_0": int(np.mean(self.last_obs["agent_0"])),
        }  # Terminate after second step
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        self.last_obs = {
            "agent_0": np.ones((1, 3, 3)),
            "other_agent_0": np.ones((1, 3, 3)),
        }
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.FixedObsPolicyEnv
class FixedObsPolicyEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(1),
            "other_agent_0": spaces.Discrete(1),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        reward = {
            "agent_0": [1, -1][int(np.asarray(action["agent_0"]).flat[0])],
            "other_agent_0": [-1, 1][int(np.asarray(action["other_agent_0"]).flat[0])],
        }  # Reward depends on action
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        observation = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.FixedObsPolicyImageEnv
class FixedObsPolicyImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        reward = {
            "agent_0": [1, -1][int(np.asarray(action["agent_0"]).flat[0])],
            "other_agent_0": [-1, 1][int(np.asarray(action["other_agent_0"]).flat[0])],
        }  # Reward depends on action
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        observation = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.FixedObsPolicyContActionsEnv
class FixedObsPolicyContActionsEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(1),
            "other_agent_0": spaces.Discrete(1),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0]]), "other_agent_0": np.array([[0.0]])},
        ]
        self.q_values = [
            {"agent_0": 0.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {"agent_0": np.array([1.0]), "other_agent_0": np.array([0.0])},
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        reward = {
            "agent_0": -((1 - action["agent_0"]) ** 2),
            "other_agent_0": -((0 - action["other_agent_0"]) ** 2),
        }  # Reward depends on action
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        observation = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.FixedObsPolicyContActionsImageEnv
class FixedObsPolicyContActionsImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0]]), "other_agent_0": np.array([[0.0]])},
        ]
        self.q_values = [
            {"agent_0": 0.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {"agent_0": np.array([1.0]), "other_agent_0": np.array([0.0])},
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        reward = {
            "agent_0": -((1 - action["agent_0"]) ** 2),
            "other_agent_0": -((0 - action["other_agent_0"]) ** 2),
        }  # Reward depends on action
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        observation = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.PolicyEnv
class PolicyEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[0]])},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 1.0},
            {"agent_0": 1.0, "other_agent_0": 1.0},
            {"agent_0": 0.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = {
            "agent_0": action["agent_0"] == self.last_obs["agent_0"],
            "other_agent_0": action["other_agent_0"] != self.last_obs["other_agent_0"],
        }  # Reward depends on action
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "other_agent_0": np.array([0])},
                {"agent_0": np.array([1]), "other_agent_0": np.array([1])},
                {"agent_0": np.array([0]), "other_agent_0": np.array([1])},
                {"agent_0": np.array([1]), "other_agent_0": np.array([0])},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.PolicyImageEnv
class PolicyImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.sample_obs = [
            {"agent_0": np.zeros((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.zeros((1, 1, 3, 3))},
            {"agent_0": np.zeros((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.zeros((1, 1, 3, 3))},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
        ]
        self.q_values = [
            {"agent_0": 1.0, "other_agent_0": 1.0},
            {"agent_0": 1.0, "other_agent_0": 1.0},
            {"agent_0": 0.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = {
            "agent_0": action["agent_0"] == np.mean(self.last_obs["agent_0"]),
            "other_agent_0": action["other_agent_0"]
            != np.mean(self.last_obs["other_agent_0"]),
        }  # Reward depends on action
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((1, 3, 3)), "other_agent_0": np.zeros((1, 3, 3))},
                {"agent_0": np.ones((1, 3, 3)), "other_agent_0": np.ones((1, 3, 3))},
                {"agent_0": np.zeros((1, 3, 3)), "other_agent_0": np.ones((1, 3, 3))},
                {"agent_0": np.ones((1, 3, 3)), "other_agent_0": np.zeros((1, 3, 3))},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.PolicyContActionsEnv
class PolicyContActionsEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (2,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (2,)),
        }
        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[0]])},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 0.0]]),
            },
            {
                "agent_0": np.array([[1.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 1.0]]),
            },
        ]
        self.q_values = [
            {"agent_0": 0.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 0.0},
            {"agent_0": -2.0, "other_agent_0": -2.0},
            {"agent_0": -2.0, "other_agent_0": -2.0},
            {"agent_0": -1.0, "other_agent_0": -1.0},
            {"agent_0": -1.0, "other_agent_0": -1.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = {}
        if self.last_obs["agent_0"]:  # last obs = 1, policy should be [0, 1]
            reward["agent_0"] = -((0 - action["agent_0"][0]) ** 2) - (
                (1 - action["agent_0"][1]) ** 2
            )
        else:  # last obs = 0, policy should be [1, 0]
            reward["agent_0"] = -((1 - action["agent_0"][0]) ** 2) - (
                (0 - action["agent_0"][1]) ** 2
            )
        if self.last_obs["other_agent_0"]:  # last obs = 1, policy should be [1, 0]
            reward["other_agent_0"] = -((1 - action["other_agent_0"][0]) ** 2) - (
                (0 - action["other_agent_0"][1]) ** 2
            )
        else:  # last obs = 0, policy should be [0, 1]
            reward["other_agent_0"] = -((0 - action["other_agent_0"][0]) ** 2) - (
                (1 - action["other_agent_0"][1]) ** 2
            )
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "other_agent_0": np.array([0])},
                {"agent_0": np.array([1]), "other_agent_0": np.array([1])},
                {"agent_0": np.array([0]), "other_agent_0": np.array([1])},
                {"agent_0": np.array([1]), "other_agent_0": np.array([0])},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.PolicyContActionsImageEnv
class PolicyContActionsImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (2,)),
            "other_agent_0": spaces.Box(0.0, 1.0, (2,)),
        }
        self.sample_obs = [
            {
                "agent_0": np.zeros((1, 1, 3, 3)),
                "other_agent_0": np.zeros((1, 1, 3, 3)),
            },
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.zeros((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.zeros((1, 1, 3, 3))},
            {"agent_0": np.zeros((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.zeros((1, 1, 3, 3))},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 0.0]]),
            },
            {
                "agent_0": np.array([[1.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 1.0]]),
            },
        ]
        self.q_values = [
            {"agent_0": 0.0, "other_agent_0": 0.0},
            {"agent_0": 0.0, "other_agent_0": 0.0},
            {"agent_0": -2.0, "other_agent_0": -2.0},
            {"agent_0": -2.0, "other_agent_0": -2.0},
            {"agent_0": -1.0, "other_agent_0": -1.0},
            {"agent_0": -1.0, "other_agent_0": -1.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = {}
        # First, deal with agent_0
        if np.mean(self.last_obs["agent_0"]):  # last obs = 1, policy should be [0, 1]
            reward["agent_0"] = -((0 - action["agent_0"][0]) ** 2) - (
                (1 - action["agent_0"][1]) ** 2
            )
        else:  # last obs = 0, policy should be [1, 0]
            reward["agent_0"] = -((1 - action["agent_0"][0]) ** 2) - (
                (0 - action["agent_0"][1]) ** 2
            )

        # other_agent_0 should learn the opposite behaviour
        if np.mean(
            self.last_obs["other_agent_0"],
        ):  # last obs = 1, policy should be [1, 0]
            reward["other_agent_0"] = -((1 - action["other_agent_0"][0]) ** 2) - (
                (0 - action["other_agent_0"][1]) ** 2
            )
        else:  # last obs = 0, policy should be [0, 1]
            reward["other_agent_0"] = -((0 - action["other_agent_0"][0]) ** 2) - (
                (1 - action["other_agent_0"][1]) ** 2
            )

        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((1, 3, 3)), "other_agent_0": np.zeros((1, 3, 3))},
                {"agent_0": np.ones((1, 3, 3)), "other_agent_0": np.ones((1, 3, 3))},
                {"agent_0": np.zeros((1, 3, 3)), "other_agent_0": np.ones((1, 3, 3))},
                {"agent_0": np.ones((1, 3, 3)), "other_agent_0": np.zeros((1, 3, 3))},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.MultiPolicyEnv
class MultiPolicyEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "other_agent_0": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[0]])},
            {"agent_0": np.array([[0]]), "other_agent_0": np.array([[1]])},
            {"agent_0": np.array([[1]]), "other_agent_0": np.array([[0]])},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
        ]
        self.q_values = [
            {"agent_0": 2.0, "other_agent_0": 2.0},
            {"agent_0": 2.0, "other_agent_0": 2.0},
            {"agent_0": 1.0, "other_agent_0": 1.0},
            {"agent_0": 1.0, "other_agent_0": 1.0},
            {"agent_0": 0.0, "other_agent_0": 3.0},
            {"agent_0": 0.0, "other_agent_0": 3.0},
            {"agent_0": 3.0, "other_agent_0": 0.0},
            {"agent_0": 3.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = {
            "agent_0": 2 * (action["agent_0"] == self.last_obs["agent_0"])
            + (action["other_agent_0"] == self.last_obs["other_agent_0"]),
            "other_agent_0": 2
            * (action["other_agent_0"] != self.last_obs["other_agent_0"])
            + (action["agent_0"] != self.last_obs["agent_0"]),
        }  # Reward depends on action
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "other_agent_0": np.array([0])},
                {"agent_0": np.array([1]), "other_agent_0": np.array([1])},
                {"agent_0": np.array([0]), "other_agent_0": np.array([1])},
                {"agent_0": np.array([1]), "other_agent_0": np.array([0])},
            ],
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.MultiPolicyImageEnv
class MultiPolicyImageEnv:
    def __init__(self) -> None:
        self.possible_agents = ["agent_0", "other_agent_0"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((1, 3, 3)),
            "other_agent_0": np.zeros((1, 3, 3)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
            "other_agent_0": spaces.Box(0.0, 1.0, (1, 3, 3)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "other_agent_0": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.zeros((1, 1, 3, 3))},
            {"agent_0": np.zeros((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.zeros((1, 1, 3, 3))},
            {"agent_0": np.zeros((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.zeros((1, 1, 3, 3))},
            {"agent_0": np.zeros((1, 1, 3, 3)), "other_agent_0": np.ones((1, 1, 3, 3))},
            {"agent_0": np.ones((1, 1, 3, 3)), "other_agent_0": np.zeros((1, 1, 3, 3))},
        ]
        self.sample_actions = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
        ]
        self.q_values = [
            {"agent_0": 2.0, "other_agent_0": 2.0},
            {"agent_0": 2.0, "other_agent_0": 2.0},
            {"agent_0": 1.0, "other_agent_0": 1.0},
            {"agent_0": 1.0, "other_agent_0": 1.0},
            {"agent_0": 0.0, "other_agent_0": 3.0},
            {"agent_0": 0.0, "other_agent_0": 3.0},
            {"agent_0": 3.0, "other_agent_0": 0.0},
            {"agent_0": 3.0, "other_agent_0": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]
        self.policy_values = [
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
            {
                "agent_0": np.array([[1.0, 0.0]]),
                "other_agent_0": np.array([[1.0, 0.0]]),
            },
            {
                "agent_0": np.array([[0.0, 1.0]]),
                "other_agent_0": np.array([[0.0, 1.0]]),
            },
        ]

    def step(
        self,
        action: dict[str, np.ndarray] | np.ndarray,
    ) -> tuple[Any, Any, Any, Any, dict[str, Any]]:
        observation = self.last_obs
        reward = {
            "agent_0": 2
            * (np.mean(action["agent_0"]) == np.mean(self.last_obs["agent_0"]))
            + (
                np.mean(action["other_agent_0"])
                == np.mean(self.last_obs["other_agent_0"])
            ),
            "other_agent_0": 2
            * (
                np.mean(action["other_agent_0"])
                != np.mean(self.last_obs["other_agent_0"])
            )
            + (np.mean(action["agent_0"]) != np.mean(self.last_obs["agent_0"])),
        }  # Reward depends on action
        terminated = {"agent_0": True, "other_agent_0": True}
        truncated = {"agent_0": False, "other_agent_0": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(
        self,
        seed: int | None = None,
        options: dict[str, Any] | None = None,
    ) -> tuple[Any, dict[str, Any]]:
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((1, 3, 3)), "other_agent_0": np.zeros((1, 3, 3))},
                {"agent_0": np.ones((1, 3, 3)), "other_agent_0": np.ones((1, 3, 3))},
                {"agent_0": np.zeros((1, 3, 3)), "other_agent_0": np.ones((1, 3, 3))},
                {"agent_0": np.ones((1, 3, 3)), "other_agent_0": np.zeros((1, 3, 3))},
            ],
        )
        info = {}
        return self.last_obs, info
agilerl.utils.probe_envs_ma.check_policy_q_learning_with_probe_env(env: Any, algo_class: type[Any], algo_args: dict[str, Any], memory: Any, learn_steps: int = 1000, device: str = 'cpu') → None
def check_policy_q_learning_with_probe_env(
    env: Any,
    algo_class: type[Any],
    algo_args: dict[str, Any],
    memory: Any,
    learn_steps: int = 1000,
    device: str = "cpu",
) -> None:
    """Train a multi-agent actor-critic algorithm on a probe environment and
    compare what it learned with the environment's known answers.

    The function collects 1000 transitions into ``memory``, runs
    ``learn_steps`` learning updates, and then evaluates every agent's critic
    and actor on the environment's ``sample_obs`` / ``sample_actions``. A
    prediction that differs from ``env.q_values`` or ``env.policy_values`` by
    more than an absolute tolerance of 0.1 is reported with a ``UserWarning``;
    no exception is raised for a mismatch.

    :param env: Probe environment exposing ``reset``, ``step``, ``sample_obs``,
        ``sample_actions``, ``q_values`` and ``policy_values``
    :type env: Any
    :param algo_class: Algorithm class to instantiate
    :type algo_class: type[Any]
    :param algo_args: Keyword arguments forwarded to ``algo_class``
    :type algo_args: dict[str, Any]
    :param memory: Replay buffer providing ``save_to_memory`` and ``sample``
    :type memory: Any
    :param learn_steps: Number of learning updates to run, defaults to 1000
    :type learn_steps: int, optional
    :param device: Device passed to the algorithm and the tensor helpers,
        defaults to "cpu"
    :type device: str, optional
    """
    import warnings

    agent = algo_class(**algo_args, vect_noise_dim=1, device=device)

    # Fill the replay buffer with experiences gathered by the untrained agent.
    state, _ = env.reset()
    agent.set_training_mode(True)
    for _ in range(1000):
        # Make vectorized
        state = {agent_id: np.expand_dims(s, 0) for agent_id, s in state.items()}
        # The environment consumes the processed action; the buffer stores the raw one.
        processed_action, raw_action = agent.get_action(state)
        next_state, reward, done, _, _ = env.step(processed_action)
        reward = {
            agent_id: np.expand_dims(np.array(r), 0) for agent_id, r in reward.items()
        }
        done = {
            agent_id: np.expand_dims(np.array(d), 0) for agent_id, d in done.items()
        }
        mem_next_state = {
            agent_id: np.expand_dims(ns, 0) for agent_id, ns in next_state.items()
        }
        memory.save_to_memory(
            state,
            raw_action,
            reward,
            mem_next_state,
            done,
            is_vectorised=True,
        )
        state = next_state
        if done[agent.agent_ids[0]]:
            state, _ = env.reset()

    # Learn from experiences
    for _ in trange(learn_steps):
        experiences = memory.sample(agent.batch_size)

        # Learn according to agent's RL algorithm
        agent.learn(experiences)

    # Compare the trained networks with the environment's reference tables.
    agent.set_training_mode(False)
    with torch.no_grad():
        for agent_id in agent.agent_ids:
            actor = agent.actors[agent_id]
            critic = agent.critics[agent_id]
            for sample_obs, sample_action, q_values, policy_values in zip(
                env.sample_obs,
                env.sample_actions,
                env.q_values,
                env.policy_values,
                strict=False,
            ):
                state = prepare_ma_states(sample_obs, agent.observation_space, device)

                if q_values is not None:
                    action = prepare_ma_actions(sample_action, device)
                    # The critic receives all agents' actions concatenated along dim 1.
                    stacked_actions = torch.cat(list(action.values()), dim=1)
                    predicted_q_values = (
                        critic(state, stacked_actions).detach().cpu().numpy()[0]
                    )
                    if not np.allclose(
                        q_values[agent_id],
                        predicted_q_values,
                        atol=0.1,
                    ):
                        # Surface the failure instead of silently discarding it,
                        # without raising so existing callers keep running.
                        warnings.warn(
                            f"Probe env Q-value check failed for {agent_id}: "
                            f"expected {q_values[agent_id]}, "
                            f"predicted {predicted_q_values}",
                            stacklevel=2,
                        )

                if policy_values is not None:
                    predicted_policy_values = (
                        actor(state[agent_id]).detach().cpu().numpy()[0]
                    )
                    if not np.allclose(
                        policy_values[agent_id],
                        predicted_policy_values,
                        atol=0.1,
                    ):
                        warnings.warn(
                            f"Probe env policy check failed for {agent_id}: "
                            f"expected {policy_values[agent_id]}, "
                            f"predicted {predicted_policy_values}",
                            stacklevel=2,
                        )