1267 lines
57 KiB
Python
1267 lines
57 KiB
Python
|
import numpy as np
|
||
|
import json
|
||
|
import os
|
||
|
from datetime import datetime
|
||
|
from typing import List, Optional
|
||
|
import gymnasium as gym
|
||
|
from .common.utils import *
|
||
|
from .common.agent import *
|
||
|
|
||
|
class YesCmdrEnv(gym.Env):
|
||
|
# Gymnasium render metadata: supported render modes and nominal frames per second.
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
|
||
|
|
||
|
def __init__(self,
             use_real_engine: bool = False,
             replay_path: Optional[str] = None,
             campaign_id: str = "NOID",
             team_id: int = 0,
             max_team: int = 2,
             max_faction: int = 2,
             data_path: str = "../data",
             max_episode_steps: int = 1000,
             bot_version: str = "v0",
             war_fog: bool = True) -> None:
    """Store the environment configuration.

    The map, teams and spaces are built lazily by the first ``reset()``
    call; the constructor itself performs no I/O.

    Args:
        use_real_engine: When True, ``END_OF_TURN`` does not switch the
            active team inside this env.
        replay_path: Directory for replay videos; ``None`` disables frame
            recording and replay saving.
        campaign_id: Campaign identifier, stored as ``self.campaign_id``.
        team_id: Team that acts first after each ``reset()``.
        max_team: Number of teams loaded from ``AgentsInfo.json``.
        max_faction: Number of factions (alliances of teams).
        data_path: Directory holding ``MapInfo.json`` and ``AgentsInfo.json``.
        max_episode_steps: Number of END_OF_TURN steps after which an
            episode is truncated.
        bot_version: Policy used by ``bot_action``.
        war_fog: Whether enemy units are hidden unless spotted.
    """
    self._init_flag = False

    # set other properties...
    self.use_real_engine = use_real_engine
    self.replay_path = replay_path
    self.campaign_id = campaign_id
    self.init_team_id = team_id
    self.max_team = max_team
    self.max_faction = max_faction
    self.data_path = data_path
    self.max_episode_steps = max_episode_steps
    self.bot_version = bot_version
    self.war_fog = war_fog

    # Lazily computed per-turn caches. They are initialised here so the
    # properties reading them never hit an AttributeError; in particular
    # ``_spotted_enemy_agents`` is read by a property but was never assigned.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_enemy_ids = None
    self._spotted_enemy_agents = None
|
||
|
|
||
|
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
    """Reset the episode and return ``(observation, info)``.

    The first call builds the environment from disk (``_make_env``) and, when
    recording replays, creates the renderer. Later calls reset every map node
    and agent in place and re-link on-map agents to their tiles.

    Args:
        seed: Forwarded to ``gym.Env.reset``, which seeds ``self.np_random``.
        options: Accepted for Gymnasium API compatibility; unused.

    Returns:
        The observation of the starting team and an info dict whose
        ``"next_team"`` is the 1-based id of the team to act.
    """
    super().reset(seed=seed)

    # The actual environment initialisation happens on the first reset call.
    if not self._init_flag:
        self._init_flag = True
        self._make_env()
        if self.replay_path is not None:
            from .common.renderer import Renderer
            self._renderer = Renderer(self.map.nodes.values())
    else:
        for node in self.map.nodes.values():
            node.reset()
        for team in self.teams:
            for agent_id, agent in team.agents.items():
                agent.reset()
                if agent.pos == (-1, -1):  # Illegal position for padding
                    continue
                assert agent.pos in self.map.nodes, f"Invalid agent position: {agent.pos}"
                self.map.nodes[agent.pos].agent_id = agent_id
                self.map.nodes[agent.pos].team_id = team.team_id

    if self.replay_path is not None:
        # One frame buffer per team; save_replay walks range(max_team).
        self._frames = [[] for _ in range(self.max_team)]

    # Clear every per-turn cache.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_enemy_ids = None
    self._spotted_enemy_agents = None

    self._team_id = self.init_team_id
    self._spotted_agents = None
    self.episode_steps = 0
    # Indexed by team id; at least two slots because step() reports
    # info["accum_reward"] for teams 0 and 1.
    self._cumulative_rewards = np.zeros(max(self.max_team, 2))

    obs = self.observe()
    info = {
        "next_team": self.team_id + 1
    }
    return obs, info
|
||
|
|
||
|
def step(self, action: Union[int, Action]):
    """Apply one action for the active team.

    Args:
        action: A flat action id (Python or NumPy integer) or an ``Action``.
            An ``Action`` that is invalid only because its MOVE destination
            is occupied is replaced by the legal MOVE of the same agent whose
            destination is closest to the requested one.

    Returns:
        ``(obs, reward, terminated, truncated, info)``. Invalid actions are
        ignored: nothing changes, reward is 0 and ``info["exception"]``
        explains why.

    Raises:
        TypeError: If ``action`` is neither an integer nor an ``Action``.
    """
    if isinstance(action, (int, np.integer)):
        action = int(action)  # NumPy ints from RL frameworks are accepted too
        if action not in self.legal_action_ids:
            print(f"Invalid action will be ignored: {action}")
            return self._ignored_action_result(f"Invalid action {action}")
        action = self.id_to_action(action)
    elif isinstance(action, Action):
        valid, reason = self.check_validity(action)
        if not valid:
            replacement = None
            if action.action_type == ActionType.MOVE and "occupied" in reason:
                print("Trying to move to an adjacent position that is empty.")
                replacement = self._nearest_legal_move(action.agent_id, action.des)
            if replacement is None:
                # Either not a recoverable MOVE, or no legal MOVE exists for
                # this agent: ignore instead of executing an invalid action.
                print(f"Invalid action will be ignored: {action}")
                print(f"Reason: {reason}")
                return self._ignored_action_result(f"Invalid action: {action}. Reason: {reason}")
            action = replacement
    else:
        raise TypeError(f"action must be an int or an Action, got {type(action).__name__}")

    # END_OF_TURN switches self.team_id inside _player_step, so remember who acted.
    acting_team = self.team_id
    reward, terminated = self._player_step(action)

    # NOTE: Clear the per-turn caches before observing the new state.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_enemy_ids = None
    self._spotted_enemy_agents = None

    obs = self.observe()
    truncated = self.episode_steps >= self.max_episode_steps
    self._cumulative_rewards[acting_team] += reward

    info = {
        "accum_reward": (self._cumulative_rewards[0], self._cumulative_rewards[1]),
        "next_team": self.team_id + 1 if not terminated else -1,
    }

    if self.replay_path is not None:
        self._frames[self.team_id].append(self.render(mode="rgb_array"))

    if terminated or truncated:
        # The eval_episode_return is calculated from Player 1's perspective
        info["eval_episode_return"] = reward if acting_team == 0 else -reward
        info["done_reason"] = "Terminated" if terminated else "Truncated"
        if self.replay_path is not None:
            self.save_replay()

    return obs, reward, terminated, truncated, info

def _ignored_action_result(self, message):
    """Build the ``step()`` return value for an ignored action (state unchanged)."""
    info = {
        "accum_reward": (self._cumulative_rewards[0], self._cumulative_rewards[1]),
        "next_team": self.team_id + 1,
        "exception": message,
    }
    return self.observe(), 0, False, False, info

def _nearest_legal_move(self, agent_id, target):
    """Return the legal MOVE of ``agent_id`` whose ``des`` is closest to ``target``, or None."""
    best, best_dis = None, float('inf')
    for candidate in self.legal_actions:
        if candidate.agent_id != agent_id or candidate.action_type != ActionType.MOVE:
            continue
        # Distance is always measured to the originally requested target.
        dis = get_axial_dis(candidate.des, target)
        if dis < best_dis:
            best, best_dis = candidate, dis
    return best
|
||
|
|
||
|
def render(self, mode="human", attack_agent=None, defend_agent=None):
    """Render the current state.

    Args:
        mode: ``"rgb_array"`` returns a frame from the replay renderer (the
            renderer is only created by ``reset()`` when ``replay_path`` is
            set); ``"human"`` returns the observation dict from ``observe()``.
        attack_agent: Optional attacker, passed through to the renderer.
        defend_agent: Optional defender, passed through to the renderer.

    Raises:
        ValueError: For an unsupported ``mode``.
    """
    if mode == "rgb_array":
        return self._renderer.render(self.current_agents, self.spotted_enemy_agents, attack_agent, defend_agent)
    elif mode == "human":
        return self.observe()
    else:
        raise ValueError("Invalid render mode: {}".format(mode))
