diff --git a/Sources/LanguageServerProtocol/Error.swift b/Sources/LanguageServerProtocol/Error.swift index 8a7d026f..b2ca028f 100644 --- a/Sources/LanguageServerProtocol/Error.swift +++ b/Sources/LanguageServerProtocol/Error.swift @@ -106,6 +106,8 @@ public struct ResponseError: Error, Codable, Hashable { extension ResponseError { // MARK: Convenience properties for common errors. + public static let contentModified: ResponseError = ResponseError(code: .contentModified, message: "content modified") + public static let cancelled: ResponseError = ResponseError(code: .cancelled, message: "request cancelled by client") public static let serverCancelled: ResponseError = ResponseError( diff --git a/Sources/LanguageServerProtocolTransport/QueueBasedMessageHandler.swift b/Sources/LanguageServerProtocolTransport/QueueBasedMessageHandler.swift index fadc54aa..878802d5 100644 --- a/Sources/LanguageServerProtocolTransport/QueueBasedMessageHandler.swift +++ b/Sources/LanguageServerProtocolTransport/QueueBasedMessageHandler.swift @@ -38,8 +38,9 @@ public actor QueueBasedMessageHandlerHelper { /// The requests that we are currently handling. /// - /// Used to cancel the tasks if the client requests cancellation. - private var inProgressRequestsByID: [RequestID: Task<(), Never>] = [:] + /// Used to cancel the tasks if the client requests cancellation. `cancellationError` is the error that should be + /// returned to the client if `task` is cancelled. + private var inProgressRequestsByID: [RequestID: (task: Task<(), Never>, cancellationError: ResponseError?)] = [:] /// Up to 10 request IDs that have recently finished. /// @@ -56,22 +57,42 @@ public actor QueueBasedMessageHandlerHelper { /// /// Cancellation is performed automatically when a `$/cancelRequest` notification is received. This can be called to /// implicitly cancel requests based on some criteria. - @_spi(SourceKitLSP) public nonisolated func cancelRequest(id: RequestID) { + /// + /// `cancellationError` is the error that should be returned to the client for the cancelled request. + @_spi(SourceKitLSP) public nonisolated func cancelRequest(id: RequestID, error cancellationError: ResponseError) { // Since the request is very cheap to execute and stops other requests // from performing more work, we execute it with a high priority. cancellationMessageHandlingQueue.async(priority: .high) { - if let task = await self.inProgressRequestsByID[id] { - task.cancel() - return - } - if await !self.recentlyFinishedRequests.contains(id) { - logger.error( - "Cannot cancel request \(id, privacy: .public) because it hasn't been scheduled for execution yet" - ) + await self.cancelRequestImpl(id: id, cancellationError: cancellationError) + } + } + + private func cancelRequestImpl(id: RequestID, cancellationError: ResponseError) { + // Since the request is very cheap to execute and stops other requests + // from performing more work, we execute it with a high priority. + if let task = self.inProgressRequestsByID[id]?.task { + if self.inProgressRequestsByID[id]?.cancellationError == nil { + // If we already have a cancellation error, stick with that one instead of overriding it. + self.inProgressRequestsByID[id]?.cancellationError = cancellationError } + task.cancel() + return + } + if !self.recentlyFinishedRequests.contains(id) { + logger.error( + "Cannot cancel request \(id, privacy: .public) because it hasn't been scheduled for execution yet" + ) } } + /// The error that should be returned to the client when the request with the given ID has ben cancelled by calling + /// `cancelRequest(id:)`. + fileprivate func cancellationError(for id: RequestID) async -> ResponseError? { + // We don't need to hop onto `cancellationMessageHandlingQueue` here because we will have already set the + // `cancellationError` in `inProgressRequestsByID` before cancelling the `Task`. + self.inProgressRequestsByID[id]?.cancellationError + } + fileprivate nonisolated func setInProgressRequest(id: RequestID, request: some RequestType, task: Task<(), Never>?) { self.cancellationMessageHandlingQueue.async(priority: .background) { await self.setInProgressRequestImpl(id: id, request: request, task: task) @@ -79,8 +100,10 @@ public actor QueueBasedMessageHandlerHelper { } private func setInProgressRequestImpl(id: RequestID, request: some RequestType, task: Task<(), Never>?) { - self.inProgressRequestsByID[id] = task - if task == nil { + if let task { + self.inProgressRequestsByID[id] = (task, nil) + } else { + self.inProgressRequestsByID[id] = nil self.recentlyFinishedRequests.append(id) while self.recentlyFinishedRequests.count > 10 { self.recentlyFinishedRequests.removeFirst() @@ -159,7 +182,7 @@ public protocol QueueBasedMessageHandler: MessageHandler { func handle( request: Request, id: RequestID, - reply: @Sendable @escaping (LSPResult) -> Void + reply: @Sendable @escaping (Result) -> Void ) async } @@ -174,7 +197,7 @@ extension QueueBasedMessageHandler { // need to execute it on `messageHandlingQueue`. if let notification = notification as? CancelRequestNotification { logger.log("Received cancel request notification: \(notification.forLogging)") - self.messageHandlingHelper.cancelRequest(id: notification.id) + self.messageHandlingHelper.cancelRequest(id: notification.id, error: .cancelled) return } self.didReceive(notification: notification) @@ -214,7 +237,21 @@ extension QueueBasedMessageHandler { signposter.emitEvent("Start handling", id: signpostID) await self.messageHandlingHelper.withRequestLoggingScopeIfNecessary(id: id) { await withTaskCancellationHandler { - await self.handle(request: request, id: id, reply: reply) + await self.handle(request: request, id: id) { result in + switch result { + case .success(let response): + reply(.success(response)) + case .failure(let error as CancellationError): + Task { + guard let cancellationError = await self.messageHandlingHelper.cancellationError(for: id) else { + return reply(.failure(ResponseError(error))) + } + reply(.failure(cancellationError)) + } + case .failure(let error): + reply(.failure(ResponseError(error))) + } + } signposter.endInterval("Request", state, "Done") } onCancel: { signposter.emitEvent("Cancelled", id: signpostID) diff --git a/Sources/LanguageServerProtocolTransport/RequestAndReply.swift b/Sources/LanguageServerProtocolTransport/RequestAndReply.swift index f41fd296..bfdf1062 100644 --- a/Sources/LanguageServerProtocolTransport/RequestAndReply.swift +++ b/Sources/LanguageServerProtocolTransport/RequestAndReply.swift @@ -15,15 +15,18 @@ public import LanguageServerProtocol /// A request and a callback that returns the request's reply public final class RequestAndReply: Sendable { + /// The request that is handled by this `RequestAndReply` object. public let params: Params - private let replyBlock: @Sendable (LSPResult) -> Void + + /// The closure that is invoked when the `body` closure passed to `reply` terminates. + private let reply: @Sendable (Result) -> Void /// Whether a reply has been made. Every request must reply exactly once. private let replied: AtomicBool = AtomicBool(initialValue: false) - public init(_ request: Params, reply: @escaping @Sendable (LSPResult) -> Void) { + public init(_ request: Params, reply: @escaping @Sendable (Result) -> Void) { self.params = request - self.replyBlock = reply + self.reply = reply } deinit { @@ -31,13 +34,13 @@ public final class RequestAndReply: Sendable { } /// Call the `replyBlock` with the result produced by the given closure. - public func reply(_ body: @Sendable () async throws -> Params.Response) async { - precondition(!replied.value, "replied to request more than once") - replied.value = true + public func reply(_ body: () async throws -> Params.Response) async { + let didReply = replied.setAndGet(newValue: true) + precondition(!didReply, "replied to request more than once") do { - replyBlock(.success(try await body())) + reply(.success(try await body())) } catch { - replyBlock(.failure(ResponseError(error))) + reply(.failure(error)) } } }