import numpy as np import json import os from datetime import datetime from typing import List, Optional import gymnasium as gym from .common.utils import * from .common.agent import * class YesCmdrEnv(gym.Env): metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4} def __init__(self, use_real_engine: bool = False, replay_path: str = None, campaign_id: str = "NOID", team_id: int = 0, faction_id: int = 0, max_team: int = 2, max_faction: int = 2, data_path: str = "../data", max_episode_steps: int = 1000, bot_version: str = "v0", war_fog: bool = True) -> None: self._init_flag = False # set other properties... self.use_real_engine = use_real_engine self.replay_path = replay_path self.campaign_id = campaign_id self.init_team_id = team_id self.init_faction_id = faction_id self.max_team = max_team self.max_faction = max_faction self.data_path = data_path self.max_episode_steps = max_episode_steps self.bot_version = bot_version self.war_fog = war_fog def reset(self): # reset the environment... 
# 实际的环境初始化是在第一次调用 reset 方法时进行的 if not self._init_flag: self._init_flag = True self._make_env() if self.replay_path is not None: from .common.renderer import Renderer self._renderer = Renderer(self.map.nodes.values()) else: for node in self.map.nodes.values(): node.reset() for team in self.teams: for agent_id, agent in team.agents.items(): agent.reset() if agent.pos == (-1, -1): # Illegal position for padding continue assert agent.pos in self.map.nodes, f"Invalid agent position: {agent.pos}" self.map.nodes[agent.pos].agent_id = agent_id self.map.nodes[agent.pos].team_id = team.team_id self.map.nodes[agent.pos].faction_id = team.faction_id if self.replay_path is not None: self._frames = [[] for _ in range(self.max_team)] self._legal_actions = None self._legal_action_ids = None self._spotted_enemy_ids = None self._spotted_enemy_agents = None self._team_id = self.init_team_id self._spotted_agents = None self.episode_steps = 0 self._cumulative_rewards = np.zeros(self.max_team) obs = self.observe() info = { "next_team": self.team_id + 1 } return obs, info def step(self, action: Union[int, Action]): if isinstance(action, int): if action not in self.legal_action_ids: print(f"Invalid action will be ignored: {action}") info = {} info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1]) info["next_team"] = self.team_id + 1 info["exception"] = f"Invalid action {action}" return self.observe(), 0, False, False, info action = self.id_to_action(action) elif isinstance(action, Action): valid, reason = self.check_validity(action) if not valid: if action.action_type == ActionType.MOVE and reason.find("occupied") > 0: print("Trying to move to an adjacent position that is empty.") dis = float('inf') for _action in self.legal_actions: if _action.agent_id == action.agent_id and _action.action_type == ActionType.MOVE and get_axial_dis(_action.des, action.des) < dis: dis = get_axial_dis(_action.des, action.des) action = _action else: print(f"Invalid action will be 
ignored: {action}") print(f"Reason: {reason}") info = {} info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1]) info["next_team"] = self.team_id + 1 info["exception"] = f"Invalid action: {action}. Reason: {reason}" return self.observe(), 0, False, False, info assert(isinstance(action, Action)) reward, terminated = self._player_step(action) """ NOTE: Clear the states before observation """ self._legal_actions = None self._legal_action_ids = None self._spotted_enemy_ids = None self._spotted_enemy_agents = None obs = self.observe() truncated = self.episode_steps >= self.max_episode_steps self._cumulative_rewards[self.team_id] += reward info = {} info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1]) info["next_team"] = self.team_id + 1 if not terminated else -1 if self.replay_path is not None: self._frames[self.team_id].append(self.render(mode="rgb_array")) if terminated or truncated: # The eval_episode_return is calculated from Player 1's perspective info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward info["done_reason"] = "Terminated" if terminated else "Truncated" if self.replay_path is not None: self.save_replay() return obs, reward, terminated, truncated, info def render(self, mode="human", attack_agent=None, defend_agent=None): if mode == "rgb_array": return self._renderer.render(self.union_agents, self.spotted_enemy_agents, attack_agent, defend_agent) elif mode == "human": return self.observe() else: raise ValueError("Invalid render mode: {}".format(mode)) def action_to_id(self, action: Action) -> int: if action.action_type == ActionType.END_OF_TURN: return 0 action_space = self.action_space if action.agent_id not in action_space.keys(): raise ValueError(f"Invalid agent_id: {action.agent_id[:8]}") agent_action_space = action_space[action.agent_id] if action.action_type not in agent_action_space.keys(): agent = self.get_agent(action.agent_id) 
print(agent.available_types, agent.parked_agents) print(action) raise ValueError(f"Invalid action_type: {action.action_type.name} for agent {action.agent_id[:8]}. Available action_types: {agent_action_space.keys()}") action_id = 1 for agent_id, agent_action_space in action_space.items(): for action_type, action_space in agent_action_space.items(): if agent_id != action.agent_id or action_type != action.action_type: action_id += action_space.n else: if action.action_type == ActionType.SWITCH_WEAPON: # 以下标编码,不以位置编码 assert action.weapon_idx < action_space.n, f"Invalid weapon_idx: {action.weapon_idx}, action_space.n: {action_space.n}" action_id += action.weapon_idx elif action.action_type == ActionType.RELEASE: assert action.target_id < action_space.n, f"Invalid target_idx: {action.target_id}, action_space.n: {action_space.n}" action_id += action.target_id # 此处为下标 else: pos, des = self.get_agent(action.agent_id).pos, action.des encode_id = encode_axial((des[0] - pos[0], des[1] - pos[1])) action_id += encode_id - 1 if action.action_type == ActionType.MOVE: cur_range = self.get_agent(action.agent_id).mobility elif action.action_type == ActionType.ATTACK: cur_range = self.get_agent(action.agent_id).attack_range elif action.action_type == ActionType.INTERACT: cur_range = 1 assert encode_id - 1 < action_space.n, f"Action {action.action_type.name} Distance: {get_axial_dis(pos, des)}, cur_range: {cur_range} encode_id: {encode_id} action_space.n: {action_space.n}\n{action}" return int(action_id) assert False, "Should not reach here" def id_to_action(self, action_id: int) -> Action: if action_id == 0: return Action(ActionType.END_OF_TURN) action_id -= 1 action_space = self.action_space for agent_id, agent_action_space in action_space.items(): for action_type, action_space in agent_action_space.items(): if action_id < action_space.n: if action_type == ActionType.SWITCH_WEAPON: return Action(agent_id=agent_id, action_type=action_type, weapon_idx=action_id, 
weapon_name=self.get_agent(agent_id).weapon.name) elif action_type == ActionType.RELEASE: return Action(agent_id=agent_id, target_id=action_id, action_type=action_type) else: pos = self.get_agent(agent_id).pos delta = decode_axial(action_id + 1) # 编码 0 是原位置 des = (pos[0] + delta[0], pos[1] + delta[1]) return Action(agent_id=agent_id, des=des, action_type=action_type) else: action_id -= action_space.n raise ValueError("Invalid action_id: {}".format(action_id)) def get_agent(self, agent_id: str) -> Agent: for team in self.teams: if agent_id in team.agents.keys(): agent = team.agents[agent_id] # if not agent.alive: # raise ValueError(f"Agent {agent_id} is not alive") return agent raise ValueError(f"Invalid agent id: {agent_id}") def get_agents(self, agent_ids: List[str]) -> List[Agent]: agents = [] for agent_id in agent_ids: agent = self.get_agent(agent_id) # if not agent.alive: # raise ValueError(f"Agent {agent_id} is not alive") agents.append(agent) return agents def get_agents_dict(self, agent_ids: List[str]) -> Dict[str, dict]: agents_dict = {} for agent_id in agent_ids: agent = self.get_agent(agent_id) agents_dict[agent_id] = agent.to_dict() return agents_dict def get_node(self, pos: Tuple[int, int]) -> TileNode: if pos not in self.map.nodes: raise ValueError(f"Invalid position: {pos}") return self.map.nodes[pos] @property def legal_actions(self) -> List[Action]: if self._legal_actions is not None: return self._legal_actions map_data = self.map.nodes team_data = self.current_agents_dict spotted_enemy_ids = self.spotted_enemy_ids _legal_actions = [Action(ActionType.END_OF_TURN)] # Default action for agent_id, agent in team_data.items(): if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried: continue pos = agent.pos # 移动 mobility = min(agent.fuel, agent.mobility) if mobility > 0: adj_pos = get_adj_pos(pos, int(mobility)) aval_data = { pos: node for pos in adj_pos if pos in map_data and ( (node := map_data[pos]).faction_id == 
self.faction_id or node.agent_id not in spotted_enemy_ids ) } aval_nodes = astar_search(aval_data, agent.move_type, start=pos, limit=mobility) for des in aval_nodes: if des == pos: continue node = map_data[des] if node.faction_id != self.faction_id and node.agent_id not in spotted_enemy_ids: # 一个格子只能有一个单位 _legal_actions.append(Action( agent_id=agent_id, des=des, action_type=ActionType.MOVE )) elif node.faction_id == self.faction_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \ and len(target.parked_agents) < target.capacity: # 除非终点有可以停靠的单位 _legal_actions.append(Action( agent_id=agent_id, target_id=node.agent_id, des=des, action_type=ActionType.MOVE )) # 进攻 if agent.attack_range > 0 and agent.ammo > 0: for des in get_adj_pos(pos, agent.attack_range): if des not in map_data: continue node = map_data[des] if node.agent_id in spotted_enemy_ids and self.get_agent(node.agent_id).move_type in agent.strike_types: # 已发现敌军 _legal_actions.append(Action( agent_id=agent_id, target_id=node.agent_id, action_type=ActionType.ATTACK, des=des )) # 交互 for des in get_adj_pos(pos, 1): if des not in map_data: continue node = map_data[des] if node.faction_id == self.faction_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \ and len(target.parked_agents) < target.capacity: _legal_actions.append(Action( agent_id=agent_id, target_id=node.agent_id, des=des, action_type=ActionType.INTERACT )) # 释放 if agent.parked_agents: for des in get_adj_pos(pos, 1): if des not in map_data: continue node = map_data[des] if node.team_id != -1: continue for idx, agent_to_release in enumerate(agent.parked_agents): if get_cost(agent_to_release.move_type, node.terrain_type) != 0: _legal_actions.append(Action( agent_id=agent_id, target_id=idx, des=des, action_type=ActionType.RELEASE )) break # 切换武器 if agent.switchable_weapons: for des in get_adj_pos(pos, 1): if des not in map_data: continue node = map_data[des] if node.faction_id == 
self.faction_id and self.get_agent(node.agent_id).has_supply: for weapon_idx, weapon in enumerate(agent.switchable_weapons): if not agent.weapon or weapon.name != agent.weapon.name: _legal_actions.append(Action( agent_id=agent_id, target_id=weapon.name, weapon_idx=weapon_idx, action_type=ActionType.SWITCH_WEAPON )) self._legal_actions = _legal_actions return self._legal_actions @property def legal_action_ids(self) -> List[int]: if self._legal_action_ids is not None: return self._legal_action_ids self._legal_action_ids = [self.action_to_id(action) for action in self.legal_actions] return self._legal_action_ids @property def spotted_enemy_ids(self) -> List[str]: if not self.war_fog: return [enemy_id for enemy_id, enemy_agent in self.enemy_agents_dict.items() if enemy_agent.alive] if self._spotted_enemy_ids is not None: return self._spotted_enemy_ids current_agents = self.current_agents enemy_data = self.enemy_agents_dict map_data = self.map.nodes _spotted_enemy_ids = [] for agent in current_agents: info_level = agent.info_level for des in get_adj_pos(agent.pos, agent.info_level): if des not in map_data: continue node = map_data[des] dis = get_axial_dis(agent.pos, des) if dis == 1: # Adjacent agents are always spotted dis = -1e9 if node.faction_id != self.faction_id \ and enemy_data[node.agent_id].stealth_level < info_level - dis + 1: _spotted_enemy_ids.append(node.agent_id) self._spotted_enemy_ids = _spotted_enemy_ids return self._spotted_enemy_ids def save_replay(self): if not os.path.exists(self.replay_path): os.makedirs(self.replay_path) timestamp = datetime.now().strftime("%Y%m%d%H%M%S") for team_id in range(self.max_team): path = os.path.join( self.replay_path, f"tianqiong_{timestamp}_Team{team_id + 1}.mp4" ) self.display_frames_as_mp4(self._frames[team_id], path) print(f'replay {path} saved!') @staticmethod def display_frames_as_mp4(frames: list, path: str, fps=4) -> None: assert path.endswith('.mp4'), f'path must end with .mp4, but got {path}' import imageio 
imageio.mimwrite(path, frames, fps=fps) def _make_env(self): # Configuration for file paths map_path = f"{self.data_path}/MapInfo.json" agents_path = f"{self.data_path}/AgentsInfo.json" # Load map and agents information with open(map_path, 'r') as f: origin_map = json.load(f) with open(agents_path, 'r') as f: origin_agents = json.load(f) _nodes = { (node := dict_to_node(node_dict)).pos: node for node_dict in origin_map } agents_dict_lists = [ [agent_dict for agent_dict in origin_agents if agent_dict["team_id"] == team_id] for team_id in range(self.max_team) ] _agents = [ { (agent := dict_to_agent(agent_dict)).agent_id: agent for agent_dict in agents_dict_lists[team_id] } for team_id in range(self.max_team) ] _teams = [ Team( team_id=_team_id, faction_id=agents_dict_lists[_team_id][0]["faction_id"] if agents_dict_lists[_team_id] else -1, # 可能为空 agents=_agents[_team_id] ) for _team_id in range(self.max_team) ] for team in _teams: for agent_id, agent in team.agents.items(): if agent.pos == (-1, -1): # Illegal position for padding continue assert agent.pos in _nodes, f"Invalid agent position: {agent.pos}" _nodes[agent.pos].agent_id = agent_id _nodes[agent.pos].team_id = team.team_id _nodes[agent.pos].faction_id = team.faction_id _map = Map(_nodes) self.map = _map self.teams = _teams self.obs_shape = _map.width, _map.height # Get action space self._action_spaces = [ gym.spaces.Dict( { f"team_{team.team_id + 1}": gym.spaces.Dict( { ActionType.END_OF_TURN: gym.spaces.Discrete(1) # Default action } ), **{ agent_id: agent.action_space for agent_id, agent in team.agents.items() } } ) for team in self.teams ] self._action_space_sizes = [ sum(action_space.n for agent_action_space in self._action_spaces[team_id].values() for action_space in agent_action_space.values()) for team_id in range(self.max_team) ] self._team_id = 0 for action_space_size in self._action_space_sizes: for action_id in range(action_space_size): assert self.action_to_id(self.id_to_action(action_id)) == 
action_id self._team_id += 1 # Get observation space self._observation_spaces = [ gym.spaces.Dict( { "action_mask": gym.spaces.Box(0, 1, (self._action_space_sizes[team_id], ), dtype=np.int8), # Different size for each team "union_endurance": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32), "union_info_level": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16), "union_stealth_level": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16), "union_mobility": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32), "union_defense": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16), "union_damage": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32), "union_fuel": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32), "union_ammo": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16), "enemy_endurance": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32), "enemy_info_level": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16), "enemy_stealth_level": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16), "enemy_mobility": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32), "enemy_defense": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16), "enemy_damage": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32), "enemy_fuel": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32), "enemy_ammo": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16) } ) for team_id in range(self.max_team) ] self._reward_space = gym.spaces.Box(low=-1000, high=1000, shape=(1, ), dtype=np.float32) # 预处理己方联盟和所有敌方的单位信息 self._union_agents = [[] for _ in range(self.max_faction)] self._union_agents_dict = [{} for _ in range(self.max_faction)] self._enemy_agents = [[] for _ in range(self.max_faction)] self._enemy_agents_dict = [{} for _ in range(self.max_faction)] for team_id in range(self.max_team): faction_id = self.teams[team_id].faction_id agents_list = list(self.teams[team_id].agents.values()) agents_dict = 
self.teams[team_id].agents self._union_agents[faction_id].extend(agents_list) for other_faction_id in range(self.max_faction): if other_faction_id != faction_id: self._enemy_agents[other_faction_id].extend(agents_list) self._enemy_agents_dict[other_faction_id] |= agents_dict @property def team_id(self): return self._team_id @property def current_agents(self): return self.teams[self._team_id].agents.values() @property def current_agents_dict(self): return self.teams[self._team_id].agents @property def union_agents(self): return self._union_agents[self.faction_id] @property def union_agents_dict(self): return self._union_agents_dict[self.faction_id] @property def enemy_agents(self): return self._enemy_agents[self.faction_id] # if self._enemy_agents is not None: # return self._enemy_agents # _enemy_agents = [] # for team_id in range(self.max_team): # if self.teams[team_id].faction_id != self.faction_id: # _enemy_agents.extend(list(self.teams[team_id].agents.values())) # self._enemy_agents = _enemy_agents # return self._enemy_agents @property def enemy_agents_dict(self): return self._enemy_agents_dict[self.faction_id] # if self._enemy_agents_dict is not None: # return self._enemy_agents_dict # _enemy_agents_dict = {} # for team_id in range(self.max_team): # if self.teams[team_id].faction_id != self.faction_id: # for agent_id, agent in self.teams[team_id].agents.items(): # _enemy_agents_dict[agent_id] = agent # self._enemy_agents_dict = _enemy_agents_dicct # return self._enemy_agents_dict @property def spotted_enemy_agents(self): if self._spotted_enemy_agents is not None: return self._spotted_enemy_agents _spotted_enemy_agents = [self.enemy_agents_dict[enemy_id] for enemy_id in self.spotted_enemy_ids] self._spotted_enemy_agents = _spotted_enemy_agents return self._spotted_enemy_agents @property def faction_id(self): if self._team_id >= self.max_team: raise ValueError(f"Team id {self._team_id} exceeds max id {self.max_team - 1}") return 
self.teams[self._team_id].faction_id @property def observation_space(self): return self._observation_spaces[self.team_id] @property def action_space(self): return self._action_spaces[self.team_id] @property def action_space_size(self): return self._action_space_sizes[self.team_id] @property def reward_space(self): return self._reward_space def random_action(self) -> Action: return np.random.choice(self.legal_actions) def next_turn(self): self._team_id = (self._team_id + 1) % self.max_team def bot_action(self) -> Action: if self.bot_version == "v0": attack_actions = [action for action in self.legal_actions if action.action_type == ActionType.ATTACK] if attack_actions: return np.random.choice(attack_actions) else: return self.random_action() else: raise NotImplementedError(f"Invalid bot version: {self.bot_version}") def _player_step(self, action: Action): if action.action_type == ActionType.END_OF_TURN: """ NOTE: here exchange the player """ # 清除动作标记 for agent in self.current_agents: agent.commenced_action = False # 切换玩家 self.episode_steps += 1 if not self.use_real_engine: self.next_turn() reward, terminated = 0, False # 这里可以加上惩罚项 elif action.action_type == ActionType.MOVE: reward = self._move(action.agent_id, action.des, action.target_id) terminated = False elif action.action_type == ActionType.ATTACK: reward, terminated = self._attack(action.agent_id, action.target_id) elif action.action_type == ActionType.INTERACT: reward = self._interact(action.agent_id, action.target_id) terminated = False elif action.action_type == ActionType.RELEASE: reward = self._release(action.agent_id, action.target_id, action.des) terminated = False elif action.action_type == ActionType.SWITCH_WEAPON: reward = self._switch_weapon(action.agent_id, action.weapon_idx) terminated = False else: raise ValueError(f"Invalid action type: {action.action_type}") for agent in self.current_agents: for pos in get_adj_pos(agent.pos, 1): if pos in self.map.nodes and self.map.nodes[pos].faction_id == 
self.faction_id: adj_agent = self.get_agent(self.map.nodes[pos].agent_id) if not adj_agent.has_supply: continue module = adj_agent.supply if module.add_endurance > 0: agent.endurance += min(module.add_endurance, agent.max_endurance - agent.endurance) else: agent.endurance = agent.max_endurance if agent.weapon is not None: if module.add_ammo > 0: agent.weapon.ammo += min(module.add_ammo, agent.weapon.max_ammo - agent.weapon.ammo) else: agent.weapon.ammo = agent.weapon.max_ammo if module.add_fuel > 0: agent.fuel += min(module.add_fuel, agent.max_fuel - agent.fuel) else: agent.fuel = agent.max_fuel return reward, terminated def _move(self, agent_id, des, target_id): map_data = self.map.nodes agent = self.current_agents_dict[agent_id] spotted_enemy_ids = self.spotted_enemy_ids origin_pos = agent.pos # 清除原位置关联 node = map_data[agent.pos] node.agent_id = None node.team_id = -1 node.faction_id = -1 mobility = min(agent.mobility, agent.fuel) adj_pos = get_adj_pos(agent.pos, int(mobility)) aval_data = { pos: node for pos in adj_pos if pos in map_data and ( (node := map_data[pos]).faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids ) } path = astar_search(aval_data, move_type=agent.move_type, start=agent.pos, goal=des, limit=mobility) path = [agent.pos] + path exception = False for cur, nxt in zip(path[:-1], path[1:]): if self.war_fog and map_data[nxt].faction_id != self.faction_id: # 前进方向遇到敌军,停下 exception = True break if self.replay_path is not None: self._frames[self.team_id].append(self.render(mode="rgb_array")) agent.fuel -= get_cost(agent.move_type, map_data[nxt].terrain_type) agent.pos = nxt # 更新位置以及与地图的关联 cur = agent.pos agent.commenced_action = True if map_data[cur].team_id == -1: map_data[cur].agent_id = agent_id map_data[cur].team_id = self.team_id map_data[cur].faction_id = self.faction_id elif map_data[cur].agent_id == target_id: carry_agent = self.get_agent(target_id) assert agent.faction_id == carry_agent.faction_id assert agent.agent_type in 
carry_agent.available_types, f"{agent} not in {carry_agent.available_types} (id: {carry_agent.agent_id[:8]}). " assert len(carry_agent.parked_agents) < carry_agent.capacity, f"Agent {carry_agent[:8]} is available but full" carry_agent.parked_agents.append(agent) agent.is_carried = True else: assert exception return 0 def _attack(self, attack_id, target_id): agent = self.get_agent(attack_id) target_agent = self.get_agent(target_id) if self.replay_path is not None: self._frames[self.team_id].extend([self.render(mode="rgb_array", attack_agent=agent, defend_agent=target_agent)] * 4) damage = max(agent.damage - target_agent.defense, 0) damage = min(damage, target_agent.endurance) target_agent.endurance -= damage agent.weapon.ammo -= 1 agent.commenced_action = True terminated = False if target_agent.endurance <= 0: # 删除与地图关联 node = self.get_node(target_agent.pos) node.agent_id = None node.team_id = -1 node.faction_id = -1 terminated = sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0 return damage, terminated def _interact(self, agent_id, target_id): agent = self.get_agent(agent_id) target = self.get_agent(target_id) target.parked_agents.append(agent) agent.is_carried = True return 0 # 后续可以考虑修改奖励为补给的线性组合 def _release(self, agent_id: str, target_idx: int, des: Tuple[int, int]): agent = self.get_agent(agent_id) agent_to_release = agent.parked_agents[target_idx] agent_to_release.pos = des node = self.get_node(des) node.team_id = agent_to_release.team_id node.faction_id = agent_to_release.faction_id node.agent_id = agent_to_release.agent_id for module in agent.modules: if module.add_endurance > 0: agent_to_release.endurance += min(module.add_endurance, agent_to_release.max_endurance - agent_to_release.endurance) else: agent_to_release.endurance = agent_to_release.max_endurance if agent_to_release.weapon is not None: if module.add_ammo > 0: agent_to_release.weapon.ammo += min(module.add_ammo, 
agent_to_release.weapon.max_ammo - agent_to_release.weapon.ammo) else: agent_to_release.weapon.ammo = agent_to_release.weapon.max_ammo if module.add_fuel > 0: agent_to_release.fuel += min(module.add_fuel, agent_to_release.max_fuel - agent_to_release.fuel) else: agent_to_release.fuel = agent_to_release.max_fuel agent_to_release.is_carried = False return 0 # 后续可以考虑修改奖励为补给的线性组合 def _switch_weapon(self, agent_id: str, weapon_idx: int): agent = self.get_agent(agent_id) agent.weapon.reset() agent.weapon = agent.switchable_weapons[weapon_idx] return 0 def observe(self): union_agents = self.union_agents spotted_enemy_agents = self.spotted_enemy_agents obs = {} for desc, space in self.observation_space.items(): obs[desc] = np.zeros(space.shape, dtype=space.dtype) legal_actions_ids = self.legal_action_ids for action_id in legal_actions_ids: obs["action_mask"][action_id] = 1 for agent in union_agents: pos = agent.pos obs["union_info_level"][pos[0], pos[1]] = agent.info_level obs["union_stealth_level"][pos[0], pos[1]] = agent.stealth_level obs["union_mobility"][pos[0], pos[1]] = agent.mobility obs["union_defense"][pos[0], pos[1]] = agent.defense obs["union_damage"][pos[0], pos[1]] = agent.damage obs["union_fuel"][pos[0], pos[1]] = agent.fuel obs["union_ammo"][pos[0], pos[1]] = agent.ammo obs["union_endurance"][pos[0], pos[1]] = agent.endurance for agent in spotted_enemy_agents: pos = agent.pos obs["enemy_info_level"][pos[0], pos[1]] = agent.info_level obs["enemy_stealth_level"][pos[0], pos[1]] = agent.stealth_level obs["enemy_mobility"][pos[0], pos[1]] = agent.mobility obs["enemy_defense"][pos[0], pos[1]] = agent.defense obs["enemy_damage"][pos[0], pos[1]] = agent.damage obs["enemy_fuel"][pos[0], pos[1]] = agent.fuel obs["enemy_ammo"][pos[0], pos[1]] = agent.ammo obs["enemy_endurance"][pos[0], pos[1]] = agent.endurance return obs def command_move(self, agent_id, des: Union[Tuple[int, int], List[Tuple[int, int]]], start_time=0, subtask=False) -> Command: agent = 
self.get_agent(agent_id) state = agent.todo[-1].state if agent.todo else agent.state if state.pos == des or state.pos in des: return map_data = self.map.nodes origin_todo_length = len(agent.todo) # 记录原 todo 的长度 spotted_enemy_ids = self.spotted_enemy_ids adj_pos = get_adj_pos(state.pos, int(state.fuel)) adj_pos.append(state.pos) aval_data = { pos: node for pos in adj_pos if pos in map_data and ( (node := map_data[pos]).faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids ) } path = get_path(agent=agent, map_data=aval_data, start=state.pos, des=des, limit=state.fuel) # 可能没有可行路径 for pos in path: state.fuel -= get_cost(agent.move_type, map_data[pos].terrain_type) state.pos = pos agent.todo.append(Action(agent_id=agent_id, action_type=ActionType.MOVE, des=pos, start_time=start_time, state=state)) command = Command(agent_id=agent_id, action_type=ActionType.MOVE, des=des, start_time=start_time, end_time=start_time+len(agent.todo)-origin_todo_length-1, state=state) if not subtask: end_time = command.end_time for i in range(origin_todo_length, len(agent.todo)): agent.todo[i].start_time = start_time + i - origin_todo_length agent.todo[i].end_time = end_time agent.todo[-1].end_type = ActionType.MOVE agent.cmd_todo.append(command) return command def command_attack(self, attack_id, target_id, attack_count=1, start_time=0, subtask=False) -> Command: agent = self.get_agent(attack_id) target_agent = self.get_agent(target_id) origin_todo_length = len(agent.todo) # 记录原 todo 的长度 state = agent.todo[-1].state if agent.todo else agent.state if target_agent.move_type not in state.weapon.strike_types: flag = False no_enough_ammo = False for weapon_idx, weapon in enumerate(agent.switchable_weapons): if weapon.strike_types and target_agent.move_type in weapon.strike_types: if weapon.max_ammo < attack_count: no_enough_ammo = True try: self.command_switch(attack_id, weapon_idx, start_time=start_time, subtask=True) flag = True break except Exception as e: agent.todo = 
agent.todo[:origin_todo_length] raise ValueError("Failed switching to the proper weapon") if not flag: if no_enough_ammo: agent.todo = agent.todo[:origin_todo_length] raise ValueError("Agent has the striking weapon but no one with enough ammo") else: agent.todo = agent.todo[:origin_todo_length] raise ValueError("No proper weapon found") state = agent.todo[-1].state if agent.todo else agent.state if state.weapon.ammo < attack_count: if state.weapon.max_ammo >= attack_count: # 不需要换武器,尝试补给 try: self.command_supply(attack_id, start_time=start_time, subtask=True) except Exception as e: agent.todo = agent.todo[:origin_todo_length] raise ValueError("No enough ammo and failed to get supplies") else: # 需要换武器 flag = False for weapon_idx, weapon in enumerate(agent.switchable_weapons): if weapon.strike_types and target_agent.move_type in weapon.strike_types and weapon.max_ammo >= attack_count: try: self.command_switch(attack_id, weapon_idx, start_time=start_time, subtask=True) flag = True break except Exception as e: agent.todo = agent.todo[:origin_todo_length] raise ValueError("Failed switching to the proper weapon") if not flag: agent.todo = agent.todo[:origin_todo_length] raise ValueError("No proper weapon has enough ammo") map_data = self.map.nodes spotted_enemy_ids = self.spotted_enemy_ids possible_des = [ pos for pos in get_adj_pos(target_agent.pos, int(agent.attack_range)) if pos in map_data and ( (node := map_data[pos]).faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids ) ] try: self.command_move(attack_id, possible_des, start_time=start_time, subtask=True) except Exception as e: agent.todo = agent.todo[:origin_todo_length] raise ValueError("Failed to move to the target") state = agent.todo[-1].state if agent.todo else agent.state for _ in range(attack_count): state.weapon.ammo -= 1 agent.todo.append(Action(agent_id=attack_id, action_type=ActionType.ATTACK, target_id=target_id, des=target_agent.pos, start_time=start_time, state=state)) command = 
Command(agent_id=attack_id, action_type=ActionType.ATTACK, target_id=target_id, des=target_agent.pos, attack_count=attack_count, start_time=start_time, end_time=start_time+len(agent.todo)-origin_todo_length-1, state=state) if not subtask: end_time = command.end_time for i in range(origin_todo_length, len(agent.todo)): agent.todo[i].start_time = start_time + i - origin_todo_length agent.todo[i].end_time = end_time agent.todo[-1].end_type = ActionType.ATTACK agent.cmd_todo.append(command) return command def command_supply(self, agent_id, start_time=0, subtask=False) -> Command: """ 需要补给的单位移动至有补给模块的单位附近 """ agent = self.get_agent(agent_id) current_agents = self.current_agents map_data = self.map.nodes origin_todo_length = len(agent.todo) # 记录原 todo 的长度 possible_des = [] for _agent in current_agents: if _agent.has_supply: for pos in get_adj_pos(_agent.pos, 1): if pos in map_data: possible_des.append(pos) # print(possible_des) self.command_move(agent_id, possible_des, start_time, subtask=True) des = agent.todo[-1].des if agent.todo else agent.pos for pos in get_adj_pos(des, 1): if pos in map_data and map_data[pos].faction_id == agent.faction_id and self.get_agent(map_data[pos].agent_id).has_supply: supply_id = map_data[pos].agent_id supply_module = self.get_agent(supply_id).supply state = agent.todo[-1].state if agent.todo else agent.state if supply_module.add_endurance > 0: state.endurance = min(agent.max_endurance, state.endurance + supply_module.add_endurance) else: state.endurance = agent.max_endurance if supply_module.add_ammo > 0: state.weapon.ammo = min(state.weapon.max_ammo, state.weapon.ammo + supply_module.add_ammo) else: state.weapon.ammo = state.weapon.max_ammo if supply_module.add_fuel > 0: state.fuel = min(agent.max_fuel, state.fuel + supply_module.add_fuel) else: state.fuel = agent.max_fuel command = Command(agent_id=agent_id, action_type=ActionType.SUPPLY, des=des, start_time=start_time, end_time=start_time+len(agent.todo)-origin_todo_length-1, 
def command_switch(self, agent_id, weapon_idx, start_time=0, subtask=False) -> Command:
    """Plan a weapon switch: go to a supply unit, then switch weapon there.

    Args:
        agent_id: id of the unit switching weapons.
        weapon_idx: index into ``agent.switchable_weapons``.
        start_time: time step at which the plan starts.
        subtask: when True the plan is a step of a larger command, so the
            todo entries are not re-timed and the command is not queued in
            ``agent.cmd_todo``.

    Returns:
        The SWITCH_WEAPON ``Command`` describing the plan.

    Raises:
        IndexError: ``weapon_idx`` is out of range. Raised before the plan is
            modified.
        Any exception raised by ``command_supply`` propagates unchanged.
    """
    agent = self.get_agent(agent_id)
    # Resolve the weapon first so a bad index fails before any action is queued.
    weapon_name = agent.switchable_weapons[weapon_idx].name
    origin_todo_length = len(agent.todo)  # remember the todo length for timing

    # Switching requires being next to a supply unit. Forward start_time so the
    # sub-plan is timed consistently with the one built by command_attack
    # (it used to fall back to the default of 0).
    self.command_supply(agent_id, start_time=start_time, subtask=True)

    state = agent.todo[-1].state if agent.todo else agent.state
    agent.todo.append(
        Action(
            agent_id=agent_id,
            action_type=ActionType.SWITCH_WEAPON,
            weapon_idx=weapon_idx,
            weapon_name=weapon_name,
            des=state.pos,
            start_time=start_time,
            state=state,
        )
    )
    command = Command(
        agent_id=agent_id,
        action_type=ActionType.SWITCH_WEAPON,
        weapon_idx=weapon_idx,
        weapon_name=weapon_name,
        des=state.pos,
        start_time=start_time,
        end_time=start_time + len(agent.todo) - origin_todo_length - 1,
        state=state,
    )

    if not subtask:
        # Top-level command: re-time the new entries, mark the last one as the
        # end of this command and queue the command.
        end_time = command.end_time
        for i in range(origin_todo_length, len(agent.todo)):
            agent.todo[i].start_time = start_time + i - origin_todo_length
            agent.todo[i].end_time = end_time
        agent.todo[-1].end_type = ActionType.SWITCH_WEAPON
        agent.cmd_todo.append(command)
    return command


def make_plan(self, command: Command, subtask=False):
    """Dispatch a ``Command`` to the planner that matches its action type.

    Args:
        command: the command to expand into todo actions.
        subtask: forwarded unchanged to the selected planner.

    Returns:
        The ``Command`` returned by the selected ``command_*`` planner.

    Raises:
        ValueError: the command's action type is not MOVE, ATTACK, SUPPLY or
            SWITCH_WEAPON.
    """
    kind = command.action_type
    if kind == ActionType.MOVE:
        return self.command_move(
            command.agent_id,
            command.des,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.ATTACK:
        return self.command_attack(
            command.agent_id,
            command.target_id,
            attack_count=command.attack_count,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.SUPPLY:
        return self.command_supply(
            command.agent_id,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.SWITCH_WEAPON:
        return self.command_switch(
            command.agent_id,
            command.weapon_idx,
            start_time=command.start_time,
            subtask=subtask,
        )
    raise ValueError(f"Invalid command type: {command.action_type}")
def todo_action(self, agent_id: str, retry=0) -> Optional[Action]:
    """Pop the next planned action of an agent, re-planning if it is invalid.

    The head of ``agent.todo`` is validated with ``check_validity``. A valid
    action is returned, and its command is removed from ``agent.cmd_todo``
    when the action is marked as that command's last step. An invalid action
    triggers a re-plan of the command it belongs to, followed by a recursive
    retry. After more than two retries the action is returned even though it
    is invalid.

    Args:
        agent_id: id of the agent whose queue is consumed.
        retry: number of re-planning attempts already made (internal).

    Returns:
        The next action, or ``None`` when the queue is empty or a RELEASE
        action has no free adjacent tile.

    Raises:
        Any exception raised by ``make_plan`` while re-planning propagates;
        the remaining queued actions and commands are restored first.
    """
    agent = self.get_agent(agent_id)
    if not agent.todo:
        return None

    action = agent.todo.pop(0)
    valid, msg = self.check_validity(action)
    if valid or retry > 2:
        # The command is finished once its last action has been handed out.
        # Guard the lookup so an empty command queue cannot raise IndexError.
        if agent.cmd_todo and agent.cmd_todo[0].action_type == action.end_type:
            agent.cmd_todo.pop(0)
        return action

    print(f"Retrying to plan actions for {agent_id}: {msg}")

    if action.action_type == ActionType.RELEASE:
        # Try to release onto any free tile next to the requested one.
        for pos in get_adj_pos(action.des, 1):
            if pos in self.map.nodes and not self.map.nodes[pos].agent_id:
                action.des = pos
                return action
        return None  # Wait

    # Re-plan the command the failed action belongs to.
    command = agent.cmd_todo.pop(0)
    if action.end_type != command.action_type:
        # Drop the rest of the failed command's actions, up to and including
        # the one marked as its last step. When the failed action was itself
        # the last step there is nothing left to drop; scanning anyway used to
        # eat the actions of the following commands.
        while agent.todo and agent.todo[0].end_type != command.action_type:
            agent.todo.pop(0)
        if agent.todo:
            agent.todo.pop(0)

    remaining_actions = agent.todo.copy()
    remaining_commands = agent.cmd_todo.copy()
    # Plan from the agent's current state with empty queues. Planning on top
    # of the old queues and then extending with the copies duplicated every
    # remaining action and command.
    agent.todo.clear()
    agent.cmd_todo.clear()
    try:
        self.make_plan(command)
    finally:
        # Put the later commands back behind the fresh plan, even on failure.
        agent.todo.extend(remaining_actions)
        agent.cmd_todo.extend(remaining_commands)
    return self.todo_action(agent_id, retry=retry + 1)
def check_validity(self, action: Action) -> Tuple[bool, str]:
    """Check whether an action is legal in the current state.

    Args:
        action: the action to validate.

    Returns:
        ``(True, "OK")`` when the action is legal, otherwise ``(False, reason)``
        with a human-readable explanation.
    """
    if action.action_type == ActionType.END_OF_TURN:
        return True, "OK"

    map_data = self.map.nodes
    spotted_enemy_ids = self.spotted_enemy_ids

    if action.agent_id not in self.current_agents_dict:
        return False, f"Invalid agent id: {action.agent_id}"
    agent = self.get_agent(action.agent_id)
    if action.action_type not in agent.action_space.keys():
        return False, f"Invalid action {action.action_type.name} for agent {agent.agent_id[:8]}"

    if action.action_type == ActionType.MOVE:
        des = action.des
        # An off-map destination is simply invalid; it used to raise KeyError.
        if des not in map_data:
            return False, f"Destination {des} is not on the map"
        node = map_data[des]
        # Moving onto an ally is only allowed if that ally can carry this
        # agent type and still has room.
        if node.faction_id == agent.faction_id and (
            agent.agent_type not in (target := self.get_agent(node.agent_id)).available_types
            or len(target.parked_agents) >= target.capacity
        ):
            return False, f"{node.agent_id[:8]} at {des} is an ally but not available"
        if node.agent_id in spotted_enemy_ids:
            return False, f"{des} has been occupied by enemy {node.agent_id[:8]}"
        if get_axial_dis(agent.pos, des) > agent.mobility:
            return False, f"Destination is out of range. Mobility: {agent.mobility} Euclidean distance: {get_axial_dis(agent.pos, des)}"
        try:
            # Path search limited by both fuel and mobility, over tiles that
            # are friendly or not occupied by a spotted enemy.
            mobility = min(agent.fuel, agent.mobility)
            adj_pos = get_adj_pos(agent.pos, int(mobility))
            aval_data = {
                pos: tile
                for pos in adj_pos
                if pos in map_data and (
                    (tile := map_data[pos]).faction_id == self.faction_id
                    or tile.agent_id not in spotted_enemy_ids
                )
            }
            aval_nodes = astar_search(aval_data, agent.move_type, start=agent.pos, limit=mobility)
            if action.des in aval_nodes:
                return True, "OK"
            return False, f"{des} could not be reached in one step."
        except Exception as e:
            return False, f"Failed to move to {des}: {e}"

    elif action.action_type == ActionType.ATTACK:
        try:
            target_agent = self.get_agent(action.target_id)
        except Exception:
            return False, f"Invalid target agent id: {action.target_id}, ignored"
        des = action.des
        if agent.faction_id == target_agent.faction_id:
            return False, f"Cannot attack allies. Attack: {action.agent_id}. Defend: {action.target_id}"
        if agent.weapon is None:
            return False, "No weapon equipped"
        if agent.ammo <= 0:
            return False, "No ammo left"
        if target_agent.move_type not in agent.strike_types:
            return False, "Target agent cannot be attacked with this weapon"
        if get_axial_dis(agent.pos, des) > agent.attack_range:
            return False, f"Target agent is out of range. Attack range: {agent.attack_range} Distance: {get_axial_dis(agent.pos, des)}"

    elif action.action_type == ActionType.SWITCH_WEAPON:
        weapon_idx = action.weapon_idx
        # Switching is only possible next to a friendly unit with a supply module.
        around_supply = False
        for pos in get_adj_pos(agent.pos, 1):
            if pos not in map_data or not map_data[pos].agent_id:
                continue
            neighbour = self.get_agent(map_data[pos].agent_id)
            if neighbour.faction_id == agent.faction_id and neighbour.has_supply:
                around_supply = True
                break
        if not around_supply:
            return False, "Agent is not around a supply module"
        # Reject negative indices too: they would silently wrap around.
        if not 0 <= weapon_idx < len(agent.switchable_weapons):
            return False, f"Invalid weapon id {weapon_idx}. Max id: {len(agent.switchable_weapons) - 1}"

    elif action.action_type == ActionType.RELEASE:
        # For RELEASE, target_id is an index into the carrier's parked queue.
        if action.target_id >= len(agent.parked_agents):
            return False, f"Trying to release an agent that is not in the queue. Target id: {action.target_id}, Current queue: {agent.parked_agents}"
        des = action.des
        if get_axial_dis(agent.pos, des) > 1:
            return False, "Only adjacent positions can be released"
        # An adjacent tile may still lie outside the map; it used to raise KeyError.
        if des not in map_data:
            return False, f"Cannot release agent outside the map: {des}"
        if map_data[des].team_id != -1:
            return False, "Cannot release agent to occupied position"

    elif action.action_type == ActionType.INTERACT:
        # Same defensive lookup as ATTACK: an unknown target is invalid, not fatal.
        try:
            target_agent = self.get_agent(action.target_id)
        except Exception:
            return False, f"Invalid target agent id: {action.target_id}, ignored"
        if get_axial_dis(agent.pos, target_agent.pos) > 1:
            return False, "Target agent is out of range"
        if agent.faction_id != target_agent.faction_id:
            return False, "Cannot interact with enemy agent"
        if agent.agent_type not in target_agent.available_types:
            return False, f"Target agent {target_agent.agent_id[:8]} is not interactable with agent {agent}"
        if len(target_agent.parked_agents) >= target_agent.capacity:
            return False, f"Target agent {target_agent.agent_id[:8]} full"

    return True, "OK"
def update(self, sync_info):
    """Synchronise the local state with a snapshot from the real engine.

    Args:
        sync_info: mapping with the keys ``'units'`` and ``'spottedHostiles'``,
            each a sequence of per-agent dicts providing ``agent_id``,
            ``pos`` (``q``/``r``), ``fuel``, ``endurance``,
            ``commenced_action``, ``current_weapon`` and ``ammo``.

    Returns:
        ``(obs, reward, terminated, truncated, info)``. ``reward`` is the
        total endurance lost since the previous sync by synced agents of
        another faction.
    """
    sync_agents = sync_info['units']
    spotted_enemies = sync_info['spottedHostiles']
    map_data = self.map.nodes
    reward = 0

    def release_tile(agent_id, agent):
        # Detach the agent from the map tile it currently occupies, if any.
        if agent.pos != (-1, -1):
            node = map_data[agent.pos]
            if node.agent_id == agent_id:
                node.agent_id = None
                node.team_id = -1
                node.faction_id = -1

    for sync_agent in [*sync_agents, *spotted_enemies]:
        agent_id = sync_agent['agent_id']
        agent = self.get_agent(agent_id)

        release_tile(agent_id, agent)

        agent.pos = (sync_agent['pos']['q'], sync_agent['pos']['r'])
        agent.fuel = sync_agent['fuel']
        if agent.faction_id != self.faction_id:
            # Damage taken by the other faction counts as reward.
            reward += agent.endurance - sync_agent['endurance']
        agent.endurance = sync_agent['endurance']
        agent.commenced_action = sync_agent['commenced_action']

        current_weapon = sync_agent['current_weapon']
        if current_weapon:
            for weapon in agent.switchable_weapons:
                if weapon.name == current_weapon:
                    agent.weapon = weapon
                    break
            # If the reported name matches no switchable weapon the previous
            # weapon is kept; with no weapon at all there is nothing to
            # update (this used to raise AttributeError).
            if agent.weapon is not None:
                agent.weapon.ammo = sync_agent['ammo']
        else:
            agent.weapon = None

        # Re-attach living agents to the map at their new position.
        if agent.alive:
            node = map_data[agent.pos]
            node.agent_id = agent_id
            node.team_id = agent.team_id
            node.faction_id = agent.faction_id

    # NOTE(review): this also detaches enemies that are still listed in
    # spottedHostiles and were re-attached just above — confirm that this
    # ordering is intended.
    for agent_id in self._spotted_enemy_ids:
        release_tile(agent_id, self.get_agent(agent_id))

    # Refresh the list of currently spotted enemies.
    self._spotted_enemy_ids = [enemy['agent_id'] for enemy in spotted_enemies]

    # NOTE: clear the cached states before observation.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_agents = None

    obs = self.observe()
    truncated = self.episode_steps >= self.max_episode_steps
    # The episode terminates once no unit of any other faction is alive.
    terminated = sum(
        self.teams[team_id].alive_count
        for team_id in range(self.max_team)
        if self.teams[team_id].faction_id != self.faction_id
    ) == 0

    self._cumulative_rewards[self.team_id] += reward
    info = {}
    info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
    info["next_team"] = self.team_id + 1 if not terminated else -1

    if self.replay_path is not None:
        self._frames[self.team_id].append(self.render(mode="rgb_array"))

    if terminated or truncated:
        # The eval_episode_return is calculated from Player 1's perspective
        info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
        info["done_reason"] = "Terminated" if terminated else "Truncated"
        if self.replay_path is not None:
            self.save_replay()

    return obs, reward, terminated, truncated, info
def format_agents(self, agent_ids = None):
    """Format agents through ``json_to_markdown``.

    Args:
        agent_ids: ids of the agents to include. When ``None``, the keys of
            ``self.current_agents_dict`` are used.

    Returns:
        The result of ``json_to_markdown`` applied to
        ``self.get_agents_dict(agent_ids)``.
    """
    selected_ids = self.current_agents_dict.keys() if agent_ids is None else agent_ids
    agents_info = self.get_agents_dict(selected_ids)
    return json_to_markdown(agents_info)
def format_battlefields(self, battlefield_infos):
    """Format per-battlefield agent summaries through ``json_to_markdown``.

    Args:
        battlefield_infos: iterable of dicts with the keys ``sectorName``,
            ``myUnitsID``, ``allyUnitsID`` and ``hostileUnitsID``.

    Returns:
        The result of ``json_to_markdown`` applied to a dict keyed by sector
        name. Each entry holds the sector name and the agent dicts of own,
        allied and enemy units. A later entry with the same sector name
        replaces an earlier one.
    """
    summary = {
        sector["sectorName"]: {
            "battlefield_name": sector["sectorName"],
            "my_agents": self.get_agents_dict(sector["myUnitsID"]),
            "ally_agents": self.get_agents_dict(sector["allyUnitsID"]),
            "enemy_agents": self.get_agents_dict(sector["hostileUnitsID"]),
        }
        for sector in battlefield_infos
    }
    return json_to_markdown(summary)