diff --git a/micropip/_compat_in_pyodide.py b/micropip/_compat_in_pyodide.py
index e47aadf..0e1d2b9 100644
--- a/micropip/_compat_in_pyodide.py
+++ b/micropip/_compat_in_pyodide.py
@@ -1,4 +1,7 @@
+from asyncio import CancelledError
+from collections.abc import Awaitable, Callable
 from pathlib import Path
+from typing import TYPE_CHECKING, Concatenate, ParamSpec, TypeVar
 from urllib.parse import urlparse
 
 from pyodide._package_loader import get_dynlibs
@@ -7,7 +10,7 @@
 
 try:
     import pyodide_js
-    from js import Object
+    from js import AbortController, AbortSignal, Object
    from pyodide_js import loadedPackages, loadPackage
     from pyodide_js._api import (  # type: ignore[import]
         loadBinaryFile,
@@ -21,21 +24,45 @@
         raise
     # Otherwise, this is pytest test collection so let it go.
 
+if IN_BROWSER or TYPE_CHECKING:
+    P = ParamSpec("P")
+    T = TypeVar("T")
 
-async def fetch_bytes(url: str, kwargs: dict[str, str]) -> bytes:
+    def _abort_on_cancel(
+        func: Callable[Concatenate[AbortSignal, P], Awaitable[T]],
+    ) -> Callable[P, Awaitable[T]]:
+        """inject an AbortSignal as the first argument"""
+
+        async def wrapper(*args: P.args, **kwargs: P.kwargs) -> T:
+            controller = AbortController.new()
+            try:
+                return await func(controller.signal, *args, **kwargs)
+            except CancelledError:
+                controller.abort()
+                raise
+
+        return wrapper
+
+else:
+    _abort_on_cancel = lambda func: lambda *args, **kwargs: func(None, *args, **kwargs)
+
+
+@_abort_on_cancel
+async def fetch_bytes(signal: AbortSignal, url: str, kwargs: dict[str, str]) -> bytes:
     parsed_url = urlparse(url)
     if parsed_url.scheme == "emfs":
         return Path(parsed_url.path).read_bytes()
     if parsed_url.scheme == "file":
         return (await loadBinaryFile(parsed_url.path)).to_bytes()
-    return await (await pyfetch(url, **kwargs)).bytes()
+    return await (await pyfetch(url, **kwargs, signal=signal)).bytes()
 
 
+@_abort_on_cancel
 async def fetch_string_and_headers(
-    url: str, kwargs: dict[str, str]
+    signal: AbortSignal, url: str, kwargs: dict[str, str]
 ) -> tuple[str, dict[str, str]]:
-    response = await pyfetch(url, **kwargs)
+    response = await pyfetch(url, **kwargs, signal=signal)
 
     content = await response.string()
     # TODO: replace with response.headers when pyodide>= 0.24 is released
diff --git a/micropip/_compat_not_in_pyodide.py b/micropip/_compat_not_in_pyodide.py
index 137ec1d..dac34c0 100644
--- a/micropip/_compat_not_in_pyodide.py
+++ b/micropip/_compat_not_in_pyodide.py
@@ -74,7 +74,7 @@ def __get__(self, attr):
 REPODATA_INFO: dict[str, str] = {}
 
 
-def loadPackage(packages: str | list[str]) -> None:
+async def loadPackage(packages: str | list[str]) -> None:
     pass
 
 
diff --git a/micropip/transaction.py b/micropip/transaction.py
index bbc91e2..73adaee 100644
--- a/micropip/transaction.py
+++ b/micropip/transaction.py
@@ -48,11 +48,17 @@ async def gather_requirements(
         self,
         requirements: list[str] | list[Requirement],
     ) -> None:
-        requirement_promises = []
-        for requirement in requirements:
-            requirement_promises.append(self.add_requirement(requirement))
-
-        await asyncio.gather(*requirement_promises)
+        futures: list[asyncio.Future] = []
+        try:
+            for requirement in requirements:
+                futures.append(asyncio.ensure_future(self.add_requirement(requirement)))
+            await asyncio.gather(*futures)
+        except ValueError:
+            if not self.keep_going:
+                for future in futures:
+                    if not future.done():
+                        future.cancel()
+            raise
 
     async def add_requirement(self, req: str | Requirement) -> None:
         if isinstance(req, Requirement):
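
Usage sketch (not part of the diff): a minimal example of how the cancellation path above could be exercised from user code, assuming a Pyodide runtime with this patch applied. The helper name install_with_timeout and the use of asyncio.wait_for are illustrative assumptions, not part of the change; the point is that cancelling the task raises CancelledError inside the decorated fetch helpers, and the _abort_on_cancel wrapper then calls AbortController.abort() so the in-flight browser fetch is torn down rather than left running.

import asyncio

import micropip


async def install_with_timeout(package: str, timeout: float) -> bool:
    """Install a package, aborting its underlying network fetches on timeout."""
    task = asyncio.ensure_future(micropip.install(package))
    try:
        await asyncio.wait_for(task, timeout)
        return True
    except asyncio.TimeoutError:
        # wait_for cancels the task; the CancelledError propagates into
        # fetch_bytes / fetch_string_and_headers, whose _abort_on_cancel
        # wrapper aborts the corresponding AbortController.
        return False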