Skip to content

Commit

Permalink
patch adaptation: model import
Browse files: browse the repository at this point in the history
  • Loading branch information
markus583 committed Sep 2, 2024
1 parent ba565f1 commit 1220a3b
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 26 deletions.
3 changes: 2 additions & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,5 @@ mosestokenizer
cached_property
tqdm
skops
pandas
pandas
protobuf==3.20
31 changes: 6 additions & 25 deletions wtpsplit/train/train_lora.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,18 @@
from glob import glob
from typing import List, Optional, Union

import adapters
import datasets
import numpy as np
import torch
from adapters import AdapterArguments
from adapters.models import MODEL_MIXIN_MAPPING
from adapters.models.bert.mixin_bert import BertModelAdaptersMixin
from tokenizers import AddedToken
from tqdm import tqdm
from transformers import AutoTokenizer, HfArgumentParser, TrainingArguments, set_seed

import adapters
import wandb
from adapters import AdapterArguments
from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
from wtpsplit.train.adaptertrainer import AdapterTrainer
from wtpsplit.train.evaluate import evaluate_sentence
Expand All @@ -32,6 +34,8 @@

# Set the TOKENIZERS_PARALLELISM environment variable to "false" for this process.
# This runs at import time, before any training code executes.
# NOTE(review): presumably this disables the Hugging Face `tokenizers` internal
# parallelism to avoid its fork-related warning/deadlock — confirm.
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Add an entry to the `adapters` library's MODEL_MIXIN_MAPPING so that the class name
# "SubwordXLMRobertaModel" maps to the BERT adapters mixin (module-level side effect).
# NOTE(review): presumably this lets `adapters` treat wtpsplit's custom XLM-R backbone
# like a BERT-style model when adding adapter/LoRA support — confirm against the
# `adapters` docs and the model class defined in wtpsplit.models.
MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin


@dataclass
class Args:
Expand Down Expand Up @@ -387,29 +391,6 @@ def maybe_pad(text):
for lang in tqdm(data.keys(), desc="Language"):
if lang in args.include_languages:
for dataset_name in data[lang]["sentence"].keys():
if "corrupted-asr" in dataset_name and (
"lyrics" not in dataset_name
and "short" not in dataset_name
and "code" not in dataset_name
and "ted2020" not in dataset_name
and "legal" not in dataset_name
):
print("SKIP: ", lang, dataset_name)
continue
if "legal" in dataset_name and not ("laws" in dataset_name or "judgements" in dataset_name):
print("SKIP: ", lang, dataset_name)
continue
if "social-media" in dataset_name:
continue
if "nllb" in dataset_name:
continue
if lang == "en" and "legal-all-laws" in dataset_name:
# not available.
print("SKIP: ", lang, dataset_name)
continue
if not any(x in dataset_name for x in ["ersatz", "opus", "ud"]):
print("SKIP: ", lang, dataset_name)
continue
print("RUNNING:", dataset_name, lang)
# do model stuff here; otherwise, head params would be overwritten every time
backbone = SubwordXLMForTokenClassification.from_pretrained(
Expand Down

0 comments on commit 1220a3b

Please sign in to comment.