diff --git a/conftest.py b/conftest.py index ed290ec..7ea9068 100644 --- a/conftest.py +++ b/conftest.py @@ -1,3 +1,4 @@ +import pathlib as std_pathlib import sys import pytest @@ -5,11 +6,11 @@ if sys.version_info < (3, 10): import pathlib2 as pathlib # pragma: nocover else: - import pathlib # pragma: nocover + pathlib = std_pathlib # pragma: nocover @pytest.fixture -def tmp_path(tmp_path): +def tmp_path(tmp_path: std_pathlib.Path) -> pathlib.Path: """ Override tmp_path to wrap in a more modern interface. """ diff --git a/docs/conf.py b/docs/conf.py index d226e14..0dcb725 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -56,3 +56,12 @@ # local extensions += ['jaraco.tidelift'] +nitpick_ignore += [ + ("py:class", "FileDescriptorOrPath"), + ("py:class", "_GetItemIterable"), + ("py:class", "_SupportsDecode"), + ("py:class", "_TranslateTable"), + ("py:class", "itertools.chain"), + ("py:class", "jaraco.text._SupportsDecode"), + ("py:class", "jaraco.text._T"), +] diff --git a/jaraco/text/__init__.py b/jaraco/text/__init__.py index 5066662..42666e9 100644 --- a/jaraco/text/__init__.py +++ b/jaraco/text/__init__.py @@ -1,22 +1,52 @@ +from __future__ import annotations + import functools import itertools import re +import sys import textwrap -from collections.abc import Iterable +from collections.abc import Callable, Generator, Iterable, Sequence from importlib.resources import files +from typing import ( + TYPE_CHECKING, + Literal, + Protocol, + SupportsIndex, + TypeVar, + overload, +) from jaraco.context import ExceptionTrap from jaraco.functools import compose, method_cache +if sys.version_info >= (3, 11): # pragma: no cover + from importlib.resources.abc import Traversable +else: # pragma: no cover + from importlib.abc import Traversable + +if TYPE_CHECKING: + from _typeshed import FileDescriptorOrPath, SupportsGetItem + from typing_extensions import Self, TypeAlias, TypeGuard, Unpack + + _T_co = TypeVar("_T_co", covariant=True) + # Same as builtins._GetItemIterable 
from typeshed + _GetItemIterable: TypeAlias = SupportsGetItem[int, _T_co] + +_T = TypeVar("_T") + -def substitution(old, new): +class _SupportsDecode(Protocol): + def decode(self) -> object: ... + + +def substitution(old: str, new: str) -> Callable[[str], str]: """ Return a function that will perform a substitution on a string """ return lambda s: s.replace(old, new) -def multi_substitution(*substitutions): +def multi_substitution(*substitutions: str) -> Callable[[str], str]: """ Take a sequence of pairs specifying substitutions, and create a function that performs those substitutions. @@ -24,11 +54,13 @@ def multi_substitution(*substitutions): >>> multi_substitution(('foo', 'bar'), ('bar', 'baz'))('foo') 'baz' """ - substitutions = itertools.starmap(substitution, substitutions) + callables: Iterable[Callable[[str], str]] = itertools.starmap( + substitution, substitutions + ) # compose function applies last function first, so reverse the # substitutions to get the expected order. - substitutions = reversed(tuple(substitutions)) - return compose(*substitutions) + reversed_ = reversed(tuple(callables)) + return compose(*reversed_) class FoldedCase(str): @@ -53,6 +85,11 @@ class FoldedCase(str): >>> s.split('O') ['hell', ' w', 'rld'] + Like ``str``, split accepts None as ''. 
+ + >>> s.split(None) + ['hello', 'world'] + >>> sorted(map(FoldedCase, ['GAMMA', 'alpha', 'Beta'])) ['alpha', 'Beta', 'GAMMA'] @@ -92,45 +129,81 @@ class FoldedCase(str): >>> FoldedCase('ß') == FoldedCase('ss') True + + Also supports string to object comparisons: + + >>> FoldedCase('foo') == object() + False + >>> FoldedCase('foo') != object() + True + >>> object() in FoldedCase('foo') + False """ - def __lt__(self, other): + def __lt__(self, other: str) -> bool: return self.casefold() < other.casefold() - def __gt__(self, other): + def __gt__(self, other: str) -> bool: return self.casefold() > other.casefold() - def __eq__(self, other): - return self.casefold() == other.casefold() + @functools.singledispatchmethod + def __eq__(self, other: object) -> bool: + return False + + @__eq__.register + def _(self, other: str) -> bool: + return self.casefold().__eq__(other.casefold()) - def __ne__(self, other): - return self.casefold() != other.casefold() + @functools.singledispatchmethod + def __ne__(self, other: object) -> bool: + return True - def __hash__(self): + @__ne__.register + def _(self, other: str) -> bool: + return self.casefold().__ne__(other.casefold()) + + def __hash__(self) -> int: return hash(self.casefold()) - def __contains__(self, other): + @functools.singledispatchmethod + def __contains__(self, other: object) -> bool: + return False + + @__contains__.register + def _(self, other: str) -> bool: return super().casefold().__contains__(other.casefold()) - def in_(self, other): - "Does self appear in other?" + def in_(self, other: str) -> bool: + """Does self appear in other?""" return self in FoldedCase(other) # cache casefold since it's likely to be called frequently. 
@method_cache - def casefold(self): + def casefold(self) -> str: return super().casefold() - def index(self, sub): - return self.casefold().index(sub.casefold()) - - def split(self, splitter=' ', maxsplit=0): + def index( + self, + sub: str, + start: SupportsIndex | None = None, + end: SupportsIndex | None = None, + ) -> int: + return self.casefold().index(sub.casefold(), start, end) + + @functools.singledispatchmethod + def split( + self, splitter: str | None = ' ', maxsplit: SupportsIndex = 0 + ) -> list[str]: + return self.split(' ', maxsplit=maxsplit) + + @split.register + def _(self, splitter: str, maxsplit: SupportsIndex = 0) -> list[str]: pattern = re.compile(re.escape(splitter), re.I) - return pattern.split(self, maxsplit) + return pattern.split(self, int(maxsplit)) -@ExceptionTrap(UnicodeDecodeError).passes -def is_decodable(value): +@ExceptionTrap(UnicodeDecodeError).passes # type: ignore[no-untyped-call, untyped-decorator, unused-ignore, misc] # jaraco/jaraco.context#15 +def is_decodable(value: _SupportsDecode) -> None: r""" Return True if the supplied value is decodable (using the default encoding). @@ -143,7 +216,7 @@ def is_decodable(value): value.decode() -def is_binary(value): +def is_binary(value: _SupportsDecode) -> TypeGuard[bytes]: r""" Return True if the value appears to be binary (that is, it's a byte string and isn't decodable). @@ -156,7 +229,7 @@ def is_binary(value): return isinstance(value, bytes) and not is_decodable(value) -def trim(s): +def trim(s: str) -> str: r""" Trim something like a docstring to remove the whitespace that is common due to indentation and formatting. @@ -167,7 +240,7 @@ def trim(s): return textwrap.dedent(s).strip() -def wrap(s): +def wrap(s: str) -> str: """ Wrap lines of text, retaining existing newlines as paragraph markers. @@ -200,7 +273,7 @@ def wrap(s): return '\n\n'.join(wrapped) -def unwrap(s): +def unwrap(s: str) -> str: r""" Given a multi-line string, return an unwrapped version. 
@@ -233,14 +306,14 @@ class Splitter: ['hello', ' world', ' this is your', ' master calling'] """ - def __init__(self, *args): + def __init__(self, *args: Unpack[tuple[str | None, SupportsIndex]]) -> None: self.args = args - def __call__(self, s): + def __call__(self, s: str) -> list[str]: return s.split(*self.args) -def indent(string, prefix=' ' * 4): +def indent(string: str, prefix: str = ' ' * 4) -> str: """ >>> indent('foo') ' foo' @@ -248,7 +321,7 @@ def indent(string, prefix=' ' * 4): return prefix + string -class WordSet(tuple): +class WordSet(tuple[str, ...]): """ Given an identifier, return the words that identifier represents, whether in camel case, underscore-separated, etc. @@ -304,31 +377,31 @@ class WordSet(tuple): _pattern = re.compile('([A-Z]?[a-z]+)|([A-Z]+(?![a-z]))') - def capitalized(self): + def capitalized(self) -> WordSet: return WordSet(word.capitalize() for word in self) - def lowered(self): + def lowered(self) -> WordSet: return WordSet(word.lower() for word in self) - def camel_case(self): + def camel_case(self) -> str: return ''.join(self.capitalized()) - def headless_camel_case(self): + def headless_camel_case(self) -> str: words = iter(self) first = next(words).lower() new_words = itertools.chain((first,), WordSet(words).camel_case()) return ''.join(new_words) - def underscore_separated(self): + def underscore_separated(self) -> str: return '_'.join(self) - def dash_separated(self): + def dash_separated(self) -> str: return '-'.join(self) - def space_separated(self): + def space_separated(self) -> str: return ' '.join(self) - def trim_right(self, item): + def trim_right(self, item: str) -> WordSet: """ Remove the item from the end of the set. @@ -341,7 +414,7 @@ def trim_right(self, item): """ return self[:-1] if self and self[-1] == item else self - def trim_left(self, item): + def trim_left(self, item: str) -> WordSet: """ Remove the item from the beginning of the set. 
@@ -354,26 +427,30 @@ def trim_left(self, item): """ return self[1:] if self and self[0] == item else self - def trim(self, item): + def trim(self, item: str) -> WordSet: """ >>> WordSet.parse('foo bar').trim('foo') ('bar',) """ return self.trim_left(item).trim_right(item) - def __getitem__(self, item): + @overload # type:ignore[override] # more restricted return type + def __getitem__(self, item: slice) -> WordSet: ... + @overload + def __getitem__(self, item: SupportsIndex) -> str: ... + def __getitem__(self, item: slice | SupportsIndex) -> WordSet | str: result = super().__getitem__(item) - if isinstance(item, slice): - result = WordSet(result) + if isinstance(result, tuple): + return WordSet(result) return result @classmethod - def parse(cls, identifier): + def parse(cls, identifier: str) -> WordSet: matches = cls._pattern.finditer(identifier) return WordSet(match.group(0) for match in matches) @classmethod - def from_class_name(cls, subject): + def from_class_name(cls, subject: object) -> WordSet: return cls.parse(subject.__class__.__name__) @@ -381,7 +458,7 @@ def from_class_name(cls, subject): words = WordSet.parse -def simple_html_strip(s): +def simple_html_strip(s: str) -> str: r""" Remove HTML from the string `s`. 
@@ -419,7 +496,7 @@ class SeparatedValues(str): separator = ',' - def __iter__(self): + def __iter__(self) -> filter[str]: parts = self.split(self.separator) return filter(None, (part.strip() for part in parts)) @@ -451,24 +528,30 @@ class Stripper: ['abcd\n', '1234\n'] """ - def __init__(self, prefix, lines): + def __init__(self, prefix: str | None, lines: Iterable[str]) -> None: self.prefix = prefix self.lines = map(self, lines) @classmethod - def strip_prefix(cls, lines): + def strip_prefix(cls, lines: Iterable[str]) -> Self: prefix_lines, lines = itertools.tee(lines) prefix = functools.reduce(cls.common_prefix, prefix_lines) return cls(prefix, lines) - def __call__(self, line): + def __call__(self, line: str) -> str: if not self.prefix: return line null, prefix, rest = line.partition(self.prefix) return rest + @overload + @staticmethod + def common_prefix(s1: str, s2: str) -> str: ... + @overload + @staticmethod + def common_prefix(s1: Sequence[str], s2: Sequence[str]) -> Sequence[str]: ... @staticmethod - def common_prefix(s1, s2): + def common_prefix(s1: Sequence[str], s2: Sequence[str]) -> Sequence[str]: """ Return the common prefix of two lines. """ @@ -478,7 +561,7 @@ def common_prefix(s1, s2): return s1[:index] -def remove_prefix(text, prefix): +def remove_prefix(text: str, prefix: str) -> str: """ Remove the prefix from the text if it exists. @@ -492,7 +575,7 @@ def remove_prefix(text, prefix): return rest -def remove_suffix(text, suffix): +def remove_suffix(text: str, suffix: str) -> str: """ Remove the suffix from the text if it exists. @@ -506,7 +589,7 @@ def remove_suffix(text, suffix): return rest -def normalize_newlines(text): +def normalize_newlines(text: str) -> str: r""" Replace alternate newlines with the canonical newline. 
@@ -522,12 +605,12 @@ def normalize_newlines(text): return re.sub(pattern, '\n', text) -def _nonblank(str): +def _nonblank(str: str) -> bool | Literal['']: return str and not str.startswith('#') @functools.singledispatch -def yield_lines(iterable): +def yield_lines(iterable: Iterable[_T] | str) -> itertools.chain[str]: r""" Yield valid lines of a string or iterable. @@ -546,18 +629,18 @@ def yield_lines(iterable): @yield_lines.register(str) -def _(text): +def _(text: str) -> filter[str]: return clean(text.splitlines()) -def clean(lines: Iterable[str]): +def clean(lines: Iterable[str]) -> filter[str]: """ Yield non-blank, non-comment elements from lines. """ return filter(_nonblank, map(str.strip, lines)) -def drop_comment(line): +def drop_comment(line: str) -> str: """ Drop comments. @@ -572,7 +655,7 @@ def drop_comment(line): return line.partition(' #')[0] -def join_continuation(lines): +def join_continuation(lines: _GetItemIterable[str]) -> Generator[str]: r""" Join lines continued by a trailing backslash. @@ -595,17 +678,19 @@ def join_continuation(lines): >>> list(join_continuation(['foo', 'bar\\', 'baz\\'])) ['foo'] """ - lines = iter(lines) - for item in lines: + lines_ = iter(lines) + for item in lines_: while item.endswith('\\'): try: - item = item[:-2].strip() + next(lines) + item = item[:-2].strip() + next(lines_) except StopIteration: return yield item -def read_newlines(filename, limit=1024): +def read_newlines( + filename: FileDescriptorOrPath, limit: int | None = 1024 +) -> str | tuple[str, ...] | None: r""" >>> tmp_path = getfixture('tmp_path') >>> filename = tmp_path / 'out.txt' @@ -624,7 +709,7 @@ def read_newlines(filename, limit=1024): return fp.newlines -def lines_from(input): +def lines_from(input: Traversable) -> Generator[str]: """ Generate lines from a :class:`importlib.resources.abc.Traversable` path. 
diff --git a/jaraco/text/layouts.py b/jaraco/text/layouts.py index 9636f0f..7ad0cd4 100644 --- a/jaraco/text/layouts.py +++ b/jaraco/text/layouts.py @@ -1,3 +1,14 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Union + +if TYPE_CHECKING: + from _typeshed import SupportsGetItem, SupportsRead + from typing_extensions import TypeAlias + + # Same as builtins._TranslateTable from typeshed + _TranslateTable: TypeAlias = SupportsGetItem[int, Union[str, int, None]] + qwerty = "-=qwertyuiop[]asdfghjkl;'zxcvbnm,./_+QWERTYUIOP{}ASDFGHJKL:\"ZXCVBNM<>?" dvorak = "[]',.pyfgcrl/=aoeuidhtns-;qjkxbmwvz{}\"<>PYFGCRL?+AOEUIDHTNS_:QJKXBMWVZ" @@ -6,7 +17,7 @@ to_qwerty = str.maketrans(dvorak, qwerty) -def translate(input, translation): +def translate(input: str, translation: _TranslateTable) -> str: """ >>> translate('dvorak', to_dvorak) 'ekrpat' @@ -16,7 +27,7 @@ def translate(input, translation): return input.translate(translation) -def _translate_stream(stream, translation): +def _translate_stream(stream: SupportsRead[str], translation: _TranslateTable) -> None: """ >>> import io >>> _translate_stream(io.StringIO('foo'), to_dvorak) diff --git a/jaraco/text/py.typed b/jaraco/text/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/jaraco/text/show-newlines.py b/jaraco/text/show-newlines.py index ef4cc54..e3d1c67 100644 --- a/jaraco/text/show-newlines.py +++ b/jaraco/text/show-newlines.py @@ -1,11 +1,18 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING + import autocommand import inflect from more_itertools import always_iterable import jaraco.text +if TYPE_CHECKING: + from _typeshed import FileDescriptorOrPath + -def report_newlines(filename): +def report_newlines(filename: FileDescriptorOrPath) -> None: r""" Report the newlines in the indicated file. 
@@ -23,6 +30,7 @@ def report_newlines(filename): count = len(tuple(always_iterable(newlines))) engine = inflect.engine() print( + # Pyright typing issue: jaraco/inflect#210 engine.plural_noun("newline", count), engine.plural_verb("is", count), repr(newlines), diff --git a/jaraco/text/strip-prefix.py b/jaraco/text/strip-prefix.py index 761717a..3ca85a9 100644 --- a/jaraco/text/strip-prefix.py +++ b/jaraco/text/strip-prefix.py @@ -5,7 +5,7 @@ from jaraco.text import Stripper -def strip_prefix(): +def strip_prefix() -> None: r""" Strip any common prefix from stdin. diff --git a/jaraco/text/to-dvorak.py b/jaraco/text/to-dvorak.py index 14c8981..12ca363 100644 --- a/jaraco/text/to-dvorak.py +++ b/jaraco/text/to-dvorak.py @@ -2,4 +2,4 @@ from . import layouts -__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_dvorak) +__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_dvorak) # type: ignore[func-returns-value] diff --git a/jaraco/text/to-qwerty.py b/jaraco/text/to-qwerty.py index 23596fd..0b4d6d2 100644 --- a/jaraco/text/to-qwerty.py +++ b/jaraco/text/to-qwerty.py @@ -2,4 +2,4 @@ from . import layouts -__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_qwerty) +__name__ == '__main__' and layouts._translate_stream(sys.stdin, layouts.to_qwerty) # type: ignore[func-returns-value] diff --git a/mypy.ini b/mypy.ini index 0e233b0..cb55399 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,6 +1,6 @@ [mypy] # Is the project well-typed? 
-strict = False +strict = True # Early opt-in even when strict = False warn_unused_ignores = True diff --git a/newsfragments/17.feature.1.rst b/newsfragments/17.feature.1.rst new file mode 100644 index 0000000..88e9b7c --- /dev/null +++ b/newsfragments/17.feature.1.rst @@ -0,0 +1 @@ +Complete annotations and add ``py.typed`` marker -- by :user:`Avasam` diff --git a/newsfragments/17.feature.2.rst b/newsfragments/17.feature.2.rst new file mode 100644 index 0000000..238f9ec --- /dev/null +++ b/newsfragments/17.feature.2.rst @@ -0,0 +1 @@ +Add support for ``start`` and ``end`` in `jaraco.text.FoldedCase.index` -- by :user:`Avasam` diff --git a/newsfragments/17.feature.3.rst b/newsfragments/17.feature.3.rst new file mode 100644 index 0000000..e7e6137 --- /dev/null +++ b/newsfragments/17.feature.3.rst @@ -0,0 +1 @@ +Support a `None` ``splitter`` argument in `jaraco.text.FoldedCase.split` -- by :user:`Avasam`