class AlephClient:
    """Client for the Aleph network message API.

    Queries are delegated to a :class:`QueryManager` wrapping a
    :class:`QueryEngine` (HTTP by default).
    """

    manager: QueryManager
    # True when this client created its engine and therefore owns its lifecycle.
    _close_engine_on_exit: bool

    def __init__(
        self,
        engine: Optional[QueryEngine] = None,
        api_server: Optional[str] = None,
        api_unix_socket: Optional[str] = None,
        allow_unix_sockets: bool = True,
        # NOTE(review): this parameter line is elided in the hunk; inferred
        # from the `timeout=timeout` forwarding below — confirm.
        timeout: Optional[aiohttp.ClientTimeout] = None,
    ):
        """AlephClient can use HTTP(S) or a Unix socket to reach the API server.

        Unix sockets are used when running inside a virtual machine, and
        can be shared across containers in a more secure way than TCP ports.

        :param engine: Optional pre-built query engine; when omitted, an HTTP
            engine is created from the remaining parameters.
        """
        if engine is None:
            engine = HttpQueryEngine.create_with_new_session(
                api_server=api_server,
                api_unix_socket=api_unix_socket,
                allow_unix_sockets=allow_unix_sockets,
                timeout=timeout,
            )
            # This client created the engine, so it must also close it.
            self._close_engine_on_exit = True
        else:
            # BUGFIX: the attribute was only assigned when the engine was
            # created internally; with an externally supplied engine,
            # __exit__/__aexit__ raised AttributeError. Externally owned
            # engines are left open for their owner to close.
            self._close_engine_on_exit = False

        self.manager = QueryManager(engine=engine)

        if not self.api_server:
            raise ValueError("Missing API host")

    @property
    def api_server(self) -> str:
        """Address of the API server backing the query engine."""
        return str(self.manager.engine.source)

    def __enter__(self) -> UserSessionSync:
        return UserSessionSync(async_session=self)

    def __exit__(self, exc_type, exc_val, exc_tb):
        if not self._close_engine_on_exit:
            return
        close_fut = self.manager.engine.stop()
        try:
            # NOTE(review): run_until_complete() raises RuntimeError on an
            # already-running loop, so this branch only works from synchronous
            # code with a non-running loop — confirm the intended usage of the
            # sync context manager.
            loop = asyncio.get_running_loop()
            loop.run_until_complete(close_fut)
        except RuntimeError:
            asyncio.run(close_fut)

    async def __aenter__(self) -> AlephClient:
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        if self._close_engine_on_exit:
            await self.manager.engine.stop()
