diff --git a/setup.py b/setup.py
index 0e47c112..5f74338c 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 
 setup(
     name="wtpsplit",
-    version="2.0.3",
+    version="2.0.4",
     packages=find_packages(),
     description="Universal Robust, Efficient and Adaptable Sentence Segmentation",
     author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer",
diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py
index 8cbcd34a..09f9465d 100644
--- a/wtpsplit/__init__.py
+++ b/wtpsplit/__init__.py
@@ -8,23 +8,20 @@
 with contextlib.redirect_stderr(open(os.devnull, "w")):
     import transformers  # noqa
 
-import adapters  # noqa
 import numpy as np
 import skops.io as sio
-from adapters.models import MODEL_MIXIN_MAPPING
-from adapters.models.bert.mixin_bert import BertModelAdaptersMixin
+
 from huggingface_hub import hf_hub_download
-from transformers import AutoConfig, AutoModelForTokenClassification
+from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
 from transformers.utils.hub import cached_file
 
-from wtpsplit.evaluation import token_to_char_probs
 from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, extract
-from wtpsplit.utils import Constants, indices_to_sentences, sigmoid
+from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs
 
-__version__ = "2.0.3"
+__version__ = "2.0.4"
 
 warnings.simplefilter("default", DeprecationWarning)  # show by default
-
+warnings.simplefilter("ignore", category=FutureWarning)  # for transformers
 
 class WtP:
     def __init__(
@@ -438,6 +435,10 @@ def __init__(
 
         self.use_lora = False
 
+        self.tokenizer = AutoTokenizer.from_pretrained(
+            "facebookAI/xlm-roberta-base"
+        )
+
         if isinstance(model_name_or_model, (str, Path)):
             model_name = str(model_name_or_model)
             is_local = os.path.isdir(model_name)
@@ -500,6 +501,9 @@ def __init__(
         if (style_or_domain and not language) or (language and not style_or_domain):
            raise ValueError("Please specify both language and style_or_domain!")
         if style_or_domain and language:
+            import adapters  # noqa
+            from adapters.models import MODEL_MIXIN_MAPPING  # noqa
+            from adapters.models.bert.mixin_bert import BertModelAdaptersMixin  # noqa
             # monkey patch mixin to avoid forking whole adapters library
             MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin
             model_type = self.model.model.config.model_type
@@ -638,6 +642,7 @@ def newline_probability_fn(logits):
                 batch_size=batch_size,
                 pad_last_batch=pad_last_batch,
                 verbose=verbose,
+                tokenizer=self.tokenizer,
             )
 
             # convert token probabilities to character probabilities for the entire array
diff --git a/wtpsplit/evaluation/__init__.py b/wtpsplit/evaluation/__init__.py
index 8d3ed68c..0147ca90 100644
--- a/wtpsplit/evaluation/__init__.py
+++ b/wtpsplit/evaluation/__init__.py
@@ -427,31 +427,3 @@ def punkt_sentencize(lang_code, text):
         return reconstruct_sentences(text, sent_tokenize(text, language=lang_code_to_lang(lang_code).lower()))
     except LookupError:
         raise LanguageError(f"punkt does not support {lang_code}")
-
-
-def get_token_spans(tokenizer, offsets_mapping, tokens):
-    # Filter out special tokens and get their character start and end positions
-    valid_indices = np.array(
-        [
-            idx
-            for idx, token in enumerate(tokens)
-            if token not in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]
-            and idx < len(offsets_mapping)
-        ]
-    )
-    valid_offsets = np.array(offsets_mapping)[valid_indices]
-    return valid_indices, valid_offsets
-
-
-def token_to_char_probs(text, tokens, token_logits, tokenizer, offsets_mapping):
-    """Map from token probabalities to character probabilities"""
-    char_probs = np.full((len(text), token_logits.shape[1]), np.min(token_logits))  # Initialize with very low numbers
-
-    valid_indices, valid_offsets = get_token_spans(tokenizer, offsets_mapping, tokens)
-
-    # Assign the token's probability to the last character of the token
-    for i in range(valid_offsets.shape[0]):
-        start, end = valid_offsets[i]
-        char_probs[end - 1] = token_logits[valid_indices[i]]
-
-    return char_probs
diff --git a/wtpsplit/evaluation/adapt.py b/wtpsplit/evaluation/adapt.py
index cbc5f7f2..1e5da05e 100644
--- a/wtpsplit/evaluation/adapt.py
+++ b/wtpsplit/evaluation/adapt.py
@@ -17,11 +17,11 @@
 import adapters
 
 import wtpsplit.models  # noqa: F401
-from wtpsplit.evaluation import evaluate_mixture, get_labels, token_to_char_probs, train_mixture
+from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture
 from wtpsplit.evaluation.intrinsic_baselines import split_language_data
 from wtpsplit.extract import PyTorchWrapper, extract
 from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
-from wtpsplit.utils import Constants
+from wtpsplit.utils import Constants, token_to_char_probs
 
 logger = logging.getLogger()
 logger.setLevel(logging.WARNING)
diff --git a/wtpsplit/evaluation/intrinsic_pairwise.py b/wtpsplit/evaluation/intrinsic_pairwise.py
index 37aac490..2da00a29 100644
--- a/wtpsplit/evaluation/intrinsic_pairwise.py
+++ b/wtpsplit/evaluation/intrinsic_pairwise.py
@@ -17,11 +17,11 @@
 import adapters
 
 import wtpsplit.models  # noqa: F401
-from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
+from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture
 from wtpsplit.evaluation.intrinsic_baselines import split_language_data
 from wtpsplit.extract import PyTorchWrapper
 from wtpsplit.extract_batched import extract_batched
-from wtpsplit.utils import Constants
+from wtpsplit.utils import Constants, token_to_char_probs
 from wtpsplit.evaluation.intrinsic import compute_statistics
 
 logger = logging.getLogger()
diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index 4847abfb..97236a2e 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -91,6 +91,7 @@ def extract(
     lang_code=None,
     pad_last_batch=False,
     verbose=False,
+    tokenizer=None,
 ):
     """
     Computes logits for the given batch of texts by:
@@ -102,9 +103,10 @@ def extract(
     """
     if "xlm" in model.config.model_type:
         use_subwords = True
-        tokenizer = AutoTokenizer.from_pretrained(
-            "facebookAI/xlm-roberta-base",
-        )
+        if not tokenizer:
+            tokenizer = AutoTokenizer.from_pretrained(
+                "facebookAI/xlm-roberta-base",
+            )
         # tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
         tokens = tokenizer(batch_of_texts, return_offsets_mapping=True, verbose=False, add_special_tokens=False)
         # remove CLS and SEP tokens, they are added later anyhow
diff --git a/wtpsplit/train/evaluate.py b/wtpsplit/train/evaluate.py
index f2a627a7..31fea227 100644
--- a/wtpsplit/train/evaluate.py
+++ b/wtpsplit/train/evaluate.py
@@ -5,10 +5,9 @@
 import pysbd
 import sklearn.metrics
 
-from wtpsplit.evaluation import token_to_char_probs
 from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers
 from wtpsplit.extract import PyTorchWrapper, extract
-from wtpsplit.utils import Constants, sigmoid, corrupt
+from wtpsplit.utils import Constants, sigmoid, corrupt, token_to_char_probs
 
 logger = logging.getLogger(__name__)
 
diff --git a/wtpsplit/utils/__init__.py b/wtpsplit/utils/__init__.py
index e8ed4821..509e2dfc 100644
--- a/wtpsplit/utils/__init__.py
+++ b/wtpsplit/utils/__init__.py
@@ -433,6 +433,34 @@ def reconstruct_sentences(text, partial_sentences):
     return fixed_sentences
 
 
+def get_token_spans(tokenizer, offsets_mapping, tokens):
+    # Filter out special tokens and get their character start and end positions
+    valid_indices = np.array(
+        [
+            idx
+            for idx, token in enumerate(tokens)
+            if token not in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]
+            and idx < len(offsets_mapping)
+        ]
+    )
+    valid_offsets = np.array(offsets_mapping)[valid_indices]
+    return valid_indices, valid_offsets
+
+
+def token_to_char_probs(text, tokens, token_logits, tokenizer, offsets_mapping):
+    """Map from token probabilities to character probabilities"""
+    char_probs = np.full((len(text), token_logits.shape[1]), np.min(token_logits))  # Initialize with very low numbers
+
+    valid_indices, valid_offsets = get_token_spans(tokenizer, offsets_mapping, tokens)
+
+    # Assign the token's probability to the last character of the token
+    for i in range(valid_offsets.shape[0]):
+        start, end = valid_offsets[i]
+        char_probs[end - 1] = token_logits[valid_indices[i]]
+
+    return char_probs
+
+
 if __name__ == "__main__":
     # test corrupt function
     from tokenizers import AddedToken
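
Note on the relocation above: token_to_char_probs (with its get_token_spans helper) moves from wtpsplit.evaluation to wtpsplit.utils, and extract() now accepts an optional pre-loaded tokenizer so the splitter can pass the one it creates in __init__ instead of reloading it on every call. Below is a minimal usage sketch of the relocated helper, assuming the patched layout above; the toy text, tokens, offsets, and the SimpleNamespace tokenizer stand-in are illustrative only (real callers pass the Hugging Face tokenizer and the offsets_mapping it returns).

# Sketch: exercising wtpsplit.utils.token_to_char_probs after the move above.
# The tokenizer stand-in and toy inputs are NOT part of wtpsplit; they only
# mimic the attributes and array shapes the function actually touches.
from types import SimpleNamespace

import numpy as np

from wtpsplit.utils import token_to_char_probs

text = "Hello there world"
tokens = ["Hello", "there", "world"]             # toy tokens, no special tokens
offsets_mapping = [(0, 5), (5, 11), (11, 17)]    # character span covered by each token
token_logits = np.array([[-4.0], [1.0], [2.5]])  # one logit row per token

# Only cls_token/sep_token/pad_token are read from the tokenizer, so a namespace suffices here.
tokenizer_stub = SimpleNamespace(cls_token="<s>", sep_token="</s>", pad_token="<pad>")

char_probs = token_to_char_probs(text, tokens, token_logits, tokenizer_stub, offsets_mapping)

print(char_probs.shape)                               # (17, 1): one row per character of text
print(char_probs[4], char_probs[10], char_probs[16])  # each token's logit lands on its last character

Positions not covered by any token keep the minimum logit as a fill value (the "very low numbers" noted in the code comment), while each token's logit is written to its final character.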