541 lines
19 KiB
Python
541 lines
19 KiB
Python
from abc import ABC
|
|
import copy
|
|
from enum import Enum
|
|
from typing import List, Tuple
|
|
import gymnasium as gym
|
|
|
|
class MyEnum(Enum):
    """Enum base class with helpers to look members up by value, by name, or by either."""

    @classmethod
    def from_value(cls, value):
        """Return the member whose ``value`` equals *value*.

        Raises:
            ValueError: if no member has that value.
        """
        for member in cls:
            if member.value == value:
                return member
        raise ValueError(f"{value} is not a valid value for {cls.__name__}")

    @classmethod
    def from_string(cls, name):
        """Return the member whose name is exactly *name*.

        Raises:
            ValueError: if *name* is not a member name. The internal KeyError is
                suppressed so callers see a single, clean ValueError traceback.
        """
        try:
            return cls[name]
        except KeyError:
            raise ValueError(f"{name} is not a valid {cls.__name__}") from None

    @classmethod
    def from_input(cls, input):  # parameter name kept for keyword callers
        """Return a member from either a name (``str``) or a value (``int``).

        Raises:
            ValueError: if *input* is neither str nor int, or matches no member.
        """
        if isinstance(input, str):
            return cls.from_string(input)
        if isinstance(input, int):
            return cls.from_value(input)
        raise ValueError("Input must be either a string or an integer.")
|
|
|
|
class MoveType(MyEnum):
    """How an agent moves across the map.

    Also used by ``Weapon.strike_types`` as a list of move types.
    """

    AIR = 0
    GROUND = 1
    SEA = 2
    SUBWATER = 3
|
|
|
|
class TerrainType(MyEnum):
    """Terrain of a map tile (``TileNode.terrain_type``)."""

    PLAIN = 0
    ROAD = 1
    SEA = 2
    SUBWATER = 3
|
|
|
|
class ActionType(MyEnum):
    """Kinds of order an agent can receive.

    ``Action.json`` emits the integer value of the member as ``commandType``.
    """

    ILLEGAL = -1  # default for Action(); rendered as "Unknown action"
    END_OF_TURN = 0
    MOVE = 1
    ATTACK = 2
    INTERACT = 3
    RELEASE = 4
    SWITCH_WEAPON = 5
    SUPPLY = 6
|
|
|
|
class ActionStatus(MyEnum):
    """Status codes for an action."""

    VALID = 0
    OCCUPIED_DES = 1  # presumably: destination already occupied — confirm
    OUT_OF_RANGE = 2
|
|
|
|
class AgentType(MyEnum):
    """Unit categories; serialized by member name in ``Agent.to_dict``."""

    Infantry = 0  # infantry
    Tank = 1  # armored unit
    AntiAir = 2  # self-propelled anti-aircraft gun
    MobilizedInfantry = 3  # mechanized infantry
    Helicopter = 4  # helicopter
    Fighter = 5  # fighter aircraft
    UAV = 6  # unmanned aerial vehicle (drone)
    CombatShip = 7  # warship
    Carrier = 8  # aircraft carrier
    Artillery = 9  # artillery
    Construction = 10  # building / structure
    MissileLauncher = 11  # missile launcher vehicle
    TransportHelicopter = 12  # transport helicopter
    SupplyTruck = 13  # mobile supply truck
    Submarine = 14  # submarine
    AdvancedTank = 15  # advanced tank
    Airport = 16  # airfield
    SupplyStation = 17  # supply station
    RadarStation = 18  # radar station
    Bomber = 19  # bomber
    AWACS = 20  # airborne early-warning aircraft
    TransportShip = 21  # transport ship
    CommandPost = 22  # command post / headquarters
|
|
|
|
class ModuleType(MyEnum):
    """Kind of equipment module (see the Hanger, Supply and Transport classes)."""

    HANGER = 0
    SUPPLY = 1
    TRANSPORT = 2
