Skip to content

Commit

Permalink
added MLX text decoder (#161)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkrukowski committed Jun 12, 2024
1 parent 470e227 commit b88079d
Show file tree
Hide file tree
Showing 11 changed files with 963 additions and 184 deletions.
109 changes: 94 additions & 15 deletions Sources/WhisperKit/Core/Models.swift
Original file line number Diff line number Diff line change
Expand Up @@ -195,18 +195,42 @@ public enum DecodingTask: CustomStringConvertible, CaseIterable {
}

public struct DecodingInputs {
var initialPrompt: [Int]
var inputIds: MLMultiArray
var cacheLength: MLMultiArray
var keyCache: MLMultiArray
var valueCache: MLMultiArray
var alignmentWeights: MLMultiArray
var kvCacheUpdateMask: MLMultiArray
var decoderKeyPaddingMask: MLMultiArray
var prefillKeyCache: MLMultiArray
var prefillValueCache: MLMultiArray

func reset(prefilledCacheSize: Int, maxTokenContext: Int) {
public var initialPrompt: [Int]
public var inputIds: MLMultiArray
public var cacheLength: MLMultiArray
public var keyCache: MLMultiArray?
public var valueCache: MLMultiArray?
public var alignmentWeights: MLMultiArray
public var kvCacheUpdateMask: MLMultiArray
public var decoderKeyPaddingMask: MLMultiArray
public var prefillKeyCache: MLMultiArray
public var prefillValueCache: MLMultiArray

public init(
initialPrompt: [Int],
inputIds: MLMultiArray,
cacheLength: MLMultiArray,
keyCache: MLMultiArray?,
valueCache: MLMultiArray?,
alignmentWeights: MLMultiArray,
kvCacheUpdateMask: MLMultiArray,
decoderKeyPaddingMask: MLMultiArray,
prefillKeyCache: MLMultiArray,
prefillValueCache: MLMultiArray
) {
self.initialPrompt = initialPrompt
self.inputIds = inputIds
self.cacheLength = cacheLength
self.keyCache = keyCache
self.valueCache = valueCache
self.alignmentWeights = alignmentWeights
self.kvCacheUpdateMask = kvCacheUpdateMask
self.decoderKeyPaddingMask = decoderKeyPaddingMask
self.prefillKeyCache = prefillKeyCache
self.prefillValueCache = prefillValueCache
}

public func reset(prefilledCacheSize: Int, maxTokenContext: Int) {
// NOTE: Because we have a mask on the kvcache,
// we can simply shift the masks without touching the data,
// it will be overwritten by the new data without impact on the output
Expand All @@ -230,9 +254,19 @@ public struct DecodingInputs {
}

public struct DecodingCache {
var keyCache: MLMultiArray?
var valueCache: MLMultiArray?
var alignmentWeights: MLMultiArray?
public var keyCache: MLMultiArray?
public var valueCache: MLMultiArray?
public var alignmentWeights: MLMultiArray?

public init(
keyCache: MLMultiArray?,
valueCache: MLMultiArray?,
alignmentWeights: MLMultiArray?
) {
self.keyCache = keyCache
self.valueCache = valueCache
self.alignmentWeights = alignmentWeights
}
}

public enum ChunkingStrategy: String, CaseIterable {
Expand Down Expand Up @@ -410,6 +444,34 @@ public struct DecodingResult {
public var timings: TranscriptionTimings?
public var fallback: DecodingFallback?

public init(
language: String,
languageProbs: [String : Float],
tokens: [Int],
tokenLogProbs: [[Int : Float]],
text: String,
avgLogProb: Float,
noSpeechProb: Float,
temperature: Float,
compressionRatio: Float,
cache: DecodingCache?,
timings: TranscriptionTimings?,
fallback: DecodingFallback?
) {
self.language = language
self.languageProbs = languageProbs
self.tokens = tokens
self.tokenLogProbs = tokenLogProbs
self.text = text
self.avgLogProb = avgLogProb
self.noSpeechProb = noSpeechProb
self.temperature = temperature
self.compressionRatio = compressionRatio
self.cache = cache
self.timings = timings
self.fallback = fallback
}

public static var emptyResults: DecodingResult {
return DecodingResult(language: "",
languageProbs: [:],
Expand Down Expand Up @@ -596,6 +658,23 @@ public struct TranscriptionProgress {
public var avgLogprob: Float?
public var compressionRatio: Float?
public var windowId: Int = 0

public init(
timings: TranscriptionTimings,
text: String,
tokens: [Int],
temperature: Float?,
avgLogprob: Float?,
compressionRatio: Float?,
windowId: Int = 0
) {
self.timings = timings
self.text = text
self.tokens = tokens
self.temperature = temperature
self.avgLogprob = avgLogprob
self.compressionRatio = compressionRatio
}
}

/// Callback to receive progress updates during transcription.
Expand Down
Loading

0 comments on commit b88079d

Please sign in to comment.