From 93558d3caed56fad85d6053e885a35b6376fe6a7 Mon Sep 17 00:00:00 2001 From: hkr04 Date: Thu, 19 Dec 2024 09:56:57 +0800 Subject: [PATCH] update --- .gitignore | 3 + common/agent.py | 2 + common/renderer.py | 24 ++++---- data/test/AgentsInfo.json | 1 + requirements.txt | 5 ++ tests/test_yes_cmdr_env.py | 22 +++++--- yes_cmdr_env.py | 109 ++++++++++++++++++++++--------------- 7 files changed, 100 insertions(+), 66 deletions(-) create mode 100644 .gitignore create mode 100644 requirements.txt diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..cff9d04 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +/video +*/__pycache__* +/__pycache__ \ No newline at end of file diff --git a/common/agent.py b/common/agent.py index 5f4dfcc..6c62714 100644 --- a/common/agent.py +++ b/common/agent.py @@ -455,6 +455,7 @@ class TileNode: self.agent_id = agent_id self.is_city = is_city self.team_id = -1 + self.faction_id = -1 def reset(self): self.chaotic_value = 0 @@ -462,6 +463,7 @@ class TileNode: self.occupied_by = None self.agent_id = None self.team_id = -1 + self.faction_id = -1 def __lt__(self, other): return self.pos < other.pos diff --git a/common/renderer.py b/common/renderer.py index 0d045e2..e3d79b6 100644 --- a/common/renderer.py +++ b/common/renderer.py @@ -25,18 +25,18 @@ team_colors = [ icons_path = "./tianqiong/envs/icons" unit_icons: Dict[str, Dict[int, pygame.Surface]] = { - AgentType.AntiAir: {0: pygame.image.load(f"{icons_path}/AntiAir_b.png"), 1: pygame.image.load(f"{icons_path}/AntiAir_r.png")}, - AgentType.Airport: {0: pygame.image.load(f"{icons_path}/Airport_b.png"), 1: pygame.image.load(f"{icons_path}/Airport_r.png")}, - AgentType.Artillery: {0: pygame.image.load(f"{icons_path}/Artillery_b.png"), 1: pygame.image.load(f"{icons_path}/Artillery_r.png")}, - AgentType.Bomber: {0: pygame.image.load(f"{icons_path}/Bomber_b.png"), 1: pygame.image.load(f"{icons_path}/Bomber_r.png")}, - AgentType.Carrier: {0: pygame.image.load(f"{icons_path}/Carrier_b.png"), 1: pygame.image.load(f"{icons_path}/Carrier_r.png")}, - AgentType.Fighter: {0: pygame.image.load(f"{icons_path}/Fighter_b.png"), 1: pygame.image.load(f"{icons_path}/Fighter_r.png")}, - AgentType.Helicopter: {0: pygame.image.load(f"{icons_path}/Helicopter_b.png"), 1: pygame.image.load(f"{icons_path}/Helicopter_r.png")}, - AgentType.Infantry: {0: pygame.image.load(f"{icons_path}/Infantry_b.png"), 1: pygame.image.load(f"{icons_path}/Infantry_r.png")}, - AgentType.Tank: {0: pygame.image.load(f"{icons_path}/Tank_b.png"), 1: pygame.image.load(f"{icons_path}/Tank_r.png")}, - AgentType.TransportHelicopter: {0: pygame.image.load(f"{icons_path}/TransportHelicopter_b.png"), 1: pygame.image.load(f"{icons_path}/TransportHelicopter_r.png")}, - AgentType.TransportShip: {0: pygame.image.load(f"{icons_path}/TransportShip_b.png"), 1: pygame.image.load(f"{icons_path}/TransportShip_r.png")}, - AgentType.CombatShip: {0: pygame.image.load(f"{icons_path}/CombatShip_b.png"), 1: pygame.image.load(f"{icons_path}/CombatShip_r.png")}, + # AgentType.AntiAir: {0: pygame.image.load(f"{icons_path}/AntiAir_b.png"), 1: pygame.image.load(f"{icons_path}/AntiAir_r.png")}, + # AgentType.Airport: {0: pygame.image.load(f"{icons_path}/Airport_b.png"), 1: pygame.image.load(f"{icons_path}/Airport_r.png")}, + # AgentType.Artillery: {0: pygame.image.load(f"{icons_path}/Artillery_b.png"), 1: pygame.image.load(f"{icons_path}/Artillery_r.png")}, + # AgentType.Bomber: {0: pygame.image.load(f"{icons_path}/Bomber_b.png"), 1: pygame.image.load(f"{icons_path}/Bomber_r.png")}, + # AgentType.Carrier: {0: pygame.image.load(f"{icons_path}/Carrier_b.png"), 1: pygame.image.load(f"{icons_path}/Carrier_r.png")}, + # AgentType.Fighter: {0: pygame.image.load(f"{icons_path}/Fighter_b.png"), 1: pygame.image.load(f"{icons_path}/Fighter_r.png")}, + # AgentType.Helicopter: {0: pygame.image.load(f"{icons_path}/Helicopter_b.png"), 1: pygame.image.load(f"{icons_path}/Helicopter_r.png")}, + # AgentType.Infantry: {0: pygame.image.load(f"{icons_path}/Infantry_b.png"), 1: pygame.image.load(f"{icons_path}/Infantry_r.png")}, + # AgentType.Tank: {0: pygame.image.load(f"{icons_path}/Tank_b.png"), 1: pygame.image.load(f"{icons_path}/Tank_r.png")}, + # AgentType.TransportHelicopter: {0: pygame.image.load(f"{icons_path}/TransportHelicopter_b.png"), 1: pygame.image.load(f"{icons_path}/TransportHelicopter_r.png")}, + # AgentType.TransportShip: {0: pygame.image.load(f"{icons_path}/TransportShip_b.png"), 1: pygame.image.load(f"{icons_path}/TransportShip_r.png")}, + # AgentType.CombatShip: {0: pygame.image.load(f"{icons_path}/CombatShip_b.png"), 1: pygame.image.load(f"{icons_path}/CombatShip_r.png")}, } class Renderer: diff --git a/data/test/AgentsInfo.json b/data/test/AgentsInfo.json index 2b5126b..a9e518f 100644 --- a/data/test/AgentsInfo.json +++ b/data/test/AgentsInfo.json @@ -319,6 +319,7 @@ "type": 16, "team_id": 2, "faction_id": 1, + "move_type": 1, "defense": 5.0, "mobility": 0.0, "info_level": 6, diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..fb3d15d --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +numpy +pygame +gymnasium +imageio +imageio_ffmpeg \ No newline at end of file diff --git a/tests/test_yes_cmdr_env.py b/tests/test_yes_cmdr_env.py index 542b879..89709ea 100644 --- a/tests/test_yes_cmdr_env.py +++ b/tests/test_yes_cmdr_env.py @@ -7,29 +7,32 @@ from yes_cmdr.common.agent import Action, Command, ActionType def test_action_id_transform(env): print("Testing action id transform") env.reset() - for idx in range(env.action_space_size): - assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}" + + for team_id in range(env.max_team): + for idx in range(env.action_space_size): + assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}" - env.step(0) - - for idx in range(env.action_space_size): - assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}" + env.step(0) def test_random_action(env): print("Testing random action") env.reset() for i in range(100): action = env.random_action() - assert env.check_validity(action)[0], f"Action {action} is not valid" + validity, reason = env.check_validity(action) + assert validity, f"Action {action} is not valid: {reason}" env.step(action) + env.save_replay() def test_bot_action(env): print("Testing bot action") env.reset() for i in range(100): action = env.bot_action() - assert env.check_validity(action)[0], f"Action {action} is not valid" + validity, reason = env.check_validity(action) + assert validity, f"Action {action} is not valid: {reason}" env.step(action) + env.save_replay() def test_command(env): env.reset() @@ -79,7 +82,8 @@ if __name__ == "__main__": data_path="./data/test", max_team=5, max_faction=4, - war_fog=False + war_fog=False, + replay_path="./video" ) test_action_id_transform(env) diff --git a/yes_cmdr_env.py b/yes_cmdr_env.py index 3c8d1d4..9b99fad 100644 --- a/yes_cmdr_env.py +++ b/yes_cmdr_env.py @@ -54,15 +54,17 @@ class YesCmdrEnv(gym.Env): assert agent.pos in self.map.nodes, f"Invalid agent position: {agent.pos}" self.map.nodes[agent.pos].agent_id = agent_id self.map.nodes[agent.pos].team_id = team.team_id + self.map.nodes[agent.pos].faction_id = team.faction_id if self.replay_path is not None: - self._frames = [[], []] + self._frames = [[] for _ in range(self.max_team)] self._legal_actions = None self._legal_action_ids = None self._spotted_enemy_ids = None + self._spotted_enemy_agents = None self._team_id = self.init_team_id self._spotted_agents = None self.episode_steps = 0 - self._cumulative_rewards = np.zeros(2) + self._cumulative_rewards = np.zeros(self.max_team) obs = self.observe() info = { @@ -110,6 +112,7 @@ class YesCmdrEnv(gym.Env): self._legal_actions = None self._legal_action_ids = None self._spotted_enemy_ids = None + self._spotted_enemy_agents = None obs = self.observe() truncated = self.episode_steps >= self.max_episode_steps @@ -124,7 +127,7 @@ class YesCmdrEnv(gym.Env): if terminated or truncated: # The eval_episode_return is calculated from Player 1's perspective - info["eval_episode_return"] = reward if self.team_id == 0 else -reward + info["eval_episode_return"] = reward if self.faction_id == self.teams[self.init_team_id].faction_id else -reward info["done_reason"] = "Terminated" if terminated else "Truncated" if self.replay_path is not None: @@ -134,8 +137,7 @@ class YesCmdrEnv(gym.Env): def render(self, mode="human", attack_agent=None, defend_agent=None): if mode == "rgb_array": - player_agents = self.teams[self.team_id].agents.values() - return self._renderer.render(self.current_agents, self.spotted_enemy_agents, attack_agent, defend_agent) + return self._renderer.render(self.union_agents, self.spotted_enemy_agents, attack_agent, defend_agent) elif mode == "human": return self.observe() else: @@ -237,7 +239,7 @@ class YesCmdrEnv(gym.Env): pos: node for pos in adj_pos if pos in map_data and ( - (node := map_data[pos]).team_id != self.enemy_id + (node := map_data[pos]).faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids ) @@ -247,14 +249,14 @@ class YesCmdrEnv(gym.Env): if des == pos: continue node = map_data[des] - if node.team_id != self.team_id and node.agent_id not in spotted_enemy_ids: + if node.faction_id != self.faction_id and node.agent_id not in spotted_enemy_ids: # 一个格子只能有一个单位 _legal_actions.append(Action( agent_id=agent_id, des=des, action_type=ActionType.MOVE )) - elif node.team_id == self.team_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \ + elif node.faction_id == self.faction_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \ and len(target.parked_agents) < target.capacity: # 除非终点有可以停靠的单位 _legal_actions.append(Action( @@ -282,7 +284,7 @@ class YesCmdrEnv(gym.Env): if des not in map_data: continue node = map_data[des] - if node.team_id == self.team_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \ + if node.faction_id == self.faction_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \ and len(target.parked_agents) < target.capacity: _legal_actions.append(Action( agent_id=agent_id, @@ -313,7 +315,7 @@ class YesCmdrEnv(gym.Env): if des not in map_data: continue node = map_data[des] - if node.team_id == self.team_id and self.get_agent(node.agent_id).has_supply: + if node.faction_id == self.faction_id and self.get_agent(node.agent_id).has_supply: for weapon_idx, weapon in enumerate(agent.switchable_weapons): if not agent.weapon or weapon.name != agent.weapon.name: _legal_actions.append(Action( @@ -337,17 +339,17 @@ class YesCmdrEnv(gym.Env): @property def spotted_enemy_ids(self) -> List[str]: if not self.war_fog: - return self.teams[self.enemy_id].alive_ids + return [enemy_id for enemy_id, enemy_agent in self.enemy_agents_dict.items() if enemy_agent.alive] if self._spotted_enemy_ids is not None: return self._spotted_enemy_ids - player_agents = self.current_agents + current_agents = self.current_agents enemy_data = self.enemy_agents_dict map_data = self.map.nodes _spotted_enemy_ids = [] - for agent in player_agents: + for agent in current_agents: info_level = agent.info_level for des in get_adj_pos(agent.pos, agent.info_level): if des not in map_data: @@ -356,7 +358,7 @@ class YesCmdrEnv(gym.Env): dis = get_axial_dis(agent.pos, des) if dis == 1: # Adjacent agents are always spotted dis = -1e9 - if node.team_id == self.enemy_id \ + if node.faction_id != self.faction_id \ and enemy_data[node.agent_id].stealth_level < info_level - dis + 1: _spotted_enemy_ids.append(node.agent_id) @@ -409,7 +411,7 @@ class YesCmdrEnv(gym.Env): _teams = [ Team( team_id=_team_id, - faction_id=_agents[_team_id][0].faction_id, + faction_id=agents_dict_lists[_team_id][0]["faction_id"], agents=_agents[_team_id] ) for _team_id in range(self.max_team) @@ -422,6 +424,7 @@ class YesCmdrEnv(gym.Env): assert agent.pos in _nodes, f"Invalid agent position: {agent.pos}" _nodes[agent.pos].agent_id = agent_id _nodes[agent.pos].team_id = team.team_id + _nodes[agent.pos].faction_id = team.faction_id _map = Map(_nodes) self.map = _map @@ -512,6 +515,14 @@ class YesCmdrEnv(gym.Env): def current_agents_dict(self): return self.teams[self._team_id].agents + @property + def union_agents(self): + return self._union_agents[self.faction_id] + + @property + def union_agents_dict(self): + return self._union_agents_dict[self.faction_id] + @property def enemy_agents(self): return self._enemy_agents[self.faction_id] @@ -543,6 +554,7 @@ class YesCmdrEnv(gym.Env): return self._spotted_enemy_agents _spotted_enemy_agents = [self.enemy_agents_dict[enemy_id] for enemy_id in self.spotted_enemy_ids] self._spotted_enemy_agents = _spotted_enemy_agents + return self._spotted_enemy_agents @property def faction_id(self): @@ -612,7 +624,7 @@ class YesCmdrEnv(gym.Env): for agent in self.current_agents: for pos in get_adj_pos(agent.pos, 1): - if pos in self.map.nodes and self.map.nodes[pos].team_id == self.team_id: + if pos in self.map.nodes and self.map.nodes[pos].faction_id == self.faction_id: adj_agent = self.get_agent(self.map.nodes[pos].agent_id) if not adj_agent.has_supply: continue @@ -642,6 +654,7 @@ class YesCmdrEnv(gym.Env): node = map_data[agent.pos] node.agent_id = None node.team_id = -1 + node.faction_id = -1 mobility = min(agent.mobility, agent.fuel) adj_pos = get_adj_pos(agent.pos, int(mobility)) @@ -649,7 +662,7 @@ class YesCmdrEnv(gym.Env): pos: node for pos in adj_pos if pos in map_data and ( - (node := map_data[pos]).team_id != self.enemy_id + (node := map_data[pos]).faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids ) @@ -660,7 +673,7 @@ class YesCmdrEnv(gym.Env): exception = False for cur, nxt in zip(path[:-1], path[1:]): - if self.war_fog and map_data[nxt].team_id == self.enemy_id: # 前进方向遇到敌军,停下 + if self.war_fog and map_data[nxt].faction_id != self.faction_id: # 前进方向遇到敌军,停下 exception = True break if self.replay_path is not None: @@ -675,9 +688,10 @@ class YesCmdrEnv(gym.Env): if map_data[cur].team_id == -1: map_data[cur].agent_id = agent_id map_data[cur].team_id = self.team_id + map_data[cur].faction_id = self.faction_id elif map_data[cur].agent_id == target_id: carry_agent = self.get_agent(target_id) - assert agent.team_id == carry_agent.team_id + assert agent.faction_id == carry_agent.faction_id assert agent.agent_type in carry_agent.available_types, f"{agent} not in {carry_agent.available_types} (id: {carry_agent.agent_id[:8]}). " assert len(carry_agent.parked_agents) < carry_agent.capacity, f"Agent {carry_agent[:8]} is available but full" carry_agent.parked_agents.append(agent) @@ -688,8 +702,8 @@ class YesCmdrEnv(gym.Env): return 0 def _attack(self, attack_id, target_id): - agent = self.current_agents[attack_id] - target_agent = self.teams[self.enemy_id].agents[target_id] + agent = self.get_agent(attack_id) + target_agent = self.get_agent(target_id) if self.replay_path is not None: self._frames[self.team_id].extend([self.render(mode="rgb_array", attack_agent=agent, defend_agent=target_agent)] * 4) @@ -705,8 +719,9 @@ class YesCmdrEnv(gym.Env): node = self.get_node(target_agent.pos) node.agent_id = None node.team_id = -1 + node.faction_id = -1 - if self.teams[self.enemy_id].alive_count == 0: + if sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0: return damage, True return damage, False @@ -724,6 +739,7 @@ class YesCmdrEnv(gym.Env): agent_to_release.pos = des node = self.get_node(des) node.team_id = agent_to_release.team_id + node.faction_id = agent_to_release.faction_id node.agent_id = agent_to_release.agent_id for module in agent.modules: @@ -751,7 +767,7 @@ class YesCmdrEnv(gym.Env): return 0 def observe(self): - player_agents = self.current_agents + union_agents = self.union_agents spotted_enemy_agents = self.spotted_enemy_agents obs = {} @@ -764,16 +780,16 @@ class YesCmdrEnv(gym.Env): for action_id in legal_actions_ids: obs["action_mask"][action_id] = 1 - for agent in player_agents: + for agent in union_agents: pos = agent.pos - obs["player_info_level"][pos[0], pos[1]] = agent.info_level - obs["player_stealth_level"][pos[0], pos[1]] = agent.stealth_level - obs["player_mobility"][pos[0], pos[1]] = agent.mobility - obs["player_defense"][pos[0], pos[1]] = agent.defense - obs["player_damage"][pos[0], pos[1]] = agent.damage - obs["player_fuel"][pos[0], pos[1]] = agent.fuel - obs["player_ammo"][pos[0], pos[1]] = agent.ammo - obs["player_endurance"][pos[0], pos[1]] = agent.endurance + obs["union_info_level"][pos[0], pos[1]] = agent.info_level + obs["union_stealth_level"][pos[0], pos[1]] = agent.stealth_level + obs["union_mobility"][pos[0], pos[1]] = agent.mobility + obs["union_defense"][pos[0], pos[1]] = agent.defense + obs["union_damage"][pos[0], pos[1]] = agent.damage + obs["union_fuel"][pos[0], pos[1]] = agent.fuel + obs["union_ammo"][pos[0], pos[1]] = agent.ammo + obs["union_endurance"][pos[0], pos[1]] = agent.endurance for agent in spotted_enemy_agents: pos = agent.pos @@ -803,7 +819,7 @@ class YesCmdrEnv(gym.Env): pos: node for pos in adj_pos if pos in map_data and ( - (node := map_data[pos]).team_id != self.enemy_id + (node := map_data[pos]).faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids ) @@ -892,7 +908,7 @@ class YesCmdrEnv(gym.Env): pos for pos in get_adj_pos(target_agent.pos, int(agent.attack_range)) if pos in map_data and ( - (node := map_data[pos]).team_id != self.enemy_id + (node := map_data[pos]).faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids ) @@ -939,13 +955,13 @@ class YesCmdrEnv(gym.Env): 需要补给的单位移动至有补给模块的单位附近 """ agent = self.get_agent(agent_id) - player_agents = self.teams[agent.team_id].agents.values() + current_agents = self.current_agents map_data = self.map.nodes origin_todo_length = len(agent.todo) # 记录原 todo 的长度 possible_des = [] - for _agent in player_agents: + for _agent in current_agents: if _agent.has_supply: for pos in get_adj_pos(_agent.pos, 1): if pos in map_data: @@ -957,7 +973,7 @@ class YesCmdrEnv(gym.Env): des = agent.todo[-1].des if agent.todo else agent.pos for pos in get_adj_pos(des, 1): - if pos in map_data and map_data[pos].team_id == agent.team_id and self.get_agent(map_data[pos].agent_id).has_supply: + if pos in map_data and map_data[pos].faction_id == agent.faction_id and self.get_agent(map_data[pos].agent_id).has_supply: supply_id = map_data[pos].agent_id supply_module = self.get_agent(supply_id).supply @@ -1085,7 +1101,7 @@ class YesCmdrEnv(gym.Env): map_data = self.map.nodes spotted_enemy_ids = self.spotted_enemy_ids - if action.agent_id not in self.teams[self.player_id].agents: + if action.agent_id not in self.current_agents_dict: return False, f"Invalid agent id: {action.agent_id}" agent = self.get_agent(action.agent_id) @@ -1098,7 +1114,7 @@ class YesCmdrEnv(gym.Env): node = map_data[des] - if node.team_id == agent.team_id and (not agent.agent_type in (target := self.get_agent(node.agent_id)).available_types or len(target.parked_agents) >= target.capacity): + if node.faction_id == agent.faction_id and (not agent.agent_type in (target := self.get_agent(node.agent_id)).available_types or len(target.parked_agents) >= target.capacity): return False, f"{node.agent_id[:8]} at {des} is an ally but not available" if node.agent_id in spotted_enemy_ids: @@ -1114,7 +1130,7 @@ class YesCmdrEnv(gym.Env): pos: node for pos in adj_pos if pos in map_data and ( - (node := map_data[pos]).team_id != self.enemy_id + (node := map_data[pos]).faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids ) @@ -1135,8 +1151,8 @@ class YesCmdrEnv(gym.Env): des = action.des - if agent.team_id == target_agent.team_id: - return False, f"Cannot attack teammate. Attack: {action.agent_id}. Defend: {action.target_id}" + if agent.faction_id == target_agent.faction_id: + return False, f"Cannot attack allies. Attack: {action.agent_id}. Defend: {action.target_id}" if agent.weapon is None: return False, "No weapon equipped" if agent.ammo <= 0: @@ -1150,7 +1166,7 @@ class YesCmdrEnv(gym.Env): weapon_idx = action.weapon_idx around_supply = False for pos in get_adj_pos(agent.pos, 1): - if pos in map_data and map_data[pos].agent_id and (target_agent := self.get_agent(map_data[pos].agent_id)).team_id == agent.team_id and target_agent.has_supply: + if pos in map_data and map_data[pos].agent_id and (target_agent := self.get_agent(map_data[pos].agent_id)).faction_id == agent.faction_id and target_agent.has_supply: around_supply = True if not around_supply: return False, "Agent is not around a supply module" @@ -1172,7 +1188,7 @@ class YesCmdrEnv(gym.Env): target_agent = self.get_agent(action.target_id) if get_axial_dis(agent.pos, target_agent.pos) > 1: return False, "Target agent is out of range" - if agent.team_id != target_agent.team_id: + if agent.faction_id != target_agent.faction_id: return False, "Cannot interact with enemy agent" if agent.agent_type not in target_agent.available_types: return False, f"Target agent {target_agent.agent_id[:8]} is not interactable with agent {agent}" @@ -1200,9 +1216,10 @@ class YesCmdrEnv(gym.Env): if node.agent_id == agent_id: node.agent_id = None node.team_id = -1 + node.faction_id = -1 agent.pos = (sync_agent['pos']['q'], sync_agent['pos']['r']) agent.fuel = sync_agent['fuel'] - if agent.team_id == self.enemy_id: + if agent.faction_id != self.faction_id: reward += agent.endurance - sync_agent['endurance'] agent.endurance = sync_agent['endurance'] agent.commenced_action = sync_agent['commenced_action'] @@ -1220,6 +1237,7 @@ class YesCmdrEnv(gym.Env): node = map_data[agent.pos] node.agent_id = agent_id node.team_id = agent.team_id + node.faction_id = agent.faction_id for agent_id in self._spotted_enemy_ids: agent = self.get_agent(agent_id) @@ -1229,6 +1247,7 @@ class YesCmdrEnv(gym.Env): if node.agent_id == agent_id: node.agent_id = None node.team_id = -1 + node.faction_id = -1 # 更新敌人感知 self._spotted_enemy_ids = [enemy['agent_id'] for enemy in spotted_enemies]