Commit 6b39c19

[Fix] Remove get_token_len for GPT API

1 parent cf4d61b · commit 6b39c19

2 files changed: +102 −73 lines

vlmeval/api/gpt.py (+10 −10)

```diff
@@ -177,16 +177,16 @@ def generate_inner(self, inputs, **kwargs) -> str:
         temperature = kwargs.pop('temperature', self.temperature)
         max_tokens = kwargs.pop('max_tokens', self.max_tokens)
 
-        context_window = GPT_context_window(self.model)
-        new_max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
-        if 0 < new_max_tokens <= 100 and new_max_tokens < max_tokens:
-            self.logger.warning(
-                'Less than 100 tokens left, '
-                'may exceed the context window with some additional meta symbols. '
-            )
-        if new_max_tokens <= 0:
-            return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
-        max_tokens = new_max_tokens
+        # context_window = GPT_context_window(self.model)
+        # new_max_tokens = min(max_tokens, context_window - self.get_token_len(inputs))
+        # if 0 < new_max_tokens <= 100 and new_max_tokens < max_tokens:
+        #     self.logger.warning(
+        #         'Less than 100 tokens left, '
+        #         'may exceed the context window with some additional meta symbols. '
+        #     )
+        # if new_max_tokens <= 0:
+        #     return 0, self.fail_msg + 'Input string longer than context window. ', 'Length Exceeded. '
+        # max_tokens = new_max_tokens
 
         # Will send request if use Azure, dk how to use openai client for it
         if self.use_azure:
```
vlmeval/tools.py (+92 −63)

```diff
@@ -4,7 +4,7 @@
 from vlmeval.smp import *
 
 # Define valid modes
-MODES = ('dlist', 'mlist', 'missing', 'circular', 'localize', 'check', 'run', 'eval')
+MODES = ('dlist', 'mlist', 'missing', 'circular', 'localize', 'check', 'run', 'eval', 'merge_pkl')
 
 CLI_HELP_MSG = \
     f"""
@@ -33,6 +33,8 @@
         vlmutil run l2 hf
     8. Evaluate data file:
         vlmutil eval [dataset_name] [prediction_file]
+    9. Merge pkl files:
+        vlmutil merge_pkl [pkl_dir] [world_size]
 
     GitHub: https://github.com/open-compass/VLMEvalKit
     """ # noqa: E501
@@ -393,75 +395,102 @@ def parse_args_eval():
     return args
 
 
+def MERGE_PKL(pkl_dir, world_size=1):
+    prefs = []
+    for ws in list(range(1, 9)):
+        prefs.extend([f'{i}{ws}_' for i in range(ws)])
+    prefs = set(prefs)
+    files = os.listdir(pkl_dir)
+    files = [x for x in files if x[:3] in prefs]
+    # Merge the files
+    res_all = defaultdict(dict)
+    for f in files:
+        full_path = osp.join(pkl_dir, f)
+        key = f[3:]
+        res_all[key].update(load(full_path))
+        os.remove(full_path)
+
+    dump_prefs = [f'{i}{world_size}_' for i in range(world_size)]
+    for k in res_all:
+        for pf in dump_prefs:
+            dump(res_all[k], f'{pkl_dir}/{pf}{k}')
+        print(f'Merged {len(res_all[k])} records into {pkl_dir}/{dump_prefs[0]}{k}')
+
+
 def cli():
     logger = get_logger('VLMEvalKit Tools')
     args = sys.argv[1:]
     if not args:  # no arguments passed
         logger.info(CLI_HELP_MSG)
         return
-    if args[0].lower() in MODES:
-        if args[0].lower() == 'dlist':
-            assert len(args) >= 2
-            lst = DLIST(args[1])
-            print(' '.join(lst))
-        elif args[0].lower() == 'mlist':
-            assert len(args) >= 2
-            size = 'all'
-            if len(args) > 2:
-                size = args[2].lower()
-            lst = MLIST(args[1], size)
-            print('\n'.join(lst))
-        elif args[0].lower() == 'missing':
-            assert len(args) >= 2
-            missing_list = MISSING(args[1])
-            logger = get_logger('Find Missing')
-            logger.info(colored(f'Level {args[1]} Missing Results: ', 'red'))
-            lines = []
-            for m, D in missing_list:
-                line = f'Model {m}, Dataset {D}'
-                logger.info(colored(line, 'red'))
-                lines.append(line)
-            mwlines(lines, f'{args[1]}_missing.txt')
-        elif args[0].lower() == 'circular':
-            assert len(args) >= 2
-            CIRCULAR(args[1])
-        elif args[0].lower() == 'localize':
-            assert len(args) >= 2
-            LOCALIZE(args[1])
-        elif args[0].lower() == 'check':
-            assert len(args) >= 2
-            model_list = args[1:]
-            for m in model_list:
-                CHECK(m)
-        elif args[0].lower() == 'run':
-            assert len(args) >= 2
-            lvl = args[1]
-            if len(args) == 2:
-                model = 'all'
+
+    if args[0].lower() == 'dlist':
+        assert len(args) >= 2
+        lst = DLIST(args[1])
+        print(' '.join(lst))
+    elif args[0].lower() == 'mlist':
+        assert len(args) >= 2
+        size = 'all'
+        if len(args) > 2:
+            size = args[2].lower()
+        lst = MLIST(args[1], size)
+        print('\n'.join(lst))
+    elif args[0].lower() == 'missing':
+        assert len(args) >= 2
+        missing_list = MISSING(args[1])
+        logger = get_logger('Find Missing')
+        logger.info(colored(f'Level {args[1]} Missing Results: ', 'red'))
+        lines = []
+        for m, D in missing_list:
+            line = f'Model {m}, Dataset {D}'
+            logger.info(colored(line, 'red'))
+            lines.append(line)
+        mwlines(lines, f'{args[1]}_missing.txt')
+    elif args[0].lower() == 'circular':
+        assert len(args) >= 2
+        CIRCULAR(args[1])
+    elif args[0].lower() == 'localize':
+        assert len(args) >= 2
+        LOCALIZE(args[1])
+    elif args[0].lower() == 'check':
+        assert len(args) >= 2
+        model_list = args[1:]
+        for m in model_list:
+            CHECK(m)
+    elif args[0].lower() == 'run':
+        assert len(args) >= 2
+        lvl = args[1]
+        if len(args) == 2:
+            model = 'all'
+            RUN(lvl, model)
+        else:
+            for model in args[2:]:
                 RUN(lvl, model)
-            else:
-                for model in args[2:]:
-                    RUN(lvl, model)
-        elif args[0].lower() == 'eval':
-            args = parse_args_eval()
-            data_file = args.data_file
-
-            def extract_dataset(file_name):
-                fname = osp.splitext(file_name)[0].split('/')[-1]
-                parts = fname.split('_')
-                for i in range(len(parts)):
-                    if '_'.join(parts[i:]) in SUPPORTED_DATASETS:
-                        return '_'.join(parts[i:])
-                return None
-
-            dataset = extract_dataset(data_file)
-            assert dataset is not None, f'Cannot infer dataset name from {data_file}'
-            kwargs = {'nproc': args.nproc}
-            if args.judge is not None:
-                kwargs['model'] = args.judge
-            if args.retry is not None:
-                kwargs['retry'] = args.retry
-            EVAL(dataset_name=dataset, data_file=data_file, **kwargs)
+    elif args[0].lower() == 'eval':
+        args = parse_args_eval()
+        data_file = args.data_file
+
+        def extract_dataset(file_name):
+            fname = osp.splitext(file_name)[0].split('/')[-1]
+            parts = fname.split('_')
+            for i in range(len(parts)):
+                if '_'.join(parts[i:]) in SUPPORTED_DATASETS:
+                    return '_'.join(parts[i:])
+            return None
+
+        dataset = extract_dataset(data_file)
+        assert dataset is not None, f'Cannot infer dataset name from {data_file}'
+        kwargs = {'nproc': args.nproc}
+        if args.judge is not None:
+            kwargs['model'] = args.judge
+        if args.retry is not None:
+            kwargs['retry'] = args.retry
+        EVAL(dataset_name=dataset, data_file=data_file, **kwargs)
+    elif args[0].lower() == 'merge_pkl':
+        assert len(args) == 3
+        args[2] = int(args[2])
+        assert args[2] in [1, 2, 4, 8]
+        MERGE_PKL(args[1], args[2])
     else:
        logger.error('WARNING: command error!')
        logger.info(CLI_HELP_MSG)
```
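
The new `merge_pkl` mode consolidates the per-rank prediction caches that a multi-GPU run leaves behind. `MERGE_PKL` recognizes shard files by a three-character `{rank}{world_size}_` prefix (which is why world sizes are limited to 1 through 8), merges every shard that shares the same suffix, then rewrites the merged records under the prefixes of the target world size, presumably so a follow-up run at that world size can resume from them. A small sketch of the naming convention, with made-up file names:

```python
# Hypothetical shards left by a 4-GPU run over one result file:
shards = ['04_demo.pkl', '14_demo.pkl', '24_demo.pkl', '34_demo.pkl']

# The same prefix set MERGE_PKL builds: every '{rank}{world_size}_' pair
# for world sizes 1 through 8.
prefs = {f'{i}{ws}_' for ws in range(1, 9) for i in range(ws)}

assert all(s[:3] in prefs for s in shards)      # all four shards are recognized
assert {s[3:] for s in shards} == {'demo.pkl'}  # and merged under a single key
```

From the command line, merging a shard directory down to a single-GPU layout would look like `vlmutil merge_pkl ./outputs/GPT4V 1`, where the directory path here is illustrative.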
