diff --git a/.gitignore b/.gitignore
index da858d9..ef7bb86 100644
--- a/.gitignore
+++ b/.gitignore
@@ -15,3 +15,7 @@ __pycache__
.pytest_cache
.ruff_cache
coverage.xml
+.coverage
+.venv
+venv
+.DS_Store
diff --git a/pyproject.toml b/pyproject.toml
index 0f3a947..f4ee601 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -37,6 +37,9 @@ classifiers = [
"Development Status :: 5 - Production/Stable",
"Topic :: Internet :: File Transfer Protocol (FTP)",
]
+dependencies = [
+ "typing_extensions >= 4.10,<5.0",
+]
[project.urls]
Github = "https://github.com/aio-libs/aioftp"
@@ -60,6 +63,10 @@ dev = [
"black",
"ruff",
+ # typecheckers
+ "pyright",
+ "mypy",
+
# docs
"sphinx",
"alabaster",
@@ -81,9 +88,11 @@ target-version = ["py38"]
[tool.ruff]
line-length = 120
target-version = "py38"
-select = ["E", "W", "F", "Q", "UP", "I", "ASYNC"]
src = ["src"]
+[tool.ruff.lint]
+select = ["E", "W", "F", "Q", "UP", "I", "ASYNC"]
+
[tool.coverage]
run.source = ["./src/aioftp"]
run.omit = ["./src/aioftp/__main__.py"]
@@ -96,3 +105,28 @@ log_format = "%(asctime)s.%(msecs)03d %(name)-20s %(levelname)-8s %(filename)-15
log_date_format = "%H:%M:%S"
log_level = "DEBUG"
asyncio_mode = "strict"
+
+[tool.pyright]
+pythonVersion = "3.8"
+strict = ["src"]
+reportImplicitOverride = true
+exclude = ["venv", ".venv", "build", "tests", "docs", "ftpbench.py"]
+
+[tool.mypy]
+files = "src/aioftp"
+show_absolute_path = true
+pretty = true
+strict = true
+enable_error_code = [
+ "explicit-override",
+ "redundant-self", "redundant-expr", "possibly-undefined",
+]
+warn_unreachable = true
+disallow_any_generics = true
+disallow_untyped_defs = true
+
+[[tool.mypy.overrides]]
+module = "tests"
+strict = false
+disallow_any_generics = false
+disallow_untyped_defs = false
diff --git a/src/aioftp/__init__.py b/src/aioftp/__init__.py
index 15235af..226e69d 100644
--- a/src/aioftp/__init__.py
+++ b/src/aioftp/__init__.py
@@ -4,15 +4,106 @@
import importlib.metadata
-from .client import *
-from .common import *
-from .errors import *
-from .pathio import *
-from .server import *
+from .client import BaseClient, Client, Code, DataConnectionThrottleStreamIO
+from .common import (
+ DEFAULT_ACCOUNT,
+ DEFAULT_BLOCK_SIZE,
+ DEFAULT_PASSWORD,
+ DEFAULT_PORT,
+ DEFAULT_USER,
+ END_OF_LINE,
+ AbstractAsyncLister,
+ AsyncListerMixin,
+ AsyncStreamIterator,
+ StreamIO,
+ StreamThrottle,
+ Throttle,
+ ThrottleStreamIO,
+ async_enterable,
+ setlocale,
+ with_timeout,
+ wrap_with_container,
+)
+from .errors import AIOFTPException, NoAvailablePort, PathIOError, PathIsNotAbsolute, StatusCodeError
+from .pathio import (
+ AbstractPathIO,
+ AsyncPathIO,
+ MemoryPathIO,
+ PathIO,
+ PathIONursery,
+)
+from .server import (
+ AbstractUserManager,
+ AvailableConnections,
+ Connection,
+ ConnectionConditions,
+ MemoryUserManager,
+ PathConditions,
+ PathPermissions,
+ Permission,
+ Server,
+ User,
+ worker,
+)
+from .types import NotEmptyCodes, check_not_empty_codes
-__version__ = importlib.metadata.version(__package__)
+__version__ = importlib.metadata.version(__package__) # pyright: ignore[reportArgumentType]
version = tuple(map(int, __version__.split(".")))
+
__all__ = (
- client.__all__ + server.__all__ + errors.__all__ + common.__all__ + pathio.__all__ + ("version", "__version__")
+ # client
+ "BaseClient",
+ "Client",
+ "DataConnectionThrottleStreamIO",
+ "Code",
+ # server
+ "Server",
+ "Permission",
+ "User",
+ "AbstractUserManager",
+ "MemoryUserManager",
+ "Connection",
+ "AvailableConnections",
+ "ConnectionConditions",
+ "PathConditions",
+ "PathPermissions",
+ "worker",
+ # errors
+ "AIOFTPException",
+ "StatusCodeError",
+ "PathIsNotAbsolute",
+ "PathIOError",
+ "NoAvailablePort",
+ # common
+ "with_timeout",
+ "StreamIO",
+ "Throttle",
+ "StreamThrottle",
+ "ThrottleStreamIO",
+ "END_OF_LINE",
+ "DEFAULT_BLOCK_SIZE",
+ "wrap_with_container",
+ "AsyncStreamIterator",
+ "AbstractAsyncLister",
+ "AsyncListerMixin",
+ "async_enterable",
+ "DEFAULT_PORT",
+ "DEFAULT_USER",
+ "DEFAULT_PASSWORD",
+ "DEFAULT_ACCOUNT",
+ "setlocale",
+ # pathio
+ "AbstractPathIO",
+ "PathIO",
+ "AsyncPathIO",
+ "MemoryPathIO",
+ "PathIONursery",
+ #
+ # types
+ "NotEmptyCodes",
+ "check_not_empty_codes",
+ #
+ "version",
+ "__version__",
)
diff --git a/src/aioftp/__main__.py b/src/aioftp/__main__.py
index 3b7db1d..c12ca00 100644
--- a/src/aioftp/__main__.py
+++ b/src/aioftp/__main__.py
@@ -5,6 +5,7 @@
import contextlib
import logging
import socket
+from typing import Type
import aioftp
@@ -65,6 +66,7 @@
format="%(asctime)s [%(name)s] %(message)s",
datefmt="[%H:%M:%S]:",
)
+path_io_factory: Type[aioftp.AbstractPathIO]
if args.memory:
user = aioftp.User(args.login, args.password, base_path="/")
path_io_factory = aioftp.MemoryPathIO
@@ -81,7 +83,7 @@
}[args.family]
-async def main():
+async def main() -> None:
server = aioftp.Server([user], path_io_factory=path_io_factory)
await server.run(args.host, args.port, family=family)
diff --git a/src/aioftp/client.py b/src/aioftp/client.py
index 338f73c..3fceb65 100644
--- a/src/aioftp/client.py
+++ b/src/aioftp/client.py
@@ -7,6 +7,26 @@
import pathlib
import re
from functools import partial
+from ssl import SSLContext
+from types import TracebackType
+from typing import (
+ Any,
+ AsyncIterator,
+ Callable,
+ Deque,
+ Dict,
+ List,
+ Literal,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ Union,
+ cast,
+ overload,
+)
+
+from typing_extensions import NotRequired, Self, TypeAlias, TypedDict, override
from . import errors, pathio
from .common import (
@@ -25,9 +45,10 @@
setlocale,
wrap_with_container,
)
+from .types import AsyncEnterableProtocol, NotEmptyCodes, check_not_empty_codes
try:
- from siosocks.io.asyncio import open_connection
+ from siosocks.io.asyncio import open_connection # type: ignore
except ImportError:
from asyncio import open_connection
@@ -40,13 +61,42 @@
)
logger = logging.getLogger(__name__)
+_PathType: TypeAlias = Union[str, pathlib.PurePath]
+_ConnectionType: TypeAlias = Literal["I", "A", "E", "L"]
+
+_InfoFileType: TypeAlias = Literal["file", "dir", "link", "unknown"]
+
+
+class WindowsInfoDict(TypedDict):
+ type: _InfoFileType
+ size: NotRequired[str]
+ modify: str
+
+
+UnixInfoDict = TypedDict(
+ "UnixInfoDict",
+ {
+ "type": _InfoFileType,
+ "size": str,
+ "modify": str,
+ "unix.mode": int,
+ "unix.links": str,
+ "unix.owner": str,
+ "unix.group": str,
+ },
+)
+
+InfoDict: TypeAlias = Union[WindowsInfoDict, UnixInfoDict]
+
+_ParserType: TypeAlias = Callable[[bytes], Tuple[pathlib.PurePosixPath, InfoDict]]
+
class Code(str):
"""
Representation of server status code.
"""
- def matches(self, mask):
+ def matches(self, mask: str) -> bool:
"""
:param mask: Template for comparision. If mask symbol is not digit
then it passes.
@@ -77,11 +127,15 @@ class DataConnectionThrottleStreamIO(ThrottleStreamIO):
:py:class:`aioftp.ThrottleStreamIO`
"""
- def __init__(self, client, *args, **kwargs):
+ def __init__(self, client: "BaseClient", *args: Any, **kwargs: Any):
super().__init__(*args, **kwargs)
self.client = client
- async def finish(self, expected_codes="2xx", wait_codes="1xx"):
+ async def finish(
+ self,
+ expected_codes: Code = Code("2xx"),
+ wait_codes: Union[Tuple[Code], Code] = Code("1xx"),
+ ) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -99,29 +153,45 @@ async def finish(self, expected_codes="2xx", wait_codes="1xx"):
self.close()
await self.client.command(None, expected_codes, wait_codes)
- async def __aexit__(self, exc_type, exc, tb):
+ @override
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
if exc is None:
await self.finish()
else:
self.close()
+_CodesORCode: TypeAlias = Union[NotEmptyCodes, Code]
+_MaybeCodesORCode: TypeAlias = Union[Tuple[Code, ...], Code]
+_EmptyCodes: TypeAlias = Tuple[()]
+
+
class BaseClient:
+
+ stream: Optional[ThrottleStreamIO]
+ server_port: Optional[int]
+ server_host: Optional[str]
+
def __init__(
self,
*,
- socket_timeout=None,
- connection_timeout=None,
- read_speed_limit=None,
- write_speed_limit=None,
- path_timeout=None,
- path_io_factory=pathio.PathIO,
- encoding="utf-8",
- ssl=None,
- parse_list_line_custom=None,
- parse_list_line_custom_first=True,
- passive_commands=("epsv", "pasv"),
- **siosocks_asyncio_kwargs,
+ socket_timeout: Optional[int] = None,
+ connection_timeout: Optional[int] = None,
+ read_speed_limit: Optional[int] = None,
+ write_speed_limit: Optional[int] = None,
+ path_timeout: Optional[int] = None,
+ path_io_factory: Type[pathio.AbstractPathIO] = pathio.PathIO,
+ encoding: str = "utf-8",
+ ssl: Optional[SSLContext] = None,
+ parse_list_line_custom: Optional[_ParserType] = None,
+ parse_list_line_custom_first: bool = True,
+ passive_commands: Sequence[str] = ("epsv", "pasv"),
+ **siosocks_asyncio_kwargs: Any,
):
self.socket_timeout = socket_timeout
self.connection_timeout = connection_timeout
@@ -138,8 +208,10 @@ def __init__(
self.parse_list_line_custom_first = parse_list_line_custom_first
self._passive_commands = passive_commands
self._open_connection = partial(open_connection, ssl=self.ssl, **siosocks_asyncio_kwargs)
+ self.server_host = None
+ self.server_port = None
- async def connect(self, host, port=DEFAULT_PORT):
+ async def connect(self, host: str, port: int = DEFAULT_PORT) -> Optional[List[str]]:
self.server_host = host
self.server_port = port
reader, writer = await asyncio.wait_for(
@@ -152,15 +224,16 @@ async def connect(self, host, port=DEFAULT_PORT):
throttles={"_": self.throttle},
timeout=self.socket_timeout,
)
+ return None
- def close(self):
+ def close(self) -> None:
"""
Close connection.
"""
if self.stream is not None:
self.stream.close()
- async def parse_line(self):
+ async def parse_line(self) -> Tuple[Code, str]:
"""
:py:func:`asyncio.coroutine`
@@ -174,6 +247,9 @@ async def parse_line(self):
:raises asyncio.TimeoutError: if there where no data for `timeout`
period
"""
+ if self.stream is None:
+ raise ValueError("Connection is not established")
+
line = await self.stream.readline()
if not line:
self.stream.close()
@@ -182,7 +258,7 @@ async def parse_line(self):
logger.debug(s)
return Code(s[:3]), s[3:]
- async def parse_response(self):
+ async def parse_response(self) -> Tuple[Code, List[str]]:
"""
:py:func:`asyncio.coroutine`
@@ -207,7 +283,12 @@ async def parse_response(self):
info.append(curr_code + rest)
return code, info
- def check_codes(self, expected_codes, received_code, info):
+ def check_codes(
+ self,
+ expected_codes: Union[Tuple[Code, ...], Code],
+ received_code: Code,
+ info: List[str],
+ ) -> None:
"""
Checks if any of expected matches received.
@@ -226,13 +307,49 @@ def check_codes(self, expected_codes, received_code, info):
if not any(map(received_code.matches, expected_codes)):
raise errors.StatusCodeError(expected_codes, received_code, info)
+ @overload
async def command(
self,
- command=None,
- expected_codes=(),
- wait_codes=(),
- censor_after=None,
- ):
+ command: Optional[str],
+ expected_codes: _EmptyCodes,
+ wait_codes: _EmptyCodes,
+ censor_after: Optional[int] = None,
+ ) -> None: ...
+
+ @overload
+ async def command(
+ self,
+ command: Optional[str],
+ expected_codes: _MaybeCodesORCode,
+ wait_codes: _CodesORCode,
+ censor_after: Optional[int] = None,
+ ) -> Tuple[Code, List[str]]: ...
+
+ @overload
+ async def command(
+ self,
+ command: Optional[str],
+ expected_codes: _CodesORCode,
+ wait_codes: _MaybeCodesORCode = (),
+ censor_after: Optional[int] = None,
+ ) -> Tuple[Code, List[str]]: ...
+
+ @overload
+ async def command(
+ self,
+ command: Optional[str],
+ expected_codes: _MaybeCodesORCode = (),
+ wait_codes: _MaybeCodesORCode = (),
+ censor_after: Optional[int] = None,
+ ) -> Optional[Tuple[Code, List[str]]]: ...
+
+ async def command(
+ self,
+ command: Optional[str] = None,
+ expected_codes: Union[_EmptyCodes, _MaybeCodesORCode, _CodesORCode] = (),
+ wait_codes: Union[_EmptyCodes, _MaybeCodesORCode, _CodesORCode] = (),
+ censor_after: Optional[int] = None,
+ ) -> Optional[Tuple[Code, List[str]]]:
"""
:py:func:`asyncio.coroutine`
@@ -268,6 +385,10 @@ async def command(
else:
logger.debug(command)
message = command + END_OF_LINE
+
+ if self.stream is None:
+ raise ValueError("Connection is not established")
+
await self.stream.write(message.encode(encoding=self.encoding))
if expected_codes or wait_codes:
code, info = await self.parse_response()
@@ -276,9 +397,10 @@ async def command(
if expected_codes:
self.check_codes(expected_codes, code, info)
return code, info
+ return None
@staticmethod
- def parse_epsv_response(s):
+ def parse_epsv_response(s: str) -> Tuple[None, int]:
"""
Parsing `EPSV` (`message (|||port|)`) response.
@@ -294,7 +416,7 @@ def parse_epsv_response(s):
return None, port
@staticmethod
- def parse_pasv_response(s):
+ def parse_pasv_response(s: str) -> Tuple[str, int]:
"""
Parsing `PASV` server response.
@@ -311,7 +433,7 @@ def parse_pasv_response(s):
return ip, port
@staticmethod
- def parse_directory_response(s):
+ def parse_directory_response(s: str) -> pathlib.PurePosixPath:
"""
Parsing directory server response.
@@ -340,7 +462,7 @@ def parse_directory_response(s):
return pathlib.PurePosixPath(directory)
@staticmethod
- def parse_unix_mode(s):
+ def parse_unix_mode(s: str) -> int:
"""
Parsing unix mode strings ("rwxr-x--t") into hexacimal notation.
@@ -379,7 +501,7 @@ def parse_unix_mode(s):
return mode
@staticmethod
- def format_date_time(d):
+ def format_date_time(d: datetime.datetime) -> str:
"""
Formats dates from strptime in a consistent format
@@ -391,7 +513,7 @@ def format_date_time(d):
return d.strftime("%Y%m%d%H%M00")
@classmethod
- def parse_ls_date(cls, s, *, now=None):
+ def parse_ls_date(cls, s: str, *, now: Optional[datetime.datetime] = None) -> str:
"""
Parsing dates from the ls unix utility. For example,
"Nov 18 1958", "Jan 03 2018", and "Nov 18 12:29".
@@ -430,7 +552,7 @@ def parse_ls_date(cls, s, *, now=None):
d = datetime.datetime.strptime(s, "%b %d %Y")
return cls.format_date_time(d)
- def parse_list_line_unix(self, b):
+ def parse_list_line_unix(self, b: bytes) -> Tuple[pathlib.PurePosixPath, UnixInfoDict]:
"""
Attempt to parse a LIST line (similar to unix ls utility).
@@ -441,7 +563,7 @@ def parse_list_line_unix(self, b):
:rtype: (:py:class:`pathlib.PurePosixPath`, :py:class:`dict`)
"""
s = b.decode(encoding=self.encoding).rstrip()
- info = {}
+ info: dict[str, Any] = {}
if s[0] == "-":
info["type"] = "file"
elif s[0] == "d":
@@ -483,9 +605,22 @@ def parse_list_line_unix(self, b):
i = -2 if link_dst[-1] == "'" or link_dst[-1] == '"' else -1
info["type"] = "dir" if link_dst[i] == "/" else "file"
s = link_src
- return pathlib.PurePosixPath(s), info
- def parse_list_line_windows(self, b):
+ typed_info = UnixInfoDict(
+ {
+ "type": info["type"],
+ "size": info["size"],
+ "modify": info["modify"],
+ "unix.mode": info["unix.mode"],
+ "unix.links": info["unix.links"],
+ "unix.owner": info["unix.owner"],
+ "unix.group": info["unix.group"],
+ },
+ )
+
+ return pathlib.PurePosixPath(s), typed_info
+
+ def parse_list_line_windows(self, b: bytes) -> Tuple[pathlib.PurePosixPath, WindowsInfoDict]:
"""
Parsing Microsoft Windows `dir` output
@@ -497,13 +632,13 @@ def parse_list_line_windows(self, b):
"""
line = b.decode(encoding=self.encoding).rstrip("\r\n")
date_time_end = line.index("M")
- date_time_str = line[: date_time_end + 1].strip().split(" ")
- date_time_str = " ".join([x for x in date_time_str if len(x) > 0])
+ date_time_list_str = line[: date_time_end + 1].strip().split(" ")
+ date_time_str = " ".join([x for x in date_time_list_str if len(x) > 0])
line = line[date_time_end + 1 :].lstrip()
with setlocale("C"):
strptime = datetime.datetime.strptime
date_time = strptime(date_time_str, "%m/%d/%Y %I:%M %p")
- info = {}
+ info: dict[str, Any] = {}
info["modify"] = self.format_date_time(date_time)
next_space = line.index(" ")
if line.startswith("
"):
@@ -519,9 +654,17 @@ def parse_list_line_windows(self, b):
filename = line[next_space:].lstrip()
if filename == "." or filename == "..":
raise ValueError
- return pathlib.PurePosixPath(filename), info
- def parse_list_line(self, b):
+ windows_info: WindowsInfoDict = {
+ "type": info["type"],
+ "modify": info["modify"],
+ }
+ if "size" in info:
+ windows_info["size"] = info["size"]
+
+ return pathlib.PurePosixPath(filename), windows_info
+
+ def parse_list_line(self, b: bytes) -> Tuple[pathlib.PurePosixPath, InfoDict]:
"""
Parse LIST response with both Microsoft Windows® parser and
UNIX parser
@@ -532,8 +675,8 @@ def parse_list_line(self, b):
:return: (path, info)
:rtype: (:py:class:`pathlib.PurePosixPath`, :py:class:`dict`)
"""
- ex = []
- parsers = [
+ ex: List[Exception] = []
+ parsers: List[Optional[_ParserType]] = [
self.parse_list_line_unix,
self.parse_list_line_windows,
]
@@ -550,7 +693,7 @@ def parse_list_line(self, b):
ex.append(e)
raise ValueError("All parsers failed to parse", b, ex)
- def parse_mlsx_line(self, b):
+ def parse_mlsx_line(self, b: Union[bytes, str]) -> Tuple[pathlib.PurePosixPath, InfoDict]:
"""
Parsing MLS(T|D) response.
@@ -566,11 +709,12 @@ def parse_mlsx_line(self, b):
s = b
line = s.rstrip()
facts_found, _, name = line.partition(" ")
- entry = {}
+ entry: Dict[str, str] = {}
for fact in facts_found[:-1].split(";"):
key, _, value = fact.partition("=")
entry[key.lower()] = value
- return pathlib.PurePosixPath(name), entry
+
+ return pathlib.PurePosixPath(name), cast(InfoDict, entry)
class Client(BaseClient):
@@ -619,7 +763,8 @@ class Client(BaseClient):
:param **siosocks_asyncio_kwargs: siosocks key-word only arguments
"""
- async def connect(self, host, port=DEFAULT_PORT):
+ @override
+ async def connect(self, host: str, port: int = DEFAULT_PORT) -> List[str]:
"""
:py:func:`asyncio.coroutine`
@@ -632,15 +777,15 @@ async def connect(self, host, port=DEFAULT_PORT):
:type port: :py:class:`int`
"""
await super().connect(host, port)
- code, info = await self.command(None, "220", "120")
+ _, info = await self.command(None, Code("220"), Code("120"))
return info
async def login(
self,
- user=DEFAULT_USER,
- password=DEFAULT_PASSWORD,
- account=DEFAULT_ACCOUNT,
- ):
+ user: str = DEFAULT_USER,
+ password: str = DEFAULT_PASSWORD,
+ account: str = DEFAULT_ACCOUNT,
+ ) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -657,7 +802,7 @@ async def login(
:raises aioftp.StatusCodeError: if unknown code received
"""
- code, info = await self.command("USER " + user, ("230", "33x"))
+ code, info = await self.command("USER " + user, check_not_empty_codes((Code("230"), Code("33x"))))
while code.matches("33x"):
censor_after = None
if code == "331":
@@ -666,14 +811,19 @@ async def login(
elif code == "332":
cmd = "ACCT " + account
else:
- raise errors.StatusCodeError("33x", code, info)
+ raise errors.StatusCodeError(Code("33x"), code, info)
code, info = await self.command(
cmd,
- ("230", "33x"),
+ check_not_empty_codes(
+ (
+ Code("230"),
+ Code("33x"),
+ ),
+ ),
censor_after=censor_after,
)
- async def get_current_directory(self):
+ async def get_current_directory(self) -> pathlib.PurePosixPath:
"""
:py:func:`asyncio.coroutine`
@@ -681,11 +831,11 @@ async def get_current_directory(self):
:rtype: :py:class:`pathlib.PurePosixPath`
"""
- code, info = await self.command("PWD", "257")
+ _, info = await self.command("PWD", Code("257"))
directory = self.parse_directory_response(info[-1])
return directory
- async def change_directory(self, path=".."):
+ async def change_directory(self, path: _PathType = "..") -> None:
"""
:py:func:`asyncio.coroutine`
@@ -699,9 +849,9 @@ async def change_directory(self, path=".."):
cmd = "CDUP"
else:
cmd = "CWD " + str(path)
- await self.command(cmd, "2xx")
+ await self.command(cmd, Code("2xx"))
- async def make_directory(self, path, *, parents=True):
+ async def make_directory(self, path: _PathType, *, parents: bool = True) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -714,7 +864,7 @@ async def make_directory(self, path, *, parents=True):
:type parents: :py:class:`bool`
"""
path = pathlib.PurePosixPath(path)
- need_create = []
+ need_create: List[_PathType] = []
while path.name and not await self.exists(path):
need_create.append(path)
path = path.parent
@@ -722,9 +872,9 @@ async def make_directory(self, path, *, parents=True):
break
need_create.reverse()
for path in need_create:
- await self.command("MKD " + str(path), "257")
+ await self.command("MKD " + str(path), Code("257"))
- async def remove_directory(self, path):
+ async def remove_directory(self, path: _PathType) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -733,9 +883,15 @@ async def remove_directory(self, path):
:param path: empty directory to remove
:type path: :py:class:`str` or :py:class:`pathlib.PurePosixPath`
"""
- await self.command("RMD " + str(path), "250")
+ await self.command("RMD " + str(path), Code("250"))
- def list(self, path="", *, recursive=False, raw_command=None):
+ def list(
+ self,
+ path: _PathType = "",
+ *,
+ recursive: bool = False,
+ raw_command: Optional[str] = None,
+ ) -> AsyncListerMixin[Tuple[pathlib.PurePath, InfoDict]]:
"""
:py:func:`asyncio.coroutine`
@@ -770,59 +926,78 @@ def list(self, path="", *, recursive=False, raw_command=None):
>>> stats = await client.list()
"""
- class AsyncLister(AsyncListerMixin):
- stream = None
-
- async def _new_stream(cls, local_path):
- cls.path = local_path
- cls.parse_line = self.parse_mlsx_line
+ class AsyncLister(
+ AsyncIterator[Tuple[pathlib.PurePath, InfoDict]],
+ AsyncListerMixin[Tuple[pathlib.PurePath, InfoDict]],
+ ):
+ stream: Optional[DataConnectionThrottleStreamIO] = None
+ directories: Deque[Tuple[_PathType, InfoDict]]
+ parse_line: Callable[[bytes], Tuple[pathlib.PurePath, InfoDict]]
+ client: Client = self
+
+ async def _new_stream(
+ self,
+ local_path: _PathType,
+ ) -> Optional[DataConnectionThrottleStreamIO]:
+ self.path = local_path
+ self.parse_line = self.client.parse_mlsx_line
if raw_command not in [None, "MLSD", "LIST"]:
raise ValueError(
- "raw_command must be one of MLSD or " f"LIST, but got {raw_command}",
+ f"raw_command must be one of MLSD or LIST, but got {raw_command}",
)
if raw_command in [None, "MLSD"]:
try:
- command = ("MLSD " + str(cls.path)).strip()
- return await self.get_stream(command, "1xx")
+ command = ("MLSD " + str(self.path)).strip()
+ return await self.client.get_stream(command, Code("1xx"))
except errors.StatusCodeError as e:
code = e.received_codes[-1]
if not code.matches("50x") or raw_command is not None:
raise
if raw_command in [None, "LIST"]:
- cls.parse_line = self.parse_list_line
- command = ("LIST " + str(cls.path)).strip()
- return await self.get_stream(command, "1xx")
+ self.parse_line = self.client.parse_list_line
+ command = ("LIST " + str(self.path)).strip()
+ return await self.client.get_stream(command, Code("1xx"))
+ return None
+
+ @override
+ def __aiter__(self) -> Self:
+ self.directories = collections.deque()
+ return self
- def __aiter__(cls):
- cls.directories = collections.deque()
- return cls
+ @override
+ async def __anext__(self) -> Tuple[pathlib.PurePath, InfoDict]:
+ if self.stream is None:
+ self.stream = await self._new_stream(path)
- async def __anext__(cls):
- if cls.stream is None:
- cls.stream = await cls._new_stream(path)
while True:
- line = await cls.stream.readline()
+ if self.stream is None:
+ raise StopAsyncIteration
+ line = await self.stream.readline()
while not line:
- await cls.stream.finish()
- if cls.directories:
- current_path, info = cls.directories.popleft()
- cls.stream = await cls._new_stream(current_path)
- line = await cls.stream.readline()
+ await self.stream.finish()
+ if self.directories:
+ current_path, info = self.directories.popleft()
+ self.stream = await self._new_stream(current_path)
+
+ if self.stream is None:
+ raise ValueError("Connection is not established")
+
+ line = await self.stream.readline()
else:
raise StopAsyncIteration
- name, info = cls.parse_line(line)
+ name, info = self.parse_line(line)
# skipping . and .. as these are symlinks in Unix
if str(name) in (".", ".."):
continue
- stat = cls.path / name, info
+ stat = self.path / name, info
if info["type"] == "dir" and recursive:
- cls.directories.append(stat)
+ self.directories.append(stat)
return stat
return AsyncLister()
- async def stat(self, path):
+ async def stat(self, path: _PathType) -> InfoDict:
"""
:py:func:`asyncio.coroutine`
@@ -836,8 +1011,8 @@ async def stat(self, path):
"""
path = pathlib.PurePosixPath(path)
try:
- code, info = await self.command("MLST " + str(path), "2xx")
- name, info = self.parse_mlsx_line(info[1].lstrip())
+ _, info_ = await self.command("MLST " + str(path), Code("2xx"))
+ _, info = self.parse_mlsx_line(info_[1].lstrip())
return info
except errors.StatusCodeError as e:
if not e.received_codes[-1].matches("50x"):
@@ -853,7 +1028,7 @@ async def stat(self, path):
"path does not exists",
)
- async def is_file(self, path):
+ async def is_file(self, path: _PathType) -> bool:
"""
:py:func:`asyncio.coroutine`
@@ -867,7 +1042,7 @@ async def is_file(self, path):
info = await self.stat(path)
return info["type"] == "file"
- async def is_dir(self, path):
+ async def is_dir(self, path: _PathType) -> bool:
"""
:py:func:`asyncio.coroutine`
@@ -881,7 +1056,7 @@ async def is_dir(self, path):
info = await self.stat(path)
return info["type"] == "dir"
- async def exists(self, path):
+ async def exists(self, path: _PathType) -> bool:
"""
:py:func:`asyncio.coroutine`
@@ -900,7 +1075,7 @@ async def exists(self, path):
return False
raise
- async def rename(self, source, destination):
+ async def rename(self, source: _PathType, destination: _PathType) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -912,10 +1087,10 @@ async def rename(self, source, destination):
:param destination: path new name
:type destination: :py:class:`str` or :py:class:`pathlib.PurePosixPath`
"""
- await self.command("RNFR " + str(source), "350")
- await self.command("RNTO " + str(destination), "2xx")
+ await self.command("RNFR " + str(source), Code("350"))
+ await self.command("RNTO " + str(destination), Code("2xx"))
- async def remove_file(self, path):
+ async def remove_file(self, path: _PathType) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -924,9 +1099,9 @@ async def remove_file(self, path):
:param path: file to remove
:type path: :py:class:`str` or :py:class:`pathlib.PurePosixPath`
"""
- await self.command("DELE " + str(path), "2xx")
+ await self.command("DELE " + str(path), Code("2xx"))
- async def remove(self, path):
+ async def remove(self, path: _PathType) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -946,7 +1121,12 @@ async def remove(self, path):
await self.remove(name)
await self.remove_directory(path)
- def upload_stream(self, destination, *, offset=0):
+ def upload_stream(
+ self,
+ destination: _PathType,
+ *,
+ offset: int = 0,
+ ) -> AsyncEnterableProtocol[DataConnectionThrottleStreamIO]:
"""
Create stream for write data to `destination` file.
@@ -964,7 +1144,12 @@ def upload_stream(self, destination, *, offset=0):
offset=offset,
)
- def append_stream(self, destination, *, offset=0):
+ def append_stream(
+ self,
+ destination: _PathType,
+ *,
+ offset: int = 0,
+ ) -> AsyncEnterableProtocol[DataConnectionThrottleStreamIO]:
"""
Create stream for append (write) data to `destination` file.
@@ -982,7 +1167,14 @@ def append_stream(self, destination, *, offset=0):
offset=offset,
)
- async def upload(self, source, destination="", *, write_into=False, block_size=DEFAULT_BLOCK_SIZE):
+ async def upload(
+ self,
+ source: _PathType,
+ destination: _PathType = "",
+ *,
+ write_into: bool = False,
+ block_size: int = DEFAULT_BLOCK_SIZE,
+ ) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -1003,25 +1195,27 @@ async def upload(self, source, destination="", *, write_into=False, block_size=D
:param block_size: block size for transaction
:type block_size: :py:class:`int`
"""
- source = pathlib.Path(source)
- destination = pathlib.PurePosixPath(destination)
+ source_path = pathlib.Path(source)
+ destination_path = pathlib.PurePosixPath(destination)
if not write_into:
- destination = destination / source.name
- if await self.path_io.is_file(source):
- await self.make_directory(destination.parent)
- async with self.path_io.open(source, mode="rb") as file_in, self.upload_stream(destination) as stream:
+ destination_path = destination_path / source_path.name
+ if await self.path_io.is_file(source_path):
+ await self.make_directory(destination_path.parent)
+ async with self.path_io.open(source_path, mode="rb") as file_in, self.upload_stream(
+ destination_path,
+ ) as stream:
async for block in file_in.iter_by_block(block_size):
await stream.write(block)
- elif await self.path_io.is_dir(source):
- await self.make_directory(destination)
- sources = collections.deque([source])
+ elif await self.path_io.is_dir(source_path):
+ await self.make_directory(destination_path)
+ sources = collections.deque([source_path])
while sources:
src = sources.popleft()
async for path in self.path_io.list(src):
if write_into:
- relative = destination.name / path.relative_to(source)
+ relative = destination_path.name / path.relative_to(source_path)
else:
- relative = path.relative_to(source.parent)
+ relative = path.relative_to(source_path.parent)
if await self.path_io.is_dir(path):
await self.make_directory(relative)
sources.append(path)
@@ -1033,7 +1227,12 @@ async def upload(self, source, destination="", *, write_into=False, block_size=D
block_size=block_size,
)
- def download_stream(self, source, *, offset=0):
+ def download_stream(
+ self,
+ source: _PathType,
+ *,
+ offset: int = 0,
+ ) -> AsyncEnterableProtocol[DataConnectionThrottleStreamIO]:
"""
:py:func:`asyncio.coroutine`
@@ -1049,7 +1248,14 @@ def download_stream(self, source, *, offset=0):
"""
return self.get_stream("RETR " + str(source), "1xx", offset=offset)
- async def download(self, source, destination="", *, write_into=False, block_size=DEFAULT_BLOCK_SIZE):
+ async def download(
+ self,
+ source: _PathType,
+ destination: _PathType = "",
+ *,
+ write_into: bool = False,
+ block_size: int = DEFAULT_BLOCK_SIZE,
+ ) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -1070,23 +1276,25 @@ async def download(self, source, destination="", *, write_into=False, block_size
:param block_size: block size for transaction
:type block_size: :py:class:`int`
"""
- source = pathlib.PurePosixPath(source)
- destination = pathlib.Path(destination)
+ source_path = pathlib.PurePosixPath(source)
+ destination_path = pathlib.Path(destination)
if not write_into:
- destination = destination / source.name
- if await self.is_file(source):
+ destination_path = destination_path / source_path.name
+ if await self.is_file(source_path):
await self.path_io.mkdir(
- destination.parent,
+ destination_path.parent,
parents=True,
exist_ok=True,
)
- async with self.path_io.open(destination, mode="wb") as file_out, self.download_stream(source) as stream:
+ async with self.path_io.open(destination_path, mode="wb") as file_out, self.download_stream(
+ source_path,
+ ) as stream:
async for block in stream.iter_by_block(block_size):
await file_out.write(block)
- elif await self.is_dir(source):
- await self.path_io.mkdir(destination, parents=True, exist_ok=True)
- for name, info in await self.list(source):
- full = destination / name.relative_to(source)
+ elif await self.is_dir(source_path):
+ await self.path_io.mkdir(destination_path, parents=True, exist_ok=True)
+ for name, info in await self.list(source_path):
+ full = destination_path / name.relative_to(source_path)
if info["type"] in ("file", "dir"):
await self.download(
name,
@@ -1095,30 +1303,30 @@ async def download(self, source, destination="", *, write_into=False, block_size
block_size=block_size,
)
- async def quit(self):
+ async def quit(self) -> None:
"""
:py:func:`asyncio.coroutine`
Send "QUIT" and close connection.
"""
- await self.command("QUIT", "2xx")
+ await self.command("QUIT", Code("2xx"))
self.close()
- async def _do_epsv(self):
- code, info = await self.command("EPSV", "229")
+ async def _do_epsv(self) -> Tuple[None, int]:
+ _, info = await self.command("EPSV", Code("229"))
ip, port = self.parse_epsv_response(info[-1])
return ip, port
- async def _do_pasv(self):
- code, info = await self.command("PASV", "227")
+ async def _do_pasv(self) -> Tuple[str, int]:
+ _, info = await self.command("PASV", Code("227"))
ip, port = self.parse_pasv_response(info[-1])
return ip, port
async def get_passive_connection(
self,
- conn_type="I",
- commands=None,
- ):
+ conn_type: _ConnectionType = "I",
+ commands: Optional[Sequence[str]] = None,
+ ) -> Tuple[asyncio.StreamReader, asyncio.StreamWriter]:
"""
:py:func:`asyncio.coroutine`
@@ -1143,7 +1351,7 @@ async def get_passive_connection(
commands = self._passive_commands
if not commands:
raise ValueError("No passive commands provided")
- await self.command("TYPE " + conn_type, "200")
+ await self.command("TYPE " + conn_type, Code("200"))
for i, name in enumerate(commands, start=1):
name = name.lower()
if name not in functions:
@@ -1155,13 +1363,18 @@ async def get_passive_connection(
is_last = i == len(commands)
if is_last or not e.received_codes[-1].matches("50x"):
raise
- if ip in ("0.0.0.0", None):
+ if ip in ("0.0.0.0", None): # type: ignore
ip = self.server_host
- reader, writer = await self._open_connection(ip, port)
+ reader, writer = await self._open_connection(ip, port) # type: ignore
return reader, writer
@async_enterable
- async def get_stream(self, *command_args, conn_type="I", offset=0):
+ async def get_stream(
+ self,
+ *command_args: Any,
+ conn_type: _ConnectionType = "I",
+ offset: int = 0,
+ ) -> DataConnectionThrottleStreamIO:
"""
:py:func:`asyncio.coroutine`
@@ -1180,7 +1393,7 @@ async def get_stream(self, *command_args, conn_type="I", offset=0):
"""
reader, writer = await self.get_passive_connection(conn_type)
if offset:
- await self.command("REST " + str(offset), "350")
+ await self.command("REST " + str(offset), Code("350"))
await self.command(*command_args)
stream = DataConnectionThrottleStreamIO(
self,
@@ -1191,7 +1404,7 @@ async def get_stream(self, *command_args, conn_type="I", offset=0):
)
return stream
- async def abort(self, *, wait=True):
+ async def abort(self, *, wait: bool = True) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -1201,15 +1414,21 @@ async def abort(self, *, wait=True):
:type wait: :py:class:`bool`
"""
if wait:
- await self.command("ABOR", "226", "426")
+ await self.command("ABOR", Code("226"), Code("426"))
else:
await self.command("ABOR")
@classmethod
@contextlib.asynccontextmanager
async def context(
- cls, host, port=DEFAULT_PORT, user=DEFAULT_USER, password=DEFAULT_PASSWORD, account=DEFAULT_ACCOUNT, **kwargs
- ):
+ cls,
+ host: str,
+ port: int = DEFAULT_PORT,
+ user: str = DEFAULT_USER,
+ password: str = DEFAULT_PASSWORD,
+ account: str = DEFAULT_ACCOUNT,
+ **kwargs: Any,
+ ) -> AsyncIterator["Client"]:
"""
Classmethod async context manager. This create
:py:class:`aioftp.Client`, make async call to
diff --git a/src/aioftp/common.py b/src/aioftp/common.py
index 41eec46..fca6d30 100644
--- a/src/aioftp/common.py
+++ b/src/aioftp/common.py
@@ -1,10 +1,40 @@
import abc
import asyncio
-import collections
import functools
import locale
import threading
from contextlib import contextmanager
+from types import TracebackType
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ AsyncContextManager,
+ AsyncIterable,
+ AsyncIterator,
+ Awaitable,
+ Callable,
+ Coroutine,
+ Dict,
+ Final,
+ Generator,
+ Generic,
+ List,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
+
+from typing_extensions import ParamSpec, Self, override
+
+from .types import AsyncEnterableProtocol
+from .utils import get_param
+
+if TYPE_CHECKING:
+ from .pathio import AsyncPathIO
__all__ = (
"with_timeout",
@@ -27,35 +57,67 @@
)
-END_OF_LINE = "\r\n"
-DEFAULT_BLOCK_SIZE = 8192
+END_OF_LINE: Final[str] = "\r\n"
+DEFAULT_BLOCK_SIZE: Final[int] = 8192
+
+DEFAULT_PORT: Final[int] = 21
+DEFAULT_USER: Final[str] = "anonymous"
+DEFAULT_PASSWORD: Final[str] = "anon@"
+DEFAULT_ACCOUNT: Final[str] = ""
+HALF_OF_YEAR_IN_SECONDS: Final[int] = 15778476
+TWO_YEARS_IN_SECONDS: Final[float] = ((365 * 3 + 366) * 24 * 60 * 60) / 2
-DEFAULT_PORT = 21
-DEFAULT_USER = "anonymous"
-DEFAULT_PASSWORD = "anon@"
-DEFAULT_ACCOUNT = ""
-HALF_OF_YEAR_IN_SECONDS = 15778476
-TWO_YEARS_IN_SECONDS = ((365 * 3 + 366) * 24 * 60 * 60) / 2
+_T = TypeVar("_T")
+_PS = ParamSpec("_PS")
-def _now():
+def _now() -> float:
return asyncio.get_running_loop().time()
-def _with_timeout(name):
- def decorator(f):
+def _with_timeout(
+ name: str,
+) -> Callable[
+ [Callable[_PS, Coroutine[None, None, _T]]],
+ Callable[_PS, Coroutine[None, None, _T]],
+]:
+ def decorator(f: Callable[_PS, Coroutine[None, None, _T]]) -> Callable[_PS, Coroutine[None, None, _T]]:
@functools.wraps(f)
- def wrapper(cls, *args, **kwargs):
- coro = f(cls, *args, **kwargs)
- timeout = getattr(cls, name)
- return asyncio.wait_for(coro, timeout)
+ async def wrapper(*args: _PS.args, **kwargs: _PS.kwargs) -> _T:
+ self: "AsyncPathIO" = get_param((0, "self"), args, kwargs)
+ coro = f(*args, **kwargs)
+ timeout = getattr(self, name)
+ return await asyncio.wait_for(coro, timeout)
return wrapper
return decorator
-def with_timeout(name):
+@overload
+def with_timeout(
+ name: str,
+) -> Callable[
+ [Callable[_PS, Coroutine[None, None, _T]]],
+ Callable[_PS, Coroutine[None, None, _T]],
+]: ...
+
+
+@overload
+def with_timeout(
+ name: Callable[_PS, Coroutine[None, None, _T]],
+) -> Callable[_PS, Coroutine[None, None, _T]]: ...
+
+
+def with_timeout(
+ name: Union[str, Callable[_PS, Coroutine[None, None, _T]]],
+) -> Union[
+ Callable[
+ [Callable[_PS, Coroutine[None, None, _T]]],
+ Callable[_PS, Coroutine[None, None, _T]],
+ ],
+ Callable[_PS, Coroutine[None, None, _T]],
+]:
"""
Method decorator, wraps method with :py:func:`asyncio.wait_for`. `timeout`
argument takes from `name` decorator argument or "timeout".
@@ -97,14 +159,16 @@ def with_timeout(name):
return _with_timeout("timeout")(name)
-class AsyncStreamIterator:
- def __init__(self, read_coro):
+class AsyncStreamIterator(AsyncIterator[_T], Generic[_T]):
+ def __init__(self, read_coro: Callable[[], Awaitable[_T]]):
self.read_coro = read_coro
- def __aiter__(self):
+ @override
+ def __aiter__(self) -> "AsyncStreamIterator[_T]":
return self
- async def __anext__(self):
+ @override
+ async def __anext__(self) -> _T:
data = await self.read_coro()
if data:
return data
@@ -112,7 +176,11 @@ async def __anext__(self):
raise StopAsyncIteration
-class AsyncListerMixin:
+class AsyncListerMixin(
+ AsyncIterable[_T],
+ Awaitable[List[_T]],
+ Generic[_T],
+):
"""
Add ability to `async for` context to collect data to list via await.
@@ -123,17 +191,18 @@ class AsyncListerMixin:
>>> results = await Context(...)
"""
- async def _to_list(self):
- items = []
+ async def _to_list(self) -> List[_T]:
+ items: List[_T] = []
async for item in self:
items.append(item)
return items
- def __await__(self):
+ @override
+ def __await__(self) -> Generator[None, None, List[_T]]:
return self._to_list().__await__()
-class AbstractAsyncLister(AsyncListerMixin, abc.ABC):
+class AbstractAsyncLister(AsyncListerMixin[_T], abc.ABC):
"""
Abstract context with ability to collect all iterables into
:py:class:`list` via `await` with optional timeout (via
@@ -162,16 +231,17 @@ class AbstractAsyncLister(AsyncListerMixin, abc.ABC):
[block, block, block, ...]
"""
- def __init__(self, *, timeout=None):
+ def __init__(self, *, timeout: Optional[float] = None) -> None:
super().__init__()
self.timeout = timeout
- def __aiter__(self):
+ @override
+ def __aiter__(self) -> "AbstractAsyncLister[_T]":
return self
@with_timeout
@abc.abstractmethod
- async def __anext__(self):
+ async def __anext__(self) -> _T:
"""
:py:func:`asyncio.coroutine`
@@ -179,7 +249,12 @@ async def __anext__(self):
"""
-def async_enterable(f):
+_T_acm = TypeVar("_T_acm", bound=AsyncContextManager[Any])
+
+
+def async_enterable(
+ f: Callable[_PS, Coroutine[None, None, _T_acm]],
+) -> Callable[_PS, AsyncEnterableProtocol[_T_acm]]:
"""
Decorator. Bring coroutine result up, so it can be used as async context
@@ -215,16 +290,19 @@ def async_enterable(f):
"""
@functools.wraps(f)
- def wrapper(*args, **kwargs):
- class AsyncEnterableInstance:
- async def __aenter__(self):
+ def wrapper(*args: _PS.args, **kwargs: _PS.kwargs) -> AsyncEnterableProtocol[_T_acm]:
+ class AsyncEnterableInstance(AsyncEnterableProtocol[_T_acm]): # pyright: ignore[reportGeneralTypeIssues]
+ @override
+ async def __aenter__(self) -> _T_acm:
self.context = await f(*args, **kwargs)
- return await self.context.__aenter__()
+ return await self.context.__aenter__() # type: ignore
- async def __aexit__(self, *args, **kwargs):
+ @override
+ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
await self.context.__aexit__(*args, **kwargs)
- def __await__(self):
+ @override
+ def __await__(self) -> Generator[None, None, _T_acm]:
return f(*args, **kwargs).__await__()
return AsyncEnterableInstance()
@@ -232,9 +310,18 @@ def __await__(self):
return wrapper
-def wrap_with_container(o):
+_T_str_bounded = TypeVar("_T_str_bounded", bound=str)
+
+
+@overload
+def wrap_with_container(o: _T_str_bounded) -> Tuple[_T_str_bounded]: ...
+@overload
+def wrap_with_container(o: Any) -> Any: ...
+
+
+def wrap_with_container(o: _T) -> Union[_T, Tuple[str]]:
if isinstance(o, str):
- o = (o,)
+ return (o,)
return o
@@ -260,14 +347,22 @@ class StreamIO:
:type write_timeout: :py:class:`int`, :py:class:`float` or :py:class:`None`
"""
- def __init__(self, reader, writer, *, timeout=None, read_timeout=None, write_timeout=None):
+ def __init__(
+ self,
+ reader: asyncio.StreamReader,
+ writer: asyncio.StreamWriter,
+ *,
+ timeout: Optional[Union[int, float]] = None,
+ read_timeout: Optional[Union[int, float]] = None,
+ write_timeout: Optional[Union[int, float]] = None,
+ ) -> None:
self.reader = reader
self.writer = writer
self.read_timeout = read_timeout or timeout
self.write_timeout = write_timeout or timeout
@with_timeout("read_timeout")
- async def readline(self):
+ async def readline(self) -> bytes:
"""
:py:func:`asyncio.coroutine`
@@ -276,7 +371,7 @@ async def readline(self):
return await self.reader.readline()
@with_timeout("read_timeout")
- async def read(self, count=-1):
+ async def read(self, count: int = -1) -> bytes:
"""
:py:func:`asyncio.coroutine`
@@ -288,7 +383,7 @@ async def read(self, count=-1):
return await self.reader.read(count)
@with_timeout("read_timeout")
- async def readexactly(self, count):
+ async def readexactly(self, count: int) -> bytes:
"""
:py:func:`asyncio.coroutine`
@@ -300,7 +395,7 @@ async def readexactly(self, count):
return await self.reader.readexactly(count)
@with_timeout("write_timeout")
- async def write(self, data):
+ async def write(self, data: bytes) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -313,7 +408,7 @@ async def write(self, data):
self.writer.write(data)
await self.writer.drain()
- def close(self):
+ def close(self) -> None:
"""
Close connection.
"""
@@ -332,13 +427,13 @@ class Throttle:
:type reset_rate: :py:class:`int` or :py:class:`float`
"""
- def __init__(self, *, limit=None, reset_rate=10):
+ def __init__(self, *, limit: Optional[int] = None, reset_rate: Union[int, float] = 10) -> None:
self._limit = limit
self.reset_rate = reset_rate
- self._start = None
+ self._start: Optional[float] = None
self._sum = 0
- async def wait(self):
+ async def wait(self) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -349,7 +444,7 @@ async def wait(self):
end = self._start + self._sum / self._limit
await asyncio.sleep(max(0, end - now))
- def append(self, data, start):
+ def append(self, data: bytes, start: float) -> None:
"""
Count `data` for throttle
@@ -369,14 +464,14 @@ def append(self, data, start):
self._sum += len(data)
@property
- def limit(self):
+ def limit(self) -> Optional[int]:
"""
Throttle limit
"""
return self._limit
@limit.setter
- def limit(self, value):
+ def limit(self, value: Optional[int]) -> None:
"""
Set throttle limit
@@ -387,17 +482,18 @@ def limit(self, value):
self._start = None
self._sum = 0
- def clone(self):
+ def clone(self) -> "Throttle":
"""
Clone throttle without memory
"""
return Throttle(limit=self._limit, reset_rate=self.reset_rate)
- def __repr__(self):
+ @override
+ def __repr__(self) -> str:
return f"{self.__class__.__name__}(limit={self._limit!r}, " f"reset_rate={self.reset_rate!r})"
-class StreamThrottle(collections.namedtuple("StreamThrottle", "read write")):
+class StreamThrottle(NamedTuple):
"""
Stream throttle with `read` and `write` :py:class:`aioftp.Throttle`
@@ -408,7 +504,10 @@ class StreamThrottle(collections.namedtuple("StreamThrottle", "read write")):
:type write: :py:class:`aioftp.Throttle`
"""
- def clone(self):
+ read: Throttle
+ write: Throttle
+
+ def clone(self) -> "StreamThrottle":
"""
Clone throttles without memory
"""
@@ -418,7 +517,11 @@ def clone(self):
)
@classmethod
- def from_limits(cls, read_speed_limit=None, write_speed_limit=None):
+ def from_limits(
+ cls,
+ read_speed_limit: Optional[int] = None,
+ write_speed_limit: Optional[int] = None,
+ ) -> "StreamThrottle":
"""
Simple wrapper for creation :py:class:`aioftp.StreamThrottle`
@@ -463,11 +566,11 @@ class ThrottleStreamIO(StreamIO):
... )
"""
- def __init__(self, *args, throttles={}, **kwargs):
+ def __init__(self, *args: Any, throttles: Dict[str, StreamThrottle] = {}, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.throttles = throttles
- async def wait(self, name):
+ async def wait(self, name: str) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -476,7 +579,7 @@ async def wait(self, name):
:param name: name of throttle to acquire ("read" or "write")
:type name: :py:class:`str`
"""
- tasks = []
+ tasks: List[asyncio.Task[None]] = []
for throttle in self.throttles.values():
curr_throttle = getattr(throttle, name)
if curr_throttle.limit:
@@ -484,7 +587,7 @@ async def wait(self, name):
if tasks:
await asyncio.wait(tasks)
- def append(self, name, data, start):
+ def append(self, name: str, data: bytes, start: float) -> None:
"""
Update timeout for all throttles
@@ -501,7 +604,8 @@ def append(self, name, data, start):
for throttle in self.throttles.values():
getattr(throttle, name).append(data, start)
- async def read(self, count=-1):
+ @override
+ async def read(self, count: int = -1) -> bytes:
"""
:py:func:`asyncio.coroutine`
@@ -513,7 +617,8 @@ async def read(self, count=-1):
self.append("read", data, start)
return data
- async def readline(self):
+ @override
+ async def readline(self) -> bytes:
"""
:py:func:`asyncio.coroutine`
@@ -525,7 +630,8 @@ async def readline(self):
self.append("read", data, start)
return data
- async def write(self, data):
+ @override
+ async def write(self, data: bytes) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -536,13 +642,18 @@ async def write(self, data):
await super().write(data)
self.append("write", data, start)
- async def __aenter__(self):
+ async def __aenter__(self) -> Self:
return self
- async def __aexit__(self, *args):
+ async def __aexit__(
+ self,
+ exc_type: Optional[Type[BaseException]],
+ exc: Optional[BaseException],
+ tb: Optional[TracebackType],
+ ) -> None:
self.close()
- def iter_by_line(self):
+ def iter_by_line(self) -> AsyncStreamIterator[bytes]:
"""
Read/iterate stream by line.
@@ -553,9 +664,9 @@ def iter_by_line(self):
>>> async for line in stream.iter_by_line():
... ...
"""
- return AsyncStreamIterator(self.readline)
+ return AsyncStreamIterator[bytes](self.readline)
- def iter_by_block(self, count=DEFAULT_BLOCK_SIZE):
+ def iter_by_block(self, count: int = DEFAULT_BLOCK_SIZE) -> AsyncStreamIterator[bytes]:
"""
Read/iterate stream by block.
@@ -573,7 +684,7 @@ def iter_by_block(self, count=DEFAULT_BLOCK_SIZE):
@contextmanager
-def setlocale(name):
+def setlocale(name: str) -> Generator[str, None, None]:
"""
Context manager with threading lock for set locale on enter, and set it
back to original state on exit.
diff --git a/src/aioftp/errors.py b/src/aioftp/errors.py
index 25af25c..a80e73a 100644
--- a/src/aioftp/errors.py
+++ b/src/aioftp/errors.py
@@ -1,5 +1,11 @@
+import typing
+from typing import Any, Iterable, Optional, Tuple, Union
+
from . import common
+if typing.TYPE_CHECKING:
+ from . import Code
+
__all__ = (
"AIOFTPException",
"StatusCodeError",
@@ -41,9 +47,14 @@ class StatusCodeError(AIOFTPException):
Exception members are tuples, even for one code.
"""
- def __init__(self, expected_codes, received_codes, info):
+ def __init__(
+ self,
+ expected_codes: Union[Tuple["Code", ...], "Code"],
+ received_codes: Union[Tuple["Code", ...], "Code"],
+ info: Iterable[str],
+ ) -> None:
super().__init__(
- f"Waiting for {expected_codes} but got " f"{received_codes} {info!r}",
+ f"Waiting for {expected_codes} but got {received_codes} {info!r}",
)
self.expected_codes = common.wrap_with_container(expected_codes)
self.received_codes = common.wrap_with_container(received_codes)
@@ -72,7 +83,7 @@ class PathIOError(AIOFTPException):
... # handle
"""
- def __init__(self, *args, reason=None, **kwargs):
+ def __init__(self, *args: Any, reason: Optional[Any] = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.reason = reason
diff --git a/src/aioftp/pathio.py b/src/aioftp/pathio.py
index 8335ea6..9a6248d 100644
--- a/src/aioftp/pathio.py
+++ b/src/aioftp/pathio.py
@@ -1,13 +1,36 @@
import abc
import asyncio
-import collections
import functools
import io
import operator
+import os
import pathlib
import stat
import sys
import time
+import typing
+from concurrent import futures
+from typing import (
+ Any,
+ Awaitable,
+ BinaryIO,
+ Callable,
+ Coroutine,
+ Dict,
+ Generator,
+ List,
+ Literal,
+ NamedTuple,
+ Optional,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ overload,
+)
+
+from typing_extensions import ParamSpec, Protocol, Self, TypeAlias, override
from . import errors
from .common import (
@@ -16,6 +39,11 @@
AsyncStreamIterator,
with_timeout,
)
+from .types import OpenMode, StatsProtocol
+from .utils import get_param
+
+if typing.TYPE_CHECKING:
+ from .server import Connection
__all__ = (
"AbstractPathIO",
@@ -25,6 +53,36 @@
"PathIONursery",
)
+_Ps = ParamSpec("_Ps")
+_T = TypeVar("_T")
+_T_co = TypeVar("_T_co", covariant=True)
+
+
+class _DirNodeProtocol(Protocol):
+ type: Literal["dir"]
+ name: str
+ ctime: int
+ mtime: int
+ content: List[Union["_FileNodeProtocol", "_DirNodeProtocol"]]
+
+
+class _FileNodeProtocol(Protocol):
+ type: Literal["file"]
+ name: str
+ ctime: int
+ mtime: int
+ content: io.BytesIO
+
+
+_NodeProtocol: TypeAlias = Union[_DirNodeProtocol, _FileNodeProtocol]
+_NodeType: TypeAlias = Literal["file", "dir"]
+
+
+class SupportsNext(Protocol[_T_co]):
+ @abc.abstractmethod
+ def __next__(self) -> _T_co:
+ raise NotImplementedError
+
class AsyncPathIOContext:
"""
@@ -47,32 +105,34 @@ class AsyncPathIOContext:
"""
- def __init__(self, pathio, args, kwargs):
+ close: Optional["functools.partial[Awaitable[None]]"]
+
+ def __init__(self, pathio: "AbstractPathIO", args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> None:
self.close = None
self.pathio = pathio
self.args = args
self.kwargs = kwargs
- async def __aenter__(self):
- self.file = await self.pathio._open(*self.args, **self.kwargs)
+ async def __aenter__(self) -> Self:
+ self.file = await self.pathio._open(*self.args, **self.kwargs) # pyright: ignore[reportPrivateUsage]
self.seek = functools.partial(self.pathio.seek, self.file)
self.write = functools.partial(self.pathio.write, self.file)
self.read = functools.partial(self.pathio.read, self.file)
self.close = functools.partial(self.pathio.close, self.file)
return self
- async def __aexit__(self, *args):
+ async def __aexit__(self, *args: object) -> None:
if self.close is not None:
await self.close()
- def __await__(self):
+ def __await__(self) -> Generator[None, None, "AsyncPathIOContext"]:
return self.__aenter__().__await__()
- def iter_by_block(self, count=DEFAULT_BLOCK_SIZE):
+ def iter_by_block(self, count: int = DEFAULT_BLOCK_SIZE) -> AsyncStreamIterator[bytes]:
return AsyncStreamIterator(lambda: self.read(count))
-def universal_exception(coro):
+def universal_exception(coro: Callable[_Ps, Coroutine[None, None, _T]]) -> Callable[_Ps, Coroutine[None, None, _T]]:
"""
Decorator. Reraising any exception (except `CancelledError` and
`NotImplementedError`) with universal exception
@@ -80,7 +140,7 @@ def universal_exception(coro):
"""
@functools.wraps(coro)
- async def wrapper(*args, **kwargs):
+ async def wrapper(*args: _Ps.args, **kwargs: _Ps.kwargs) -> _T:
try:
return await coro(*args, **kwargs)
except (
@@ -96,30 +156,38 @@ async def wrapper(*args, **kwargs):
class PathIONursery:
- def __init__(self, factory):
+ state: Optional[Union[List[_NodeProtocol], io.BytesIO]]
+
+ def __init__(self, factory: Type["AbstractPathIO"]):
self.factory = factory
self.state = None
- def __call__(self, *args, **kwargs):
- instance = self.factory(*args, state=self.state, **kwargs)
+ def __call__(self, *args: Any, **kwargs: Any) -> "AbstractPathIO":
+ instance = self.factory(*args, state=self.state, **kwargs) # type: ignore
if self.state is None:
self.state = instance.state
return instance
-def defend_file_methods(coro):
+def defend_file_methods(
+ coro: Callable[_Ps, Coroutine[None, None, _T]],
+) -> Callable[_Ps, Coroutine[None, None, _T]]:
"""
Decorator. Raises exception when file methods called with wrapped by
:py:class:`aioftp.AsyncPathIOContext` file object.
"""
@functools.wraps(coro)
- async def wrapper(self, file, *args, **kwargs):
+ async def wrapper(
+ *args: _Ps.args,
+ **kwargs: _Ps.kwargs,
+ ) -> _T:
+ file = get_param((1, "file"), args, kwargs)
if isinstance(file, AsyncPathIOContext):
raise ValueError(
- "Native path io file methods can not be used " "with wrapped file object",
+ "Native path io file methods can not be used with wrapped file object",
)
- return await coro(self, file, *args, **kwargs)
+ return await coro(*args, **kwargs)
return wrapper
@@ -137,19 +205,25 @@ class AbstractPathIO(abc.ABC):
:param state: shared pathio state per server
"""
- def __init__(self, timeout=None, connection=None, state=None):
+ def __init__(
+ self,
+ timeout: Optional[Union[float, int]] = None,
+ connection: Optional["Connection"] = None,
+ state: Optional[List[_NodeProtocol]] = None,
+ ):
self.timeout = timeout
self.connection = connection
@property
- def state(self):
+ def state(self) -> Optional[Union[List[_NodeProtocol], io.BytesIO]]:
"""
Shared pathio state per server
"""
+ return None
@universal_exception
@abc.abstractmethod
- async def exists(self, path):
+ async def exists(self, path: pathlib.Path) -> bool:
"""
:py:func:`asyncio.coroutine`
@@ -163,7 +237,7 @@ async def exists(self, path):
@universal_exception
@abc.abstractmethod
- async def is_dir(self, path):
+ async def is_dir(self, path: pathlib.Path) -> bool:
"""
:py:func:`asyncio.coroutine`
@@ -177,7 +251,7 @@ async def is_dir(self, path):
@universal_exception
@abc.abstractmethod
- async def is_file(self, path):
+ async def is_file(self, path: pathlib.Path) -> bool:
"""
:py:func:`asyncio.coroutine`
@@ -191,7 +265,7 @@ async def is_file(self, path):
@universal_exception
@abc.abstractmethod
- async def mkdir(self, path, *, parents=False, exist_ok=False):
+ async def mkdir(self, path: pathlib.Path, *, parents: bool = False, exist_ok: bool = False) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -209,7 +283,7 @@ async def mkdir(self, path, *, parents=False, exist_ok=False):
@universal_exception
@abc.abstractmethod
- async def rmdir(self, path):
+ async def rmdir(self, path: pathlib.Path) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -221,7 +295,7 @@ async def rmdir(self, path):
@universal_exception
@abc.abstractmethod
- async def unlink(self, path):
+ async def unlink(self, path: pathlib.Path) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -232,7 +306,7 @@ async def unlink(self, path):
"""
@abc.abstractmethod
- def list(self, path):
+ def list(self, path: pathlib.Path) -> AbstractAsyncLister[Any]:
"""
Create instance of subclass of :py:class:`aioftp.AbstractAsyncLister`.
You should subclass and implement `__anext__` method
@@ -260,7 +334,7 @@ def list(self, path):
@universal_exception
@abc.abstractmethod
- async def stat(self, path):
+ async def stat(self, path: pathlib.Path) -> StatsProtocol:
"""
:py:func:`asyncio.coroutine`
@@ -276,7 +350,7 @@ async def stat(self, path):
@universal_exception
@abc.abstractmethod
- async def _open(self, path, mode):
+ async def _open(self, path: pathlib.Path, mode: OpenMode) -> BinaryIO:
"""
:py:func:`asyncio.coroutine`
@@ -295,7 +369,7 @@ async def _open(self, path, mode):
:return: file-object
"""
- def open(self, *args, **kwargs):
+ def open(self, *args: Any, **kwargs: Any) -> AsyncPathIOContext:
"""
Create instance of :py:class:`aioftp.pathio.AsyncPathIOContext`,
parameters passed to :py:meth:`aioftp.AbstractPathIO._open`
@@ -307,7 +381,7 @@ def open(self, *args, **kwargs):
@universal_exception
@defend_file_methods
@abc.abstractmethod
- async def seek(self, file, offset, whence=io.SEEK_SET):
+ async def seek(self, file: BinaryIO, offset: int, whence: int = io.SEEK_SET) -> int:
"""
:py:func:`asyncio.coroutine`
@@ -326,7 +400,7 @@ async def seek(self, file, offset, whence=io.SEEK_SET):
@universal_exception
@defend_file_methods
@abc.abstractmethod
- async def write(self, file, data):
+ async def write(self, file: BinaryIO, data: bytes) -> int:
"""
:py:func:`asyncio.coroutine`
@@ -341,7 +415,7 @@ async def write(self, file, data):
@universal_exception
@defend_file_methods
@abc.abstractmethod
- async def read(self, file, block_size):
+ async def read(self, file: BinaryIO, block_size: int) -> bytes:
"""
:py:func:`asyncio.coroutine`
@@ -358,7 +432,7 @@ async def read(self, file, block_size):
@universal_exception
@defend_file_methods
@abc.abstractmethod
- async def close(self, file):
+ async def close(self, file: BinaryIO) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -369,7 +443,7 @@ async def close(self, file):
@universal_exception
@abc.abstractmethod
- async def rename(self, source, destination):
+ async def rename(self, source: pathlib.Path, destination: pathlib.Path) -> Optional[pathlib.Path]:
"""
:py:func:`asyncio.coroutine`
@@ -388,36 +462,44 @@ class PathIO(AbstractPathIO):
Blocking path io. Directly based on :py:class:`pathlib.Path` methods.
"""
+ @override
@universal_exception
- async def exists(self, path):
+ async def exists(self, path: pathlib.Path) -> bool:
return path.exists()
+ @override
@universal_exception
- async def is_dir(self, path):
+ async def is_dir(self, path: pathlib.Path) -> bool:
return path.is_dir()
+ @override
@universal_exception
- async def is_file(self, path):
+ async def is_file(self, path: pathlib.Path) -> bool:
return path.is_file()
+ @override
@universal_exception
- async def mkdir(self, path, *, parents=False, exist_ok=False):
+ async def mkdir(self, path: pathlib.Path, *, parents: bool = False, exist_ok: bool = False) -> None:
return path.mkdir(parents=parents, exist_ok=exist_ok)
+ @override
@universal_exception
- async def rmdir(self, path):
+ async def rmdir(self, path: pathlib.Path) -> None:
return path.rmdir()
+ @override
@universal_exception
- async def unlink(self, path):
+ async def unlink(self, path: pathlib.Path) -> None:
return path.unlink()
- def list(self, path):
- class Lister(AbstractAsyncLister):
- iter = None
+ @override
+ def list(self, path: pathlib.Path) -> AbstractAsyncLister[pathlib.Path]:
+ class Lister(AbstractAsyncLister[pathlib.Path]):
+ iter: Optional[Generator[pathlib.Path, None, None]] = None
+ @override
@universal_exception
- async def __anext__(self):
+ async def __anext__(self) -> pathlib.Path:
if self.iter is None:
self.iter = path.glob("*")
try:
@@ -427,45 +509,58 @@ async def __anext__(self):
return Lister(timeout=self.timeout)
+ @override
@universal_exception
- async def stat(self, path):
+ async def stat(self, path: pathlib.Path) -> os.stat_result:
return path.stat()
+ @override
@universal_exception
- async def _open(self, path, *args, **kwargs):
- return path.open(*args, **kwargs)
+ async def _open(self, path: pathlib.Path, mode: OpenMode) -> BinaryIO:
+ return path.open(mode)
+ @override
@universal_exception
@defend_file_methods
- async def seek(self, file, *args, **kwargs):
+ async def seek(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int:
return file.seek(*args, **kwargs)
+ @override
@universal_exception
@defend_file_methods
- async def write(self, file, *args, **kwargs):
+ async def write(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int:
return file.write(*args, **kwargs)
+ @override
@universal_exception
@defend_file_methods
- async def read(self, file, *args, **kwargs):
+ async def read(self, file: BinaryIO, *args: Any, **kwargs: Any) -> bytes:
return file.read(*args, **kwargs)
+ @override
@universal_exception
@defend_file_methods
- async def close(self, file):
+ async def close(self, file: BinaryIO) -> None:
return file.close()
+ @override
@universal_exception
- async def rename(self, source, destination):
+ async def rename(self, source: pathlib.Path, destination: pathlib.Path) -> pathlib.Path:
return source.rename(destination)
-def _blocking_io(f):
+def _blocking_io(
+ f: Callable[_Ps, _T],
+) -> Callable[_Ps, Coroutine[None, None, _T]]:
@functools.wraps(f)
- async def wrapper(self, *args, **kwargs):
+ async def wrapper(
+ *args: _Ps.args,
+ **kwargs: _Ps.kwargs,
+ ) -> _T:
+ self: "AsyncPathIO" = get_param((0, "self"), args, kwargs)
return await asyncio.get_running_loop().run_in_executor(
self.executor,
- functools.partial(f, self, *args, **kwargs),
+ functools.partial(f, *args, **kwargs),
)
return wrapper
@@ -482,130 +577,195 @@ class AsyncPathIO(AbstractPathIO):
:type executor: :py:class:`concurrent.futures.Executor`
"""
- def __init__(self, *args, executor=None, **kwargs):
+ executor: Optional[futures.Executor]
+
+ def __init__(
+ self,
+ *args: Any,
+ executor: Optional[futures.Executor] = None,
+ **kwargs: Any,
+ ) -> None:
super().__init__(*args, **kwargs)
self.executor = executor
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def exists(self, path):
+ def exists(self, path: pathlib.Path) -> bool:
return path.exists()
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def is_dir(self, path):
+ def is_dir(self, path: pathlib.Path) -> bool:
return path.is_dir()
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def is_file(self, path):
+ def is_file(self, path: pathlib.Path) -> bool:
return path.is_file()
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def mkdir(self, path, *, parents=False, exist_ok=False):
+ def mkdir(
+ self,
+ path: pathlib.Path,
+ *,
+ parents: bool = False,
+ exist_ok: bool = False,
+ ) -> None:
return path.mkdir(parents=parents, exist_ok=exist_ok)
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def rmdir(self, path):
+ def rmdir(self, path: pathlib.Path) -> None:
return path.rmdir()
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def unlink(self, path):
+ def unlink(self, path: pathlib.Path) -> None:
return path.unlink()
- def list(self, path):
- class Lister(AbstractAsyncLister):
- iter = None
+ @override
+ def list(self, path: pathlib.Path) -> AbstractAsyncLister[pathlib.Path]:
+ class Lister(AbstractAsyncLister[pathlib.Path]):
+ iter: Optional[SupportsNext[pathlib.Path]] = None
- def __init__(self, *args, executor=None, **kwargs):
+ def __init__(self, *args: Any, executor: Optional[futures.Executor] = None, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.executor = executor
- def worker(self):
+ def worker(self) -> pathlib.Path:
try:
- return next(self.iter)
+ return next(self.iter) # type: ignore
except StopIteration:
raise StopAsyncIteration
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def __anext__(self):
+ def __anext__(self) -> pathlib.Path:
if self.iter is None:
self.iter = path.glob("*")
return self.worker()
return Lister(timeout=self.timeout, executor=self.executor)
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def stat(self, path):
+ def stat(self, path: pathlib.Path) -> os.stat_result:
return path.stat()
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def _open(self, path, *args, **kwargs):
- return path.open(*args, **kwargs)
+ def _open(self, path: pathlib.Path, *args: Any, **kwargs: Any) -> BinaryIO:
+ return cast(BinaryIO, path.open(*args, **kwargs))
+ @override
@universal_exception
@defend_file_methods
@with_timeout
@_blocking_io
- def seek(self, file, *args, **kwargs):
+ def seek(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int:
return file.seek(*args, **kwargs)
+ @override
@universal_exception
@defend_file_methods
@with_timeout
@_blocking_io
- def write(self, file, *args, **kwargs):
+ def write(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int:
return file.write(*args, **kwargs)
+ @override
@universal_exception
@defend_file_methods
@with_timeout
@_blocking_io
- def read(self, file, *args, **kwargs):
+ def read(self, file: BinaryIO, *args: Any, **kwargs: Any) -> bytes:
return file.read(*args, **kwargs)
+ @override
@universal_exception
@defend_file_methods
@with_timeout
@_blocking_io
- def close(self, file):
+ def close(self, file: BinaryIO) -> None:
return file.close()
+ @override
@universal_exception
@with_timeout
@_blocking_io
- def rename(self, source, destination):
+ def rename(self, source: pathlib.Path, destination: pathlib.Path) -> pathlib.Path:
return source.rename(destination)
class Node:
- def __init__(self, type, name, ctime=None, mtime=None, *, content):
+ type: _NodeType
+ name: str
+ ctime: int
+ mtime: int
+ content: Union[List["Node"], io.BytesIO]
+
+ @overload
+ def __init__(
+ self,
+ type: Literal["file"],
+ name: str,
+ ctime: Optional[int] = None,
+ mtime: Optional[int] = None,
+ *,
+ content: io.BytesIO,
+ ) -> None: ...
+
+ @overload
+ def __init__(
+ self,
+ type: Literal["dir"],
+ name: str,
+ ctime: Optional[int] = None,
+ mtime: Optional[int] = None,
+ *,
+ content: List["Node"],
+ ) -> None: ...
+
+ def __init__(
+ self,
+ type: _NodeType,
+ name: str,
+ ctime: Optional[int] = None,
+ mtime: Optional[int] = None,
+ *,
+ content: Union[List["Node"], io.BytesIO],
+ ) -> None:
self.type = type
self.name = name
self.ctime = ctime or int(time.time())
self.mtime = mtime or int(time.time())
self.content = content
- def __repr__(self):
+ @override
+ def __repr__(self) -> str:
return (
f"{self.__class__.__name__}(type={self.type!r}, "
f"name={self.name!r}, ctime={self.ctime!r}, "
- f"mtime={self.mtime!r}, content={self.content!r})"
+ f"mtime={self.mtime!r}, content={self.content!r}"
)
@@ -615,132 +775,159 @@ class MemoryPathIO(AbstractPathIO):
and probably not so fast as it can be.
"""
- Stats = collections.namedtuple(
- "Stats",
- (
- "st_size",
- "st_ctime",
- "st_mtime",
- "st_nlink",
- "st_mode",
- ),
- )
-
- def __init__(self, *args, state=None, cwd=None, **kwargs):
+ class Stats(NamedTuple):
+ st_size: int
+ st_ctime: int
+ st_mtime: int
+ st_nlink: int
+ st_mode: int
+
+ fs: Union[List[_NodeProtocol], io.BytesIO]
+
+ def __init__(
+ self,
+ *args: Any,
+ state: Optional[Union[List[_NodeProtocol], io.BytesIO]] = None,
+ cwd: Optional[str] = None,
+ **kwargs: Any,
+ ) -> None:
super().__init__(*args, **kwargs)
self.cwd = pathlib.PurePosixPath(cwd or "/")
if state is None:
- self.fs = [Node("dir", "/", content=[])]
+ self.fs = [cast(_DirNodeProtocol, Node("dir", "/", content=[]))]
else:
self.fs = state
@property
- def state(self):
+ @override
+ def state(self) -> Union[List[_NodeProtocol], io.BytesIO]:
return self.fs
- def __repr__(self):
+ @override
+ def __repr__(self) -> str:
return repr(self.fs)
- def _absolute(self, path):
+ def _absolute(self, path: pathlib.PurePath) -> pathlib.PurePath:
if not path.is_absolute():
path = self.cwd / path
return path
- def get_node(self, path):
- nodes = self.fs
+ def get_node(self, path: pathlib.PurePath) -> Optional[_NodeProtocol]:
+ nodes: Union[List[_NodeProtocol], io.BytesIO] = self.fs
node = None
- path = self._absolute(path)
- for part in path.parts:
+ path_ = self._absolute(path)
+ for part in path_.parts:
if not isinstance(nodes, list):
- return
+ return None
for node in nodes:
if node.name == part:
nodes = node.content
break
else:
- return
+ return None
return node
+ @override
@universal_exception
- async def exists(self, path):
+ async def exists(self, path: pathlib.Path) -> bool:
return self.get_node(path) is not None
+ @override
@universal_exception
- async def is_dir(self, path):
+ async def is_dir(self, path: pathlib.Path) -> bool:
node = self.get_node(path)
return not (node is None or node.type != "dir")
+ @override
@universal_exception
- async def is_file(self, path):
+ async def is_file(self, path: pathlib.Path) -> bool:
node = self.get_node(path)
return not (node is None or node.type != "file")
+ @override
@universal_exception
- async def mkdir(self, path, *, parents=False, exist_ok=False):
- path = self._absolute(path)
- node = self.get_node(path)
+ async def mkdir(
+ self,
+ path: pathlib.Path,
+ *,
+ parents: bool = False,
+ exist_ok: bool = False,
+ ) -> None:
+ path_ = self._absolute(path)
+ node = self.get_node(path_)
if node:
if node.type != "dir" or not exist_ok:
raise FileExistsError
elif not parents:
- parent = self.get_node(path.parent)
+ parent = self.get_node(path_.parent)
if parent is None:
raise FileNotFoundError
if parent.type != "dir":
raise NotADirectoryError
- node = Node("dir", path.name, content=[])
+ node = cast(_DirNodeProtocol, Node("dir", path.name, content=[]))
parent.content.append(node)
else:
- nodes = self.fs
- for part in path.parts:
+ nodes: Union[List[_NodeProtocol], io.BytesIO] = self.fs
+ for part in path_.parts:
if isinstance(nodes, list):
for node in nodes:
if node.name == part:
nodes = node.content
break
else:
- node = Node("dir", part, content=[])
+ node = cast(_DirNodeProtocol, Node("dir", part, content=[]))
nodes.append(node)
nodes = node.content
else:
raise NotADirectoryError
+ @override
@universal_exception
- async def rmdir(self, path):
+ async def rmdir(self, path: pathlib.Path) -> None:
node = self.get_node(path)
if node is None:
raise FileNotFoundError
if node.type != "dir":
raise NotADirectoryError
+
if node.content:
raise OSError("Directory not empty")
- parent = self.get_node(path.parent)
+ parent = cast(_DirNodeProtocol, self.get_node(path.parent))
for i, node in enumerate(parent.content):
if node.name == path.name:
break
+ else:
+ return None
+
parent.content.pop(i)
+ @override
@universal_exception
- async def unlink(self, path):
+ async def unlink(self, path: pathlib.Path) -> None:
node = self.get_node(path)
if node is None:
raise FileNotFoundError
if node.type != "file":
raise IsADirectoryError
- parent = self.get_node(path.parent)
+ parent = cast(_DirNodeProtocol, self.get_node(path.parent))
for i, node in enumerate(parent.content):
if node.name == path.name:
break
+ else:
+ return None
+
parent.content.pop(i)
- def list(self, path):
- class Lister(AbstractAsyncLister):
- iter = None
+ @override
+ def list(self, path: pathlib.Path) -> "AbstractAsyncLister[pathlib.Path]":
+ class Lister(AbstractAsyncLister[pathlib.Path]):
+ iter: Optional[SupportsNext[pathlib.Path]] = None
+ @override
@universal_exception
- async def __anext__(cls):
+ async def __anext__(cls) -> pathlib.Path:
if cls.iter is None:
node = self.get_node(path)
if node is None or node.type != "dir":
@@ -756,8 +943,9 @@ async def __anext__(cls):
return Lister(timeout=self.timeout)
+ @override
@universal_exception
- async def stat(self, path):
+ async def stat(self, path: pathlib.Path) -> "MemoryPathIO.Stats":
node = self.get_node(path)
if node is None:
raise FileNotFoundError
@@ -776,21 +964,22 @@ async def stat(self, path):
mode,
)
+ @override
@universal_exception
- async def _open(self, path, mode="rb", *args, **kwargs):
+ async def _open(self, path: pathlib.Path, mode: OpenMode = "rb", *args: Any, **kwargs: Any) -> io.BytesIO:
if mode == "rb":
- node = self.get_node(path)
+ node = cast(Optional[_FileNodeProtocol], self.get_node(path))
if node is None:
raise FileNotFoundError
file_like = node.content
file_like.seek(0, io.SEEK_SET)
elif mode in ("wb", "ab", "r+b"):
- node = self.get_node(path)
+ node = cast(Optional[_FileNodeProtocol], self.get_node(path))
if node is None:
parent = self.get_node(path.parent)
if parent is None or parent.type != "dir":
raise FileNotFoundError
- new_node = Node("file", path.name, content=io.BytesIO())
+ new_node = cast(_FileNodeProtocol, Node("file", path.name, content=io.BytesIO()))
parent.content.append(new_node)
file_like = new_node.content
elif node.type != "file":
@@ -806,36 +995,42 @@ async def _open(self, path, mode="rb", *args, **kwargs):
file_like.seek(0, io.SEEK_SET)
else:
raise ValueError(f"invalid mode: {mode}")
- return file_like
+ return file_like # type: ignore
+ @override
@universal_exception
@defend_file_methods
- async def seek(self, file, *args, **kwargs):
+ async def seek(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int:
return file.seek(*args, **kwargs)
+ @override
@universal_exception
@defend_file_methods
- async def write(self, file, *args, **kwargs):
- file.write(*args, **kwargs)
- file.mtime = int(time.time())
+ async def write(self, file: BinaryIO, *args: Any, **kwargs: Any) -> int:
+ result = file.write(*args, **kwargs)
+ file.mtime = int(time.time()) # type: ignore
+ return result
+ @override
@universal_exception
@defend_file_methods
- async def read(self, file, *args, **kwargs):
+ async def read(self, file: BinaryIO, *args: Any, **kwargs: Any) -> bytes:
return file.read(*args, **kwargs)
+ @override
@universal_exception
@defend_file_methods
- async def close(self, file):
+ async def close(self, file: BinaryIO) -> None:
pass
+ @override
@universal_exception
- async def rename(self, source, destination):
+ async def rename(self, source: pathlib.Path, destination: pathlib.Path) -> Optional[pathlib.Path]:
if source != destination:
- sparent = self.get_node(source.parent)
- dparent = self.get_node(destination.parent)
+ sparent = cast(_DirNodeProtocol, self.get_node(source.parent))
+ dparent = cast(Optional[_DirNodeProtocol], self.get_node(destination.parent))
snode = self.get_node(source)
- if None in (snode, dparent):
+ if snode is None or dparent is None:
raise FileNotFoundError
for i, node in enumerate(sparent.content):
if node.name == source.name:
@@ -847,3 +1042,4 @@ async def rename(self, source, destination):
break
else:
dparent.content.append(snode)
+ return None
diff --git a/src/aioftp/server.py b/src/aioftp/server.py
index 1830d78..a16f421 100644
--- a/src/aioftp/server.py
+++ b/src/aioftp/server.py
@@ -1,6 +1,5 @@
import abc
import asyncio
-import collections
import enum
import errno
import functools
@@ -8,19 +7,46 @@
import pathlib
import socket
import stat
-import sys
import time
+from ssl import SSLContext
+from typing import (
+ Any,
+ Callable,
+ ClassVar,
+ Coroutine,
+ DefaultDict,
+ Dict,
+ Generic,
+ Iterable,
+ List,
+ NamedTuple,
+ Optional,
+ Protocol,
+ Sequence,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ cast,
+ final,
+)
+
+from typing_extensions import Literal, NotRequired, ParamSpec, Self, TypeAlias, TypedDict, override
-from . import errors, pathio
+from . import client, errors, pathio
from .common import (
DEFAULT_BLOCK_SIZE,
END_OF_LINE,
HALF_OF_YEAR_IN_SECONDS,
+ StreamIO,
StreamThrottle,
ThrottleStreamIO,
setlocale,
wrap_with_container,
)
+from .types import OpenMode, StatsProtocol
+from .utils import get_param
__all__ = (
"Permission",
@@ -36,14 +62,69 @@
"Server",
)
-IS_PY37_PLUS = sys.version_info[:2] >= (3, 7)
-if IS_PY37_PLUS:
- get_current_task = asyncio.current_task
-else:
- get_current_task = asyncio.Task.current_task
logger = logging.getLogger(__name__)
+_PS = ParamSpec("_PS")
+_T = TypeVar("_T")
+
+_PathType: TypeAlias = Union[str, pathlib.PurePosixPath]
+
+
+_FAIL_CODE: TypeAlias = Literal["503", "425"]
+
+
+class _ConnectionCondition(NamedTuple):
+ name: str
+ message: str
+
+
+_Future: TypeAlias = "asyncio.Future[Any]"
+
+
+class ConnectionProtocol(Protocol):
+ extra_workers: Set["asyncio.Task[Any]"]
+ client_host: str
+ server_host: str
+ passive_server_port: int
+ server_port: int
+ client_port: int
+ command_connection: ThrottleStreamIO
+ data_connection: ThrottleStreamIO
+ socket_timeout: int
+ wait_future_timeout: int
+ block_size: int
+ path_io_factory: pathio.AbstractPathIO
+ path_timeout: Optional[Union[int, float]]
+ user: "User"
+ response: Callable[..., None]
+ acquired: bool
+ restart_offset: int
+ current_directory: pathlib.PurePosixPath
+ _dispatcher: "asyncio.Task[Any]"
+ path_io: pathio.AbstractPathIO
+ passive_server: asyncio.base_events.Server
+ rename_from: pathlib.Path
+ logged: object
+ transfer_type: str
+
+ future: "Connection.Container[Any]"
+
+ def __getitem__(self, name: str) -> _Future: ...
+
+
+class _PathCondition(NamedTuple):
+ name: str
+ fail: bool
+ message: str
+
+
+class _MLSXFacts(TypedDict):
+ Size: int
+ Create: str
+ Modify: str
+ Type: NotRequired[Literal["dir", "file", "unknown"]]
+
class Permission:
"""
@@ -59,19 +140,26 @@ class Permission:
:type writable: :py:class:`bool`
"""
- def __init__(self, path="/", *, readable=True, writable=True):
+ def __init__(
+ self,
+ path: Union[str, pathlib.PurePosixPath] = "/",
+ *,
+ readable: bool = True,
+ writable: bool = True,
+ ) -> None:
self.path = pathlib.PurePosixPath(path)
self.readable = readable
self.writable = writable
- def is_parent(self, other):
+ def is_parent(self, other: pathlib.PurePosixPath) -> bool:
try:
other.relative_to(self.path)
return True
except ValueError:
return False
- def __repr__(self):
+ @override
+ def __repr__(self) -> str:
return f"{self.__class__.__name__}({self.path!r}, " f"readable={self.readable!r}, writable={self.writable!r})"
@@ -116,18 +204,18 @@ class User:
def __init__(
self,
- login=None,
- password=None,
+ login: Optional[str] = None,
+ password: Optional[str] = None,
*,
- base_path=pathlib.Path("."),
- home_path=pathlib.PurePosixPath("/"),
- permissions=None,
- maximum_connections=None,
- read_speed_limit=None,
- write_speed_limit=None,
- read_speed_limit_per_connection=None,
- write_speed_limit_per_connection=None,
- ):
+ base_path: Union[str, pathlib.Path] = pathlib.Path("."),
+ home_path: Union[str, pathlib.PurePosixPath] = pathlib.PurePosixPath("/"),
+ permissions: Optional[Iterable[Permission]] = None,
+ maximum_connections: Optional[int] = None,
+ read_speed_limit: Optional[int] = None,
+ write_speed_limit: Optional[int] = None,
+ read_speed_limit_per_connection: Optional[int] = None,
+ write_speed_limit_per_connection: Optional[int] = None,
+ ) -> None:
self.login = login
self.password = password
self.base_path = pathlib.Path(base_path)
@@ -142,7 +230,7 @@ def __init__(
# damn 80 symbols
self.write_speed_limit_per_connection = write_speed_limit_per_connection
- async def get_permissions(self, path):
+ async def get_permissions(self, path: Union[str, pathlib.PurePosixPath]) -> Permission:
"""
Return nearest parent permission for `path`.
@@ -151,16 +239,17 @@ async def get_permissions(self, path):
:rtype: :py:class:`aioftp.Permission`
"""
- path = pathlib.PurePosixPath(path)
- parents = filter(lambda p: p.is_parent(path), self.permissions)
+ path_ = pathlib.PurePosixPath(path)
+ parents = filter(lambda p: p.is_parent(path_), self.permissions)
perm = min(
parents,
- key=lambda p: len(path.relative_to(p.path).parts),
+ key=lambda p: len(path_.relative_to(p.path).parts),
default=Permission(),
)
return perm
- def __repr__(self):
+ @override
+ def __repr__(self) -> str:
return (
f"{self.__class__.__name__}({self.login!r}, "
f"{self.password!r}, base_path={self.base_path!r}, "
@@ -184,16 +273,16 @@ class AbstractUserManager(abc.ABC):
:type timeout: :py:class:`float`, :py:class:`int` or :py:class:`None`
"""
- GetUserResponse = enum.Enum(
- "UserManagerResponse",
- "OK PASSWORD_REQUIRED ERROR",
- )
+ class GetUserResponse(enum.Enum):
+ OK = enum.auto()
+ PASSWORD_REQUIRED = enum.auto()
+ ERROR = enum.auto()
- def __init__(self, *, timeout=None):
+ def __init__(self, *, timeout: Optional[Union[float, int]] = None) -> None:
self.timeout = timeout
@abc.abstractmethod
- async def get_user(self, login):
+ async def get_user(self, login: str) -> Tuple[GetUserResponse, Optional[User], str]:
"""
:py:func:`asyncio.coroutine`
@@ -204,7 +293,7 @@ async def get_user(self, login):
"""
@abc.abstractmethod
- async def authenticate(self, user, password):
+ async def authenticate(self, user: User, password: str) -> bool:
"""
:py:func:`asyncio.coroutine`
@@ -219,7 +308,7 @@ async def authenticate(self, user, password):
:rtype: :py:class:`bool`
"""
- async def notify_logout(self, user):
+ async def notify_logout(self, user: User) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -239,12 +328,13 @@ class MemoryUserManager(AbstractUserManager):
:py:class:`aioftp.User`
"""
- def __init__(self, users, *args, **kwargs):
+ def __init__(self, users: Optional[Sequence[User]], *args: Any, **kwargs: Any) -> None:
super().__init__(*args, **kwargs)
self.users = users or [User()]
self.available_connections = dict((user, AvailableConnections(user.maximum_connections)) for user in self.users)
- async def get_user(self, login):
+ @override
+ async def get_user(self, login: str) -> Tuple[AbstractUserManager.GetUserResponse, Optional[User], str]:
user = None
for u in self.users:
if u.login is None and user is None:
@@ -269,17 +359,19 @@ async def get_user(self, login):
info = "password required"
if state != AbstractUserManager.GetUserResponse.ERROR:
- self.available_connections[user].acquire()
+ self.available_connections[user].acquire() # type: ignore
return state, user, info
- async def authenticate(self, user, password):
+ @override
+ async def authenticate(self, user: User, password: str) -> bool:
return user.password == password
- async def notify_logout(self, user):
+ @override
+ async def notify_logout(self, user: User) -> None:
self.available_connections[user].release()
-class Connection(collections.defaultdict):
+class Connection(DefaultDict[str, _Future]):
"""
Connection state container for transparent work with futures for async
wait
@@ -311,37 +403,40 @@ class Connection(collections.defaultdict):
__slots__ = ("future",)
- class Container:
- def __init__(self, storage):
+ class Container(Generic[_T]):
+ def __init__(self, storage: Dict[str, "asyncio.Future[_T]"]):
self.storage = storage
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> "asyncio.Future[_T]":
return self.storage[name]
- def __delattr__(self, name):
+ @override
+ def __delattr__(self, name: str) -> None:
self.storage.pop(name)
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any) -> None:
super().__init__(asyncio.Future)
- self.future = Connection.Container(self)
+ self.future = Connection.Container[Any](self)
for k, v in kwargs.items():
self[k].set_result(v)
- def __getattr__(self, name):
+ def __getattr__(self, name: str) -> Any:
if name in self:
return self[name].result()
else:
raise AttributeError(f"{name!r} not in storage")
- def __setattr__(self, name, value):
+ @override
+ def __setattr__(self, name: str, value: Any) -> None:
if name in Connection.__slots__:
super().__setattr__(name, value)
else:
if self[name].done():
- self[name] = super().default_factory()
+ self[name] = cast(Callable[[], _Future], super().default_factory)()
self[name].set_result(value)
- def __delattr__(self, name):
+ @override
+ def __delattr__(self, name: str) -> None:
if name in self:
self.pop(name)
@@ -355,10 +450,10 @@ class AvailableConnections:
:type value: :py:class:`int` or :py:class:`None`
"""
- def __init__(self, value=None):
+ def __init__(self, value: Optional[int] = None):
self.value = self.maximum_value = value
- def locked(self):
+ def locked(self) -> bool:
"""
Returns True if semaphore-like can not be acquired.
@@ -366,7 +461,7 @@ def locked(self):
"""
return self.value == 0
- def acquire(self):
+ def acquire(self) -> None:
"""
Acquire, decrementing the internal counter by one.
"""
@@ -375,13 +470,13 @@ def acquire(self):
if self.value < 0:
raise ValueError("Too many acquires")
- def release(self):
+ def release(self) -> None:
"""
Release, incrementing the internal counter by one.
"""
if self.value is not None:
self.value += 1
- if self.value > self.maximum_value:
+ if self.value > cast(int, self.maximum_value): # self.maximum_value exist if exist self.value
raise ValueError("Too many releases")
@@ -424,27 +519,59 @@ class ConnectionConditions:
... ...
"""
- user_required = ("user", "no user (use USER firstly)")
- login_required = ("logged", "not logged in")
- passive_server_started = (
+ user_required: ClassVar[_ConnectionCondition] = _ConnectionCondition("user", "no user (use USER firstly)")
+ login_required: ClassVar[_ConnectionCondition] = _ConnectionCondition("logged", "not logged in")
+ passive_server_started: ClassVar[_ConnectionCondition] = _ConnectionCondition(
"passive_server",
"no listen socket created (use PASV firstly)",
)
- data_connection_made = ("data_connection", "no data connection made")
- rename_from_required = ("rename_from", "no filename (use RNFR firstly)")
+ data_connection_made: ClassVar[_ConnectionCondition] = _ConnectionCondition(
+ "data_connection",
+ "no data connection made",
+ )
+ rename_from_required: ClassVar[_ConnectionCondition] = _ConnectionCondition(
+ "rename_from",
+ "no filename (use RNFR firstly)",
+ )
- def __init__(self, *fields, wait=False, fail_code="503", fail_info=None):
+ def __init__(
+ self,
+ *fields: _ConnectionCondition,
+ wait: bool = False,
+ fail_code: _FAIL_CODE = "503",
+ fail_info: Optional[str] = None,
+ ) -> None:
self.fields = fields
self.wait = wait
self.fail_code = fail_code
self.fail_info = fail_info
- def __call__(self, f):
+ def __call__(
+ self,
+ f: Callable[_PS, Coroutine[None, None, _T]],
+ ) -> Callable[
+ _PS,
+ Coroutine[
+ None,
+ None,
+ Union[
+ _T,
+ Literal[True],
+ ],
+ ],
+ ]:
+ ctx = self
+
@functools.wraps(f)
- async def wrapper(cls, connection, rest, *args):
- futures = {connection[name]: msg for name, msg in self.fields}
+ async def wrapper(
+ *args: _PS.args,
+ **kwargs: _PS.kwargs,
+ ) -> Union[_T, Literal[True]]:
+ connection: ConnectionProtocol = get_param((1, "connection"), args, kwargs)
+
+ futures: Dict[_Future, str] = {connection[name]: msg for name, msg in ctx.fields}
aggregate = asyncio.gather(*futures)
- if self.wait:
+ if ctx.wait:
timeout = connection.wait_future_timeout
else:
timeout = 0
@@ -457,17 +584,18 @@ async def wrapper(cls, connection, rest, *args):
except asyncio.TimeoutError:
for future, message in futures.items():
if not future.done():
- if self.fail_info is None:
+ if ctx.fail_info is None:
info = f"bad sequence of commands ({message})"
else:
- info = self.fail_info
- connection.response(self.fail_code, info)
+ info = ctx.fail_info
+ connection.response(ctx.fail_code, info)
return True
- return await f(cls, connection, rest, *args)
+ return await f(*args, **kwargs)
return wrapper
+@final
class PathConditions:
"""
Decorator for checking paths. Available options:
@@ -486,28 +614,51 @@ class PathConditions:
... ...
"""
- path_must_exists = ("exists", False, "path does not exists")
- path_must_not_exists = ("exists", True, "path already exists")
- path_must_be_dir = ("is_dir", False, "path is not a directory")
- path_must_be_file = ("is_file", False, "path is not a file")
+ path_must_exists: ClassVar[_PathCondition] = _PathCondition("exists", False, "path does not exists")
+ path_must_not_exists: ClassVar[_PathCondition] = _PathCondition("exists", True, "path already exists")
+ path_must_be_dir: ClassVar[_PathCondition] = _PathCondition("is_dir", False, "path is not a directory")
+ path_must_be_file: ClassVar[_PathCondition] = _PathCondition("is_file", False, "path is not a file")
- def __init__(self, *conditions):
+ def __init__(self, *conditions: _PathCondition) -> None:
self.conditions = conditions
- def __call__(self, f):
+ def __call__(
+ self,
+ f: Callable[_PS, Coroutine[None, None, _T]],
+ ) -> Callable[
+ _PS,
+ Coroutine[
+ None,
+ None,
+ Union[
+ _T,
+ Literal[True],
+ ],
+ ],
+ ]:
+ ctx = self
+
@functools.wraps(f)
- async def wrapper(cls, connection, rest, *args):
- real_path, virtual_path = cls.get_paths(connection, rest)
- for name, fail, message in self.conditions:
+ async def wrapper(
+ *args: _PS.args,
+ **kwargs: _PS.kwargs,
+ ) -> Union[_T, Literal[True]]:
+ self: Server = get_param((0, "self"), args, kwargs)
+ connection: ConnectionProtocol = get_param((1, "connection"), args, kwargs)
+ rest: pathlib.PurePosixPath = get_param((2, "rest"), args, kwargs)
+
+ real_path, _ = self.get_paths(connection, rest)
+ for name, fail, message in ctx.conditions:
coro = getattr(connection.path_io, name)
if await coro(real_path) == fail:
connection.response("550", message)
return True
- return await f(cls, connection, rest, *args)
+ return await f(*args, **kwargs)
return wrapper
+@final
class PathPermissions:
"""
Decorator for checking path permissions. There is two permissions right
@@ -528,29 +679,52 @@ class PathPermissions:
... ...
"""
- readable = "readable"
- writable = "writable"
+ readable: ClassVar[str] = "readable"
+ writable: ClassVar[str] = "writable"
- def __init__(self, *permissions):
+ def __init__(self, *permissions: str) -> None:
self.permissions = permissions
- def __call__(self, f):
+ def __call__(
+ self,
+ f: Callable[_PS, Coroutine[None, None, _T]],
+ ) -> Callable[
+ _PS,
+ Coroutine[
+ None,
+ None,
+ Optional[
+ Union[
+ _T,
+ Literal[True],
+ ]
+ ],
+ ],
+ ]:
+ ctx = self
+
@functools.wraps(f)
- async def wrapper(cls, connection, rest, *args):
- real_path, virtual_path = cls.get_paths(connection, rest)
- current_permission = await connection.user.get_permissions(
- virtual_path,
- )
- for permission in self.permissions:
+ async def wrapper(
+ *args: _PS.args,
+ **kwargs: _PS.kwargs,
+ ) -> Optional[Union[_T, Literal[True]]]:
+ self: Server = get_param((0, "self"), args, kwargs)
+ connection: ConnectionProtocol = get_param((1, "connection"), args, kwargs)
+ rest: pathlib.PurePosixPath = get_param((2, "rest"), args, kwargs)
+
+ _, virtual_path = self.get_paths(connection, rest)
+ current_permission = await connection.user.get_permissions(virtual_path)
+ for permission in ctx.permissions:
if not getattr(current_permission, permission):
connection.response("550", "permission denied")
return True
- return await f(cls, connection, rest, *args)
+ return await f(*args, **kwargs)
+ return None
return wrapper
-def worker(f):
+def worker(f: Callable[_PS, Coroutine[None, None, None]]) -> Callable[_PS, Coroutine[None, None, None]]:
"""
Decorator. Abortable worker. If wrapped task will be cancelled by
dispatcher, decorator will send ftp codes of successful interrupt.
@@ -564,12 +738,15 @@ def worker(f):
"""
@functools.wraps(f)
- async def wrapper(cls, connection, rest):
+ async def wrapper(*args: _PS.args, **kwargs: _PS.kwargs) -> None:
+ connection: ConnectionProtocol = get_param((1, "connection"), args, kwargs)
+
try:
- await f(cls, connection, rest)
+ await f(*args, **kwargs)
except asyncio.CancelledError:
connection.response("426", "transfer aborted")
connection.response("226", "abort successful")
+ return None
return wrapper
@@ -641,25 +818,36 @@ class Server:
:type ssl: :py:class:`ssl.SSLContext`
"""
+ available_data_ports: Optional["asyncio.PriorityQueue[Tuple[int, int]]"]
+ throttle_per_user: Dict[User, StreamThrottle]
+ commands_mapping: Dict[str, Callable[..., Any]]
+ connections: Dict[StreamIO, ConnectionProtocol]
+
def __init__(
self,
- users=None,
+ users: Optional[
+ Union[
+ Tuple[User],
+ List[User],
+ AbstractUserManager,
+ ]
+ ] = None,
*,
- block_size=DEFAULT_BLOCK_SIZE,
- socket_timeout=None,
- idle_timeout=None,
- wait_future_timeout=1,
- path_timeout=None,
- path_io_factory=pathio.PathIO,
- maximum_connections=None,
- read_speed_limit=None,
- write_speed_limit=None,
- read_speed_limit_per_connection=None,
- write_speed_limit_per_connection=None,
- ipv4_pasv_forced_response_address=None,
- data_ports=None,
- encoding="utf-8",
- ssl=None,
+ block_size: int = DEFAULT_BLOCK_SIZE,
+ socket_timeout: Optional[Union[int, float]] = None,
+ idle_timeout: Optional[Union[int, float]] = None,
+ wait_future_timeout: Optional[Union[int, float]] = 1,
+ path_timeout: Optional[Union[int, float]] = None,
+ path_io_factory: Type[pathio.AbstractPathIO] = pathio.PathIO,
+ maximum_connections: Optional[int] = None,
+ read_speed_limit: Optional[int] = None,
+ write_speed_limit: Optional[int] = None,
+ read_speed_limit_per_connection: Optional[int] = None,
+ write_speed_limit_per_connection: Optional[int] = None,
+ ipv4_pasv_forced_response_address: Optional[str] = None,
+ data_ports: Optional[Iterable[int]] = None,
+ encoding: str = "utf-8",
+ ssl: Optional[SSLContext] = None,
):
self.block_size = block_size
self.socket_timeout = socket_timeout
@@ -720,7 +908,7 @@ def __init__(
"user": self.user,
}
- async def start(self, host=None, port=0, **kwargs):
+ async def start(self, host: Optional[str] = None, port: int = 0, **kwargs: Any) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -755,7 +943,7 @@ async def start(self, host=None, port=0, **kwargs):
self.server_host = host
logger.info("serving on %s:%s", host, port)
- async def serve_forever(self):
+ async def serve_forever(self) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -763,7 +951,7 @@ async def serve_forever(self):
"""
return await self.server.serve_forever()
- async def run(self, host=None, port=0, **kwargs):
+ async def run(self, host: Optional[str] = None, port: int = 0, **kwargs: Any) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -785,13 +973,13 @@ async def run(self, host=None, port=0, **kwargs):
await self.close()
@property
- def address(self):
+ def address(self) -> Tuple[Optional[str], int]:
"""
Server listen socket host and port as :py:class:`tuple`
"""
return self.server_host, self.server_port
- async def close(self):
+ async def close(self) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -800,16 +988,22 @@ async def close(self):
self.server.close()
tasks = [asyncio.create_task(self.server.wait_closed())]
for connection in self.connections.values():
- connection._dispatcher.cancel()
- tasks.append(connection._dispatcher)
+ connection._dispatcher.cancel() # pyright: ignore[reportPrivateUsage]
+ tasks.append(connection._dispatcher) # pyright: ignore[reportPrivateUsage]
logger.debug("waiting for %d tasks", len(tasks))
await asyncio.wait(tasks)
- async def write_line(self, stream, line):
+ async def write_line(self, stream: StreamIO, line: str) -> None:
logger.debug(line)
await stream.write((line + END_OF_LINE).encode(encoding=self.encoding))
- async def write_response(self, stream, code, lines="", list=False):
+ async def write_response(
+ self,
+ stream: StreamIO,
+ code: client.Code,
+ lines: Iterable[str] = "",
+ list: bool = False,
+ ) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -842,7 +1036,11 @@ async def write_response(self, stream, code, lines="", list=False):
await write(code + "-" + line)
await write(code + " " + tail)
- async def parse_command(self, stream, censor_commands=("pass",)):
+ async def parse_command(
+ self,
+ stream: StreamIO,
+ censor_commands: Tuple[str] = ("pass",),
+ ) -> Tuple[str, str]:
"""
:py:func:`asyncio.coroutine`
@@ -871,7 +1069,7 @@ async def parse_command(self, stream, censor_commands=("pass",)):
return cmd.lower(), rest
- async def response_writer(self, stream, response_queue):
+ async def response_writer(self, stream: StreamIO, response_queue: "asyncio.Queue[Any]") -> None:
"""
:py:func:`asyncio.coroutine`
@@ -892,7 +1090,7 @@ async def response_writer(self, stream, response_queue):
finally:
response_queue.task_done()
- async def dispatcher(self, reader, writer):
+ async def dispatcher(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
"""
:py:func:`asyncio.coroutine`
@@ -911,7 +1109,7 @@ async def dispatcher(self, reader, writer):
read_timeout=self.idle_timeout,
write_timeout=self.socket_timeout,
)
- response_queue = asyncio.Queue()
+ response_queue: "asyncio.Queue[Any]" = asyncio.Queue()
connection = Connection(
client_host=host,
client_port=port,
@@ -926,21 +1124,22 @@ async def dispatcher(self, reader, writer):
path_io_factory=self.path_io_factory,
path_timeout=self.path_timeout,
extra_workers=set(),
- response=lambda *args: response_queue.put_nowait(args),
+ response=lambda *args: response_queue.put_nowait(args), # pyright: ignore[reportUnknownLambdaType]
acquired=False,
restart_offset=0,
- _dispatcher=get_current_task(),
+ _dispatcher=asyncio.current_task(),
)
connection.path_io = self.path_io_factory(
timeout=self.path_timeout,
connection=connection,
)
+ connection_: ConnectionProtocol = cast(Any, connection)
pending = {
- asyncio.create_task(self.greeting(connection, "")),
+ asyncio.create_task(self.greeting(connection_, "")),
asyncio.create_task(self.response_writer(stream, response_queue)),
asyncio.create_task(self.parse_command(stream)),
}
- self.connections[key] = connection
+ self.connections[key] = connection_
try:
while True:
done, pending = await asyncio.wait(
@@ -949,6 +1148,7 @@ async def dispatcher(self, reader, writer):
)
connection.extra_workers -= done
for task in done:
+ result: Optional[Union[bool, Tuple[str, str]]] = None
try:
result = task.result()
except errors.PathIOError:
@@ -981,7 +1181,7 @@ async def dispatcher(self, reader, writer):
logger.exception("dispatcher caught exception")
finally:
logger.info("closing connection from %s:%s", host, port)
- tasks_to_wait = []
+ tasks_to_wait: List[asyncio.Task[Any]] = []
if not asyncio.get_running_loop().is_closed():
for task in pending | connection.extra_workers:
task.cancel()
@@ -1006,7 +1206,7 @@ async def dispatcher(self, reader, writer):
await asyncio.wait(tasks_to_wait)
@staticmethod
- def get_paths(connection, path):
+ def get_paths(connection: ConnectionProtocol, path: _PathType) -> Tuple[pathlib.Path, pathlib.PurePosixPath]:
"""
Return *real* and *virtual* paths, resolves ".." with "up" action.
*Real* path is path for path_io, when *virtual* deals with
@@ -1040,7 +1240,7 @@ def get_paths(connection, path):
resolved_virtual_path = pathlib.PurePosixPath("/")
return real_path, resolved_virtual_path
- async def greeting(self, connection, rest):
+ async def greeting(self, connection: ConnectionProtocol, rest: object) -> bool:
if self.available_connections.locked():
ok, code, info = False, "421", "Too many connections"
else:
@@ -1050,7 +1250,7 @@ async def greeting(self, connection, rest):
connection.response(code, info)
return ok
- async def user(self, connection, rest):
+ async def user(self, connection: ConnectionProtocol, rest: str) -> Literal[True]:
if connection.future.user.done():
await self.user_manager.notify_logout(connection.user)
del connection.user
@@ -1059,15 +1259,14 @@ async def user(self, connection, rest):
if state == AbstractUserManager.GetUserResponse.OK:
code = "230"
connection.logged = True
- connection.user = user
+ connection.user = cast(User, user)
elif state == AbstractUserManager.GetUserResponse.PASSWORD_REQUIRED:
code = "331"
- connection.user = user
+ connection.user = cast(User, user)
elif state == AbstractUserManager.GetUserResponse.ERROR:
code = "530"
else:
- message = f"Unknown response {state}"
- raise NotImplementedError(message)
+ raise NotImplementedError(f"Unknown response {state}")
if connection.future.user.done():
connection.current_directory = connection.user.home_path
@@ -1089,7 +1288,7 @@ async def user(self, connection, rest):
return True
@ConnectionConditions(ConnectionConditions.user_required)
- async def pass_(self, connection, rest):
+ async def pass_(self, connection: ConnectionProtocol, rest: str) -> Literal[True]:
if connection.future.logged.done():
code, info = "503", "already logged in"
elif await self.user_manager.authenticate(connection.user, rest):
@@ -1100,12 +1299,12 @@ async def pass_(self, connection, rest):
connection.response(code, info)
return True
- async def quit(self, connection, rest):
+ async def quit(self, connection: ConnectionProtocol, rest: Any) -> Literal[False]:
connection.response("221", "bye")
return False
@ConnectionConditions(ConnectionConditions.login_required)
- async def pwd(self, connection, rest):
+ async def pwd(self, connection: ConnectionProtocol, rest: Any) -> Literal[True]:
code, info = "257", f'"{connection.current_directory}"'
connection.response(code, info)
return True
@@ -1116,21 +1315,21 @@ async def pwd(self, connection, rest):
PathConditions.path_must_be_dir,
)
@PathPermissions(PathPermissions.readable)
- async def cwd(self, connection, rest):
- real_path, virtual_path = self.get_paths(connection, rest)
+ async def cwd(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Literal[True]:
+ _, virtual_path = self.get_paths(connection, rest)
connection.current_directory = virtual_path
connection.response("250", "")
return True
@ConnectionConditions(ConnectionConditions.login_required)
- async def cdup(self, connection, rest):
+ async def cdup(self, connection: ConnectionProtocol, rest: object) -> Optional[Literal[True]]:
return await self.cwd(connection, connection.current_directory.parent)
@ConnectionConditions(ConnectionConditions.login_required)
@PathConditions(PathConditions.path_must_not_exists)
@PathPermissions(PathPermissions.writable)
- async def mkd(self, connection, rest):
- real_path, virtual_path = self.get_paths(connection, rest)
+ async def mkd(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]:
+ real_path, _ = self.get_paths(connection, rest)
await connection.path_io.mkdir(real_path, parents=True)
connection.response("257", "")
return True
@@ -1141,24 +1340,25 @@ async def mkd(self, connection, rest):
PathConditions.path_must_be_dir,
)
@PathPermissions(PathPermissions.writable)
- async def rmd(self, connection, rest):
- real_path, virtual_path = self.get_paths(connection, rest)
+ async def rmd(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]:
+ real_path, _ = self.get_paths(connection, rest)
await connection.path_io.rmdir(real_path)
connection.response("250", "")
return True
@staticmethod
- def _format_mlsx_time(local_seconds):
+ def _format_mlsx_time(local_seconds: float) -> str:
return time.strftime("%Y%m%d%H%M%S", time.gmtime(local_seconds))
- def _build_mlsx_facts_from_stats(self, stats):
+ def _build_mlsx_facts_from_stats(self, stats: StatsProtocol) -> _MLSXFacts:
return {
"Size": stats.st_size,
"Create": self._format_mlsx_time(stats.st_ctime),
"Modify": self._format_mlsx_time(stats.st_mtime),
}
- async def build_mlsx_string(self, connection, path):
+ async def build_mlsx_string(self, connection: ConnectionProtocol, path: pathlib.Path) -> str:
+ facts: Union[Dict[Any, Any], _MLSXFacts]
if not await connection.path_io.exists(path):
facts = {}
else:
@@ -1172,7 +1372,7 @@ async def build_mlsx_string(self, connection, path):
facts["Type"] = "unknown"
s = ""
- for name, value in facts.items():
+ for name, value in cast(_MLSXFacts, facts).items():
s += f"{name}={value};"
s += " " + path.name
return s
@@ -1183,7 +1383,7 @@ async def build_mlsx_string(self, connection, path):
)
@PathConditions(PathConditions.path_must_exists)
@PathPermissions(PathPermissions.readable)
- async def mlsd(self, connection, rest):
+ async def mlsd(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]:
@ConnectionConditions(
ConnectionConditions.data_connection_made,
wait=True,
@@ -1191,7 +1391,7 @@ async def mlsd(self, connection, rest):
fail_info="Can't open data connection",
)
@worker
- async def mlsd_worker(self, connection, rest):
+ async def mlsd_worker(self: Self, connection: ConnectionProtocol, rest: _PathType) -> None:
stream = connection.data_connection
del connection.data_connection
async with stream:
@@ -1200,17 +1400,16 @@ async def mlsd_worker(self, connection, rest):
b = (s + END_OF_LINE).encode(encoding=self.encoding)
await stream.write(b)
connection.response("200", "mlsd transfer done")
- return True
- real_path, virtual_path = self.get_paths(connection, rest)
+ real_path, _ = self.get_paths(connection, rest)
coro = mlsd_worker(self, connection, rest)
- task = asyncio.create_task(coro)
+ task: asyncio.Task[Optional[Literal[True]]] = asyncio.create_task(coro)
connection.extra_workers.add(task)
connection.response("150", "mlsd transfer started")
return True
@staticmethod
- def build_list_mtime(st_mtime, now=None):
+ def build_list_mtime(st_mtime: float, now: Optional[float] = None) -> str:
if now is None:
now = time.time()
mtime = time.localtime(st_mtime)
@@ -1221,7 +1420,7 @@ def build_list_mtime(st_mtime, now=None):
s = time.strftime("%b %e %Y", mtime)
return s
- async def build_list_string(self, connection, path):
+ async def build_list_string(self, connection: ConnectionProtocol, path: pathlib.Path) -> str:
stats = await connection.path_io.stat(path)
mtime = self.build_list_mtime(stats.st_mtime)
fields = (
@@ -1242,7 +1441,7 @@ async def build_list_string(self, connection, path):
)
@PathConditions(PathConditions.path_must_exists)
@PathPermissions(PathPermissions.readable)
- async def list(self, connection, rest):
+ async def list(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Literal[True]:
@ConnectionConditions(
ConnectionConditions.data_connection_made,
wait=True,
@@ -1250,7 +1449,7 @@ async def list(self, connection, rest):
fail_info="Can't open data connection",
)
@worker
- async def list_worker(self, connection, rest):
+ async def list_worker(self: Self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> None:
stream = connection.data_connection
del connection.data_connection
async with stream:
@@ -1262,11 +1461,10 @@ async def list_worker(self, connection, rest):
b = (s + END_OF_LINE).encode(encoding=self.encoding)
await stream.write(b)
connection.response("226", "list transfer done")
- return True
- real_path, virtual_path = self.get_paths(connection, rest)
+ real_path, _ = self.get_paths(connection, rest)
coro = list_worker(self, connection, rest)
- task = asyncio.create_task(coro)
+ task: asyncio.Task[Optional[Literal[True]]] = asyncio.create_task(coro)
connection.extra_workers.add(task)
connection.response("150", "list transfer started")
return True
@@ -1274,8 +1472,8 @@ async def list_worker(self, connection, rest):
@ConnectionConditions(ConnectionConditions.login_required)
@PathConditions(PathConditions.path_must_exists)
@PathPermissions(PathPermissions.readable)
- async def mlst(self, connection, rest):
- real_path, virtual_path = self.get_paths(connection, rest)
+ async def mlst(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Literal[True]:
+ real_path, _ = self.get_paths(connection, rest)
s = await self.build_mlsx_string(connection, real_path)
connection.response("250", ["start", s, "end"], True)
return True
@@ -1283,8 +1481,8 @@ async def mlst(self, connection, rest):
@ConnectionConditions(ConnectionConditions.login_required)
@PathConditions(PathConditions.path_must_exists)
@PathPermissions(PathPermissions.writable)
- async def rnfr(self, connection, rest):
- real_path, virtual_path = self.get_paths(connection, rest)
+ async def rnfr(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]:
+ real_path, _ = self.get_paths(connection, rest)
connection.rename_from = real_path
connection.response("350", "rename from accepted")
return True
@@ -1295,8 +1493,8 @@ async def rnfr(self, connection, rest):
)
@PathConditions(PathConditions.path_must_not_exists)
@PathPermissions(PathPermissions.writable)
- async def rnto(self, connection, rest):
- real_path, virtual_path = self.get_paths(connection, rest)
+ async def rnto(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]:
+ real_path, _ = self.get_paths(connection, rest)
rename_from = connection.rename_from
del connection.rename_from
await connection.path_io.rename(rename_from, real_path)
@@ -1309,8 +1507,8 @@ async def rnto(self, connection, rest):
PathConditions.path_must_be_file,
)
@PathPermissions(PathPermissions.writable)
- async def dele(self, connection, rest):
- real_path, virtual_path = self.get_paths(connection, rest)
+ async def dele(self, connection: ConnectionProtocol, rest: _PathType) -> Literal[True]:
+ real_path, _ = self.get_paths(connection, rest)
await connection.path_io.unlink(real_path)
connection.response("250", "")
return True
@@ -1320,7 +1518,12 @@ async def dele(self, connection, rest):
ConnectionConditions.passive_server_started,
)
@PathPermissions(PathPermissions.writable)
- async def stor(self, connection, rest, mode="wb"):
+ async def stor(
+ self,
+ connection: ConnectionProtocol,
+ rest: _PathType,
+ mode: OpenMode = "wb",
+ ) -> Literal[True]:
@ConnectionConditions(
ConnectionConditions.data_connection_made,
wait=True,
@@ -1328,7 +1531,7 @@ async def stor(self, connection, rest, mode="wb"):
fail_info="Can't open data connection",
)
@worker
- async def stor_worker(self, connection, rest):
+ async def stor_worker(self: Self, connection: ConnectionProtocol, rest: object) -> None:
stream = connection.data_connection
del connection.data_connection
if connection.restart_offset:
@@ -1341,17 +1544,16 @@ async def stor_worker(self, connection, rest):
await file_out.seek(connection.restart_offset)
async for data in stream.iter_by_block(connection.block_size):
await file_out.write(data)
- connection.response("226", "data transfer done")
- return True
+ connection.response("226", ["data transfer done"])
- real_path, virtual_path = self.get_paths(connection, rest)
+ real_path, _ = self.get_paths(connection, rest)
if await connection.path_io.is_dir(real_path.parent):
coro = stor_worker(self, connection, rest)
task = asyncio.create_task(coro)
connection.extra_workers.add(task)
- code, info = "150", "data transfer started"
+ code, info = "150", ["data transfer started"]
else:
- code, info = "550", "path unreachable"
+ code, info = "550", ["path unreachable"]
connection.response(code, info)
return True
@@ -1364,7 +1566,7 @@ async def stor_worker(self, connection, rest):
PathConditions.path_must_be_file,
)
@PathPermissions(PathPermissions.readable)
- async def retr(self, connection, rest):
+ async def retr(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Literal[True]:
@ConnectionConditions(
ConnectionConditions.data_connection_made,
wait=True,
@@ -1372,7 +1574,7 @@ async def retr(self, connection, rest):
fail_info="Can't open data connection",
)
@worker
- async def retr_worker(self, connection, rest):
+ async def retr_worker(self: Self, connection: ConnectionProtocol, rest: object) -> None:
stream = connection.data_connection
del connection.data_connection
file_in = connection.path_io.open(real_path, mode="rb")
@@ -1382,9 +1584,8 @@ async def retr_worker(self, connection, rest):
async for data in file_in.iter_by_block(connection.block_size):
await stream.write(data)
connection.response("226", "data transfer done")
- return True
- real_path, virtual_path = self.get_paths(connection, rest)
+ real_path, _ = self.get_paths(connection, rest)
coro = retr_worker(self, connection, rest)
task = asyncio.create_task(coro)
connection.extra_workers.add(task)
@@ -1392,7 +1593,7 @@ async def retr_worker(self, connection, rest):
return True
@ConnectionConditions(ConnectionConditions.login_required)
- async def type(self, connection, rest):
+ async def type(self, connection: ConnectionProtocol, rest: str) -> Literal[True]:
if rest in ("I", "A"):
connection.transfer_type = rest
code, info = "200", ""
@@ -1402,12 +1603,12 @@ async def type(self, connection, rest):
return True
@ConnectionConditions(ConnectionConditions.login_required)
- async def pbsz(self, connection, rest):
+ async def pbsz(self, connection: ConnectionProtocol, rest: object) -> Literal[True]:
connection.response("200", "")
return True
@ConnectionConditions(ConnectionConditions.login_required)
- async def prot(self, connection, rest):
+ async def prot(self, connection: ConnectionProtocol, rest: str) -> Literal[True]:
if rest == "P":
code, info = "200", ""
else:
@@ -1415,9 +1616,16 @@ async def prot(self, connection, rest):
connection.response(code, info)
return True
- async def _start_passive_server(self, connection, handler_callback):
+ async def _start_passive_server(
+ self,
+ connection: ConnectionProtocol,
+ handler_callback: Callable[
+ [asyncio.StreamReader, asyncio.StreamWriter],
+ Coroutine[None, None, None],
+ ],
+ ) -> asyncio.base_events.Server:
if self.available_data_ports is not None:
- viewed_ports = set()
+ viewed_ports: Set[int] = set()
while True:
try:
priority, port = self.available_data_ports.get_nowait()
@@ -1436,7 +1644,7 @@ async def _start_passive_server(self, connection, handler_callback):
except asyncio.QueueEmpty:
raise errors.NoAvailablePort
except OSError as err:
- self.available_data_ports.put_nowait((priority + 1, port))
+ self.available_data_ports.put_nowait((priority + 1, port)) # type: ignore
if err.errno != errno.EADDRINUSE:
raise
else:
@@ -1447,11 +1655,11 @@ async def _start_passive_server(self, connection, handler_callback):
ssl=self.ssl,
**self._start_server_extra_arguments,
)
- return passive_server
+ return passive_server # type: ignore
@ConnectionConditions(ConnectionConditions.login_required)
- async def pasv(self, connection, rest):
- async def handler(reader, writer):
+ async def pasv(self, connection: ConnectionProtocol, rest: object) -> bool:
+ async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
if connection.future.data_connection.done():
writer.close()
else:
@@ -1486,7 +1694,7 @@ async def handler(reader, writer):
connection.response("503", ["this server started in ipv6 mode"])
return False
- nums = tuple(map(int, host.split("."))) + (port >> 8, port & 0xFF)
+ nums = tuple(map(int, host.split("."))) + (port >> 8, port & 0xFF) # type: ignore
info.append(f"({','.join(map(str, nums))})")
if connection.future.data_connection.done():
connection.data_connection.close()
@@ -1495,8 +1703,8 @@ async def handler(reader, writer):
return True
@ConnectionConditions(ConnectionConditions.login_required)
- async def epsv(self, connection, rest):
- async def handler(reader, writer):
+ async def epsv(self, connection: ConnectionProtocol, rest: object) -> bool:
+ async def handler(reader: asyncio.StreamReader, writer: asyncio.StreamWriter) -> None:
if connection.future.data_connection.done():
writer.close()
else:
@@ -1527,7 +1735,7 @@ async def handler(reader, writer):
_, port, *_ = sock.getsockname()
break
- info[0] += f" (|||{port}|)"
+ info[0] += f" (|||{port}|)" # type: ignore
if connection.future.data_connection.done():
connection.data_connection.close()
del connection.data_connection
@@ -1535,7 +1743,7 @@ async def handler(reader, writer):
return True
@ConnectionConditions(ConnectionConditions.login_required)
- async def abor(self, connection, rest):
+ async def abor(self, connection: ConnectionProtocol, rest: object) -> Literal[True]:
if connection.extra_workers:
for worker in connection.extra_workers:
worker.cancel()
@@ -1543,10 +1751,10 @@ async def abor(self, connection, rest):
connection.response("226", "nothing to abort")
return True
- async def appe(self, connection, rest):
+ async def appe(self, connection: ConnectionProtocol, rest: pathlib.PurePosixPath) -> Optional[Literal[True]]:
return await self.stor(connection, rest, "ab")
- async def rest(self, connection, rest):
+ async def rest(self, connection: ConnectionProtocol, rest: str) -> Literal[True]:
if rest.isdigit():
connection.restart_offset = int(rest)
connection.response("350", f"restarting at {rest}")
@@ -1556,7 +1764,7 @@ async def rest(self, connection, rest):
connection.response("501", message)
return True
- async def syst(self, connection, rest):
+ async def syst(self, connection: ConnectionProtocol, rest: object) -> Literal[True]:
"""Return system type (always returns UNIX type: L8)."""
connection.response("215", "UNIX Type: L8")
return True
diff --git a/src/aioftp/types.py b/src/aioftp/types.py
new file mode 100644
index 0000000..5ac8dd3
--- /dev/null
+++ b/src/aioftp/types.py
@@ -0,0 +1,55 @@
+from typing import TYPE_CHECKING, Any, Generator, Literal, NewType, Protocol, Tuple, TypeVar
+
+from typing_extensions import TypeAlias
+
+if TYPE_CHECKING:
+ from .client import Code
+
+_T_co = TypeVar("_T_co", covariant=True)
+
+__all__ = ("NotEmptyCodes", "check_not_empty_codes")
+
+
+class StatsProtocol(Protocol):
+ @property
+ def st_size(self) -> int:
+ raise NotImplementedError
+
+ @property
+ def st_ctime(self) -> float:
+ raise NotImplementedError
+
+ @property
+ def st_mtime(self) -> float:
+ raise NotImplementedError
+
+ @property
+ def st_nlink(self) -> int:
+ raise NotImplementedError
+
+ @property
+ def st_mode(self) -> int:
+ raise NotImplementedError
+
+
+class AsyncEnterableProtocol(Protocol[_T_co]):
+ async def __aenter__(self) -> _T_co:
+ raise NotImplementedError
+
+ async def __aexit__(self, *args: Any, **kwargs: Any) -> None:
+ raise NotImplementedError
+
+ def __await__(self) -> Generator[None, None, _T_co]:
+ raise NotImplementedError
+
+
+OpenMode: TypeAlias = Literal["rb", "wb", "ab", "r+b"]
+
+NotEmptyCodes = NewType("NotEmptyCodes", Tuple["Code", ...])
+
+
+def check_not_empty_codes(codes: Tuple["Code", ...]) -> NotEmptyCodes:
+ if not codes:
+ raise ValueError("Codes should not be empty")
+
+ return NotEmptyCodes(codes)
diff --git a/src/aioftp/utils.py b/src/aioftp/utils.py
new file mode 100644
index 0000000..cf74c2f
--- /dev/null
+++ b/src/aioftp/utils.py
@@ -0,0 +1,5 @@
+from typing import Any, Dict, Tuple
+
+
+def get_param(where: Tuple[int, str], args: Tuple[Any, ...], kwargs: Dict[str, Any]) -> Any:
+ return kwargs.get(where[1]) or args[where[0]]