
Commit 5591950

lapp0 authored and rlouf committed
Use OpenAI API For Structured Generation (json, choice)
1 parent 2b1aed0 commit 5591950

File tree

5 files changed: +134 −261 lines changed


outlines/generate/choice.py

Lines changed: 19 additions & 8 deletions
@@ -1,10 +1,12 @@
+import json as pyjson
 from functools import singledispatch
 from typing import Callable, List
 
 from outlines.generate.api import SequenceGeneratorAdapter
 from outlines.models import OpenAI
 from outlines.samplers import Sampler, multinomial
 
+from .json import json
 from .regex import regex
 
 
@@ -24,13 +26,22 @@ def choice(
 def choice_openai(
     model: OpenAI, choices: List[str], sampler: Sampler = multinomial()
 ) -> Callable:
-    if not isinstance(sampler, multinomial):
-        raise NotImplementedError(
-            r"The OpenAI API does not support any other sampling algorithm "
-            + "that the multinomial sampler."
-        )
-
-    def generate_choice(prompt: str, max_tokens: int = 1):
-        return model.generate_choice(prompt, choices, max_tokens)
+    """
+    Call OpenAI API with response_format of a dict:
+    {"result": <one of choices>}
+    """
+
+    choices_schema = pyjson.dumps(
+        {
+            "type": "object",
+            "properties": {"result": {"type": "string", "enum": choices}},
+            "additionalProperties": False,
+            "required": ["result"],
+        }
+    )
+    generator = json(model, choices_schema, sampler)
+
+    def generate_choice(*args, **kwargs):
+        return generator(*args, **kwargs)["result"]
 
     return generate_choice
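
With this change, choice generation over the OpenAI API no longer steers the chat endpoint token by token with logit biases: `choice_openai` wraps the allowed values in a one-field JSON Schema, delegates to the `json` generator, and unwraps the `"result"` key from the parsed response. A minimal usage sketch, assuming the public `outlines.models.openai` and `outlines.generate.choice` entry points; the model name and prompt are illustrative, not part of this commit:

import outlines

# Illustrative model name, not pinned by this commit; any chat model
# that supports OpenAI's structured outputs should behave the same way.
model = outlines.models.openai("gpt-4o-mini")
generator = outlines.generate.choice(model, ["Positive", "Negative"])

# Internally the request carries a schema equivalent to
# {"result": {"type": "string", "enum": ["Positive", "Negative"]}}
# and the wrapper returns generator(...)["result"].
label = generator("Classify the sentiment: 'The soup was cold.'")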

outlines/generate/json.py

Lines changed: 34 additions & 4 deletions
@@ -70,9 +70,39 @@ def json(
 
 @json.register(OpenAI)
 def json_openai(
-    model, schema_object: Union[str, object, Callable], sampler: Sampler = multinomial()
+    model, schema_object: Union[str, object], sampler: Sampler = multinomial()
 ):
-    raise NotImplementedError(
-        "Cannot use JSON Schema-structure generation with an OpenAI model "
-        + "due to the limitations of the OpenAI API"
+    if not isinstance(sampler, multinomial):
+        raise NotImplementedError(
+            r"The OpenAI API does not support any other sampling algorithm "
+            + "than the multinomial sampler."
+        )
+
+    if isinstance(schema_object, type(BaseModel)):
+        schema = pyjson.dumps(schema_object.model_json_schema())
+        format_sequence = lambda x: schema_object.parse_raw(x)
+    elif isinstance(schema_object, str):
+        schema = schema_object
+        format_sequence = lambda x: pyjson.loads(x)
+    else:
+        raise ValueError(
+            f"Cannot parse schema {schema_object}. The schema must be either "
+            + "a Pydantic object, a function or a string that contains the JSON "
+            + "Schema specification"
+        )
+
+    # create copied, patched model with normalized json schema set
+    generator = model.new_with_replacements(
+        response_format={
+            "type": "json_schema",
+            "json_schema": {
+                "name": "default",
+                "strict": True,
+                "schema": pyjson.loads(schema),
+            },
+        }
     )
+
+    generator.format_sequence = format_sequence
+
+    return generator
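
`json_openai` now accepts either a Pydantic model class or a JSON Schema string, forwards the schema through OpenAI's strict `json_schema` response format, and attaches a `format_sequence` hook so the raw completion is parsed back into the caller's type. A hedged sketch of the intended call pattern; the `User` model, model name, and prompt are invented for illustration:

from pydantic import BaseModel

import outlines

class User(BaseModel):
    # Invented example schema, not part of this commit
    name: str
    age: int

model = outlines.models.openai("gpt-4o-mini")  # illustrative model name
generator = outlines.generate.json(model, User)

# The patched model copy carries response_format={"type": "json_schema",
# "json_schema": {"name": "default", "strict": True, "schema": ...}} and
# format_sequence parses the completion back into a User instance.
user = generator("Extract the user: Alice is 31 years old.")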

outlines/models/openai.py

Lines changed: 13 additions & 200 deletions
@@ -1,9 +1,8 @@
 """Integration with OpenAI's API."""
+import copy
 import functools
-import warnings
 from dataclasses import asdict, dataclass, field, replace
-from itertools import zip_longest
-from typing import Callable, Dict, List, Optional, Set, Tuple, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 
@@ -74,7 +73,6 @@ def __init__(
         self,
         client,
         config,
-        tokenizer=None,
         system_prompt: Optional[str] = None,
     ):
         """Create an `OpenAI` instance.
@@ -89,13 +87,9 @@ def __init__(
         config
             An instance of `OpenAIConfig`. Can be useful to specify some
             parameters that cannot be set by calling this class' methods.
-        tokenizer
-            The tokenizer associated with the model the client connects to.
-
         """
 
         self.client = client
-        self.tokenizer = tokenizer
         self.config = config
 
         # We count the total number of prompt and generated tokens as returned
@@ -104,6 +98,8 @@ def __init__(
         self.prompt_tokens = 0
         self.completion_tokens = 0
 
+        self.format_sequence = lambda seq: seq
+
     def __call__(
         self,
         prompt: Union[str, List[str]],
@@ -152,107 +148,17 @@ def __call__(
         self.prompt_tokens += prompt_tokens
         self.completion_tokens += completion_tokens
 
-        return response
+        return self.format_sequence(response)
 
     def stream(self, *args, **kwargs):
         raise NotImplementedError(
             "Streaming is currently not supported for the OpenAI API"
         )
 
-    def generate_choice(
-        self,
-        prompt: str,
-        choices: List[str],
-        max_tokens: Optional[int] = None,
-        system_prompt: Optional[str] = None,
-    ) -> str:
-        """Call the OpenAI API to generate one of several choices.
-
-        Parameters
-        ----------
-        prompt
-            A string or list of strings that will be used to prompt the model
-        choices
-            The list of strings between which we ask the model to choose
-        max_tokens
-            The maximum number of tokens to generate
-        system_prompt
-            The content of the system message that precedes the user's prompt.
-
-        """
-        if self.tokenizer is None:
-            raise ValueError(
-                "You must initialize the `OpenAI` class with a tokenizer to use `outlines.generate.choice`"
-            )
-
-        config = replace(self.config, max_tokens=max_tokens)
-
-        greedy = False
-        decoded: List[str] = []
-        encoded_choices_left: List[List[int]] = [
-            self.tokenizer.encode(word) for word in choices
-        ]
-
-        while len(encoded_choices_left) > 0:
-            max_tokens_left = max([len(tokens) for tokens in encoded_choices_left])
-            transposed_choices_left: List[Set] = [
-                {item for item in subset if item is not None}
-                for subset in zip_longest(*encoded_choices_left)
-            ]
-
-            if not greedy:
-                mask = build_optimistic_mask(transposed_choices_left)
-            else:
-                mask = {}
-                for token in transposed_choices_left[0]:  # build greedy mask
-                    mask[token] = 100
-
-            if len(mask) == 0:
-                break
-
-            config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)
-
-            response, prompt_tokens, completion_tokens = generate_chat(
-                prompt, system_prompt, self.client, config
-            )
-            self.prompt_tokens += prompt_tokens
-            self.completion_tokens += completion_tokens
-
-            encoded_response = self.tokenizer.encode(response)
-
-            if encoded_response in encoded_choices_left:
-                decoded.append(response)
-                break
-            else:
-                (
-                    encoded_response,
-                    encoded_choices_left,
-                ) = find_response_choices_intersection(
-                    encoded_response, encoded_choices_left
-                )
-
-                if len(encoded_response) == 0:
-                    greedy = True  # next iteration will be "greedy"
-                    continue
-                else:
-                    decoded.append("".join(self.tokenizer.decode(encoded_response)))
-
-                    if len(encoded_choices_left) == 1:  # only one choice left
-                        choice_left = self.tokenizer.decode(encoded_choices_left[0])
-                        decoded.append(choice_left)
-                        break
-
-                    greedy = False  # after each success, stay with (or switch to) "optimistic" approach
-
-        prompt = prompt + "".join(decoded)
-
-        choice = "".join(decoded)
-
-        return choice
-
-    def generate_json(self):
-        """Call the OpenAI API to generate a JSON object."""
-        raise NotImplementedError
+    def new_with_replacements(self, **kwargs):
+        new_instance = copy.copy(self)
+        new_instance.config = replace(new_instance.config, **kwargs)
+        return new_instance
 
     def __str__(self):
         return self.__class__.__name__ + " API"
 
@@ -313,81 +219,6 @@ async def call_api(prompt, system_prompt, config):
     return results, usage["prompt_tokens"], usage["completion_tokens"]
 
 
-def find_longest_intersection(response: List[int], choice: List[int]) -> List[int]:
-    """Find the longest intersection between the response and the choice."""
-    for i, (token_r, token_c) in enumerate(zip_longest(response, choice)):
-        if token_r != token_c:
-            return response[:i]
-
-    return response
-
-
-def find_response_choices_intersection(
-    response: List[int], choices: List[List[int]]
-) -> Tuple[List[int], List[List[int]]]:
-    """Find the longest intersection between the response and the different
-    choices.
-
-    Say the response is of the form `[1, 2, 3, 4, 5]` and we have the choices
-    `[[1, 2], [1, 2, 3], [6, 7, 8]` then the function will return `[1, 2, 3]` as the
-    intersection, and `[[]]` as the list of choices left.
-
-    Parameters
-    ----------
-    response
-        The model's response
-    choices
-        The remaining possible choices
-
-    Returns
-    -------
-    A tuple that contains the longest intersection between the response and the
-    different choices, and the choices which start with this intersection, with the
-    intersection removed.
-
-    """
-    max_len_prefix = 0
-    choices_left = []
-    longest_prefix = []
-    for i, choice in enumerate(choices):
-        # Find the longest intersection between the response and the choice.
-        prefix = find_longest_intersection(response, choice)
-
-        if len(prefix) > max_len_prefix:
-            max_len_prefix = len(prefix)
-            choices_left = [choice[len(prefix) :]]
-            longest_prefix = prefix
-
-        elif len(prefix) == max_len_prefix:
-            choices_left.append(choice[len(prefix) :])
-
-    return longest_prefix, choices_left
-
-
-def build_optimistic_mask(
-    transposed: List[Set[int]], max_mask_size: int = 300
-) -> Dict[int, int]:
-    """We build the largest mask possible.
-
-    Tokens are added from left to right, so if the encoded choices are e.g.
-    `[[1,2], [3,4]]`, `1` and `3` will be added before `2` and `4`.
-
-    Parameters
-    ----------
-    transposed
-        A list of lists that contain the nth token of each choice.
-
-    """
-    mask: Dict[int, int] = {}
-    for tokens in transposed:
-        for token in tokens:
-            if len(mask) == max_mask_size:
-                return mask
-            mask[token] = 100
-
-    return mask
-
-
 def error_handler(api_call_fn: Callable) -> Callable:
     """Handle OpenAI API errors and missing API key."""
 
@@ -427,11 +258,10 @@ def openai_model(
     **openai_client_params,
 ):
     try:
-        import tiktoken
         from openai import AsyncOpenAI
     except ImportError:
         raise ImportError(
-            "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' OpenAI integration."
+            "The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
         )
 
     if config is not None:
@@ -441,15 +271,7 @@
 
     client = AsyncOpenAI(**openai_client_params)
 
-    try:
-        tokenizer = tiktoken.encoding_for_model(model_name)
-    except KeyError:
-        warnings.warn(
-            f"Could not find a tokenizer for model {model_name}. Using default cl100k_base."
-        )
-        tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    return OpenAI(client, config, tokenizer)
+    return OpenAI(client, config)
 
 
 def azure_openai(
@@ -459,11 +281,10 @@ def azure_openai(
     **azure_openai_client_params,
 ):
     try:
-        import tiktoken
         from openai import AsyncAzureOpenAI
     except ImportError:
         raise ImportError(
-            "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' Azure OpenAI integration."
+            "The `openai` library needs to be installed in order to use Outlines' Azure OpenAI integration."
        )
 
     if config is not None:
@@ -473,12 +294,4 @@
 
     client = AsyncAzureOpenAI(**azure_openai_client_params)
 
-    try:
-        tokenizer = tiktoken.encoding_for_model(model_name or deployment_name)
-    except KeyError:
-        warnings.warn(
-            f"Could not find a tokenizer for model {model_name or deployment_name}. Using default cl100k_base."
-        )
-        tokenizer = tiktoken.get_encoding("cl100k_base")
-
-    return OpenAI(client, config, tokenizer)
+    return OpenAI(client, config)
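
The `new_with_replacements` helper that replaces the deleted tokenizer machinery is a copy-on-write pattern: shallow-copy the wrapper, then rebuild its dataclass config with `dataclasses.replace`, so the base model instance is never mutated. A self-contained sketch of the idea, using simplified stand-in names rather than the real `OpenAI`/`OpenAIConfig` classes:

import copy
from dataclasses import dataclass, replace
from typing import Optional

@dataclass(frozen=True)
class Config:
    # Stand-in for OpenAIConfig; the real field set is larger.
    model: str = "gpt-4o-mini"
    response_format: Optional[dict] = None

class Wrapper:
    # Stand-in for the OpenAI model wrapper.
    def __init__(self, config: Config):
        self.config = config

    def new_with_replacements(self, **kwargs):
        # Shallow copy, then swap the requested config fields;
        # the original wrapper and its config are never mutated.
        new_instance = copy.copy(self)
        new_instance.config = replace(new_instance.config, **kwargs)
        return new_instance

base = Wrapper(Config())
patched = base.new_with_replacements(
    response_format={"type": "json_schema", "json_schema": {"name": "default"}}
)
assert base.config.response_format is None  # base unchanged
assert patched.config.response_format is not None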

pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -122,7 +122,6 @@ module = [
     "pydantic.*",
     "pytest",
     "referencing.*",
-    "tiktoken.*",
     "torch.*",
     "transformers.*",
     "llama_cpp",
