Skip to content

Commit

Permalink
Add type hints to CleverCSV (#108)
Browse files Browse the repository at this point in the history
* Add type hints to CleverCSV

This commit adds type hints to CleverCSV and enables static type checking with MyPy. An attempt was made to keep the changes minimal, in order to change only the code needed to get MyPy to pass. As can be seen, this already required a large number of changes to the code. It is also clear that certain design decisions could be reevaluated to make the type checking smoother (e.g., requiring @overload is not ideal). Such improvements are left for future work.
  • Loading branch information
GjjvdBurg committed Sep 6, 2023
1 parent 6220562 commit 207f4ef
Show file tree
Hide file tree
Showing 44 changed files with 1,992 additions and 261 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ jobs:
strategy:
matrix:
os: [ 'ubuntu-latest', 'macos-latest', 'windows-latest' ]
py: [ '3.7', '3.11' ] # minimal and latest
py: [ '3.8', '3.11' ] # minimal and latest
steps:
- name: Install Python ${{ matrix.py }}
uses: actions/setup-python@v2
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ jobs:
env:
CIBW_TEST_COMMAND: "python -VV && python -m unittest discover -f -s {project}/tests/test_unit/"
CIBW_TEST_EXTRAS: "full"
CIBW_SKIP: "pp* cp27-* cp33-* cp34-* cp35-* cp36-* cp310-win32 *-musllinux_* *-manylinux_i686"
CIBW_SKIP: "pp* cp27-* cp33-* cp34-* cp35-* cp36-* cp37-* cp310-win32 *-musllinux_* *-manylinux_i686"
CIBW_ARCHS_MACOS: x86_64 arm64 universal2
CIBW_ARCHS_LINUX: auto aarch64

Expand Down
8 changes: 6 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ MAKEFLAGS += --no-builtin-rules

PACKAGE=clevercsv
DOC_DIR=./docs/
VENV_DIR=/tmp/clevercsv_venv/
VENV_DIR=/tmp/clevercsv_venv
PYTHON ?= python

.PHONY: help
Expand Down Expand Up @@ -51,14 +51,18 @@ dist: man ## Make Python source distribution

.PHONY: test integration integration_partial

test: green pytest
test: mypy green pytest

green: venv ## Run unit tests
source $(VENV_DIR)/bin/activate && green -a -vv ./tests/test_unit

pytest: venv ## Run unit tests with PyTest
source $(VENV_DIR)/bin/activate && pytest -ra -m 'not network'

mypy: venv ## Run type checks
source $(VENV_DIR)/bin/activate && \
mypy --check-untyped-defs ./stubs $(PACKAGE) ./tests

integration: venv ## Run integration tests
source $(VENV_DIR)/bin/activate && python ./tests/test_integration/test_dialect_detection.py -v

Expand Down
6 changes: 4 additions & 2 deletions clevercsv/__version__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-

VERSION = (0, 8, 0)
from typing import Tuple

__version__ = ".".join(map(str, VERSION))
VERSION: Tuple[int, int, int] = (0, 8, 0)

__version__: str = ".".join(map(str, VERSION))
7 changes: 6 additions & 1 deletion clevercsv/_optional.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@

import importlib

from types import ModuleType

from typing import Dict
from typing import List
from typing import NamedTuple
from typing import Optional

from packaging.version import Version

Expand All @@ -35,7 +38,9 @@ class OptionalDependency(NamedTuple):
]


def import_optional_dependency(name, raise_on_missing=True):
def import_optional_dependency(
name: str, raise_on_missing: bool = True
) -> Optional[ModuleType]:
"""
Import an optional dependency.
Expand Down
50 changes: 50 additions & 0 deletions clevercsv/_types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# -*- coding: utf-8 -*-

from __future__ import annotations

import csv
import os
import sys

from typing import TYPE_CHECKING
from typing import Any
from typing import Mapping
from typing import Type
from typing import Union

from clevercsv.dialect import SimpleDialect

# Shared type aliases used by the package's annotations.

# A filesystem path as text, bytes, or an ``os.PathLike`` wrapping either.
# The PathLike members are written as string forward references.
AnyPath = Union[str, bytes, "os.PathLike[str]", "os.PathLike[bytes]"]

# Anything in ``AnyPath`` or an int.
# NOTE(review): the int is presumably an open file descriptor -- confirm.
_OpenFile = Union[AnyPath, int]

# A row represented as a mapping from string keys to arbitrary values.
_DictRow = Mapping[str, Any]

# Values accepted as a dialect: a string, a ``csv.Dialect`` instance or
# subclass, or a ``SimpleDialect``.
# NOTE(review): the string is presumably a registered dialect name -- confirm.
_DialectLike = Union[str, csv.Dialect, Type[csv.Dialect], SimpleDialect]

# Mapping type selected by interpreter version: ``typing.Dict`` on 3.8+,
# ``collections.OrderedDict`` before that.
# NOTE(review): presumably mirrors the row type returned by
# ``csv.DictReader``, which changed in Python 3.8 -- confirm.
if sys.version_info >= (3, 8):
    from typing import Dict as _DictReadMapping
else:
    from collections import OrderedDict as _DictReadMapping


# The ``_typeshed`` names are imported only under TYPE_CHECKING, so they are
# not defined when the module is executed normally. ``__all__`` therefore
# lists them only in the type-checking branch.
if TYPE_CHECKING:
    from _typeshed import FileDescriptorOrPath  # NOQA
    from _typeshed import SupportsIter  # NOQA
    from _typeshed import SupportsWrite  # NOQA

    __all__ = [
        "SupportsWrite",
        "SupportsIter",
        "FileDescriptorOrPath",
        "AnyPath",
        "_OpenFile",
        "_DictRow",
        "_DialectLike",
        "_DictReadMapping",
    ]
else:
    __all__ = [
        "AnyPath",
        "_OpenFile",
        "_DictRow",
        "_DialectLike",
        "_DictReadMapping",
    ]
69 changes: 45 additions & 24 deletions clevercsv/break_ties.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@
"""

