Probe Environments

For more information on how to use probe environments, see Debugging RL.
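
Every probe environment below follows the same pattern: alongside a tiny step/reset implementation, each class stores sample_obs (and, where relevant, sample_actions) together with the q_values, v_values or policy_values that a correctly implemented agent should converge to. As a minimal sketch, using only the classes documented on this page, a probe environment can be stepped directly like any other Gymnasium-style environment:

from agilerl.utils.probe_envs import ConstantRewardEnv

env = ConstantRewardEnv()
obs, info = env.reset()
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
assert reward == 1   # reward is constant regardless of the action taken
assert terminated    # every episode lasts exactly one step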

Single-agent probe environments

class agilerl.utils.probe_envs.ConstantRewardEnv
class ConstantRewardEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Discrete(1)
        self.sample_obs = [np.zeros((1, 1))]
        self.q_values = [[1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[1.0]]  # Correct V values to learn, s table

    def step(self, action):
        observation = 0
        reward = 1  # Constant reward of 1
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = 0
        info = {}
        return observation, info
class agilerl.utils.probe_envs.ConstantRewardImageEnv
class ConstantRewardImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 0.0, (3, 32, 32))
        self.action_space = spaces.Discrete(1)
        self.sample_obs = [np.zeros((1, 3, 32, 32))]
        self.q_values = [[1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[1.0]]  # Correct V values to learn, s table

    def step(self, action):
        observation = np.zeros((3, 32, 32))
        reward = 1  # Constant reward of 1
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = np.zeros((3, 32, 32))
        info = {}
        return observation, info
class agilerl.utils.probe_envs.ConstantRewardContActionsEnv
class ConstantRewardContActionsEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.sample_obs = [np.zeros((1, 1))]
        self.sample_actions = [[[1.0]]]
        self.q_values = [[1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[1.0]]  # Correct V values to learn, s table
        self.policy_values = [None]  # Correct policy to learn

    def step(self, action):
        observation = 0
        reward = 1  # Constant reward of 1
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = 0
        info = {}
        return observation, info
class agilerl.utils.probe_envs.ConstantRewardContActionsImageEnv
class ConstantRewardContActionsImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 0.0, (3, 32, 32))
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.sample_obs = [np.zeros((1, 3, 32, 32))]
        self.sample_actions = [[[1.0]]]
        self.q_values = [[1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[1.0]]  # Correct V values to learn, s table
        self.policy_values = [None]  # Correct policy to learn

    def step(self, action):
        observation = np.zeros((3, 32, 32))
        reward = 1  # Constant reward of 1
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = np.zeros((3, 32, 32))
        info = {}
        return observation, info
class agilerl.utils.probe_envs.ObsDependentRewardEnv
class ObsDependentRewardEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Discrete(1)
        self.last_obs = 1
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.q_values = [[-1.0], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[-1.0], [1.0]]  # Correct V values to learn, s table

    def step(self, action):
        observation = self.last_obs
        reward = -1 if self.last_obs == 0 else 1  # Reward depends on observation
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice([0, 1])
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.ObsDependentRewardImageEnv
class ObsDependentRewardImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 1.0, (3, 32, 32))
        self.action_space = spaces.Discrete(1)
        self.last_obs = np.ones((3, 32, 32))
        self.sample_obs = [np.zeros((1, 3, 32, 32)), np.ones((1, 3, 32, 32))]
        self.q_values = [[-1.0], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[-1.0], [1.0]]  # Correct V values to learn, s table

    def step(self, action):
        observation = self.last_obs
        reward = (
            -1 if np.mean(self.last_obs) == 0.0 else 1
        )  # Reward depends on observation
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice([np.zeros((3, 32, 32)), np.ones((3, 32, 32))])
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.ObsDependentRewardContActionsEnv
class ObsDependentRewardContActionsEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = 1
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.sample_actions = [[[1.0]], [[1.0]]]
        self.q_values = [[-1.0], [1.0]]  # Correct Q values to learn, s x a table
        self.policy_values = [None]  # Correct policy to learn
        self.v_values = [[-1.0], [1.0]]  # Correct V values to learn, s table

    def step(self, action):
        observation = self.last_obs
        reward = -1 if self.last_obs == 0 else 1  # Reward depends on observation
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice([0, 1])
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.ObsDependentRewardContActionsImageEnv
class ObsDependentRewardContActionsImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 1.0, (3, 32, 32))
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = np.ones((3, 32, 32))
        self.sample_obs = [np.zeros((1, 3, 32, 32)), np.ones((1, 3, 32, 32))]
        self.sample_actions = [[[1.0]], [[1.0]]]
        self.q_values = [[-1.0], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[-1.0], [1.0]]  # Correct V values to learn, s table
        self.policy_values = [None]  # Correct policy to learn

    def step(self, action):
        observation = self.last_obs
        reward = (
            -1 if np.mean(self.last_obs) == 0.0 else 1
        )  # Reward depends on observation
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice([np.zeros((3, 32, 32)), np.ones((3, 32, 32))])
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.DiscountedRewardEnv
class DiscountedRewardEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Discrete(1)
        self.last_obs = 0
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.q_values = [[0.99], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[0.99], [1.0]]  # Correct V values to learn, s table

    def step(self, action):
        observation = 1
        reward = self.last_obs  # Reward depends on observation
        terminated = self.last_obs  # Terminate after second step
        truncated = False
        info = {}
        self.last_obs = 1
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = 0
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.DiscountedRewardImageEnv
class DiscountedRewardImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 1.0, (3, 32, 32))
        self.action_space = spaces.Discrete(1)
        self.last_obs = np.zeros((3, 32, 32))
        self.sample_obs = [np.zeros((1, 3, 32, 32)), np.ones((1, 3, 32, 32))]
        self.q_values = [[0.99], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[0.99], [1.0]]  # Correct V values to learn, s table

    def step(self, action):
        observation = np.ones((3, 32, 32))
        reward = np.mean(self.last_obs)  # Reward depends on observation
        terminated = int(np.mean(self.last_obs))  # Terminate after second step
        truncated = False
        info = {}
        self.last_obs = np.ones((3, 32, 32))
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = np.zeros((3, 32, 32))
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.DiscountedRewardContActionsEnv
class DiscountedRewardContActionsEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = 0
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.sample_actions = [[[1.0]], [[1.0]]]
        self.q_values = [[0.99], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[0.99], [1.0]]  # Correct V values to learn, s table
        self.policy_values = [None]  # Correct policy to learn

    def step(self, action):
        observation = 1
        reward = self.last_obs  # Reward depends on observation
        terminated = self.last_obs  # Terminate after second step
        truncated = False
        info = {}
        self.last_obs = 1
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = 0
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.DiscountedRewardContActionsImageEnv
class DiscountedRewardContActionsImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 1.0, (3, 32, 32))
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = np.zeros((3, 32, 32))
        self.sample_obs = [np.zeros((1, 3, 32, 32)), np.ones((1, 3, 32, 32))]
        self.sample_actions = [[[1.0]], [[1.0]]]
        self.q_values = [[0.99], [1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [[0.99], [1.0]]  # Correct V values to learn, s table
        self.policy_values = [None]  # Correct policy to learn

    def step(self, action):
        observation = np.ones((3, 32, 32))
        reward = np.mean(self.last_obs)  # Reward depends on observation
        terminated = int(np.mean(self.last_obs))  # Terminate after second step
        truncated = False
        info = {}
        self.last_obs = np.ones((3, 32, 32))
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = np.zeros((3, 32, 32))
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.FixedObsPolicyEnv
class FixedObsPolicyEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Discrete(2)
        self.sample_obs = [np.array([[0]])]
        self.q_values = [[-1.0, 1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [None]  # Correct V values to learn, s table

    def step(self, action):
        if isinstance(action, (np.ndarray, list)):
            action = action[0]
        observation = 0
        reward = [-1, 1][action]  # Reward depends on action
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = 0
        info = {}
        return observation, info
class agilerl.utils.probe_envs.FixedObsPolicyImageEnv
class FixedObsPolicyImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 0.0, (3, 32, 32))
        self.action_space = spaces.Discrete(2)
        self.sample_obs = [np.zeros((1, 3, 32, 32))]
        self.q_values = [[-1.0, 1.0]]  # Correct Q values to learn, s x a table
        self.v_values = [None]  # Correct V values to learn, s table

    def step(self, action):
        observation = np.zeros((3, 32, 32))
        if isinstance(action, (np.ndarray, list)):
            action = action[0]
        reward = [-1, 1][action]  # Reward depends on action
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = np.zeros((3, 32, 32))
        info = {}
        return observation, info
class agilerl.utils.probe_envs.FixedObsPolicyContActionsEnv
class FixedObsPolicyContActionsEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(1)
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.sample_obs = [np.array([[0]])]
        self.sample_actions = [np.array([[1.0]])]
        self.q_values = np.array([[0.0]])  # Correct Q values to learn, s x a table
        self.v_values = [None]  # Correct V values to learn, s table
        self.policy_values = [[1.0]]  # Correct policy to learn

    def step(self, action):
        observation = 0
        reward = -((1 - action[0]) ** 2)  # Reward depends on action
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = 0
        info = {}
        return observation, info
class agilerl.utils.probe_envs.FixedObsPolicyContActionsImageEnv
class FixedObsPolicyContActionsImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 1.0, (3, 32, 32))
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.sample_obs = [np.zeros((1, 3, 32, 32))]
        self.sample_actions = [np.array([[1.0]])]
        self.q_values = np.array([[0.0]])  # Correct Q values to learn, s x a table
        self.v_values = [None]  # Correct V values to learn, s table
        self.policy_values = [[1.0]]  # Correct policy to learn

    def step(self, action):
        observation = np.zeros((3, 32, 32))
        reward = -((1 - action[0]) ** 2)  # Reward depends on action
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = np.zeros((3, 32, 32))
        info = {}
        return observation, info
class agilerl.utils.probe_envs.PolicyEnv
class PolicyEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Discrete(2)
        self.last_obs = 0
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.q_values = [
            [1.0, -1.0],
            [-1.0, 1.0],
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]  # Correct V values to learn, s table

    def step(self, action):
        observation = self.last_obs
        reward = (
            1 if action == self.last_obs else -1
        )  # Reward depends on action in observation
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice([0, 1])
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.PolicyImageEnv
class PolicyImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 1.0, (3, 32, 32))
        self.action_space = spaces.Discrete(2)
        self.last_obs = np.ones((3, 32, 32))
        self.sample_obs = [np.zeros((1, 3, 32, 32)), np.ones((1, 3, 32, 32))]
        self.q_values = [
            [1.0, -1.0],
            [-1.0, 1.0],
        ]  # Correct Q values to learn, s x a table
        self.v_values = [None]  # Correct V values to learn, s table

    def step(self, action):
        observation = self.last_obs
        reward = (
            1 if action == int(np.mean(self.last_obs)) else -1
        )  # Reward depends on action in observation
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice([np.zeros((3, 32, 32)), np.ones((3, 32, 32))])
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.PolicyContActionsEnv
class PolicyContActionsEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Discrete(2)
        self.action_space = spaces.Box(0.0, 1.0, (2,))
        self.last_obs = 0
        self.sample_obs = [np.array([[1, 0]]), np.array([[0, 1]])]
        self.sample_actions = [np.array([[1.0, 0.0]]), np.array([[0.0, 1.0]])]
        self.q_values = [[0.0], [0.0]]  # Correct Q values to learn
        self.v_values = [None]  # Correct V values to learn, s table
        self.policy_values = [[1.0, 0.0], [0.0, 1.0]]  # Correct policy to learn

    def step(self, action):
        observation = self.last_obs
        if self.last_obs:  # last obs = 1, policy should be [0, 1]
            reward = -((0 - action[0]) ** 2) - (1 - action[1]) ** 2
        else:  # last obs = 0, policy should be [1, 0]
            reward = -((1 - action[0]) ** 2) - (0 - action[1]) ** 2
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice([0, 1])
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.PolicyContActionsImageEnvSimple
class PolicyContActionsImageEnvSimple(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 1.0, (3, 32, 32))
        self.action_space = spaces.Box(0.0, 1.0, (1,))
        self.last_obs = np.zeros((3, 32, 32))
        self.sample_obs = [
            np.zeros((1, 3, 32, 32)),
            np.zeros((1, 3, 32, 32)),
            np.ones((1, 3, 32, 32)),
            np.ones((1, 3, 32, 32)),
        ]
        self.sample_actions = [
            np.array([[0.0]]),
            np.array([[1.0]]),
            np.array([[0.0]]),
            np.array([[1.0]]),
        ]
        self.q_values = [[0.0], [-1.0], [-1.0], [0.0]]  # Correct Q values to learn
        self.policy_values = [[0.0], [0.0], [1.0], [1.0]]  # Correct policy to learn
        self.v_values = [None]  # Correct V values to learn, s table

    def step(self, action):
        observation = self.last_obs
        if int(np.mean(self.last_obs)):  # last obs = 1, policy should be [1]
            reward = -((1 - action[0]) ** 2)
        else:  # last obs = 0, policy should be [0]
            reward = -((0 - action[0]) ** 2)
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        # self.last_obs = random.choice([np.zeros((3, 32, 32)), np.ones((3, 32, 32))])
        if int(np.mean(self.last_obs)):
            self.last_obs = np.zeros((3, 32, 32))
        else:
            self.last_obs = np.ones((3, 32, 32))

        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs.PolicyContActionsImageEnv
class PolicyContActionsImageEnv(gym.Env):
    def __init__(self):
        self.observation_space = spaces.Box(0.0, 1.0, (3, 32, 32))
        self.action_space = spaces.Box(0.0, 1.0, (2,))
        self.last_obs = np.zeros((3, 32, 32))
        self.sample_obs = [np.zeros((1, 3, 32, 32)), np.ones((1, 3, 32, 32))]
        self.sample_actions = [np.array([[1.0, 0.0]]), np.array([[0.0, 1.0]])]
        self.q_values = [[0.0], [0.0]]  # Correct Q values to learn
        self.policy_values = [[1.0, 0.0], [0.0, 1.0]]  # Correct policy to learn
        self.v_values = [None]  # Correct V values to learn, s table

    def step(self, action):
        observation = self.last_obs
        if int(np.mean(self.last_obs)):  # last obs = 1, policy should be [0, 1]
            reward = -((0 - action[0]) ** 2) - (1 - action[1]) ** 2
        else:  # last obs = 0, policy should be [1, 0]
            reward = -((1 - action[0]) ** 2) - (0 - action[1]) ** 2
        terminated = True
        truncated = False
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice([np.zeros((3, 32, 32)), np.ones((3, 32, 32))])
        info = {}
        return self.last_obs, info
agilerl.utils.probe_envs.check_q_learning_with_probe_env(env, algo_class, algo_args, memory, learn_steps=1000, device='cpu')
def check_q_learning_with_probe_env(
    env, algo_class, algo_args, memory, learn_steps=1000, device="cpu"
):
    print(f"Probe environment: {type(env).__name__}")

    agent = algo_class(**algo_args, device=device)

    state, _ = env.reset()
    for _ in range(500):
        action = agent.getAction(np.expand_dims(state, 0), epsilon=1)
        next_state, reward, done, _, _ = env.step(action)
        memory.save2memory(state, action, reward, next_state, done)
        state = next_state
        if done:
            state, _ = env.reset()

    # Learn from experiences
    for _ in trange(learn_steps):
        experiences = memory.sample(agent.batch_size)
        # Learn according to agent's RL algorithm
        agent.learn(experiences)

    for sample_obs, q_values in zip(env.sample_obs, env.q_values):
        predicted_q_values = agent.actor(sample_obs).detach().cpu().numpy()[0]
        assert np.allclose(q_values, predicted_q_values, atol=0.1)
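
A hedged sketch of invoking this checker with a value-based agent. The DQN and ReplayBuffer imports and their constructor arguments are assumptions drawn from the wider AgileRL API and may need adjusting to match your installed version:

from agilerl.algorithms.dqn import DQN
from agilerl.components.replay_buffer import ReplayBuffer
from agilerl.utils.probe_envs import ConstantRewardEnv, check_q_learning_with_probe_env

env = ConstantRewardEnv()
memory = ReplayBuffer(
    action_dim=1,  # assumed constructor arguments; check your ReplayBuffer signature
    memory_size=1000,
    field_names=["state", "action", "reward", "next_state", "done"],
    device="cpu",
)
algo_args = {"state_dim": (1,), "action_dim": 1, "one_hot": True}  # assumed keys for DQN
check_q_learning_with_probe_env(env, DQN, algo_args, memory, learn_steps=1000)

If the agent's Q-network cannot fit even this constant-reward case to within the 0.1 tolerance asserted above, the value-learning update itself is the first place to look.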
agilerl.utils.probe_envs.check_policy_q_learning_with_probe_env(env, algo_class, algo_args, memory, learn_steps=1000, device='cpu')
def check_policy_q_learning_with_probe_env(
    env, algo_class, algo_args, memory, learn_steps=1000, device="cpu"
):
    print(f"Probe environment: {type(env).__name__}")

    agent = algo_class(**algo_args, device=device)

    state, _ = env.reset()
    for _ in range(5000):
        action = agent.getAction(np.expand_dims(state, 0), epsilon=1)[0]
        next_state, reward, done, _, _ = env.step(action)
        memory.save2memory(state, action, reward, next_state, done)
        state = next_state
        if done:
            state, _ = env.reset()

    # Learn from experiences
    for _ in trange(learn_steps):
        experiences = memory.sample(agent.batch_size)
        # Learn according to agent's RL algorithm
        agent.learn(experiences)

    for sample_obs, sample_action, q_values, policy_values in zip(
        env.sample_obs, env.sample_actions, env.q_values, env.policy_values
    ):
        state = torch.tensor(sample_obs).float().to(device)
        action = torch.tensor(sample_action).float().to(device)
        if agent.arch == "mlp":
            input_combined = torch.cat([state, action], 1)
            predicted_q_values = agent.critic(input_combined).detach().cpu().numpy()[0]
        else:
            predicted_q_values = agent.critic(state, action).detach().cpu().numpy()[0]
        # print("---")
        # print("q", q_values, predicted_q_values)
        assert np.allclose(q_values, predicted_q_values, atol=0.1)

        if policy_values is not None:
            predicted_policy_values = agent.actor(sample_obs).detach().cpu().numpy()[0]

            # print("pol", policy_values, predicted_policy_values)
            assert np.allclose(policy_values, predicted_policy_values, atol=0.1)
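
A similarly hedged sketch for an off-policy actor-critic agent such as DDPG, run against a continuous-action probe. The constructor keys in algo_args are assumptions for illustration; match them to your installed version's DDPG and ReplayBuffer signatures:

from agilerl.algorithms.ddpg import DDPG
from agilerl.components.replay_buffer import ReplayBuffer
from agilerl.utils.probe_envs import (
    FixedObsPolicyContActionsEnv,
    check_policy_q_learning_with_probe_env,
)

env = FixedObsPolicyContActionsEnv()
memory = ReplayBuffer(
    action_dim=1,  # assumed constructor arguments
    memory_size=10000,
    field_names=["state", "action", "reward", "next_state", "done"],
    device="cpu",
)
algo_args = {"state_dim": (1,), "action_dim": 1, "one_hot": True}  # assumed keys for DDPG
check_policy_q_learning_with_probe_env(env, DDPG, algo_args, memory, learn_steps=1000)

Because FixedObsPolicyContActionsEnv defines both q_values and policy_values, this check exercises the critic and the actor: the critic should predict a Q-value near 0.0 for the sampled action, and the actor should output an action close to 1.0.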
agilerl.utils.probe_envs.check_policy_on_policy_with_probe_env(env, algo_class, algo_args, learn_steps=5000, device='cpu')
def check_policy_on_policy_with_probe_env(
    env, algo_class, algo_args, learn_steps=5000, device="cpu"
):
    print(f"Probe environment: {type(env).__name__}")

    agent = algo_class(**algo_args, device=device)

    for _ in trange(learn_steps):
        state, _ = env.reset()
        states = []
        actions = []
        log_probs = []
        rewards = []
        dones = []
        values = []
        truncs = []

        for _ in range(100):
            action, log_prob, _, value = agent.getAction(np.expand_dims(state, 0))
            action = action[0]
            log_prob = log_prob[0]
            value = value[0]
            next_state, reward, done, trunc, _ = env.step(action)

            states.append(state)
            actions.append(action)
            log_probs.append(log_prob)
            rewards.append(reward)
            dones.append(done)
            values.append(value)
            truncs.append(trunc)

            state = next_state
            if done:
                state, _ = env.reset()

        experiences = (
            states,
            actions,
            log_probs,
            rewards,
            dones,
            values,
            next_state,
        )
        agent.learn(experiences)

    for sample_obs, v_values in zip(env.sample_obs, env.v_values):
        state = torch.tensor(sample_obs).float().to(device)
        if v_values is not None:
            predicted_v_values = agent.critic(state).detach().cpu().numpy()[0]
            # print("---")
            # print("v", v_values, predicted_v_values)
            assert np.allclose(v_values, predicted_v_values, atol=0.1)

    if hasattr(env, "sample_actions"):
        for sample_action, policy_values in zip(env.sample_actions, env.policy_values):
            action = torch.tensor(sample_action).float().to(device)
            if policy_values is not None:
                predicted_policy_values = (
                    agent.actor(sample_obs).detach().cpu().numpy()[0]
                )
                # print("pol", policy_values, predicted_policy_values)
                assert np.allclose(policy_values, predicted_policy_values, atol=0.1)
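
A hedged sketch for an on-policy agent such as PPO; no replay buffer is needed here. The algo_args keys are assumptions and should be matched to your installed version's PPO constructor:

from agilerl.algorithms.ppo import PPO
from agilerl.utils.probe_envs import (
    ObsDependentRewardEnv,
    check_policy_on_policy_with_probe_env,
)

env = ObsDependentRewardEnv()
algo_args = {
    "state_dim": (2,),          # assumed keys; match your PPO.__init__
    "action_dim": 1,
    "one_hot": True,
    "discrete_actions": True,
}
check_policy_on_policy_with_probe_env(env, PPO, algo_args, learn_steps=1000)

The checker asserts the critic's value estimates against env.v_values (here -1.0 and 1.0 for the two observations) and, for environments that define sample_actions, also checks the actor's outputs against env.policy_values.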

Multi-agent probe environments

class agilerl.utils.probe_envs_ma.ConstantRewardEnv
class ConstantRewardEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.observation_space = {
            "agent_0": spaces.Discrete(1),
            "agent_1": spaces.Discrete(1),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [{"agent_0": np.array([[0]]), "agent_1": np.array([[0]])}]
        self.sample_actions = [
            {"agent_0": np.array([[0.2, 0.8]]), "agent_1": np.array([[0.8, 0.2]])}
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 0.0}
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None]

    def step(self, action):
        observation = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        reward = {"agent_0": 1, "agent_1": 0}  # Constant reward of 1
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        info = {}
        return observation, info
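
The multi-agent variants mirror the single-agent probes, but observations, rewards, terminations and the stored sample_obs / q_values are dictionaries keyed by agent ID. A minimal sketch of stepping the multi-agent ConstantRewardEnv directly, using only what is shown in the listing above:

from agilerl.utils.probe_envs_ma import ConstantRewardEnv

env = ConstantRewardEnv()
obs, info = env.reset()
actions = {agent_id: 0 for agent_id in env.agents}  # the action is ignored by this probe
obs, reward, terminated, truncated, info = env.step(actions)
assert reward == {"agent_0": 1, "agent_1": 0}  # per-agent constant rewards
assert all(terminated.values())                # both agents terminate after one step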
class agilerl.utils.probe_envs_ma.ConstantRewardImageEnv
class ConstantRewardImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))}
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2, 0.8]]), "agent_1": np.array([[0.8, 0.2]])}
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 0.0}
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None]

    def step(self, action):
        observation = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        reward = {"agent_0": 1, "agent_1": 0}  # Constant reward of 1
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.ConstantRewardContActionsEnv
class ConstantRewardContActionsEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.observation_space = {
            "agent_0": spaces.Discrete(1),
            "agent_1": spaces.Discrete(1),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "agent_1": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [{"agent_0": np.array([[0]]), "agent_1": np.array([[0]])}]
        self.sample_actions = [
            {"agent_0": np.array([[0.0]]), "agent_1": np.array([[1.0]])}
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 0.0}
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None]

    def step(self, action):
        observation = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        reward = {"agent_0": 1, "agent_1": 0}  # Constant reward
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.ConstantRewardContActionsImageEnv
class ConstantRewardContActionsImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "agent_1": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))}
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.0]]), "agent_1": np.array([[1.0]])}
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 0.0}
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None]

    def step(self, action):
        observation = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        reward = {"agent_0": 1, "agent_1": 0}  # Constant reward
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.ObsDependentRewardEnv
class ObsDependentRewardEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[1]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2, 0.8]]), "agent_1": np.array([[0.8, 0.2]])},
            {"agent_0": np.array([[0.8, 0.2]]), "agent_1": np.array([[0.2, 0.8]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 0.0},
            {"agent_0": 0.0, "agent_1": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None, None]

    def step(self, action):
        observation = self.last_obs
        reward = (
            {"agent_0": 1, "agent_1": 0}
            if self.last_obs["agent_0"] == 0
            else {"agent_0": 0, "agent_1": 1}
        )  # Reward depends on observation
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "agent_1": np.array([0])},
                {"agent_0": np.array([1]), "agent_1": np.array([1])},
            ]
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.ObsDependentRewardImageEnv
class ObsDependentRewardImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2, 0.8]]), "agent_1": np.array([[0.8, 0.2]])},
            {"agent_0": np.array([[0.8, 0.2]]), "agent_1": np.array([[0.2, 0.8]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 0.0},
            {"agent_0": 0.0, "agent_1": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None, None]

    def step(self, action):
        observation = self.last_obs
        reward = (
            {"agent_0": 1, "agent_1": 0}
            if np.mean(self.last_obs["agent_0"]) == 0
            else {"agent_0": 0, "agent_1": 1}
        )  # Reward depends on observation
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((3, 32, 32)), "agent_1": np.zeros((3, 32, 32))},
                {"agent_0": np.ones((3, 32, 32)), "agent_1": np.ones((3, 32, 32))},
            ]
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.ObsDependentRewardContActionsEnv
class ObsDependentRewardContActionsEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "agent_1": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[1]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2]]), "agent_1": np.array([[0.0]])},
            {"agent_0": np.array([[0.8]]), "agent_1": np.array([[0.6]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 0.0},
            {"agent_0": 0.0, "agent_1": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None, None]

    def step(self, action):
        observation = self.last_obs
        reward = (
            {"agent_0": 1, "agent_1": 0}
            if self.last_obs["agent_0"] == 0
            else {"agent_0": 0, "agent_1": 1}
        )  # Reward depends on observation
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "agent_1": np.array([0])},
                {"agent_0": np.array([1]), "agent_1": np.array([1])},
            ]
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.ObsDependentRewardContActionsImageEnv
class ObsDependentRewardContActionsImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "agent_1": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2]]), "agent_1": np.array([[0.0]])},
            {"agent_0": np.array([[0.8]]), "agent_1": np.array([[0.6]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 0.0},
            {"agent_0": 0.0, "agent_1": 1.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None, None]

    def step(self, action):
        observation = self.last_obs
        reward = (
            {"agent_0": 1, "agent_1": 0}
            if np.mean(self.last_obs["agent_0"]) == 0
            else {"agent_0": 0, "agent_1": 1}
        )  # Reward depends on observation
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((3, 32, 32)), "agent_1": np.zeros((3, 32, 32))},
                {"agent_0": np.ones((3, 32, 32)), "agent_1": np.ones((3, 32, 32))},
            ]
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.DiscountedRewardEnv
class DiscountedRewardEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[1]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2, 0.8]]), "agent_1": np.array([[0.8, 0.2]])},
            {"agent_0": np.array([[0.8, 0.2]]), "agent_1": np.array([[0.2, 0.8]])},
        ]
        self.q_values = [
            {"agent_0": 0.99, "agent_1": 0.495},
            {"agent_0": 1.0, "agent_1": 0.5},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None, None]

    def step(self, action):
        observation = {"agent_0": np.array([1]), "agent_1": np.array([1])}
        reward = (
            {"agent_0": 1, "agent_1": 0.5}
            if self.last_obs["agent_0"] == 1
            else {"agent_0": 0, "agent_1": 0}
        )  # Reward depends on observation
        terminated = self.last_obs  # Terminate after second step
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        self.last_obs = {"agent_0": np.array([1]), "agent_1": np.array([1])}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.DiscountedRewardImageEnv
class DiscountedRewardImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2, 0.8]]), "agent_1": np.array([[0.8, 0.2]])},
            {"agent_0": np.array([[0.8, 0.2]]), "agent_1": np.array([[0.2, 0.8]])},
        ]
        self.q_values = [
            {"agent_0": 0.99, "agent_1": 0.495},
            {"agent_0": 1.0, "agent_1": 0.5},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None, None]

    def step(self, action):
        observation = {"agent_0": np.ones((3, 32, 32)), "agent_1": np.ones((3, 32, 32))}
        reward = (
            {"agent_0": 1, "agent_1": 0.5}
            if np.mean(self.last_obs["agent_0"]) == 1
            else {"agent_0": 0, "agent_1": 0}
        )  # Reward depends on observation
        terminated = {
            "agent_0": int(np.mean(self.last_obs["agent_0"])),
            "agent_1": int(np.mean(self.last_obs["agent_0"])),
        }  # Terminate after second step
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        self.last_obs = {
            "agent_0": np.ones((3, 32, 32)),
            "agent_1": np.ones((3, 32, 32)),
        }
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.DiscountedRewardContActionsEnv
class DiscountedRewardContActionsEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "agent_1": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[1]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2]]), "agent_1": np.array([[0.4]])},
            {"agent_0": np.array([[0.8]]), "agent_1": np.array([[0.1]])},
        ]
        self.q_values = [
            {"agent_0": 0.99, "agent_1": 0.495},
            {"agent_0": 1.0, "agent_1": 0.5},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None, None]

    def step(self, action):
        observation = {"agent_0": np.array([1]), "agent_1": np.array([1])}
        reward = (
            {"agent_0": 1, "agent_1": 0.5}
            if self.last_obs["agent_0"] == 1
            else {"agent_0": 0, "agent_1": 0}
        )  # Reward depends on observation
        terminated = self.last_obs  # Terminate after second step
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        self.last_obs = {"agent_0": np.array([1]), "agent_1": np.array([1])}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.DiscountedRewardContActionsImageEnv
class DiscountedRewardContActionsImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "agent_1": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[0.2]]), "agent_1": np.array([[0.4]])},
            {"agent_0": np.array([[0.8]]), "agent_1": np.array([[0.1]])},
        ]
        self.q_values = [
            {"agent_0": 0.99, "agent_1": 0.495},
            {"agent_0": 1.0, "agent_1": 0.5},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [None, None]

    def step(self, action):
        observation = {"agent_0": np.ones((3, 32, 32)), "agent_1": np.ones((3, 32, 32))}
        reward = (
            {"agent_0": 1, "agent_1": 0.5}
            if np.mean(self.last_obs["agent_0"]) == 1
            else {"agent_0": 0, "agent_1": 0}
        )  # Reward depends on observation
        terminated = {
            "agent_0": int(np.mean(self.last_obs["agent_0"])),
            "agent_1": int(np.mean(self.last_obs["agent_0"])),
        }  # Terminate after second step
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        self.last_obs = {
            "agent_0": np.ones((3, 32, 32)),
            "agent_1": np.ones((3, 32, 32)),
        }
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.FixedObsPolicyEnv
class FixedObsPolicyEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(1),
            "agent_1": spaces.Discrete(1),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[0]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 1.0}
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])}
        ]

    def step(self, action):
        observation = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        reward = {
            "agent_0": [1, -1][action["agent_0"]],
            "agent_1": [-1, 1][action["agent_1"]],
        }  # Reward depends on action
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.FixedObsPolicyImageEnv
class FixedObsPolicyImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 0.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 0.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 1.0}
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])}
        ]

    def step(self, action):
        observation = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        reward = {
            "agent_0": [1, -1][action["agent_0"]],
            "agent_1": [-1, 1][action["agent_1"]],
        }  # Reward depends on action
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.FixedObsPolicyContActionsEnv
class FixedObsPolicyContActionsEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(1),
            "agent_1": spaces.Discrete(1),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "agent_1": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[0]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0]]), "agent_1": np.array([[0.0]])},
        ]
        self.q_values = [
            {"agent_0": 0.0, "agent_1": 0.0}
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [{"agent_0": np.array([1.0]), "agent_1": np.array([0.0])}]

    def step(self, action):
        observation = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        reward = {
            "agent_0": -((1 - action["agent_0"]) ** 2),
            "agent_1": -((0 - action["agent_1"]) ** 2),
        }  # Reward depends on action
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.FixedObsPolicyContActionsImageEnv
class FixedObsPolicyContActionsImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 0.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 0.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (1,)),
            "agent_1": spaces.Box(0.0, 1.0, (1,)),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0]]), "agent_1": np.array([[0.0]])},
        ]
        self.q_values = [
            {"agent_0": 0.0, "agent_1": 0.0}
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [{"agent_0": np.array([1.0]), "agent_1": np.array([0.0])}]

    def step(self, action):
        observation = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        reward = {
            "agent_0": -((1 - action["agent_0"]) ** 2),
            "agent_1": -((0 - action["agent_1"]) ** 2),
        }  # Reward depends on action
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        observation = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        info = {}
        return observation, info
class agilerl.utils.probe_envs_ma.PolicyEnv
class PolicyEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[0]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[0]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 1.0},
            {"agent_0": 1.0, "agent_1": 1.0},
            {"agent_0": 0.0, "agent_1": 0.0},
            {"agent_0": 0.0, "agent_1": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
        ]

    def step(self, action):
        observation = self.last_obs
        reward = {
            "agent_0": action["agent_0"] == self.last_obs["agent_0"],
            "agent_1": action["agent_1"] != self.last_obs["agent_1"],
        }  # Reward depends on action
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "agent_1": np.array([0])},
                {"agent_0": np.array([1]), "agent_1": np.array([1])},
                {"agent_0": np.array([0]), "agent_1": np.array([1])},
                {"agent_0": np.array([1]), "agent_1": np.array([0])},
            ]
        )
        info = {}
        return self.last_obs, info
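In this probe the two agents have opposite objectives: agent_0 earns a reward of 1 for repeating its observation, while agent_1 earns 1 for outputting the opposite of its observation. A minimal sketch of the optimal behaviour, assuming the class is importable as documented above:

from agilerl.utils.probe_envs_ma import PolicyEnv

env = PolicyEnv()
obs, _ = env.reset()
# Optimal policy: agent_0 copies its observation, agent_1 flips it
action = {"agent_0": obs["agent_0"], "agent_1": 1 - obs["agent_1"]}
_, reward, _, _, _ = env.step(action)
assert reward["agent_0"].item() == 1
assert reward["agent_1"].item() == 1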
class agilerl.utils.probe_envs_ma.PolicyImageEnv
class PolicyImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
        ]
        self.q_values = [
            {"agent_0": 1.0, "agent_1": 1.0},
            {"agent_0": 1.0, "agent_1": 1.0},
            {"agent_0": 0.0, "agent_1": 0.0},
            {"agent_0": 0.0, "agent_1": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
        ]

    def step(self, action):
        observation = self.last_obs
        reward = {
            "agent_0": action["agent_0"] == np.mean(self.last_obs["agent_0"]),
            "agent_1": action["agent_1"] != np.mean(self.last_obs["agent_1"]),
        }  # Reward depends on action
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((3, 32, 32)), "agent_1": np.zeros((3, 32, 32))},
                {"agent_0": np.ones((3, 32, 32)), "agent_1": np.ones((3, 32, 32))},
                {"agent_0": np.zeros((3, 32, 32)), "agent_1": np.ones((3, 32, 32))},
                {"agent_0": np.ones((3, 32, 32)), "agent_1": np.zeros((3, 32, 32))},
            ]
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.PolicyContActionsEnv
class PolicyContActionsEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (2,)),
            "agent_1": spaces.Box(0.0, 1.0, (2,)),
        }
        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[0]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[0]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[0]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 0.0]]), "agent_1": np.array([[0.0, 0.0]])},
            {"agent_0": np.array([[1.0, 1.0]]), "agent_1": np.array([[1.0, 1.0]])},
        ]
        self.q_values = [
            {"agent_0": 0.0, "agent_1": 0.0},
            {"agent_0": 0.0, "agent_1": 0.0},
            {"agent_0": -2.0, "agent_1": -2.0},
            {"agent_0": -2.0, "agent_1": -2.0},
            {"agent_0": -1.0, "agent_1": -1.0},
            {"agent_0": -1.0, "agent_1": -1.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
        ]

    def step(self, action):
        observation = self.last_obs
        reward = {}
        if self.last_obs["agent_0"]:  # last obs = 1, policy should be [0, 1]
            reward["agent_0"] = -((0 - action["agent_0"][0]) ** 2) - (
                (1 - action["agent_0"][1]) ** 2
            )
        else:  # last obs = 0, policy should be [1, 0]
            reward["agent_0"] = -((1 - action["agent_0"][0]) ** 2) - (
                (0 - action["agent_0"][1]) ** 2
            )
        if self.last_obs["agent_1"]:  # last obs = 1, policy should be [1, 0]
            reward["agent_1"] = -((1 - action["agent_1"][0]) ** 2) - (
                (0 - action["agent_1"][1]) ** 2
            )
        else:  # last obs = 0, policy should be [0, 1]
            reward["agent_1"] = -((0 - action["agent_1"][0]) ** 2) - (
                (1 - action["agent_1"][1]) ** 2
            )
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "agent_1": np.array([0])},
                {"agent_0": np.array([1]), "agent_1": np.array([1])},
                {"agent_0": np.array([0]), "agent_1": np.array([1])},
                {"agent_0": np.array([1]), "agent_1": np.array([0])},
            ]
        )
        info = {}
        return self.last_obs, info
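Here the reward is the negative squared error between the continuous action and an observation-dependent target ([1, 0] for agent_0 when it observes 0 and [0, 1] when it observes 1, with the targets reversed for agent_1), so the target Q values above are 0 for on-target actions, -2 for fully wrong actions and -1 for the all-zeros or all-ones actions. A minimal sketch of one of these cases, assuming the class is importable as documented above:

import numpy as np

from agilerl.utils.probe_envs_ma import PolicyContActionsEnv

env = PolicyContActionsEnv()
env.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([1])}  # pin the obs for a deterministic check
# Fully wrong actions: the targets here are [1, 0] for both agents
action = {"agent_0": np.array([0.0, 1.0]), "agent_1": np.array([0.0, 1.0])}
_, reward, _, _, _ = env.step(action)
# Matches the third entry of q_values above
assert reward["agent_0"].item() == -2.0
assert reward["agent_1"].item() == -2.0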
class agilerl.utils.probe_envs_ma.PolicyContActionsImageEnv
class PolicyContActionsImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Box(0.0, 1.0, (2,)),
            "agent_1": spaces.Box(0.0, 1.0, (2,)),
        }
        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 0.0]]), "agent_1": np.array([[0.0, 0.0]])},
            {"agent_0": np.array([[1.0, 1.0]]), "agent_1": np.array([[1.0, 1.0]])},
        ]
        self.q_values = [
            {"agent_0": 0.0, "agent_1": 0.0},
            {"agent_0": 0.0, "agent_1": 0.0},
            {"agent_0": -2.0, "agent_1": -2.0},
            {"agent_0": -2.0, "agent_1": -2.0},
            {"agent_0": -1.0, "agent_1": -1.0},
            {"agent_0": -1.0, "agent_1": -1.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
        ]

    def step(self, action):
        observation = self.last_obs
        reward = {}
        if np.mean(self.last_obs["agent_0"]):  # last obs = 1, policy should be [0, 1]
            reward["agent_0"] = -((0 - action["agent_0"][0]) ** 2) - (
                (1 - action["agent_0"][1]) ** 2
            )
        else:  # last obs = 0, policy should be [1, 0]
            reward["agent_0"] = -((1 - action["agent_0"][0]) ** 2) - (
                (0 - action["agent_0"][1]) ** 2
            )
        if np.mean(self.last_obs["agent_1"]):  # last obs = 1, policy should be [1, 0]
            reward["agent_1"] = -((1 - action["agent_1"][0]) ** 2) - (
                (0 - action["agent_1"][1]) ** 2
            )
        else:  # last obs = 0, policy should be [0, 1]
            reward["agent_1"] = -((0 - action["agent_1"][0]) ** 2) - (
                (1 - action["agent_1"][1]) ** 2
            )
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((3, 32, 32)), "agent_1": np.zeros((3, 32, 32))},
                {"agent_0": np.ones((3, 32, 32)), "agent_1": np.ones((3, 32, 32))},
                {"agent_0": np.zeros((3, 32, 32)), "agent_1": np.ones((3, 32, 32))},
                {"agent_0": np.ones((3, 32, 32)), "agent_1": np.zeros((3, 32, 32))},
            ]
        )
        info = {}
        return self.last_obs, info
class agilerl.utils.probe_envs_ma.MultiPolicyEnv
class MultiPolicyEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {"agent_0": np.array([0]), "agent_1": np.array([0])}
        self.observation_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.array([[0]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[0]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[0]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[0]])},
            {"agent_0": np.array([[0]]), "agent_1": np.array([[1]])},
            {"agent_0": np.array([[1]]), "agent_1": np.array([[0]])},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[1.0, 0.0]])},
        ]
        self.q_values = [
            {"agent_0": 2.0, "agent_1": 2.0},
            {"agent_0": 2.0, "agent_1": 2.0},
            {"agent_0": 1.0, "agent_1": 1.0},
            {"agent_0": 1.0, "agent_1": 1.0},
            {"agent_0": 0.0, "agent_1": 3.0},
            {"agent_0": 0.0, "agent_1": 3.0},
            {"agent_0": 3.0, "agent_1": 0.0},
            {"agent_0": 3.0, "agent_1": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
        ]

    def step(self, action):
        observation = self.last_obs
        reward = {
            "agent_0": 2 * (action["agent_0"] == self.last_obs["agent_0"])
            + (action["agent_1"] == self.last_obs["agent_1"]),
            "agent_1": 2 * (action["agent_1"] != self.last_obs["agent_1"])
            + (action["agent_0"] != self.last_obs["agent_0"]),
        }  # Reward depends on action
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.array([0]), "agent_1": np.array([0])},
                {"agent_0": np.array([1]), "agent_1": np.array([1])},
                {"agent_0": np.array([0]), "agent_1": np.array([1])},
                {"agent_0": np.array([1]), "agent_1": np.array([0])},
            ]
        )
        info = {}
        return self.last_obs, info
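Unlike PolicyEnv, each agent's reward here depends on both agents' actions: agent_0 receives 2 for matching its own observation plus 1 if agent_1 matches agent_1's observation, while agent_1 receives 2 for differing from its own observation plus 1 if agent_0 differs from agent_0's observation. This is why the target Q values above range from 0 to 3. A minimal sketch, assuming the class is importable as documented above:

from agilerl.utils.probe_envs_ma import MultiPolicyEnv

env = MultiPolicyEnv()
obs, _ = env.reset()
# Each agent pursues its own objective: agent_0 copies its obs, agent_1 flips its obs
action = {"agent_0": obs["agent_0"], "agent_1": 1 - obs["agent_1"]}
_, reward, _, _, _ = env.step(action)
# agent_0: 2 (matched its obs) + 0 (agent_1 did not match its own obs)
# agent_1: 2 (differed from its obs) + 0 (agent_0 did not differ from its own obs)
assert reward["agent_0"].item() == 2
assert reward["agent_1"].item() == 2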
class agilerl.utils.probe_envs_ma.MultiPolicyImageEnv
class MultiPolicyImageEnv:
    def __init__(self):
        self.possible_agents = ["agent_0", "agent_1"]
        self.agents = self.possible_agents
        self.max_num_agents = len(self.possible_agents)
        self.num_agents = len(self.agents)

        self.last_obs = {
            "agent_0": np.zeros((3, 32, 32)),
            "agent_1": np.zeros((3, 32, 32)),
        }
        self.observation_space = {
            "agent_0": spaces.Box(0.0, 1.0, (3, 32, 32)),
            "agent_1": spaces.Box(0.0, 1.0, (3, 32, 32)),
        }
        self.action_space = {
            "agent_0": spaces.Discrete(2),
            "agent_1": spaces.Discrete(2),
        }

        self.sample_obs = [
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
            {"agent_0": np.zeros((1, 3, 32, 32)), "agent_1": np.ones((1, 3, 32, 32))},
            {"agent_0": np.ones((1, 3, 32, 32)), "agent_1": np.zeros((1, 3, 32, 32))},
        ]
        self.sample_actions = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[1.0, 0.0]])},
        ]
        self.q_values = [
            {"agent_0": 2.0, "agent_1": 2.0},
            {"agent_0": 2.0, "agent_1": 2.0},
            {"agent_0": 1.0, "agent_1": 1.0},
            {"agent_0": 1.0, "agent_1": 1.0},
            {"agent_0": 0.0, "agent_1": 3.0},
            {"agent_0": 0.0, "agent_1": 3.0},
            {"agent_0": 3.0, "agent_1": 0.0},
            {"agent_0": 3.0, "agent_1": 0.0},
        ]  # Correct Q values to learn, s x a table
        self.policy_values = [
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
            {"agent_0": np.array([[1.0, 0.0]]), "agent_1": np.array([[1.0, 0.0]])},
            {"agent_0": np.array([[0.0, 1.0]]), "agent_1": np.array([[0.0, 1.0]])},
        ]

    def step(self, action):
        observation = self.last_obs
        reward = {
            "agent_0": 2
            * (np.mean(action["agent_0"]) == np.mean(self.last_obs["agent_0"]))
            + (np.mean(action["agent_1"]) == np.mean(self.last_obs["agent_1"])),
            "agent_1": 2
            * (np.mean(action["agent_1"]) != np.mean(self.last_obs["agent_1"]))
            + (np.mean(action["agent_0"]) != np.mean(self.last_obs["agent_0"])),
        }  # Reward depends on action
        terminated = {"agent_0": True, "agent_1": True}
        truncated = {"agent_0": False, "agent_1": False}
        info = {}
        return observation, reward, terminated, truncated, info

    def reset(self):
        self.last_obs = random.choice(
            [
                {"agent_0": np.zeros((3, 32, 32)), "agent_1": np.zeros((3, 32, 32))},
                {"agent_0": np.ones((3, 32, 32)), "agent_1": np.ones((3, 32, 32))},
                {"agent_0": np.zeros((3, 32, 32)), "agent_1": np.ones((3, 32, 32))},
                {"agent_0": np.ones((3, 32, 32)), "agent_1": np.zeros((3, 32, 32))},
            ]
        )
        info = {}
        return self.last_obs, info
agilerl.utils.probe_envs_ma.check_policy_q_learning_with_probe_env(env, algo_class, algo_args, memory, learn_steps=1000, device='cpu')
def check_policy_q_learning_with_probe_env(
    env, algo_class, algo_args, memory, learn_steps=1000, device="cpu"
):
    print(f"Probe environment: {type(env).__name__}")

    agent = algo_class(**algo_args, device=device)

    state, _ = env.reset()
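    # Collect experience with fully random actions (epsilon=1) to fill the replay buffer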
    for _ in range(10000):
        if agent.net_config["arch"] == "cnn":
            state = {agent_id: np.expand_dims(s, 0) for agent_id, s in state.items()}
        cont_actions, discrete_action = agent.getAction(state, epsilon=1)
        action = discrete_action if agent.discrete_actions else cont_actions
        next_state, reward, done, _, _ = env.step(action)
        if agent.net_config["arch"] == "cnn":
            state = {agent_id: np.squeeze(s) for agent_id, s in state.items()}
        memory.save2memory(state, cont_actions, reward, next_state, done)
        state = next_state
        if done[agent.agent_ids[0]]:
            state, _ = env.reset()

    # Learn from experiences
    for _ in trange(learn_steps):
        experiences = memory.sample(agent.batch_size)
        # Learn according to agent's RL algorithm
        agent.learn(experiences)

    with torch.no_grad():
        for agent_id, actor, critic in zip(
            agent.agent_ids, agent.actors, agent.critics
        ):
            for sample_obs, sample_action, q_values, policy_values in zip(
                env.sample_obs, env.sample_actions, env.q_values, env.policy_values
            ):
                state = prepare_ma_states(
                    sample_obs, agent.one_hot, agent.state_dims, device
                )

                if q_values is not None:
                    action = prepare_ma_actions(sample_action, device)
                    if agent.arch == "mlp":
                        input_combined = torch.cat(
                            list(state.values()) + list(action.values()), 1
                        )
                        predicted_q_values = (
                            critic(input_combined).detach().cpu().numpy()[0]
                        )
                    else:
                        stacked_states = torch.stack(list(state.values()), dim=2)
                        stacked_actions = torch.cat(list(action.values()), dim=1)
                        predicted_q_values = (
                            critic(stacked_states, stacked_actions)
                            .detach()
                            .cpu()
                            .numpy()[0]
                        )
                    # print("---")
                    # print(agent_id, "q", q_values[agent_id], predicted_q_values)
                    # assert np.allclose(q_values[agent_id], predicted_q_values, atol=0.1):
                    if not np.allclose(
                        q_values[agent_id], predicted_q_values, atol=0.1
                    ):
                        print(agent_id, "q", q_values[agent_id], predicted_q_values)

                if policy_values is not None:
                    if agent.arch == "mlp":
                        predicted_policy_values = (
                            actor(state[agent_id]).detach().cpu().numpy()[0]
                        )
                    else:
                        predicted_policy_values = (
                            actor(state[agent_id].unsqueeze(2))
                            .detach()
                            .cpu()
                            .numpy()[0]
                        )
                    # print(agent_id, "pol", policy_values[agent_id], predicted_policy_values)
                    # assert np.allclose(policy_values[agent_id], predicted_policy_values, atol=0.1)
                    if not np.allclose(
                        policy_values[agent_id], predicted_policy_values, atol=0.1
                    ):
                        print(
                            agent_id,
                            "pol",
                            policy_values[agent_id],
                            predicted_policy_values,
                        )
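As a usage sketch, the checker can be driven with one of the probe environments above, a multi-agent off-policy algorithm class such as MADDPG, and a multi-agent replay buffer. The constructor keys below (state_dims, action_dims, one_hot, n_agents, agent_ids, max_action, min_action, discrete_actions, net_config, batch_size) and the MADDPG/MultiAgentReplayBuffer import paths are assumptions inferred from the attributes the checker reads; they may differ between AgileRL versions, so follow the Debugging RL guide for the exact arguments.

import torch

from agilerl.algorithms.maddpg import MADDPG
from agilerl.components.multi_agent_replay_buffer import MultiAgentReplayBuffer
from agilerl.utils.probe_envs_ma import (
    PolicyContActionsEnv,
    check_policy_q_learning_with_probe_env,
)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = PolicyContActionsEnv()

# NOTE: assumed constructor keys -- check the MADDPG docs for your AgileRL version
algo_args = {
    "state_dims": [(env.observation_space[agent].n,) for agent in env.agents],
    "action_dims": [env.action_space[agent].shape[0] for agent in env.agents],
    "one_hot": True,
    "n_agents": env.num_agents,
    "agent_ids": env.possible_agents,
    "max_action": [(1.0,), (1.0,)],
    "min_action": [(0.0,), (0.0,)],
    "discrete_actions": False,
    "net_config": {"arch": "mlp", "h_size": [32, 32]},
    "batch_size": 128,
}

field_names = ["state", "action", "reward", "next_state", "done"]
memory = MultiAgentReplayBuffer(
    memory_size=10000,
    field_names=field_names,
    agent_ids=env.possible_agents,
    device=device,
)

check_policy_q_learning_with_probe_env(
    env, MADDPG, algo_args, memory, learn_steps=1000, device=device
)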