Skip to content

Commit c94a6a7

Browse files
feat: Added Ollama engine via OpenAI api (#51)
* feat: Added Ollama engine via OpenAI api * fix: Added PR remarks and test
1 parent a15e7b6 commit c94a6a7

File tree

4 files changed

+61
-5
lines changed

4 files changed

+61
-5
lines changed

Diff for: README.md

+14
Original file line numberDiff line numberDiff line change
@@ -384,6 +384,20 @@ We are grateful for all the help we got from our contributors!
384384
<br />
385385
<sub><b>tboen1</b></sub>
386386
</a>
387+
</td>
388+
<td align="center">
389+
<a href="https://github.com/nihalnayak">
390+
<img src="https://avatars.githubusercontent.com/u/5679782?v=4" width="100;" alt="nihalnayak"/>
391+
<br />
392+
<sub><b>Nihal Nayak</b></sub>
393+
</a>
394+
</td>
395+
<td align="center">
396+
<a href="https://github.com/AtakanTekparmak">
397+
<img src="https://avatars.githubusercontent.com/u/59488384?v=4" width="100;" alt="AtakanTekparmak"/>
398+
<br />
399+
<sub><b>Atakan Tekparmak</b></sub>
400+
</a>
387401
</td>
388402
</tr>
389403
<tbody>

Diff for: tests/test_engines.py

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
import pytest
2+
3+
from textgrad.engine import get_engine
4+
5+
def test_ollama_engine():
6+
# Declare test constants
7+
OLLAMA_BASE_URL = 'http://localhost:11434/v1'
8+
MODEL_STRING = "test-model-string"
9+
10+
# Initialise the engine
11+
engine = get_engine("ollama-" + MODEL_STRING)
12+
13+
assert engine
14+
assert engine.model_string == MODEL_STRING
15+
assert engine.base_url == OLLAMA_BASE_URL

Diff for: textgrad/engine/__init__.py

+8
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,14 @@ def get_engine(engine_name: str, **kwargs) -> EngineLM:
5555
elif engine_name in ["command-r-plus", "command-r", "command", "command-light"]:
5656
from .cohere import ChatCohere
5757
return ChatCohere(model_string=engine_name, **kwargs)
58+
elif engine_name.startswith("ollama"):
59+
from .openai import ChatOpenAI, OLLAMA_BASE_URL
60+
model_string = engine_name.replace("ollama-", "")
61+
return ChatOpenAI(
62+
model_string=model_string,
63+
base_url=OLLAMA_BASE_URL,
64+
**kwargs
65+
)
5866
elif "vllm" in engine_name:
5967
from .vllm import ChatVLLM
6068
engine_name = engine_name.replace("vllm-", "")

Diff for: textgrad/engine/openai.py

+24-5
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,12 @@
1717
from .base import EngineLM, CachedEngine
1818
from .engine_utils import get_image_type_from_bytes
1919

20+
# Default base URL for OLLAMA
21+
OLLAMA_BASE_URL = 'http://localhost:11434/v1'
22+
23+
# Check if the user set the OLLAMA_BASE_URL environment variable
24+
if os.getenv("OLLAMA_BASE_URL"):
25+
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL")
2026

2127
class ChatOpenAI(EngineLM, CachedEngine):
2228
DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
@@ -26,23 +32,36 @@ def __init__(
2632
model_string: str="gpt-3.5-turbo-0613",
2733
system_prompt: str=DEFAULT_SYSTEM_PROMPT,
2834
is_multimodal: bool=False,
35+
base_url: str=None,
2936
**kwargs):
3037
"""
3138
:param model_string:
3239
:param system_prompt:
40+
:param base_url: Used to support Ollama
3341
"""
3442
root = platformdirs.user_cache_dir("textgrad")
3543
cache_path = os.path.join(root, f"cache_openai_{model_string}.db")
3644

3745
super().__init__(cache_path=cache_path)
3846

3947
self.system_prompt = system_prompt
40-
if os.getenv("OPENAI_API_KEY") is None:
41-
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")
48+
self.base_url = base_url
4249

43-
self.client = OpenAI(
44-
api_key=os.getenv("OPENAI_API_KEY"),
45-
)
50+
if not base_url:
51+
if os.getenv("OPENAI_API_KEY") is None:
52+
raise ValueError("Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")
53+
54+
self.client = OpenAI(
55+
api_key=os.getenv("OPENAI_API_KEY")
56+
)
57+
elif base_url and base_url == OLLAMA_BASE_URL:
58+
self.client = OpenAI(
59+
base_url=base_url,
60+
api_key="ollama"
61+
)
62+
else:
63+
raise ValueError("Invalid base URL provided. Please use the default OLLAMA base URL or None.")
64+
4665
self.model_string = model_string
4766
self.is_multimodal = is_multimodal
4867

0 commit comments

Comments
 (0)