|
8 | 8 | from ._chat import Chat |
9 | 9 | from ._content import ( |
10 | 10 | Content, |
| 11 | + ContentImage, |
11 | 12 | ContentImageInline, |
12 | 13 | ContentImageRemote, |
13 | 14 | ContentJson, |
|
20 | 21 | from ._provider import Provider |
21 | 22 | from ._tokens import tokens_log |
22 | 23 | from ._tools import Tool, basemodel_to_param_schema |
23 | | -from ._turn import Turn, normalize_turns |
| 24 | +from ._turn import Turn, normalize_turns, user_turn |
24 | 25 | from ._utils import MISSING, MISSING_TYPE, is_testing |
25 | 26 |
|
26 | 27 | if TYPE_CHECKING: |
@@ -351,6 +352,57 @@ async def stream_turn_async(self, completion, has_data_model, stream): |
351 | 352 | def value_turn(self, completion, has_data_model) -> Turn: |
352 | 353 | return self._as_turn(completion, has_data_model) |
353 | 354 |
|
| 355 | + def token_count( |
| 356 | + self, |
| 357 | + *args: Content | str, |
| 358 | + tools: dict[str, Tool], |
| 359 | + data_model: Optional[type[BaseModel]], |
| 360 | + ) -> int: |
| 361 | + try: |
| 362 | + import tiktoken |
| 363 | + except ImportError: |
| 364 | + raise ImportError( |
| 365 | + "The tiktoken package is required for token counting. " |
| 366 | + "Please install it with `pip install tiktoken`." |
| 367 | + ) |
| 368 | + |
| 369 | + encoding = tiktoken.encoding_for_model(self._model) |
| 370 | + |
| 371 | + turn = user_turn(*args) |
| 372 | + |
| 373 | + # Count the tokens in image contents |
| 374 | + image_tokens = sum( |
| 375 | + self._image_token_count(x) |
| 376 | + for x in turn.contents |
| 377 | + if isinstance(x, ContentImage) |
| 378 | + ) |
| 379 | + |
| 380 | + # For other contents, get the token count from the actual message param |
| 381 | + other_contents = [x for x in turn.contents if not isinstance(x, ContentImage)] |
| 382 | + other_full = self._as_message_param([Turn("user", other_contents)]) |
| 383 | + other_tokens = len(encoding.encode(str(other_full))) |
| 384 | + |
| 385 | + return other_tokens + image_tokens |
| 386 | + |
| 387 | + async def token_count_async( |
| 388 | + self, |
| 389 | + *args: Content | str, |
| 390 | + tools: dict[str, Tool], |
| 391 | + data_model: Optional[type[BaseModel]], |
| 392 | + ) -> int: |
| 393 | + return self.token_count(*args, tools=tools, data_model=data_model) |
| 394 | + |
| 395 | + @staticmethod |
| 396 | + def _image_token_count(image: ContentImage) -> int: |
| 397 | + if isinstance(image, ContentImageRemote) and image.detail == "low": |
| 398 | + return 85 |
| 399 | + else: |
| 400 | + # This is just the max token count for an image The highest possible |
| 401 | + # resolution is 768 x 2048, and 8 tiles of size 512px can fit inside |
| 402 | + # TODO: this is obviously a very conservative estimate and could be improved |
| 403 | + # https://platform.openai.com/docs/guides/vision/calculating-costs |
| 404 | + return 170 * 8 + 85 |
| 405 | + |
354 | 406 | @staticmethod |
355 | 407 | def _as_message_param(turns: list[Turn]) -> list["ChatCompletionMessageParam"]: |
356 | 408 | from openai.types.chat import ( |
|
0 commit comments