Skip to content

Commit

Permalink
Merge pull request #74 from jkrukowski/background-url-session
Browse files Browse the repository at this point in the history
Updated swift-transformers, do not use background url session in CLI
  • Loading branch information
ZachNagengast committed Mar 15, 2024
2 parents 8588a38 + 7d809a2 commit 17b472c
Show file tree
Hide file tree
Showing 5 changed files with 68 additions and 25 deletions.
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "24605a8c0cc974bec5b94a6752eb687bae77db31",
"version" : "0.1.3"
"revision" : "3bd02269b7797ade67c15679a575cd5c6f203ce6",
"version" : "0.1.5"
}
}
],
Expand Down
2 changes: 1 addition & 1 deletion Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ let package = Package(
),
],
dependencies: [
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.3"),
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.5"),
.package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
],
targets: [
Expand Down
5 changes: 3 additions & 2 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -272,10 +272,11 @@ public func resolveAbsolutePath(_ inputPath: String) -> String {

func loadTokenizer(
for pretrained: ModelVariant,
tokenizerFolder: URL? = nil
tokenizerFolder: URL? = nil,
useBackgroundSession: Bool = false
) async throws -> Tokenizer {
let tokenizerName = tokenizerNameForVariant(pretrained)
let hubApi = HubApi(downloadBase: tokenizerFolder)
let hubApi = HubApi(downloadBase: tokenizerFolder, useBackgroundSession: useBackgroundSession)
return try await AutoTokenizer.from(pretrained: tokenizerName, hubApi: hubApi)
}

Expand Down
47 changes: 38 additions & 9 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,6 @@ public class WhisperKit: Transcriber {
public var modelVariant: ModelVariant = .tiny
public var modelState: ModelState = .unloaded
public var modelCompute: ModelComputeOptions
public var modelFolder: URL?
public var tokenizerFolder: URL?
public var tokenizer: Tokenizer?

/// Protocols
Expand All @@ -48,7 +46,13 @@ public class WhisperKit: Transcriber {
public var decoderInputs: DecodingInputs?
public var currentTimings: TranscriptionTimings?

/// State
public let progress = Progress()

/// Configuration
public var modelFolder: URL?
public var tokenizerFolder: URL?
private let useBackgroundDownloadSession: Bool

public init(
model: String? = nil,
Expand All @@ -67,7 +71,8 @@ public class WhisperKit: Transcriber {
logLevel: Logging.LogLevel = .info,
prewarm: Bool? = nil,
load: Bool? = nil,
download: Bool = true
download: Bool = true,
useBackgroundDownloadSession: Bool = false
) async throws {
self.modelCompute = computeOptions ?? ModelComputeOptions()
self.audioProcessor = audioProcessor ?? AudioProcessor()
Expand All @@ -77,6 +82,7 @@ public class WhisperKit: Transcriber {
self.logitsFilters = logitsFilters ?? []
self.segmentSeeker = segmentSeeker ?? SegmentSeeker()
self.tokenizerFolder = tokenizerFolder
self.useBackgroundDownloadSession = useBackgroundDownloadSession
Logging.shared.logLevel = verbose ? logLevel : .none
currentTimings = TranscriptionTimings()

Expand Down Expand Up @@ -170,8 +176,14 @@ public class WhisperKit: Transcriber {
return sortedModels
}

public static func download(variant: String, downloadBase: URL? = nil, from repo: String = "argmaxinc/whisperkit-coreml", progressCallback: ((Progress) -> Void)? = nil) async throws -> URL? {
let hubApi = HubApi(downloadBase: downloadBase)
public static func download(
variant: String,
downloadBase: URL? = nil,
useBackgroundSession: Bool = false,
from repo: String = "argmaxinc/whisperkit-coreml",
progressCallback: ((Progress) -> Void)? = nil
) async throws -> URL? {
let hubApi = HubApi(downloadBase: downloadBase, useBackgroundSession: useBackgroundSession)
let repo = Hub.Repo(id: repo, type: .models)
do {
let modelFolder = try await hubApi.snapshot(from: repo, matching: ["*\(variant.description)/*"]) { progress in
Expand All @@ -191,7 +203,13 @@ public class WhisperKit: Transcriber {
}

/// Sets up the model folder either from a local path or by downloading from a repository.
public func setupModels(model: String?, downloadBase: URL? = nil, modelRepo: String?, modelFolder: String?, download: Bool) async throws {
public func setupModels(
model: String?,
downloadBase: URL? = nil,
modelRepo: String?,
modelFolder: String?,
download: Bool
) async throws {
// Determine the model variant to use
let modelVariant = model ?? WhisperKit.recommendedModels().default

Expand All @@ -201,7 +219,12 @@ public class WhisperKit: Transcriber {
} else if download {
let repo = modelRepo ?? "argmaxinc/whisperkit-coreml"
do {
let hubModelFolder = try await Self.download(variant: modelVariant, downloadBase: downloadBase, from: repo)
let hubModelFolder = try await Self.download(
variant: modelVariant,
downloadBase: downloadBase,
useBackgroundSession: useBackgroundDownloadSession,
from: repo
)
self.modelFolder = hubModelFolder!
} catch {
// Handle errors related to model downloading
Expand All @@ -217,7 +240,9 @@ public class WhisperKit: Transcriber {
try await loadModels(prewarmMode: true)
}

public func loadModels(prewarmMode: Bool = false) async throws {
public func loadModels(
prewarmMode: Bool = false
) async throws {
modelState = prewarmMode ? .prewarming : .loading

let modelLoadStart = CFAbsoluteTimeGetCurrent()
Expand Down Expand Up @@ -292,7 +317,11 @@ public class WhisperKit: Transcriber {
{
modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim)
Logging.debug("Loading tokenizer for \(modelVariant)")
tokenizer = try await loadTokenizer(for: modelVariant, tokenizerFolder: tokenizerFolder)
tokenizer = try await loadTokenizer(
for: modelVariant,
tokenizerFolder: tokenizerFolder,
useBackgroundSession: useBackgroundDownloadSession
)
textDecoder.tokenizer = tokenizer
Logging.debug("Loaded tokenizer")
} else {
Expand Down
35 changes: 24 additions & 11 deletions Sources/WhisperKitCLI/Transcribe.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ struct Transcribe: AsyncParsableCommand {
abstract: "Transcribe audio to text using WhisperKit"
)

@OptionGroup
@OptionGroup
var cliArguments: CLIArguments

mutating func run() async throws {
Expand All @@ -36,32 +36,39 @@ struct Transcribe: AsyncParsableCommand {
audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits,
textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits
)

let downloadTokenizerFolder: URL? =
if let filePath = cliArguments.downloadTokenizerPath {
URL(filePath: filePath)
} else {
nil
}

let downloadModelFolder: URL? =
if let filePath = cliArguments.downloadModelPath {
URL(filePath: filePath)
} else {
nil
}

print("Initializing models...")
if cliArguments.verbose {
print("Initializing models...")
}

let whisperKit = try await WhisperKit(
model: cliArguments.model,
downloadBase: downloadModelFolder,
modelFolder: cliArguments.modelPath,
tokenizerFolder: downloadTokenizerFolder,
computeOptions: computeOptions,
verbose: cliArguments.verbose,
logLevel: .debug
logLevel: .debug,
useBackgroundDownloadSession: false
)
print("Models initialized")

if cliArguments.verbose {
print("Models initialized")
}

let options = DecodingOptions(
verbose: cliArguments.verbose,
Expand All @@ -83,7 +90,7 @@ struct Transcribe: AsyncParsableCommand {
)

let transcribeResult = try await whisperKit.transcribe(
audioPath: resolvedAudioPath,
audioPath: resolvedAudioPath,
decodeOptions: options
)

Expand Down Expand Up @@ -136,26 +143,32 @@ struct Transcribe: AsyncParsableCommand {
} else {
nil
}

let downloadModelFolder: URL? =
if let filePath = cliArguments.downloadModelPath {
URL(filePath: filePath)
} else {
nil
}

print("Initializing models...")
if cliArguments.verbose {
print("Initializing models...")
}

let whisperKit = try await WhisperKit(
model: cliArguments.model,
downloadBase: downloadModelFolder,
modelFolder: cliArguments.modelPath,
tokenizerFolder: downloadTokenizerFolder,
computeOptions: computeOptions,
verbose: cliArguments.verbose,
logLevel: .debug
logLevel: .debug,
useBackgroundDownloadSession: false
)
print("Models initialized")

if cliArguments.verbose {
print("Models initialized")
}
let decodingOptions = DecodingOptions(
verbose: cliArguments.verbose,
task: .transcribe,
Expand Down

0 comments on commit 17b472c

Please sign in to comment.