Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
92 changes: 92 additions & 0 deletions llm4ad/tools/llm/llm_api_local.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
# This file is part of the LLM4AD project (https://github.com/Optima-CityU/llm4ad).
# Last Revision: 2025/2/16
#
# ------------------------------- Copyright --------------------------------
# Copyright (c) 2025 Optima Group.
#
# Permission is granted to use the LLM4AD platform for research purposes.
# All publications, software, or other works that utilize this platform
# or any part of its codebase must acknowledge the use of "LLM4AD" and
# cite the following reference:
#
# Fei Liu, Rui Zhang, Zhuoliang Xie, Rui Sun, Kai Li, Xi Lin, Zhenkun Wang,
# Zhichao Lu, and Qingfu Zhang, "LLM4AD: A Platform for Algorithm Design
# with Large Language Model," arXiv preprint arXiv:2412.17287 (2024).
#
# For inquiries regarding commercial use or licensing, please contact
# http://www.llm4ad.com/contact.html
# --------------------------------------------------------------------------

from __future__ import annotations

import http.client
import json
import time
from typing import Any
import traceback
from ...base import LLM


class LocalApi(LLM):
def __init__(self, model, timeout=20, **kwargs):
# Both 'host' and 'base url' parameters are supported
host = kwargs.pop('host', None)
if host is None:
host = kwargs.pop('base_url', None)
if host is None:
raise ValueError("The host or base url parameter must be provided")

# Both 'key' and 'api key' parameters are supported
key = kwargs.pop('key', None)
if key is None:
key = kwargs.pop('api_key', None)
if key is None:
raise ValueError("The key or api key parameter must be provided")

super().__init__(**kwargs)
self._host = host
self._key = key
self._model = model
self._timeout = timeout
self._kwargs = kwargs
self._cumulative_error = 0

def draw_sample(self, prompt: str | Any, *args, **kwargs) -> str:
if isinstance(prompt, str):
prompt = [{'role': 'user', 'content': prompt.strip()}]

while True:
try:
conn = http.client.HTTPConnection(self._host, timeout=self._timeout)
payload = json.dumps({
'max_tokens': self._kwargs.get('max_tokens', 4096),
'top_p': self._kwargs.get('top_p', None),
'temperature': self._kwargs.get('temperature', 1.0),
'model': self._model,
'messages': prompt
})
headers = {
'Authorization': f'Bearer {self._key}',
'User-Agent': 'Apifox/1.0.0 (https://apifox.com)',
'Content-Type': 'application/json'
}
conn.request('POST', '/v1/chat/completions', payload, headers)
res = conn.getresponse()
data = res.read().decode('utf-8')
data = json.loads(data)
# print(data)
response = data['choices'][0]['message']['content']
if self.debug_mode:
self._cumulative_error = 0
return response
except Exception as e:
self._cumulative_error += 1
if self.debug_mode:
if self._cumulative_error == 10:
raise RuntimeError(f'{self.__class__.__name__} error: {traceback.format_exc()}.'
f'You may check your API host and API key.')
else:
print(f'{self.__class__.__name__} error: {traceback.format_exc()}.'
f'You may check your API host and API key.')
time.sleep(2)
continue