Skip to content

Commit 7664447

Browse files
[Refactor] Merge the video-dataset-related args into the config JSON and into each dataset
1 parent 547b36f commit 7664447

10 files changed

+199
-217
lines changed

run.py

+6-11
Original file line number | Diff line number | Diff line change
@@ -26,11 +26,15 @@ def build_model_from_config(cfg):
2626

2727
def build_dataset_from_config(cfg):
2828
import vlmeval.dataset
29+
import inspect
2930
config = cp.deepcopy(cfg)
3031
assert 'class' in config
3132
cls_name = config.pop('class')
3233
if hasattr(vlmeval.dataset, cls_name):
33-
return getattr(vlmeval.dataset, cls_name)(**config)
34+
cls = getattr(vlmeval.dataset, cls_name)
35+
sig = inspect.signature(cls.__init__)
36+
valid_params = {k: v for k, v in config.items() if k in sig.parameters}
37+
return cls(**valid_params)
3438
else:
3539
raise ValueError(f'Class {cls_name} is not supported in `vlmeval.dataset`')
3640

@@ -101,11 +105,6 @@ def parse_args():
101105
parser.add_argument('--data', type=str, nargs='+', help='Names of Datasets')
102106
parser.add_argument('--model', type=str, nargs='+', help='Names of Models')
103107
parser.add_argument('--config', type=str, help='Path to the Config Json File')
104-
# Args that only apply to Video Dataset
105-
parser.add_argument('--nframe', type=int, default=8)
106-
parser.add_argument('--pack', action='store_true')
107-
parser.add_argument('--use-subtitle', action='store_true')
108-
parser.add_argument('--fps', type=float, default=-1)
109108
# Work Dir
110109
parser.add_argument('--work-dir', type=str, default='./outputs', help='select the output directory')
111110
# Infer + Eval or Infer Only
@@ -287,12 +286,8 @@ def main():
287286
work_dir=pred_root,
288287
model_name=model_name,
289288
dataset=dataset,
290-
nframe=args.nframe,
291-
pack=args.pack,
292289
verbose=args.verbose,
293-
subtitle=args.use_subtitle,
294-
api_nproc=args.nproc,
295-
fps=args.fps)
290+
api_nproc=args.nproc)
296291
elif dataset.TYPE == 'MT':
297292
model = infer_data_job_mt(
298293
model,

vlmeval/dataset/longvideobench.py

+16-16
Original file line number | Diff line number | Diff line change
@@ -94,8 +94,8 @@ class LongVideoBench(VideoBaseDataset):
9494

9595
TYPE = 'Video-MCQ'
9696

97-
def __init__(self, dataset='LongVideoBench', use_subtitle=False):
98-
super().__init__(dataset=dataset)
97+
def __init__(self, dataset='LongVideoBench', use_subtitle=False, nframe=8, fps=-1):
98+
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
9999
self.use_subtitle = use_subtitle
100100
self.dataset_name = dataset
101101

@@ -195,25 +195,25 @@ def concat_tar_parts(tar_parts, output_tar):
195195

196196
return dict(data_file=data_file, root=dataset_path)
197197

198-
def save_video_frames(self, video_path, num_frames=8, fps=-1, video_llm=False):
198+
def save_video_frames(self, video_path, video_llm=False):
199199

200200
vid_path = osp.join(self.data_root, video_path)
201201
vid = decord.VideoReader(vid_path)
202202
video_info = {
203203
'fps': vid.get_avg_fps(),
204204
'n_frames': len(vid),
205205
}
206-
if num_frames > 0 and fps < 0:
207-
step_size = len(vid) / (num_frames + 1)
208-
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
209-
frame_paths = self.frame_paths(video_path[:-4], num_frames)
210-
elif fps > 0:
206+
if self.nframe > 0 and self.fps < 0:
207+
step_size = len(vid) / (self.nframe + 1)
208+
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
209+
frame_paths = self.frame_paths(video_path[:-4])
210+
elif self.fps > 0:
211211
# not constrained by num_frames, get frames by fps
212212
total_duration = video_info['n_frames'] / video_info['fps']
213-
required_frames = int(total_duration * fps)
214-
step_size = video_info['fps'] / fps
213+
required_frames = int(total_duration * self.fps)
214+
step_size = video_info['fps'] / self.fps
215215
indices = [int(i * step_size) for i in range(required_frames)]
216-
frame_paths = self.frame_paths_fps(video_path[:-4], len(indices), fps)
216+
frame_paths = self.frame_paths_fps(video_path[:-4], len(indices))
217217

218218
flag = np.all([osp.exists(p) for p in frame_paths])
219219

@@ -226,16 +226,16 @@ def save_video_frames(self, video_path, num_frames=8, fps=-1, video_llm=False):
226226

227227
return frame_paths, indices, video_info
228228

229-
def save_video_into_images(self, line, num_frames=8):
230-
frame_paths, indices, video_info = self.save_video_frames(line['video_path'], num_frames)
231-
return frame_paths
229+
# def save_video_into_images(self, line, num_frames=8):
230+
# frame_paths, indices, video_info = self.save_video_frames(line['video_path'], num_frames)
231+
# return frame_paths
232232

233-
def build_prompt(self, line, num_frames, video_llm, fps):
233+
def build_prompt(self, line, video_llm):
234234
if isinstance(line, int):
235235
assert line < len(self)
236236
line = self.data.iloc[line]
237237

238-
frames, indices, video_info = self.save_video_frames(line['video_path'], num_frames, fps, video_llm)
238+
frames, indices, video_info = self.save_video_frames(line['video_path'], video_llm)
239239
fps = video_info["fps"]
240240

241241
message = [dict(type='text', value=self.SYS)]

vlmeval/dataset/mlvu.py

+26-26
Original file line number | Diff line number | Diff line change
@@ -152,7 +152,7 @@ def qa_template(self, data):
152152
answer = f"({chr(ord('A') + answer_idx)}) {answer}"
153153
return question, answer
154154

155-
def save_video_frames(self, line, num_frames=8, fps=-1):
155+
def save_video_frames(self, line):
156156
suffix = line['video'].split('.')[-1]
157157
video = line['video'].replace(f'.{suffix}','')
158158
vid_path = osp.join(self.data_root, line['prefix'], line['video'])
@@ -161,17 +161,17 @@ def save_video_frames(self, line, num_frames=8, fps=-1):
161161
'fps': vid.get_avg_fps(),
162162
'n_frames': len(vid),
163163
}
164-
if num_frames > 0 and fps < 0:
165-
step_size = len(vid) / (num_frames + 1)
166-
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
167-
frame_paths = self.frame_paths(video, num_frames)
168-
elif fps > 0:
164+
if self.nframe > 0 and self.fps < 0:
165+
step_size = len(vid) / (self.nframe + 1)
166+
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
167+
frame_paths = self.frame_paths(video)
168+
elif self.fps > 0:
169169
# not constrained by num_frames, get frames by fps
170170
total_duration = video_info['n_frames'] / video_info['fps']
171-
required_frames = int(total_duration * fps)
172-
step_size = video_info['fps'] / fps
171+
required_frames = int(total_duration * self.fps)
172+
step_size = video_info['fps'] / self.fps
173173
indices = [int(i * step_size) for i in range(required_frames)]
174-
frame_paths = self.frame_paths_fps(video, len(indices), fps)
174+
frame_paths = self.frame_paths_fps(video, len(indices))
175175

