add ONNX support for SaT models #129

Merged · 4 commits · Sep 24, 2024
32 changes: 32 additions & 0 deletions README.md
@@ -46,6 +46,38 @@ sat_adapted.split("This is a test This is another test.")
# returns ['This is a test ', 'This is another test']
```

## ONNX Support
🚀 You can now enable even faster ONNX inference for `sat` and `sat-sm` models! 🚀

```python
sat = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
```
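
This assumes the `onnxruntime` package is installed; for the `CUDAExecutionProvider`, the GPU build (`onnxruntime-gpu`) is presumably the one you want:

```bash
pip install onnxruntime-gpu  # GPU build; use `pip install onnxruntime` for CPU-only inference
```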

```python
>>> from wtpsplit import SaT
>>> texts = ["This is a sentence. This is another sentence."] * 1000

# PyTorch GPU
>>> model_pytorch = SaT("sat-3l-sm")
>>> model_pytorch.half().to("cuda");
>>> %timeit list(model_pytorch.split(texts))
# 144 ms ± 252 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# quite fast already, but...

# onnxruntime GPU
>>> model_ort = SaT("sat-3l-sm", ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])
>>> %timeit list(model_ort.split(texts))
# 94.9 ms ± 165 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)
# ...this should be ~50% faster! (tested on RTX 3090)
```

If you wish to use LoRA in combination with an ONNX model:
- Run `scripts/export_to_onnx_sat.py` with `use_lora: True` and an appropriate `output_dir: <OUTPUT_DIR>` (see the example invocation below).
- If you have a local LoRA module, use `lora_path`.
- If you wish to load a LoRA module from the HuggingFace hub, use `style_or_domain` and `language`.
- Load the ONNX model with merged LoRA weights:
`sat = SaT(<OUTPUT_DIR>, ort_providers=["CUDAExecutionProvider", "CPUExecutionProvider"])`
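
A minimal sketch of the export step, assuming the `ud`/`en` LoRA module is available for `sat-3l` on the HuggingFace Hub (the output directory name is illustrative):

```bash
# export sat-3l with the ud/en LoRA module merged into the base weights
python scripts/export_to_onnx_sat.py \
  --model_name_or_path=segment-any-text/sat-3l \
  --use_lora=True \
  --style_or_domain=ud \
  --language=en \
  --output_dir=sat-3l-lora-ud-en
```

The resulting directory (`sat-3l-lora-ud-en` here) is then passed to `SaT` as `<OUTPUT_DIR>` in the last step above.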


## Available Models

23 changes: 23 additions & 0 deletions scripts/export_all_to_onnx.sh
@@ -0,0 +1,23 @@
# export every model in this manually defined list
models=(
"sat-1l-sm"
"sat-3l-sm"
"sat-6l-sm"
"sat-9l-sm"
"sat-12l-sm"
"sat-1l"
"sat-3l"
"sat-6l"
"sat-9l"
"sat-12l"
"sat-1l-no-limited-lookahead"
"sat-3l-no-limited-lookahead"
"sat-6l-no-limited-lookahead"
"sat-9l-no-limited-lookahead"
"sat-12l-no-limited-lookahead"
)

for model in "${models[@]}"
do
python scripts/export_to_onnx_sat.py --model_name_or_path=segment-any-text/$model --output_dir=output_onnx_exports/$model --upload_to_hub=True
done
89 changes: 80 additions & 9 deletions scripts/export_to_onnx_sat.py
@@ -1,21 +1,33 @@
from dataclasses import dataclass
from pathlib import Path

import adapters # noqa
import onnx
import torch
from adapters.models import MODEL_MIXIN_MAPPING # noqa
from adapters.models.bert.mixin_bert import BertModelAdaptersMixin # noqa
from huggingface_hub import hf_hub_download, HfApi
from onnxruntime.transformers.optimizer import optimize_model # noqa
from transformers import AutoModelForTokenClassification, HfArgumentParser

import wtpsplit # noqa
import wtpsplit.models # noqa
from wtpsplit.utils import Constants

MODEL_MIXIN_MAPPING["SubwordXLMRobertaModel"] = BertModelAdaptersMixin


@dataclass
class Args:
model_name_or_path: str = "segment-any-text/sat-1l-sm"
output_dir: str = "sat-1l-sm"
model_name_or_path: str = "segment-any-text/sat-1l-sm" # model from HF Hub: https://huggingface.co/segment-any-text
output_dir: str = "sat-1l-sm" # output directory, saves to current directory
device: str = "cuda"
# LoRA: optionally merge adapter weights into the base model before export
use_lora: bool = False
lora_path: str = None # local path to lora weights
# otherwise, fetch from HF Hub:
style_or_domain: str = "ud"
language: str = "en"
upload_to_hub: bool = False


if __name__ == "__main__":
@@ -25,25 +37,70 @@ class Args:
output_dir.mkdir(exist_ok=True, parents=True)

model = AutoModelForTokenClassification.from_pretrained(args.model_name_or_path, force_download=False)
model = model.half() # CUDA ONLY!

model = model.to(args.device)

# fetch config.json from huggingface hub
hf_hub_download(
repo_id=args.model_name_or_path,
filename="config.json",
local_dir=output_dir,
)

# LoRA SETUP
if args.use_lora:
# the adapters library needs xlm-roberta as the model type.
model_type = model.config.model_type
model.config.model_type = "xlm-roberta"
adapters.init(model)
# reset model type (used later)
model.config.model_type = model_type
if not args.lora_path:
for file in [
"adapter_config.json",
"head_config.json",
"pytorch_adapter.bin",
"pytorch_model_head.bin",
]:
hf_hub_download(
repo_id=args.model_name_or_path,
subfolder=f"loras/{args.style_or_domain}/{args.language}",
filename=file,
local_dir=Constants.CACHE_DIR,
)
lora_load_path = str(Constants.CACHE_DIR / "loras" / args.style_or_domain / args.language)
else:
lora_load_path = args.lora_path

print(f"Using LoRA weights from {lora_load_path}.")
model.load_adapter(
lora_load_path,
set_active=True,
with_head=True,
load_as="sat-lora",
)
# merge LoRA weights into the transformer so inference incurs no adapter overhead
model.merge_adapter("sat-lora")
print("LoRA setup done.")
# LoRA setup done, model is now ready for export.

model = model.half()

torch.onnx.export(
model,
{
"attention_mask": torch.zeros((1, 1), dtype=torch.float16, device=args.device),
"input_ids": torch.zeros((1, 1), dtype=torch.int64, device=args.device),
"input_ids": torch.randint(0, model.config.vocab_size, (1, 1), dtype=torch.int64, device=args.device),
"attention_mask": torch.randn((1, 1), dtype=torch.float16, device=args.device),
},
output_dir / "model.onnx",
verbose=True,
input_names=["attention_mask", "input_ids"],
input_names=["input_ids", "attention_mask"],
output_names=["logits"],
dynamic_axes={
"input_ids": {0: "batch", 1: "sequence"},
"attention_mask": {0: "batch", 1: "sequence"},
"logits": {0: "batch", 1: "sequence"},
},
# opset_version=11
)

m = optimize_model(
@@ -60,4 +117,18 @@ class Args:
onnx.save_model(m.model, optimized_model_path)

onnx_model = onnx.load(output_dir / "model.onnx")
onnx.checker.check_model(onnx_model, full_check=True)
print(onnx.checker.check_model(onnx_model, full_check=True))

if args.upload_to_hub:
api = HfApi()

api.upload_file(
path_or_fileobj=output_dir / "model_optimized.onnx",
path_in_repo="model_optimized.onnx",
repo_id=args.model_name_or_path,
)
api.upload_file(
path_or_fileobj=output_dir / "model.onnx",
path_in_repo="model.onnx",
repo_id=args.model_name_or_path,
)
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

setup(
name="wtpsplit",
version="2.0.8",
version="2.1.0",
packages=find_packages(),
description="Universal Robust, Efficient and Adaptable Sentence Segmentation",
author="Markus Frohmann, Igor Sterner, Benjamin Minixhofer",
8 changes: 4 additions & 4 deletions test.py
@@ -2,11 +2,11 @@
from wtpsplit import WtP, SaT


# def test_split_ort():
# sat = SaT("segment-any-text/sat-3l", ort_providers=["CPUExecutionProvider"])
def test_split_ort():
sat = SaT("sat-3l-sm", ort_providers=["CPUExecutionProvider"])

# splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.005)
# assert splits == ["This is a test sentence ", "This is another test sentence."]
splits = sat.split("This is a test sentence This is another test sentence.", threshold=0.25)
assert splits == ["This is a test sentence ", "This is another test sentence."]


def test_split_torch():
56 changes: 28 additions & 28 deletions wtpsplit/__init__.py
@@ -15,10 +15,10 @@
from transformers import AutoConfig, AutoModelForTokenClassification, AutoTokenizer
from transformers.utils.hub import cached_file

from wtpsplit.extract import BertCharORTWrapper, PyTorchWrapper, extract
from wtpsplit.extract import BertCharORTWrapper, SaTORTWrapper, PyTorchWrapper, extract
from wtpsplit.utils import Constants, indices_to_sentences, sigmoid, token_to_char_probs

__version__ = "2.0.8"
__version__ = "2.1.0"

warnings.simplefilter("default", DeprecationWarning) # show by default
warnings.simplefilter("ignore", category=FutureWarning) # for tranformers
@@ -88,8 +88,6 @@ def __init__(

try:
import onnxruntime as ort # noqa

ort.set_default_logger_severity(0)
except ModuleNotFoundError:
raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.")

@@ -449,38 +447,41 @@ def __init__(

if is_local:
model_path = Path(model_name)
onnx_path = model_path / "model.onnx"
onnx_path = model_path / "model_optimized.onnx"
if not onnx_path.exists():
onnx_path = None
else:
# no need to load if no ort_providers set
if ort_providers is not None:
onnx_path = cached_file(model_name_to_fetch, "model.onnx", **(from_pretrained_kwargs or {}))
onnx_path = cached_file(
model_name_to_fetch, "model_optimized.onnx", **(from_pretrained_kwargs or {})
)
else:
onnx_path = None

if ort_providers is not None:
raise NotImplementedError("ONNX is not supported for SaT *yet*.")
# if onnx_path is None:
# raise ValueError(
# "Could not find an ONNX model in the model directory. Try `use_ort=False` to run with PyTorch."
# )

# try:
# import onnxruntime as ort # noqa

# ort.set_default_logger_severity(0)
# except ModuleNotFoundError:
# raise ValueError("Please install `onnxruntime` to use WtP with an ONNX model.")

# # to register models for AutoConfig
# import wtpsplit.configs # noqa

# # TODO: ONNX integration
# self.model = SaTORTWrapper(
# AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})),
# ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})),
# )
if onnx_path is None:
raise ValueError(
"Could not find an ONNX model in the model directory. Try `use_ort=False` to run with PyTorch."
)

try:
import onnxruntime as ort # noqa
except ModuleNotFoundError:
raise ValueError("Please install `onnxruntime` to use SaT with an ONNX model.")

# to register models for AutoConfig
import wtpsplit.configs # noqa

self.model = SaTORTWrapper(
AutoConfig.from_pretrained(model_name_to_fetch, **(from_pretrained_kwargs or {})),
ort.InferenceSession(str(onnx_path), providers=ort_providers, **(ort_kwargs or {})),
)
if lora_path:
raise ValueError(
"If using ONNX with LoRA, execute `scripts/export_to_onnx_sat.py` with `use_lora=True`."
"Reference the chosen `output_dir` here for `model_name_or_model`. and set `lora_path=None`."
)
else:
# to register models for AutoConfig
try:
@@ -496,7 +497,6 @@ def __init__(
)
)
# LoRA LOADING
# TODO: LoRA + ONNX ?
if not lora_path:
if (style_or_domain and not language) or (language and not style_or_domain):
raise ValueError("Please specify both language and style_or_domain!")
10 changes: 5 additions & 5 deletions wtpsplit/extract.py
@@ -43,11 +43,11 @@ def __getattr__(self, name):

def __call__(self, input_ids, attention_mask):
logits = self.ort_session.run(
output_names=["logits"],
input_feed={
"attention_mask": attention_mask.astype(np.int64),
"input_ids": input_ids.astype(np.float16),
}, # .astype(np.int64)},
["logits"],
{
"attention_mask": attention_mask.astype(np.float16),
"input_ids": input_ids.astype(np.int64),
},
)[0]

return {"logits": logits}