2024-12-18 21:30:53 +08:00
import numpy as np
import json
import os
from datetime import datetime
from typing import List , Optional
import gymnasium as gym
from . common . utils import *
from . common . agent import *
class YesCmdrEnv ( gym . Env ) :
metadata = { " render_modes " : [ " human " , " rgb_array " ] , " render_fps " : 4 }
def __init__(self,
             use_real_engine: bool = False,
             replay_path: str = None,
             campaign_id: str = "NOID",
             team_id: int = 0,
             faction_id: int = 0,
             max_team: int = 2,
             max_faction: int = 2,
             data_path: str = "../data",
             max_episode_steps: int = 1000,
             bot_version: str = "v0",
             war_fog: bool = True) -> None:
    """Record the configuration only; the world is built on the first ``reset``.

    Args:
        use_real_engine: When true, END_OF_TURN does not rotate the active
            team locally (see ``_player_step``).
        replay_path: Directory for replay videos; ``None`` disables frame
            recording.
        campaign_id: Stored as given.
        team_id: Team that is active right after ``reset``.
        faction_id: Faction used as the reference for ``eval_episode_return``.
        max_team: Number of teams.
        max_faction: Number of factions.
        data_path: Directory containing ``MapInfo.json`` and ``AgentsInfo.json``.
        max_episode_steps: Turn count after which ``step`` reports truncation.
        bot_version: Scripted-bot variant used by ``bot_action``.
        war_fog: Whether enemy units must be spotted before they are visible.
    """
    # Scenario and data source.
    self.campaign_id = campaign_id
    self.data_path = data_path
    self.use_real_engine = use_real_engine
    self.replay_path = replay_path

    # Sides.
    self.init_team_id = team_id
    self.init_faction_id = faction_id
    self.max_team = max_team
    self.max_faction = max_faction

    # Episode rules.
    self.max_episode_steps = max_episode_steps
    self.bot_version = bot_version
    self.war_fog = war_fog

    # Flipped by reset() once _make_env() has run.
    self._init_flag = False
def reset(self, *, seed=None, options=None):
    """Start a new episode and return ``(observation, info)``.

    The first call builds the world through ``_make_env`` (and, when a replay
    path is configured, the renderer). Later calls reset every map node and
    every agent and re-register the agents on the map.

    Args:
        seed: Optional seed forwarded to ``gym.Env.reset``. It only seeds the
            environment's own ``np_random``; the helpers in this class draw
            from the global ``np.random``.
        options: Accepted for Gymnasium API compatibility; not used.

    Returns:
        The observation of the starting team and an info dict whose
        ``"next_team"`` entry is the 1-based id of the team to act.
    """
    if seed is not None:
        # Gymnasium callers and wrappers pass seed/options as keywords.
        super().reset(seed=seed)

    # The actual environment construction happens on the first reset() call.
    if not self._init_flag:
        self._init_flag = True
        self._make_env()
        if self.replay_path is not None:
            from .common.renderer import Renderer
            self._renderer = Renderer(self.map.nodes.values())
    else:
        for node in self.map.nodes.values():
            node.reset()
        for team in self.teams:
            for agent_id, agent in team.agents.items():
                agent.reset()
                if agent.pos == (-1, -1):  # Illegal position used for padding
                    continue
                assert agent.pos in self.map.nodes, f"Invalid agent position: {agent.pos}"
                node = self.map.nodes[agent.pos]
                node.agent_id = agent_id
                node.team_id = team.team_id
                node.faction_id = team.faction_id

    if self.replay_path is not None:
        # One frame buffer per team.
        self._frames = [[] for _ in range(self.max_team)]

    # Drop every per-step cache before the first observation is built.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_enemy_ids = None
    self._spotted_enemy_agents = None
    self._team_id = self.init_team_id
    self._spotted_agents = None
    self.episode_steps = 0
    self._cumulative_rewards = np.zeros(self.max_team)

    obs = self.observe()
    info = {
        "next_team": self.team_id + 1
    }
    return obs, info
def step(self, action: Union[int, Action]):
    """Apply one action for the active team.

    Args:
        action: Either an action id (Python or NumPy integer) or an
            ``Action`` instance.

    Returns:
        ``(obs, reward, terminated, truncated, info)``. An action that is not
        legal is ignored: the current observation is returned with reward 0
        and ``info["exception"]`` describing the problem.
    """
    def _rejected(message):
        # Common reply for an ignored action; the game state is left untouched.
        info = {}
        info["accum_reward"] = tuple(self._cumulative_rewards)
        info["next_team"] = self.team_id + 1
        info["exception"] = message
        return self.observe(), 0, False, False, info

    # Policies commonly hand back numpy integers, which are not `int` instances.
    if isinstance(action, (int, np.integer)):
        action = int(action)
        if action not in self.legal_action_ids:
            print(f"Invalid action will be ignored: {action}")
            return _rejected(f"Invalid action {action}")
        action = self.id_to_action(action)
    elif isinstance(action, Action):
        valid, reason = self.check_validity(action)
        if not valid:
            replacement = None
            if action.action_type == ActionType.MOVE and "occupied" in reason:
                # Destination is taken: fall back to the legal move of the same
                # agent whose destination is closest to the requested one.
                print("Trying to move to an adjacent position that is empty.")
                best = float('inf')
                for candidate in self.legal_actions:
                    if candidate.agent_id != action.agent_id or candidate.action_type != ActionType.MOVE:
                        continue
                    distance = get_axial_dis(candidate.des, action.des)
                    if distance < best:
                        best, replacement = distance, candidate
            else:
                print(f"Invalid action will be ignored: {action}")
                print(f"Reason: {reason}")
            if replacement is None:
                # Either not a recoverable MOVE, or no legal fallback exists.
                return _rejected(f"Invalid action: {action}. Reason: {reason}")
            action = replacement
    assert (isinstance(action, Action))

    reward, terminated = self._player_step(action)

    # NOTE: clear the cached state before building the observation.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_enemy_ids = None
    self._spotted_enemy_agents = None

    obs = self.observe()
    truncated = self.episode_steps >= self.max_episode_steps
    self._cumulative_rewards[self.team_id] += reward

    info = {}
    info["accum_reward"] = tuple(self._cumulative_rewards)
    info["next_team"] = self.team_id + 1 if not terminated else -1
    if self.replay_path is not None:
        self._frames[self.team_id].append(self.render(mode="rgb_array"))
    if terminated or truncated:
        # The eval_episode_return is calculated from the initial faction's perspective.
        info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
        info["done_reason"] = "Terminated" if terminated else "Truncated"
        if self.replay_path is not None:
            self.save_replay()
    return obs, reward, terminated, truncated, info
def render(self, mode="human", attack_agent=None, defend_agent=None):
    """Render the current state.

    ``"human"`` returns the current observation; ``"rgb_array"`` delegates to
    the renderer with the allied units, the spotted enemy units and the
    optional attacker/defender pair. Any other mode raises ``ValueError``.
    """
    if mode == "human":
        return self.observe()
    if mode == "rgb_array":
        renderer = self._renderer
        return renderer.render(self.union_agents, self.spotted_enemy_agents, attack_agent, defend_agent)
    raise ValueError(f"Invalid render mode: {mode}")
