diff --git a/vlmeval/api/claude.py b/vlmeval/api/claude.py index 8bc79916..c129cf4c 100644 --- a/vlmeval/api/claude.py +++ b/vlmeval/api/claude.py @@ -5,11 +5,17 @@ import mimetypes from PIL import Image -url = 'https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat' -headers = { +alles_url = 'https://openxlab.org.cn/gw/alles-apin-hub/v1/claude/v1/text/chat' +alles_headers = { 'alles-apin-token': '', 'Content-Type': 'application/json' } +official_url = 'https://api.anthropic.com/v1/messages' +official_headers = { + 'x-api-key': '', + 'anthropic-version': '2023-06-01', + 'content-type': 'application/json' +} class Claude_Wrapper(BaseAPI): @@ -17,6 +23,7 @@ class Claude_Wrapper(BaseAPI): is_api: bool = True def __init__(self, + backend: str = 'alles', model: str = 'claude-3-opus-20240229', key: str = None, retry: int = 10, @@ -27,15 +34,26 @@ def __init__(self, max_tokens: int = 1024, **kwargs): + if os.environ.get('ANTHROPIC_BACKEND', '') == 'official': + backend = 'official' + + assert backend in ['alles', 'official'], f'Invalid backend: {backend}' + self.backend = backend + self.url = alles_url if backend == 'alles' else official_url self.model = model - self.headers = headers self.temperature = temperature self.max_tokens = max_tokens + self.headers = alles_headers if backend == 'alles' else official_headers + if key is not None: self.key = key else: - self.key = os.environ.get('ALLES', '') - self.headers['alles-apin-token'] = self.key + self.key = os.environ.get('ALLES', '') if self.backend == 'alles' else os.environ.get('ANTHROPIC_API_KEY', '') # noqa: E501 + + if self.backend == 'alles': + self.headers['alles-apin-token'] = self.key + else: + self.headers['x-api-key'] = self.key super().__init__(retry=retry, wait=wait, verbose=verbose, system_prompt=system_prompt, **kwargs) @@ -81,15 +99,16 @@ def prepare_inputs(self, inputs): return input_msgs def generate_inner(self, inputs, **kwargs) -> str: - - payload = json.dumps({ + payload = { 'model': self.model, 'max_tokens': self.max_tokens, 'messages': self.prepare_inputs(inputs), - 'system': self.system_prompt, **kwargs - }) - response = requests.request('POST', url, headers=headers, data=payload) + } + if self.system_prompt is not None: + payload['system'] = self.system_prompt + + response = requests.request('POST', self.url, headers=self.headers, data=json.dumps(payload)) ret_code = response.status_code ret_code = 0 if (200 <= int(ret_code) < 300) else ret_code answer = self.fail_msg