Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add graceful shutdown of channels #95

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@

import java.time.Duration;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.TimeUnit;

import org.springframework.beans.factory.DisposableBean;
import org.springframework.core.log.LogAccessor;
import org.springframework.util.Assert;

import io.grpc.ChannelCredentials;
Expand All @@ -41,6 +44,8 @@
public class DefaultGrpcChannelFactory<T extends ManagedChannelBuilder<T>>
implements GrpcChannelFactory, DisposableBean {

private final LogAccessor log = new LogAccessor(getClass());

private final List<ManagedChannelWithShutdown> channels = new ArrayList<>();

private final List<GrpcChannelBuilderCustomizer<T>> globalCustomizers = new ArrayList<>();
Expand Down Expand Up @@ -105,14 +110,56 @@ protected T newChannelBuilder(String target, ChannelCredentials credentials) {
return (T) Grpc.newChannelBuilder(target, credentials);
}

/**
* Performs a shutdown on all created channels as follows:
* <ul>
* <li>First an {@link ManagedChannel#shutdown() orderly shutdown} is initiated on
* each channel.
* <li>Next the channels are ordered by smallest to largest grace period, and in
* serial fashion each channel is sent an {@link ManagedChannel#awaitTermination
* awaitTermination} with the channel's remaining grace period.
* <li>Finally, any channel not terminated is sent a
* {@link ManagedChannel#shutdownNow() forceful shutdown}.
* </ul>
*/
@Override
public void destroy() {
this.channels.forEach((c) -> {
var shutdownGracePeriod = c.shutdownGracePeriod();
var channel = c.channel();
// TODO use grace period to do the magical shutdown here
channel.shutdown();
});
this.channels.stream().map(ManagedChannelWithShutdown::channel).forEach(ManagedChannel::shutdown);
this.channels.sort(Comparator.comparingLong((t) -> t.shutdownGracePeriod().toMillis()));
try {
long start = System.currentTimeMillis();
this.channels.forEach((channelWithShutdown) -> {
var channel = channelWithShutdown.channel();
var gracePeriod = channelWithShutdown.shutdownGracePeriod();
if (!channel.isTerminated()) {
this.log.debug(() -> "Awaiting channel termination: " + channel.authority());
long totalTimeWaitedSinceStart = System.currentTimeMillis() - start;
long gracePeriodRemaining = gracePeriod.toMillis() - totalTimeWaitedSinceStart;
this.awaitTermination(channel, gracePeriodRemaining);
}
this.log.debug(() -> "Completed channel termination: " + channel.authority());
});
}
finally {
this.channels.stream().map(ManagedChannelWithShutdown::channel).forEach((channel) -> {
if (!channel.isTerminated()) {
this.log.debug(() -> "Channel not terminated yet - forcing shutdown: " + channel.authority());
channel.shutdownNow();
}
});
}
}

private void awaitTermination(ManagedChannel channel, long awaitMillis) {
try {
if (awaitMillis > 0) {
channel.awaitTermination(awaitMillis, TimeUnit.MILLISECONDS);
}
}
catch (InterruptedException e) {
this.log.debug(() -> "Channel wait exceeded grace period - forcing shutdown: " + channel.authority());
channel.shutdownNow();
}
}

record ManagedChannelWithShutdown(ManagedChannel channel, Duration shutdownGracePeriod) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
/*
* Copyright 2023-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.grpc.client;

import static org.mockito.ArgumentMatchers.anyLong;
import static org.mockito.ArgumentMatchers.eq;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doReturn;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

import java.time.Duration;
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

import org.awaitility.Awaitility;
import org.junit.jupiter.api.Test;

import org.springframework.grpc.client.DefaultGrpcChannelFactory.ManagedChannelWithShutdown;
import org.springframework.test.util.ReflectionTestUtils;

import io.grpc.ManagedChannel;

/**
* Tests for shutdown process in {@link DefaultGrpcChannelFactory}.
*/
class DefaultGrpcChannelFactoryShutdownTests {

@Test
void channelsAreGracefullyShutdown() throws InterruptedException {
this.channelsShutdownAsExpected(false);
}

@Test
void whenChannelExceedsAwaitTimeOtherChannelsAreStillShutdownGracefully() throws InterruptedException {
this.channelsShutdownAsExpected(true);
}

private void channelsShutdownAsExpected(boolean exceedAwaitTime) throws InterruptedException {
var channelFactory = new DefaultGrpcChannelFactory<>(List.of(), mock());
channelFactory.setVirtualTargets(path -> path);

// create channels using factory and options
var c1 = channelFactory.createChannel("c1",
ChannelBuilderOptions.defaults().withShutdownGracePeriod(Duration.ofSeconds(7)));
var c2 = channelFactory.createChannel("c2",
ChannelBuilderOptions.defaults().withShutdownGracePeriod(Duration.ofSeconds(5)));
var c3 = channelFactory.createChannel("c3",
ChannelBuilderOptions.defaults().withShutdownGracePeriod(Duration.ofSeconds(10)));

// spy each channel to wait accordingly
var c1Spy = setupSpy(c1);
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not excited about having to get in and spy like this. However, it is a necessary evil to get some basic test coverage.

Namely, we ensure each channel gets the following:

  • shutdown invoked
  • awaitTermination invoked
  • shutdownNow never invoked EXCEPT only on channel2 when exceedAwaitTime = true

var c2Spy = setupSpy(c2, exceedAwaitTime);
var c3Spy = setupSpy(c3);

// replace factory channels with spy channels
List<ManagedChannelWithShutdown> spiedChannels = new ArrayList<>();
spiedChannels.add(new ManagedChannelWithShutdown(c1Spy, Duration.ofSeconds(7)));
spiedChannels.add(new ManagedChannelWithShutdown(c2Spy, Duration.ofSeconds(5)));
spiedChannels.add(new ManagedChannelWithShutdown(c3Spy, Duration.ofSeconds(10)));
ReflectionTestUtils.setField(channelFactory, "channels", spiedChannels);

// invoke the shutdown
channelFactory.destroy();

Awaitility.await().atMost(Duration.ofSeconds(15)).untilAsserted(() -> {
// each channel should get ordered shutdown called
verify(c1Spy).shutdown();
verify(c2Spy).shutdown();
verify(c3Spy).shutdown();

// each channel should be awaitTermination (shortest grace periods first)
var inOrder = inOrder(c1Spy, c2Spy, c3Spy);
inOrder.verify(c2Spy).awaitTermination(anyLong(), eq(TimeUnit.MILLISECONDS));
inOrder.verify(c1Spy).awaitTermination(anyLong(), eq(TimeUnit.MILLISECONDS));
inOrder.verify(c3Spy).awaitTermination(anyLong(), eq(TimeUnit.MILLISECONDS));

// c1 and c3 should never get forcibly shutdown
// c2 is forcibly shutdown when exceedAwaitTime is true
verify(c1Spy, never()).shutdownNow();
verify(c3Spy, never()).shutdownNow();
verify(c2Spy, times(exceedAwaitTime ? 1 : 0)).shutdownNow();
});
}

private ManagedChannel setupSpy(ManagedChannel channel) throws InterruptedException {
return this.setupSpy(channel, false);
}

private ManagedChannel setupSpy(ManagedChannel channel, boolean exceedAwaitTime) throws InterruptedException {
var channelSpy = spy(channel);
doAnswer((i) -> {
try {
Thread.sleep(3000);
}
catch (InterruptedException ex) {
Thread.currentThread().interrupt();
}
if (!exceedAwaitTime) {
return i.callRealMethod();
}
throw new InterruptedException("Exceeded await time");
}).when(channelSpy).awaitTermination(anyLong(), eq(TimeUnit.MILLISECONDS));
doReturn(false, true).when(channelSpy).isTerminated();
return channelSpy;
}

}
Loading