Skip to content

Commit 4bcff6e

Browse files
Better align shutdown behavior of testing event loops
1 parent bf69999 commit 4bcff6e

File tree

6 files changed

+224
-108
lines changed

6 files changed

+224
-108
lines changed

Sources/NIOEmbedded/AsyncTestingEventLoop.swift

+61-49
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,10 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
9595
/// The queue on which we run all our operations.
9696
private let queue = DispatchQueue(label: "io.swiftnio.AsyncEmbeddedEventLoop")
9797

98+
private enum State: Int, AtomicValue { case open, closing, closed }
99+
private let _state = ManagedAtomic(State.open)
100+
private var state: State { self._state.load(ordering: .relaxed) }
101+
98102
// This function must only be called on queue.
99103
private func nextTaskNumber() -> UInt64 {
100104
dispatchPrecondition(condition: .onQueue(self.queue))
@@ -150,6 +154,15 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
150154
let promise: EventLoopPromise<T> = self.makePromise()
151155
let taskID = self.scheduledTaskCounter.loadThenWrappingIncrement(ordering: .relaxed)
152156

157+
switch self.state {
158+
case .open:
159+
break
160+
case .closing, .closed:
161+
// If the event loop is shut down, or shutting down, immediately cancel the task.
162+
promise.fail(EventLoopError.cancelled)
163+
return Scheduled(promise: promise, cancellationTask: {})
164+
}
165+
153166
let scheduled = Scheduled(
154167
promise: promise,
155168
cancellationTask: {
@@ -187,27 +200,7 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
187200
} else {
188201
self.queue.async {
189202
self.scheduleTask(deadline: self.now, task)
190-
191-
var tasks = CircularBuffer<EmbeddedScheduledTask>()
192-
while let nextTask = self.scheduledTasks.peek() {
193-
guard nextTask.readyTime <= self.now else {
194-
break
195-
}
196-
197-
// Now we want to grab all tasks that are ready to execute at the same
198-
// time as the first.
199-
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime
200-
{
201-
tasks.append(candidateTask)
202-
self.scheduledTasks.pop()
203-
}
204-
205-
for task in tasks {
206-
task.task()
207-
}
208-
209-
tasks.removeAll(keepingCapacity: true)
210-
}
203+
self._run()
211204
}
212205
}
213206
}
@@ -233,41 +226,50 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
233226
///
234227
/// - Note: If `deadline` is before the current time, the current time will not be advanced.
235228
public func advanceTime(to deadline: NIODeadline) async {
236-
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
229+
await withCheckedContinuation { continuation in
237230
self.queue.async {
238-
let newTime = max(deadline, self.now)
239-
240-
var tasks = CircularBuffer<EmbeddedScheduledTask>()
241-
while let nextTask = self.scheduledTasks.peek() {
242-
guard nextTask.readyTime <= newTime else {
243-
break
244-
}
231+
self._advanceTime(to: deadline)
232+
continuation.resume()
233+
}
234+
}
235+
}
245236

246-
// Now we want to grab all tasks that are ready to execute at the same
247-
// time as the first.
248-
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime
249-
{
250-
tasks.append(candidateTask)
251-
self.scheduledTasks.pop()
252-
}
237+
internal func _advanceTime(to deadline: NIODeadline) {
238+
dispatchPrecondition(condition: .onQueue(self.queue))
253239

254-
// Set the time correctly before we call into user code, then
255-
// call in for all tasks.
256-
self._now.store(nextTask.readyTime.uptimeNanoseconds, ordering: .relaxed)
240+
let newTime = max(deadline, self.now)
257241

258-
for task in tasks {
259-
task.task()
260-
}
242+
var tasks = CircularBuffer<EmbeddedScheduledTask>()
243+
while let nextTask = self.scheduledTasks.peek() {
244+
guard nextTask.readyTime <= newTime else {
245+
break
246+
}
261247

262-
tasks.removeAll(keepingCapacity: true)
263-
}
248+
// Now we want to grab all tasks that are ready to execute at the same
249+
// time as the first.
250+
while let candidateTask = self.scheduledTasks.peek(), candidateTask.readyTime == nextTask.readyTime {
251+
tasks.append(candidateTask)
252+
self.scheduledTasks.pop()
253+
}
264254

265-
// Finally ensure we got the time right.
266-
self._now.store(newTime.uptimeNanoseconds, ordering: .relaxed)
255+
// Set the time correctly before we call into user code, then
256+
// call in for all tasks.
257+
self._now.store(nextTask.readyTime.uptimeNanoseconds, ordering: .relaxed)
267258

268-
continuation.resume()
259+
for task in tasks {
260+
task.task()
269261
}
262+
263+
tasks.removeAll(keepingCapacity: true)
270264
}
265+
266+
// Finally ensure we got the time right.
267+
self._now.store(newTime.uptimeNanoseconds, ordering: .relaxed)
268+
}
269+
270+
internal func _run() {
271+
dispatchPrecondition(condition: .onQueue(self.queue))
272+
self._advanceTime(to: self.now)
271273
}
272274

273275
/// Executes the given function in the context of this event loop. This is useful when it's necessary to be confident that an operation
@@ -293,6 +295,13 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
293295
}
294296
}
295297

