yes_cmdr/yes_cmdr_env.py

1286 lines
58 KiB
Python

import numpy as np
import json
import os
from datetime import datetime
from typing import List, Optional
import gymnasium as gym
from .common.utils import *
from .common.agent import *
class YesCmdrEnv(gym.Env):
metadata = {"render_modes": ["human", "rgb_array"], "render_fps": 4}
def __init__(self,
             use_real_engine: bool = False,
             replay_path: Optional[str] = None,
             campaign_id: str = "NOID",
             team_id: int = 0,
             max_team: int = 2,
             max_faction: int = 2,
             data_path: str = "../data",
             max_episode_steps: int = 1000,
             bot_version: str = "v0",
             war_fog: bool = True) -> None:
    """Store the configuration; the world itself is built by the first reset().

    Args:
        use_real_engine: When True, END_OF_TURN does not rotate the active team
            inside this class (see ``_player_step``).
        replay_path: Directory that receives the mp4 replays. ``None`` disables
            frame recording and rendering set-up.
        campaign_id: Stored as-is on the instance.
        team_id: Team that is active right after every reset().
        max_team: Number of teams loaded from the data files.
        max_faction: Number of factions the teams are grouped into.
        data_path: Directory containing ``MapInfo.json`` and ``AgentsInfo.json``.
        max_episode_steps: step() reports truncation once this many turns ended.
        bot_version: Policy used by ``bot_action``; only ``"v0"`` is implemented.
        war_fog: When False every living enemy unit counts as spotted.
    """
    super().__init__()
    # Map, teams and spaces are created lazily on the first reset() call.
    self._init_flag = False
    self.use_real_engine = use_real_engine
    self.replay_path = replay_path
    self.campaign_id = campaign_id
    self.init_team_id = team_id
    self.max_team = max_team
    self.max_faction = max_faction
    self.data_path = data_path
    self.max_episode_steps = max_episode_steps
    self.bot_version = bot_version
    self.war_fog = war_fog
def reset(self, *, seed: Optional[int] = None, options: Optional[dict] = None):
    """Start a new episode and return ``(observation, info)``.

    The first call builds the environment from the data files; later calls put
    every tile and unit back into its initial state and re-register the units
    on the map.

    Args:
        seed: Forwarded to ``gym.Env.reset`` so ``self.np_random`` is seeded as
            the gymnasium API requires.
        options: Accepted for gymnasium API compatibility; currently unused.

    Returns:
        The observation of the starting team and an info dict whose
        ``"next_team"`` entry is the 1-based id of the team to move.
    """
    super().reset(seed=seed)

    if not self._init_flag:
        # The actual environment construction happens on the first reset().
        self._init_flag = True
        self._make_env()
        if self.replay_path is not None:
            from .common.renderer import Renderer
            self._renderer = Renderer(self.map.nodes.values())
    else:
        for node in self.map.nodes.values():
            node.reset()
        for team in self.teams:
            for agent_id, agent in team.agents.items():
                agent.reset()
                if agent.pos == (-1, -1):  # Illegal position for padding
                    continue
                assert agent.pos in self.map.nodes, f"Invalid agent position: {agent.pos}"
                tile = self.map.nodes[agent.pos]
                tile.agent_id = agent_id
                tile.team_id = team.team_id
                tile.faction_id = team.faction_id

    if self.replay_path is not None:
        # One frame buffer per team.
        self._frames = [[] for _ in range(self.max_team)]

    # Drop every per-state cache before the first observation is built.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_enemy_ids = None
    self._spotted_enemy_agents = None
    self._team_id = self.init_team_id
    self._spotted_agents = None
    self.episode_steps = 0
    self._cumulative_rewards = np.zeros(self.max_team)

    obs = self.observe()
    info = {
        "next_team": self.team_id + 1
    }
    return obs, info
def step(self, action: Union[int, Action]):
    """Apply one action for the current team.

    Args:
        action: Either a flat action id (Python or NumPy integer) or an
            ``Action`` object.

    Returns:
        ``(obs, reward, terminated, truncated, info)``. An illegal request is
        not executed: the state is left untouched, the reward is 0 and
        ``info["exception"]`` describes the problem.

    A MOVE onto an occupied tile is redirected to the legal MOVE of the same
    agent whose destination is closest to the requested one; if that agent has
    no legal MOVE the request is rejected like any other illegal action.
    """

    def _rejected(message):
        # Same team stays on move; nothing was executed.
        info = {}
        info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
        info["next_team"] = self.team_id + 1
        info["exception"] = message
        return self.observe(), 0, False, False, info

    # NumPy integers (e.g. from np.argmax or Space.sample) are valid ids too.
    if isinstance(action, (int, np.integer)):
        if action not in self.legal_action_ids:
            print(f"Invalid action will be ignored: {action}")
            return _rejected(f"Invalid action {action}")
        action = self.id_to_action(int(action))
    elif isinstance(action, Action):
        valid, reason = self.check_validity(action)
        if not valid:
            fallback = None
            if action.action_type == ActionType.MOVE and "occupied" in reason:
                print("Trying to move to an adjacent position that is empty.")
                # Always measure against the destination that was requested,
                # not against a candidate picked earlier in the search.
                wanted_des = action.des
                moves = [
                    candidate for candidate in self.legal_actions
                    if candidate.agent_id == action.agent_id
                    and candidate.action_type == ActionType.MOVE
                ]
                fallback = min(
                    moves,
                    key=lambda candidate: get_axial_dis(candidate.des, wanted_des),
                    default=None,
                )
            if fallback is None:
                print(f"Invalid action will be ignored: {action}")
                print(f"Reason: {reason}")
                return _rejected(f"Invalid action: {action}. Reason: {reason}")
            action = fallback
    assert isinstance(action, Action)

    reward, terminated = self._player_step(action)

    # NOTE: clear the cached state before building the observation.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_enemy_ids = None
    self._spotted_enemy_agents = None

    obs = self.observe()
    truncated = self.episode_steps >= self.max_episode_steps
    self._cumulative_rewards[self.team_id] += reward

    info = {}
    info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
    info["next_team"] = self.team_id + 1 if not terminated else -1
    if self.replay_path is not None:
        self._frames[self.team_id].append(self.render(mode="rgb_array"))
    if terminated or truncated:
        # The eval_episode_return is calculated from Player 1's perspective
        info["eval_episode_return"] = reward if self.faction_id == self.teams[self.init_team_id].faction_id else -reward
        info["done_reason"] = "Terminated" if terminated else "Truncated"
        if self.replay_path is not None:
            self.save_replay()
    return obs, reward, terminated, truncated, info
def render(self, mode="human", attack_agent=None, defend_agent=None):
    """Render the current state.

    ``"human"`` returns the observation dict; ``"rgb_array"`` delegates to the
    renderer, passing the allied units, the spotted enemy units and the
    optional attacker/defender pair. Any other mode raises ``ValueError``.
    """
    if mode == "human":
        return self.observe()
    if mode != "rgb_array":
        raise ValueError("Invalid render mode: {}".format(mode))
    return self._renderer.render(self.union_agents, self.spotted_enemy_agents, attack_agent, defend_agent)
def action_to_id(self, action: Action) -> int:
# Encode `action` as a flat integer id within the current team's action space.
# Id 0 is reserved for END_OF_TURN. Every other id is 1 plus the summed sizes
# of all (agent, action_type) sub-spaces that precede the action's own
# sub-space in iteration order, plus an offset inside that sub-space.
# Raises ValueError when the agent or the action type is not part of the
# action space.
if action.action_type == ActionType.END_OF_TURN:
return 0
action_space = self.action_space
if action.agent_id not in action_space.keys():
raise ValueError(f"Invalid agent_id: {action.agent_id[:8]}")
agent_action_space = action_space[action.agent_id]
if action.action_type not in agent_action_space.keys():
# Debug output printed before the error is raised.
agent = self.get_agent(action.agent_id)
print(agent.available_types, agent.parked_agents)
print(action)
raise ValueError(f"Invalid action_type: {action.action_type.name} for agent {action.agent_id[:8]}. Available action_types: {agent_action_space.keys()}")
action_id = 1
# NOTE: the loop variables below rebind `agent_action_space` and
# `action_space`; iteration continues over the iterators created from the
# outer objects before the rebinding.
for agent_id, agent_action_space in action_space.items():
for action_type, action_space in agent_action_space.items():
if agent_id != action.agent_id or action_type != action.action_type:
# Not the target sub-space yet: skip over all of its ids.
action_id += action_space.n
else:
if action.action_type == ActionType.SWITCH_WEAPON: # encoded by index, not by position
assert action.weapon_idx < action_space.n, f"Invalid weapon_idx: {action.weapon_idx}, action_space.n: {action_space.n}"
action_id += action.weapon_idx
elif action.action_type == ActionType.RELEASE:
assert action.target_id < action_space.n, f"Invalid target_idx: {action.target_id}, action_space.n: {action_space.n}"
action_id += action.target_id # this is an index
else:
# Positional actions are encoded by the axial offset from the
# agent's position to the destination; the code is shifted by
# one (id_to_action decodes with `+ 1`).
pos, des = self.get_agent(action.agent_id).pos, action.des
encode_id = encode_axial((des[0] - pos[0], des[1] - pos[1]))
action_id += encode_id - 1
# `cur_range` is only used in the assertion message below.
# NOTE(review): it stays unbound for any other action type, which
# would turn a failing assertion into a NameError - confirm.
if action.action_type == ActionType.MOVE:
cur_range = self.get_agent(action.agent_id).mobility
elif action.action_type == ActionType.ATTACK:
cur_range = self.get_agent(action.agent_id).attack_range
elif action.action_type == ActionType.INTERACT:
cur_range = 1
assert encode_id - 1 < action_space.n, f"Action {action.action_type.name} Distance: {get_axial_dis(pos, des)}, cur_range: {cur_range} encode_id: {encode_id} action_space.n: {action_space.n}\n{action}"
return int(action_id)
assert False, "Should not reach here"
def id_to_action(self, action_id: int) -> Action:
    """Decode a flat action id of the current team into an ``Action``.

    Id 0 is END_OF_TURN. Otherwise the id (minus one) is walked through the
    (agent, action_type) sub-spaces in iteration order until it falls inside
    one of them. Raises ``ValueError`` when the id exceeds every sub-space.
    """
    if action_id == 0:
        return Action(ActionType.END_OF_TURN)

    remaining = action_id - 1
    for agent_id, spaces_by_type in self.action_space.items():
        for action_type, sub_space in spaces_by_type.items():
            if remaining >= sub_space.n:
                # Not in this sub-space: consume its ids and keep looking.
                remaining -= sub_space.n
                continue
            if action_type == ActionType.SWITCH_WEAPON:
                return Action(agent_id=agent_id, action_type=action_type, weapon_idx=remaining, weapon_name=self.get_agent(agent_id).weapon.name)
            if action_type == ActionType.RELEASE:
                return Action(agent_id=agent_id, target_id=remaining, action_type=action_type)
            origin = self.get_agent(agent_id).pos
            # Code 0 is the agent's own position, hence the shift by one.
            delta = decode_axial(remaining + 1)
            des = (origin[0] + delta[0], origin[1] + delta[1])
            return Action(agent_id=agent_id, des=des, action_type=action_type)
    raise ValueError("Invalid action_id: {}".format(remaining))
def get_agent(self, agent_id: str) -> Agent:
    """Return the unit with ``agent_id`` from whichever team owns it.

    No liveness check is performed. Raises ``ValueError`` for an id that
    belongs to no team.
    """
    for team in self.teams:
        roster = team.agents
        if agent_id in roster.keys():
            return roster[agent_id]
    raise ValueError(f"Invalid agent id: {agent_id}")
def get_node(self, pos: Tuple[int, int]) -> TileNode:
    """Return the map tile at ``pos``; raise ``ValueError`` if it is off the map."""
    tiles = self.map.nodes
    if pos in tiles:
        return tiles[pos]
    raise ValueError(f"Invalid position: {pos}")