def action_to_id(self, action: Action) -> int:
    """Encode an ``Action`` as a flat id for the active team.

    Id 0 is END_OF_TURN. Any other id is 1 plus the sizes of all sub-spaces
    that precede the action's (agent, action type) sub-space in iteration
    order, plus an offset inside that sub-space: the weapon index for
    SWITCH_WEAPON, the parked-agent index for RELEASE, and otherwise the
    axial encoding of the displacement minus one.

    Raises:
        ValueError: If the agent or the action type is not part of the active
            team's action space.
        AssertionError: If the offset does not fit the sub-space.
    """
    if action.action_type == ActionType.END_OF_TURN:
        return 0

    team_space = self.action_space
    # Key membership is tested through .keys(), as in the original code.
    if action.agent_id not in team_space.keys():
        raise ValueError(f"Invalid agent_id: {action.agent_id[:8]}")
    own_spaces = team_space[action.agent_id]
    if action.action_type not in own_spaces.keys():
        agent = self.get_agent(action.agent_id)
        print(agent.available_types, agent.parked_agents)
        print(action)
        raise ValueError(f"Invalid action_type: {action.action_type.name} for agent {action.agent_id[:8]}. Available action_types: {own_spaces.keys()}")

    action_id = 1
    # Distinct loop names: the original rebound `action_space` while iterating it.
    for agent_id, type_spaces in team_space.items():
        for action_type, space in type_spaces.items():
            if agent_id != action.agent_id or action_type != action.action_type:
                action_id += space.n
                continue

            if action.action_type == ActionType.SWITCH_WEAPON:  # encoded by index, not by position
                assert action.weapon_idx < space.n, f"Invalid weapon_idx: {action.weapon_idx}, action_space.n: {space.n}"
                action_id += action.weapon_idx
            elif action.action_type == ActionType.RELEASE:
                assert action.target_id < space.n, f"Invalid target_idx: {action.target_id}, action_space.n: {space.n}"
                action_id += action.target_id  # an index here
            else:
                actor = self.get_agent(action.agent_id)
                pos, des = actor.pos, action.des
                encode_id = encode_axial((des[0] - pos[0], des[1] - pos[1]))
                action_id += encode_id - 1
                if action.action_type == ActionType.MOVE:
                    cur_range = actor.mobility
                elif action.action_type == ActionType.ATTACK:
                    cur_range = actor.attack_range
                elif action.action_type == ActionType.INTERACT:
                    cur_range = 1
                else:
                    # Keep the diagnostic below usable for any other type.
                    cur_range = None
                assert encode_id - 1 < space.n, f"Action {action.action_type.name} Distance: {get_axial_dis(pos, des)}, cur_range: {cur_range} encode_id: {encode_id} action_space.n: {space.n}\n{action}"
            return int(action_id)

    # Explicit raise so the guard survives `python -O`.
    raise AssertionError("Should not reach here")
def id_to_action(self, action_id: int) -> Action:
    """Decode a flat action id of the active team into an ``Action``.

    Inverse of ``action_to_id``: id 0 is END_OF_TURN; otherwise the id minus
    one is consumed sub-space by sub-space until it falls inside one.

    Raises:
        ValueError: If ``action_id`` is negative or beyond the action space.
    """
    if action_id == 0:
        return Action(ActionType.END_OF_TURN)
    if action_id < 0:
        # A negative offset would satisfy `offset < n` and decode to garbage.
        raise ValueError("Invalid action_id: {}".format(action_id))

    offset = action_id - 1
    for agent_id, type_spaces in self.action_space.items():
        for action_type, space in type_spaces.items():
            if offset >= space.n:
                offset -= space.n
                continue
            if action_type == ActionType.SWITCH_WEAPON:
                # NOTE(review): this reports the currently equipped weapon's
                # name, not the weapon at `weapon_idx` - confirm the intent.
                equipped = self.get_agent(agent_id).weapon
                weapon_name = equipped.name if equipped is not None else None
                return Action(agent_id=agent_id, action_type=action_type, weapon_idx=offset, weapon_name=weapon_name)
            if action_type == ActionType.RELEASE:
                return Action(agent_id=agent_id, target_id=offset, action_type=action_type)
            pos = self.get_agent(agent_id).pos
            delta = decode_axial(offset + 1)  # code 0 is the agent's own position
            des = (pos[0] + delta[0], pos[1] + delta[1])
            return Action(agent_id=agent_id, des=des, action_type=action_type)

    # Report the id the caller passed, not the decremented remainder.
    raise ValueError("Invalid action_id: {}".format(action_id))
def get_agent(self, agent_id: str) -> Agent:
    """Return the agent with ``agent_id``, searching every team in order.

    No liveness check is performed.

    Raises:
        ValueError: If no team owns ``agent_id``.
    """
    owner = next((team for team in self.teams if agent_id in team.agents.keys()), None)
    if owner is None:
        raise ValueError(f"Invalid agent id: {agent_id}")
    return owner.agents[agent_id]
2024-12-20 19:13:37 +08:00
def get_agents(self, agent_ids: List[str]) -> List[Agent]:
    """Resolve each id through ``get_agent``, preserving the input order.

    No liveness check is performed; an unknown id raises from ``get_agent``.
    """
    return [self.get_agent(agent_id) for agent_id in agent_ids]
def get_agents_dict(self, agent_ids: List[str]) -> Dict[str, dict]:
    """Map each requested id to the ``to_dict()`` form of its agent."""
    return {
        agent_id: self.get_agent(agent_id).to_dict()
        for agent_id in agent_ids
    }
2024-12-18 21:30:53 +08:00
def get_node(self, pos: Tuple[int, int]) -> TileNode:
    """Return the map node at ``pos``.

    Raises:
        ValueError: If ``pos`` is not a key of the map.
    """
    nodes = self.map.nodes
    if pos in nodes:
        return nodes[pos]
    raise ValueError(f"Invalid position: {pos}")
@property
def legal_actions(self) -> List[Action]:
    """Legal actions of the active team (cached until the cache is cleared).

    The list always starts with END_OF_TURN. For every agent of the active
    team that has not acted yet, is alive, is on the map and is not being
    carried, it adds MOVE, ATTACK, INTERACT, RELEASE and SWITCH_WEAPON
    actions according to the sections below.
    """
    if self._legal_actions is not None:
        return self._legal_actions
    map_data = self.map.nodes
    team_data = self.current_agents_dict
    spotted_enemy_ids = self.spotted_enemy_ids
    _legal_actions = [Action(ActionType.END_OF_TURN)]  # Default action
    for agent_id, agent in team_data.items():
        if agent.commenced_action or not agent.alive or agent.pos == (-1, -1) or agent.is_carried:
            continue
        pos = agent.pos
        # --- Movement ---
        # The reachable range is limited by both remaining fuel and mobility.
        mobility = min(agent.fuel, agent.mobility)
        if mobility > 0:
            adj_pos = get_adj_pos(pos, int(mobility))
            # Tiles usable for path-finding: own-faction tiles plus every tile
            # that does not hold a spotted enemy.
            aval_data = {
                pos: node
                for pos in adj_pos if pos in map_data and
                (
                    (node := map_data[pos]).faction_id == self.faction_id
                    or
                    node.agent_id not in spotted_enemy_ids
                )
            }
            aval_nodes = astar_search(aval_data, agent.move_type, start=pos, limit=mobility)
            for des in aval_nodes:
                if des == pos:
                    continue
                node = map_data[des]
                if node.faction_id != self.faction_id and node.agent_id not in spotted_enemy_ids:
                    # A tile can hold only one unit
                    _legal_actions.append(Action(
                        agent_id=agent_id,
                        des=des,
                        action_type=ActionType.MOVE
                    ))
                elif node.faction_id == self.faction_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \
                        and len(target.parked_agents) < target.capacity:
                    # ... unless the destination holds a unit this agent can board
                    _legal_actions.append(Action(
                        agent_id=agent_id,
                        target_id=node.agent_id,
                        des=des,
                        action_type=ActionType.MOVE
                    ))
        # --- Attack ---
        if agent.attack_range > 0 and agent.ammo > 0:
            for des in get_adj_pos(pos, agent.attack_range):
                if des not in map_data:
                    continue
                node = map_data[des]
                if node.agent_id in spotted_enemy_ids and self.get_agent(node.agent_id).move_type in agent.strike_types:
                    # The enemy has been spotted
                    _legal_actions.append(Action(
                        agent_id=agent_id,
                        target_id=node.agent_id,
                        action_type=ActionType.ATTACK,
                        des=des
                    ))
        # --- Interaction ---
        for des in get_adj_pos(pos, 1):
            if des not in map_data:
                continue
            node = map_data[des]
            if node.faction_id == self.faction_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \
                    and len(target.parked_agents) < target.capacity:
                _legal_actions.append(Action(
                    agent_id=agent_id,
                    target_id=node.agent_id,
                    des=des,
                    action_type=ActionType.INTERACT
                ))
        # --- Release ---
        if agent.parked_agents:
            for des in get_adj_pos(pos, 1):
                if des not in map_data:
                    continue
                node = map_data[des]
                if node.team_id != -1:  # tile already occupied
                    continue
                for idx, agent_to_release in enumerate(agent.parked_agents):
                    if get_cost(agent_to_release.move_type, node.terrain_type) != 0:
                        _legal_actions.append(Action(
                            agent_id=agent_id,
                            target_id=idx,
                            des=des,
                            action_type=ActionType.RELEASE
                        ))
                        # NOTE(review): placement of this `break` is reconstructed
                        # (one RELEASE per free tile) - confirm against the original.
                        break
        # --- Switch weapon ---
        if agent.switchable_weapons:
            for des in get_adj_pos(pos, 1):
                if des not in map_data:
                    continue
                node = map_data[des]
                if node.faction_id == self.faction_id and self.get_agent(node.agent_id).has_supply:
                    # NOTE(review): several adjacent supply units make this loop
                    # append the same switch actions once per supplier.
                    for weapon_idx, weapon in enumerate(agent.switchable_weapons):
                        if not agent.weapon or weapon.name != agent.weapon.name:
                            _legal_actions.append(Action(
                                agent_id=agent_id,
                                target_id=weapon.name,
                                weapon_idx=weapon_idx,
                                action_type=ActionType.SWITCH_WEAPON
                            ))
    self._legal_actions = _legal_actions
    return self._legal_actions
