Skip to content

Commit

Permalink
Chinese bots (#22)
Browse files Browse the repository at this point in the history
* basic support for iflytek Spark model

* add env variables

* bug fix

* separate API client and Langchain wrapper.

* add dependency

* add env variable

* add Alibaba models wrapper

* add utils for alibaba

* update to v2 api

* add retry to api call

* add variables to .env example

* update helper

* sort deps
  • Loading branch information
semio authored Aug 22, 2023
1 parent 49932e6 commit 0bd63e3
Show file tree
Hide file tree
Showing 10 changed files with 423 additions and 1 deletion.
8 changes: 8 additions & 0 deletions automation-api/.env.example
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,16 @@ GCP_REGION="europe-west3"
# For running in production
OPENAI_API_KEY=""
OPENAI_ORG_ID=""
## for Huggingface Hub
HUGGINGFACEHUB_API_TOKEN=""
## for PALM
GOOGLE_API_KEY=""
## for iFlytek
IFLYTEK_API_KEY=""
IFLYTEK_API_SECRET=""
IFLYTEK_APPID=""
## for Alibaba
DASHSCOPE_API_KEY=""

# For local development
SERVICE_ACCOUNT_CREDENTIALS=""
Expand Down
4 changes: 4 additions & 0 deletions automation-api/lib/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ def read_config() -> dict[str, str]:
"AI_EVAL_DEV_SPREADSHEET_ID",
"HUGGINGFACEHUB_API_TOKEN",
"GOOGLE_API_KEY",
"IFLYTEK_APPID",
"IFLYTEK_API_KEY",
"IFLYTEK_API_SECRET",
"DASHSCOPE_API_KEY",
]:
config[key] = os.getenv(key=key, default="")
return config
118 changes: 118 additions & 0 deletions automation-api/lib/llms/alibaba.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
import random
from http import HTTPStatus
from typing import Any, Dict, List, Mapping, Optional

import dashscope
from dashscope import Generation
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import root_validator
from tenacity import (
retry,
retry_if_exception_type,
retry_if_not_result,
stop_after_attempt,
)

from lib.config import read_config


def response_is_ok(response):
    """Return True when ``response.status_code`` equals HTTP 200 (OK)."""
    return response.status_code == HTTPStatus.OK


@retry(
    stop=stop_after_attempt(3),
    retry=retry_if_exception_type() | retry_if_not_result(response_is_ok),
)
def get_reply(**kwargs):
    """Forward ``kwargs`` to ``Generation.call`` and return its response.

    The call is attempted up to three times: another attempt is made when the
    call raises an exception or when ``response_is_ok`` rejects the response.
    """
    response = Generation.call(**kwargs)
    return response


def get_from_dict_or_env(data, key, env_key):
    """Look up a setting in ``data`` first, then in the project config.

    Args:
        data: mapping that may hold the value under ``key``.
        key: name to look for in ``data``.
        env_key: name to look for in the mapping returned by ``read_config``.

    Returns:
        The first non-empty value found, preferring ``data``.

    Raises:
        ValueError: if neither source holds a non-empty value.
    """
    # Empty values count as missing, so fall through to the config.
    if data.get(key):
        return data[key]

    from_config = read_config().get(env_key)
    if from_config:
        return from_config

    raise ValueError(
        f"Did not found {key} in provided dict and {env_key} in environment variables"
    )