@property
def legal_actions(self) -> List[Action]:
# All actions the current team may take in the present state.
# The result is cached in self._legal_actions and reused until that attribute
# is set back to None. The list always starts with END_OF_TURN, followed per
# eligible agent by its MOVE, ATTACK, INTERACT, RELEASE and SWITCH_WEAPON
# candidates.
if self._legal_actions is not None:
return self._legal_actions
map_data = self.map.nodes
player_data = self.current_agents_dict
spotted_enemy_ids = self.spotted_enemy_ids
_legal_actions = [Action(ActionType.END_OF_TURN)] # Default action
for agent_id, agent in player_data.items():
# Skip agents that already acted, are dead, are padding entries or are carried.
if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried:
continue
pos = agent.pos
# Move: the reachable range is limited by both remaining fuel and mobility.
mobility = min(agent.fuel, agent.mobility)
if mobility > 0:
adj_pos = get_adj_pos(pos, int(mobility))
# Tiles the path search may use: on-map tiles that belong to the own
# faction or are not occupied by a spotted enemy.
aval_data = {
pos: node
for pos in adj_pos if pos in map_data and
(
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
}
aval_nodes = astar_search(aval_data, agent.move_type, start=pos, limit=mobility)
for des in aval_nodes:
if des == pos:
continue
node = map_data[des]
if node.faction_id != self.faction_id and node.agent_id not in spotted_enemy_ids:
# A tile can hold only one unit
_legal_actions.append(Action(
agent_id=agent_id,
des=des,
action_type=ActionType.MOVE
))
elif node.faction_id == self.faction_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \
and len(target.parked_agents) < target.capacity:
# ... unless the destination holds a unit that can carry this agent
_legal_actions.append(Action(
agent_id=agent_id,
target_id=node.agent_id,
des=des,
action_type=ActionType.MOVE
))
# Attack: needs a positive attack range and remaining ammo.
if agent.attack_range > 0 and agent.ammo > 0:
for des in get_adj_pos(pos, agent.attack_range):
if des not in map_data:
continue
node = map_data[des]
if node.agent_id in spotted_enemy_ids and self.get_agent(node.agent_id).move_type in agent.strike_types:
# Enemy already spotted
_legal_actions.append(Action(
agent_id=agent_id,
target_id=node.agent_id,
action_type=ActionType.ATTACK,
des=des
))
# Interact: board a same-faction unit within distance 1 that accepts this
# agent type and still has free capacity.
for des in get_adj_pos(pos, 1):
if des not in map_data:
continue
node = map_data[des]
if node.faction_id == self.faction_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \
and len(target.parked_agents) < target.capacity:
_legal_actions.append(Action(
agent_id=agent_id,
target_id=node.agent_id,
des=des,
action_type=ActionType.INTERACT
))
# Release: put a parked unit onto a tile within distance 1 that has no team on it.
if agent.parked_agents:
for des in get_adj_pos(pos, 1):
if des not in map_data:
continue
node = map_data[des]
if node.team_id != -1:
continue
for idx, agent_to_release in enumerate(agent.parked_agents):
# presumably a cost of 0 marks terrain this move type cannot enter - TODO confirm
if get_cost(agent_to_release.move_type, node.terrain_type) != 0:
# target_id carries the index into parked_agents, not an agent id.
_legal_actions.append(Action(
agent_id=agent_id,
target_id=idx,
des=des,
action_type=ActionType.RELEASE
))
break
# Switch weapon: requires a same-faction unit with supply within distance 1.
if agent.switchable_weapons:
for des in get_adj_pos(pos, 1):
if des not in map_data:
continue
node = map_data[des]
if node.faction_id == self.faction_id and self.get_agent(node.agent_id).has_supply:
for weapon_idx, weapon in enumerate(agent.switchable_weapons):
# Offer every weapon except the one currently equipped.
if not agent.weapon or weapon.name != agent.weapon.name:
_legal_actions.append(Action(
agent_id=agent_id,
target_id=weapon.name,
weapon_idx=weapon_idx,
action_type=ActionType.SWITCH_WEAPON
))
self._legal_actions = _legal_actions
return self._legal_actions
@property
def legal_action_ids(self) -> List[int]:
    """Flat ids of ``legal_actions``, cached until ``_legal_action_ids`` is reset."""
    cached = self._legal_action_ids
    if cached is None:
        cached = [self.action_to_id(candidate) for candidate in self.legal_actions]
        self._legal_action_ids = cached
    return cached
@property
def spotted_enemy_ids(self) -> List[str]:
    """Ids of the enemy units the current team can see.

    Without fog of war every living enemy is visible. With fog, an enemy is
    visible when it stands within a scout's ``info_level`` range and its
    ``stealth_level`` is below ``info_level - distance + 1``; enemies at
    distance 1 are always visible. The fog result is cached until
    ``_spotted_enemy_ids`` is reset to None.

    Dead units and padding units parked at ``(-1, -1)`` do not scout, tiles
    without an enemy unit are ignored, and each enemy id is reported once.
    """
    if not self.war_fog:
        return [enemy_id for enemy_id, enemy_agent in self.enemy_agents_dict.items() if enemy_agent.alive]
    if self._spotted_enemy_ids is not None:
        return self._spotted_enemy_ids

    enemy_data = self.enemy_agents_dict
    map_data = self.map.nodes
    spotted = {}  # insertion-ordered set of enemy ids
    for agent in self.current_agents:
        if not agent.alive or agent.pos == (-1, -1):
            continue
        info_level = agent.info_level
        for des in get_adj_pos(agent.pos, info_level):
            if des not in map_data:
                continue
            node = map_data[des]
            if node.faction_id == self.faction_id:
                continue
            # Empty tiles also differ from our faction id; only tiles that
            # actually hold an enemy unit can be spotted.
            enemy = enemy_data.get(node.agent_id)
            if enemy is None:
                continue
            dis = get_axial_dis(agent.pos, des)
            if dis == 1:  # Adjacent agents are always spotted
                dis = -1e9
            if enemy.stealth_level < info_level - dis + 1:
                spotted[node.agent_id] = None
    self._spotted_enemy_ids = list(spotted)
    return self._spotted_enemy_ids
def save_replay(self):
    """Write one mp4 per team with the frames recorded during the episode.

    Files are named ``tianqiong_<timestamp>_Team<N>.mp4`` (N is 1-based) and
    placed in ``self.replay_path``, which is created when missing.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(self.replay_path, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    for team_id in range(self.max_team):
        path = os.path.join(
            self.replay_path,
            f"tianqiong_{timestamp}_Team{team_id + 1}.mp4"
        )
        self.display_frames_as_mp4(self._frames[team_id], path)
        print(f'replay {path} saved!')
@staticmethod
def display_frames_as_mp4(frames: list, path: str, fps=4) -> None:
    """Encode ``frames`` into the mp4 file at ``path`` at ``fps`` frames per second."""
    has_mp4_suffix = path.endswith('.mp4')
    assert has_mp4_suffix, f'path must end with .mp4, but got {path}'
    # Imported here so imageio is only needed when a replay is actually written.
    import imageio
    imageio.mimwrite(path, frames, fps=fps)
def _make_env(self):
    """Load the map and unit data and build teams, spaces and faction rosters.

    Reads ``MapInfo.json`` and ``AgentsInfo.json`` from ``self.data_path``,
    registers every unit on its tile, creates the per-team action and
    observation spaces, checks that action ids round-trip through
    ``id_to_action``/``action_to_id`` for every team, and precomputes the
    allied and enemy unit collections of each faction.
    """
    # Configuration for file paths
    map_path = f"{self.data_path}/MapInfo.json"
    agents_path = f"{self.data_path}/AgentsInfo.json"

    # Load map and agents information. JSON is UTF-8 encoded, so do not rely
    # on the platform's default encoding.
    with open(map_path, 'r', encoding="utf-8") as f:
        origin_map = json.load(f)
    with open(agents_path, 'r', encoding="utf-8") as f:
        origin_agents = json.load(f)

    _nodes = {
        (node := dict_to_node(node_dict)).pos: node for node_dict in origin_map
    }
    agents_dict_lists = [
        [agent_dict for agent_dict in origin_agents if agent_dict["team_id"] == team_id]
        for team_id in range(self.max_team)
    ]
    _agents = [
        {
            (agent := dict_to_agent(agent_dict)).agent_id: agent
            for agent_dict in agents_dict_lists[team_id]
        }
        for team_id in range(self.max_team)
    ]
    # A team takes the faction of its first unit.
    _teams = [
        Team(
            team_id=_team_id,
            faction_id=agents_dict_lists[_team_id][0]["faction_id"],
            agents=_agents[_team_id]
        )
        for _team_id in range(self.max_team)
    ]
    for team in _teams:
        for agent_id, agent in team.agents.items():
            if agent.pos == (-1, -1):  # Illegal position for padding
                continue
            assert agent.pos in _nodes, f"Invalid agent position: {agent.pos}"
            tile = _nodes[agent.pos]
            tile.agent_id = agent_id
            tile.team_id = team.team_id
            tile.faction_id = team.faction_id

    _map = Map(_nodes)
    self.map = _map
    self.teams = _teams
    self.obs_shape = _map.width, _map.height

    # Get action space
    self._action_spaces = [
        gym.spaces.Dict(
            {
                f"team_{team.team_id + 1}": gym.spaces.Dict(
                    {
                        ActionType.END_OF_TURN: gym.spaces.Discrete(1)  # Default action
                    }
                ),
                **{
                    agent_id: agent.action_space for agent_id, agent in team.agents.items()
                }
            }
        )
        for team in self.teams
    ]
    self._action_space_sizes = [
        sum(action_space.n for agent_action_space in self._action_spaces[team_id].values() for action_space in agent_action_space.values())
        for team_id in range(self.max_team)
    ]

    # Self-check: every id must survive a decode/encode round trip. The
    # conversion helpers work on the *current* team, so select each in turn.
    for team_id, action_space_size in enumerate(self._action_space_sizes):
        self._team_id = team_id
        for action_id in range(action_space_size):
            assert self.action_to_id(self.id_to_action(action_id)) == action_id
    # Leave a valid team selected instead of an out-of-range index.
    self._team_id = self.init_team_id

    # Get observation space: one (width, height) plane per unit attribute,
    # once for allied units and once for spotted enemy units.
    plane_dtypes = (
        ("endurance", np.float32),
        ("info_level", np.int16),
        ("stealth_level", np.int16),
        ("mobility", np.float32),
        ("defense", np.int16),
        ("damage", np.float32),
        ("fuel", np.float32),
        ("ammo", np.int16),
    )
    self._observation_spaces = [
        gym.spaces.Dict(
            {
                # Different size for each team
                "action_mask": gym.spaces.Box(0, 1, (self._action_space_sizes[team_id], ), dtype=np.int8),
                **{
                    f"{side}_{name}": gym.spaces.Box(0, 1000, self.obs_shape, dtype=dtype)
                    for side in ("union", "enemy")
                    for name, dtype in plane_dtypes
                }
            }
        )
        for team_id in range(self.max_team)
    ]
    self._reward_space = gym.spaces.Box(low=-1000, high=1000, shape=(1, ), dtype=np.float32)

    # Precompute, per faction, the allied units and all enemy units.
    self._union_agents = [[] for _ in range(self.max_faction)]
    self._union_agents_dict = [{} for _ in range(self.max_faction)]
    self._enemy_agents = [[] for _ in range(self.max_faction)]
    self._enemy_agents_dict = [{} for _ in range(self.max_faction)]
    for team_id in range(self.max_team):
        faction_id = self.teams[team_id].faction_id
        agents_dict = self.teams[team_id].agents
        agents_list = list(agents_dict.values())
        self._union_agents[faction_id].extend(agents_list)
        # Keep the id lookup in sync with the list above.
        self._union_agents_dict[faction_id] |= agents_dict
        for other_faction_id in range(self.max_faction):
            if other_faction_id != faction_id:
                self._enemy_agents[other_faction_id].extend(agents_list)
                self._enemy_agents_dict[other_faction_id] |= agents_dict
@property
def team_id(self):
    """Zero-based index of the team whose turn it is."""
    active_team = self._team_id
    return active_team
@property
def current_agents(self):
    """View over the units of the team whose turn it is."""
    active_team = self.teams[self._team_id]
    return active_team.agents.values()
@property
def current_agents_dict(self):
    """Mapping of agent id to unit for the team whose turn it is."""
    active_team = self.teams[self._team_id]
    return active_team.agents
@property
def union_agents(self):
    """List of all units belonging to the current team's faction."""
    own_faction = self.faction_id
    return self._union_agents[own_faction]
@property
def union_agents_dict(self):
    """Per-faction id lookup for the current team's faction."""
    own_faction = self.faction_id
    return self._union_agents_dict[own_faction]
@property
def enemy_agents(self):
    """List of all units that are enemies of the current team's faction."""
    own_faction = self.faction_id
    return self._enemy_agents[own_faction]
@property
def enemy_agents_dict(self):
    """Mapping of agent id to unit for all enemies of the current team's faction."""
    own_faction = self.faction_id
    return self._enemy_agents_dict[own_faction]
@property
def spotted_enemy_agents(self):
    """Enemy units behind ``spotted_enemy_ids``, cached until reset to None."""
    if self._spotted_enemy_agents is None:
        visible = []
        for enemy_id in self.spotted_enemy_ids:
            visible.append(self.enemy_agents_dict[enemy_id])
        self._spotted_enemy_agents = visible
    return self._spotted_enemy_agents
@property
def faction_id(self):
    """Faction of the team whose turn it is."""
    active_team = self.teams[self._team_id]
    return active_team.faction_id
@property
def observation_space(self):
    """Observation space of the team whose turn it is."""
    active_team = self.team_id
    return self._observation_spaces[active_team]
@property
def action_space(self):
    """Action space of the team whose turn it is."""
    active_team = self.team_id
    return self._action_spaces[active_team]
@property
def action_space_size(self):
    """Number of flat action ids available to the team whose turn it is."""
    active_team = self.team_id
    return self._action_space_sizes[active_team]
@property
def reward_space(self):
    """Space describing the per-step reward (shared by all teams)."""
    space = self._reward_space
    return space
def random_action(self) -> Action:
    """Pick one of the currently legal actions uniformly at random."""
    candidates = self.legal_actions
    return np.random.choice(candidates)
def next_turn(self):
    """Hand the turn to the next team, wrapping around after the last one."""
    following = self._team_id + 1
    self._team_id = following % self.max_team
def bot_action(self) -> Action:
    """Choose an action for the scripted bot.

    Version ``"v0"`` attacks at random whenever an attack is legal and
    otherwise falls back to ``random_action``. Any other version raises
    ``NotImplementedError``.
    """
    if self.bot_version != "v0":
        raise NotImplementedError(f"Invalid bot version: {self.bot_version}")
    attack_actions = [candidate for candidate in self.legal_actions if candidate.action_type == ActionType.ATTACK]
    if not attack_actions:
        return self.random_action()
    return np.random.choice(attack_actions)
def _player_step(self, action: Action):
# Execute one already-validated action and return (reward, terminated).
# The action is dispatched on its type; afterwards the units of the team that
# is current at that point receive supplies from neighbouring supply units.
# Raises ValueError for an unknown action type.
if action.action_type == ActionType.END_OF_TURN:
"""
NOTE: here exchange the player
"""
# Clear the per-turn "already acted" flags
for agent in self.current_agents:
agent.commenced_action = False
# Switch player
self.episode_steps += 1
# With the real engine the active team is not rotated here.
if not self.use_real_engine:
self.next_turn()
reward, terminated = 0, False # a penalty term could be added here
elif action.action_type == ActionType.MOVE:
reward = self._move(action.agent_id, action.des, action.target_id)
terminated = False
elif action.action_type == ActionType.ATTACK:
reward, terminated = self._attack(action.agent_id, action.target_id)
elif action.action_type == ActionType.INTERACT:
reward = self._interact(action.agent_id, action.target_id)
terminated = False
elif action.action_type == ActionType.RELEASE:
reward = self._release(action.agent_id, action.target_id, action.des)
terminated = False
elif action.action_type == ActionType.SWITCH_WEAPON:
reward = self._switch_weapon(action.agent_id, action.weapon_idx)
terminated = False
else:
raise ValueError(f"Invalid action type: {action.action_type}")
# Resupply pass. `self.current_agents` is evaluated here, i.e. after a
# possible next_turn() above, so after END_OF_TURN it refers to the team that
# has just become active.
for agent in self.current_agents:
for pos in get_adj_pos(agent.pos, 1):
if pos in self.map.nodes and self.map.nodes[pos].faction_id == self.faction_id:
adj_agent = self.get_agent(self.map.nodes[pos].agent_id)
if not adj_agent.has_supply:
continue
module = adj_agent.supply
# A positive add_* value tops up by that amount, capped at the maximum;
# any other value refills to the maximum.
if module.add_endurance > 0:
agent.endurance += min(module.add_endurance, agent.max_endurance - agent.endurance)
else:
agent.endurance = agent.max_endurance
if agent.weapon is not None:
if module.add_ammo > 0:
agent.weapon.ammo += min(module.add_ammo, agent.weapon.max_ammo - agent.weapon.ammo)
else:
agent.weapon.ammo = agent.weapon.max_ammo
if module.add_fuel > 0:
agent.fuel += min(module.add_fuel, agent.max_fuel - agent.fuel)
else:
agent.fuel = agent.max_fuel
return reward, terminated
def _move(self, agent_id, des, target_id):
# Move one unit of the current team towards `des` along an A* path and
# return the reward (always 0). If the unit ends on a tile that holds
# `target_id`, it is parked inside that carrier instead of occupying the tile.
map_data = self.map.nodes
agent = self.current_agents_dict[agent_id]
spotted_enemy_ids = self.spotted_enemy_ids
origin_pos = agent.pos
# Detach the unit from its origin tile
node = map_data[agent.pos]
node.agent_id = None
node.team_id = -1
node.faction_id = -1
# The path length is limited by both mobility and remaining fuel.
mobility = min(agent.mobility, agent.fuel)
adj_pos = get_adj_pos(agent.pos, int(mobility))
# Tiles usable for path finding: on-map tiles of the own faction or tiles
# not occupied by a spotted enemy.
aval_data = {
pos: node
for pos in adj_pos if pos in map_data and
(
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
}
path = astar_search(aval_data, move_type=agent.move_type, start=agent.pos, goal=des, limit=mobility)
path = [agent.pos] + path
exception = False
for cur, nxt in zip(path[:-1], path[1:]):
# NOTE(review): this test is true for every tile whose faction_id differs
# from ours; the origin tile was just set to faction_id -1 above, so if
# empty tiles generally carry -1 this also stops on empty tiles under fog
# of war - confirm the intended condition.
if self.war_fog and map_data[nxt].faction_id != self.faction_id: # enemy encountered on the way: stop
exception = True
break
if self.replay_path is not None:
self._frames[self.team_id].append(self.render(mode="rgb_array"))
agent.fuel -= get_cost(agent.move_type, map_data[nxt].terrain_type)
agent.pos = nxt
# Update the position and its association with the map
cur = agent.pos
agent.commenced_action = True
if map_data[cur].team_id == -1:
# Free tile: the unit occupies it.
map_data[cur].agent_id = agent_id
map_data[cur].team_id = self.team_id
map_data[cur].faction_id = self.faction_id
elif map_data[cur].agent_id == target_id:
# Tile holds the requested carrier: board it.
carry_agent = self.get_agent(target_id)
assert agent.faction_id == carry_agent.faction_id
assert agent.agent_type in carry_agent.available_types, f"{agent} not in {carry_agent.available_types} (id: {carry_agent.agent_id[:8]}). "
# NOTE(review): the message slices the carrier object itself
# (`carry_agent[:8]`); `carry_agent.agent_id[:8]` was probably intended.
assert len(carry_agent.parked_agents) < carry_agent.capacity, f"Agent {carry_agent[:8]} is available but full"
carry_agent.parked_agents.append(agent)
agent.is_carried = True
else:
# Any other outcome is only expected after an interrupted move.
assert exception
return 0
def _attack(self, attack_id, target_id):
    """Resolve one attack and return ``(damage, terminated)``.

    Damage is the attacker's damage minus the defender's defense, clamped to
    the range ``[0, defender.endurance]``. The attacker spends one round of
    ammo and is marked as having acted. A destroyed defender is removed from
    its tile; the episode terminates when no enemy-faction team has living
    units left.
    """
    attacker = self.get_agent(attack_id)
    defender = self.get_agent(target_id)
    if self.replay_path is not None:
        # Hold the attack frame for four frames in the replay.
        self._frames[self.team_id].extend([self.render(mode="rgb_array", attack_agent=attacker, defend_agent=defender)] * 4)

    damage = max(attacker.damage - defender.defense, 0)
    damage = min(damage, defender.endurance)
    defender.endurance -= damage
    attacker.weapon.ammo -= 1
    attacker.commenced_action = True

    if defender.endurance > 0:
        return damage, False

    # Remove the destroyed unit from the map.
    tile = self.get_node(defender.pos)
    tile.agent_id = None
    tile.team_id = -1
    tile.faction_id = -1

    enemies_alive = sum(
        self.teams[team_id].alive_count
        for team_id in range(self.max_team)
        if self.teams[team_id].faction_id != self.faction_id
    )
    if enemies_alive == 0:
        return damage, True
    return damage, False
def _interact(self, agent_id, target_id):
    """Park ``agent_id`` inside the carrier ``target_id`` and return reward 0."""
    passenger = self.get_agent(agent_id)
    carrier = self.get_agent(target_id)
    carrier.parked_agents.append(passenger)
    passenger.is_carried = True
    # The reward could later become a linear combination of the supplies received.
    return 0
def _release(self, agent_id: str, target_idx: int, des: Tuple[int, int]):
    """Unload a parked unit from carrier ``agent_id`` onto the tile ``des``.

    Args:
        agent_id: Id of the carrier.
        target_idx: Index of the unit inside the carrier's ``parked_agents``.
        des: Tile that receives the unit.

    Returns:
        The reward, always 0.

    The unit is removed from the carrier, registered on the destination tile
    and resupplied by every module of the carrier: a positive ``add_*`` value
    tops up by that amount (capped at the maximum), any other value refills
    completely.
    """
    carrier = self.get_agent(agent_id)
    # Take the unit out of the carrier; otherwise it could be released again
    # and its slot would never become free.
    unit = carrier.parked_agents.pop(target_idx)
    unit.pos = des

    tile = self.get_node(des)
    tile.team_id = unit.team_id
    tile.faction_id = unit.faction_id
    tile.agent_id = unit.agent_id

    for module in carrier.modules:
        if module.add_endurance > 0:
            unit.endurance += min(module.add_endurance, unit.max_endurance - unit.endurance)
        else:
            unit.endurance = unit.max_endurance
        if unit.weapon is not None:
            if module.add_ammo > 0:
                unit.weapon.ammo += min(module.add_ammo, unit.weapon.max_ammo - unit.weapon.ammo)
            else:
                unit.weapon.ammo = unit.weapon.max_ammo
        if module.add_fuel > 0:
            unit.fuel += min(module.add_fuel, unit.max_fuel - unit.fuel)
        else:
            unit.fuel = unit.max_fuel

    unit.is_carried = False
    # The reward could later become a linear combination of the supplies received.
    return 0
def _switch_weapon(self, agent_id: str, weapon_idx: int):
    """Equip ``switchable_weapons[weapon_idx]`` on the unit and return reward 0.

    The weapon being put away is reset first. A unit that currently carries
    no weapon simply equips the new one.
    """
    agent = self.get_agent(agent_id)
    # legal_actions offers SWITCH_WEAPON to units without a weapon as well.
    if agent.weapon is not None:
        agent.weapon.reset()
    agent.weapon = agent.switchable_weapons[weapon_idx]
    return 0
def observe(self):
    """Build the observation dict for the team whose turn it is.

    Every entry of ``observation_space`` starts as a zero array of the
    matching shape and dtype. ``"action_mask"`` is set to 1 at each legal
    action id. For every allied unit (``union_*`` planes) and every spotted
    enemy unit (``enemy_*`` planes) the unit's attributes are written at
    ``[pos[0], pos[1]]``.

    Padding units parked at ``(-1, -1)`` are skipped; indexing with that
    position would otherwise write into the last cell of each plane.
    """
    union_agents = self.union_agents
    spotted_enemy_agents = self.spotted_enemy_agents

    obs = {
        desc: np.zeros(space.shape, dtype=space.dtype)
        for desc, space in self.observation_space.items()
    }
    for action_id in self.legal_action_ids:
        obs["action_mask"][action_id] = 1

    attributes = (
        "info_level", "stealth_level", "mobility", "defense",
        "damage", "fuel", "ammo", "endurance",
    )
    for side, agents in (("union", union_agents), ("enemy", spotted_enemy_agents)):
        for agent in agents:
            pos = agent.pos
            if pos == (-1, -1):  # Illegal position for padding
                continue
            for attribute in attributes:
                obs[f"{side}_{attribute}"][pos[0], pos[1]] = getattr(agent, attribute)
    return obs
def command_move(self, agent_id, des: Union[Tuple[int, int], List[Tuple[int, int]]], start_time=0, subtask=False) -> Command:
# Plan a move for `agent_id` to `des` (a single position or a list of
# acceptable positions) by appending one MOVE Action per path step to
# agent.todo, and return the Command describing the plan.
# Returns None (no Command) when the planned position already matches `des`.
agent = self.get_agent(agent_id)
# Planning continues from the state of the last queued action, if any.
state = agent.todo[-1].state if agent.todo else agent.state
if state.pos == des or state.pos in des:
return
map_data = self.map.nodes
origin_todo_length = len(agent.todo) # remember the original length of todo
spotted_enemy_ids = self.spotted_enemy_ids
adj_pos = get_adj_pos(state.pos, int(state.fuel))
adj_pos.append(state.pos)
# Tiles usable for path finding: on-map tiles of the own faction or tiles
# not occupied by a spotted enemy.
aval_data = {
pos: node
for pos in adj_pos if pos in map_data and
(
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
}
path = get_path(agent=agent, map_data=aval_data, start=state.pos, des=des, limit=state.fuel) # there may be no feasible path
# NOTE(review): `state` is mutated in place and the same object is attached to
# every queued Action - confirm whether `.state` hands out a copy.
for pos in path:
state.fuel -= get_cost(agent.move_type, map_data[pos].terrain_type)
state.pos = pos
agent.todo.append(Action(agent_id=agent_id, action_type=ActionType.MOVE, des=pos, start_time=start_time, state=state))
# end_time is start_time plus the number of actions queued by this call, minus one.
command = Command(agent_id=agent_id,
action_type=ActionType.MOVE,
des=des,
start_time=start_time,
end_time=start_time+len(agent.todo)-origin_todo_length-1,
state=state)
if not subtask:
end_time = command.end_time
# Renumber the newly queued actions consecutively from start_time.
for i in range(origin_todo_length, len(agent.todo)):
agent.todo[i].start_time = start_time + i - origin_todo_length
agent.todo[i].end_time = end_time
agent.todo[-1].end_type = ActionType.MOVE
agent.cmd_todo.append(command)
return command
def command_attack(self, attack_id, target_id, attack_count=1, start_time=0, subtask=False) -> Command:
# Plan `attack_count` attacks of `attack_id` on `target_id`: switch to a
# suitable weapon or resupply first if needed, move into range, then queue the
# ATTACK actions on agent.todo. Returns the resulting Command.
# On any failure agent.todo is truncated back to its original length and a
# ValueError is raised.
agent = self.get_agent(attack_id)
target_agent = self.get_agent(target_id)
origin_todo_length = len(agent.todo) # remember the original length of todo
state = agent.todo[-1].state if agent.todo else agent.state
# Step 1: the planned weapon cannot strike the target's move type - look for
# a switchable weapon that can.
if target_agent.move_type not in state.weapon.strike_types:
flag = False
no_enough_ammo = False
for weapon_idx, weapon in enumerate(agent.switchable_weapons):
if weapon.strike_types and target_agent.move_type in weapon.strike_types:
if weapon.max_ammo < attack_count:
no_enough_ammo = True
try:
self.command_switch(attack_id, weapon_idx, start_time=start_time, subtask=True)
flag = True
break
except Exception as e:
# The original exception is not chained into the ValueError below.
agent.todo = agent.todo[:origin_todo_length]
raise ValueError("Failed switching to the proper weapon")
if not flag:
if no_enough_ammo:
agent.todo = agent.todo[:origin_todo_length]
raise ValueError("Agent has the striking weapon but no one with enough ammo")
else:
agent.todo = agent.todo[:origin_todo_length]
raise ValueError("No proper weapon found")
# Step 2: make sure the planned weapon holds enough ammo.
state = agent.todo[-1].state if agent.todo else agent.state
if state.weapon.ammo < attack_count:
if state.weapon.max_ammo >= attack_count: # no weapon switch needed, try to resupply
try:
self.command_supply(attack_id, start_time=start_time, subtask=True)
except Exception as e:
agent.todo = agent.todo[:origin_todo_length]
raise ValueError("No enough ammo and failed to get supplies")
else: # a weapon switch is needed
flag = False
for weapon_idx, weapon in enumerate(agent.switchable_weapons):
if weapon.strike_types and target_agent.move_type in weapon.strike_types and weapon.max_ammo >= attack_count:
try:
self.command_switch(attack_id, weapon_idx, start_time=start_time, subtask=True)
flag = True
break
except Exception as e:
agent.todo = agent.todo[:origin_todo_length]
raise ValueError("Failed switching to the proper weapon")
if not flag:
agent.todo = agent.todo[:origin_todo_length]
raise ValueError("No proper weapon has enough ammo")
# Step 3: move to a tile within attack range of the target.
map_data = self.map.nodes
spotted_enemy_ids = self.spotted_enemy_ids
# Candidate tiles: on-map tiles around the target that belong to the own
# faction or are not occupied by a spotted enemy.
possible_des = [
pos for pos in get_adj_pos(target_agent.pos, int(agent.attack_range))
if pos in map_data and
(
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
]
try:
self.command_move(attack_id, possible_des, start_time=start_time, subtask=True)
except Exception as e:
agent.todo = agent.todo[:origin_todo_length]
raise ValueError("Failed to move to the target")
# Step 4: queue the attacks, spending one round of planned ammo each.
state = agent.todo[-1].state if agent.todo else agent.state
for _ in range(attack_count):
state.weapon.ammo -= 1
agent.todo.append(Action(agent_id=attack_id,
action_type=ActionType.ATTACK,
target_id=target_id,
des=target_agent.pos,
start_time=start_time,
state=state))
# end_time is start_time plus the number of actions queued by this call, minus one.
command = Command(agent_id=attack_id,
action_type=ActionType.ATTACK,
target_id=target_id,
des=target_agent.pos,
attack_count=attack_count,
start_time=start_time,
end_time=start_time+len(agent.todo)-origin_todo_length-1,
state=state)
if not subtask:
end_time = command.end_time
# Renumber the newly queued actions consecutively from start_time.
for i in range(origin_todo_length, len(agent.todo)):
agent.todo[i].start_time = start_time + i - origin_todo_length
agent.todo[i].end_time = end_time
agent.todo[-1].end_type = ActionType.ATTACK
agent.cmd_todo.append(command)
return command
def command_supply(self, agent_id, start_time=0, subtask=False) -> Command:
"""
The unit that needs supplies moves next to a unit that has a supply module.
"""
agent = self.get_agent(agent_id)
current_agents = self.current_agents
map_data = self.map.nodes
origin_todo_length = len(agent.todo) # remember the original length of todo
# Candidate destinations: every on-map tile within distance 1 of a supply
# unit of the current team.
possible_des = []
for _agent in current_agents:
if _agent.has_supply:
for pos in get_adj_pos(_agent.pos, 1):
if pos in map_data:
possible_des.append(pos)
# print(possible_des)
self.command_move(agent_id, possible_des, start_time, subtask=True)
des = agent.todo[-1].des if agent.todo else agent.pos
# Apply the effect of a neighbouring supply unit to the planned state.
# NOTE(review): `state` and `command` are only assigned below; confirm they
# cannot be read unassigned when no supply unit neighbours `des`.
for pos in get_adj_pos(des, 1):
if pos in map_data and map_data[pos].faction_id == agent.faction_id and self.get_agent(map_data[pos].agent_id).has_supply:
supply_id = map_data[pos].agent_id
supply_module = self.get_agent(supply_id).supply
state = agent.todo[-1].state if agent.todo else agent.state
# A positive add_* value tops up by that amount, capped at the maximum;
# any other value refills to the maximum.
if supply_module.add_endurance > 0:
state.endurance = min(agent.max_endurance, state.endurance + supply_module.add_endurance)
else:
state.endurance = agent.max_endurance
if supply_module.add_ammo > 0:
state.weapon.ammo = min(state.weapon.max_ammo, state.weapon.ammo + supply_module.add_ammo)
else:
state.weapon.ammo = state.weapon.max_ammo
if supply_module.add_fuel > 0:
state.fuel = min(agent.max_fuel, state.fuel + supply_module.add_fuel)
else:
state.fuel = agent.max_fuel
# end_time is start_time plus the number of actions queued by this call, minus one.
command = Command(agent_id=agent_id,
action_type=ActionType.SUPPLY,
des=des,
start_time=start_time,
end_time=start_time+len(agent.todo)-origin_todo_length-1,
state=state)
if not subtask:
end_time = command.end_time
# Renumber the newly queued actions consecutively from start_time.
for i in range(origin_todo_length, len(agent.todo)):
agent.todo[i].start_time = start_time + i - origin_todo_length
agent.todo[i].end_time = end_time
agent.todo[-1].end_type = ActionType.SUPPLY
agent.cmd_todo.append(command)
return command
def command_switch(self, agent_id, weapon_idx, start_time=0, subtask=False) -> Command:
    """
    Plan a weapon switch: approach a supply unit, then switch weapon there.

    Args:
        agent_id: id of the unit that switches weapon.
        weapon_idx: index into ``agent.switchable_weapons`` of the weapon to equip.
        start_time: time step given to the first action planned by this command.
        subtask: when True the plan is part of a larger command, so the planned
            actions are not re-timed and nothing is pushed onto ``agent.cmd_todo``.

    Returns:
        The ``Command`` describing the switch plan.
    """
    agent = self.get_agent(agent_id)
    origin_todo_length = len(agent.todo)  # length of todo before this command plans anything

    # Forward start_time so the approach is timed from this command's start,
    # the same way command_supply forwards it to command_move.
    self.command_supply(agent_id, start_time=start_time, subtask=True)

    state = agent.todo[-1].state if agent.todo else agent.state
    weapon_name = agent.switchable_weapons[weapon_idx].name
    # The switch happens right after the actions planned above, which is also
    # the command's end time.
    switch_time = start_time + len(agent.todo) - origin_todo_length
    agent.todo.append(Action(agent_id=agent_id,
                             action_type=ActionType.SWITCH_WEAPON,
                             weapon_idx=weapon_idx,
                             weapon_name=weapon_name,
                             des=state.pos,
                             start_time=switch_time,
                             state=state))
    command = Command(agent_id=agent_id,
                      action_type=ActionType.SWITCH_WEAPON,
                      weapon_idx=weapon_idx,
                      weapon_name=weapon_name,
                      des=state.pos,
                      start_time=start_time,
                      end_time=switch_time,
                      state=state)
    if not subtask:
        # Top-level command: stamp every newly planned action with consecutive
        # start times and the shared end time, then tag the last one so that
        # the command can be matched against it when the action is executed.
        end_time = command.end_time
        for i in range(origin_todo_length, len(agent.todo)):
            agent.todo[i].start_time = start_time + i - origin_todo_length
            agent.todo[i].end_time = end_time
        agent.todo[-1].end_type = ActionType.SWITCH_WEAPON
        agent.cmd_todo.append(command)
    return command
def make_plan(self, command: Command, subtask=False):
    """
    Expand a command into planned actions by calling the matching planner.

    Args:
        command: the command to plan; its ``action_type`` selects the planner.
        subtask: forwarded unchanged to the selected planner.

    Returns:
        Whatever the selected ``command_*`` planner returns.

    Raises:
        ValueError: if the command type is not MOVE, ATTACK, SUPPLY or SWITCH_WEAPON.
    """
    kind = command.action_type
    if kind == ActionType.MOVE:
        return self.command_move(
            command.agent_id,
            command.des,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.ATTACK:
        return self.command_attack(
            command.agent_id,
            command.target_id,
            attack_count=command.attack_count,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.SUPPLY:
        return self.command_supply(
            command.agent_id,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.SWITCH_WEAPON:
        return self.command_switch(
            command.agent_id,
            command.weapon_idx,
            start_time=command.start_time,
            subtask=subtask,
        )
    raise ValueError(f"Invalid command type: {kind}")
def todo_action(self, agent_id: str, retry=0) -> Optional[Action]:
    """
    Pop the next planned action of an agent, re-planning when it is no longer valid.

    A valid action is returned as is; when it is tagged as the end of the
    command at the head of ``agent.cmd_todo``, that command is popped as well.
    An invalid action causes its command to be planned again from scratch and
    the lookup to be retried; once ``retry`` exceeds 2 the action is returned
    even if it is still invalid.

    Args:
        agent_id: id of the agent whose plan is consumed.
        retry: number of re-planning attempts already made (used by the recursion).

    Returns:
        The next action, or None when there is nothing to do right now.
    """
    agent = self.get_agent(agent_id)
    if not agent.todo:
        return None
    action = agent.todo.pop(0)
    valid, msg = self.check_validity(action)
    if valid or retry > 2:
        # Guard against an empty command queue before peeking at its head
        if agent.cmd_todo and agent.cmd_todo[0].action_type == action.end_type:
            agent.cmd_todo.pop(0)
        return action
    print(f"Retrying to plan actions for {agent_id}: {msg}")
    if action.action_type == ActionType.RELEASE:
        # make_plan has no planner for RELEASE, so just look for another free
        # adjacent tile to release onto.
        for pos in get_adj_pos(action.des, 1):
            if pos in self.map.nodes and not self.map.nodes[pos].agent_id:
                action.des = pos
                return action
        # NOTE(review): the popped action is not put back here, so "waiting"
        # actually drops it -- confirm this is intended.
        return None  # Wait
    if not agent.cmd_todo:
        # No command to rebuild the plan from: the invalid action is dropped
        return None
    command = agent.cmd_todo.pop(0)
    # Discard what is left of the failed command's plan. When the invalid
    # action was itself the command's last step, everything still queued
    # belongs to later commands and must be kept.
    if action.end_type != command.action_type:
        while agent.todo and agent.todo[0].end_type != command.action_type:
            agent.todo.pop(0)
        if agent.todo:
            agent.todo.pop(0)
    # Set the later commands aside and empty both queues, so the new plan
    # starts from the agent's current state and ends up in front of them.
    pending_actions = agent.todo.copy()
    pending_commands = agent.cmd_todo.copy()
    agent.todo.clear()
    agent.cmd_todo.clear()
    self.make_plan(command)
    agent.todo.extend(pending_actions)
    agent.cmd_todo.extend(pending_commands)
    return self.todo_action(agent_id, retry=retry + 1)
def check_validity(self, action: Action) -> Tuple[bool, str]:
    """
    Check whether an action may be executed in the current game state.

    Args:
        action: the action to validate.

    Returns:
        ``(True, "OK")`` when the action is legal, otherwise ``(False, reason)``
        where ``reason`` is a human-readable explanation.
    """
    if action.action_type == ActionType.END_OF_TURN:
        return True, "OK"
    map_data = self.map.nodes
    spotted_enemy_ids = self.spotted_enemy_ids
    if action.agent_id not in self.current_agents_dict:
        return False, f"Invalid agent id: {action.agent_id}"
    agent = self.get_agent(action.agent_id)
    if action.action_type not in agent.action_space.keys():
        return False, f"Invalid action {action.action_type.name} for agent {agent.agent_id[:8]}"

    if action.action_type == ActionType.MOVE:
        des = action.des
        # An off-map destination is an invalid action, not a crash
        if des not in map_data:
            return False, f"Invalid destination: {des}"
        node = map_data[des]
        if node.faction_id == agent.faction_id:
            # Moving onto an ally is only allowed when the ally accepts this
            # agent type and still has room for it.
            target = self.get_agent(node.agent_id)
            if agent.agent_type not in target.available_types or len(target.parked_agents) >= target.capacity:
                return False, f"{node.agent_id[:8]} at {des} is an ally but not available"
        if node.agent_id in spotted_enemy_ids:
            return False, f"{des} has been occupied by enemy {node.agent_id[:8]}"
        distance = get_axial_dis(agent.pos, des)
        if distance > agent.mobility:
            return False, f"Destination is out of range. Mobility: {agent.mobility} Euclidean distance: {distance}"
        try:
            mobility = min(agent.fuel, agent.mobility)
            # Tiles usable for path finding: on the map and either held by our
            # own faction or not occupied by a spotted enemy.
            aval_data = {}
            for pos in get_adj_pos(agent.pos, int(mobility)):
                if pos not in map_data:
                    continue
                tile = map_data[pos]
                if tile.faction_id == self.faction_id or tile.agent_id not in spotted_enemy_ids:
                    aval_data[pos] = tile
            aval_nodes = astar_search(aval_data, agent.move_type, start=agent.pos, limit=mobility)
            if action.des in aval_nodes:
                return True, "OK"
            return False, f"{des} could not be reached in one step."
        except Exception as e:
            # Any failure of the path search is reported as an invalid move
            return False, f"Failed to move to {des}: {e}"

    elif action.action_type == ActionType.ATTACK:
        try:
            target_agent = self.get_agent(action.target_id)
        except Exception:
            return False, f"Invalid target agent id: {action.target_id}, ignored"
        des = action.des
        if agent.faction_id == target_agent.faction_id:
            return False, f"Cannot attack allies. Attack: {action.agent_id}. Defend: {action.target_id}"
        if agent.weapon is None:
            return False, "No weapon equipped"
        if agent.ammo <= 0:
            return False, "No ammo left"
        if target_agent.move_type not in agent.strike_types:
            return False, "Target agent cannot be attacked with this weapon"
        if get_axial_dis(agent.pos, des) > agent.attack_range:
            return False, f"Target agent is out of range. Attack range: {agent.attack_range} Distance: {get_axial_dis(agent.pos, des)}"

    elif action.action_type == ActionType.SWITCH_WEAPON:
        weapon_idx = action.weapon_idx
        # Switching requires a friendly unit with a supply module on an adjacent tile
        around_supply = False
        for pos in get_adj_pos(agent.pos, 1):
            occupant_id = map_data[pos].agent_id if pos in map_data else None
            if not occupant_id:
                continue
            occupant = self.get_agent(occupant_id)
            if occupant.faction_id == agent.faction_id and occupant.has_supply:
                around_supply = True
                break
        if not around_supply:
            return False, "Agent is not around a supply module"
        if weapon_idx >= len(agent.switchable_weapons):
            return False, f"Invalid weapon id {weapon_idx}. Max id: {len(agent.switchable_weapons) - 1}"

    elif action.action_type == ActionType.RELEASE:
        if action.target_id >= len(agent.parked_agents):
            return False, f"Trying to release an agent that is not in the queue. Target id: {action.target_id}, Current queue: {agent.parked_agents}"
        des = action.des
        if get_axial_dis(agent.pos, des) > 1:
            return False, "Only adjacent positions can be released"
        # An off-map destination is an invalid action, not a crash
        if des not in map_data:
            return False, f"Invalid destination: {des}"
        if map_data[des].team_id != -1:
            return False, "Cannot release agent to occupied position"

    elif action.action_type == ActionType.INTERACT:
        # Same handling of an unknown target id as in the ATTACK branch
        try:
            target_agent = self.get_agent(action.target_id)
        except Exception:
            return False, f"Invalid target agent id: {action.target_id}, ignored"
        if get_axial_dis(agent.pos, target_agent.pos) > 1:
            return False, "Target agent is out of range"
        if agent.faction_id != target_agent.faction_id:
            return False, "Cannot interact with enemy agent"
        if agent.agent_type not in target_agent.available_types:
            return False, f"Target agent {target_agent.agent_id[:8]} is not interactable with agent {agent}"
        if len(target_agent.parked_agents) >= target_agent.capacity:
            return False, f"Target agent {target_agent.agent_id[:8]} full"

    return True, "OK"
def update(self, sync_info):
    """
    Synchronise the environment with a state snapshot from the real engine.

    Own units and currently spotted hostiles are copied from ``sync_info`` onto
    the local agents and re-registered on the map. Enemies that were spotted
    before but are missing from the snapshot are taken off the map.

    Args:
        sync_info: dict with the keys ``'units'`` and ``'spottedHostiles'``, each
            a list of per-agent dicts. The keys read from each entry are
            ``agent_id``, ``pos`` (with ``q`` and ``r``), ``fuel``, ``endurance``,
            ``commenced_action``, ``current_weapon`` and ``ammo``.

    Returns:
        ``(obs, reward, terminated, truncated, info)``. ``reward`` is the total
        endurance lost since the previous sync by the listed units that belong
        to another faction.
    """
    sync_agents = sync_info['units']
    spotted_enemies = sync_info['spottedHostiles']
    map_data = self.map.nodes
    reward = 0

    def detach_from_map(agent_id, agent):
        # Free the tile the agent stands on, but only if the tile is really
        # registered to this agent; (-1, -1) marks an off-map position.
        if agent.pos == (-1, -1):
            return
        node = map_data[agent.pos]
        if node.agent_id == agent_id:
            node.agent_id = None
            node.team_id = -1
            node.faction_id = -1

    for sync_agent in [*sync_agents, *spotted_enemies]:
        agent_id = sync_agent['agent_id']
        agent = self.get_agent(agent_id)
        # Remove the agent from its old tile before its position changes
        detach_from_map(agent_id, agent)
        agent.pos = (sync_agent['pos']['q'], sync_agent['pos']['r'])
        agent.fuel = sync_agent['fuel']
        if agent.faction_id != self.faction_id:
            reward += agent.endurance - sync_agent['endurance']
        agent.endurance = sync_agent['endurance']
        agent.commenced_action = sync_agent['commenced_action']
        current_weapon = sync_agent['current_weapon']
        if current_weapon:
            for weapon in agent.switchable_weapons:
                if weapon.name == current_weapon:
                    agent.weapon = weapon
                    break
            # The reported name may match none of the switchable weapons
            # while no weapon is equipped; there is nothing to reload then.
            if agent.weapon is not None:
                agent.weapon.ammo = sync_agent['ammo']
        else:
            agent.weapon = None
        # Register the agent on its new tile
        if agent.alive:
            node = map_data[agent.pos]
            node.agent_id = agent_id
            node.team_id = agent.team_id
            node.faction_id = agent.faction_id

    # Take enemies that dropped out of sight off the map. Enemies that are
    # still spotted were just placed by the loop above and must stay there.
    # The previous list is None until it has been filled for the first time.
    current_spotted_ids = [enemy['agent_id'] for enemy in spotted_enemies]
    still_visible = set(current_spotted_ids)
    for agent_id in self._spotted_enemy_ids or ():
        if agent_id in still_visible:
            continue
        detach_from_map(agent_id, self.get_agent(agent_id))
    # Refresh the list of spotted enemies
    self._spotted_enemy_ids = current_spotted_ids

    # NOTE: clear the cached state before observing
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_agents = None
    obs = self.observe()

    truncated = self.episode_steps >= self.max_episode_steps
    terminated = self.teams[self.enemy_id].alive_count == 0
    self._cumulative_rewards[self.player_id] += reward
    info = {}
    info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
    info["next_player"] = self.player_id + 1 if not terminated else -1
    if self.replay_path is not None:
        self._frames[self.player_id].append(self.render(mode="rgb_array"))
    if terminated or truncated:
        # The eval_episode_return is calculated from Player 1's perspective
        info["eval_episode_return"] = reward if self.player_id == 0 else -reward
        info["done_reason"] = "Terminated" if terminated else "Truncated"
        if self.replay_path is not None:
            self.save_replay()
    return obs, reward, terminated, truncated, info