@property
def legal_action_ids(self) -> List[int]:
    """Flat ids of ``legal_actions``, computed once and then served from cache."""
    cached = self._legal_action_ids
    if cached is None:
        cached = [self.action_to_id(action) for action in self.legal_actions]
        self._legal_action_ids = cached
    return self._legal_action_ids
@property
def spotted_enemy_ids(self) -> List[str]:
    """Ids of the enemy units currently visible to the active team.

    Without fog of war every living enemy is visible (not cached). With fog,
    an enemy standing within an agent's ``info_level`` radius is spotted when
    its ``stealth_level`` is below ``info_level - distance + 1``; directly
    adjacent enemies are always spotted. The fog result is cached and lists
    each enemy once, in first-seen order.
    """
    if not self.war_fog:
        return [enemy_id for enemy_id, enemy_agent in self.enemy_agents_dict.items() if enemy_agent.alive]
    if self._spotted_enemy_ids is not None:
        return self._spotted_enemy_ids

    enemy_data = self.enemy_agents_dict
    map_data = self.map.nodes
    spotted = {}  # dict used as an insertion-ordered set
    for agent in self.current_agents:
        info_level = agent.info_level
        for des in get_adj_pos(agent.pos, info_level):
            if des not in map_data:
                continue
            node = map_data[des]
            if node.faction_id == self.faction_id:
                continue
            # Empty tiles also differ from our faction (-1) but carry no
            # agent; indexing the enemy table with them raised KeyError.
            enemy = enemy_data.get(node.agent_id)
            if enemy is None:
                continue
            dis = get_axial_dis(agent.pos, des)
            if dis == 1:  # Adjacent agents are always spotted
                dis = -1e9
            if enemy.stealth_level < info_level - dis + 1:
                spotted[node.agent_id] = None

    self._spotted_enemy_ids = list(spotted)
    return self._spotted_enemy_ids
