diff --git a/llm4ad/tools/llm/llm_api_local.py b/llm4ad/tools/llm/llm_api_local.py new file mode 100644 index 00000000..974fd518 --- /dev/null +++ b/llm4ad/tools/llm/llm_api_local.py @@ -0,0 +1,92 @@ +# This file is part of the LLM4AD project (https://github.com/Optima-CityU/llm4ad). +# Last Revision: 2025/2/16 +# +# ------------------------------- Copyright -------------------------------- +# Copyright (c) 2025 Optima Group. +# +# Permission is granted to use the LLM4AD platform for research purposes. +# All publications, software, or other works that utilize this platform +# or any part of its codebase must acknowledge the use of "LLM4AD" and +# cite the following reference: +# +# Fei Liu, Rui Zhang, Zhuoliang Xie, Rui Sun, Kai Li, Xi Lin, Zhenkun Wang, +# Zhichao Lu, and Qingfu Zhang, "LLM4AD: A Platform for Algorithm Design +# with Large Language Model," arXiv preprint arXiv:2412.17287 (2024). +# +# For inquiries regarding commercial use or licensing, please contact +# http://www.llm4ad.com/contact.html +# -------------------------------------------------------------------------- + +from __future__ import annotations + +import http.client +import json +import time +from typing import Any +import traceback +from ...base import LLM + + +class LocalApi(LLM): + def __init__(self, model, timeout=20, **kwargs): + # Both 'host' and 'base url' parameters are supported + host = kwargs.pop('host', None) + if host is None: + host = kwargs.pop('base_url', None) + if host is None: + raise ValueError("The host or base url parameter must be provided") + + # Both 'key' and 'api key' parameters are supported + key = kwargs.pop('key', None) + if key is None: + key = kwargs.pop('api_key', None) + if key is None: + raise ValueError("The key or api key parameter must be provided") + + super().__init__(**kwargs) + self._host = host + self._key = key + self._model = model + self._timeout = timeout + self._kwargs = kwargs + self._cumulative_error = 0 + + def draw_sample(self, prompt: str | Any, *args, **kwargs) -> str: + if isinstance(prompt, str): + prompt = [{'role': 'user', 'content': prompt.strip()}] + + while True: + try: + conn = http.client.HTTPConnection(self._host, timeout=self._timeout) + payload = json.dumps({ + 'max_tokens': self._kwargs.get('max_tokens', 4096), + 'top_p': self._kwargs.get('top_p', None), + 'temperature': self._kwargs.get('temperature', 1.0), + 'model': self._model, + 'messages': prompt + }) + headers = { + 'Authorization': f'Bearer {self._key}', + 'User-Agent': 'Apifox/1.0.0 (https://apifox.com)', + 'Content-Type': 'application/json' + } + conn.request('POST', '/v1/chat/completions', payload, headers) + res = conn.getresponse() + data = res.read().decode('utf-8') + data = json.loads(data) + # print(data) + response = data['choices'][0]['message']['content'] + if self.debug_mode: + self._cumulative_error = 0 + return response + except Exception as e: + self._cumulative_error += 1 + if self.debug_mode: + if self._cumulative_error == 10: + raise RuntimeError(f'{self.__class__.__name__} error: {traceback.format_exc()}.' + f'You may check your API host and API key.') + else: + print(f'{self.__class__.__name__} error: {traceback.format_exc()}.' + f'You may check your API host and API key.') + time.sleep(2) + continue