|
|
|
|
class Weapon:
    """A switchable weapon with an attack range, damage value and ammo pool."""

    def __init__(self, name: str, attack_range: int, damage: float, max_ammo: int, ammo: int = None, strike_types: 'List[MoveType]' = None):
        """
        Args:
            name: weapon name.
            attack_range: attack range of the weapon.
            damage: damage value.
            max_ammo: ammo capacity; ``reset`` refills ``ammo`` to this.
            ammo: current ammo; defaults to ``max_ammo``.
            strike_types: list of MoveType members (presumably the move types
                this weapon can hit). Defaults to a new empty list for each
                weapon — a shared ``[]`` default would leak mutations between
                weapons.
        """
        self.name = name
        self.attack_range = attack_range
        self.damage = damage
        self.max_ammo = max_ammo
        self.ammo = max_ammo if ammo is None else ammo
        self.strike_types = [] if strike_types is None else strike_types

    def reset(self):
        """Refill ammo to ``max_ammo``."""
        self.ammo = self.max_ammo

    def to_dict(self):
        """Serialize to a plain dict; strike types are emitted by member name."""
        return {
            "name": self.name,
            "attack_range": self.attack_range,
            "damage": self.damage,
            "max_ammo": self.max_ammo,
            "ammo": self.ammo,
            "strike_types": [strike_type.name for strike_type in self.strike_types]
        }

    def __deepcopy__(self, memo):
        """Deep-copy the weapon, including its strike_types list."""
        copied_strike_types = copy.deepcopy(self.strike_types, memo)
        return Weapon(self.name, self.attack_range, self.damage, self.max_ammo, self.ammo, copied_strike_types)
|
|
|
|
class Action:
    """An order for one agent, with timing info and an optional AgentState snapshot.

    Any extra keyword arguments are attached to the instance as attributes.
    """

    def __init__(self,
                 action_type: ActionType=ActionType.ILLEGAL,
                 agent_id: str = "",
                 des: Tuple[int, int] = (-1, -1),
                 target_id: str = "",
                 weapon_id: int = 0,
                 weapon_name: str = "",
                 start_time: int = 0,
                 end_time: int = -1,
                 state: 'AgentState' = None,
                 **kwargs):
        """
        Args:
            action_type: kind of order; defaults to ActionType.ILLEGAL.
            agent_id: id of the acting agent.
            des: destination position; (-1, -1) when not set.
            target_id: id of the target agent; "" when not set.
            weapon_id: integer weapon identifier.
            weapon_name: weapon name (shown for SWITCH_WEAPON in ``__str__``).
            start_time: time the action starts.
            end_time: time the action ends; -1 means "same as start_time".
            state: optional AgentState snapshot attached to this action.
            **kwargs: extra attributes, assigned after the fields above.
        """
        self.agent_id = agent_id
        self.target_id = target_id
        self.des = des
        self.action_type = action_type
        self.weapon_id = weapon_id
        self.weapon_name = weapon_name
        self.start_time = start_time
        # -1 is the "not given" sentinel: the action ends when it starts.
        self.end_time = end_time if end_time != -1 else start_time
        self.end_type = None
        self.state = state
        for attr_name, attr_value in kwargs.items():
            setattr(self, attr_name, attr_value)

    def json(self) -> dict:
        """Return the action as a command payload dict.

        1000 is subtracted from each destination coordinate before conversion
        with ``utils.axial_to_cube_dict`` (presumably undoing an internal
        coordinate offset — confirm against the payload consumer).
        """
        from .utils import axial_to_cube_dict

        payload = {
            "curOrderedUnitID": self.agent_id,
            "targetUnitID": self.target_id,
        }
        shifted_des = (self.des[0] - 1000, self.des[1] - 1000)
        payload["destination"] = axial_to_cube_dict(shifted_des)
        payload["commandType"] = self.action_type.value
        return payload

    def __str__(self):
        """Human-readable one-line summary of the action."""
        renderers = (
            (ActionType.MOVE, lambda: f"{self.agent_id} moves to {self.des}"),
            (ActionType.ATTACK, lambda: f"{self.agent_id} attacks {self.target_id}"),
            (ActionType.INTERACT, lambda: f"Interact {self.agent_id} with {self.target_id}"),
            (ActionType.RELEASE, lambda: f"Release {self.target_id} from {self.agent_id} to {self.des}"),
            (ActionType.END_OF_TURN, lambda: "End of turn"),
            (ActionType.SWITCH_WEAPON, lambda: f"{self.agent_id} switches to {self.weapon_name}"),
        )
        for candidate, render in renderers:
            if self.action_type == candidate:
                return render()
        return "Unknown action"

    def to_dict(self):
        """Serialize to a plain dict; the action type is emitted by member name."""
        result = {
            "agent_id": self.agent_id,
            "target_id": self.target_id,
            "des": self.des,
            "action_type": self.action_type.name,
            "weapon_id": self.weapon_id,
            "weapon_name": self.weapon_name,
            "start_time": self.start_time,
            "end_type": self.end_type,
        }
        return result
|
|
|
|
class Command(Action):
    """An Action that can narrate its own start and completion as text."""

    def __init__(self, **kwargs):
        """Accept the same keyword arguments as Action (keywords only)."""
        super().__init__(**kwargs)

    def description(self) -> List[str]:
        """Return two sentences: when the command started and when it finished.

        Action type, destination and weapon name are read from ``self.command``
        when such an attribute was supplied as an extra keyword argument,
        otherwise from this object itself. Agent id and times always come from
        this object.

        Raises:
            NotImplementedError: for action types other than MOVE, ATTACK,
                SUPPLY and SWITCH_WEAPON.
        """
        # Action.__init__ never defines ``command``; it only exists if passed
        # via kwargs. Falling back to ``self`` lets a plain Command describe
        # itself instead of raising AttributeError.
        cmd = getattr(self, "command", self)
        if cmd.action_type == ActionType.MOVE:
            return [
                f"Agent {self.agent_id} started to move to {cmd.des} at time {self.start_time}.",
                f"Agent {self.agent_id} arrived at {cmd.des} at time {self.end_time}."
            ]
        elif cmd.action_type == ActionType.ATTACK:
            return [
                f"Agent {self.agent_id} started to attack {cmd.des} at time {self.start_time}.",
                f"Agent {self.agent_id} attacked {cmd.des} at time {self.end_time}."
            ]
        elif cmd.action_type == ActionType.SUPPLY:
            return [
                f"Agent {self.agent_id} started to get supply from {cmd.des} at time {self.start_time}.",
                f"Agent {self.agent_id} got supply from {cmd.des} at time {self.end_time}."
            ]
        elif cmd.action_type == ActionType.SWITCH_WEAPON:
            return [
                f"Agent {self.agent_id} started to switch weapon to {cmd.weapon_name} at time {self.start_time}.",
                f"Agent {self.agent_id} switched weapon to {cmd.weapon_name} at time {self.end_time}."
            ]
        else:
            raise NotImplementedError(f"Unsupported command type: {cmd.action_type}")
