|
4 | 4 | from vlmeval.smp import *
|
5 | 5 |
|
6 | 6 | # Define valid modes
|
7 |
| -MODES = ('dlist', 'mlist', 'missing', 'circular', 'localize', 'check', 'run', 'eval') |
| 7 | +MODES = ('dlist', 'mlist', 'missing', 'circular', 'localize', 'check', 'run', 'eval', 'merge_pkl') |
8 | 8 |
|
9 | 9 | CLI_HELP_MSG = \
|
10 | 10 | f"""
|
|
33 | 33 | vlmutil run l2 hf
|
34 | 34 | 8. Evaluate data file:
|
35 | 35 | vlmutil eval [dataset_name] [prediction_file]
|
| 36 | + 9. Merge pkl files: |
| 37 | + vlmutil merge_pkl [pkl_dir] [world_size] |
36 | 38 |
|
37 | 39 | GitHub: https://github.com/open-compass/VLMEvalKit
|
38 | 40 | """ # noqa: E501
|
@@ -393,75 +395,102 @@ def parse_args_eval():
|
393 | 395 | return args
|
394 | 396 |
|
395 | 397 |
|
def MERGE_PKL(pkl_dir, world_size=1):
    """Merge per-rank prediction `.pkl` shards in ``pkl_dir`` and re-dump them.

    Shard files are expected to carry a 3-character prefix ``{rank}{ws}_``
    (``ws`` in 1..8, ``rank`` in 0..ws-1) followed by a shared suffix. All
    shards with the same suffix are merged into one record dict, the source
    shards are deleted, and the merged dict is written once per target prefix
    ``{i}{world_size}_`` so that every rank of the new world size sees the
    full result set.

    Args:
        pkl_dir (str): Directory containing the shard files.
        world_size (int): Number of ranks to re-dump the merged records for
            (default 1). Assumed to be a single digit — TODO confirm callers
            only pass 1/2/4/8 as the CLI does.
    """
    # Every valid 3-char shard prefix across world sizes 1..8 (single digits).
    prefs = {f'{i}{ws}_' for ws in range(1, 9) for i in range(ws)}
    files = [x for x in os.listdir(pkl_dir) if x[:3] in prefs]

    # Group records by shard suffix (file name minus the 3-char rank prefix).
    res_all = defaultdict(dict)
    for f in files:
        full_path = osp.join(pkl_dir, f)
        res_all[f[3:]].update(load(full_path))
        os.remove(full_path)  # shard is consumed once its records are merged

    dump_prefs = [f'{i}{world_size}_' for i in range(world_size)]
    for k in res_all:
        # Each target rank receives a full copy of the merged records.
        for pf in dump_prefs:
            dump(res_all[k], f'{pkl_dir}/{pf}{k}')
        print(f'Merged {len(res_all[k])} records into {pkl_dir}/{dump_prefs[0]}{k}')
def cli():
    """Entry point of the ``vlmutil`` command-line tool.

    Dispatches on the first CLI argument (one of ``MODES``) and delegates to
    the corresponding helper. Prints ``CLI_HELP_MSG`` when invoked with no
    arguments or with an unrecognized command.
    """
    logger = get_logger('VLMEvalKit Tools')
    args = sys.argv[1:]
    if not args:  # no arguments passed
        logger.info(CLI_HELP_MSG)
        return

    # Hoist the (case-insensitive) command keyword instead of recomputing
    # args[0].lower() in every branch.
    mode = args[0].lower()
    if mode == 'dlist':
        assert len(args) >= 2
        lst = DLIST(args[1])
        print(' '.join(lst))
    elif mode == 'mlist':
        assert len(args) >= 2
        size = 'all'
        if len(args) > 2:
            size = args[2].lower()
        lst = MLIST(args[1], size)
        print('\n'.join(lst))
    elif mode == 'missing':
        assert len(args) >= 2
        missing_list = MISSING(args[1])
        logger = get_logger('Find Missing')
        logger.info(colored(f'Level {args[1]} Missing Results: ', 'red'))
        lines = []
        for m, D in missing_list:
            line = f'Model {m}, Dataset {D}'
            logger.info(colored(line, 'red'))
            lines.append(line)
        mwlines(lines, f'{args[1]}_missing.txt')
    elif mode == 'circular':
        assert len(args) >= 2
        CIRCULAR(args[1])
    elif mode == 'localize':
        assert len(args) >= 2
        LOCALIZE(args[1])
    elif mode == 'check':
        assert len(args) >= 2
        model_list = args[1:]
        for m in model_list:
            CHECK(m)
    elif mode == 'run':
        assert len(args) >= 2
        lvl = args[1]
        if len(args) == 2:
            # No models specified: run every model at this level.
            model = 'all'
            RUN(lvl, model)
        else:
            for model in args[2:]:
                RUN(lvl, model)
    elif mode == 'eval':
        args = parse_args_eval()
        data_file = args.data_file

        def extract_dataset(file_name):
            # Infer the dataset name from the prediction file name: try every
            # underscore-delimited suffix until one matches a known dataset.
            fname = osp.splitext(file_name)[0].split('/')[-1]
            parts = fname.split('_')
            for i in range(len(parts)):
                if '_'.join(parts[i:]) in SUPPORTED_DATASETS:
                    return '_'.join(parts[i:])
            return None

        dataset = extract_dataset(data_file)
        assert dataset is not None, f'Cannot infer dataset name from {data_file}'
        kwargs = {'nproc': args.nproc}
        if args.judge is not None:
            kwargs['model'] = args.judge
        if args.retry is not None:
            kwargs['retry'] = args.retry
        EVAL(dataset_name=dataset, data_file=data_file, **kwargs)
    elif mode == 'merge_pkl':
        assert len(args) == 3
        args[2] = int(args[2])
        assert args[2] in [1, 2, 4, 8]  # supported target world sizes
        MERGE_PKL(args[1], args[2])
    else:
        logger.error('WARNING: command error!')
        logger.info(CLI_HELP_MSG)
|
|
0 commit comments