Skip to content

Commit 8b70019

Browse files
add supported_video_datasets function for quick start
1 parent 7a177d2 commit 8b70019

File tree

5 files changed

+78
-37
lines changed

5 files changed

+78
-37
lines changed

run.py

+11-31
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import torch.distributed as dist
33

44
from vlmeval.config import supported_VLM
5+
from vlmeval.dataset.video_dataset_config import supported_video_datasets
56
from vlmeval.dataset import build_dataset
67
from vlmeval.inference import infer_data_job
78
from vlmeval.inference_video import infer_data_job_video
@@ -26,16 +27,22 @@ def build_model_from_config(cfg, model_name):
2627
raise ValueError(f'Class {cls_name} is not supported in `vlmeval.api` or `vlmeval.vlm`')
2728

2829

29-
def build_dataset_from_config(cfg):
30+
def build_dataset_from_config(cfg, dataset_name):
3031
import vlmeval.dataset
3132
import inspect
32-
config = cp.deepcopy(cfg)
33+
config = cp.deepcopy(cfg[dataset_name])
34+
if config == {}:
35+
return supported_video_datasets[dataset_name]()
3336
assert 'class' in config
3437
cls_name = config.pop('class')
3538
if hasattr(vlmeval.dataset, cls_name):
3639
cls = getattr(vlmeval.dataset, cls_name)
3740
sig = inspect.signature(cls.__init__)
3841
valid_params = {k: v for k, v in config.items() if k in sig.parameters}
42+
if valid_params.get('fps', 0) > 0 and valid_params.get('nframe', 0) > 0:
43+
raise ValueError('fps and nframe should not be set at the same time')
44+
if valid_params.get('fps', 0) <= 0 and valid_params.get('nframe', 0) <= 0:
45+
raise ValueError('fps and nframe should be set at least one valid value')
3946
return cls(**valid_params)
4047
else:
4148
raise ValueError(f'Class {cls_name} is not supported in `vlmeval.dataset`')
@@ -190,20 +197,16 @@ def main():
190197
if use_config:
191198
if world_size > 1:
192199
if rank == 0:
193-
dataset = build_dataset_from_config(cfg['data'][dataset_name])
200+
dataset = build_dataset_from_config(cfg['data'], dataset_name)
194201
dist.barrier()
195-
dataset = build_dataset_from_config(cfg['data'][dataset_name])
202+
dataset = build_dataset_from_config(cfg['data'], dataset_name)
196203
if dataset is None:
197204
logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
198205
continue
199206
else:
200207
dataset_kwargs = {}
201208
if dataset_name in ['MMLongBench_DOC', 'DUDE', 'DUDE_MINI', 'SLIDEVQA', 'SLIDEVQA_MINI']:
202209
dataset_kwargs['model'] = model_name
203-
if dataset_name == 'MMBench-Video':
204-
dataset_kwargs['pack'] = args.pack
205-
if dataset_name == 'Video-MME':
206-
dataset_kwargs['use_subtitle'] = args.use_subtitle
207210

208211
# If distributed, first build the dataset on the main process for doing preparation works
209212
if world_size > 1:
@@ -215,29 +218,6 @@ def main():
215218
if dataset is None:
216219
logger.error(f'Dataset {dataset_name} is not valid, will be skipped. ')
217220
continue
218-
# Handling Video Datasets. For Video Dataset, set the fps for priority
219-
if args.fps > 0:
220-
if dataset_name == 'MVBench':
221-
raise ValueError('MVBench does not support fps setting, please transfer to MVBench_MP4!')
222-
args.nframe = 0
223-
if dataset_name in ['MMBench-Video']:
224-
packstr = 'pack' if args.pack else 'nopack'
225-
if args.nframe > 0:
226-
result_file_base = f'{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
227-
else:
228-
result_file_base = f'{model_name}_{dataset_name}_{args.fps}fps_{packstr}.xlsx'
229-
elif dataset.MODALITY == 'VIDEO':
230-
if args.pack:
231-
logger.info(f'{dataset_name} not support Pack Mode, directly change to unpack')
232-
args.pack = False
233-
packstr = 'pack' if args.pack else 'nopack'
234-
if args.nframe > 0:
235-
result_file_base = f'{model_name}_{dataset_name}_{args.nframe}frame_{packstr}.xlsx'
236-
else:
237-
result_file_base = f'{model_name}_{dataset_name}_{args.fps}fps_{packstr}.xlsx'
238-
if dataset_name in ['Video-MME', 'LongVideoBench']:
239-
subtitlestr = 'subs' if args.use_subtitle else 'nosubs'
240-
result_file_base = result_file_base.replace('.xlsx', f'_{subtitlestr}.xlsx')
241221

242222
# Handling Multi-Turn Dataset
243223
if dataset.TYPE == 'MT':

vlmeval/dataset/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@
3434
from .mmmath import MMMath
3535
from .dynamath import Dynamath
3636
from .utils import *
37+
from .video_dataset_config import *
3738
from ..smp import *
3839

3940

@@ -196,7 +197,9 @@ def DATASET_MODALITY(dataset, *, default: str = 'IMAGE') -> str:
196197

197198
def build_dataset(dataset_name, **kwargs):
198199
for cls in DATASET_CLASSES:
199-
if dataset_name in cls.supported_datasets():
200+
if dataset_name in supported_video_datasets:
201+
return supported_video_datasets[dataset_name](**kwargs)
202+
elif dataset_name in cls.supported_datasets():
200203
return cls(dataset=dataset_name, **kwargs)
201204

202205
warnings.warn(f'Dataset {dataset_name} is not officially supported. ')

vlmeval/dataset/mvbench.py

+4-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ class MVBench(VideoBaseDataset):
2828

2929
TYPE = 'Video-MCQ'
3030

31-
def __init__(self, dataset='MVBench', pack=False, nframe=0, fps=-1):
31+
def __init__(self, dataset='MVBench', nframe=0, fps=-1):
3232
self.type_data_list = {
3333
'Action Sequence': ('action_sequence.json',
3434
'your_data_path/star/Charades_v1_480/', 'video', True), # has start & end
@@ -71,7 +71,7 @@ def __init__(self, dataset='MVBench', pack=False, nframe=0, fps=-1):
7171
'Counterfactual Inference': ('counterfactual_inference.json',
7272
'your_data_path/clevrer/video_validation/', 'video', False),
7373
}
74-
super().__init__(dataset=dataset, pack=pack, nframe=nframe, fps=fps)
74+
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
7575

7676
@classmethod
7777
def supported_datasets(cls):
@@ -432,8 +432,8 @@ class MVBench_MP4(VideoBaseDataset):
432432
"""
433433
TYPE = 'Video-MCQ'
434434

435-
def __init__(self, dataset='MVBench_MP4', pack=False, nframe=0, fps=-1):
436-
super().__init__(dataset=dataset, pack=pack, nframe=nframe, fps=fps)
435+
def __init__(self, dataset='MVBench_MP4', nframe=0, fps=-1):
436+
super().__init__(dataset=dataset, nframe=nframe, fps=fps)
437437

438438
@classmethod
439439
def supported_datasets(cls):

vlmeval/dataset/video_base.py

+5-1
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,10 @@ def __init__(self,
3737
self.pack = pack
3838
self.nframe = nframe
3939
self.fps = fps
40+
if self.fps > 0 and self.nframe > 0:
41+
raise ValueError('fps and nframe should not be set at the same time')
42+
if self.fps <= 0 and self.nframe <= 0:
43+
raise ValueError('fps and nframe should be set at least one valid value')
4044

4145
def __len__(self):
4246
return len(self.videos) if self.pack else len(self.data)
@@ -81,7 +85,7 @@ def save_video_frames(self, video):
8185
indices = [int(i * step_size) for i in range(required_frames)]
8286

8387
# 提取帧并保存
84-
frame_paths = self.frame_paths_fps(video, len(indices), self.fps)
88+
frame_paths = self.frame_paths_fps(video, len(indices))
8589
flag = np.all([osp.exists(p) for p in frame_paths])
8690
if flag:
8791
return frame_paths
+54
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
from vlmeval.dataset import *
2+
from functools import partial
3+
4+
mmbench_video_dataset = {
5+
'MMBench_Video_8frame_nopack': partial(MMBenchVideo, dataset='MMBench-Video', nframe=8, pack=False),
6+
'MMBench_Video_8frame_pack': partial(MMBenchVideo, dataset='MMBench-Video', nframe=8, pack=True),
7+
'MMBench_Video_16frame_nopack': partial(MMBenchVideo, dataset='MMBench-Video', nframe=16, pack=False),
8+
'MMBench_Video_1fps_nopack': partial(MMBenchVideo, dataset='MMBench-Video', fps=1.0, pack=False),
9+
'MMBench_Video_1fps_pack': partial(MMBenchVideo, dataset='MMBench-Video', fps=1.0, pack=True)
10+
}
11+
12+
mvbench_dataset = {
13+
'MVBench_8frame': partial(MVBench, dataset='MVBench', nframe=8),
14+
# MVBench not support fps, but MVBench_MP4 does
15+
'MVBench_MP4_8frame': partial(MVBench_MP4, dataset='MVBench_MP4', nframe=8),
16+
'MVBench_MP4_1fps': partial(MVBench_MP4, dataset='MVBench_MP4', fps=1.0),
17+
}
18+
19+
videomme_dataset = {
20+
'Video-MME_8frame': partial(VideoMME, dataset='Video-MME', nframe=8),
21+
'Video-MME_8frame_subs': partial(VideoMME, dataset='Video-MME', nframe=8, use_subtitle=True),
22+
'Video-MME_1fps': partial(VideoMME, dataset='Video-MME', fps=1.0),
23+
'Video-MME_0.5fps': partial(VideoMME, dataset='Video-MME', fps=0.5),
24+
'Video-MME_0.5fps_subs': partial(VideoMME, dataset='Video-MME', fps=0.5, use_subtitle=True),
25+
}
26+
27+
longvideobench_dataset = {
28+
'LongVideoBench_8frame': partial(LongVideoBench, dataset='LongVideoBench', nframe=8),
29+
'LongVideoBench_8frame_subs': partial(LongVideoBench, dataset='LongVideoBench', nframe=8, use_subtitle=True),
30+
'LongVideoBench_1fps': partial(LongVideoBench, dataset='LongVideoBench', fps=1.0),
31+
'LongVideoBench_0.5fps': partial(LongVideoBench, dataset='LongVideoBench', fps=0.5),
32+
'LongVideoBench_0.5fps_subs': partial(LongVideoBench, dataset='LongVideoBench', fps=0.5, use_subtitle=True)
33+
}
34+
35+
mlvu_dataset = {
36+
'MLVU_8frame': partial(MLVU, dataset='MLVU', nframe=8),
37+
'MLVU_1fps': partial(MLVU, dataset='MLVU', fps=1.0)
38+
}
39+
40+
tempcompass_dataset = {
41+
'TempCompass_8frame': partial(TempCompass, dataset='TempCompass', nframe=8),
42+
'TempCompass_1fps': partial(TempCompass, dataset='TempCompass', fps=1.0),
43+
'TempCompass_0.5fps': partial(TempCompass, dataset='TempCompass', fps=0.5)
44+
}
45+
46+
supported_video_datasets = {}
47+
48+
dataset_groups = [
49+
mmbench_video_dataset, mvbench_dataset, videomme_dataset, longvideobench_dataset,
50+
mlvu_dataset, tempcompass_dataset
51+
]
52+
53+
for grp in dataset_groups:
54+
supported_video_datasets.update(grp)

0 commit comments

Comments
 (0)