From 207f4ef87b7d4180c6a11c8921586b460c140789 Mon Sep 17 00:00:00 2001 From: Gertjan van den Burg Date: Wed, 6 Sep 2023 21:22:39 +0100 Subject: [PATCH] Add type hints to CleverCSV (#108) * Add type hints to CleverCSV This commit adds type hints to CleverCSV and enables static type checking with MyPy. An attempt was made to keep the changes minimal, in order to change only the code needed to get MyPy to pass. As can be seen, this already required a large number of changes to the code. It is also clear that certain design decisions could be reevaluated to make the type checking smoother (e.g., requiring @overload is not ideal). Such improvements are left for future work. --- .github/workflows/build.yml | 2 +- .github/workflows/deploy.yml | 2 +- Makefile | 8 +- clevercsv/__version__.py | 6 +- clevercsv/_optional.py | 7 +- clevercsv/_types.py | 50 ++ clevercsv/break_ties.py | 69 ++- clevercsv/cabstraction.pyi | 11 + clevercsv/consistency.py | 15 +- clevercsv/console/commands/detect.py | 5 +- clevercsv/console/commands/view.py | 31 +- clevercsv/cparser.pyi | 54 ++ clevercsv/cparser_util.py | 91 +++- clevercsv/cparser_util.pyi | 69 +++ clevercsv/detect.py | 14 +- clevercsv/detect_pattern.py | 22 +- clevercsv/detect_type.py | 43 +- clevercsv/dialect.py | 48 +- clevercsv/dict_read_write.py | 97 ++-- clevercsv/encoding.py | 7 +- clevercsv/escape.py | 14 +- clevercsv/py.typed | 0 clevercsv/read.py | 49 +- clevercsv/wrappers.py | 120 +++-- clevercsv/write.py | 34 +- pyproject.toml | 7 + setup.py | 3 +- stubs/pandas/__init__.pyi | 119 +++++ stubs/pythonfuzz/__init__.pyi | 0 stubs/pythonfuzz/main.pyi | 6 + stubs/regex/__init__.pyi | 61 +++ stubs/regex/_regex.pyi | 13 + stubs/regex/_regex_core.pyi | 503 ++++++++++++++++++ stubs/regex/regex.pyi | 189 +++++++ stubs/tabview/__init__.pyi | 1 + stubs/tabview/tabview.pyi | 163 ++++++ stubs/termcolor/__init__.pyi | 22 + stubs/wilderness/__init__.pyi | 168 ++++++ .../test_dialect_detection.py | 3 +- tests/test_unit/test_console.py | 60 ++- tests/test_unit/test_detect_type.py | 33 +- tests/test_unit/test_dict.py | 18 +- tests/test_unit/test_wrappers.py | 9 +- tests/test_unit/test_write.py | 7 - 44 files changed, 1992 insertions(+), 261 deletions(-) create mode 100644 clevercsv/_types.py create mode 100644 clevercsv/cabstraction.pyi create mode 100644 clevercsv/cparser.pyi create mode 100644 clevercsv/cparser_util.pyi create mode 100644 clevercsv/py.typed create mode 100644 stubs/pandas/__init__.pyi create mode 100644 stubs/pythonfuzz/__init__.pyi create mode 100644 stubs/pythonfuzz/main.pyi create mode 100644 stubs/regex/__init__.pyi create mode 100644 stubs/regex/_regex.pyi create mode 100644 stubs/regex/_regex_core.pyi create mode 100644 stubs/regex/regex.pyi create mode 100644 stubs/tabview/__init__.pyi create mode 100644 stubs/tabview/tabview.pyi create mode 100644 stubs/termcolor/__init__.pyi create mode 100644 stubs/wilderness/__init__.pyi diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 04088745..312ba269 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -42,7 +42,7 @@ jobs: strategy: matrix: os: [ 'ubuntu-latest', 'macos-latest', 'windows-latest' ] - py: [ '3.7', '3.11' ] # minimal and latest + py: [ '3.8', '3.11' ] # minimal and latest steps: - name: Install Python ${{ matrix.py }} uses: actions/setup-python@v2 diff --git a/.github/workflows/deploy.yml b/.github/workflows/deploy.yml index 79b3169b..7eb0dce8 100644 --- a/.github/workflows/deploy.yml +++ b/.github/workflows/deploy.yml @@ -31,7 +31,7 
@@ jobs: env: CIBW_TEST_COMMAND: "python -VV && python -m unittest discover -f -s {project}/tests/test_unit/" CIBW_TEST_EXTRAS: "full" - CIBW_SKIP: "pp* cp27-* cp33-* cp34-* cp35-* cp36-* cp310-win32 *-musllinux_* *-manylinux_i686" + CIBW_SKIP: "pp* cp27-* cp33-* cp34-* cp35-* cp36-* cp37-* cp310-win32 *-musllinux_* *-manylinux_i686" CIBW_ARCHS_MACOS: x86_64 arm64 universal2 CIBW_ARCHS_LINUX: auto aarch64 diff --git a/Makefile b/Makefile index 3cba1509..4317c169 100644 --- a/Makefile +++ b/Makefile @@ -9,7 +9,7 @@ MAKEFLAGS += --no-builtin-rules PACKAGE=clevercsv DOC_DIR=./docs/ -VENV_DIR=/tmp/clevercsv_venv/ +VENV_DIR=/tmp/clevercsv_venv PYTHON ?= python .PHONY: help @@ -51,7 +51,7 @@ dist: man ## Make Python source distribution .PHONY: test integration integration_partial -test: green pytest +test: mypy green pytest green: venv ## Run unit tests source $(VENV_DIR)/bin/activate && green -a -vv ./tests/test_unit @@ -59,6 +59,10 @@ green: venv ## Run unit tests pytest: venv ## Run unit tests with PyTest source $(VENV_DIR)/bin/activate && pytest -ra -m 'not network' +mypy: venv ## Run type checks + source $(VENV_DIR)/bin/activate && \ + mypy --check-untyped-defs ./stubs $(PACKAGE) ./tests + integration: venv ## Run integration tests source $(VENV_DIR)/bin/activate && python ./tests/test_integration/test_dialect_detection.py -v diff --git a/clevercsv/__version__.py b/clevercsv/__version__.py index 2b9f8268..e054b220 100644 --- a/clevercsv/__version__.py +++ b/clevercsv/__version__.py @@ -1,5 +1,7 @@ # -*- coding: utf-8 -*- -VERSION = (0, 8, 0) +from typing import Tuple -__version__ = ".".join(map(str, VERSION)) +VERSION: Tuple[int, int, int] = (0, 8, 0) + +__version__: str = ".".join(map(str, VERSION)) diff --git a/clevercsv/_optional.py b/clevercsv/_optional.py index 922f44b1..77588e84 100644 --- a/clevercsv/_optional.py +++ b/clevercsv/_optional.py @@ -13,9 +13,12 @@ import importlib +from types import ModuleType + from typing import Dict from typing import List from typing import NamedTuple +from typing import Optional from packaging.version import Version @@ -35,7 +38,9 @@ class OptionalDependency(NamedTuple): ] -def import_optional_dependency(name, raise_on_missing=True): +def import_optional_dependency( + name: str, raise_on_missing: bool = True +) -> Optional[ModuleType]: """ Import an optional dependency. 
diff --git a/clevercsv/_types.py b/clevercsv/_types.py new file mode 100644 index 00000000..cc684bfd --- /dev/null +++ b/clevercsv/_types.py @@ -0,0 +1,50 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +import csv +import os +import sys + +from typing import TYPE_CHECKING +from typing import Any +from typing import Mapping +from typing import Type +from typing import Union + +from clevercsv.dialect import SimpleDialect + +AnyPath = Union[str, bytes, "os.PathLike[str]", "os.PathLike[bytes]"] +_OpenFile = Union[AnyPath, int] +_DictRow = Mapping[str, Any] +_DialectLike = Union[str, csv.Dialect, Type[csv.Dialect], SimpleDialect] + +if sys.version_info >= (3, 8): + from typing import Dict as _DictReadMapping +else: + from collections import OrderedDict as _DictReadMapping + + +if TYPE_CHECKING: + from _typeshed import FileDescriptorOrPath # NOQA + from _typeshed import SupportsIter # NOQA + from _typeshed import SupportsWrite # NOQA + + __all__ = [ + "SupportsWrite", + "SupportsIter", + "FileDescriptorOrPath", + "AnyPath", + "_OpenFile", + "_DictRow", + "_DialectLike", + "_DictReadMapping", + ] +else: + __all__ = [ + "AnyPath", + "_OpenFile", + "_DictRow", + "_DialectLike", + "_DictReadMapping", + ] diff --git a/clevercsv/break_ties.py b/clevercsv/break_ties.py index 017fc631..490adfd7 100644 --- a/clevercsv/break_ties.py +++ b/clevercsv/break_ties.py @@ -7,12 +7,17 @@ """ +from typing import List +from typing import Optional + from .cparser_util import parse_string from .dialect import SimpleDialect from .utils import pairwise -def tie_breaker(data, dialects): +def tie_breaker( + data: str, dialects: List[SimpleDialect] +) -> Optional[SimpleDialect]: """ Break ties between dialects. @@ -42,7 +47,9 @@ def tie_breaker(data, dialects): return None -def reduce_pairwise(data, dialects): +def reduce_pairwise( + data: str, dialects: List[SimpleDialect] +) -> Optional[List[SimpleDialect]]: """Reduce the set of dialects by breaking pairwise ties Parameters @@ -62,7 +69,7 @@ def reduce_pairwise(data, dialects): """ equal_delim = len(set([d.delimiter for d in dialects])) == 1 if not equal_delim: - return None + return None # TODO: This might be wrong, it can just return the input! # First, identify dialects that result in the same parsing result. equal_dialects = [] @@ -99,7 +106,9 @@ def _dialects_only_differ_in_field( ) -def break_ties_two(data, A, B): +def break_ties_two( + data: str, A: SimpleDialect, B: SimpleDialect +) -> Optional[SimpleDialect]: """Break ties between two dialects. This function breaks ties between two dialects that give the same score. 
We @@ -152,7 +161,7 @@ def break_ties_two(data, A, B): # quotechar has an effect return d_yes elif _dialects_only_differ_in_field(A, B, "delimiter"): - if sorted([A.delimiter, B.delimiter]) == sorted([",", " "]): + if set([A.delimiter, B.delimiter]) == set([",", " "]): # Artifact due to type detection (comma as radix point) if A.delimiter == ",": return A @@ -175,14 +184,14 @@ def break_ties_two(data, A, B): # we can't break this tie (for now) if len(X) != len(Y): return None - for x, y in zip(X, Y): - if len(x) != len(y): + for row_X, row_Y in zip(X, Y): + if len(row_X) != len(row_Y): return None cells_escaped = [] cells_unescaped = [] - for x, y in zip(X, Y): - for u, v in zip(x, y): + for row_X, row_Y in zip(X, Y): + for u, v in zip(row_X, row_Y): if u != v: cells_unescaped.append(u) cells_escaped.append(v) @@ -221,16 +230,18 @@ def break_ties_two(data, A, B): if len(X) != len(Y): return None - for x, y in zip(X, Y): - if len(x) != len(y): + for row_X, row_Y in zip(X, Y): + if len(row_X) != len(row_Y): return None # if we're here, then there is no effect on structure. # we test if the only cells that differ are those that have an # escapechar+quotechar combination. + assert isinstance(d_yes.escapechar, str) + assert isinstance(d_yes.quotechar, str) eq = d_yes.escapechar + d_yes.quotechar - for rX, rY in zip(X, Y): - for x, y in zip(rX, rY): + for row_X, row_Y in zip(X, Y): + for x, y in zip(row_X, row_Y): if x != y: if eq not in x: return None @@ -243,7 +254,9 @@ def break_ties_two(data, A, B): return None -def break_ties_three(data, A, B, C): +def break_ties_three( + data: str, A: SimpleDialect, B: SimpleDialect, C: SimpleDialect +) -> Optional[SimpleDialect]: """Break ties between three dialects. If the delimiters and the escape characters are all equal, then we look for @@ -273,7 +286,7 @@ def break_ties_three(data, A, B, C): Returns ------- - dialect: SimpleDialect + dialect: Optional[SimpleDialect] The chosen dialect if the tie can be broken, None otherwise. Notes @@ -307,6 +320,7 @@ def break_ties_three(data, A, B, C): ) if p_none is None: return None + assert d_none is not None rem = [ (p, d) for p, d in zip([pA, pB, pC], dialects) if not p == p_none @@ -318,6 +332,8 @@ def break_ties_three(data, A, B, C): # the CSV paper. When fixing the delimiter to Tab, rem = []. # Try to reduce pairwise new_dialects = reduce_pairwise(data, dialects) + if new_dialects is None: + return None if len(new_dialects) == 1: return new_dialects[0] return None @@ -347,7 +363,9 @@ def break_ties_three(data, A, B, C): return None -def break_ties_four(data, dialects): +def break_ties_four( + data: str, dialects: List[SimpleDialect] +) -> Optional[SimpleDialect]: """Break ties between four dialects. This function works by breaking the ties between pairs of dialects that @@ -368,7 +386,7 @@ def break_ties_four(data, dialects): Returns ------- - dialect: SimpleDialect + dialect: Optional[SimpleDialect] The chosen dialect if the tie can be broken, None otherwise. Notes @@ -378,19 +396,22 @@ def break_ties_four(data, dialects): examples are found. """ + # TODO: Check for length 4, handle more than 4 too? 
equal_delim = len(set([d.delimiter for d in dialects])) == 1 if not equal_delim: return None - dialects = reduce_pairwise(data, dialects) + reduced_dialects = reduce_pairwise(data, dialects) + if reduced_dialects is None: + return None # Defer to other functions if the number of dialects was reduced - if len(dialects) == 1: - return dialects[0] - elif len(dialects) == 2: - return break_ties_two(data, *dialects) - elif len(dialects) == 3: - return break_ties_three(data, *dialects) + if len(reduced_dialects) == 1: + return reduced_dialects[0] + elif len(reduced_dialects) == 2: + return break_ties_two(data, *reduced_dialects) + elif len(reduced_dialects) == 3: + return break_ties_three(data, *reduced_dialects) return None diff --git a/clevercsv/cabstraction.pyi b/clevercsv/cabstraction.pyi new file mode 100644 index 00000000..df3228e7 --- /dev/null +++ b/clevercsv/cabstraction.pyi @@ -0,0 +1,11 @@ +# -*- coding: utf-8 -*- + +from typing import Optional + +def base_abstraction( + data: str, + delimiter: Optional[str], + quotechar: Optional[str], + escapechar: Optional[str], +) -> str: ... +def c_merge_with_quotechar(data: str) -> str: ... diff --git a/clevercsv/consistency.py b/clevercsv/consistency.py index 6a80d711..71d2afdb 100644 --- a/clevercsv/consistency.py +++ b/clevercsv/consistency.py @@ -89,7 +89,7 @@ def cached_is_known_type(cell: str, is_quoted: bool) -> bool: def detect( self, data: str, delimiters: Optional[Iterable[str]] = None - ) -> None: + ) -> Optional[SimpleDialect]: """Detect the dialect using the consistency measure Parameters @@ -184,8 +184,11 @@ def get_best_dialects( ) -> List[SimpleDialect]: """Identify the dialects with the highest consistency score""" Qscores = [score.Q for score in scores.values()] - Qscores = list(filter(lambda q: q is not None, Qscores)) - Qmax = max(Qscores) + Qmax = -float("inf") + for q in Qscores: + if q is None: + continue + Qmax = max(Qmax, q) return [d for d, score in scores.items() if score.Q == Qmax] def compute_type_score( @@ -194,6 +197,7 @@ def compute_type_score( """Compute the type score""" total = known = 0 for row in parse_string(data, dialect, return_quoted=True): + assert all(isinstance(cell, tuple) for cell in row) for cell, is_quoted in row: total += 1 known += self._cached_is_known_type(cell, is_quoted=is_quoted) @@ -203,7 +207,10 @@ def compute_type_score( def detect_dialect_consistency( - data, delimiters=None, skip=True, verbose=False + data: str, + delimiters: Optional[Iterable[str]] = None, + skip: bool = True, + verbose: bool = False, ): """Helper function that wraps ConsistencyDetector""" # Mostly kept for backwards compatibility diff --git a/clevercsv/console/commands/detect.py b/clevercsv/console/commands/detect.py index 590338a4..74775d90 100644 --- a/clevercsv/console/commands/detect.py +++ b/clevercsv/console/commands/detect.py @@ -4,6 +4,9 @@ import sys import time +from typing import Any +from typing import Dict + from wilderness import Command from clevercsv.wrappers import detect_dialect @@ -125,7 +128,7 @@ def handle(self): if self.args.add_runtime: print(f"runtime = {runtime}") elif self.args.json: - dialect_dict = dialect.to_dict() + dialect_dict: Dict[str, Any] = dialect.to_dict() if self.args.add_runtime: dialect_dict["runtime"] = runtime print(json.dumps(dialect_dict)) diff --git a/clevercsv/console/commands/view.py b/clevercsv/console/commands/view.py index dbce0916..df09b509 100644 --- a/clevercsv/console/commands/view.py +++ b/clevercsv/console/commands/view.py @@ -2,19 +2,9 @@ import sys -try: - 
import tabview -except ImportError: - - class TabView: - def view(*args, **kwargs): - print( - "Error: unfortunately Tabview is not available on Windows.", - file=sys.stderr, - ) - - tabview = TabView() - +from typing import List +from typing import Optional +from typing import Sequence from wilderness import Command @@ -61,6 +51,17 @@ def register(self): help="Transpose the columns of the input file before viewing", ) + def _tabview(self, rows) -> None: + try: + from tabview import view + except ImportError: + print( + "Error: unfortunately Tabview is not available on Windows.", + file=sys.stderr, + ) + return + view(rows) + def handle(self) -> int: verbose = self.args.verbose num_chars = parse_int(self.args.num_chars, "num-chars") @@ -77,7 +78,7 @@ def handle(self) -> int: if self.args.transpose: max_row_length = max(map(len, rows)) - fixed_rows = [] + fixed_rows: List[Sequence[Optional[str]]] = [] for row in rows: if len(row) == max_row_length: fixed_rows.append(row) @@ -86,5 +87,5 @@ def handle(self) -> int: row + [None] * (max_row_length - len(row)) ) rows = list(map(list, zip(*fixed_rows))) - tabview.view(rows) + self._tabview(rows) return 0 diff --git a/clevercsv/cparser.pyi b/clevercsv/cparser.pyi new file mode 100644 index 00000000..53e3a064 --- /dev/null +++ b/clevercsv/cparser.pyi @@ -0,0 +1,54 @@ +# -*- coding: utf-8 -*- + +from __future__ import annotations + +from typing import Final +from typing import Generic +from typing import Iterable +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import TypeVar +from typing import overload + +_T = TypeVar("_T") + +class Parser(Generic[_T]): + _return_quoted: Final[bool] + + @overload + def __init__( + self: Parser[List[Tuple[str, bool]]], + delimiter: Optional[str] = "", + quotechar: Optional[str] = "", + escapechar: Optional[str] = "", + field_limit: Optional[int] = 128 * 1024, + strict: Optional[bool] = False, + return_quoted: Literal[True] = ..., + ) -> None: ... + @overload + def __init__( + self: Parser[List[str]], + delimiter: Optional[str] = "", + quotechar: Optional[str] = "", + escapechar: Optional[str] = "", + field_limit: Optional[int] = 128 * 1024, + strict: Optional[bool] = False, + return_quoted: Literal[False] = ..., + ) -> None: ... + @overload + def __init__( + self, + data: Iterable[str], + delimiter: Optional[str] = "", + quotechar: Optional[str] = "", + escapechar: Optional[str] = "", + field_limit: Optional[int] = 128 * 1024, + strict: Optional[bool] = False, + return_quoted: bool = ..., + ) -> None: ... + def __iter__(self) -> "Parser": ... + def __next__(self) -> _T: ... + +class Error(Exception): ... diff --git a/clevercsv/cparser_util.py b/clevercsv/cparser_util.py index 4ac528ea..484c2573 100644 --- a/clevercsv/cparser_util.py +++ b/clevercsv/cparser_util.py @@ -7,15 +7,23 @@ import io +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional +from typing import Tuple +from typing import Union + from .cparser import Error as ParserError from .cparser import Parser from .dialect import SimpleDialect from .exceptions import Error -_FIELD_SIZE_LIMIT = 128 * 1024 +_FIELD_SIZE_LIMIT: int = 128 * 1024 -def field_size_limit(*args, **kwargs): +def field_size_limit(*args: Any, **kwargs: Any) -> int: """Get/Set the limit to the field size. This function is adapted from the one in the Python CSV module. 
See the @@ -23,29 +31,54 @@ def field_size_limit(*args, **kwargs): """ global _FIELD_SIZE_LIMIT old_limit = _FIELD_SIZE_LIMIT - args = list(args) + list(kwargs.values()) - if not 0 <= len(args) <= 1: + all_args = list(args) + list(kwargs.values()) + if not 0 <= len(all_args) <= 1: raise TypeError( - "field_size_limit expected at most 1 arguments, got %i" % len(args) + "field_size_limit expected at most 1 arguments, got %i" + % len(all_args) ) - if len(args) == 0: + if len(all_args) == 0: return old_limit - limit = args[0] + limit = all_args[0] if not isinstance(limit, int): raise TypeError("limit must be an integer") _FIELD_SIZE_LIMIT = int(limit) return old_limit +def _parse_data( + data: Iterable[str], + delimiter: str, + quotechar: str, + escapechar: str, + strict: bool, + return_quoted: bool = False, +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: + parser = Parser( + data, + delimiter=delimiter, + quotechar=quotechar, + escapechar=escapechar, + field_limit=field_size_limit(), + strict=strict, + return_quoted=return_quoted, + ) + try: + for row in parser: + yield row + except ParserError as e: + raise Error(str(e)) + + def parse_data( - data, - dialect=None, - delimiter=None, - quotechar=None, - escapechar=None, - strict=None, - return_quoted=False, -): + data: Iterable[str], + dialect: Optional[SimpleDialect] = None, + delimiter: Optional[str] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + strict: Optional[bool] = None, + return_quoted: bool = False, +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: """Parse the data given a dialect using the C parser Parameters @@ -96,22 +129,24 @@ def parse_data( escapechar_ = escapechar if escapechar is not None else dialect.escapechar strict_ = strict if strict is not None else dialect.strict - parser = Parser( + yield from _parse_data( data, - delimiter=delimiter_, - quotechar=quotechar_, - escapechar=escapechar_, - field_limit=field_size_limit(), - strict=strict_, + delimiter_, + quotechar_, + escapechar_, + strict_, return_quoted=return_quoted, ) - try: - for row in parser: - yield row - except ParserError as e: - raise Error(str(e)) -def parse_string(data, *args, **kwargs): +def parse_string( + data: str, + dialect: SimpleDialect, + return_quoted: bool = False, +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: """Utility for when the CSV file is encoded as a single string""" - return parse_data(io.StringIO(data, newline=""), *args, **kwargs) + return parse_data( + iter(io.StringIO(data, newline="")), + dialect=dialect, + return_quoted=return_quoted, + ) diff --git a/clevercsv/cparser_util.pyi b/clevercsv/cparser_util.pyi new file mode 100644 index 00000000..e78c21cd --- /dev/null +++ b/clevercsv/cparser_util.pyi @@ -0,0 +1,69 @@ +# -*- coding: utf-8 -*- + +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Literal +from typing import Optional +from typing import Tuple +from typing import Union +from typing import overload + +from .dialect import SimpleDialect + +def field_size_limit(*args: Any, **kwargs: Any) -> int: ... +@overload +def _parse_data( + data: Iterable[str], + delimiter: str, + quotechar: str, + escapechar: str, + strict: bool, + return_quoted: Literal[False] = ..., +) -> Iterator[List[str]]: ... +@overload +def _parse_data( + data: Iterable[str], + delimiter: str, + quotechar: str, + escapechar: str, + strict: bool, + return_quoted: Literal[True], +) -> Iterator[List[Tuple[str, bool]]]: ... 
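# Illustration (not part of the patch): the Literal-typed @overload
# pattern in this stub is what the commit message calls "not ideal",
# but it is what lets mypy narrow the return type on the return_quoted
# flag. A minimal usage sketch, assuming the clevercsv package as
# patched here; the sample data is made up:

from clevercsv.cparser_util import parse_string
from clevercsv.dialect import SimpleDialect

dialect = SimpleDialect(delimiter=",", quotechar='"', escapechar="")

rows = parse_string('a,"b,c"\r\n1,2\r\n', dialect)
# mypy infers Iterator[List[str]] for `rows` ...

quoted = parse_string('a,"b,c"\r\n1,2\r\n', dialect, return_quoted=True)
# ... and Iterator[List[Tuple[str, bool]]] for `quoted`, so unpacking
# each cell into (text, is_quoted) type-checks without casts.
for row in quoted:
    for text, is_quoted in row:
        print(text, is_quoted)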
+@overload +def _parse_data( + data: Iterable[str], + delimiter: str, + quotechar: str, + escapechar: str, + strict: bool, + return_quoted: bool = ..., +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: ... +def parse_data( + data: Iterable[str], + dialect: Optional[SimpleDialect] = None, + delimiter: Optional[str] = None, + quotechar: Optional[str] = None, + escapechar: Optional[str] = None, + strict: Optional[bool] = None, + return_quoted: bool = False, +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: ... +@overload +def parse_string( + data: str, + dialect: SimpleDialect, + return_quoted: Literal[False] = ..., +) -> Iterator[List[str]]: ... +@overload +def parse_string( + data: str, + dialect: SimpleDialect, + return_quoted: Literal[True], +) -> Iterator[List[Tuple[str, bool]]]: ... +@overload +def parse_string( + data: str, + dialect: SimpleDialect, + return_quoted: bool = ..., +) -> Iterator[Union[List[str], List[Tuple[str, bool]]]]: ... diff --git a/clevercsv/detect.py b/clevercsv/detect.py index 8c5735ff..57c1e619 100644 --- a/clevercsv/detect.py +++ b/clevercsv/detect.py @@ -9,6 +9,10 @@ from io import StringIO +from typing import Dict +from typing import Optional +from typing import Union + from .consistency import ConsistencyDetector from .normal_form import detect_dialect_normal from .read import reader @@ -28,9 +32,6 @@ class Detector: """ - def __init__(self): - pass - def sniff(self, sample, delimiters=None, verbose=False): # Compatibility method for Python return self.detect(sample, delimiters=delimiters, verbose=verbose) @@ -126,10 +127,11 @@ def has_header(self, sample): header = next(rdr) # assume first row is header columns = len(header) - columnTypes = {} + columnTypes: Dict[int, Optional[Union[int, type]]] = {} for i in range(columns): columnTypes[i] = None + thisType: Union[int, type] checked = 0 for row in rdr: # arbitrary number of rows to check, to keep it sane @@ -169,6 +171,10 @@ def has_header(self, sample): else: hasHeader -= 1 else: # attempt typecast + if colType is None: + hasHeader += 1 + continue + try: colType(header[col]) except (ValueError, TypeError): diff --git a/clevercsv/detect_pattern.py b/clevercsv/detect_pattern.py index 1b26bbea..dc8b49bd 100644 --- a/clevercsv/detect_pattern.py +++ b/clevercsv/detect_pattern.py @@ -10,24 +10,28 @@ import collections import re +from typing import Optional from typing import Pattern from .cabstraction import base_abstraction from .cabstraction import c_merge_with_quotechar +from .dialect import SimpleDialect -DEFAULT_EPS_PAT = 1e-3 +DEFAULT_EPS_PAT: float = 1e-3 RE_MULTI_C: Pattern = re.compile(r"C{2,}") -def pattern_score(data, dialect, eps=DEFAULT_EPS_PAT): +def pattern_score( + data: str, dialect: SimpleDialect, eps: float = DEFAULT_EPS_PAT +) -> float: """ Compute the pattern score for given data and a dialect. Parameters ---------- - data : string + data : str The data of the file as a raw character string dialect: dialect.Dialect @@ -41,7 +45,7 @@ def pattern_score(data, dialect, eps=DEFAULT_EPS_PAT): """ A = make_abstraction(data, dialect) row_patterns = collections.Counter(A.split("R")) - P = 0 + P = 0.0 for pat_k, Nk in row_patterns.items(): Lk = len(pat_k.split("D")) P += Nk * (max(eps, Lk - 1) / Lk) @@ -49,7 +53,7 @@ def pattern_score(data, dialect, eps=DEFAULT_EPS_PAT): return P -def make_abstraction(data, dialect): +def make_abstraction(data: str, dialect: SimpleDialect) -> str: """Create an abstract representation of the CSV file based on the dialect. 
This function constructs the basic abstraction used to compute the row @@ -78,7 +82,9 @@ def make_abstraction(data, dialect): return A -def merge_with_quotechar(S, dialect=None): +def merge_with_quotechar( + S: str, dialect: Optional[SimpleDialect] = None +) -> str: """Merge quoted blocks in the abstraction This function takes the abstract representation and merges quoted blocks @@ -103,7 +109,7 @@ def merge_with_quotechar(S, dialect=None): return c_merge_with_quotechar(S) -def fill_empties(abstract): +def fill_empties(abstract: str) -> str: """Fill empty cells in the abstraction The way the row patterns are constructed assumes that empty cells are @@ -143,7 +149,7 @@ def fill_empties(abstract): return abstract -def strip_trailing(abstract): +def strip_trailing(abstract: str) -> str: """Strip trailing row separator from abstraction.""" while abstract.endswith("R"): abstract = abstract[:-1] diff --git a/clevercsv/detect_type.py b/clevercsv/detect_type.py index abc7f69d..97fe116e 100644 --- a/clevercsv/detect_type.py +++ b/clevercsv/detect_type.py @@ -10,6 +10,7 @@ import json from typing import Dict +from typing import List from typing import Optional from typing import Pattern @@ -19,7 +20,7 @@ DEFAULT_EPS_TYPE = 1e-10 -class TypeDetector(object): +class TypeDetector: def __init__( self, patterns: Optional[Dict[str, Pattern]] = None, @@ -48,26 +49,27 @@ def _register_type_tests(self): ("json", self.is_json_obj), ] - def list_known_types(self): + def list_known_types(self) -> List[str]: return [tt[0] for tt in self._type_tests] - def is_known_type(self, cell, is_quoted=False): + def is_known_type(self, cell: str, is_quoted: bool = False) -> bool: return self.detect_type(cell, is_quoted=is_quoted) is not None - def detect_type(self, cell, is_quoted=False): + def detect_type(self, cell: str, is_quoted: bool = False): cell = cell.strip() if self.strip_whitespace else cell for name, func in self._type_tests: if func(cell, is_quoted=is_quoted): return name return None - def _run_regex(self, cell, patname): + def _run_regex(self, cell: str, patname: str) -> bool: cell = cell.strip() if self.strip_whitespace else cell pat = self.patterns.get(patname, None) + assert pat is not None match = pat.fullmatch(cell) return match is not None - def is_number(self, cell, **kwargs): + def is_number(self, cell: str, is_quoted: bool = False) -> bool: if cell == "": return False if self._run_regex(cell, "number_1"): @@ -78,21 +80,21 @@ def is_number(self, cell, **kwargs): return True return False - def is_ipv4(self, cell, **kwargs): + def is_ipv4(self, cell: str, is_quoted: bool = False) -> bool: return self._run_regex(cell, "ipv4") - def is_url(self, cell, **kwargs): + def is_url(self, cell: str, is_quoted: bool = False) -> bool: return self._run_regex(cell, "url") - def is_email(self, cell, **kwargs): + def is_email(self, cell: str, is_quoted: bool = False) -> bool: return self._run_regex(cell, "email") - def is_unicode_alphanum(self, cell, is_quoted=False, **kwargs): + def is_unicode_alphanum(self, cell: str, is_quoted: bool = False) -> bool: if is_quoted: return self._run_regex(cell, "unicode_alphanum_quoted") return self._run_regex(cell, "unicode_alphanum") - def is_date(self, cell, **kwargs): + def is_date(self, cell: str, is_quoted: bool = False) -> bool: # This function assumes the cell is not a number. 
cell = cell.strip() if self.strip_whitespace else cell if not cell: @@ -101,7 +103,7 @@ def is_date(self, cell, **kwargs): return False return self._run_regex(cell, "date") - def is_time(self, cell, **kwargs): + def is_time(self, cell: str, is_quoted: bool = False) -> bool: cell = cell.strip() if self.strip_whitespace else cell if not cell: return False @@ -114,14 +116,15 @@ def is_time(self, cell, **kwargs): or self._run_regex(cell, "time_hhmmsszz") ) - def is_empty(self, cell, **kwargs): + def is_empty(self, cell: str, is_quoted: bool = False) -> bool: return cell == "" - def is_percentage(self, cell, **kwargs): + def is_percentage(self, cell: str, is_quoted: bool = False) -> bool: return cell.endswith("%") and self.is_number(cell.rstrip("%")) - def is_currency(self, cell, **kwargs): + def is_currency(self, cell: str, is_quoted: bool = False) -> bool: pat = self.patterns.get("currency", None) + assert pat is not None m = pat.fullmatch(cell) if m is None: return False @@ -130,7 +133,7 @@ def is_currency(self, cell, **kwargs): return False return True - def is_datetime(self, cell, **kwargs): + def is_datetime(self, cell: str, is_quoted: bool = False) -> bool: # Takes care of cells with '[date] [time]' and '[date]T[time]' (iso) if not cell: return False @@ -182,18 +185,18 @@ def is_datetime(self, cell, **kwargs): return True return False - def is_nan(self, cell, **kwargs): + def is_nan(self, cell: str, is_quoted: bool = False) -> bool: if cell.lower() in ["n/a", "na", "nan"]: return True return False - def is_unix_path(self, cell, **kwargs): + def is_unix_path(self, cell: str, is_quoted: bool = False) -> bool: return self._run_regex(cell, "unix_path") - def is_bytearray(self, cell: str, **kwargs) -> bool: + def is_bytearray(self, cell: str, is_quoted: bool = False) -> bool: return cell.startswith("bytearray(b") and cell.endswith(")") - def is_json_obj(self, cell: str, **kwargs) -> bool: + def is_json_obj(self, cell: str, is_quoted: bool = False) -> bool: if not (cell.startswith("{") and cell.endswith("}")): return False try: diff --git a/clevercsv/dialect.py b/clevercsv/dialect.py index e01879ed..023622a3 100644 --- a/clevercsv/dialect.py +++ b/clevercsv/dialect.py @@ -12,13 +12,19 @@ import functools import json +from typing import Any +from typing import Dict +from typing import Optional +from typing import Type +from typing import Union + excel = csv.excel excel_tab = csv.excel_tab unix_dialect = csv.unix_dialect @functools.total_ordering -class SimpleDialect(object): +class SimpleDialect: """ The simplified dialect object. 
@@ -42,13 +48,19 @@ class SimpleDialect(object): """ - def __init__(self, delimiter, quotechar, escapechar, strict=False): + def __init__( + self, + delimiter: Optional[str], + quotechar: Optional[str], + escapechar: Optional[str], + strict: bool = False, + ): self.delimiter = delimiter self.quotechar = quotechar self.escapechar = escapechar self.strict = strict - def validate(self): + def validate(self) -> None: if self.delimiter is None or len(self.delimiter) > 1: raise ValueError( "Delimiter should be zero or one characters, got: %r" @@ -70,21 +82,26 @@ def validate(self): ) @classmethod - def from_dict(cls, d): - d = cls( + def from_dict( + cls: Type["SimpleDialect"], d: Dict[str, Any] + ) -> "SimpleDialect": + dialect = cls( d["delimiter"], d["quotechar"], d["escapechar"], strict=d["strict"] ) - return d + return dialect @classmethod - def from_csv_dialect(cls, d): + def from_csv_dialect( + cls: Type["SimpleDialect"], d: csv.Dialect + ) -> "SimpleDialect": delimiter = "" if d.delimiter is None else d.delimiter quotechar = "" if d.quoting == csv.QUOTE_NONE else d.quotechar escapechar = "" if d.escapechar is None else d.escapechar return cls(delimiter, quotechar, escapechar, strict=d.strict) - def to_csv_dialect(self): + def to_csv_dialect(self) -> csv.Dialect: class dialect(csv.Dialect): + assert self.delimiter is not None delimiter = self.delimiter quotechar = '"' if self.quotechar == "" else self.quotechar escapechar = None if self.escapechar == "" else self.escapechar @@ -93,10 +110,13 @@ class dialect(csv.Dialect): csv.QUOTE_NONE if self.quotechar == "" else csv.QUOTE_MINIMAL ) skipinitialspace = False + # TODO: We need to set this because it can't be None anymore in + # recent versions of Python + lineterminator = "\n" - return dialect + return dialect() - def to_dict(self): + def to_dict(self) -> Dict[str, Union[str, bool, None]]: self.validate() d = dict( delimiter=self.delimiter, @@ -106,16 +126,16 @@ def to_dict(self): ) return d - def serialize(self): + def serialize(self) -> str: """Serialize dialect to a JSON object""" return json.dumps(self.to_dict()) @classmethod - def deserialize(cls, obj): + def deserialize(cls: Type["SimpleDialect"], obj: str) -> "SimpleDialect": """Deserialize dialect from a JSON object""" return cls.from_dict(json.loads(obj)) - def __repr__(self): + def __repr__(self) -> str: return "SimpleDialect(%r, %r, %r)" % ( self.delimiter, self.quotechar, @@ -125,7 +145,7 @@ def __repr__(self): def __key(self): return (self.delimiter, self.quotechar, self.escapechar, self.strict) - def __hash__(self): + def __hash__(self) -> int: return hash(self.__key()) def __eq__(self, other): diff --git a/clevercsv/dict_read_write.py b/clevercsv/dict_read_write.py index 97a1e7bb..521d10a9 100644 --- a/clevercsv/dict_read_write.py +++ b/clevercsv/dict_read_write.py @@ -3,50 +3,78 @@ """ DictReader and DictWriter. -This code is entirely copied from the Python csv module. The only exception is +This code is entirely copied from the Python csv module. The only exception is that it uses the `reader` and `writer` classes from our package. 
Author: Gertjan van den Burg """ +from __future__ import annotations + import warnings from collections import OrderedDict - -from .read import reader -from .write import writer - - -class DictReader(object): +from collections.abc import Collection + +from typing import TYPE_CHECKING +from typing import Any +from typing import Generic +from typing import Iterable +from typing import Iterator +from typing import Literal +from typing import Mapping +from typing import Optional +from typing import Sequence +from typing import TypeVar +from typing import Union +from typing import cast + +from clevercsv.read import reader +from clevercsv.write import writer + +if TYPE_CHECKING: + from clevercsv._types import SupportsWrite + from clevercsv._types import _DialectLike + from clevercsv._types import _DictReadMapping + +_T = TypeVar("_T") + + +class DictReader( + Generic[_T], Iterator["_DictReadMapping[Union[_T, Any], Union[str, Any]]"] +): def __init__( self, - f, - fieldnames=None, - restkey=None, - restval=None, - dialect="excel", - *args, - **kwds - ): + f: Iterable[str], + fieldnames: Optional[Sequence[_T]] = None, + restkey: Optional[str] = None, + restval: Optional[str] = None, + dialect: "_DialectLike" = "excel", + *args: Any, + **kwds: Any, + ) -> None: self._fieldnames = fieldnames self.restkey = restkey self.restval = restval - self.reader = reader(f, dialect, *args, **kwds) + self.reader: reader = reader(f, dialect, *args, **kwds) self.dialect = dialect self.line_num = 0 - def __iter__(self): + def __iter__(self) -> "DictReader": return self @property - def fieldnames(self): + def fieldnames(self) -> Sequence[_T]: if self._fieldnames is None: try: - self._fieldnames = next(self.reader) + fieldnames = next(self.reader) + self._fieldnames = [cast(_T, f) for f in fieldnames] except StopIteration: pass + assert self._fieldnames is not None + # Note: this was added because I don't think it's expected that Python # simply drops information if there are duplicate headers. 
There is # discussion on this issue in the Python bug tracker here: @@ -62,10 +90,10 @@ def fieldnames(self): return self._fieldnames @fieldnames.setter - def fieldnames(self, value): + def fieldnames(self, value: Sequence[_T]) -> None: self._fieldnames = value - def __next__(self): + def __next__(self) -> "_DictReadMapping[Union[_T, Any], Union[str, Any]]": if self.line_num == 0: self.fieldnames row = next(self.reader) @@ -73,7 +101,8 @@ def __next__(self): while row == []: row = next(self.reader) - d = OrderedDict(zip(self.fieldnames, row)) + + d: _DictReadMapping = OrderedDict(zip(self.fieldnames, row)) lf = len(self.fieldnames) lr = len(row) if lf < lr: @@ -84,16 +113,16 @@ def __next__(self): return d -class DictWriter(object): +class DictWriter(Generic[_T]): def __init__( self, - f, - fieldnames, - restval="", - extrasaction="raise", - dialect="excel", - *args, - **kwds + f: SupportsWrite[str], + fieldnames: Collection[_T], + restval: Optional[Any] = "", + extrasaction: Literal["raise", "ignore"] = "raise", + dialect: "_DialectLike" = "excel", + *args: Any, + **kwds: Any, ): self.fieldnames = fieldnames self.restval = restval @@ -104,11 +133,11 @@ def __init__( self.extrasaction = extrasaction self.writer = writer(f, dialect, *args, **kwds) - def writeheader(self): + def writeheader(self) -> Any: header = dict(zip(self.fieldnames, self.fieldnames)) return self.writerow(header) - def _dict_to_list(self, rowdict): + def _dict_to_list(self, rowdict: Mapping[_T, Any]) -> Iterator[Any]: if self.extrasaction == "raise": wrong_fields = rowdict.keys() - self.fieldnames if wrong_fields: @@ -118,8 +147,8 @@ def _dict_to_list(self, rowdict): ) return (rowdict.get(key, self.restval) for key in self.fieldnames) - def writerow(self, rowdict): + def writerow(self, rowdict: Mapping[_T, Any]) -> Any: return self.writer.writerow(self._dict_to_list(rowdict)) - def writerows(self, rowdicts): + def writerows(self, rowdicts: Iterable[Mapping[_T, Any]]) -> None: return self.writer.writerows(map(self._dict_to_list, rowdicts)) diff --git a/clevercsv/encoding.py b/clevercsv/encoding.py index e1010d0b..68bd259c 100644 --- a/clevercsv/encoding.py +++ b/clevercsv/encoding.py @@ -9,12 +9,17 @@ """ +from typing import Optional + import chardet from ._optional import import_optional_dependency +from ._types import _OpenFile -def get_encoding(filename, try_cchardet=True): +def get_encoding( + filename: _OpenFile, try_cchardet: bool = True +) -> Optional[str]: """Get the encoding of the file This function uses the chardet package for detecting the encoding of a diff --git a/clevercsv/escape.py b/clevercsv/escape.py index d72b0b4d..3ba574b5 100644 --- a/clevercsv/escape.py +++ b/clevercsv/escape.py @@ -11,8 +11,12 @@ import sys import unicodedata +from typing import Iterable +from typing import Optional +from typing import Set + #: Set of default characters to *never* consider as escape character -DEFAULT_BLOCK_CHARS = set( +DEFAULT_BLOCK_CHARS: Set[str] = set( [ "!", "?", @@ -30,7 +34,7 @@ ) #: Set of characters in the Unicode "Po" category -UNICODE_PO_CHARS = set( +UNICODE_PO_CHARS: Set[str] = set( [ c for c in map(chr, range(sys.maxunicode + 1)) @@ -39,7 +43,9 @@ ) -def is_potential_escapechar(char, encoding, block_char=None): +def is_potential_escapechar( + char: str, encoding: str, block_char: Optional[Iterable[str]] = None +) -> bool: """Check if a character is a potential escape character. 
A character is considered a potential escape character if it is in the @@ -54,7 +60,7 @@ def is_potential_escapechar(char, encoding, block_char=None): encoding : str The encoding of the character - block_char : iterable + block_char : Optional[Iterable[str]] Characters that are in the Punctuation Other category but that should not be considered as escape character. If None, the default set is used, which is defined in :py:data:`DEFAULT_BLOCK_CHARS`. diff --git a/clevercsv/py.typed b/clevercsv/py.typed new file mode 100644 index 00000000..e69de29b diff --git a/clevercsv/read.py b/clevercsv/read.py index 365115bd..90ce3867 100644 --- a/clevercsv/read.py +++ b/clevercsv/read.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- """ -Drop-in replacement for the Python csv reader class. This is a wrapper for the +Drop-in replacement for the Python csv reader class. This is a wrapper for the Parser class, defined in :mod:`cparser`. Author: Gertjan van den Burg @@ -10,22 +10,40 @@ import csv +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Optional + from . import field_size_limit +from ._types import _DialectLike from .cparser import Error as ParserError from .cparser import Parser from .dialect import SimpleDialect from .exceptions import Error -class reader(object): - def __init__(self, csvfile, dialect="excel", **fmtparams): +class reader: + def __init__( + self, + csvfile: Iterable[str], + dialect: _DialectLike = "excel", + **fmtparams: Any, + ): self.csvfile = csvfile self.original_dialect = dialect - self.dialect = self._make_simple_dialect(dialect, **fmtparams) - self.line_num = 0 - self.parser_gen = None + self._dialect = self._make_simple_dialect(dialect, **fmtparams) + self.line_num: int = 0 + self.parser_gen: Optional[Parser] = None + + @property + def dialect(self) -> csv.Dialect: + return self._dialect.to_csv_dialect() - def _make_simple_dialect(self, dialect, **fmtparams): + def _make_simple_dialect( + self, dialect: _DialectLike, **fmtparams: Any + ) -> SimpleDialect: if isinstance(dialect, str): sd = SimpleDialect.from_csv_dialect(csv.get_dialect(dialect)) elif isinstance(dialect, csv.Dialect): @@ -40,27 +58,24 @@ def _make_simple_dialect(self, dialect, **fmtparams): sd.validate() return sd - def __iter__(self): + def __iter__(self) -> Iterator[List[str]]: self.parser_gen = Parser( self.csvfile, - delimiter=self.dialect.delimiter, - quotechar=self.dialect.quotechar, - escapechar=self.dialect.escapechar, + delimiter=self._dialect.delimiter, + quotechar=self._dialect.quotechar, + escapechar=self._dialect.escapechar, field_limit=field_size_limit(), - strict=self.dialect.strict, + strict=self._dialect.strict, ) return self - def __next__(self): + def __next__(self) -> List[str]: if self.parser_gen is None: self.__iter__() + assert self.parser_gen is not None try: row = next(self.parser_gen) except ParserError as e: raise Error(str(e)) self.line_num += 1 return row - - def next(self): - # for python 2 - return self.__next__() diff --git a/clevercsv/wrappers.py b/clevercsv/wrappers.py index 486a2bd1..807c33f9 100644 --- a/clevercsv/wrappers.py +++ b/clevercsv/wrappers.py @@ -6,12 +6,23 @@ Author: Gertjan van den Burg """ +from __future__ import annotations import os import warnings +from typing import TYPE_CHECKING +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import List +from typing import Mapping +from typing import Optional +from typing import TypeVar + from 
._optional import import_optional_dependency from .detect import Detector +from .dialect import SimpleDialect from .dict_read_write import DictReader from .dict_read_write import DictWriter from .encoding import get_encoding @@ -19,10 +30,23 @@ from .read import reader from .write import writer +if TYPE_CHECKING: + import pandas as pd + + from ._types import FileDescriptorOrPath + from ._types import _DialectLike + from ._types import _DictReadMapping + +_T = TypeVar("_T") + def stream_dicts( - filename, dialect=None, encoding=None, num_chars=None, verbose=False -): + filename: FileDescriptorOrPath, + dialect: Optional[_DialectLike] = None, + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, +) -> Iterator["_DictReadMapping"]: """Read a CSV file as a generator over dictionaries This function streams the rows of the CSV file as dictionaries. The keys of @@ -71,14 +95,18 @@ def stream_dicts( data = fid.read(num_chars) if num_chars else fid.read() dialect = Detector().detect(data, verbose=verbose) fid.seek(0) - r = DictReader(fid, dialect=dialect) - for row in r: + reader: DictReader = DictReader(fid, dialect=dialect) + for row in reader: yield row def read_dicts( - filename, dialect=None, encoding=None, num_chars=None, verbose=False -): + filename: "FileDescriptorOrPath", + dialect: Optional["_DialectLike"] = None, + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, +) -> List["_DictReadMapping"]: """Read a CSV file as a list of dictionaries This function returns the rows of the CSV file as a list of dictionaries. @@ -132,12 +160,12 @@ def read_dicts( def read_table( - filename, - dialect=None, - encoding=None, - num_chars=None, - verbose=False, -): + filename: "FileDescriptorOrPath", + dialect: Optional["_DialectLike"] = None, + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, +) -> List[List[str]]: """Read a CSV file as a table (a list of lists) This is a convenience function that reads a CSV file and returns the data @@ -191,12 +219,12 @@ def read_table( def stream_table( - filename, - dialect=None, - encoding=None, - num_chars=None, - verbose=False, -): + filename: "FileDescriptorOrPath", + dialect: Optional["_DialectLike"] = None, + encoding: Optional[str] = None, + num_chars: Optional[int] = None, + verbose: bool = False, +) -> Iterator[List[str]]: """Read a CSV file as a generator over rows of a table This is a convenience function that reads a CSV file and returns the data @@ -251,7 +279,12 @@ def stream_table( yield from r -def read_dataframe(filename, *args, num_chars=None, **kwargs): +def read_dataframe( + filename: "FileDescriptorOrPath", + *args: Any, + num_chars: Optional[int] = None, + **kwargs: Any, +) -> pd.DataFrame: """Read a CSV file to a Pandas dataframe This function uses CleverCSV to detect the dialect, and then passes this to @@ -284,6 +317,7 @@ def read_dataframe(filename, *args, num_chars=None, **kwargs): if not (os.path.exists(filename) and os.path.isfile(filename)): raise ValueError("Filename must be a regular file") pd = import_optional_dependency("pandas") + assert pd is not None # Use provided encoding or detect it, and record it for pandas enc = kwargs.get("encoding") or get_encoding(filename) @@ -306,13 +340,13 @@ def read_dataframe(filename, *args, num_chars=None, **kwargs): def detect_dialect( - filename, - num_chars=None, - encoding=None, - verbose=False, - method="auto", - skip=True, -): + filename: "FileDescriptorOrPath", + 
num_chars: Optional[int] = None, + encoding: Optional[str] = None, + verbose: bool = False, + method: str = "auto", + skip: bool = True, +) -> SimpleDialect: """Detect the dialect of a CSV file This is a utility function that simply returns the detected dialect of a @@ -360,8 +394,12 @@ def detect_dialect( def write_table( - table, filename, dialect="excel", transpose=False, encoding=None -): + table: Iterable[Iterable[Any]], + filename: "FileDescriptorOrPath", + dialect: "_DialectLike" = "excel", + transpose: bool = False, + encoding: Optional[str] = None, +) -> None: """Write a table (a list of lists) to a file This is a convenience function for writing a table to a CSV file. If the @@ -400,17 +438,24 @@ def write_table( return if transpose: - table = list(map(list, zip(*table))) + list_table = list(map(list, zip(*table))) + else: + list_table = list(map(list, table)) - if len(set(map(len, table))) > 1: + if len(set(map(len, list_table))) > 1: raise ValueError("Table doesn't have constant row length.") with open(filename, "w", newline="", encoding=encoding) as fp: w = writer(fp, dialect=dialect) - w.writerows(table) + w.writerows(list_table) -def write_dicts(items, filename, dialect="excel", encoding=None): +def write_dicts( + items: Iterable[Mapping[_T, Any]], + filename: "FileDescriptorOrPath", + dialect: "_DialectLike" = "excel", + encoding: Optional[str] = None, +) -> None: """Write a list of dicts to a file This is a convenience function to write dicts to a file. The header is @@ -440,8 +485,15 @@ def write_dicts(items, filename, dialect="excel", encoding=None): if not items: return - fieldnames = list(items[0].keys()) + iterator = iter(items) + try: + first = next(iterator) + except StopIteration: + return + + fieldnames = list(first.keys()) with open(filename, "w", newline="", encoding=encoding) as fp: w = DictWriter(fp, fieldnames=fieldnames, dialect=dialect) w.writeheader() - w.writerows(items) + w.writerow(first) + w.writerows(iterator) diff --git a/clevercsv/write.py b/clevercsv/write.py index 01472bd1..a0e3403b 100644 --- a/clevercsv/write.py +++ b/clevercsv/write.py @@ -8,8 +8,20 @@ """ +from __future__ import annotations + import csv +from typing import TYPE_CHECKING +from typing import Any +from typing import Iterable +from typing import Type + +if TYPE_CHECKING: + from clevercsv._types import SupportsWrite + +from clevercsv._types import _DialectLike + from .dialect import SimpleDialect from .exceptions import Error @@ -25,13 +37,23 @@ ] -class writer(object): - def __init__(self, csvfile, dialect="excel", **fmtparams): +class writer: + def __init__( + self, + csvfile: SupportsWrite, + dialect: _DialectLike = "excel", + **fmtparams, + ): self.original_dialect = dialect - self.dialect = self._make_python_dialect(dialect, **fmtparams) + self.dialect: Type[csv.Dialect] = self._make_python_dialect( + dialect, **fmtparams + ) self._writer = csv.writer(csvfile, dialect=self.dialect) - def _make_python_dialect(self, dialect, **fmtparams): + def _make_python_dialect( + self, dialect: _DialectLike, **fmtparams + ) -> Type[csv.Dialect]: + d: _DialectLike = "" if isinstance(dialect, str): d = csv.get_dialect(dialect) elif isinstance(dialect, csv.Dialect): @@ -56,13 +78,13 @@ def _make_python_dialect(self, dialect, **fmtparams): newdialect = type("dialect", (csv.Dialect,), props) return newdialect - def writerow(self, row): + def writerow(self, row: Iterable[Any]) -> Any: try: return self._writer.writerow(row) except csv.Error as e: raise Error(str(e)) - def writerows(self, rows): + 
def writerows(self, rows: Iterable[Iterable[Any]]) -> Any: try: return self._writer.writerows(rows) except csv.Error as e: diff --git a/pyproject.toml b/pyproject.toml index 23130b7a..a834bd6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -11,3 +11,10 @@ lines_between_types=1 [tool.ruff] # Exclude stubs directory for now exclude = ["stubs"] + +[tool.mypy] +python_version = 3.8 +warn_unused_configs = true + +# [[tool.mypy.overrides]] +# packages = ["stubs", "clevercsv", "tests"] diff --git a/setup.py b/setup.py index 96b20b9d..5adab650 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ LICENSE = "MIT" LICENSE_TROVE = "License :: OSI Approved :: MIT License" NAME = "clevercsv" -REQUIRES_PYTHON = ">=3.6.0" +REQUIRES_PYTHON = ">=3.8.0" URL = "https://github.com/alan-turing-institute/CleverCSV" VERSION = None @@ -113,6 +113,7 @@ def run(self): install_requires=REQUIRED, extras_require=EXTRAS, include_package_data=True, + package_data={"clevercsv": ["py.typed"]}, license=LICENSE, ext_modules=[ Extension("clevercsv.cparser", sources=["src/cparser.c"]), diff --git a/stubs/pandas/__init__.pyi b/stubs/pandas/__init__.pyi new file mode 100644 index 00000000..e84ca9dd --- /dev/null +++ b/stubs/pandas/__init__.pyi @@ -0,0 +1,119 @@ +from typing import Any + +from pandas._config import describe_option as describe_option +from pandas._config import get_option as get_option +from pandas._config import option_context as option_context +from pandas._config import options as options +from pandas._config import reset_option as reset_option +from pandas._config import set_option as set_option +from pandas.core.api import NA as NA +from pandas.core.api import BooleanDtype as BooleanDtype +from pandas.core.api import Categorical as Categorical +from pandas.core.api import CategoricalDtype as CategoricalDtype +from pandas.core.api import CategoricalIndex as CategoricalIndex +from pandas.core.api import DataFrame as DataFrame +from pandas.core.api import DateOffset as DateOffset +from pandas.core.api import DatetimeIndex as DatetimeIndex +from pandas.core.api import DatetimeTZDtype as DatetimeTZDtype +from pandas.core.api import Flags as Flags +from pandas.core.api import Float32Dtype as Float32Dtype +from pandas.core.api import Float64Dtype as Float64Dtype +from pandas.core.api import Float64Index as Float64Index +from pandas.core.api import Grouper as Grouper +from pandas.core.api import Index as Index +from pandas.core.api import IndexSlice as IndexSlice +from pandas.core.api import Int8Dtype as Int8Dtype +from pandas.core.api import Int16Dtype as Int16Dtype +from pandas.core.api import Int32Dtype as Int32Dtype +from pandas.core.api import Int64Dtype as Int64Dtype +from pandas.core.api import Int64Index as Int64Index +from pandas.core.api import Interval as Interval +from pandas.core.api import IntervalDtype as IntervalDtype +from pandas.core.api import IntervalIndex as IntervalIndex +from pandas.core.api import MultiIndex as MultiIndex +from pandas.core.api import NamedAgg as NamedAgg +from pandas.core.api import NaT as NaT +from pandas.core.api import Period as Period +from pandas.core.api import PeriodDtype as PeriodDtype +from pandas.core.api import PeriodIndex as PeriodIndex +from pandas.core.api import RangeIndex as RangeIndex +from pandas.core.api import Series as Series +from pandas.core.api import StringDtype as StringDtype +from pandas.core.api import Timedelta as Timedelta +from pandas.core.api import TimedeltaIndex as TimedeltaIndex +from pandas.core.api import Timestamp as Timestamp +from 
pandas.core.api import UInt8Dtype as UInt8Dtype +from pandas.core.api import UInt16Dtype as UInt16Dtype +from pandas.core.api import UInt32Dtype as UInt32Dtype +from pandas.core.api import UInt64Dtype as UInt64Dtype +from pandas.core.api import UInt64Index as UInt64Index +from pandas.core.api import array as array +from pandas.core.api import bdate_range as bdate_range +from pandas.core.api import date_range as date_range +from pandas.core.api import factorize as factorize +from pandas.core.api import interval_range as interval_range +from pandas.core.api import isna as isna +from pandas.core.api import isnull as isnull +from pandas.core.api import notna as notna +from pandas.core.api import notnull as notnull +from pandas.core.api import period_range as period_range +from pandas.core.api import set_eng_float_format as set_eng_float_format +from pandas.core.api import timedelta_range as timedelta_range +from pandas.core.api import to_datetime as to_datetime +from pandas.core.api import to_numeric as to_numeric +from pandas.core.api import to_timedelta as to_timedelta +from pandas.core.api import unique as unique +from pandas.core.api import value_counts as value_counts +from pandas.core.arrays.sparse import SparseDtype as SparseDtype +from pandas.core.computation.api import eval as eval +from pandas.core.reshape.api import concat as concat +from pandas.core.reshape.api import crosstab as crosstab +from pandas.core.reshape.api import cut as cut +from pandas.core.reshape.api import get_dummies as get_dummies +from pandas.core.reshape.api import lreshape as lreshape +from pandas.core.reshape.api import melt as melt +from pandas.core.reshape.api import merge as merge +from pandas.core.reshape.api import merge_asof as merge_asof +from pandas.core.reshape.api import merge_ordered as merge_ordered +from pandas.core.reshape.api import pivot as pivot +from pandas.core.reshape.api import pivot_table as pivot_table +from pandas.core.reshape.api import qcut as qcut +from pandas.core.reshape.api import wide_to_long as wide_to_long +from pandas.io.api import ExcelFile as ExcelFile +from pandas.io.api import ExcelWriter as ExcelWriter +from pandas.io.api import HDFStore as HDFStore +from pandas.io.api import read_clipboard as read_clipboard +from pandas.io.api import read_csv as read_csv +from pandas.io.api import read_excel as read_excel +from pandas.io.api import read_feather as read_feather +from pandas.io.api import read_fwf as read_fwf +from pandas.io.api import read_gbq as read_gbq +from pandas.io.api import read_hdf as read_hdf +from pandas.io.api import read_html as read_html +from pandas.io.api import read_json as read_json +from pandas.io.api import read_orc as read_orc +from pandas.io.api import read_parquet as read_parquet +from pandas.io.api import read_pickle as read_pickle +from pandas.io.api import read_sas as read_sas +from pandas.io.api import read_spss as read_spss +from pandas.io.api import read_sql as read_sql +from pandas.io.api import read_sql_query as read_sql_query +from pandas.io.api import read_sql_table as read_sql_table +from pandas.io.api import read_stata as read_stata +from pandas.io.api import read_table as read_table +from pandas.io.api import to_pickle as to_pickle +from pandas.tseries import offsets as offsets +from pandas.tseries.api import infer_freq as infer_freq +from pandas.util._print_versions import show_versions as show_versions +from pandas.util._tester import test as test + +__docformat__: str +hard_dependencies: Any +missing_dependencies: Any +module: Any 
+v: Any +__git_version__: Any + +def __getattr__(name: Any): ... + +# __doc__: str diff --git a/stubs/pythonfuzz/__init__.pyi b/stubs/pythonfuzz/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/stubs/pythonfuzz/main.pyi b/stubs/pythonfuzz/main.pyi new file mode 100644 index 00000000..f13dbe8b --- /dev/null +++ b/stubs/pythonfuzz/main.pyi @@ -0,0 +1,6 @@ +from typing import Any +from typing import Callable + +class PythonFuzz: + def __init__(self, func: Callable) -> None: ... + def __call__(self, *args: Any, **kwargs: Any) -> None: ... diff --git a/stubs/regex/__init__.pyi b/stubs/regex/__init__.pyi new file mode 100644 index 00000000..b47bdfb5 --- /dev/null +++ b/stubs/regex/__init__.pyi @@ -0,0 +1,61 @@ +from .regex import * + +# Names in __all__ with no definition: +# A +# ASCII +# B +# BESTMATCH +# D +# DEBUG +# DEFAULT_VERSION +# DOTALL +# E +# ENHANCEMATCH +# F +# FULLCASE +# I +# IGNORECASE +# L +# LOCALE +# M +# MULTILINE +# Match +# P +# POSIX +# Pattern +# R +# REVERSE +# Regex +# S +# Scanner +# T +# TEMPLATE +# U +# UNICODE +# V0 +# V1 +# VERBOSE +# VERSION0 +# VERSION1 +# W +# WORD +# X +# __doc__ +# __version__ +# cache_all +# compile +# error +# escape +# findall +# finditer +# fullmatch +# match +# purge +# search +# split +# splititer +# sub +# subf +# subfn +# subn +# template diff --git a/stubs/regex/_regex.pyi b/stubs/regex/_regex.pyi new file mode 100644 index 00000000..6171611d --- /dev/null +++ b/stubs/regex/_regex.pyi @@ -0,0 +1,13 @@ +from typing import Any + +CODE_SIZE: int +MAGIC: int +copyright: str + +def compile(*args, **kwargs) -> Any: ... +def fold_case(*args, **kwargs) -> Any: ... +def get_all_cases(*args, **kwargs) -> Any: ... +def get_code_size(*args, **kwargs) -> Any: ... +def get_expand_on_folding(*args, **kwargs) -> Any: ... +def get_properties(*args, **kwargs) -> Any: ... +def has_property_value(*args, **kwargs) -> Any: ... diff --git a/stubs/regex/_regex_core.pyi b/stubs/regex/_regex_core.pyi new file mode 100644 index 00000000..66cb0e6c --- /dev/null +++ b/stubs/regex/_regex_core.pyi @@ -0,0 +1,503 @@ +from typing import Any as _Any + +class error(Exception): + msg: _Any + pattern: _Any + pos: _Any + lineno: _Any + colno: _Any + def __init__( + self, message, pattern: _Any | None = ..., pos: _Any | None = ... + ) -> None: ... + +class _UnscopedFlagSet(Exception): ... +class ParseError(Exception): ... +class _FirstSetError(Exception): ... + +A: int + +ASCII: int +B: int +BESTMATCH: int +D: int +DEBUG: int +E: int +ENHANCEMATCH: int +F: int +FULLCASE: int +I: int +IGNORECASE: int +L: int +LOCALE: int +M: int +MULTILINE: int +P: int +POSIX: int +R: int +REVERSE: int +S: int +DOTALL: int +U: int +UNICODE: int +V0: int +VERSION0: int +V1: int +VERSION1: int +W: int +WORD: int +X: int +VERBOSE: int +T: int +TEMPLATE: int +DEFAULT_VERSION = VERSION1 + +class Namespace: ... + +class RegexBase: + def __init__(self) -> None: ... + def with_flags( + self, + positive: _Any | None = ..., + case_flags: _Any | None = ..., + zerowidth: _Any | None = ..., + ): ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse) -> None: ... + def has_simple_start(self): ... + def compile(self, reverse: bool = ..., fuzzy: bool = ...): ... + def is_empty(self): ... + def __hash__(self): ... 
+ def __eq__(self, other): ... + def __ne__(self, other): ... + def get_required_string(self, reverse): ... + +class ZeroWidthBase(RegexBase): + positive: _Any + def __init__(self, positive: bool = ...) -> None: ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + +class Any(RegexBase): + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + +class AnyAll(Any): ... +class AnyU(Any): ... + +class Atomic(RegexBase): + subpattern: _Any + def __init__(self, subpattern) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class Boundary(ZeroWidthBase): ... + +class Branch(RegexBase): + branches: _Any + def __init__(self, branches) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + +class CallGroup(RegexBase): + info: _Any + group: _Any + position: _Any + def __init__(self, info, group, position) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def remove_captures(self) -> None: ... + def dump(self, indent, reverse) -> None: ... + def __eq__(self, other): ... + def max_width(self): ... + def __del__(self) -> None: ... + +class CallRef(RegexBase): + ref: _Any + parsed: _Any + def __init__(self, ref, parsed) -> None: ... + +class Character(RegexBase): + value: _Any + positive: _Any + case_flags: _Any + zerowidth: _Any + folded: _Any + def __init__( + self, + value, + positive: bool = ..., + case_flags=..., + zerowidth: bool = ..., + ) -> None: ... + def rebuild(self, positive, case_flags, zerowidth): ... + def optimise(self, info, reverse, in_set: bool = ...): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def matches(self, ch): ... + def max_width(self): ... + folded_characters: _Any + def get_required_string(self, reverse): ... + +class Conditional(RegexBase): + info: _Any + group: _Any + yes_item: _Any + no_item: _Any + position: _Any + def __init__(self, info, group, yes_item, no_item, position) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self) -> None: ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def __del__(self) -> None: ... + +class DefaultBoundary(ZeroWidthBase): ... +class DefaultEndOfWord(ZeroWidthBase): ... +class DefaultStartOfWord(ZeroWidthBase): ... +class EndOfLine(ZeroWidthBase): ... 
+class EndOfLineU(EndOfLine): ... +class EndOfString(ZeroWidthBase): ... +class EndOfStringLine(ZeroWidthBase): ... +class EndOfStringLineU(EndOfStringLine): ... +class EndOfWord(ZeroWidthBase): ... +class Failure(ZeroWidthBase): ... + +class Fuzzy(RegexBase): + subpattern: _Any + constraints: _Any + def __init__(self, subpattern, constraints: _Any | None = ...) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def contains_group(self): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + +class Grapheme(RegexBase): + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + +class GraphemeBoundary: + def compile(self, reverse, fuzzy): ... + +class GreedyRepeat(RegexBase): + subpattern: _Any + min_count: _Any + max_count: _Any + def __init__(self, subpattern, min_count, max_count) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class PossessiveRepeat(GreedyRepeat): + def is_atomic(self): ... + def dump(self, indent, reverse) -> None: ... + +class Group(RegexBase): + info: _Any + group: _Any + subpattern: _Any + call_ref: _Any + def __init__(self, info, group, subpattern) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + def __del__(self) -> None: ... + +class Keep(ZeroWidthBase): ... +class LazyRepeat(GreedyRepeat): ... + +class LookAround(RegexBase): + behind: _Any + positive: _Any + subpattern: _Any + def __init__(self, behind, positive, subpattern) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + +class LookAroundConditional(RegexBase): + behind: _Any + positive: _Any + subpattern: _Any + yes_item: _Any + no_item: _Any + def __init__( + self, behind, positive, subpattern, yes_item, no_item + ) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self) -> None: ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... 
+ def max_width(self): ... + def get_required_string(self, reverse): ... + +class PrecompiledCode(RegexBase): + code: _Any + def __init__(self, code) -> None: ... + +class Property(RegexBase): + value: _Any + positive: _Any + case_flags: _Any + zerowidth: _Any + def __init__( + self, + value, + positive: bool = ..., + case_flags=..., + zerowidth: bool = ..., + ) -> None: ... + def rebuild(self, positive, case_flags, zerowidth): ... + def optimise(self, info, reverse, in_set: bool = ...): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def matches(self, ch): ... + def max_width(self): ... + +class Prune(ZeroWidthBase): ... + +class Range(RegexBase): + lower: _Any + upper: _Any + positive: _Any + case_flags: _Any + zerowidth: _Any + def __init__( + self, + lower, + upper, + positive: bool = ..., + case_flags=..., + zerowidth: bool = ..., + ) -> None: ... + def rebuild(self, positive, case_flags, zerowidth): ... + def optimise(self, info, reverse, in_set: bool = ...): ... + def dump(self, indent, reverse) -> None: ... + def matches(self, ch): ... + def max_width(self): ... + +class RefGroup(RegexBase): + info: _Any + group: _Any + position: _Any + case_flags: _Any + def __init__(self, info, group, position, case_flags=...) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def remove_captures(self) -> None: ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + def __del__(self) -> None: ... + +class SearchAnchor(ZeroWidthBase): ... + +class Sequence(RegexBase): + items: _Any + def __init__(self, items: _Any | None = ...) -> None: ... + def fix_groups(self, pattern, reverse, fuzzy) -> None: ... + def optimise(self, info, reverse): ... + def pack_characters(self, info): ... + def remove_captures(self): ... + def is_atomic(self): ... + def can_be_affix(self): ... + def contains_group(self): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def is_empty(self): ... + def __eq__(self, other): ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class SetBase(RegexBase): + info: _Any + items: _Any + positive: _Any + case_flags: _Any + zerowidth: _Any + char_width: int + def __init__( + self, + info, + items, + positive: bool = ..., + case_flags=..., + zerowidth: bool = ..., + ) -> None: ... + def rebuild(self, positive, case_flags, zerowidth): ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + def __del__(self) -> None: ... + +class SetDiff(SetBase): + items: _Any + def optimise(self, info, reverse, in_set: bool = ...): ... + def matches(self, ch): ... + +class SetInter(SetBase): + items: _Any + def optimise(self, info, reverse, in_set: bool = ...): ... + def matches(self, ch): ... + +class SetSymDiff(SetBase): + items: _Any + def optimise(self, info, reverse, in_set: bool = ...): ... + def matches(self, ch): ... + +class SetUnion(SetBase): + items: _Any + def optimise(self, info, reverse, in_set: bool = ...): ... + def matches(self, ch): ... + +class Skip(ZeroWidthBase): ... +class StartOfLine(ZeroWidthBase): ... +class StartOfLineU(StartOfLine): ... +class StartOfString(ZeroWidthBase): ... +class StartOfWord(ZeroWidthBase): ... 
+ +class String(RegexBase): + characters: _Any + case_flags: _Any + folded_characters: _Any + required: bool + def __init__(self, characters, case_flags=...) -> None: ... + def get_firstset(self, reverse): ... + def has_simple_start(self): ... + def dump(self, indent, reverse) -> None: ... + def max_width(self): ... + def get_required_string(self, reverse): ... + +class Literal(String): + def dump(self, indent, reverse) -> None: ... + +class StringSet(Branch): + info: _Any + name: _Any + case_flags: _Any + set_key: _Any + branches: _Any + def __init__(self, info, name, case_flags=...) -> None: ... + def dump(self, indent, reverse) -> None: ... + def __del__(self) -> None: ... + +class Source: + string: _Any + char_type: _Any + pos: int + ignore_space: bool + sep: _Any + def __init__(self, string): ... + def get(self, override_ignore: bool = ...): ... + def get_many(self, count: int = ...): ... + def get_while(self, test_set, include: bool = ...): ... + def skip_while(self, test_set, include: bool = ...) -> None: ... + def match(self, substring): ... + def expect(self, substring) -> None: ... + def at_end(self): ... + +class Info: + flags: _Any + global_flags: _Any + inline_locale: bool + kwargs: _Any + group_count: int + group_index: _Any + group_name: _Any + char_type: _Any + named_lists_used: _Any + open_groups: _Any + open_group_count: _Any + defined_groups: _Any + group_calls: _Any + private_groups: _Any + def __init__( + self, flags: int = ..., char_type: _Any | None = ..., kwargs=... + ) -> None: ... + def open_group(self, name: _Any | None = ...): ... + def close_group(self) -> None: ... + def is_open_group(self, name): ... + +class Scanner: + lexicon: _Any + scanner: _Any + def __init__(self, lexicon, flags: int = ...) -> None: ... + match: _Any + def scan(self, string): ... diff --git a/stubs/regex/regex.pyi b/stubs/regex/regex.pyi new file mode 100644 index 00000000..a6819c4f --- /dev/null +++ b/stubs/regex/regex.pyi @@ -0,0 +1,189 @@ +from typing import Any + +from regex._regex_core import VERSION0 + +def match( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + partial: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def fullmatch( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + partial: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def search( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + partial: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def sub( + pattern, + repl, + string, + count: int = ..., + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def subf( + pattern, + format, + string, + count: int = ..., + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def subn( + pattern, + repl, + string, + count: int = ..., + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... 
+def subfn( + pattern, + format, + string, + count: int = ..., + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def split( + pattern, + string, + maxsplit: int = ..., + flags: int = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def splititer( + pattern, + string, + maxsplit: int = ..., + flags: int = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def findall( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + overlapped: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def finditer( + pattern, + string, + flags: int = ..., + pos: Any | None = ..., + endpos: Any | None = ..., + overlapped: bool = ..., + partial: bool = ..., + concurrent: Any | None = ..., + timeout: Any | None = ..., + ignore_unused: bool = ..., + **kwargs +): ... +def compile( + pattern, flags: int = ..., ignore_unused: bool = ..., **kwargs +): ... +def purge() -> None: ... +def cache_all(value: bool = ...): ... +def template(pattern, flags: int = ...): ... +def escape(pattern, special_only: bool = ..., literal_spaces: bool = ...): ... + +DEFAULT_VERSION = VERSION0 +Pattern: Any +Match: Any +Regex = compile + +# Names in __all__ with no definition: +# A +# ASCII +# B +# BESTMATCH +# D +# DEBUG +# DOTALL +# E +# ENHANCEMATCH +# F +# FULLCASE +# I +# IGNORECASE +# L +# LOCALE +# M +# MULTILINE +# P +# POSIX +# R +# REVERSE +# S +# Scanner +# T +# TEMPLATE +# U +# UNICODE +# V0 +# V1 +# VERBOSE +# VERSION0 +# VERSION1 +# W +# WORD +# X +# __doc__ +# __version__ +# error diff --git a/stubs/tabview/__init__.pyi b/stubs/tabview/__init__.pyi new file mode 100644 index 00000000..db597540 --- /dev/null +++ b/stubs/tabview/__init__.pyi @@ -0,0 +1 @@ +from .tabview import view as view diff --git a/stubs/tabview/tabview.pyi b/stubs/tabview/tabview.pyi new file mode 100644 index 00000000..4685cd48 --- /dev/null +++ b/stubs/tabview/tabview.pyi @@ -0,0 +1,163 @@ +import io + +from typing import Any + +basestring = str +file = io.FileIO + +def KEY_CTRL(key): ... +def addstr(*args): ... +def insstr(*args): ... + +class ReloadException(Exception): + start_pos: Any + column_width_mode: Any + column_gap: Any + column_widths: Any + search_str: Any + def __init__( + self, start_pos, column_width, column_gap, column_widths, search_str + ) -> None: ... + +class QuitException(Exception): ... + +class Viewer: + scr: Any + data: Any + info: Any + header_offset_orig: int + header: Any + header_offset: Any + num_data_columns: Any + column_width_mode: Any + column_gap: Any + trunc_char: Any + num_columns: int + vis_columns: int + init_search: Any + modifier: Any + def __init__(self, *args, **kwargs) -> None: ... + def column_xw(self, x): ... + def quit(self) -> None: ... + def reload(self) -> None: ... + def consume_modifier(self, default: int = ...): ... + def down(self) -> None: ... + def up(self) -> None: ... + def left(self) -> None: ... + def right(self) -> None: ... + y: Any + win_y: Any + def page_down(self) -> None: ... + def page_up(self) -> None: ... + x: Any + win_x: Any + def page_right(self) -> None: ... + def page_left(self) -> None: ... + def mark(self) -> None: ... + def goto_mark(self) -> None: ... + def home(self) -> None: ... 
+ def goto_y(self, y) -> None: ... + def goto_row(self) -> None: ... + def goto_x(self, x) -> None: ... + def goto_col(self) -> None: ... + def goto_yx(self, y, x) -> None: ... + def line_home(self) -> None: ... + def line_end(self) -> None: ... + def show_cell(self) -> None: ... + def show_info(self): ... + textpad: Any + search_str: Any + def search(self) -> None: ... + def search_results( + self, rev: bool = ..., look_in_cur: bool = ... + ) -> None: ... + def search_results_prev( + self, rev: bool = ..., look_in_cur: bool = ... + ) -> None: ... + def help(self) -> None: ... + def toggle_header(self) -> None: ... + def column_gap_down(self) -> None: ... + def column_gap_up(self) -> None: ... + column_width: Any + def column_width_all_down(self) -> None: ... + def column_width_all_up(self) -> None: ... + def column_width_down(self) -> None: ... + def column_width_up(self) -> None: ... + def sort_by_column_numeric(self): ... + def sort_by_column_numeric_reverse(self): ... + def sort_by_column(self) -> None: ... + def sort_by_column_reverse(self) -> None: ... + def sort_by_column_natural(self) -> None: ... + def sort_by_column_natural_reverse(self) -> None: ... + def sorted_nicely(self, ls, key, rev: bool = ...): ... + def float_string_key(self, value): ... + def toggle_column_width(self) -> None: ... + def set_current_column_width(self) -> None: ... + def yank_cell(self) -> None: ... + keys: Any + def define_keys(self) -> None: ... + def run(self) -> None: ... + def handle_keys(self) -> None: ... + def handle_modifier(self, mod) -> None: ... + def resize(self) -> None: ... + def num_columns_fwd(self, x): ... + def num_columns_rev(self, x): ... + def recalculate_layout(self) -> None: ... + def location_string(self, yp, xp): ... + def display(self) -> None: ... + def strpad(self, s, width): ... + def hdrstr(self, x, width): ... + def cellstr(self, y, x, width): ... + def skip_to_row_change(self) -> None: ... + def skip_to_row_change_reverse(self) -> None: ... + def skip_to_col_change(self) -> None: ... + def skip_to_col_change_reverse(self) -> None: ... + +class TextBox: + scr: Any + data: Any + title: Any + tdata: Any + hid_rows: int + def __init__(self, scr, data: str = ..., title: str = ...) -> None: ... + def __call__(self) -> None: ... + handlers: Any + def setup_handlers(self) -> None: ... + def run(self) -> None: ... + def handle_key(self, key) -> None: ... + def close(self) -> None: ... + def scroll_down(self) -> None: ... + def scroll_up(self) -> None: ... + def display(self) -> None: ... + +def csv_sniff(data, enc): ... +def fix_newlines(data): ... +def adjust_space_delim(data, enc): ... +def process_data( + data, + enc: Any | None = ..., + delim: Any | None = ..., + quoting: Any | None = ..., + quote_char=..., +): ... +def data_list_or_file(data): ... +def pad_data(d): ... +def readme(): ... +def detect_encoding(data: Any | None = ...): ... +def main(stdscr, *args, **kwargs) -> None: ... +def view( + data, + enc: Any | None = ..., + start_pos=..., + column_width: int = ..., + column_gap: int = ..., + trunc_char: str = ..., + column_widths: Any | None = ..., + search_str: Any | None = ..., + double_width: bool = ..., + delimiter: Any | None = ..., + quoting: Any | None = ..., + info: Any | None = ..., + quote_char=..., +): ... +def parse_path(path): ... 
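A note on the stubs/ directory introduced above: mypy only needs declarations for the names CleverCSV actually touches, so a local stub for an untyped dependency can stay tiny, with everything else omitted or typed as Any. The sketch below is illustrative only (somelib and all of its members are hypothetical, not part of this patch). For mypy to find such local stubs, its search path has to include the directory, e.g. mypy_path = "stubs" under [tool.mypy] or the MYPYPATH environment variable; that wiring is not visible in the hunks shown here, so treat it as an assumption about how the patch is hooked up. The py.typed marker added in setup.py works in the other direction: per PEP 561 it tells downstream type checkers that clevercsv itself ships inline type information.

    # stubs/somelib/__init__.pyi -- hypothetical minimal stub for an untyped package
    from typing import Any

    __version__: str

    def load(path: str, *, strict: bool = ...) -> Any: ...

    class Result:
        ok: bool
        def raise_for_error(self) -> None: ...

Only signatures matter in a .pyi file: function bodies are always the literal ellipsis, and default values are also written as ... because their concrete values are irrelevant to the checker.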
diff --git a/stubs/termcolor/__init__.pyi b/stubs/termcolor/__init__.pyi new file mode 100644 index 00000000..9c937267 --- /dev/null +++ b/stubs/termcolor/__init__.pyi @@ -0,0 +1,22 @@ +from typing import Any + +__ALL__: Any +VERSION: Any +ATTRIBUTES: Any +HIGHLIGHTS: Any +COLORS: Any +RESET: str + +def colored( + text, + color: Any | None = ..., + on_color: Any | None = ..., + attrs: Any | None = ..., +): ... +def cprint( + text, + color: Any | None = ..., + on_color: Any | None = ..., + attrs: Any | None = ..., + **kwargs +) -> None: ... diff --git a/stubs/wilderness/__init__.pyi b/stubs/wilderness/__init__.pyi new file mode 100644 index 00000000..47e061a2 --- /dev/null +++ b/stubs/wilderness/__init__.pyi @@ -0,0 +1,168 @@ +import abc +import argparse + +from typing import Dict +from typing import List +from typing import Optional +from typing import TextIO + +class DocumentableMixin(metaclass=abc.ABCMeta): + def __init__( + self, + description: Optional[str] = None, + extra_sections: Optional[Dict[str, str]] = None, + options_prolog: Optional[str] = None, + options_epilog: Optional[str] = None, + ) -> None: ... + @property + def description(self) -> Optional[str]: ... + @property + def parser(self) -> argparse.ArgumentParser: ... + @parser.setter + def parser(self, parser: argparse.ArgumentParser): ... + @property + def args(self) -> argparse.Namespace: ... + @args.setter + def args(self, args: argparse.Namespace): ... + @property + def argument_help(self) -> Dict[str, Optional[str]]: ... + +class Application(DocumentableMixin): + def __init__( + self, + name: str, + version: str, + author: Optional[str] = None, + title: Optional[str] = None, + description: Optional[str] = None, + default_command: Optional[str] = None, + add_help: bool = True, + extra_sections: Optional[Dict[str, str]] = None, + prolog: Optional[str] = None, + epilog: Optional[str] = None, + options_prolog: Optional[str] = None, + options_epilog: Optional[str] = None, + add_commands_section: bool = False, + ) -> None: ... + @property + def name(self) -> str: ... + @property + def author(self) -> str: ... + @property + def version(self) -> str: ... + @property + def commands(self) -> List[Command]: ... + @property + def groups(self) -> List[Group]: ... + def add_argument(self, *args, **kwargs) -> argparse.Action: ... + def add(self, command: Command): ... + def add_group(self, title: str) -> Group: ... + def register(self): ... + def handle(self) -> int: ... + def run( + self, + args: Optional[List[str]] = None, + namespace: Optional[argparse.Namespace] = None, + exit_on_error: bool = True, + ) -> int: ... + def run_command(self, command: Command) -> int: ... + def get_command(self, command_name: str) -> Command: ... + def set_prolog(self, prolog: str) -> None: ... + def set_epilog(self, epilog: str) -> None: ... + def get_commands_text(self) -> str: ... + def create_manpage(self) -> ManPage: ... + def format_help(self) -> str: ... + def print_help(self, file: Optional[TextIO] = None) -> None: ... + +class Group: + def __init__( + self, title: Optional[str] = None, is_root: bool = False + ) -> None: ... + @property + def application(self) -> Optional[Application]: ... + @property + def title(self) -> Optional[str]: ... + @property + def commands(self) -> List[Command]: ... + @property + def is_root(self) -> bool: ... + def commands_as_actions(self) -> List[argparse.Action]: ... + def set_app(self, app: Application) -> None: ... + def add(self, command: Command) -> None: ... + def __len__(self) -> int: ... 
+ +class Command(DocumentableMixin, metaclass=abc.ABCMeta): + def __init__( + self, + name: str, + title: Optional[str] = None, + description: Optional[str] = None, + add_help: bool = True, + extra_sections: Optional[Dict[str, str]] = None, + options_prolog: Optional[str] = None, + options_epilog: Optional[str] = None, + ) -> None: ... + @property + def application(self) -> Optional[Application]: ... + @property + def name(self) -> str: ... + @property + def title(self) -> Optional[str]: ... + def add_argument(self, *args, **kwargs) -> None: ... + def add_argument_group(self, *args, **kwargs) -> ArgumentGroup: ... + def add_mutually_exclusive_group( + self, *args, **kwargs + ) -> MutuallyExclusiveGroup: ... + def register(self) -> None: ... + @abc.abstractmethod + def handle(self) -> int: ... + def create_manpage(self) -> ManPage: ... + +class ManPage: + def __init__( + self, + application_name: str, + author: Optional[str] = "", + command_name: Optional[str] = None, + date: Optional[str] = None, + title: Optional[str] = None, + version: Optional[str] = "", + ) -> None: ... + @property + def name(self) -> str: ... + def metadata(self) -> List[str]: ... + def preamble(self) -> List[str]: ... + def header(self) -> str: ... + def section_name(self) -> str: ... + def add_section_synopsis(self, synopsis: str) -> None: ... + def add_section(self, label: str, text: str) -> None: ... + def groffify(self, text: str) -> str: ... + def groffify_line(self, line: str) -> str: ... + def export(self, output_dir: str) -> str: ... + +class ArgumentGroup: + def __init__(self, group: argparse._ArgumentGroup) -> None: ... + @property + def command(self) -> Optional[Command]: ... + @command.setter + def command(self, command: Command): ... + def add_argument(self, *args, **kwargs): ... + +class MutuallyExclusiveGroup: + def __init__(self, meg: argparse._MutuallyExclusiveGroup) -> None: ... + @property + def command(self) -> Optional[Command]: ... + @command.setter + def command(self, command: Command): ... + def add_argument(self, *args, **kwargs): ... + +class Tester: + def __init__(self, app: Application) -> None: ... + @property + def application(self) -> Application: ... + def clear(self) -> None: ... + def get_return_code(self) -> Optional[int]: ... + def get_stdout(self) -> Optional[str]: ... + def get_stderr(self) -> Optional[str]: ... + def test_command(self, cmd_name: str, args: List[str]) -> None: ... + def test_application(self, args: Optional[List[str]] = None) -> None: ... 
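The seemingly redundant aliasing that dominates the pandas stub above (from pandas.core.api import DataFrame as DataFrame) is deliberate: in a .pyi file, mypy treats a plainly imported name as private to the stub, and only the name-as-name form (or import module as module) marks it as an intentional public re-export. A minimal sketch of the difference, using hypothetical module names:

    # pkg/__init__.pyi
    from pkg._impl import helper            # private to the stub: not re-exported
    from pkg._impl import Engine as Engine  # explicit re-export: part of pkg's API

    # client.py
    from pkg import Engine  # accepted by mypy
    from pkg import helper  # rejected: "pkg" does not explicitly export "helper"

This is why the pandas stub spells out every public name with the as alias instead of importing them plainly.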
diff --git a/tests/test_integration/test_dialect_detection.py b/tests/test_integration/test_dialect_detection.py index 917cff47..7308ea73 100644 --- a/tests/test_integration/test_dialect_detection.py +++ b/tests/test_integration/test_dialect_detection.py @@ -47,7 +47,8 @@ def log_result(name, kind, verbose, partial): "success": (LOG_SUCCESS, LOG_SUCCESS_PARTIAL, "green"), "failure": (LOG_FAILED, LOG_FAILED_PARTIAL, "red"), } - outfull, outpartial, color = table.get(kind) + assert kind in table + outfull, outpartial, color = table[kind] fname = outpartial if partial else outfull with open(fname, "a") as fp: diff --git a/tests/test_unit/test_console.py b/tests/test_unit/test_console.py index 5ea1a65a..29595575 100644 --- a/tests/test_unit/test_console.py +++ b/tests/test_unit/test_console.py @@ -11,6 +11,9 @@ import tempfile import unittest +from typing import List +from typing import Union + from wilderness import Tester from clevercsv import __version__ @@ -20,7 +23,7 @@ class ConsoleTestCase(unittest.TestCase): - def _build_file(self, table, dialect, encoding=None, newline=None): + def _build_file(self, table, dialect, encoding=None, newline=None) -> str: tmpfd, tmpfname = tempfile.mkstemp( prefix="ccsv_", suffix=".csv", @@ -40,7 +43,10 @@ def _detect_test_wrap(self, table, dialect): tester.test_command("detect", [tmpfname]) try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() self.assertEqual(exp, output) finally: os.unlink(tmpfname) @@ -79,7 +85,10 @@ def test_detect_opts_1(self): exp = "Detected: " + str(dialect) try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() self.assertEqual(exp, output) finally: os.unlink(tmpfname) @@ -96,7 +105,10 @@ def test_detect_opts_2(self): exp = "Detected: " + str(dialect) try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() self.assertEqual(exp, output) finally: os.unlink(tmpfname) @@ -115,7 +127,10 @@ def test_detect_opts_3(self): quotechar = escapechar =""" try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() self.assertEqual(exp, output) finally: os.unlink(tmpfname) @@ -130,7 +145,10 @@ def test_detect_opts_4(self): tester.test_command("detect", ["--json", "--add-runtime", tmpfname]) try: - output = tester.get_stdout().strip() + stdout = tester.get_stdout() + self.assertIsNotNone(stdout) + assert stdout is not None + output = stdout.strip() data = json.loads(output) self.assertEqual(data["delimiter"], ";") self.assertEqual(data["quotechar"], "") @@ -442,7 +460,11 @@ def test_standardize_in_place_noop(self): os.unlink(tmpfname) def test_standardize_multi(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["A", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "unix", "excel-tab"] tmpfnames = [self._build_file(table, D, newline="") for D in dialects] @@ -476,7 +498,11 @@ def test_standardize_multi(self): any(map(os.unlink, tmpoutnames)) def test_standardize_multi_errors(self): - table = [["A", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["A", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "unix", "excel-tab"] tmpfnames = 
[self._build_file(table, D, newline="") for D in dialects] @@ -507,7 +533,11 @@ def test_standardize_multi_errors(self): any(map(os.unlink, tmpoutnames)) def test_standardize_multi_encoding(self): - table = [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["Å", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "unix", "excel-tab"] encoding = "ISO-8859-1" tmpfnames = [ @@ -547,7 +577,11 @@ def test_standardize_multi_encoding(self): any(map(os.unlink, tmpoutnames)) def test_standardize_in_place_multi(self): - table = [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["Å", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "unix", "excel-tab"] encoding = "ISO-8859-1" tmpfnames = [ @@ -572,7 +606,11 @@ def test_standardize_in_place_multi(self): any(map(os.unlink, tmpfnames)) def test_standardize_in_place_multi_noop(self): - table = [["Å", "B", "C"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["Å", "B", "C"], + [1, 2, 3], + [4, 5, 6], + ] dialects = ["excel", "excel", "excel"] tmpfnames = [self._build_file(table, D, newline="") for D in dialects] diff --git a/tests/test_unit/test_detect_type.py b/tests/test_unit/test_detect_type.py index c4b8065b..83fcb42e 100644 --- a/tests/test_unit/test_detect_type.py +++ b/tests/test_unit/test_detect_type.py @@ -9,6 +9,8 @@ import unittest +from typing import List + from clevercsv.detect_type import TypeDetector from clevercsv.detect_type import type_score from clevercsv.dialect import SimpleDialect @@ -21,7 +23,7 @@ def setUp(self): # NUMBERS def test_number(self): - yes_number = [ + yes_number: List[str] = [ "1", "2", "34", @@ -87,7 +89,7 @@ def test_number(self): for num in yes_number: with self.subTest(num=num): self.assertTrue(self.td.is_number(num)) - no_number = [ + no_number: List[str] = [ "0000.213654", "123.465.798", "0.5e0.5", @@ -111,7 +113,7 @@ def test_number(self): # DATES def test_date(self): - yes_date = [ + yes_date: List[str] = [ "031219", "03122019", "03-12-19", @@ -162,7 +164,7 @@ def test_date(self): for date in yes_date: with self.subTest(date=date): self.assertTrue(self.td.is_date(date)) - no_date = [ + no_date: List[str] = [ "2018|01|02", "30/07-88", "12.01-99", @@ -177,11 +179,14 @@ def test_date(self): # DATETIME def test_datetime(self): - yes_dt = ["2019-01-12T04:01:23Z", "2021-09-26T12:13:31+01:00"] + yes_dt: List[str] = [ + "2019-01-12T04:01:23Z", + "2021-09-26T12:13:31+01:00", + ] for dt in yes_dt: with self.subTest(dt=dt): self.assertTrue(self.td.is_datetime(dt)) - no_date = [] + no_date: List[str] = [] for date in no_date: with self.subTest(date=date): self.assertFalse(self.td.is_datetime(dt)) @@ -190,7 +195,7 @@ def test_datetime(self): def test_url(self): # Some cases copied from https://mathiasbynens.be/demo/url-regex - yes_url = [ + yes_url: List[str] = [ "Cocoal.icio.us", "Websquash.com", "bbc.co.uk", @@ -262,7 +267,7 @@ def test_url(self): with self.subTest(url=url): self.assertTrue(self.td.is_url(url)) - no_url = [ + no_url: List[str] = [ "//", "///", "///a", @@ -305,7 +310,7 @@ def test_unicode_alphanum(self): # These tests are by no means inclusive and ought to be extended in the # future. 
- yes_alphanum = ["this is a cell", "1231 pounds"] + yes_alphanum: List[str] = ["this is a cell", "1231 pounds"] for unicode_alphanum in yes_alphanum: with self.subTest(unicode_alphanum=unicode_alphanum): self.assertTrue(self.td.is_unicode_alphanum(unicode_alphanum)) @@ -315,7 +320,7 @@ def test_unicode_alphanum(self): ) ) - no_alphanum = ["https://www.gertjan.dev"] + no_alphanum: List[str] = ["https://www.gertjan.dev"] for unicode_alpanum in no_alphanum: with self.subTest(unicode_alpanum=unicode_alpanum): self.assertFalse(self.td.is_unicode_alphanum(unicode_alpanum)) @@ -325,7 +330,7 @@ def test_unicode_alphanum(self): ) ) - only_quoted = ["this string, with a comma"] + only_quoted: List[str] = ["this string, with a comma"] for unicode_alpanum in only_quoted: with self.subTest(unicode_alpanum=unicode_alpanum): self.assertFalse( @@ -340,12 +345,12 @@ def test_unicode_alphanum(self): ) def test_bytearray(self): - yes_bytearray = [ + yes_bytearray: List[str] = [ "bytearray(b'')", "bytearray(b'abc,*&@\"')", "bytearray(b'bytearray(b'')')", ] - no_bytearray = [ + no_bytearray: List[str] = [ "bytearray(b'abc", "bytearray(b'abc'", "bytearray('abc')", @@ -363,7 +368,7 @@ def test_bytearray(self): # Unix path def test_unix_path(self): - yes_path = [ + yes_path: List[str] = [ "/Users/person/abc/def-ghi/blabla.csv.test", "/home/username/share/a/_b/c_d/e.py", "/home/username/share", diff --git a/tests/test_unit/test_dict.py b/tests/test_unit/test_dict.py index 8a941da8..43d45c5a 100644 --- a/tests/test_unit/test_dict.py +++ b/tests/test_unit/test_dict.py @@ -12,8 +12,13 @@ import tempfile import unittest +from typing import Any +from typing import Dict + import clevercsv +from clevercsv.dict_read_write import DictReader + class DictTestCase(unittest.TestCase): ############################ @@ -57,8 +62,9 @@ def test_write_fields_not_in_fieldnames(self): with tempfile.TemporaryFile("w+", newline="") as fp: writer = clevercsv.DictWriter(fp, fieldnames=["f1", "f2", "f3"]) # Of special note is the non-string key (CPython issue 19449) + content: Dict[Any, Any] = {"f4": 10, "f2": "spam", 1: "abc"} with self.assertRaises(ValueError) as cx: - writer.writerow({"f4": 10, "f2": "spam", 1: "abc"}) + writer.writerow(content) exception = str(cx.exception) self.assertIn("fieldnames", exception) self.assertIn("'f4'", exception) @@ -101,7 +107,7 @@ def test_read_dict_no_fieldnames(self): with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2,f3\r\n1,2,abc\r\n") fp.seek(0) - reader = clevercsv.DictReader(fp) + reader: DictReader = clevercsv.DictReader(fp) self.assertEqual(next(reader), {"f1": "1", "f2": "2", "f3": "abc"}) self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) @@ -123,7 +129,7 @@ def test_read_dict_fieldnames_chain(self): with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2,f3\r\n1,2,abc\r\n") fp.seek(0) - reader = clevercsv.DictReader(fp) + reader: DictReader = clevercsv.DictReader(fp) first = next(reader) for row in itertools.chain([first], reader): self.assertEqual(reader.fieldnames, ["f1", "f2", "f3"]) @@ -155,7 +161,7 @@ def test_read_long_with_rest_no_fieldnames(self): with tempfile.TemporaryFile("w+") as fp: fp.write("f1,f2\r\n1,2,abc,4,5,6\r\n") fp.seek(0) - reader = clevercsv.DictReader(fp, restkey="_rest") + reader: DictReader = clevercsv.DictReader(fp, restkey="_rest") self.assertEqual(reader.fieldnames, ["f1", "f2"]) self.assertEqual( next(reader), @@ -238,7 +244,9 @@ def test_read_semi_sep(self): # Start tests added for CleverCSV # def test_read_duplicate_fieldnames(self): - reader = 
clevercsv.DictReader(["f1,f2,f1\r\n", "a", "b", "c"]) + reader: DictReader = clevercsv.DictReader( + ["f1,f2,f1\r\n", "a", "b", "c"] + ) with self.assertWarns(UserWarning): reader.fieldnames diff --git a/tests/test_unit/test_wrappers.py b/tests/test_unit/test_wrappers.py index a2becc6a..e054f002 100644 --- a/tests/test_unit/test_wrappers.py +++ b/tests/test_unit/test_wrappers.py @@ -12,6 +12,9 @@ import types import unittest +from typing import List +from typing import Union + import pandas as pd from clevercsv import wrappers @@ -204,7 +207,11 @@ def _write_test_table(self, table, expected, **kwargs): os.unlink(tmpfname) def test_write_table(self): - table = [["A", "B,C", "D"], [1, 2, 3], [4, 5, 6]] + table: List[List[Union[str, int]]] = [ + ["A", "B,C", "D"], + [1, 2, 3], + [4, 5, 6], + ] exp = 'A,"B,C",D\r\n1,2,3\r\n4,5,6\r\n' with self.subTest(name="default"): self._write_test_table(table, exp) diff --git a/tests/test_unit/test_write.py b/tests/test_unit/test_write.py index 01f495e4..2c8367c9 100644 --- a/tests/test_unit/test_write.py +++ b/tests/test_unit/test_write.py @@ -18,13 +18,6 @@ class WriterTestCase(unittest.TestCase): - def writerAssertEqual(self, input, expected_result): - with tempfile.TemporaryFile("w+", newline="", prefix="ccsv_") as fp: - writer = clevercsv.writer(fp, dialect=self.dialect) - writer.writerows(input) - fp.seek(0) - self.assertEqual(fp.read(), expected_result) - def _write_test(self, fields, expect, **kwargs): with tempfile.TemporaryFile("w+", newline="", prefix="ccsv_") as fp: writer = clevercsv.writer(fp, **kwargs)
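One pattern from the test changes above deserves a closing note: the repeated three-line dance in test_console.py (fetch stdout, assertIsNotNone, then a bare assert) exists because Tester.get_stdout() is typed Optional[str] and unittest's assertIsNotNone is opaque to mypy, so it does not narrow the variable's type. The plain assert is added purely for the checker, while the unittest assertion keeps the informative test-failure message. A self-contained sketch of the narrowing follows; read_output and its argument are illustrative, not part of the patch.

    from typing import Callable
    from typing import Optional

    def read_output(get_stdout: Callable[[], Optional[str]]) -> str:
        stdout = get_stdout()
        # At this point mypy sees Optional[str]; calling stdout.strip() directly
        # would be rejected because None has no attribute "strip".
        assert stdout is not None  # narrows Optional[str] to str for the checker
        return stdout.strip()

    print(read_output(lambda: "  Detected: excel  "))  # prints the stripped text

The same reasoning explains the assert added in test_dialect_detection.py: dict.get() returns an Optional value, so asserting kind in table and then indexing with table[kind] gives mypy a non-Optional tuple to unpack.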