diff --git a/garak/generators/openai.py b/garak/generators/openai.py
index 536393c70..1c8065399 100644
--- a/garak/generators/openai.py
+++ b/garak/generators/openai.py
@@ -14,6 +14,7 @@
 import json
 import logging
 import re
+import tiktoken
 from typing import List, Union
 
 import openai
@@ -114,12 +115,24 @@
     "gpt-4o": 128000,
     "gpt-4o-2024-05-13": 128000,
     "gpt-4o-2024-08-06": 128000,
-    "gpt-4o-mini": 16384,
+    "gpt-4o-mini": 128000,
     "gpt-4o-mini-2024-07-18": 16384,
-    "o1-mini": 65536,
+    "o1": 200000,
+    "o1-mini": 128000,
     "o1-mini-2024-09-12": 65536,
     "o1-preview": 32768,
     "o1-preview-2024-09-12": 32768,
+    "o3-mini": 200000,
 }
 
+output_max = {
+    "gpt-3.5-turbo": 4096,
+    "gpt-4": 8192,
+    "gpt-4o": 16384,
+    "o3-mini": 100000,
+    "o1": 100000,
+    "o1-mini": 65536,
+    "gpt-4o-mini": 16384,
+}
+
@@ -173,6 +186,87 @@ def _clear_client(self):
     def _validate_config(self):
         pass
 
+    def _validate_token_args(self, create_args: dict, prompt: Conversation) -> dict:
+        """Ensure maximum token limit compatibility with OpenAI create request"""
+        token_generation_limit_key = "max_tokens"
+        fixed_cost = 0
+        if (
+            self.generator == self.client.chat.completions
+            and self.max_tokens is not None
+        ):
+            token_generation_limit_key = "max_completion_tokens"
+            if not hasattr(self, "max_completion_tokens"):
+                create_args["max_completion_tokens"] = self.max_tokens
+
+            create_args.pop(
+                "max_tokens", None
+            )  # remove deprecated value, utilize `max_completion_tokens`
+            # each message carries ~3 toks of overhead (plus 1 when a name is set) and every
+            # reply is primed with <|start|>assistant<|message|> (3 toks)
+            # see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
+            # section 6 "Counting tokens for chat completions API calls"
+            # TODO: adjust fixed cost to account for each `turn` already in the prompt
+            fixed_cost = 7
+
+        # basic token boundary validation to ensure requests are not rejected for exceeding target context length
+        token_generation_limit = create_args.pop(token_generation_limit_key, None)
+        if token_generation_limit is not None:
+            # Suppress max_tokens if greater than context_len
+            if (
+                hasattr(self, "context_len")
+                and self.context_len is not None
+                and token_generation_limit > self.context_len
+            ):
+                logging.warning(
+                    f"Requested garak maximum tokens {token_generation_limit} exceeds context length {self.context_len}, no limit will be applied to the request"
+                )
+                token_generation_limit = None
+
+            if (
+                self.name in output_max
+                and token_generation_limit is not None
+                and token_generation_limit > output_max[self.name]
+            ):
+                logging.warning(
+                    f"Requested maximum tokens {token_generation_limit} exceeds max output {output_max[self.name]}, no limit will be applied to the request"
+                )
+                token_generation_limit = None
+
+            if self.context_len is not None and token_generation_limit is not None:
+                # count tokens in the prompt and ensure the requested token_generation_limit
+                # stays within context_len and the allowed output_max
+                prompt_tokens = 0  # counted across every turn in the prompt
+                try:
+                    encoding = tiktoken.encoding_for_model(self.name)
+                    for turn in prompt.turns:
+                        prompt_tokens += len(encoding.encode(turn.content.text))
+                except KeyError:
+                    # extra naive fallback: 1 token ~= 3/4 of a word
+                    prompt_tokens = int(
+                        sum(len(turn.content.text.split()) for turn in prompt.turns)
+                        * 4
+                        / 3
+                    )
+
+                if (
+                    prompt_tokens + fixed_cost + token_generation_limit
+                    > self.context_len
+                ) and (prompt_tokens + fixed_cost < self.context_len):
+                    token_generation_limit = (
+                        self.context_len - prompt_tokens - fixed_cost
+                    )
+                elif token_generation_limit > prompt_tokens + fixed_cost:
+                    token_generation_limit = (
+                        token_generation_limit - prompt_tokens - fixed_cost
+                    )
+                else:
+                    raise garak.exception.GarakException(
+                        "A response of %s toks plus prompt %s toks cannot be generated; API capped at context length %s toks"
+                        % (
+                            self.max_tokens,
+                            prompt_tokens + fixed_cost,
+                            self.context_len,
+                        )
+                    )
+
+        if token_generation_limit is not None:
+            create_args[token_generation_limit_key] = token_generation_limit
+
+        return create_args
+
     def __init__(self, name="", config_root=_config):
         self.name = name
         self._load_config(config_root)
@@ -218,7 +312,8 @@ def _call_model(
         create_args = {}
         if "n" not in self.suppressed_params:
             create_args["n"] = generations_this_call
-        for arg in inspect.signature(self.generator.create).parameters:
+        create_params = inspect.signature(self.generator.create).parameters
+        for arg in create_params:
             if arg == "model":
                 create_args[arg] = self.name
                 continue
@@ -232,6 +327,12 @@ def _call_model(
         for k, v in self.extra_params.items():
             create_args[k] = v
 
+        try:
+            create_args = self._validate_token_args(create_args, prompt)
+        except garak.exception.GarakException as e:
+            logging.exception(e)
+            return [None] * generations_this_call
+
         if self.generator == self.client.completions:
             if not isinstance(prompt, Conversation) or len(prompt.turns) > 1:
                 msg = (
@@ -239,7 +340,7 @@
                     f"Returning nothing!"
                 )
                 logging.error(msg)
-                return list()
+                return [None] * generations_this_call
 
             create_args["prompt"] = prompt.last_message().text
 
@@ -255,7 +356,7 @@
                     f"Returning nothing!"
                 )
                 logging.error(msg)
-                return list()
+                return [None] * generations_this_call
 
             create_args["messages"] = messages
 
@@ -265,7 +366,7 @@
             msg = "Bad request: " + str(repr(prompt))
             logging.exception(e)
             logging.error(msg)
-            return [None]
+            return [None] * generations_this_call
         except json.decoder.JSONDecodeError as e:
             logging.exception(e)
             if self.retry_json:
@@ -282,7 +383,7 @@
             if self.retry_json:
                 raise garak.exception.GarakBackoffTrigger(msg)
             else:
-                return [None]
+                return [None] * generations_this_call
 
         if self.generator == self.client.completions:
             return [Message(c.text) for c in response.choices]
diff --git a/tests/generators/test_openai_compatible.py b/tests/generators/test_openai_compatible.py
index af2e7461e..1311ba104 100644
--- a/tests/generators/test_openai_compatible.py
+++ b/tests/generators/test_openai_compatible.py
@@ -3,11 +3,14 @@
 import os
 
 import httpx
+import lorem
 import respx
 import pytest
+import tiktoken
 import importlib
 import inspect
+
 from collections.abc import Iterable
 
 from garak.attempt import Message, Turn, Conversation
@@ -25,7 +28,7 @@
     "generators.groq.GroqChat",
 ]
 
-MODEL_NAME = "gpt-3.5-turbo-instruct"
+MODEL_NAME = "gpt-3.5-turbo"
 ENV_VAR = os.path.abspath(
     __file__
 )  # use test path as hint encase env changes are missed
@@ -124,3 +127,134 @@ def test_openai_multiprocessing(openai_compat_mocks, classname):
     assert isinstance(
         result[0], Message
     ), "generator should return list of Turns or Nones"
+
+
+def create_prompt(prompt_length: int):
+    test_large_context = ""
+    encoding = tiktoken.encoding_for_model(MODEL_NAME)
+    while len(encoding.encode(test_large_context)) < prompt_length:
+        test_large_context += "\n" + lorem.paragraph()
+    return Conversation([Turn(role="user", content=Message(test_large_context))])
+
+
+TOKEN_LIMIT_EXPECTATIONS = {
+    "use_max_completion_tokens": (
+        100,  # prompt_length
+        2048,  # max_tokens
+        4096,  # context_len
+        False,  # skips_request
+        lambda a: a["max_completion_tokens"] <= 2048,  # check_lambda
+        "request max_completion_tokens must account for prompt tokens",  # err_msg
+    ),
+    "use_max_tokens": (
+        100,
+        2048,
+        4096,
+        False,
+        lambda a: a["max_tokens"] <= 2048,
+        "request max_tokens must account for prompt tokens",
+    ),
+    "suppress_tokens": (
+        100,
+        4096,
+        2048,
+        False,
+        lambda a: a.get("max_completion_tokens", None) is None
+        and a.get("max_tokens", None) is None,
+        "request max_tokens is suppressed when larger than context length",
+    ),
+    "skip_request_above_user_limit": (
+        4096,
+        2048,
+        4096,
+        True,
+        None,
+        "a prompt larger than max_tokens must skip request",
+    ),
+    "skip_request_based_on_model_context": (
+        4096,
+        4096,
+        4096,
+        True,
+        None,
+        "a prompt larger than context_len must skip request",
+    ),
+}
+
+
+@pytest.mark.parametrize(
+    "test_conditions",
+    [key for key in TOKEN_LIMIT_EXPECTATIONS.keys() if key != "use_max_tokens"],
+)
+def test_validate_call_model_chat_token_restrictions(
+    openai_compat_mocks, test_conditions
+):
+    import json
+
+    prompt_length, max_tokens, context_len, skips_request, check_lambda, err_msg = (
+        TOKEN_LIMIT_EXPECTATIONS[test_conditions]
+    )
+
+    generator = build_test_instance(OpenAICompatible)
+    generator._load_client()
+    generator.max_tokens = max_tokens
+    generator.context_len = context_len
+    generator.generator = generator.client.chat.completions
+    mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
+    with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
+        mock_response = openai_compat_mocks["chat"]
+        respx_mock.post("chat/completions").mock(
+            return_value=httpx.Response(
+                mock_response["code"], json=mock_response["json"]
+            )
+        )
+        prompt_text = create_prompt(prompt_length)
+        if skips_request:
+            resp = generator._call_model(prompt_text)
+            assert not respx_mock.routes[0].called
+            assert resp == [None]
+        else:
+            generator._call_model(prompt_text)
+            req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
+            assert check_lambda(req_body), err_msg
+
+
+@pytest.mark.parametrize(
+    "test_conditions",
+    [
+        key
+        for key in TOKEN_LIMIT_EXPECTATIONS.keys()
+        if key != "use_max_completion_tokens"
+    ],
+)
+def test_validate_call_model_completion_token_restrictions(
+    openai_compat_mocks, test_conditions
+):
+    import json
+
+    prompt_length, max_tokens, context_len, skips_request, check_lambda, err_msg = (
+        TOKEN_LIMIT_EXPECTATIONS[test_conditions]
+    )
+
+    generator = build_test_instance(OpenAICompatible)
+    generator._load_client()
+    generator.max_tokens = max_tokens
+    generator.context_len = context_len
+    generator.generator = generator.client.completions
+    mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
+    with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
+        mock_response = openai_compat_mocks["completion"]
+        respx_mock.post("/completions").mock(
+            return_value=httpx.Response(
+                mock_response["code"], json=mock_response["json"]
+            )
+        )
+        prompt_text = create_prompt(prompt_length)
+        if skips_request:
+            resp = generator._call_model(prompt_text)
+            assert not respx_mock.routes[0].called
+            assert resp == [None]
+        else:
+            generator._call_model(prompt_text)
+            req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
+            assert check_lambda(req_body), err_msg
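
For reference while reviewing, below is a minimal standalone sketch of the budgeting rule that _validate_token_args applies, condensed so the arithmetic can be checked in isolation. It is not part of the patch: the completion_budget helper, the default model name, and the example numbers are illustrative only, and the per-model output_max suppression step is omitted.

from typing import Optional

import tiktoken


def completion_budget(
    prompt_text: str,
    requested: int,
    context_len: int,
    model: str = "gpt-3.5-turbo",
    fixed_cost: int = 7,  # chat priming/name overhead used in the patch
) -> Optional[int]:
    """Return the completion-token cap to send, None to send no cap, or raise if nothing fits."""
    try:
        encoding = tiktoken.encoding_for_model(model)
        prompt_tokens = len(encoding.encode(prompt_text))
    except KeyError:
        # naive fallback: 1 token ~= 3/4 of a word
        prompt_tokens = int(len(prompt_text.split()) * 4 / 3)

    if requested > context_len:
        return None  # suppress the cap entirely, as the patch does
    if (
        prompt_tokens + fixed_cost + requested > context_len
        and prompt_tokens + fixed_cost < context_len
    ):
        return context_len - prompt_tokens - fixed_cost  # shrink to what still fits
    if requested > prompt_tokens + fixed_cost:
        return requested - prompt_tokens - fixed_cost  # cap is treated as a total budget
    raise ValueError("prompt leaves no room for a response within the context window")


# e.g. a 100-token prompt with requested=2048 and context_len=4096 yields 2048 - 100 - 7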