Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: optional dependency issue #190

Merged
merged 25 commits into from
Mar 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6b901f8
optional dependency issue fix started
zahid-syed Mar 9, 2024
80ba002
started fix on llama cpp optional dependency
zahid-syed Mar 11, 2024
363ae33
fixing pytest for PR
zahid-syed Mar 11, 2024
3b0de2d
fixing pytest issues for optional dependency issues
zahid-syed Mar 11, 2024
354307e
should fix issues relating to optional dependencies and mistral
zahid-syed Mar 12, 2024
7ac33fa
reverted poetry lock and pyrproject.toml
zahid-syed Mar 12, 2024
43e662d
Merge branch 'main' into zahid/issue_fixes
zahid-syed Mar 12, 2024
ffc5f88
updated dependencies
zahid-syed Mar 12, 2024
fcf7077
merge
zahid-syed Mar 12, 2024
61d22a7
attempt to fix pytest coverage
zahid-syed Mar 12, 2024
97007ca
unable to fix coverage for both MistralEncoder and MistralLLM
zahid-syed Mar 13, 2024
eb343f6
Update semantic_router/llms/llamacpp.py
zahid-syed Mar 14, 2024
07b7c6a
Update semantic_router/llms/llamacpp.py
zahid-syed Mar 14, 2024
704ab01
Update semantic_router/encoders/mistral.py
zahid-syed Mar 14, 2024
e32c5e3
Update semantic_router/encoders/mistral.py
zahid-syed Mar 14, 2024
a4ae571
Update semantic_router/encoders/mistral.py
zahid-syed Mar 14, 2024
df56086
Update semantic_router/encoders/mistral.py
zahid-syed Mar 14, 2024
bf517a9
Update semantic_router/encoders/mistral.py
zahid-syed Mar 14, 2024
2b69033
Update semantic_router/encoders/mistral.py
zahid-syed Mar 14, 2024
09f0335
Update semantic_router/encoders/mistral.py
zahid-syed Mar 14, 2024
4a12b4a
fix mistral ai llm issue
zahid-syed Mar 14, 2024
bbf717a
fix llama cpp test
zahid-syed Mar 14, 2024
19d54ee
update mistral tests
ashraq1455 Mar 14, 2024
5e6fd2f
Merge branch 'main' into zahid/issue_fixes
ashraq1455 Mar 14, 2024
314005d
add LlamaCppLLM to init
ashraq1455 Mar 14, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
191 changes: 96 additions & 95 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ python = ">=3.9,<3.13"
pydantic = "^2.5.3"
openai = "^1.10.0"
cohere = "^4.32"
mistralai= "^0.0.12"
mistralai= {version = "^0.0.12", optional = true}
numpy = "^1.25.2"
colorlog = "^6.8.0"
pyyaml = "^6.0.1"
Expand All @@ -44,6 +44,7 @@ local = ["torch", "transformers", "llama-cpp-python"]
pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib"]
mistralai = ["mistralai"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
Expand All @@ -67,4 +68,4 @@ build-backend = "poetry.core.masonry.api"
line-length = 88

[tool.mypy]
ignore_missing_imports = true
ignore_missing_imports = true
41 changes: 30 additions & 11 deletions semantic_router/encoders/mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@

import os
from time import sleep
from typing import List, Optional
from typing import List, Optional, Any

from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from pydantic.v1 import PrivateAttr


class MistralEncoder(BaseEncoder):
"""Class to encode text using MistralAI"""

client: Optional[MistralClient]
_client: Any = PrivateAttr()
_mistralai: Any = PrivateAttr()
type: str = "mistral"

def __init__(
Expand All @@ -27,33 +26,53 @@ def __init__(
if name is None:
name = EncoderDefault.MISTRAL.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
self._client, self._mistralai = self._initialize_client(mistralai_api_key)

def _initialize_client(self, api_key):
try:
import mistralai
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)

api_key = api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("Mistral API key not provided")
try:
self.client = MistralClient(api_key=api_key)
client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
return client, mistralai

def __call__(self, docs: List[str]) -> List[List[float]]:
if self.client is None:
if self._client is None:
raise ValueError("Mistral client not initialized")
embeds = None
error_message = ""

# Exponential backoff
for _ in range(3):
try:
embeds = self.client.embeddings(model=self.name, input=docs)
embeds = self._client.embeddings(model=self.name, input=docs)
if embeds.data:
break
except MistralException as e:
except self._mistralai.exceptions.MistralException as e:
sleep(2**_)
error_message = str(e)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data:
if (
not embeds
or not isinstance(
embeds, self._mistralai.models.embeddings.EmbeddingResponse
)
or not embeds.data
):
raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
2 changes: 2 additions & 0 deletions semantic_router/llms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from semantic_router.llms.base import BaseLLM
from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.llms.mistral import MistralAILLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM
Expand All @@ -8,6 +9,7 @@
__all__ = [
"BaseLLM",
"OpenAILLM",
"LlamaCppLLM",
"OpenRouterLLM",
"CohereLLM",
"AzureOpenAILLM",
Expand Down
27 changes: 20 additions & 7 deletions semantic_router/llms/llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,26 +2,27 @@
from pathlib import Path
from typing import Any, Optional

from llama_cpp import Llama, LlamaGrammar

from semantic_router.llms.base import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger

from pydantic.v1 import PrivateAttr


class LlamaCppLLM(BaseLLM):
llm: Llama
llm: Any
temperature: float
max_tokens: Optional[int] = 200
grammar: Optional[LlamaGrammar] = None
grammar: Optional[Any] = None
_llama_cpp: Any = PrivateAttr()

def __init__(
self,
llm: Llama,
llm: Any,
name: str = "llama.cpp",
temperature: float = 0.2,
max_tokens: Optional[int] = 200,
grammar: Optional[LlamaGrammar] = None,
grammar: Optional[Any] = None,
):
super().__init__(
name=name,
Expand All @@ -30,6 +31,18 @@ def __init__(
max_tokens=max_tokens,
grammar=grammar,
)

try:
import llama_cpp
except ImportError:
raise ImportError(
"Please install LlamaCPP to use Llama CPP llm. "
"You can install it with: "
"`pip install 'semantic-router[local]'`"
)
self._llama_cpp = llama_cpp
llm = self._llama_cpp.Llama
grammar = self._llama_cpp.LlamaGrammar
self.llm = llm
self.temperature = temperature
self.max_tokens = max_tokens
Expand Down Expand Up @@ -62,7 +75,7 @@ def _grammar(self):
grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
assert grammar_path.exists(), f"{grammar_path}\ndoes not exist"
try:
self.grammar = LlamaGrammar.from_file(grammar_path)
self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path)
yield
finally:
self.grammar = None
Expand Down
42 changes: 32 additions & 10 deletions semantic_router/llms/mistral.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
import os
from typing import List, Optional
from typing import List, Optional, Any

from mistralai.client import MistralClient

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger

from pydantic.v1 import PrivateAttr


class MistralAILLM(BaseLLM):
client: Optional[MistralClient]
_client: Any = PrivateAttr()
temperature: Optional[float]
max_tokens: Optional[int]
_mistralai: Any = PrivateAttr()

def __init__(
self,
Expand All @@ -24,25 +26,45 @@ def __init__(
if name is None:
name = EncoderDefault.MISTRAL.value["language_model"]
super().__init__(name=name)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
self._client, self._mistralai = self._initialize_client(mistralai_api_key)
self.temperature = temperature
self.max_tokens = max_tokens

def _initialize_client(self, api_key):
try:
import mistralai
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralAI LLM. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
api_key = api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("MistralAI API key cannot be 'None'.")
try:
self.client = MistralClient(api_key=api_key)
client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}"
) from e
self.temperature = temperature
self.max_tokens = max_tokens
return client, mistralai

def __call__(self, messages: List[Message]) -> str:
if self.client is None:
if self._client is None:
raise ValueError("MistralAI client is not initialized.")

chat_messages = [
self._mistralai.models.chat_completion.ChatMessage(
role=m.role, content=m.content
)
for m in messages
]
try:
completion = self.client.chat(
completion = self._client.chat(
model=self.name,
messages=[m.to_mistral() for m in messages],
messages=chat_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
Expand Down
26 changes: 20 additions & 6 deletions tests/unit/encoders/test_mistral.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from semantic_router.encoders import MistralEncoder

from unittest.mock import patch


@pytest.fixture
def mistralai_encoder(mocker):
Expand All @@ -12,9 +14,21 @@ def mistralai_encoder(mocker):


class TestMistralEncoder:
def test_mistral_encoder_import_errors(self):
with patch.dict("sys.modules", {"mistralai": None}):
with pytest.raises(ImportError) as error:
MistralEncoder()

assert (
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`" in str(error.value)
)

def test_mistralai_encoder_init_success(self, mocker):
encoder = MistralEncoder(mistralai_api_key="test_api_key")
assert encoder.client is not None
assert encoder._client is not None
assert encoder._mistralai is not None

def test_mistralai_encoder_init_no_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
Expand All @@ -23,7 +37,7 @@ def test_mistralai_encoder_init_no_api_key(self, mocker):

def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder):
# Set the client to None to simulate an uninitialized client
mistralai_encoder.client = None
mistralai_encoder._client = None
with pytest.raises(ValueError) as e:
mistralai_encoder(["test document"])
assert "Mistral client not initialized" in str(e.value)
Expand Down Expand Up @@ -60,7 +74,7 @@ def test_mistralai_encoder_call_success(self, mistralai_encoder, mocker):

responses = [MistralException("mistralai error"), mock_response]
mocker.patch.object(
mistralai_encoder.client, "embeddings", side_effect=responses
mistralai_encoder._client, "embeddings", side_effect=responses
)
embeddings = mistralai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
Expand All @@ -69,7 +83,7 @@ def test_mistralai_encoder_call_with_retries(self, mistralai_encoder, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
mocker.patch.object(
mistralai_encoder.client,
mistralai_encoder._client,
"embeddings",
side_effect=MistralException("Test error"),
)
Expand All @@ -83,7 +97,7 @@ def test_mistralai_encoder_call_failure_non_mistralai_error(
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
mocker.patch.object(
mistralai_encoder.client,
mistralai_encoder._client,
"embeddings",
side_effect=Exception("Non-MistralException"),
)
Expand Down Expand Up @@ -118,7 +132,7 @@ def test_mistralai_encoder_call_successful_retry(self, mistralai_encoder, mocker

responses = [MistralException("mistralai error"), mock_response]
mocker.patch.object(
mistralai_encoder.client, "embeddings", side_effect=responses
mistralai_encoder._client, "embeddings", side_effect=responses
)
embeddings = mistralai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
13 changes: 13 additions & 0 deletions tests/unit/llms/test_llm_llamacpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.schema import Message

from unittest.mock import patch


@pytest.fixture
def llamacpp_llm(mocker):
Expand All @@ -13,6 +15,17 @@ def llamacpp_llm(mocker):


class TestLlamaCppLLM:
def test_llama_cpp_import_errors(self, llamacpp_llm):
with patch.dict("sys.modules", {"llama_cpp": None}):
with pytest.raises(ImportError) as error:
LlamaCppLLM(llamacpp_llm.llm)

assert (
"Please install LlamaCPP to use Llama CPP llm. "
"You can install it with: "
"`pip install 'semantic-router[local]'`" in str(error.value)
)

def test_llamacpp_llm_init_success(self, llamacpp_llm):
assert llamacpp_llm.name == "llama.cpp"
assert llamacpp_llm.temperature == 0.2
Expand Down
Loading
Loading