Skip to content

Commit

Permalink
Add ollama
Browse files Browse the repository at this point in the history
llm_type is a totally irrelvant name name, will need to change at some
point
  • Loading branch information
marekzp committed Sep 13, 2024
1 parent 8f1c72c commit 4a25532
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 2 deletions.
29 changes: 28 additions & 1 deletion llm-debate/llm_clients.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import json
import os
from abc import ABC, abstractmethod
from typing import Optional
from typing import Any, Dict, Optional

import requests
from anthropic import Anthropic
from openai import OpenAI

Expand Down Expand Up @@ -43,10 +45,35 @@ def get_response(self, prompt: str, model: str) -> str:
return response.content[0].text


class OllamaClient(LLMClient):
def __init__(self) -> None:
self.base_url: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")
self.session: requests.Session = requests.Session()

def get_response(self, prompt: str, model: str) -> str:
url = f"{self.base_url}/api/generate"
data: Dict[str, Any] = {"model": model, "prompt": prompt}
response = self.session.post(url, json=data, stream=True)
response.raise_for_status()

full_response = ""
for line in response.iter_lines():
if line:
json_response = json.loads(line)
if "response" in json_response:
full_response += json_response["response"]
if json_response.get("done", False):
break

return full_response.strip()


def get_llm_client(llm_type: str) -> LLMClient:
if llm_type == "openai":
return OpenAIClient()
elif llm_type == "anthropic":
return AnthropicClient()
elif llm_type == "ollama":
return OllamaClient()
else:
raise ValueError(f"Unsupported LLM type: {llm_type}")
2 changes: 1 addition & 1 deletion llm-debate/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

logger = logging.getLogger(__name__)

LLM_CHOICES: List[str] = ["openai", "anthropic"]
LLM_CHOICES: List[str] = ["openai", "anthropic", "ollama"]
LOG_LEVELS: Dict[str, int] = {
"DEBUG": logging.DEBUG,
"INFO": logging.INFO,
Expand Down

0 comments on commit 4a25532

Please sign in to comment.