update
parent
c9b6ec74b8
commit
93558d3cae
|
@ -0,0 +1,3 @@
|
|||
/video
|
||||
*/__pycache__*
|
||||
/__pycache__
|
|
@ -455,6 +455,7 @@ class TileNode:
|
|||
self.agent_id = agent_id
|
||||
self.is_city = is_city
|
||||
self.team_id = -1
|
||||
self.faction_id = -1
|
||||
|
||||
def reset(self):
|
||||
self.chaotic_value = 0
|
||||
|
@ -462,6 +463,7 @@ class TileNode:
|
|||
self.occupied_by = None
|
||||
self.agent_id = None
|
||||
self.team_id = -1
|
||||
self.faction_id = -1
|
||||
|
||||
def __lt__(self, other):
|
||||
return self.pos < other.pos
|
||||
|
|
|
@ -25,18 +25,18 @@ team_colors = [
|
|||
icons_path = "./tianqiong/envs/icons"
|
||||
|
||||
unit_icons: Dict[str, Dict[int, pygame.Surface]] = {
|
||||
AgentType.AntiAir: {0: pygame.image.load(f"{icons_path}/AntiAir_b.png"), 1: pygame.image.load(f"{icons_path}/AntiAir_r.png")},
|
||||
AgentType.Airport: {0: pygame.image.load(f"{icons_path}/Airport_b.png"), 1: pygame.image.load(f"{icons_path}/Airport_r.png")},
|
||||
AgentType.Artillery: {0: pygame.image.load(f"{icons_path}/Artillery_b.png"), 1: pygame.image.load(f"{icons_path}/Artillery_r.png")},
|
||||
AgentType.Bomber: {0: pygame.image.load(f"{icons_path}/Bomber_b.png"), 1: pygame.image.load(f"{icons_path}/Bomber_r.png")},
|
||||
AgentType.Carrier: {0: pygame.image.load(f"{icons_path}/Carrier_b.png"), 1: pygame.image.load(f"{icons_path}/Carrier_r.png")},
|
||||
AgentType.Fighter: {0: pygame.image.load(f"{icons_path}/Fighter_b.png"), 1: pygame.image.load(f"{icons_path}/Fighter_r.png")},
|
||||
AgentType.Helicopter: {0: pygame.image.load(f"{icons_path}/Helicopter_b.png"), 1: pygame.image.load(f"{icons_path}/Helicopter_r.png")},
|
||||
AgentType.Infantry: {0: pygame.image.load(f"{icons_path}/Infantry_b.png"), 1: pygame.image.load(f"{icons_path}/Infantry_r.png")},
|
||||
AgentType.Tank: {0: pygame.image.load(f"{icons_path}/Tank_b.png"), 1: pygame.image.load(f"{icons_path}/Tank_r.png")},
|
||||
AgentType.TransportHelicopter: {0: pygame.image.load(f"{icons_path}/TransportHelicopter_b.png"), 1: pygame.image.load(f"{icons_path}/TransportHelicopter_r.png")},
|
||||
AgentType.TransportShip: {0: pygame.image.load(f"{icons_path}/TransportShip_b.png"), 1: pygame.image.load(f"{icons_path}/TransportShip_r.png")},
|
||||
AgentType.CombatShip: {0: pygame.image.load(f"{icons_path}/CombatShip_b.png"), 1: pygame.image.load(f"{icons_path}/CombatShip_r.png")},
|
||||
# AgentType.AntiAir: {0: pygame.image.load(f"{icons_path}/AntiAir_b.png"), 1: pygame.image.load(f"{icons_path}/AntiAir_r.png")},
|
||||
# AgentType.Airport: {0: pygame.image.load(f"{icons_path}/Airport_b.png"), 1: pygame.image.load(f"{icons_path}/Airport_r.png")},
|
||||
# AgentType.Artillery: {0: pygame.image.load(f"{icons_path}/Artillery_b.png"), 1: pygame.image.load(f"{icons_path}/Artillery_r.png")},
|
||||
# AgentType.Bomber: {0: pygame.image.load(f"{icons_path}/Bomber_b.png"), 1: pygame.image.load(f"{icons_path}/Bomber_r.png")},
|
||||
# AgentType.Carrier: {0: pygame.image.load(f"{icons_path}/Carrier_b.png"), 1: pygame.image.load(f"{icons_path}/Carrier_r.png")},
|
||||
# AgentType.Fighter: {0: pygame.image.load(f"{icons_path}/Fighter_b.png"), 1: pygame.image.load(f"{icons_path}/Fighter_r.png")},
|
||||
# AgentType.Helicopter: {0: pygame.image.load(f"{icons_path}/Helicopter_b.png"), 1: pygame.image.load(f"{icons_path}/Helicopter_r.png")},
|
||||
# AgentType.Infantry: {0: pygame.image.load(f"{icons_path}/Infantry_b.png"), 1: pygame.image.load(f"{icons_path}/Infantry_r.png")},
|
||||
# AgentType.Tank: {0: pygame.image.load(f"{icons_path}/Tank_b.png"), 1: pygame.image.load(f"{icons_path}/Tank_r.png")},
|
||||
# AgentType.TransportHelicopter: {0: pygame.image.load(f"{icons_path}/TransportHelicopter_b.png"), 1: pygame.image.load(f"{icons_path}/TransportHelicopter_r.png")},
|
||||
# AgentType.TransportShip: {0: pygame.image.load(f"{icons_path}/TransportShip_b.png"), 1: pygame.image.load(f"{icons_path}/TransportShip_r.png")},
|
||||
# AgentType.CombatShip: {0: pygame.image.load(f"{icons_path}/CombatShip_b.png"), 1: pygame.image.load(f"{icons_path}/CombatShip_r.png")},
|
||||
}
|
||||
|
||||
class Renderer:
|
||||
|
|
|
@ -319,6 +319,7 @@
|
|||
"type": 16,
|
||||
"team_id": 2,
|
||||
"faction_id": 1,
|
||||
"move_type": 1,
|
||||
"defense": 5.0,
|
||||
"mobility": 0.0,
|
||||
"info_level": 6,
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
numpy
|
||||
pygame
|
||||
gymnasium
|
||||
imageio
|
||||
imageio_ffmpeg
|
|
@ -7,29 +7,32 @@ from yes_cmdr.common.agent import Action, Command, ActionType
|
|||
def test_action_id_transform(env):
|
||||
print("Testing action id transform")
|
||||
env.reset()
|
||||
|
||||
for team_id in range(env.max_team):
|
||||
for idx in range(env.action_space_size):
|
||||
assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}"
|
||||
|
||||
env.step(0)
|
||||
|
||||
for idx in range(env.action_space_size):
|
||||
assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}"
|
||||
|
||||
def test_random_action(env):
|
||||
print("Testing random action")
|
||||
env.reset()
|
||||
for i in range(100):
|
||||
action = env.random_action()
|
||||
assert env.check_validity(action)[0], f"Action {action} is not valid"
|
||||
validity, reason = env.check_validity(action)
|
||||
assert validity, f"Action {action} is not valid: {reason}"
|
||||
env.step(action)
|
||||
env.save_replay()
|
||||
|
||||
def test_bot_action(env):
|
||||
print("Testing bot action")
|
||||
env.reset()
|
||||
for i in range(100):
|
||||
action = env.bot_action()
|
||||
assert env.check_validity(action)[0], f"Action {action} is not valid"
|
||||
validity, reason = env.check_validity(action)
|
||||
assert validity, f"Action {action} is not valid: {reason}"
|
||||
env.step(action)
|
||||
env.save_replay()
|
||||
|
||||
def test_command(env):
|
||||
env.reset()
|
||||
|
@ -79,7 +82,8 @@ if __name__ == "__main__":
|
|||
data_path="./data/test",
|
||||
max_team=5,
|
||||
max_faction=4,
|
||||
war_fog=False
|
||||
war_fog=False,
|
||||
replay_path="./video"
|
||||
)
|
||||
|
||||
test_action_id_transform(env)
|
||||
|
|
109
yes_cmdr_env.py
109
yes_cmdr_env.py
|
@ -54,15 +54,17 @@ class YesCmdrEnv(gym.Env):
|
|||
assert agent.pos in self.map.nodes, f"Invalid agent position: {agent.pos}"
|
||||
self.map.nodes[agent.pos].agent_id = agent_id
|
||||
self.map.nodes[agent.pos].team_id = team.team_id
|
||||
self.map.nodes[agent.pos].faction_id = team.faction_id
|
||||
if self.replay_path is not None:
|
||||
self._frames = [[], []]
|
||||
self._frames = [[] for _ in range(self.max_team)]
|
||||
self._legal_actions = None
|
||||
self._legal_action_ids = None
|
||||
self._spotted_enemy_ids = None
|
||||
self._spotted_enemy_agents = None
|
||||
self._team_id = self.init_team_id
|
||||
self._spotted_agents = None
|
||||
self.episode_steps = 0
|
||||
self._cumulative_rewards = np.zeros(2)
|
||||
self._cumulative_rewards = np.zeros(self.max_team)
|
||||
|
||||
obs = self.observe()
|
||||
info = {
|
||||
|
@ -110,6 +112,7 @@ class YesCmdrEnv(gym.Env):
|
|||
self._legal_actions = None
|
||||
self._legal_action_ids = None
|
||||
self._spotted_enemy_ids = None
|
||||
self._spotted_enemy_agents = None
|
||||
|
||||
obs = self.observe()
|
||||
truncated = self.episode_steps >= self.max_episode_steps
|
||||
|
@ -124,7 +127,7 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
if terminated or truncated:
|
||||
# The eval_episode_return is calculated from Player 1's perspective
|
||||
info["eval_episode_return"] = reward if self.team_id == 0 else -reward
|
||||
info["eval_episode_return"] = reward if self.faction_id == self.teams[self.init_team_id].faction_id else -reward
|
||||
info["done_reason"] = "Terminated" if terminated else "Truncated"
|
||||
|
||||
if self.replay_path is not None:
|
||||
|
@ -134,8 +137,7 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
def render(self, mode="human", attack_agent=None, defend_agent=None):
|
||||
if mode == "rgb_array":
|
||||
player_agents = self.teams[self.team_id].agents.values()
|
||||
return self._renderer.render(self.current_agents, self.spotted_enemy_agents, attack_agent, defend_agent)
|
||||
return self._renderer.render(self.union_agents, self.spotted_enemy_agents, attack_agent, defend_agent)
|
||||
elif mode == "human":
|
||||
return self.observe()
|
||||
else:
|
||||
|
@ -237,7 +239,7 @@ class YesCmdrEnv(gym.Env):
|
|||
pos: node
|
||||
for pos in adj_pos if pos in map_data and
|
||||
(
|
||||
(node := map_data[pos]).team_id != self.enemy_id
|
||||
(node := map_data[pos]).faction_id == self.faction_id
|
||||
or
|
||||
node.agent_id not in spotted_enemy_ids
|
||||
)
|
||||
|
@ -247,14 +249,14 @@ class YesCmdrEnv(gym.Env):
|
|||
if des == pos:
|
||||
continue
|
||||
node = map_data[des]
|
||||
if node.team_id != self.team_id and node.agent_id not in spotted_enemy_ids:
|
||||
if node.faction_id != self.faction_id and node.agent_id not in spotted_enemy_ids:
|
||||
# 一个格子只能有一个单位
|
||||
_legal_actions.append(Action(
|
||||
agent_id=agent_id,
|
||||
des=des,
|
||||
action_type=ActionType.MOVE
|
||||
))
|
||||
elif node.team_id == self.team_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \
|
||||
elif node.faction_id == self.faction_id and agent.agent_type in (target := self.get_agent(node.agent_id)).available_types \
|
||||
and len(target.parked_agents) < target.capacity:
|
||||
# 除非终点有可以停靠的单位
|
||||
_legal_actions.append(Action(
|
||||
|
@ -282,7 +284,7 @@ class YesCmdrEnv(gym.Env):
|
|||
if des not in map_data:
|
||||
continue
|
||||
node = map_data[des]
|
||||
if node.team_id == self.team_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \
|
||||
if node.faction_id == self.faction_id and agent.agent_type in (target :=self.get_agent(node.agent_id)).available_types \
|
||||
and len(target.parked_agents) < target.capacity:
|
||||
_legal_actions.append(Action(
|
||||
agent_id=agent_id,
|
||||
|
@ -313,7 +315,7 @@ class YesCmdrEnv(gym.Env):
|
|||
if des not in map_data:
|
||||
continue
|
||||
node = map_data[des]
|
||||
if node.team_id == self.team_id and self.get_agent(node.agent_id).has_supply:
|
||||
if node.faction_id == self.faction_id and self.get_agent(node.agent_id).has_supply:
|
||||
for weapon_idx, weapon in enumerate(agent.switchable_weapons):
|
||||
if not agent.weapon or weapon.name != agent.weapon.name:
|
||||
_legal_actions.append(Action(
|
||||
|
@ -337,17 +339,17 @@ class YesCmdrEnv(gym.Env):
|
|||
@property
|
||||
def spotted_enemy_ids(self) -> List[str]:
|
||||
if not self.war_fog:
|
||||
return self.teams[self.enemy_id].alive_ids
|
||||
return [enemy_id for enemy_id, enemy_agent in self.enemy_agents_dict.items() if enemy_agent.alive]
|
||||
|
||||
if self._spotted_enemy_ids is not None:
|
||||
return self._spotted_enemy_ids
|
||||
|
||||
player_agents = self.current_agents
|
||||
current_agents = self.current_agents
|
||||
enemy_data = self.enemy_agents_dict
|
||||
map_data = self.map.nodes
|
||||
_spotted_enemy_ids = []
|
||||
|
||||
for agent in player_agents:
|
||||
for agent in current_agents:
|
||||
info_level = agent.info_level
|
||||
for des in get_adj_pos(agent.pos, agent.info_level):
|
||||
if des not in map_data:
|
||||
|
@ -356,7 +358,7 @@ class YesCmdrEnv(gym.Env):
|
|||
dis = get_axial_dis(agent.pos, des)
|
||||
if dis == 1: # Adjacent agents are always spotted
|
||||
dis = -1e9
|
||||
if node.team_id == self.enemy_id \
|
||||
if node.faction_id != self.faction_id \
|
||||
and enemy_data[node.agent_id].stealth_level < info_level - dis + 1:
|
||||
_spotted_enemy_ids.append(node.agent_id)
|
||||
|
||||
|
@ -409,7 +411,7 @@ class YesCmdrEnv(gym.Env):
|
|||
_teams = [
|
||||
Team(
|
||||
team_id=_team_id,
|
||||
faction_id=_agents[_team_id][0].faction_id,
|
||||
faction_id=agents_dict_lists[_team_id][0]["faction_id"],
|
||||
agents=_agents[_team_id]
|
||||
)
|
||||
for _team_id in range(self.max_team)
|
||||
|
@ -422,6 +424,7 @@ class YesCmdrEnv(gym.Env):
|
|||
assert agent.pos in _nodes, f"Invalid agent position: {agent.pos}"
|
||||
_nodes[agent.pos].agent_id = agent_id
|
||||
_nodes[agent.pos].team_id = team.team_id
|
||||
_nodes[agent.pos].faction_id = team.faction_id
|
||||
|
||||
_map = Map(_nodes)
|
||||
self.map = _map
|
||||
|
@ -512,6 +515,14 @@ class YesCmdrEnv(gym.Env):
|
|||
def current_agents_dict(self):
|
||||
return self.teams[self._team_id].agents
|
||||
|
||||
@property
|
||||
def union_agents(self):
|
||||
return self._union_agents[self.faction_id]
|
||||
|
||||
@property
|
||||
def union_agents_dict(self):
|
||||
return self._union_agents_dict[self.faction_id]
|
||||
|
||||
@property
|
||||
def enemy_agents(self):
|
||||
return self._enemy_agents[self.faction_id]
|
||||
|
@ -543,6 +554,7 @@ class YesCmdrEnv(gym.Env):
|
|||
return self._spotted_enemy_agents
|
||||
_spotted_enemy_agents = [self.enemy_agents_dict[enemy_id] for enemy_id in self.spotted_enemy_ids]
|
||||
self._spotted_enemy_agents = _spotted_enemy_agents
|
||||
return self._spotted_enemy_agents
|
||||
|
||||
@property
|
||||
def faction_id(self):
|
||||
|
@ -612,7 +624,7 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
for agent in self.current_agents:
|
||||
for pos in get_adj_pos(agent.pos, 1):
|
||||
if pos in self.map.nodes and self.map.nodes[pos].team_id == self.team_id:
|
||||
if pos in self.map.nodes and self.map.nodes[pos].faction_id == self.faction_id:
|
||||
adj_agent = self.get_agent(self.map.nodes[pos].agent_id)
|
||||
if not adj_agent.has_supply:
|
||||
continue
|
||||
|
@ -642,6 +654,7 @@ class YesCmdrEnv(gym.Env):
|
|||
node = map_data[agent.pos]
|
||||
node.agent_id = None
|
||||
node.team_id = -1
|
||||
node.faction_id = -1
|
||||
|
||||
mobility = min(agent.mobility, agent.fuel)
|
||||
adj_pos = get_adj_pos(agent.pos, int(mobility))
|
||||
|
@ -649,7 +662,7 @@ class YesCmdrEnv(gym.Env):
|
|||
pos: node
|
||||
for pos in adj_pos if pos in map_data and
|
||||
(
|
||||
(node := map_data[pos]).team_id != self.enemy_id
|
||||
(node := map_data[pos]).faction_id == self.faction_id
|
||||
or
|
||||
node.agent_id not in spotted_enemy_ids
|
||||
)
|
||||
|
@ -660,7 +673,7 @@ class YesCmdrEnv(gym.Env):
|
|||
exception = False
|
||||
|
||||
for cur, nxt in zip(path[:-1], path[1:]):
|
||||
if self.war_fog and map_data[nxt].team_id == self.enemy_id: # 前进方向遇到敌军,停下
|
||||
if self.war_fog and map_data[nxt].faction_id != self.faction_id: # 前进方向遇到敌军,停下
|
||||
exception = True
|
||||
break
|
||||
if self.replay_path is not None:
|
||||
|
@ -675,9 +688,10 @@ class YesCmdrEnv(gym.Env):
|
|||
if map_data[cur].team_id == -1:
|
||||
map_data[cur].agent_id = agent_id
|
||||
map_data[cur].team_id = self.team_id
|
||||
map_data[cur].faction_id = self.faction_id
|
||||
elif map_data[cur].agent_id == target_id:
|
||||
carry_agent = self.get_agent(target_id)
|
||||
assert agent.team_id == carry_agent.team_id
|
||||
assert agent.faction_id == carry_agent.faction_id
|
||||
assert agent.agent_type in carry_agent.available_types, f"{agent} not in {carry_agent.available_types} (id: {carry_agent.agent_id[:8]}). "
|
||||
assert len(carry_agent.parked_agents) < carry_agent.capacity, f"Agent {carry_agent[:8]} is available but full"
|
||||
carry_agent.parked_agents.append(agent)
|
||||
|
@ -688,8 +702,8 @@ class YesCmdrEnv(gym.Env):
|
|||
return 0
|
||||
|
||||
def _attack(self, attack_id, target_id):
|
||||
agent = self.current_agents[attack_id]
|
||||
target_agent = self.teams[self.enemy_id].agents[target_id]
|
||||
agent = self.get_agent(attack_id)
|
||||
target_agent = self.get_agent(target_id)
|
||||
|
||||
if self.replay_path is not None:
|
||||
self._frames[self.team_id].extend([self.render(mode="rgb_array", attack_agent=agent, defend_agent=target_agent)] * 4)
|
||||
|
@ -705,8 +719,9 @@ class YesCmdrEnv(gym.Env):
|
|||
node = self.get_node(target_agent.pos)
|
||||
node.agent_id = None
|
||||
node.team_id = -1
|
||||
node.faction_id = -1
|
||||
|
||||
if self.teams[self.enemy_id].alive_count == 0:
|
||||
if sum(self.teams[team_id].alive_count for team_id in range(self.max_team) if self.teams[team_id].faction_id != self.faction_id) == 0:
|
||||
return damage, True
|
||||
|
||||
return damage, False
|
||||
|
@ -724,6 +739,7 @@ class YesCmdrEnv(gym.Env):
|
|||
agent_to_release.pos = des
|
||||
node = self.get_node(des)
|
||||
node.team_id = agent_to_release.team_id
|
||||
node.faction_id = agent_to_release.faction_id
|
||||
node.agent_id = agent_to_release.agent_id
|
||||
|
||||
for module in agent.modules:
|
||||
|
@ -751,7 +767,7 @@ class YesCmdrEnv(gym.Env):
|
|||
return 0
|
||||
|
||||
def observe(self):
|
||||
player_agents = self.current_agents
|
||||
union_agents = self.union_agents
|
||||
spotted_enemy_agents = self.spotted_enemy_agents
|
||||
|
||||
obs = {}
|
||||
|
@ -764,16 +780,16 @@ class YesCmdrEnv(gym.Env):
|
|||
for action_id in legal_actions_ids:
|
||||
obs["action_mask"][action_id] = 1
|
||||
|
||||
for agent in player_agents:
|
||||
for agent in union_agents:
|
||||
pos = agent.pos
|
||||
obs["player_info_level"][pos[0], pos[1]] = agent.info_level
|
||||
obs["player_stealth_level"][pos[0], pos[1]] = agent.stealth_level
|
||||
obs["player_mobility"][pos[0], pos[1]] = agent.mobility
|
||||
obs["player_defense"][pos[0], pos[1]] = agent.defense
|
||||
obs["player_damage"][pos[0], pos[1]] = agent.damage
|
||||
obs["player_fuel"][pos[0], pos[1]] = agent.fuel
|
||||
obs["player_ammo"][pos[0], pos[1]] = agent.ammo
|
||||
obs["player_endurance"][pos[0], pos[1]] = agent.endurance
|
||||
obs["union_info_level"][pos[0], pos[1]] = agent.info_level
|
||||
obs["union_stealth_level"][pos[0], pos[1]] = agent.stealth_level
|
||||
obs["union_mobility"][pos[0], pos[1]] = agent.mobility
|
||||
obs["union_defense"][pos[0], pos[1]] = agent.defense
|
||||
obs["union_damage"][pos[0], pos[1]] = agent.damage
|
||||
obs["union_fuel"][pos[0], pos[1]] = agent.fuel
|
||||
obs["union_ammo"][pos[0], pos[1]] = agent.ammo
|
||||
obs["union_endurance"][pos[0], pos[1]] = agent.endurance
|
||||
|
||||
for agent in spotted_enemy_agents:
|
||||
pos = agent.pos
|
||||
|
@ -803,7 +819,7 @@ class YesCmdrEnv(gym.Env):
|
|||
pos: node
|
||||
for pos in adj_pos if pos in map_data and
|
||||
(
|
||||
(node := map_data[pos]).team_id != self.enemy_id
|
||||
(node := map_data[pos]).faction_id == self.faction_id
|
||||
or
|
||||
node.agent_id not in spotted_enemy_ids
|
||||
)
|
||||
|
@ -892,7 +908,7 @@ class YesCmdrEnv(gym.Env):
|
|||
pos for pos in get_adj_pos(target_agent.pos, int(agent.attack_range))
|
||||
if pos in map_data and
|
||||
(
|
||||
(node := map_data[pos]).team_id != self.enemy_id
|
||||
(node := map_data[pos]).faction_id == self.faction_id
|
||||
or
|
||||
node.agent_id not in spotted_enemy_ids
|
||||
)
|
||||
|
@ -939,13 +955,13 @@ class YesCmdrEnv(gym.Env):
|
|||
需要补给的单位移动至有补给模块的单位附近
|
||||
"""
|
||||
agent = self.get_agent(agent_id)
|
||||
player_agents = self.teams[agent.team_id].agents.values()
|
||||
current_agents = self.current_agents
|
||||
map_data = self.map.nodes
|
||||
origin_todo_length = len(agent.todo) # 记录原 todo 的长度
|
||||
|
||||
possible_des = []
|
||||
|
||||
for _agent in player_agents:
|
||||
for _agent in current_agents:
|
||||
if _agent.has_supply:
|
||||
for pos in get_adj_pos(_agent.pos, 1):
|
||||
if pos in map_data:
|
||||
|
@ -957,7 +973,7 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
des = agent.todo[-1].des if agent.todo else agent.pos
|
||||
for pos in get_adj_pos(des, 1):
|
||||
if pos in map_data and map_data[pos].team_id == agent.team_id and self.get_agent(map_data[pos].agent_id).has_supply:
|
||||
if pos in map_data and map_data[pos].faction_id == agent.faction_id and self.get_agent(map_data[pos].agent_id).has_supply:
|
||||
supply_id = map_data[pos].agent_id
|
||||
supply_module = self.get_agent(supply_id).supply
|
||||
|
||||
|
@ -1085,7 +1101,7 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
map_data = self.map.nodes
|
||||
spotted_enemy_ids = self.spotted_enemy_ids
|
||||
if action.agent_id not in self.teams[self.player_id].agents:
|
||||
if action.agent_id not in self.current_agents_dict:
|
||||
return False, f"Invalid agent id: {action.agent_id}"
|
||||
|
||||
agent = self.get_agent(action.agent_id)
|
||||
|
@ -1098,7 +1114,7 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
node = map_data[des]
|
||||
|
||||
if node.team_id == agent.team_id and (not agent.agent_type in (target := self.get_agent(node.agent_id)).available_types or len(target.parked_agents) >= target.capacity):
|
||||
if node.faction_id == agent.faction_id and (not agent.agent_type in (target := self.get_agent(node.agent_id)).available_types or len(target.parked_agents) >= target.capacity):
|
||||
return False, f"{node.agent_id[:8]} at {des} is an ally but not available"
|
||||
|
||||
if node.agent_id in spotted_enemy_ids:
|
||||
|
@ -1114,7 +1130,7 @@ class YesCmdrEnv(gym.Env):
|
|||
pos: node
|
||||
for pos in adj_pos if pos in map_data and
|
||||
(
|
||||
(node := map_data[pos]).team_id != self.enemy_id
|
||||
(node := map_data[pos]).faction_id == self.faction_id
|
||||
or
|
||||
node.agent_id not in spotted_enemy_ids
|
||||
)
|
||||
|
@ -1135,8 +1151,8 @@ class YesCmdrEnv(gym.Env):
|
|||
|
||||
des = action.des
|
||||
|
||||
if agent.team_id == target_agent.team_id:
|
||||
return False, f"Cannot attack teammate. Attack: {action.agent_id}. Defend: {action.target_id}"
|
||||
if agent.faction_id == target_agent.faction_id:
|
||||
return False, f"Cannot attack allies. Attack: {action.agent_id}. Defend: {action.target_id}"
|
||||
if agent.weapon is None:
|
||||
return False, "No weapon equipped"
|
||||
if agent.ammo <= 0:
|
||||
|
@ -1150,7 +1166,7 @@ class YesCmdrEnv(gym.Env):
|
|||
weapon_idx = action.weapon_idx
|
||||
around_supply = False
|
||||
for pos in get_adj_pos(agent.pos, 1):
|
||||
if pos in map_data and map_data[pos].agent_id and (target_agent := self.get_agent(map_data[pos].agent_id)).team_id == agent.team_id and target_agent.has_supply:
|
||||
if pos in map_data and map_data[pos].agent_id and (target_agent := self.get_agent(map_data[pos].agent_id)).faction_id == agent.faction_id and target_agent.has_supply:
|
||||
around_supply = True
|
||||
if not around_supply:
|
||||
return False, "Agent is not around a supply module"
|
||||
|
@ -1172,7 +1188,7 @@ class YesCmdrEnv(gym.Env):
|
|||
target_agent = self.get_agent(action.target_id)
|
||||
if get_axial_dis(agent.pos, target_agent.pos) > 1:
|
||||
return False, "Target agent is out of range"
|
||||
if agent.team_id != target_agent.team_id:
|
||||
if agent.faction_id != target_agent.faction_id:
|
||||
return False, "Cannot interact with enemy agent"
|
||||
if agent.agent_type not in target_agent.available_types:
|
||||
return False, f"Target agent {target_agent.agent_id[:8]} is not interactable with agent {agent}"
|
||||
|
@ -1200,9 +1216,10 @@ class YesCmdrEnv(gym.Env):
|
|||
if node.agent_id == agent_id:
|
||||
node.agent_id = None
|
||||
node.team_id = -1
|
||||
node.faction_id = -1
|
||||
agent.pos = (sync_agent['pos']['q'], sync_agent['pos']['r'])
|
||||
agent.fuel = sync_agent['fuel']
|
||||
if agent.team_id == self.enemy_id:
|
||||
if agent.faction_id != self.faction_id:
|
||||
reward += agent.endurance - sync_agent['endurance']
|
||||
agent.endurance = sync_agent['endurance']
|
||||
agent.commenced_action = sync_agent['commenced_action']
|
||||
|
@ -1220,6 +1237,7 @@ class YesCmdrEnv(gym.Env):
|
|||
node = map_data[agent.pos]
|
||||
node.agent_id = agent_id
|
||||
node.team_id = agent.team_id
|
||||
node.faction_id = agent.faction_id
|
||||
|
||||
for agent_id in self._spotted_enemy_ids:
|
||||
agent = self.get_agent(agent_id)
|
||||
|
@ -1229,6 +1247,7 @@ class YesCmdrEnv(gym.Env):
|
|||
if node.agent_id == agent_id:
|
||||
node.agent_id = None
|
||||
node.team_id = -1
|
||||
node.faction_id = -1
|
||||
|
||||
# 更新敌人感知
|
||||
self._spotted_enemy_ids = [enemy['agent_id'] for enemy in spotted_enemies]
|
||||
|
|
Loading…
Reference in New Issue