|
|
|
|
class Agent:
    """A unit on the map: static stats and equipment plus mutable per-episode state.

    ``reset()`` (also called at the end of ``__init__``) restores position,
    endurance, fuel, weapon ammo and the action queues to their initial values.
    """

    def __init__(self,
                 agent_id: str,
                 agent_type: AgentType,
                 team_id: int,
                 faction_id: int,
                 move_type: MoveType,
                 pos: Tuple[int, int],
                 info_level: int, stealth_level: int,
                 max_endurance: int, defense: int,
                 max_fuel: float, mobility: float,
                 switchable_weapons=None,
                 modules=None,
                 is_key: bool = False):
        """Create an agent, build its action space and scan its modules.

        Args:
            agent_id: unique id; used for equality, ordering and hashing.
            agent_type: AgentType member (serialized by name).
            team_id: owning team id.
            faction_id: owning faction id.
            move_type: MoveType member or its integer value.
            pos: initial position; ``reset`` moves the agent back here.
            info_level: stored as-is (not serialized by ``to_dict``).
            stealth_level: stored as-is (not serialized by ``to_dict``).
            max_endurance: endurance restored by ``reset``; ``alive`` is endurance > 0.
            defense: defense value.
            max_fuel: fuel restored by ``reset``.
            mobility: sizes the MOVE action space; no MOVE space when <= 0.
            switchable_weapons: list of Weapon; the first is made active by
                ``reset``. Defaults to a new empty list for each agent (a
                shared ``[]`` default would be aliased across agents).
            modules: list of Module objects. Defaults to a new empty list for
                each agent.
            is_key: flag stored as-is.
        """
        self.agent_id = agent_id
        self.agent_type = agent_type
        self.team_id = team_id
        self.faction_id = faction_id
        self.move_type = MoveType(move_type)
        self.init_pos = pos
        self.info_level = info_level
        self.stealth_level = stealth_level
        self.max_endurance = max_endurance
        self.defense = defense
        self.endurance = max_endurance
        self.max_fuel = max_fuel
        self.mobility = mobility
        self.switchable_weapons = [] if switchable_weapons is None else switchable_weapons
        self.modules = [] if modules is None else modules
        self.is_key = is_key

        # Build the action space; action types whose size would be zero are omitted.
        from .utils import range_to_count

        max_attack_range = max((weapon.attack_range for weapon in self.switchable_weapons), default=0)
        max_capacity = max((getattr(module, "capacity", 0) for module in self.modules), default=0)
        candidate_spaces = {
            ActionType.MOVE: gym.spaces.Discrete(range_to_count(int(self.mobility)))
            if self.mobility > 0 else None,
            ActionType.ATTACK: gym.spaces.Discrete(range_to_count(max_attack_range))
            if max_attack_range > 0 else None,
            ActionType.SWITCH_WEAPON: gym.spaces.Discrete(len(self.switchable_weapons))
            if len(self.switchable_weapons) >= 1 else None,
            ActionType.INTERACT: gym.spaces.Discrete(range_to_count(1)),
            ActionType.RELEASE: gym.spaces.Discrete(max_capacity)
            if max_capacity > 0 else None,
        }
        self._action_space = gym.spaces.Dict(
            {action_type: space for action_type, space in candidate_spaces.items() if space is not None}
        )

        self._has_supply = False
        self._available_types = []
        self._parked_agents = []
        self._capacity = max_capacity  # in practice there should be at most one transport module

        for module in self.modules:
            if module.module_type == ModuleType.SUPPLY:
                self._has_supply = True
                self.supply = module
            elif module.module_type == ModuleType.HANGER:
                # Aliases the module's list (no copy), so both see the same agents.
                self._parked_agents = module.parked_agents
            elif module.module_type == ModuleType.TRANSPORT:
                self._parked_agents = module.parked_agents
                self._available_types.extend(module.available_types)

        self.reset()

    def reset(self):
        """Restore initial position, endurance and fuel; refill weapons; clear queues."""
        self.pos = self.init_pos
        self.endurance = self.max_endurance
        self.fuel = self.max_fuel
        if self.switchable_weapons:
            for weapon in self.switchable_weapons:
                weapon.reset()
            self.weapon = self.switchable_weapons[0]
        else:
            self.weapon = None
        self.commenced_action = False
        self.cmd_todo = []
        self.todo = []
        self.is_carried = False

    def __lt__(self, other):
        if not isinstance(other, Agent):
            return NotImplemented
        return self.agent_id < other.agent_id

    def __eq__(self, other):
        # Non-agents compare unequal instead of raising AttributeError.
        if not isinstance(other, Agent):
            return NotImplemented
        return self.agent_id == other.agent_id

    def __hash__(self):
        return hash(self.agent_id)

    def __str__(self):
        return f"{self.agent_id} ({self.move_type.name}) at {self.pos}"

    def __repr__(self):
        return f"{self.agent_id} ({self.move_type.name}) at {self.pos}"

    def to_dict(self):
        """Serialize static and current mutable attributes to a plain dict.

        ``info_level`` and ``stealth_level`` are deliberately not included.
        """
        return {
            "agent_id": self.agent_id,
            "agent_type": self.agent_type.name,
            "team_id": self.team_id,
            "faction_id": self.faction_id,
            "move_type": self.move_type.name,
            "defense": self.defense,
            "max_endurance": self.max_endurance,
            "max_fuel": self.max_fuel,
            "switchable_weapons": [weapon.to_dict() for weapon in self.switchable_weapons],
            "modules": [module.to_dict() for module in self.modules],
            # Mutable attributes follow.
            "pos": list(self.pos),
            "endurance": self.endurance,
            "fuel": self.fuel,
            "mobility": self.mobility,
            "weapon": self.weapon.to_dict() if self.weapon is not None else None
        }

    def plan_to_dict(self):
        """Like ``to_dict`` but without team/faction ids, and with ``pos`` taken
        from the state snapshot of the last queued action (current state if the
        queue is empty).
        """
        state = self.todo[-1].state if self.todo else self.state
        return {
            "agent_id": self.agent_id,
            "agent_type": self.agent_type.name,
            "move_type": self.move_type.name,
            "defense": self.defense,
            "max_endurance": self.max_endurance,
            "max_fuel": self.max_fuel,
            "switchable_weapons": [weapon.to_dict() for weapon in self.switchable_weapons],
            "modules": [module.to_dict() for module in self.modules],
            "pos": list(state.pos),
            "endurance": self.endurance,
            "fuel": self.fuel,
            "mobility": self.mobility,
            "weapon": self.weapon.to_dict() if self.weapon is not None else None,
        }

    @property
    def alive(self):
        return self.endurance > 0

    @property
    def attack_range(self):
        return self.weapon.attack_range if self.weapon is not None else 0

    @property
    def strike_types(self):
        return self.weapon.strike_types if self.weapon is not None else []

    @property
    def damage(self):
        return self.weapon.damage if self.weapon is not None else 0

    @property
    def ammo(self):
        return self.weapon.ammo if self.weapon is not None else 0

    @property
    def action_space(self):
        return self._action_space

    @property
    def has_supply(self):
        return self._has_supply

    @property
    def available_types(self):
        return self._available_types

    @property
    def parked_agents(self):
        return self._parked_agents

    @property
    def capacity(self):
        return self._capacity

    @property
    def state(self):
        """A new AgentState snapshot of the current mutable state."""
        return AgentState(self.pos, self.endurance, self.fuel, self.weapon, self.commenced_action, self.cmd_todo, self.todo)

    def update(self, state: 'AgentState'):
        """Overwrite mutable state from *state*, deep-copying queues and weapon.

        Note: afterwards ``self.weapon`` is a copy, not one of ``switchable_weapons``.
        """
        self.pos = state.pos
        self.endurance = state.endurance
        self.fuel = state.fuel
        self.commenced_action = state.commenced_action
        self.cmd_todo = copy.deepcopy(state.cmd_todo)
        self.todo = copy.deepcopy(state.todo)
        self.weapon = copy.deepcopy(state.weapon)
|
|
|
|
class AgentState:
    """Snapshot of an agent's mutable state.

    The weapon and both action lists are deep-copied on construction, so the
    snapshot does not share mutable objects with its source.
    """

    def __init__(self, pos: Tuple[int, int], endurance: int, fuel: float, weapon: 'Weapon' = None, commenced_action: bool = False, cmd_todo: 'List[Action]' = None, todo: 'List[Action]' = None):
        """
        Args:
            pos: position (stored as-is).
            endurance: current endurance.
            fuel: current fuel.
            weapon: active weapon, deep-copied; None when the agent has none.
            commenced_action: flag stored as-is.
            cmd_todo: queued actions, deep-copied; defaults to a new empty list.
            todo: queued actions, deep-copied; defaults to a new empty list.
        """
        self.pos = pos
        self.endurance = endurance
        self.fuel = fuel
        self.commenced_action = commenced_action
        self.weapon = copy.deepcopy(weapon) if weapon is not None else None
        self.cmd_todo = copy.deepcopy(cmd_todo) if cmd_todo is not None else []
        self.todo = copy.deepcopy(todo) if todo is not None else []
|
|
|
|
class Team:
    """A group of agents sharing a team id and faction id, keyed by agent id."""

    def __init__(self, team_id: int, faction_id: int, agents: "dict[str, Agent]" = None):
        """
        Args:
            team_id: team identifier (used for equality, ordering and hashing).
            faction_id: faction identifier.
            agents: mapping of agent_id -> agent, stored as-is (not copied).
                Defaults to a new empty dict for each team; a shared ``{}``
                default would make every default-constructed team mutate the
                same dict.
        """
        self.team_id = team_id
        self.faction_id = faction_id
        self.agents = {} if agents is None else agents
        self.reset()

    def __lt__(self, other):
        if not isinstance(other, Team):
            return NotImplemented
        return self.team_id < other.team_id

    def __eq__(self, other):
        # Non-teams compare unequal instead of raising AttributeError.
        if not isinstance(other, Team):
            return NotImplemented
        return self.team_id == other.team_id

    def __hash__(self):
        return hash(self.team_id)

    def __str__(self):
        return f"Team {self.team_id}"

    def __repr__(self):
        return f"Team {self.team_id}"

    def add_agent(self, agent: 'Agent'):
        """Register *agent* under its ``agent_id`` (replacing any agent with that id).

        ``agents`` is a dict, so the previous ``append`` call always raised
        AttributeError.
        """
        self.agents[agent.agent_id] = agent

    def remove_agent(self, agent_id: str):
        """Remove the agent with *agent_id*; raises KeyError if absent."""
        self.agents.pop(agent_id)

    def reset(self):
        """Call ``reset()`` on every agent in the team."""
        for agent in self.agents.values():
            agent.reset()

    @property
    def alive_count(self):
        """Number of agents whose ``alive`` is true."""
        return sum(1 for agent in self.agents.values() if agent.alive)

    @property
    def alive_ids(self):
        """Ids of agents whose ``alive`` is true, in dict order."""
        return [agent.agent_id for agent in self.agents.values() if agent.alive]
