diff --git a/.gitignore b/.gitignore index da858d9..ef7bb86 100644 --- a/.gitignore +++ b/.gitignore @@ -15,3 +15,7 @@ __pycache__ .pytest_cache .ruff_cache coverage.xml +.coverage +.venv +venv +.DS_Store diff --git a/pyproject.toml b/pyproject.toml index 0f3a947..f4ee601 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ classifiers = [ "Development Status :: 5 - Production/Stable", "Topic :: Internet :: File Transfer Protocol (FTP)", ] +dependencies = [ + "typing_extensions >= 4.10,<5.0", +] [project.urls] Github = "https://github.com/aio-libs/aioftp" @@ -60,6 +63,10 @@ dev = [ "black", "ruff", + # typecheckers + "pyright", + "mypy", + # docs "sphinx", "alabaster", @@ -81,9 +88,11 @@ target-version = ["py38"] [tool.ruff] line-length = 120 target-version = "py38" -select = ["E", "W", "F", "Q", "UP", "I", "ASYNC"] src = ["src"] +[tool.ruff.lint] +select = ["E", "W", "F", "Q", "UP", "I", "ASYNC"] + [tool.coverage] run.source = ["./src/aioftp"] run.omit = ["./src/aioftp/__main__.py"] @@ -96,3 +105,28 @@ log_format = "%(asctime)s.%(msecs)03d %(name)-20s %(levelname)-8s %(filename)-15 log_date_format = "%H:%M:%S" log_level = "DEBUG" asyncio_mode = "strict" + +[tool.pyright] +pythonVersion = "3.8" +strict = ["src"] +reportImplicitOverride = true +exclude = ["venv", ".venv", "build", "tests", "docs", "ftpbench.py"] + +[tool.mypy] +files = "src/aioftp" +show_absolute_path = true +pretty = true +strict = true +enable_error_code = [ + "explicit-override", + "redundant-self", "redundant-expr", "possibly-undefined", +] +warn_unreachable = true +disallow_any_generics = true +disallow_untyped_defs = true + +[[tool.mypy.overrides]] +module = "tests" +strict = false +disallow_any_generics = false +disallow_untyped_defs = false diff --git a/src/aioftp/__init__.py b/src/aioftp/__init__.py index 15235af..226e69d 100644 --- a/src/aioftp/__init__.py +++ b/src/aioftp/__init__.py @@ -4,15 +4,106 @@ import importlib.metadata -from .client import * -from .common import * -from .errors import * -from .pathio import * -from .server import * +from .client import BaseClient, Client, Code, DataConnectionThrottleStreamIO +from .common import ( + DEFAULT_ACCOUNT, + DEFAULT_BLOCK_SIZE, + DEFAULT_PASSWORD, + DEFAULT_PORT, + DEFAULT_USER, + END_OF_LINE, + AbstractAsyncLister, + AsyncListerMixin, + AsyncStreamIterator, + StreamIO, + StreamThrottle, + Throttle, + ThrottleStreamIO, + async_enterable, + setlocale, + with_timeout, + wrap_with_container, +) +from .errors import AIOFTPException, NoAvailablePort, PathIOError, PathIsNotAbsolute, StatusCodeError +from .pathio import ( + AbstractPathIO, + AsyncPathIO, + MemoryPathIO, + PathIO, + PathIONursery, +) +from .server import ( + AbstractUserManager, + AvailableConnections, + Connection, + ConnectionConditions, + MemoryUserManager, + PathConditions, + PathPermissions, + Permission, + Server, + User, + worker, +) +from .types import NotEmptyCodes, check_not_empty_codes -__version__ = importlib.metadata.version(__package__) +__version__ = importlib.metadata.version(__package__) # pyright: ignore[reportArgumentType] version = tuple(map(int, __version__.split("."))) + __all__ = ( - client.__all__ + server.__all__ + errors.__all__ + common.__all__ + pathio.__all__ + ("version", "__version__") + # client + "BaseClient", + "Client", + "DataConnectionThrottleStreamIO", + "Code", + # server + "Server", + "Permission", + "User", + "AbstractUserManager", + "MemoryUserManager", + "Connection", + "AvailableConnections", + "ConnectionConditions", + "PathConditions", + "PathPermissions", + "worker", + # errors + "AIOFTPException", + "StatusCodeError", + "PathIsNotAbsolute", + "PathIOError", + "NoAvailablePort", + # common + "with_timeout", + "StreamIO", + "Throttle", + "StreamThrottle", + "ThrottleStreamIO", + "END_OF_LINE", + "DEFAULT_BLOCK_SIZE", + "wrap_with_container", + "AsyncStreamIterator", + "AbstractAsyncLister", + "AsyncListerMixin", + "async_enterable", + "DEFAULT_PORT", + "DEFAULT_USER", + "DEFAULT_PASSWORD", + "DEFAULT_ACCOUNT", + "setlocale", + # pathio + "AbstractPathIO", + "PathIO", + "AsyncPathIO", + "MemoryPathIO", + "PathIONursery", + # + # types + "NotEmptyCodes", + "check_not_empty_codes", + # + "version", + "__version__", ) diff --git a/src/aioftp/__main__.py b/src/aioftp/__main__.py index 3b7db1d..c12ca00 100644 --- a/src/aioftp/__main__.py +++ b/src/aioftp/__main__.py @@ -5,6 +5,7 @@ import contextlib import logging import socket +from typing import Type import aioftp @@ -65,6 +66,7 @@ format="%(asctime)s [%(name)s] %(message)s", datefmt="[%H:%M:%S]:", ) +path_io_factory: Type[aioftp.AbstractPathIO] if args.memory: user = aioftp.User(args.login, args.password, base_path="/") path_io_factory = aioftp.MemoryPathIO @@ -81,7 +83,7 @@ }[args.family] -async def main(): +async def main() -> None: server = aioftp.Server([user], path_io_factory=path_io_factory) await server.run(args.host, args.port, family=family) diff --git a/src/aioftp/client.py b/src/aioftp/client.py index 338f73c..3fceb65 100644 --- a/src/aioftp/client.py +++ b/src/aioftp/client.py @@ -7,6 +7,26 @@ import pathlib import re from functools import partial +from ssl import SSLContext +from types import TracebackType +from typing import ( + Any, + AsyncIterator, + Callable, + Deque, + Dict, + List, + Literal, + Optional, + Sequence, + Tuple, + Type, + Union, + cast, + overload, +) + +from typing_extensions import NotRequired, Self, TypeAlias, TypedDict, override from . import errors, pathio from .common import ( @@ -25,9 +45,10 @@ setlocale, wrap_with_container, ) +from .types import AsyncEnterableProtocol, NotEmptyCodes, check_not_empty_codes try: - from siosocks.io.asyncio import open_connection + from siosocks.io.asyncio import open_connection # type: ignore except ImportError: from asyncio import open_connection @@ -40,13 +61,42 @@ ) logger = logging.getLogger(__name__) +_PathType: TypeAlias = Union[str, pathlib.PurePath] +_ConnectionType: TypeAlias = Literal["I", "A", "E", "L"] + +_InfoFileType: TypeAlias = Literal["file", "dir", "link", "unknown"] + + +class WindowsInfoDict(TypedDict): + type: _InfoFileType + size: NotRequired[str] + modify: str + + +UnixInfoDict = TypedDict( + "UnixInfoDict", + { + "type": _InfoFileType, + "size": str, + "modify": str, + "unix.mode": int, + "unix.links": str, + "unix.owner": str, + "unix.group": str, + }, +) + +InfoDict: TypeAlias = Union[WindowsInfoDict, UnixInfoDict] + +_ParserType: TypeAlias = Callable[[bytes], Tuple[pathlib.PurePosixPath, InfoDict]] + class Code(str): """ Representation of server status code. """ - def matches(self, mask): + def matches(self, mask: str) -> bool: """ :param mask: Template for comparision. If mask symbol is not digit then it passes. @@ -77,11 +127,15 @@ class DataConnectionThrottleStreamIO(ThrottleStreamIO): :py:class:`aioftp.ThrottleStreamIO` """ - def __init__(self, client, *args, **kwargs): + def __init__(self, client: "BaseClient", *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self.client = client - async def finish(self, expected_codes="2xx", wait_codes="1xx"): + async def finish( + self, + expected_codes: Code = Code("2xx"), + wait_codes: Union[Tuple[Code], Code] = Code("1xx"), + ) -> None: """ :py:func:`asyncio.coroutine` @@ -99,29 +153,45 @@ async def finish(self, expected_codes="2xx", wait_codes="1xx"): self.close() await self.client.command(None, expected_codes, wait_codes) - async def __aexit__(self, exc_type, exc, tb): + @override + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: if exc is None: await self.finish() else: self.close() +_CodesORCode: TypeAlias = Union[NotEmptyCodes, Code] +_MaybeCodesORCode: TypeAlias = Union[Tuple[Code, ...], Code] +_EmptyCodes: TypeAlias = Tuple[()] + + class BaseClient: + + stream: Optional[ThrottleStreamIO] + server_port: Optional[int] + server_host: Optional[str] + def __init__( self, *, - socket_timeout=None, - connection_timeout=None, - read_speed_limit=None, - write_speed_limit=None, - path_timeout=None, - path_io_factory=pathio.PathIO, - encoding="utf-8", - ssl=None, - parse_list_line_custom=None, - parse_list_line_custom_first=True, - passive_commands=("epsv", "pasv"), - **siosocks_asyncio_kwargs, + socket_timeout: Optional[int] = None, + connection_timeout: Optional[int] = None, + read_speed_limit: Optional[int] = None, + write_speed_limit: Optional[int] = None, + path_timeout: Optional[int] = None, + path_io_factory: Type[pathio.AbstractPathIO] = pathio.PathIO, + encoding: str = "utf-8", + ssl: Optional[SSLContext] = None, + parse_list_line_custom: Optional[_ParserType] = None, + parse_list_line_custom_first: bool = True, + passive_commands: Sequence[str] = ("epsv", "pasv"), + **siosocks_asyncio_kwargs: Any, ): self.socket_timeout = socket_timeout self.connection_timeout = connection_timeout @@ -138,8 +208,10 @@ def __init__( self.parse_list_line_custom_first = parse_list_line_custom_first self._passive_commands = passive_commands self._open_connection = partial(open_connection, ssl=self.ssl, **siosocks_asyncio_kwargs) + self.server_host = None + self.server_port = None - async def connect(self, host, port=DEFAULT_PORT): + async def connect(self, host: str, port: int = DEFAULT_PORT) -> Optional[List[str]]: self.server_host = host self.server_port = port reader, writer = await asyncio.wait_for( @@ -152,15 +224,16 @@ async def connect(self, host, port=DEFAULT_PORT): throttles={"_": self.throttle}, timeout=self.socket_timeout, ) + return None - def close(self): + def close(self) -> None: """ Close connection. """ if self.stream is not None: self.stream.close() - async def parse_line(self): + async def parse_line(self) -> Tuple[Code, str]: """ :py:func:`asyncio.coroutine` @@ -174,6 +247,9 @@ async def parse_line(self): :raises asyncio.TimeoutError: if there where no data for `timeout` period """ + if self.stream is None: + raise ValueError("Connection is not established") + line = await self.stream.readline() if not line: self.stream.close() @@ -182,7 +258,7 @@ async def parse_line(self): logger.debug(s) return Code(s[:3]), s[3:] - async def parse_response(self): + async def parse_response(self) -> Tuple[Code, List[str]]: """ :py:func:`asyncio.coroutine` @@ -207,7 +283,12 @@ async def parse_response(self): info.append(curr_code + rest) return code, info - def check_codes(self, expected_codes, received_code, info): + def check_codes( + self, + expected_codes: Union[Tuple[Code, ...], Code], + received_code: Code, + info: List[str], + ) -> None: """ Checks if any of expected matches received. @@ -226,13 +307,49 @@ def check_codes(self, expected_codes, received_code, info): if not any(map(received_code.matches, expected_codes)): raise errors.StatusCodeError(expected_codes, received_code, info) + @overload async def command( self, - command=None, - expected_codes=(), - wait_codes=(), - censor_after=None, - ): + command: Optional[str], + expected_codes: _EmptyCodes, + wait_codes: _EmptyCodes, + censor_after: Optional[int] = None, + ) -> None: ... + + @overload + async def command( + self, + command: Optional[str], + expected_codes: _MaybeCodesORCode, + wait_codes: _CodesORCode, + censor_after: Optional[int] = None, + ) -> Tuple[Code, List[str]]: ... + + @overload + async def command( + self, + command: Optional[str], + expected_codes: _CodesORCode, + wait_codes: _MaybeCodesORCode = (), + censor_after: Optional[int] = None, + ) -> Tuple[Code, List[str]]: ... + + @overload + async def command( + self, + command: Optional[str], + expected_codes: _MaybeCodesORCode = (), + wait_codes: _MaybeCodesORCode = (), + censor_after: Optional[int] = None, + ) -> Optional[Tuple[Code, List[str]]]: ... + + async def command( + self, + command: Optional[str] = None, + expected_codes: Union[_EmptyCodes, _MaybeCodesORCode, _CodesORCode] = (), + wait_codes: Union[_EmptyCodes, _MaybeCodesORCode, _CodesORCode] = (), + censor_after: Optional[int] = None, + ) -> Optional[Tuple[Code, List[str]]]: """ :py:func:`asyncio.coroutine` @@ -268,6 +385,10 @@ async def command( else: logger.debug(command) message = command + END_OF_LINE + + if self.stream is None: + raise ValueError("Connection is not established") + await self.stream.write(message.encode(encoding=self.encoding)) if expected_codes or wait_codes: code, info = await self.parse_response() @@ -276,9 +397,10 @@ async def command( if expected_codes: self.check_codes(expected_codes, code, info) return code, info + return None @staticmethod - def parse_epsv_response(s): + def parse_epsv_response(s: str) -> Tuple[None, int]: """ Parsing `EPSV` (`message (|||port|)`) response. @@ -294,7 +416,7 @@ def parse_epsv_response(s): return None, port @staticmethod - def parse_pasv_response(s): + def parse_pasv_response(s: str) -> Tuple[str, int]: """ Parsing `PASV` server response. @@ -311,7 +433,7 @@ def parse_pasv_response(s): return ip, port @staticmethod - def parse_directory_response(s): + def parse_directory_response(s: str) -> pathlib.PurePosixPath: """ Parsing directory server response. @@ -340,7 +462,7 @@ def parse_directory_response(s): return pathlib.PurePosixPath(directory) @staticmethod - def parse_unix_mode(s): + def parse_unix_mode(s: str) -> int: """ Parsing unix mode strings ("rwxr-x--t") into hexacimal notation. @@ -379,7 +501,7 @@ def parse_unix_mode(s): return mode @staticmethod - def format_date_time(d): + def format_date_time(d: datetime.datetime) -> str: """ Formats dates from strptime in a consistent format @@ -391,7 +513,7 @@ def format_date_time(d): return d.strftime("%Y%m%d%H%M00") @classmethod - def parse_ls_date(cls, s, *, now=None): + def parse_ls_date(cls, s: str, *, now: Optional[datetime.datetime] = None) -> str: """ Parsing dates from the ls unix utility. For example, "Nov 18 1958", "Jan 03 2018", and "Nov 18 12:29". @@ -430,7 +552,7 @@ def parse_ls_date(cls, s, *, now=None): d = datetime.datetime.strptime(s, "%b %d %Y") return cls.format_date_time(d) - def parse_list_line_unix(self, b): + def parse_list_line_unix(self, b: bytes) -> Tuple[pathlib.PurePosixPath, UnixInfoDict]: """ Attempt to parse a LIST line (similar to unix ls utility). @@ -441,7 +563,7 @@ def parse_list_line_unix(self, b): :rtype: (:py:class:`pathlib.PurePosixPath`, :py:class:`dict`) """ s = b.decode(encoding=self.encoding).rstrip() - info = {} + info: dict[str, Any] = {} if s[0] == "-": info["type"] = "file" elif s[0] == "d": @@ -483,9 +605,22 @@ def parse_list_line_unix(self, b): i = -2 if link_dst[-1] == "'" or link_dst[-1] == '"' else -1 info["type"] = "dir" if link_dst[i] == "/" else "file" s = link_src - return pathlib.PurePosixPath(s), info - def parse_list_line_windows(self, b): + typed_info = UnixInfoDict( + { + "type": info["type"], + "size": info["size"], + "modify": info["modify"], + "unix.mode": info["unix.mode"], + "unix.links": info["unix.links"], + "unix.owner": info["unix.owner"], + "unix.group": info["unix.group"], + }, + ) + + return pathlib.PurePosixPath(s), typed_info + + def parse_list_line_windows(self, b: bytes) -> Tuple[pathlib.PurePosixPath, WindowsInfoDict]: """ Parsing Microsoft Windows `dir` output @@ -497,13 +632,13 @@ def parse_list_line_windows(self, b): """ line = b.decode(encoding=self.encoding).rstrip("\r\n") date_time_end = line.index("M") - date_time_str = line[: date_time_end + 1].strip().split(" ") - date_time_str = " ".join([x for x in date_time_str if len(x) > 0]) + date_time_list_str = line[: date_time_end + 1].strip().split(" ") + date_time_str = " ".join([x for x in date_time_list_str if len(x) > 0]) line = line[date_time_end + 1 :].lstrip() with setlocale("C"): strptime = datetime.datetime.strptime date_time = strptime(date_time_str, "%m/%d/%Y %I:%M %p") - info = {} + info: dict[str, Any] = {} info["modify"] = self.format_date_time(date_time) next_space = line.index(" ") if line.startswith(""): @@ -519,9 +654,17 @@ def parse_list_line_windows(self, b): filename = line[next_space:].lstrip() if filename == "." or filename == "..": raise ValueError - return pathlib.PurePosixPath(filename), info - def parse_list_line(self, b): + windows_info: WindowsInfoDict = { + "type": info["type"], + "modify": info["modify"], + } + if "size" in info: + windows_info["size"] = info["size"] + + return pathlib.PurePosixPath(filename), windows_info + + def parse_list_line(self, b: bytes) -> Tuple[pathlib.PurePosixPath, InfoDict]: """ Parse LIST response with both Microsoft Windows® parser and UNIX parser @@ -532,8 +675,8 @@ def parse_list_line(self, b): :return: (path, info) :rtype: (:py:class:`pathlib.PurePosixPath`, :py:class:`dict`) """ - ex = [] - parsers = [ + ex: List[Exception] = [] + parsers: List[Optional[_ParserType]] = [ self.parse_list_line_unix, self.parse_list_line_windows, ] @@ -550,7 +693,7 @@ def parse_list_line(self, b): ex.append(e) raise ValueError("All parsers failed to parse", b, ex) - def parse_mlsx_line(self, b): + def parse_mlsx_line(self, b: Union[bytes, str]) -> Tuple[pathlib.PurePosixPath, InfoDict]: """ Parsing MLS(T|D) response. @@ -566,11 +709,12 @@ def parse_mlsx_line(self, b): s = b line = s.rstrip() facts_found, _, name = line.partition(" ") - entry = {} + entry: Dict[str, str] = {} for fact in facts_found[:-1].split(";"): key, _, value = fact.partition("=") entry[key.lower()] = value - return pathlib.PurePosixPath(name), entry + + return pathlib.PurePosixPath(name), cast(InfoDict, entry) class Client(BaseClient): @@ -619,7 +763,8 @@ class Client(BaseClient): :param **siosocks_asyncio_kwargs: siosocks key-word only arguments """ - async def connect(self, host, port=DEFAULT_PORT): + @override + async def connect(self, host: str, port: int = DEFAULT_PORT) -> List[str]: """ :py:func:`asyncio.coroutine` @@ -632,15 +777,15 @@ async def connect(self, host, port=DEFAULT_PORT): :type port: :py:class:`int` """ await super().connect(host, port) - code, info = await self.command(None, "220", "120") + _, info = await self.command(None, Code("220"), Code("120")) return info async def login( self, - user=DEFAULT_USER, - password=DEFAULT_PASSWORD, - account=DEFAULT_ACCOUNT, - ): + user: str = DEFAULT_USER, + password: str = DEFAULT_PASSWORD, + account: str = DEFAULT_ACCOUNT, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -657,7 +802,7 @@ async def login( :raises aioftp.StatusCodeError: if unknown code received """ - code, info = await self.command("USER " + user, ("230", "33x")) + code, info = await self.command("USER " + user, check_not_empty_codes((Code("230"), Code("33x")))) while code.matches("33x"): censor_after = None if code == "331": @@ -666,14 +811,19 @@ async def login( elif code == "332": cmd = "ACCT " + account else: - raise errors.StatusCodeError("33x", code, info) + raise errors.StatusCodeError(Code("33x"), code, info) code, info = await self.command( cmd, - ("230", "33x"), + check_not_empty_codes( + ( + Code("230"), + Code("33x"), + ), + ), censor_after=censor_after, ) - async def get_current_directory(self): + async def get_current_directory(self) -> pathlib.PurePosixPath: """ :py:func:`asyncio.coroutine` @@ -681,11 +831,11 @@ async def get_current_directory(self): :rtype: :py:class:`pathlib.PurePosixPath` """ - code, info = await self.command("PWD", "257") + _, info = await self.command("PWD", Code("257")) directory = self.parse_directory_response(info[-1]) return directory - async def change_directory(self, path=".."): + async def change_directory(self, path: _PathType = "..") -> None: """ :py:func:`asyncio.coroutine` @@ -699,9 +849,9 @@ async def change_directory(self, path=".."): cmd = "CDUP" else: cmd = "CWD " + str(path) - await self.command(cmd, "2xx") + await self.command(cmd, Code("2xx")) - async def make_directory(self, path, *, parents=True): + async def make_directory(self, path: _PathType, *, parents: bool = True) -> None: """ :py:func:`asyncio.coroutine` @@ -714,7 +864,7 @@ async def make_directory(self, path, *, parents=True): :type parents: :py:class:`bool` """ path = pathlib.PurePosixPath(path) - need_create = [] + need_create: List[_PathType] = [] while path.name and not await self.exists(path): need_create.append(path) path = path.parent @@ -722,9 +872,9 @@ async def make_directory(self, path, *, parents=True): break need_create.reverse() for path in need_create: - await self.command("MKD " + str(path), "257") + await self.command("MKD " + str(path), Code("257")) - async def remove_directory(self, path): + async def remove_directory(self, path: _PathType) -> None: """ :py:func:`asyncio.coroutine` @@ -733,9 +883,15 @@ async def remove_directory(self, path): :param path: empty directory to remove :type path: :py:class:`str` or :py:class:`pathlib.PurePosixPath` """ - await self.command("RMD " + str(path), "250") + await self.command("RMD " + str(path), Code("250")) - def list(self, path="", *, recursive=False, raw_command=None): + def list( + self, + path: _PathType = "", + *, + recursive: bool = False, + raw_command: Optional[str] = None, + ) -> AsyncListerMixin[Tuple[pathlib.PurePath, InfoDict]]: """ :py:func:`asyncio.coroutine` @@ -770,59 +926,78 @@ def list(self, path="", *, recursive=False, raw_command=None): >>> stats = await client.list() """ - class AsyncLister(AsyncListerMixin): - stream = None - - async def _new_stream(cls, local_path): - cls.path = local_path - cls.parse_line = self.parse_mlsx_line + class AsyncLister( + AsyncIterator[Tuple[pathlib.PurePath, InfoDict]], + AsyncListerMixin[Tuple[pathlib.PurePath, InfoDict]], + ): + stream: Optional[DataConnectionThrottleStreamIO] = None + directories: Deque[Tuple[_PathType, InfoDict]] + parse_line: Callable[[bytes], Tuple[pathlib.PurePath, InfoDict]] + client: Client = self + + async def _new_stream( + self, + local_path: _PathType, + ) -> Optional[DataConnectionThrottleStreamIO]: + self.path = local_path + self.parse_line = self.client.parse_mlsx_line if raw_command not in [None, "MLSD", "LIST"]: raise ValueError( - "raw_command must be one of MLSD or " f"LIST, but got {raw_command}", + f"raw_command must be one of MLSD or LIST, but got {raw_command}", ) if raw_command in [None, "MLSD"]: try: - command = ("MLSD " + str(cls.path)).strip() - return await self.get_stream(command, "1xx") + command = ("MLSD " + str(self.path)).strip() + return await self.client.get_stream(command, Code("1xx")) except errors.StatusCodeError as e: code = e.received_codes[-1] if not code.matches("50x") or raw_command is not None: raise if raw_command in [None, "LIST"]: - cls.parse_line = self.parse_list_line - command = ("LIST " + str(cls.path)).strip() - return await self.get_stream(command, "1xx") + self.parse_line = self.client.parse_list_line + command = ("LIST " + str(self.path)).strip() + return await self.client.get_stream(command, Code("1xx")) + return None + + @override + def __aiter__(self) -> Self: + self.directories = collections.deque() + return self - def __aiter__(cls): - cls.directories = collections.deque() - return cls + @override + async def __anext__(self) -> Tuple[pathlib.PurePath, InfoDict]: + if self.stream is None: + self.stream = await self._new_stream(path) - async def __anext__(cls): - if cls.stream is None: - cls.stream = await cls._new_stream(path) while True: - line = await cls.stream.readline() + if self.stream is None: + raise StopAsyncIteration + line = await self.stream.readline() while not line: - await cls.stream.finish() - if cls.directories: - current_path, info = cls.directories.popleft() - cls.stream = await cls._new_stream(current_path) - line = await cls.stream.readline() + await self.stream.finish() + if self.directories: + current_path, info = self.directories.popleft() + self.stream = await self._new_stream(current_path) + + if self.stream is None: + raise ValueError("Connection is not established") + + line = await self.stream.readline() else: raise StopAsyncIteration - name, info = cls.parse_line(line) + name, info = self.parse_line(line) # skipping . and .. as these are symlinks in Unix if str(name) in (".", ".."): continue - stat = cls.path / name, info + stat = self.path / name, info if info["type"] == "dir" and recursive: - cls.directories.append(stat) + self.directories.append(stat) return stat return AsyncLister() - async def stat(self, path): + async def stat(self, path: _PathType) -> InfoDict: """ :py:func:`asyncio.coroutine` @@ -836,8 +1011,8 @@ async def stat(self, path): """ path = pathlib.PurePosixPath(path) try: - code, info = await self.command("MLST " + str(path), "2xx") - name, info = self.parse_mlsx_line(info[1].lstrip()) + _, info_ = await self.command("MLST " + str(path), Code("2xx")) + _, info = self.parse_mlsx_line(info_[1].lstrip()) return info except errors.StatusCodeError as e: if not e.received_codes[-1].matches("50x"): @@ -853,7 +1028,7 @@ async def stat(self, path): "path does not exists", ) - async def is_file(self, path): + async def is_file(self, path: _PathType) -> bool: """ :py:func:`asyncio.coroutine` @@ -867,7 +1042,7 @@ async def is_file(self, path): info = await self.stat(path) return info["type"] == "file" - async def is_dir(self, path): + async def is_dir(self, path: _PathType) -> bool: """ :py:func:`asyncio.coroutine` @@ -881,7 +1056,7 @@ async def is_dir(self, path): info = await self.stat(path) return info["type"] == "dir" - async def exists(self, path): + async def exists(self, path: _PathType) -> bool: """ :py:func:`asyncio.coroutine` @@ -900,7 +1075,7 @@ async def exists(self, path): return False raise - async def rename(self, source, destination): + async def rename(self, source: _PathType, destination: _PathType) -> None: """ :py:func:`asyncio.coroutine` @@ -912,10 +1087,10 @@ async def rename(self, source, destination): :param destination: path new name :type destination: :py:class:`str` or :py:class:`pathlib.PurePosixPath` """ - await self.command("RNFR " + str(source), "350") - await self.command("RNTO " + str(destination), "2xx") + await self.command("RNFR " + str(source), Code("350")) + await self.command("RNTO " + str(destination), Code("2xx")) - async def remove_file(self, path): + async def remove_file(self, path: _PathType) -> None: """ :py:func:`asyncio.coroutine` @@ -924,9 +1099,9 @@ async def remove_file(self, path): :param path: file to remove :type path: :py:class:`str` or :py:class:`pathlib.PurePosixPath` """ - await self.command("DELE " + str(path), "2xx") + await self.command("DELE " + str(path), Code("2xx")) - async def remove(self, path): + async def remove(self, path: _PathType) -> None: """ :py:func:`asyncio.coroutine` @@ -946,7 +1121,12 @@ async def remove(self, path): await self.remove(name) await self.remove_directory(path) - def upload_stream(self, destination, *, offset=0): + def upload_stream( + self, + destination: _PathType, + *, + offset: int = 0, + ) -> AsyncEnterableProtocol[DataConnectionThrottleStreamIO]: """ Create stream for write data to `destination` file. @@ -964,7 +1144,12 @@ def upload_stream(self, destination, *, offset=0): offset=offset, ) - def append_stream(self, destination, *, offset=0): + def append_stream( + self, + destination: _PathType, + *, + offset: int = 0, + ) -> AsyncEnterableProtocol[DataConnectionThrottleStreamIO]: """ Create stream for append (write) data to `destination` file. @@ -982,7 +1167,14 @@ def append_stream(self, destination, *, offset=0): offset=offset, ) - async def upload(self, source, destination="", *, write_into=False, block_size=DEFAULT_BLOCK_SIZE): + async def upload( + self, + source: _PathType, + destination: _PathType = "", + *, + write_into: bool = False, + block_size: int = DEFAULT_BLOCK_SIZE, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -1003,25 +1195,27 @@ async def upload(self, source, destination="", *, write_into=False, block_size=D :param block_size: block size for transaction :type block_size: :py:class:`int` """ - source = pathlib.Path(source) - destination = pathlib.PurePosixPath(destination) + source_path = pathlib.Path(source) + destination_path = pathlib.PurePosixPath(destination) if not write_into: - destination = destination / source.name - if await self.path_io.is_file(source): - await self.make_directory(destination.parent) - async with self.path_io.open(source, mode="rb") as file_in, self.upload_stream(destination) as stream: + destination_path = destination_path / source_path.name + if await self.path_io.is_file(source_path): + await self.make_directory(destination_path.parent) + async with self.path_io.open(source_path, mode="rb") as file_in, self.upload_stream( + destination_path, + ) as stream: async for block in file_in.iter_by_block(block_size): await stream.write(block) - elif await self.path_io.is_dir(source): - await self.make_directory(destination) - sources = collections.deque([source]) + elif await self.path_io.is_dir(source_path): + await self.make_directory(destination_path) + sources = collections.deque([source_path]) while sources: src = sources.popleft() async for path in self.path_io.list(src): if write_into: - relative = destination.name / path.relative_to(source) + relative = destination_path.name / path.relative_to(source_path) else: - relative = path.relative_to(source.parent) + relative = path.relative_to(source_path.parent) if await self.path_io.is_dir(path): await self.make_directory(relative) sources.append(path) @@ -1033,7 +1227,12 @@ async def upload(self, source, destination="", *, write_into=False, block_size=D block_size=block_size, ) - def download_stream(self, source, *, offset=0): + def download_stream( + self, + source: _PathType, + *, + offset: int = 0, + ) -> AsyncEnterableProtocol[DataConnectionThrottleStreamIO]: """ :py:func:`asyncio.coroutine` @@ -1049,7 +1248,14 @@ def download_stream(self, source, *, offset=0): """ return self.get_stream("RETR " + str(source), "1xx", offset=offset) - async def download(self, source, destination="", *, write_into=False, block_size=DEFAULT_BLOCK_SIZE): + async def download( + self, + source: _PathType, + destination: _PathType = "", + *, + write_into: bool = False, + block_size: int = DEFAULT_BLOCK_SIZE, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -1070,23 +1276,25 @@ async def download(self, source, destination="", *, write_into=False, block_size :param block_size: block size for transaction :type block_size: :py:class:`int` """ - source = pathlib.PurePosixPath(source) - destination = pathlib.Path(destination) + source_path = pathlib.PurePosixPath(source) + destination_path = pathlib.Path(destination) if not write_into: - destination = destination / source.name - if await self.is_file(source): + destination_path = destination_path / source_path.name + if await self.is_file(source_path): await self.path_io.mkdir( - destination.parent, + destination_path.parent, parents=True, exist_ok=True, ) - async with self.path_io.open(destination, mode="wb") as file_out, self.download_stream(source) as stream: + async with self.path_io.open(destination_path, mode="wb") as file_out, self.download_stream( + source_path, + ) as stream: async for block in stream.iter_by_block(block_size): await file_out.write(block) - elif await self.is_dir(source): - await self.path_io.mkdir(destination, parents=True, exist_ok=True) - for name, info in await self.list(source): - full = destination / name.relative_to(source) + elif await self.is_dir(source_path): + await self.path_io.mkdir(destination_path, parents=True, exist_ok=True) + for name, info in await self.list(source_path): + full = destination_path / name.relative_to(source_path) if info["type"] in ("file", "dir"): await self.download( name, @@ -1095,30 +1303,30 @@ async def download(self, source, destination="", *, write_into=False, block_size block_size=block_size, ) - async def quit(self): + async def quit(self) -> None: """ :py:func:`asyncio.coroutine` Send "QUIT" and close connection. """ - await self.command("QUIT", "2xx") + await self.command("QUIT", Code("2xx")) self.close() - async def _do_epsv(self): - code, info = await self.command("EPSV", "229") + async def _do_epsv(self) -> Tuple[None, int]: + _, info = await self.command("EPSV", Code("229")) ip, port = self.parse_epsv_response(info[-1]) return ip, port - async def _do_pasv(self): - code, info = await self.command("PASV", "227") + async def _do_pasv(self) -> Tuple[str, int]: + _, info = await self.command("PASV", Code("227")) ip, port = self.parse_pasv_response(info[-1]) return ip, port async def get_passive_connection( self, - conn_type="I", - commands=None, - ): + conn_type: _ConnectionType = "I", + commands: Optional[Sequence[str]] = None, + ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]: """ :py:func:`asyncio.coroutine` @@ -1143,7 +1351,7 @@ async def get_passive_connection( commands = self._passive_commands if not commands: raise ValueError("No passive commands provided") - await self.command("TYPE " + conn_type, "200") + await self.command("TYPE " + conn_type, Code("200")) for i, name in enumerate(commands, start=1): name = name.lower() if name not in functions: @@ -1155,13 +1363,18 @@ async def get_passive_connection( is_last = i == len(commands) if is_last or not e.received_codes[-1].matches("50x"): raise - if ip in ("0.0.0.0", None): + if ip in ("0.0.0.0", None): # type: ignore ip = self.server_host - reader, writer = await self._open_connection(ip, port) + reader, writer = await self._open_connection(ip, port) # type: ignore return reader, writer @async_enterable - async def get_stream(self, *command_args, conn_type="I", offset=0): + async def get_stream( + self, + *command_args: Any, + conn_type: _ConnectionType = "I", + offset: int = 0, + ) -> DataConnectionThrottleStreamIO: """ :py:func:`asyncio.coroutine` @@ -1180,7 +1393,7 @@ async def get_stream(self, *command_args, conn_type="I", offset=0): """ reader, writer = await self.get_passive_connection(conn_type) if offset: - await self.command("REST " + str(offset), "350") + await self.command("REST " + str(offset), Code("350")) await self.command(*command_args) stream = DataConnectionThrottleStreamIO( self, @@ -1191,7 +1404,7 @@ async def get_stream(self, *command_args, conn_type="I", offset=0): ) return stream - async def abort(self, *, wait=True): + async def abort(self, *, wait: bool = True) -> None: """ :py:func:`asyncio.coroutine` @@ -1201,15 +1414,21 @@ async def abort(self, *, wait=True): :type wait: :py:class:`bool` """ if wait: - await self.command("ABOR", "226", "426") + await self.command("ABOR", Code("226"), Code("426")) else: await self.command("ABOR") @classmethod @contextlib.asynccontextmanager async def context( - cls, host, port=DEFAULT_PORT, user=DEFAULT_USER, password=DEFAULT_PASSWORD, account=DEFAULT_ACCOUNT, **kwargs - ): + cls, + host: str, + port: int = DEFAULT_PORT, + user: str = DEFAULT_USER, + password: str = DEFAULT_PASSWORD, + account: str = DEFAULT_ACCOUNT, + **kwargs: Any, + ) -> AsyncIterator["Client"]: """ Classmethod async context manager. This create :py:class:`aioftp.Client`, make async call to diff --git a/src/aioftp/common.py b/src/aioftp/common.py index 41eec46..fca6d30 100644 --- a/src/aioftp/common.py +++ b/src/aioftp/common.py @@ -1,10 +1,40 @@ import abc import asyncio -import collections import functools import locale import threading from contextlib import contextmanager +from types import TracebackType +from typing import ( + TYPE_CHECKING, + Any, + AsyncContextManager, + AsyncIterable, + AsyncIterator, + Awaitable, + Callable, + Coroutine, + Dict, + Final, + Generator, + Generic, + List, + NamedTuple, + Optional, + Tuple, + Type, + TypeVar, + Union, + overload, +) + +from typing_extensions import ParamSpec, Self, override + +from .types import AsyncEnterableProtocol +from .utils import get_param + +if TYPE_CHECKING: + from .pathio import AsyncPathIO __all__ = ( "with_timeout", @@ -27,35 +57,67 @@ ) -END_OF_LINE = "\r\n" -DEFAULT_BLOCK_SIZE = 8192 +END_OF_LINE: Final[str] = "\r\n" +DEFAULT_BLOCK_SIZE: Final[int] = 8192 + +DEFAULT_PORT: Final[int] = 21 +DEFAULT_USER: Final[str] = "anonymous" +DEFAULT_PASSWORD: Final[str] = "anon@" +DEFAULT_ACCOUNT: Final[str] = "" +HALF_OF_YEAR_IN_SECONDS: Final[int] = 15778476 +TWO_YEARS_IN_SECONDS: Final[float] = ((365 * 3 + 366) * 24 * 60 * 60) / 2 -DEFAULT_PORT = 21 -DEFAULT_USER = "anonymous" -DEFAULT_PASSWORD = "anon@" -DEFAULT_ACCOUNT = "" -HALF_OF_YEAR_IN_SECONDS = 15778476 -TWO_YEARS_IN_SECONDS = ((365 * 3 + 366) * 24 * 60 * 60) / 2 +_T = TypeVar("_T") +_PS = ParamSpec("_PS") -def _now(): +def _now() -> float: return asyncio.get_running_loop().time() -def _with_timeout(name): - def decorator(f): +def _with_timeout( + name: str, +) -> Callable[ + [Callable[_PS, Coroutine[None, None, _T]]], + Callable[_PS, Coroutine[None, None, _T]], +]: + def decorator(f: Callable[_PS, Coroutine[None, None, _T]]) -> Callable[_PS, Coroutine[None, None, _T]]: @functools.wraps(f) - def wrapper(cls, *args, **kwargs): - coro = f(cls, *args, **kwargs) - timeout = getattr(cls, name) - return asyncio.wait_for(coro, timeout) + async def wrapper(*args: _PS.args, **kwargs: _PS.kwargs) -> _T: + self: "AsyncPathIO" = get_param((0, "self"), args, kwargs) + coro = f(*args, **kwargs) + timeout = getattr(self, name) + return await asyncio.wait_for(coro, timeout) return wrapper return decorator -def with_timeout(name): +@overload +def with_timeout( + name: str, +) -> Callable[ + [Callable[_PS, Coroutine[None, None, _T]]], + Callable[_PS, Coroutine[None, None, _T]], +]: ... + + +@overload +def with_timeout( + name: Callable[_PS, Coroutine[None, None, _T]], +) -> Callable[_PS, Coroutine[None, None, _T]]: ... + + +def with_timeout( + name: Union[str, Callable[_PS, Coroutine[None, None, _T]]], +) -> Union[ + Callable[ + [Callable[_PS, Coroutine[None, None, _T]]], + Callable[_PS, Coroutine[None, None, _T]], + ], + Callable[_PS, Coroutine[None, None, _T]], +]: """ Method decorator, wraps method with :py:func:`asyncio.wait_for`. `timeout` argument takes from `name` decorator argument or "timeout". @@ -97,14 +159,16 @@ def with_timeout(name): return _with_timeout("timeout")(name) -class AsyncStreamIterator: - def __init__(self, read_coro): +class AsyncStreamIterator(AsyncIterator[_T], Generic[_T]): + def __init__(self, read_coro: Callable[[], Awaitable[_T]]): self.read_coro = read_coro - def __aiter__(self): + @override + def __aiter__(self) -> "AsyncStreamIterator[_T]": return self - async def __anext__(self): + @override + async def __anext__(self) -> _T: data = await self.read_coro() if data: return data @@ -112,7 +176,11 @@ async def __anext__(self): raise StopAsyncIteration -class AsyncListerMixin: +class AsyncListerMixin( + AsyncIterable[_T], + Awaitable[List[_T]], + Generic[_T], +): """ Add ability to `async for` context to collect data to list via await. @@ -123,17 +191,18 @@ class AsyncListerMixin: >>> results = await Context(...) """ - async def _to_list(self): - items = [] + async def _to_list(self) -> List[_T]: + items: List[_T] = [] async for item in self: items.append(item) return items - def __await__(self): + @override + def __await__(self) -> Generator[None, None, List[_T]]: return self._to_list().__await__() -class AbstractAsyncLister(AsyncListerMixin, abc.ABC): +class AbstractAsyncLister(AsyncListerMixin[_T], abc.ABC): """ Abstract context with ability to collect all iterables into :py:class:`list` via `await` with optional timeout (via @@ -162,16 +231,17 @@ class AbstractAsyncLister(AsyncListerMixin, abc.ABC): [block, block, block, ...] """ - def __init__(self, *, timeout=None): + def __init__(self, *, timeout: Optional[float] = None) -> None: super().__init__() self.timeout = timeout - def __aiter__(self): + @override + def __aiter__(self) -> "AbstractAsyncLister[_T]": return self @with_timeout @abc.abstractmethod - async def __anext__(self): + async def __anext__(self) -> _T: """ :py:func:`asyncio.coroutine` @@ -179,7 +249,12 @@ async def __anext__(self): """ -def async_enterable(f): +_T_acm = TypeVar("_T_acm", bound=AsyncContextManager[Any]) + + +def async_enterable( + f: Callable[_PS, Coroutine[None, None, _T_acm]], +) -> Callable[_PS, AsyncEnterableProtocol[_T_acm]]: """ Decorator. Bring coroutine result up, so it can be used as async context @@ -215,16 +290,19 @@ def async_enterable(f): """ @functools.wraps(f) - def wrapper(*args, **kwargs): - class AsyncEnterableInstance: - async def __aenter__(self): + def wrapper(*args: _PS.args, **kwargs: _PS.kwargs) -> AsyncEnterableProtocol[_T_acm]: + class AsyncEnterableInstance(AsyncEnterableProtocol[_T_acm]): # pyright: ignore[reportGeneralTypeIssues] + @override + async def __aenter__(self) -> _T_acm: self.context = await f(*args, **kwargs) - return await self.context.__aenter__() + return await self.context.__aenter__() # type: ignore - async def __aexit__(self, *args, **kwargs): + @override + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: await self.context.__aexit__(*args, **kwargs) - def __await__(self): + @override + def __await__(self) -> Generator[None, None, _T_acm]: return f(*args, **kwargs).__await__() return AsyncEnterableInstance() @@ -232,9 +310,18 @@ def __await__(self): return wrapper -def wrap_with_container(o): +_T_str_bounded = TypeVar("_T_str_bounded", bound=str) + + +@overload +def wrap_with_container(o: _T_str_bounded) -> Tuple[_T_str_bounded]: ... +@overload +def wrap_with_container(o: Any) -> Any: ... + + +def wrap_with_container(o: _T) -> Union[_T, Tuple[str]]: if isinstance(o, str): - o = (o,) + return (o,) return o @@ -260,14 +347,22 @@ class StreamIO: :type write_timeout: :py:class:`int`, :py:class:`float` or :py:class:`None` """ - def __init__(self, reader, writer, *, timeout=None, read_timeout=None, write_timeout=None): + def __init__( + self, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter, + *, + timeout: Optional[Union[int, float]] = None, + read_timeout: Optional[Union[int, float]] = None, + write_timeout: Optional[Union[int, float]] = None, + ) -> None: self.reader = reader self.writer = writer self.read_timeout = read_timeout or timeout self.write_timeout = write_timeout or timeout @with_timeout("read_timeout") - async def readline(self): + async def readline(self) -> bytes: """ :py:func:`asyncio.coroutine` @@ -276,7 +371,7 @@ async def readline(self): return await self.reader.readline() @with_timeout("read_timeout") - async def read(self, count=-1): + async def read(self, count: int = -1) -> bytes: """ :py:func:`asyncio.coroutine` @@ -288,7 +383,7 @@ async def read(self, count=-1): return await self.reader.read(count) @with_timeout("read_timeout") - async def readexactly(self, count): + async def readexactly(self, count: int) -> bytes: """ :py:func:`asyncio.coroutine` @@ -300,7 +395,7 @@ async def readexactly(self, count): return await self.reader.readexactly(count) @with_timeout("write_timeout") - async def write(self, data): + async def write(self, data: bytes) -> None: """ :py:func:`asyncio.coroutine` @@ -313,7 +408,7 @@ async def write(self, data): self.writer.write(data) await self.writer.drain() - def close(self): + def close(self) -> None: """ Close connection. """ @@ -332,13 +427,13 @@ class Throttle: :type reset_rate: :py:class:`int` or :py:class:`float` """ - def __init__(self, *, limit=None, reset_rate=10): + def __init__(self, *, limit: Optional[int] = None, reset_rate: Union[int, float] = 10) -> None: self._limit = limit self.reset_rate = reset_rate - self._start = None + self._start: Optional[float] = None self._sum = 0 - async def wait(self): + async def wait(self) -> None: """ :py:func:`asyncio.coroutine` @@ -349,7 +444,7 @@ async def wait(self): end = self._start + self._sum / self._limit await asyncio.sleep(max(0, end - now)) - def append(self, data, start): + def append(self, data: bytes, start: float) -> None: """ Count `data` for throttle @@ -369,14 +464,14 @@ def append(self, data, start): self._sum += len(data) @property - def limit(self): + def limit(self) -> Optional[int]: """ Throttle limit """ return self._limit @limit.setter - def limit(self, value): + def limit(self, value: Optional[int]) -> None: """ Set throttle limit @@ -387,17 +482,18 @@ def limit(self, value): self._start = None self._sum = 0 - def clone(self): + def clone(self) -> "Throttle": """ Clone throttle without memory """ return Throttle(limit=self._limit, reset_rate=self.reset_rate) - def __repr__(self): + @override + def __repr__(self) -> str: return f"{self.__class__.__name__}(limit={self._limit!r}, " f"reset_rate={self.reset_rate!r})" -class StreamThrottle(collections.namedtuple("StreamThrottle", "read write")): +class StreamThrottle(NamedTuple): """ Stream throttle with `read` and `write` :py:class:`aioftp.Throttle` @@ -408,7 +504,10 @@ class StreamThrottle(collections.namedtuple("StreamThrottle", "read write")): :type write: :py:class:`aioftp.Throttle` """ - def clone(self): + read: Throttle + write: Throttle + + def clone(self) -> "StreamThrottle": """ Clone throttles without memory """ @@ -418,7 +517,11 @@ def clone(self): ) @classmethod - def from_limits(cls, read_speed_limit=None, write_speed_limit=None): + def from_limits( + cls, + read_speed_limit: Optional[int] = None, + write_speed_limit: Optional[int] = None, + ) -> "StreamThrottle": """ Simple wrapper for creation :py:class:`aioftp.StreamThrottle` @@ -463,11 +566,11 @@ class ThrottleStreamIO(StreamIO): ... ) """ - def __init__(self, *args, throttles={}, **kwargs): + def __init__(self, *args: Any, throttles: Dict[str, StreamThrottle] = {}, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.throttles = throttles - async def wait(self, name): + async def wait(self, name: str) -> None: """ :py:func:`asyncio.coroutine` @@ -476,7 +579,7 @@ async def wait(self, name): :param name: name of throttle to acquire ("read" or "write") :type name: :py:class:`str` """ - tasks = [] + tasks: List[asyncio.Task[None]] = [] for throttle in self.throttles.values(): curr_throttle = getattr(throttle, name) if curr_throttle.limit: @@ -484,7 +587,7 @@ async def wait(self, name): if tasks: await asyncio.wait(tasks) - def append(self, name, data, start): + def append(self, name: str, data: bytes, start: float) -> None: """ Update timeout for all throttles @@ -501,7 +604,8 @@ def append(self, name, data, start): for throttle in self.throttles.values(): getattr(throttle, name).append(data, start) - async def read(self, count=-1): + @override + async def read(self, count: int = -1) -> bytes: """ :py:func:`asyncio.coroutine` @@ -513,7 +617,8 @@ async def read(self, count=-1): self.append("read", data, start) return data - async def readline(self): + @override + async def readline(self) -> bytes: """ :py:func:`asyncio.coroutine` @@ -525,7 +630,8 @@ async def readline(self): self.append("read", data, start) return data - async def write(self, data): + @override + async def write(self, data: bytes) -> None: """ :py:func:`asyncio.coroutine` @@ -536,13 +642,18 @@ async def write(self, data): await super().write(data) self.append("write", data, start) - async def __aenter__(self): + async def __aenter__(self) -> Self: return self - async def __aexit__(self, *args): + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc: Optional[BaseException], + tb: Optional[TracebackType], + ) -> None: self.close() - def iter_by_line(self): + def iter_by_line(self) -> AsyncStreamIterator[bytes]: """ Read/iterate stream by line. @@ -553,9 +664,9 @@ def iter_by_line(self): >>> async for line in stream.iter_by_line(): ... ... """ - return AsyncStreamIterator(self.readline) + return AsyncStreamIterator[bytes](self.readline) - def iter_by_block(self, count=DEFAULT_BLOCK_SIZE): + def iter_by_block(self, count: int = DEFAULT_BLOCK_SIZE) -> AsyncStreamIterator[bytes]: """ Read/iterate stream by block. @@ -573,7 +684,7 @@ def iter_by_block(self, count=DEFAULT_BLOCK_SIZE): @contextmanager -def setlocale(name): +def setlocale(name: str) -> Generator[str, None, None]: """ Context manager with threading lock for set locale on enter, and set it back to original state on exit. diff --git a/src/aioftp/errors.py b/src/aioftp/errors.py index 25af25c..a80e73a 100644 --- a/src/aioftp/errors.py +++ b/src/aioftp/errors.py @@ -1,5 +1,11 @@ +import typing +from typing import Any, Iterable, Optional, Tuple, Union + from . import common +if typing.TYPE_CHECKING: + from . import Code + __all__ = ( "AIOFTPException", "StatusCodeError", @@ -41,9 +47,14 @@ class StatusCodeError(AIOFTPException): Exception members are tuples, even for one code. """ - def __init__(self, expected_codes, received_codes, info): + def __init__( + self, + expected_codes: Union[Tuple["Code", ...], "Code"], + received_codes: Union[Tuple["Code", ...], "Code"], + info: Iterable[str], + ) -> None: super().__init__( - f"Waiting for {expected_codes} but got " f"{received_codes} {info!r}", + f"Waiting for {expected_codes} but got {received_codes} {info!r}", ) self.expected_codes = common.wrap_with_container(expected_codes) self.received_codes = common.wrap_with_container(received_codes) @@ -72,7 +83,7 @@ class PathIOError(AIOFTPException): ... # handle """ - def __init__(self, *args, reason=None, **kwargs): + def __init__(self, *args: Any, reason: Optional[Any] = None, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.reason = reason diff --git a/src/aioftp/pathio.py b/src/aioftp/pathio.py index 8335ea6..9a6248d 100644 --- a/src/aioftp/pathio.py +++ b/src/aioftp/pathio.py @@ -1,13 +1,36 @@ import abc import asyncio -import collections import functools import io import operator +import os import pathlib import stat import sys import time +import typing +from concurrent import futures +from typing import ( + Any, + Awaitable, + BinaryIO, + Callable, + Coroutine, + Dict, + Generator, + List, + Literal, + NamedTuple, + Optional, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) + +from typing_extensions import ParamSpec, Protocol, Self, TypeAlias, override from . import errors from .common import ( @@ -16,6 +39,11 @@ AsyncStreamIterator, with_timeout, ) +from .types import OpenMode, StatsProtocol +from .utils import get_param + +if typing.TYPE_CHECKING: + from .server import Connection __all__ = ( "AbstractPathIO", @@ -25,6 +53,36 @@ "PathIONursery", ) +_Ps = ParamSpec("_Ps") +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) + + +class _DirNodeProtocol(Protocol): + type: Literal["dir"] + name: str + ctime: int + mtime: int + content: List[Union["_FileNodeProtocol", "_DirNodeProtocol"]] + + +class _FileNodeProtocol(Protocol): + type: Literal["file"] + name: str + ctime: int + mtime: int + content: io.BytesIO + + +_NodeProtocol: TypeAlias = Union[_DirNodeProtocol, _FileNodeProtocol] +_NodeType: TypeAlias = Literal["file", "dir"] + + +class SupportsNext(Protocol[_T_co]): + @abc.abstractmethod + def __next__(self) -> _T_co: + raise NotImplementedError + class AsyncPathIOContext: """ @@ -47,32 +105,34 @@ class AsyncPathIOContext: """ - def __init__(self, pathio, args, kwargs): + close: Optional["functools.partial[Awaitable[None]]"] + + def __init__(self, pathio: "AbstractPathIO", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None: self.close = None self.pathio = pathio self.args = args self.kwargs = kwargs - async def __aenter__(self): - self.file = await self.pathio._open(*self.args, **self.kwargs) + async def __aenter__(self) -> Self: + self.file = await self.pathio._open(*self.args, **self.kwargs) # pyright: ignore[reportPrivateUsage] self.seek = functools.partial(self.pathio.seek, self.file) self.write = functools.partial(self.pathio.write, self.file) self.read = functools.partial(self.pathio.read, self.file) self.close = functools.partial(self.pathio.close, self.file) return self - async def __aexit__(self, *args): + async def __aexit__(self, *args: object) -> None: if self.close is not None: await self.close() - def __await__(self): + def __await__(self) -> Generator[None, None, "AsyncPathIOContext"]: return self.__aenter__().__await__() - def iter_by_block(self, count=DEFAULT_BLOCK_SIZE): + def iter_by_block(self, count: int = DEFAULT_BLOCK_SIZE) -> AsyncStreamIterator[bytes]: return AsyncStreamIterator(lambda: self.read(count)) -def universal_exception(coro): +def universal_exception(coro: Callable[_Ps, Coroutine[None, None, _T]]) -> Callable[_Ps, Coroutine[None, None, _T]]: """ Decorator. Reraising any exception (except `CancelledError` and `NotImplementedError`) with universal exception @@ -80,7 +140,7 @@ def universal_exception(coro): """ @functools.wraps(coro) - async def wrapper(*args, **kwargs): + async def wrapper(*args: _Ps.args, **kwargs: _Ps.kwargs) -> _T: try: return await coro(*args, **kwargs) except ( @@ -96,30 +156,38 @@ async def wrapper(*args, **kwargs): class PathIONursery: - def __init__(self, factory): + state: Optional[Union[List[_NodeProtocol], io.BytesIO]] + + def __init__(self, factory: Type["AbstractPathIO"]): self.factory = factory self.state = None - def __call__(self, *args, **kwargs): - instance = self.factory(*args, state=self.state, **kwargs) + def __call__(self, *args: Any, **kwargs: Any) -> "AbstractPathIO": + instance = self.factory(*args, state=self.state, **kwargs) # type: ignore if self.state is None: self.state = instance.state return instance -def defend_file_methods(coro): +def defend_file_methods( + coro: Callable[_Ps, Coroutine[None, None, _T]], +) -> Callable[_Ps, Coroutine[None, None, _T]]: """ Decorator. Raises exception when file methods called with wrapped by :py:class:`aioftp.AsyncPathIOContext` file object. """ @functools.wraps(coro) - async def wrapper(self, file, *args, **kwargs): + async def wrapper( + *args: _Ps.args, + **kwargs: _Ps.kwargs, + ) -> _T: + file = get_param((1, "file"), args, kwargs) if isinstance(file, AsyncPathIOContext): raise ValueError( - "Native path io file methods can not be used " "with wrapped file object", + "Native path io file methods can not be used with wrapped file object", ) - return await coro(self, file, *args, **kwargs) + return await coro(*args, **kwargs) return wrapper @@ -137,19 +205,25 @@ class AbstractPathIO(abc.ABC): :param state: shared pathio state per server """ - def __init__(self, timeout=None, connection=None, state=None): + def __init__( + self, + timeout: Optional[Union[float, int]] = None, + connection: Optional["Connection"] = None, + state: Optional[List[_NodeProtocol]] = None, + ): self.timeout = timeout self.connection = connection @property - def state(self): + def state(self) -> Optional[Union[List[_NodeProtocol], io.BytesIO]]: """ Shared pathio state per server """ + return None @universal_exception @abc.abstractmethod - async def exists(self, path): + async def exists(self, path: pathlib.Path) -> bool: """ :py:func:`asyncio.coroutine` @@ -163,7 +237,7 @@ async def exists(self, path): @universal_exception @abc.abstractmethod - async def is_dir(self, path): + async def is_dir(self, path: pathlib.Path) -> bool: """ :py:func:`asyncio.coroutine` @@ -177,7 +251,7 @@ async def is_dir(self, path): @universal_exception @abc.abstractmethod - async def is_file(self, path): + async def is_file(self, path: pathlib.Path) -> bool: """ :py:func:`asyncio.coroutine` @@ -191,7 +265,7 @@ async def is_file(self, path): @universal_exception @abc.abstractmethod - async def mkdir(self, path, *, parents=False, exist_ok=False): + async def mkdir(self, path: pathlib.Path, *, parents: bool = False, exist_ok: bool = False) -> None: """ :py:func:`asyncio.coroutine` @@ -209,7 +283,7 @@ async def mkdir(self, path, *, parents=False, exist_ok=False): @universal_exception @abc.abstractmethod - async def rmdir(self, path): + async def rmdir(self, path: pathlib.Path) -> None: """ :py:func:`asyncio.coroutine` @@ -221,7 +295,7 @@ async def rmdir(self, path): @universal_exception @abc.abstractmethod - async def unlink(self, path): + async def unlink(self, path: pathlib.Path) -> None: """ :py:func:`asyncio.coroutine` @@ -232,7 +306,7 @@ async def unlink(self, path): """ @abc.abstractmethod - def list(self, path): + def list(self, path: pathlib.Path) -> AbstractAsyncLister[Any]: """ Create instance of subclass of :py:class:`aioftp.AbstractAsyncLister`. You should subclass and implement `__anext__` method @@ -260,7 +334,7 @@ def list(self, path): @universal_exception @abc.abstractmethod - async def stat(self, path): + async def stat(self, path: pathlib.Path) -> StatsProtocol: """ :py:func:`asyncio.coroutine` @@ -276,7 +350,7 @@ async def stat(self, path): @universal_exception @abc.abstractmethod - async def _open(self, path, mode): + async def _open(self, path: pathlib.Path, mode: OpenMode) -> BinaryIO: """ :py:func:`asyncio.coroutine` @@ -295,7 +369,7 @@ async def _open(self, path, mode): :return: file-object """ - def open(self, *args, **kwargs): + def open(self, *args: Any, **kwargs: Any) -> AsyncPathIOContext: """ Create instance of :py:class:`aioftp.pathio.AsyncPathIOContext`, parameters passed to :py:meth:`aioftp.AbstractPathIO._open` @@ -307,7 +381,7 @@ def open(self, *args, **kwargs): @universal_exception @defend_file_methods @abc.abstractmethod - async def seek(self, file, offset, whence=io.SEEK_SET): + async def seek(self, file: BinaryIO, offset: int, whence: int = io.SEEK_SET) -> int: """ :py:func:`asyncio.coroutine` @@ -326,7 +400,7 @@ async def seek(self, file, offset, whence=io.SEEK_SET): @universal_exception @defend_file_methods @abc.abstractmethod - async def write(self, file, data): + async def write(self, file: BinaryIO, data: bytes) -> int: """ :py:func:`asyncio.coroutine` @@ -341,7 +415,7 @@ async def write(self, file, data): @universal_exception @defend_file_methods @abc.abstractmethod - async def read(self, file, block_size): + async def read(self, file: BinaryIO, block_size: int) -> bytes: """ :py:func:`asyncio.coroutine` @@ -358,7 +432,7 @@ async def read(self, file, block_size): @universal_exception @defend_file_methods @abc.abstractmethod - async def close(self, file): + async def close(self, file: BinaryIO) -> None: """ :py:func:`asyncio.coroutine` @@ -369,7 +443,7 @@ async def close(self, file): @universal_exception @abc.abstractmethod - async def rename(self, source, destination): + async def rename(self, source: pathlib.Path, destination: pathlib.Path) -> Optional[pathlib.Path]: """ :py:func:`asyncio.coroutine` @@ -388,36 +462,44 @@ class PathIO(AbstractPathIO): Blocking path io. Directly based on :py:class:`pathlib.Path` methods. """ + @override @universal_exception - async def exists(self, path): + async def exists(self, path: pathlib.Path) -> bool: return path.exists() + @override @universal_exception - async def is_dir(self, path): + async def is_dir(self, path: pathlib.Path) -> bool: return path.is_dir() + @override @universal_exception - async def is_file(self, path): + async def is_file(self, path: pathlib.Path) -> bool: return path.is_file() + @override @universal_exception - async def mkdir(self, path, *, parents=False, exist_ok=False): + async def mkdir(self, path: pathlib.Path, *, parents: bool = False, exist_ok: bool = False) -> None: return path.mkdir(parents=parents, exist_ok=exist_ok) + @override @universal_exception - async def rmdir(self, path): + async def rmdir(self, path: pathlib.Path) -> None: return path.rmdir() + @override @universal_exception - async def unlink(self, path): + async def unlink(self, path: pathlib.Path) -> None: return path.unlink() - def list(self, path): - class Lister(AbstractAsyncLister): - iter = None + @override + def list(self, path: pathlib.Path) -> AbstractAsyncLister[pathlib.Path]: + class Lister(AbstractAsyncLister[pathlib.Path]): + iter: Optional[Generator[pathlib.Path, None, None]] = None + @override @universal_exception - async def __anext__(self): + async def __anext__(self) -> pathlib.Path: if self.iter is None: self.iter = path.glob("*") try: @@ -427,45 +509,58 @@ async def __anext__(self): return Lister(timeout=self.timeout) + @override @universal_exception - async def stat(self, path): + async def stat(self, path: pathlib.Path) -> os.stat_result: return path.stat() + @override @universal_exception - async def _open(self, path, *args, **kwargs): - return path.open(*args, **kwargs) + async def _open(self, path: pathlib.Path, mode: OpenMode) -> BinaryIO: + return path.open(mode) + @override @universal_exception @defend_file_methods - async def seek(self, file, *args, **kwargs): + async def seek(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int: return file.seek(*args, **kwargs) + @override @universal_exception @defend_file_methods - async def write(self, file, *args, **kwargs): + async def write(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int: return file.write(*args, **kwargs) + @override @universal_exception @defend_file_methods - async def read(self, file, *args, **kwargs): + async def read(self, file: BinaryIO, *args: Any, **kwargs: Any) -> bytes: return file.read(*args, **kwargs) + @override @universal_exception @defend_file_methods - async def close(self, file): + async def close(self, file: BinaryIO) -> None: return file.close() + @override @universal_exception - async def rename(self, source, destination): + async def rename(self, source: pathlib.Path, destination: pathlib.Path) -> pathlib.Path: return source.rename(destination) -def _blocking_io(f): +def _blocking_io( + f: Callable[_Ps, _T], +) -> Callable[_Ps, Coroutine[None, None, _T]]: @functools.wraps(f) - async def wrapper(self, *args, **kwargs): + async def wrapper( + *args: _Ps.args, + **kwargs: _Ps.kwargs, + ) -> _T: + self: "AsyncPathIO" = get_param((0, "self"), args, kwargs) return await asyncio.get_running_loop().run_in_executor( self.executor, - functools.partial(f, self, *args, **kwargs), + functools.partial(f, *args, **kwargs), ) return wrapper @@ -482,130 +577,195 @@ class AsyncPathIO(AbstractPathIO): :type executor: :py:class:`concurrent.futures.Executor` """ - def __init__(self, *args, executor=None, **kwargs): + executor: Optional[futures.Executor] + + def __init__( + self, + *args: Any, + executor: Optional[futures.Executor] = None, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.executor = executor + @override @universal_exception @with_timeout @_blocking_io - def exists(self, path): + def exists(self, path: pathlib.Path) -> bool: return path.exists() + @override @universal_exception @with_timeout @_blocking_io - def is_dir(self, path): + def is_dir(self, path: pathlib.Path) -> bool: return path.is_dir() + @override @universal_exception @with_timeout @_blocking_io - def is_file(self, path): + def is_file(self, path: pathlib.Path) -> bool: return path.is_file() + @override @universal_exception @with_timeout @_blocking_io - def mkdir(self, path, *, parents=False, exist_ok=False): + def mkdir( + self, + path: pathlib.Path, + *, + parents: bool = False, + exist_ok: bool = False, + ) -> None: return path.mkdir(parents=parents, exist_ok=exist_ok) + @override @universal_exception @with_timeout @_blocking_io - def rmdir(self, path): + def rmdir(self, path: pathlib.Path) -> None: return path.rmdir() + @override @universal_exception @with_timeout @_blocking_io - def unlink(self, path): + def unlink(self, path: pathlib.Path) -> None: return path.unlink() - def list(self, path): - class Lister(AbstractAsyncLister): - iter = None + @override + def list(self, path: pathlib.Path) -> AbstractAsyncLister[pathlib.Path]: + class Lister(AbstractAsyncLister[pathlib.Path]): + iter: Optional[SupportsNext[pathlib.Path]] = None - def __init__(self, *args, executor=None, **kwargs): + def __init__(self, *args: Any, executor: Optional[futures.Executor] = None, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.executor = executor - def worker(self): + def worker(self) -> pathlib.Path: try: - return next(self.iter) + return next(self.iter) # type: ignore except StopIteration: raise StopAsyncIteration + @override @universal_exception @with_timeout @_blocking_io - def __anext__(self): + def __anext__(self) -> pathlib.Path: if self.iter is None: self.iter = path.glob("*") return self.worker() return Lister(timeout=self.timeout, executor=self.executor) + @override @universal_exception @with_timeout @_blocking_io - def stat(self, path): + def stat(self, path: pathlib.Path) -> os.stat_result: return path.stat() + @override @universal_exception @with_timeout @_blocking_io - def _open(self, path, *args, **kwargs): - return path.open(*args, **kwargs) + def _open(self, path: pathlib.Path, *args: Any, **kwargs: Any) -> BinaryIO: + return cast(BinaryIO, path.open(*args, **kwargs)) + @override @universal_exception @defend_file_methods @with_timeout @_blocking_io - def seek(self, file, *args, **kwargs): + def seek(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int: return file.seek(*args, **kwargs) + @override @universal_exception @defend_file_methods @with_timeout @_blocking_io - def write(self, file, *args, **kwargs): + def write(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int: return file.write(*args, **kwargs) + @override @universal_exception @defend_file_methods @with_timeout @_blocking_io - def read(self, file, *args, **kwargs): + def read(self, file: BinaryIO, *args: Any, **kwargs: Any) -> bytes: return file.read(*args, **kwargs) + @override @universal_exception @defend_file_methods @with_timeout @_blocking_io - def close(self, file): + def close(self, file: BinaryIO) -> None: return file.close() + @override @universal_exception @with_timeout @_blocking_io - def rename(self, source, destination): + def rename(self, source: pathlib.Path, destination: pathlib.Path) -> pathlib.Path: return source.rename(destination) class Node: - def __init__(self, type, name, ctime=None, mtime=None, *, content): + type: _NodeType + name: str + ctime: int + mtime: int + content: Union[List["Node"], io.BytesIO] + + @overload + def __init__( + self, + type: Literal["file"], + name: str, + ctime: Optional[int] = None, + mtime: Optional[int] = None, + *, + content: io.BytesIO, + ) -> None: ... + + @overload + def __init__( + self, + type: Literal["dir"], + name: str, + ctime: Optional[int] = None, + mtime: Optional[int] = None, + *, + content: List["Node"], + ) -> None: ... + + def __init__( + self, + type: _NodeType, + name: str, + ctime: Optional[int] = None, + mtime: Optional[int] = None, + *, + content: Union[List["Node"], io.BytesIO], + ) -> None: self.type = type self.name = name self.ctime = ctime or int(time.time()) self.mtime = mtime or int(time.time()) self.content = content - def __repr__(self): + @override + def __repr__(self) -> str: return ( f"{self.__class__.__name__}(type={self.type!r}, " f"name={self.name!r}, ctime={self.ctime!r}, " - f"mtime={self.mtime!r}, content={self.content!r})" + f"mtime={self.mtime!r}, content={self.content!r}" ) @@ -615,132 +775,159 @@ class MemoryPathIO(AbstractPathIO): and probably not so fast as it can be. """ - Stats = collections.namedtuple( - "Stats", - ( - "st_size", - "st_ctime", - "st_mtime", - "st_nlink", - "st_mode", - ), - ) - - def __init__(self, *args, state=None, cwd=None, **kwargs): + class Stats(NamedTuple): + st_size: int + st_ctime: int + st_mtime: int + st_nlink: int + st_mode: int + + fs: Union[List[_NodeProtocol], io.BytesIO] + + def __init__( + self, + *args: Any, + state: Optional[Union[List[_NodeProtocol], io.BytesIO]] = None, + cwd: Optional[str] = None, + **kwargs: Any, + ) -> None: super().__init__(*args, **kwargs) self.cwd = pathlib.PurePosixPath(cwd or "/") if state is None: - self.fs = [Node("dir", "/", content=[])] + self.fs = [cast(_DirNodeProtocol, Node("dir", "/", content=[]))] else: self.fs = state @property - def state(self): + @override + def state(self) -> Union[List[_NodeProtocol], io.BytesIO]: return self.fs - def __repr__(self): + @override + def __repr__(self) -> str: return repr(self.fs) - def _absolute(self, path): + def _absolute(self, path: pathlib.PurePath) -> pathlib.PurePath: if not path.is_absolute(): path = self.cwd / path return path - def get_node(self, path): - nodes = self.fs + def get_node(self, path: pathlib.PurePath) -> Optional[_NodeProtocol]: + nodes: Union[List[_NodeProtocol], io.BytesIO] = self.fs node = None - path = self._absolute(path) - for part in path.parts: + path_ = self._absolute(path) + for part in path_.parts: if not isinstance(nodes, list): - return + return None for node in nodes: if node.name == part: nodes = node.content break else: - return + return None return node + @override @universal_exception - async def exists(self, path): + async def exists(self, path: pathlib.Path) -> bool: return self.get_node(path) is not None + @override @universal_exception - async def is_dir(self, path): + async def is_dir(self, path: pathlib.Path) -> bool: node = self.get_node(path) return not (node is None or node.type != "dir") + @override @universal_exception - async def is_file(self, path): + async def is_file(self, path: pathlib.Path) -> bool: node = self.get_node(path) return not (node is None or node.type != "file") + @override @universal_exception - async def mkdir(self, path, *, parents=False, exist_ok=False): - path = self._absolute(path) - node = self.get_node(path) + async def mkdir( + self, + path: pathlib.Path, + *, + parents: bool = False, + exist_ok: bool = False, + ) -> None: + path_ = self._absolute(path) + node = self.get_node(path_) if node: if node.type != "dir" or not exist_ok: raise FileExistsError elif not parents: - parent = self.get_node(path.parent) + parent = self.get_node(path_.parent) if parent is None: raise FileNotFoundError if parent.type != "dir": raise NotADirectoryError - node = Node("dir", path.name, content=[]) + node = cast(_DirNodeProtocol, Node("dir", path.name, content=[])) parent.content.append(node) else: - nodes = self.fs - for part in path.parts: + nodes: Union[List[_NodeProtocol], io.BytesIO] = self.fs + for part in path_.parts: if isinstance(nodes, list): for node in nodes: if node.name == part: nodes = node.content break else: - node = Node("dir", part, content=[]) + node = cast(_DirNodeProtocol, Node("dir", part, content=[])) nodes.append(node) nodes = node.content else: raise NotADirectoryError + @override @universal_exception - async def rmdir(self, path): + async def rmdir(self, path: pathlib.Path) -> None: node = self.get_node(path) if node is None: raise FileNotFoundError if node.type != "dir": raise NotADirectoryError + if node.content: raise OSError("Directory not empty") - parent = self.get_node(path.parent) + parent = cast(_DirNodeProtocol, self.get_node(path.parent)) for i, node in enumerate(parent.content): if node.name == path.name: break + else: + return None + parent.content.pop(i) + @override @universal_exception - async def unlink(self, path): + async def unlink(self, path: pathlib.Path) -> None: node = self.get_node(path) if node is None: raise FileNotFoundError if node.type != "file": raise IsADirectoryError - parent = self.get_node(path.parent) + parent = cast(_DirNodeProtocol, self.get_node(path.parent)) for i, node in enumerate(parent.content): if node.name == path.name: break + else: + return None + parent.content.pop(i) - def list(self, path): - class Lister(AbstractAsyncLister): - iter = None + @override + def list(self, path: pathlib.Path) -> "AbstractAsyncLister[pathlib.Path]": + class Lister(AbstractAsyncLister[pathlib.Path]): + iter: Optional[SupportsNext[pathlib.Path]] = None + @override @universal_exception - async def __anext__(cls): + async def __anext__(cls) -> pathlib.Path: if cls.iter is None: node = self.get_node(path) if node is None or node.type != "dir": @@ -756,8 +943,9 @@ async def __anext__(cls): return Lister(timeout=self.timeout) + @override @universal_exception - async def stat(self, path): + async def stat(self, path: pathlib.Path) -> "MemoryPathIO.Stats": node = self.get_node(path) if node is None: raise FileNotFoundError @@ -776,21 +964,22 @@ async def stat(self, path): mode, ) + @override @universal_exception - async def _open(self, path, mode="rb", *args, **kwargs): + async def _open(self, path: pathlib.Path, mode: OpenMode = "rb", *args: Any, **kwargs: Any) -> io.BytesIO: if mode == "rb": - node = self.get_node(path) + node = cast(Optional[_FileNodeProtocol], self.get_node(path)) if node is None: raise FileNotFoundError file_like = node.content file_like.seek(0, io.SEEK_SET) elif mode in ("wb", "ab", "r+b"): - node = self.get_node(path) + node = cast(Optional[_FileNodeProtocol], self.get_node(path)) if node is None: parent = self.get_node(path.parent) if parent is None or parent.type != "dir": raise FileNotFoundError - new_node = Node("file", path.name, content=io.BytesIO()) + new_node = cast(_FileNodeProtocol, Node("file", path.name, content=io.BytesIO())) parent.content.append(new_node) file_like = new_node.content elif node.type != "file": @@ -806,36 +995,42 @@ async def _open(self, path, mode="rb", *args, **kwargs): file_like.seek(0, io.SEEK_SET) else: raise ValueError(f"invalid mode: {mode}") - return file_like + return file_like # type: ignore + @override @universal_exception @defend_file_methods - async def seek(self, file, *args, **kwargs): + async def seek(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int: return file.seek(*args, **kwargs) + @override @universal_exception @defend_file_methods - async def write(self, file, *args, **kwargs): - file.write(*args, **kwargs) - file.mtime = int(time.time()) + async def write(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int: + result = file.write(*args, **kwargs) + file.mtime = int(time.time()) # type: ignore + return result + @override @universal_exception @defend_file_methods - async def read(self, file, *args, **kwargs): + async def read(self, file: BinaryIO, *args: Any, **kwargs: Any) -> bytes: return file.read(*args, **kwargs) + @override @universal_exception @defend_file_methods - async def close(self, file): + async def close(self, file: BinaryIO) -> None: pass + @override @universal_exception - async def rename(self, source, destination): + async def rename(self, source: pathlib.Path, destination: pathlib.Path) -> Optional[pathlib.Path]: if source != destination: - sparent = self.get_node(source.parent) - dparent = self.get_node(destination.parent) + sparent = cast(_DirNodeProtocol, self.get_node(source.parent)) + dparent = cast(Optional[_DirNodeProtocol], self.get_node(destination.parent)) snode = self.get_node(source) - if None in (snode, dparent): + if snode is None or dparent is None: raise FileNotFoundError for i, node in enumerate(sparent.content): if node.name == source.name: @@ -847,3 +1042,4 @@ async def rename(self, source, destination): break else: dparent.content.append(snode) + return None diff --git a/src/aioftp/server.py b/src/aioftp/server.py index 1830d78..a16f421 100644 --- a/src/aioftp/server.py +++ b/src/aioftp/server.py @@ -1,6 +1,5 @@ import abc import asyncio -import collections import enum import errno import functools @@ -8,19 +7,46 @@ import pathlib import socket import stat -import sys import time +from ssl import SSLContext +from typing import ( + Any, + Callable, + ClassVar, + Coroutine, + DefaultDict, + Dict, + Generic, + Iterable, + List, + NamedTuple, + Optional, + Protocol, + Sequence, + Set, + Tuple, + Type, + TypeVar, + Union, + cast, + final, +) + +from typing_extensions import Literal, NotRequired, ParamSpec, Self, TypeAlias, TypedDict, override -from . import errors, pathio +from . import client, errors, pathio from .common import ( DEFAULT_BLOCK_SIZE, END_OF_LINE, HALF_OF_YEAR_IN_SECONDS, + StreamIO, StreamThrottle, ThrottleStreamIO, setlocale, wrap_with_container, ) +from .types import OpenMode, StatsProtocol +from .utils import get_param __all__ = ( "Permission", @@ -36,14 +62,69 @@ "Server", ) -IS_PY37_PLUS = sys.version_info[:2] >= (3, 7) -if IS_PY37_PLUS: - get_current_task = asyncio.current_task -else: - get_current_task = asyncio.Task.current_task logger = logging.getLogger(__name__) +_PS = ParamSpec("_PS") +_T = TypeVar("_T") + +_PathType: TypeAlias = Union[str, pathlib.PurePosixPath] + + +_FAIL_CODE: TypeAlias = Literal["503", "425"] + + +class _ConnectionCondition(NamedTuple): + name: str + message: str + + +_Future: TypeAlias = "asyncio.Future[Any]" + + +class ConnectionProtocol(Protocol): + extra_workers: Set["asyncio.Task[Any]"] + client_host: str + server_host: str + passive_server_port: int + server_port: int + client_port: int + command_connection: ThrottleStreamIO + data_connection: ThrottleStreamIO + socket_timeout: int + wait_future_timeout: int + block_size: int + path_io_factory: pathio.AbstractPathIO + path_timeout: Optional[Union[int, float]] + user: "User" + response: Callable[..., None] + acquired: bool + restart_offset: int + current_directory: pathlib.PurePosixPath + _dispatcher: "asyncio.Task[Any]" + path_io: pathio.AbstractPathIO + passive_server: asyncio.base_events.Server + rename_from: pathlib.Path + logged: object + transfer_type: str + + future: "Connection.Container[Any]" + + def __getitem__(self, name: str) -> _Future: ... + + +class _PathCondition(NamedTuple): + name: str + fail: bool + message: str + + +class _MLSXFacts(TypedDict): + Size: int + Create: str + Modify: str + Type: NotRequired[Literal["dir", "file", "unknown"]] + class Permission: """ @@ -59,19 +140,26 @@ class Permission: :type writable: :py:class:`bool` """ - def __init__(self, path="/", *, readable=True, writable=True): + def __init__( + self, + path: Union[str, pathlib.PurePosixPath] = "/", + *, + readable: bool = True, + writable: bool = True, + ) -> None: self.path = pathlib.PurePosixPath(path) self.readable = readable self.writable = writable - def is_parent(self, other): + def is_parent(self, other: pathlib.PurePosixPath) -> bool: try: other.relative_to(self.path) return True except ValueError: return False - def __repr__(self): + @override + def __repr__(self) -> str: return f"{self.__class__.__name__}({self.path!r}, " f"readable={self.readable!r}, writable={self.writable!r})" @@ -116,18 +204,18 @@ class User: def __init__( self, - login=None, - password=None, + login: Optional[str] = None, + password: Optional[str] = None, *, - base_path=pathlib.Path("."), - home_path=pathlib.PurePosixPath("/"), - permissions=None, - maximum_connections=None, - read_speed_limit=None, - write_speed_limit=None, - read_speed_limit_per_connection=None, - write_speed_limit_per_connection=None, - ): + base_path: Union[str, pathlib.Path] = pathlib.Path("."), + home_path: Union[str, pathlib.PurePosixPath] = pathlib.PurePosixPath("/"), + permissions: Optional[Iterable[Permission]] = None, + maximum_connections: Optional[int] = None, + read_speed_limit: Optional[int] = None, + write_speed_limit: Optional[int] = None, + read_speed_limit_per_connection: Optional[int] = None, + write_speed_limit_per_connection: Optional[int] = None, + ) -> None: self.login = login self.password = password self.base_path = pathlib.Path(base_path) @@ -142,7 +230,7 @@ def __init__( # damn 80 symbols self.write_speed_limit_per_connection = write_speed_limit_per_connection - async def get_permissions(self, path): + async def get_permissions(self, path: Union[str, pathlib.PurePosixPath]) -> Permission: """ Return nearest parent permission for `path`. @@ -151,16 +239,17 @@ async def get_permissions(self, path): :rtype: :py:class:`aioftp.Permission` """ - path = pathlib.PurePosixPath(path) - parents = filter(lambda p: p.is_parent(path), self.permissions) + path_ = pathlib.PurePosixPath(path) + parents = filter(lambda p: p.is_parent(path_), self.permissions) perm = min( parents, - key=lambda p: len(path.relative_to(p.path).parts), + key=lambda p: len(path_.relative_to(p.path).parts), default=Permission(), ) return perm - def __repr__(self): + @override + def __repr__(self) -> str: return ( f"{self.__class__.__name__}({self.login!r}, " f"{self.password!r}, base_path={self.base_path!r}, " @@ -184,16 +273,16 @@ class AbstractUserManager(abc.ABC): :type timeout: :py:class:`float`, :py:class:`int` or :py:class:`None` """ - GetUserResponse = enum.Enum( - "UserManagerResponse", - "OK PASSWORD_REQUIRED ERROR", - ) + class GetUserResponse(enum.Enum): + OK = enum.auto() + PASSWORD_REQUIRED = enum.auto() + ERROR = enum.auto() - def __init__(self, *, timeout=None): + def __init__(self, *, timeout: Optional[Union[float, int]] = None) -> None: self.timeout = timeout @abc.abstractmethod - async def get_user(self, login): + async def get_user(self, login: str) -> Tuple[GetUserResponse, Optional[User], str]: """ :py:func:`asyncio.coroutine` @@ -204,7 +293,7 @@ async def get_user(self, login): """ @abc.abstractmethod - async def authenticate(self, user, password): + async def authenticate(self, user: User, password: str) -> bool: """ :py:func:`asyncio.coroutine` @@ -219,7 +308,7 @@ async def authenticate(self, user, password): :rtype: :py:class:`bool` """ - async def notify_logout(self, user): + async def notify_logout(self, user: User) -> None: """ :py:func:`asyncio.coroutine` @@ -239,12 +328,13 @@ class MemoryUserManager(AbstractUserManager): :py:class:`aioftp.User` """ - def __init__(self, users, *args, **kwargs): + def __init__(self, users: Optional[Sequence[User]], *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.users = users or [User()] self.available_connections = dict((user, AvailableConnections(user.maximum_connections)) for user in self.users) - async def get_user(self, login): + @override + async def get_user(self, login: str) -> Tuple[AbstractUserManager.GetUserResponse, Optional[User], str]: user = None for u in self.users: if u.login is None and user is None: @@ -269,17 +359,19 @@ async def get_user(self, login): info = "password required" if state != AbstractUserManager.GetUserResponse.ERROR: - self.available_connections[user].acquire() + self.available_connections[user].acquire() # type: ignore return state, user, info - async def authenticate(self, user, password): + @override + async def authenticate(self, user: User, password: str) -> bool: return user.password == password - async def notify_logout(self, user): + @override + async def notify_logout(self, user: User) -> None: self.available_connections[user].release() -class Connection(collections.defaultdict): +class Connection(DefaultDict[str, _Future]): """ Connection state container for transparent work with futures for async wait @@ -311,37 +403,40 @@ class Connection(collections.defaultdict): __slots__ = ("future",) - class Container: - def __init__(self, storage): + class Container(Generic[_T]): + def __init__(self, storage: Dict[str, "asyncio.Future[_T]"]): self.storage = storage - def __getattr__(self, name): + def __getattr__(self, name: str) -> "asyncio.Future[_T]": return self.storage[name] - def __delattr__(self, name): + @override + def __delattr__(self, name: str) -> None: self.storage.pop(name) - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: super().__init__(asyncio.Future) - self.future = Connection.Container(self) + self.future = Connection.Container[Any](self) for k, v in kwargs.items(): self[k].set_result(v) - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: if name in self: return self[name].result() else: raise AttributeError(f"{name!r} not in storage") - def __setattr__(self, name, value): + @override + def __setattr__(self, name: str, value: Any) -> None: if name in Connection.__slots__: super().__setattr__(name, value) else: if self[name].done(): - self[name] = super().default_factory() + self[name] = cast(Callable[[], _Future], super().default_factory)() self[name].set_result(value) - def __delattr__(self, name): + @override + def __delattr__(self, name: str) -> None: if name in self: self.pop(name) @@ -355,10 +450,10 @@ class AvailableConnections: :type value: :py:class:`int` or :py:class:`None` """ - def __init__(self, value=None): + def __init__(self, value: Optional[int] = None): self.value = self.maximum_value = value - def locked(self): + def locked(self) -> bool: """ Returns True if semaphore-like can not be acquired. @@ -366,7 +461,7 @@ def locked(self): """ return self.value == 0 - def acquire(self): + def acquire(self) -> None: """ Acquire, decrementing the internal counter by one. """ @@ -375,13 +470,13 @@ def acquire(self): if self.value < 0: raise ValueError("Too many acquires") - def release(self): + def release(self) -> None: """ Release, incrementing the internal counter by one. """ if self.value is not None: self.value += 1 - if self.value > self.maximum_value: + if self.value > cast(int, self.maximum_value): # self.maximum_value exist if exist self.value raise ValueError("Too many releases") @@ -424,27 +519,59 @@ class ConnectionConditions: ... ... """ - user_required = ("user", "no user (use USER firstly)") - login_required = ("logged", "not logged in") - passive_server_started = ( + user_required: ClassVar[_ConnectionCondition] = _ConnectionCondition("user", "no user (use USER firstly)") + login_required: ClassVar[_ConnectionCondition] = _ConnectionCondition("logged", "not logged in") + passive_server_started: ClassVar[_ConnectionCondition] = _ConnectionCondition( "passive_server", "no listen socket created (use PASV firstly)", ) - data_connection_made = ("data_connection", "no data connection made") - rename_from_required = ("rename_from", "no filename (use RNFR firstly)") + data_connection_made: ClassVar[_ConnectionCondition] = _ConnectionCondition( + "data_connection", + "no data connection made", + ) + rename_from_required: ClassVar[_ConnectionCondition] = _ConnectionCondition( + "rename_from", + "no filename (use RNFR firstly)", + ) - def __init__(self, *fields, wait=False, fail_code="503", fail_info=None): + def __init__( + self, + *fields: _ConnectionCondition, + wait: bool = False, + fail_code: _FAIL_CODE = "503", + fail_info: Optional[str] = None, + ) -> None: self.fields = fields self.wait = wait self.fail_code = fail_code self.fail_info = fail_info - def __call__(self, f): + def __call__( + self, + f: Callable[_PS, Coroutine[None, None, _T]], + ) -> Callable[ + _PS, + Coroutine[ + None, + None, + Union[ + _T, + Literal[True], + ], + ], + ]: + ctx = self + @functools.wraps(f) - async def wrapper(cls, connection, rest, *args): - futures = {connection[name]: msg for name, msg in self.fields} + async def wrapper( + *args: _PS.args, + **kwargs: _PS.kwargs, + ) -> Union[_T, Literal[True]]: + connection: ConnectionProtocol = get_param((1, "connection"), args, kwargs) + + futures: Dict[_Future, str] = {connection[name]: msg for name, msg in ctx.fields} aggregate = asyncio.gather(*futures) - if self.wait: + if ctx.wait: timeout = connection.wait_future_timeout else: timeout = 0 @@ -457,17 +584,18 @@ async def wrapper(cls, connection, rest, *args): except asyncio.TimeoutError: for future, message in futures.items(): if not future.done(): - if self.fail_info is None: + if ctx.fail_info is None: info = f"bad sequence of commands ({message})" else: - info = self.fail_info - connection.response(self.fail_code, info) + info = ctx.fail_info + connection.response(ctx.fail_code, info) return True - return await f(cls, connection, rest, *args) + return await f(*args, **kwargs) return wrapper +@final class PathConditions: """ Decorator for checking paths. Available options: @@ -486,28 +614,51 @@ class PathConditions: ... ... """ - path_must_exists = ("exists", False, "path does not exists") - path_must_not_exists = ("exists", True, "path already exists") - path_must_be_dir = ("is_dir", False, "path is not a directory") - path_must_be_file = ("is_file", False, "path is not a file") + path_must_exists: ClassVar[_PathCondition] = _PathCondition("exists", False, "path does not exists") + path_must_not_exists: ClassVar[_PathCondition] = _PathCondition("exists", True, "path already exists") + path_must_be_dir: ClassVar[_PathCondition] = _PathCondition("is_dir", False, "path is not a directory") + path_must_be_file: ClassVar[_PathCondition] = _PathCondition("is_file", False, "path is not a file") - def __init__(self, *conditions): + def __init__(self, *conditions: _PathCondition) -> None: self.conditions = conditions - def __call__(self, f): + def __call__( + self, + f: Callable[_PS, Coroutine[None, None, _T]], + ) -> Callable[ + _PS, + Coroutine[ + None, + None, + Union[ + _T, + Literal[True], + ], + ], + ]: + ctx = self + @functools.wraps(f) - async def wrapper(cls, connection, rest, *args): - real_path, virtual_path = cls.get_paths(connection, rest) - for name, fail, message in self.conditions: + async def wrapper( + *args: _PS.args, + **kwargs: _PS.kwargs, + ) -> Union[_T, Literal[True]]: + self: Server = get_param((0, "self"), args, kwargs) + connection: ConnectionProtocol = get_param((1, "connection"), args, kwargs) + rest: pathlib.PurePosixPath = get_param((2, "rest"), args, kwargs) + + real_path, _ = self.get_paths(connection, rest) + for name, fail, message in ctx.conditions: coro = getattr(connection.path_io, name) if await coro(real_path) == fail: connection.response("550", message) return True - return await f(cls, connection, rest, *args) + return await f(*args, **kwargs) return wrapper +@final class PathPermissions: """ Decorator for checking path permissions. There is two permissions right @@ -528,29 +679,52 @@ class PathPermissions: ... ... """ - readable = "readable" - writable = "writable" + readable: ClassVar[str] = "readable" + writable: ClassVar[str] = "writable" - def __init__(self, *permissions): + def __init__(self, *permissions: str) -> None: self.permissions = permissions - def __call__(self, f): + def __call__( + self, + f: Callable[_PS, Coroutine[None, None, _T]], + ) -> Callable[ + _PS, + Coroutine[ + None, + None, + Optional[ + Union[ + _T, + Literal[True], + ] + ], + ], + ]: + ctx = self + @functools.wraps(f) - async def wrapper(cls, connection, rest, *args): - real_path, virtual_path = cls.get_paths(connection, rest) - current_permission = await connection.user.get_permissions( - virtual_path, - ) - for permission in self.permissions: + async def wrapper( + *args: _PS.args, + **kwargs: _PS.kwargs, + ) -> Optional[Union[_T, Literal[True]]]: + self: Server = get_param((0, "self"), args, kwargs) + connection: ConnectionProtocol = get_param((1, "connection"), args, kwargs) + rest: pathlib.PurePosixPath = get_param((2, "rest"), args, kwargs) + + _, virtual_path = self.get_paths(connection, rest) + current_permission = await connection.user.get_permissions(virtual_path) + for permission in ctx.permissions: if not getattr(current_permission, permission): connection.response("550", "permission denied") return True - return await f(cls, connection, rest, *args) + return await f(*args, **kwargs) + return None return wrapper -def worker(f): +def worker(f: Callable[_PS, Coroutine[None, None, None]]) -> Callable[_PS, Coroutine[None, None, None]]: """ Decorator. Abortable worker. If wrapped task will be cancelled by dispatcher, decorator will send ftp codes of successful interrupt. @@ -564,12 +738,15 @@ def worker(f): """ @functools.wraps(f) - async def wrapper(cls, connection, rest): + async def wrapper(*args: _PS.args, **kwargs: _PS.kwargs) -> None: + connection: ConnectionProtocol = get_param((1, "connection"), args, kwargs) + try: - await f(cls, connection, rest) + await f(*args, **kwargs) except asyncio.CancelledError: connection.response("426", "transfer aborted") connection.response("226", "abort successful") + return None return wrapper @@ -641,25 +818,36 @@ class Server: :type ssl: :py:class:`ssl.SSLContext` """ + available_data_ports: Optional["asyncio.PriorityQueue[Tuple[int, int]]"] + throttle_per_user: Dict[User, StreamThrottle] + commands_mapping: Dict[str, Callable[..., Any]] + connections: Dict[StreamIO, ConnectionProtocol] + def __init__( self, - users=None, + users: Optional[ + Union[ + Tuple[User], + List[User], + AbstractUserManager, + ] + ] = None, *, - block_size=DEFAULT_BLOCK_SIZE, - socket_timeout=None, - idle_timeout=None, - wait_future_timeout=1, - path_timeout=None, - path_io_factory=pathio.PathIO, - maximum_connections=None, - read_speed_limit=None, - write_speed_limit=None, - read_speed_limit_per_connection=None, - write_speed_limit_per_connection=None, - ipv4_pasv_forced_response_address=None, - data_ports=None, - encoding="utf-8", - ssl=None, + block_size: int = DEFAULT_BLOCK_SIZE, + socket_timeout: Optional[Union[int, float]] = None, + idle_timeout: Optional[Union[int, float]] = None, + wait_future_timeout: Optional[Union[int, float]] = 1, + path_timeout: Optional[Union[int, float]] = None, + path_io_factory: Type[pathio.AbstractPathIO] = pathio.PathIO, + maximum_connections: Optional[int] = None, + read_speed_limit: Optional[int] = None, + write_speed_limit: Optional[int] = None, + read_speed_limit_per_connection: Optional[int] = None, + write_speed_limit_per_connection: Optional[int] = None, + ipv4_pasv_forced_response_address: Optional[str] = None, + data_ports: Optional[Iterable[int]] = None, + encoding: str = "utf-8", + ssl: Optional[SSLContext] = None, ): self.block_size = block_size self.socket_timeout = socket_timeout @@ -720,7 +908,7 @@ def __init__( "user": self.user, } - async def start(self, host=None, port=0, **kwargs): + async def start(self, host: Optional[str] = None, port: int = 0, **kwargs: Any) -> None: """ :py:func:`asyncio.coroutine` @@ -755,7 +943,7 @@ async def start(self, host=None, port=0, **kwargs): self.server_host = host logger.info("serving on %s:%s", host, port) - async def serve_forever(self): + async def serve_forever(self) -> None: """ :py:func:`asyncio.coroutine` @@ -763,7 +951,7 @@ async def serve_forever(self): """ return await self.server.serve_forever() - async def run(self, host=None, port=0, **kwargs): + async def run(self, host: Optional[str] = None, port: int = 0, **kwargs: Any) -> None: """ :py:func:`asyncio.coroutine` @@ -785,13 +973,13 @@ async def run(self, host=None, port=0, **kwargs): await self.close() @property - def address(self): + def address(self) -> Tuple[Optional[str], int]: """ Server listen socket host and port as :py:class:`tuple` """ return self.server_host, self.server_port - async def close(self): + async def close(self) -> None: """ :py:func:`asyncio.coroutine` @@ -800,16 +988,22 @@ async def close(self): self.server.close() tasks = [asyncio.create_task(self.server.wait_closed())] for connection in self.connections.values(): - connection._dispatcher.cancel() - tasks.append(connection._dispatcher) + connection._dispatcher.cancel() # pyright: ignore[reportPrivateUsage] + tasks.append(connection._dispatcher) # pyright: ignore[reportPrivateUsage] logger.debug("waiting for %d tasks", len(tasks)) await asyncio.wait(tasks) - async def write_line(self, stream, line): + async def write_line(self, stream: StreamIO, line: str) -> None: logger.debug(line) await stream.write((line + END_OF_LINE).encode(encoding=self.encoding)) - async def write_response(self, stream, code, lines="", list=False): + async def write_response( + self, + stream: StreamIO, + code: client.Code, + lines: Iterable[str] = "", + list: bool = False, + ) -> None: """ :py:func:`asyncio.coroutine` @@ -842,7 +1036,11 @@ async def write_response(self, stream, code, lines="", list=False): await write(code + "-" + line) await write(code + " " + tail) - async def parse_command(self, stream, censor_commands=("pass",)): + async def parse_command( + self, + stream: StreamIO, + censor_commands: Tuple[str] = ("pass",), + ) -> Tuple[str, str]: """ :py:func:`asyncio.coroutine` @@ -871,7 +1069,7 @@ async def parse_command(self, stream, censor_commands=("pass",)): return cmd.lower(), rest - async def response_writer(self, stream, response_queue): + async def response_writer(self, stream: StreamIO, response_queue: "asyncio.Queue[Any]") -> None: """ :py:func:`asyncio.coroutine` @@ -892,7 +1090,7 @@ async def response_writer(self, stream, response_queue): finally: response_queue.task_done() - async def dispatcher(self, reader, writer): + async def dispatcher(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: """ :py:func:`asyncio.coroutine` @@ -911,7 +1109,7 @@ async def dispatcher(self, reader, writer): read_timeout=self.idle_timeout, write_timeout=self.socket_timeout, ) - response_queue = asyncio.Queue() + response_queue: "asyncio.Queue[Any]" = asyncio.Queue() connection = Connection( client_host=host, client_port=port, @@ -926,21 +1124,22 @@ async def dispatcher(self, reader, writer): path_io_factory=self.path_io_factory, path_timeout=self.path_timeout, extra_workers=set(), - response=lambda *args: response_queue.put_nowait(args), + response=lambda *args: response_queue.put_nowait(args), # pyright: ignore[reportUnknownLambdaType] acquired=False, restart_offset=0, - _dispatcher=get_current_task(), + _dispatcher=asyncio.current_task(), ) connection.path_io = self.path_io_factory( timeout=self.path_timeout, connection=connection, ) + connection_: ConnectionProtocol = cast(Any, connection) pending = { - asyncio.create_task(self.greeting(connection, "")), + asyncio.create_task(self.greeting(connection_, "")), asyncio.create_task(self.response_writer(stream, response_queue)), asyncio.create_task(self.parse_command(stream)), } - self.connections[key] = connection + self.connections[key] = connection_ try: while True: done, pending = await asyncio.wait( @@ -949,6 +1148,7 @@ async def dispatcher(self, reader, writer): ) connection.extra_workers -= done for task in done: + result: Optional[Union[bool, Tuple[str, str]]] = None try: result = task.result() except errors.PathIOError: @@ -981,7 +1181,7 @@ async def dispatcher(self, reader, writer): logger.exception("dispatcher caught exception") finally: logger.info("closing connection from %s:%s", host, port) - tasks_to_wait = [] + tasks_to_wait: List[asyncio.Task[Any]] = [] if not asyncio.get_running_loop().is_closed(): for task in pending | connection.extra_workers: task.cancel() @@ -1006,7 +1206,7 @@ async def dispatcher(self, reader, writer): await asyncio.wait(tasks_to_wait) @staticmethod - def get_paths(connection, path): + def get_paths(connection: ConnectionProtocol, path: _PathType) -> Tuple[pathlib.Path, pathlib.PurePosixPath]: """ Return *real* and *virtual* paths, resolves ".." with "up" action. *Real* path is path for path_io, when *virtual* deals with @@ -1040,7 +1240,7 @@ def get_paths(connection, path): resolved_virtual_path = pathlib.PurePosixPath("/") return real_path, resolved_virtual_path - async def greeting(self, connection, rest): + async def greeting(self, connection: ConnectionProtocol, rest: object) -> bool: if self.available_connections.locked(): ok, code, info = False, "421", "Too many connections" else: @@ -1050,7 +1250,7 @@ async def greeting(self, connection, rest): connection.response(code, info) return ok - async def user(self, connection, rest): + async def user(self, connection: ConnectionProtocol, rest: str) -> Literal[True]: if connection.future.user.done(): await self.user_manager.notify_logout(connection.user) del connection.user @@ -1059,15 +1259,14 @@ async def user(self, connection, rest): if state == AbstractUserManager.GetUserResponse.OK: code = "230" connection.logged = True - connection.user = user + connection.user = cast(User, user) elif state == AbstractUserManager.GetUserResponse.PASSWORD_REQUIRED: code = "331" - connection.user = user + connection.user = cast(User, user) elif state == AbstractUserManager.GetUserResponse.ERROR: code = "530" else: - message = f"Unknown response {state}" - raise NotImplementedError(message) + raise NotImplementedError(f"Unknown response {state}") if connection.future.user.done(): connection.current_directory = connection.user.home_path @@ -1089,7 +1288,7 @@ async def user(self, connection, rest): return True @ConnectionConditions(ConnectionConditions.user_required) - async def pass_(self, connection, rest): + async def pass_(self, connection: ConnectionProtocol, rest: str) -> Literal[True]: if connection.future.logged.done(): code, info = "503", "already logged in" elif await self.user_manager.authenticate(connection.user, rest): @@ -1100,12 +1299,12 @@ async def pass_(self, connection, rest): connection.response(code, info) return True - async def quit(self, connection, rest): + async def quit(self, connection: ConnectionProtocol, rest: Any) -> Literal[False]: connection.response("221", "bye") return False @ConnectionConditions(ConnectionConditions.login_required) - async def pwd(self, connection, rest): + async def pwd(self, connection: ConnectionProtocol, rest: Any) -> Literal[True]: code, info = "257", f'"{connection.current_directory}"' connection.response(code, info) return True @@ -1116,21 +1315,21 @@ async def pwd(self, connection, rest): PathConditions.path_must_be_dir, ) @PathPermissions(PathPermissions.readable) - async def cwd(self, connection, rest): - real_path, virtual_path = self.get_paths(connection, rest) + async def cwd(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Literal[True]: + _, virtual_path = self.get_paths(connection, rest) connection.current_directory = virtual_path connection.response("250", "") return True @ConnectionConditions(ConnectionConditions.login_required) - async def cdup(self, connection, rest): + async def cdup(self, connection: ConnectionProtocol, rest: object) -> Optional[Literal[True]]: return await self.cwd(connection, connection.current_directory.parent) @ConnectionConditions(ConnectionConditions.login_required) @PathConditions(PathConditions.path_must_not_exists) @PathPermissions(PathPermissions.writable) - async def mkd(self, connection, rest): - real_path, virtual_path = self.get_paths(connection, rest) + async def mkd(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]: + real_path, _ = self.get_paths(connection, rest) await connection.path_io.mkdir(real_path, parents=True) connection.response("257", "") return True @@ -1141,24 +1340,25 @@ async def mkd(self, connection, rest): PathConditions.path_must_be_dir, ) @PathPermissions(PathPermissions.writable) - async def rmd(self, connection, rest): - real_path, virtual_path = self.get_paths(connection, rest) + async def rmd(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]: + real_path, _ = self.get_paths(connection, rest) await connection.path_io.rmdir(real_path) connection.response("250", "") return True @staticmethod - def _format_mlsx_time(local_seconds): + def _format_mlsx_time(local_seconds: float) -> str: return time.strftime("%Y%m%d%H%M%S", time.gmtime(local_seconds)) - def _build_mlsx_facts_from_stats(self, stats): + def _build_mlsx_facts_from_stats(self, stats: StatsProtocol) -> _MLSXFacts: return { "Size": stats.st_size, "Create": self._format_mlsx_time(stats.st_ctime), "Modify": self._format_mlsx_time(stats.st_mtime), } - async def build_mlsx_string(self, connection, path): + async def build_mlsx_string(self, connection: ConnectionProtocol, path: pathlib.Path) -> str: + facts: Union[Dict[Any, Any], _MLSXFacts] if not await connection.path_io.exists(path): facts = {} else: @@ -1172,7 +1372,7 @@ async def build_mlsx_string(self, connection, path): facts["Type"] = "unknown" s = "" - for name, value in facts.items(): + for name, value in cast(_MLSXFacts, facts).items(): s += f"{name}={value};" s += " " + path.name return s @@ -1183,7 +1383,7 @@ async def build_mlsx_string(self, connection, path): ) @PathConditions(PathConditions.path_must_exists) @PathPermissions(PathPermissions.readable) - async def mlsd(self, connection, rest): + async def mlsd(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]: @ConnectionConditions( ConnectionConditions.data_connection_made, wait=True, @@ -1191,7 +1391,7 @@ async def mlsd(self, connection, rest): fail_info="Can't open data connection", ) @worker - async def mlsd_worker(self, connection, rest): + async def mlsd_worker(self: Self, connection: ConnectionProtocol, rest: _PathType) -> None: stream = connection.data_connection del connection.data_connection async with stream: @@ -1200,17 +1400,16 @@ async def mlsd_worker(self, connection, rest): b = (s + END_OF_LINE).encode(encoding=self.encoding) await stream.write(b) connection.response("200", "mlsd transfer done") - return True - real_path, virtual_path = self.get_paths(connection, rest) + real_path, _ = self.get_paths(connection, rest) coro = mlsd_worker(self, connection, rest) - task = asyncio.create_task(coro) + task: asyncio.Task[Optional[Literal[True]]] = asyncio.create_task(coro) connection.extra_workers.add(task) connection.response("150", "mlsd transfer started") return True @staticmethod - def build_list_mtime(st_mtime, now=None): + def build_list_mtime(st_mtime: float, now: Optional[float] = None) -> str: if now is None: now = time.time() mtime = time.localtime(st_mtime) @@ -1221,7 +1420,7 @@ def build_list_mtime(st_mtime, now=None): s = time.strftime("%b %e %Y", mtime) return s - async def build_list_string(self, connection, path): + async def build_list_string(self, connection: ConnectionProtocol, path: pathlib.Path) -> str: stats = await connection.path_io.stat(path) mtime = self.build_list_mtime(stats.st_mtime) fields = ( @@ -1242,7 +1441,7 @@ async def build_list_string(self, connection, path): ) @PathConditions(PathConditions.path_must_exists) @PathPermissions(PathPermissions.readable) - async def list(self, connection, rest): + async def list(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Literal[True]: @ConnectionConditions( ConnectionConditions.data_connection_made, wait=True, @@ -1250,7 +1449,7 @@ async def list(self, connection, rest): fail_info="Can't open data connection", ) @worker - async def list_worker(self, connection, rest): + async def list_worker(self: Self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> None: stream = connection.data_connection del connection.data_connection async with stream: @@ -1262,11 +1461,10 @@ async def list_worker(self, connection, rest): b = (s + END_OF_LINE).encode(encoding=self.encoding) await stream.write(b) connection.response("226", "list transfer done") - return True - real_path, virtual_path = self.get_paths(connection, rest) + real_path, _ = self.get_paths(connection, rest) coro = list_worker(self, connection, rest) - task = asyncio.create_task(coro) + task: asyncio.Task[Optional[Literal[True]]] = asyncio.create_task(coro) connection.extra_workers.add(task) connection.response("150", "list transfer started") return True @@ -1274,8 +1472,8 @@ async def list_worker(self, connection, rest): @ConnectionConditions(ConnectionConditions.login_required) @PathConditions(PathConditions.path_must_exists) @PathPermissions(PathPermissions.readable) - async def mlst(self, connection, rest): - real_path, virtual_path = self.get_paths(connection, rest) + async def mlst(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Literal[True]: + real_path, _ = self.get_paths(connection, rest) s = await self.build_mlsx_string(connection, real_path) connection.response("250", ["start", s, "end"], True) return True @@ -1283,8 +1481,8 @@ async def mlst(self, connection, rest): @ConnectionConditions(ConnectionConditions.login_required) @PathConditions(PathConditions.path_must_exists) @PathPermissions(PathPermissions.writable) - async def rnfr(self, connection, rest): - real_path, virtual_path = self.get_paths(connection, rest) + async def rnfr(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]: + real_path, _ = self.get_paths(connection, rest) connection.rename_from = real_path connection.response("350", "rename from accepted") return True @@ -1295,8 +1493,8 @@ async def rnfr(self, connection, rest): ) @PathConditions(PathConditions.path_must_not_exists) @PathPermissions(PathPermissions.writable) - async def rnto(self, connection, rest): - real_path, virtual_path = self.get_paths(connection, rest) + async def rnto(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]: + real_path, _ = self.get_paths(connection, rest) rename_from = connection.rename_from del connection.rename_from await connection.path_io.rename(rename_from, real_path) @@ -1309,8 +1507,8 @@ async def rnto(self, connection, rest): PathConditions.path_must_be_file, ) @PathPermissions(PathPermissions.writable) - async def dele(self, connection, rest): - real_path, virtual_path = self.get_paths(connection, rest) + async def dele(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]: + real_path, _ = self.get_paths(connection, rest) await connection.path_io.unlink(real_path) connection.response("250", "") return True @@ -1320,7 +1518,12 @@ async def dele(self, connection, rest): ConnectionConditions.passive_server_started, ) @PathPermissions(PathPermissions.writable) - async def stor(self, connection, rest, mode="wb"): + async def stor( + self, + connection: ConnectionProtocol, + rest: _PathType, + mode: OpenMode = "wb", + ) -> Literal[True]: @ConnectionConditions( ConnectionConditions.data_connection_made, wait=True, @@ -1328,7 +1531,7 @@ async def stor(self, connection, rest, mode="wb"): fail_info="Can't open data connection", ) @worker - async def stor_worker(self, connection, rest): + async def stor_worker(self: Self, connection: ConnectionProtocol, rest: object) -> None: stream = connection.data_connection del connection.data_connection if connection.restart_offset: @@ -1341,17 +1544,16 @@ async def stor_worker(self, connection, rest): await file_out.seek(connection.restart_offset) async for data in stream.iter_by_block(connection.block_size): await file_out.write(data) - connection.response("226", "data transfer done") - return True + connection.response("226", ["data transfer done"]) - real_path, virtual_path = self.get_paths(connection, rest) + real_path, _ = self.get_paths(connection, rest) if await connection.path_io.is_dir(real_path.parent): coro = stor_worker(self, connection, rest) task = asyncio.create_task(coro) connection.extra_workers.add(task) - code, info = "150", "data transfer started" + code, info = "150", ["data transfer started"] else: - code, info = "550", "path unreachable" + code, info = "550", ["path unreachable"] connection.response(code, info) return True @@ -1364,7 +1566,7 @@ async def stor_worker(self, connection, rest): PathConditions.path_must_be_file, ) @PathPermissions(PathPermissions.readable) - async def retr(self, connection, rest): + async def retr(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Literal[True]: @ConnectionConditions( ConnectionConditions.data_connection_made, wait=True, @@ -1372,7 +1574,7 @@ async def retr(self, connection, rest): fail_info="Can't open data connection", ) @worker - async def retr_worker(self, connection, rest): + async def retr_worker(self: Self, connection: ConnectionProtocol, rest: object) -> None: stream = connection.data_connection del connection.data_connection file_in = connection.path_io.open(real_path, mode="rb") @@ -1382,9 +1584,8 @@ async def retr_worker(self, connection, rest): async for data in file_in.iter_by_block(connection.block_size): await stream.write(data) connection.response("226", "data transfer done") - return True - real_path, virtual_path = self.get_paths(connection, rest) + real_path, _ = self.get_paths(connection, rest) coro = retr_worker(self, connection, rest) task = asyncio.create_task(coro) connection.extra_workers.add(task) @@ -1392,7 +1593,7 @@ async def retr_worker(self, connection, rest): return True @ConnectionConditions(ConnectionConditions.login_required) - async def type(self, connection, rest): + async def type(self, connection: ConnectionProtocol, rest: str) -> Literal[True]: if rest in ("I", "A"): connection.transfer_type = rest code, info = "200", "" @@ -1402,12 +1603,12 @@ async def type(self, connection, rest): return True @ConnectionConditions(ConnectionConditions.login_required) - async def pbsz(self, connection, rest): + async def pbsz(self, connection: ConnectionProtocol, rest: object) -> Literal[True]: connection.response("200", "") return True @ConnectionConditions(ConnectionConditions.login_required) - async def prot(self, connection, rest): + async def prot(self, connection: ConnectionProtocol, rest: str) -> Literal[True]: if rest == "P": code, info = "200", "" else: @@ -1415,9 +1616,16 @@ async def prot(self, connection, rest): connection.response(code, info) return True - async def _start_passive_server(self, connection, handler_callback): + async def _start_passive_server( + self, + connection: ConnectionProtocol, + handler_callback: Callable[ + [asyncio.StreamReader, asyncio.StreamWriter], + Coroutine[None, None, None], + ], + ) -> asyncio.base_events.Server: if self.available_data_ports is not None: - viewed_ports = set() + viewed_ports: Set[int] = set() while True: try: priority, port = self.available_data_ports.get_nowait() @@ -1436,7 +1644,7 @@ async def _start_passive_server(self, connection, handler_callback): except asyncio.QueueEmpty: raise errors.NoAvailablePort except OSError as err: - self.available_data_ports.put_nowait((priority + 1, port)) + self.available_data_ports.put_nowait((priority + 1, port)) # type: ignore if err.errno != errno.EADDRINUSE: raise else: @@ -1447,11 +1655,11 @@ async def _start_passive_server(self, connection, handler_callback): ssl=self.ssl, **self._start_server_extra_arguments, ) - return passive_server + return passive_server # type: ignore @ConnectionConditions(ConnectionConditions.login_required) - async def pasv(self, connection, rest): - async def handler(reader, writer): + async def pasv(self, connection: ConnectionProtocol, rest: object) -> bool: + async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: if connection.future.data_connection.done(): writer.close() else: @@ -1486,7 +1694,7 @@ async def handler(reader, writer): connection.response("503", ["this server started in ipv6 mode"]) return False - nums = tuple(map(int, host.split("."))) + (port >> 8, port & 0xFF) + nums = tuple(map(int, host.split("."))) + (port >> 8, port & 0xFF) # type: ignore info.append(f"({','.join(map(str, nums))})") if connection.future.data_connection.done(): connection.data_connection.close() @@ -1495,8 +1703,8 @@ async def handler(reader, writer): return True @ConnectionConditions(ConnectionConditions.login_required) - async def epsv(self, connection, rest): - async def handler(reader, writer): + async def epsv(self, connection: ConnectionProtocol, rest: object) -> bool: + async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None: if connection.future.data_connection.done(): writer.close() else: @@ -1527,7 +1735,7 @@ async def handler(reader, writer): _, port, *_ = sock.getsockname() break - info[0] += f" (|||{port}|)" + info[0] += f" (|||{port}|)" # type: ignore if connection.future.data_connection.done(): connection.data_connection.close() del connection.data_connection @@ -1535,7 +1743,7 @@ async def handler(reader, writer): return True @ConnectionConditions(ConnectionConditions.login_required) - async def abor(self, connection, rest): + async def abor(self, connection: ConnectionProtocol, rest: object) -> Literal[True]: if connection.extra_workers: for worker in connection.extra_workers: worker.cancel() @@ -1543,10 +1751,10 @@ async def abor(self, connection, rest): connection.response("226", "nothing to abort") return True - async def appe(self, connection, rest): + async def appe(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Optional[Literal[True]]: return await self.stor(connection, rest, "ab") - async def rest(self, connection, rest): + async def rest(self, connection: ConnectionProtocol, rest: str) -> Literal[True]: if rest.isdigit(): connection.restart_offset = int(rest) connection.response("350", f"restarting at {rest}") @@ -1556,7 +1764,7 @@ async def rest(self, connection, rest): connection.response("501", message) return True - async def syst(self, connection, rest): + async def syst(self, connection: ConnectionProtocol, rest: object) -> Literal[True]: """Return system type (always returns UNIX type: L8).""" connection.response("215", "UNIX Type: L8") return True diff --git a/src/aioftp/types.py b/src/aioftp/types.py new file mode 100644 index 0000000..5ac8dd3 --- /dev/null +++ b/src/aioftp/types.py @@ -0,0 +1,55 @@ +from typing import TYPE_CHECKING, Any, Generator, Literal, NewType, Protocol, Tuple, TypeVar + +from typing_extensions import TypeAlias + +if TYPE_CHECKING: + from .client import Code + +_T_co = TypeVar("_T_co", covariant=True) + +__all__ = ("NotEmptyCodes", "check_not_empty_codes") + + +class StatsProtocol(Protocol): + @property + def st_size(self) -> int: + raise NotImplementedError + + @property + def st_ctime(self) -> float: + raise NotImplementedError + + @property + def st_mtime(self) -> float: + raise NotImplementedError + + @property + def st_nlink(self) -> int: + raise NotImplementedError + + @property + def st_mode(self) -> int: + raise NotImplementedError + + +class AsyncEnterableProtocol(Protocol[_T_co]): + async def __aenter__(self) -> _T_co: + raise NotImplementedError + + async def __aexit__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + + def __await__(self) -> Generator[None, None, _T_co]: + raise NotImplementedError + + +OpenMode: TypeAlias = Literal["rb", "wb", "ab", "r+b"] + +NotEmptyCodes = NewType("NotEmptyCodes", Tuple["Code", ...]) + + +def check_not_empty_codes(codes: Tuple["Code", ...]) -> NotEmptyCodes: + if not codes: + raise ValueError("Codes should not be empty") + + return NotEmptyCodes(codes) diff --git a/src/aioftp/utils.py b/src/aioftp/utils.py new file mode 100644 index 0000000..cf74c2f --- /dev/null +++ b/src/aioftp/utils.py @@ -0,0 +1,5 @@ +from typing import Any, Dict, Tuple + + +def get_param(where: Tuple[int, str], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any: + return kwargs.get(where[1]) or args[where[0]]