Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 20 additions & 6 deletions Type4Me/Audio/AudioCaptureEngine.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,14 @@
@preconcurrency import AVFoundation

/// Abstraction over the microphone capture engine so `RecognitionSession`
/// can be driven by a test double instead of a live `AVCaptureSession`.
///
/// Conformers are reference types shared across concurrency domains
/// (`AnyObject, Sendable`); callback properties are mutated by the session
/// and invoked from the capture pipeline.
protocol AudioCapturing: AnyObject, Sendable {
// Invoked with each converted audio chunk, including the tail chunk
// flushed by stop().
var onAudioChunk: ((Data) -> Void)? { get set }
// Invoked with a per-buffer level reading for UI metering.
var onAudioLevel: ((Float) -> Void)? { get set }

// Pre-initializes capture resources so start() is fast.
func warmUp()
// Begins capture; throws when the session cannot be started
// (e.g. permission denied — see AudioCaptureError).
func start() throws
// Stops capture and flushes any buffered tail audio through onAudioChunk.
func stop()
}

enum AudioCaptureError: Error, LocalizedError {
case converterCreationFailed
case microphonePermissionDenied
Expand All @@ -17,7 +26,7 @@ enum AudioCaptureError: Error, LocalizedError {
}
}

final class AudioCaptureEngine: NSObject, @unchecked Sendable, AVCaptureAudioDataOutputSampleBufferDelegate {
final class AudioCaptureEngine: NSObject, @unchecked Sendable, AVCaptureAudioDataOutputSampleBufferDelegate, AudioCapturing {

// MARK: - Static properties

Expand Down Expand Up @@ -128,11 +137,16 @@ final class AudioCaptureEngine: NSObject, @unchecked Sendable, AVCaptureAudioDat
}

/// Stops audio capture and flushes any buffered tail audio.
///
/// The entire teardown runs synchronously on `outputQueue` — the same serial
/// queue AVFoundation delivers sample-buffer callbacks on — so every callback
/// that was already enqueued is drained before `flushRemaining()` emits the
/// final chunk. Without this ordering the last spoken frames can be stranded
/// behind stop() and never reach the session pipeline.
///
/// NOTE(review): `outputQueue.sync` deadlocks if stop() is ever invoked from
/// outputQueue itself (e.g. from inside a sample-buffer callback) — confirm
/// all callers run on a different queue.
func stop() {
    outputQueue.sync {
        // Drain any pending AVCapture callbacks on the delegate queue before
        // flushing the tail buffer, otherwise the last spoken frames can be
        // stranded behind stop() and never reach the session pipeline.
        captureSession?.stopRunning()
        captureSession = nil
        converter = nil
        levelCounter = 0
        flushRemaining()
    }
    NSLog("[Audio] Capture session stopped")
}

Expand Down
137 changes: 87 additions & 50 deletions Type4Me/Session/RecognitionSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ import AppKit
import os

actor RecognitionSession {
private static let stopCaptureGracePeriod: Duration = .milliseconds(350)
private static let stopSoundRestoreDelay: Duration = .milliseconds(50)

// MARK: - State

Expand Down Expand Up @@ -30,16 +32,26 @@ actor RecognitionSession {

// MARK: - Dependencies

private let audioEngine = AudioCaptureEngine()
private let injectionEngine = TextInjectionEngine()
let historyStore = HistoryStore()
private let audioEngine: any AudioCapturing
private let injectionEngine: TextInjectionEngine
let historyStore: HistoryStore
private var asrClient: (any SpeechRecognizer)?

private let logger = Logger(
subsystem: "com.type4me.session",
category: "RecognitionSession"
)

/// Creates a session with injectable dependencies.
///
/// Defaults preserve production behavior (real capture engine, real
/// injection engine, default history store); tests pass fakes — see
/// `FakeAudioEngine` usage in RecognitionSessionTests.
init(
audioEngine: any AudioCapturing = AudioCaptureEngine(),
injectionEngine: TextInjectionEngine = TextInjectionEngine(),
historyStore: HistoryStore = HistoryStore()
) {
self.audioEngine = audioEngine
self.injectionEngine = injectionEngine
self.historyStore = historyStore
}

/// Return the appropriate LLM client for the currently selected provider.
private func currentLLMClient() -> any LLMClient {
let provider = KeychainService.selectedLLMProvider
Expand Down Expand Up @@ -368,15 +380,23 @@ actor RecognitionSession {
}

let stopT0 = ContinuousClock.now
SystemVolumeManager.restore() // Restore before stop sound plays
try? await Task.sleep(for: .milliseconds(50)) // Let OS apply volume change
SoundFeedback.playStop()
state = .finishing

// Give CoreAudio a brief grace window to deliver frames that were spoken
// just before the stop hotkey event but have not yet surfaced as callbacks.
try? await Task.sleep(for: Self.stopCaptureGracePeriod)

// Stop capture first so flushRemaining() can emit the tail audio chunk.
audioEngine.stop()
audioEngine.onAudioChunk = nil
await finishAudioChunkPipeline()
audioEngine.onAudioChunk = nil
audioEngine.onAudioLevel = nil

// Stop feedback should play only after the microphone is fully stopped,
// otherwise the end sound itself can mask the user's trailing syllables.
SystemVolumeManager.restore()
try? await Task.sleep(for: Self.stopSoundRestoreDelay)
SoundFeedback.playStop()
DebugFileLogger.log("stop: audio stopped +\(ContinuousClock.now - stopT0)")
guard sessionGeneration == myGeneration else {
DebugFileLogger.log("stopRecording: zombie after audio pipeline, bailing")
Expand All @@ -391,12 +411,14 @@ actor RecognitionSession {
let provider = KeychainService.selectedASRProvider
let canEarlyLLM = ASRProviderRegistry.capabilities(for: provider).isStreaming
var earlyLLMTask: Task<String?, Never>?
var earlyLLMInputText: String?
if needsLLM && canEarlyLLM {
var earlyText = currentTranscript.composedText
.trimmingCharacters(in: .whitespacesAndNewlines)
earlyText = SnippetStorage.applyEffective(to: earlyText)
DebugFileLogger.log("stop: needsLLM=true mode=\(currentMode.name) text=\(earlyText.count)chars specMatch=\(earlyText == speculativeLLMText)")
if !earlyText.isEmpty {
earlyLLMInputText = earlyText
if earlyText == speculativeLLMText, let specTask = speculativeLLMTask {
// Speculative LLM matches — reuse (may already be done!)
earlyLLMTask = specTask
Expand Down Expand Up @@ -426,55 +448,50 @@ actor RecognitionSession {
}
}

// ASR teardown: streaming providers can skip endAudio in LLM modes since
// we already have text. Batch providers (e.g. OpenAI REST) MUST await endAudio
// because that's where the actual recognition happens.
// ASR teardown: correctness matters more than latency here. Always await
// final ASR teardown so trailing words that were captured before the stop
// hotkey still have a chance to become part of the final transcript.
// Batch providers (e.g. OpenAI REST) MUST await endAudio because that's
// where the actual recognition happens.
let providerIsStreaming = ASRProviderRegistry.capabilities(for: provider).isStreaming
if let client = asrClient {
if needsLLM && earlyLLMTask != nil && providerIsStreaming {
// Fast path (streaming only): just disconnect, skip the 2-3s finalization.
eventConsumptionTask?.cancel()
await client.disconnect()
DebugFileLogger.log("stop: ASR fast-disconnect +\(ContinuousClock.now - stopT0)")
} else {
// Full teardown: batch providers get a longer timeout for the HTTP round-trip.
let endAudioTimeout: Duration = providerIsStreaming ? .seconds(3) : .seconds(60)
do {
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask { try await client.endAudio() }
group.addTask {
try await Task.sleep(for: endAudioTimeout)
throw CancellationError()
}
try await group.next()
group.cancelAll()
// Full teardown: batch providers get a longer timeout for the HTTP round-trip.
let endAudioTimeout: Duration = providerIsStreaming ? .seconds(3) : .seconds(60)
do {
try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask { try await client.endAudio() }
group.addTask {
try await Task.sleep(for: endAudioTimeout)
throw CancellationError()
}
} catch {
NSLog("[Session] endAudio timed out or failed: %@", String(describing: error))
DebugFileLogger.log("endAudio timeout/error: \(error)")
try await group.next()
group.cancelAll()
}
let drainTimeout: Duration = providerIsStreaming ? .seconds(2) : .seconds(5)
if let task = eventConsumptionTask {
let streamDrained = await withTaskGroup(of: Bool.self) { group in
group.addTask {
await task.value
return true
}
group.addTask {
try? await Task.sleep(for: drainTimeout)
return false
}
let first = await group.next() ?? true
group.cancelAll()
return first
} catch {
NSLog("[Session] endAudio timed out or failed: %@", String(describing: error))
DebugFileLogger.log("endAudio timeout/error: \(error)")
}
let drainTimeout: Duration = providerIsStreaming ? .seconds(2) : .seconds(5)
if let task = eventConsumptionTask {
let streamDrained = await withTaskGroup(of: Bool.self) { group in
group.addTask {
await task.value
return true
}
if !streamDrained {
task.cancel()
DebugFileLogger.log("event stream drain timeout; eventConsumptionTask cancelled")
group.addTask {
try? await Task.sleep(for: drainTimeout)
return false
}
let first = await group.next() ?? true
group.cancelAll()
return first
}
if !streamDrained {
task.cancel()
DebugFileLogger.log("event stream drain timeout; eventConsumptionTask cancelled")
}
await client.disconnect()
}
await client.disconnect()
}
eventConsumptionTask = nil
asrClient = nil
Expand All @@ -500,7 +517,8 @@ actor RecognitionSession {
// LLM post-processing: prefer early result (fired at stop time),
// fall back to synchronous call for very short recordings where
// no streaming text was available yet.
if let earlyTask = earlyLLMTask {
if let earlyTask = earlyLLMTask,
earlyLLMInputText == finalText {
state = .postProcessing
DebugFileLogger.log("stop: awaiting early LLM result +\(ContinuousClock.now - stopT0)")
let earlyResult = await earlyTask.value
Expand All @@ -518,6 +536,9 @@ actor RecognitionSession {
}
} else if needsLLM {
state = .postProcessing
if let earlyLLMInputText, earlyLLMInputText != finalText {
DebugFileLogger.log("stop: final transcript changed after stop, discarding stale early LLM result")
}
if let llmConfig = KeychainService.loadLLMConfig() {
DebugFileLogger.log("stop: sync LLM firing mode=\(currentMode.name) model=\(llmConfig.model) with \(finalText.count) chars")
do {
Expand Down Expand Up @@ -716,6 +737,22 @@ actor RecognitionSession {
audioChunkSenderTask = nil
}

/// Test-only seam: puts the actor directly into the `.recording` state with
/// a caller-supplied recognizer, bypassing startRecording()'s permission
/// checks, connection, and audio start.
///
/// NOTE(review): this lives in production code solely for tests — consider
/// guarding it behind a DEBUG/test compilation condition.
/// - Parameters:
///   - client: recognizer to receive forwarded audio chunks and teardown calls.
///   - mode: processing mode to run the stop path under (default `.direct`).
///   - transcript: pre-seeded transcript state (default `.empty`).
func prepareForStopTesting(
client: any SpeechRecognizer,
mode: ProcessingMode = .direct,
transcript: RecognitionTranscript = .empty
) {
asrClient = client
currentMode = mode
currentTranscript = transcript
state = .recording

// Wire the audio engine into the chunk pipeline exactly as the real start
// path does, so stopRecording() exercises the same flush/drain machinery.
let chunkContinuation = setupAudioChunkPipeline()
audioEngine.onAudioChunk = { data in
chunkContinuation.yield(data)
}
}

private func markReadyIfNeeded() {
guard !hasEmittedReadyForCurrentSession else { return }
hasEmittedReadyForCurrentSession = true
Expand Down Expand Up @@ -804,9 +841,9 @@ actor RecognitionSession {
resetSpeculativeLLM()

audioEngine.stop()
await finishAudioChunkPipeline(timeout: .milliseconds(100))
audioEngine.onAudioChunk = nil
audioEngine.onAudioLevel = nil
await finishAudioChunkPipeline(timeout: .milliseconds(100))

if let client = asrClient {
Task { await client.disconnect() } // fire-and-forget: don't block reset on WebSocket teardown
Expand Down
101 changes: 101 additions & 0 deletions Type4MeTests/RecognitionSessionTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -50,4 +50,105 @@ final class RecognitionSessionTests: XCTestCase {
let mode = await session.currentModeForTesting()
XCTAssertEqual(mode.id, ProcessingMode.directId)
}

/// Verifies that stopRecording() delivers the tail audio chunk to the
/// recognizer BEFORE endAudio/disconnect, by blocking the fake recognizer's
/// sendAudio until explicitly resumed.
func testStopRecordingWaitsForTailChunkBeforeEndAudio() async throws {
    // Save and restore the globally selected ASR provider so this test does
    // not leak state into other tests in the suite.
    let previousProvider = KeychainService.selectedASRProvider
    KeychainService.selectedASRProvider = .volcano
    defer { KeychainService.selectedASRProvider = previousProvider }

    // Throwaway on-disk history store; remove the file when the test ends.
    let tempDB = URL(fileURLWithPath: NSTemporaryDirectory())
        .appendingPathComponent(UUID().uuidString)
        .appendingPathExtension("sqlite")
    defer { try? FileManager.default.removeItem(at: tempDB) }

    let audioEngine = FakeAudioEngine(tailChunk: Data([1, 2, 3, 4]))
    let recognizer = FakeSpeechRecognizer()
    let session = RecognitionSession(
        audioEngine: audioEngine,
        historyStore: HistoryStore(path: tempDB.path)
    )

    await session.prepareForStopTesting(client: recognizer)

    let stopTask = Task {
        await session.stopRecording()
    }

    // While sendAudio is suspended, stopRecording must not yet have reached
    // endAudio, and no audio may be recorded as sent.
    try await Task.sleep(for: .milliseconds(100))
    let didCallEndAudio = await recognizer.didCallEndAudio
    let sentAudioBeforeResume = await recognizer.sentAudio
    XCTAssertFalse(didCallEndAudio)
    XCTAssertEqual(sentAudioBeforeResume, [])

    await recognizer.resumeSendAudio()
    await stopTask.value

    // After resuming: the tail chunk arrived, and teardown followed the
    // strict sendAudio → endAudio → disconnect order.
    let sentAudioAfterResume = await recognizer.sentAudio
    let callOrder = await recognizer.callOrder
    XCTAssertEqual(sentAudioAfterResume, [Data([1, 2, 3, 4])])
    XCTAssertEqual(callOrder, ["sendAudio", "endAudio", "disconnect"])
}
}

/// Test double for `AudioCapturing` that emits one fixed audio chunk at the
/// moment `stop()` is called, mimicking the real engine's tail-buffer flush.
private final class FakeAudioEngine: AudioCapturing, @unchecked Sendable {
    /// The chunk handed to `onAudioChunk` when capture stops.
    private let chunkEmittedOnStop: Data

    var onAudioChunk: ((Data) -> Void)?
    var onAudioLevel: ((Float) -> Void)?

    init(tailChunk: Data) {
        chunkEmittedOnStop = tailChunk
    }

    /// No-op: the fake has nothing to warm up.
    func warmUp() {}

    /// No-op: the fake never fails to start.
    func start() throws {}

    /// Simulates the tail flush the real engine performs during stop().
    func stop() {
        if let emit = onAudioChunk {
            emit(chunkEmittedOnStop)
        }
    }
}

/// Test double for `SpeechRecognizer` that records the order of pipeline
/// calls and lets a test suspend `sendAudio` until `resumeSendAudio()` is
/// invoked, so the test can observe what stopRecording() does while the tail
/// chunk is still "in flight".
private actor FakeSpeechRecognizer: SpeechRecognizer {
// Parked continuation for a sendAudio call that is currently suspended.
private var sendContinuation: CheckedContinuation<Void, Never>?
// Set when resumeSendAudio() arrives BEFORE sendAudio suspends, so the
// next sendAudio completes without parking (avoids a lost-wakeup race).
private var shouldResumeImmediately = false
// Chunks that have fully completed sendAudio (appended only after resume).
private(set) var sentAudio: [Data] = []
// Names of pipeline methods in invocation order, asserted on by tests.
private(set) var callOrder: [String] = []

// True once endAudio() has been invoked at least once.
var didCallEndAudio: Bool {
callOrder.contains("endAudio")
}

// An immediately-finished stream: the fake never emits recognition events.
var events: AsyncStream<RecognitionEvent> {
let (stream, continuation) = AsyncStream<RecognitionEvent>.makeStream()
continuation.finish()
return stream
}

// No-op connect; the fake has no transport.
func connect(config: any ASRProviderConfig, options: ASRRequestOptions) async throws {}

// Records the call, then suspends until resumeSendAudio() unless a resume
// was already queued. NOTE: "sendAudio" is appended BEFORE suspension but
// the data is appended to sentAudio only AFTER resume — tests rely on this
// split to observe the in-flight state.
func sendAudio(_ data: Data) async throws {
callOrder.append("sendAudio")
if shouldResumeImmediately {
shouldResumeImmediately = false
sentAudio.append(data)
return
}
await withCheckedContinuation { continuation in
sendContinuation = continuation
}
sentAudio.append(data)
}

// Records the call; no finalization work.
func endAudio() async throws {
callOrder.append("endAudio")
}

// Records the call; no teardown work.
func disconnect() async {
callOrder.append("disconnect")
}

// Releases a parked sendAudio, or arms shouldResumeImmediately when no
// call is parked yet. The continuation is resumed exactly once and cleared.
func resumeSendAudio() {
if let sendContinuation {
sendContinuation.resume()
self.sendContinuation = nil
} else {
shouldResumeImmediately = true
}
}
}