|
|
|
|
class TileNode:
    """A single map tile: position, terrain, occupant and occupation bookkeeping."""

    def __init__(self, pos: Tuple[int, int], terrain_type: TerrainType, agent_id: str = None, is_city: bool = False):
        """
        Args:
            pos: tile position (used for equality, ordering and hashing).
            terrain_type: TerrainType member or its integer value.
            agent_id: id of the agent on this tile, or None.
            is_city: whether the tile is a city.
        """
        self.pos = pos
        self.terrain_type = TerrainType(terrain_type)
        self.agent_id = agent_id
        self.is_city = is_city
        self.team_id = -1
        # Same defaults reset() assigns, so these attributes exist before the
        # first reset() call instead of raising AttributeError when read.
        # (reset() itself is not called here because it would clear agent_id.)
        self.chaotic_value = 0
        self.occupy_value = 0
        self.occupied_by = None

    def reset(self):
        """Clear occupant and occupation state (terrain and city flag are kept)."""
        self.chaotic_value = 0
        self.occupy_value = 0
        self.occupied_by = None
        self.agent_id = None
        self.team_id = -1

    def __lt__(self, other):
        if not isinstance(other, TileNode):
            return NotImplemented
        return self.pos < other.pos

    def __eq__(self, other):
        # Non-tiles compare unequal instead of raising AttributeError.
        if not isinstance(other, TileNode):
            return NotImplemented
        return self.pos == other.pos

    def __hash__(self):
        return hash(self.pos)

    def __str__(self):
        return f"{self.pos} ({self.terrain_type.name})"

    def __repr__(self):
        return f"{self.pos} ({self.terrain_type.name})"
