diff --git a/requirements.txt b/requirements.txt
index 33ac92b5..ccbf1fc2 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -23,4 +23,5 @@ mosestokenizer
 cached_property
 tqdm
 skops
-pandas
\ No newline at end of file
+pandas
+protobuf==3.20
\ No newline at end of file
diff --git a/wtpsplit/train/train_lora.py b/wtpsplit/train/train_lora.py
index b0e9154c..c70e8bed 100644
--- a/wtpsplit/train/train_lora.py
+++ b/wtpsplit/train/train_lora.py
@@ -10,16 +10,18 @@
 from glob import glob
 from typing import List, Optional, Union
 
+import adapters
 import datasets
 import numpy as np
 import torch
+from adapters import AdapterArguments
+from adapters.models import MODEL_MIXIN_MAPPING
+from adapters.models.bert.mixin_bert import BertModelAdaptersMixin
 from tokenizers import AddedToken
 from tqdm import tqdm
 from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed
 
-import adapters
 import wandb
-from adapters import AdapterArguments
 from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
 from wtpsplit.train.adaptertrainer import AdapterTrainer
 from wtpsplit.train.evaluate import evaluate_sentence
@@ -32,6 +34,8 @@
 
 os.environ["TOKENIZERS_PARALLELISM"] = "false"
 
+MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin
+
 
 @dataclass
 class Args:
@@ -387,29 +391,6 @@ def maybe_pad(text):
     for lang in tqdm(data.keys(), desc="Language"):
         if lang in args.include_languages:
             for dataset_name in data[lang]["sentence"].keys():
-                if "corrupted-asr" in dataset_name and (
-                    "lyrics" not in dataset_name
-                    and "short" not in dataset_name
-                    and "code" not in dataset_name
-                    and "ted2020" not in dataset_name
-                    and "legal" not in dataset_name
-                ):
-                    print("SKIP: ", lang, dataset_name)
-                    continue
-                if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
-                    print("SKIP: ", lang, dataset_name)
-                    continue
-                if "social-media" in dataset_name:
-                    continue
-                if "nllb" in dataset_name:
-                    continue
-                if lang == "en" and "legal-all-laws" in dataset_name:
-                    # not available.
-                    print("SKIP: ", lang, dataset_name)
-                    continue
-                if not any(x in dataset_name for x in ["ersatz", "opus", "ud"]):
-                    print("SKIP: ", lang, dataset_name)
-                    continue
                 print("RUNNING:", dataset_name, lang)
                 # do model stuff here; otherwise, head params would be overwritten every time
                 backbone = SubwordXLMForTokenClassification.from_pretrained(
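
For context on the `MODEL_MIXIN_MAPPING` change above: the `adapters` library can only wrap model classes it finds in that mapping, so the custom `SubwordXLMRobertaModel` is registered under the BERT mixin, which matches its attention layout. Below is a minimal sketch of how that registration is then consumed; the checkpoint path is a placeholder, and the real training flow in `train_lora.py` reads its paths from CLI arguments and uses `AdapterTrainer`.

```python
import adapters
from adapters import LoRAConfig
from adapters.models import MODEL_MIXIN_MAPPING
from adapters.models.bert.mixin_bert import BertModelAdaptersMixin

from wtpsplit.models import SubwordXLMForTokenClassification

# Register the custom backbone class (looked up by class name) so that
# adapters.init() knows how to inject adapter support into its modules.
MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin

# Placeholder checkpoint path for illustration only.
backbone = SubwordXLMForTokenClassification.from_pretrained("path/to/checkpoint")

adapters.init(backbone)  # wrap the model with adapter support
backbone.add_adapter("lora", config=LoRAConfig())  # attach a LoRA adapter
backbone.train_adapter("lora")  # freeze the backbone, train only the adapter
```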