From be02b34b53882cb2f71f05bb04003cc2b7672087 Mon Sep 17 00:00:00 2001 From: Fabian Fett Date: Wed, 16 Feb 2022 15:01:35 +0100 Subject: [PATCH] Release stream callback, once the stream has finished (#1363) --- .../GRPC/ClientCalls/ResponseContainers.swift | 14 +++- Tests/GRPCTests/FakeChannelTests.swift | 17 +++- ...treamResponseHandlerRetainCycleTests.swift | 81 +++++++++++++++++++ 3 files changed, 106 insertions(+), 6 deletions(-) create mode 100644 Tests/GRPCTests/StreamResponseHandlerRetainCycleTests.swift diff --git a/Sources/GRPC/ClientCalls/ResponseContainers.swift b/Sources/GRPC/ClientCalls/ResponseContainers.swift index 5d9c4b0bf..ad294fcda 100644 --- a/Sources/GRPC/ClientCalls/ResponseContainers.swift +++ b/Sources/GRPC/ClientCalls/ResponseContainers.swift @@ -98,7 +98,7 @@ internal class StreamingResponseParts { private let eventLoop: EventLoop /// A callback for response messages. - private let responseCallback: (Response) -> Void + private var responseCallback: Optional<(Response) -> Void> /// Lazy promises for the status, initial-, and trailing-metadata. private var initialMetadataPromise: LazyEventLoopPromise @@ -139,9 +139,13 @@ internal class StreamingResponseParts { self.initialMetadataPromise.succeed(metadata) case let .message(response): - self.responseCallback(response) + self.responseCallback?(response) case let .end(status, trailers): + // Once the stream has finished, we must release the callback, to make sure don't + // break potential retain cycles (the callback may reference other object's that in + // turn reference `StreamingResponseParts`). + self.responseCallback = nil self.initialMetadataPromise.fail(status) self.trailingMetadataPromise.succeed(trailers) self.statusPromise.succeed(status) @@ -149,6 +153,12 @@ internal class StreamingResponseParts { } internal func handleError(_ error: Error) { + self.eventLoop.assertInEventLoop() + + // Once the stream has finished, we must release the callback, to make sure don't + // break potential retain cycles (the callback may reference other object's that in + // turn reference `StreamingResponseParts`). + self.responseCallback = nil let withoutContext = error.removingContext() let status = withoutContext.makeGRPCStatus() self.initialMetadataPromise.fail(withoutContext) diff --git a/Tests/GRPCTests/FakeChannelTests.swift b/Tests/GRPCTests/FakeChannelTests.swift index 0803afbc5..8da3b7dd6 100644 --- a/Tests/GRPCTests/FakeChannelTests.swift +++ b/Tests/GRPCTests/FakeChannelTests.swift @@ -81,6 +81,10 @@ class FakeChannelTests: GRPCTestCase { } func testBidirectional() { + final class ResponseCollector { + private(set) var responses = [Response]() + func collect(_ response: Response) { self.responses.append(response) } + } var requests: [Request] = [] let response = self.makeStreamingResponse { part in switch part { @@ -91,10 +95,12 @@ class FakeChannelTests: GRPCTestCase { } } - var responses: [Response] = [] - let call = self.makeBidirectionalStreamingCall { - responses.append($0) + var collector = ResponseCollector() + XCTAssertTrue(isKnownUniquelyReferenced(&collector)) + let call = self.makeBidirectionalStreamingCall { [collector] in + collector.collect($0) } + XCTAssertFalse(isKnownUniquelyReferenced(&collector)) XCTAssertNoThrow(try call.sendMessage(.with { $0.text = "1" }).wait()) XCTAssertNoThrow(try call.sendMessage(.with { $0.text = "2" }).wait()) @@ -106,9 +112,12 @@ class FakeChannelTests: GRPCTestCase { XCTAssertNoThrow(try response.sendMessage(.with { $0.text = "4" })) XCTAssertNoThrow(try response.sendMessage(.with { $0.text = "5" })) XCTAssertNoThrow(try response.sendMessage(.with { $0.text = "6" })) + XCTAssertEqual(collector.responses.count, 3) + XCTAssertFalse(isKnownUniquelyReferenced(&collector)) XCTAssertNoThrow(try response.sendEnd()) + XCTAssertTrue(isKnownUniquelyReferenced(&collector)) - XCTAssertEqual(responses, (4 ... 6).map { number in .with { $0.text = "\(number)" } }) + XCTAssertEqual(collector.responses, (4 ... 6).map { number in .with { $0.text = "\(number)" } }) XCTAssertTrue(try call.status.map { $0.isOk }.wait()) } diff --git a/Tests/GRPCTests/StreamResponseHandlerRetainCycleTests.swift b/Tests/GRPCTests/StreamResponseHandlerRetainCycleTests.swift new file mode 100644 index 000000000..88fb6ad85 --- /dev/null +++ b/Tests/GRPCTests/StreamResponseHandlerRetainCycleTests.swift @@ -0,0 +1,81 @@ +/* + * Copyright 2022, gRPC Authors All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +import EchoImplementation +import EchoModel +import GRPC +import NIOConcurrencyHelpers +import NIOCore +import NIOPosix +import XCTest + +final class StreamResponseHandlerRetainCycleTests: GRPCTestCase { + var group: EventLoopGroup! + var server: Server! + var client: ClientConnection! + + var echo: Echo_EchoClient! + + override func setUp() { + super.setUp() + self.group = MultiThreadedEventLoopGroup(numberOfThreads: 1) + + self.server = try! Server.insecure(group: self.group) + .withServiceProviders([EchoProvider()]) + .withLogger(self.serverLogger) + .bind(host: "localhost", port: 0) + .wait() + + self.client = ClientConnection.insecure(group: self.group) + .withBackgroundActivityLogger(self.clientLogger) + .connect(host: "localhost", port: self.server.channel.localAddress!.port!) + + self.echo = Echo_EchoClient( + channel: self.client, + defaultCallOptions: CallOptions(logger: self.clientLogger) + ) + } + + override func tearDown() { + XCTAssertNoThrow(try self.client.close().wait()) + XCTAssertNoThrow(try self.server.close().wait()) + XCTAssertNoThrow(try self.group.syncShutdownGracefully()) + super.tearDown() + } + + func testHandlerClosureIsReleasedOnceStreamEnds() { + final class Counter { + private let atomic = NIOAtomic.makeAtomic(value: 0) + func increment() { self.atomic.add(1) } + var value: Int { + self.atomic.load() + } + } + + var counter = Counter() + XCTAssertTrue(isKnownUniquelyReferenced(&counter)) + let get = self.echo.update { [capturedCounter = counter] _ in + capturedCounter.increment() + } + XCTAssertFalse(isKnownUniquelyReferenced(&counter)) + + get.sendMessage(.init(text: "hello world"), promise: nil) + XCTAssertFalse(isKnownUniquelyReferenced(&counter)) + XCTAssertNoThrow(try get.sendEnd().wait()) + XCTAssertNoThrow(try get.status.wait()) + XCTAssertEqual(counter.value, 1) + XCTAssertTrue(isKnownUniquelyReferenced(&counter)) + } +}