await resp.json() - data = result.get("data", dict()) - return data.get(key) + return await self.manager.engine.fetch_aggregate(address, key, limit) async def fetch_aggregates( self, @@ -562,20 +549,7 @@ async def fetch_aggregates( :param limit: Maximum number of items to fetch (Default: 100) """ - keys_str = ",".join(keys) if keys else "" - params: Dict[str, Any] = {} - if keys_str: - params["keys"] = keys_str - if limit: - params["limit"] = limit - - async with self.http_session.get( - f"/api/v0/aggregates/{address}.json", - params=params, - ) as resp: - result = await resp.json() - data = result.get("data", dict()) - return data + return await self.manager.engine.fetch_aggregates(address, keys, limit) async def get_posts( self, @@ -607,6 +581,21 @@ async def get_posts( :param end_date: Latest date to fetch messages from """ + query_filter = PostFilter( + types=types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) + + queryset = self.manager.engine.apply_filter(query_filter) + return await queryset.fetch_posts(pagination=pagination, page=page) + params: Dict[str, Any] = dict(pagination=pagination, page=page) if types is not None: @@ -648,7 +637,7 @@ async def download_file_to_buffer( :param output_buffer: Writable binary buffer. The file will be written to this buffer. """ - async with self.http_session.get( + async with self.manager.engine._http_session.get( f"/api/v0/storage/raw/{file_hash}" ) as response: if response.status == 200: @@ -736,96 +725,24 @@ async def get_messages( ) -> MessagesResponse: """ Fetch a list of messages from the network. 
- - :param pagination: Number of items to fetch (Default: 200) - :param page: Page to fetch, begins at 1 (Default: 1) - :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" - :param content_types: Filter by content type - :param content_keys: Filter by content key - :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) - :param addresses: Addresses of the posts to fetch (Default: all addresses) - :param tags: Tags of the posts to fetch (Default: all tags) - :param hashes: Specific item_hashes to fetch - :param channels: Channels of the posts to fetch (Default: all channels) - :param chains: Filter by sender address chain - :param start_date: Earliest date to fetch messages from - :param end_date: Latest date to fetch messages from - :param ignore_invalid_messages: Ignore invalid messages (Default: False) - :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) """ - ignore_invalid_messages = ( - True if ignore_invalid_messages is None else ignore_invalid_messages - ) - invalid_messages_log_level = ( - logging.NOTSET - if invalid_messages_log_level is None - else invalid_messages_log_level - ) - - params: Dict[str, Any] = dict(pagination=pagination, page=page) - if message_type is not None: - params["msgType"] = message_type.value - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if content_keys is not None: - params["contentKeys"] = ",".join(content_keys) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - - if start_date is not None: - if not isinstance(start_date, float) 
and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date + query_filter = MessageFilter( + message_type=message_type, + content_types=content_types, + content_keys=content_keys, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) - async with self.http_session.get( - "/api/v0/messages.json", params=params - ) as resp: - resp.raise_for_status() - response_json = await resp.json() - messages_raw = response_json["messages"] - - # All messages may not be valid according to the latest specification in - # aleph-message. This allows the user to specify how errors should be handled. - messages: List[AlephMessage] = [] - for message_raw in messages_raw: - try: - message = parse_message(message_raw) - messages.append(message) - except KeyError as e: - if not ignore_invalid_messages: - raise e - logger.log( - level=invalid_messages_log_level, - msg=f"KeyError: Field '{e.args[0]}' not found", - ) - except ValidationError as e: - if not ignore_invalid_messages: - raise e - if invalid_messages_log_level: - logger.log(level=invalid_messages_log_level, msg=e) - - return MessagesResponse( - messages=messages, - pagination_page=response_json["pagination_page"], - pagination_total=response_json["pagination_total"], - pagination_per_page=response_json["pagination_per_page"], - pagination_item=response_json["pagination_item"], - ) + queryset = self.manager.apply_filter(query_filter=query_filter) + return await queryset.fetch_messages(page=page, page_size=pagination) async def get_message( self, @@ -840,17 +757,16 @@ async def get_message( :param message_type: Type of message to fetch :param channel: Channel of the message to fetch """ - messages_response = await 
self.get_messages( + + query_filter = MessageFilter( hashes=[item_hash], + message_type=message_type, channels=[channel] if channel else None, ) - if len(messages_response.messages) < 1: - raise MessageNotFoundError(f"No such hash {item_hash}") - if len(messages_response.messages) != 1: - raise MultipleMessagesError( - f"Multiple messages found for the same item_hash `{item_hash}`" - ) - message: GenericMessage = messages_response.messages[0] + queryset = self.manager.apply_filter(query_filter=query_filter) + message: GenericMessage = await queryset.first() + + # Optional, additional message type check. if message_type: expected_type = get_message_type_value(message_type) if message.type != expected_type: @@ -860,7 +776,7 @@ async def get_message( ) return message - async def watch_messages( + def watch_messages( self, message_type: Optional[MessageType] = None, content_types: Optional[Iterable[str]] = None, @@ -887,48 +803,21 @@ async def watch_messages( :param start_date: Start date from when to watch :param end_date: End date until when to watch """ - params: Dict[str, Any] = dict() - - if message_type is not None: - params["msgType"] = message_type.value - if content_types is not None: - params["contentTypes"] = ",".join(content_types) - if refs is not None: - params["refs"] = ",".join(refs) - if addresses is not None: - params["addresses"] = ",".join(addresses) - if tags is not None: - params["tags"] = ",".join(tags) - if hashes is not None: - params["hashes"] = ",".join(hashes) - if channels is not None: - params["channels"] = ",".join(channels) - if chains is not None: - params["chains"] = ",".join(chains) - if start_date is not None: - if not isinstance(start_date, float) and hasattr(start_date, "timestamp"): - start_date = start_date.timestamp() - params["startDate"] = start_date - if end_date is not None: - if not isinstance(end_date, float) and hasattr(start_date, "timestamp"): - end_date = end_date.timestamp() - params["endDate"] = end_date + 
query_filter = MessageFilter( + message_type=message_type, + content_types=content_types, + refs=refs, + addresses=addresses, + tags=tags, + hashes=hashes, + channels=channels, + chains=chains, + start_date=start_date, + end_date=end_date, + ) - async with self.http_session.ws_connect( - "/api/ws0/messages", params=params - ) as ws: - logger.debug("Websocket connected") - async for msg in ws: - if msg.type == aiohttp.WSMsgType.TEXT: - if msg.data == "close cmd": - await ws.close() - break - else: - data = json.loads(msg.data) - yield parse_message(data) - elif msg.type == aiohttp.WSMsgType.ERROR: - break + return self.manager.apply_filter(query_filter=query_filter).watch() class AuthenticatedAlephClient(AlephClient): @@ -962,10 +851,10 @@ def __init__( ) self.account = account - def __enter__(self) -> "AuthenticatedUserSessionSync": + def __enter__(self) -> AuthenticatedUserSessionSync: return AuthenticatedUserSessionSync(async_session=self) - async def __aenter__(self) -> "AuthenticatedAlephClient": + async def __aenter__(self) -> AuthenticatedAlephClient: return self async def ipfs_push(self, content: Mapping) -> str: diff --git a/src/aleph/sdk/objects.py b/src/aleph/sdk/objects.py new file mode 100644 index 00000000..120c6608 --- /dev/null +++ b/src/aleph/sdk/objects.py @@ -0,0 +1,49 @@ +def test_object(): + account = Account() + client = AlephClient(account) + client.get_messages(...) 
+ + my_posts = Post.objects.filter(sender="0x...") + + new_post = Post.objects.create(body="Hello, world!") + new_post.save() + + +class Manager: + pass + + +class PostManager(Manager): + pass + + +class HttpPostManager(PostManager): + def filter(self, **kwargs): + query_filter = QueryFilter(**kwargs) + pass + + def create(self, body: str): + pass + + def save(self, post): + pass + + +class SqliteCachedHttpManager(HttpPostManager): + pass + + +class Post: + objects: PostManager + message: AlephMessage + + def __init__(self, manager: Optional[PostManager] = None): + self.objects = PostManager() + + +class MyPost(Post): + title: str + number: int + + +Aleph diff --git a/src/aleph/sdk/query.py b/src/aleph/sdk/query.py new file mode 100644 index 00000000..4ef11e51 --- /dev/null +++ b/src/aleph/sdk/query.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import logging +from datetime import datetime +from typing import Any, AsyncIterator, Dict, Iterable, List, Optional, Union + +import aiohttp +from aleph_message.models import ( + AlephMessage, + MessagesResponse, + MessageType, + parse_message, +) +from pydantic import ValidationError + +logger = logging.getLogger(__name__) + + +def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: + if values: + return ",".join(values) + else: + return None + + +def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: + if date is None: + return None + elif isinstance(date, float): + return date + elif hasattr(date, "timestamp"): + return date.timestamp() + else: + raise TypeError(f"Invalid type: `{type(date)}`") + + +class MessageQueryFilter: + """ + A collection of filters that can be applied on message queries. 
class MessageQueryFilter:
    """
    A collection of filters that can be applied on message queries.

    :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET"
    :param content_types: Filter by content type
    :param content_keys: Filter by content key
    :param refs: If set, only fetch posts that reference these hashes (in the "refs" field)
    :param addresses: Addresses of the posts to fetch (Default: all addresses)
    :param tags: Tags of the posts to fetch (Default: all tags)
    :param hashes: Specific item_hashes to fetch
    :param channels: Channels of the posts to fetch (Default: all channels)
    :param chains: Filter by sender address chain
    :param start_date: Earliest date to fetch messages from
    :param end_date: Latest date to fetch messages from
    """

    message_type: Optional[MessageType]
    content_types: Optional[Iterable[str]]
    content_keys: Optional[Iterable[str]]
    refs: Optional[Iterable[str]]
    addresses: Optional[Iterable[str]]
    tags: Optional[Iterable[str]]
    hashes: Optional[Iterable[str]]
    channels: Optional[Iterable[str]]
    chains: Optional[Iterable[str]]
    start_date: Optional[Union[datetime, float]]
    end_date: Optional[Union[datetime, float]]

    def __init__(
        self,
        message_type: Optional[MessageType] = None,
        content_types: Optional[Iterable[str]] = None,
        content_keys: Optional[Iterable[str]] = None,
        refs: Optional[Iterable[str]] = None,
        addresses: Optional[Iterable[str]] = None,
        tags: Optional[Iterable[str]] = None,
        hashes: Optional[Iterable[str]] = None,
        channels: Optional[Iterable[str]] = None,
        chains: Optional[Iterable[str]] = None,
        start_date: Optional[Union[datetime, float]] = None,
        end_date: Optional[Union[datetime, float]] = None,
    ):
        self.message_type = message_type
        self.content_types = content_types
        self.content_keys = content_keys
        self.refs = refs
        self.addresses = addresses
        self.tags = tags
        self.hashes = hashes
        self.channels = channels
        self.chains = chains
        self.start_date = start_date
        self.end_date = end_date

    @staticmethod
    def _serialize_list(values: Optional[Iterable[str]]) -> Optional[str]:
        """Join strings into one comma-separated query value, or None when empty."""
        return ",".join(values) if values else None

    @staticmethod
    def _to_timestamp(date: Optional[Union[datetime, float]]) -> Optional[float]:
        """Normalize a date filter value to a POSIX timestamp, or None."""
        if date is None:
            return None
        if isinstance(date, float):
            return date
        if hasattr(date, "timestamp"):
            return date.timestamp()
        raise TypeError(f"Invalid type: `{type(date)}`")

    def as_http_params(self) -> Dict[str, str]:
        """Convert the filters into a dict that can be used by an `aiohttp`
        client as `params` to build the HTTP query string.

        Empty filters are dropped; every kept value is rendered as a string.
        """
        partial_result = {
            "msgType": self.message_type.value if self.message_type else None,
            "contentTypes": self._serialize_list(self.content_types),
            "contentKeys": self._serialize_list(self.content_keys),
            "refs": self._serialize_list(self.refs),
            "addresses": self._serialize_list(self.addresses),
            "tags": self._serialize_list(self.tags),
            "hashes": self._serialize_list(self.hashes),
            "channels": self._serialize_list(self.channels),
            "chains": self._serialize_list(self.chains),
            "startDate": self._to_timestamp(self.start_date),
            "endDate": self._to_timestamp(self.end_date),
        }

        result: Dict[str, str] = {}
        for key, value in partial_result.items():
            if value:
                # BUGFIX: the original asserted isinstance(value, str), which
                # failed whenever a start/end date was supplied (the timestamps
                # are floats). Render every kept value as a string instead —
                # aiohttp requires string query parameter values.
                result[key] = value if isinstance(value, str) else str(value)
        return result
class MessageQuery:
    """
    Interface to query messages from an API server.

    :param query_filter: The filter to apply when fetching messages
    :param http_client_session: The Aiohttp client session to the API server
    :param ignore_invalid_messages: Ignore invalid messages (Default: True)
    :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET)
    """

    query_filter: MessageQueryFilter
    http_client_session: aiohttp.ClientSession
    ignore_invalid_messages: bool
    invalid_messages_log_level: int

    def __init__(
        self,
        query_filter: MessageQueryFilter,
        http_client_session: aiohttp.ClientSession,
        ignore_invalid_messages: bool = True,
        invalid_messages_log_level: int = logging.NOTSET,
    ):
        self.query_filter = query_filter
        self.http_client_session = http_client_session
        self.ignore_invalid_messages = ignore_invalid_messages
        self.invalid_messages_log_level = invalid_messages_log_level

    async def fetch_json(self, page: int = 0, pagination: int = 200):
        """Return the raw JSON response from the API server."""
        params: Dict[str, Any] = self.query_filter.as_http_params()
        params["page"] = str(page)
        params["pagination"] = str(pagination)
        async with self.http_client_session.get(
            "/api/v0/messages.json", params=params
        ) as resp:
            resp.raise_for_status()
            return await resp.json()

    async def fetch(self, page: int = 0, pagination: int = 200):
        """Return one page of parsed messages from the API server."""
        response_json = await self.fetch_json(page=page, pagination=pagination)
        messages_raw = response_json["messages"]

        # All messages may not be valid according to the latest specification
        # in aleph-message. This allows the user to specify how errors should
        # be handled.
        messages: List[AlephMessage] = []
        for message_raw in messages_raw:
            try:
                messages.append(parse_message(message_raw))
            except KeyError as e:
                if not self.ignore_invalid_messages:
                    raise
                logger.log(
                    level=self.invalid_messages_log_level,
                    msg=f"KeyError: Field '{e.args[0]}' not found",
                )
            except ValidationError as e:
                if not self.ignore_invalid_messages:
                    raise
                if self.invalid_messages_log_level:
                    logger.log(level=self.invalid_messages_log_level, msg=e)

        return MessagesResponse(
            messages=messages,
            pagination_page=response_json["pagination_page"],
            pagination_total=response_json["pagination_total"],
            pagination_per_page=response_json["pagination_per_page"],
            pagination_item=response_json["pagination_item"],
        )

    async def __aiter__(self) -> AsyncIterator[AlephMessage]:
        """Iterate asynchronously over all matching messages.

        Handles pagination internally:

        ```
        async for message in MessageQuery(query_filter=filter, ...):
            print(message)
        ```
        """
        # BUGFIX: the original loop re-fetched `fetch(page=0)` forever (the
        # `page` counter was incremented but never used) and tested
        # `while partial_result`, which is always truthy for a response
        # object — an infinite loop as soon as page 0 had any message.
        # Advance the page and stop when a page comes back empty.
        page = 0
        while True:
            partial_result = await self.fetch(page=page)
            if not partial_result.messages:
                break
            for message in partial_result.messages:
                yield message
            page += 1
server. + + :param query_filter: The filter to apply when fetching messages + :param http_client_session: The Aiohttp client session to the API server + :param ignore_invalid_messages: Ignore invalid messages (Default: False) + :param invalid_messages_log_level: Log level to use for invalid messages (Default: logging.NOTSET) + """ + + query_filter: MessageFilter + source: Any + + def stop(self): + pass + + async def __aiter__(self) -> AsyncIterator[AlephMessage]: + pass + + async def first(self) -> Optional[AlephMessage]: + pass + + async def all(self) -> List[AlephMessage]: + pass + + async def fetch_messages( + self, query_filter: MessageFilter, page: int = 0, page_size: int = 200 + ): + pass + + async def fetch_aggregate( + self, + address: str, + key: str, + limit: int = 100, + ) -> Dict[str, Dict]: + pass + + async def watch_messages( + self, query_filter: WatchFilter + ) -> AsyncIterator[AlephMessage]: + yield + raise NotImplementedError() diff --git a/src/aleph/sdk/query/engines/http.py b/src/aleph/sdk/query/engines/http.py new file mode 100644 index 00000000..ab9c62cd --- /dev/null +++ b/src/aleph/sdk/query/engines/http.py @@ -0,0 +1,246 @@ +from __future__ import annotations + +import json +import logging +from datetime import datetime +from typing import Any, AsyncIterator, Dict, Iterable, Optional, Union + +import aiohttp +from aleph_message import parse_message +from aleph_message.models import AlephMessage +from yarl import URL + +from aleph.sdk.conf import settings +from aleph.sdk.query.engines.base import QueryEngine +from aleph.sdk.query.filter import MessageFilter, WatchFilter, BaseFilter, PostFilter +from aleph.sdk.utils import check_unix_socket_valid + +logger = logging.getLogger(__name__) + + +def create_http_session( + api_server: Optional[str] = None, + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, +): + """Create an HTTP session, using an UNIX socket or TCP and an 
optional timeout.""" + + host = api_server or settings.API_HOST + if not host: + raise ValueError("Missing API host") + + unix_socket_path = api_unix_socket or settings.API_UNIX_SOCKET + if unix_socket_path and allow_unix_sockets: + check_unix_socket_valid(unix_socket_path) + connector = aiohttp.UnixConnector(path=unix_socket_path) + else: + connector = None + + # ClientSession timeout defaults to a private sentinel object and may not be None. + return ( + aiohttp.ClientSession( + base_url=host, + connector=connector, + timeout=timeout, + ) + if timeout + else aiohttp.ClientSession( + base_url=host, + connector=connector, + ) + ) + + +class HttpQueryEngine(QueryEngine): + _http_session: aiohttp.ClientSession + _api_server: URL + ignore_invalid_messages: bool + invalid_messages_log_level: int + + def __init__( + self, + http_session: aiohttp.ClientSession, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ): + base_url = http_session._base_url + if not base_url: + raise ValueError("No API server defined on the HTTP session.") + + self._http_session = http_session + self._api_server = base_url + self.ignore_invalid_messages = ignore_invalid_messages + self.invalid_messages_log_level = invalid_messages_log_level + + async def stop(self): + await self._http_session.close() + + @property + def source(self) -> URL: + return self._api_server + + @classmethod + def create_with_new_session( + cls, + api_server: Optional[str] = None, + api_unix_socket: Optional[str] = None, + allow_unix_sockets: bool = True, + timeout: Optional[aiohttp.ClientTimeout] = None, + ) -> HttpQueryEngine: + http_session = create_http_session( + api_server=api_server, + api_unix_socket=api_unix_socket, + allow_unix_sockets=allow_unix_sockets, + timeout=timeout, + ) + return cls( + http_session=http_session, + ) + + async def fetch_messages( + self, query_filter, page: int = 0, page_size: int = 200 + ) -> Dict[str, Any]: + """Return the raw JSON response 
from the API server.""" + params: Dict[str, Any] = self._convert_query_filter(query_filter) + params["page"] = str(page) + params["pagination"] = str(page_size) + async with self._http_session.get( + "/api/v0/messages.json", params=params + ) as resp: + resp.raise_for_status() + result = await resp.json() + return result + + async def fetch_aggregate( + self, + address: str, + key: str, + limit: int = 100, + ) -> Dict[str, Dict]: + """ + Fetch a value from the aggregate store by owner address and item key. + + :param address: Address of the owner of the aggregate + :param key: Key of the aggregate + :param limit: Maximum number of items to fetch (Default: 100) + """ + + params: Dict[str, Any] = {"keys": key} + if limit: + params["limit"] = limit + + async with self._http_session.get( + f"/api/v0/aggregates/{address}.json", params=params + ) as resp: + result = await resp.json() + data = result.get("data", dict()) + return data.get(key) + + async def fetch_aggregates( + self, + address: str, + keys: Optional[Iterable[str]] = None, + limit: int = 100, + ) -> Dict[str, Dict]: + """ + Fetch key-value pairs from the aggregate store by owner address. 
+ + :param address: Address of the owner of the aggregate + :param keys: Keys of the aggregates to fetch (Default: all items) + :param limit: Maximum number of items to fetch (Default: 100) + """ + + keys_str = ",".join(keys) if keys else "" + params: Dict[str, Any] = {} + if keys_str: + params["keys"] = keys_str + if limit: + params["limit"] = limit + + async with self._http_session.get( + f"/api/v0/aggregates/{address}.json", + params=params, + ) as resp: + result = await resp.json() + data = result.get("data", dict()) + return data + + async def watch_messages( + self, query_filter: WatchFilter + ) -> AsyncIterator[AlephMessage]: + """Return an async iterator that will yield messages as they are received.""" + params: Dict[str, Any] = self._convert_query_filter(query_filter) + async with self._http_session.ws_connect( + "/api/ws0/messages", params=params + ) as ws: + logger.debug("Websocket connected") + async for msg in ws: + msg: aiohttp.WSMessage + if msg.type == aiohttp.WSMsgType.TEXT: + if msg.data == "close cmd": + await ws.close() + break + else: + data: Dict = json.loads(msg.data) + yield parse_message(data) + elif msg.type == aiohttp.WSMsgType.ERROR: + break + + @staticmethod + def _convert_query_filter(query_filter: BaseFilter) -> Dict[str, Any]: + """Convert the filters into a dict that can be used by an `aiohttp` client + as `params` to build the HTTP query string. 
+ """ + + message_type = ( + query_filter.message_type.value if query_filter.message_type else None + ) + + partial_result = { + "msgType": message_type, + "contentTypes": serialize_list(query_filter.content_types), + "refs": serialize_list(query_filter.refs), + "addresses": serialize_list(query_filter.addresses), + "tags": serialize_list(query_filter.tags), + "hashes": serialize_list(query_filter.hashes), + "channels": serialize_list(query_filter.channels), + "chains": serialize_list(query_filter.chains), + "startDate": _date_field_to_float(query_filter.start_date), + "endDate": _date_field_to_float(query_filter.end_date), + } + + if isinstance(query_filter, MessageFilter): + partial_result["contentKeys"] = serialize_list(query_filter.content_keys) + + if isinstance(query_filter, PostFilter): + partial_result["types"] = serialize_list(query_filter.types) + + # Ensure all values are strings. + result: Dict[str, str] = {} + + # Drop empty values + for key, value in partial_result.items(): + if value: + assert isinstance(value, str), f"Value must be a string: `{value}`" + result[key] = value + + return result + + +def serialize_list(values: Optional[Iterable[str]]) -> Optional[str]: + if values: + return ",".join(values) + else: + return None + + +def _date_field_to_float(date: Optional[Union[datetime, float]]) -> Optional[float]: + if date is None: + return None + elif isinstance(date, float): + return date + elif hasattr(date, "timestamp"): + return date.timestamp() + else: + raise TypeError(f"Invalid type: `{type(date)}`") diff --git a/src/aleph/sdk/query/engines/sqlite/__init__.py b/src/aleph/sdk/query/engines/sqlite/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/aleph/sdk/query/engines/sqlite/engine.py b/src/aleph/sdk/query/engines/sqlite/engine.py new file mode 100644 index 00000000..c90df9a3 --- /dev/null +++ b/src/aleph/sdk/query/engines/sqlite/engine.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import 
AsyncIterator + +from aleph_message.models import AlephMessage + +from aleph.sdk.query.engines.base import QueryEngine + + +class SqliteDatabase: + # Should use Peewee or something similar + pass + + +class SqliteQueryEngine(QueryEngine): + async def page(self, page: int = 0, page_size: int = 200): + raise NotImplementedError() + + async def __aiter__(self) -> AsyncIterator[AlephMessage]: + raise NotImplementedError() diff --git a/src/aleph/sdk/query/engines/sqlite/models.py b/src/aleph/sdk/query/engines/sqlite/models.py new file mode 100644 index 00000000..8f0d515a --- /dev/null +++ b/src/aleph/sdk/query/engines/sqlite/models.py @@ -0,0 +1,49 @@ +from datetime import datetime +from typing import Any, Dict, List, Optional + +from aleph_message.models import AlephMessage, ItemHash +from pydantic import BaseModel, Field + + +class MessagesResponse(PaginationResponse): + """Response from an Aleph node API on the path /api/v0/messages.json""" + + messages: List[AlephMessage] + pagination_item = "messages" + + +class Post(BaseModel): + """ + A post is a type of message that can be updated. Over the get_posts API + we get the latest version of a post. 
+ """ + + item_hash: ItemHash = Field(description="Hash of the content (sha256 by default)") + content: Dict[str, Any] = Field( + description="The content.content of the POST message" + ) + original_item_hash: ItemHash = Field( + description="Hash of the original content (sha256 by default)" + ) + original_type: str = Field( + description="The original, user-generated 'content-type' of the POST message" + ) + address: str = Field(description="The address of the sender of the POST message") + ref: Optional[str] = Field(description="Other message referenced by this one") + channel: Optional[str] = Field( + description="The channel where the POST message was published" + ) + created: datetime = Field(description="The time when the POST message was created") + last_updated: datetime = Field( + description="The time when the POST message was last updated" + ) + + class Config: + allow_extra = False + + +class PostsResponse(PaginationResponse): + """Response from an Aleph node API on the path /api/v0/posts.json""" + + posts: List[Post] + pagination_item = "posts" diff --git a/src/aleph/sdk/query/filter.py b/src/aleph/sdk/query/filter.py new file mode 100644 index 00000000..bb4ad812 --- /dev/null +++ b/src/aleph/sdk/query/filter.py @@ -0,0 +1,49 @@ +from __future__ import annotations + +from dataclasses import dataclass +from datetime import datetime +from typing import Iterable, Optional, Union + +from aleph_message.models import MessageType + + +@dataclass +class BaseFilter: + refs: Optional[Iterable[str]] = None + addresses: Optional[Iterable[str]] = None + tags: Optional[Iterable[str]] = None + hashes: Optional[Iterable[str]] = None + channels: Optional[Iterable[str]] = None + chains: Optional[Iterable[str]] = None + start_date: Optional[Union[datetime, float]] = None + end_date: Optional[Union[datetime, float]] = None + + +class WatchFilter(BaseFilter): + message_type: Optional[MessageType] = None + + +class MessageFilter(BaseFilter): + """ + A collection of filters 
that can be applied on message queries. + + :param message_type: Filter by message type, can be "AGGREGATE", "POST", "PROGRAM", "VM", "STORE" or "FORGET" + :param content_types: Filter by content type + :param content_keys: Filter by content key + :param refs: If set, only fetch posts that reference these hashes (in the "refs" field) + :param addresses: Addresses of the posts to fetch (Default: all addresses) + :param tags: Tags of the posts to fetch (Default: all tags) + :param hashes: Specific item_hashes to fetch + :param channels: Channels of the posts to fetch (Default: all channels) + :param chains: Filter by sender address chain + :param start_date: Earliest date to fetch messages from + :param end_date: Latest date to fetch messages from + """ + + message_type: Optional[MessageType] = None + content_types: Optional[Iterable[str]] = None + content_keys: Optional[Iterable[str]] = None + + +class PostFilter(BaseFilter): + types: Optional[Iterable[str]] = None diff --git a/src/aleph/sdk/query/manager.py b/src/aleph/sdk/query/manager.py new file mode 100644 index 00000000..eebd6428 --- /dev/null +++ b/src/aleph/sdk/query/manager.py @@ -0,0 +1,59 @@ +from __future__ import annotations + +import logging + +from aleph.sdk.query.engines.base import QueryEngine +from aleph.sdk.query.filter import MessageFilter +from aleph.sdk.query.queryset import QuerySet + +logger = logging.getLogger(__name__) + + +class QueryManager: + """Query manager for Aleph messages. + + This is the main entry point for querying messages from an engine with a filter. 
+ """ + + query_filter: MessageFilter + engine: QueryEngine + ignore_invalid_messages: bool + invalid_messages_log_level: int + + def __init__( + self, + engine: QueryEngine, + ignore_invalid_messages: bool = True, + invalid_messages_log_level: int = logging.NOTSET, + ): + self.engine = engine + self.ignore_invalid_messages = ignore_invalid_messages + self.invalid_messages_log_level = invalid_messages_log_level + + def filter(self, **kwargs) -> QuerySet: + query_filter = MessageFilter(**kwargs) + return QuerySet( + query_filter=query_filter, + engine=self.engine, + ignore_invalid_messages=self.ignore_invalid_messages, + invalid_messages_log_level=self.invalid_messages_log_level, + ) + + def apply_filter(self, query_filter: MessageFilter) -> QuerySet: + return QuerySet( + query_filter=query_filter, + engine=self.engine, + ignore_invalid_messages=self.ignore_invalid_messages, + invalid_messages_log_level=self.invalid_messages_log_level, + ) + + def all(self) -> QuerySet: + """ + Return all messages. Use with caution, as this may return a lot of messages. + """ + return QuerySet( + query_filter=MessageFilter(), # An empty filter should return all messages. 
+ engine=self.engine, + ignore_invalid_messages=self.ignore_invalid_messages, + invalid_messages_log_level=self.invalid_messages_log_level, + ) diff --git a/src/aleph/sdk/query/queryset.py b/src/aleph/sdk/query/queryset.py new file mode 100644 index 00000000..a1449ec6 --- /dev/null +++ b/src/aleph/sdk/query/queryset.py @@ -0,0 +1,119 @@ +from __future__ import annotations + +import logging +from typing import AsyncIterator, List + +from aleph_message.models import ( + AlephMessage, + MessagesResponse, + item_hash, + parse_message, +) +from pydantic import ValidationError + +from aleph.sdk.exceptions import MessageNotFoundError, MultipleMessagesError +from aleph.sdk.query.engines.base import QueryEngine +from aleph.sdk.query.filter import MessageFilter, BaseFilter + +logger = logging.getLogger(__name__) + + +class QuerySet: + """QuerySet for Aleph messages. + + Helps to iterate over messages from an engine with a filter. + """ + + query_filter: BaseFilter + engine: QueryEngine + ignore_invalid_messages: bool + invalid_messages_log_level: int + + def __init__( + self, + query_filter: BaseFilter, + engine: QueryEngine, + ignore_invalid_messages: bool, + invalid_messages_log_level: int, + ): + self.query_filter = query_filter + self.engine = engine + self.ignore_invalid_messages = ignore_invalid_messages + self.invalid_messages_log_level = invalid_messages_log_level + + async def fetch_messages( + self, page: int = 0, page_size: int = 200 + ) -> MessagesResponse: + """Return the parsed messages from the API server.""" + response_json = await self.engine.fetch_messages( + query_filter=self.query_filter, page=page, page_size=page_size + ) + + messages_raw = response_json["messages"] # TODO: Depends on API response format + + # All messages may not be valid according to the latest specification in + # aleph-message. This allows the user to specify how errors should be handled. 
+        messages: List[AlephMessage] = []
+        for message_raw in messages_raw:
+            try:
+                message = parse_message(message_raw)
+                messages.append(message)
+            except KeyError as e:
+                if not self.ignore_invalid_messages:
+                    raise
+                logger.log(
+                    level=self.invalid_messages_log_level,
+                    msg=f"KeyError: Field '{e.args[0]}' not found",
+                )
+            except ValidationError as e:
+                if not self.ignore_invalid_messages:
+                    raise
+                if self.invalid_messages_log_level:
+                    logger.log(level=self.invalid_messages_log_level, msg=e)
+
+        return MessagesResponse(
+            messages=messages,
+            pagination_page=response_json["pagination_page"],
+            pagination_total=response_json["pagination_total"],
+            pagination_per_page=response_json["pagination_per_page"],
+            pagination_item=response_json["pagination_item"],
+        )
+
+    async def __aiter__(self) -> AsyncIterator[AlephMessage]:
+        """Iterate asynchronously over matching messages.
+        Handles pagination internally.
+
+        ```
+        async for message in MessageQuery(query_filter=filter):
+            print(message)
+        ```
+        """
+        page: int = 0
+        partial_result = await self.fetch_messages(page=0)
+        while partial_result.messages:
+            for message in partial_result.messages:
+                yield message
+
+            page += 1
+            partial_result = await self.fetch_messages(page=page)
+
+    async def first(self) -> AlephMessage:
+        """Return the first matching message."""
+        response = await self.fetch_messages(page=0, page_size=1)
+
+        # Raise specific exceptions.
+        if len(response.messages) < 1:
+            raise MessageNotFoundError(f"No message found for {self.query_filter}")
+        if len(response.messages) != 1:
+            raise MultipleMessagesError(
+                f"Multiple messages found for a unique query: {self.query_filter}"
+            )
+
+        message = response.messages[0]
+        return message
+
+    def watch(self) -> AsyncIterator[AlephMessage]:
+        """Watch for new messages.
+        This is an infinite iterator that will yield messages as they are received.
+        """
+        return self.engine.watch_messages(query_filter=self.query_filter)
diff --git a/src/aleph/sdk/sync.py b/src/aleph/sdk/sync.py
new file mode 100644
index 00000000..e69de29b
diff --git a/tests/unit/test_asynchronous.py b/tests/unit/test_asynchronous.py
index 8973263b..f247d5c4 100644
--- a/tests/unit/test_asynchronous.py
+++ b/tests/unit/test_asynchronous.py
@@ -52,7 +52,7 @@ async def text(self):
 
     client = AuthenticatedAlephClient(
         account=ethereum_account, api_server="http://localhost"
     )
-    client.http_session = http_session
+    # client.http_session = http_session
     return client
 
diff --git a/tests/unit/test_asynchronous_get.py b/tests/unit/test_asynchronous_get.py
index 5a139328..564c22c1 100644
--- a/tests/unit/test_asynchronous_get.py
+++ b/tests/unit/test_asynchronous_get.py
@@ -31,7 +31,7 @@ def get(self, *_args, **_kwargs):
 
     http_session = MockHttpSession()
 
     client = AlephClient(api_server="http://localhost")
-    client.http_session = http_session
+    # client.http_session = http_session
     return client
 