94 lines
3.0 KiB
Python
94 lines
3.0 KiB
Python
import sys
|
|
sys.path.append("../")
|
|
|
|
from yes_cmdr.yes_cmdr_env import YesCmdrEnv
|
|
from yes_cmdr.common.agent import Action, Command, ActionType
|
|
|
|
def test_action_id_transform(env):
    """Check that action ids survive an id -> action -> id round trip.

    After resetting ``env``, the full id range ``[0, env.action_space_size)``
    is checked once per team (``env.max_team`` passes), and ``env.step(0)``
    is issued after each pass.

    Args:
        env: Environment under test. It must provide ``reset``, ``step``,
            ``id_to_action``, ``action_to_id``, ``max_team`` and
            ``action_space_size``.

    Raises:
        AssertionError: If any id does not map back to itself.
    """
    print("Testing action id transform")
    env.reset()

    for _team in range(env.max_team):
        for idx in range(env.action_space_size):
            round_trip_id = env.action_to_id(env.id_to_action(idx))
            assert round_trip_id == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}"

        # NOTE(review): action id 0 presumably ends the current team's turn so
        # that the next pass runs for the following team -- confirm.
        env.step(0)
|
|
|
|
def test_random_action(env):
    """Play 100 randomly sampled actions, asserting that each one is valid.

    The environment is reset first. Every action returned by
    ``env.random_action()`` is checked with ``env.check_validity`` before it
    is passed to ``env.step``. ``env.save_replay()`` is called once all the
    steps have run.

    Args:
        env: Environment under test.

    Raises:
        AssertionError: If a sampled action is reported as invalid.
    """
    print("Testing random action")
    env.reset()

    steps_to_play = 100
    for _ in range(steps_to_play):
        action = env.random_action()
        is_valid, reason = env.check_validity(action)
        assert is_valid, f"Action {action} is not valid: {reason}"
        env.step(action)

    env.save_replay()
|
|
|
|
def test_bot_action(env):
    """Play 100 bot-chosen actions, asserting that each one is valid.

    The environment is reset first. Every action returned by
    ``env.bot_action()`` is checked with ``env.check_validity`` before it is
    passed to ``env.step``. ``env.save_replay()`` is called once all the
    steps have run.

    Args:
        env: Environment under test.

    Raises:
        AssertionError: If a bot action is reported as invalid.
    """
    print("Testing bot action")
    env.reset()

    steps_to_play = 100
    for _ in range(steps_to_play):
        action = env.bot_action()
        is_valid, reason = env.check_validity(action)
        assert is_valid, f"Action {action} is not valid: {reason}"
        env.step(action)

    env.save_replay()
|
|
|
|
def test_command(env):
    """Queue scripted ATTACK commands for team 0 and execute them over time.

    The test resets ``env``, checks the id -> action -> id round trip for
    every id in the action space, prints every agent of ``env.teams[0]``,
    and queues three ATTACK commands through ``env.make_plan``. It then runs
    at most 20 rounds. In each round, every agent whose first todo entry is
    due (``start_time <= step_count``) executes it, followed by two
    ``env.step(0)`` calls. The loop ends early as soon as a step reports the
    episode as terminated or truncated. ``env`` is closed before returning.

    Args:
        env: Environment under test. ``env.step`` must return the 5-tuple
            ``(obs, reward, terminated, truncated, info)``.

    Raises:
        AssertionError: If an action id does not map back to itself.
    """
    env.reset()

    done = False
    step_count = 1
    max_rounds = 20

    for idx in range(env.action_space_size):
        assert env.action_to_id(env.id_to_action(idx)) == idx, f"Expected {env.id_to_action(idx)} as {idx} but got {env.action_to_id(env.id_to_action(idx))}"

    agents_data = env.teams[0].agents
    agents = agents_data.values()

    for agent in agents:
        print(agent.to_dict())

    # NOTE(review): the agent/target ids below are hard-coded; presumably they
    # match the fixture the environment was loaded from -- confirm.
    env.make_plan(Command(agent_id="9f19", target_id="95f3", attack_count=1, action_type=ActionType.ATTACK, start_time=1))
    env.make_plan(Command(agent_id="9f19", target_id="2477", attack_count=1, action_type=ActionType.ATTACK, start_time=5))
    for action in agents_data["9f19"].todo:
        print(action)

    env.make_plan(Command(agent_id="cc4c", target_id="c53f", attack_count=4, action_type=ActionType.ATTACK))
    for action in agents_data["cc4c"].todo:
        print(action)

    while not done and step_count <= max_rounds:
        for agent in agents:
            # Skip agents with nothing queued or whose next entry is not due.
            if not (agent.todo and agent.todo[0].start_time <= step_count):
                continue

            action = env.todo_action(agent.agent_id)
            print(action)
            _obs, reward, terminated, truncated, info = env.step(action)
            # Previously `done` was never updated, so the loop guard could
            # not fire and the log line always reported "Done: False".
            done = bool(terminated or truncated)
            print(f"Step {step_count}: Action {action.action_type.name} -> Reward: {reward}, Done: {done}")
            print(f"Info: {info}")
            if "exception" in info:
                # Drop the rest of this agent's plan once a step reports an
                # exception.
                agent.todo.clear()
            if done:
                break

        if done:
            # Do not keep stepping an episode that has already ended.
            break

        # NOTE(review): two steps with action id 0 close each round;
        # presumably they pass the turn -- confirm.
        env.step(0)
        env.step(0)

        step_count += 1

    env.close()
|
|
|
|
if __name__ == "__main__":
    # Configure the environment parameters.
    env_config = {
        "data_path": "./data/test",
        "max_team": 5,
        "max_faction": 4,
        "war_fog": False,
        "replay_path": "./video",
    }
    env = YesCmdrEnv(**env_config)

    # Run the enabled checks in order against the same environment instance.
    enabled_checks = (
        test_action_id_transform,
        test_random_action,
        test_bot_action,
        # test_command,
    )
    for run_check in enabled_checks:
        run_check(env)
|
|
|