Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 45 additions & 2 deletions llm4ad/tools/llm/llm_api_https.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,11 @@ class HttpsApi(LLM):
def __init__(self, host, key, model, timeout=60, **kwargs):
"""Https API
Args:
host : host name. please note that the host name does not include 'https://'
host : host name. please note that the host name does not include 'https://'.
Note: use "localhost:11434" for ollama's default settings
key : API key.
model : LLM model name.
model : LLM model name.
Note: pass local ollama model in the form like "ollama:deepseek-r1:14b"
timeout: API timeout.
"""
super().__init__(**kwargs)
Expand All @@ -45,6 +47,47 @@ def __init__(self, host, key, model, timeout=60, **kwargs):
self._cumulative_error = 0

def draw_sample(self, prompt: str | Any, *args, **kwargs) -> str:
if self._model.startswith('ollama:'):
return self.local_ollama_draw_sample(prompt)
else:
return self.remote_draw_sample(prompt)

def local_ollama_draw_sample(self, prompt: str | Any, *args, **kwargs) -> str:
if isinstance(prompt, str):
prompt = prompt.strip()
ollama_model = self._model.replace('ollama:','')
while True:
try:
url = f"http://{self._host}/api/generate"
headers = {"Content-Type": "application/json"}
data = {
"model": ollama_model,
"prompt": prompt,
"stream": False
}
res = requests.post(url, headers=headers, json=data)
resjson = res.json()
if 'response' in resjson:
response = resjson['response']
return response
else:
print(res.status_code)
print(f"[OLLAMA ERROR] Requesting {url} with data: {data}\n")
print(f"[OLLAMA RESPONSE] {resjson}\n")
return {"ollama error": resjson.get("error", "Unknown error")}
except Exception as e:
self._cumulative_error += 1
if self.debug_mode:
if self._cumulative_error == 10:
raise RuntimeError(f'{self.__class__.__name__} error: {traceback.format_exc()}.'
f'You may check the Ollama ({ollama_model}) service.')
else:
print(f'{self.__class__.__name__} error: {traceback.format_exc()}.'
f'You may check the Ollama ({ollama_model}) service.')
time.sleep(2)
continue

def remote_draw_sample(self, prompt: str | Any, *args, **kwargs) -> str:
if isinstance(prompt, str):
prompt = [{'role': 'user', 'content': prompt.strip()}]

Expand Down