diff --git a/Packages/OsaurusCore/Managers/RemoteProviderManager.swift b/Packages/OsaurusCore/Managers/RemoteProviderManager.swift index 0e0672c77..63aba2b5b 100644 --- a/Packages/OsaurusCore/Managers/RemoteProviderManager.swift +++ b/Packages/OsaurusCore/Managers/RemoteProviderManager.swift @@ -77,9 +77,10 @@ public final class RemoteProviderManager: ObservableObject { var didChange = false for i in configuration.providers.indices { let host = configuration.providers[i].host.lowercased() - if configuration.providers[i].providerType == .openaiLegacy + let shouldMigrate = + configuration.providers[i].providerType == .openaiLegacy && host.contains("openai.com") - { + if shouldMigrate { configuration.providers[i].providerType = .openResponses didChange = true } @@ -229,12 +230,11 @@ public final class RemoteProviderManager: ObservableObject { providerStates[providerId] = state do { - if provider.authType == .openAICodexOAuth, - let tokens = provider.getOAuthTokens(), - tokens.isExpired - { - let refreshed = try await OpenAICodexOAuthService.refresh(tokens) - RemoteProviderKeychain.saveOAuthTokens(refreshed, for: provider.id) + if provider.authType == .openAICodexOAuth { + if let tokens = provider.getOAuthTokens(), tokens.isExpired { + let refreshed = try await OpenAICodexOAuthService.refresh(tokens) + RemoteProviderKeychain.saveOAuthTokens(refreshed, for: provider.id) + } } // Fetch models from the provider and merge any manually configured deployment IDs. @@ -368,6 +368,9 @@ public final class RemoteProviderManager: ObservableObject { state.discoveredModels = merged providerStates[providerId] = state + if let service = services[providerId] { + await service.updateModels(merged) + } notifyModelsChanged() } @@ -383,9 +386,9 @@ public final class RemoteProviderManager: ObservableObject { let now = Date() let throttle = Self.modelRefetchThrottle let dueIds: [UUID] = self.configuration.enabledProviders.compactMap { provider in - if let last = self.lastModelRefetchAt[provider.id], - now.timeIntervalSince(last) < throttle - { + let lastRefetch = self.lastModelRefetchAt[provider.id] + let isThrottled = lastRefetch.map { now.timeIntervalSince($0) < throttle } ?? false + if isThrottled { return nil } return provider.id @@ -539,9 +542,12 @@ public final class RemoteProviderManager: ObservableObject { // Add headers for (key, value) in testHeaders { // Don't log the full auth header for security - if key.lowercased() == "authorization" || key.lowercased() == "x-api-key" - || key.lowercased() == "x-goog-api-key" - { + let lowercasedKey = key.lowercased() + let shouldRedact = + lowercasedKey == "authorization" + || lowercasedKey == "x-api-key" + || lowercasedKey == "x-goog-api-key" + if shouldRedact { print("[Osaurus] Test Connection: Adding header \(key)=***") } else { print("[Osaurus] Test Connection: Adding header \(key)=\(value)") @@ -624,9 +630,10 @@ public final class RemoteProviderManager: ObservableObject { } /// Test Anthropic connection by fetching models from the /models endpoint - private func testAnthropicConnection(tempProvider: RemoteProvider, testHeaders: [String: String]) async throws - -> [String] - { + private func testAnthropicConnection( + tempProvider: RemoteProvider, + testHeaders: [String: String] + ) async throws -> [String] { guard let baseURL = tempProvider.url(for: "/models") else { print("[Osaurus] Test Connection (Anthropic): Invalid URL") throw RemoteProviderError.invalidURL @@ -659,9 +666,14 @@ public final class RemoteProviderManager: ObservableObject { // MARK: - Test Helpers - /// Insert a fake connected provider directly into state, bypassing the - /// real `connect()` (no network, no service instance). Test-only. - func _testInstallConnectedProvider(_ provider: RemoteProvider, discoveredModels: [String]) { + /// Insert a fake connected provider directly into state, optionally with a + /// matching service instance for tests that assert routing state. Test-only. + @discardableResult + func _testInstallConnectedProvider( + _ provider: RemoteProvider, + discoveredModels: [String], + installService: Bool = false + ) -> RemoteProviderService? { configuration.add(provider) ephemeralProviderIds.insert(provider.id) var state = RemoteProviderState(providerId: provider.id) @@ -669,6 +681,16 @@ public final class RemoteProviderManager: ObservableObject { state.discoveredModels = discoveredModels state.lastConnectedAt = Date() providerStates[provider.id] = state + + guard installService else { return nil } + + let service = RemoteProviderService( + provider: provider, + models: discoveredModels, + resolvedHeaders: provider.resolvedHeaders() + ) + services[provider.id] = service + return service } /// Mutate a test-installed provider's state. Test-only. @@ -684,6 +706,9 @@ public final class RemoteProviderManager: ObservableObject { ephemeralProviderIds.remove(id) providerStates.removeValue(forKey: id) lastModelRefetchAt.removeValue(forKey: id) + if let service = services.removeValue(forKey: id) { + Task { await service.invalidateSession() } + } } refreshConnectedTask = nil testFetchModelsOverride = nil diff --git a/Packages/OsaurusCore/Tests/Helpers/RemoteProviderTestLock.swift b/Packages/OsaurusCore/Tests/Helpers/RemoteProviderTestLock.swift new file mode 100644 index 000000000..78980b7ec --- /dev/null +++ b/Packages/OsaurusCore/Tests/Helpers/RemoteProviderTestLock.swift @@ -0,0 +1,50 @@ +// +// RemoteProviderTestLock.swift +// OsaurusCoreTests +// +// Process-wide serialization for tests that mutate `RemoteProviderManager.shared` +// or assert on `ModelPickerItemCache.shared` while remote-provider notifications +// are in flight. `@Suite(.serialized)` only serializes tests inside one suite. +// + +import Foundation + +actor RemoteProviderTestLock { + static let shared = RemoteProviderTestLock() + + private var holder = false + private var waiters: [CheckedContinuation] = [] + + private func acquire() async { + if !holder { + holder = true + return + } + await withCheckedContinuation { (cont: CheckedContinuation) in + waiters.append(cont) + } + } + + private func release() { + if let next = waiters.first { + waiters.removeFirst() + next.resume() + } else { + holder = false + } + } + + func run( + _ body: @MainActor @Sendable () async throws -> T + ) async rethrows -> T { + await acquire() + do { + let value = try await body() + release() + return value + } catch { + release() + throw error + } + } +} diff --git a/Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift b/Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift index 958080ef2..7730d567a 100644 --- a/Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift +++ b/Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift @@ -21,6 +21,7 @@ import Testing @testable import OsaurusCore +@Suite(.serialized) @MainActor struct ModelPickerItemCacheTests { @@ -33,33 +34,35 @@ struct ModelPickerItemCacheTests { /// order, so callers could disagree about whether remote models were /// present. @Test func concurrentCallers_returnIdenticalResults() async throws { - // Establish a baseline so we know what to compare against, and so - // any work needed to populate the cache (e.g. local model - // discovery) doesn't perturb the concurrent run below. - let baselineItems = await ModelPickerItemCache.shared.buildModelPickerItems() - let baselineIds = baselineItems.map(\.id) + await RemoteProviderTestLock.shared.run { + // Establish a baseline so we know what to compare against, and so + // any work needed to populate the cache (e.g. local model + // discovery) doesn't perturb the concurrent run below. + let baselineItems = await ModelPickerItemCache.shared.buildModelPickerItems() + let baselineIds = baselineItems.map(\.id) - // Spawn many detached tasks that each call into the @MainActor - // cache. Detached tasks are deliberately used so the calls hop - // back into the MainActor at the await point and exercise the - // serialized rebuild path the way real callers do (notification - // observer Tasks, the AppDelegate prewarm Task, ChatView's - // refresh Task, and so on). - var tasks: [Task<[String], Never>] = [] - for _ in 0 ..< 32 { - tasks.append( - Task.detached { - let items = await ModelPickerItemCache.shared.buildModelPickerItems() - return items.map(\.id) - } - ) - } + // Spawn many detached tasks that each call into the @MainActor + // cache. Detached tasks are deliberately used so the calls hop + // back into the MainActor at the await point and exercise the + // serialized rebuild path the way real callers do (notification + // observer Tasks, the AppDelegate prewarm Task, ChatView's + // refresh Task, and so on). + var tasks: [Task<[String], Never>] = [] + for _ in 0 ..< 32 { + tasks.append( + Task.detached { + let items = await ModelPickerItemCache.shared.buildModelPickerItems() + return items.map(\.id) + } + ) + } - for task in tasks { - let ids = await task.value - #expect(ids == baselineIds) + for task in tasks { + let ids = await task.value + #expect(ids == baselineIds) + } + #expect(ModelPickerItemCache.shared.isLoaded) } - #expect(ModelPickerItemCache.shared.isLoaded) } /// Posting a burst of `.remoteProviderModelsChanged` notifications used @@ -70,46 +73,48 @@ struct ModelPickerItemCacheTests { /// asserts the invariant that, once populated, `items` never goes /// empty across rebuilds. @Test func notificationBurst_doesNotTransientlyEmptyItems() async throws { - let cache = ModelPickerItemCache.shared + await RemoteProviderTestLock.shared.run { + let cache = ModelPickerItemCache.shared - // Make sure we start populated. If this machine has no foundation - // model, no local MLX models, and no connected remote providers, - // the invariant is trivially satisfied — skip in that case so CI - // doesn't false-positive. - _ = await cache.buildModelPickerItems() - guard !cache.items.isEmpty else { return } - let initialCount = cache.items.count + // Make sure we start populated. If this machine has no foundation + // model, no local MLX models, and no connected remote providers, + // the invariant is trivially satisfied - skip in that case so CI + // doesn't false-positive. + _ = await cache.buildModelPickerItems() + guard !cache.items.isEmpty else { return } + let initialCount = cache.items.count - // Spam many notifications. Each one schedules an observer Task - // that calls `buildModelPickerItems()`. Pre-fix, each Task would - // first set `items = []` and `isLoaded = false`. - for _ in 0 ..< 50 { - NotificationCenter.default.post( - name: .remoteProviderModelsChanged, - object: nil - ) - } + // Spam many notifications. Each one schedules an observer Task + // that calls `buildModelPickerItems()`. Pre-fix, each Task would + // first set `items = []` and `isLoaded = false`. + for _ in 0 ..< 50 { + NotificationCenter.default.post( + name: .remoteProviderModelsChanged, + object: nil + ) + } - // Drain the observer Tasks by repeatedly yielding the MainActor - // and sampling `items`. With the fix, every sample must be - // non-empty — the rebuild only assigns `items` when it has the - // full list. - var samples: [Int] = [] - for _ in 0 ..< 200 { - samples.append(cache.items.count) - try? await Task.sleep(nanoseconds: 200_000) // 0.2ms - } + // Drain the observer Tasks by repeatedly yielding the MainActor + // and sampling `items`. With the fix, every sample must be + // non-empty - the rebuild only assigns `items` when it has the + // full list. + var samples: [Int] = [] + for _ in 0 ..< 200 { + samples.append(cache.items.count) + try? await Task.sleep(nanoseconds: 200_000) // 0.2ms + } - #expect( - !samples.contains(0), - "items must remain populated during rebuilds; observed sample counts: \(samples)" - ) + #expect( + !samples.contains(0), + "items must remain populated during rebuilds; observed sample counts: \(samples)" + ) - // After the burst settles, the cache should still hold a - // populated list (state hasn't actually changed, so it should - // match the initial count). - let final = await cache.buildModelPickerItems() - #expect(!final.isEmpty) - #expect(final.count == initialCount) + // After the burst settles, the cache should still hold a + // populated list (state hasn't actually changed, so it should + // match the initial count). + let final = await cache.buildModelPickerItems() + #expect(!final.isEmpty) + #expect(final.count == initialCount) + } } } diff --git a/Packages/OsaurusCore/Tests/Provider/RemoteProviderManagerRefreshTests.swift b/Packages/OsaurusCore/Tests/Provider/RemoteProviderManagerRefreshTests.swift index 0fb968da3..acaaaa84f 100644 --- a/Packages/OsaurusCore/Tests/Provider/RemoteProviderManagerRefreshTests.swift +++ b/Packages/OsaurusCore/Tests/Provider/RemoteProviderManagerRefreshTests.swift @@ -18,6 +18,7 @@ private final class Counter { func increment() { value += 1 } } +@Suite(.serialized) @MainActor struct RemoteProviderManagerRefreshTests { @@ -56,131 +57,166 @@ struct RemoteProviderManagerRefreshTests { // MARK: - refetchModels @Test func refetchModels_updatesDiscoveredModelsAndPostsNotification() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager, discovered: ["old-model"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager, discovered: ["old-model"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } - let counter = Counter() - let observer = observeModelsChanged(counter) - defer { NotificationCenter.default.removeObserver(observer) } + let counter = Counter() + let observer = observeModelsChanged(counter) + defer { NotificationCenter.default.removeObserver(observer) } - manager.testFetchModelsOverride = { _ in ["new-a", "new-b"] } + manager.testFetchModelsOverride = { _ in ["new-a", "new-b"] } - await manager.refetchModels(providerId: provider.id) + await manager.refetchModels(providerId: provider.id) - let updated = manager.providerStates[provider.id]?.discoveredModels ?? [] - #expect(updated == ["new-a", "new-b"]) + let updated = manager.providerStates[provider.id]?.discoveredModels ?? [] + #expect(updated == ["new-a", "new-b"]) - try? await Task.sleep(nanoseconds: 10_000_000) - #expect(counter.value == 1) + try? await Task.sleep(nanoseconds: 10_000_000) + #expect(counter.value == 1) + } + } + + @Test func refetchModels_updatesServiceModelSnapshot() async throws { + try await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = makeProvider() + let service = try #require( + manager._testInstallConnectedProvider( + provider, + discoveredModels: ["old-model"], + installService: true + ) + ) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + manager.testFetchModelsOverride = { _ in ["new-model"] } + + await manager.refetchModels(providerId: provider.id) + + #expect(await service.getRawModels() == ["new-model"]) + } } @Test func refetchModels_skipsNotificationWhenListUnchanged() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager, discovered: ["same-model"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager, discovered: ["same-model"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } - let counter = Counter() - let observer = observeModelsChanged(counter) - defer { NotificationCenter.default.removeObserver(observer) } + let counter = Counter() + let observer = observeModelsChanged(counter) + defer { NotificationCenter.default.removeObserver(observer) } - manager.testFetchModelsOverride = { _ in ["same-model"] } - await manager.refetchModels(providerId: provider.id) + manager.testFetchModelsOverride = { _ in ["same-model"] } + await manager.refetchModels(providerId: provider.id) - try? await Task.sleep(nanoseconds: 10_000_000) - #expect(counter.value == 0) + try? await Task.sleep(nanoseconds: 10_000_000) + #expect(counter.value == 0) + } } @Test func refetchModels_preservesStateOnFetchFailure() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager, discovered: ["keep-me"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager, discovered: ["keep-me"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } - struct Boom: Error {} - manager.testFetchModelsOverride = { _ in throw Boom() } + struct Boom: Error {} + manager.testFetchModelsOverride = { _ in throw Boom() } - await manager.refetchModels(providerId: provider.id) + await manager.refetchModels(providerId: provider.id) - let state = manager.providerStates[provider.id] - #expect(state?.discoveredModels == ["keep-me"]) - #expect(state?.isConnected == true) + let state = manager.providerStates[provider.id] + #expect(state?.discoveredModels == ["keep-me"]) + #expect(state?.isConnected == true) + } } @Test func refetchModels_noopWhenProviderNotConnected() async throws { - let manager = RemoteProviderManager.shared - let provider = makeProvider(name: "Disconnected") - manager._testInstallConnectedProvider(provider, discoveredModels: ["x"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } - - // Flip to disconnected via the test helper. - var state = manager.providerStates[provider.id]! - state.isConnected = false - manager._testSetState(state, for: provider.id) - - let counter = Counter() - manager.testFetchModelsOverride = { _ in - counter.increment() - return ["should-not-be-fetched"] + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = makeProvider(name: "Disconnected") + manager._testInstallConnectedProvider(provider, discoveredModels: ["x"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + // Flip to disconnected via the test helper. + var state = manager.providerStates[provider.id]! + state.isConnected = false + manager._testSetState(state, for: provider.id) + + let counter = Counter() + manager.testFetchModelsOverride = { _ in + counter.increment() + return ["should-not-be-fetched"] + } + + await manager.refetchModels(providerId: provider.id) + #expect(counter.value == 0) } - - await manager.refetchModels(providerId: provider.id) - #expect(counter.value == 0) } // MARK: - refreshConnectedProviders @Test func refreshConnectedProviders_throttlesRepeatedCalls() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager) - defer { manager._testRemoveProviders(ids: [provider.id]) } - - let counter = Counter() - manager.testFetchModelsOverride = { _ in - counter.increment() - return ["a"] + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + let counter = Counter() + manager.testFetchModelsOverride = { _ in + counter.increment() + return ["a"] + } + + await manager.refreshConnectedProviders() + await manager.refreshConnectedProviders() + await manager.refreshConnectedProviders() + + #expect(counter.value == 1, "second + third calls should be throttled within the window") } - - await manager.refreshConnectedProviders() - await manager.refreshConnectedProviders() - await manager.refreshConnectedProviders() - - #expect(counter.value == 1, "second + third calls should be throttled within the window") } @Test func refreshConnectedProviders_coalescesConcurrentCalls() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager) - defer { manager._testRemoveProviders(ids: [provider.id]) } - - let counter = Counter() - manager.testFetchModelsOverride = { _ in - counter.increment() - try? await Task.sleep(nanoseconds: 30_000_000) // 30ms — long enough to overlap - return ["a"] + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + let counter = Counter() + manager.testFetchModelsOverride = { _ in + counter.increment() + try? await Task.sleep(nanoseconds: 30_000_000) // 30ms - long enough to overlap + return ["a"] + } + + async let r1: Void = manager.refreshConnectedProviders() + async let r2: Void = manager.refreshConnectedProviders() + async let r3: Void = manager.refreshConnectedProviders() + _ = await (r1, r2, r3) + + #expect(counter.value == 1, "concurrent callers should coalesce onto a single fetch") } - - async let r1: Void = manager.refreshConnectedProviders() - async let r2: Void = manager.refreshConnectedProviders() - async let r3: Void = manager.refreshConnectedProviders() - _ = await (r1, r2, r3) - - #expect(counter.value == 1, "concurrent callers should coalesce onto a single fetch") } @Test func refreshConnectedProviders_skipsDisabledProviders() async throws { - let manager = RemoteProviderManager.shared - var provider = makeProvider(name: "Disabled") - provider.enabled = false - manager._testInstallConnectedProvider(provider, discoveredModels: ["x"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } - - let counter = Counter() - manager.testFetchModelsOverride = { _ in - counter.increment() - return ["y"] + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + var provider = makeProvider(name: "Disabled") + provider.enabled = false + manager._testInstallConnectedProvider(provider, discoveredModels: ["x"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + let counter = Counter() + manager.testFetchModelsOverride = { _ in + counter.increment() + return ["y"] + } + + await manager.refreshConnectedProviders() + #expect(counter.value == 0) } - - await manager.refreshConnectedProviders() - #expect(counter.value == 0) } }