Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Sources/LanguageServerProtocol/Error.swift
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,8 @@ public struct ResponseError: Error, Codable, Hashable {
extension ResponseError {
// MARK: Convenience properties for common errors.

public static let contentModified: ResponseError = ResponseError(code: .contentModified, message: "content modified")

public static let cancelled: ResponseError = ResponseError(code: .cancelled, message: "request cancelled by client")

public static let serverCancelled: ResponseError = ResponseError(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,9 @@ public actor QueueBasedMessageHandlerHelper {

/// The requests that we are currently handling.
///
/// Used to cancel the tasks if the client requests cancellation.
private var inProgressRequestsByID: [RequestID: Task<(), Never>] = [:]
/// Used to cancel the tasks if the client requests cancellation. `cancellationError` is the error that should be
/// returned to the client if `task` is cancelled.
private var inProgressRequestsByID: [RequestID: (task: Task<(), Never>, cancellationError: ResponseError?)] = [:]

/// Up to 10 request IDs that have recently finished.
///
Expand All @@ -56,31 +57,53 @@ public actor QueueBasedMessageHandlerHelper {
///
/// Cancellation is performed automatically when a `$/cancelRequest` notification is received. This can be called to
/// implicitly cancel requests based on some criteria.
@_spi(SourceKitLSP) public nonisolated func cancelRequest(id: RequestID) {
///
/// `cancellationError` is the error that should be returned to the client for the cancelled request.
@_spi(SourceKitLSP) public nonisolated func cancelRequest(id: RequestID, error cancellationError: ResponseError) {
// Since the request is very cheap to execute and stops other requests
// from performing more work, we execute it with a high priority.
cancellationMessageHandlingQueue.async(priority: .high) {
if let task = await self.inProgressRequestsByID[id] {
task.cancel()
return
}
if await !self.recentlyFinishedRequests.contains(id) {
logger.error(
"Cannot cancel request \(id, privacy: .public) because it hasn't been scheduled for execution yet"
)
await self.cancelRequestImpl(id: id, cancellationError: cancellationError)
}
}

private func cancelRequestImpl(id: RequestID, cancellationError: ResponseError) {
// Since the request is very cheap to execute and stops other requests
// from performing more work, we execute it with a high priority.
if let task = self.inProgressRequestsByID[id]?.task {
if self.inProgressRequestsByID[id]?.cancellationError == nil {
// If we already have a cancellation error, stick with that one instead of overriding it.
self.inProgressRequestsByID[id]?.cancellationError = cancellationError
}
task.cancel()
return
}
if !self.recentlyFinishedRequests.contains(id) {
logger.error(
"Cannot cancel request \(id, privacy: .public) because it hasn't been scheduled for execution yet"
)
}
}

/// The error that should be returned to the client when the request with the given ID has ben cancelled by calling
/// `cancelRequest(id:)`.
fileprivate func cancellationError(for id: RequestID) async -> ResponseError? {
// We don't need to hop onto `cancellationMessageHandlingQueue` here because we will have already set the
// `cancellationError` in `inProgressRequestsByID` before cancelling the `Task`.
self.inProgressRequestsByID[id]?.cancellationError
}

fileprivate nonisolated func setInProgressRequest(id: RequestID, request: some RequestType, task: Task<(), Never>?) {
self.cancellationMessageHandlingQueue.async(priority: .background) {
await self.setInProgressRequestImpl(id: id, request: request, task: task)
}
}

private func setInProgressRequestImpl(id: RequestID, request: some RequestType, task: Task<(), Never>?) {
self.inProgressRequestsByID[id] = task
if task == nil {
if let task {
self.inProgressRequestsByID[id] = (task, nil)
} else {
self.inProgressRequestsByID[id] = nil
self.recentlyFinishedRequests.append(id)
while self.recentlyFinishedRequests.count > 10 {
self.recentlyFinishedRequests.removeFirst()
Expand Down Expand Up @@ -159,7 +182,7 @@ public protocol QueueBasedMessageHandler: MessageHandler {
func handle<Request: RequestType>(
request: Request,
id: RequestID,
reply: @Sendable @escaping (LSPResult<Request.Response>) -> Void
reply: @Sendable @escaping (Result<Request.Response, any Error>) -> Void
) async
}

Expand All @@ -174,7 +197,7 @@ extension QueueBasedMessageHandler {
// need to execute it on `messageHandlingQueue`.
if let notification = notification as? CancelRequestNotification {
logger.log("Received cancel request notification: \(notification.forLogging)")
self.messageHandlingHelper.cancelRequest(id: notification.id)
self.messageHandlingHelper.cancelRequest(id: notification.id, error: .cancelled)
return
}
self.didReceive(notification: notification)
Expand Down Expand Up @@ -214,7 +237,21 @@ extension QueueBasedMessageHandler {
signposter.emitEvent("Start handling", id: signpostID)
await self.messageHandlingHelper.withRequestLoggingScopeIfNecessary(id: id) {
await withTaskCancellationHandler {
await self.handle(request: request, id: id, reply: reply)
await self.handle(request: request, id: id) { result in
switch result {
case .success(let response):
reply(.success(response))
case .failure(let error as CancellationError):
Task {
guard let cancellationError = await self.messageHandlingHelper.cancellationError(for: id) else {
return reply(.failure(ResponseError(error)))
}
reply(.failure(cancellationError))
}
case .failure(let error):
reply(.failure(ResponseError(error)))
}
}
signposter.endInterval("Request", state, "Done")
} onCancel: {
signposter.emitEvent("Cancelled", id: signpostID)
Expand Down
19 changes: 11 additions & 8 deletions Sources/LanguageServerProtocolTransport/RequestAndReply.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,29 +15,32 @@ public import LanguageServerProtocol

/// A request and a callback that returns the request's reply
public final class RequestAndReply<Params: RequestType>: Sendable {
/// The request that is handled by this `RequestAndReply` object.
public let params: Params
private let replyBlock: @Sendable (LSPResult<Params.Response>) -> Void

/// The closure that is invoked when the `body` closure passed to `reply` terminates.
private let reply: @Sendable (Result<Params.Response, any Error>) -> Void

/// Whether a reply has been made. Every request must reply exactly once.
private let replied: AtomicBool = AtomicBool(initialValue: false)

public init(_ request: Params, reply: @escaping @Sendable (LSPResult<Params.Response>) -> Void) {
public init(_ request: Params, reply: @escaping @Sendable (Result<Params.Response, any Error>) -> Void) {
self.params = request
self.replyBlock = reply
self.reply = reply
}

deinit {
precondition(replied.value, "request never received a reply")
}

/// Call the `replyBlock` with the result produced by the given closure.
public func reply(_ body: @Sendable () async throws -> Params.Response) async {
precondition(!replied.value, "replied to request more than once")
replied.value = true
public func reply(_ body: () async throws -> Params.Response) async {
let didReply = replied.setAndGet(newValue: true)
precondition(!didReply, "replied to request more than once")
do {
replyBlock(.success(try await body()))
reply(.success(try await body()))
} catch {
replyBlock(.failure(ResponseError(error)))
reply(.failure(error))
}
}
}
Loading