class Alibaba(LLM):
    """LangChain LLM wrapper around DashScope's ``Generation`` API.

    The API key is taken from the ``dashscope_api_key`` argument when given,
    otherwise from the ``DASHSCOPE_API_KEY`` entry returned by ``read_config``.
    """

    # TODO: maybe rewrite based on BaseLLM. Need to implement the more complex _generate method.
    model_name: Optional[str] = "qwen-v1"
    top_p: Optional[float] = 0.8
    top_k: Optional[int] = 100
    enable_search: Optional[bool] = False
    seed: Optional[int] = None
    # Declared as a field so that a key passed to the constructor is visible
    # to the validator below; left as None it is resolved from the config.
    dashscope_api_key: Optional[str] = None
    # Maximum number of attempts made for a single request.
    max_retries: int = 3

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this LLM type."""
        return "alibaba"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:  # noqa: N805
        """Resolve the API key and check that sampling parameters are in range.

        Raises:
            ValueError: if no API key can be found, or if ``top_p``, ``top_k``
                or ``max_retries`` is out of range.
        """
        dashscope_api_key = get_from_dict_or_env(
            values, "dashscope_api_key", "DASHSCOPE_API_KEY"
        )
        values["dashscope_api_key"] = dashscope_api_key
        # The dashscope package reads the key from this module-level attribute.
        dashscope.api_key = dashscope_api_key

        # ``.get`` is used because a field that failed its own validation is
        # absent from ``values``; indexing would raise KeyError instead of
        # letting the field's validation error surface.
        top_p = values.get("top_p")
        if top_p is not None and not 0.0 <= top_p <= 1.0:
            raise ValueError("top_p must be between 0 and 1")

        top_k = values.get("top_k")
        if top_k is not None and not 1 <= top_k <= 100:
            raise ValueError("top_k must be between 1 and 100")

        max_retries = values.get("max_retries")
        if max_retries is not None and max_retries < 1:
            raise ValueError("max_retries must be at least 1")

        return values

    def _call(
        self,
        prompt: str,
        history: Optional[List[Dict]] = None,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Send ``prompt`` to the model and return the generated text.

        Args:
            prompt: the text prompt.
            history: earlier conversation turns forwarded to the API; defaults
                to an empty list.
            stop: must be None, stop sequences are not supported.
            run_manager: accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``stop`` is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")

        if history is None:
            history = []

        if self.seed is None:
            # Python ints are arbitrary precision, so random.randint can draw
            # from a range this large without overflow.
            # NOTE(review): the original comment says the API accepts uint64;
            # the upper bound is kept at 2**63 as before - confirm before widening.
            seed = random.randint(0, 2**63)
        else:
            seed = self.seed

        call_with_retry = get_reply.retry_with(
            stop=stop_after_attempt(self.max_retries)
        )
        result = call_with_retry(
            model=self.model_name,
            prompt=prompt,
            history=history,
            top_p=self.top_p,
            top_k=self.top_k,
            seed=seed,
            enable_search=self.enable_search,
        )

        return result["output"]["text"]

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            # The model name is included so that two instances targeting
            # different models are not treated as identical.
            "model_name": self.model_name,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "enable_search": self.enable_search,
        }
1 change: 1 addition & 0 deletions automation-api/lib/llms/iflytek/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
from .spark_api import SparkClient # noqa: F401
143 changes: 143 additions & 0 deletions automation-api/lib/llms/iflytek/spark_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
"""
API for IFlyTek's Spark, and langchain class for it.
"""

import base64
import hashlib
import hmac
import json
from datetime import datetime
from time import mktime
from typing import Any, Dict, Optional
from urllib.parse import urlencode, urlparse
from wsgiref.handlers import format_date_time

import websocket


class Ws_Param:
    """Build the signed websocket URL used to authenticate with the API.

    This class was taken from IFlyTek's doc.
    """

    def __init__(self, APPID, APIKey, APISecret, gpt_url):
        """Store the credentials and split ``gpt_url`` into host and path."""
        parsed = urlparse(gpt_url)
        self.APPID = APPID
        self.APIKey = APIKey
        self.APISecret = APISecret
        self.host = parsed.netloc
        self.path = parsed.path
        self.gpt_url = gpt_url

    def create_url(self):
        """Return ``gpt_url`` with authorization, date and host query parameters.

        The authorization value is the base64 form of a string holding the API
        key and an HMAC-SHA256 signature (keyed with the API secret) over the
        host, the current date and the request line.
        """
        # RFC1123 timestamp of the current time
        timestamp = mktime(datetime.now().timetuple())
        date = format_date_time(timestamp)

        # The text that gets signed: host, date and request line.
        signature_origin = "\n".join(
            (
                f"host: {self.host}",
                f"date: {date}",
                f"GET {self.path} HTTP/1.1",
            )
        )

        digest = hmac.new(
            self.APISecret.encode("utf-8"),
            signature_origin.encode("utf-8"),
            digestmod=hashlib.sha256,
        ).digest()
        signature = base64.b64encode(digest).decode(encoding="utf-8")

        authorization_origin = (
            f'api_key="{self.APIKey}", algorithm="hmac-sha256", '
            f'headers="host date request-line", signature="{signature}"'
        )
        authorization = base64.b64encode(
            authorization_origin.encode("utf-8")
        ).decode(encoding="utf-8")

        query = urlencode(
            {"authorization": authorization, "date": date, "host": self.host}
        )
        return self.gpt_url + "?" + query


def get_reply(url: str, data: Dict) -> Dict[str, Any]:
    """Send one request over a websocket and collect the streamed reply.

    Args:
        url: websocket URL to connect to.
        data: request body; it is JSON-encoded before sending.

    Returns:
        ``{"text": ..., "usage": ...}`` once a frame with status 2 arrives,
        where ``text`` is the concatenation of all received content pieces.
        When the server answers with one of the content-filter codes, only
        ``{"text": <server message>}`` is returned.

    Raises:
        websocket.WebSocketException: if the server reports any other
            non-zero code.
    """
    # These codes mean the input/output were blocked by content filter.
    content_filter_codes = (10013, 10014, 10019)

    ws = websocket.WebSocket()
    # try/finally guarantees the socket is closed on every exit path,
    # including failures while connecting, receiving or decoding a frame.
    try:
        ws.connect(url)
        ws.send(json.dumps(data))

        pieces = []
        while True:
            frame = json.loads(ws.recv())
            header = frame["header"]
            code = header["code"]
            message = header["message"]

            if code != 0:
                print(f"WS error: {code}, {message}")
                if code in content_filter_codes:
                    return {"text": message}
                # Carry the server's code and message so callers can see
                # why the request failed.
                raise websocket.WebSocketException(
                    f"Websocket Error: {code}, {message}"
                )

            payload = frame["payload"]
            choices = payload["choices"]
            pieces.append(choices["text"][0]["content"])
            if choices["status"] == 2:
                # Status 2 ends the loop; this frame also carries the usage data.
                return {"text": "".join(pieces), "usage": payload["usage"]}
    finally:
        ws.close()


class SparkClient:
    """Minimal client for IFlyTek's Spark chat API over websocket."""

    gpt_url: str = "wss://spark-api.xf-yun.com/v2.1/chat"
    # TODO: add support for selecting v1 and v2?
    # v1 url: "ws(s)://spark-api.xf-yun.com/v1.1/chat"

    def __init__(self, appid: str, api_key: str, api_secret: str) -> None:
        """Store the app id and prepare the URL signer.

        Args:
            appid: application id sent in every request header.
            api_key: API key used in the authorization string.
            api_secret: secret used to sign the authorization string.
        """
        self.appid = appid
        # Keep the signer so a fresh URL can be produced for each request.
        self._ws_param = Ws_Param(appid, api_key, api_secret, self.gpt_url)
        self.ws_url = self._ws_param.create_url()

    def gen_parameters(
        self,
        uid: str = "0",
        chat_id: Optional[str] = None,
        temperature: float = 0.5,
        max_tokens: int = 2048,  # [1, 4096]
        top_k: int = 4,  # [1, 6]
    ) -> Dict:
        """Build the ``header`` and ``parameter`` sections of a request.

        ``chat_id`` is only included when it is truthy.
        """
        data: Dict[str, Any] = {
            "header": {"app_id": self.appid, "uid": uid},
            "parameter": {
                "chat": {
                    "domain": "generalv2",  # v1 domain: "general"
                    "temperature": temperature,
                    "max_tokens": max_tokens,
                    "top_k": top_k,
                }
            },
        }
        if chat_id:
            data["parameter"]["chat"]["chat_id"] = chat_id

        return data

    def gen_payload(self, content):
        """Build the ``payload`` section holding a single user message."""
        data = {
            "payload": {"message": {"text": [{"role": "user", "content": content}]}}
        }
        return data

    def generate_text(self, content, **kwargs) -> Dict[str, Any]:
        """Send ``content`` as a user message and return the reply dict.

        Args:
            content: the user message text.
            **kwargs: forwarded to ``gen_parameters``.

        Returns:
            The dict produced by ``get_reply``.
        """
        data = self.gen_parameters(**kwargs)
        data.update(self.gen_payload(content))
        # The signed URL embeds the current date, so it is regenerated for
        # every request rather than reusing the one built in __init__.
        # NOTE(review): presumably the server rejects signatures whose date is
        # too old - confirm the allowed skew in the API documentation.
        self.ws_url = self._ws_param.create_url()
        res = get_reply(self.ws_url, data)
        return res

    def chat(self):
        """Not implemented yet."""
        # TODO: add chat function, which accepts some message history and generate new reply.
        raise NotImplementedError()
