diff --git a/__pycache__/__init__.cpython-310.pyc b/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 90ac4db..0000000 Binary files a/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/yes_cmdr_env.cpython-310.pyc b/__pycache__/yes_cmdr_env.cpython-310.pyc deleted file mode 100644 index 2fe8ef9..0000000 Binary files a/__pycache__/yes_cmdr_env.cpython-310.pyc and /dev/null differ diff --git a/__pycache__/yes_cmdr_utils.cpython-310.pyc b/__pycache__/yes_cmdr_utils.cpython-310.pyc deleted file mode 100644 index a713446..0000000 Binary files a/__pycache__/yes_cmdr_utils.cpython-310.pyc and /dev/null differ diff --git a/yes_cmdr_env.py b/yes_cmdr_env.py index 9b99fad..249bbec 100644 --- a/yes_cmdr_env.py +++ b/yes_cmdr_env.py @@ -15,6 +15,7 @@ class YesCmdrEnv(gym.Env): replay_path: str = None, campaign_id: str = "NOID", team_id: int = 0, + faction_id: int = 0, max_team: int = 2, max_faction: int = 2, data_path: str = "../data", @@ -27,6 +28,7 @@ class YesCmdrEnv(gym.Env): self.replay_path = replay_path self.campaign_id = campaign_id self.init_team_id = team_id + self.init_faction_id = faction_id self.max_team = max_team self.max_faction = max_faction self.data_path = data_path @@ -127,7 +129,7 @@ class YesCmdrEnv(gym.Env): if terminated or truncated: # The eval_episode_return is calculated from Player 1's perspective - info["eval_episode_return"] = reward if self.faction_id == self.teams[self.init_team_id].faction_id else -reward + info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward info["done_reason"] = "Terminated" if terminated else "Truncated" if self.replay_path is not None: @@ -222,12 +224,12 @@ class YesCmdrEnv(gym.Env): return self._legal_actions map_data = self.map.nodes - player_data = self.current_agents_dict + team_data = self.current_agents_dict spotted_enemy_ids = self.spotted_enemy_ids _legal_actions = [Action(ActionType.END_OF_TURN)] # Default action - for agent_id, agent in player_data.items(): + for agent_id, agent in team_data.items(): if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried: continue pos = agent.pos @@ -411,7 +413,7 @@ class YesCmdrEnv(gym.Env): _teams = [ Team( team_id=_team_id, - faction_id=agents_dict_lists[_team_id][0]["faction_id"], + faction_id=agents_dict_lists[_team_id][0]["faction_id"] if agents_dict_lists[_team_id] else -1, # 可能为空 agents=_agents[_team_id] ) for _team_id in range(self.max_team) @@ -558,6 +560,8 @@ class YesCmdrEnv(gym.Env): @property def faction_id(self): + if self._team_id >= self.max_team: + raise ValueError(f"Team id {self._team_id} exceeds max id {self.max_team - 1}") return self.teams[self._team_id].faction_id @property @@ -713,6 +717,8 @@ class YesCmdrEnv(gym.Env): target_agent.endurance -= damage agent.weapon.ammo -= 1 agent.commenced_action = True + + terminated = False if target_agent.endurance <= 0: # 删除与地图关联 @@ -721,10 +727,9 @@ class YesCmdrEnv(gym.Env): node.team_id = -1 node.faction_id = -1 - if sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0: - return damage, True + terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0 - return damage, False + return damage, terminated def _interact(self, agent_id, target_id): agent = self.get_agent(agent_id) @@ -1261,19 +1266,19 @@ class YesCmdrEnv(gym.Env): obs = self.observe() truncated = self.episode_steps >= self.max_episode_steps - terminated = self.teams[self.enemy_id].alive_count == 0 - self._cumulative_rewards[self.player_id] += reward + terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0 + self._cumulative_rewards[self.team_id] += reward info = {} info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1]) - info["next_player"] = self.player_id + 1 if not terminated else -1 + info["next_team"] = self.team_id + 1 if not terminated else -1 if self.replay_path is not None: - self._frames[self.player_id].append(self.render(mode="rgb_array")) + self._frames[self.team_id].append(self.render(mode="rgb_array")) if terminated or truncated: # The eval_episode_return is calculated from Player 1's perspective - info["eval_episode_return"] = reward if self.player_id == 0 else -reward + info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward info["done_reason"] = "Terminated" if terminated else "Truncated" if self.replay_path is not None: