Add WMT19 dataset configuration and inference code #1503

Open · wants to merge 1 commit into base: main
122 changes: 122 additions & 0 deletions configs/datasets/wmt19/wmt19_gen.py
@@ -0,0 +1,122 @@
from opencompass.openicl.icl_prompt_template import PromptTemplate
from opencompass.openicl.icl_retriever import ZeroRetriever, BM25Retriever
from opencompass.openicl.icl_inferencer import GenInferencer
from opencompass.openicl.icl_evaluator import BleuEvaluator
from opencompass.datasets.wmt19 import WMT19TranslationDataset

LANG_CODE_TO_NAME = {
'cs': 'Czech',
'de': 'German',
'en': 'English',
'fi': 'Finnish',
'fr': 'French',
'gu': 'Gujarati',
'kk': 'Kazakh',
'lt': 'Lithuanian',
'ru': 'Russian',
'zh': 'Chinese'
}

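# The loader below exposes only the validation split, so it doubles as both
# the in-context example pool (train_split) and the test set (test_split).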
wmt19_reader_cfg = dict(
input_columns=['input'],
output_column='target',
train_split='validation',
test_split='validation')

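# 0-shot prompt. {src_lang_name}/{tgt_lang_name} are filled in via str.format()
# in the loop at the bottom of this file; the doubled braces keep {input} as a
# literal placeholder for the prompt template.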
wmt19_infer_cfg_0shot = dict(
prompt_template=dict(
type=PromptTemplate,
template=dict(
round=[
dict(role='HUMAN', prompt='Translate the following {src_lang_name} text to {tgt_lang_name}:\n{{input}}\n'),
dict(role='BOT', prompt='Translation:\n')
]
)
),
retriever=dict(type=ZeroRetriever),
inferencer=dict(type=GenInferencer)
)

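# 5-shot prompt. BM25Retriever selects 5 in-context examples, which are
# substituted at the </E> token of the prompt template.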
wmt19_infer_cfg_5shot = dict(
ice_template=dict(
type=PromptTemplate,
template='Example:\n{src_lang_name}: {{input}}\n{tgt_lang_name}: {{target}}'
),
prompt_template=dict(
type=PromptTemplate,
template='</E>\nTranslate the following {src_lang_name} text to {tgt_lang_name}:\n{{input}}\nTranslation:\n',
ice_token='</E>',
),
retriever=dict(type=BM25Retriever, ice_num=5),
inferencer=dict(type=GenInferencer),
)

wmt19_eval_cfg = dict(
evaluator=dict(
type=BleuEvaluator
),
pred_role='BOT',
)

language_pairs = [
('cs', 'en'), ('de', 'en'), ('fi', 'en'), ('fr', 'de'),
('gu', 'en'), ('kk', 'en'), ('lt', 'en'), ('ru', 'en'), ('zh', 'en')
]

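# Build one 0-shot and one 5-shot dataset config per language pair, with the
# language names baked into the prompt strings. The path below is a placeholder
# and should point at a local copy of the WMT19 parquet files.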
wmt19_datasets = []

for src_lang, tgt_lang in language_pairs:
src_lang_name = LANG_CODE_TO_NAME[src_lang]
tgt_lang_name = LANG_CODE_TO_NAME[tgt_lang]

wmt19_datasets.extend([
dict(
abbr=f'wmt19_{src_lang}-{tgt_lang}_0shot',
type=WMT19TranslationDataset,
path='/path/to/wmt19',
src_lang=src_lang,
tgt_lang=tgt_lang,
reader_cfg=wmt19_reader_cfg,
infer_cfg={
**wmt19_infer_cfg_0shot,
'prompt_template': {
**wmt19_infer_cfg_0shot['prompt_template'],
'template': {
**wmt19_infer_cfg_0shot['prompt_template']['template'],
'round': [
{
**wmt19_infer_cfg_0shot['prompt_template']['template']['round'][0],
'prompt': wmt19_infer_cfg_0shot['prompt_template']['template']['round'][0]['prompt'].format(
src_lang_name=src_lang_name, tgt_lang_name=tgt_lang_name
)
},
wmt19_infer_cfg_0shot['prompt_template']['template']['round'][1]
]
}
}
},
eval_cfg=wmt19_eval_cfg),
dict(
abbr=f'wmt19_{src_lang}-{tgt_lang}_5shot',
type=WMT19TranslationDataset,
path='/path/to/wmt19',
src_lang=src_lang,
tgt_lang=tgt_lang,
reader_cfg=wmt19_reader_cfg,
infer_cfg={
**wmt19_infer_cfg_5shot,
'ice_template': {
**wmt19_infer_cfg_5shot['ice_template'],
'template': wmt19_infer_cfg_5shot['ice_template']['template'].format(
src_lang_name=src_lang_name, tgt_lang_name=tgt_lang_name
)
},
'prompt_template': {
**wmt19_infer_cfg_5shot['prompt_template'],
'template': wmt19_infer_cfg_5shot['prompt_template']['template'].format(
src_lang_name=src_lang_name, tgt_lang_name=tgt_lang_name
)
}
},
eval_cfg=wmt19_eval_cfg),
])
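
For reference, a minimal sketch of how these configs might be pulled into an evaluation run, assuming the usual OpenCompass read_base pattern (the eval-config filename and model list are illustrative, not part of this PR):

# configs/eval_wmt19.py (hypothetical)
from mmengine.config import read_base

with read_base():
    from .datasets.wmt19.wmt19_gen import wmt19_datasets

datasets = [*wmt19_datasets]
# models = [...]  # add model configs here, then launch with the standard OpenCompass entry point
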
1 change: 1 addition & 0 deletions opencompass/datasets/__init__.py
@@ -121,6 +121,7 @@
from .wikibench import * # noqa: F401, F403
from .winograd import * # noqa: F401, F403
from .winogrande import * # noqa: F401, F403
from .wmt19 import * # noqa: F401, F403
from .wnli import wnliDataset # noqa: F401, F403
from .wsc import * # noqa: F401, F403
from .xcopa import * # noqa: F401, F403
38 changes: 38 additions & 0 deletions opencompass/datasets/wmt19.py
@@ -0,0 +1,38 @@
import os
import pandas as pd
from datasets import Dataset, DatasetDict
from opencompass.registry import LOAD_DATASET
from opencompass.datasets.base import BaseDataset

@LOAD_DATASET.register_module()
class WMT19TranslationDataset(BaseDataset):
@staticmethod
def load(path: str, src_lang: str, tgt_lang: str):
print(f"Attempting to load data from path: {path}")
print(f"Source language: {src_lang}, Target language: {tgt_lang}")

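        # WMT19 stores each pair under one canonical directory name, so try both orderings.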
lang_pair_dir = os.path.join(path, f"{src_lang}-{tgt_lang}")
if not os.path.exists(lang_pair_dir):
lang_pair_dir = os.path.join(path, f"{tgt_lang}-{src_lang}")
if not os.path.exists(lang_pair_dir):
raise ValueError(f"Cannot find directory for language pair {src_lang}-{tgt_lang} or {tgt_lang}-{src_lang}")

print(f"Loading data from directory: {lang_pair_dir}")

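        # Assumes a single validation parquet shard per pair (e.g. the layout produced by a Hugging Face datasets export).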
val_file = os.path.join(lang_pair_dir, "validation-00000-of-00001.parquet")
val_df = pd.read_parquet(val_file)

def process_split(df):
return Dataset.from_dict({
'input': df['translation'].apply(lambda x: x[src_lang]).tolist(),
'target': df['translation'].apply(lambda x: x[tgt_lang]).tolist()
})

return DatasetDict({
'validation': process_split(val_df)
})

@classmethod
def get_dataset(cls, path, src_lang, tgt_lang):
return cls.load(path, src_lang, tgt_lang)

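A quick way to sanity-check the loader against a local copy of the data (the path is a placeholder, as in the config above):

ds = WMT19TranslationDataset.load('/path/to/wmt19', src_lang='de', tgt_lang='en')
print(len(ds['validation']))   # number of validation sentence pairs
print(ds['validation'][0])     # {'input': <German sentence>, 'target': <English sentence>}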