|
||
|
|
||
|
def action_to_id(self, action: Action) -> int:
    """Encode ``action`` as a flat integer id for the active team.

    Id 0 is END_OF_TURN. Other ids start at 1 and are laid out by walking
    the active team's action space (entry by entry, sub-space by sub-space),
    adding each sub-space's ``n`` until the action's own sub-space is reached.
    The offset inside that sub-space is:

    * SWITCH_WEAPON: ``weapon_idx``;
    * RELEASE: ``target_id`` (an index into the carrier's ``parked_agents``);
    * otherwise: ``encode_axial(des - pos) - 1``.

    Raises:
        ValueError: If the agent, or its action type, is not part of the
            active team's action space.
    """
    if action.action_type == ActionType.END_OF_TURN:
        return 0

    team_space = self.action_space
    if action.agent_id not in team_space.keys():
        raise ValueError(f"Invalid agent_id: {action.agent_id[:8]}")
    agent_action_space = team_space[action.agent_id]
    if action.action_type not in agent_action_space.keys():
        agent = self.get_agent(action.agent_id)
        print(agent.available_types, agent.parked_agents)
        print(action)
        raise ValueError(f"Invalid action_type: {action.action_type.name} for agent {action.agent_id[:8]}. Available action_types: {agent_action_space.keys()}")

    action_id = 1
    for agent_id, per_agent_space in team_space.items():
        for action_type, sub_space in per_agent_space.items():
            if agent_id != action.agent_id or action_type != action.action_type:
                action_id += sub_space.n
                continue
            if action.action_type == ActionType.SWITCH_WEAPON:  # encoded by index, not by position
                assert action.weapon_idx < sub_space.n, f"Invalid weapon_idx: {action.weapon_idx}, action_space.n: {sub_space.n}"
                action_id += action.weapon_idx
            elif action.action_type == ActionType.RELEASE:
                assert action.target_id < sub_space.n, f"Invalid target_idx: {action.target_id}, action_space.n: {sub_space.n}"
                action_id += action.target_id  # target_id is an index here
            else:
                pos, des = self.get_agent(action.agent_id).pos, action.des
                encode_id = encode_axial((des[0] - pos[0], des[1] - pos[1]))
                action_id += encode_id - 1
                # Only used in the assertion message; defaulted so an
                # unexpected action type still yields an AssertionError
                # rather than a NameError.
                cur_range = None
                if action.action_type == ActionType.MOVE:
                    cur_range = self.get_agent(action.agent_id).mobility
                elif action.action_type == ActionType.ATTACK:
                    cur_range = self.get_agent(action.agent_id).attack_range
                elif action.action_type == ActionType.INTERACT:
                    cur_range = 1
                assert encode_id - 1 < sub_space.n, f"Action {action.action_type.name} Distance: {get_axial_dis(pos, des)}, cur_range: {cur_range} encode_id: {encode_id} action_space.n: {sub_space.n}\n{action}"
            return int(action_id)
    assert False, "Should not reach here"
|
||
|
|
||
|
def id_to_action(self, action_id: int) -> Action:
    """Decode a flat action id for the active team (inverse of ``action_to_id``).

    Raises:
        ValueError: If ``action_id`` is negative or past the end of the
            active team's action space.
    """
    if action_id == 0:
        return Action(ActionType.END_OF_TURN)
    if action_id < 0:
        raise ValueError("Invalid action_id: {}".format(action_id))

    remaining = action_id - 1
    for agent_id, per_agent_space in self.action_space.items():
        for action_type, sub_space in per_agent_space.items():
            if remaining >= sub_space.n:
                remaining -= sub_space.n
                continue
            if action_type == ActionType.SWITCH_WEAPON:
                # NOTE(review): weapon_name is the *currently equipped* weapon,
                # not switchable_weapons[remaining]; confirm this is intended.
                return Action(agent_id=agent_id, action_type=action_type, weapon_idx=remaining, weapon_name=self.get_agent(agent_id).weapon.name)
            if action_type == ActionType.RELEASE:
                return Action(agent_id=agent_id, target_id=remaining, action_type=action_type)
            pos = self.get_agent(agent_id).pos
            delta = decode_axial(remaining + 1)  # code 0 is the agent's own position
            des = (pos[0] + delta[0], pos[1] + delta[1])
            return Action(agent_id=agent_id, des=des, action_type=action_type)
    # Report the id the caller passed, not the partially consumed remainder.
    raise ValueError("Invalid action_id: {}".format(action_id))
|
||
|
|
||
|
def get_agent(self, agent_id: str) -> Agent:
    """Find an agent by id in any team; dead agents are returned as well.

    Raises:
        ValueError: If no team owns ``agent_id``.
    """
    for candidate_team in self.teams:
        members = candidate_team.agents
        if agent_id in members.keys():
            return members[agent_id]
    raise ValueError(f"Invalid agent id: {agent_id}")
|
||
|
|
||
|
def get_node(self, pos: Tuple[int, int]) -> TileNode:
    """Return the map tile stored under ``pos``.

    Raises:
        ValueError: If ``pos`` is not a tile of the map.
    """
    nodes = self.map.nodes
    if pos in nodes:
        return nodes[pos]
    raise ValueError(f"Invalid position: {pos}")
|
||
|
|
||
|
@property
def legal_actions(self) -> List[Action]:
    """Legal actions for the active team, cached in ``self._legal_actions``.

    The list always starts with ``END_OF_TURN``. Agents that already acted,
    are dead, sit at the ``(-1, -1)`` padding position or are being carried
    are skipped. For every other agent it adds:

    * MOVE to each tile returned by ``astar_search`` within
      ``min(fuel, mobility)`` that is not held by the own team and not by a
      spotted enemy, or that holds an own unit able to carry this agent
      (``agent_type`` in its ``available_types``) with free capacity;
    * ATTACK on each spotted enemy within ``attack_range`` whose
      ``move_type`` is in ``strike_types`` (requires ``ammo > 0``);
    * INTERACT with each adjacent own unit able to carry this agent;
    * RELEASE onto each adjacent empty tile (``team_id == -1``) of the first
      parked agent for which ``get_cost(move_type, terrain_type) != 0``;
    * SWITCH_WEAPON to each switchable weapon other than the equipped one,
      when adjacent to an own unit with ``has_supply``.

    The cache is reset by ``step()``/``reset()``.
    """
    if self._legal_actions is not None:
        return self._legal_actions

    map_data = self.map.nodes
    player_data = self.current_agents_dict
    spotted_enemy_ids = self.spotted_enemy_ids

    _legal_actions = [Action(ActionType.END_OF_TURN)] # Default action

    for agent_id, agent in player_data.items():
        if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried:
            continue
        pos = agent.pos
        # Move
        mobility = min(agent.fuel, agent.mobility)
        if mobility > 0:
            adj_pos = get_adj_pos(pos, int(mobility))
            # Tiles the path search may use: everything in range except tiles
            # holding a spotted enemy (unspotted enemies stay "passable").
            aval_data = {
                pos: node
                for pos in adj_pos if pos in map_data and
                (
                    (node := map_data[pos]).team_id != self.enemy_id
                    or
                    node.agent_id not in spotted_enemy_ids
                )
            }
            aval_nodes = astar_search(aval_data, agent.move_type, start=pos, limit=mobility)
            for des in aval_nodes:
                if des == pos:
                    continue
                node = map_data[des]
                if node.team_id != self.team_id and node.agent_id not in spotted_enemy_ids:
                    # A tile can hold only one unit
                    _legal_actions.append(Action(
                        agent_id=agent_id,
                        des=des,
                        action_type=ActionType.MOVE
                    ))
                elif node.team_id == self.team_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \
                        and len(target.parked_agents) < target.capacity:
                    # ...unless the destination holds a unit that can carry this agent
                    _legal_actions.append(Action(
                        agent_id=agent_id,
                        target_id=node.agent_id,
                        des=des,
                        action_type=ActionType.MOVE
                    ))
        # Attack
        if agent.attack_range > 0 and agent.ammo > 0:
            for des in get_adj_pos(pos, agent.attack_range):
                if des not in map_data:
                    continue
                node = map_data[des]
                if node.agent_id in spotted_enemy_ids and self.get_agent(node.agent_id).move_type in agent.strike_types:
                    # Spotted enemy that this agent's weapon can strike
                    _legal_actions.append(Action(
                        agent_id=agent_id,
                        target_id=node.agent_id,
                        action_type=ActionType.ATTACK,
                        des=des
                    ))
        # Interact (board an adjacent own carrier)
        for des in get_adj_pos(pos, 1):
            if des not in map_data:
                continue
            node = map_data[des]
            if node.team_id == self.team_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \
                    and len(target.parked_agents) < target.capacity:
                _legal_actions.append(Action(
                    agent_id=agent_id,
                    target_id=node.agent_id,
                    des=des,
                    action_type=ActionType.INTERACT
                ))
        # Release (unload a parked agent onto an adjacent empty tile)
        if agent.parked_agents:
            for des in get_adj_pos(pos, 1):
                if des not in map_data:
                    continue
                node = map_data[des]
                if node.team_id != -1:
                    continue
                for idx, agent_to_release in enumerate(agent.parked_agents):
                    if get_cost(agent_to_release.move_type, node.terrain_type) != 0:
                        # target_id is the index into parked_agents
                        _legal_actions.append(Action(
                            agent_id=agent_id,
                            target_id=idx,
                            des=des,
                            action_type=ActionType.RELEASE
                        ))
                        # Only the first parked agent that can enter this tile is offered.
                        break
        # Switch weapon (needs an adjacent own supply unit)
        if agent.switchable_weapons:
            for des in get_adj_pos(pos, 1):
                if des not in map_data:
                    continue
                node = map_data[des]
                if node.team_id == self.team_id and self.get_agent(node.agent_id).has_supply:
                    for weapon_idx, weapon in enumerate(agent.switchable_weapons):
                        if not agent.weapon or weapon.name != agent.weapon.name:
                            _legal_actions.append(Action(
                                agent_id=agent_id,
                                target_id=weapon.name,
                                weapon_idx=weapon_idx,
                                action_type=ActionType.SWITCH_WEAPON
                            ))

    self._legal_actions = _legal_actions
    return self._legal_actions
