更新至多阵营多方协作版本

main
zuchaoli 2024-12-19 16:56:48 +08:00
parent 93558d3cae
commit 1af8749a8c
4 changed files with 17 additions and 12 deletions

Binary file not shown.

View File

@ -15,6 +15,7 @@ class YesCmdrEnv(gym.Env):
replay_path: str = None,
campaign_id: str = "NOID",
team_id: int = 0,
faction_id: int = 0,
max_team: int = 2,
max_faction: int = 2,
data_path: str = "../data",
@ -27,6 +28,7 @@ class YesCmdrEnv(gym.Env):
self.replay_path = replay_path
self.campaign_id = campaign_id
self.init_team_id = team_id
self.init_faction_id = faction_id
self.max_team = max_team
self.max_faction = max_faction
self.data_path = data_path
@ -127,7 +129,7 @@ class YesCmdrEnv(gym.Env):
if terminated or truncated:
# The eval_episode_return is reported from the initial faction's perspective (init_faction_id)
info["eval_episode_return"] = reward if self.faction_id == self.teams[self.init_team_id].faction_id else -reward
info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
info["done_reason"] = "Terminated" if terminated else "Truncated"
if self.replay_path is not None:
@ -222,12 +224,12 @@ class YesCmdrEnv(gym.Env):
return self._legal_actions
map_data = self.map.nodes
player_data = self.current_agents_dict
team_data = self.current_agents_dict
spotted_enemy_ids = self.spotted_enemy_ids
_legal_actions = [Action(ActionType.END_OF_TURN)] # Default action
for agent_id, agent in player_data.items():
for agent_id, agent in team_data.items():
if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried:
continue
pos = agent.pos
@ -411,7 +413,7 @@ class YesCmdrEnv(gym.Env):
_teams = [
Team(
team_id=_team_id,
faction_id=agents_dict_lists[_team_id][0]["faction_id"],
faction_id=agents_dict_lists[_team_id][0]["faction_id"] if agents_dict_lists[_team_id] else -1, # 可能为空
agents=_agents[_team_id]
)
for _team_id in range(self.max_team)
@ -558,6 +560,8 @@ class YesCmdrEnv(gym.Env):
@property
def faction_id(self):
    """Return the faction id of the team selected by ``self._team_id``.

    Raises:
        ValueError: if ``self._team_id`` is negative, or is not below
            ``self.max_team``.
    """
    if self._team_id < 0:
        # NOTE(review): self.teams appears to be a list, so a negative id
        # would index from the end and return another team's faction
        # instead of failing. Reject it explicitly.
        raise ValueError(
            f"Team id {self._team_id} is negative; valid ids are 0..{self.max_team - 1}"
        )
    if self._team_id >= self.max_team:
        raise ValueError(f"Team id {self._team_id} exceeds max id {self.max_team - 1}")
    return self.teams[self._team_id].faction_id
@property
@ -713,6 +717,8 @@ class YesCmdrEnv(gym.Env):
target_agent.endurance -= damage
agent.weapon.ammo -= 1
agent.commenced_action = True
terminated = False
if target_agent.endurance <= 0:
# 删除与地图关联
@ -721,10 +727,9 @@ class YesCmdrEnv(gym.Env):
node.team_id = -1
node.faction_id = -1
if sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0:
return damage, True
terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0
return damage, False
return damage, terminated
def _interact(self, agent_id, target_id):
agent = self.get_agent(agent_id)
@ -1261,19 +1266,19 @@ class YesCmdrEnv(gym.Env):
obs = self.observe()
truncated = self.episode_steps >= self.max_episode_steps
terminated = self.teams[self.enemy_id].alive_count == 0
self._cumulative_rewards[self.player_id] += reward
terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0
self._cumulative_rewards[self.team_id] += reward
info = {}
info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
info["next_player"] = self.player_id + 1 if not terminated else -1
info["next_team"] = self.team_id + 1 if not terminated else -1
if self.replay_path is not None:
self._frames[self.player_id].append(self.render(mode="rgb_array"))
self._frames[self.team_id].append(self.render(mode="rgb_array"))
if terminated or truncated:
# The eval_episode_return is reported from the initial faction's perspective (init_faction_id)
info["eval_episode_return"] = reward if self.player_id == 0 else -reward
info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
info["done_reason"] = "Terminated" if terminated else "Truncated"
if self.replay_path is not None: