diff --git a/llm4ad/tools/llm/llm_api_https.py b/llm4ad/tools/llm/llm_api_https.py index 9d1f8db9..98bd3f13 100644 --- a/llm4ad/tools/llm/llm_api_https.py +++ b/llm4ad/tools/llm/llm_api_https.py @@ -31,9 +31,11 @@ class HttpsApi(LLM): def __init__(self, host, key, model, timeout=60, **kwargs): """Https API Args: - host : host name. please note that the host name does not include 'https://' + host : host name. please note that the host name does not include 'https://'. + Note: use "localhost:11434" for ollama's default settings key : API key. - model : LLM model name. + model : LLM model name. + Note: pass local ollama model in the form like "ollama:deepseek-r1:14b" timeout: API timeout. """ super().__init__(**kwargs) @@ -45,6 +47,47 @@ def __init__(self, host, key, model, timeout=60, **kwargs): self._cumulative_error = 0 def draw_sample(self, prompt: str | Any, *args, **kwargs) -> str: + if self._model.startswith('ollama:'): + return self.local_ollama_draw_sample(prompt) + else: + return self.remote_draw_sample(prompt) + + def local_ollama_draw_sample(self, prompt: str | Any, *args, **kwargs) -> str: + if isinstance(prompt, str): + prompt = prompt.strip() + ollama_model = self._model.replace('ollama:','') + while True: + try: + url = f"http://{self._host}/api/generate" + headers = {"Content-Type": "application/json"} + data = { + "model": ollama_model, + "prompt": prompt, + "stream": False + } + res = requests.post(url, headers=headers, json=data) + resjson = res.json() + if 'response' in resjson: + response = resjson['response'] + return response + else: + print(res.status_code) + print(f"[OLLAMA ERROR] Requesting {url} with data: {data}\n") + print(f"[OLLAMA RESPONSE] {resjson}\n") + return {"ollama error": resjson.get("error", "Unknown error")} + except Exception as e: + self._cumulative_error += 1 + if self.debug_mode: + if self._cumulative_error == 10: + raise RuntimeError(f'{self.__class__.__name__} error: {traceback.format_exc()}.' + f'You may check the Ollama ({ollama_model}) service.') + else: + print(f'{self.__class__.__name__} error: {traceback.format_exc()}.' + f'You may check the Ollama ({ollama_model}) service.') + time.sleep(2) + continue + + def remote_draw_sample(self, prompt: str | Any, *args, **kwargs) -> str: if isinstance(prompt, str): prompt = [{'role': 'user', 'content': prompt.strip()}]