|
||
|
|
||
|
@property
def legal_action_ids(self) -> List[int]:
    """Flat ids of ``legal_actions`` (via ``action_to_id``), cached per turn."""
    cached = self._legal_action_ids
    if cached is None:
        cached = list(map(self.action_to_id, self.legal_actions))
        self._legal_action_ids = cached
    return cached
|
||
|
|
||
|
@property
def spotted_enemy_ids(self) -> List[str]:
    """Ids of enemy agents currently visible to the active team.

    Without war fog, all living agents of ``self.teams[self.enemy_id]`` are
    returned. Otherwise, for each living, on-map agent of the active team,
    an enemy on a tile within ``info_level`` is spotted when its
    ``stealth_level < info_level - distance + 1``. Adjacent enemies are
    always spotted. Each enemy appears once, in first-seen order. The result
    is cached in ``self._spotted_enemy_ids`` until the caches are cleared.
    """
    if not self.war_fog:
        return self.teams[self.enemy_id].alive_ids

    if self._spotted_enemy_ids is not None:
        return self._spotted_enemy_ids

    enemy_data = self.enemy_agents_dict
    map_data = self.map.nodes
    # Insertion-ordered set: several observers may see the same enemy.
    spotted = {}

    for agent in self.current_agents:
        # Padding agents at (-1, -1) and dead agents cannot observe anything.
        if agent.pos == (-1, -1) or not agent.alive:
            continue
        info_level = agent.info_level
        for des in get_adj_pos(agent.pos, info_level):
            if des not in map_data:
                continue
            node = map_data[des]
            if node.team_id != self.enemy_id:
                continue
            dis = get_axial_dis(agent.pos, des)
            if dis == 1:  # Adjacent agents are always spotted
                dis = -1e9
            if enemy_data[node.agent_id].stealth_level < info_level - dis + 1:
                spotted[node.agent_id] = None

    self._spotted_enemy_ids = list(spotted)
    return self._spotted_enemy_ids
|
||
|
|
||
|
def save_replay(self):
    """Write one MP4 per team from the recorded frames into ``replay_path``.

    Files are named ``tianqiong_<timestamp>_Team<k>.mp4`` with a 1-based
    team number. Teams without recorded frames are skipped, because
    imageio's ``mimwrite`` raises on an empty frame list.
    """
    # exist_ok avoids the check-then-create race of os.path.exists + makedirs.
    os.makedirs(self.replay_path, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    for team_id in range(self.max_team):
        frames = self._frames[team_id] if team_id < len(self._frames) else []
        if not frames:
            print(f"No frames recorded for Team{team_id + 1}; replay skipped.")
            continue
        path = os.path.join(
            self.replay_path,
            f"tianqiong_{timestamp}_Team{team_id + 1}.mp4"
        )
        self.display_frames_as_mp4(frames, path)
        print(f'replay {path} saved!')
|
||
|
|
||
|
@staticmethod
def display_frames_as_mp4(frames: list, path: str, fps=4) -> None:
    """Encode ``frames`` into an MP4 file at ``path`` with ``imageio.mimwrite``.

    Args:
        frames: Image frames, e.g. as returned by ``render(mode="rgb_array")``.
        path: Output file path; must end with ``.mp4``.
        fps: Playback frame rate.

    Raises:
        ValueError: If ``path`` does not end with ``.mp4``.
    """
    if not path.endswith('.mp4'):
        # Explicit check rather than ``assert`` so it is not stripped by ``python -O``.
        raise ValueError(f'path must end with .mp4, but got {path}')
    import imageio  # imported lazily; only needed when replays are saved
    imageio.mimwrite(path, frames, fps=fps)
|
||
|
|
||
|
def _make_env(self):
    """Build the map, teams, action/observation spaces and faction lookups.

    Reads ``MapInfo.json`` (a list of node dicts) and ``AgentsInfo.json`` (a
    list of agent dicts with a ``"team_id"`` key) from ``self.data_path``.
    Called from the first ``reset()``.

    Raises:
        ValueError: If some team in ``range(self.max_team)`` has no agents.
    """
    # Configuration for file paths
    map_path = f"{self.data_path}/MapInfo.json"
    agents_path = f"{self.data_path}/AgentsInfo.json"

    # Load map and agents information. JSON is UTF-8, so do not depend on
    # the platform's locale-dependent default encoding.
    with open(map_path, 'r', encoding='utf-8') as f:
        origin_map = json.load(f)
    with open(agents_path, 'r', encoding='utf-8') as f:
        origin_agents = json.load(f)

    _nodes = {
        (node := dict_to_node(node_dict)).pos: node for node_dict in origin_map
    }
    agents_dict_lists = [
        [agent_dict for agent_dict in origin_agents if agent_dict["team_id"] == team_id]
        for team_id in range(self.max_team)
    ]
    _agents = [
        {
            (agent := dict_to_agent(agent_dict)).agent_id: agent
            for agent_dict in agents_dict_lists[team_id]
        }
        for team_id in range(self.max_team)
    ]

    _teams = []
    for _team_id, team_agents in enumerate(_agents):
        if not team_agents:
            raise ValueError(f"No agents with team_id {_team_id} in {agents_path}")
        # team_agents is keyed by agent_id, so ``team_agents[0]`` is a
        # KeyError; take the faction of the first agent in insertion order.
        first_agent = next(iter(team_agents.values()))
        _teams.append(Team(
            team_id=_team_id,
            faction_id=first_agent.faction_id,
            agents=team_agents
        ))

    # Link every on-map agent to its starting tile.
    for team in _teams:
        for agent_id, agent in team.agents.items():
            if agent.pos == (-1, -1):  # Illegal position for padding
                continue
            assert agent.pos in _nodes, f"Invalid agent position: {agent.pos}"
            _nodes[agent.pos].agent_id = agent_id
            _nodes[agent.pos].team_id = team.team_id

    _map = Map(_nodes)
    self.map = _map
    self.teams = _teams
    self.obs_shape = _map.width, _map.height

    # Get action space: one Dict per team holding a single-slot END_OF_TURN
    # entry plus every agent's own action space.
    self._action_spaces = [
        gym.spaces.Dict(
            {
                f"team_{team.team_id + 1}": gym.spaces.Dict(
                    {
                        ActionType.END_OF_TURN: gym.spaces.Discrete(1)  # Default action
                    }
                ),
                **{
                    agent_id: agent.action_space for agent_id, agent in team.agents.items()
                }
            }
        )
        for team in self.teams
    ]

    self._action_space_sizes = [
        sum(action_space.n for agent_action_space in self._action_spaces[team_id].values() for action_space in agent_action_space.values())
        for team_id in range(self.max_team)
    ]

    # Self-check: id -> Action -> id must round-trip for every team. This
    # walks self._team_id through all teams; reset() then sets the real
    # starting team.
    self._team_id = 0
    for action_space_size in self._action_space_sizes:
        for action_id in range(action_space_size):
            assert self.action_to_id(self.id_to_action(action_id)) == action_id
        self._team_id += 1

    # Get observation space
    self._observation_spaces = [
        gym.spaces.Dict(
            {
                "action_mask": gym.spaces.Box(0, 1, (self._action_space_sizes[team_id], ), dtype=np.int8),  # Different size for each team
                "union_endurance": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32),
                "union_info_level": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16),
                "union_stealth_level": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16),
                "union_mobility": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32),
                "union_defense": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16),
                "union_damage": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32),
                "union_fuel": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32),
                "union_ammo": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16),
                "enemy_endurance": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32),
                "enemy_info_level": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16),
                "enemy_stealth_level": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16),
                "enemy_mobility": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32),
                "enemy_defense": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16),
                "enemy_damage": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32),
                "enemy_fuel": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.float32),
                "enemy_ammo": gym.spaces.Box(0, 1000, self.obs_shape, dtype=np.int16)
            }
        )
        for team_id in range(self.max_team)
    ]

    self._reward_space = gym.spaces.Box(low=-1000, high=1000, shape=(1, ), dtype=np.float32)

    # Precompute, per faction, its own (union) agents and all enemy agents.
    self._union_agents = [[] for _ in range(self.max_faction)]
    self._union_agents_dict = [{} for _ in range(self.max_faction)]
    self._enemy_agents = [[] for _ in range(self.max_faction)]
    self._enemy_agents_dict = [{} for _ in range(self.max_faction)]
    for team_id in range(self.max_team):
        faction_id = self.teams[team_id].faction_id
        agents_list = list(self.teams[team_id].agents.values())
        agents_dict = self.teams[team_id].agents
        self._union_agents[faction_id].extend(agents_list)
        # Keep the dict lookup in sync with the list, as for enemies below.
        self._union_agents_dict[faction_id] |= agents_dict
        for other_faction_id in range(self.max_faction):
            if other_faction_id != faction_id:
                self._enemy_agents[other_faction_id].extend(agents_list)
                self._enemy_agents_dict[other_faction_id] |= agents_dict