100 changes: 100 additions & 0 deletions automation-api/lib/llms/spark.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Langchain Wrapper for iFlytek Spark
"""

from typing import Any, Dict, List, Mapping, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from pydantic import root_validator
from tenacity import (
retry,
retry_if_exception_type,
stop_after_attempt,
)

from lib.config import read_config
from lib.llms.iflytek import SparkClient


def get_from_dict_or_env(data, key, env_key):
    """Look up a setting in ``data`` first, then in the project config.

    Args:
        data: mapping that may hold the value under ``key``.
        key: name to look for in ``data``.
        env_key: name to look for in the mapping returned by ``read_config``.

    Returns:
        The first non-empty value found, preferring ``data``.

    Raises:
        ValueError: if neither source holds a non-empty value.
    """
    # Empty values count as missing, so fall through to the config.
    if data.get(key):
        return data[key]

    from_config = read_config().get(env_key)
    if from_config:
        return from_config

    raise ValueError(
        f"Did not found {key} in provided dict and {env_key} in environment variables"
    )


class Spark(LLM):
    """LangChain LLM wrapper for iFlytek Spark.

    Each credential is taken from the constructor argument when given,
    otherwise from the matching ``IFLYTEK_APPID``, ``IFLYTEK_API_KEY`` or
    ``IFLYTEK_API_SECRET`` entry returned by ``read_config``.
    """

    # TODO: maybe rewrite based on BaseLLM. Need to implement the more complex _generate method.
    client: Any
    # The credentials default to None so that they may be omitted and resolved
    # from the config by the validator below; as required fields, leaving them
    # out would fail validation regardless of what the validator finds.
    iflytek_appid: Optional[str] = None
    iflytek_api_key: Optional[str] = None
    iflytek_api_secret: Optional[str] = None
    temperature: Optional[float] = 0.5
    max_tokens: Optional[int] = 2048
    top_k: Optional[int] = 4

    @property
    def _llm_type(self) -> str:
        """Return the identifier of this LLM type."""
        return "iflytek_spark"

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:  # noqa: N805
        """Resolve credentials, build the client and check parameter ranges.

        Raises:
            ValueError: if a credential cannot be found, or if ``temperature``,
                ``top_k`` or ``max_tokens`` is out of range.
        """
        iflytek_appid = get_from_dict_or_env(values, "iflytek_appid", "IFLYTEK_APPID")
        iflytek_api_key = get_from_dict_or_env(
            values, "iflytek_api_key", "IFLYTEK_API_KEY"
        )
        iflytek_api_secret = get_from_dict_or_env(
            values, "iflytek_api_secret", "IFLYTEK_API_SECRET"
        )

        # Store the resolved values so the instance attributes reflect the
        # credentials actually in use.
        values["iflytek_appid"] = iflytek_appid
        values["iflytek_api_key"] = iflytek_api_key
        values["iflytek_api_secret"] = iflytek_api_secret

        values["client"] = SparkClient(
            iflytek_appid, iflytek_api_key, iflytek_api_secret
        )

        # ``.get`` is used because a field that failed its own validation is
        # absent from ``values``; indexing would raise KeyError instead of
        # letting the field's validation error surface.
        temperature = values.get("temperature")
        if temperature is not None and not 0 <= temperature <= 1:
            raise ValueError("temperature must be in the range [0.0, 1.0]")

        top_k = values.get("top_k")
        if top_k is not None and not 1 <= top_k <= 6:
            raise ValueError("top_k must be between 1 and 6")

        max_tokens = values.get("max_tokens")
        if max_tokens is not None and not 1 <= max_tokens <= 4096:
            raise ValueError("max_tokens must be between 1 and 4096")

        return values

    @retry(
        retry=(retry_if_exception_type()),
        stop=stop_after_attempt(3),
    )
    def generate_text_with_retry(self, prompt):
        """Return the generated text for ``prompt``, trying up to three times.

        Another attempt is made whenever the client call raises an exception.
        """
        return self.client.generate_text(
            prompt,
            temperature=self.temperature,
            max_tokens=self.max_tokens,
            top_k=self.top_k,
        )["text"]

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Send ``prompt`` to the model and return the generated text.

        Args:
            prompt: the text prompt.
            stop: must be None, stop sequences are not supported.
            run_manager: accepted for interface compatibility; unused.

        Raises:
            ValueError: if ``stop`` is given.
        """
        if stop is not None:
            raise ValueError("stop kwargs are not permitted.")
        return self.generate_text_with_retry(prompt)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_k": self.top_k,
        }
Loading

0 comments on commit 0bd63e3

Please sign in to comment.