Skip to content

Commit

Permalink
refactor: improves error handling (#8)
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinhermawan authored Oct 30, 2024
1 parent 1f78b37 commit 1bfdd92
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 120 deletions.
31 changes: 26 additions & 5 deletions Playground/Playground/Views/ChatView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,9 @@ struct ChatView: View {
@State private var outputTokens: Int = 0
@State private var totalTokens: Int = 0

@State private var isGenerating: Bool = false
@State private var generationTask: Task<Void, Never>?

var body: some View {
VStack {
Form {
Expand All @@ -33,7 +36,11 @@ struct ChatView: View {
}

VStack {
SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream)
if isGenerating {
CancelButton(onCancel: { generationTask?.cancel() })
} else {
SendButton(stream: viewModel.stream, onSend: onSend, onStream: onStream)
}
}
}
.toolbar {
Expand All @@ -53,15 +60,22 @@ struct ChatView: View {
private func onSend() {
clear()

isGenerating = true

let messages = [
ChatMessage(role: .system, content: viewModel.systemPrompt),
ChatMessage(role: .user, content: prompt)
]

let options = ChatOptions(temperature: viewModel.temperature)

Task {
generationTask = Task {
do {
defer {
self.isGenerating = false
self.generationTask = nil
}

let completion = try await viewModel.chat.send(model: viewModel.selectedModel, messages: messages, options: options)

if let text = completion.content.first?.text {
Expand All @@ -74,23 +88,30 @@ struct ChatView: View {
self.totalTokens = usage.totalTokens
}
} catch {
print(String(describing: error))
print(error)
}
}
}

private func onStream() {
clear()

isGenerating = true

let messages = [
ChatMessage(role: .system, content: viewModel.systemPrompt),
ChatMessage(role: .user, content: prompt)
]

let options = ChatOptions(temperature: viewModel.temperature)

Task {
generationTask = Task {
do {
defer {
self.isGenerating = false
self.generationTask = nil
}

for try await chunk in viewModel.chat.stream(model: viewModel.selectedModel, messages: messages, options: options) {
if let text = chunk.delta?.text {
self.response += text
Expand All @@ -103,7 +124,7 @@ struct ChatView: View {
}
}
} catch {
print(String(describing: error))
print(error)
}
}
}
Expand Down
27 changes: 27 additions & 0 deletions Playground/Playground/Views/Subviews/CancelButton.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
//
// CancelButton.swift
// Playground
//
// Created by Kevin Hermawan on 10/30/24.
//

import SwiftUI

struct CancelButton: View {
private let onCancel: () -> Void

init(onCancel: @escaping () -> Void) {
self.onCancel = onCancel
}

var body: some View {
Button(action: onCancel) {
Text("Cancel")
.padding(.vertical, 8)
.frame(maxWidth: .infinity)
}
.buttonStyle(.bordered)
.padding([.horizontal, .bottom])
.padding(.top, 8)
}
}
12 changes: 9 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -215,10 +215,16 @@ do {
case .networkError(let error):
// Handle network-related errors (e.g., no internet connection)
print("Network Error: \(error.localizedDescription)")
case .badServerResponse:
// Handle invalid server responses
print("Invalid response received from server")
case .decodingError(let error):
// Handle errors that occur when the response cannot be decoded
print("Decoding Error: \(error)")
case .cancelled:
// Handle requests that are cancelled
print("Request was cancelled")
}
} catch {
// Handle any other errors
print("An unexpected error occurred: \(error)")
}
```

Expand Down
12 changes: 9 additions & 3 deletions Sources/LLMChatAnthropic/Documentation.docc/Documentation.md
Original file line number Diff line number Diff line change
Expand Up @@ -186,10 +186,16 @@ do {
case .networkError(let error):
// Handle network-related errors (e.g., no internet connection)
print("Network Error: \(error.localizedDescription)")
case .badServerResponse:
// Handle invalid server responses
print("Invalid response received from server")
case .decodingError(let error):
// Handle errors that occur when the response cannot be decoded
print("Decoding Error: \(error)")
case .cancelled:
// Handle requests that are cancelled
print("Request was cancelled")
}
} catch {
// Handle any other errors
print("An unexpected error occurred: \(error)")
}
```

Expand Down
165 changes: 94 additions & 71 deletions Sources/LLMChatAnthropic/LLMChatAnthropic.swift
Original file line number Diff line number Diff line change
Expand Up @@ -89,15 +89,26 @@ private extension LLMChatAnthropic {
let request = try createRequest(for: endpoint, with: body)
let (data, response) = try await URLSession.shared.data(for: request)

guard let httpResponse = response as? HTTPURLResponse else {
throw LLMChatAnthropicError.serverError(response.description)
}

// Check for API errors first, as they might come with 200 status
if let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) {
throw LLMChatAnthropicError.serverError(errorResponse.error.message)
}

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
throw LLMChatAnthropicError.badServerResponse
guard 200...299 ~= httpResponse.statusCode else {
throw LLMChatAnthropicError.serverError(response.description)
}

return try JSONDecoder().decode(ChatCompletion.self, from: data)
} catch is CancellationError {
throw LLMChatAnthropicError.cancelled
} catch let error as URLError where error.code == .cancelled {
throw LLMChatAnthropicError.cancelled
} catch let error as DecodingError {
throw LLMChatAnthropicError.decodingError(error)
} catch let error as LLMChatAnthropicError {
throw error
} catch {
Expand All @@ -107,87 +118,99 @@ private extension LLMChatAnthropic {

func performStreamRequest(with body: RequestBody) -> AsyncThrowingStream<ChatCompletionChunk, Error> {
AsyncThrowingStream { continuation in
Task {
do {
let request = try createRequest(for: endpoint, with: body)
let (bytes, response) = try await URLSession.shared.bytes(for: request)

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
let task = Task {
await withTaskCancellationHandler {
do {
let request = try createRequest(for: endpoint, with: body)
let (bytes, response) = try await URLSession.shared.bytes(for: request)

guard let httpResponse = response as? HTTPURLResponse, 200...299 ~= httpResponse.statusCode else {
throw LLMChatAnthropicError.serverError(response.description)
}

var currentChunk = ChatCompletionChunk(id: "", model: "", role: "")

for try await line in bytes.lines {
if let data = line.data(using: .utf8), let errorResponse = try? JSONDecoder().decode(ChatCompletionError.self, from: data) {
throw LLMChatAnthropicError.serverError(errorResponse.error.message)
try Task.checkCancellation()

if line.hasPrefix("event: error") {
throw LLMChatAnthropicError.streamError
}

break
}

throw LLMChatAnthropicError.badServerResponse
}

var currentChunk = ChatCompletionChunk(id: "", model: "", role: "")

for try await line in bytes.lines {
if line.hasPrefix("data: ") {
let jsonData = line.dropFirst(6)
guard line.hasPrefix("data: "), let data = line.dropFirst(6).data(using: .utf8) else {
continue
}

if let data = jsonData.data(using: .utf8) {
let rawChunk = try JSONDecoder().decode(RawChatCompletionChunk.self, from: data)

switch rawChunk.type {
case "message_start":
if let message = rawChunk.message {
currentChunk.id = message.id
currentChunk.role = message.role
currentChunk.model = message.model

if let usage = message.usage, let inputTokens = usage.inputTokens, let outputTokens = usage.outputTokens {
currentChunk.usage = ChatCompletionChunk.Usage(inputTokens: inputTokens, outputTokens: outputTokens)
}

continuation.yield(currentChunk)
}
case "content_block_start":
if let contentBlock = rawChunk.contentBlock {
currentChunk.delta = ChatCompletionChunk.Delta(type: contentBlock.type, toolName: contentBlock.name)

continuation.yield(currentChunk)
}
case "content_block_delta":
if let delta = rawChunk.delta {
currentChunk.delta?.text = delta.text
currentChunk.delta?.toolInput = delta.partialJson

continuation.yield(currentChunk)
let rawChunk = try JSONDecoder().decode(RawChatCompletionChunk.self, from: data)

switch rawChunk.type {
case "message_start":
if let message = rawChunk.message {
currentChunk.id = message.id
currentChunk.role = message.role
currentChunk.model = message.model

if let usage = message.usage, let inputTokens = usage.inputTokens, let outputTokens = usage.outputTokens {
currentChunk.usage = .init(inputTokens: inputTokens, outputTokens: outputTokens)
}
case "message_delta":
if let delta = rawChunk.delta {
currentChunk.delta?.text = nil
currentChunk.delta?.toolInput = nil
currentChunk.stopReason = delta.stopReason
currentChunk.stopSequence = delta.stopSequence

if let outputTokens = rawChunk.usage?.outputTokens {
currentChunk.usage?.outputTokens = outputTokens
}

continuation.yield(currentChunk)

continuation.yield(currentChunk)
}

case "content_block_start":
if let contentBlock = rawChunk.contentBlock {
currentChunk.delta = .init(type: contentBlock.type, toolName: contentBlock.name)

continuation.yield(currentChunk)
}
case "content_block_delta":
if let delta = rawChunk.delta {
currentChunk.delta?.text = delta.text
currentChunk.delta?.toolInput = delta.partialJson

continuation.yield(currentChunk)
}
case "message_delta":
if let delta = rawChunk.delta {
currentChunk.delta?.text = nil
currentChunk.delta?.toolInput = nil
currentChunk.stopReason = delta.stopReason
currentChunk.stopSequence = delta.stopSequence

if let outputTokens = rawChunk.usage?.outputTokens {
currentChunk.usage?.outputTokens = outputTokens
}
case "message_stop":
continuation.finish()
default:
break

continuation.yield(currentChunk)
}
case "message_stop":
continuation.finish()
return
default:
break
}
}

continuation.finish()
} catch is CancellationError {
continuation.finish(throwing: LLMChatAnthropicError.cancelled)
} catch let error as URLError where error.code == .cancelled {
continuation.finish(throwing: LLMChatAnthropicError.cancelled)
} catch let error as DecodingError {
continuation.finish(throwing: LLMChatAnthropicError.decodingError(error))
} catch let error as LLMChatAnthropicError {
continuation.finish(throwing: error)
} catch {
continuation.finish(throwing: LLMChatAnthropicError.networkError(error))
}

continuation.finish()
} catch let error as LLMChatAnthropicError {
continuation.finish(throwing: error)
} catch {
continuation.finish(throwing: LLMChatAnthropicError.networkError(error))
} onCancel: {
continuation.finish(throwing: LLMChatAnthropicError.cancelled)
}
}

continuation.onTermination = { @Sendable _ in
task.cancel()
}
}
}
}
Expand Down
32 changes: 14 additions & 18 deletions Sources/LLMChatAnthropic/LLMChatAnthropicError.swift
Original file line number Diff line number Diff line change
Expand Up @@ -8,29 +8,25 @@
import Foundation

/// An enum that represents errors from the chat completion request.
public enum LLMChatAnthropicError: LocalizedError, Sendable {
/// A case that represents a server-side error response.
public enum LLMChatAnthropicError: Error, Sendable {
/// An error that occurs during JSON decoding.
///
/// - Parameter message: The error message from the server.
case serverError(String)
/// - Parameter error: The underlying decoding error.
case decodingError(Error)

/// A case that represents a network-related error.
/// An error that occurs during network operations.
///
/// - Parameter error: The underlying network error.
case networkError(Error)

/// A case that represents an invalid server response.
case badServerResponse
/// An error returned by the server.
///
/// - Parameter message: The error message received from the server.
case serverError(String)

/// An error that occurs during stream processing.
case streamError

/// A localized message that describes the error.
public var errorDescription: String? {
switch self {
case .serverError(let error):
return error
case .networkError(let error):
return error.localizedDescription
case .badServerResponse:
return "Invalid response received from server"
}
}
/// An error that occurs when the request is cancelled.
case cancelled
}
Loading

0 comments on commit 1bfdd92

Please sign in to comment.