|
||
|
|
||
|
|
||
|
@property
def team_id(self):
    """0-based index of the team whose turn it currently is."""
    active = self._team_id
    return active
|
||
|
|
||
|
@property
def current_agents(self):
    """``values()`` view over the active team's agents (not indexable by id)."""
    active_team = self.teams[self._team_id]
    return active_team.agents.values()
|
||
|
|
||
|
@property
def current_agents_dict(self):
    """Mapping ``agent_id -> agent`` for the active team."""
    active_team = self.teams[self._team_id]
    return active_team.agents
|
||
|
|
||
|
@property
def enemy_agents(self):
    """All agents of factions other than the active team's faction.

    Precomputed per faction in ``_make_env``. The list includes every agent
    of those teams, dead and padding agents included.
    """
    return self._enemy_agents[self.faction_id]
|
||
|
|
||
|
@property
def enemy_agents_dict(self):
    """Mapping ``agent_id -> agent`` for all agents of other factions.

    Precomputed per faction in ``_make_env``.
    """
    return self._enemy_agents_dict[self.faction_id]
|
||
|
|
||
|
@property
def spotted_enemy_agents(self):
    """Enemy agent objects for ``spotted_enemy_ids``, in the same order.

    Rebuilt from ``spotted_enemy_ids`` on every access; that property holds
    the per-turn cache. This avoids returning a list that is stale after a
    step, and it works even if ``_spotted_enemy_agents`` was never assigned.
    The previous version also fell off the end and returned ``None``
    whenever it computed the list.
    """
    enemy_dict = self.enemy_agents_dict
    self._spotted_enemy_agents = [enemy_dict[enemy_id] for enemy_id in self.spotted_enemy_ids]
    return self._spotted_enemy_agents
|
||
|
|
||
|
@property
def faction_id(self):
    """Faction (alliance) id of the team whose turn it is."""
    active_team = self.teams[self._team_id]
    return active_team.faction_id
|
||
|
|
||
|
@property
def observation_space(self):
    """Observation space of the active team (``action_mask`` length differs per team)."""
    per_team = self._observation_spaces
    return per_team[self.team_id]
|
||
|
|
||
|
@property
def action_space(self):
    """Action space (a ``gym.spaces.Dict``) of the team whose turn it is."""
    per_team = self._action_spaces
    return per_team[self.team_id]
|
||
|
|
||
|
@property
def action_space_size(self):
    """Number of flat action ids (the ``action_mask`` length) for the active team."""
    sizes = self._action_space_sizes
    return sizes[self.team_id]
|
||
|
|
||
|
@property
def reward_space(self):
    """Reward space built in ``_make_env`` (a 1-element Box in [-1000, 1000])."""
    space = self._reward_space
    return space
|
||
|
|
||
|
def random_action(self) -> Action:
    """Return a uniformly random element of ``legal_actions``.

    An index is drawn instead of passing the list to ``np.random.choice``.
    ``choice`` first converts the list into an ndarray, which is wasted work
    and can produce a multi-dimensional array if ``Action`` is sequence-like.
    ``legal_actions`` always contains END_OF_TURN, so it is never empty.
    """
    actions = self.legal_actions
    return actions[np.random.randint(len(actions))]
|
||
|
|
||
|
def next_turn(self):
    """Give the turn to the next team, wrapping around after the last one."""
    following = self._team_id + 1
    self._team_id = following % self.max_team
|
||
|
|
||
|
def bot_action(self) -> Action:
    """Pick an action for the built-in bot.

    ``"v0"``: a random legal ATTACK if one exists, otherwise ``random_action()``.

    Raises:
        NotImplementedError: For any other ``bot_version``.
    """
    if self.bot_version != "v0":
        raise NotImplementedError(f"Invalid bot version: {self.bot_version}")
    attack_actions = [action for action in self.legal_actions if action.action_type == ActionType.ATTACK]
    if not attack_actions:
        return self.random_action()
    # Draw an index so numpy never converts the Action objects into an ndarray.
    return attack_actions[np.random.randint(len(attack_actions))]
|
||
|
|
||
|
def _player_step(self, action: Action):
    """Execute an already validated ``action``, then resupply the active team.

    After the action, every living, on-map agent of the active team is
    resupplied from each adjacent own unit with ``has_supply``, using that
    unit's ``supply`` module. After END_OF_TURN the active team is the one
    that now moves. A non-positive ``add_*`` value refills that stat to its
    maximum.

    Returns:
        ``(reward, terminated)``; only ATTACK can terminate the episode.

    Raises:
        ValueError: For an unknown action type.
    """
    action_type = action.action_type
    if action_type == ActionType.END_OF_TURN:
        # NOTE: this is where the active player changes.
        # Clear the "already acted" flags of the team that is ending its turn.
        for agent in self.current_agents:
            agent.commenced_action = False
        # Switch player
        self.episode_steps += 1
        if not self.use_real_engine:
            self.next_turn()
        reward, terminated = 0, False  # a penalty term could be added here
    elif action_type == ActionType.MOVE:
        reward = self._move(action.agent_id, action.des, action.target_id)
        terminated = False
    elif action_type == ActionType.ATTACK:
        reward, terminated = self._attack(action.agent_id, action.target_id)
    elif action_type == ActionType.INTERACT:
        reward = self._interact(action.agent_id, action.target_id)
        terminated = False
    elif action_type == ActionType.RELEASE:
        reward = self._release(action.agent_id, action.target_id, action.des)
        terminated = False
    elif action_type == ActionType.SWITCH_WEAPON:
        reward = self._switch_weapon(action.agent_id, action.weapon_idx)
        terminated = False
    else:
        raise ValueError(f"Invalid action type: {action.action_type}")

    map_nodes = self.map.nodes
    for agent in self.current_agents:
        # Dead and padding agents are not resupplied. Restoring endurance on a
        # unit that _attack already removed from the map would be wrong.
        if not agent.alive or agent.pos == (-1, -1):
            continue
        for pos in get_adj_pos(agent.pos, 1):
            if pos not in map_nodes or map_nodes[pos].team_id != self.team_id:
                continue
            adj_agent = self.get_agent(map_nodes[pos].agent_id)
            if not adj_agent.has_supply:
                continue
            module = adj_agent.supply
            if module.add_endurance > 0:
                agent.endurance += min(module.add_endurance, agent.max_endurance - agent.endurance)
            else:
                agent.endurance = agent.max_endurance
            if agent.weapon is not None:
                if module.add_ammo > 0:
                    agent.weapon.ammo += min(module.add_ammo, agent.weapon.max_ammo - agent.weapon.ammo)
                else:
                    agent.weapon.ammo = agent.weapon.max_ammo
            if module.add_fuel > 0:
                agent.fuel += min(module.add_fuel, agent.max_fuel - agent.fuel)
            else:
                agent.fuel = agent.max_fuel

    return reward, terminated
