更新至多阵营多方协作版本
parent
93558d3cae
commit
1af8749a8c
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -15,6 +15,7 @@ class YesCmdrEnv(gym.Env):
|
|||
replay_path: str = None,
|
||||
campaign_id: str = "NOID",
|
||||
team_id: int = 0,
|
||||
faction_id: int = 0,
|
||||
max_team: int = 2,
|
||||
max_faction: int = 2,
|
||||
data_path: str = "../data",
|
||||
|
@ -27,6 +28,7 @@ class YesCmdrEnv(gym.Env):
|
|||
self.replay_path = replay_path
|
||||
self.campaign_id = campaign_id
|
||||
self.init_team_id = team_id
|
||||
self.init_faction_id = faction_id
|
||||
self.max_team = max_team
|
||||
self.max_faction = max_faction
|
||||
self.data_path = data_path
|
||||
|
@ -127,7 +129,7 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
if terminated or truncated:
|
||||
# The eval_episode_return is calculated from the perspective of the initial faction (init_faction_id)
|
||||
info["eval_episode_return"] = reward if self.faction_id == self.teams[self.init_team_id].faction_id else -reward
|
||||
info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
|
||||
info["done_reason"] = "Terminated" if terminated else "Truncated"
|
||||
|
||||
if self.replay_path is not None:
|
||||
|
@ -222,12 +224,12 @@ class YesCmdrEnv(gym.Env):
|
|||
return self._legal_actions
|
||||
|
||||
map_data = self.map.nodes
|
||||
player_data = self.current_agents_dict
|
||||
team_data = self.current_agents_dict
|
||||
spotted_enemy_ids = self.spotted_enemy_ids
|
||||
|
||||
_legal_actions = [Action(ActionType.END_OF_TURN)] # Default action
|
||||
|
||||
for agent_id, agent in player_data.items():
|
||||
for agent_id, agent in team_data.items():
|
||||
if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried:
|
||||
continue
|
||||
pos = agent.pos
|
||||
|
@ -411,7 +413,7 @@ class YesCmdrEnv(gym.Env):
|
|||
_teams = [
|
||||
Team(
|
||||
team_id=_team_id,
|
||||
faction_id=agents_dict_lists[_team_id][0]["faction_id"],
|
||||
faction_id=agents_dict_lists[_team_id][0]["faction_id"] if agents_dict_lists[_team_id] else -1, # 可能为空
|
||||
agents=_agents[_team_id]
|
||||
)
|
||||
for _team_id in range(self.max_team)
|
||||
|
@ -558,6 +560,8 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
@property
|
||||
def faction_id(self):
|
||||
if self._team_id >= self.max_team:
|
||||
raise ValueError(f"Team id {self._team_id} exceeds max id {self.max_team - 1}")
|
||||
return self.teams[self._team_id].faction_id
|
||||
|
||||
@property
|
||||
|
@ -713,6 +717,8 @@ class YesCmdrEnv(gym.Env):
|
|||
target_agent.endurance -= damage
|
||||
agent.weapon.ammo -= 1
|
||||
agent.commenced_action = True
|
||||
|
||||
terminated = False
|
||||
|
||||
if target_agent.endurance <= 0:
|
||||
# 删除与地图关联
|
||||
|
@ -721,10 +727,9 @@ class YesCmdrEnv(gym.Env):
|
|||
node.team_id = -1
|
||||
node.faction_id = -1
|
||||
|
||||
if sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0:
|
||||
return damage, True
|
||||
terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0
|
||||
|
||||
return damage, False
|
||||
return damage, terminated
|
||||
|
||||
def _interact(self, agent_id, target_id):
|
||||
agent = self.get_agent(agent_id)
|
||||
|
@ -1261,19 +1266,19 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
obs = self.observe()
|
||||
truncated = self.episode_steps >= self.max_episode_steps
|
||||
terminated = self.teams[self.enemy_id].alive_count == 0
|
||||
self._cumulative_rewards[self.player_id] += reward
|
||||
terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0
|
||||
self._cumulative_rewards[self.team_id] += reward
|
||||
|
||||
info = {}
|
||||
info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
|
||||
info["next_player"] = self.player_id + 1 if not terminated else -1
|
||||
info["next_team"] = self.team_id + 1 if not terminated else -1
|
||||
|
||||
if self.replay_path is not None:
|
||||
self._frames[self.player_id].append(self.render(mode="rgb_array"))
|
||||
self._frames[self.team_id].append(self.render(mode="rgb_array"))
|
||||
|
||||
if terminated or truncated:
|
||||
# The eval_episode_return is calculated from the perspective of the initial faction (init_faction_id)
|
||||
info["eval_episode_return"] = reward if self.player_id == 0 else -reward
|
||||
info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
|
||||
info["done_reason"] = "Terminated" if terminated else "Truncated"
|
||||
|
||||
if self.replay_path is not None:
|
||||
|
|
Loading…
Reference in New Issue