Skip to content

Commit

Permalink
Better align shutdown behavior of testing event loops
Browse files Browse the repository at this point in the history
  • Loading branch information
simonjbeaumont committed Jul 22, 2024
1 parent bf69999 commit 4c28e6c
Show file tree
Hide file tree
Showing 6 changed files with 223 additions and 108 deletions.
109 changes: 60 additions & 49 deletions Sources/NIOEmbedded/AsyncTestingEventLoop.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,9 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
/// The queue on which we run all our operations.
private let queue = DispatchQueue(label: "io.swiftnio.AsyncEmbeddedEventLoop")

private enum State: Int, AtomicValue { case open, closing, closed }
private let state = ManagedAtomic(State.open)

// This function must only be called on queue.
private func nextTaskNumber() -> UInt64 {
dispatchPrecondition(condition: .onQueue(self.queue))
Expand Down Expand Up @@ -150,6 +153,15 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
let promise: EventLoopPromise<T> = self.makePromise()
let taskID = self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed)

switch self.state.load(ordering: .acquiring) {
case .open:
break
case .closing, .closed:
// If the event loop is shut down, or shutting down, immediately cancel the task.
promise.fail(EventLoopError.cancelled)
return Scheduled(promise: promise, cancellationTask: {})
}

let scheduled = Scheduled(
promise: promise,
cancellationTask: {
Expand Down Expand Up @@ -187,27 +199,7 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
} else {
self.queue.async {
self.scheduleTask(deadline: self.now, task)

var tasks = CircularBuffer<EmbeddedScheduledTask>()
while let nextTask = self.scheduledTasks.peek() {
guard nextTask.readyTime <= self.now else {
break
}

// Now we want to grab all tasks that are ready to execute at the same
// time as the first.
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime
{
tasks.append(candidateTask)
self.scheduledTasks.pop()
}

for task in tasks {
task.task()
}

tasks.removeAll(keepingCapacity: true)
}
self._run()
}
}
}
Expand All @@ -233,41 +225,50 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
///
/// - Note: If `deadline` is before the current time, the current time will not be advanced.
public func advanceTime(to deadline: NIODeadline) async {
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
await withCheckedContinuation { continuation in
self.queue.async {
let newTime = max(deadline, self.now)

var tasks = CircularBuffer<EmbeddedScheduledTask>()
while let nextTask = self.scheduledTasks.peek() {
guard nextTask.readyTime <= newTime else {
break
}
self._advanceTime(to: deadline)
continuation.resume()
}
}
}

// Now we want to grab all tasks that are ready to execute at the same
// time as the first.
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime
{
tasks.append(candidateTask)
self.scheduledTasks.pop()
}
internal func _advanceTime(to deadline: NIODeadline) {
dispatchPrecondition(condition: .onQueue(self.queue))

// Set the time correctly before we call into user code, then
// call in for all tasks.
self._now.store(nextTask.readyTime.uptimeNanoseconds, ordering: .relaxed)
let newTime = max(deadline, self.now)

for task in tasks {
task.task()
}
var tasks = CircularBuffer<EmbeddedScheduledTask>()
while let nextTask = self.scheduledTasks.peek() {
guard nextTask.readyTime <= newTime else {
break
}

tasks.removeAll(keepingCapacity: true)
}
// Now we want to grab all tasks that are ready to execute at the same
// time as the first.
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime {
tasks.append(candidateTask)
self.scheduledTasks.pop()
}

// Finally ensure we got the time right.
self._now.store(newTime.uptimeNanoseconds, ordering: .relaxed)
// Set the time correctly before we call into user code, then
// call in for all tasks.
self._now.store(nextTask.readyTime.uptimeNanoseconds, ordering: .relaxed)

continuation.resume()
for task in tasks {
task.task()
}

tasks.removeAll(keepingCapacity: true)
}

// Finally ensure we got the time right.
self._now.store(newTime.uptimeNanoseconds, ordering: .relaxed)
}

internal func _run() {
dispatchPrecondition(condition: .onQueue(self.queue))
self._advanceTime(to: self.now)
}

/// Executes the given function in the context of this event loop. This is useful when it's necessary to be confident that an operation
Expand All @@ -293,6 +294,13 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
}
}

internal func _cancelRemainingScheduledTasks() {
dispatchPrecondition(condition: .onQueue(self.queue))
while let task = self.scheduledTasks.pop() {
task.fail(EventLoopError.cancelled)
}
}