def save_replay(self):
    """Write one MP4 per team from the recorded frames into ``replay_path``.

    The directory is created when missing. Files are named
    ``tianqiong_<timestamp>_Team<k>.mp4`` with a 1-based team number.
    """
    # exist_ok avoids the check-then-create race of a separate exists() test.
    os.makedirs(self.replay_path, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    for team_id in range(self.max_team):
        path = os.path.join(
            self.replay_path,
            f"tianqiong_{timestamp}_Team{team_id + 1}.mp4"
        )
        self.display_frames_as_mp4(self._frames[team_id], path)
        print(f'replay {path} saved!')
@staticmethod
def display_frames_as_mp4(frames: list, path: str, fps=4) -> None:
    """Encode ``frames`` into the video file ``path`` using imageio.

    ``path`` must end with ``.mp4``; this is checked before imageio is
    imported.
    """
    is_mp4 = path.endswith('.mp4')
    assert is_mp4, f'path must end with .mp4, but got {path}'
    from imageio import mimwrite
    mimwrite(path, frames, fps=fps)
def _make_env(self):
    """Build the map, teams, spaces and per-faction unit tables from disk.

    Reads ``MapInfo.json`` and ``AgentsInfo.json`` from ``data_path``,
    registers every placed agent on its map node, creates one action space
    and one observation space per team, checks that the action id encoding
    round-trips, and pre-computes the allied / enemy unit collections of
    every faction.
    """
    # Configuration for file paths
    map_path = f"{self.data_path}/MapInfo.json"
    agents_path = f"{self.data_path}/AgentsInfo.json"

    # Load map and agents information
    with open(map_path, 'r') as f:
        origin_map = json.load(f)
    with open(agents_path, 'r') as f:
        origin_agents = json.load(f)

    _nodes = {
        (node := dict_to_node(node_dict)).pos: node for node_dict in origin_map
    }
    agents_dict_lists = [
        [agent_dict for agent_dict in origin_agents if agent_dict["team_id"] == team_id]
        for team_id in range(self.max_team)
    ]
    _agents = [
        {
            (agent := dict_to_agent(agent_dict)).agent_id: agent
            for agent_dict in agents_dict_lists[team_id]
        }
        for team_id in range(self.max_team)
    ]
    _teams = [
        Team(
            team_id=_team_id,
            # A team may have no agents; it then gets faction -1.
            faction_id=agents_dict_lists[_team_id][0]["faction_id"] if agents_dict_lists[_team_id] else -1,
            agents=_agents[_team_id]
        )
        for _team_id in range(self.max_team)
    ]

    # Register every placed agent on its node.
    for team in _teams:
        for agent_id, agent in team.agents.items():
            if agent.pos == (-1, -1):  # Illegal position used for padding
                continue
            assert agent.pos in _nodes, f"Invalid agent position: {agent.pos}"
            _nodes[agent.pos].agent_id = agent_id
            _nodes[agent.pos].team_id = team.team_id
            _nodes[agent.pos].faction_id = team.faction_id

    _map = Map(_nodes)
    self.map = _map
    self.teams = _teams
    self.obs_shape = _map.width, _map.height

    # Get action space
    self._action_spaces = [
        gym.spaces.Dict(
            {
                f"team_{team.team_id + 1}": gym.spaces.Dict(
                    {
                        ActionType.END_OF_TURN: gym.spaces.Discrete(1)  # Default action
                    }
                ),
                **{
                    agent_id: agent.action_space for agent_id, agent in team.agents.items()
                }
            }
        )
        for team in self.teams
    ]
    self._action_space_sizes = [
        sum(action_space.n for agent_action_space in self._action_spaces[team_id].values() for action_space in agent_action_space.values())
        for team_id in range(self.max_team)
    ]

    # Sanity check: every id of every team must survive decode -> encode.
    # The properties used by the codecs follow `_team_id`, so it is advanced
    # team by team; it ends at `max_team` and reset() re-assigns it afterwards.
    self._team_id = 0
    for action_space_size in self._action_space_sizes:
        for action_id in range(action_space_size):
            assert self.action_to_id(self.id_to_action(action_id)) == action_id
        self._team_id += 1

    # Get observation space: one action mask plus one map-shaped plane per
    # (side, feature) pair.
    feature_dtypes = (
        ("endurance", np.float32),
        ("info_level", np.int16),
        ("stealth_level", np.int16),
        ("mobility", np.float32),
        ("defense", np.int16),
        ("damage", np.float32),
        ("fuel", np.float32),
        ("ammo", np.int16),
    )
    self._observation_spaces = []
    for team_id in range(self.max_team):
        planes = {
            # Different size for each team
            "action_mask": gym.spaces.Box(0, 1, (self._action_space_sizes[team_id],), dtype=np.int8),
        }
        for side in ("union", "enemy"):
            for feature, dtype in feature_dtypes:
                planes[f"{side}_{feature}"] = gym.spaces.Box(0, 1000, self.obs_shape, dtype=dtype)
        self._observation_spaces.append(gym.spaces.Dict(planes))
    self._reward_space = gym.spaces.Box(low=-1000, high=1000, shape=(1,), dtype=np.float32)

    # Pre-compute, per faction, the allied units and all enemy units.
    self._union_agents = [[] for _ in range(self.max_faction)]
    self._union_agents_dict = [{} for _ in range(self.max_faction)]
    self._enemy_agents = [[] for _ in range(self.max_faction)]
    self._enemy_agents_dict = [{} for _ in range(self.max_faction)]
    for team_id in range(self.max_team):
        faction_id = self.teams[team_id].faction_id
        agents_list = list(self.teams[team_id].agents.values())
        agents_dict = self.teams[team_id].agents
        self._union_agents[faction_id].extend(agents_list)
        # Previously never filled, which left `union_agents_dict` always empty.
        self._union_agents_dict[faction_id] |= agents_dict
        for other_faction_id in range(self.max_faction):
            if other_faction_id != faction_id:
                self._enemy_agents[other_faction_id].extend(agents_list)
                self._enemy_agents_dict[other_faction_id] |= agents_dict
@property
def team_id(self):
    """Zero-based index of the team whose turn it is."""
    active = self._team_id
    return active
@property
def current_agents(self):
    """Agents of the active team, as the ``.values()`` view of its agent mapping."""
    active_team = self.teams[self._team_id]
    return active_team.agents.values()
@property
def current_agents_dict(self):
    """The active team's ``agent_id -> agent`` mapping (the live object, not a copy)."""
    active_team = self.teams[self._team_id]
    return active_team.agents
2024-12-19 09:56:57 +08:00
@property
def union_agents(self):
    """Pre-computed list of all agents belonging to the active faction."""
    per_faction = self._union_agents
    return per_faction[self.faction_id]
@property
def union_agents_dict(self):
    """Pre-computed ``agent_id -> agent`` table of the active faction."""
    per_faction = self._union_agents_dict
    return per_faction[self.faction_id]
2024-12-18 21:30:53 +08:00
@property
def enemy_agents(self):
    """Pre-computed list of every agent that is not in the active faction."""
    per_faction = self._enemy_agents
    return per_faction[self.faction_id]
@property
def enemy_agents_dict(self):
    """Pre-computed ``agent_id -> agent`` table of every agent outside the active faction."""
    per_faction = self._enemy_agents_dict
    return per_faction[self.faction_id]
@property
def spotted_enemy_agents(self):
    """Agent objects for ``spotted_enemy_ids``; cached until the cache is cleared."""
    if self._spotted_enemy_agents is None:
        lookup = self.enemy_agents_dict
        self._spotted_enemy_agents = [lookup[enemy_id] for enemy_id in self.spotted_enemy_ids]
    return self._spotted_enemy_agents
2024-12-18 21:30:53 +08:00
@property
def faction_id(self):
    """Faction of the active team.

    Raises:
        ValueError: If the active team index is not below ``max_team``.
    """
    out_of_range = self._team_id >= self.max_team
    if out_of_range:
        raise ValueError(f"Team id {self._team_id} exceeds max id {self.max_team - 1}")
    active_team = self.teams[self._team_id]
    return active_team.faction_id
@property
def observation_space(self):
    """Observation space of the active team."""
    per_team = self._observation_spaces
    return per_team[self.team_id]
@property
def action_space(self):
    """Action space of the active team."""
    per_team = self._action_spaces
    return per_team[self.team_id]
@property
def action_space_size(self):
    """Number of flat action ids available to the active team."""
    per_team = self._action_space_sizes
    return per_team[self.team_id]
@property
def reward_space(self):
    """The reward space created by ``_make_env`` (shared by all teams)."""
    space = self._reward_space
    return space
def random_action(self) -> Action:
    """Pick one of the currently legal actions uniformly, using the global NumPy RNG."""
    candidates = self.legal_actions
    return np.random.choice(candidates)
def next_turn(self):
    """Hand the turn to the next team, wrapping around after the last one."""
    following = self._team_id + 1
    self._team_id = following % self.max_team
def bot_action(self) -> Action:
    """Choose an action for the scripted bot.

    Version ``"v0"`` picks a random ATTACK when one is legal and otherwise
    any random legal action.

    Raises:
        NotImplementedError: For any other ``bot_version``.
    """
    if self.bot_version != "v0":
        raise NotImplementedError(f"Invalid bot version: {self.bot_version}")
    attack_actions = [action for action in self.legal_actions if action.action_type == ActionType.ATTACK]
    if not attack_actions:
        return self.random_action()
    return np.random.choice(attack_actions)
def _player_step(self, action: Action):
    """Execute ``action`` for the active team and then apply adjacent supply.

    Returns:
        ``(reward, terminated)``. Only ATTACK can terminate the episode.

    Raises:
        ValueError: For an unknown action type.
    """
    kind = action.action_type
    terminated = False
    if kind == ActionType.END_OF_TURN:
        # NOTE: this is where the active player is exchanged.
        # Clear the "already acted" markers.
        for agent in self.current_agents:
            agent.commenced_action = False
        # Switch player.
        self.episode_steps += 1
        if not self.use_real_engine:
            self.next_turn()
        reward = 0  # a penalty term could be added here
    elif kind == ActionType.MOVE:
        reward = self._move(action.agent_id, action.des, action.target_id)
    elif kind == ActionType.ATTACK:
        reward, terminated = self._attack(action.agent_id, action.target_id)
    elif kind == ActionType.INTERACT:
        reward = self._interact(action.agent_id, action.target_id)
    elif kind == ActionType.RELEASE:
        reward = self._release(action.agent_id, action.target_id, action.des)
    elif kind == ActionType.SWITCH_WEAPON:
        reward = self._switch_weapon(action.agent_id, action.weapon_idx)
    else:
        raise ValueError(f"Invalid action type: {action.action_type}")

    # Every agent of the (now) active team is replenished by each adjacent
    # own-faction unit that has a supply module. A non-positive increment
    # means "refill to the maximum".
    nodes = self.map.nodes
    for agent in self.current_agents:
        for pos in get_adj_pos(agent.pos, 1):
            if pos not in nodes or nodes[pos].faction_id != self.faction_id:
                continue
            adj_agent = self.get_agent(nodes[pos].agent_id)
            if not adj_agent.has_supply:
                continue
            module = adj_agent.supply

            if module.add_endurance > 0:
                agent.endurance += min(module.add_endurance, agent.max_endurance - agent.endurance)
            else:
                agent.endurance = agent.max_endurance

            if agent.weapon is not None:
                if module.add_ammo > 0:
                    agent.weapon.ammo += min(module.add_ammo, agent.weapon.max_ammo - agent.weapon.ammo)
                else:
                    agent.weapon.ammo = agent.weapon.max_ammo

            if module.add_fuel > 0:
                agent.fuel += min(module.add_fuel, agent.max_fuel - agent.fuel)
            else:
                agent.fuel = agent.max_fuel
    return reward, terminated
def _move(self, agent_id, des, target_id):
    """Move an agent of the active team towards ``des`` and return the reward (0).

    The agent is detached from its tile, a path is searched over the tiles it
    may enter, and it advances tile by tile, paying fuel. Under fog of war it
    stops in front of the first tile occupied by a unit of another faction.
    At its final tile it either occupies the free tile or boards the carrier
    ``target_id`` standing there.

    Args:
        agent_id: Id of the moving agent (must belong to the active team).
        des: Destination position.
        target_id: Id of the carrier to board at the destination, if any.
    """
    map_data = self.map.nodes
    agent = self.current_agents_dict[agent_id]
    spotted_enemy_ids = self.spotted_enemy_ids

    # Detach the agent from its current tile.
    origin_node = map_data[agent.pos]
    origin_node.agent_id = None
    origin_node.team_id = -1
    origin_node.faction_id = -1

    mobility = min(agent.mobility, agent.fuel)
    adj_pos = get_adj_pos(agent.pos, int(mobility))
    # Own-faction tiles plus every tile that holds no spotted enemy.
    aval_data = {
        pos: tile
        for pos in adj_pos if pos in map_data and
        (
            (tile := map_data[pos]).faction_id == self.faction_id
            or
            tile.agent_id not in spotted_enemy_ids
        )
    }
    path = astar_search(aval_data, move_type=agent.move_type, start=agent.pos, goal=des, limit=mobility)
    path = [agent.pos] + path

    exception = False
    for cur, nxt in zip(path[:-1], path[1:]):
        ahead = map_data[nxt]
        # Stop when the next tile is occupied by another faction's unit. Free
        # tiles also carry faction -1, so occupancy must be tested explicitly;
        # without it the unit never advanced under fog of war.
        if self.war_fog and ahead.team_id != -1 and ahead.faction_id != self.faction_id:
            exception = True
            break
        if self.replay_path is not None:
            self._frames[self.team_id].append(self.render(mode="rgb_array"))
        agent.fuel -= get_cost(agent.move_type, ahead.terrain_type)
        agent.pos = nxt

    # Update the position and its association with the map.
    cur = agent.pos
    agent.commenced_action = True
    landing = map_data[cur]
    if landing.team_id == -1:
        landing.agent_id = agent_id
        landing.team_id = self.team_id
        landing.faction_id = self.faction_id
    elif landing.agent_id == target_id:
        carry_agent = self.get_agent(target_id)
        assert agent.faction_id == carry_agent.faction_id
        assert agent.agent_type in carry_agent.available_types, f"{agent} not in {carry_agent.available_types} (id: {carry_agent.agent_id[:8]})."
        # The message used to slice the Agent object itself instead of its id.
        assert len(carry_agent.parked_agents) < carry_agent.capacity, f"Agent {carry_agent.agent_id[:8]} is available but full"
        carry_agent.parked_agents.append(agent)
        agent.is_carried = True
    else:
        # Only reachable when the move was interrupted.
        assert exception
    return 0
def _attack(self, attack_id, target_id):
    """Resolve one attack and return ``(damage_dealt, terminated)``.

    Damage is the attacker's damage minus the target's defense, floored at 0
    and capped by the target's remaining endurance. The attack costs one
    round of ammo and uses up the attacker's action. A destroyed target is
    removed from the map; the episode terminates when no enemy-faction team
    has units alive.
    """
    attacker = self.get_agent(attack_id)
    defender = self.get_agent(target_id)

    if self.replay_path is not None:
        # Hold the attack frame for four frames.
        frames = self._frames[self.team_id]
        frames.extend([self.render(mode="rgb_array", attack_agent=attacker, defend_agent=defender)] * 4)

    dealt = min(max(attacker.damage - defender.defense, 0), defender.endurance)
    defender.endurance -= dealt
    attacker.weapon.ammo -= 1
    attacker.commenced_action = True

    if defender.endurance > 0:
        return dealt, False

    # Remove the destroyed unit from the map.
    tile = self.get_node(defender.pos)
    tile.agent_id = None
    tile.team_id = -1
    tile.faction_id = -1

    enemies_alive = sum(
        self.teams[team_id].alive_count
        for team_id in range(self.max_team)
        if self.teams[team_id].faction_id != self.faction_id
    )
    return dealt, enemies_alive == 0
2024-12-18 21:30:53 +08:00
def _interact(self, agent_id, target_id):
    """Park ``agent_id`` inside ``target_id`` and return the reward (0).

    Only the carrier's parked list and the agent's ``is_carried`` flag are
    updated here.
    """
    passenger = self.get_agent(agent_id)
    carrier = self.get_agent(target_id)
    carrier.parked_agents.append(passenger)
    passenger.is_carried = True
    # The reward may later become a linear combination of the supplies received.
    return 0
def _release(self, agent_id: str, target_idx: int, des: Tuple[int, int]):
    """Unload a parked unit from a carrier onto ``des`` and return the reward (0).

    The released unit is removed from the carrier's parked list, registered
    on the destination node and replenished by every module of the carrier
    (a non-positive increment means "refill to the maximum").

    Args:
        agent_id: Id of the carrier.
        target_idx: Index of the unit inside ``parked_agents``.
        des: Position the unit is placed on.
    """
    carrier = self.get_agent(agent_id)
    # pop() rather than indexing: the unit used to stay in the parked list
    # after leaving, so it kept using capacity and could be released again.
    agent_to_release = carrier.parked_agents.pop(target_idx)

    agent_to_release.pos = des
    node = self.get_node(des)
    node.team_id = agent_to_release.team_id
    node.faction_id = agent_to_release.faction_id
    node.agent_id = agent_to_release.agent_id

    for module in carrier.modules:
        if module.add_endurance > 0:
            agent_to_release.endurance += min(module.add_endurance, agent_to_release.max_endurance - agent_to_release.endurance)
        else:
            agent_to_release.endurance = agent_to_release.max_endurance

        weapon = agent_to_release.weapon
        if weapon is not None:
            if module.add_ammo > 0:
                weapon.ammo += min(module.add_ammo, weapon.max_ammo - weapon.ammo)
            else:
                weapon.ammo = weapon.max_ammo

        if module.add_fuel > 0:
            agent_to_release.fuel += min(module.add_fuel, agent_to_release.max_fuel - agent_to_release.fuel)
        else:
            agent_to_release.fuel = agent_to_release.max_fuel

    agent_to_release.is_carried = False
    # The reward may later become a linear combination of the supplies received.
    return 0
def _switch_weapon(self, agent_id: str, weapon_idx: int):
    """Equip ``switchable_weapons[weapon_idx]`` on the agent and return the reward (0).

    The weapon being replaced is reset first. An agent without an equipped
    weapon simply receives the new one.
    """
    agent = self.get_agent(agent_id)
    # legal_actions offers a switch to agents with no weapon equipped, so
    # there may be nothing to reset.
    if agent.weapon is not None:
        agent.weapon.reset()
    agent.weapon = agent.switchable_weapons[weapon_idx]
    return 0
def observe(self):
    """Build the observation dict of the active team.

    Every entry of ``observation_space`` starts as a zero array. The action
    mask gets a 1 at each legal action id; the ``union_*`` planes receive the
    attributes of allied units and the ``enemy_*`` planes those of spotted
    enemy units, written at each unit's position. Units parked at the padding
    position ``(-1, -1)`` are skipped.
    """
    union_agents = self.union_agents
    spotted_enemy_agents = self.spotted_enemy_agents

    obs = {
        desc: np.zeros(space.shape, dtype=space.dtype)
        for desc, space in self.observation_space.items()
    }
    for action_id in self.legal_action_ids:
        obs["action_mask"][action_id] = 1

    features = (
        "info_level", "stealth_level", "mobility", "defense",
        "damage", "fuel", "ammo", "endurance",
    )
    for side, agents in (("union", union_agents), ("enemy", spotted_enemy_agents)):
        for agent in agents:
            pos = agent.pos
            if pos == (-1, -1):
                # Padding position: negative indices would silently write
                # into the last cell of every plane.
                continue
            for feature in features:
                obs[f"{side}_{feature}"][pos[0], pos[1]] = getattr(agent, feature)
    return obs
def command_move(self, agent_id, des: Union[Tuple[int, int], List[Tuple[int, int]]], start_time=0, subtask=False) -> Command:
    """Queue the MOVE actions that take an agent towards ``des``.

    Planning starts from the state stored on the agent's last queued
    action, or from the agent's own state when nothing is queued.
    ``des`` is either a single position or a list of candidate positions;
    when the planning state is already at ``des`` (or at one of the
    candidates) nothing is planned and ``None`` is returned.

    One MOVE ``Action`` is appended to ``agent.todo`` for every position
    of the path returned by ``get_path``.  Unless ``subtask`` is true the
    new actions are re-timed, the last queued action is tagged with
    ``end_type`` MOVE and the command is appended to ``agent.cmd_todo``.

    Returns the ``Command`` describing the planned move.
    """
    agent = self.get_agent(agent_id)
    if agent.todo:
        state = agent.todo[-1].state
    else:
        state = agent.state
    if state.pos == des or state.pos in des:
        return

    map_data = self.map.nodes
    origin_todo_length = len(agent.todo)  # length of todo before planning
    spotted_enemy_ids = self.spotted_enemy_ids

    candidates = get_adj_pos(state.pos, int(state.fuel))
    candidates.append(state.pos)
    # Keep on-map nodes that belong to our faction or are not occupied by
    # a spotted enemy.
    aval_data = {}
    for candidate in candidates:
        if candidate not in map_data:
            continue
        candidate_node = map_data[candidate]
        if (candidate_node.faction_id == self.faction_id
                or candidate_node.agent_id not in spotted_enemy_ids):
            aval_data[candidate] = candidate_node

    # A feasible path may not exist.
    path = get_path(agent=agent, map_data=aval_data, start=state.pos, des=des, limit=state.fuel)
    # NOTE(review): ``state`` is mutated in place and handed to every
    # Action; this presumably relies on Action copying it -- confirm.
    for step in path:
        state.fuel -= get_cost(agent.move_type, map_data[step].terrain_type)
        state.pos = step
        agent.todo.append(Action(agent_id=agent_id,
                                 action_type=ActionType.MOVE,
                                 des=step,
                                 start_time=start_time,
                                 state=state))

    command = Command(agent_id=agent_id,
                      action_type=ActionType.MOVE,
                      des=des,
                      start_time=start_time,
                      end_time=start_time + len(agent.todo) - origin_todo_length - 1,
                      state=state)
    if subtask:
        return command

    end_time = command.end_time
    for offset, queued in enumerate(agent.todo[origin_todo_length:]):
        queued.start_time = start_time + offset
        queued.end_time = end_time
    agent.todo[-1].end_type = ActionType.MOVE
    agent.cmd_todo.append(command)
    return command
def command_attack(self, attack_id, target_id, attack_count=1, start_time=0, subtask=False) -> Command:
    """Plan an attack of ``attack_id`` on ``target_id``.

    The plan is built on top of the agent's queued actions and may contain
    these subtasks, in order:

    1. a weapon switch, when the planned weapon cannot strike the target's
       ``move_type``;
    2. a resupply (the planned weapon can hold ``attack_count`` rounds but
       currently has fewer) or another weapon switch (it cannot hold that
       many);
    3. a move to a free node within ``agent.attack_range`` of the target;
    4. ``attack_count`` ATTACK actions.

    Unless ``subtask`` is true the new actions are re-timed, the last one
    is tagged with ``end_type`` ATTACK and the command is appended to
    ``agent.cmd_todo``.

    Raises:
        ValueError: when a subtask cannot be planned.  ``agent.todo`` is
            truncated back to its original length first, and the
            underlying exception is chained as ``__cause__``.
    """
    agent = self.get_agent(attack_id)
    target_agent = self.get_agent(target_id)
    origin_todo_length = len(agent.todo)  # length of todo before planning

    def rollback():
        # Drop every action queued by this (failed) planning attempt.
        agent.todo = agent.todo[:origin_todo_length]

    def can_strike(weapon):
        return weapon.strike_types and target_agent.move_type in weapon.strike_types

    state = agent.todo[-1].state if agent.todo else agent.state
    if target_agent.move_type not in state.weapon.strike_types:
        switched = False
        no_enough_ammo = False
        for weapon_idx, weapon in enumerate(agent.switchable_weapons):
            if not can_strike(weapon):
                continue
            if weapon.max_ammo < attack_count:
                # Right weapon type but too small a magazine: remember it
                # for the error message and keep looking.
                no_enough_ammo = True
                continue
            try:
                self.command_switch(attack_id, weapon_idx, start_time=start_time, subtask=True)
            except Exception as exc:
                rollback()
                raise ValueError("Failed switching to the proper weapon") from exc
            switched = True
            break
        if not switched:
            rollback()
            if no_enough_ammo:
                raise ValueError("Agent has the striking weapon but no one with enough ammo")
            raise ValueError("No proper weapon found")

    state = agent.todo[-1].state if agent.todo else agent.state
    if state.weapon.ammo < attack_count:
        if state.weapon.max_ammo >= attack_count:
            # No weapon switch needed, try to resupply.
            try:
                self.command_supply(attack_id, start_time=start_time, subtask=True)
            except Exception as exc:
                rollback()
                raise ValueError("No enough ammo and failed to get supplies") from exc
        else:
            # A weapon with a larger magazine is needed.
            switched = False
            for weapon_idx, weapon in enumerate(agent.switchable_weapons):
                if can_strike(weapon) and weapon.max_ammo >= attack_count:
                    try:
                        self.command_switch(attack_id, weapon_idx, start_time=start_time, subtask=True)
                    except Exception as exc:
                        rollback()
                        raise ValueError("Failed switching to the proper weapon") from exc
                    switched = True
                    break
            if not switched:
                rollback()
                raise ValueError("No proper weapon has enough ammo")

    map_data = self.map.nodes
    spotted_enemy_ids = self.spotted_enemy_ids
    # Nodes in attack range of the target that are on the map and either
    # belong to our faction or are not occupied by a spotted enemy.
    possible_des = []
    for pos in get_adj_pos(target_agent.pos, int(agent.attack_range)):
        if pos not in map_data:
            continue
        node = map_data[pos]
        if node.faction_id == self.faction_id or node.agent_id not in spotted_enemy_ids:
            possible_des.append(pos)
    try:
        self.command_move(attack_id, possible_des, start_time=start_time, subtask=True)
    except Exception as exc:
        rollback()
        raise ValueError("Failed to move to the target") from exc

    state = agent.todo[-1].state if agent.todo else agent.state
    for _ in range(attack_count):
        state.weapon.ammo -= 1
        agent.todo.append(Action(agent_id=attack_id,
                                 action_type=ActionType.ATTACK,
                                 target_id=target_id,
                                 des=target_agent.pos,
                                 start_time=start_time,
                                 state=state))

    command = Command(agent_id=attack_id,
                      action_type=ActionType.ATTACK,
                      target_id=target_id,
                      des=target_agent.pos,
                      attack_count=attack_count,
                      start_time=start_time,
                      end_time=start_time + len(agent.todo) - origin_todo_length - 1,
                      state=state)
    if not subtask:
        end_time = command.end_time
        for offset, queued in enumerate(agent.todo[origin_todo_length:]):
            queued.start_time = start_time + offset
            queued.end_time = end_time
        agent.todo[-1].end_type = ActionType.ATTACK
        agent.cmd_todo.append(command)
    return command
def command_supply(self, agent_id, start_time=0, subtask=False) -> Command:
    """Plan a resupply: move the agent next to a unit with a supply module.

    Every on-map node adjacent to an agent in ``self.current_agents``
    with ``has_supply`` is a candidate destination for the move subtask.
    After the move is planned, each same-faction supplier adjacent to the
    planned destination tops up the planning state: endurance, current
    weapon ammo and fuel are increased by the module's ``add_*`` amount
    (capped at the maximum), or set to the maximum when that amount is
    not positive.

    Unless ``subtask`` is true, and only when at least one action was
    actually queued, the new actions are re-timed, the last one is tagged
    with ``end_type`` SUPPLY and the command is appended to
    ``agent.cmd_todo``.  When the agent already stands next to a supplier
    there is nothing to queue, and tagging ``agent.todo[-1]`` would
    either fail or retag an action that belongs to an earlier command.

    Raises:
        ValueError: when no friendly supplier is adjacent to the planned
            destination; actions queued by this call are removed first.
    """
    agent = self.get_agent(agent_id)
    current_agents = self.current_agents
    map_data = self.map.nodes
    origin_todo_length = len(agent.todo)  # length of todo before planning

    possible_des = []
    for supplier in current_agents:
        if not supplier.has_supply:
            continue
        for pos in get_adj_pos(supplier.pos, 1):
            if pos in map_data:
                possible_des.append(pos)

    self.command_move(agent_id, possible_des, start_time, subtask=True)

    # NOTE(review): when no move was queued this is the ``des`` of an
    # older action, which may not be the agent's planned position --
    # confirm.
    des = agent.todo[-1].des if agent.todo else agent.pos

    state = None
    for pos in get_adj_pos(des, 1):
        if pos not in map_data:
            continue
        node = map_data[pos]
        if node.faction_id != agent.faction_id:
            continue
        supplier = self.get_agent(node.agent_id)
        if not supplier.has_supply:
            continue
        supply_module = supplier.supply
        state = agent.todo[-1].state if agent.todo else agent.state
        if supply_module.add_endurance > 0:
            state.endurance = min(agent.max_endurance, state.endurance + supply_module.add_endurance)
        else:
            state.endurance = agent.max_endurance
        if supply_module.add_ammo > 0:
            state.weapon.ammo = min(state.weapon.max_ammo, state.weapon.ammo + supply_module.add_ammo)
        else:
            state.weapon.ammo = state.weapon.max_ammo
        if supply_module.add_fuel > 0:
            state.fuel = min(agent.max_fuel, state.fuel + supply_module.add_fuel)
        else:
            state.fuel = agent.max_fuel

    if state is None:
        # Previously this surfaced as an unbound-variable error below.
        agent.todo = agent.todo[:origin_todo_length]
        raise ValueError("No supply unit adjacent to the planned position")

    command = Command(agent_id=agent_id,
                      action_type=ActionType.SUPPLY,
                      des=des,
                      start_time=start_time,
                      end_time=start_time + len(agent.todo) - origin_todo_length - 1,
                      state=state)
    if not subtask and len(agent.todo) > origin_todo_length:
        end_time = command.end_time
        for offset, queued in enumerate(agent.todo[origin_todo_length:]):
            queued.start_time = start_time + offset
            queued.end_time = end_time
        agent.todo[-1].end_type = ActionType.SUPPLY
        agent.cmd_todo.append(command)
    return command
def command_switch(self, agent_id, weapon_idx, start_time=0, subtask=False) -> Command:
    """Plan a weapon switch to ``agent.switchable_weapons[weapon_idx]``.

    A supply subtask is planned first (it moves the agent next to a
    supplier), then one SWITCH_WEAPON action is queued at the planned
    position.  ``start_time`` is forwarded to the supply subtask so its
    actions carry the same provisional start time as the switch action,
    the way ``command_attack`` already forwards it.

    Unless ``subtask`` is true the new actions are re-timed, the last one
    is tagged with ``end_type`` SWITCH_WEAPON and the command is appended
    to ``agent.cmd_todo``.

    NOTE(review): the planning state's ``weapon`` is not changed here;
    presumably the switch is applied when the action executes -- confirm.
    """
    agent = self.get_agent(agent_id)
    origin_todo_length = len(agent.todo)  # length of todo before planning
    self.command_supply(agent_id, start_time=start_time, subtask=True)

    state = agent.todo[-1].state if agent.todo else agent.state
    weapon_name = agent.switchable_weapons[weapon_idx].name
    agent.todo.append(Action(agent_id=agent_id,
                             action_type=ActionType.SWITCH_WEAPON,
                             weapon_idx=weapon_idx,
                             weapon_name=weapon_name,
                             des=state.pos,
                             start_time=start_time,
                             state=state))
    command = Command(agent_id=agent_id,
                      action_type=ActionType.SWITCH_WEAPON,
                      weapon_idx=weapon_idx,
                      weapon_name=weapon_name,
                      des=state.pos,
                      start_time=start_time,
                      end_time=start_time + len(agent.todo) - origin_todo_length - 1,
                      state=state)
    if not subtask:
        end_time = command.end_time
        for offset, queued in enumerate(agent.todo[origin_todo_length:]):
            queued.start_time = start_time + offset
            queued.end_time = end_time
        agent.todo[-1].end_type = ActionType.SWITCH_WEAPON
        agent.cmd_todo.append(command)
    return command
def make_plan(self, command: Command, subtask=False):
    """Dispatch ``command`` to the planner matching its ``action_type``.

    MOVE, ATTACK, SUPPLY and SWITCH_WEAPON commands are forwarded to
    ``command_move``, ``command_attack``, ``command_supply`` and
    ``command_switch`` with the command's own fields, and the planner's
    result is returned.

    Raises:
        ValueError: for any other action type.
    """
    kind = command.action_type
    if kind == ActionType.MOVE:
        return self.command_move(
            command.agent_id,
            command.des,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.ATTACK:
        return self.command_attack(
            command.agent_id,
            command.target_id,
            attack_count=command.attack_count,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.SUPPLY:
        return self.command_supply(
            command.agent_id,
            start_time=command.start_time,
            subtask=subtask,
        )
    if kind == ActionType.SWITCH_WEAPON:
        return self.command_switch(
            command.agent_id,
            command.weapon_idx,
            start_time=command.start_time,
            subtask=subtask,
        )
    raise ValueError(f"Invalid command type: {command.action_type}")
def todo_action(self, agent_id: str, retry=0) -> Optional[Action]:
    """Pop and return the next executable action of an agent.

    Returns ``None`` when nothing is queued.  A valid action is returned
    as is; if its ``end_type`` matches the head of ``agent.cmd_todo``
    that command is finished and popped.  After more than two retries the
    action is returned even when it is still invalid.

    An invalid RELEASE action is retargeted to the first free on-map node
    adjacent to its destination, or ``None`` is returned when there is no
    such node.  NOTE(review): in that case the popped action is not put
    back on the queue -- confirm that dropping it is intended.

    For any other invalid action the command at the head of
    ``agent.cmd_todo`` is planned again: its remaining actions are
    discarded, the queues are emptied so that the new plan starts from
    the agent's current state and lands in front, and the untouched
    remainder is appended after it.
    """
    agent = self.get_agent(agent_id)
    if not agent.todo:
        return None

    action = agent.todo.pop(0)
    valid, msg = self.check_validity(action)
    if valid or retry > 2:
        # Actions may be queued without a command, so the command queue
        # can legitimately be empty here.
        if agent.cmd_todo and agent.cmd_todo[0].action_type == action.end_type:
            agent.cmd_todo.pop(0)
        return action

    print(f"Retrying to plan actions for {agent_id}: {msg}")
    if action.action_type == ActionType.RELEASE:
        for pos in get_adj_pos(action.des, 1):
            if pos in self.map.nodes and not self.map.nodes[pos].agent_id:
                action.des = pos
                return action
        return None  # Wait

    command = agent.cmd_todo.pop(0)
    # Discard what is left of the failed command.  If the failed action
    # was already its terminal action there is nothing left to discard,
    # and draining would eat into the next command's actions.
    if action.end_type != command.action_type:
        while agent.todo and agent.todo[0].end_type != command.action_type:
            agent.todo.pop(0)
        if agent.todo:
            agent.todo.pop(0)

    pending_actions = agent.todo.copy()
    pending_commands = agent.cmd_todo.copy()
    # Plan on empty queues; otherwise the new plan would be appended
    # behind the pending actions and the extend below would duplicate
    # them.
    agent.todo.clear()
    agent.cmd_todo.clear()
    try:
        self.make_plan(command)
    finally:
        # Restore the remainder even when planning fails.
        agent.todo.extend(pending_actions)
        agent.cmd_todo.extend(pending_commands)
    return self.todo_action(agent_id, retry=retry + 1)
def check_validity(self, action: Action) -> Tuple[bool, str]:
    """Check whether ``action`` is legal in the current state.

    Returns ``(True, "OK")`` for a legal action and ``(False, reason)``
    otherwise.  END_OF_TURN is always legal.  Any other action must come
    from an agent in ``self.current_agents_dict`` whose ``action_space``
    contains the action type; MOVE, ATTACK, SWITCH_WEAPON, RELEASE and
    INTERACT then get the type-specific checks below.

    Destinations outside the map, negative weapon indices and unknown
    INTERACT targets are reported as invalid instead of raising.
    """
    if action.action_type == ActionType.END_OF_TURN:
        return True, "OK"

    map_data = self.map.nodes
    spotted_enemy_ids = self.spotted_enemy_ids
    if action.agent_id not in self.current_agents_dict:
        return False, f"Invalid agent id: {action.agent_id}"
    agent = self.get_agent(action.agent_id)
    if action.action_type not in agent.action_space.keys():
        return False, f"Invalid action {action.action_type.name} for agent {agent.agent_id[:8]}"

    if action.action_type == ActionType.MOVE:
        des = action.des
        if des not in map_data:
            return False, f"{des} is not on the map"
        node = map_data[des]
        # A node held by our faction can only be entered when its occupant
        # accepts this agent type and still has free capacity.
        if node.faction_id == agent.faction_id:
            target = self.get_agent(node.agent_id)
            if (agent.agent_type not in target.available_types
                    or len(target.parked_agents) >= target.capacity):
                return False, f"{node.agent_id[:8]} at {des} is an ally but not available"
        if node.agent_id in spotted_enemy_ids:
            return False, f"{des} has been occupied by enemy {node.agent_id[:8]}"
        if get_axial_dis(agent.pos, des) > agent.mobility:
            return False, f"Destination is out of range. Mobility: {agent.mobility} Euclidean distance: {get_axial_dis(agent.pos, des)}"
        try:
            mobility = min(agent.fuel, agent.mobility)
            aval_data = {}
            for pos in get_adj_pos(agent.pos, int(mobility)):
                if pos not in map_data:
                    continue
                candidate = map_data[pos]
                if (candidate.faction_id == self.faction_id
                        or candidate.agent_id not in spotted_enemy_ids):
                    aval_data[pos] = candidate
            aval_nodes = astar_search(aval_data, agent.move_type, start=agent.pos, limit=mobility)
            if action.des in aval_nodes:
                return True, "OK"
            return False, f"{des} could not be reached in one step."
        except Exception as e:
            return False, f"Failed to move to {des}: {e}"

    elif action.action_type == ActionType.ATTACK:
        try:
            target_agent = self.get_agent(action.target_id)
        except Exception:
            return False, f"Invalid target agent id: {action.target_id}, ignored"
        des = action.des
        if agent.faction_id == target_agent.faction_id:
            return False, f"Cannot attack allies. Attack: {action.agent_id}. Defend: {action.target_id}"
        if agent.weapon is None:
            return False, "No weapon equipped"
        if agent.ammo <= 0:
            return False, "No ammo left"
        if target_agent.move_type not in agent.strike_types:
            return False, "Target agent cannot be attacked with this weapon"
        if get_axial_dis(agent.pos, des) > agent.attack_range:
            return False, f"Target agent is out of range. Attack range: {agent.attack_range} Distance: {get_axial_dis(agent.pos, des)}"

    elif action.action_type == ActionType.SWITCH_WEAPON:
        weapon_idx = action.weapon_idx
        # Switching requires a friendly supplier on an adjacent node.
        around_supply = False
        for pos in get_adj_pos(agent.pos, 1):
            if pos not in map_data or not map_data[pos].agent_id:
                continue
            neighbour = self.get_agent(map_data[pos].agent_id)
            if neighbour.faction_id == agent.faction_id and neighbour.has_supply:
                around_supply = True
                break
        if not around_supply:
            return False, "Agent is not around a supply module"
        # A negative index would silently select a weapon from the end.
        if weapon_idx < 0 or weapon_idx >= len(agent.switchable_weapons):
            return False, f"Invalid weapon id {weapon_idx}. Max id: {len(agent.switchable_weapons) - 1}"

    elif action.action_type == ActionType.RELEASE:
        if action.target_id >= len(agent.parked_agents):
            return False, f"Trying to release an agent that is not in the queue. Target id: {action.target_id}, Current queue: {agent.parked_agents}"
        des = action.des
        if get_axial_dis(agent.pos, des) > 1:
            return False, "Only adjacent positions can be released"
        if des not in map_data:
            return False, f"{des} is not on the map"
        if map_data[des].team_id != -1:
            return False, "Cannot release agent to occupied position"

    elif action.action_type == ActionType.INTERACT:
        try:
            target_agent = self.get_agent(action.target_id)
        except Exception:
            return False, f"Invalid target agent id: {action.target_id}, ignored"
        if get_axial_dis(agent.pos, target_agent.pos) > 1:
            return False, "Target agent is out of range"
        if agent.faction_id != target_agent.faction_id:
            return False, "Cannot interact with enemy agent"
        if agent.agent_type not in target_agent.available_types:
            return False, f"Target agent {target_agent.agent_id[:8]} is not interactable with agent {agent}"
        if len(target_agent.parked_agents) >= target_agent.capacity:
            return False, f"Target agent {target_agent.agent_id[:8]} full"

    return True, "OK"
def update(self, sync_info):
    """Synchronise the environment with a snapshot from the real engine.

    ``sync_info['units']`` and ``sync_info['spottedHostiles']`` are lists
    of per-agent dicts (``agent_id``, ``pos`` with ``q``/``r``, ``fuel``,
    ``endurance``, ``commenced_action``, ``current_weapon``, ``ammo``).
    Each listed agent is detached from its old map node, updated from the
    snapshot, and re-attached to its new node if it is alive.  The reward
    is the total endurance lost by agents of other factions.

    Enemies that were spotted after the previous sync but are missing
    from this one are removed from the map.  Enemies that are still
    spotted are left alone: clearing them as well would erase the node
    association that was just written for them.

    Returns:
        ``(obs, reward, terminated, truncated, info)``.
    """
    sync_agents = sync_info['units']
    spotted_enemies = sync_info['spottedHostiles']
    map_data = self.map.nodes
    reward = 0

    def detach(agent, agent_id):
        # Clear the map association, but only if the node still points
        # at this agent.
        if agent.pos != (-1, -1):
            node = map_data[agent.pos]
            if node.agent_id == agent_id:
                node.agent_id = None
                node.team_id = -1
                node.faction_id = -1

    for sync_agent in [*sync_agents, *spotted_enemies]:
        agent_id = sync_agent['agent_id']
        agent = self.get_agent(agent_id)
        detach(agent, agent_id)

        agent.pos = (sync_agent['pos']['q'], sync_agent['pos']['r'])
        agent.fuel = sync_agent['fuel']
        if agent.faction_id != self.faction_id:
            # Damage dealt to other factions since the last sync.
            reward += agent.endurance - sync_agent['endurance']
        agent.endurance = sync_agent['endurance']
        agent.commenced_action = sync_agent['commenced_action']

        current_weapon = sync_agent['current_weapon']
        if current_weapon:
            for weapon in agent.switchable_weapons:
                if weapon.name == current_weapon:
                    agent.weapon = weapon
                    break
            agent.weapon.ammo = sync_agent['ammo']
        else:
            agent.weapon = None

        # Re-attach the agent to the map at its new position.
        if agent.alive:
            node = map_data[agent.pos]
            node.agent_id = agent_id
            node.team_id = agent.team_id
            node.faction_id = agent.faction_id

    # Remove enemies that dropped out of sight since the previous sync.
    visible_enemy_ids = [enemy['agent_id'] for enemy in spotted_enemies]
    still_visible = set(visible_enemy_ids)
    for agent_id in self._spotted_enemy_ids:
        if agent_id in still_visible:
            continue
        detach(self.get_agent(agent_id), agent_id)

    # Update the list of spotted enemies.
    self._spotted_enemy_ids = visible_enemy_ids

    # NOTE: Clear the cached states before observation.
    self._legal_actions = None
    self._legal_action_ids = None
    self._spotted_agents = None

    obs = self.observe()
    truncated = self.episode_steps >= self.max_episode_steps
    terminated = sum(
        self.teams[team_id].alive_count
        for team_id in range(self.max_team)
        if self.teams[team_id].faction_id != self.faction_id
    ) == 0
    self._cumulative_rewards[self.team_id] += reward

    info = {}
    info["accum_reward"] = (self._cumulative_rewards[0], self._cumulative_rewards[1])
    info["next_team"] = self.team_id + 1 if not terminated else -1
    if self.replay_path is not None:
        self._frames[self.team_id].append(self.render(mode="rgb_array"))
    if terminated or truncated:
        # The eval_episode_return is calculated from Player 1's perspective
        info["eval_episode_return"] = reward if self.faction_id == self.init_faction_id else -reward
        info["done_reason"] = "Terminated" if terminated else "Truncated"
        if self.replay_path is not None:
            self.save_replay()
    return obs, reward, terminated, truncated, info
2024-12-20 19:13:37 +08:00
def format_agents(self, agent_ids=None):
    """Render agents as markdown.

    ``agent_ids`` defaults to the keys of ``self.current_agents_dict``.
    The ids are resolved with ``self.get_agents_dict`` and the result is
    passed through ``json_to_markdown``.
    """
    ids = self.current_agents_dict.keys() if agent_ids is None else agent_ids
    agents = self.get_agents_dict(ids)
    return json_to_markdown(agents)
2024-12-18 21:30:53 +08:00
2024-12-20 19:13:37 +08:00
def format_battlefeild(self, battlefield_info):
    """Render a battlefield summary as markdown.

    Reads ``sectorName``, ``myUnitsID``, ``allyUnitsID`` and
    ``hostileUnitsID`` from ``battlefield_info``, resolves each id list
    with ``self.get_agents_dict`` and passes the resulting dict through
    ``json_to_markdown``.
    """
    summary = {
        "battlefield_name": battlefield_info["sectorName"],
        "my_agents": self.get_agents_dict(battlefield_info["myUnitsID"]),
        "ally_agents": self.get_agents_dict(battlefield_info["allyUnitsID"]),
        "enemy_agents": self.get_agents_dict(battlefield_info["hostileUnitsID"]),
    }
    return json_to_markdown(summary)
2024-12-18 21:30:53 +08:00