Commit
Merge pull request #1 from norsulabs/dev/test
added quantization support
ameen-91 authored Nov 26, 2024
2 parents 143d0c8 + 93db8af commit 58d63cc
Showing 8 changed files with 305 additions and 156 deletions.
19 changes: 19 additions & 0 deletions README.md
@@ -15,6 +15,7 @@ Infero allows you to easily download, convert, and host your models using the ON
- Automatic downloads.
- Automatic ONNX conversions.
- Automatic server setup.
- 8-bit quantization support.

## Installation

@@ -32,6 +33,24 @@ Here is a simple example of how to use Infero:
```bash
infero run [hf_model_name]
```

To run a model with 8-bit quantization:

```bash
infero run [hf_model_name] --quantize
```

To list all available models:

```bash
infero list
```

To remove a model:

```bash
infero remove [hf_model_name]
```

Infero is licensed under the MIT License. See the [LICENSE](LICENSE) file for more details.

## Contact
37 changes: 36 additions & 1 deletion infero/convert/onnx.py
@@ -1,7 +1,14 @@
import os
-from infero.utils import sanitize_model_name, print_success
+from infero.utils import sanitize_model_name, print_success, print_neutral
from transformers import AutoModelForSequenceClassification
import torch
from onnxruntime.quantization import quantize_dynamic, QuantType
import warnings
import logging

logging.getLogger("root").setLevel(
logging.ERROR
) ## temporary fix for warning message from onnxruntime


def convert_to_onnx(model_name):
@@ -11,6 +18,9 @@ def convert_to_onnx(model_name):
if os.path.exists(output_path):
print_success(f"ONNX model for {model_name} already exists")
return

print_neutral(f"Creating ONNX model for {model_name}")

model = AutoModelForSequenceClassification.from_pretrained(
f"infero/data/models/{sanitize_model_name(model_name)}/"
)
@@ -39,3 +49,28 @@ def convert_to_onnx(model_name):
"attention_mask": symbolic_names,
},
)

print_success(f"ONNX model for {model_name} created")


def convert_to_onnx_q8(model_name):

onnx_model_path = f"infero/data/models/{sanitize_model_name(model_name)}/model.onnx"
quantized_model_path = (
f"infero/data/models/{sanitize_model_name(model_name)}/model_quantized.onnx"
)

if os.path.exists(quantized_model_path):
print_success(f"Quantized model for {model_name} already exists")
return

print_neutral(f"Creating quantized model for {model_name}")
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="Please consider to run pre-processing before quantization",
)
quantize_dynamic(
onnx_model_path, quantized_model_path, weight_type=QuantType.QInt8
)
print_success(f"Quantized model for {model_name} created")
14 changes: 9 additions & 5 deletions infero/main.py
@@ -2,21 +2,25 @@
import subprocess
import typer
from infero.pull.download import check_model
-from infero.convert.onnx import convert_to_onnx
+from infero.convert.onnx import convert_to_onnx, convert_to_onnx_q8
from infero.utils import sanitize_model_name
from infero.pull.models import remove_model

app = typer.Typer(name="infero")


@app.command("run")
-def pull(model: str):
+def pull(model: str, quantize: bool = False):
if check_model(model):
convert_to_onnx(model)
if quantize:
convert_to_onnx_q8(model)
model_path = f"infero/data/models/{sanitize_model_name(model)}"
-        script_dir = os.path.dirname(__file__)
-        server_script_path = os.path.join(script_dir, "serve", "server", "server.py")
-        subprocess.run(["python", server_script_path, model_path])
+        script_dir = os.path.dirname(os.path.abspath(__file__))
+        server_script_path = os.path.join(script_dir, "serve", "server.py")
+        subprocess.run(
+            ["python", server_script_path, model_path, str(quantize).lower()]
+        )
else:
typer.echo("Failed to run model")

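Worth noting: Typer derives the CLI surface from the function signature, so the new `quantize: bool = False` parameter automatically becomes a `--quantize/--no-quantize` flag pair. A minimal sketch of the same pattern (the demo names are illustrative, not part of Infero):

```python
import typer

app = typer.Typer(name="demo")


@app.command("run")
def run(model: str, quantize: bool = False):
    # Typer exposes this as: demo run MODEL [--quantize / --no-quantize]
    typer.echo(f"model={model} quantize={quantize}")


if __name__ == "__main__":
    app()
```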
8 changes: 4 additions & 4 deletions infero/pull/download.py
@@ -4,7 +4,7 @@
import yaml
from tqdm import tqdm

-from infero.utils import print_error, print_success, sanitize_model_name
+from infero.utils import print_error, print_success, sanitize_model_name, print_neutral


def is_supported(model: str):
@@ -32,18 +32,18 @@ def check_model_integrity(model: str):
config_path = f"infero/data/models/{sanitize_model_name(model)}/config.json"

if not os.path.exists(model_path):
print_error(f"Model {model} not found, downloading...")
print_neutral(f"Model {model} not found, downloading...")
download_model(model)

if not os.path.exists(vocab_path) and not os.path.exists(vocab_path_2):
print_error(f"Vocab file for {model} not found, downloading...")
print_neutral(f"Vocab file for {model} not found, downloading...")
vocab_url = f"https://huggingface.co/{model}/resolve/main/vocab.json"
if not download_file(vocab_url, vocab_path):
vocab_url = f"https://huggingface.co/{model}/resolve/main/vocab.txt"
download_file(vocab_url, vocab_path_2)

if not os.path.exists(config_path):
print_error(f"Config file for {model} not found, downloading...")
print_neutral(f"Config file for {model} not found, downloading...")
config_url = f"https://huggingface.co/{model}/raw/main/config.json"
download_file(config_url, config_path)

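`download_file` itself is outside this hunk; the diff only shows that it returns a falsy value on failure, which is what drives the vocab.json → vocab.txt fallback above. A minimal sketch of what such a streaming helper might look like, assuming `requests` alongside the `tqdm` already imported in this file (the contract is inferred, not shown):

```python
import os

import requests
from tqdm import tqdm


def download_file(url: str, path: str) -> bool:
    """Stream url to path with a progress bar; return False on HTTP errors.

    Assumed contract, inferred from the fallback logic in check_model_integrity.
    """
    response = requests.get(url, stream=True, timeout=30)
    if response.status_code != 200:
        return False
    os.makedirs(os.path.dirname(path), exist_ok=True)
    total = int(response.headers.get("content-length", 0))
    with open(path, "wb") as f, tqdm(total=total, unit="B", unit_scale=True) as bar:
        for chunk in response.iter_content(chunk_size=8192):
            f.write(chunk)
            bar.update(len(chunk))
    return True
```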
19 changes: 14 additions & 5 deletions infero/serve/server.py
@@ -3,26 +3,35 @@
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer
from infero.utils import print_success_bold
import onnxruntime


class TextRequest(BaseModel):
text: str


-def load_model(model_path):
+def load_model(model_path, quantize=False):
tokenizer = AutoTokenizer.from_pretrained(model_path)
sess_options = onnxruntime.SessionOptions()
-    session = onnxruntime.InferenceSession(
-        os.path.join(model_path, "model.onnx"), sess_options
-    )
+
+    if not quantize:
+        model_path_onnx = os.path.join(model_path, "model.onnx")
+        session = onnxruntime.InferenceSession(model_path_onnx, sess_options)
+        print_success_bold("Running: " + model_path_onnx)
+    else:
+        model_path_onnx = os.path.join(model_path, "model_quantized.onnx")
+        session = onnxruntime.InferenceSession(model_path_onnx, sess_options)
+        print_success_bold("Running: " + model_path_onnx)
return tokenizer, session


api_server = FastAPI()

if len(sys.argv) < 2:
    raise ValueError("Model path not provided")
model_path = sys.argv[1]
-tokenizer, session = load_model(model_path)
+quantize = sys.argv[2].lower() == "true"
+
+tokenizer, session = load_model(model_path, quantize)


@api_server.post("/inference")
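Once the server is up, the `/inference` endpoint accepts the `TextRequest` schema shown above. A sketch of a client call; the host/port (uvicorn's usual default) and the response shape are assumptions, since the server startup code is outside this diff:

```python
import requests

# "text" is the one field defined on TextRequest; host, port, and the
# response schema are assumptions not shown in the diff.
resp = requests.post(
    "http://127.0.0.1:8000/inference",
    json={"text": "I loved this movie!"},
    timeout=30,
)
resp.raise_for_status()
print(resp.json())
```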
12 changes: 12 additions & 0 deletions infero/utils.py
@@ -9,5 +9,17 @@ def print_error(message: str):
typer.echo(typer.style(message, fg=typer.colors.RED))


def print_success_bold(message: str):
typer.echo(typer.style(message, fg=typer.colors.GREEN, bold=True))


def print_neutral(message: str):
typer.echo(typer.style(message, fg=typer.colors.BLUE))


def sanitize_model_name(model: str):
return model.replace("/", "_")


def unsanitize_model_name(model: str):
return model.replace("_", "/")
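One caveat on the new helper pair: `unsanitize_model_name` only inverts `sanitize_model_name` when the original Hugging Face name contains no underscores, since both `/` and `_` collapse to `_` on the way in. A quick illustration (the second name is made up):

```python
from infero.utils import sanitize_model_name, unsanitize_model_name

name = "distilbert/distilbert-base-uncased"
assert unsanitize_model_name(sanitize_model_name(name)) == name  # clean round-trip

tricky = "org/model_v2"  # hypothetical name containing an underscore
print(unsanitize_model_name(sanitize_model_name(tricky)))  # -> org/model/v2, not org/model_v2
```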