Skip to content

Commit 4870d18

Browse files
kashifpcuenca
and authored
Logit warpers (#273)
* inital * formatting * added generation integration tests * use MLTensor * fix top-p and do-sample * fix CI issue * undo changes to LanguageModel.swift * use MLTensor * Update Tests/GenerationTests/GenerationIntegrationTests.swift Co-authored-by: Pedro Cuenca <[email protected]> * Throws: If penalty is not strictly positive * remove unused selectNextTokenUsingTopKSampling * make sure ordering of warpers is that of transformers * throw * add Min-P --------- Co-authored-by: Pedro Cuenca <[email protected]>
1 parent cda4025 commit 4870d18

13 files changed

+1336
-31
lines changed

Examples/transformers-cli/Sources/transformers-cli/Transformers.swift

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,28 @@ struct TransformersCLI: AsyncParsableCommand {
2727
@Option(
2828
help: """
2929
When enabled, two generation passes are ran, one to 'warm up' and another to collect \
30-
benchmark metrics.
30+
benchmark metrics.
3131
""")
3232
var warmup: Bool = false
3333

34+
@Option(help: "Enable sampling mode (true) or use greedy decoding (false)")
35+
var doSample: Bool = false
36+
37+
@Option(help: "Temperature for sampling (lower = more deterministic, typical: 0.1-2.0)")
38+
var temperature: Float?
39+
40+
@Option(help: "Top-k filtering - only consider k most likely tokens (typical: 5-50)")
41+
var topK: Int?
42+
43+
@Option(help: "Top-p (nucleus) sampling - cumulative probability threshold (typical: 0.9-0.95)")
44+
var topP: Float?
45+
46+
@Option(help: "Min-p sampling - minimum probability threshold scaled by top token (typical: 0.01-0.2)")
47+
var minP: Float?
48+
49+
@Option(help: "Repetition penalty to discourage repeating tokens (typical: 1.0-2.0, 1.0 = no penalty)")
50+
var repetitionPenalty: Float?
51+
3452
func generate(
3553
model: LanguageModel,
3654
config: GenerationConfig,
@@ -88,11 +106,26 @@ struct TransformersCLI: AsyncParsableCommand {
88106
print("Loading model \(compiledURL)")
89107
let model = try LanguageModel.loadCompiled(url: compiledURL, computeUnits: computeUnits.asMLComputeUnits)
90108

91-
// Using greedy generation for now
92109
var config = model.defaultGenerationConfig
93-
config.doSample = false
110+
config.doSample = doSample
94111
config.maxNewTokens = maxLength
95112

113+
if let temperature = temperature {
114+
config.temperature = temperature
115+
}
116+
if let topK = topK {
117+
config.topK = topK
118+
}
119+
if let topP = topP {
120+
config.topP = topP
121+
}
122+
if let minP = minP {
123+
config.minP = minP
124+
}
125+
if let repetitionPenalty = repetitionPenalty {
126+
config.repetitionPenalty = repetitionPenalty
127+
}
128+
96129
// Given the size of the out-of-model computation, dispatch all
97130
// tensor operations to the CPU.
98131

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ let package = Package(
2424
.target(name: "Hub", dependencies: [.product(name: "Jinja", package: "swift-jinja")], resources: [.process("Resources")], swiftSettings: swiftSettings),
2525
.target(name: "Models", dependencies: ["Tokenizers", "Generation"]),
2626
.target(name: "Tokenizers", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")]),
27+
.testTarget(name: "GenerationTests", dependencies: ["Generation"]),
2728
.testTarget(name: "HubTests", dependencies: ["Hub", .product(name: "Jinja", package: "swift-jinja")], swiftSettings: swiftSettings),
2829
.testTarget(name: "ModelsTests", dependencies: ["Models", "Hub"], resources: [.process("Resources")]),
2930
.testTarget(name: "TokenizersTests", dependencies: ["Tokenizers", "Models", "Hub"], resources: [.process("Resources")]),

Sources/Generation/Decoders.swift

Lines changed: 42 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -5,24 +5,51 @@ import CoreML
55

66
@available(macOS 15.0, iOS 18.0, *)
77
func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
8-
scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
8+
let indices = scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
9+
// Ensure indices are Int32 for concatenation with input tokens
10+
return indices.scalarType == Int32.self ? indices : indices.cast(to: Int32.self)
911
}
1012

11-
// MARK: Top-K Sampling
13+
// MARK: Sampling
1214

15+
/// Performs multinomial sampling from processed logits.
16+
///
17+
/// Assumes logits have already been processed by LogitsProcessorList
18+
/// (temperature, top-k, top-p, etc. already applied).
19+
///
20+
/// - Parameter scores: Processed logits tensor [batch_size, vocab_size]
21+
/// - Returns: Sampled token ID tensor [batch_size, 1]
1322
@available(macOS 15.0, iOS 18.0, *)
14-
func selectNextTokenUsingTopKSampling(from scores: MLTensor, temperature: Float, topK: Int) -> MLTensor {
15-
let temperatureAdjustedScores = scores / temperature
16-
let (topKScores, topKIndices) = temperatureAdjustedScores.topK(topK)
17-
let topKProbs = topKScores.softmax(alongAxis: -1)
18-
let rnd = topKProbs.sum() * Float.random(in: 0..<1)
19-
var accumTopKProbs = topKProbs.cumulativeSum(alongAxis: -1)
20-
accumTopKProbs += (accumTopKProbs .< rnd) * 100.0
21-
let topKIndex = accumTopKProbs.argsort()[..., 0]
22-
let nextTokenTensor = topKIndices.gathering(
23-
atIndices: topKIndex,
24-
alongAxis: topKIndices.rank - 1
25-
)
26-
return nextTokenTensor.reshaped(to: [1, 1])
23+
func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor {
24+
// Convert logits to probabilities
25+
let probs = scores.softmax(alongAxis: -1)
26+
27+
// Multinomial sampling using cumulative sum method:
28+
// 1. Generate random number in [0, 1)
29+
// 2. Compute cumulative sum of probabilities
30+
// 3. Find first index where cumsum >= random_number
31+
//
32+
// This is equivalent to torch.multinomial() but using available MLTensor ops
33+
34+
let batchSize = scores.shape[0]
35+
let rndTensor = MLTensor(randomUniform: [batchSize, 1], in: 0..<1, scalarType: Float.self)
36+
let cumulativeProbs = probs.cumulativeSum(alongAxis: -1)
37+
38+
// Ensure random tensor matches the type of cumulativeProbs
39+
let rnd = cumulativeProbs.scalarType == Float.self ? rndTensor : rndTensor.cast(to: cumulativeProbs.scalarType)
40+
41+
// Create mask where cumsum >= rnd (these are candidates)
42+
// We want the FIRST position where this is true
43+
// Strategy: Set all positions where cumsum < rnd to a large value (1000.0)
44+
// Set all positions where cumsum >= rnd to their index value
45+
// Then argmin will give us the first qualifying index
46+
47+
let mask = cumulativeProbs .< rnd
48+
let penalized = mask * 1000.0 // Large value for positions to skip
49+
let indexed = penalized + cumulativeProbs // Positions >= rnd will have small values
50+
51+
let sampledIndex = indexed.argmin(alongAxis: -1).reshaped(to: [1, 1])
52+
// Ensure indices are Int32 for concatenation with input tokens
53+
return sampledIndex.scalarType == Int32.self ? sampledIndex : sampledIndex.cast(to: Int32.self)
2754
}
2855
#endif // canImport(CoreML)

Sources/Generation/Generation.swift

Lines changed: 63 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,18 +72,27 @@ extension Generation {
7272
) async -> GenerationOutput {
7373
let tokens = tokens.map { Int32($0) }
7474
var outputTokens = MLTensor(tokens).expandingShape(at: 0)
75-
while outputTokens.shape[1] < config.maxLength {
75+
76+
// Create logits processor list based on config
77+
let logitsProcessorList = createLogitsProcessorList(config: config)
78+
79+
let inputLength = outputTokens.shape[1]
80+
let maxTotalLength = min(config.maxLength, inputLength + config.maxNewTokens)
81+
82+
while outputTokens.shape[1] < maxTotalLength {
83+
// Get raw logits from model
7684
let nextTokenScores = await model(outputTokens, config)
85+
86+
// Apply logits processors
87+
let processedScores = await logitsProcessorList(outputTokens, nextTokenScores)
88+
89+
// Select next token based on generation mode
7790
let nextToken =
7891
switch config.generationMode {
7992
case .greedy:
80-
selectNextTokenUsingGreedyDecoding(from: nextTokenScores)
93+
selectNextTokenUsingGreedyDecoding(from: processedScores)
8194
case .sample:
82-
selectNextTokenUsingTopKSampling(
83-
from: nextTokenScores,
84-
temperature: config.temperature,
85-
topK: config.topK
86-
)
95+
selectNextTokenUsingSampling(from: processedScores)
8796
default:
8897
fatalError("Generation mode \(config.generationMode) not implemented yet")
8998
}
@@ -101,6 +110,53 @@ extension Generation {
101110
return await tensorToGenerationOutput(outputTokens)
102111
}
103112

113+
/// Creates a list of logits processors based on generation configuration.
114+
///
115+
/// - Parameter config: Generation configuration specifying which processors to apply
116+
/// - Returns: List of logits processors to apply during generation
117+
private func createLogitsProcessorList(config: GenerationConfig) -> LogitsProcessorList {
118+
var processors: [any LogitsProcessor] = []
119+
120+
// Repetition penalty (applied before sampling warpers)
121+
if config.repetitionPenalty != 1.0 {
122+
if let processor = try? RepetitionPenaltyLogitsProcessor(penalty: Float(config.repetitionPenalty)) {
123+
processors.append(processor)
124+
}
125+
}
126+
127+
// Temperature scaling (if not default)
128+
if config.temperature > 0 && config.temperature != 1.0 {
129+
if let processor = try? TemperatureLogitsWarper(temperature: config.temperature) {
130+
processors.append(processor)
131+
}
132+
}
133+
134+
// Top-K filtering (only apply if topK is meaningful)
135+
// Note: We can't determine vocab size here, so TopKLogitsWarper handles the case
136+
// where topK >= vocabSize internally
137+
if config.topK > 0 && config.topK < Int.max {
138+
if let processor = try? TopKLogitsWarper(topK: config.topK) {
139+
processors.append(processor)
140+
}
141+
}
142+
143+
// Top-P (nucleus) sampling
144+
if config.topP < 1.0 {
145+
if let processor = try? TopPLogitsWarper(topP: Float(config.topP)) {
146+
processors.append(processor)
147+
}
148+
}
149+
150+
// Min-P sampling (applied after temperature scaling)
151+
if let minP = config.minP {
152+
if let processor = try? MinPLogitsWarper(minP: Float(minP)) {
153+
processors.append(processor)
154+
}
155+
}
156+
157+
return LogitsProcessorList(processors: processors)
158+
}
159+
104160
private func tensorToGenerationOutput(_ tensor: MLTensor) async -> GenerationOutput {
105161
await tensor.shapedArray(of: Int32.self).scalars.map { Int($0) }
106162
}

Sources/Generation/GenerationConfig.swift

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,10 +39,13 @@ public struct GenerationConfig {
3939
public var topK = 50
4040

4141
/// Cumulative probability threshold for top-p sampling.
42-
public var topP = 1.0
42+
public var topP: Float = 1.0
43+
44+
/// Minimum token probability threshold, scaled by the most likely token's probability.
45+
public var minP: Float?
4346

4447
/// Penalty for token repetition (1.0 means no penalty).
45-
public var repetitionPenalty = 1.0
48+
public var repetitionPenalty: Float = 1.0
4649

4750
/// Token ID used for padding sequences.
4851
public var padTokenId: Int?
@@ -65,6 +68,7 @@ public struct GenerationConfig {
6568
/// - temperature: Sampling temperature
6669
/// - topK: Top-k sampling parameter
6770
/// - topP: Top-p sampling parameter
71+
/// - minP: Min-p sampling parameter
6872
/// - repetitionPenalty: Repetition penalty factor
6973
public init(
7074
maxLength: Int = 20,
@@ -73,20 +77,22 @@ public struct GenerationConfig {
7377
numBeams: Int = 1,
7478
numBeamGroups: Int = 1,
7579
penaltyAlpha: Double? = nil,
76-
temperature: Double = 1.0,
80+
temperature: Float = 1.0,
7781
topK: Int = 50,
78-
topP: Double = 1.0,
79-
repetitionPenalty: Double = 1.0
82+
topP: Float = 1.0,
83+
minP: Float? = nil,
84+
repetitionPenalty: Float = 1.0
8085
) {
8186
self.maxLength = maxLength
8287
self.maxNewTokens = maxNewTokens
8388
self.doSample = doSample
8489
self.numBeams = numBeams
8590
self.numBeamGroups = numBeamGroups
8691
self.penaltyAlpha = penaltyAlpha
87-
self.temperature = Float(temperature)
92+
self.temperature = temperature
8893
self.topK = topK
8994
self.topP = topP
95+
self.minP = minP
9096
self.repetitionPenalty = repetitionPenalty
9197
}
9298
}
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
#if canImport(CoreML)
2+
import CoreML
3+
4+
/// Abstract base class for all logits processors that can be applied during generation.
5+
///
6+
/// Logits processors modify the probability distribution over vocabulary tokens by transforming
7+
/// the raw logit scores produced by language models. This enables various sampling strategies
8+
/// such as temperature scaling, top-k/top-p filtering, and repetition penalties.
9+
///
10+
/// Based on: https://github.com/huggingface/transformers/blob/main/src/transformers/generation/logits_process.py
11+
@available(macOS 15.0, iOS 18.0, *)
12+
public protocol LogitsProcessor {
13+
/// Processes logits for next token prediction.
14+
///
15+
/// - Parameters:
16+
/// - inputIds: Tensor of input token IDs with shape `[batch_size, sequence_length]`
17+
/// - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]`
18+
/// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]`
19+
///
20+
/// - Note: The `inputIds` parameter provides context for processors that need to examine
21+
/// the generated sequence (e.g., repetition penalty). Processors that don't need this
22+
/// context (e.g., temperature) can ignore it.
23+
func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor
24+
}
25+
26+
/// A list of logits processors that applies each processor sequentially.
27+
///
28+
/// This class provides a convenient way to chain multiple logits processors together.
29+
/// Each processor is applied in order to the logits tensor, with the output of one
30+
/// processor becoming the input to the next.
31+
@available(macOS 15.0, iOS 18.0, *)
32+
public struct LogitsProcessorList {
33+
public var processors: [any LogitsProcessor]
34+
35+
public init(processors: [any LogitsProcessor]) {
36+
self.processors = processors
37+
}
38+
39+
/// Applies all logits processors sequentially to the input scores.
40+
///
41+
/// - Parameters:
42+
/// - inputIds: Tensor of input token IDs with shape `[batch_size, sequence_length]`
43+
/// - scores: Tensor of raw logit scores with shape `[batch_size, vocab_size]`
44+
/// - Returns: Processed logits tensor with shape `[batch_size, vocab_size]`
45+
public func callAsFunction(_ inputIds: MLTensor, _ scores: MLTensor) async -> MLTensor {
46+
// Following transformers convention: all logits processing happens in Float32
47+
// Cast to Float32 once at the start, process, then cast back to original type at the end
48+
let originalScalarType = scores.scalarType
49+
var processedScores = scores.scalarType == Float.self ? scores : scores.cast(to: Float.self)
50+
51+
for processor in processors {
52+
processedScores = await processor(inputIds, processedScores)
53+
}
54+
55+
// Cast back to original type if needed
56+
return originalScalarType == Float.self ? processedScores : processedScores.cast(to: originalScalarType)
57+
}
58+
}
59+
#endif

0 commit comments

Comments (0)