from typing import List
from typing import Optional

from .cparser_util import parse_string
from .dialect import SimpleDialect
from .utils import pairwise


def tie_breaker(data, dialects):
def tie_breaker(
data: str, dialects: List[SimpleDialect]
) -> Optional[SimpleDialect]:
"""
Break ties between dialects.
Expand Down Expand Up @@ -42,7 +47,9 @@ def tie_breaker(data, dialects):
return None


def reduce_pairwise(data, dialects):
def reduce_pairwise(
data: str, dialects: List[SimpleDialect]
) -> Optional[List[SimpleDialect]]:
"""Reduce the set of dialects by breaking pairwise ties
Parameters
Expand All @@ -62,7 +69,7 @@ def reduce_pairwise(data, dialects):
"""
equal_delim = len(set([d.delimiter for d in dialects])) == 1
if not equal_delim:
return None
return None # TODO: This might be wrong, it can just return the input!

# First, identify dialects that result in the same parsing result.
equal_dialects = []
Expand Down Expand Up @@ -99,7 +106,9 @@ def _dialects_only_differ_in_field(
)


def break_ties_two(data, A, B):
def break_ties_two(
data: str, A: SimpleDialect, B: SimpleDialect
) -> Optional[SimpleDialect]:
"""Break ties between two dialects.
This function breaks ties between two dialects that give the same score. We
Expand Down Expand Up @@ -152,7 +161,7 @@ def break_ties_two(data, A, B):
# quotechar has an effect
return d_yes
elif _dialects_only_differ_in_field(A, B, "delimiter"):
if sorted([A.delimiter, B.delimiter]) == sorted([",", " "]):
if set([A.delimiter, B.delimiter]) == set([",", " "]):
# Artifact due to type detection (comma as radix point)
if A.delimiter == ",":
return A
Expand All @@ -175,14 +184,14 @@ def break_ties_two(data, A, B):
# we can't break this tie (for now)
if len(X) != len(Y):
return None
for x, y in zip(X, Y):
if len(x) != len(y):
for row_X, row_Y in zip(X, Y):
if len(row_X) != len(row_Y):
return None

cells_escaped = []
cells_unescaped = []
for x, y in zip(X, Y):
for u, v in zip(x, y):
for row_X, row_Y in zip(X, Y):
for u, v in zip(row_X, row_Y):
if u != v:
cells_unescaped.append(u)
cells_escaped.append(v)
Expand Down Expand Up @@ -221,16 +230,18 @@ def break_ties_two(data, A, B):

if len(X) != len(Y):
return None
for x, y in zip(X, Y):
if len(x) != len(y):
for row_X, row_Y in zip(X, Y):
if len(row_X) != len(row_Y):
return None

# if we're here, then there is no effect on structure.
# we test if the only cells that differ are those that have an
# escapechar+quotechar combination.
assert isinstance(d_yes.escapechar, str)
assert isinstance(d_yes.quotechar, str)
eq = d_yes.escapechar + d_yes.quotechar
for rX, rY in zip(X, Y):
for x, y in zip(rX, rY):
for row_X, row_Y in zip(X, Y):
for x, y in zip(row_X, row_Y):
if x != y:
if eq not in x:
return None
Expand All @@ -243,7 +254,9 @@ def break_ties_two(data, A, B):
return None


