[FEATURE] Add support for Mistral via the mistral Python SDK (#374)
* add mistral support

* linting

* fix typo

* add tests

* add examples notebook

* linting

* fix langchain typo in pyproject.toml (updated to 0.2.14)

* fix mistralai import and `undo_override` function

* add mistral to readme

* fix typo

* modified self.llm_event to llm_event

* refactoring

* black

* rename examples directory

* fix merge

* init merge

* updated model name so that tokencost will recognize this as a mistral model

* black lint

---------

Co-authored-by: reibs <[email protected]>
the-praxs and areibman authored Nov 6, 2024
1 parent 1b8f473 commit c0fa6bb
Showing 6 changed files with 692 additions and 1 deletion.
143 changes: 143 additions & 0 deletions README.md
@@ -370,6 +370,149 @@ async def main() -> None:
    print(message.content)


await main()
```
</details>

### Mistral 〽️

Track agents built with the Mistral Python SDK (>=1.0.1).

- [AgentOps integration example](./examples/mistral/mistral_example.ipynb)
- [Official Mistral documentation](https://docs.mistral.ai)

<details>
<summary>Installation</summary>

```bash
pip install mistralai
```

Sync

```python python
from mistralai import Mistral
import agentops
import os

# Beginning of program's code (i.e. main.py, __init__.py)
agentops.init(<INSERT YOUR API KEY HERE>)

client = Mistral(
    # This is the default and can be omitted
    api_key=os.environ.get("MISTRAL_API_KEY"),
)

message = client.chat.complete(
    messages=[
        {
            "role": "user",
            "content": "Tell me a cool fact about AgentOps",
        }
    ],
    model="open-mistral-nemo",
)
print(message.choices[0].message.content)

agentops.end_session('Success')
```

Streaming

```python python
from mistralai import Mistral
import agentops
import os

# Beginning of program's code (i.e. main.py, __init__.py)
agentops.init(<INSERT YOUR API KEY HERE>)

client = Mistral(
    # This is the default and can be omitted
    api_key=os.environ.get("MISTRAL_API_KEY"),
)

message = client.chat.stream(
    messages=[
        {
            "role": "user",
            "content": "Tell me something cool about streaming agents",
        }
    ],
    model="open-mistral-nemo",
)

response = ""
for event in message:
    if event.data.choices[0].finish_reason == "stop":
        print("\n")
        print(response)
        print("\n")
    else:
        response += event.data.choices[0].delta.content

agentops.end_session('Success')
```

Async

```python python
import asyncio
from mistralai import Mistral
import os

client = Mistral(
    # This is the default and can be omitted
    api_key=os.environ.get("MISTRAL_API_KEY"),
)


async def main() -> None:
    message = await client.chat.complete_async(
        messages=[
            {
                "role": "user",
                "content": "Tell me something interesting about async agents",
            }
        ],
        model="open-mistral-nemo",
    )
    print(message.choices[0].message.content)


await main()
```

Async Streaming

```python python
import asyncio
from mistralai import Mistral
import os

client = Mistral(
    # This is the default and can be omitted
    api_key=os.environ.get("MISTRAL_API_KEY"),
)


async def main() -> None:
    message = await client.chat.stream_async(
        messages=[
            {
                "role": "user",
                "content": "Tell me something interesting about async streaming agents",
            }
        ],
        model="open-mistral-nemo",
    )

    response = ""
    async for event in message:
        if event.data.choices[0].finish_reason == "stop":
            print("\n")
            print(response)
            print("\n")
        else:
            response += event.data.choices[0].delta.content


await main()
```
</details>
16 changes: 16 additions & 0 deletions agentops/llms/__init__.py
@@ -13,6 +13,7 @@
from .ollama import OllamaProvider
from .openai import OpenAiProvider
from .anthropic import AnthropicProvider
from .mistral import MistralProvider
from .ai21 import AI21Provider

original_func = {}
@@ -40,6 +41,9 @@ class LlmTracker:
        "anthropic": {
            "0.32.0": ("completions.create",),
        },
        "mistralai": {
            "1.0.1": ("chat.complete", "chat.stream"),
        },
        "ai21": {
            "2.0.0": (
                "chat.completions.create",
@@ -142,6 +146,17 @@ def override_api(self):
                    f"Only Anthropic>=0.32.0 supported. v{module_version} found."
                )

        if api == "mistralai":
            module_version = version(api)

            if Version(module_version) >= parse("1.0.1"):
                provider = MistralProvider(self.client)
                provider.override()
            else:
                logger.warning(
                    f"Only MistralAI>=1.0.1 supported. v{module_version} found."
                )

        if api == "ai21":
            module_version = version(api)

@@ -165,4 +180,5 @@ def stop_instrumenting(self):
        LiteLLMProvider(self.client).undo_override()
        OllamaProvider(self.client).undo_override()
        AnthropicProvider(self.client).undo_override()
        MistralProvider(self.client).undo_override()
        AI21Provider(self.client).undo_override()
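
The `MistralProvider` itself lives in `agentops/llms/mistral.py` and is not shown here. Its `override()`/`undo_override()` calls above suggest the usual monkey-patching pattern: keep a reference to the original bound method, swap in a wrapper that records the call, and put the original back when instrumentation stops. Below is a minimal, self-contained sketch of that pattern; `FakeClient`, `SketchProvider`, and `record_event` are illustrative stand-ins, not the real AgentOps or Mistral APIs.

```python
import functools
import time


class FakeChat:
    """Stand-in for a chat interface; not the real mistralai SDK."""

    def complete(self, messages, model):
        return f"echo: {messages[-1]['content']} ({model})"


class FakeClient:
    def __init__(self):
        self.chat = FakeChat()


def record_event(name, duration, result):
    # Hypothetical sink; AgentOps would build and ship an LLMEvent here.
    print(f"[{name}] {duration:.4f}s -> {result!r}")


class SketchProvider:
    """Illustrative only; not the actual agentops MistralProvider."""

    def __init__(self, client):
        self.client = client
        self._original_complete = None

    def override(self):
        # Keep a handle to the original bound method so it can be restored.
        self._original_complete = self.client.chat.complete

        @functools.wraps(self._original_complete)
        def patched(*args, **kwargs):
            start = time.monotonic()
            result = self._original_complete(*args, **kwargs)
            record_event("chat.complete", time.monotonic() - start, result)
            return result

        self.client.chat.complete = patched

    def undo_override(self):
        # Restore the unpatched method, mirroring stop_instrumenting() above.
        if self._original_complete is not None:
            self.client.chat.complete = self._original_complete
            self._original_complete = None


client = FakeClient()
provider = SketchProvider(client)
provider.override()
client.chat.complete(messages=[{"role": "user", "content": "hi"}], model="open-mistral-nemo")
provider.undo_override()
```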
