Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite InMemoryBroker over aiochannel #92

Open
wants to merge 4 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ jobs:
python-version: "3.9"
cache: "poetry"
- name: Install deps
run: poetry install
run: poetry install --extras "memory zmq"
- name: Run lint check
run: poetry run pre-commit run -a ${{ matrix.cmd }}
pytest:
Expand All @@ -46,7 +46,7 @@ jobs:
python-version: "${{ matrix.py_version }}"
cache: "poetry"
- name: Install deps
run: poetry install
run: poetry install --extras "memory zmq"
- name: Run pytest check
run: poetry run pytest -vv -n auto --cov="taskiq" .
- name: Generate report
Expand Down
25 changes: 19 additions & 6 deletions poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,10 @@ classifiers = [
keywords = ["taskiq", "tasks", "distributed", "async"]

[tool.poetry.dependencies]
python = "^3.7"
python = "^3.7.2"
typing-extensions = ">=3.10.0.0"
pydantic = "^1.6.2"
aiochannel = { version = ">=1.2.0", optional = true }
pyzmq = { version = "^23.2.0", optional = true }
uvloop = { version = ">=0.16.0,<1", optional = true }
watchdog = "^2.1.9"
Expand All @@ -55,6 +56,7 @@ pytest-xdist = { version = "^2.5.0", extras = ["psutil"] }
types-mock = "^4.0.15"

[tool.poetry.extras]
memory = ["aiochannel"]
zmq = ["pyzmq"]
uv = ["uvloop"]

Expand Down
50 changes: 11 additions & 39 deletions taskiq/brokers/inmemory_broker.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
import inspect
from collections import OrderedDict
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar, get_type_hints
from typing import Any, AsyncGenerator, Callable, Optional, TypeVar

from taskiq_dependencies import DependencyGraph
from aiochannel import Channel

from taskiq.abc.broker import AsyncBroker
from taskiq.abc.result_backend import AsyncResultBackend, TaskiqResult
from taskiq.cli.worker.args import WorkerArgs
from taskiq.cli.worker.receiver import Receiver
from taskiq.events import TaskiqEvents
from taskiq.exceptions import TaskiqError
from taskiq.message import BrokerMessage
Expand Down Expand Up @@ -87,12 +84,9 @@ class InMemoryBroker(AsyncBroker):
It's useful for local development, if you don't want to setup real broker.
"""

def __init__( # noqa: WPS211
def __init__(
self,
sync_tasks_pool_size: int = 4,
logs_format: Optional[str] = None,
max_stored_results: int = 100,
cast_types: bool = True,
result_backend: Optional[AsyncResultBackend[Any]] = None,
task_id_generator: Optional[Callable[[], str]] = None,
) -> None:
Expand All @@ -104,16 +98,7 @@ def __init__( # noqa: WPS211
result_backend=result_backend,
task_id_generator=task_id_generator,
)
self.receiver = Receiver(
self,
WorkerArgs(
broker="",
modules=[],
max_threadpool_threads=sync_tasks_pool_size,
no_parse=not cast_types,
log_collector_format=logs_format or WorkerArgs.log_collector_format,
),
)
self.channel: Channel[BrokerMessage] = Channel()

async def kick(self, message: BrokerMessage) -> None:
"""
Expand All @@ -128,31 +113,18 @@ async def kick(self, message: BrokerMessage) -> None:
target_task = self.available_tasks.get(message.task_name)
if target_task is None:
raise TaskiqError("Unknown task.")
if not self.receiver.dependency_graphs.get(target_task.task_name):
self.receiver.dependency_graphs[target_task.task_name] = DependencyGraph(
target_task.original_func,
)
if not self.receiver.task_signatures.get(target_task.task_name):
self.receiver.task_signatures[target_task.task_name] = inspect.signature(
target_task.original_func,
)
if not self.receiver.task_hints.get(target_task.task_name):
self.receiver.task_hints[target_task.task_name] = get_type_hints(
target_task.original_func,
)

await self.receiver.callback(message=message)
await self.channel.put(message)

def listen(self) -> AsyncGenerator[BrokerMessage, None]:
async def listen(self) -> AsyncGenerator[BrokerMessage, None]:
"""
Inmemory broker cannot listen.
Listen to channel.

This method throws RuntimeError if you call it.
Because inmemory broker cannot really listen to any of tasks.
This function listens to channel and yields every new message.

:raises RuntimeError: if this method is called.
:yields: broker message.
"""
raise RuntimeError("Inmemory brokers cannot listen.")
async for message in self.channel:
yield message

async def startup(self) -> None:
"""Runs startup events for client and worker side."""
Expand Down
33 changes: 33 additions & 0 deletions tests/test_brokers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import asyncio
from contextlib import contextmanager
from typing import Generator

import pytest

from taskiq.abc.broker import AsyncBroker
from taskiq.brokers.inmemory_broker import InMemoryBroker
from taskiq.cli.worker.args import WorkerArgs
from taskiq.cli.worker.async_task_runner import async_listen_messages


@contextmanager
def receive_messages(broker: AsyncBroker) -> Generator[None, None, None]:
cli_args = WorkerArgs(broker="", modules=[])
listen_task = asyncio.create_task(async_listen_messages(broker, cli_args))
yield
listen_task.cancel()


@pytest.mark.anyio
async def test_inmemory_broker_handle_message() -> None:
"""Test that inmemory broker receive and handle task works."""
broker = InMemoryBroker()

@broker.task()
async def test_echo(msg: str) -> str:
return msg

with receive_messages(broker):
task = await test_echo.kiq("foo")
result = await task.wait_result(timeout=1)
assert result.return_value == "foo"