176176
flag = np.all([osp.exists(p) for p in frame_paths])
177177

@@ -184,11 +184,11 @@ def save_video_frames(self, line, num_frames=8, fps=-1):
184184

185185
return frame_paths
186186

187-
def save_video_into_images(self, line, num_frames, fps):
188-
frame_paths = self.save_video_frames(line, num_frames, fps)
187+
def save_video_into_images(self, line):
188+
frame_paths = self.save_video_frames(line)
189189
return frame_paths
190190

191-
def build_prompt(self, line, num_frames, video_llm, fps=-1):
191+
def build_prompt(self, line, video_llm):
192192
if isinstance(line, int):
193193
assert line < len(self)
194194
line = self.data.iloc[line]
@@ -200,7 +200,7 @@ def build_prompt(self, line, num_frames, video_llm, fps=-1):
200200
if video_llm:
201201
message.append(dict(type='video', value=video_path))
202202
else:
203-
img_frame_paths = self.save_video_into_images(line, num_frames, fps)
203+
img_frame_paths = self.save_video_into_images(line)
204204
for im in img_frame_paths:
205205
message.append(dict(type='image', value=im))
206206
message.append(dict(type='text', value='\nOnly give the best option.'))
@@ -355,7 +355,7 @@ def qa_template(self, data):
355355
answer = data['answer']
356356
return question, answer
357357

358-
def save_video_frames(self, line, num_frames=8, fps=-1):
358+
def save_video_frames(self, line):
359359
suffix = line['video'].split('.')[-1]
360360
video = line['video'].replace(f'.{suffix}','')
361361
vid_path = osp.join(self.data_root, line['prefix'], line['video'])
@@ -364,17 +364,17 @@ def save_video_frames(self, line, num_frames=8, fps=-1):
364364
'fps': vid.get_avg_fps(),
365365
'n_frames': len(vid),
366366
}
367-
if num_frames > 0 and fps < 0:
368-
step_size = len(vid) / (num_frames + 1)
369-
indices = [int(i * step_size) for i in range(1, num_frames + 1)]
370-
frame_paths = self.frame_paths(video, num_frames)
371-
elif fps > 0:
367+
if self.nframe > 0 and self.fps < 0:
368+
step_size = len(vid) / (self.nframe + 1)
369+
indices = [int(i * step_size) for i in range(1, self.nframe + 1)]
370+
frame_paths = self.frame_paths(video)
371+
elif self.fps > 0:
372372
# not constrained by num_frames, get frames by fps
373373
total_duration = video_info['n_frames'] / video_info['fps']
374-
required_frames = int(total_duration * fps)
375-
step_size = video_info['fps'] / fps
374+
required_frames = int(total_duration * self.fps)
375+
step_size = video_info['fps'] / self.fps
376376
indices = [int(i * step_size) for i in range(required_frames)]
377-
frame_paths = self.frame_paths_fps(video, len(indices), fps)
377+
frame_paths = self.frame_paths_fps(video, len(indices))
378378

379379
flag = np.all([osp.exists(p) for p in frame_paths])
380380

@@ -387,11 +387,11 @@ def save_video_frames(self, line, num_frames=8, fps=-1):
387387

388388
return frame_paths
389389

390-
def save_video_into_images(self, line, num_frames, fps):
391-
frame_paths = self.save_video_frames(line, num_frames, fps)
390+
def save_video_into_images(self, line):
391+
frame_paths = self.save_video_frames(line)
392392
return frame_paths
393393

394-
def build_prompt(self, line, num_frames, video_llm, fps=-1):
394+
def build_prompt(self, line, video_llm):
395395
if isinstance(line, int):
396396
assert line < len(self)
397397
line = self.data.iloc[line]
@@ -403,7 +403,7 @@ def build_prompt(self, line, num_frames, video_llm, fps=-1):
403403
if video_llm:
404404
message.append(dict(type='video', value=video_path))
405405
else:
406-
img_frame_paths = self.save_video_into_images(line, num_frames, fps)
406+
img_frame_paths = self.save_video_into_images(line)
407407
for im in img_frame_paths:
408408
message.append(dict(type='image', value=im))
409409
return message

vlmeval/dataset/mmbench_video.py

+9-9
Original file line number | Diff line number | Diff line change
@@ -59,8 +59,8 @@ class MMBenchVideo(VideoBaseDataset):
5959

6060
TYPE = 'Video-VQA'
6161

62-
def __init__(self, dataset='MMBench-Video', pack=False):
63-
super().__init__(dataset=dataset, pack=pack)
62+
def __init__(self, dataset='MMBench-Video', pack=False, nframe=8, fps=-1):
63+
super().__init__(dataset=dataset, pack=pack, nframe=nframe, fps=fps)
6464

6565
@classmethod
6666
def supported_datasets(cls):
@@ -92,7 +92,7 @@ def check_integrity(pth):
9292

9393
return dict(data_file=data_file, root=osp.join(dataset_path, 'video'))
9494

95-
def build_prompt_pack(self, line, num_frames, fps=-1):
95+
def build_prompt_pack(self, line):
9696
if isinstance(line, int):
9797
assert line < len(self)
9898
video = self.videos[line]
@@ -101,7 +101,7 @@ def build_prompt_pack(self, line, num_frames, fps=-1):
101101
elif isinstance(line, str):
102102
video = line
103103

104-
frames = self.save_video_frames(video, num_frames, fps)
104+
frames = self.save_video_frames(video)
105105
sub = self.data[self.data['video'] == video]
106106
sys_prompt = self.SYS + self.FRAMES_TMPL_PACK.format(len(frames))
107107
message = [dict(type='text', value=sys_prompt)]
@@ -114,7 +114,7 @@ def build_prompt_pack(self, line, num_frames, fps=-1):
114114
message.append(dict(type='text', value=prompt))
115115
return message
116116

117-
def build_prompt_nopack(self, line, num_frames, video_llm, fps):
117+
def build_prompt_nopack(self, line, video_llm):
118118
if isinstance(line, int):
119119
assert line < len(self)
120120
line = self.data.iloc[line]
@@ -125,7 +125,7 @@ def build_prompt_nopack(self, line, num_frames, video_llm, fps):
125125
message.append(dict(type='video', value=os.path.join(self.video_path, video_idx_path)))
126126
return message
127127
else:
128-
frames = self.save_video_frames(line['video'], num_frames, fps)
128+
frames = self.save_video_frames(line['video'])
129129
sys_prompt = self.FRAMES_TMPL_NOPACK.format(len(frames))
130130
message = [dict(type='text', value=sys_prompt)]
131131
for im in frames:
@@ -134,11 +134,11 @@ def build_prompt_nopack(self, line, num_frames, video_llm, fps):
134134
message.append(dict(type='text', value=prompt))
135135
return message
136136

137-
def build_prompt(self, line, num_frames, video_llm, fps):
137+
def build_prompt(self, line, video_llm):
138138
if self.pack and not video_llm:
139-
return self.build_prompt_pack(line, num_frames, fps)
139+
return self.build_prompt_pack(line)
140140
else:
141-
return self.build_prompt_nopack(line, num_frames, video_llm, fps)
141+
return self.build_prompt_nopack(line, video_llm)
142142

143143
@staticmethod
144144
def remove_side_quote(s, syms=[',', '"', "'"]):

0 commit comments

Comments
 (0)