Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better align shutdown semantics of testing event loops #2800

Merged
merged 5 commits into from
Jul 25, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 60 additions & 49 deletions Sources/NIOEmbedded/AsyncTestingEventLoop.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
/// The queue on which we run all our operations.
private let queue = DispatchQueue(label: "io.swiftnio.AsyncEmbeddedEventLoop")

private enum State: Int, AtomicValue { case open, closing, closed }
private let state = ManagedAtomic(State.open)

// This function must only be called on queue.
private func nextTaskNumber() -> UInt64 {
dispatchPrecondition(condition: .onQueue(self.queue))
Expand Down Expand Up @@ -150,6 +153,15 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
let promise: EventLoopPromise<T> = self.makePromise()
let taskID = self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed)

switch self.state.load(ordering: .acquiring) {
case .open:
break
case .closing, .closed:
// If the event loop is shut down, or shutting down, immediately cancel the task.
promise.fail(EventLoopError.cancelled)
return Scheduled(promise: promise, cancellationTask: {})
}

let scheduled = Scheduled(
promise: promise,
cancellationTask: {
Expand Down Expand Up @@ -187,27 +199,7 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
} else {
self.queue.async {
self.scheduleTask(deadline: self.now, task)

var tasks = CircularBuffer<EmbeddedScheduledTask>()
while let nextTask = self.scheduledTasks.peek() {
guard nextTask.readyTime <= self.now else {
break
}

// Now we want to grab all tasks that are ready to execute at the same
// time as the first.
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime
{
tasks.append(candidateTask)
self.scheduledTasks.pop()
}

for task in tasks {
task.task()
}

tasks.removeAll(keepingCapacity: true)
}
self._run()
}
}
}
Expand All @@ -233,41 +225,50 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
///
/// - Note: If `deadline` is before the current time, the current time will not be advanced.
public func advanceTime(to deadline: NIODeadline) async {
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
await withCheckedContinuation { continuation in
self.queue.async {
let newTime = max(deadline, self.now)

var tasks = CircularBuffer<EmbeddedScheduledTask>()
while let nextTask = self.scheduledTasks.peek() {
guard nextTask.readyTime <= newTime else {
break
}
self._advanceTime(to: deadline)
continuation.resume()
}
}
}

// Now we want to grab all tasks that are ready to execute at the same
// time as the first.
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime
{
tasks.append(candidateTask)
self.scheduledTasks.pop()
}
internal func _advanceTime(to deadline: NIODeadline) {
dispatchPrecondition(condition: .onQueue(self.queue))

// Set the time correctly before we call into user code, then
// call in for all tasks.
self._now.store(nextTask.readyTime.uptimeNanoseconds, ordering: .relaxed)
let newTime = max(deadline, self.now)

for task in tasks {
task.task()
}
var tasks = CircularBuffer<EmbeddedScheduledTask>()
while let nextTask = self.scheduledTasks.peek() {
guard nextTask.readyTime <= newTime else {
break
}

tasks.removeAll(keepingCapacity: true)
}
// Now we want to grab all tasks that are ready to execute at the same
// time as the first.
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime {
tasks.append(candidateTask)
self.scheduledTasks.pop()
}

// Finally ensure we got the time right.
self._now.store(newTime.uptimeNanoseconds, ordering: .relaxed)
// Set the time correctly before we call into user code, then
// call in for all tasks.
self._now.store(nextTask.readyTime.uptimeNanoseconds, ordering: .relaxed)

continuation.resume()
for task in tasks {
task.task()
}

tasks.removeAll(keepingCapacity: true)
}

// Finally ensure we got the time right.
self._now.store(newTime.uptimeNanoseconds, ordering: .relaxed)
}

internal func _run() {
dispatchPrecondition(condition: .onQueue(self.queue))
self._advanceTime(to: self.now)
}

/// Executes the given function in the context of this event loop. This is useful when it's necessary to be confident that an operation
Expand All @@ -293,6 +294,13 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
}
}

internal func _cancelRemainingScheduledTasks() {
dispatchPrecondition(condition: .onQueue(self.queue))
while let task = self.scheduledTasks.pop() {
task.fail(EventLoopError.cancelled)
}
}

