Skip to content

Commit

Permalink
Merge pull request #72 from jkrukowski/tokenizer-download-path
Browse files Browse the repository at this point in the history
  • Loading branch information
ZachNagengast committed Mar 15, 2024
2 parents 0b78c52 + 4b83f89 commit 8588a38
Show file tree
Hide file tree
Showing 11 changed files with 100 additions and 37 deletions.
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ build:

build-cli:
@echo "Building WhisperKit CLI..."
@swift build -c release --product transcribe
@swift build -c release --product whisperkit-cli


test:
Expand Down
4 changes: 2 additions & 2 deletions Package.resolved
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@
"kind" : "remoteSourceControl",
"location" : "https://github.com/huggingface/swift-transformers.git",
"state" : {
"revision" : "564442fba36b0b694d730a62d0593e5f54043b55",
"version" : "0.1.2"
"revision" : "24605a8c0cc974bec5b94a6752eb687bae77db31",
"version" : "0.1.3"
}
}
],
Expand Down
4 changes: 2 additions & 2 deletions Package.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ let package = Package(
targets: ["WhisperKit"]
),
.executable(
name: "transcribe",
name: "whisperkit-cli",
targets: ["WhisperKitCLI"]
),
],
dependencies: [
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.2"),
.package(url: "https://github.com/huggingface/swift-transformers.git", exact: "0.1.3"),
.package(url: "https://github.com/apple/swift-argument-parser.git", exact: "1.3.0"),
],
targets: [
Expand Down
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -121,13 +121,13 @@ make download-models
You can then run them via the CLI with:

```bash
swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --audio-path "path/to/your/audio.{wav,mp3,m4a,flac}"
swift run whisperkit-cli transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --audio-path "path/to/your/audio.{wav,mp3,m4a,flac}"
```

Which should print a transcription of the audio file. If you would like to stream the audio directly from a microphone, use:

```bash
swift run transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --stream
swift run whisperkit-cli transcribe --model-path "Models/whisperkit-coreml/openai_whisper-large-v3" --stream
```

## Contributing & Roadmap
Expand Down
10 changes: 7 additions & 3 deletions Sources/WhisperKit/Core/Utils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import AVFoundation
import CoreML
import Foundation
import Tokenizers
import Hub
#if canImport(UIKit)
import UIKit
#elseif canImport(AppKit)
Expand Down Expand Up @@ -269,10 +270,13 @@ public func resolveAbsolutePath(_ inputPath: String) -> String {
return inputPath
}

func loadTokenizer(for pretrained: ModelVariant) async throws -> Tokenizer {
// TODO: Cache tokenizer config to avoid repeated network requests
func loadTokenizer(
for pretrained: ModelVariant,
tokenizerFolder: URL? = nil
) async throws -> Tokenizer {
let tokenizerName = tokenizerNameForVariant(pretrained)
return try await AutoTokenizer.from(pretrained: tokenizerName)
let hubApi = HubApi(downloadBase: tokenizerFolder)
return try await AutoTokenizer.from(pretrained: tokenizerName, hubApi: hubApi)
}

func formatTimestamp(_ timestamp: Float) -> String {
Expand Down
13 changes: 11 additions & 2 deletions Sources/WhisperKit/Core/WhisperKit.swift
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ public class WhisperKit: Transcriber {
public var modelState: ModelState = .unloaded
public var modelCompute: ModelComputeOptions
public var modelFolder: URL?
public var tokenizerFolder: URL?
public var tokenizer: Tokenizer?

/// Protocols
Expand Down Expand Up @@ -54,6 +55,7 @@ public class WhisperKit: Transcriber {
downloadBase: URL? = nil,
modelRepo: String? = nil,
modelFolder: String? = nil,
tokenizerFolder: URL? = nil,
computeOptions: ModelComputeOptions? = nil,
audioProcessor: (any AudioProcessing)? = nil,
featureExtractor: (any FeatureExtracting)? = nil,
Expand All @@ -74,10 +76,17 @@ public class WhisperKit: Transcriber {
self.textDecoder = textDecoder ?? TextDecoder()
self.logitsFilters = logitsFilters ?? []
self.segmentSeeker = segmentSeeker ?? SegmentSeeker()
self.tokenizerFolder = tokenizerFolder
Logging.shared.logLevel = verbose ? logLevel : .none
currentTimings = TranscriptionTimings()

try await setupModels(model: model, downloadBase: downloadBase, modelRepo: modelRepo, modelFolder: modelFolder, download: download)
try await setupModels(
model: model,
downloadBase: downloadBase,
modelRepo: modelRepo,
modelFolder: modelFolder,
download: download
)

if let prewarm = prewarm, prewarm {
Logging.info("Prewarming models...")
Expand Down Expand Up @@ -283,7 +292,7 @@ public class WhisperKit: Transcriber {
{
modelVariant = detectVariant(logitsDim: logitsDim, encoderDim: encoderDim)
Logging.debug("Loading tokenizer for \(modelVariant)")
tokenizer = try await loadTokenizer(for: modelVariant)
tokenizer = try await loadTokenizer(for: modelVariant, tokenizerFolder: tokenizerFolder)
textDecoder.tokenizer = tokenizer
Logging.debug("Loaded tokenizer")
} else {
Expand Down
11 changes: 10 additions & 1 deletion Sources/WhisperKitCLI/CLIArguments.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,16 @@ struct CLIArguments: ParsableArguments {
var audioPath: String = "Tests/WhisperKitTests/Resources/jfk.wav"

@Option(help: "Path of model files")
var modelPath: String = "Models/whisperkit-coreml/openai_whisper-tiny"
var modelPath: String?

@Option(help: "Model to download if no modelPath is provided")
var model: String?

@Option(help: "Path to save the downloaded model")
var downloadModelPath: String?

@Option(help: "Path to save the downloaded tokenizer files")
var downloadTokenizerPath: String?

@Option(help: "Compute units for audio encoder model with {all,cpuOnly,cpuAndGPU,cpuAndNeuralEngine,random}")
var audioEncoderComputeUnits: ComputeUnits = .cpuAndNeuralEngine
Expand Down
1 change: 0 additions & 1 deletion Sources/WhisperKitCLI/CLIUtils.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import ArgumentParser
import CoreML
import Foundation
import WhisperKit

enum ComputeUnits: String, ExpressibleByArgument, CaseIterable {
case all, cpuAndGPU, cpuOnly, cpuAndNeuralEngine, random
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,48 +7,56 @@ import Foundation
import WhisperKit

@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
@main
struct WhisperKitCLI: AsyncParsableCommand {
struct Transcribe: AsyncParsableCommand {
static let configuration = CommandConfiguration(
commandName: "transcribe",
abstract: "WhisperKit Transcribe CLI",
discussion: "Swift native speech recognition with Whisper for Apple Silicon"
abstract: "Transcribe audio to text using WhisperKit"
)

@OptionGroup
var cliArguments: CLIArguments

mutating func run() async throws {
if cliArguments.stream {
try await transcribeStream(modelPath: cliArguments.modelPath)
try await transcribeStream()
} else {
let audioURL = URL(fileURLWithPath: cliArguments.audioPath)
if cliArguments.verbose {
print("Transcribing audio at \(audioURL)")
}
try await transcribe(audioPath: cliArguments.audioPath, modelPath: cliArguments.modelPath)
try await transcribe()
}
}

private func transcribe(audioPath: String, modelPath: String) async throws {
let resolvedModelPath = resolveAbsolutePath(modelPath)
guard FileManager.default.fileExists(atPath: resolvedModelPath) else {
fatalError("Model path does not exist \(resolvedModelPath)")
}

let resolvedAudioPath = resolveAbsolutePath(audioPath)
private func transcribe() async throws {
let resolvedAudioPath = resolveAbsolutePath(cliArguments.audioPath)
guard FileManager.default.fileExists(atPath: resolvedAudioPath) else {
fatalError("Resource path does not exist \(resolvedAudioPath)")
throw CocoaError.error(.fileNoSuchFile)
}
if cliArguments.verbose {
print("Transcribing audio at \(cliArguments.audioPath)")
}

let computeOptions = ModelComputeOptions(
audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits,
textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits
)

let downloadTokenizerFolder: URL? =
if let filePath = cliArguments.downloadTokenizerPath {
URL(filePath: filePath)
} else {
nil
}

let downloadModelFolder: URL? =
if let filePath = cliArguments.downloadModelPath {
URL(filePath: filePath)
} else {
nil
}

print("Initializing models...")
let whisperKit = try await WhisperKit(
modelFolder: modelPath,
model: cliArguments.model,
downloadBase: downloadModelFolder,
modelFolder: cliArguments.modelPath,
tokenizerFolder: downloadTokenizerFolder,
computeOptions: computeOptions,
verbose: cliArguments.verbose,
logLevel: .debug
Expand Down Expand Up @@ -82,7 +90,7 @@ struct WhisperKitCLI: AsyncParsableCommand {
let transcription = transcribeResult?.text ?? "Transcription failed"

if cliArguments.report, let result = transcribeResult {
let audioFileName = URL(fileURLWithPath: audioPath).lastPathComponent.components(separatedBy: ".").first!
let audioFileName = URL(fileURLWithPath: cliArguments.audioPath).lastPathComponent.components(separatedBy: ".").first!

// Write SRT (SubRip Subtitle Format) for the transcription
let srtReportWriter = WriteSRT(outputDir: cliArguments.reportPath)
Expand Down Expand Up @@ -116,15 +124,32 @@ struct WhisperKitCLI: AsyncParsableCommand {
}
}

private func transcribeStream(modelPath: String) async throws {
private func transcribeStream() async throws {
let computeOptions = ModelComputeOptions(
audioEncoderCompute: cliArguments.audioEncoderComputeUnits.asMLComputeUnits,
textDecoderCompute: cliArguments.textDecoderComputeUnits.asMLComputeUnits
)

let downloadTokenizerFolder: URL? =
if let filePath = cliArguments.downloadTokenizerPath {
URL(filePath: filePath)
} else {
nil
}

let downloadModelFolder: URL? =
if let filePath = cliArguments.downloadModelPath {
URL(filePath: filePath)
} else {
nil
}

print("Initializing models...")
let whisperKit = try await WhisperKit(
modelFolder: modelPath,
model: cliArguments.model,
downloadBase: downloadModelFolder,
modelFolder: cliArguments.modelPath,
tokenizerFolder: downloadTokenizerFolder,
computeOptions: computeOptions,
verbose: cliArguments.verbose,
logLevel: .debug
Expand Down
16 changes: 16 additions & 0 deletions Sources/WhisperKitCLI/WhisperKitCLI.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
// For licensing see accompanying LICENSE.md file.
// Copyright © 2024 Argmax, Inc. All rights reserved.

import ArgumentParser
import Foundation

/// Root entry point for the WhisperKit command-line tool.
///
/// Declares the top-level `whisperkit-cli` command and registers its
/// subcommands (currently only `Transcribe`, defined in a sibling file).
/// Argument parsing and async dispatch are provided by
/// `AsyncParsableCommand` from swift-argument-parser; `@main` makes this
/// the executable's entry point.
@available(macOS 13, iOS 16, watchOS 10, visionOS 1, *)
@main
struct WhisperKitCLI: AsyncParsableCommand {
// Static command metadata consumed by ArgumentParser for help text
// and subcommand routing.
static let configuration = CommandConfiguration(
commandName: "whisperkit-cli",
abstract: "WhisperKit CLI",
discussion: "Swift native speech recognition with Whisper for Apple Silicon",
subcommands: [Transcribe.self]
)
}
1 change: 1 addition & 0 deletions Tests/WhisperKitTests/UnitTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import AVFoundation
import CoreML
import Tokenizers
import Hub
@testable import WhisperKit
import XCTest

Expand Down

0 comments on commit 8588a38

Please sign in to comment.