From 4a255326498c622fc09def8fccad7c2f0adf2b6d Mon Sep 17 00:00:00 2001 From: Marek Zaremba-Pike Date: Fri, 13 Sep 2024 19:56:59 +0100 Subject: [PATCH] Add ollama llm_type is a totally irrelvant name name, will need to change at some point --- llm-debate/llm_clients.py | 29 ++++++++++++++++++++++++++++- llm-debate/main.py | 2 +- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/llm-debate/llm_clients.py b/llm-debate/llm_clients.py index 9eeef9d..40a694b 100644 --- a/llm-debate/llm_clients.py +++ b/llm-debate/llm_clients.py @@ -1,7 +1,9 @@ +import json import os from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Dict, Optional +import requests from anthropic import Anthropic from openai import OpenAI @@ -43,10 +45,35 @@ def get_response(self, prompt: str, model: str) -> str: return response.content[0].text +class OllamaClient(LLMClient): + def __init__(self) -> None: + self.base_url: str = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434") + self.session: requests.Session = requests.Session() + + def get_response(self, prompt: str, model: str) -> str: + url = f"{self.base_url}/api/generate" + data: Dict[str, Any] = {"model": model, "prompt": prompt} + response = self.session.post(url, json=data, stream=True) + response.raise_for_status() + + full_response = "" + for line in response.iter_lines(): + if line: + json_response = json.loads(line) + if "response" in json_response: + full_response += json_response["response"] + if json_response.get("done", False): + break + + return full_response.strip() + + def get_llm_client(llm_type: str) -> LLMClient: if llm_type == "openai": return OpenAIClient() elif llm_type == "anthropic": return AnthropicClient() + elif llm_type == "ollama": + return OllamaClient() else: raise ValueError(f"Unsupported LLM type: {llm_type}") diff --git a/llm-debate/main.py b/llm-debate/main.py index fdc5187..e109464 100644 --- a/llm-debate/main.py +++ b/llm-debate/main.py @@ -11,7 +11,7 @@ logger = logging.getLogger(__name__) -LLM_CHOICES: List[str] = ["openai", "anthropic"] +LLM_CHOICES: List[str] = ["openai", "anthropic", "ollama"] LOG_LEVELS: Dict[str, int] = { "DEBUG": logging.DEBUG, "INFO": logging.INFO,