更新至多阵营多方协作版本
parent
93558d3cae
commit
1af8749a8c
Binary file not shown.
Binary file not shown.
Binary file not shown.
|
@ -15,6 +15,7 @@ class YesCmdrEnv(gym.Env):
|
||||||
replay_path: str = None,
|
replay_path: str = None,
|
||||||
campaign_id: str = "NOID",
|
campaign_id: str = "NOID",
|
||||||
team_id: int = 0,
|
team_id: int = 0,
|
||||||
|
faction_id: int = 0,
|
||||||
max_team: int = 2,
|
max_team: int = 2,
|
||||||
max_faction: int = 2,
|
max_faction: int = 2,
|
||||||
data_path: str = "../data",
|
data_path: str = "../data",
|
||||||
|
@ -27,6 +28,7 @@ class YesCmdrEnv(gym.Env):
|
||||||
self.replay_path = replay_path
|
self.replay_path = replay_path
|
||||||
self.campaign_id = campaign_id
|
self.campaign_id = campaign_id
|
||||||
self.init_team_id = team_id
|
self.init_team_id = team_id
|
||||||
|
self.init_faction_id = faction_id
|
||||||
self.max_team = max_team
|
self.max_team = max_team
|
||||||
self.max_faction = max_faction
|
self.max_faction = max_faction
|
||||||
self.data_path = data_path
|
self.data_path = data_path
|
||||||
|
@ -127,7 +129,7 @@ class YesCmdrEnv(gym.Env):
|
||||||
|
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
# The eval_episode_return is calculated from Player 1's perspective
|
# The eval_episode_return is calculated from the perspective of the initial faction (self.init_faction_id)
|
||||||
info["eval_episode_return"] = reward if self.faction_id == self.teams[self.init_team_id].faction_id else -reward
|
info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
|
||||||
info["done_reason"] = "Terminated" if terminated else "Truncated"
|
info["done_reason"] = "Terminated" if terminated else "Truncated"
|
||||||
|
|
||||||
if self.replay_path is not None:
|
if self.replay_path is not None:
|
||||||
|
@ -222,12 +224,12 @@ class YesCmdrEnv(gym.Env):
|
||||||
return self._legal_actions
|
return self._legal_actions
|
||||||
|
|
||||||
map_data = self.map.nodes
|
map_data = self.map.nodes
|
||||||
player_data = self.current_agents_dict
|
team_data = self.current_agents_dict
|
||||||
spotted_enemy_ids = self.spotted_enemy_ids
|
spotted_enemy_ids = self.spotted_enemy_ids
|
||||||
|
|
||||||
_legal_actions = [Action(ActionType.END_OF_TURN)] # Default action
|
_legal_actions = [Action(ActionType.END_OF_TURN)] # Default action
|
||||||
|
|
||||||
for agent_id, agent in player_data.items():
|
for agent_id, agent in team_data.items():
|
||||||
if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried:
|
if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried:
|
||||||
continue
|
continue
|
||||||
pos = agent.pos
|
pos = agent.pos
|
||||||
|
@ -411,7 +413,7 @@ class YesCmdrEnv(gym.Env):
|
||||||
_teams = [
|
_teams = [
|
||||||
Team(
|
Team(
|
||||||
team_id=_team_id,
|
team_id=_team_id,
|
||||||
faction_id=agents_dict_lists[_team_id][0]["faction_id"],
|
faction_id=agents_dict_lists[_team_id][0]["faction_id"] if agents_dict_lists[_team_id] else -1, # 可能为空
|
||||||
agents=_agents[_team_id]
|
agents=_agents[_team_id]
|
||||||
)
|
)
|
||||||
for _team_id in range(self.max_team)
|
for _team_id in range(self.max_team)
|
||||||
|
@ -558,6 +560,8 @@ class YesCmdrEnv(gym.Env):
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def faction_id(self):
|
def faction_id(self):
|
||||||
|
if self._team_id >= self.max_team:
|
||||||
|
raise ValueError(f"Team id {self._team_id} exceeds max id {self.max_team - 1}")
|
||||||
return self.teams[self._team_id].faction_id
|
return self.teams[self._team_id].faction_id
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
@ -714,6 +718,8 @@ class YesCmdrEnv(gym.Env):
|
||||||
agent.weapon.ammo -= 1
|
agent.weapon.ammo -= 1
|
||||||
agent.commenced_action = True
|
agent.commenced_action = True
|
||||||
|
|
||||||
|
terminated = False
|
||||||
|
|
||||||
if target_agent.endurance <= 0:
|
if target_agent.endurance <= 0:
|
||||||
# 删除与地图关联
|
# 删除与地图关联
|
||||||
node = self.get_node(target_agent.pos)
|
node = self.get_node(target_agent.pos)
|
||||||
|
@ -721,10 +727,9 @@ class YesCmdrEnv(gym.Env):
|
||||||
node.team_id = -1
|
node.team_id = -1
|
||||||
node.faction_id = -1
|
node.faction_id = -1
|
||||||
|
|
||||||
if sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0:
|
terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0
|
||||||
return damage, True
|
|
||||||
|
|
||||||
return damage, False
|
return damage, terminated
|
||||||
|
|
||||||
def _interact(self, agent_id, target_id):
|
def _interact(self, agent_id, target_id):
|
||||||
agent = self.get_agent(agent_id)
|
agent = self.get_agent(agent_id)
|
||||||
|
@ -1261,19 +1266,19 @@ class YesCmdrEnv(gym.Env):
|
||||||
|
|
||||||
obs = self.observe()
|
obs = self.observe()
|
||||||
truncated = self.episode_steps >= self.max_episode_steps
|
truncated = self.episode_steps >= self.max_episode_steps
|
||||||
terminated = self.teams[self.enemy_id].alive_count == 0
|
terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0
|
||||||
self._cumulative_rewards[self.player_id] += reward
|
self._cumulative_rewards[self.team_id] += reward
|
||||||
|
|
||||||
info = {}
|
info = {}
|
||||||
info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
|
info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
|
||||||
info["next_player"] = self.player_id + 1 if not terminated else -1
|
info["next_team"] = self.team_id + 1 if not terminated else -1
|
||||||
|
|
||||||
if self.replay_path is not None:
|
if self.replay_path is not None:
|
||||||
self._frames[self.player_id].append(self.render(mode="rgb_array"))
|
self._frames[self.team_id].append(self.render(mode="rgb_array"))
|
||||||
|
|
||||||
if terminated or truncated:
|
if terminated or truncated:
|
||||||
# The eval_episode_return is calculated from Player 1's perspective
|
# The eval_episode_return is calculated from the perspective of the initial faction (self.init_faction_id)
|
||||||
info["eval_episode_return"] = reward if self.player_id == 0 else -reward
|
info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
|
||||||
info["done_reason"] = "Terminated" if terminated else "Truncated"
|
info["done_reason"] = "Terminated" if terminated else "Truncated"
|
||||||
|
|
||||||
if self.replay_path is not None:
|
if self.replay_path is not None:
|
||||||
|
|
Loading…
Reference in New Issue