Commit

support official claude
kennymckormick committed Dec 22, 2024
1 parent dbda46a commit 2089e3c
Showing 1 changed file with 29 additions and 10 deletions.
39 changes: 29 additions & 10 deletions vlmeval/api/claude.py
@@ -5,18 +5,25 @@
 import mimetypes
 from PIL import Image
 
-url = 'https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat'
-headers = {
+alles_url = 'https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat'
+alles_headers = {
     'alles-apin-token': '',
     'Content-Type': 'application/json'
 }
+official_url = 'https://api.anthropic.com/v1/messages'
+official_headers = {
+    'x-api-key': '',
+    'anthropic-version': '2023-06-01',
+    'content-type': 'application/json'
+}
 
 
 class Claude_Wrapper(BaseAPI):
 
     is_api: bool = True
 
     def __init__(self,
+                 backend: str = 'alles',
                  model: str = 'claude-3-opus-20240229',
                  key: str = None,
                  retry: int = 10,
@@ -27,15 +34,26 @@ def __init__(self,
                  max_tokens: int = 1024,
                  **kwargs):
 
+        if os.environ.get('ANTHROPIC_BACKEND', '') == 'official':
+            backend = 'official'
+
+        assert backend in ['alles', 'official'], f'Invalid backend: {backend}'
+        self.backend = backend
+        self.url = alles_url if backend == 'alles' else official_url
         self.model = model
-        self.headers = headers
         self.temperature = temperature
         self.max_tokens = max_tokens
+        self.headers = alles_headers if backend == 'alles' else official_headers
 
         if key is not None:
             self.key = key
         else:
-            self.key = os.environ.get('ALLES', '')
-        self.headers['alles-apin-token'] = self.key
+            self.key = os.environ.get('ALLES', '') if self.backend == 'alles' else os.environ.get('ANTHROPIC_API_KEY', '')  # noqa: E501
+
+        if self.backend == 'alles':
+            self.headers['alles-apin-token'] = self.key
+        else:
+            self.headers['x-api-key'] = self.key
 
         super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs)

@@ -81,15 +99,16 @@ def prepare_inputs(self, inputs):
         return input_msgs
 
     def generate_inner(self, inputs, **kwargs) -> str:
-
-        payload = json.dumps({
+        payload = {
             'model': self.model,
             'max_tokens': self.max_tokens,
             'messages': self.prepare_inputs(inputs),
-            'system': self.system_prompt,
             **kwargs
-        })
-        response = requests.request('POST', url, headers=headers, data=payload)
+        }
+        if self.system_prompt is not None:
+            payload['system'] = self.system_prompt
+
+        response = requests.request('POST', self.url, headers=self.headers, data=json.dumps(payload))
         ret_code = response.status_code
         ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code
         answer = self.fail_msg
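For reference, a minimal usage sketch (not part of the commit) of how the new 'official' backend could be selected, based on the defaults and environment variables introduced in the diff above. The API key value and model id below are placeholders, and the import path simply follows the file location vlmeval/api/claude.py:

    import os
    from vlmeval.api.claude import Claude_Wrapper

    # The official Anthropic backend reads ANTHROPIC_API_KEY (placeholder value here).
    os.environ['ANTHROPIC_API_KEY'] = 'sk-ant-...'

    # Option 1: pass the backend explicitly.
    wrapper = Claude_Wrapper(backend='official', model='claude-3-opus-20240229')

    # Option 2: force the official backend globally; per the diff, __init__
    # overrides the backend argument whenever ANTHROPIC_BACKEND == 'official'.
    os.environ['ANTHROPIC_BACKEND'] = 'official'
    wrapper = Claude_Wrapper()

Without either option, the wrapper keeps the previous behavior: the 'alles' backend with the OpenXLab proxy URL and the ALLES token from the environment.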
