diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index e99555f1..118f6132 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -23,7 +23,7 @@ jobs: python-version: "3.9" cache: "poetry" - name: Install deps - run: poetry install + run: poetry install --extras "memory zmq" - name: Run lint check run: poetry run pre-commit run -a ${{ matrix.cmd }} pytest: @@ -46,7 +46,7 @@ jobs: python-version: "${{ matrix.py_version }}" cache: "poetry" - name: Install deps - run: poetry install + run: poetry install --extras "memory zmq" - name: Run pytest check run: poetry run pytest -vv -n auto --cov="taskiq" . - name: Generate report diff --git a/poetry.lock b/poetry.lock index b61b4c06..b527cbb5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,16 @@ -# This file is automatically @generated by Poetry 1.4.0 and should not be changed by hand. +# This file is automatically @generated by Poetry and should not be changed by hand. + +[[package]] +name = "aiochannel" +version = "1.2.1" +description = "asyncio Channels (closable queues) inspired by golang" +category = "main" +optional = true +python-versions = ">=3.7.2,<4.0.0" +files = [ + {file = "aiochannel-1.2.1-py3-none-any.whl", hash = "sha256:9187e3832a556fb308ac3fe070b042155a3213737204d5cb57da9a549c2b5f1e"}, + {file = "aiochannel-1.2.1.tar.gz", hash = "sha256:25d960c6e438861556a2623516161a724c1caa786264cbbd0bb62a8f0423f467"}, +] [[package]] name = "anyio" @@ -1447,14 +1459,14 @@ jupyter = ["ipywidgets (>=7.5.1,<9)"] [[package]] name = "setuptools" -version = "67.6.0" +version = "67.6.1" description = "Easily download, build, install, upgrade, and uninstall Python packages" category = "dev" optional = false python-versions = ">=3.7" files = [ - {file = "setuptools-67.6.0-py3-none-any.whl", hash = "sha256:b78aaa36f6b90a074c1fa651168723acbf45d14cb1196b6f02c0fd07f17623b2"}, - {file = "setuptools-67.6.0.tar.gz", hash = "sha256:2ee892cd5f29f3373097f5a814697e397cf3ce313616df0af11231e2ad118077"}, + {file = "setuptools-67.6.1-py3-none-any.whl", hash = "sha256:e728ca814a823bf7bf60162daf9db95b93d532948c4c0bea762ce62f60189078"}, + {file = "setuptools-67.6.1.tar.gz", hash = "sha256:257de92a9d50a60b8e22abfcbb771571fde0dbf3ec234463212027a4eeecbe9a"}, ] [package.extras] @@ -1786,10 +1798,11 @@ docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [extras] +memory = ["aiochannel"] uv = ["uvloop"] zmq = ["pyzmq"] [metadata] lock-version = "2.0" -python-versions = "^3.7" -content-hash = "49a0d7f64cb71f9d637e0418c0ad61c9f4a709ca7c18dc8b98ad3b6da786cac9" +python-versions = "^3.7.2" +content-hash = "0a1c21ffed540daa017b453016603229ebdb4e8ca11b201f01c5697fec9397e8" diff --git a/pyproject.toml b/pyproject.toml index a53440ce..c99d9cb8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,9 +26,10 @@ classifiers = [ keywords = ["taskiq", "tasks", "distributed", "async"] [tool.poetry.dependencies] -python = "^3.7" +python = "^3.7.2" typing-extensions = ">=3.10.0.0" pydantic = "^1.6.2" +aiochannel = { version = ">=1.2.0", optional = true } pyzmq = { version = "^23.2.0", optional = true } uvloop = { version = ">=0.16.0,<1", optional = true } watchdog = "^2.1.9" @@ -55,6 +56,7 @@ pytest-xdist = { version = "^2.5.0", extras = ["psutil"] } types-mock = "^4.0.15" [tool.poetry.extras] +memory = ["aiochannel"] zmq = ["pyzmq"] uv = ["uvloop"] diff --git a/taskiq/brokers/inmemory_broker.py b/taskiq/brokers/inmemory_broker.py index 823169d2..492c8425 100644 --- a/taskiq/brokers/inmemory_broker.py +++ b/taskiq/brokers/inmemory_broker.py @@ -1,13 +1,10 @@ -import inspect from collections import OrderedDict -from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, get_type_hints +from typing import Any, AsyncGenerator, Callable, Optional, TypeVar -from taskiq_dependencies import DependencyGraph +from aiochannel import Channel from taskiq.abc.broker import AsyncBroker from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult -from taskiq.cli.worker.args import WorkerArgs -from taskiq.cli.worker.receiver import Receiver from taskiq.events import TaskiqEvents from taskiq.exceptions import TaskiqError from taskiq.message import BrokerMessage @@ -87,12 +84,9 @@ class InMemoryBroker(AsyncBroker): It's useful for local development, if you don't want to setup real broker. """ - def __init__( # noqa: WPS211 + def __init__( self, - sync_tasks_pool_size: int = 4, - logs_format: Optional[str] = None, max_stored_results: int = 100, - cast_types: bool = True, result_backend: Optional[AsyncResultBackend[Any]] = None, task_id_generator: Optional[Callable[[], str]] = None, ) -> None: @@ -104,16 +98,7 @@ def __init__( # noqa: WPS211 result_backend=result_backend, task_id_generator=task_id_generator, ) - self.receiver = Receiver( - self, - WorkerArgs( - broker="", - modules=[], - max_threadpool_threads=sync_tasks_pool_size, - no_parse=not cast_types, - log_collector_format=logs_format or WorkerArgs.log_collector_format, - ), - ) + self.channel: Channel[BrokerMessage] = Channel() async def kick(self, message: BrokerMessage) -> None: """ @@ -128,31 +113,18 @@ async def kick(self, message: BrokerMessage) -> None: target_task = self.available_tasks.get(message.task_name) if target_task is None: raise TaskiqError("Unknown task.") - if not self.receiver.dependency_graphs.get(target_task.task_name): - self.receiver.dependency_graphs[target_task.task_name] = DependencyGraph( - target_task.original_func, - ) - if not self.receiver.task_signatures.get(target_task.task_name): - self.receiver.task_signatures[target_task.task_name] = inspect.signature( - target_task.original_func, - ) - if not self.receiver.task_hints.get(target_task.task_name): - self.receiver.task_hints[target_task.task_name] = get_type_hints( - target_task.original_func, - ) - - await self.receiver.callback(message=message) + await self.channel.put(message) - def listen(self) -> AsyncGenerator[BrokerMessage, None]: + async def listen(self) -> AsyncGenerator[BrokerMessage, None]: """ - Inmemory broker cannot listen. + Listen to channel. - This method throws RuntimeError if you call it. - Because inmemory broker cannot really listen to any of tasks. + This function listens to channel and yields every new message. - :raises RuntimeError: if this method is called. + :yields: broker message. """ - raise RuntimeError("Inmemory brokers cannot listen.") + async for message in self.channel: + yield message async def startup(self) -> None: """Runs startup events for client and worker side.""" diff --git a/tests/test_brokers.py b/tests/test_brokers.py new file mode 100644 index 00000000..6675f186 --- /dev/null +++ b/tests/test_brokers.py @@ -0,0 +1,33 @@ +import asyncio +from contextlib import contextmanager +from typing import Generator + +import pytest + +from taskiq.abc.broker import AsyncBroker +from taskiq.brokers.inmemory_broker import InMemoryBroker +from taskiq.cli.worker.args import WorkerArgs +from taskiq.cli.worker.async_task_runner import async_listen_messages + + +@contextmanager +def receive_messages(broker: AsyncBroker) -> Generator[None, None, None]: + cli_args = WorkerArgs(broker="", modules=[]) + listen_task = asyncio.create_task(async_listen_messages(broker, cli_args)) + yield + listen_task.cancel() + + +@pytest.mark.anyio +async def test_inmemory_broker_handle_message() -> None: + """Test that inmemory broker receive and handle task works.""" + broker = InMemoryBroker() + + @broker.task() + async def test_echo(msg: str) -> str: + return msg + + with receive_messages(broker): + task = await test_echo.kiq("foo") + result = await task.wait_result(timeout=1) + assert result.return_value == "foo"