internal func drainScheduledTasksByRunningAllCurrentlyScheduledTasks() {
var currentlyScheduledTasks = self.scheduledTasks
while let nextTask = currentlyScheduledTasks.pop() {
Expand All @@ -309,7 +317,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {

private func _shutdownGracefully() {
dispatchPrecondition(condition: .onQueue(self.queue))
self.drainScheduledTasksByRunningAllCurrentlyScheduledTasks()
self._run()
self._cancelRemainingScheduledTasks()
}

/// - see: `EventLoop.shutdownGracefully`
Expand All @@ -324,9 +333,11 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {

/// The concurrency-aware equivalent of `shutdownGracefully(queue:_:)`.
public func shutdownGracefully() async {
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
await withCheckedContinuation { continuation in
self.state.store(.closing, ordering: .releasing)
self.queue.async {
self._shutdownGracefully()
self.state.store(.closed, ordering: .releasing)
continuation.resume()
}
}
Expand Down
35 changes: 19 additions & 16 deletions Sources/NIOEmbedded/Embedded.swift
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,9 @@ public final class EmbeddedEventLoop: EventLoop {
/// The current "time" for this event loop. This is an amount in nanoseconds.
internal var _now: NIODeadline = .uptimeNanoseconds(0)

private enum State { case open, closing, closed }
private var state: State = .open

private var scheduledTaskCounter: UInt64 = 0
private var scheduledTasks = PriorityQueue<EmbeddedScheduledTask>()

Expand Down Expand Up @@ -110,6 +113,16 @@ public final class EmbeddedEventLoop: EventLoop {
@discardableResult
public func scheduleTask<T>(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled<T> {
let promise: EventLoopPromise<T> = makePromise()

switch self.state {
case .open:
break
case .closing, .closed:
// If the event loop is shut down, or shutting down, immediately cancel the task.
promise.fail(EventLoopError.cancelled)
return Scheduled(promise: promise, cancellationTask: {})
}

self.scheduledTaskCounter += 1
let task = EmbeddedScheduledTask(
id: self.scheduledTaskCounter,
Expand Down Expand Up @@ -197,28 +210,18 @@ public final class EmbeddedEventLoop: EventLoop {
self._now = newTime
}

internal func drainScheduledTasksByRunningAllCurrentlyScheduledTasks() {
var currentlyScheduledTasks = self.scheduledTasks
while let nextTask = currentlyScheduledTasks.pop() {
self._now = nextTask.readyTime
nextTask.task()
}
// Just fail all the remaining scheduled tasks. Despite having run all the tasks that were
// scheduled when we entered the method this may still contain tasks as running the tasks
// may have enqueued more tasks.
internal func cancelRemainingScheduledTasks() {
while let task = self.scheduledTasks.pop() {
task.fail(EventLoopError.shutdown)
task.fail(EventLoopError.cancelled)
}
}

/// - see: `EventLoop.close`
func close() throws {
// Nothing to do here
}

/// - see: `EventLoop.shutdownGracefully`
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
self.state = .closing
run()
cancelRemainingScheduledTasks()
self.state = .closed
queue.sync {
callback(nil)
}
Expand Down Expand Up @@ -640,8 +643,8 @@ public final class EmbeddedChannel: Channel {
throw error
}
}
self.embeddedEventLoop.drainScheduledTasksByRunningAllCurrentlyScheduledTasks()
self.embeddedEventLoop.run()
self.embeddedEventLoop.cancelRemainingScheduledTasks()
try throwIfErrorCaught()
let c = self.channelcore
if c.outboundBuffer.isEmpty && c.inboundBuffer.isEmpty && c.pendingOutboundBuffer.isEmpty {
Expand Down
96 changes: 77 additions & 19 deletions Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -450,39 +450,97 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase {
}
}

func testDrainScheduledTasks() async throws {
func testShutdownCancelsFutureScheduledTasks() async {
let eventLoop = NIOAsyncTestingEventLoop()
let tasksRun = ManagedAtomic(0)
let startTime = eventLoop.now

eventLoop.scheduleTask(in: .nanoseconds(3_141_592)) {
XCTAssertEqual(eventLoop.now, startTime + .nanoseconds(3_141_592))
tasksRun.wrappingIncrement(ordering: .relaxed)
}
let a = eventLoop.scheduleTask(in: .seconds(1)) { tasksRun.wrappingIncrement(ordering: .relaxed) }
let b = eventLoop.scheduleTask(in: .seconds(2)) { tasksRun.wrappingIncrement(ordering: .relaxed) }

eventLoop.scheduleTask(in: .seconds(3_141_592)) {
XCTAssertEqual(eventLoop.now, startTime + .seconds(3_141_592))
tasksRun.wrappingIncrement(ordering: .relaxed)
}
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 0)

await eventLoop.shutdownGracefully()
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 2)
await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)

XCTAssertNoThrow(try eventLoop.syncShutdownGracefully())
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)

await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)

await eventLoop.advanceTime(to: .distantFuture)
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)

XCTAssertNoThrow(try a.futureResult.wait())
await XCTAssertThrowsError(try await b.futureResult.get()) { error in
XCTAssertEqual(error as? EventLoopError, .cancelled)
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
}
}

func testDrainScheduledTasksDoesNotRunNewlyScheduledTasks() async throws {
func testTasksScheduledDuringShutdownAreAutomaticallyCancelled() async throws {
let eventLoop = NIOAsyncTestingEventLoop()
let tasksRun = ManagedAtomic(0)
var childTasks: [Scheduled<Void>] = []

func scheduleNowAndIncrement() {
eventLoop.scheduleTask(in: .nanoseconds(0)) {
func scheduleRecursiveTask(
at taskStartTime: NIODeadline,
andChildTaskAfter childTaskStartDelay: TimeAmount
) -> Scheduled<Void> {
eventLoop.scheduleTask(deadline: taskStartTime) {
tasksRun.wrappingIncrement(ordering: .relaxed)
scheduleNowAndIncrement()
childTasks.append(
scheduleRecursiveTask(
at: eventLoop.now + childTaskStartDelay,
andChildTaskAfter: childTaskStartDelay
)
)
}
}

scheduleNowAndIncrement()
await eventLoop.shutdownGracefully()
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
_ = scheduleRecursiveTask(at: .uptimeNanoseconds(1), andChildTaskAfter: .zero)

try await withThrowingTaskGroup(of: Void.self) { group in
group.addTask {
try await Task.sleep(for: .milliseconds(1))
await eventLoop.shutdownGracefully()
}
group.addTask {
await eventLoop.advanceTime(to: .uptimeNanoseconds(1))
}
try await group.waitForAll()
}

XCTAssertGreaterThan(tasksRun.load(ordering: .relaxed), 1)
XCTAssertEqual(childTasks.count, tasksRun.load(ordering: .relaxed))
}

func testShutdownCancelsRemainingScheduledTasks() async {
let eventLoop = NIOAsyncTestingEventLoop()
var tasksRun = 0

let a = eventLoop.scheduleTask(in: .seconds(1)) { tasksRun += 1 }
let b = eventLoop.scheduleTask(in: .seconds(2)) { tasksRun += 1 }

XCTAssertEqual(tasksRun, 0)

await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun, 1)

XCTAssertNoThrow(try eventLoop.syncShutdownGracefully())
XCTAssertEqual(tasksRun, 1)

await eventLoop.advanceTime(by: .seconds(1))
XCTAssertEqual(tasksRun, 1)

await eventLoop.advanceTime(to: .distantFuture)
XCTAssertEqual(tasksRun, 1)

XCTAssertNoThrow(try a.futureResult.wait())
await XCTAssertThrowsError(try await b.futureResult.get()) { error in
XCTAssertEqual(error as? EventLoopError, .cancelled)
XCTAssertEqual(tasksRun, 1)
}
}

func testAdvanceTimeToDeadline() async throws {
Expand Down
22 changes: 20 additions & 2 deletions Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift
Original file line number Diff line number Diff line change
Expand Up @@ -478,19 +478,37 @@ class EmbeddedChannelTest: XCTestCase {

func testFinishWithRecursivelyScheduledTasks() throws {
let channel = EmbeddedChannel()
var tasks: [Scheduled<Void>] = []
var invocations = 0

func recursivelyScheduleAndIncrement() {
channel.pipeline.eventLoop.scheduleTask(deadline: .distantFuture) {
let task = channel.pipeline.eventLoop.scheduleTask(deadline: .distantFuture) {
invocations += 1
recursivelyScheduleAndIncrement()
}
tasks.append(task)
}

recursivelyScheduleAndIncrement()

try XCTAssertNoThrow(channel.finish())
XCTAssertEqual(invocations, 1)

// None of the tasks should have been executed, they were scheduled for distant future.
XCTAssertEqual(invocations, 0)

// Because the root task didn't run, it should be the onnly one scheduled.
XCTAssertEqual(tasks.count, 1)

// Check the task was failed with cancelled error.
let taskChecked = expectation(description: "task future fulfilled")
tasks.first?.futureResult.whenComplete { result in
switch result {
case .success: XCTFail("Expected task to be cancelled, not run.")
case .failure(let error): XCTAssertEqual(error as? EventLoopError, .cancelled)
}
taskChecked.fulfill()
}
wait(for: [taskChecked], timeout: 0)
}

func testGetChannelOptionAutoReadIsSupported() {
Expand Down
Loading

0 comments on commit 4c28e6c

Please sign in to comment.