-
Notifications
You must be signed in to change notification settings - Fork 23
/
pytest_black.py
118 lines (91 loc) · 3.55 KB
/
pytest_black.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
# -*- coding: utf-8 -*-
# stdlib imports
import subprocess
import re
import sys
# third-party imports
import pytest
import toml
# Key under which pytest's cache stores a {file path: mtime} mapping for files
# that passed the black check, so unchanged files can be skipped on later runs.
HISTKEY = "black/mtimes"
def pytest_addoption(parser):
    """Expose the opt-in ``--black`` flag in pytest's "general" option group."""
    option_spec = dict(action="store_true", help="enable format checking with black")
    parser.getgroup("general").addoption("--black", **option_spec)
def pytest_collect_file(path, parent):
    """Collect *path* as a BlackItem when ``--black`` is on and it is a .py file.

    For any other path this hook returns ``None``, contributing no item.
    """
    if not parent.config.option.black:
        return None
    if path.ext != ".py":
        return None
    # Newer pytest versions build nodes via ``from_parent``; older ones only
    # support calling the constructor directly.
    if not hasattr(BlackItem, "from_parent"):
        return BlackItem(path, parent)
    return BlackItem.from_parent(parent, fspath=path)
def pytest_configure(config):
    """Register the ``black`` marker and, if enabled, load the mtime cache.

    The mtimes are read from pytest's cache under HISTKEY only when
    ``--black`` is set and the config has a ``cache`` attribute; the marker
    is registered in every case.
    """
    wants_cache = config.option.black and hasattr(config, "cache")
    if wants_cache:
        # Load cached mtimes at session startup.
        config._blackmtimes = config.cache.get(HISTKEY, {})
    config.addinivalue_line("markers", "black: enable format checking with black")
def pytest_unconfigure(config):
    """Write the recorded mtimes back to pytest's cache at session end."""
    if not hasattr(config, "_blackmtimes"):
        # No mtime mapping was loaded this session, so there is nothing to save.
        return
    config.cache.set(HISTKEY, config._blackmtimes)
class BlackItem(pytest.Item, pytest.File):
    """A collected Python file checked with ``black --check --diff``.

    The node is both the file collector and its own single test item (see
    ``collect``); ``::BLACK`` is appended to its node id and it carries the
    ``black`` marker.
    """

    def __init__(self, fspath, parent):
        super(BlackItem, self).__init__(fspath, parent)
        self._nodeid += "::BLACK"
        self.add_marker("black")
        # Read [tool.black] from pyproject.toml in the current working
        # directory. Any failure (missing file or section, unparsable TOML,
        # invalid include/exclude regex) falls back to "no filtering".
        try:
            with open("pyproject.toml") as toml_file:
                settings = toml.load(toml_file)["tool"]["black"]
            if "include" in settings:
                settings["include"] = self._re_fix_verbose(settings["include"])
            if "exclude" in settings:
                settings["exclude"] = self._re_fix_verbose(settings["exclude"])
            self.pyproject = settings
        except Exception:
            self.pyproject = {}

    def setup(self):
        """Skip when black is missing, the file is unchanged, or it is excluded.

        "Unchanged" means its current mtime equals the mtime recorded the last
        time it passed, as loaded into ``config._blackmtimes``.
        """
        pytest.importorskip("black")
        mtimes = getattr(self.config, "_blackmtimes", {})
        self._blackmtime = self.fspath.mtime()
        old = mtimes.get(str(self.fspath), 0)
        if self._blackmtime == old:
            pytest.skip("file(s) previously passed black format checks")
        if self._skip_test():
            pytest.skip("file(s) excluded by pyproject.toml")

    def runtest(self):
        """Run black on this file in a subprocess using the current interpreter.

        On success the file's mtime is recorded so later runs can skip it.

        Raises:
            BlackError: black exited non-zero; ``args[0]`` is the
                ``subprocess.CalledProcessError`` with captured output.
        """
        cmd = [
            sys.executable,
            "-m",
            "black",
            "--check",
            "--diff",
            "--quiet",
            str(self.fspath),
        ]
        try:
            subprocess.run(
                cmd,
                check=True,
                stdout=subprocess.PIPE,
                # Capture stderr as well, so error output from black (as
                # opposed to a diff on stdout) reaches the failure report
                # instead of leaving it empty.
                stderr=subprocess.PIPE,
                universal_newlines=True,
            )
        except subprocess.CalledProcessError as e:
            raise BlackError(e)
        mtimes = getattr(self.config, "_blackmtimes", {})
        mtimes[str(self.fspath)] = self._blackmtime

    def repr_failure(self, excinfo):
        """Show black's diff and any error output instead of a traceback."""
        if excinfo.errisinstance(BlackError):
            proc_error = excinfo.value.args[0]
            output = (proc_error.stdout or "") + (proc_error.stderr or "")
            if not output:
                # Never report a completely blank failure.
                return "black exited with status {}".format(proc_error.returncode)
            return output
        return super(BlackItem, self).repr_failure(excinfo)

    def reportinfo(self):
        """Return (path, line, description); -1 means no specific line."""
        return (self.fspath, -1, "Black format check")

    def _skip_test(self):
        """Return truthy when pyproject.toml's include/exclude rules skip this file."""
        return self._excluded() or (not self._included())

    def _included(self):
        """Return truthy when no include pattern is set or the path matches it."""
        if "include" not in self.pyproject:
            return True
        return re.search(self.pyproject["include"], str(self.fspath))

    def _excluded(self):
        """Return truthy when an exclude pattern is set and the path matches it."""
        if "exclude" not in self.pyproject:
            return False
        return re.search(self.pyproject["exclude"], str(self.fspath))

    def _re_fix_verbose(self, regex):
        """Compile *regex*, using verbose mode (``(?x)``) for multi-line patterns.

        In verbose mode, whitespace and newlines inside the pattern are
        ignored, so multi-line TOML strings compile as intended.
        """
        if "\n" in regex:
            regex = "(?x)" + regex
        return re.compile(regex)

    def collect(self):
        """Return the children (items and collectors) of this collection node.

        The file is its own single item.
        """
        return (self,)
class BlackError(Exception):
    """Signals that the black check failed for a file.

    As raised by ``BlackItem.runtest``, ``args[0]`` holds the underlying
    ``subprocess.CalledProcessError``.
    """