Add support for mistralai and fireworks-ai #74

Open · wants to merge 2 commits into main
8 changes: 8 additions & 0 deletions textgrad/engine/__init__.py
@@ -71,5 +71,13 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
        from .groq import ChatGroq
        engine_name = engine_name.replace("groq-", "")
        return ChatGroq(model_string=engine_name, **kwargs)
    elif "fireworks" in engine_name:
        from .fireworks import ChatFireworks
        engine_name = engine_name.replace("fireworks-", "")
        return ChatFireworks(model_string=engine_name, **kwargs)
    elif "mistral" in engine_name:
        from .mistral import ChatMistral
        engine_name = engine_name.replace("mistral-", "")
        return ChatMistral(model_string=engine_name, **kwargs)
    else:
        raise ValueError(f"Engine {engine_name} not supported")
69 changes: 69 additions & 0 deletions textgrad/engine/fireworks.py
@@ -0,0 +1,69 @@
try:
    from fireworks.client import Fireworks
except ImportError:
    raise ImportError("If you'd like to use Fireworks models, please install the fireworks-ai package by running `pip install fireworks-ai`, and add 'FIREWORKS_API_KEY' to your environment variables.")

import os
import platformdirs
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from .base import EngineLM, CachedEngine

class ChatFireworks(EngineLM, CachedEngine):
    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(
        self,
        model_string="accounts/fireworks/models/llama-v3-70b-instruct",
        system_prompt=DEFAULT_SYSTEM_PROMPT):
        """
        :param model_string: Fireworks model identifier to query.
        :param system_prompt: Default system prompt prepended to every request.
        """
        root = platformdirs.user_cache_dir("textgrad")
        cache_path = os.path.join(root, f"cache_fireworks_{model_string}.db")
        super().__init__(cache_path=cache_path)

        self.system_prompt = system_prompt
        if os.getenv("FIREWORKS_API_KEY") is None:
            raise ValueError("Please set the FIREWORKS_API_KEY environment variable if you'd like to use Fireworks AI models.")

        self.client = Fireworks(
            api_key=os.getenv("FIREWORKS_API_KEY"),
        )
        self.model_string = model_string

    def generate(
        self, prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        # Return the cached response if this (system prompt, prompt) pair was seen before.
        cache_or_none = self._check_cache(sys_prompt_arg + prompt)
        if cache_or_none is not None:
            return cache_or_none

        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=[
                {"role": "system", "content": sys_prompt_arg},
                {"role": "user", "content": prompt},
            ],
            stop=None,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        response = response.choices[0].message.content
        self._save_cache(sys_prompt_arg + prompt, response)
        return response

    # Retry with exponential backoff on transient API failures, up to 5 attempts.
    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
    def __call__(self, prompt, **kwargs):
        return self.generate(prompt, **kwargs)
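A minimal usage sketch for the new engine (assumes fireworks-ai is installed and FIREWORKS_API_KEY is set; the prompt text is illustrative):

from textgrad.engine.fireworks import ChatFireworks

engine = ChatFireworks()  # defaults to accounts/fireworks/models/llama-v3-70b-instruct
print(engine("Explain backpropagation in one sentence."))
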
68 changes: 68 additions & 0 deletions textgrad/engine/mistral.py
@@ -0,0 +1,68 @@
try:
    from mistralai.client import MistralClient
except ImportError:
    raise ImportError("If you'd like to use Mistral models, please install the mistralai package by running `pip install mistralai`, and add 'MISTRAL_API_KEY' to your environment variables.")

import os
import platformdirs
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)

from .base import EngineLM, CachedEngine

class ChatMistral(EngineLM, CachedEngine):
    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(
        self,
        model_string="open-mixtral-8x7b",
        system_prompt=DEFAULT_SYSTEM_PROMPT):
        """
        :param model_string: Mistral model identifier to query.
        :param system_prompt: Default system prompt prepended to every request.
        """
        root = platformdirs.user_cache_dir("textgrad")
        cache_path = os.path.join(root, f"cache_mistral_{model_string}.db")
        super().__init__(cache_path=cache_path)

        self.system_prompt = system_prompt
        if os.getenv("MISTRAL_API_KEY") is None:
            raise ValueError("Please set the MISTRAL_API_KEY environment variable if you'd like to use Mistral AI models.")

        self.client = MistralClient(
            api_key=os.getenv("MISTRAL_API_KEY"),
        )
        self.model_string = model_string

    def generate(
        self, prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
    ):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        # Return the cached response if this (system prompt, prompt) pair was seen before.
        cache_or_none = self._check_cache(sys_prompt_arg + prompt)
        if cache_or_none is not None:
            return cache_or_none

        response = self.client.chat(
            model=self.model_string,
            messages=[
                {"role": "system", "content": sys_prompt_arg},
                {"role": "user", "content": prompt},
            ],
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        response = response.choices[0].message.content
        self._save_cache(sys_prompt_arg + prompt, response)
        return response

    # Retry with exponential backoff on transient API failures, up to 5 attempts.
    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
    def __call__(self, prompt, **kwargs):
        return self.generate(prompt, **kwargs)
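
Similarly, a minimal usage sketch for this engine (assumes mistralai is installed and MISTRAL_API_KEY is set; the prompt text is illustrative):

from textgrad.engine.mistral import ChatMistral

engine = ChatMistral()  # defaults to open-mixtral-8x7b
print(engine("Explain backpropagation in one sentence."))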