Skip to content

Commit

Permalink
Merge pull request #2 from norsulabs/dev/test
Browse files Browse the repository at this point in the history
pathing refactor and list memory
  • Loading branch information
ameen-91 authored Nov 29, 2024
2 parents 174a6ca + 9ec6d3c commit 0a6e5b5
Show file tree
Hide file tree
Showing 7 changed files with 508 additions and 31 deletions.
423 changes: 423 additions & 0 deletions file.json

Large diffs are not rendered by default.

45 changes: 36 additions & 9 deletions infero/main.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,27 @@
import os
import subprocess
import typer
from infero.pull.download import check_model
from infero.pull.download import check_model, pull_model
from tabulate import tabulate
from infero.convert.onnx import convert_to_onnx, convert_to_onnx_q8
from infero.utils import (
sanitize_model_name,
get_models_dir,
get_package_dir,
print_neutral,
print_success_bold,
print_error,
get_memory_usage,
)
from infero.pull.models import remove_model

app = typer.Typer(name="infero")


@app.command("run")
def pull(model: str, quantize: bool = False):
print
def run(model: str, quantize: bool = False):
print_neutral(f"{get_memory_usage()/ 1024 / 1024} MB")
if check_model(model):
convert_to_onnx(model)
if quantize:
convert_to_onnx_q8(model)
model_path = os.path.join(get_models_dir(), sanitize_model_name(model))
package_dir = get_package_dir()
server_script_path = os.path.join(package_dir, "serve", "server.py")
Expand All @@ -31,14 +32,40 @@ def pull(model: str, quantize: bool = False):
typer.echo("Failed to run model")


@app.command("pull")
def pull(model: str, quantize: bool = False):
if pull_model(model):
convert_to_onnx(model)
if quantize:
convert_to_onnx_q8(model)
print_success_bold(f"Model {model} pulled successfully")
else:
print_error("Failed to get model")


@app.command("list")
def list_models():
if not os.path.exists(get_models_dir()):
print_neutral("No models found")
return
models = os.path.join(get_models_dir(), sanitize_model_name)
for model in os.listdir(models):
typer.echo(model)
models_dir = get_models_dir()
models = []
for model in os.listdir(models_dir):
quantized = (
f"{os.path.getsize(os.path.join(models_dir, model, 'model_quantized.onnx')) / 1024 / 1024:.2f}"
if os.path.exists(os.path.join(models_dir, model, "model_quantized.onnx"))
else ""
)
size = (
os.path.getsize(os.path.join(models_dir, model, "pytorch_model.bin"))
/ 1024
/ 1024
)
models.append([model, size, quantized])
table = tabulate(
models, headers=["Name", "Size (MB)", "Quantized (MB)"], tablefmt="grid"
)
print_neutral(table)


@app.command("remove")
Expand Down
22 changes: 21 additions & 1 deletion infero/pull/download.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def check_model_integrity(model: str):

if not os.path.exists(model_path):
print_neutral(f"Model {model} not found, downloading...")
download_model(model)
return False

if not os.path.exists(vocab_path) and not os.path.exists(vocab_path_2):
print_neutral(f"Vocab file for {model} not found, downloading...")
Expand Down Expand Up @@ -122,6 +122,26 @@ def download_model(model: str):


def check_model(model: str):

if is_supported(model):
print_success(f"Model {model} is supported")
else:
print_error("Model architecture not supported")

if os.path.exists(
os.path.join(get_package_dir(), f"data/models/{sanitize_model_name(model)}")
):
print_success(f"Model {model} already exists")
chk = check_model_integrity(model)
if chk is True:
return True
else:
print_error(f"Model {model} not found, please run 'infero pull {model}'")
return False


def pull_model(model: str):

if is_supported(model):
print_success(f"Model {model} is supported")
else:
Expand Down
6 changes: 2 additions & 4 deletions infero/pull/models.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import os
from infero.utils import sanitize_model_name, print_success, print_error
from infero.utils import sanitize_model_name, print_success, print_error, get_models_dir


def remove_model(model):
model_path = os.path.join(
os.getcwd(), "infero/data/models", sanitize_model_name(model)
)
model_path = os.path.join(get_models_dir, sanitize_model_name(model))
if os.path.exists(model_path):
os.rmdir(model_path)
print_success(f"Model {model} removed")
Expand Down
16 changes: 15 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ transformers = "^4.46.3"
fastapi = {extras = ["standard"], version = "^0.115.5"}
torch = "^2.5.1"
psutil = "^6.1.0"
tabulate = "^0.9.0"



Expand Down
26 changes: 10 additions & 16 deletions tests/test_download.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,15 @@
import os
import pytest
from infero.pull.download import check_model_integrity
from unittest.mock import patch
from infero.pull.download import download_file


@pytest.fixture
def setup_model_directory():
model_name = "test-model"
model_dir = f"infero/data/models/{model_name}"
os.makedirs(model_dir, exist_ok=True)
yield model_name
# Cleanup after test
if os.path.exists(model_dir):
for file in os.listdir(model_dir):
os.remove(os.path.join(model_dir, file))
os.rmdir(model_dir)
@patch("infero.pull.download.download_file")
def test_download_file_success(mock_download_file):
mock_download_file.return_value = True
url = "https://huggingface.co/cardiffnlp/twitter-roberta-base-sentiment/blob/main/vocab.json"
result = download_file(url, "file.json")
assert result


def test_check_model_integrity(setup_model_directory):
model_name = setup_model_directory
assert check_model_integrity(model_name) is True
if __name__ == "__main__":
pytest.main()

0 comments on commit 0a6e5b5

Please sign in to comment.