Commit 88bdc91: add ONNX support
markus583 committed Sep 10, 2024 (1 parent: 46f3d19)
Showing 6 changed files with 140 additions and 47 deletions.
README.md: 30 additions & 0 deletions
@@ -46,6 +46,36 @@ sat_adapted.split("This is a test This is another test.")
# returns ['This is a test ', 'This is another test']
```

## ONNX Support
🚀 You can now enable ONNX inference for `sat` and `sat-sm` models! 🚀

```python
sat = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"])
```

```python
>>> from wtpsplit import SaT
>>> texts = ["This is a sentence. This is another sentence."] * 1000

# PyTorch GPU
>>> model = SaT("sat-3l-sm")
>>> model.half().to("cuda")
>>> %timeit list(model.split(texts))
138 ms ± 8.41 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

# onnxruntime GPU
>>> model = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider"])
>>> %timeit list(model.split(texts))
198 ms ± 1.36 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
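
ONNX inference also works on CPU via `CPUExecutionProvider` (this is what `test.py` exercises); a minimal sketch, assuming `onnxruntime` is installed:

```python
from wtpsplit import SaT

# CPU-only ONNX inference; model and threshold as in test.py
sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])
print(sat.split("This is a test sentence This is another test sentence.", threshold=0.25))
# ['This is a test sentence ', 'This is another test sentence.']
```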

If you wish to use LoRA in combination with an ONNX model (see the sketch below):
- Run `scripts/export_to_onnx_sat.py` with `use_lora=True` and an appropriate `output_dir=<OUTPUT_DIR>`.
  - If you have a local LoRA module, use `lora_path`.
  - If you wish to load a LoRA module from the Hugging Face Hub, use `style_or_domain` and `language`.
- Load the ONNX model with merged LoRA weights:
  `sat = SaT(<OUTPUT_DIR>, ort_providers=["CUDAExecutionProvider"])`
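
A hedged end-to-end sketch: the export script's CLI flags come from its `Args` dataclass via `HfArgumentParser`; the model, style, and `<OUTPUT_DIR>` below are placeholders, and the snippet assumes the repository is checked out locally.

```python
import subprocess

from wtpsplit import SaT

# 1) export sat-3l-sm to ONNX with a LoRA module from the HF Hub merged in
subprocess.run(
    [
        "python", "scripts/export_to_onnx_sat.py",
        "--model_name_or_path", "segment-any-text/sat-3l-sm",
        "--output_dir", "sat-3l-sm-lora-onnx",  # placeholder <OUTPUT_DIR>
        "--use_lora", "true",
        "--style_or_domain", "ud",  # used because no --lora_path is given
        "--language", "en",
    ],
    check=True,
)

# 2) load the ONNX model with merged LoRA weights
sat = SaT("sat-3l-sm-lora-onnx", ort_providers=["CUDAExecutionProvider"])
```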


## Available Models

scripts/export_to_onnx_sat.py: 65 additions & 9 deletions
@@ -1,21 +1,32 @@
from dataclasses import dataclass
from pathlib import Path

import adapters # noqa
import onnx
import torch
from adapters.models import MODEL_MIXIN_MAPPING # noqa
from adapters.models.bert.mixin_bert import BertModelAdaptersMixin # noqa
from huggingface_hub import hf_hub_download
from onnxruntime.transformers.optimizer import optimize_model # noqa
from transformers import AutoModelForTokenClassification, HfArgumentParser

import wtpsplit # noqa
import wtpsplit.models # noqa
from wtpsplit.utils import Constants

MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin


@dataclass
class Args:
    model_name_or_path: str = "segment-any-text/sat-1l-sm"  # model from HF Hub: https://huggingface.co/segment-any-text
    output_dir: str = "sat-1l-sm"  # output directory; saves to the current directory
    device: str = "cuda"
    use_lora: bool = False
    lora_path: str = None  # local path to LoRA weights
    # otherwise, fetch from the HF Hub:
    style_or_domain: str = "ud"
    language: str = "en"


if __name__ == "__main__":
@@ -25,25 +36,70 @@ class Args:
    output_dir.mkdir(exist_ok=True, parents=True)

    model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, force_download=False)
    model = model.to(args.device)

    # fetch config.json from the Hugging Face Hub
    hf_hub_download(
        repo_id=args.model_name_or_path,
        filename="config.json",
        local_dir=output_dir,
    )

    # LoRA SETUP
    if args.use_lora:
        # the adapters library needs xlm-roberta as the model type
        model_type = model.config.model_type
        model.config.model_type = "xlm-roberta"
        adapters.init(model)
        # reset model type (used later)
        model.config.model_type = model_type
        if not args.lora_path:
            for file in [
                "adapter_config.json",
                "head_config.json",
                "pytorch_adapter.bin",
                "pytorch_model_head.bin",
            ]:
                hf_hub_download(
                    repo_id=args.model_name_or_path,
                    subfolder=f"loras/{args.style_or_domain}/{args.language}",
                    filename=file,
                    local_dir=Constants.CACHE_DIR,
                )
            lora_load_path = str(Constants.CACHE_DIR / "loras" / args.style_or_domain / args.language)
        else:
            lora_load_path = args.lora_path

        print(f"Using LoRA weights from {lora_load_path}.")
        model.load_adapter(
            lora_load_path,
            set_active=True,
            with_head=True,
            load_as="sat-lora",
        )
        # merge LoRA weights into the transformer for zero inference overhead
        model.merge_adapter("sat-lora")
        print("LoRA setup done.")
    # LoRA setup done; the model is now ready for export.

    model = model.half()

    torch.onnx.export(
        model,
        {
            "attention_mask": torch.randn((1, 1), dtype=torch.float16, device=args.device),
            "input_ids": torch.randint(0, 250002, (1, 1), dtype=torch.int64, device=args.device),
        },
        output_dir / "model.onnx",
        verbose=True,
        input_names=["attention_mask", "input_ids"],
        output_names=["logits"],
        dynamic_axes={
            "attention_mask": {0: "batch", 1: "sequence"},
            "input_ids": {0: "batch", 1: "sequence"},
            "logits": {0: "batch", 1: "sequence"},
        },
    )

    m = optimize_model(
@@ -60,4 +116,4 @@ class Args:
    onnx.save_model(m.model, optimized_model_path)

    onnx_model = onnx.load(output_dir / "model.onnx")
    print(onnx.checker.check_model(onnx_model, full_check=True))
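
After export, the optimized model can be smoke-tested with `onnxruntime` directly. A minimal sketch, assuming the default `output_dir` of `sat-1l-sm`: inputs are fed by name with the dtypes fixed at export time (int64 ids, float16 mask), and since the graph is float16 a CUDA provider is the safe choice, with CPU as fallback.

```python
import numpy as np
import onnxruntime as ort

# placeholder token ids; a real check would use the model's tokenizer
input_ids = np.zeros((1, 8), dtype=np.int64)
attention_mask = np.ones((1, 8), dtype=np.float16)

sess = ort.InferenceSession(
    "sat-1l-sm/model_optimized.onnx",
    providers=["CUDAExecutionProvider", "CPUExecutionProvider"],
)
(logits,) = sess.run(["logits"], {"input_ids": input_ids, "attention_mask": attention_mask})
print(logits.shape)  # (batch, sequence, num_labels)
```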
setup.py: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

setup(
    name="wtpsplit",
    version="2.1.0",
    packages=find_packages(),
    description="Universal Robust, Efficient and Adaptable Sentence Segmentation",
    author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer",
test.py: 4 additions & 4 deletions
@@ -2,11 +2,11 @@
from wtpsplit import WtP, SaT


def test_split_ort():
    sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])

    splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.25)
    assert splits == ["This is a test sentence ", "This is another test sentence."]


def test_split_torch():
wtpsplit/__init__.py: 35 additions & 28 deletions
@@ -15,10 +15,10 @@
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
from transformers.utils.hub import cached_file

from wtpsplit.extract import BertCharORTWrapper, SaTORTWrapper, PyTorchWrapper, extract
from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs

__version__ = "2.1.0"

warnings.simplefilter("default", DeprecationWarning) # show by default
warnings.simplefilter("ignore", category=FutureWarning) # for tranformers
@@ -88,8 +88,6 @@ def __init__(

        try:
            import onnxruntime as ort  # noqa
        except ModuleNotFoundError:
            raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.")

@@ -449,38 +447,39 @@ def __init__(

        if is_local:
            model_path = Path(model_name)
            onnx_path = model_path / "model_optimized.onnx"
            if not onnx_path.exists():
                onnx_path = None
        else:
            # no need to load if no ort_providers set
            if ort_providers is not None:
                onnx_path = cached_file(model_name_to_fetch, "model_optimized.onnx", **(from_pretrained_kwargs or {}))
            else:
                onnx_path = None

        if ort_providers is not None:
            if onnx_path is None:
                raise ValueError(
                    "Could not find an ONNX model in the model directory. Try `use_ort=False` to run with PyTorch."
                )

            try:
                import onnxruntime as ort  # noqa
            except ModuleNotFoundError:
                raise ValueError("Please install `onnxruntime` to use SaT with an ONNX model.")

            # to register models for AutoConfig
            import wtpsplit.configs  # noqa

            self.model = SaTORTWrapper(
                AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})),
                ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})),
            )
            if lora_path:
                raise ValueError(
                    "If using ONNX with LoRA, execute `scripts/export_to_onnx_sat.py` with `use_lora=True`. "
                    "Reference the chosen `output_dir` here for `model_name_or_model` and set `lora_path=None`."
                )
        else:
            # to register models for AutoConfig
            try:
@@ -496,7 +495,6 @@ def __init__(
)
)
        # LoRA LOADING
        if not lora_path:
            if (style_or_domain and not language) or (language and not style_or_domain):
                raise ValueError("Please specify both language and style_or_domain!")
@@ -792,3 +790,12 @@ def get_default_threshold(model_str: str):
text, np.where(probs > sentence_threshold)[0], strip_whitespace=strip_whitespace
)
yield sentences


if __name__ == "__main__":
    sat = SaT("sat-3l-lora", ort_providers=["CPUExecutionProvider"])
    print(sat.split("Hello, World! Next."))

    wtp = WtP("wtp-bert-tiny", ort_providers=["CPUExecutionProvider"])
    print(wtp.split("Hello, World! Next."))
    print("DONE!")
wtpsplit/extract.py: 5 additions & 5 deletions
@@ -43,11 +43,11 @@ def __getattr__(self, name):

    def __call__(self, input_ids, attention_mask):
        logits = self.ort_session.run(
            ["logits"],
            {
                # inputs are matched to the session's input order, which was fixed at export time;
                # ids are int64 and the mask float16, matching the exported graph
                self.ort_session.get_inputs()[0].name: input_ids.astype(np.int64),
                self.ort_session.get_inputs()[1].name: attention_mask.astype(np.float16),
            },
        )[0]

        return {"logits": logits}