From 3ec2afbeda51c73fbdd6da052250ce0fab726a10 Mon Sep 17 00:00:00 2001 From: Kevin Hermawan <84965338+kevinhermawan@users.noreply.github.com> Date: Sun, 27 Oct 2024 05:44:50 +0700 Subject: [PATCH] improve: adds better error handling (#12) --- README.md | 29 ++++ .../Documentation.docc/Documentation.md | 29 ++++ Sources/LLMChatOpenAI/LLMChatOpenAI.swift | 124 +++++++++++------- .../LLMChatOpenAI/LLMChatOpenAIError.swift | 36 +++++ .../ChatCompletionTests.swift | 103 ++++++++++++++- 5 files changed, 271 insertions(+), 50 deletions(-) create mode 100644 Sources/LLMChatOpenAI/LLMChatOpenAIError.swift diff --git a/README.md b/README.md index decf570..8c4d1cb 100644 --- a/README.md +++ b/README.md @@ -236,6 +236,35 @@ Task { To learn more about structured outputs, check out the [OpenAI documentation](https://platform.openai.com/docs/guides/structured-outputs/introduction). +### Error Handling + +`LLMChatOpenAI` provides structured error handling through the `LLMChatOpenAIError` enum. This enum contains three cases that represent different types of errors you might encounter: + +```swift +let messages = [ + ChatMessage(role: .system, content: "You are a helpful assistant."), + ChatMessage(role: .user, content: "What is the capital of Indonesia?") +] + +do { + let completion = try await chat.send(model: "gpt-4o", messages: messages) + + print(completion.choices.first?.message.content ?? "No response") +} catch let error as LLMChatOpenAIError { + switch error { + case .serverError(let message): + // Handle server-side errors (e.g., invalid API key, rate limits) + print("Server Error: \(message)") + case .networkError(let error): + // Handle network-related errors (e.g., no internet connection) + print("Network Error: \(error.localizedDescription)") + case .badServerResponse: + // Handle invalid server responses + print("Invalid response received from server") + } +} +``` + ## Related Packages - [swift-ai-model-retriever](https://github.com/kevinhermawan/swift-ai-model-retriever) diff --git a/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md b/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md index 0c7b921..2368189 100644 --- a/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md +++ b/Sources/LLMChatOpenAI/Documentation.docc/Documentation.md @@ -207,6 +207,35 @@ Task { To learn more about structured outputs, check out the [OpenAI documentation](https://platform.openai.com/docs/guides/structured-outputs/introduction). +### Error Handling + +``LLMChatOpenAI`` provides structured error handling through the ``LLMChatOpenAIError`` enum. This enum contains three cases that represent different types of errors you might encounter: + +```swift +let messages = [ + ChatMessage(role: .system, content: "You are a helpful assistant."), + ChatMessage(role: .user, content: "What is the capital of Indonesia?") +] + +do { + let completion = try await chat.send(model: "gpt-4o", messages: messages) + + print(completion.choices.first?.message.content ?? "No response") +} catch let error as LLMChatOpenAIError { + switch error { + case .serverError(let message): + // Handle server-side errors (e.g., invalid API key, rate limits) + print("Server Error: \(message)") + case .networkError(let error): + // Handle network-related errors (e.g., no internet connection) + print("Network Error: \(error.localizedDescription)") + case .badServerResponse: + // Handle invalid server responses + print("Invalid response received from server") + } +} +``` + ## Related Packages - [swift-ai-model-retriever](https://github.com/kevinhermawan/swift-ai-model-retriever) diff --git a/Sources/LLMChatOpenAI/LLMChatOpenAI.swift b/Sources/LLMChatOpenAI/LLMChatOpenAI.swift index 45693ca..815e5cc 100644 --- a/Sources/LLMChatOpenAI/LLMChatOpenAI.swift +++ b/Sources/LLMChatOpenAI/LLMChatOpenAI.swift @@ -30,9 +30,23 @@ public struct LLMChatOpenAI { self.endpoint = endpoint ?? URL(string: "https://api.openai.com/v1/chat/completions")! self.headers = headers } + + var allHeaders: [String: String] { + var defaultHeaders = [ + "Content-Type": "application/json", + "Authorization": "Bearer \(apiKey)" + ] + + if let headers { + defaultHeaders.merge(headers) { _, new in new } + } + + return defaultHeaders + } } -extension LLMChatOpenAI { +// MARK: - Send +public extension LLMChatOpenAI { /// Sends a chat completion request. /// /// - Parameters: @@ -41,7 +55,7 @@ extension LLMChatOpenAI { /// - options: Optional ``ChatOptions`` that customize the completion request. /// /// - Returns: A ``ChatCompletion`` object that contains the API's response. - public func send(model: String, messages: [ChatMessage], options: ChatOptions? = nil) async throws -> ChatCompletion { + func send(model: String, messages: [ChatMessage], options: ChatOptions? = nil) async throws -> ChatCompletion { let body = RequestBody(stream: false, model: model, messages: messages, options: options) return try await performRequest(with: body) @@ -57,7 +71,7 @@ extension LLMChatOpenAI { /// - Returns: A ``ChatCompletion`` object that contains the API's response. /// /// - Note: This method enables fallback functionality when using OpenRouter. For other providers, only the first model in the array will be used. - public func send(models: [String], messages: [ChatMessage], options: ChatOptions? = nil) async throws -> ChatCompletion { + func send(models: [String], messages: [ChatMessage], options: ChatOptions? = nil) async throws -> ChatCompletion { let body: RequestBody if isSupportFallbackModel { @@ -68,7 +82,10 @@ extension LLMChatOpenAI { return try await performRequest(with: body) } - +} + +// MARK: - Stream +public extension LLMChatOpenAI { /// Streams a chat completion request. /// /// - Parameters: @@ -77,7 +94,7 @@ extension LLMChatOpenAI { /// - options: Optional ``ChatOptions`` that customize the completion request. /// /// - Returns: An `AsyncThrowingStream` of ``ChatCompletionChunk`` objects. - public func stream(model: String, messages: [ChatMessage], options: ChatOptions? = nil) -> AsyncThrowingStream { + func stream(model: String, messages: [ChatMessage], options: ChatOptions? = nil) -> AsyncThrowingStream { let body = RequestBody(stream: true, model: model, messages: messages, options: options) return performStreamRequest(with: body) @@ -93,7 +110,7 @@ extension LLMChatOpenAI { /// - Returns: An `AsyncThrowingStream` of ``ChatCompletionChunk`` objects. /// /// - Note: This method enables fallback functionality when using OpenRouter. For other providers, only the first model in the array will be used. - public func stream(models: [String], messages: [ChatMessage], options: ChatOptions? = nil) -> AsyncThrowingStream { + func stream(models: [String], messages: [ChatMessage], options: ChatOptions? = nil) -> AsyncThrowingStream { let body: RequestBody if isSupportFallbackModel { @@ -104,22 +121,58 @@ extension LLMChatOpenAI { return performStreamRequest(with: body) } - - private func performRequest(with body: RequestBody) async throws -> ChatCompletion { - let request = try createRequest(for: endpoint, with: body) - let (data, response) = try await URLSession.shared.data(for: request) - try validateHTTPResponse(response) +} + +// MARK: - Helpers +private extension LLMChatOpenAI { + func createRequest(for url: URL, with body: RequestBody) throws -> URLRequest { + var request = URLRequest(url: url) + request.httpMethod = "POST" + request.httpBody = try JSONEncoder().encode(body) + request.allHTTPHeaderFields = allHeaders - return try JSONDecoder().decode(ChatCompletion.self, from: data) + return request + } + + func performRequest(with body: RequestBody) async throws -> ChatCompletion { + do { + let request = try createRequest(for: endpoint, with: body) + let (data, response) = try await URLSession.shared.data(for: request) + + if let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) { + throw LLMChatOpenAIError.serverError(errorResponse.error.message) + } + + guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { + throw LLMChatOpenAIError.badServerResponse + } + + return try JSONDecoder().decode(ChatCompletion.self, from: data) + } catch let error as LLMChatOpenAIError { + throw error + } catch { + throw LLMChatOpenAIError.networkError(error) + } } - private func performStreamRequest(with body: RequestBody) -> AsyncThrowingStream { + func performStreamRequest(with body: RequestBody) -> AsyncThrowingStream { AsyncThrowingStream { continuation in Task { do { let request = try createRequest(for: endpoint, with: body) let (bytes, response) = try await URLSession.shared.bytes(for: request) - try validateHTTPResponse(response) + + guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { + for try await line in bytes.lines { + if let data = line.data(using: .utf8), let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) { + throw LLMChatOpenAIError.serverError(errorResponse.error.message) + } + + break + } + + throw LLMChatOpenAIError.badServerResponse + } for try await line in bytes.lines { if line.hasPrefix("data: ") { @@ -138,45 +191,16 @@ extension LLMChatOpenAI { } continuation.finish() - } catch { + } catch let error as LLMChatOpenAIError { continuation.finish(throwing: error) + } catch { + continuation.finish(throwing: LLMChatOpenAIError.networkError(error)) } } } } } -// MARK: - Helper Methods -private extension LLMChatOpenAI { - var allHeaders: [String: String] { - var defaultHeaders = [ - "Content-Type": "application/json", - "Authorization": "Bearer \(apiKey)" - ] - - if let headers { - defaultHeaders.merge(headers) { _, new in new } - } - - return defaultHeaders - } - - func createRequest(for url: URL, with body: RequestBody) throws -> URLRequest { - var request = URLRequest(url: url) - request.httpMethod = "POST" - request.httpBody = try JSONEncoder().encode(body) - request.allHTTPHeaderFields = allHeaders - - return request - } - - func validateHTTPResponse(_ response: URLResponse) throws { - guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { - throw URLError(.badServerResponse) - } - } -} - // MARK: - Supporting Types private extension LLMChatOpenAI { struct RequestBody: Encodable { @@ -228,4 +252,12 @@ private extension LLMChatOpenAI { case streamOptions = "stream_options" } } + + struct ChatCompletionError: Codable { + let error: Error + + struct Error: Codable { + public let message: String + } + } } diff --git a/Sources/LLMChatOpenAI/LLMChatOpenAIError.swift b/Sources/LLMChatOpenAI/LLMChatOpenAIError.swift new file mode 100644 index 0000000..7d0925c --- /dev/null +++ b/Sources/LLMChatOpenAI/LLMChatOpenAIError.swift @@ -0,0 +1,36 @@ +// +// LLMChatOpenAIError.swift +// LLMChatOpenAI +// +// Created by Kevin Hermawan on 10/27/24. +// + +import Foundation + +/// An enum that represents errors from the chat completion request. +public enum LLMChatOpenAIError: LocalizedError { + /// A case that represents a server-side error response. + /// + /// - Parameter message: The error message from the server. + case serverError(String) + + /// A case that represents a network-related error. + /// + /// - Parameter error: The underlying network error. + case networkError(Error) + + /// A case that represents an invalid server response. + case badServerResponse + + /// A localized message that describes the error. + public var errorDescription: String? { + switch self { + case .serverError(let error): + return error + case .networkError(let error): + return error.localizedDescription + case .badServerResponse: + return "Invalid response received from server" + } + } +} diff --git a/Tests/LLMChatOpenAITests/ChatCompletionTests.swift b/Tests/LLMChatOpenAITests/ChatCompletionTests.swift index e4d8102..c5fa47f 100644 --- a/Tests/LLMChatOpenAITests/ChatCompletionTests.swift +++ b/Tests/LLMChatOpenAITests/ChatCompletionTests.swift @@ -37,7 +37,7 @@ final class ChatCompletionTests: XCTestCase { } func testSendChatCompletion() async throws { - let mockResponseString = """ + let mockResponse = """ { "id": "chatcmpl-123", "object": "chat.completion", @@ -61,7 +61,7 @@ final class ChatCompletionTests: XCTestCase { } """ - URLProtocolMock.mockData = mockResponseString.data(using: .utf8) + URLProtocolMock.mockData = mockResponse.data(using: .utf8) let completion = try await chat.send(model: "gpt-4o", messages: messages) let choice = completion.choices.first let message = choice?.message @@ -99,7 +99,7 @@ final class ChatCompletionTests: XCTestCase { } func testSendChatCompletionWithFallbackModels() async throws { - let mockResponseString = """ + let mockResponse = """ { "id": "chatcmpl-123", "object": "chat.completion", @@ -123,7 +123,7 @@ final class ChatCompletionTests: XCTestCase { } """ - URLProtocolMock.mockData = mockResponseString.data(using: .utf8) + URLProtocolMock.mockData = mockResponse.data(using: .utf8) let completion = try await chat.send(models: ["openai/gpt-4o", "mistralai/mixtral-8x7b-instruct"], messages: messages) let choice = completion.choices.first let message = choice?.message @@ -160,3 +160,98 @@ final class ChatCompletionTests: XCTestCase { XCTAssertEqual(receivedContent, "The capital of Indonesia is Jakarta.") } } + +// MARK: - Error Handling +extension ChatCompletionTests { + func testServerError() async throws { + let mockErrorResponse = """ + { + "error": { + "message": "Invalid API key provided" + } + } + """ + + URLProtocolMock.mockData = mockErrorResponse.data(using: .utf8) + + do { + _ = try await chat.send(model: "gpt-4", messages: messages) + + XCTFail("Expected serverError to be thrown") + } catch let error as LLMChatOpenAIError { + switch error { + case .serverError(let message): + XCTAssertEqual(message, "Invalid API key provided") + default: + XCTFail("Expected serverError but got \(error)") + } + } + } + + func testNetworkError() async throws { + URLProtocolMock.mockError = NSError( + domain: NSURLErrorDomain, + code: NSURLErrorNotConnectedToInternet, + userInfo: [NSLocalizedDescriptionKey: "The Internet connection appears to be offline."] + ) + + do { + _ = try await chat.send(model: "gpt-4", messages: messages) + + XCTFail("Expected networkError to be thrown") + } catch let error as LLMChatOpenAIError { + switch error { + case .networkError(let underlyingError): + XCTAssertEqual((underlyingError as NSError).code, NSURLErrorNotConnectedToInternet) + default: + XCTFail("Expected networkError but got \(error)") + } + } + } + + func testStreamServerError() async throws { + let mockErrorResponse = """ + { + "error": { + "message": "Rate limit exceeded" + } + } + """ + + URLProtocolMock.mockStreamData = [mockErrorResponse] + + do { + for try await _ in chat.stream(model: "gpt-4", messages: messages) { + XCTFail("Expected serverError to be thrown") + } + } catch let error as LLMChatOpenAIError { + switch error { + case .serverError(let message): + XCTAssertEqual(message, "Rate limit exceeded") + default: + XCTFail("Expected serverError but got \(error)") + } + } + } + + func testStreamNetworkError() async throws { + URLProtocolMock.mockError = NSError( + domain: NSURLErrorDomain, + code: NSURLErrorNotConnectedToInternet, + userInfo: [NSLocalizedDescriptionKey: "The Internet connection appears to be offline."] + ) + + do { + for try await _ in chat.stream(model: "gpt-4", messages: messages) { + XCTFail("Expected networkError to be thrown") + } + } catch let error as LLMChatOpenAIError { + switch error { + case .networkError(let underlyingError): + XCTAssertEqual((underlyingError as NSError).code, NSURLErrorNotConnectedToInternet) + default: + XCTFail("Expected networkError but got \(error)") + } + } + } +}