|
| 1 | +import os |
| 2 | +import sys |
| 3 | +import argparse |
| 4 | +import yaml |
| 5 | + |
| 6 | +parant_dir = os.path.dirname(__file__) |
| 7 | +project_home = os.path.dirname(os.path.dirname(parant_dir)) |
| 8 | +sys.path = [ |
| 9 | + os.path.join(project_home, 'third_party/Android-Lab') |
| 10 | +] + sys.path |
| 11 | + |
| 12 | +from agent import get_agent |
| 13 | +from evaluation.auto_test import * |
| 14 | +from evaluation.parallel import parallel_worker |
| 15 | +from generate_result import find_all_task_files |
| 16 | +from evaluation.configs import AppConfig, TaskConfig |
| 17 | +from mobile_use_auto_test import * |
| 18 | +from mobile_use_executor import * |
| 19 | + |
| 20 | + |
| 21 | +if __name__ == '__main__': |
| 22 | + android_lab_dir = os.path.join(project_home, 'third_party/Android-Lab') |
| 23 | + task_yamls = os.listdir(f'{android_lab_dir}/evaluation/config') |
| 24 | + task_yamls = [f"{android_lab_dir}/evaluation/config/" + i for i in task_yamls if i.endswith(".yaml")] |
| 25 | + |
| 26 | + arg_parser = argparse.ArgumentParser() |
| 27 | + arg_parser.add_argument("-n", "--name", default=None, type=str) |
| 28 | + arg_parser.add_argument("-c", "--config", default=f"{parant_dir}/config.yaml", type=str) |
| 29 | + arg_parser.add_argument("--task_config", nargs="+", default=task_yamls, help="All task config(s) to load") |
| 30 | + arg_parser.add_argument("--task_id", nargs="+", default=None) |
| 31 | + arg_parser.add_argument("--debug", action="store_true", default=False) |
| 32 | + arg_parser.add_argument("--app", nargs="+", default=None) |
| 33 | + arg_parser.add_argument("-p", "--parallel", default=1, type=int) |
| 34 | + |
| 35 | + args = arg_parser.parse_args() |
| 36 | + with open(args.config, "r") as file: |
| 37 | + yaml_data = yaml.safe_load(file) |
| 38 | + |
| 39 | + agent_config = yaml_data["agent"] |
| 40 | + task_config = yaml_data["task"] |
| 41 | + eval_config = yaml_data["eval"] |
| 42 | + |
| 43 | + if args.name is None: |
| 44 | + args.name = f"{yaml_data.get('name', agent_config['name'])}_{datetime.datetime.now().strftime('%Y%m%dT%H%M%S')}" |
| 45 | + |
| 46 | + autotask_class = task_config["class"] if "class" in task_config else "ScreenshotMobileTask_AutoTest" |
| 47 | + |
| 48 | + single_config = TaskConfig(**task_config["args"]) |
| 49 | + single_config = single_config.add_config(eval_config) |
| 50 | + if "True" == agent_config.get("relative_bbox"): |
| 51 | + single_config.is_relative_bbox = True |
| 52 | + agent_class = globals().get(agent_config["name"]) |
| 53 | + if agent_class is None: |
| 54 | + agent = get_agent(agent_config["name"], **agent_config["args"]) |
| 55 | + else: |
| 56 | + agent = agent_class(**agent_config["args"]) |
| 57 | + |
| 58 | + task_files = find_all_task_files(args.task_config) |
| 59 | + print(f"Evaluation saved name: {args.name}") |
| 60 | + if os.path.exists(os.path.join(single_config.save_dir, args.name)): |
| 61 | + already_run = os.listdir(os.path.join(single_config.save_dir, args.name)) |
| 62 | + already_run = [i.split("_")[0] + "_" + i.split("_")[1] for i in already_run] |
| 63 | + else: |
| 64 | + already_run = [] |
| 65 | + |
| 66 | + all_task_start_info = [] |
| 67 | + for app_task_config_path in task_files: |
| 68 | + app_config = AppConfig(app_task_config_path) |
| 69 | + if args.task_id is None: |
| 70 | + task_ids = list(app_config.task_name.keys()) |
| 71 | + else: |
| 72 | + task_ids = args.task_id |
| 73 | + for task_id in task_ids: |
| 74 | + if task_id in already_run: |
| 75 | + print(f"Task {task_id} already run, skipping") |
| 76 | + continue |
| 77 | + if task_id not in app_config.task_name: |
| 78 | + continue |
| 79 | + task_instruction = app_config.task_name[task_id].strip() |
| 80 | + app = app_config.APP |
| 81 | + if args.app is not None: |
| 82 | + print(app, args.app) |
| 83 | + if app not in args.app: |
| 84 | + continue |
| 85 | + package = app_config.package |
| 86 | + command_per_step = app_config.command_per_step.get(task_id, None) |
| 87 | + |
| 88 | + task_instruction = f"You should use {app} to complete the following task: {task_instruction}" |
| 89 | + all_task_start_info.append({ |
| 90 | + "agent": agent, |
| 91 | + "task_id": task_id, |
| 92 | + "task_instruction": task_instruction, |
| 93 | + "package": package, |
| 94 | + "command_per_step": command_per_step, |
| 95 | + "app": app |
| 96 | + }) |
| 97 | + |
| 98 | + class_ = globals().get(autotask_class) |
| 99 | + if class_ is None: |
| 100 | + raise AttributeError(f"Class {autotask_class} not found. Please check the class name in the config file.") |
| 101 | + |
| 102 | + if args.parallel == 1: |
| 103 | + Auto_Test = class_(single_config.subdir_config(args.name)) |
| 104 | + print("Auto_Test", Auto_Test) |
| 105 | + Auto_Test.run_serial(all_task_start_info) |
| 106 | + else: |
| 107 | + parallel_worker(class_, single_config.subdir_config(args.name), args.parallel, all_task_start_info) |
0 commit comments