|
|
|
|
|
|
class Map:
    """A tile map keyed by position."""

    def __init__(self, nodes: "dict[Tuple[int, int], TileNode]"):
        """
        Args:
            nodes: mapping of position -> tile, stored as-is.

        ``width`` / ``height`` are one more than the largest first / second
        position coordinate; both are 0 for an empty mapping (``max`` on an
        empty sequence would otherwise raise ValueError).
        """
        self.width = max((pos[0] for pos in nodes.keys()), default=-1) + 1
        self.height = max((pos[1] for pos in nodes.keys()), default=-1) + 1
        self.nodes = nodes
|
|
|
|
class Module(ABC):
    """Base class for equipment modules attached to an agent.

    Subclasses (Hanger, Supply, Transport) fix ``module_type`` and the
    ``add_*`` values.
    """

    def __init__(self, module_type: ModuleType, add_endurance: int, add_ammo: int, add_fuel: float, available_types: List[AgentType] = None, capacity: int = 0):
        """
        Args:
            module_type: ModuleType member identifying the module.
            add_endurance: endurance value associated with the module.
            add_ammo: ammo value associated with the module.
            add_fuel: fuel value associated with the module.
            available_types: list of AgentType members; defaults to a new empty list.
            capacity: module capacity (Agent reads it to size its RELEASE space).
        """
        self.module_type = module_type
        self.add_endurance = add_endurance
        self.add_ammo = add_ammo
        self.add_fuel = add_fuel
        self.available_types = [] if available_types is None else available_types
        self.capacity = capacity

    def to_dict(self):
        """Serialize to a plain dict (``capacity`` is not included)."""
        type_names = [kind.name for kind in self.available_types]
        return dict(
            module_type=self.module_type.name,
            add_endurance=self.add_endurance,
            add_ammo=self.add_ammo,
            add_fuel=self.add_fuel,
            available_types=type_names,
        )
|
|
|
|
class Hanger(Module):
    """Hangar module: holds parked agents (default capacity 6)."""

    def __init__(self, available_types: List[AgentType] = None, capacity: int = 6):
        super().__init__(module_type=ModuleType.HANGER,
                         add_endurance=4,
                         add_ammo=-1,
                         add_fuel=-1,
                         available_types=available_types,
                         capacity=capacity)
        self.parked_agents = []

    def reset(self):
        """Empty the parked-agent list in place.

        ``Agent.__init__`` stores a reference to this very list as the agent's
        ``parked_agents``; clearing in place keeps that reference in sync,
        whereas rebinding to a new list would leave the agent holding a stale one.
        """
        self.parked_agents.clear()
|
|
|
|
class Supply(Module):
    """Supply module; an agent carrying one reports ``has_supply``."""

    def __init__(self, available_types: List[AgentType] = None, capacity: int = 0):
        super().__init__(
            module_type=ModuleType.SUPPLY,
            add_endurance=2,
            add_ammo=2,
            add_fuel=-1,
            available_types=available_types,
            capacity=capacity,
        )
|
|
|
|
class Transport(Module):
    """Transport module: carries parked agents (default capacity 6)."""

    def __init__(self, available_types: List[AgentType] = None, capacity: int = 6):
        super().__init__(module_type=ModuleType.TRANSPORT,
                         add_endurance=0,
                         add_ammo=0,
                         add_fuel=0,
                         available_types=available_types,
                         capacity=capacity)
        self.parked_agents = []

    def reset(self):
        """Empty the parked-agent list in place.

        ``Agent.__init__`` stores a reference to this very list as the agent's
        ``parked_agents``; clearing in place keeps that reference in sync,
        whereas rebinding to a new list would leave the agent holding a stale one.
        """
        self.parked_agents.clear()
|
|
|