def break_ties_three(data, A, B, C):
def break_ties_three(
data: str, A: SimpleDialect, B: SimpleDialect, C: SimpleDialect
) -> Optional[SimpleDialect]:
"""Break ties between three dialects.
If the delimiters and the escape characters are all equal, then we look for
Expand Down Expand Up @@ -273,7 +286,7 @@ def break_ties_three(data, A, B, C):
Returns
-------
dialect: SimpleDialect
dialect: Optional[SimpleDialect]
The chosen dialect if the tie can be broken, None otherwise.
Notes
Expand Down Expand Up @@ -307,6 +320,7 @@ def break_ties_three(data, A, B, C):
)
if p_none is None:
return None
assert d_none is not None

rem = [
(p, d) for p, d in zip([pA, pB, pC], dialects) if not p == p_none
Expand All @@ -318,6 +332,8 @@ def break_ties_three(data, A, B, C):
# the CSV paper. When fixing the delimiter to Tab, rem = [].
# Try to reduce pairwise
new_dialects = reduce_pairwise(data, dialects)
if new_dialects is None:
return None
if len(new_dialects) == 1:
return new_dialects[0]
return None
Expand Down Expand Up @@ -347,7 +363,9 @@ def break_ties_three(data, A, B, C):
return None


def break_ties_four(data, dialects):
def break_ties_four(
data: str, dialects: List[SimpleDialect]
) -> Optional[SimpleDialect]:
"""Break ties between four dialects.
This function works by breaking the ties between pairs of dialects that
Expand All @@ -368,7 +386,7 @@ def break_ties_four(data, dialects):
Returns
-------
dialect: SimpleDialect
dialect: Optional[SimpleDialect]
The chosen dialect if the tie can be broken, None otherwise.
Notes
Expand All @@ -378,19 +396,22 @@ def break_ties_four(data, dialects):
examples are found.
"""
# TODO: Check for length 4, handle more than 4 too?

equal_delim = len(set([d.delimiter for d in dialects])) == 1
if not equal_delim:
return None

dialects = reduce_pairwise(data, dialects)
reduced_dialects = reduce_pairwise(data, dialects)
if reduced_dialects is None:
return None

# Defer to other functions if the number of dialects was reduced
if len(dialects) == 1:
return dialects[0]
elif len(dialects) == 2:
return break_ties_two(data, *dialects)
elif len(dialects) == 3:
return break_ties_three(data, *dialects)
if len(reduced_dialects) == 1:
return reduced_dialects[0]
elif len(reduced_dialects) == 2:
return break_ties_two(data, *reduced_dialects)
elif len(reduced_dialects) == 3:
return break_ties_three(data, *reduced_dialects)

return None
11 changes: 11 additions & 0 deletions clevercsv/cabstraction.pyi
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
# -*- coding: utf-8 -*-

from typing import Optional

# Type stub for ``base_abstraction``: the body is elided (``...``) because
# this is a ``.pyi`` file that declares only the signature.
# Accepts the text ``data`` together with a delimiter, quote character and
# escape character (each of which may be None) and returns a str.
# NOTE(review): presumably implemented by a compiled ``cabstraction``
# extension module -- confirm against the build configuration.
def base_abstraction(
data: str,
delimiter: Optional[str],
quotechar: Optional[str],
escapechar: Optional[str],
) -> str: ...
# Type stub for ``c_merge_with_quotechar``: takes a str and returns a str.
# NOTE(review): the transformation itself is not visible from this stub;
# consult the implementation before relying on its semantics.
def c_merge_with_quotechar(data: str) -> str: ...
15 changes: 11 additions & 4 deletions clevercsv/consistency.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def cached_is_known_type(cell: str, is_quoted: bool) -> bool:

def detect(
self, data: str, delimiters: Optional[Iterable[str]] = None
) -> None:
) -> Optional[SimpleDialect]:
"""Detect the dialect using the consistency measure
Parameters
Expand Down Expand Up @@ -184,8 +184,11 @@ def get_best_dialects(
) -> List[SimpleDialect]:
"""Identify the dialects with the highest consistency score"""
Qscores = [score.Q for score in scores.values()]
Qscores = list(filter(lambda q: q is not None, Qscores))
Qmax = max(Qscores)
Qmax = -float("inf")
for q in Qscores:
if q is None:
continue
Qmax = max(Qmax, q)
return [d for d, score in scores.items() if score.Q == Qmax]

def compute_type_score(
Expand All @@ -194,6 +197,7 @@ def compute_type_score(
"""Compute the type score"""
total = known = 0
for row in parse_string(data, dialect, return_quoted=True):
assert all(isinstance(cell, tuple) for cell in row)
for cell, is_quoted in row:
total += 1
known += self._cached_is_known_type(cell, is_quoted=is_quoted)
Expand All @@ -203,7 +207,10 @@ def compute_type_score(


def detect_dialect_consistency(
data, delimiters=None, skip=True, verbose=False
data: str,
delimiters: Optional[Iterable[str]] = None,
skip: bool = True,
verbose: bool = False,
):
"""Helper function that wraps ConsistencyDetector"""
# Mostly kept for backwards compatibility
Expand Down
5 changes: 4 additions & 1 deletion clevercsv/console/commands/detect.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,9 @@
import sys
import time

from typing import Any
from typing import Dict

from wilderness import Command

from clevercsv.wrappers import detect_dialect
Expand Down Expand Up @@ -125,7 +128,7 @@ def handle(self):
if self.args.add_runtime:
print(f"runtime = {runtime}")
elif self.args.json:
dialect_dict = dialect.to_dict()
dialect_dict: Dict[str, Any] = dialect.to_dict()
if self.args.add_runtime:
dialect_dict["runtime"] = runtime
print(json.dumps(dialect_dict))
Expand Down
Loading

0 comments on commit 207f4ef

Please sign in to comment.