Commit 8a728b1

cpfiffer authored and rlouf committed

Add LogitTrackingProcessor

1 parent 9c98de7 commit 8a728b1

19 files changed (+1468, -12 lines)

Diff for: docs/reference/processors.md (+247 lines)

# Logit processors

Logit processors modify token probabilities during text generation to enforce constraints or analyze the generation process. While processors can be used directly, most users will interact with them through the high-level generation APIs (see [Generating JSON](generation/json.md), [Regex Generation](generation/regex.md), and [CFG Generation](generation/cfg.md)).

Users can track the token probabilities and logits at each step of generation using the `LogitTrackingProcessor`. This is useful for debugging and for understanding how constraints shape the output.

## Available Processors

Outlines provides several specialized processors for different use cases:

- `JSONLogitsProcessor`: Ensures generation follows a JSON schema
- `RegexLogitsProcessor`: Constrains generation to match a regex pattern
- `CFGLogitsProcessor`: Enforces a context-free grammar
- `LogitTrackingProcessor`: Tracks token probabilities and logits

### RegexLogitsProcessor

The `RegexLogitsProcessor` constrains generation to match a regular expression pattern:

```python
import outlines
from outlines.processors import RegexLogitsProcessor

# Assumes `model` and `tokenizer` are already defined

# Create a processor that only allows 4-digit numbers
processor = RegexLogitsProcessor(r"[0-9]{4}", tokenizer)

# Use with a generator
generator = outlines.generate.regex(model, r"[0-9]{4}")
generator.logits_processor = processor
```

See [Regex Generation](generation/regex.md) for more details and examples.

### JSONLogitsProcessor

The `JSONLogitsProcessor` ensures generation follows a JSON schema defined using Pydantic:

```python
import outlines
from pydantic import BaseModel
from outlines.processors import JSONLogitsProcessor

class Response(BaseModel):
    name: str
    age: int
    city: str

# Create processor from schema (assumes `tokenizer` is already defined)
processor = JSONLogitsProcessor(Response, tokenizer)

# Use with a generator (assumes `model` is already defined)
generator = outlines.generate.json(model, Response)
generator.logits_processor = processor
```

See [Generating JSON](generation/json.md) for more details and examples.

### CFGLogitsProcessor

The `CFGLogitsProcessor` constrains generation to follow a context-free grammar:

```python
import outlines
from outlines.processors import CFGLogitsProcessor

# Define a simple grammar
grammar = """
start: NUMBER "+" NUMBER "=" NUMBER
NUMBER: /[0-9]+/
"""

# Create processor from grammar (assumes `tokenizer` is already defined)
processor = CFGLogitsProcessor(grammar, tokenizer)

# Use with a generator (assumes `model` is already defined)
generator = outlines.generate.cfg(model, grammar)
generator.logits_processor = processor
```

See [CFG Generation](generation/cfg.md) for more details and examples.

## Tracking logit scores and token probabilities

The `LogitTrackingProcessor` wraps any processor to track logit scores and token probabilities before and after processing. This is useful for:

- Debugging logit processors by analyzing how they modify token probabilities
- Visualizing the effects of logit biasing on token distributions
- Understanding how constraints affect the generation process
- Validating that processors are working as intended
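
At each step, the wrapper records both the raw ("unstructured") logits it receives and the constrained ("structured") logits produced by the wrapped processor. As a minimal sketch of the wrapping itself (assuming `tokenizer` is an Outlines tokenizer; see the last section for a complete runnable example):

```python
from outlines.processors import RegexLogitsProcessor
from outlines.processors.tracking import LogitTrackingProcessor

# Wrap any processor to record its input and output distributions
tracking = LogitTrackingProcessor(RegexLogitsProcessor(r"[0-9]{4}", tokenizer))
```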

### Adding tracking to a generator

The simplest way to add tracking is using the convenience function `track_logits`:

```python
from outlines import generate, models
from outlines.processors import track_logits
from pydantic import BaseModel

# Define your schema
class Person(BaseModel):
    name: str
    age: int

# Create generator with tracking
model = models.transformers("HuggingFaceTB/SmolLM2-135M-Instruct")
generator = generate.json(model, Person)
generator = track_logits(generator)  # Enable tracking

# Apply templating if needed
prompt = model.tokenizer.tokenizer.apply_chat_template(
    [{"role": "system", "content": "You are a helpful assistant, responding in JSON."},
     {"role": "user", "content": "Make me a person with a name and age. Return the JSON only."}],
    tokenize=False,
    add_bos=True,
    add_generation_prompt=True,
)

# Generate the response
response = generator(prompt)
```

**NOTE**: You __must__ call `generator.logits_processor.clear()` between generations, otherwise the processor will keep accumulating logits from the previous generation. Alternatively, construct a new generator and call `track_logits` again to start tracking from scratch.
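
For example, a minimal sketch that resets tracking before each new prompt (`prompts` here is a placeholder list of templated prompts):

```python
prompts = ["<templated prompt 1>", "<templated prompt 2>"]  # placeholders

for prompt in prompts:
    generator.logits_processor.clear()  # drop data from the previous generation
    response = generator(prompt)
```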

### Analyzing generation results

Once tracking is enabled, you can analyze the generation process in several ways:

1. Get the logits and probabilities at each position as a matrix:

```python
# Raw logits as a dictionary with two keys: unstructured and structured
logits = generator.logits_processor.get_logits()

# Get a vocab_size x n_positions matrix of logits for
# structured and unstructured logits
unstructured_logits = logits['unstructured']
structured_logits = logits['structured']

probabilities = generator.logits_processor.get_probabilities()

# Get a vocab_size x n_positions matrix of probabilities
# for structured and unstructured logits
unstructured_probs = probabilities['unstructured']
structured_probs = probabilities['structured']
```

2. Get the top tokens at each position:

```python
# Get top 5 tokens at each position
top_k = generator.logits_processor.get_top_tokens(k=5)

# Analyze each position
for position_dict in top_k:
    print(f"\nPosition {position_dict['position']}:")
    print(f"Text so far: {position_dict['text_so_far']}")

    for token in position_dict['tokens']:
        print(f"\nToken: {token['token']}")
        print(f"Unstructured probability: {token['unstructured_prob']:.3f}")
        print(f"Structured probability: {token['structured_prob']:.3f}")
        print(f"Unstructured logit: {token['unstructured_logit']:.3f}")
        print(f"Structured logit: {token['structured_logit']:.3f}")
        print(f"Was chosen: {token['is_chosen']}")
```

3. Convert to a pandas DataFrame for analysis (a combined example follows this list):

```python
import pandas as pd

# Get all tokens with probability > 1%
df = generator.logits_processor.to_dataframe(show="probs", min_value=0.01)
print(df)
#    position token   natural  constrained  chosen
# 0         0   You  0.021324          0.0   False
# 1         0   The  0.021959          0.0   False
# 2         0  Sure  0.025492          0.0   False
# 3         0  JSON  0.031045          0.0   False
# 4         0    To  0.031047          0.0   False
```

4. Get the generated sequence up to a position:

```python
# Get text generated up to position 5
text = generator.logits_processor.sequence(5)
```
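
Putting these together, here is a sketch that flags the positions where the constraint moved the most probability mass onto the chosen token. It assumes a tracked `generator` as above and uses the `natural`, `constrained`, and `chosen` columns shown in the sample output:

```python
df = generator.logits_processor.to_dataframe(show="probs")

# Chosen tokens that were unlikely without the constraint but likely with
# it: the places where structured generation did the most work
chosen = df[df["chosen"]]
shifted = chosen[(chosen["constrained"] - chosen["natural"]) > 0.5]
print(shifted[["position", "token", "natural", "constrained"]])
```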

### Memory management

The tracking processor stores logits in memory for analysis, offloading them to main memory when you use a GPU. For long sequences, you have several options:

1. Clear tracking data when it is no longer needed:

```python
generator.logits_processor.clear()
```

2. Filter data when analyzing:

```python
# Only analyze specific positions
results = generator.logits_processor.get_top_tokens(positions=[0, 1, 2])

# Only look at high probability tokens
df = generator.logits_processor.to_dataframe(show="probs", min_value=0.01)
```

### Important notes about logit tracking

- Tracking logits incurs significant overhead, so do not use it in production environments
- The processor accumulates logits every time you call `generator(prompt)`, so stored tokens can be aggregated across generations. Use `generator.logits_processor.clear()` to reset the processor, or construct a new generator and call `track_logits` again to start tracking from scratch
- Processed logits will contain `-inf` values when structured outputs are used
- Token decoding requires the wrapped processor to have a tokenizer attribute
- Memory usage grows linearly with sequence length
- The tracking processor only supports single-batch processing
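
The `-inf` masking mentioned above can be checked directly; a short sketch, again assuming a tracked `generator`:

```python
import numpy as np

logits = generator.logits_processor.get_logits()

# Tokens disallowed by the constraint are masked to -inf at each position
n_masked = np.isinf(logits["structured"]).sum(axis=0)
print(n_masked)  # number of masked vocabulary entries per position
```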

## Using the tracking processor directly

The tracking processor can be used directly with transformers pipelines:

```python
import transformers

import outlines.models as models
from outlines.processors import RegexLogitsProcessor
from outlines.processors.tracking import LogitTrackingProcessor

model_uri = "HuggingFaceTB/SmolLM2-135M-Instruct"

outlines_tokenizer = models.TransformerTokenizer(
    transformers.AutoTokenizer.from_pretrained(model_uri)
)
phone_number_logits_processor = LogitTrackingProcessor(RegexLogitsProcessor(
    r"\+?[1-9][0-9]{7,14}",  # phone number pattern
    outlines_tokenizer,
))

generator = transformers.pipeline('text-generation', model=model_uri)

# Perform inference
output = generator(
    "Jenny gave me her number it's ",
    logits_processor=transformers.LogitsProcessorList([phone_number_logits_processor])
)

# Retrieve the logits
phone_number_logits_processor.get_logits()
```
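
The analysis helpers shown earlier work the same way here; for instance, a quick check of the recorded matrices (shapes follow the vocab_size x n_positions convention described above):

```python
tracked = phone_number_logits_processor.get_logits()
print(tracked["unstructured"].shape)  # vocab_size x n_positions
print(tracked["structured"].shape)
```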
