From a0507e4792966a3d4e736dac96c9664c5f868a74 Mon Sep 17 00:00:00 2001 From: Zach Sailer Date: Thu, 13 Feb 2025 11:20:06 -0800 Subject: [PATCH] parallelize async extension start --- jupyter_server/extension/manager.py | 11 +++++------ tests/extension/mockextensions/app.py | 3 +++ 2 files changed, 8 insertions(+), 6 deletions(-) diff --git a/jupyter_server/extension/manager.py b/jupyter_server/extension/manager.py index 80510e21a6..2b18573c95 100644 --- a/jupyter_server/extension/manager.py +++ b/jupyter_server/extension/manager.py @@ -168,12 +168,12 @@ def load(self, serverapp): loader = self._get_loader() return loader(serverapp) - def start(self, serverapp): + async def start(self, serverapp): """Call's the extensions 'start' hook where it can start (possibly async) tasks _after_ the event loop is running. """ starter = self._get_starter() - return starter(serverapp) + return await starter(serverapp) class ExtensionPackage(LoggingConfigurable): @@ -247,10 +247,10 @@ def load_point(self, point_name, serverapp): point = self.extension_points[point_name] return point.load(serverapp) - def start_point(self, point_name, serverapp): + async def start_point(self, point_name, serverapp): """Load an extension point.""" point = self.extension_points[point_name] - return point.start(serverapp) + return await point.start(serverapp) def link_all_points(self, serverapp): """Link all extension points.""" @@ -447,8 +447,7 @@ async def start_all_extensions(self): """Start all enabled extensions.""" # Sort the extension names to enforce deterministic loading # order. - for name in self.sorted_extensions: - await self.start_extension(name) + await multi([self.start_extension(name) for name in self.sorted_extensions]) async def stop_all_extensions(self): """Call the shutdown hooks in all extensions.""" diff --git a/tests/extension/mockextensions/app.py b/tests/extension/mockextensions/app.py index 361988929f..d54cd102e2 100644 --- a/tests/extension/mockextensions/app.py +++ b/tests/extension/mockextensions/app.py @@ -65,6 +65,9 @@ class MockExtensionApp(ExtensionAppJinjaMixin, ExtensionApp): } } + async def _start_jupyter_server_extension(self, serverapp): + self.started = True + @staticmethod def get_extension_package(): return "tests.extension.mockextensions"