1
1
"""Integration with OpenAI's API."""
2
+ import copy
2
3
import functools
3
- import warnings
4
4
from dataclasses import asdict , dataclass , field , replace
5
- from itertools import zip_longest
6
- from typing import Callable , Dict , List , Optional , Set , Tuple , Union
5
+ from typing import Callable , Dict , List , Optional , Tuple , Union
7
6
8
7
import numpy as np
9
8
@@ -74,7 +73,6 @@ def __init__(
74
73
self ,
75
74
client ,
76
75
config ,
77
- tokenizer = None ,
78
76
system_prompt : Optional [str ] = None ,
79
77
):
80
78
"""Create an `OpenAI` instance.
@@ -89,13 +87,9 @@ def __init__(
89
87
config
90
88
An instance of `OpenAIConfig`. Can be useful to specify some
91
89
parameters that cannot be set by calling this class' methods.
92
- tokenizer
93
- The tokenizer associated with the model the client connects to.
94
-
95
90
"""
96
91
97
92
self .client = client
98
- self .tokenizer = tokenizer
99
93
self .config = config
100
94
101
95
# We count the total number of prompt and generated tokens as returned
@@ -104,6 +98,8 @@ def __init__(
104
98
self .prompt_tokens = 0
105
99
self .completion_tokens = 0
106
100
101
+ self .format_sequence = lambda seq : seq
102
+
107
103
def __call__ (
108
104
self ,
109
105
prompt : Union [str , List [str ]],
@@ -152,107 +148,17 @@ def __call__(
152
148
self .prompt_tokens += prompt_tokens
153
149
self .completion_tokens += completion_tokens
154
150
155
- return response
151
+ return self . format_sequence ( response )
156
152
157
153
def stream(self, *args, **kwargs):
    """Raise immediately: streaming is not available for this integration."""
    message = "Streaming is currently not supported for the OpenAI API"
    raise NotImplementedError(message)
161
157
162
def generate_choice(
    self,
    prompt: str,
    choices: List[str],
    max_tokens: Optional[int] = None,
    system_prompt: Optional[str] = None,
) -> str:
    """Call the OpenAI API to generate one of several choices.

    The model is steered token-by-token with a `logit_bias` mask built from
    the tokenized choices: an "optimistic" mask biases every token of every
    remaining choice at once, while a "greedy" mask (used after the response
    diverged from all choices) only biases the candidates' next token.

    Parameters
    ----------
    prompt
        A string or list of strings that will be used to prompt the model
    choices
        The list of strings between which we ask the model to choose
    max_tokens
        The maximum number of tokens to generate
    system_prompt
        The content of the system message that precedes the user's prompt.

    Returns
    -------
    The generated choice, decoded back to a string.

    """
    if self.tokenizer is None:
        raise ValueError(
            "You must initialize the `OpenAI` class with a tokenizer to use `outlines.generate.choice`"
        )

    config = replace(self.config, max_tokens=max_tokens)

    greedy = False  # start with the "optimistic" masking strategy
    decoded: List[str] = []  # fragments of the choice generated so far
    # Token ids of each choice still compatible with what has been generated.
    encoded_choices_left: List[List[int]] = [
        self.tokenizer.encode(word) for word in choices
    ]

    while len(encoded_choices_left) > 0:
        # No remaining choice needs more tokens than the longest one left.
        max_tokens_left = max([len(tokens) for tokens in encoded_choices_left])

        # transposed_choices_left[i] is the set of i-th tokens across the
        # remaining choices (shorter choices contribute nothing at depth i).
        transposed_choices_left: List[Set] = [
            {item for item in subset if item is not None}
            for subset in zip_longest(*encoded_choices_left)
        ]

        if not greedy:
            mask = build_optimistic_mask(transposed_choices_left)
        else:
            mask = {}
            for token in transposed_choices_left[0]:  # build greedy mask
                # NOTE(review): 100 appears to be the maximum logit bias the
                # API accepts — confirm against the OpenAI API reference.
                mask[token] = 100

        if len(mask) == 0:
            break

        config = replace(config, logit_bias=mask, max_tokens=max_tokens_left)

        response, prompt_tokens, completion_tokens = generate_chat(
            prompt, system_prompt, self.client, config
        )
        self.prompt_tokens += prompt_tokens
        self.completion_tokens += completion_tokens

        encoded_response = self.tokenizer.encode(response)

        if encoded_response in encoded_choices_left:
            # The model generated one full remaining choice: we are done.
            decoded.append(response)
            break
        else:
            # Keep only the response prefix that matches some choice, and
            # narrow the candidates to the choices that start with it.
            (
                encoded_response,
                encoded_choices_left,
            ) = find_response_choices_intersection(
                encoded_response, encoded_choices_left
            )

            if len(encoded_response) == 0:
                greedy = True  # next iteration will be "greedy"
                continue
            else:
                decoded.append("".join(self.tokenizer.decode(encoded_response)))

                if len(encoded_choices_left) == 1:  # only one choice left
                    choice_left = self.tokenizer.decode(encoded_choices_left[0])
                    decoded.append(choice_left)
                    break

                greedy = False  # after each success, stay with (or switch to) "optimistic" approach

            # Re-prompt with everything decoded so far before the next call.
            prompt = prompt + "".join(decoded)

    choice = "".join(decoded)

    return choice
252
-
253
def generate_json(self):
    """Call the OpenAI API to generate a JSON object.

    Not implemented for this integration.
    """
    raise NotImplementedError
158
def new_with_replacements(self, **kwargs):
    """Return a shallow copy of this instance with config fields overridden.

    Parameters
    ----------
    **kwargs
        Fields of the `config` dataclass to replace on the copy; the
        original instance's config is left untouched.
    """
    clone = copy.copy(self)
    clone.config = replace(self.config, **kwargs)
    return clone
256
162
257
163
def __str__(self):
    """Return a short display name for this API client."""
    return f"{type(self).__name__} API"
@@ -313,81 +219,6 @@ async def call_api(prompt, system_prompt, config):
313
219
return results , usage ["prompt_tokens" ], usage ["completion_tokens" ]
314
220
315
221
316
def find_longest_intersection(response: List[int], choice: List[int]) -> List[int]:
    """Return the longest shared prefix of `response` and `choice`."""
    limit = min(len(response), len(choice))
    prefix_len = 0
    while prefix_len < limit and response[prefix_len] == choice[prefix_len]:
        prefix_len += 1
    return response[:prefix_len]
323
-
324
-
325
def find_response_choices_intersection(
    response: List[int], choices: List[List[int]]
) -> Tuple[List[int], List[List[int]]]:
    """Find the longest prefix of the response shared with any choice.

    For a response `[1, 2, 3, 4, 5]` and choices `[[1, 2], [1, 2, 3],
    [6, 7, 8]]` the longest shared prefix is `[1, 2, 3]`, and the choices
    left (with that prefix stripped) are `[[]]`.

    Parameters
    ----------
    response
        The model's response
    choices
        The remaining possible choices

    Returns
    -------
    A tuple with the longest shared prefix and the choices that start with
    it, each with the prefix removed.

    """
    # Longest shared prefix between the response and each choice.
    prefixes = [find_longest_intersection(response, choice) for choice in choices]
    best = max((len(prefix) for prefix in prefixes), default=0)

    # Every prefix is a prefix of `response`, so all maximal-length prefixes
    # are identical and equal to response[:best].
    longest_prefix = response[:best]
    choices_left = [
        choice[best:]
        for choice, prefix in zip(choices, prefixes)
        if len(prefix) == best
    ]

    return longest_prefix, choices_left
365
-
366
-
367
def build_optimistic_mask(
    transposed: List[Set[int]], max_mask_size: int = 300
) -> Dict[int, int]:
    """Build the largest possible logit-bias mask, left to right.

    Tokens earlier in the choices are added first: for encoded choices
    `[[1, 2], [3, 4]]`, tokens `1` and `3` enter the mask before `2` and
    `4`. At most `max_mask_size` entries are kept, each biased at 100.

    Parameters
    ----------
    transposed
        A list of sets that contain the nth token of each choice.
    max_mask_size
        The maximum number of entries allowed in the mask.

    """
    bias: Dict[int, int] = {}
    flattened = (token for tokens in transposed for token in tokens)
    for token in flattened:
        if len(bias) == max_mask_size:
            break
        bias[token] = 100

    return bias
389
-
390
-
391
222
def error_handler (api_call_fn : Callable ) -> Callable :
392
223
"""Handle OpenAI API errors and missing API key."""
393
224
@@ -427,11 +258,10 @@ def openai_model(
427
258
** openai_client_params ,
428
259
):
429
260
try :
430
- import tiktoken
431
261
from openai import AsyncOpenAI
432
262
except ImportError :
433
263
raise ImportError (
434
- "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' OpenAI integration."
264
+ "The `openai` library needs to be installed in order to use Outlines' OpenAI integration."
435
265
)
436
266
437
267
if config is not None :
@@ -441,15 +271,7 @@ def openai_model(
441
271
442
272
client = AsyncOpenAI (** openai_client_params )
443
273
444
- try :
445
- tokenizer = tiktoken .encoding_for_model (model_name )
446
- except KeyError :
447
- warnings .warn (
448
- f"Could not find a tokenizer for model { model_name } . Using default cl100k_base."
449
- )
450
- tokenizer = tiktoken .get_encoding ("cl100k_base" )
451
-
452
- return OpenAI (client , config , tokenizer )
274
+ return OpenAI (client , config )
453
275
454
276
455
277
def azure_openai (
@@ -459,11 +281,10 @@ def azure_openai(
459
281
** azure_openai_client_params ,
460
282
):
461
283
try :
462
- import tiktoken
463
284
from openai import AsyncAzureOpenAI
464
285
except ImportError :
465
286
raise ImportError (
466
- "The `openai` and `tiktoken` libraries needs to be installed in order to use Outlines' Azure OpenAI integration."
287
+ "The `openai` library needs to be installed in order to use Outlines' Azure OpenAI integration."
467
288
)
468
289
469
290
if config is not None :
@@ -473,12 +294,4 @@ def azure_openai(
473
294
474
295
client = AsyncAzureOpenAI (** azure_openai_client_params )
475
296
476
- try :
477
- tokenizer = tiktoken .encoding_for_model (model_name or deployment_name )
478
- except KeyError :
479
- warnings .warn (
480
- f"Could not find a tokenizer for model { model_name or deployment_name } . Using default cl100k_base."
481
- )
482
- tokenizer = tiktoken .get_encoding ("cl100k_base" )
483
-
484
- return OpenAI (client , config , tokenizer )
297
+ return OpenAI (client , config )
0 commit comments