Merge pull request #190 from aurelio-labs/zahid/issue_fixes
fix: optional dependency issue
jamescalam committed Mar 15, 2024
2 parents 19cf5d2 + 314005d commit e02a4a7
Showing 9 changed files with 233 additions and 135 deletions.
191 changes: 96 additions & 95 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
@@ -19,7 +19,7 @@ python = ">=3.9,<3.13"
pydantic = "^2.5.3"
openai = "^1.10.0"
cohere = "^4.32"
mistralai= "^0.0.12"
mistralai= {version = "^0.0.12", optional = true}
numpy = "^1.25.2"
colorlog = "^6.8.0"
pyyaml = "^6.0.1"
@@ -44,6 +44,7 @@ local = ["torch", "transformers", "llama-cpp-python"]
pinecone = ["pinecone-client"]
vision = ["torch", "torchvision", "transformers", "pillow"]
processing = ["matplotlib"]
mistralai = ["mistralai"]

[tool.poetry.group.dev.dependencies]
ipykernel = "^6.25.0"
@@ -67,4 +68,4 @@ build-backend = "poetry.core.masonry.api"
line-length = 88

[tool.mypy]
ignore_missing_imports = true
ignore_missing_imports = true
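
With mistralai moved behind an extra, a plain install of semantic-router no longer pulls in the Mistral SDK; the modules below guard their imports and point users at the extra instead. The two install paths this implies (the extra name and command are taken from the error messages added in this commit):

pip install semantic-router                 # base install, mistralai not included
pip install 'semantic-router[mistralai]'    # additionally installs mistralai ^0.0.12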
41 changes: 30 additions & 11 deletions semantic_router/encoders/mistral.py
@@ -2,20 +2,19 @@

import os
from time import sleep
from typing import List, Optional
from typing import List, Optional, Any

from mistralai.client import MistralClient
from mistralai.exceptions import MistralException
from mistralai.models.embeddings import EmbeddingResponse

from semantic_router.encoders import BaseEncoder
from semantic_router.utils.defaults import EncoderDefault
from pydantic.v1 import PrivateAttr


class MistralEncoder(BaseEncoder):
"""Class to encode text using MistralAI"""

client: Optional[MistralClient]
_client: Any = PrivateAttr()
_mistralai: Any = PrivateAttr()
type: str = "mistral"

def __init__(
@@ -27,33 +26,53 @@ def __init__(
if name is None:
name = EncoderDefault.MISTRAL.value["embedding_model"]
super().__init__(name=name, score_threshold=score_threshold)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
self._client, self._mistralai = self._initialize_client(mistralai_api_key)

def _initialize_client(self, api_key):
try:
import mistralai
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)

api_key = api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("Mistral API key not provided")
try:
self.client = MistralClient(api_key=api_key)
client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e
return client, mistralai

def __call__(self, docs: List[str]) -> List[List[float]]:
if self.client is None:
if self._client is None:
raise ValueError("Mistral client not initialized")
embeds = None
error_message = ""

# Exponential backoff
for _ in range(3):
try:
embeds = self.client.embeddings(model=self.name, input=docs)
embeds = self._client.embeddings(model=self.name, input=docs)
if embeds.data:
break
except MistralException as e:
except self._mistralai.exceptions.MistralException as e:
sleep(2**_)
error_message = str(e)
except Exception as e:
raise ValueError(f"Unable to connect to MistralAI {e.args}: {e}") from e

if not embeds or not isinstance(embeds, EmbeddingResponse) or not embeds.data:
if (
not embeds
or not isinstance(
embeds, self._mistralai.models.embeddings.EmbeddingResponse
)
or not embeds.data
):
raise ValueError(f"No embeddings returned from MistralAI: {error_message}")
embeddings = [embeds_obj.embedding for embeds_obj in embeds.data]
return embeddings
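
The retry loop in __call__ above is a compact exponential backoff: up to three attempts, sleeping 2**attempt seconds after each MistralException, and raising only if every attempt fails. A standalone sketch of the same pattern, with a hypothetical TransientError standing in for the provider's exception type:

import time
from typing import Callable, List

class TransientError(Exception):
    """Hypothetical stand-in for a provider error such as MistralException."""

def call_with_backoff(call: Callable[[], List[float]], retries: int = 3) -> List[float]:
    """Retry `call`, sleeping 2**attempt seconds after each transient failure."""
    error_message = ""
    for attempt in range(retries):
        try:
            return call()
        except TransientError as e:
            time.sleep(2**attempt)  # waits 1s, then 2s, then 4s
            error_message = str(e)
    raise ValueError(f"No embeddings returned: {error_message}")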
2 changes: 2 additions & 0 deletions semantic_router/llms/__init__.py
@@ -1,5 +1,6 @@
from semantic_router.llms.base import BaseLLM
from semantic_router.llms.cohere import CohereLLM
from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.llms.mistral import MistralAILLM
from semantic_router.llms.openai import OpenAILLM
from semantic_router.llms.openrouter import OpenRouterLLM
@@ -8,6 +9,7 @@
__all__ = [
"BaseLLM",
"OpenAILLM",
"LlamaCppLLM",
"OpenRouterLLM",
"CohereLLM",
"AzureOpenAILLM",
27 changes: 20 additions & 7 deletions semantic_router/llms/llamacpp.py
@@ -2,26 +2,27 @@
from pathlib import Path
from typing import Any, Optional

from llama_cpp import Llama, LlamaGrammar

from semantic_router.llms.base import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.logger import logger

from pydantic.v1 import PrivateAttr


class LlamaCppLLM(BaseLLM):
llm: Llama
llm: Any
temperature: float
max_tokens: Optional[int] = 200
grammar: Optional[LlamaGrammar] = None
grammar: Optional[Any] = None
_llama_cpp: Any = PrivateAttr()

def __init__(
self,
llm: Llama,
llm: Any,
name: str = "llama.cpp",
temperature: float = 0.2,
max_tokens: Optional[int] = 200,
grammar: Optional[LlamaGrammar] = None,
grammar: Optional[Any] = None,
):
super().__init__(
name=name,
@@ -30,6 +31,18 @@ def __init__(
max_tokens=max_tokens,
grammar=grammar,
)

try:
import llama_cpp
except ImportError:
raise ImportError(
"Please install LlamaCPP to use Llama CPP llm. "
"You can install it with: "
"`pip install 'semantic-router[local]'`"
)
self._llama_cpp = llama_cpp
llm = self._llama_cpp.Llama
grammar = self._llama_cpp.LlamaGrammar
self.llm = llm
self.temperature = temperature
self.max_tokens = max_tokens
@@ -62,7 +75,7 @@ def _grammar(self):
grammar_path = Path(__file__).parent.joinpath("grammars", "json.gbnf")
assert grammar_path.exists(), f"{grammar_path}\ndoes not exist"
try:
self.grammar = LlamaGrammar.from_file(grammar_path)
self.grammar = self._llama_cpp.LlamaGrammar.from_file(grammar_path)
yield
finally:
self.grammar = None
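
The _grammar helper keeps its shape after the change, only resolving LlamaGrammar through the stored module handle. The underlying pattern is a context manager that installs state for the duration of one call and restores it even on error; a minimal self-contained sketch:

from contextlib import contextmanager

class GrammarHolder:
    """Sketch: set an attribute for the duration of a block, then reset it."""
    grammar = None

    @contextmanager
    def _grammar(self, value):
        self.grammar = value       # e.g. LlamaGrammar.from_file(grammar_path)
        try:
            yield
        finally:
            self.grammar = None    # restored even if the body raises

holder = GrammarHolder()
with holder._grammar("json-grammar"):
    assert holder.grammar == "json-grammar"
assert holder.grammar is None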
42 changes: 32 additions & 10 deletions semantic_router/llms/mistral.py
@@ -1,18 +1,20 @@
import os
from typing import List, Optional
from typing import List, Optional, Any

from mistralai.client import MistralClient

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger

from pydantic.v1 import PrivateAttr


class MistralAILLM(BaseLLM):
client: Optional[MistralClient]
_client: Any = PrivateAttr()
temperature: Optional[float]
max_tokens: Optional[int]
_mistralai: Any = PrivateAttr()

def __init__(
self,
@@ -24,25 +26,45 @@ def __init__(
if name is None:
name = EncoderDefault.MISTRAL.value["language_model"]
super().__init__(name=name)
api_key = mistralai_api_key or os.getenv("MISTRALAI_API_KEY")
self._client, self._mistralai = self._initialize_client(mistralai_api_key)
self.temperature = temperature
self.max_tokens = max_tokens

def _initialize_client(self, api_key):
try:
import mistralai
from mistralai.client import MistralClient
except ImportError:
raise ImportError(
"Please install MistralAI to use MistralAI LLM. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`"
)
api_key = api_key or os.getenv("MISTRALAI_API_KEY")
if api_key is None:
raise ValueError("MistralAI API key cannot be 'None'.")
try:
self.client = MistralClient(api_key=api_key)
client = MistralClient(api_key=api_key)
except Exception as e:
raise ValueError(
f"MistralAI API client failed to initialize. Error: {e}"
) from e
self.temperature = temperature
self.max_tokens = max_tokens
return client, mistralai

def __call__(self, messages: List[Message]) -> str:
if self.client is None:
if self._client is None:
raise ValueError("MistralAI client is not initialized.")

chat_messages = [
self._mistralai.models.chat_completion.ChatMessage(
role=m.role, content=m.content
)
for m in messages
]
try:
completion = self.client.chat(
completion = self._client.chat(
model=self.name,
messages=[m.to_mistral() for m in messages],
messages=chat_messages,
temperature=self.temperature,
max_tokens=self.max_tokens,
)
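
The switch from a public `client: Optional[MistralClient]` field to `_client: Any = PrivateAttr()` is what makes the deferred import possible: a typed pydantic field needs the SDK class at module-import time, whereas a PrivateAttr of type Any does not, and private attributes are excluded from validation and the model schema. A minimal sketch of the pattern under pydantic's v1 compatibility layer — the stdlib import stands in for the optional SDK, and the object() is a stand-in for client construction:

from typing import Any, Optional
from pydantic.v1 import BaseModel, PrivateAttr

class LazyClientLLM(BaseModel):
    """Sketch: hold an optional-dependency client without a module-level import."""

    name: str = "mistral-small"
    temperature: Optional[float] = 0.7
    _client: Any = PrivateAttr()   # not a field: no validation, no schema entry
    _sdk: Any = PrivateAttr()

    def __init__(self, api_key: Optional[str] = None, **data: Any):
        super().__init__(**data)
        try:
            import statistics as sdk  # stdlib stand-in for the optional SDK
        except ImportError:
            raise ImportError("Please install the optional dependency first.")
        if api_key is None:
            raise ValueError("API key cannot be 'None'.")
        self._sdk = sdk           # module handle for later attribute access
        self._client = object()   # stand-in for MistralClient(api_key=api_key)

llm = LazyClientLLM(api_key="test-key")
assert llm._client is not None and llm._sdk is not None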
26 changes: 20 additions & 6 deletions tests/unit/encoders/test_mistral.py
@@ -4,6 +4,8 @@

from semantic_router.encoders import MistralEncoder

from unittest.mock import patch


@pytest.fixture
def mistralai_encoder(mocker):
@@ -12,9 +14,21 @@ def mistralai_encoder(mocker):


class TestMistralEncoder:
def test_mistral_encoder_import_errors(self):
with patch.dict("sys.modules", {"mistralai": None}):
with pytest.raises(ImportError) as error:
MistralEncoder()

assert (
"Please install MistralAI to use MistralEncoder. "
"You can install it with: "
"`pip install 'semantic-router[mistralai]'`" in str(error.value)
)

def test_mistralai_encoder_init_success(self, mocker):
encoder = MistralEncoder(mistralai_api_key="test_api_key")
assert encoder.client is not None
assert encoder._client is not None
assert encoder._mistralai is not None

def test_mistralai_encoder_init_no_api_key(self, mocker):
mocker.patch("os.getenv", return_value=None)
@@ -23,7 +37,7 @@ def test_mistralai_encoder_init_no_api_key(self, mocker):

def test_mistralai_encoder_call_uninitialized_client(self, mistralai_encoder):
# Set the client to None to simulate an uninitialized client
mistralai_encoder.client = None
mistralai_encoder._client = None
with pytest.raises(ValueError) as e:
mistralai_encoder(["test document"])
assert "Mistral client not initialized" in str(e.value)
@@ -60,7 +74,7 @@ def test_mistralai_encoder_call_success(self, mistralai_encoder, mocker):

responses = [MistralException("mistralai error"), mock_response]
mocker.patch.object(
mistralai_encoder.client, "embeddings", side_effect=responses
mistralai_encoder._client, "embeddings", side_effect=responses
)
embeddings = mistralai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
@@ -69,7 +83,7 @@ def test_mistralai_encoder_call_with_retries(self, mistralai_encoder, mocker):
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
mocker.patch.object(
mistralai_encoder.client,
mistralai_encoder._client,
"embeddings",
side_effect=MistralException("Test error"),
)
@@ -83,7 +97,7 @@ def test_mistralai_encoder_call_failure_non_mistralai_error(
mocker.patch("os.getenv", return_value="fake-api-key")
mocker.patch("time.sleep", return_value=None) # To speed up the test
mocker.patch.object(
mistralai_encoder.client,
mistralai_encoder._client,
"embeddings",
side_effect=Exception("Non-MistralException"),
)
@@ -118,7 +132,7 @@ def test_mistralai_encoder_call_successful_retry(self, mistralai_encoder, mocker

responses = [MistralException("mistralai error"), mock_response]
mocker.patch.object(
mistralai_encoder.client, "embeddings", side_effect=responses
mistralai_encoder._client, "embeddings", side_effect=responses
)
embeddings = mistralai_encoder(["test document"])
assert embeddings == [[0.1, 0.2]]
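
The new import-error test relies on a documented CPython behavior: if sys.modules maps a module name to None, any import of that name raises ImportError, even when the real package is installed. patch.dict restores sys.modules on exit, so the rest of the suite is unaffected. A standalone sketch, where init_encoder is a hypothetical stand-in for the guarded import in _initialize_client:

from unittest.mock import patch

import pytest

def init_encoder():
    """Hypothetical stand-in for MistralEncoder's guarded import."""
    try:
        import mistralai  # noqa: F401
    except ImportError:
        raise ImportError(
            "Please install MistralAI to use MistralEncoder. "
            "You can install it with: "
            "`pip install 'semantic-router[mistralai]'`"
        )

def test_import_error_when_dependency_missing():
    # None in sys.modules short-circuits the import system for this name.
    with patch.dict("sys.modules", {"mistralai": None}):
        with pytest.raises(ImportError) as error:
            init_encoder()
    assert "Please install MistralAI" in str(error.value)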
13 changes: 13 additions & 0 deletions tests/unit/llms/test_llm_llamacpp.py
@@ -4,6 +4,8 @@
from semantic_router.llms.llamacpp import LlamaCppLLM
from semantic_router.schema import Message

from unittest.mock import patch


@pytest.fixture
def llamacpp_llm(mocker):
@@ -13,6 +15,17 @@ def llamacpp_llm(mocker):


class TestLlamaCppLLM:
def test_llama_cpp_import_errors(self, llamacpp_llm):
with patch.dict("sys.modules", {"llama_cpp": None}):
with pytest.raises(ImportError) as error:
LlamaCppLLM(llamacpp_llm.llm)

assert (
"Please install LlamaCPP to use Llama CPP llm. "
"You can install it with: "
"`pip install 'semantic-router[local]'`" in str(error.value)
)

def test_llamacpp_llm_init_success(self, llamacpp_llm):
assert llamacpp_llm.name == "llama.cpp"
assert llamacpp_llm.temperature == 0.2