From 9c10d0aab3f721e2053af9e48e51841edb18f20a Mon Sep 17 00:00:00 2001 From: Ruslan Kuprieiev Date: Thu, 27 Jun 2024 23:30:28 +0300 Subject: [PATCH] feat: introduce List and Cancel (#146) * feat: introduce List and Cancel * Task -> RunTask Per https://github.com/fal-ai/isolate/pull/146#discussion_r1643147280 --- src/isolate/server/definitions/server.proto | 26 +++- src/isolate/server/definitions/server_pb2.py | 18 ++- src/isolate/server/definitions/server_pb2.pyi | 73 ++++++++++- .../server/definitions/server_pb2_grpc.py | 88 +++++++++++++ src/isolate/server/server.py | 117 +++++++++++++----- tests/test_server.py | 32 ++++- 6 files changed, 309 insertions(+), 45 deletions(-) diff --git a/src/isolate/server/definitions/server.proto b/src/isolate/server/definitions/server.proto index fc953e2..a42d54e 100644 --- a/src/isolate/server/definitions/server.proto +++ b/src/isolate/server/definitions/server.proto @@ -10,6 +10,12 @@ service Isolate { // Submit a function to be run without waiting for results. rpc Submit (SubmitRequest) returns (SubmitResponse) {} + + // List running tasks + rpc List (ListRequest) returns (ListResponse) {} + + // Cancel a running task + rpc Cancel (CancelRequest) returns (CancelResponse) {} } message BoundFunction { @@ -33,5 +39,23 @@ message SubmitRequest { } message SubmitResponse { - // Reserved for future use. + string task_id = 1; +} + +message ListRequest { +} + +message TaskInfo { + string task_id = 1; +} + +message ListResponse { + repeated TaskInfo tasks = 1; +} + +message CancelRequest { + string task_id = 1; +} + +message CancelResponse { } diff --git a/src/isolate/server/definitions/server_pb2.py b/src/isolate/server/definitions/server_pb2.py index 5bb5084..aa9d712 100644 --- a/src/isolate/server/definitions/server_pb2.py +++ b/src/isolate/server/definitions/server_pb2.py @@ -16,7 +16,7 @@ from google.protobuf import struct_pb2 as google_dot_protobuf_dot_struct__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cserver.proto\x1a\x0c\x63ommon.proto\x1a\x1cgoogle/protobuf/struct.proto\"\x9d\x01\n\rBoundFunction\x12,\n\x0c\x65nvironments\x18\x01 \x03(\x0b\x32\x16.EnvironmentDefinition\x12#\n\x08\x66unction\x18\x02 \x01(\x0b\x32\x11.SerializedObject\x12*\n\nsetup_func\x18\x03 \x01(\x0b\x32\x11.SerializedObjectH\x00\x88\x01\x01\x42\r\n\x0b_setup_func\"d\n\x15\x45nvironmentDefinition\x12\x0c\n\x04kind\x18\x01 \x01(\t\x12.\n\rconfiguration\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\r\n\x05\x66orce\x18\x03 \x01(\x08\"1\n\rSubmitRequest\x12 \n\x08\x66unction\x18\x01 \x01(\x0b\x32\x0e.BoundFunction\"\x10\n\x0eSubmitResponse2d\n\x07Isolate\x12,\n\x03Run\x12\x0e.BoundFunction\x1a\x11.PartialRunResult\"\x00\x30\x01\x12+\n\x06Submit\x12\x0e.SubmitRequest\x1a\x0f.SubmitResponse\"\x00\x62\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0cserver.proto\x1a\x0c\x63ommon.proto\x1a\x1cgoogle/protobuf/struct.proto\"\x9d\x01\n\rBoundFunction\x12,\n\x0c\x65nvironments\x18\x01 \x03(\x0b\x32\x16.EnvironmentDefinition\x12#\n\x08\x66unction\x18\x02 \x01(\x0b\x32\x11.SerializedObject\x12*\n\nsetup_func\x18\x03 \x01(\x0b\x32\x11.SerializedObjectH\x00\x88\x01\x01\x42\r\n\x0b_setup_func\"d\n\x15\x45nvironmentDefinition\x12\x0c\n\x04kind\x18\x01 \x01(\t\x12.\n\rconfiguration\x18\x02 \x01(\x0b\x32\x17.google.protobuf.Struct\x12\r\n\x05\x66orce\x18\x03 \x01(\x08\"1\n\rSubmitRequest\x12 \n\x08\x66unction\x18\x01 \x01(\x0b\x32\x0e.BoundFunction\"!\n\x0eSubmitResponse\x12\x0f\n\x07task_id\x18\x01 \x01(\t\"\r\n\x0bListRequest\"\x1b\n\x08TaskInfo\x12\x0f\n\x07task_id\x18\x01 \x01(\t\"(\n\x0cListResponse\x12\x18\n\x05tasks\x18\x01 \x03(\x0b\x32\t.TaskInfo\" \n\rCancelRequest\x12\x0f\n\x07task_id\x18\x01 \x01(\t\"\x10\n\x0e\x43\x61ncelResponse2\xb8\x01\n\x07Isolate\x12,\n\x03Run\x12\x0e.BoundFunction\x1a\x11.PartialRunResult\"\x00\x30\x01\x12+\n\x06Submit\x12\x0e.SubmitRequest\x1a\x0f.SubmitResponse\"\x00\x12%\n\x04List\x12\x0c.ListRequest\x1a\r.ListResponse\"\x00\x12+\n\x06\x43\x61ncel\x12\x0e.CancelRequest\x1a\x0f.CancelResponse\"\x00\x62\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) @@ -30,7 +30,17 @@ _globals['_SUBMITREQUEST']._serialized_start=322 _globals['_SUBMITREQUEST']._serialized_end=371 _globals['_SUBMITRESPONSE']._serialized_start=373 - _globals['_SUBMITRESPONSE']._serialized_end=389 - _globals['_ISOLATE']._serialized_start=391 - _globals['_ISOLATE']._serialized_end=491 + _globals['_SUBMITRESPONSE']._serialized_end=406 + _globals['_LISTREQUEST']._serialized_start=408 + _globals['_LISTREQUEST']._serialized_end=421 + _globals['_TASKINFO']._serialized_start=423 + _globals['_TASKINFO']._serialized_end=450 + _globals['_LISTRESPONSE']._serialized_start=452 + _globals['_LISTRESPONSE']._serialized_end=492 + _globals['_CANCELREQUEST']._serialized_start=494 + _globals['_CANCELREQUEST']._serialized_end=526 + _globals['_CANCELRESPONSE']._serialized_start=528 + _globals['_CANCELRESPONSE']._serialized_end=544 + _globals['_ISOLATE']._serialized_start=547 + _globals['_ISOLATE']._serialized_end=731 # @@protoc_insertion_point(module_scope) diff --git a/src/isolate/server/definitions/server_pb2.pyi b/src/isolate/server/definitions/server_pb2.pyi index 4c45754..1b185f1 100644 --- a/src/isolate/server/definitions/server_pb2.pyi +++ b/src/isolate/server/definitions/server_pb2.pyi @@ -88,12 +88,81 @@ global___SubmitRequest = SubmitRequest @typing.final class SubmitResponse(google.protobuf.message.Message): - """Reserved for future use.""" - DESCRIPTOR: google.protobuf.descriptor.Descriptor + TASK_ID_FIELD_NUMBER: builtins.int + task_id: builtins.str def __init__( self, + *, + task_id: builtins.str = ..., ) -> None: ... + def ClearField(self, field_name: typing.Literal["task_id", b"task_id"]) -> None: ... global___SubmitResponse = SubmitResponse + +@typing.final +class ListRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___ListRequest = ListRequest + +@typing.final +class TaskInfo(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TASK_ID_FIELD_NUMBER: builtins.int + task_id: builtins.str + def __init__( + self, + *, + task_id: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["task_id", b"task_id"]) -> None: ... + +global___TaskInfo = TaskInfo + +@typing.final +class ListResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TASKS_FIELD_NUMBER: builtins.int + @property + def tasks(self) -> google.protobuf.internal.containers.RepeatedCompositeFieldContainer[global___TaskInfo]: ... + def __init__( + self, + *, + tasks: collections.abc.Iterable[global___TaskInfo] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["tasks", b"tasks"]) -> None: ... + +global___ListResponse = ListResponse + +@typing.final +class CancelRequest(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + TASK_ID_FIELD_NUMBER: builtins.int + task_id: builtins.str + def __init__( + self, + *, + task_id: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing.Literal["task_id", b"task_id"]) -> None: ... + +global___CancelRequest = CancelRequest + +@typing.final +class CancelResponse(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + def __init__( + self, + ) -> None: ... + +global___CancelResponse = CancelResponse diff --git a/src/isolate/server/definitions/server_pb2_grpc.py b/src/isolate/server/definitions/server_pb2_grpc.py index 8aa3043..74d22d4 100644 --- a/src/isolate/server/definitions/server_pb2_grpc.py +++ b/src/isolate/server/definitions/server_pb2_grpc.py @@ -50,6 +50,16 @@ def __init__(self, channel): request_serializer=server__pb2.SubmitRequest.SerializeToString, response_deserializer=server__pb2.SubmitResponse.FromString, _registered_method=True) + self.List = channel.unary_unary( + '/Isolate/List', + request_serializer=server__pb2.ListRequest.SerializeToString, + response_deserializer=server__pb2.ListResponse.FromString, + _registered_method=True) + self.Cancel = channel.unary_unary( + '/Isolate/Cancel', + request_serializer=server__pb2.CancelRequest.SerializeToString, + response_deserializer=server__pb2.CancelResponse.FromString, + _registered_method=True) class IsolateServicer(object): @@ -70,6 +80,20 @@ def Submit(self, request, context): context.set_details('Method not implemented!') raise NotImplementedError('Method not implemented!') + def List(self, request, context): + """List running tasks + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + + def Cancel(self, request, context): + """Cancel a running task + """ + context.set_code(grpc.StatusCode.UNIMPLEMENTED) + context.set_details('Method not implemented!') + raise NotImplementedError('Method not implemented!') + def add_IsolateServicer_to_server(servicer, server): rpc_method_handlers = { @@ -83,6 +107,16 @@ def add_IsolateServicer_to_server(servicer, server): request_deserializer=server__pb2.SubmitRequest.FromString, response_serializer=server__pb2.SubmitResponse.SerializeToString, ), + 'List': grpc.unary_unary_rpc_method_handler( + servicer.List, + request_deserializer=server__pb2.ListRequest.FromString, + response_serializer=server__pb2.ListResponse.SerializeToString, + ), + 'Cancel': grpc.unary_unary_rpc_method_handler( + servicer.Cancel, + request_deserializer=server__pb2.CancelRequest.FromString, + response_serializer=server__pb2.CancelResponse.SerializeToString, + ), } generic_handler = grpc.method_handlers_generic_handler( 'Isolate', rpc_method_handlers) @@ -147,3 +181,57 @@ def Submit(request, timeout, metadata, _registered_method=True) + + @staticmethod + def List(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/Isolate/List', + server__pb2.ListRequest.SerializeToString, + server__pb2.ListResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) + + @staticmethod + def Cancel(request, + target, + options=(), + channel_credentials=None, + call_credentials=None, + insecure=False, + compression=None, + wait_for_ready=None, + timeout=None, + metadata=None): + return grpc.experimental.unary_unary( + request, + target, + '/Isolate/Cancel', + server__pb2.CancelRequest.SerializeToString, + server__pb2.CancelResponse.FromString, + options, + channel_credentials, + insecure, + call_credentials, + compression, + wait_for_ready, + timeout, + metadata, + _registered_method=True) diff --git a/src/isolate/server/server.py b/src/isolate/server/server.py index d6685ed..1a89470 100644 --- a/src/isolate/server/server.py +++ b/src/isolate/server/server.py @@ -4,6 +4,7 @@ import threading import time import traceback +import uuid from collections import defaultdict from concurrent import futures from concurrent.futures import ThreadPoolExecutor @@ -114,11 +115,11 @@ def establish( self, connection: LocalPythonGRPC, queue: Queue, - ) -> Iterator[tuple[definitions.AgentStub, Queue]]: + ) -> Iterator[RunnerAgent]: agent = self._allocate_new_agent(connection, queue) try: - yield agent.stub, agent.message_queue + yield agent finally: self._cache_agent(connection, agent) @@ -165,19 +166,34 @@ def __exit__(self, *exc_info: Any) -> None: agent.terminate() +@dataclass +class RunTask: + request: definitions.BoundFunction + future: futures.Future | None = None + agent: RunnerAgent | None = None + + def cancel(self): + while True: + self.future.cancel() + if self.agent: + self.agent.terminate() + try: + self.future.exception(timeout=0.1) + return + except futures.TimeoutError: + pass + + @dataclass class IsolateServicer(definitions.IsolateServicer): bridge_manager: BridgeManager default_settings: IsolateSettings = field(default_factory=IsolateSettings) - background_tasks: set[futures.Future] = field(default_factory=set) + background_tasks: dict[str, RunTask] = field(default_factory=dict) - def _run_function( - self, - request: definitions.BoundFunction, - ) -> Iterator[definitions.PartialRunResult]: + def _run_task(self, task: RunTask) -> Iterator[definitions.PartialRunResult]: messages: Queue[definitions.PartialRunResult] = Queue() environments = [] - for env in request.environments: + for env in task.request.environments: try: environments.append((env.force, from_grpc(env))) except ValueError: @@ -195,7 +211,7 @@ def _run_function( run_settings = replace( self.default_settings, log_hook=log_handler.handle, - serialization_method=request.function.method, + serialization_method=task.request.function.method, ) for _, environment in environments: @@ -244,28 +260,28 @@ def _run_function( extra_inheritance_paths=inheritance_paths, ) - with self.bridge_manager.establish(connection, queue=messages) as ( - bridge, - queue, - ): + with self.bridge_manager.establish(connection, queue=messages) as agent: + task.agent = agent function_call = definitions.FunctionCall( - function=request.function, - setup_func=request.setup_func, + function=task.request.function, + setup_func=task.request.setup_func, ) - if not request.HasField("setup_func"): + if not task.request.HasField("setup_func"): function_call.ClearField("setup_func") future = local_pool.submit( _proxy_to_queue, - queue=queue, - bridge=bridge, + queue=agent.message_queue, + bridge=agent.stub, input=function_call, ) # Unlike above; we are not interested in the result value of future # here, since it will be already transferred to other side without # us even seeing (through the queue). - yield from self.watch_queue_until_completed(queue, future.done) + yield from self.watch_queue_until_completed( + agent.message_queue, future.done + ) # But we still have to check whether there were any errors raised # during the execution, and handle them accordingly. @@ -293,14 +309,8 @@ def _run_function( StatusCode.UNKNOWN, ) - def _run_function_in_background( - self, - bound_function: definitions.BoundFunction, - ) -> None: - try: - for _ in self._run_function(bound_function): - pass - except GRPCException: + def _run_task_in_background(self, task: RunTask) -> None: + for _ in self._run_task(task): pass def Submit( @@ -308,13 +318,24 @@ def Submit( request: definitions.SubmitRequest, context: ServicerContext, ) -> definitions.SubmitResponse: - run_future = RUNNER_THREAD_POOL.submit( - self._run_function_in_background, - request.function, + task = RunTask(request=request.function) + task.future = RUNNER_THREAD_POOL.submit( + self._run_task_in_background, + task, ) - self.background_tasks.add(run_future) + task_id = str(uuid.uuid4()) + + print(f"Submitted a task {task_id}") - return definitions.SubmitResponse() + self.background_tasks[task_id] = task + + def _callback(_): + print(f"Task {task_id} finished") + self.background_tasks.pop(task_id, None) + + task.future.add_done_callback(_callback) + + return definitions.SubmitResponse(task_id=task_id) def Run( self, @@ -322,7 +343,7 @@ def Run( context: ServicerContext, ) -> Iterator[definitions.PartialRunResult]: try: - yield from self._run_function(request) + yield from self._run_task(RunTask(request=request)) except GRPCException as exc: return self.abort_with_msg( exc.message, @@ -330,6 +351,32 @@ def Run( code=exc.code, ) + def List( + self, + request: definitions.ListRequest, + context: ServicerContext, + ) -> definitions.ListResponse: + return definitions.ListResponse( + tasks=[ + definitions.TaskInfo(task_id=task_id) + for task_id in self.background_tasks.keys() + ] + ) + + def Cancel( + self, + request: definitions.CancelRequest, + context: ServicerContext, + ) -> definitions.CancelResponse: + task_id = request.task_id + + print(f"Canceling task {task_id}") + task = self.background_tasks.get(task_id) + if task is not None: + task.cancel() + + return definitions.CancelResponse() + def watch_queue_until_completed( self, queue: Queue, is_completed: Callable[[], bool] ) -> Iterator[definitions.PartialRunResult]: @@ -381,6 +428,10 @@ def abort_with_msg( context.set_details(message) return None + def cancel_tasks(self): + for task in self.background_tasks.values(): + task.cancel() + def _proxy_to_queue( queue: Queue, diff --git a/tests/test_server.py b/tests/test_server.py index b403182..778fb75 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -65,11 +65,7 @@ def make_server(tmp_path): yield Stubs(isolate_stub=isolate_stub, health_stub=health_stub) finally: server.stop(None) - for task in servicer.background_tasks: - if task.done(): - task.result() - else: - print("leaking background task", task) + servicer.cancel_tasks() @pytest.fixture @@ -697,3 +693,29 @@ def test_server_submit( stub.Submit(request) time.sleep(5) assert "completed" in file.read_text() + assert not list(stub.List(definitions.ListRequest()).tasks) + + +def myserver(): + import time + + while True: + print("running") + time.sleep(1) + + +def test_server_submit_server( + stub: definitions.IsolateStub, + monkeypatch: Any, +) -> None: + inherit_from_local(monkeypatch) + + request = definitions.SubmitRequest(function=prepare_request(myserver)) + task_id = stub.Submit(request).task_id + + tasks = [task.task_id for task in stub.List(definitions.ListRequest()).tasks] + assert task_id in tasks + + stub.Cancel(definitions.CancelRequest(task_id=task_id)) + + assert not list(stub.List(definitions.ListRequest()).tasks)