@@ -5,24 +5,51 @@ import CoreML
5
5
6
6
/// Selects the next token via greedy (argmax) decoding.
///
/// - Parameter scores: Logits tensor; the argmax is taken along the last axis.
/// - Returns: Next-token ID tensor shaped [1, 1], guaranteed to be Int32.
@available(macOS 15.0, iOS 18.0, *)
func selectNextTokenUsingGreedyDecoding(from scores: MLTensor) -> MLTensor {
    let nextToken = scores.argmax(alongAxis: -1).reshaped(to: [1, 1])
    // Downstream concatenation with the input token tensor requires Int32,
    // so cast only when the argmax result came back in another scalar type.
    guard nextToken.scalarType != Int32.self else { return nextToken }
    return nextToken.cast(to: Int32.self)
}
10
12
11
- // MARK: Top-K Sampling
13
// MARK: Sampling

/// Performs multinomial sampling from processed logits.
///
/// Assumes logits have already been processed by LogitsProcessorList
/// (temperature, top-k, top-p, etc. already applied).
///
/// Sampling uses the inverse-CDF method (equivalent to torch.multinomial()):
/// draw r ~ U[0, 1), compute the cumulative sum of the probabilities, and
/// pick the first index whose cumulative probability reaches r.
///
/// - Parameter scores: Processed logits tensor [batch_size, vocab_size]
/// - Returns: Sampled token ID tensor [batch_size, 1], guaranteed to be Int32
@available(macOS 15.0, iOS 18.0, *)
func selectNextTokenUsingSampling(from scores: MLTensor) -> MLTensor {
    // Convert logits to probabilities.
    let probs = scores.softmax(alongAxis: -1)

    // One independent uniform draw per batch row.
    let batchSize = scores.shape[0]
    let rndTensor = MLTensor(randomUniform: [batchSize, 1], in: 0 ..< 1, scalarType: Float.self)
    let cumulativeProbs = probs.cumulativeSum(alongAxis: -1)

    // Ensure the random tensor matches the scalar type of cumulativeProbs
    // before they are compared element-wise.
    let rnd = cumulativeProbs.scalarType == Float.self
        ? rndTensor
        : rndTensor.cast(to: cumulativeProbs.scalarType)

    // Find the FIRST index where cumsum >= rnd:
    // positions with cumsum < rnd get a large offset added, so argmin over
    // (offset + cumsum) can only land on an unpenalized position — and since
    // cumulative sums are non-decreasing, the first such position holds the
    // smallest value among them.
    let mask = cumulativeProbs .< rnd
    let penalized = mask * 1000.0 // large offset for positions to skip
    let indexed = penalized + cumulativeProbs

    // Bug fix: use the actual batch size instead of a hard-coded [1, 1] so
    // batch_size > 1 does not trap in reshaped(to:). Identical for batch 1.
    let sampledIndex = indexed.argmin(alongAxis: -1).reshaped(to: [batchSize, 1])
    // Ensure indices are Int32 for concatenation with input tokens.
    return sampledIndex.scalarType == Int32.self ? sampledIndex : sampledIndex.cast(to: Int32.self)
}
#endif // canImport(CoreML)
0 commit comments