tokenizer init + torch import
markus583 committed Jun 28, 2024
1 parent 66eb194 commit a381079
Showing 8 changed files with 53 additions and 46 deletions.
setup.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@

setup(
name="wtpsplit",
version="2.0.3",
version="2.0.4",
packages=find_packages(),
description="Universal Robust, Efficient and Adaptable Sentence Segmentation",
author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer",
wtpsplit/__init__.py (21 changes: 13 additions & 8 deletions)
@@ -8,23 +8,20 @@
with contextlib.redirect_stderr(open(os.devnull, "w")):
import transformers # noqa

import adapters # noqa
import numpy as np
import skops.io as sio
from adapters.models import MODEL_MIXIN_MAPPING
from adapters.models.bert.mixin_bert import BertModelAdaptersMixin

from huggingface_hub import hf_hub_download
from transformers import AutoConfig, AutoModelForTokenClassification
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
from transformers.utils.hub import cached_file

from wtpsplit.evaluation import token_to_char_probs
from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, extract
from wtpsplit.utils import Constants, indices_to_sentences, sigmoid
from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs

__version__ = "2.0.3"
__version__ = "2.0.4"

warnings.simplefilter("default", DeprecationWarning) # show by default

warnings.simplefilter("ignore", category=FutureWarning) # for tranformers

class WtP:
def __init__(
@@ -438,6 +435,10 @@ def __init__(

self.use_lora = False

self.tokenizer = AutoTokenizer.from_pretrained(
"facebookAI/xlm-roberta-base"
)

if isinstance(model_name_or_model, (str, Path)):
model_name = str(model_name_or_model)
is_local = os.path.isdir(model_name)
@@ -500,6 +501,9 @@ def __init__(
if (style_or_domain and not language) or (language and not style_or_domain):
raise ValueError("Please specify both language and style_or_domain!")
if style_or_domain and language:
import adapters # noqa
from adapters.models import MODEL_MIXIN_MAPPING # noqa
from adapters.models.bert.mixin_bert import BertModelAdaptersMixin # noqa
# monkey patch mixin to avoid forking whole adapters library
MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin
model_type = self.model.model.config.model_type
@@ -638,6 +642,7 @@ def newline_probability_fn(logits):
batch_size=batch_size,
pad_last_batch=pad_last_batch,
verbose=verbose,
tokenizer=self.tokenizer
)

# convert token probabilities to character probabilities for the entire array
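Note: the net effect of the wtpsplit/__init__.py changes is that the XLM-R tokenizer is loaded once when the model wrapper is constructed and then handed to extract() via tokenizer=self.tokenizer, instead of being re-created inside every extract() call; the adapters imports are likewise deferred until a style_or_domain adapter is actually requested. A minimal sketch of the construct-once, pass-down pattern (the Segmenter class below is a hypothetical stand-in, not wtpsplit's API):

from transformers import AutoTokenizer

class Segmenter:
    def __init__(self):
        # Load the tokenizer a single time at construction.
        self.tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base")

    def tokenize_batch(self, texts):
        # Reuse the cached tokenizer on every call instead of reloading it.
        return self.tokenizer(texts, return_offsets_mapping=True, add_special_tokens=False)

seg = Segmenter()
enc = seg.tokenize_batch(["This is a test. This is another test."])
print(len(enc["input_ids"][0]))  # number of tokens in the first text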
wtpsplit/evaluation/__init__.py (28 changes: 0 additions & 28 deletions)
@@ -427,31 +427,3 @@ def punkt_sentencize(lang_code, text):
return reconstruct_sentences(text, sent_tokenize(text, language=lang_code_to_lang(lang_code).lower()))
except LookupError:
raise LanguageError(f"punkt does not support {lang_code}")


def get_token_spans(tokenizer, offsets_mapping, tokens):
# Filter out special tokens and get their character start and end positions
valid_indices = np.array(
[
idx
for idx, token in enumerate(tokens)
if token not in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]
and idx < len(offsets_mapping)
]
)
valid_offsets = np.array(offsets_mapping)[valid_indices]
return valid_indices, valid_offsets


def token_to_char_probs(text, tokens, token_logits, tokenizer, offsets_mapping):
"""Map from token probabalities to character probabilities"""
char_probs = np.full((len(text), token_logits.shape[1]), np.min(token_logits)) # Initialize with very low numbers

valid_indices, valid_offsets = get_token_spans(tokenizer, offsets_mapping, tokens)

# Assign the token's probability to the last character of the token
for i in range(valid_offsets.shape[0]):
start, end = valid_offsets[i]
char_probs[end - 1] = token_logits[valid_indices[i]]

return char_probs
wtpsplit/evaluation/adapt.py (4 changes: 2 additions & 2 deletions)
@@ -17,11 +17,11 @@

import adapters
import wtpsplit.models # noqa: F401
from wtpsplit.evaluation import evaluate_mixture, get_labels, token_to_char_probs, train_mixture
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture
from wtpsplit.evaluation.intrinsic_baselines import split_language_data
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.models import SubwordXLMConfig, SubwordXLMForTokenClassification
from wtpsplit.utils import Constants
from wtpsplit.utils import Constants, token_to_char_probs

logger = logging.getLogger()
logger.setLevel(logging.WARNING)
wtpsplit/evaluation/intrinsic_pairwise.py (4 changes: 2 additions & 2 deletions)
@@ -17,11 +17,11 @@
import adapters

import wtpsplit.models # noqa: F401
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture, token_to_char_probs
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture
from wtpsplit.evaluation.intrinsic_baselines import split_language_data
from wtpsplit.extract import PyTorchWrapper
from wtpsplit.extract_batched import extract_batched
from wtpsplit.utils import Constants
from wtpsplit.utils import Constants, token_to_char_probs
from wtpsplit.evaluation.intrinsic import compute_statistics

logger = logging.getLogger()
wtpsplit/extract.py (8 changes: 5 additions & 3 deletions)
@@ -91,6 +91,7 @@ def extract(
lang_code=None,
pad_last_batch=False,
verbose=False,
tokenizer=None,
):
"""
Computes logits for the given batch of texts by:
@@ -102,9 +103,10 @@
"""
if "xlm" in model.config.model_type:
use_subwords = True
tokenizer = AutoTokenizer.from_pretrained(
"facebookAI/xlm-roberta-base",
)
if not tokenizer:
tokenizer = AutoTokenizer.from_pretrained(
"facebookAI/xlm-roberta-base",
)
# tokenizer.add_special_tokens({"additional_special_tokens": [AddedToken("\n")]})
tokens = tokenizer(batch_of_texts, return_offsets_mapping=True, verbose=False, add_special_tokens=False)
# remove CLS and SEP tokens, they are added later anyhow
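Note: extract() now accepts an optional tokenizer argument and only falls back to AutoTokenizer.from_pretrained when none is supplied, so existing call sites that pass nothing keep working. A minimal sketch of that optional-argument fallback (process() is illustrative, not wtpsplit code):

from transformers import AutoTokenizer

def process(batch_of_texts, tokenizer=None):
    # Pay the from_pretrained cost only when the caller did not supply a tokenizer.
    if tokenizer is None:
        tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base")
    return tokenizer(batch_of_texts, return_offsets_mapping=True, add_special_tokens=False)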
wtpsplit/train/evaluate.py (4 changes: 2 additions & 2 deletions)
@@ -5,10 +5,10 @@
import pysbd
import sklearn.metrics

from wtpsplit.evaluation import token_to_char_probs
from wtpsplit.evaluation import
from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers

Check failure on line 9 in wtpsplit/train/evaluate.py (GitHub Actions / build (3.8) and build (3.11), Ruff E999):
wtpsplit/train/evaluate.py:8:32: E999 SyntaxError: Unexpected token Newline
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants, sigmoid, corrupt
from wtpsplit.utils import Constants, sigmoid, corrupt, token_to_char_probs

logger = logging.getLogger(__name__)

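Note: the bare "from wtpsplit.evaluation import" left in this hunk is exactly what the Ruff E999 annotation above flags. Since token_to_char_probs now comes from wtpsplit.utils, a likely intended import block (an assumption about the follow-up fix, not part of this commit) would simply drop the dangling line:

# Presumed intent (assumption, not in this commit): remove the dangling evaluation import;
# token_to_char_probs is imported from wtpsplit.utils below.
from wtpsplit.evaluation.intrinsic_pairwise import generate_pairs, generate_k_mers, process_logits_k_mers
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants, sigmoid, corrupt, token_to_char_probs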
wtpsplit/utils/__init__.py (28 changes: 28 additions & 0 deletions)
@@ -433,6 +433,34 @@ def reconstruct_sentences(text, partial_sentences):
return fixed_sentences


def get_token_spans(tokenizer, offsets_mapping, tokens):
# Filter out special tokens and get their character start and end positions
valid_indices = np.array(
[
idx
for idx, token in enumerate(tokens)
if token not in [tokenizer.cls_token, tokenizer.sep_token, tokenizer.pad_token]
and idx < len(offsets_mapping)
]
)
valid_offsets = np.array(offsets_mapping)[valid_indices]
return valid_indices, valid_offsets


def token_to_char_probs(text, tokens, token_logits, tokenizer, offsets_mapping):
"""Map from token probabalities to character probabilities"""
char_probs = np.full((len(text), token_logits.shape[1]), np.min(token_logits)) # Initialize with very low numbers

valid_indices, valid_offsets = get_token_spans(tokenizer, offsets_mapping, tokens)

# Assign the token's probability to the last character of the token
for i in range(valid_offsets.shape[0]):
start, end = valid_offsets[i]
char_probs[end - 1] = token_logits[valid_indices[i]]

return char_probs


if __name__ == "__main__":
# test corrupt function
from tokenizers import AddedToken
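Note: get_token_spans and token_to_char_probs move verbatim from wtpsplit/evaluation/__init__.py into wtpsplit/utils/__init__.py, presumably so the core package can use them without importing the evaluation module. token_to_char_probs spreads per-token logits over characters by writing each token's logits to the last character that token covers and filling all other positions with the minimum logit. A toy usage sketch under these assumptions (the random logits stand in for real model outputs):

import numpy as np
from transformers import AutoTokenizer

from wtpsplit.utils import token_to_char_probs  # new location as of this commit

text = "Hello world. Goodbye."
tokenizer = AutoTokenizer.from_pretrained("facebookAI/xlm-roberta-base")
enc = tokenizer(text, return_offsets_mapping=True, add_special_tokens=False)
tokens = tokenizer.convert_ids_to_tokens(enc["input_ids"])

token_logits = np.random.randn(len(tokens), 1)  # dummy per-token scores
char_probs = token_to_char_probs(text, tokens, token_logits, tokenizer, enc["offset_mapping"])
print(char_probs.shape)  # (len(text), 1): one row per character of the input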