|
||
|
|
||
|
def _move(self, agent_id, des, target_id):
    """Move ``agent_id`` along an A* path towards ``des``.

    Fuel is reduced per step by ``get_cost(move_type, terrain_type)``. With
    war fog, the move stops before the first tile held by an enemy. When
    recording replays, a frame is captured before every step. If the final
    tile holds ``target_id``, the agent boards it (appended to
    ``parked_agents`` and marked ``is_carried``).

    Returns:
        The reward for the move (always 0).
    """
    map_data = self.map.nodes
    agent = self.current_agents_dict[agent_id]
    spotted_enemy_ids = self.spotted_enemy_ids

    # Detach the agent from its starting tile.
    start_node = map_data[agent.pos]
    start_node.agent_id = None
    start_node.team_id = -1

    mobility = min(agent.mobility, agent.fuel)
    adj_pos = get_adj_pos(agent.pos, int(mobility))
    # Tiles the path may use: everything in range except spotted enemies.
    aval_data = {
        pos: node
        for pos in adj_pos if pos in map_data and
        (
            (node := map_data[pos]).team_id != self.enemy_id
            or
            node.agent_id not in spotted_enemy_ids
        )
    }
    path = astar_search(aval_data, move_type=agent.move_type, start=agent.pos, goal=des, limit=mobility)
    path = [agent.pos] + path

    stopped_by_enemy = False
    for cur, nxt in zip(path[:-1], path[1:]):
        if self.war_fog and map_data[nxt].team_id == self.enemy_id:  # hidden enemy ahead: stop here
            stopped_by_enemy = True
            break
        if self.replay_path is not None:
            self._frames[self.team_id].append(self.render(mode="rgb_array"))
        agent.fuel -= get_cost(agent.move_type, map_data[nxt].terrain_type)
        agent.pos = nxt

    # Update the position and its link to the map.
    cur = agent.pos
    agent.commenced_action = True
    end_node = map_data[cur]

    if end_node.team_id == -1:
        end_node.agent_id = agent_id
        end_node.team_id = self.team_id
    elif end_node.agent_id == target_id:
        carry_agent = self.get_agent(target_id)
        assert agent.team_id == carry_agent.team_id
        assert agent.agent_type in carry_agent.available_types, f"{agent} not in {carry_agent.available_types} (id: {carry_agent.agent_id[:8]}). "
        # Slice the carrier's id; slicing the agent object raised TypeError.
        assert len(carry_agent.parked_agents) < carry_agent.capacity, f"Agent {carry_agent.agent_id[:8]} is available but full"
        carry_agent.parked_agents.append(agent)
        agent.is_carried = True
    else:
        # NOTE(review): if the move stopped early on a tile held by an own
        # unit, the agent is left without a map link; confirm this cannot happen.
        assert stopped_by_enemy

    return 0
|
||
|
|
||
|
def _attack(self, attack_id, target_id):
    """Resolve one attack by ``attack_id`` on the enemy agent ``target_id``.

    Damage is ``max(attacker.damage - target.defense, 0)``, capped at the
    target's remaining endurance. One round of the attacker's weapon ammo is
    spent. A target whose endurance drops to 0 is removed from its map tile.

    Returns:
        ``(damage, terminated)``. ``terminated`` is True once the enemy team's
        ``alive_count`` is 0.
    """
    # current_agents is a dict values() view and cannot be indexed by id.
    agent = self.current_agents_dict[attack_id]
    target_agent = self.teams[self.enemy_id].agents[target_id]

    if self.replay_path is not None:
        self._frames[self.team_id].extend([self.render(mode="rgb_array", attack_agent=agent, defend_agent=target_agent)] * 4)

    damage = max(agent.damage - target_agent.defense, 0)
    damage = min(damage, target_agent.endurance)
    target_agent.endurance -= damage
    agent.weapon.ammo -= 1
    agent.commenced_action = True

    if target_agent.endurance <= 0:
        # Remove the destroyed target from the map.
        node = self.get_node(target_agent.pos)
        node.agent_id = None
        node.team_id = -1

        if self.teams[self.enemy_id].alive_count == 0:
            return damage, True

    return damage, False
|
||
|
|
||
|
def _interact(self, agent_id, target_id):
    """Board ``agent_id`` onto the adjacent own carrier ``target_id``.

    This follows the boarding branch of ``_move``: the agent leaves its map
    tile, takes the carrier's position and is appended to
    ``parked_agents``. It is then marked as carried and as having acted.

    Returns:
        The reward for the interaction (always 0).
    """
    agent = self.get_agent(agent_id)
    target = self.get_agent(target_id)

    # Free the boarding agent's tile; otherwise it keeps blocking that tile
    # while it is inside the carrier.
    nodes = self.map.nodes
    if agent.pos in nodes and nodes[agent.pos].agent_id == agent_id:
        nodes[agent.pos].agent_id = None
        nodes[agent.pos].team_id = -1
    agent.pos = target.pos

    target.parked_agents.append(agent)
    agent.is_carried = True
    agent.commenced_action = True
    return 0  # the reward could later become a linear combination of the supplies
|
||
|
|
||
|
def _release(self, agent_id: str, target_idx: int, des: Tuple[int, int]):
    """Unload ``parked_agents[target_idx]`` of carrier ``agent_id`` onto tile ``des``.

    The released agent is removed from the carrier, placed on ``des`` and
    resupplied by each of the carrier's ``modules``. A non-positive
    ``add_*`` value refills that stat to its maximum.

    Returns:
        The reward for the release (always 0).

    Raises:
        ValueError: If ``des`` is not on the map; the carrier is unchanged.
    """
    agent = self.get_agent(agent_id)
    node = self.get_node(des)  # validate before mutating anything
    # pop() so the agent really leaves the carrier. Otherwise it could be
    # released again and would keep occupying carrier capacity.
    agent_to_release = agent.parked_agents.pop(target_idx)
    agent_to_release.pos = des
    node.team_id = agent_to_release.team_id
    node.agent_id = agent_to_release.agent_id

    for module in agent.modules:
        if module.add_endurance > 0:
            agent_to_release.endurance += min(module.add_endurance, agent_to_release.max_endurance - agent_to_release.endurance)
        else:
            agent_to_release.endurance = agent_to_release.max_endurance
        if agent_to_release.weapon is not None:
            if module.add_ammo > 0:
                agent_to_release.weapon.ammo += min(module.add_ammo, agent_to_release.weapon.max_ammo - agent_to_release.weapon.ammo)
            else:
                agent_to_release.weapon.ammo = agent_to_release.weapon.max_ammo
        if module.add_fuel > 0:
            agent_to_release.fuel += min(module.add_fuel, agent_to_release.max_fuel - agent_to_release.fuel)
        else:
            agent_to_release.fuel = agent_to_release.max_fuel
    agent_to_release.is_carried = False

    return 0  # the reward could later become a linear combination of the supplies
|
||
|
|
||
|
def _switch_weapon(self, agent_id: str, weapon_idx: int):
    """Equip ``switchable_weapons[weapon_idx]`` on agent ``agent_id``.

    The previously equipped weapon, if any, is ``reset()`` first.
    ``legal_actions`` offers switching to agents with no weapon
    (``not agent.weapon``), so a missing weapon must not crash here.

    Returns:
        The reward for the switch (always 0).
    """
    agent = self.get_agent(agent_id)
    if agent.weapon is not None:
        agent.weapon.reset()
    agent.weapon = agent.switchable_weapons[weapon_idx]
    return 0
|
||
|
|
||
|
def observe(self):
    """Build the observation dict of the active team.

    The keys are those of ``self.observation_space``:

    * ``action_mask``: 1 at every legal action id.
    * ``union_*`` planes: stats of the active team's agents.
    * ``enemy_*`` planes: stats of the spotted enemies.

    Each stat is written at ``[pos[0], pos[1]]`` of the agent's position.
    """
    obs = {
        desc: np.zeros(space.shape, dtype=space.dtype)
        for desc, space in self.observation_space.items()
    }

    for action_id in self.legal_action_ids:
        obs["action_mask"][action_id] = 1

    for agent in self.current_agents:
        # Skip padding agents at (-1, -1), whose negative indices would write
        # into the opposite corner of the grid. Also skip dead agents and
        # carried agents, which are not standing on a tile of their own.
        if agent.pos == (-1, -1) or not agent.alive or agent.is_carried:
            continue
        # The observation space names the own-side planes "union_*"; the
        # former "player_*" keys do not exist and raised KeyError.
        self._write_agent_planes(obs, "union", agent)

    for agent in self.spotted_enemy_agents:
        self._write_agent_planes(obs, "enemy", agent)

    return obs

@staticmethod
def _write_agent_planes(obs, prefix, agent):
    """Write ``agent``'s stats into the ``<prefix>_*`` planes of ``obs`` at its position."""
    x, y = agent.pos
    obs[f"{prefix}_info_level"][x, y] = agent.info_level
    obs[f"{prefix}_stealth_level"][x, y] = agent.stealth_level
    obs[f"{prefix}_mobility"][x, y] = agent.mobility
    obs[f"{prefix}_defense"][x, y] = agent.defense
    obs[f"{prefix}_damage"][x, y] = agent.damage
    obs[f"{prefix}_fuel"][x, y] = agent.fuel
    obs[f"{prefix}_ammo"][x, y] = agent.ammo
    obs[f"{prefix}_endurance"][x, y] = agent.endurance
