Skip to content

Commit

Permalink
Refactored server close logic to gracefully exit without using GOAWAY…
Browse files Browse the repository at this point in the history
… frames
  • Loading branch information
vmagamedov committed May 19, 2024
1 parent 5916cba commit 777d6c3
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 30 deletions.
9 changes: 6 additions & 3 deletions grpclib/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,10 @@


class Handler(AbstractHandler):
connection_lost = False
closing = False

def connection_made(self, connection: Any) -> None:
pass

def accept(self, stream: Any, headers: Any, release_stream: Any) -> None:
raise NotImplementedError('Client connection can not accept requests')
Expand All @@ -71,7 +74,7 @@ def cancel(self, stream: Any) -> None:
pass

def close(self) -> None:
self.connection_lost = True
self.closing = True


class Stream(StreamIterator[_RecvType], Generic[_SendType, _RecvType]):
Expand Down Expand Up @@ -737,7 +740,7 @@ async def _create_connection(self) -> H2Protocol:
@property
def _connected(self) -> bool:
return (self._protocol is not None
and not self._protocol.handler.connection_lost)
and not cast(Handler, self._protocol.handler).closing)

async def __connect__(self) -> H2Protocol:
if not self._connected:
Expand Down
5 changes: 5 additions & 0 deletions grpclib/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,10 @@ def closable(self) -> bool:

class AbstractHandler(ABC):

@abstractmethod
def connection_made(self, connection: Connection) -> None:
pass

@abstractmethod
def accept(
self,
Expand Down Expand Up @@ -709,6 +713,7 @@ def connection_made(self, transport: BaseTransport) -> None:
self.connection.flush()
self.connection.initialize()

self.handler.connection_made(self.connection)
self.processor = EventsProcessor(self.handler, self.connection)

def data_received(self, data: bytes) -> None:
Expand Down
56 changes: 32 additions & 24 deletions grpclib/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import logging
import asyncio
import warnings
from functools import partial

from types import TracebackType
from typing import TYPE_CHECKING, Optional, Collection, Generic, Type, cast
Expand All @@ -12,6 +13,7 @@

import h2.config
import h2.exceptions
from h2.errors import ErrorCodes

from multidict import MultiDict

Expand All @@ -24,7 +26,7 @@
from .metadata import Deadline, encode_grpc_message, _Metadata
from .metadata import encode_metadata, decode_metadata, _MetadataLike
from .metadata import _STATUS_DETAILS_KEY, encode_bin_value
from .protocol import H2Protocol, AbstractHandler
from .protocol import H2Protocol, AbstractHandler, Connection
from .exceptions import GRPCError, ProtocolError, StreamTerminatedError
from .encoding.base import GRPC_CONTENT_TYPE, CodecBase, StatusDetailsCodecBase
from .encoding.proto import ProtoCodec, ProtoStatusDetailsCodec
Expand Down Expand Up @@ -493,9 +495,8 @@ def __gc_step__(self) -> None:
self.__gc_collect__()


class Handler(_GC, AbstractHandler):
__gc_interval__ = 10

class Handler(AbstractHandler):
connection: Connection
closing = False

def __init__(
Expand All @@ -511,44 +512,51 @@ def __init__(
self.dispatch = dispatch
self.loop = asyncio.get_event_loop()
self._tasks: Dict['protocol.Stream', 'asyncio.Task[None]'] = {}
self._cancelled: Set['asyncio.Task[None]'] = set()

def __gc_collect__(self) -> None:
self._tasks = {s: t for s, t in self._tasks.items()
if not t.done()}
self._cancelled = {t for t in self._cancelled
if not t.done()}
def connection_made(self, connection: Connection) -> None:
self.connection = connection

def handler_done(
self,
stream: 'protocol.Stream',
_: 'asyncio.Future[None]',
) -> None:
self._tasks.pop(stream)
if self.closing and not self._tasks:
self.connection.close()

def accept(
self,
stream: 'protocol.Stream',
headers: _Headers,
release_stream: Callable[[], Any],
) -> None:
self.__gc_step__()
self._tasks[stream] = self.loop.create_task(request_handler(
self.mapping, stream, headers, self.codec,
self.status_details_codec, self.dispatch, release_stream,
))
if self.closing:
stream.reset_nowait(ErrorCodes.REFUSED_STREAM)
release_stream()
else:
task = self._tasks[stream] = self.loop.create_task(request_handler(
self.mapping, stream, headers, self.codec,
self.status_details_codec, self.dispatch, release_stream,
))
task.add_done_callback(partial(self.handler_done, stream))

def cancel(self, stream: 'protocol.Stream') -> None:
task = self._tasks.pop(stream)
task.cancel()
self._cancelled.add(task)
self._tasks[stream].cancel()

def close(self) -> None:
for task in self._tasks.values():
task.cancel()
self._cancelled.update(self._tasks.values())
self.closing = True

async def wait_closed(self) -> None:
if self._cancelled:
await asyncio.wait(self._cancelled)
if self._tasks:
await asyncio.wait(self._tasks.values())
else:
self.connection.close()

def check_closed(self) -> bool:
self.__gc_collect__()
return not self._tasks and not self._cancelled
return not self._tasks


class Server(_GC):
Expand Down Expand Up @@ -737,11 +745,11 @@ async def wait_closed(self) -> None:
if self._server is None or self._server_closed_fut is None:
raise RuntimeError('Server is not started')
await self._server_closed_fut
await self._server.wait_closed()
if self._handlers:
await asyncio.wait({
self._loop.create_task(h.wait_closed()) for h in self._handlers
})
await self._server.wait_closed()

async def __aenter__(self) -> 'Server':
return self
Expand Down
3 changes: 3 additions & 0 deletions tests/stubs.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ class DummyHandler(AbstractHandler):
headers = None
release_stream = None

def connection_made(self, connection):
pass

def accept(self, stream, headers, release_stream):
self.stream = stream
self.headers = headers
Expand Down
3 changes: 0 additions & 3 deletions tests/test_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,14 +83,11 @@ async def test_stream():
cs = ClientServer(DummyService, DummyServiceStub)
async with cs as (_, stub):
await stub.UnaryUnary(DummyRequest(value='ping'))
handler = next(iter(cs.server._handlers))
handler.__gc_collect__()
gc.collect()
gc.disable()
try:
pre = set(collect())
await stub.UnaryUnary(DummyRequest(value='ping'))
handler.__gc_collect__()
post = collect()

diff = set(post).difference(pre)
Expand Down

0 comments on commit 777d6c3

Please sign in to comment.