From 9c4b51de9e7e03b9bff565bd194c5a5637196371 Mon Sep 17 00:00:00 2001 From: Michael Meding Date: Mon, 4 May 2026 01:03:31 -0300 Subject: [PATCH] fix(providers): sync service models after refresh Business rationale: Custom providers can add models while Osaurus is running, and the picker now refreshes those lists on open. Keeping the provider service snapshot in sync with that refreshed state preserves trust in the local harness: a model that appears available should also be reflected by the service that will handle the request. Coding rationale: Build on the merged refreshConnectedProviders path instead of keeping a second refresh API. Update the existing service actor only after the merged model list changes, and add a process-wide test lock for provider/cache singletons because suite-level serialization does not prevent cross-suite notification races. --- .../Managers/RemoteProviderManager.swift | 65 ++++-- .../Helpers/RemoteProviderTestLock.swift | 50 ++++ .../Model/ModelPickerItemCacheTests.swift | 125 +++++----- .../RemoteProviderManagerRefreshTests.swift | 218 ++++++++++-------- 4 files changed, 287 insertions(+), 171 deletions(-) create mode 100644 Packages/OsaurusCore/Tests/Helpers/RemoteProviderTestLock.swift diff --git a/Packages/OsaurusCore/Managers/RemoteProviderManager.swift b/Packages/OsaurusCore/Managers/RemoteProviderManager.swift index 0e0672c77..63aba2b5b 100644 --- a/Packages/OsaurusCore/Managers/RemoteProviderManager.swift +++ b/Packages/OsaurusCore/Managers/RemoteProviderManager.swift @@ -77,9 +77,10 @@ public final class RemoteProviderManager: ObservableObject { var didChange = false for i in configuration.providers.indices { let host = configuration.providers[i].host.lowercased() - if configuration.providers[i].providerType == .openaiLegacy + let shouldMigrate = + configuration.providers[i].providerType == .openaiLegacy && host.contains("openai.com") - { + if shouldMigrate { configuration.providers[i].providerType = .openResponses didChange = true } @@ -229,12 +230,11 @@ public final class RemoteProviderManager: ObservableObject { providerStates[providerId] = state do { - if provider.authType == .openAICodexOAuth, - let tokens = provider.getOAuthTokens(), - tokens.isExpired - { - let refreshed = try await OpenAICodexOAuthService.refresh(tokens) - RemoteProviderKeychain.saveOAuthTokens(refreshed, for: provider.id) + if provider.authType == .openAICodexOAuth { + if let tokens = provider.getOAuthTokens(), tokens.isExpired { + let refreshed = try await OpenAICodexOAuthService.refresh(tokens) + RemoteProviderKeychain.saveOAuthTokens(refreshed, for: provider.id) + } } // Fetch models from the provider and merge any manually configured deployment IDs. @@ -368,6 +368,9 @@ public final class RemoteProviderManager: ObservableObject { state.discoveredModels = merged providerStates[providerId] = state + if let service = services[providerId] { + await service.updateModels(merged) + } notifyModelsChanged() } @@ -383,9 +386,9 @@ public final class RemoteProviderManager: ObservableObject { let now = Date() let throttle = Self.modelRefetchThrottle let dueIds: [UUID] = self.configuration.enabledProviders.compactMap { provider in - if let last = self.lastModelRefetchAt[provider.id], - now.timeIntervalSince(last) < throttle - { + let lastRefetch = self.lastModelRefetchAt[provider.id] + let isThrottled = lastRefetch.map { now.timeIntervalSince($0) < throttle } ?? false + if isThrottled { return nil } return provider.id @@ -539,9 +542,12 @@ public final class RemoteProviderManager: ObservableObject { // Add headers for (key, value) in testHeaders { // Don't log the full auth header for security - if key.lowercased() == "authorization" || key.lowercased() == "x-api-key" - || key.lowercased() == "x-goog-api-key" - { + let lowercasedKey = key.lowercased() + let shouldRedact = + lowercasedKey == "authorization" + || lowercasedKey == "x-api-key" + || lowercasedKey == "x-goog-api-key" + if shouldRedact { print("[Osaurus] Test Connection: Adding header \(key)=***") } else { print("[Osaurus] Test Connection: Adding header \(key)=\(value)") @@ -624,9 +630,10 @@ public final class RemoteProviderManager: ObservableObject { } /// Test Anthropic connection by fetching models from the /models endpoint - private func testAnthropicConnection(tempProvider: RemoteProvider, testHeaders: [String: String]) async throws - -> [String] - { + private func testAnthropicConnection( + tempProvider: RemoteProvider, + testHeaders: [String: String] + ) async throws -> [String] { guard let baseURL = tempProvider.url(for: "/models") else { print("[Osaurus] Test Connection (Anthropic): Invalid URL") throw RemoteProviderError.invalidURL @@ -659,9 +666,14 @@ public final class RemoteProviderManager: ObservableObject { // MARK: - Test Helpers - /// Insert a fake connected provider directly into state, bypassing the - /// real `connect()` (no network, no service instance). Test-only. - func _testInstallConnectedProvider(_ provider: RemoteProvider, discoveredModels: [String]) { + /// Insert a fake connected provider directly into state, optionally with a + /// matching service instance for tests that assert routing state. Test-only. + @discardableResult + func _testInstallConnectedProvider( + _ provider: RemoteProvider, + discoveredModels: [String], + installService: Bool = false + ) -> RemoteProviderService? { configuration.add(provider) ephemeralProviderIds.insert(provider.id) var state = RemoteProviderState(providerId: provider.id) @@ -669,6 +681,16 @@ public final class RemoteProviderManager: ObservableObject { state.discoveredModels = discoveredModels state.lastConnectedAt = Date() providerStates[provider.id] = state + + guard installService else { return nil } + + let service = RemoteProviderService( + provider: provider, + models: discoveredModels, + resolvedHeaders: provider.resolvedHeaders() + ) + services[provider.id] = service + return service } /// Mutate a test-installed provider's state. Test-only. @@ -684,6 +706,9 @@ public final class RemoteProviderManager: ObservableObject { ephemeralProviderIds.remove(id) providerStates.removeValue(forKey: id) lastModelRefetchAt.removeValue(forKey: id) + if let service = services.removeValue(forKey: id) { + Task { await service.invalidateSession() } + } } refreshConnectedTask = nil testFetchModelsOverride = nil diff --git a/Packages/OsaurusCore/Tests/Helpers/RemoteProviderTestLock.swift b/Packages/OsaurusCore/Tests/Helpers/RemoteProviderTestLock.swift new file mode 100644 index 000000000..78980b7ec --- /dev/null +++ b/Packages/OsaurusCore/Tests/Helpers/RemoteProviderTestLock.swift @@ -0,0 +1,50 @@ +// +// RemoteProviderTestLock.swift +// OsaurusCoreTests +// +// Process-wide serialization for tests that mutate `RemoteProviderManager.shared` +// or assert on `ModelPickerItemCache.shared` while remote-provider notifications +// are in flight. `@Suite(.serialized)` only serializes tests inside one suite. +// + +import Foundation + +actor RemoteProviderTestLock { + static let shared = RemoteProviderTestLock() + + private var holder = false + private var waiters: [CheckedContinuation] = [] + + private func acquire() async { + if !holder { + holder = true + return + } + await withCheckedContinuation { (cont: CheckedContinuation) in + waiters.append(cont) + } + } + + private func release() { + if let next = waiters.first { + waiters.removeFirst() + next.resume() + } else { + holder = false + } + } + + func run( + _ body: @MainActor @Sendable () async throws -> T + ) async rethrows -> T { + await acquire() + do { + let value = try await body() + release() + return value + } catch { + release() + throw error + } + } +} diff --git a/Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift b/Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift index 958080ef2..7730d567a 100644 --- a/Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift +++ b/Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift @@ -21,6 +21,7 @@ import Testing @testable import OsaurusCore +@Suite(.serialized) @MainActor struct ModelPickerItemCacheTests { @@ -33,33 +34,35 @@ struct ModelPickerItemCacheTests { /// order, so callers could disagree about whether remote models were /// present. @Test func concurrentCallers_returnIdenticalResults() async throws { - // Establish a baseline so we know what to compare against, and so - // any work needed to populate the cache (e.g. local model - // discovery) doesn't perturb the concurrent run below. - let baselineItems = await ModelPickerItemCache.shared.buildModelPickerItems() - let baselineIds = baselineItems.map(\.id) + await RemoteProviderTestLock.shared.run { + // Establish a baseline so we know what to compare against, and so + // any work needed to populate the cache (e.g. local model + // discovery) doesn't perturb the concurrent run below. + let baselineItems = await ModelPickerItemCache.shared.buildModelPickerItems() + let baselineIds = baselineItems.map(\.id) - // Spawn many detached tasks that each call into the @MainActor - // cache. Detached tasks are deliberately used so the calls hop - // back into the MainActor at the await point and exercise the - // serialized rebuild path the way real callers do (notification - // observer Tasks, the AppDelegate prewarm Task, ChatView's - // refresh Task, and so on). - var tasks: [Task<[String], Never>] = [] - for _ in 0 ..< 32 { - tasks.append( - Task.detached { - let items = await ModelPickerItemCache.shared.buildModelPickerItems() - return items.map(\.id) - } - ) - } + // Spawn many detached tasks that each call into the @MainActor + // cache. Detached tasks are deliberately used so the calls hop + // back into the MainActor at the await point and exercise the + // serialized rebuild path the way real callers do (notification + // observer Tasks, the AppDelegate prewarm Task, ChatView's + // refresh Task, and so on). + var tasks: [Task<[String], Never>] = [] + for _ in 0 ..< 32 { + tasks.append( + Task.detached { + let items = await ModelPickerItemCache.shared.buildModelPickerItems() + return items.map(\.id) + } + ) + } - for task in tasks { - let ids = await task.value - #expect(ids == baselineIds) + for task in tasks { + let ids = await task.value + #expect(ids == baselineIds) + } + #expect(ModelPickerItemCache.shared.isLoaded) } - #expect(ModelPickerItemCache.shared.isLoaded) } /// Posting a burst of `.remoteProviderModelsChanged` notifications used @@ -70,46 +73,48 @@ struct ModelPickerItemCacheTests { /// asserts the invariant that, once populated, `items` never goes /// empty across rebuilds. @Test func notificationBurst_doesNotTransientlyEmptyItems() async throws { - let cache = ModelPickerItemCache.shared + await RemoteProviderTestLock.shared.run { + let cache = ModelPickerItemCache.shared - // Make sure we start populated. If this machine has no foundation - // model, no local MLX models, and no connected remote providers, - // the invariant is trivially satisfied — skip in that case so CI - // doesn't false-positive. - _ = await cache.buildModelPickerItems() - guard !cache.items.isEmpty else { return } - let initialCount = cache.items.count + // Make sure we start populated. If this machine has no foundation + // model, no local MLX models, and no connected remote providers, + // the invariant is trivially satisfied - skip in that case so CI + // doesn't false-positive. + _ = await cache.buildModelPickerItems() + guard !cache.items.isEmpty else { return } + let initialCount = cache.items.count - // Spam many notifications. Each one schedules an observer Task - // that calls `buildModelPickerItems()`. Pre-fix, each Task would - // first set `items = []` and `isLoaded = false`. - for _ in 0 ..< 50 { - NotificationCenter.default.post( - name: .remoteProviderModelsChanged, - object: nil - ) - } + // Spam many notifications. Each one schedules an observer Task + // that calls `buildModelPickerItems()`. Pre-fix, each Task would + // first set `items = []` and `isLoaded = false`. + for _ in 0 ..< 50 { + NotificationCenter.default.post( + name: .remoteProviderModelsChanged, + object: nil + ) + } - // Drain the observer Tasks by repeatedly yielding the MainActor - // and sampling `items`. With the fix, every sample must be - // non-empty — the rebuild only assigns `items` when it has the - // full list. - var samples: [Int] = [] - for _ in 0 ..< 200 { - samples.append(cache.items.count) - try? await Task.sleep(nanoseconds: 200_000) // 0.2ms - } + // Drain the observer Tasks by repeatedly yielding the MainActor + // and sampling `items`. With the fix, every sample must be + // non-empty - the rebuild only assigns `items` when it has the + // full list. + var samples: [Int] = [] + for _ in 0 ..< 200 { + samples.append(cache.items.count) + try? await Task.sleep(nanoseconds: 200_000) // 0.2ms + } - #expect( - !samples.contains(0), - "items must remain populated during rebuilds; observed sample counts: \(samples)" - ) + #expect( + !samples.contains(0), + "items must remain populated during rebuilds; observed sample counts: \(samples)" + ) - // After the burst settles, the cache should still hold a - // populated list (state hasn't actually changed, so it should - // match the initial count). - let final = await cache.buildModelPickerItems() - #expect(!final.isEmpty) - #expect(final.count == initialCount) + // After the burst settles, the cache should still hold a + // populated list (state hasn't actually changed, so it should + // match the initial count). + let final = await cache.buildModelPickerItems() + #expect(!final.isEmpty) + #expect(final.count == initialCount) + } } } diff --git a/Packages/OsaurusCore/Tests/Provider/RemoteProviderManagerRefreshTests.swift b/Packages/OsaurusCore/Tests/Provider/RemoteProviderManagerRefreshTests.swift index 0fb968da3..acaaaa84f 100644 --- a/Packages/OsaurusCore/Tests/Provider/RemoteProviderManagerRefreshTests.swift +++ b/Packages/OsaurusCore/Tests/Provider/RemoteProviderManagerRefreshTests.swift @@ -18,6 +18,7 @@ private final class Counter { func increment() { value += 1 } } +@Suite(.serialized) @MainActor struct RemoteProviderManagerRefreshTests { @@ -56,131 +57,166 @@ struct RemoteProviderManagerRefreshTests { // MARK: - refetchModels @Test func refetchModels_updatesDiscoveredModelsAndPostsNotification() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager, discovered: ["old-model"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager, discovered: ["old-model"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } - let counter = Counter() - let observer = observeModelsChanged(counter) - defer { NotificationCenter.default.removeObserver(observer) } + let counter = Counter() + let observer = observeModelsChanged(counter) + defer { NotificationCenter.default.removeObserver(observer) } - manager.testFetchModelsOverride = { _ in ["new-a", "new-b"] } + manager.testFetchModelsOverride = { _ in ["new-a", "new-b"] } - await manager.refetchModels(providerId: provider.id) + await manager.refetchModels(providerId: provider.id) - let updated = manager.providerStates[provider.id]?.discoveredModels ?? [] - #expect(updated == ["new-a", "new-b"]) + let updated = manager.providerStates[provider.id]?.discoveredModels ?? [] + #expect(updated == ["new-a", "new-b"]) - try? await Task.sleep(nanoseconds: 10_000_000) - #expect(counter.value == 1) + try? await Task.sleep(nanoseconds: 10_000_000) + #expect(counter.value == 1) + } + } + + @Test func refetchModels_updatesServiceModelSnapshot() async throws { + try await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = makeProvider() + let service = try #require( + manager._testInstallConnectedProvider( + provider, + discoveredModels: ["old-model"], + installService: true + ) + ) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + manager.testFetchModelsOverride = { _ in ["new-model"] } + + await manager.refetchModels(providerId: provider.id) + + #expect(await service.getRawModels() == ["new-model"]) + } } @Test func refetchModels_skipsNotificationWhenListUnchanged() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager, discovered: ["same-model"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager, discovered: ["same-model"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } - let counter = Counter() - let observer = observeModelsChanged(counter) - defer { NotificationCenter.default.removeObserver(observer) } + let counter = Counter() + let observer = observeModelsChanged(counter) + defer { NotificationCenter.default.removeObserver(observer) } - manager.testFetchModelsOverride = { _ in ["same-model"] } - await manager.refetchModels(providerId: provider.id) + manager.testFetchModelsOverride = { _ in ["same-model"] } + await manager.refetchModels(providerId: provider.id) - try? await Task.sleep(nanoseconds: 10_000_000) - #expect(counter.value == 0) + try? await Task.sleep(nanoseconds: 10_000_000) + #expect(counter.value == 0) + } } @Test func refetchModels_preservesStateOnFetchFailure() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager, discovered: ["keep-me"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager, discovered: ["keep-me"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } - struct Boom: Error {} - manager.testFetchModelsOverride = { _ in throw Boom() } + struct Boom: Error {} + manager.testFetchModelsOverride = { _ in throw Boom() } - await manager.refetchModels(providerId: provider.id) + await manager.refetchModels(providerId: provider.id) - let state = manager.providerStates[provider.id] - #expect(state?.discoveredModels == ["keep-me"]) - #expect(state?.isConnected == true) + let state = manager.providerStates[provider.id] + #expect(state?.discoveredModels == ["keep-me"]) + #expect(state?.isConnected == true) + } } @Test func refetchModels_noopWhenProviderNotConnected() async throws { - let manager = RemoteProviderManager.shared - let provider = makeProvider(name: "Disconnected") - manager._testInstallConnectedProvider(provider, discoveredModels: ["x"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } - - // Flip to disconnected via the test helper. - var state = manager.providerStates[provider.id]! - state.isConnected = false - manager._testSetState(state, for: provider.id) - - let counter = Counter() - manager.testFetchModelsOverride = { _ in - counter.increment() - return ["should-not-be-fetched"] + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = makeProvider(name: "Disconnected") + manager._testInstallConnectedProvider(provider, discoveredModels: ["x"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + // Flip to disconnected via the test helper. + var state = manager.providerStates[provider.id]! + state.isConnected = false + manager._testSetState(state, for: provider.id) + + let counter = Counter() + manager.testFetchModelsOverride = { _ in + counter.increment() + return ["should-not-be-fetched"] + } + + await manager.refetchModels(providerId: provider.id) + #expect(counter.value == 0) } - - await manager.refetchModels(providerId: provider.id) - #expect(counter.value == 0) } // MARK: - refreshConnectedProviders @Test func refreshConnectedProviders_throttlesRepeatedCalls() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager) - defer { manager._testRemoveProviders(ids: [provider.id]) } - - let counter = Counter() - manager.testFetchModelsOverride = { _ in - counter.increment() - return ["a"] + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + let counter = Counter() + manager.testFetchModelsOverride = { _ in + counter.increment() + return ["a"] + } + + await manager.refreshConnectedProviders() + await manager.refreshConnectedProviders() + await manager.refreshConnectedProviders() + + #expect(counter.value == 1, "second + third calls should be throttled within the window") } - - await manager.refreshConnectedProviders() - await manager.refreshConnectedProviders() - await manager.refreshConnectedProviders() - - #expect(counter.value == 1, "second + third calls should be throttled within the window") } @Test func refreshConnectedProviders_coalescesConcurrentCalls() async throws { - let manager = RemoteProviderManager.shared - let provider = install(manager) - defer { manager._testRemoveProviders(ids: [provider.id]) } - - let counter = Counter() - manager.testFetchModelsOverride = { _ in - counter.increment() - try? await Task.sleep(nanoseconds: 30_000_000) // 30ms — long enough to overlap - return ["a"] + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + let provider = install(manager) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + let counter = Counter() + manager.testFetchModelsOverride = { _ in + counter.increment() + try? await Task.sleep(nanoseconds: 30_000_000) // 30ms - long enough to overlap + return ["a"] + } + + async let r1: Void = manager.refreshConnectedProviders() + async let r2: Void = manager.refreshConnectedProviders() + async let r3: Void = manager.refreshConnectedProviders() + _ = await (r1, r2, r3) + + #expect(counter.value == 1, "concurrent callers should coalesce onto a single fetch") } - - async let r1: Void = manager.refreshConnectedProviders() - async let r2: Void = manager.refreshConnectedProviders() - async let r3: Void = manager.refreshConnectedProviders() - _ = await (r1, r2, r3) - - #expect(counter.value == 1, "concurrent callers should coalesce onto a single fetch") } @Test func refreshConnectedProviders_skipsDisabledProviders() async throws { - let manager = RemoteProviderManager.shared - var provider = makeProvider(name: "Disabled") - provider.enabled = false - manager._testInstallConnectedProvider(provider, discoveredModels: ["x"]) - defer { manager._testRemoveProviders(ids: [provider.id]) } - - let counter = Counter() - manager.testFetchModelsOverride = { _ in - counter.increment() - return ["y"] + await RemoteProviderTestLock.shared.run { + let manager = RemoteProviderManager.shared + var provider = makeProvider(name: "Disabled") + provider.enabled = false + manager._testInstallConnectedProvider(provider, discoveredModels: ["x"]) + defer { manager._testRemoveProviders(ids: [provider.id]) } + + let counter = Counter() + manager.testFetchModelsOverride = { _ in + counter.increment() + return ["y"] + } + + await manager.refreshConnectedProviders() + #expect(counter.value == 0) } - - await manager.refreshConnectedProviders() - #expect(counter.value == 0) } }