298+
internal func _cancelRemainingScheduledTasks() {
299+
dispatchPrecondition(condition: .onQueue(self.queue))
300+
while let task = self.scheduledTasks.pop() {
301+
task.fail(EventLoopError.cancelled)
302+
}
303+
}
304+
296305
internal func drainScheduledTasksByRunningAllCurrentlyScheduledTasks() {
297306
var currentlyScheduledTasks = self.scheduledTasks
298307
while let nextTask = currentlyScheduledTasks.pop() {
@@ -309,7 +318,8 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
309318

310319
private func _shutdownGracefully() {
311320
dispatchPrecondition(condition: .onQueue(self.queue))
312-
self.drainScheduledTasksByRunningAllCurrentlyScheduledTasks()
321+
self._run()
322+
self._cancelRemainingScheduledTasks()
313323
}
314324

315325
/// - see: `EventLoop.shutdownGracefully`
@@ -324,9 +334,11 @@ public final class NIOAsyncTestingEventLoop: EventLoop, @unchecked Sendable {
324334

325335
/// The concurrency-aware equivalent of `shutdownGracefully(queue:_:)`.
326336
public func shutdownGracefully() async {
327-
await withCheckedContinuation { (continuation: CheckedContinuation<Void, Never>) in
337+
await withCheckedContinuation { continuation in
338+
self._state.store(.closing, ordering: .relaxed)
328339
self.queue.async {
329340
self._shutdownGracefully()
341+
self._state.store(.closed, ordering: .relaxed)
330342
continuation.resume()
331343
}
332344
}

Sources/NIOEmbedded/Embedded.swift

+19-16
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,9 @@ public final class EmbeddedEventLoop: EventLoop {
8080
/// The current "time" for this event loop. This is an amount in nanoseconds.
8181
internal var _now: NIODeadline = .uptimeNanoseconds(0)
8282

83+
private enum State { case open, closing, closed }
84+
private var state: State = .open
85+
8386
private var scheduledTaskCounter: UInt64 = 0
8487
private var scheduledTasks = PriorityQueue<EmbeddedScheduledTask>()
8588

@@ -110,6 +113,16 @@ public final class EmbeddedEventLoop: EventLoop {
110113
@discardableResult
111114
public func scheduleTask<T>(deadline: NIODeadline, _ task: @escaping () throws -> T) -> Scheduled<T> {
112115
let promise: EventLoopPromise<T> = makePromise()
116+
117+
switch self.state {
118+
case .open:
119+
break
120+
case .closing, .closed:
121+
// If the event loop is shut down, or shutting down, immediately cancel the task.
122+
promise.fail(EventLoopError.cancelled)
123+
return Scheduled(promise: promise, cancellationTask: {})
124+
}
125+
113126
self.scheduledTaskCounter += 1
114127
let task = EmbeddedScheduledTask(
115128
id: self.scheduledTaskCounter,
@@ -197,28 +210,18 @@ public final class EmbeddedEventLoop: EventLoop {
197210
self._now = newTime
198211
}
199212

200-
internal func drainScheduledTasksByRunningAllCurrentlyScheduledTasks() {
201-
var currentlyScheduledTasks = self.scheduledTasks
202-
while let nextTask = currentlyScheduledTasks.pop() {
203-
self._now = nextTask.readyTime
204-
nextTask.task()
205-
}
206-
// Just fail all the remaining scheduled tasks. Despite having run all the tasks that were
207-
// scheduled when we entered the method this may still contain tasks as running the tasks
208-
// may have enqueued more tasks.
213+
internal func cancelRemainingScheduledTasks() {
209214
while let task = self.scheduledTasks.pop() {
210-
task.fail(EventLoopError.shutdown)
215+
task.fail(EventLoopError.cancelled)
211216
}
212217
}
213218

214-
/// - see: `EventLoop.close`
215-
func close() throws {
216-
// Nothing to do here
217-
}
218-
219219
/// - see: `EventLoop.shutdownGracefully`
220220
public func shutdownGracefully(queue: DispatchQueue, _ callback: @escaping (Error?) -> Void) {
221+
self.state = .closing
221222
run()
223+
cancelRemainingScheduledTasks()
224+
self.state = .closed
222225
queue.sync {
223226
callback(nil)
224227
}
@@ -640,8 +643,8 @@ public final class EmbeddedChannel: Channel {
640643
throw error
641644
}
642645
}
643-
self.embeddedEventLoop.drainScheduledTasksByRunningAllCurrentlyScheduledTasks()
644646
self.embeddedEventLoop.run()
647+
self.embeddedEventLoop.cancelRemainingScheduledTasks()
645648
try throwIfErrorCaught()
646649
let c = self.channelcore
647650
if c.outboundBuffer.isEmpty && c.inboundBuffer.isEmpty && c.pendingOutboundBuffer.isEmpty {

Tests/NIOEmbeddedTests/AsyncTestingEventLoopTests.swift

+77-19
Original file line numberDiff line numberDiff line change
@@ -450,39 +450,97 @@ final class NIOAsyncTestingEventLoopTests: XCTestCase {
450450
}
451451
}
452452

453-
func testDrainScheduledTasks() async throws {
453+
func testShutdownCancelsFutureScheduledTasks() async {
454454
let eventLoop = NIOAsyncTestingEventLoop()
455455
let tasksRun = ManagedAtomic(0)
456-
let startTime = eventLoop.now
457456

458-
eventLoop.scheduleTask(in: .nanoseconds(3_141_592)) {
459-
XCTAssertEqual(eventLoop.now, startTime + .nanoseconds(3_141_592))
460-
tasksRun.wrappingIncrement(ordering: .relaxed)
461-
}
457+
let a = eventLoop.scheduleTask(in: .seconds(1)) { tasksRun.wrappingIncrement(ordering: .relaxed) }
458+
let b = eventLoop.scheduleTask(in: .seconds(2)) { tasksRun.wrappingIncrement(ordering: .relaxed) }
462459

463-
eventLoop.scheduleTask(in: .seconds(3_141_592)) {
464-
XCTAssertEqual(eventLoop.now, startTime + .seconds(3_141_592))
465-
tasksRun.wrappingIncrement(ordering: .relaxed)
466-
}
460+
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 0)
467461

468-
await eventLoop.shutdownGracefully()
469-
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 2)
462+
await eventLoop.advanceTime(by: .seconds(1))
463+
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
464+
465+
XCTAssertNoThrow(try eventLoop.syncShutdownGracefully())
466+
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
467+
468+
await eventLoop.advanceTime(by: .seconds(1))
469+
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
470+
471+
await eventLoop.advanceTime(to: .distantFuture)
472+
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
473+
474+
XCTAssertNoThrow(try a.futureResult.wait())
475+
await XCTAssertThrowsError(try await b.futureResult.get()) { error in
476+
XCTAssertEqual(error as? EventLoopError, .cancelled)
477+
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
478+
}
470479
}
471480

472-
func testDrainScheduledTasksDoesNotRunNewlyScheduledTasks() async throws {
481+
func testTasksScheduledDuringShutdownAreAutomaticallyCancelled() async throws {
473482
let eventLoop = NIOAsyncTestingEventLoop()
474483
let tasksRun = ManagedAtomic(0)
484+
var childTasks: [Scheduled<Void>] = []
475485

476-
func scheduleNowAndIncrement() {
477-
eventLoop.scheduleTask(in: .nanoseconds(0)) {
486+
func scheduleRecursiveTask(
487+
at taskStartTime: NIODeadline,
488+
andChildTaskAfter childTaskStartDelay: TimeAmount
489+
) -> Scheduled<Void> {
490+
eventLoop.scheduleTask(deadline: taskStartTime) {
478491
tasksRun.wrappingIncrement(ordering: .relaxed)
479-
scheduleNowAndIncrement()
492+
childTasks.append(
493+
scheduleRecursiveTask(
494+
at: eventLoop.now + childTaskStartDelay,
495+
andChildTaskAfter: childTaskStartDelay
496+
)
497+
)
480498
}
481499
}
482500

483-
scheduleNowAndIncrement()
484-
await eventLoop.shutdownGracefully()
485-
XCTAssertEqual(tasksRun.load(ordering: .relaxed), 1)
501+
_ = scheduleRecursiveTask(at: .uptimeNanoseconds(1), andChildTaskAfter: .zero)
502+
503+
try await withThrowingTaskGroup(of: Void.self) { group in
504+
group.addTask {
505+
try await Task.sleep(for: .milliseconds(1))
506+
await eventLoop.shutdownGracefully()
507+
}
508+
group.addTask {
509+
await eventLoop.advanceTime(to: .uptimeNanoseconds(1))
510+
}
511+
try await group.waitForAll()
512+
}
513+
514+
XCTAssertGreaterThan(tasksRun.load(ordering: .relaxed), 1)
515+
XCTAssertEqual(childTasks.count, tasksRun.load(ordering: .relaxed))
516+
}
517+
518+
func testShutdownCancelsRemainingScheduledTasks() async {
519+
let eventLoop = NIOAsyncTestingEventLoop()
520+
var tasksRun = 0
521+
522+
let a = eventLoop.scheduleTask(in: .seconds(1)) { tasksRun += 1 }
523+
let b = eventLoop.scheduleTask(in: .seconds(2)) { tasksRun += 1 }
524+
525+
XCTAssertEqual(tasksRun, 0)
526+
527+
await eventLoop.advanceTime(by: .seconds(1))
528+
XCTAssertEqual(tasksRun, 1)
529+
530+
XCTAssertNoThrow(try eventLoop.syncShutdownGracefully())
531+
XCTAssertEqual(tasksRun, 1)
532+
533+
await eventLoop.advanceTime(by: .seconds(1))
534+
XCTAssertEqual(tasksRun, 1)
535+
536+
await eventLoop.advanceTime(to: .distantFuture)
537+
XCTAssertEqual(tasksRun, 1)
538+
539+
XCTAssertNoThrow(try a.futureResult.wait())
540+
await XCTAssertThrowsError(try await b.futureResult.get()) { error in
541+
XCTAssertEqual(error as? EventLoopError, .cancelled)
542+
XCTAssertEqual(tasksRun, 1)
543+
}
486544
}
487545

488546
func testAdvanceTimeToDeadline() async throws {

Tests/NIOEmbeddedTests/EmbeddedChannelTest.swift

+20-2
Original file line numberDiff line numberDiff line change
@@ -478,19 +478,37 @@ class EmbeddedChannelTest: XCTestCase {
478478

479479
func testFinishWithRecursivelyScheduledTasks() throws {
480480
let channel = EmbeddedChannel()
481+
var tasks: [Scheduled<Void>] = []
481482
var invocations = 0
482483

483484
func recursivelyScheduleAndIncrement() {
484-
channel.pipeline.eventLoop.scheduleTask(deadline: .distantFuture) {
485+
let task = channel.pipeline.eventLoop.scheduleTask(deadline: .distantFuture) {
485486
invocations += 1
486487
recursivelyScheduleAndIncrement()
487488
}
489+
tasks.append(task)
488490
}
489491

490492
recursivelyScheduleAndIncrement()
491493

492494
try XCTAssertNoThrow(channel.finish())
493-
XCTAssertEqual(invocations, 1)
495+
496+
// None of the tasks should have been executed, they were scheduled for distant future.
497+
XCTAssertEqual(invocations, 0)
498+
499+
// Because the root task didn't run, it should be the onnly one scheduled.
500+
XCTAssertEqual(tasks.count, 1)
501+
502+
// Check the task was failed with cancelled error.
503+
let taskChecked = expectation(description: "task future fulfilled")
504+
tasks.first?.futureResult.whenComplete { result in
505+
switch result {
506+
case .success: XCTFail("Expected task to be cancelled, not run.")
507+
case .failure(let error): XCTAssertEqual(error as? EventLoopError, .cancelled)
508+
}
509+
taskChecked.fulfill()
510+
}
511+
wait(for: [taskChecked], timeout: 0)
494512
}
495513

496514
func testGetChannelOptionAutoReadIsSupported() {

0 commit comments

Comments
 (0)