Commit: Added annotations

oandreeva-nv committed Aug 27, 2024
1 parent 2a6dab5 commit 7964b3f
Showing 2 changed files with 75 additions and 6 deletions.
1 change: 1 addition & 0 deletions AI_Agents_Guide/Constrained_Decoding/README.md
@@ -179,6 +179,7 @@ prompt += "Here's the json schema you must adhere to:\n<schema>\n{schema}\n</sch
schema=AnswerFormat.model_json_schema())

```
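
For reference, the `{schema}` placeholder above is filled with the JSON schema that Pydantic generates from `AnswerFormat` (defined in `artifacts/utils.py`). A quick way to inspect that schema, assuming `utils.py` is on your import path:

```python
# Sketch: print the JSON schema that fills the {schema} placeholder above.
# Assumes artifacts/utils.py is importable from the working directory.
import json

from utils import AnswerFormat

print(json.dumps(AnswerFormat.model_json_schema(), indent=2))
```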
Let's try it out:

```bash
python3 /tutorials/AI_Agents_Guide/Constrained_Decoding/artifacts/client.py --prompt "Give me information about Harry Potter and the Order of Phoenix" -o 200 --use-system-prompt --use-schema
```
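
Because decoding is constrained to the `AnswerFormat` schema, the response is guaranteed to be JSON that Pydantic can parse. A minimal sketch of the downstream validation step; the JSON literal here is an illustrative stand-in, not captured model output:

```python
# Sketch: validate a constrained-decoding response against AnswerFormat.
# Assumes artifacts/utils.py is on the import path; the JSON below is an
# illustrative stand-in, not actual model output.
from utils import AnswerFormat

response = (
    '{"name": "Harry Potter", "house": "Gryffindor",'
    ' "blood_status": "Half-blood", "occupation": "Student",'
    ' "alive": "Yes",'
    ' "wand": {"wood": "Holly", "core": "Phoenix feather", "length": 11.0}}'
)

# model_validate_json raises pydantic.ValidationError on schema violations;
# constrained decoding is what makes this parse reliably.
answer = AnswerFormat.model_validate_json(response)
print(answer.wand.core)  # Phoenix feather
```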
80 changes: 74 additions & 6 deletions AI_Agents_Guide/Constrained_Decoding/artifacts/utils.py
@@ -1,3 +1,29 @@
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions
# are met:
# * Redistributions of source code must retain the above copyright
# notice, this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright
# notice, this list of conditions and the following disclaimer in the
# documentation and/or other materials provided with the distribution.
# * Neither the name of NVIDIA CORPORATION nor the names of its
# contributors may be used to endorse or promote products derived
# from this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

import json
import typing

Check notice (Code scanning / CodeQL): Module is imported with 'import' and 'import from'
Module 'typing' is imported with both 'import' and 'import from'.
from collections import defaultdict
@@ -14,12 +40,31 @@


class WandFormat(BaseModel):
"""Represents the format of a wand description.
Attributes:
wood (str): The type of wood used in the wand.
core (str): The core material of the wand.
length (float): The length of the wand.
"""

wood: str
core: str
length: float


class AnswerFormat(BaseModel):
"""Represents the output format, which LLM should follow.
Attributes:
name (str): The name of the person.
house (str): The house affiliation of the person (e.g., Gryffindor).
blood_status (str): The blood status (e.g., pure-blood).
occupation (str): The occupation of the person.
alive (str): Whether the person is alive.
wand (WandFormat): The wand information.
"""

name: str
house: str
blood_status: str
@@ -29,6 +74,10 @@ class AnswerFormat(BaseModel):


class LMFELogitsProcessor:
"""
The class implementing a logits post-processor via LM Format Enforcer.
"""

PROCESSOR_NAME = "lmfe"

def __init__(self, tokenizer_dir, schema):
@@ -37,6 +86,9 @@ def __init__(self, tokenizer_dir, schema):
)
self.eos_token = tokenizer.eos_token_id
tokenizer_data = build_trtlmm_tokenizer_data(tokenizer)
# TokenEnforcer provides a token filtering mechanism,
# given a tokenizer and a CharacterLevelParser.
# ref: https://github.com/noamgat/lm-format-enforcer/blob/fe6cbf107218839624e3ab39b47115bf7f64dd6e/lmformatenforcer/tokenenforcer.py#L32
self.token_enforcer = TokenEnforcer(tokenizer_data, JsonSchemaParser(schema))

def get_allowed_tokens(self, ids):
@@ -53,14 +105,21 @@ def __call__(
ids: typing.List[typing.List[int]],
stream_ptr: int,
):
# Create a mask with negative infinity to block all tokens initially.
mask = torch.full_like(logits, fill_value=float("-inf"), device=logits.device)
allowed = self.get_allowed_tokens(ids)
# Update the mask to zero for allowed tokens,
# allowing them to be selected.
mask[:, :, allowed] = 0
with torch.cuda.stream(torch.cuda.ExternalStream(stream_ptr)):
logits += mask
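
The additive mask used above is the core mechanism of both processors: disallowed tokens receive `-inf` (zero probability after softmax) while allowed tokens receive `0`, leaving their logits untouched. A self-contained sketch with a toy vocabulary, simplified to a flat `(vocab_size,)` tensor rather than TensorRT-LLM's `(batch, beam, vocab)` layout:

```python
# Self-contained sketch of additive logits masking with a toy vocabulary.
import torch

vocab_size = 8
logits = torch.randn(vocab_size)
allowed = [2, 5]  # token ids the grammar permits next

# -inf everywhere blocks every token; zeroing the allowed ids restores them.
mask = torch.full_like(logits, float("-inf"))
mask[allowed] = 0
constrained = logits + mask

# softmax now assigns non-zero probability only to the allowed tokens
probs = torch.softmax(constrained, dim=-1)
print(probs)  # non-zero only at indices 2 and 5
```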


class OutlinesLogitsProcessor:
"""
The class implementing a logits post-processor via Outlines.
"""

PROCESSOR_NAME = "outlines"

def __init__(self, tokenizer_dir, schema):
@@ -72,6 +131,9 @@ def __init__(self, tokenizer_dir, schema):
self.fsm = RegexGuide(regex_string, tokenizer)
self._fsm_state: DefaultDict[int, int] = defaultdict(int)
self.mask_cache: Dict[int, torch.Tensor] = {}
# By default, TensorRT-LLM includes the request query in the output.
# Outlines should only look at generated outputs, so we keep
# track of the request's input prefix.
self._prefix = [-1]

def __call__(
@@ -81,22 +143,24 @@ def __call__(
ids: typing.List[typing.List[int]],
stream_ptr: int,
):
# Initialize the FSM state dictionary if the input_ids are empty,
# as this means that the input_ids are the first tokens of the sequence.
seq_id = hash(tuple(ids[0]))

Check warning (Code scanning / CodeQL): Variable defined multiple times
This assignment to 'seq_id' is unnecessary as it is redefined before this value is used.
# If the prefix token IDs have changed we assume that we are dealing
# with a new sample and reset the FSM state
if (
ids[0][: len(self._prefix)] != self._prefix
# handling edge case, when the new request is identical to already
# processed
or len(ids[0][len(self._prefix) :]) == 0
):
self._fsm_state = defaultdict(int)
self._prefix = ids[0]
seq_id = hash(tuple([]))

else:
# Remove the prefix token IDs from the input token IDs,
# because the FSM should only be applied to the generated tokens
ids = ids[0][len(self._prefix) :]
last_token = ids[-1]
last_seq_id = hash(tuple(ids[:-1]))
@@ -110,9 +174,13 @@ def __call__(
allowed_tokens = self.fsm.get_next_instruction(
state=self._fsm_state[seq_id]
).tokens
# Create a mask with negative infinity to block all
# tokens initially.
mask = torch.full_like(
logits, fill_value=float("-inf"), device=logits.device
)
# Update the mask to zero for allowed tokens,
# allowing them to be selected.
mask[:, :, allowed_tokens] = 0
self.mask_cache[state_id] = mask
else:
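
The prefix bookkeeping in `OutlinesLogitsProcessor.__call__` is the subtlest part of this diff, so here is the same idea in isolation: a dependency-free sketch (toy token ids, with a hypothetical `generated_suffix` helper) of detecting a new request and stripping the prompt before the FSM sees it:

```python
# Pure-Python sketch of the prefix tracking used by OutlinesLogitsProcessor:
# TensorRT-LLM reports prompt + generated ids together, so the processor
# remembers the prompt ("prefix") and feeds only the suffix to the FSM.
# generated_suffix is a hypothetical helper for illustration.
prefix = [-1]  # sentinel: no request seen yet

def generated_suffix(ids):
    global prefix
    if ids[: len(prefix)] != prefix or len(ids[len(prefix):]) == 0:
        # prompt changed (or an identical request restarted): new sample
        prefix = list(ids)
        return []
    # same request: the FSM should only see tokens produced after the prompt
    return ids[len(prefix):]

print(generated_suffix([7, 8, 9]))        # [] -> new request, prompt recorded
print(generated_suffix([7, 8, 9, 4]))     # [4]
print(generated_suffix([7, 8, 9, 4, 2]))  # [4, 2]
print(generated_suffix([1, 2]))           # [] -> different prompt, reset
```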