|
||
|
|
||
|
def command_move(self, agent_id, des: Union[Tuple[int, int], List[Tuple[int, int]]], start_time=0, subtask=False) -> Command:
    """Plan a move of ``agent_id`` to ``des`` and queue it as per-step MOVE actions.

    Planning starts from the state of the last queued action, or from
    ``agent.state`` when nothing is queued. ``get_path`` finds a route within
    the remaining fuel, avoiding spotted enemies. One MOVE ``Action`` per path
    tile is appended to ``agent.todo``.

    Args:
        agent_id: Agent to command.
        des: A single target position or a list of acceptable targets.
        start_time: Time step of the first queued move.
        subtask: When False, each new todo entry gets its own start time and
            the command's end time, the last entry gets ``end_type = MOVE``,
            and the command is appended to ``agent.cmd_todo``.

    Returns:
        The ``Command`` describing the move, or ``None`` when the planning
        state is already at ``des``.

    NOTE(review): ``state`` is mutated in place and the same object is
    attached to every queued Action (when ``todo`` was empty it is
    ``agent.state`` itself); confirm a per-step copy is not intended.
    NOTE(review): when ``des`` is a single tuple, ``state.pos in des`` tests
    membership among its coordinates.
    """
    agent = self.get_agent(agent_id)
    state = agent.todo[-1].state if agent.todo else agent.state
    if state.pos == des or state.pos in des:
        return

    map_data = self.map.nodes
    origin_todo_length = len(agent.todo) # remember the original todo length
    spotted_enemy_ids = self.spotted_enemy_ids
    adj_pos = get_adj_pos(state.pos, int(state.fuel))
    adj_pos.append(state.pos)
    # Candidate tiles within fuel range, excluding spotted enemies.
    aval_data = {
        pos: node
        for pos in adj_pos if pos in map_data and
        (
            (node := map_data[pos]).team_id != self.enemy_id
            or
            node.agent_id not in spotted_enemy_ids
        )
    }
    path = get_path(agent=agent, map_data=aval_data, start=state.pos, des=des, limit=state.fuel) # there may be no feasible path

    for pos in path:
        state.fuel -= get_cost(agent.move_type, map_data[pos].terrain_type)
        state.pos = pos
        agent.todo.append(Action(agent_id=agent_id, action_type=ActionType.MOVE, des=pos, start_time=start_time, state=state))

    # end_time is the time step of the last queued move.
    command = Command(agent_id=agent_id,
                      action_type=ActionType.MOVE,
                      des=des,
                      start_time=start_time,
                      end_time=start_time+len(agent.todo)-origin_todo_length-1,
                      state=state)

    if not subtask:
        end_time = command.end_time
        for i in range(origin_todo_length, len(agent.todo)):
            agent.todo[i].start_time = start_time + i - origin_todo_length
            agent.todo[i].end_time = end_time
        # NOTE(review): raises IndexError if the path was empty and todo was empty.
        agent.todo[-1].end_type = ActionType.MOVE
        agent.cmd_todo.append(command)

    return command
|
||
|
|
||
|
def command_attack(self, attack_id, target_id, attack_count=1, start_time=0, subtask=False) -> Command:
    """
    Plan an attack of ``attack_id`` on ``target_id`` and queue the needed actions.

    Steps, each queued as subtasks when needed:
      1. switch to a weapon whose ``strike_types`` include the target's
         ``move_type`` and whose ``max_ammo`` covers ``attack_count``;
      2. if the planned weapon lacks ammo, resupply (or switch to a weapon
         that can hold enough);
      3. move to a tile within ``agent.attack_range`` of the target;
      4. queue ``attack_count`` ATTACK actions, each consuming one planned ammo.

    Args:
        attack_id: id of the attacking agent.
        target_id: id of the target agent.
        attack_count: number of ATTACK actions to queue.
        start_time: start time recorded on the queued actions and the command.
        subtask: True when composing a larger command; the caller then does
            the end-of-command bookkeeping.

    Returns:
        The ATTACK ``Command``.

    Raises:
        ValueError: when no suitable weapon, supply or approach exists. Every
            action queued by this call is removed before raising.
    """
    agent = self.get_agent(attack_id)
    target_agent = self.get_agent(target_id)
    origin_todo_length = len(agent.todo)  # queue length before this command

    def rollback():
        # Drop every action this call (and its subtasks) queued.
        agent.todo = agent.todo[:origin_todo_length]

    def can_strike(weapon):
        return bool(weapon.strike_types) and target_agent.move_type in weapon.strike_types

    state = agent.todo[-1].state if agent.todo else agent.state

    # 1. The planned weapon cannot hit this kind of target: switch.
    if target_agent.move_type not in state.weapon.strike_types:
        switched = False
        no_enough_ammo = False
        for weapon_idx, weapon in enumerate(agent.switchable_weapons):
            if not can_strike(weapon):
                continue
            if weapon.max_ammo < attack_count:
                # Can hit the target but can never hold enough ammo for this
                # volley; remember that and keep looking (previously the code
                # switched to it anyway, making the error below unreachable).
                no_enough_ammo = True
                continue
            try:
                self.command_switch(attack_id, weapon_idx, start_time=start_time, subtask=True)
            except Exception as e:
                rollback()
                raise ValueError("Failed switching to the proper weapon") from e
            switched = True
            break
        if not switched:
            rollback()
            if no_enough_ammo:
                raise ValueError("Agent has the striking weapon but no one with enough ammo")
            raise ValueError("No proper weapon found")

    # 2. Make sure the planned weapon has enough ammo.
    state = agent.todo[-1].state if agent.todo else agent.state
    if state.weapon.ammo < attack_count:
        if state.weapon.max_ammo >= attack_count:  # no switch needed, just resupply
            try:
                self.command_supply(attack_id, start_time=start_time, subtask=True)
            except Exception as e:
                rollback()
                raise ValueError("No enough ammo and failed to get supplies") from e
        else:  # the current weapon can never hold enough: switch
            for weapon_idx, weapon in enumerate(agent.switchable_weapons):
                if can_strike(weapon) and weapon.max_ammo >= attack_count:
                    try:
                        self.command_switch(attack_id, weapon_idx, start_time=start_time, subtask=True)
                    except Exception as e:
                        rollback()
                        raise ValueError("Failed switching to the proper weapon") from e
                    break
            else:
                rollback()
                raise ValueError("No proper weapon has enough ammo")

    # 3. Move within attack range, avoiding tiles held by spotted enemies.
    # NOTE(review): the range uses `agent.attack_range` (current weapon), not
    # the weapon planned above -- confirm this is intended.
    map_data = self.map.nodes
    spotted_enemy_ids = self.spotted_enemy_ids
    possible_des = [
        pos for pos in get_adj_pos(target_agent.pos, int(agent.attack_range))
        if pos in map_data and
        (
            (node := map_data[pos]).team_id != self.enemy_id
            or
            node.agent_id not in spotted_enemy_ids
        )
    ]
    try:
        self.command_move(attack_id, possible_des, start_time=start_time, subtask=True)
    except Exception as e:
        rollback()
        raise ValueError("Failed to move to the target") from e

    # 4. Queue the attacks.
    state = agent.todo[-1].state if agent.todo else agent.state
    for _ in range(attack_count):
        state.weapon.ammo -= 1
        agent.todo.append(Action(agent_id=attack_id,
                                 action_type=ActionType.ATTACK,
                                 target_id=target_id,
                                 des=target_agent.pos,
                                 start_time=start_time,
                                 state=state))

    command = Command(agent_id=attack_id,
                      action_type=ActionType.ATTACK,
                      target_id=target_id,
                      des=target_agent.pos,
                      attack_count=attack_count,
                      start_time=start_time,
                      end_time=start_time + len(agent.todo) - origin_todo_length - 1,
                      state=state)

    if not subtask:
        # Top-level command: per-step start times, shared end time, end marker.
        end_time = command.end_time
        for i in range(origin_todo_length, len(agent.todo)):
            agent.todo[i].start_time = start_time + i - origin_todo_length
            agent.todo[i].end_time = end_time
        agent.todo[-1].end_type = ActionType.ATTACK
        agent.cmd_todo.append(command)

    return command