internal func drainScheduledTasksByRunningAllCurrentlyScheduledTasks() {
var currentlyScheduledTasks = self.scheduledTasks
while let nextTask = currentlyScheduledTasks.pop() {
Expand All @@ -309,7 +317,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {

private func _shutdownGracefully() {
dispatchPrecondition(condition: .onQueue(self.queue))
self.drainScheduledTasksByRunningAllCurrentlyScheduledTasks()
self._run()
self._cancelRemainingScheduledTasks()
}

/// - see: `EventLoop.shutdownGracefully`
Expand All @@ -324,9 +333,11 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {

/// The concurrency-aware equivalent of `shutdownGracefully(queue:_:)`.
public func shutdownGracefully() async {
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
await withCheckedContinuation { continuation in
self.state.store(.closing, ordering: .releasing)
self.queue.async {
self._shutdownGracefully()
self.state.store(.closed, ordering: .releasing)
continuation.resume()
}
}
Expand Down
35 changes: 19 additions & 16 deletions Sources/NIOEmbedded/Embedded.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ public final class EmbeddedEventLoop: EventLoop {
/// The current "time" for this event loop. This is an amount in nanoseconds.
internal var _now: NIODeadline = .uptimeNanoseconds(0)

private enum State { case open, closing, closed }
private var state: State = .open

private var scheduledTaskCounter: UInt64 = 0
private var scheduledTasks = PriorityQueue<EmbeddedScheduledTask>()

Expand Down Expand Up @@ -110,6 +113,16 @@ public final class EmbeddedEventLoop: EventLoop {
@discardableResult
public func scheduleTask<T>(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled<T> {
let promise: EventLoopPromise<T> = makePromise()

switch self.state {
case .open:
break
case .closing, .closed:
// If the event loop is shut down, or shutting down, immediately cancel the task.
promise.fail(EventLoopError.cancelled)
return Scheduled(promise: promise, cancellationTask: {})
}

self.scheduledTaskCounter += 1
let task = EmbeddedScheduledTask(
id: self.scheduledTaskCounter,
Expand Down Expand Up @@ -197,28 +210,18 @@ public final class EmbeddedEventLoop: EventLoop {
self._now = newTime
}

internal func drainScheduledTasksByRunningAllCurrentlyScheduledTasks() {
var currentlyScheduledTasks = self.scheduledTasks
while let nextTask = currentlyScheduledTasks.pop() {
self._now = nextTask.readyTime
nextTask.task()
}
// Just fail all the remaining scheduled tasks. Despite having run all the tasks that were
// scheduled when we entered the method this may still contain tasks as running the tasks
// may have enqueued more tasks.
internal func cancelRemainingScheduledTasks() {
while let task = self.scheduledTasks.pop() {
task.fail(EventLoopError.shutdown)
task.fail(EventLoopError.cancelled)
}
}

/// - see: `EventLoop.close`
func close() throws {
// Nothing to do here
}

/// - see: `EventLoop.shutdownGracefully`
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
self.state = .closing
run()
cancelRemainingScheduledTasks()
self.state = .closed
queue.sync {
callback(nil)
}
Expand Down Expand Up @@ -640,8 +643,8 @@ public final class EmbeddedChannel: Channel {
throw error
}
}
self.embeddedEventLoop.drainScheduledTasksByRunningAllCurrentlyScheduledTasks()
self.embeddedEventLoop.run()
self.embeddedEventLoop.cancelRemainingScheduledTasks()
try throwIfErrorCaught()
let c = self.channelcore
if c.outboundBuffer.isEmpty && c.inboundBuffer.isEmpty && c.pendingOutboundBuffer.isEmpty {
Expand Down
96 changes: 77 additions & 19 deletions Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -450,39 +450,97 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase {
}
}

func testDrainScheduledTasks() async throws {
func testShutdownCancelsFutureScheduledTasks() async {
let eventLoop = NIOAsyncTestingEventLoop()
let tasksRun = ManagedAtomic(0)
let startTime = eventLoop.now

eventLoop.scheduleTask(in: .nanoseconds(3_141_592)) {
XCTAssertEqual(eventLoop.now, startTime + .nanoseconds(3_141_592))
tasksRun.wrappingIncrement(ordering: .relaxed)
}
let a = eventLoop.scheduleTask(in: .seconds(1)) { tasksRun.wrappingIncrement(ordering: .relaxed) }
let b = eventLoop.scheduleTask(in: .seconds(2)) { tasksRun.wrappingIncrement(ordering: .relaxed) }

eventLoop.scheduleTask(in: .seconds(3_141_592)) {
XCTAssertEqual(eventLoop.now, startTime + .seconds(3_141_592))
tasksRun.wrappingIncrement(ordering: .relaxed)
}
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 0)

await eventLoop.shutdownGracefully()
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 2)
await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)

XCTAssertNoThrow(try eventLoop.syncShutdownGracefully())
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)

await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)

await eventLoop.advanceTime(to: .distantFuture)
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)

XCTAssertNoThrow(try a.futureResult.wait())
await XCTAssertThrowsError(try await b.futureResult.get()) { error in
XCTAssertEqual(error as? EventLoopError, .cancelled)
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
}
}

func testDrainScheduledTasksDoesNotRunNewlyScheduledTasks() async throws {
func testTasksScheduledDuringShutdownAreAutomaticallyCancelled() async throws {
let eventLoop = NIOAsyncTestingEventLoop()
let tasksRun = ManagedAtomic(0)
var childTasks: [Scheduled<Void>] = []

func scheduleNowAndIncrement() {
eventLoop.scheduleTask(in: .nanoseconds(0)) {
func scheduleRecursiveTask(
at taskStartTime: NIODeadline,
andChildTaskAfter childTaskStartDelay: TimeAmount
) -> Scheduled<Void> {
eventLoop.scheduleTask(deadline: taskStartTime) {
tasksRun.wrappingIncrement(ordering: .relaxed)
scheduleNowAndIncrement()
childTasks.append(
scheduleRecursiveTask(
at: eventLoop.now + childTaskStartDelay,
andChildTaskAfter: childTaskStartDelay
)
)
}
}

scheduleNowAndIncrement()
await eventLoop.shutdownGracefully()
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
_ = scheduleRecursiveTask(at: .uptimeNanoseconds(1), andChildTaskAfter: .zero)

try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await Task.sleep(for: .milliseconds(1))
await eventLoop.shutdownGracefully()
}
group.addTask {
await eventLoop.advanceTime(to: .uptimeNanoseconds(1))
}
try await group.waitForAll()
}

XCTAssertGreaterThan(tasksRun.load(ordering: .relaxed), 1)
XCTAssertEqual(childTasks.count, tasksRun.load(ordering: .relaxed))
}

func testShutdownCancelsRemainingScheduledTasks() async {
let eventLoop = NIOAsyncTestingEventLoop()
var tasksRun = 0

let a = eventLoop.scheduleTask(in: .seconds(1)) { tasksRun += 1 }
let b = eventLoop.scheduleTask(in: .seconds(2)) { tasksRun += 1 }

XCTAssertEqual(tasksRun, 0)

await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun, 1)

XCTAssertNoThrow(try eventLoop.syncShutdownGracefully())
XCTAssertEqual(tasksRun, 1)

await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun, 1)

await eventLoop.advanceTime(to: .distantFuture)
XCTAssertEqual(tasksRun, 1)

XCTAssertNoThrow(try a.futureResult.wait())
await XCTAssertThrowsError(try await b.futureResult.get()) { error in
XCTAssertEqual(error as? EventLoopError, .cancelled)
XCTAssertEqual(tasksRun, 1)
}
}

func testAdvanceTimeToDeadline() async throws {
Expand Down
22 changes: 20 additions & 2 deletions Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -478,19 +478,37 @@ class EmbeddedChannelTest: XCTestCase {

func testFinishWithRecursivelyScheduledTasks() throws {
let channel = EmbeddedChannel()
var tasks: [Scheduled<Void>] = []
var invocations = 0

func recursivelyScheduleAndIncrement() {
channel.pipeline.eventLoop.scheduleTask(deadline: .distantFuture) {
let task = channel.pipeline.eventLoop.scheduleTask(deadline: .distantFuture) {
invocations += 1
recursivelyScheduleAndIncrement()
}
tasks.append(task)
}

recursivelyScheduleAndIncrement()

try XCTAssertNoThrow(channel.finish())
XCTAssertEqual(invocations, 1)

// None of the tasks should have been executed, they were scheduled for distant future.
XCTAssertEqual(invocations, 0)

// Because the root task didn't run, it should be the onnly one scheduled.
XCTAssertEqual(tasks.count, 1)

// Check the task was failed with cancelled error.
let taskChecked = expectation(description: "task future fulfilled")
tasks.first?.futureResult.whenComplete { result in
switch result {
case .success: XCTFail("Expected task to be cancelled, not run.")
case .failure(let error): XCTAssertEqual(error as? EventLoopError, .cancelled)
}
taskChecked.fulfill()
}
wait(for: [taskChecked], timeout: 0)
}

func testGetChannelOptionAutoReadIsSupported() {
Expand Down
Loading
Loading