From fd0c31960e2e781e07c65012226a6e728260a29d Mon Sep 17 00:00:00 2001
From: semio
Date: Wed, 16 Aug 2023 17:27:25 +0800
Subject: [PATCH] separate API client and Langchain wrapper.

---
 automation-api/lib/llms/iflytek/__init__.py  |  2 +-
 automation-api/lib/llms/iflytek/spark_api.py | 84 +-------------------
 automation-api/lib/llms/utils.py             |  2 +-
 3 files changed, 3 insertions(+), 85 deletions(-)

diff --git a/automation-api/lib/llms/iflytek/__init__.py b/automation-api/lib/llms/iflytek/__init__.py
index 3822c82..31e9c65 100644
--- a/automation-api/lib/llms/iflytek/__init__.py
+++ b/automation-api/lib/llms/iflytek/__init__.py
@@ -1 +1 @@
-from .spark_api import Spark  # noqa: F401
+from .spark_api import SparkClient  # noqa: F401
diff --git a/automation-api/lib/llms/iflytek/spark_api.py b/automation-api/lib/llms/iflytek/spark_api.py
index a8b3ec5..7d506aa 100644
--- a/automation-api/lib/llms/iflytek/spark_api.py
+++ b/automation-api/lib/llms/iflytek/spark_api.py
@@ -8,16 +8,11 @@
 import json
 from datetime import datetime
 from time import mktime
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, Optional
 from urllib.parse import urlencode, urlparse
 from wsgiref.handlers import format_date_time
 
 import websocket
-from langchain.callbacks.manager import CallbackManagerForLLMRun
-from langchain.llms.base import LLM
-from pydantic import root_validator
-
-from lib.config import read_config
 
 
 class Ws_Param(object):
@@ -144,80 +139,3 @@ def generate_text(self, content, **kwargs) -> Dict[str, Any]:
     def chat(self):
         # TODO: add chat function, which accepts some message history and generate new reply.
         raise NotImplementedError()
-
-
-def get_from_dict_or_env(data, key, env_key):
-    if key in data and data[key]:
-        return data[key]
-    else:
-        config = read_config()
-        if env_key in config and config[env_key]:
-            return config[env_key]
-    raise ValueError(
-        f"Did not found {key} in provided dict and {env_key} in environment variables"
-    )
-
-
-class Spark(LLM):
-    # TODO: maybe rewrite based on BaseLLM. Need to implement the more complex _generate method.
-    client: Any
-    iflytek_appid: str
-    iflytek_api_key: str
-    iflytek_api_secret: str
-    temperature: Optional[float] = 0.5
-    max_tokens: Optional[int] = 2048
-    top_k: Optional[int] = 4
-
-    @property
-    def _llm_type(self) -> str:
-        return "iflytek_spark"
-
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:  # noqa: N805
-        """Validate api key, python package exists."""
-        iflytek_appid = get_from_dict_or_env(values, "iflytek_appid", "IFLYTEK_APPID")
-        iflytek_api_key = get_from_dict_or_env(
-            values, "iflytek_api_key", "IFLYTEK_API_KEY"
-        )
-        iflytek_api_secret = get_from_dict_or_env(
-            values, "iflytek_api_secret", "IFLYTEK_API_SECRET"
-        )
-
-        values["client"] = SparkClient(
-            iflytek_appid, iflytek_api_key, iflytek_api_secret
-        )
-
-        if values["temperature"] is not None and not 0 <= values["temperature"] <= 1:
-            raise ValueError("temperature must be in the range [0.0, 1.0]")
-
-        if values["top_k"] is not None and not 1 <= values["top_k"] <= 6:
-            raise ValueError("top_k must be between 1 and 6")
-
-        if values["max_tokens"] is not None and not 1 <= values["max_tokens"] <= 4096:
-            raise ValueError("max_output_tokens must be between 1 and 4096")
-
-        return values
-
-    def _call(
-        self,
-        prompt: str,
-        stop: Optional[List[str]] = None,
-        run_manager: Optional[CallbackManagerForLLMRun] = None,
-    ) -> str:
-        if stop is not None:
-            raise ValueError("stop kwargs are not permitted.")
-        return self.client.generate_text(
-            prompt,
-            temperature=self.temperature,
-            max_tokens=self.max_tokens,
-            top_k=self.top_k,
-        )["text"]
-
-    @property
-    def _identifying_params(self) -> Mapping[str, Any]:
-        """Get the identifying parameters."""
-        return {
-            "temperature": self.temperature,
-            "max_tokens": self.max_tokens,
-            "top_k": self.top_k,
-        }
diff --git a/automation-api/lib/llms/utils.py b/automation-api/lib/llms/utils.py
index 5a99227..796f5d6 100644
--- a/automation-api/lib/llms/utils.py
+++ b/automation-api/lib/llms/utils.py
@@ -12,7 +12,7 @@
 from lib.config import read_config
 
 from .fake import RandomAnswerLLM
-from .iflytek import Spark
+from .spark import Spark
 
 
 def get_openai_model(model_name: str, **kwargs: Any) -> Union[ChatOpenAI, OpenAI]: