diff --git a/Playground/Playground/Views/ChatView.swift b/Playground/Playground/Views/ChatView.swift index c8dada7..d355d16 100644 --- a/Playground/Playground/Views/ChatView.swift +++ b/Playground/Playground/Views/ChatView.swift @@ -18,6 +18,9 @@ struct ChatView: View { @State private var outputTokens: Int = 0 @State private var totalTokens: Int = 0 + @State private var isGenerating: Bool = false + @State private var generationTask: Task? + var body: some View { VStack { Form { @@ -33,7 +36,11 @@ struct ChatView: View { } VStack { - SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream) + if isGenerating { + CancelButton(onCancel: { generationTask?.cancel() }) + } else { + SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream) + } } } .toolbar { @@ -53,6 +60,8 @@ struct ChatView: View { private func onSend() { clear() + isGenerating = true + let messages = [ ChatMessage(role: .system, content: viewModel.systemPrompt), ChatMessage(role: .user, content: prompt) @@ -60,8 +69,13 @@ struct ChatView: View { let options = ChatOptions(temperature: viewModel.temperature) - Task { + generationTask = Task { do { + defer { + self.isGenerating = false + self.generationTask = nil + } + let completion = try await viewModel.chat.send(model: viewModel.selectedModel, messages: messages, options: options) if let text = completion.content.first?.text { @@ -74,7 +88,7 @@ struct ChatView: View { self.totalTokens = usage.totalTokens } } catch { - print(String(describing: error)) + print(error) } } } @@ -82,6 +96,8 @@ struct ChatView: View { private func onStream() { clear() + isGenerating = true + let messages = [ ChatMessage(role: .system, content: viewModel.systemPrompt), ChatMessage(role: .user, content: prompt) @@ -89,8 +105,13 @@ struct ChatView: View { let options = ChatOptions(temperature: viewModel.temperature) - Task { + generationTask = Task { do { + defer { + self.isGenerating = false + self.generationTask = nil + } + for try await chunk in viewModel.chat.stream(model: viewModel.selectedModel, messages: messages, options: options) { if let text = chunk.delta?.text { self.response += text @@ -103,7 +124,7 @@ struct ChatView: View { } } } catch { - print(String(describing: error)) + print(error) } } } diff --git a/Playground/Playground/Views/Subviews/CancelButton.swift b/Playground/Playground/Views/Subviews/CancelButton.swift new file mode 100644 index 0000000..a2b01e9 --- /dev/null +++ b/Playground/Playground/Views/Subviews/CancelButton.swift @@ -0,0 +1,27 @@ +// +// CancelButton.swift +// Playground +// +// Created by Kevin Hermawan on 10/30/24. +// + +import SwiftUI + +struct CancelButton: View { + private let onCancel: () -> Void + + init(onCancel: @escaping () -> Void) { + self.onCancel = onCancel + } + + var body: some View { + Button(action: onCancel) { + Text("Cancel") + .padding(.vertical, 8) + .frame(maxWidth: .infinity) + } + .buttonStyle(.bordered) + .padding([.horizontal, .bottom]) + .padding(.top, 8) + } +} diff --git a/README.md b/README.md index ead6327..0c9a82b 100644 --- a/README.md +++ b/README.md @@ -215,10 +215,16 @@ do { case .networkError(let error): // Handle network-related errors (e.g., no internet connection) print("Network Error: \(error.localizedDescription)") - case .badServerResponse: - // Handle invalid server responses - print("Invalid response received from server") + case .decodingError(let error): + // Handle errors that occur when the response cannot be decoded + print("Decoding Error: \(error)") + case .cancelled: + // Handle requests that are cancelled + print("Request was cancelled") } +} catch { + // Handle any other errors + print("An unexpected error occurred: \(error)") } ``` diff --git a/Sources/LLMChatAnthropic/Documentation.docc/Documentation.md b/Sources/LLMChatAnthropic/Documentation.docc/Documentation.md index 972ccce..be3662e 100644 --- a/Sources/LLMChatAnthropic/Documentation.docc/Documentation.md +++ b/Sources/LLMChatAnthropic/Documentation.docc/Documentation.md @@ -186,10 +186,16 @@ do { case .networkError(let error): // Handle network-related errors (e.g., no internet connection) print("Network Error: \(error.localizedDescription)") - case .badServerResponse: - // Handle invalid server responses - print("Invalid response received from server") + case .decodingError(let error): + // Handle errors that occur when the response cannot be decoded + print("Decoding Error: \(error)") + case .cancelled: + // Handle requests that are cancelled + print("Request was cancelled") } +} catch { + // Handle any other errors + print("An unexpected error occurred: \(error)") } ``` diff --git a/Sources/LLMChatAnthropic/LLMChatAnthropic.swift b/Sources/LLMChatAnthropic/LLMChatAnthropic.swift index 626c31f..f7b0b8b 100644 --- a/Sources/LLMChatAnthropic/LLMChatAnthropic.swift +++ b/Sources/LLMChatAnthropic/LLMChatAnthropic.swift @@ -89,15 +89,26 @@ private extension LLMChatAnthropic { let request = try createRequest(for: endpoint, with: body) let (data, response) = try await URLSession.shared.data(for: request) + guard let httpResponse = response as? HTTPURLResponse else { + throw LLMChatAnthropicError.serverError(response.description) + } + + // Check for API errors first, as they might come with 200 status if let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) { throw LLMChatAnthropicError.serverError(errorResponse.error.message) } - guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { - throw LLMChatAnthropicError.badServerResponse + guard 200...299 ~= httpResponse.statusCode else { + throw LLMChatAnthropicError.serverError(response.description) } return try JSONDecoder().decode(ChatCompletion.self, from: data) + } catch is CancellationError { + throw LLMChatAnthropicError.cancelled + } catch let error as URLError where error.code == .cancelled { + throw LLMChatAnthropicError.cancelled + } catch let error as DecodingError { + throw LLMChatAnthropicError.decodingError(error) } catch let error as LLMChatAnthropicError { throw error } catch { @@ -107,87 +118,99 @@ private extension LLMChatAnthropic { func performStreamRequest(with body: RequestBody) -> AsyncThrowingStream { AsyncThrowingStream { continuation in - Task { - do { - let request = try createRequest(for: endpoint, with: body) - let (bytes, response) = try await URLSession.shared.bytes(for: request) - - guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { + let task = Task { + await withTaskCancellationHandler { + do { + let request = try createRequest(for: endpoint, with: body) + let (bytes, response) = try await URLSession.shared.bytes(for: request) + + guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else { + throw LLMChatAnthropicError.serverError(response.description) + } + + var currentChunk = ChatCompletionChunk(id: "", model: "", role: "") + for try await line in bytes.lines { - if let data = line.data(using: .utf8), let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) { - throw LLMChatAnthropicError.serverError(errorResponse.error.message) + try Task.checkCancellation() + + if line.hasPrefix("event: error") { + throw LLMChatAnthropicError.streamError } - break - } - - throw LLMChatAnthropicError.badServerResponse - } - - var currentChunk = ChatCompletionChunk(id: "", model: "", role: "") - - for try await line in bytes.lines { - if line.hasPrefix("data: ") { - let jsonData = line.dropFirst(6) + guard line.hasPrefix("data: "), let data = line.dropFirst(6).data(using: .utf8) else { + continue + } - if let data = jsonData.data(using: .utf8) { - let rawChunk = try JSONDecoder().decode(RawChatCompletionChunk.self, from: data) - - switch rawChunk.type { - case "message_start": - if let message = rawChunk.message { - currentChunk.id = message.id - currentChunk.role = message.role - currentChunk.model = message.model - - if let usage = message.usage, let inputTokens = usage.inputTokens, let outputTokens = usage.outputTokens { - currentChunk.usage = ChatCompletionChunk.Usage(inputTokens: inputTokens, outputTokens: outputTokens) - } - - continuation.yield(currentChunk) - } - case "content_block_start": - if let contentBlock = rawChunk.contentBlock { - currentChunk.delta = ChatCompletionChunk.Delta(type: contentBlock.type, toolName: contentBlock.name) - - continuation.yield(currentChunk) - } - case "content_block_delta": - if let delta = rawChunk.delta { - currentChunk.delta?.text = delta.text - currentChunk.delta?.toolInput = delta.partialJson - - continuation.yield(currentChunk) + let rawChunk = try JSONDecoder().decode(RawChatCompletionChunk.self, from: data) + + switch rawChunk.type { + case "message_start": + if let message = rawChunk.message { + currentChunk.id = message.id + currentChunk.role = message.role + currentChunk.model = message.model + + if let usage = message.usage, let inputTokens = usage.inputTokens, let outputTokens = usage.outputTokens { + currentChunk.usage = .init(inputTokens: inputTokens, outputTokens: outputTokens) } - case "message_delta": - if let delta = rawChunk.delta { - currentChunk.delta?.text = nil - currentChunk.delta?.toolInput = nil - currentChunk.stopReason = delta.stopReason - currentChunk.stopSequence = delta.stopSequence - - if let outputTokens = rawChunk.usage?.outputTokens { - currentChunk.usage?.outputTokens = outputTokens - } - - continuation.yield(currentChunk) + + continuation.yield(currentChunk) + } + + case "content_block_start": + if let contentBlock = rawChunk.contentBlock { + currentChunk.delta = .init(type: contentBlock.type, toolName: contentBlock.name) + + continuation.yield(currentChunk) + } + case "content_block_delta": + if let delta = rawChunk.delta { + currentChunk.delta?.text = delta.text + currentChunk.delta?.toolInput = delta.partialJson + + continuation.yield(currentChunk) + } + case "message_delta": + if let delta = rawChunk.delta { + currentChunk.delta?.text = nil + currentChunk.delta?.toolInput = nil + currentChunk.stopReason = delta.stopReason + currentChunk.stopSequence = delta.stopSequence + + if let outputTokens = rawChunk.usage?.outputTokens { + currentChunk.usage?.outputTokens = outputTokens } - case "message_stop": - continuation.finish() - default: - break + + continuation.yield(currentChunk) } + case "message_stop": + continuation.finish() + return + default: + break } } + + continuation.finish() + } catch is CancellationError { + continuation.finish(throwing: LLMChatAnthropicError.cancelled) + } catch let error as URLError where error.code == .cancelled { + continuation.finish(throwing: LLMChatAnthropicError.cancelled) + } catch let error as DecodingError { + continuation.finish(throwing: LLMChatAnthropicError.decodingError(error)) + } catch let error as LLMChatAnthropicError { + continuation.finish(throwing: error) + } catch { + continuation.finish(throwing: LLMChatAnthropicError.networkError(error)) } - - continuation.finish() - } catch let error as LLMChatAnthropicError { - continuation.finish(throwing: error) - } catch { - continuation.finish(throwing: LLMChatAnthropicError.networkError(error)) + } onCancel: { + continuation.finish(throwing: LLMChatAnthropicError.cancelled) } } + + continuation.onTermination = { @Sendable _ in + task.cancel() + } } } } diff --git a/Sources/LLMChatAnthropic/LLMChatAnthropicError.swift b/Sources/LLMChatAnthropic/LLMChatAnthropicError.swift index 2278700..b8d19b0 100644 --- a/Sources/LLMChatAnthropic/LLMChatAnthropicError.swift +++ b/Sources/LLMChatAnthropic/LLMChatAnthropicError.swift @@ -8,29 +8,25 @@ import Foundation /// An enum that represents errors from the chat completion request. -public enum LLMChatAnthropicError: LocalizedError, Sendable { - /// A case that represents a server-side error response. +public enum LLMChatAnthropicError: Error, Sendable { + /// An error that occurs during JSON decoding. /// - /// - Parameter message: The error message from the server. - case serverError(String) + /// - Parameter error: The underlying decoding error. + case decodingError(Error) - /// A case that represents a network-related error. + /// An error that occurs during network operations. /// /// - Parameter error: The underlying network error. case networkError(Error) - /// A case that represents an invalid server response. - case badServerResponse + /// An error returned by the server. + /// + /// - Parameter message: The error message received from the server. + case serverError(String) + + /// An error that occurs during stream processing. + case streamError - /// A localized message that describes the error. - public var errorDescription: String? { - switch self { - case .serverError(let error): - return error - case .networkError(let error): - return error.localizedDescription - case .badServerResponse: - return "Invalid response received from server" - } - } + /// An error that occurs when the request is cancelled. + case cancelled } diff --git a/Tests/LLMChatAnthropicTests/ChatCompletionTests.swift b/Tests/LLMChatAnthropicTests/ChatCompletionTests.swift index df7b78e..d4900f8 100644 --- a/Tests/LLMChatAnthropicTests/ChatCompletionTests.swift +++ b/Tests/LLMChatAnthropicTests/ChatCompletionTests.swift @@ -24,6 +24,7 @@ final class ChatCompletionTests: XCTestCase { ] URLProtocol.registerClass(URLProtocolMock.self) + URLProtocolMock.reset() } override func tearDown() { @@ -181,38 +182,92 @@ extension ChatCompletionTests { } } - func testStreamServerError() async throws { - let mockErrorResponse = """ - { - "error": { - "message": "Rate limit exceeded" + func testHTTPError() async throws { + URLProtocolMock.mockStatusCode = 429 + URLProtocolMock.mockData = "Rate limit exceeded".data(using: .utf8) + + do { + _ = try await chat.send(model: "claude-3-5-sonnet", messages: messages) + + XCTFail("Expected serverError to be thrown") + } catch let error as LLMChatAnthropicError { + switch error { + case .serverError(let message): + XCTAssertTrue(message.contains("429")) + default: + XCTFail("Expected serverError but got \(error)") } } - """ + } + + func testDecodingError() async throws { + let invalidJSON = "{ invalid json }" + URLProtocolMock.mockData = invalidJSON.data(using: .utf8) + + do { + _ = try await chat.send(model: "claude-3-5-sonnet", messages: messages) + + XCTFail("Expected decodingError to be thrown") + } catch let error as LLMChatAnthropicError { + switch error { + case .decodingError: + break + default: + XCTFail("Expected decodingError but got \(error)") + } + } + } + + func testCancellation() async throws { + let task = Task { + _ = try await chat.send(model: "claude-3-5-sonnet", messages: messages) + } - URLProtocolMock.mockStreamData = [mockErrorResponse] + task.cancel() + + do { + _ = try await task.value + + XCTFail("Expected cancelled error to be thrown") + } catch let error as LLMChatAnthropicError { + switch error { + case .cancelled: + break + default: + XCTFail("Expected cancelled but got \(error)") + } + } + } +} + +// MARK: - Error Handling (Stream) +extension ChatCompletionTests { + func testStreamServerError() async throws { + URLProtocolMock.mockStreamData = ["event: error\ndata: Server error occurred\n\n"] do { for try await _ in chat.stream(model: "claude-3-5-sonnet", messages: messages) { - XCTFail("Expected serverError to be thrown") + XCTFail("Expected streamError to be thrown") } } catch let error as LLMChatAnthropicError { switch error { - case .serverError(let message): - XCTAssertEqual(message, "Rate limit exceeded") + case .streamError: + break default: - XCTFail("Expected serverError but got \(error)") + XCTFail("Expected streamError but got \(error)") } } } func testStreamNetworkError() async throws { - URLProtocolMock.mockError = NSError( + let networkError = NSError( domain: NSURLErrorDomain, - code: NSURLErrorNotConnectedToInternet, - userInfo: [NSLocalizedDescriptionKey: "The Internet connection appears to be offline."] + code: NSURLErrorNetworkConnectionLost, + userInfo: [NSLocalizedDescriptionKey: "The network connection was lost."] ) + URLProtocolMock.mockError = networkError + do { for try await _ in chat.stream(model: "claude-3-5-sonnet", messages: messages) { XCTFail("Expected networkError to be thrown") @@ -220,10 +275,70 @@ extension ChatCompletionTests { } catch let error as LLMChatAnthropicError { switch error { case .networkError(let underlyingError): - XCTAssertEqual((underlyingError as NSError).code, NSURLErrorNotConnectedToInternet) + XCTAssertEqual((underlyingError as NSError).code, NSURLErrorNetworkConnectionLost) default: XCTFail("Expected networkError but got \(error)") } } } + + func testStreamHTTPError() async throws { + URLProtocolMock.mockStatusCode = 503 + URLProtocolMock.mockStreamData = [""] + + do { + for try await _ in chat.stream(model: "claude-3-5-sonnet", messages: messages) { + XCTFail("Expected serverError to be thrown") + } + } catch let error as LLMChatAnthropicError { + switch error { + case .serverError(let message): + XCTAssertTrue(message.contains("503")) + default: + XCTFail("Expected serverError but got \(error)") + } + } + } + + func testStreamDecodingError() async throws { + URLProtocolMock.mockStreamData = ["event: message_start\ndata: { invalid json }\n\n"] + + do { + for try await _ in chat.stream(model: "claude-3-5-sonnet", messages: messages) { + XCTFail("Expected decodingError to be thrown") + } + } catch let error as LLMChatAnthropicError { + switch error { + case .decodingError: + break + default: + XCTFail("Expected decodingError but got \(error)") + } + } + } + + func testStreamCancellation() async throws { + URLProtocolMock.mockStreamData = Array(repeating: "event: content_block_delta\ndata: {\"type\":\"content_block_delta\",\"index\":0,\"delta\":{\"type\":\"text_delta\",\"text\":\"test\"}}\n\n", count: 1000) + + let expectation = XCTestExpectation(description: "Stream cancelled") + + let task = Task { + do { + for try await _ in chat.stream(model: "claude-3-5-sonnet", messages: messages) { + try await Task.sleep(nanoseconds: 100_000_000) // 1 second + } + + XCTFail("Expected stream to be cancelled") + } catch is CancellationError { + expectation.fulfill() + } catch { + XCTFail("Expected CancellationError but got \(error)") + } + } + + try await Task.sleep(nanoseconds: 1_000_000_000) // 1 second + task.cancel() + + await fulfillment(of: [expectation], timeout: 5.0) + } } diff --git a/Tests/LLMChatAnthropicTests/Utils/URLProtocolMock.swift b/Tests/LLMChatAnthropicTests/Utils/URLProtocolMock.swift index 22dd1d2..d98f447 100644 --- a/Tests/LLMChatAnthropicTests/Utils/URLProtocolMock.swift +++ b/Tests/LLMChatAnthropicTests/Utils/URLProtocolMock.swift @@ -11,6 +11,7 @@ final class URLProtocolMock: URLProtocol { static var mockData: Data? static var mockStreamData: [String]? static var mockError: Error? + static var mockStatusCode: Int? override class func canInit(with request: URLRequest) -> Bool { return true @@ -28,19 +29,18 @@ final class URLProtocolMock: URLProtocol { return } - if let streamData = URLProtocolMock.mockStreamData { - let response = HTTPURLResponse(url: request.url!, statusCode: 200, httpVersion: nil, headerFields: ["Content-Type": "text/event-stream"])! + if let streamData = URLProtocolMock.mockStreamData, let url = request.url { + let response = HTTPURLResponse(url: url, statusCode: URLProtocolMock.mockStatusCode ?? 200, httpVersion: nil, headerFields: ["Content-Type": "text/event-stream"])! client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) for line in streamData { client.urlProtocol(self, didLoad: Data(line.utf8)) } - } else if let data = URLProtocolMock.mockData { - let response = HTTPURLResponse(url: request.url!, statusCode: 200, httpVersion: nil, headerFields: nil)! + } else if let data = URLProtocolMock.mockData, let url = request.url { + let response = HTTPURLResponse(url: url, statusCode: URLProtocolMock.mockStatusCode ?? 200, httpVersion: nil, headerFields: nil)! client.urlProtocol(self, didReceive: response, cacheStoragePolicy: .notAllowed) client.urlProtocol(self, didLoad: data) } else { - client.urlProtocol(self, didFailWithError: NSError(domain: "MockURLProtocol", code: -1, userInfo: [NSLocalizedDescriptionKey: "No mock data available"])) return } @@ -48,4 +48,11 @@ final class URLProtocolMock: URLProtocol { } override func stopLoading() {} + + static func reset() { + mockData = nil + mockStreamData = nil + mockError = nil + mockStatusCode = nil + } }