main
hkr04 2024-12-19 09:56:57 +08:00
parent c9b6ec74b8
commit 93558d3cae
7 changed files with 100 additions and 66 deletions

3
.gitignore vendored 100644
View File

@ -0,0 +1,3 @@
/video
*/__pycache__*
/__pycache__

View File

@ -455,6 +455,7 @@ class TileNode:
self.agent_id = agent_id
self.is_city = is_city
self.team_id = -1
self.faction_id = -1
def reset(self):
self.chaotic_value = 0
@ -462,6 +463,7 @@ class TileNode:
self.occupied_by = None
self.agent_id = None
self.team_id = -1
self.faction_id = -1
def __lt__(self, other):
return self.pos < other.pos

View File

@ -25,18 +25,18 @@ team_colors = [
icons_path = "./tianqiong/envs/icons"
unit_icons: Dict[str, Dict[int, pygame.Surface]] = {
AgentType.AntiAir: {0: pygame.image.load(f"{icons_path}/AntiAir_b.png"), 1: pygame.image.load(f"{icons_path}/AntiAir_r.png")},
AgentType.Airport: {0: pygame.image.load(f"{icons_path}/Airport_b.png"), 1: pygame.image.load(f"{icons_path}/Airport_r.png")},
AgentType.Artillery: {0: pygame.image.load(f"{icons_path}/Artillery_b.png"), 1: pygame.image.load(f"{icons_path}/Artillery_r.png")},
AgentType.Bomber: {0: pygame.image.load(f"{icons_path}/Bomber_b.png"), 1: pygame.image.load(f"{icons_path}/Bomber_r.png")},
AgentType.Carrier: {0: pygame.image.load(f"{icons_path}/Carrier_b.png"), 1: pygame.image.load(f"{icons_path}/Carrier_r.png")},
AgentType.Fighter: {0: pygame.image.load(f"{icons_path}/Fighter_b.png"), 1: pygame.image.load(f"{icons_path}/Fighter_r.png")},
AgentType.Helicopter: {0: pygame.image.load(f"{icons_path}/Helicopter_b.png"), 1: pygame.image.load(f"{icons_path}/Helicopter_r.png")},
AgentType.Infantry: {0: pygame.image.load(f"{icons_path}/Infantry_b.png"), 1: pygame.image.load(f"{icons_path}/Infantry_r.png")},
AgentType.Tank: {0: pygame.image.load(f"{icons_path}/Tank_b.png"), 1: pygame.image.load(f"{icons_path}/Tank_r.png")},
AgentType.TransportHelicopter: {0: pygame.image.load(f"{icons_path}/TransportHelicopter_b.png"), 1: pygame.image.load(f"{icons_path}/TransportHelicopter_r.png")},
AgentType.TransportShip: {0: pygame.image.load(f"{icons_path}/TransportShip_b.png"), 1: pygame.image.load(f"{icons_path}/TransportShip_r.png")},
AgentType.CombatShip: {0: pygame.image.load(f"{icons_path}/CombatShip_b.png"), 1: pygame.image.load(f"{icons_path}/CombatShip_r.png")},
# AgentType.AntiAir: {0: pygame.image.load(f"{icons_path}/AntiAir_b.png"), 1: pygame.image.load(f"{icons_path}/AntiAir_r.png")},
# AgentType.Airport: {0: pygame.image.load(f"{icons_path}/Airport_b.png"), 1: pygame.image.load(f"{icons_path}/Airport_r.png")},
# AgentType.Artillery: {0: pygame.image.load(f"{icons_path}/Artillery_b.png"), 1: pygame.image.load(f"{icons_path}/Artillery_r.png")},
# AgentType.Bomber: {0: pygame.image.load(f"{icons_path}/Bomber_b.png"), 1: pygame.image.load(f"{icons_path}/Bomber_r.png")},
# AgentType.Carrier: {0: pygame.image.load(f"{icons_path}/Carrier_b.png"), 1: pygame.image.load(f"{icons_path}/Carrier_r.png")},
# AgentType.Fighter: {0: pygame.image.load(f"{icons_path}/Fighter_b.png"), 1: pygame.image.load(f"{icons_path}/Fighter_r.png")},
# AgentType.Helicopter: {0: pygame.image.load(f"{icons_path}/Helicopter_b.png"), 1: pygame.image.load(f"{icons_path}/Helicopter_r.png")},
# AgentType.Infantry: {0: pygame.image.load(f"{icons_path}/Infantry_b.png"), 1: pygame.image.load(f"{icons_path}/Infantry_r.png")},
# AgentType.Tank: {0: pygame.image.load(f"{icons_path}/Tank_b.png"), 1: pygame.image.load(f"{icons_path}/Tank_r.png")},
# AgentType.TransportHelicopter: {0: pygame.image.load(f"{icons_path}/TransportHelicopter_b.png"), 1: pygame.image.load(f"{icons_path}/TransportHelicopter_r.png")},
# AgentType.TransportShip: {0: pygame.image.load(f"{icons_path}/TransportShip_b.png"), 1: pygame.image.load(f"{icons_path}/TransportShip_r.png")},
# AgentType.CombatShip: {0: pygame.image.load(f"{icons_path}/CombatShip_b.png"), 1: pygame.image.load(f"{icons_path}/CombatShip_r.png")},
}
class Renderer:

View File

@ -319,6 +319,7 @@
"type": 16,
"team_id": 2,
"faction_id": 1,
"move_type": 1,
"defense": 5.0,
"mobility": 0.0,
"info_level": 6,

5
requirements.txt 100644
View File

@ -0,0 +1,5 @@
numpy
pygame
gymnasium
imageio
imageio_ffmpeg

View File

@ -7,29 +7,32 @@ from yes_cmdr.common.agent import Action, Command, ActionType
def test_action_id_transform(env):
print("Testing action id transform")
env.reset()
for idx in range(env.action_space_size):
assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}"
for team_id in range(env.max_team):
for idx in range(env.action_space_size):
assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}"
env.step(0)
for idx in range(env.action_space_size):
assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}"
env.step(0)
def test_random_action(env):
print("Testing random action")
env.reset()
for i in range(100):
action = env.random_action()
assert env.check_validity(action)[0], f"Action {action} is not valid"
validity, reason = env.check_validity(action)
assert validity, f"Action {action} is not valid: {reason}"
env.step(action)
env.save_replay()
def test_bot_action(env):
print("Testing bot action")
env.reset()
for i in range(100):
action = env.bot_action()
assert env.check_validity(action)[0], f"Action {action} is not valid"
validity, reason = env.check_validity(action)
assert validity, f"Action {action} is not valid: {reason}"
env.step(action)
env.save_replay()
def test_command(env):
env.reset()
@ -79,7 +82,8 @@ if __name__ == "__main__":
data_path="./data/test",
max_team=5,
max_faction=4,
war_fog=False
war_fog=False,
replay_path="./video"
)
test_action_id_transform(env)

