Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 45 additions & 20 deletions Packages/OsaurusCore/Managers/RemoteProviderManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,10 @@ public final class RemoteProviderManager: ObservableObject {
var didChange = false
for i in configuration.providers.indices {
let host = configuration.providers[i].host.lowercased()
if configuration.providers[i].providerType == .openaiLegacy
let shouldMigrate =
configuration.providers[i].providerType == .openaiLegacy
&& host.contains("openai.com")
{
if shouldMigrate {
configuration.providers[i].providerType = .openResponses
didChange = true
}
Expand Down Expand Up @@ -229,12 +230,11 @@ public final class RemoteProviderManager: ObservableObject {
providerStates[providerId] = state

do {
if provider.authType == .openAICodexOAuth,
let tokens = provider.getOAuthTokens(),
tokens.isExpired
{
let refreshed = try await OpenAICodexOAuthService.refresh(tokens)
RemoteProviderKeychain.saveOAuthTokens(refreshed, for: provider.id)
if provider.authType == .openAICodexOAuth {
if let tokens = provider.getOAuthTokens(), tokens.isExpired {
let refreshed = try await OpenAICodexOAuthService.refresh(tokens)
RemoteProviderKeychain.saveOAuthTokens(refreshed, for: provider.id)
}
}

// Fetch models from the provider and merge any manually configured deployment IDs.
Expand Down Expand Up @@ -368,6 +368,9 @@ public final class RemoteProviderManager: ObservableObject {

state.discoveredModels = merged
providerStates[providerId] = state
if let service = services[providerId] {
await service.updateModels(merged)
}
notifyModelsChanged()
}

Expand All @@ -383,9 +386,9 @@ public final class RemoteProviderManager: ObservableObject {
let now = Date()
let throttle = Self.modelRefetchThrottle
let dueIds: [UUID] = self.configuration.enabledProviders.compactMap { provider in
if let last = self.lastModelRefetchAt[provider.id],
now.timeIntervalSince(last) < throttle
{
let lastRefetch = self.lastModelRefetchAt[provider.id]
let isThrottled = lastRefetch.map { now.timeIntervalSince($0) < throttle } ?? false
if isThrottled {
return nil
}
return provider.id
Expand Down Expand Up @@ -539,9 +542,12 @@ public final class RemoteProviderManager: ObservableObject {
// Add headers
for (key, value) in testHeaders {
// Don't log the full auth header for security
if key.lowercased() == "authorization" || key.lowercased() == "x-api-key"
|| key.lowercased() == "x-goog-api-key"
{
let lowercasedKey = key.lowercased()
let shouldRedact =
lowercasedKey == "authorization"
|| lowercasedKey == "x-api-key"
|| lowercasedKey == "x-goog-api-key"
if shouldRedact {
print("[Osaurus] Test Connection: Adding header \(key)=***")
} else {
print("[Osaurus] Test Connection: Adding header \(key)=\(value)")
Expand Down Expand Up @@ -624,9 +630,10 @@ public final class RemoteProviderManager: ObservableObject {
}

/// Test Anthropic connection by fetching models from the /models endpoint
private func testAnthropicConnection(tempProvider: RemoteProvider, testHeaders: [String: String]) async throws
-> [String]
{
private func testAnthropicConnection(
tempProvider: RemoteProvider,
testHeaders: [String: String]
) async throws -> [String] {
guard let baseURL = tempProvider.url(for: "/models") else {
print("[Osaurus] Test Connection (Anthropic): Invalid URL")
throw RemoteProviderError.invalidURL
Expand Down Expand Up @@ -659,16 +666,31 @@ public final class RemoteProviderManager: ObservableObject {

// MARK: - Test Helpers

/// Insert a fake connected provider directly into state, bypassing the
/// real `connect()` (no network, no service instance). Test-only.
func _testInstallConnectedProvider(_ provider: RemoteProvider, discoveredModels: [String]) {
/// Insert a fake connected provider directly into state, optionally with a
/// matching service instance for tests that assert routing state. Test-only.
@discardableResult
func _testInstallConnectedProvider(
_ provider: RemoteProvider,
discoveredModels: [String],
installService: Bool = false
) -> RemoteProviderService? {
configuration.add(provider)
ephemeralProviderIds.insert(provider.id)
var state = RemoteProviderState(providerId: provider.id)
state.isConnected = true
state.discoveredModels = discoveredModels
state.lastConnectedAt = Date()
providerStates[provider.id] = state

guard installService else { return nil }

let service = RemoteProviderService(
provider: provider,
models: discoveredModels,
resolvedHeaders: provider.resolvedHeaders()
)
services[provider.id] = service
return service
}

/// Mutate a test-installed provider's state. Test-only.
Expand All @@ -684,6 +706,9 @@ public final class RemoteProviderManager: ObservableObject {
ephemeralProviderIds.remove(id)
providerStates.removeValue(forKey: id)
lastModelRefetchAt.removeValue(forKey: id)
if let service = services.removeValue(forKey: id) {
Task { await service.invalidateSession() }
}
}
refreshConnectedTask = nil
testFetchModelsOverride = nil
Expand Down
50 changes: 50 additions & 0 deletions Packages/OsaurusCore/Tests/Helpers/RemoteProviderTestLock.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
//
// RemoteProviderTestLock.swift
// OsaurusCoreTests
//
// Process-wide serialization for tests that mutate `RemoteProviderManager.shared`
// or assert on `ModelPickerItemCache.shared` while remote-provider notifications
// are in flight. `@Suite(.serialized)` only serializes tests inside one suite.
//

import Foundation

actor RemoteProviderTestLock {
static let shared = RemoteProviderTestLock()

private var holder = false
private var waiters: [CheckedContinuation<Void, Never>] = []

private func acquire() async {
if !holder {
holder = true
return
}
await withCheckedContinuation { (cont: CheckedContinuation<Void, Never>) in
waiters.append(cont)
}
}

private func release() {
if let next = waiters.first {
waiters.removeFirst()
next.resume()
} else {
holder = false
}
}

func run<T: Sendable>(
_ body: @MainActor @Sendable () async throws -> T
) async rethrows -> T {
await acquire()
do {
let value = try await body()
release()
return value
} catch {
release()
throw error
}
}
}
125 changes: 65 additions & 60 deletions Packages/OsaurusCore/Tests/Model/ModelPickerItemCacheTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import Testing

@testable import OsaurusCore

@Suite(.serialized)
@MainActor
struct ModelPickerItemCacheTests {

Expand All @@ -33,33 +34,35 @@ struct ModelPickerItemCacheTests {
/// order, so callers could disagree about whether remote models were
/// present.
@Test func concurrentCallers_returnIdenticalResults() async throws {
// Establish a baseline so we know what to compare against, and so
// any work needed to populate the cache (e.g. local model
// discovery) doesn't perturb the concurrent run below.
let baselineItems = await ModelPickerItemCache.shared.buildModelPickerItems()
let baselineIds = baselineItems.map(\.id)
await RemoteProviderTestLock.shared.run {
// Establish a baseline so we know what to compare against, and so
// any work needed to populate the cache (e.g. local model
// discovery) doesn't perturb the concurrent run below.
let baselineItems = await ModelPickerItemCache.shared.buildModelPickerItems()
let baselineIds = baselineItems.map(\.id)

// Spawn many detached tasks that each call into the @MainActor
// cache. Detached tasks are deliberately used so the calls hop
// back into the MainActor at the await point and exercise the
// serialized rebuild path the way real callers do (notification
// observer Tasks, the AppDelegate prewarm Task, ChatView's
// refresh Task, and so on).
var tasks: [Task<[String], Never>] = []
for _ in 0 ..< 32 {
tasks.append(
Task.detached {
let items = await ModelPickerItemCache.shared.buildModelPickerItems()
return items.map(\.id)
}
)
}
// Spawn many detached tasks that each call into the @MainActor
// cache. Detached tasks are deliberately used so the calls hop
// back into the MainActor at the await point and exercise the
// serialized rebuild path the way real callers do (notification
// observer Tasks, the AppDelegate prewarm Task, ChatView's
// refresh Task, and so on).
var tasks: [Task<[String], Never>] = []
for _ in 0 ..< 32 {
tasks.append(
Task.detached {
let items = await ModelPickerItemCache.shared.buildModelPickerItems()
return items.map(\.id)
}
)
}

for task in tasks {
let ids = await task.value
#expect(ids == baselineIds)
for task in tasks {
let ids = await task.value
#expect(ids == baselineIds)
}
#expect(ModelPickerItemCache.shared.isLoaded)
}
#expect(ModelPickerItemCache.shared.isLoaded)
}

/// Posting a burst of `.remoteProviderModelsChanged` notifications used
Expand All @@ -70,46 +73,48 @@ struct ModelPickerItemCacheTests {
/// asserts the invariant that, once populated, `items` never goes
/// empty across rebuilds.
@Test func notificationBurst_doesNotTransientlyEmptyItems() async throws {
let cache = ModelPickerItemCache.shared
await RemoteProviderTestLock.shared.run {
let cache = ModelPickerItemCache.shared

// Make sure we start populated. If this machine has no foundation
// model, no local MLX models, and no connected remote providers,
// the invariant is trivially satisfied skip in that case so CI
// doesn't false-positive.
_ = await cache.buildModelPickerItems()
guard !cache.items.isEmpty else { return }
let initialCount = cache.items.count
// Make sure we start populated. If this machine has no foundation
// model, no local MLX models, and no connected remote providers,
// the invariant is trivially satisfied - skip in that case so CI
// doesn't false-positive.
_ = await cache.buildModelPickerItems()
guard !cache.items.isEmpty else { return }
let initialCount = cache.items.count

// Spam many notifications. Each one schedules an observer Task
// that calls `buildModelPickerItems()`. Pre-fix, each Task would
// first set `items = []` and `isLoaded = false`.
for _ in 0 ..< 50 {
NotificationCenter.default.post(
name: .remoteProviderModelsChanged,
object: nil
)
}
// Spam many notifications. Each one schedules an observer Task
// that calls `buildModelPickerItems()`. Pre-fix, each Task would
// first set `items = []` and `isLoaded = false`.
for _ in 0 ..< 50 {
NotificationCenter.default.post(
name: .remoteProviderModelsChanged,
object: nil
)
}

// Drain the observer Tasks by repeatedly yielding the MainActor
// and sampling `items`. With the fix, every sample must be
// non-empty the rebuild only assigns `items` when it has the
// full list.
var samples: [Int] = []
for _ in 0 ..< 200 {
samples.append(cache.items.count)
try? await Task.sleep(nanoseconds: 200_000) // 0.2ms
}
// Drain the observer Tasks by repeatedly yielding the MainActor
// and sampling `items`. With the fix, every sample must be
// non-empty - the rebuild only assigns `items` when it has the
// full list.
var samples: [Int] = []
for _ in 0 ..< 200 {
samples.append(cache.items.count)
try? await Task.sleep(nanoseconds: 200_000) // 0.2ms
}

#expect(
!samples.contains(0),
"items must remain populated during rebuilds; observed sample counts: \(samples)"
)
#expect(
!samples.contains(0),
"items must remain populated during rebuilds; observed sample counts: \(samples)"
)

// After the burst settles, the cache should still hold a
// populated list (state hasn't actually changed, so it should
// match the initial count).
let final = await cache.buildModelPickerItems()
#expect(!final.isEmpty)
#expect(final.count == initialCount)
// After the burst settles, the cache should still hold a
// populated list (state hasn't actually changed, so it should
// match the initial count).
let final = await cache.buildModelPickerItems()
#expect(!final.isEmpty)
#expect(final.count == initialCount)
}
}
}
Loading
Loading