From cfd5e24411d8c658b3000af0c62b6602a6c955ca Mon Sep 17 00:00:00 2001 From: markus583 Date: Sun, 7 Jul 2024 13:38:17 +0200 Subject: [PATCH] fix tok length issue --- requirements.txt | 1 - scripts/export_to_onnx_charbert.py | 2 +- scripts/export_to_onnx_sat.py | 38 ++++++++++++++++-------------- setup.py | 2 +- wtpsplit/__init__.py | 2 +- wtpsplit/extract.py | 16 ++++++++----- 6 files changed, 33 insertions(+), 28 deletions(-) diff --git a/requirements.txt b/requirements.txt index dc004891..6d3536f7 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,7 +13,6 @@ numpy==1.23.5 pydantic torchinfo conllu -genalog pandarallel cohere replicate diff --git a/scripts/export_to_onnx_charbert.py b/scripts/export_to_onnx_charbert.py index ecd7c903..1a94dc75 100644 --- a/scripts/export_to_onnx_charbert.py +++ b/scripts/export_to_onnx_charbert.py @@ -7,7 +7,7 @@ from transformers import AutoModelForTokenClassification, HfArgumentParser import wtpsplit # noqa - +import wtpsplit.models # noqa @dataclass class Args: diff --git a/scripts/export_to_onnx_sat.py b/scripts/export_to_onnx_sat.py index f7674dd9..dd82baa5 100644 --- a/scripts/export_to_onnx_sat.py +++ b/scripts/export_to_onnx_sat.py @@ -12,9 +12,9 @@ @dataclass class Args: - model_name_or_path: str = "segment-any-text/sat-12l-no-limited-lookahead" - output_dir: str = "sat-12l-no-limited-lookahead" - device: str = "cpu" + model_name_or_path: str = "segment-any-text/sat-1l-sm" + output_dir: str = "sat-1l-sm" + device: str = "cuda" # TODO: lora merging here @@ -24,15 +24,15 @@ class Args: output_dir = Path(args.output_dir) output_dir.mkdir(exist_ok=True, parents=True) - model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, force_download=True) - # model = model.half() # CUDA ONLY! + model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, force_download=False) + model = model.half() # CUDA ONLY! model = model.to(args.device) torch.onnx.export( model, { - "attention_mask": torch.zeros((1, 14), dtype=torch.long, device=args.device), - "input_ids": torch.zeros((1, 14), dtype=torch.long, device=args.device), + "attention_mask": torch.zeros((1, 1), dtype=torch.float16, device=args.device), + "input_ids": torch.zeros((1, 1), dtype=torch.int64, device=args.device), }, output_dir / "model.onnx", verbose=True, @@ -41,21 +41,23 @@ class Args: dynamic_axes={ "input_ids": {0: "batch", 1: "sequence"}, "attention_mask": {0: "batch", 1: "sequence"}, - "logits": {0: "batch", 1: "sequence"} + "logits": {0: "batch", 1: "sequence"}, }, # opset_version=11 ) - # m = optimize_model( - # str(output_dir / "model.onnx"), - # model_type="bert", - # optimization_options=None, - # opt_level=0, - # use_gpu=False, - # ) + m = optimize_model( + str(output_dir / "model.onnx"), + model_type="bert", + num_heads=0, + hidden_size=0, + optimization_options=None, + opt_level=0, + use_gpu=False, + ) - # optimized_model_path = output_dir / "model_optimized.onnx" - # onnx.save_model(m.model, optimized_model_path) + optimized_model_path = output_dir / "model_optimized.onnx" + onnx.save_model(m.model, optimized_model_path) onnx_model = onnx.load(output_dir / "model.onnx") - onnx.checker.check_model(onnx_model, full_check=True) \ No newline at end of file + onnx.checker.check_model(onnx_model, full_check=True) diff --git a/setup.py b/setup.py index 5f74338c..24dd9002 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="wtpsplit", - version="2.0.4", + version="2.0.5", packages=find_packages(), description="Universal Robust, Efficient and Adaptable Sentence Segmentation", author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer", diff --git a/wtpsplit/__init__.py b/wtpsplit/__init__.py index 09f9465d..2e291698 100644 --- a/wtpsplit/__init__.py +++ b/wtpsplit/__init__.py @@ -18,7 +18,7 @@ from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, extract from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs -__version__ = "2.0.4" +__version__ = "2.0.5" warnings.simplefilter("default", DeprecationWarning) # show by default warnings.simplefilter("ignore", category=FutureWarning) # for tranformers diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py index 97236a2e..6bdf1411 100644 --- a/wtpsplit/extract.py +++ b/wtpsplit/extract.py @@ -44,7 +44,10 @@ def __getattr__(self, name): def __call__(self, input_ids, attention_mask): logits = self.ort_session.run( output_names=["logits"], - input_feed={"attention_mask": attention_mask.astype(np.int64), "input_ids": input_ids.astype(np.int64)}, + input_feed={ + "attention_mask": attention_mask.astype(np.int64), + "input_ids": input_ids.astype(np.float16), + }, # .astype(np.int64)}, )[0] return {"logits": logits} @@ -71,9 +74,9 @@ def __call__(self, attention_mask, hashed_ids=None, language_ids=None, input_ids input_ids=torch.from_numpy(input_ids).to(self.model.device) if input_ids is not None else None, hashed_ids=torch.from_numpy(hashed_ids).to(self.model.device) if hashed_ids is not None else None, attention_mask=torch.from_numpy(attention_mask).to(self.model.device), - language_ids=torch.from_numpy(language_ids).to(self.model.device) - if language_ids is not None - else None, + language_ids=( + torch.from_numpy(language_ids).to(self.model.device) if language_ids is not None else None + ), )["logits"] .cpu() .numpy() @@ -124,8 +127,9 @@ def extract( text_lengths = [len(text) for text in batch_of_texts] # reduce block size if possible block_size = min(max_block_size, max(text_lengths)) - if use_subwords and block_size == 512: - block_size -= 2 # account for CLS and SEP tokens + if use_subwords and block_size > 510: + overflow_length = block_size - 510 + block_size -= overflow_length # account for CLS and SEP tokens # make sure block_size is a multiple of downsampling rate downsampling_rate = getattr(model.config, "downsampling_rate", 1)