View File

@ -54,15 +54,17 @@ class YesCmdrEnv(gym.Env):
assert agent.pos in self.map.nodes, f"Invalid agent position: {agent.pos}"
self.map.nodes[agent.pos].agent_id = agent_id
self.map.nodes[agent.pos].team_id = team.team_id
self.map.nodes[agent.pos].faction_id = team.faction_id
if self.replay_path is not None:
self._frames = [[], []]
self._frames = [[] for _ in range(self.max_team)]
self._legal_actions = None
self._legal_action_ids = None
self._spotted_enemy_ids = None
self._spotted_enemy_agents = None
self._team_id = self.init_team_id
self._spotted_agents = None
self.episode_steps = 0
self._cumulative_rewards = np.zeros(2)
self._cumulative_rewards = np.zeros(self.max_team)
obs = self.observe()
info = {
@ -110,6 +112,7 @@ class YesCmdrEnv(gym.Env):
self._legal_actions = None
self._legal_action_ids = None
self._spotted_enemy_ids = None
self._spotted_enemy_agents = None
obs = self.observe()
truncated = self.episode_steps >= self.max_episode_steps
@ -124,7 +127,7 @@ class YesCmdrEnv(gym.Env):
if terminated or truncated:
# The eval_episode_return is calculated from the perspective of the initial team's faction
info["eval_episode_return"] = reward if self.team_id == 0 else -reward
info["eval_episode_return"] = reward if self.faction_id == self.teams[self.init_team_id].faction_id else -reward
info["done_reason"] = "Terminated" if terminated else "Truncated"
if self.replay_path is not None:
@ -134,8 +137,7 @@ class YesCmdrEnv(gym.Env):
def render(self, mode="human", attack_agent=None, defend_agent=None):
if mode == "rgb_array":
player_agents = self.teams[self.team_id].agents.values()
return self._renderer.render(self.current_agents, self.spotted_enemy_agents, attack_agent, defend_agent)
return self._renderer.render(self.union_agents, self.spotted_enemy_agents, attack_agent, defend_agent)
elif mode == "human":
return self.observe()
else:
@ -237,7 +239,7 @@ class YesCmdrEnv(gym.Env):
pos: node
for pos in adj_pos if pos in map_data and
(
(node := map_data[pos]).team_id != self.enemy_id
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
@ -247,14 +249,14 @@ class YesCmdrEnv(gym.Env):
if des == pos:
continue
node = map_data[des]
if node.team_id != self.team_id and node.agent_id not in spotted_enemy_ids:
if node.faction_id != self.faction_id and node.agent_id not in spotted_enemy_ids:
# A tile can hold only one unit
_legal_actions.append(Action(
agent_id=agent_id,
des=des,
action_type=ActionType.MOVE
))
elif node.team_id == self.team_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \
elif node.faction_id == self.faction_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \
and len(target.parked_agents) < target.capacity:
# Unless the destination holds a unit that this agent can park on
_legal_actions.append(Action(
@ -282,7 +284,7 @@ class YesCmdrEnv(gym.Env):
if des not in map_data:
continue
node = map_data[des]
if node.team_id == self.team_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \
if node.faction_id == self.faction_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \
and len(target.parked_agents) < target.capacity:
_legal_actions.append(Action(
agent_id=agent_id,
@ -313,7 +315,7 @@ class YesCmdrEnv(gym.Env):
if des not in map_data:
continue
node = map_data[des]
if node.team_id == self.team_id and self.get_agent(node.agent_id).has_supply:
if node.faction_id == self.faction_id and self.get_agent(node.agent_id).has_supply:
for weapon_idx, weapon in enumerate(agent.switchable_weapons):
if not agent.weapon or weapon.name != agent.weapon.name:
_legal_actions.append(Action(
@ -337,17 +339,17 @@ class YesCmdrEnv(gym.Env):
@property
def spotted_enemy_ids(self) -> List[str]:
if not self.war_fog:
return self.teams[self.enemy_id].alive_ids
return [enemy_id for enemy_id, enemy_agent in self.enemy_agents_dict.items() if enemy_agent.alive]
if self._spotted_enemy_ids is not None:
return self._spotted_enemy_ids
player_agents = self.current_agents
current_agents = self.current_agents
enemy_data = self.enemy_agents_dict
map_data = self.map.nodes
_spotted_enemy_ids = []
for agent in player_agents:
for agent in current_agents:
info_level = agent.info_level
for des in get_adj_pos(agent.pos, agent.info_level):
if des not in map_data:
@ -356,7 +358,7 @@ class YesCmdrEnv(gym.Env):
dis = get_axial_dis(agent.pos, des)
if dis == 1: # Adjacent agents are always spotted
dis = -1e9
if node.team_id == self.enemy_id \
if node.faction_id != self.faction_id \
and enemy_data[node.agent_id].stealth_level < info_level - dis + 1:
_spotted_enemy_ids.append(node.agent_id)
@ -409,7 +411,7 @@ class YesCmdrEnv(gym.Env):
_teams = [
Team(
team_id=_team_id,
faction_id=_agents[_team_id][0].faction_id,
faction_id=agents_dict_lists[_team_id][0]["faction_id"],
agents=_agents[_team_id]
)
for _team_id in range(self.max_team)
@ -422,6 +424,7 @@ class YesCmdrEnv(gym.Env):
assert agent.pos in _nodes, f"Invalid agent position: {agent.pos}"
_nodes[agent.pos].agent_id = agent_id
_nodes[agent.pos].team_id = team.team_id
_nodes[agent.pos].faction_id = team.faction_id
_map = Map(_nodes)
self.map = _map
@ -512,6 +515,14 @@ class YesCmdrEnv(gym.Env):
def current_agents_dict(self):
return self.teams[self._team_id].agents
@property
def union_agents(self):
return self._union_agents[self.faction_id]
@property
def union_agents_dict(self):
return self._union_agents_dict[self.faction_id]
@property
def enemy_agents(self):
return self._enemy_agents[self.faction_id]
@ -543,6 +554,7 @@ class YesCmdrEnv(gym.Env):
return self._spotted_enemy_agents
_spotted_enemy_agents = [self.enemy_agents_dict[enemy_id] for enemy_id in self.spotted_enemy_ids]
self._spotted_enemy_agents = _spotted_enemy_agents
return self._spotted_enemy_agents
@property
def faction_id(self):
@ -612,7 +624,7 @@ class YesCmdrEnv(gym.Env):
for agent in self.current_agents:
for pos in get_adj_pos(agent.pos, 1):
if pos in self.map.nodes and self.map.nodes[pos].team_id == self.team_id:
if pos in self.map.nodes and self.map.nodes[pos].faction_id == self.faction_id:
adj_agent = self.get_agent(self.map.nodes[pos].agent_id)
if not adj_agent.has_supply:
continue
@ -642,6 +654,7 @@ class YesCmdrEnv(gym.Env):
node = map_data[agent.pos]
node.agent_id = None
node.team_id = -1
node.faction_id = -1
mobility = min(agent.mobility, agent.fuel)
adj_pos = get_adj_pos(agent.pos, int(mobility))
@ -649,7 +662,7 @@ class YesCmdrEnv(gym.Env):
pos: node
for pos in adj_pos if pos in map_data and
(
(node := map_data[pos]).team_id != self.enemy_id
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
@ -660,7 +673,7 @@ class YesCmdrEnv(gym.Env):
exception = False
for cur, nxt in zip(path[:-1], path[1:]):
if self.war_fog and map_data[nxt].team_id == self.enemy_id: # 前进方向遇到敌军,停下
if self.war_fog and map_data[nxt].faction_id != self.faction_id: # 前进方向遇到敌军,停下
exception = True
break
if self.replay_path is not None:
@ -675,9 +688,10 @@ class YesCmdrEnv(gym.Env):
if map_data[cur].team_id == -1:
map_data[cur].agent_id = agent_id
map_data[cur].team_id = self.team_id
map_data[cur].faction_id = self.faction_id
elif map_data[cur].agent_id == target_id:
carry_agent = self.get_agent(target_id)
assert agent.team_id == carry_agent.team_id
assert agent.faction_id == carry_agent.faction_id
assert agent.agent_type in carry_agent.available_types, f"{agent} not in {carry_agent.available_types} (id: {carry_agent.agent_id[:8]}). "
assert len(carry_agent.parked_agents) < carry_agent.capacity, f"Agent {carry_agent[:8]} is available but full"
carry_agent.parked_agents.append(agent)
@ -688,8 +702,8 @@ class YesCmdrEnv(gym.Env):
return 0
def _attack(self, attack_id, target_id):
agent = self.current_agents[attack_id]
target_agent = self.teams[self.enemy_id].agents[target_id]
agent = self.get_agent(attack_id)
target_agent = self.get_agent(target_id)
if self.replay_path is not None:
self._frames[self.team_id].extend([self.render(mode="rgb_array", attack_agent=agent, defend_agent=target_agent)] * 4)
@ -705,8 +719,9 @@ class YesCmdrEnv(gym.Env):
node = self.get_node(target_agent.pos)
node.agent_id = None
node.team_id = -1
node.faction_id = -1
if self.teams[self.enemy_id].alive_count == 0:
if sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0:
return damage, True
return damage, False
@ -724,6 +739,7 @@ class YesCmdrEnv(gym.Env):
agent_to_release.pos = des
node = self.get_node(des)
node.team_id = agent_to_release.team_id
node.faction_id = agent_to_release.faction_id
node.agent_id = agent_to_release.agent_id
for module in agent.modules:
@ -751,7 +767,7 @@ class YesCmdrEnv(gym.Env):
return 0
def observe(self):
player_agents = self.current_agents
union_agents = self.union_agents
spotted_enemy_agents = self.spotted_enemy_agents
obs = {}
@ -764,16 +780,16 @@ class YesCmdrEnv(gym.Env):
for action_id in legal_actions_ids:
obs["action_mask"][action_id] = 1
for agent in player_agents:
for agent in union_agents:
pos = agent.pos
obs["player_info_level"][pos[0], pos[1]] = agent.info_level
obs["player_stealth_level"][pos[0], pos[1]] = agent.stealth_level
obs["player_mobility"][pos[0], pos[1]] = agent.mobility
obs["player_defense"][pos[0], pos[1]] = agent.defense
obs["player_damage"][pos[0], pos[1]] = agent.damage
obs["player_fuel"][pos[0], pos[1]] = agent.fuel
obs["player_ammo"][pos[0], pos[1]] = agent.ammo
obs["player_endurance"][pos[0], pos[1]] = agent.endurance
obs["union_info_level"][pos[0], pos[1]] = agent.info_level
obs["union_stealth_level"][pos[0], pos[1]] = agent.stealth_level
obs["union_mobility"][pos[0], pos[1]] = agent.mobility
obs["union_defense"][pos[0], pos[1]] = agent.defense
obs["union_damage"][pos[0], pos[1]] = agent.damage
obs["union_fuel"][pos[0], pos[1]] = agent.fuel
obs["union_ammo"][pos[0], pos[1]] = agent.ammo
obs["union_endurance"][pos[0], pos[1]] = agent.endurance
for agent in spotted_enemy_agents:
pos = agent.pos
@ -803,7 +819,7 @@ class YesCmdrEnv(gym.Env):
pos: node
for pos in adj_pos if pos in map_data and
(
(node := map_data[pos]).team_id != self.enemy_id
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
@ -892,7 +908,7 @@ class YesCmdrEnv(gym.Env):
pos for pos in get_adj_pos(target_agent.pos, int(agent.attack_range))
if pos in map_data and
(
(node := map_data[pos]).team_id != self.enemy_id
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
@ -939,13 +955,13 @@ class YesCmdrEnv(gym.Env):
Move a unit that needs resupply next to a unit that has a supply module
"""
agent = self.get_agent(agent_id)
player_agents = self.teams[agent.team_id].agents.values()
current_agents = self.current_agents
map_data = self.map.nodes
origin_todo_length = len(agent.todo) # 记录原 todo 的长度
possible_des = []
for _agent in player_agents:
for _agent in current_agents:
if _agent.has_supply:
for pos in get_adj_pos(_agent.pos, 1):
if pos in map_data:
@ -957,7 +973,7 @@ class YesCmdrEnv(gym.Env):
des = agent.todo[-1].des if agent.todo else agent.pos
for pos in get_adj_pos(des, 1):
if pos in map_data and map_data[pos].team_id == agent.team_id and self.get_agent(map_data[pos].agent_id).has_supply:
if pos in map_data and map_data[pos].faction_id == agent.faction_id and self.get_agent(map_data[pos].agent_id).has_supply:
supply_id = map_data[pos].agent_id
supply_module = self.get_agent(supply_id).supply
@ -1085,7 +1101,7 @@ class YesCmdrEnv(gym.Env):
map_data = self.map.nodes
spotted_enemy_ids = self.spotted_enemy_ids
if action.agent_id not in self.teams[self.player_id].agents:
if action.agent_id not in self.current_agents_dict:
return False, f"Invalid agent id: {action.agent_id}"
agent = self.get_agent(action.agent_id)
@ -1098,7 +1114,7 @@ class YesCmdrEnv(gym.Env):
node = map_data[des]
if node.team_id == agent.team_id and (not agent.agent_type in (target := self.get_agent(node.agent_id)).available_types or len(target.parked_agents) >= target.capacity):
if node.faction_id == agent.faction_id and (not agent.agent_type in (target := self.get_agent(node.agent_id)).available_types or len(target.parked_agents) >= target.capacity):
return False, f"{node.agent_id[:8]} at {des} is an ally but not available"
if node.agent_id in spotted_enemy_ids:
@ -1114,7 +1130,7 @@ class YesCmdrEnv(gym.Env):
pos: node
for pos in adj_pos if pos in map_data and
(
(node := map_data[pos]).team_id != self.enemy_id
(node := map_data[pos]).faction_id == self.faction_id
or
node.agent_id not in spotted_enemy_ids
)
@ -1135,8 +1151,8 @@ class YesCmdrEnv(gym.Env):
des = action.des
if agent.team_id == target_agent.team_id:
return False, f"Cannot attack teammate. Attack: {action.agent_id}. Defend: {action.target_id}"
if agent.faction_id == target_agent.faction_id:
return False, f"Cannot attack allies. Attack: {action.agent_id}. Defend: {action.target_id}"
if agent.weapon is None:
return False, "No weapon equipped"
if agent.ammo <= 0:
@ -1150,7 +1166,7 @@ class YesCmdrEnv(gym.Env):
weapon_idx = action.weapon_idx
around_supply = False
for pos in get_adj_pos(agent.pos, 1):
if pos in map_data and map_data[pos].agent_id and (target_agent := self.get_agent(map_data[pos].agent_id)).team_id == agent.team_id and target_agent.has_supply:
if pos in map_data and map_data[pos].agent_id and (target_agent := self.get_agent(map_data[pos].agent_id)).faction_id == agent.faction_id and target_agent.has_supply:
around_supply = True
if not around_supply:
return False, "Agent is not around a supply module"
@ -1172,7 +1188,7 @@ class YesCmdrEnv(gym.Env):
target_agent = self.get_agent(action.target_id)
if get_axial_dis(agent.pos, target_agent.pos) > 1:
return False, "Target agent is out of range"
if agent.team_id != target_agent.team_id:
if agent.faction_id != target_agent.faction_id:
return False, "Cannot interact with enemy agent"
if agent.agent_type not in target_agent.available_types:
return False, f"Target agent {target_agent.agent_id[:8]} is not interactable with agent {agent}"
@ -1200,9 +1216,10 @@ class YesCmdrEnv(gym.Env):
if node.agent_id == agent_id:
node.agent_id = None
node.team_id = -1
node.faction_id = -1
agent.pos = (sync_agent['pos']['q'], sync_agent['pos']['r'])
agent.fuel = sync_agent['fuel']
if agent.team_id == self.enemy_id:
if agent.faction_id != self.faction_id:
reward += agent.endurance - sync_agent['endurance']
agent.endurance = sync_agent['endurance']
agent.commenced_action = sync_agent['commenced_action']
@ -1220,6 +1237,7 @@ class YesCmdrEnv(gym.Env):
node = map_data[agent.pos]
node.agent_id = agent_id
node.team_id = agent.team_id
node.faction_id = agent.faction_id
for agent_id in self._spotted_enemy_ids:
agent = self.get_agent(agent_id)
@ -1229,6 +1247,7 @@ class YesCmdrEnv(gym.Env):
if node.agent_id == agent_id:
node.agent_id = None
node.team_id = -1
node.faction_id = -1
# Update the spotted-enemy list
self._spotted_enemy_ids = [enemy['agent_id'] for enemy in spotted_enemies]