115 changes: 108 additions & 7 deletions garak/generators/openai.py
@@ -14,6 +14,7 @@
import json
import logging
import re
import tiktoken
from typing import List, Union

import openai
@@ -114,12 +115,24 @@
"gpt-4o": 128000,
"gpt-4o-2024-05-13": 128000,
"gpt-4o-2024-08-06": 128000,
"gpt-4o-mini": 16384,
"gpt-4o-mini": 128000,
"gpt-4o-mini-2024-07-18": 16384,
"o1-mini": 65536,
"o1": 200000,
"o1-mini": 128000,
"o1-mini-2024-09-12": 65536,
"o1-preview": 32768,
"o1-preview-2024-09-12": 32768,
"o3-mini": 200000,
}

output_max = {
    "gpt-3.5-turbo": 4096,
    "gpt-4": 8192,
    "gpt-4o": 16384,
    "o3-mini": 100000,
    "o1": 100000,
    "o1-mini": 65536,
    "gpt-4o-mini": 16384,
}


@@ -173,6 +186,87 @@ def _clear_client(self):
    def _validate_config(self):
        pass

    def _validate_token_args(self, create_args: dict, prompt: Conversation) -> dict:
        """Ensure maximum token limit compatibility with OpenAI create request"""
        token_generation_limit_key = "max_tokens"
        fixed_cost = 0
        if (
            self.generator == self.client.chat.completions
            and self.max_tokens is not None
        ):
            token_generation_limit_key = "max_completion_tokens"
            if not hasattr(self, "max_completion_tokens"):
                create_args["max_completion_tokens"] = self.max_tokens

            create_args.pop(
                "max_tokens", None
            )  # remove deprecated value, utilize `max_completion_tokens`
            # every reply is primed with <|start|>assistant<|message|> (3 toks) plus 1 for name change
            # see https://github.com/openai/openai-cookbook/blob/main/examples/How_to_count_tokens_with_tiktoken.ipynb
            # section 6 "Counting tokens for chat completions API calls"
            # TODO: adjust fixed cost to account for each `turn` already in the prompt
            fixed_cost = 7

        # basic token boundary validation to ensure requests are not rejected for exceeding target context length
        token_generation_limit = create_args.pop(token_generation_limit_key, None)
        if token_generation_limit is not None:
            # Suppress max_tokens if greater than context_len
            if (
                hasattr(self, "context_len")
                and self.context_len is not None
                and token_generation_limit > self.context_len
            ):
                logging.warning(
                    f"Requested garak maximum tokens {token_generation_limit} exceeds context length {self.context_len}, no limit will be applied to the request"
                )
                token_generation_limit = None

            if (
                self.name in output_max
                and token_generation_limit is not None
                and token_generation_limit > output_max[self.name]
            ):
                logging.warning(
                    f"Requested maximum tokens {token_generation_limit} exceeds max output {output_max[self.name]}, no limit will be applied to the request"
                )
                token_generation_limit = None

        if self.context_len is not None and token_generation_limit is not None:
            # count tokens in prompt and ensure token_generation_limit requested is <= context_len or output_max allowed
            prompt_tokens = 0  # this should apply to messages object
            try:
                encoding = tiktoken.encoding_for_model(self.name)
                prompt_tokens = 0
                for turn in prompt.turns:
                    prompt_tokens += len(encoding.encode(turn.content.text))
            except KeyError:
                # model unknown to tiktoken; extra naive fallback, 1 token ~= 3/4 of a word
                prompt_tokens = int(
                    sum(len(turn.content.text.split()) for turn in prompt.turns) * 4 / 3
                )

            if (
                prompt_tokens + fixed_cost + token_generation_limit
                > self.context_len
            ) and (prompt_tokens + fixed_cost < self.context_len):
                token_generation_limit = (
                    self.context_len - prompt_tokens - fixed_cost
                )
            elif token_generation_limit > prompt_tokens + fixed_cost:
                token_generation_limit = (
                    token_generation_limit - prompt_tokens - fixed_cost
                )
            else:
                raise garak.exception.GarakException(
                    "A response of %s toks plus prompt %s toks cannot be generated; API capped at context length %s toks"
                    % (
                        self.max_tokens,
                        prompt_tokens + fixed_cost,
                        self.context_len,
                    )
                )
        create_args[token_generation_limit_key] = token_generation_limit
        return create_args

    def __init__(self, name="", config_root=_config):
        self.name = name
        self._load_config(config_root)
@@ -218,7 +312,8 @@ def _call_model(
        create_args = {}
        if "n" not in self.suppressed_params:
            create_args["n"] = generations_this_call
        for arg in inspect.signature(self.generator.create).parameters:
        create_params = inspect.signature(self.generator.create).parameters
        for arg in create_params:
            if arg == "model":
                create_args[arg] = self.name
                continue
@@ -232,14 +327,20 @@
        for k, v in self.extra_params.items():
            create_args[k] = v

        try:
            create_args = self._validate_token_args(create_args, prompt)
        except garak.exception.GarakException as e:
            logging.exception(e)
            return [None] * generations_this_call

        if self.generator == self.client.completions:
            if not isinstance(prompt, Conversation) or len(prompt.turns) > 1:
                msg = (
                    f"Expected a Conversation with one Turn for {self.generator_family_name} completions model {self.name}, but got {type(prompt)}. "
                    f"Returning nothing!"
                )
                logging.error(msg)
                return list()
                return [None] * generations_this_call

            create_args["prompt"] = prompt.last_message().text

@@ -255,7 +356,7 @@
f"Returning nothing!"
)
logging.error(msg)
return list()
return [None] * generations_this_call

create_args["messages"] = messages

@@ -265,7 +366,7 @@
msg = "Bad request: " + str(repr(prompt))
logging.exception(e)
logging.error(msg)
return [None]
return [None] * generations_this_call
except json.decoder.JSONDecodeError as e:
logging.exception(e)
if self.retry_json:
@@ -282,7 +383,7 @@
            if self.retry_json:
                raise garak.exception.GarakBackoffTrigger(msg)
            else:
                return [None]
                return [None] * generations_this_call

        if self.generator == self.client.completions:
            return [Message(c.text) for c in response.choices]
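Note on the change above: the heart of `_validate_token_args` is a budget calculation — the generation limit sent to the API must leave room for the prompt's tokens plus a small fixed cost for chat reply priming, and it is dropped entirely when it exceeds the model's context length or its documented output maximum. A minimal standalone sketch of that arithmetic follows; the function name, parameters, and plain-string prompt are illustrative only (the real method operates on a garak Conversation and raises GarakException), and it assumes the 7-token fixed cost and the context_lengths/output_max values from the diff.

import tiktoken


def budget_completion_tokens(
    model: str, prompt_text: str, requested: int, context_len: int, model_output_max: int
) -> int | None:
    """Illustrative sketch: return the limit to send with the request, or None to send no limit."""
    fixed_cost = 7  # chat reply-priming overhead, per the OpenAI token-counting cookbook
    if requested > context_len or requested > model_output_max:
        return None  # suppress the limit rather than send a request the API would reject
    prompt_tokens = len(tiktoken.encoding_for_model(model).encode(prompt_text))
    if (
        prompt_tokens + fixed_cost + requested > context_len
        and prompt_tokens + fixed_cost < context_len
    ):
        return context_len - prompt_tokens - fixed_cost  # clamp to what the window can hold
    elif requested > prompt_tokens + fixed_cost:
        return requested - prompt_tokens - fixed_cost  # treat the request as a total budget
    raise ValueError("prompt plus priming overhead exhausts the requested budget")

The real method additionally falls back to a words * 4/3 estimate when tiktoken does not recognize the model name.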
136 changes: 135 additions & 1 deletion tests/generators/test_openai_compatible.py
@@ -3,11 +3,14 @@

import os
import httpx
import lorem
import respx
import pytest
import tiktoken
import importlib
import inspect


from collections.abc import Iterable

from garak.attempt import Message, Turn, Conversation
@@ -25,7 +28,7 @@
"generators.groq.GroqChat",
]

MODEL_NAME = "gpt-3.5-turbo-instruct"
MODEL_NAME = "gpt-3.5-turbo"
ENV_VAR = os.path.abspath(
    __file__
)  # use test path as hint in case env changes are missed
@@ -124,3 +127,134 @@ def test_openai_multiprocessing(openai_compat_mocks, classname):
    assert isinstance(
        result[0], Message
    ), "generator should return list of Messages or Nones"


def create_prompt(prompt_length: int):
    test_large_context = ""
    encoding = tiktoken.encoding_for_model(MODEL_NAME)
    while len(encoding.encode(test_large_context)) < prompt_length:
        test_large_context += "\n" + lorem.paragraph()
    return Conversation([Turn(role="user", content=Message(test_large_context))])


TOKEN_LIMIT_EXPECTATIONS = {
    "use_max_completion_tokens": (
        100,  # prompt_length
        2048,  # max_tokens
        4096,  # context_len
        False,  # skips_request
        lambda a: a["max_completion_tokens"] <= 2048,  # check_lambda
        "request max_completion_tokens must account for prompt tokens",  # err_msg
    ),
    "use_max_tokens": (
        100,
        2048,
        4096,
        False,
        lambda a: a["max_tokens"] <= 2048,
        "request max_tokens must account for prompt tokens",
    ),
    "suppress_tokens": (
        100,
        4096,
        2048,
        False,
        lambda a: a.get("max_completion_tokens", None) is None
        and a.get("max_tokens", None) is None,
        "request max_tokens is suppressed when larger than context length",
    ),
    "skip_request_above_user_limit": (
        4096,
        2048,
        4096,
        True,
        None,
        "a prompt larger than max_tokens must skip request",
    ),
    "skip_request_based_on_model_context": (
        4096,
        4096,
        4096,
        True,
        None,
        "a prompt larger than context_len must skip request",
    ),
}


@pytest.mark.parametrize(
    "test_conditions",
    [key for key in TOKEN_LIMIT_EXPECTATIONS.keys() if key != "use_max_tokens"],
)
def test_validate_call_model_chat_token_restrictions(
    openai_compat_mocks, test_conditions
):
    import json

    prompt_length, max_tokens, context_len, skips_request, check_lambda, err_msg = (
        TOKEN_LIMIT_EXPECTATIONS[test_conditions]
    )

    generator = build_test_instance(OpenAICompatible)
    generator._load_client()
    generator.max_tokens = max_tokens
    generator.context_len = context_len
    generator.generator = generator.client.chat.completions
    mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
    with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
        mock_response = openai_compat_mocks["chat"]
        respx_mock.post("chat/completions").mock(
            return_value=httpx.Response(
                mock_response["code"], json=mock_response["json"]
            )
        )
        prompt_text = create_prompt(prompt_length)
        if skips_request:
            resp = generator._call_model(prompt_text)
            assert not respx_mock.routes[0].called
            assert resp == [None]
        else:
            generator._call_model(prompt_text)
            req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
            assert check_lambda(req_body), err_msg


@pytest.mark.parametrize(
    "test_conditions",
    [
        key
        for key in TOKEN_LIMIT_EXPECTATIONS.keys()
        if key != "use_max_completion_tokens"
    ],
)
def test_validate_call_model_completion_token_restrictions(
    openai_compat_mocks, test_conditions
):
    import json

    prompt_length, max_tokens, context_len, skips_request, check_lambda, err_msg = (
        TOKEN_LIMIT_EXPECTATIONS[test_conditions]
    )

    generator = build_test_instance(OpenAICompatible)
    generator._load_client()
    generator.max_tokens = max_tokens
    generator.context_len = context_len
    generator.generator = generator.client.completions
    mock_url = getattr(generator, "uri", "https://api.openai.com/v1")
    with respx.mock(base_url=mock_url, assert_all_called=False) as respx_mock:
        mock_response = openai_compat_mocks["completion"]
        respx_mock.post("/completions").mock(
            return_value=httpx.Response(
                mock_response["code"], json=mock_response["json"]
            )
        )
        prompt_text = create_prompt(prompt_length)
        if skips_request:
            resp = generator._call_model(prompt_text)
            assert not respx_mock.routes[0].called
            assert resp == [None]
        else:
            generator._call_model(prompt_text)
            req_body = json.loads(respx_mock.routes[0].calls[0].request.content)
            assert check_lambda(req_body), err_msg
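For orientation, the parametrized cases above map onto the budget logic roughly as follows; this is a sketch with approximate numbers, since create_prompt overshoots the requested token count by up to one lorem paragraph and the 7-token fixed cost applies only on the chat path.

fixed_cost = 7  # reply-priming overhead applied by _validate_token_args on chat requests

# "use_max_completion_tokens" (chat): ~100-token prompt, 2048 requested, 4096-token context ->
# the request carries roughly 2048 - 100 - 7 as max_completion_tokens, satisfying the <= 2048 check.
approx_chat_limit = 2048 - 100 - fixed_cost  # ~1941

# "use_max_tokens" (completions): same numbers with no fixed cost, so roughly 1948 max_tokens.
approx_completion_limit = 2048 - 100  # ~1948

# "suppress_tokens": 4096 requested against a 2048-token context -> the limit is dropped,
# so neither max_tokens nor max_completion_tokens appears in the request body.

# "skip_request_*": a ~4096-token prompt leaves no room for any completion in a 4096-token
# window, so _call_model returns [None] and the mocked route is never called.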