|
||
|
|
||
|
def command_supply(self, agent_id, start_time=0, subtask=False) -> Command:
    """
    Move the agent next to a friendly unit with a supply module and resupply it.

    Candidate destinations are the on-map neighbours of every unit in the
    agent's team with ``has_supply``. After the (optional) move, each adjacent
    friendly supply unit's module is applied to the planned state: endurance,
    weapon ammo and fuel grow by the module's ``add_*`` amount, capped at the
    maximum; a non-positive ``add_*`` refills that quantity completely.

    Args:
        agent_id: id of the agent to resupply.
        start_time: start time recorded on the queued actions and the command.
        subtask: True when composing a larger command; the caller then does
            the end-of-command bookkeeping.

    Returns:
        The SUPPLY ``Command``.

    Raises:
        ValueError: if the team has no reachable supply unit, or no supply unit
            is adjacent to the planned position. Actions queued by this call are
            removed first. Errors from ``command_move`` propagate unchanged.
    """
    agent = self.get_agent(agent_id)
    player_agents = self.teams[agent.team_id].agents.values()
    map_data = self.map.nodes
    origin_todo_length = len(agent.todo)  # queue length before this command

    def refill(current, maximum, amount):
        # A non-positive amount means "refill to the maximum".
        return min(maximum, current + amount) if amount > 0 else maximum

    possible_des = [
        pos
        for _agent in player_agents if _agent.has_supply
        for pos in get_adj_pos(_agent.pos, 1) if pos in map_data
    ]
    if not possible_des:
        raise ValueError("No supply unit available")

    self.command_move(agent_id, possible_des, start_time=start_time, subtask=True)

    # Use the *planned* position: the last queued action may be an attack
    # whose `des` is the target's tile, not where the agent stands.
    state = agent.todo[-1].state if agent.todo else agent.state
    des = state.pos

    supplied = False
    for pos in get_adj_pos(des, 1):
        if pos in map_data and map_data[pos].team_id == agent.team_id and self.get_agent(map_data[pos].agent_id).has_supply:
            supply_module = self.get_agent(map_data[pos].agent_id).supply
            state.endurance = refill(state.endurance, agent.max_endurance, supply_module.add_endurance)
            state.weapon.ammo = refill(state.weapon.ammo, state.weapon.max_ammo, supply_module.add_ammo)
            state.fuel = refill(state.fuel, agent.max_fuel, supply_module.add_fuel)
            supplied = True

    if not supplied:
        # Previously this surfaced as an UnboundLocalError on `state`.
        agent.todo = agent.todo[:origin_todo_length]
        raise ValueError(f"No supply unit adjacent to {des}")

    command = Command(agent_id=agent_id,
                      action_type=ActionType.SUPPLY,
                      des=des,
                      start_time=start_time,
                      end_time=start_time + len(agent.todo) - origin_todo_length - 1,
                      state=state)

    if not subtask:
        # Top-level command: per-step start times, shared end time, end marker.
        end_time = command.end_time
        for i in range(origin_todo_length, len(agent.todo)):
            agent.todo[i].start_time = start_time + i - origin_todo_length
            agent.todo[i].end_time = end_time
        agent.todo[-1].end_type = ActionType.SUPPLY
        agent.cmd_todo.append(command)

    return command
|
||
|
|
||
|
def command_switch(self, agent_id, weapon_idx, start_time=0, subtask=False) -> Command:
    """
    Plan a switch of ``agent_id`` to ``agent.switchable_weapons[weapon_idx]``.

    The agent is first resupplied via ``command_supply``, which moves it next
    to a friendly supply unit. A single SWITCH_WEAPON action is then queued.

    Args:
        agent_id: id of the agent.
        weapon_idx: index into ``agent.switchable_weapons``.
        start_time: start time recorded on the queued actions and the command.
        subtask: True when composing a larger command; the caller then does
            the end-of-command bookkeeping.

    Returns:
        The SWITCH_WEAPON ``Command``.

    Raises:
        IndexError: for an out-of-range ``weapon_idx``; raised before anything
            is queued.
        Errors from ``command_supply`` propagate unchanged.
    """
    import copy

    agent = self.get_agent(agent_id)
    # Resolve the weapon before queuing the supply trip, so a bad index
    # cannot leave half a plan behind.
    weapon = agent.switchable_weapons[weapon_idx]
    origin_todo_length = len(agent.todo)  # queue length before this command

    # Forward start_time so the supply sub-plan is timed like the rest.
    self.command_supply(agent_id, start_time=start_time, subtask=True)

    state = agent.todo[-1].state if agent.todo else agent.state
    # Make the planned state carry the new weapon, so later planning sees it
    # (command_attack re-reads state.weapon after a switch to check ammo and to
    # decrement it). The copy keeps that planning-time ammo bookkeeping off the
    # agent's own weapon object.
    state.weapon = copy.copy(weapon)
    agent.todo.append(Action(agent_id=agent_id,
                             action_type=ActionType.SWITCH_WEAPON,
                             weapon_idx=weapon_idx,
                             weapon_name=weapon.name,
                             des=state.pos,
                             start_time=start_time,
                             state=state))

    command = Command(agent_id=agent_id,
                      action_type=ActionType.SWITCH_WEAPON,
                      weapon_idx=weapon_idx,
                      weapon_name=weapon.name,
                      des=state.pos,
                      start_time=start_time,
                      end_time=start_time + len(agent.todo) - origin_todo_length - 1,
                      state=state)

    if not subtask:
        # Top-level command: per-step start times, shared end time, end marker.
        end_time = command.end_time
        for i in range(origin_todo_length, len(agent.todo)):
            agent.todo[i].start_time = start_time + i - origin_todo_length
            agent.todo[i].end_time = end_time
        agent.todo[-1].end_type = ActionType.SWITCH_WEAPON
        agent.cmd_todo.append(command)

    return command
|
||
|
|
||
|
def make_plan(self, command: Command, subtask=False):
    """
    Plan ``command`` by handing it to the ``command_*`` planner for its type.

    Args:
        command: the command to plan (MOVE, ATTACK, SUPPLY or SWITCH_WEAPON).
        subtask: forwarded unchanged to the planner.

    Returns:
        Whatever the selected planner returns.

    Raises:
        ValueError: if ``command.action_type`` has no planner.
    """
    # Lazily-evaluated planners, so only the chosen one reads its fields.
    planners = {
        ActionType.MOVE: lambda: self.command_move(
            command.agent_id, command.des,
            start_time=command.start_time, subtask=subtask),
        ActionType.ATTACK: lambda: self.command_attack(
            command.agent_id, command.target_id,
            attack_count=command.attack_count,
            start_time=command.start_time, subtask=subtask),
        ActionType.SUPPLY: lambda: self.command_supply(
            command.agent_id,
            start_time=command.start_time, subtask=subtask),
        ActionType.SWITCH_WEAPON: lambda: self.command_switch(
            command.agent_id, command.weapon_idx,
            start_time=command.start_time, subtask=subtask),
    }
    planner = planners.get(command.action_type)
    if planner is None:
        raise ValueError(f"Invalid command type: {command.action_type}")
    return planner()
|
||
|
|
||
|
def todo_action(self, agent_id: str, retry=0) -> Optional[Action]:
    """
    Pop the agent's next queued action, re-planning its command if it is invalid.

    A valid action is returned directly. If it is the final action of the
    command at the head of ``cmd_todo`` (its ``end_type`` equals that
    command's ``action_type``), that command is popped as well. An invalid
    action makes the method re-plan the head command from the agent's current
    state, keep the following commands queued behind it, and try again. Once
    ``retry`` exceeds 2, the action is returned even if it is invalid.

    Args:
        agent_id: id of the agent.
        retry: number of re-planning attempts already made.

    Returns:
        The next ``Action``, or ``None`` when nothing is queued or an invalid
        RELEASE has no free adjacent tile.

    Raises:
        Whatever re-planning (``make_plan``) raises; the queued follow-up
        actions and commands are restored first.
    """
    agent = self.get_agent(agent_id)
    if not agent.todo:
        return None

    action = agent.todo.pop(0)
    valid, msg = self.check_validity(action)

    if valid or retry > 2:
        if agent.cmd_todo and agent.cmd_todo[0].action_type == action.end_type:
            agent.cmd_todo.pop(0)
        return action

    print(f"Retrying to plan actions for {agent_id}: {msg}")

    if action.action_type == ActionType.RELEASE:
        # Release into the first free neighbouring tile instead.
        for pos in get_adj_pos(action.des, 1):
            if pos in self.map.nodes and not self.map.nodes[pos].agent_id:
                action.des = pos
                return action
        # NOTE(review): the popped RELEASE action is not re-queued here;
        # confirm that "wait" is meant to discard it.
        return None  # Wait

    command = agent.cmd_todo.pop(0)

    # Discard the rest of the failed command's actions, up to and including
    # its final (end-marked) action -- unless the action that just failed was
    # that final action, in which case the queue already starts with the next
    # command and must be left alone.
    if action.end_type != command.action_type:
        while agent.todo and agent.todo[0].end_type != command.action_type:
            agent.todo.pop(0)
        if agent.todo:
            agent.todo.pop(0)

    # Set the following commands aside so that the re-planned command starts
    # from the agent's current state and is queued first; previously the
    # queues were copied but not cleared, so the re-plan was appended behind
    # them and the remaining entries were duplicated.
    remaining_actions = agent.todo.copy()
    remaining_commands = agent.cmd_todo.copy()
    agent.todo.clear()
    agent.cmd_todo.clear()
    try:
        self.make_plan(command)
    except Exception:
        agent.todo = remaining_actions
        agent.cmd_todo = remaining_commands
        raise
    agent.todo.extend(remaining_actions)
    agent.cmd_todo.extend(remaining_commands)
    return self.todo_action(agent_id, retry=retry + 1)
|
||
|
|
||
|
def check_validity(self, action: Action) -> Tuple[bool, str]:
    """
    Check whether ``action`` can be executed in the current state.

    END_OF_TURN is always valid. Any other action must belong to an agent of
    the current player, be in that agent's ``action_space``, and satisfy the
    per-type rules below (MOVE reachability, ATTACK weapon/ammo/range,
    SWITCH_WEAPON next to a friendly supply unit, RELEASE onto a free adjacent
    on-map tile, INTERACT with an adjacent allied carrier that has room).

    Args:
        action: the action to check.

    Returns:
        ``(True, "OK")`` when valid, otherwise ``(False, reason)``.
    """
    if action.action_type == ActionType.END_OF_TURN:
        return True, "OK"

    map_data = self.map.nodes
    spotted_enemy_ids = self.spotted_enemy_ids
    if action.agent_id not in self.teams[self.player_id].agents:
        return False, f"Invalid agent id: {action.agent_id}"

    agent = self.get_agent(action.agent_id)

    if action.action_type not in agent.action_space.keys():
        return False, f"Invalid action {action.action_type.name} for agent {agent.agent_id[:8]}"

    if action.action_type == ActionType.MOVE:
        des = action.des
        if des not in map_data:
            # Previously a KeyError on `map_data[des]`.
            return False, f"Destination {des} is not on the map"

        node = map_data[des]

        # An allied tile can only be entered if that ally can carry this agent.
        if node.team_id == agent.team_id and (agent.agent_type not in (target := self.get_agent(node.agent_id)).available_types or len(target.parked_agents) >= target.capacity):
            return False, f"{node.agent_id[:8]} at {des} is an ally but not available"

        if node.agent_id in spotted_enemy_ids:
            return False, f"{des} has been occupied by enemy {node.agent_id[:8]}"

        distance = get_axial_dis(agent.pos, des)
        if distance > agent.mobility:
            return False, f"Destination is out of range. Mobility: {agent.mobility} Axial distance: {distance}"

        try:
            # Reachability within one step: search over tiles in range that are
            # not held by spotted enemies, limited by min(fuel, mobility).
            mobility = min(agent.fuel, agent.mobility)
            adj_pos = get_adj_pos(agent.pos, int(mobility))
            aval_data = {
                pos: node
                for pos in adj_pos if pos in map_data and
                (
                    (node := map_data[pos]).team_id != self.enemy_id
                    or
                    node.agent_id not in spotted_enemy_ids
                )
            }
            aval_nodes = astar_search(aval_data, agent.move_type, start=agent.pos, limit=mobility)
            if action.des in aval_nodes:
                return True, "OK"
            return False, f"{des} could not be reached in one step."
        except Exception as e:
            return False, f"Failed to move to {des}: {e}"

    elif action.action_type == ActionType.ATTACK:
        try:
            target_agent = self.get_agent(action.target_id)
        except Exception:
            return False, f"Invalid target agent id: {action.target_id}, ignored"

        des = action.des

        if agent.team_id == target_agent.team_id:
            return False, f"Cannot attack teammate. Attack: {action.agent_id}. Defend: {action.target_id}"
        if agent.weapon is None:
            return False, "No weapon equipped"
        if agent.ammo <= 0:
            return False, "No ammo left"
        if target_agent.move_type not in agent.strike_types:
            return False, "Target agent cannot be attacked with this weapon"
        if get_axial_dis(agent.pos, des) > agent.attack_range:
            return False, f"Target agent is out of range. Attack range: {agent.attack_range} Distance: {get_axial_dis(agent.pos, des)}"

    elif action.action_type == ActionType.SWITCH_WEAPON:
        weapon_idx = action.weapon_idx
        around_supply = False
        for pos in get_adj_pos(agent.pos, 1):
            if pos in map_data and map_data[pos].agent_id and (target_agent := self.get_agent(map_data[pos].agent_id)).team_id == agent.team_id and target_agent.has_supply:
                around_supply = True
        if not around_supply:
            return False, "Agent is not around a supply module"
        if weapon_idx >= len(agent.switchable_weapons):
            return False, f"Invalid weapon id {weapon_idx}. Max id: {len(agent.switchable_weapons) - 1}"

    elif action.action_type == ActionType.RELEASE:
        # target_id is an index into the carrier's parked_agents queue.
        if action.target_id >= len(agent.parked_agents):
            return False, f"Trying to release an agent that is not in the queue. Target id: {action.target_id}, Current queue: {agent.parked_agents}"
        target_agent = agent.parked_agents[action.target_id]
        des = action.des
        if get_axial_dis(agent.pos, des) > 1:
            return False, "Only adjacent positions can be released"
        if des not in map_data:
            # Previously a KeyError on `map_data[des]`.
            return False, f"{des} is not on the map"
        if map_data[des].team_id != -1:
            return False, "Cannot release agent to occupied position"

    elif action.action_type == ActionType.INTERACT:
        target_agent = self.get_agent(action.target_id)
        if get_axial_dis(agent.pos, target_agent.pos) > 1:
            return False, "Target agent is out of range"
        if agent.team_id != target_agent.team_id:
            return False, "Cannot interact with enemy agent"
        if agent.agent_type not in target_agent.available_types:
            return False, f"Target agent {target_agent.agent_id[:8]} is not interactable with agent {agent}"
        if len(target_agent.parked_agents) >= target_agent.capacity:
            return False, f"Target agent {target_agent.agent_id[:8]} full"

    return True, "OK"
|
||
|
|
||
|
def update(self, sync_info):
    """
    Synchronise the environment with a snapshot from the real engine.

    Args:
        sync_info: mapping whose ``'units'`` and ``'spottedHostiles'`` lists
            hold per-agent entries with ``agent_id``, ``pos`` (``{'q', 'r'}``),
            ``fuel``, ``endurance``, ``commenced_action``, ``current_weapon``
            and, when a weapon is set, ``ammo``.

    Returns:
        ``(obs, reward, terminated, truncated, info)``. ``reward`` is the total
        endurance lost by synced agents of ``self.enemy_id`` since the last
        update; it is also added to the current player's cumulative reward.
    """
    sync_agents = sync_info['units']
    spotted_enemies = sync_info['spottedHostiles']
    map_data = self.map.nodes

    def detach(agent, agent_id):
        # Unlink the agent from the map node it is recorded at, unless another
        # unit has taken that node over in the meantime.
        if agent.pos != (-1, -1):
            node = map_data[agent.pos]
            if node.agent_id == agent_id:
                node.agent_id = None
                node.team_id = -1

    # Forget where previously spotted enemies were *before* placing the new
    # snapshot. Doing this afterwards (as before) ran on already-updated
    # positions and erased every enemy that is still visible from the map.
    for agent_id in self._spotted_enemy_ids:
        detach(self.get_agent(agent_id), agent_id)

    reward = 0
    for sync_agent in [*sync_agents, *spotted_enemies]:
        agent_id = sync_agent['agent_id']
        agent = self.get_agent(agent_id)
        detach(agent, agent_id)

        agent.pos = (sync_agent['pos']['q'], sync_agent['pos']['r'])
        agent.fuel = sync_agent['fuel']
        if agent.team_id == self.enemy_id:
            reward += agent.endurance - sync_agent['endurance']
        agent.endurance = sync_agent['endurance']
        agent.commenced_action = sync_agent['commenced_action']

        current_weapon = sync_agent['current_weapon']
        if current_weapon:
            for weapon in agent.switchable_weapons:
                if weapon.name == current_weapon:
                    agent.weapon = weapon
                    break
            # NOTE(review): if no switchable weapon matches the name, the
            # previous weapon (possibly None) receives the ammo value.
            agent.weapon.ammo = sync_agent['ammo']
        else:
            agent.weapon = None

        # Re-link living agents to their new node.
        if agent.alive:
            node = map_data[agent.pos]
            node.agent_id = agent_id
            node.team_id = agent.team_id

    self._spotted_enemy_ids = [enemy['agent_id'] for enemy in spotted_enemies]

    # Clear cached derived state before building the observation.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_agents = None

    obs = self.observe()
    truncated = self.episode_steps >= self.max_episode_steps
    terminated = self.teams[self.enemy_id].alive_count == 0
    self._cumulative_rewards[self.player_id] += reward

    info = {}
    info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
    info["next_player"] = self.player_id + 1 if not terminated else -1

    if self.replay_path is not None:
        self._frames[self.player_id].append(self.render(mode="rgb_array"))

    if terminated or truncated:
        # The eval_episode_return is calculated from Player 1's perspective
        info["eval_episode_return"] = reward if self.player_id == 0 else -reward
        info["done_reason"] = "Terminated" if terminated else "Truncated"
        if self.replay_path is not None:
            self.save_replay()

    return obs, reward, terminated, truncated, info
|
||
|
|
||
|
|
||
|
|
||
|
|