This repository has been archived by the owner on May 18, 2021. It is now read-only.
-
-
Notifications
You must be signed in to change notification settings - Fork 2
/
pymonkey.py
277 lines (216 loc) · 8.45 KB
/
pymonkey.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
from __future__ import absolute_import
from __future__ import unicode_literals
import collections
import contextlib
import imp
import os
import sys
import pkg_resources
# Result of command line parsing:
#   all     -- True when `--all` was passed
#   patches -- tuple of patch names listed before `--`
#   cmd     -- tuple holding the console script name and its arguments
Arguments = collections.namedtuple('Arguments', ['all', 'patches', 'cmd'])
class PymonkeySystemExit(SystemExit):
    """Raised to stop the pymonkey tool (after help or on a usage error)."""
class PymonkeyError(RuntimeError):
    """Raised when a pymonkey patch module misbehaves while being loaded."""
# Usage text shown by print_help_and_exit(); `{}` becomes the program name.
# There is no `--debug` flag: manual_argument_parsing() would reject it as an
# unknown option.  Debug output is controlled by PYMONKEY_DEBUG instead.
HELPMSG = '''\
usage: {} [-h] [--all] [patches [patches ...]] -- cmd [cmd ...]

A tool for applying monkeypatches to python executables. Patches are \
registered by supplying a setuptools entrypoint for `pymonkey`. Patches are \
selected by listing them on the commandline when running the pymonkey tool. \
For example, consider a registered patch pip_faster when using pip. An \
invocation may look like `pymonkey pip_faster -- pip install ...`.

positional arguments:
  patches
  cmd

optional arguments:
  -h, --help  show this help message and exit
  --all       Apply all known patches

Set PYMONKEY_DEBUG in the environment to enable debug output.'''.format(
    sys.argv[0],
)
def print_std_err(s):
    """Write ``s`` followed by a newline to stderr, flushing right away."""
    err = sys.stderr
    err.write(s + '\n')
    err.flush()
def DEBUG(msg):
    """Print ``msg`` prefixed with ``pymonkey: `` to stderr.

    Output only happens when ``PYMONKEY_DEBUG`` is present in the
    environment (any value, including empty).
    """
    if 'PYMONKEY_DEBUG' not in os.environ:
        return
    print_std_err('pymonkey: ' + msg)
def print_help_and_exit():
    """Print HELPMSG to stderr, then terminate via PymonkeySystemExit."""
    usage = HELPMSG
    print_std_err(usage)
    raise PymonkeySystemExit()
def manual_argument_parsing(argv):
    """Split ``argv`` into pymonkey options, patch names and the command.

    Done by hand (sadness) because argparse doesn't quite do what we want
    for the ``[--all] [patches ...] -- cmd [args ...]`` layout.

    :param argv: sequence of command line arguments without the program
        name.  Lists and tuples are both accepted (``make_entry_point``
        passes a tuple).
    :returns: ``Arguments(all=..., patches=tuple, cmd=tuple)``.
    :raises PymonkeySystemExit: via ``print_help_and_exit`` on ``-h`` /
        ``--help``, a missing ``--``, unknown options, ``--all`` combined
        with explicit patches, or an empty command after ``--``.
    """
    # Normalize so the list comparisons below also work for tuple input.
    argv = list(argv)

    # Special case these for a better error message
    if not argv or argv in (['-h'], ['--help']):
        print_help_and_exit()

    try:
        dashdash_index = argv.index('--')
    except ValueError:
        print_std_err('Must separate command by `--`')
        print_help_and_exit()

    patches, cmd = argv[:dashdash_index], argv[dashdash_index + 1:]

    if '--help' in patches or '-h' in patches:
        print_help_and_exit()

    # Drop every occurrence of --all (building a new list instead of calling
    # .remove()) so a repeated flag is not later reported as unknown.
    all_patches = '--all' in patches
    patches = [patch for patch in patches if patch != '--all']

    unknown_options = [patch for patch in patches if patch.startswith('-')]
    if unknown_options:
        print_std_err('Unknown options: {!r}'.format(unknown_options))
        print_help_and_exit()

    if patches and all_patches:
        print_std_err('--all and patches specified: {!r}'.format(patches))
        print_help_and_exit()

    # The first element of cmd is used as the console script name, so an
    # empty command would otherwise crash later with an IndexError.
    if not cmd:
        print_std_err('Must specify a command after `--`')
        print_help_and_exit()

    return Arguments(all=all_patches, patches=tuple(patches), cmd=tuple(cmd))
def importmod(mod):
    """Import ``mod`` by absolute dotted name and return that exact module.

    A non-empty ``fromlist`` makes ``__import__('a.b.c')`` return ``a.b.c``
    itself instead of the top-level package ``a``.  ``str()`` keeps the
    entry a native string under ``unicode_literals``.
    """
    leaf_marker = [str('__name__')]
    return __import__(mod, fromlist=leaf_marker, level=0)
def _noop(*args, **kwargs):
    """Accept any arguments, do nothing, and return None (getattr fallback)."""
class PymonkeyImportHook(object):
    """Meta-path import hook that runs the pymonkey patches on imported modules.

    Installed on ``sys.meta_path``, it claims any module that another finder
    (or ``imp.find_module``) can locate.  It imports that module through the
    regular import machinery, then passes the module, together with each
    patch's entry data, to every registered patch callable.
    """

    def __init__(self, hooks):
        # hooks: mapping of entry name -> callable(module, entry_data).
        self._hooks = hooks
        # Per-entry data filled in by set_entry_data(); starts as None.
        self._entry_data = {name: None for name in hooks}
        # Stack of module names this hook is currently loading.
        self._handling = []

    def _module_exists(self, module, path):
        """Return True if another finder, or ``imp``, can locate ``module``."""
        # Ask every other meta path finder first.
        for finder in sys.meta_path:
            if finder is self:
                continue
            if getattr(finder, 'find_spec', _noop)(module, path):
                return True
            if getattr(finder, 'find_module', _noop)(module, path):
                return True

        # Callers pass either a toplevel module name with path=None, or the
        # full dotted module name with a list for path.  imp.find_module
        # wants just the last component (the subpackage) in the latter case.
        if path is None:
            leaf = module
        else:
            leaf = module.rpartition('.')[2]

        try:
            imp.find_module(leaf, path)
        except ImportError:
            return False
        else:
            return True  # pragma: no cover (PY3 import is via sys.meta_path)

    @contextlib.contextmanager
    def handling(self, modname):
        """Record ``modname`` as being loaded for the duration of the block."""
        self._handling.append(modname)
        try:
            yield
        finally:
            finished = self._handling.pop()
            # Names must be pushed and popped in strict LIFO order.
            assert finished == modname, (finished, modname)

    def find_module(self, fullname, path=None):
        """Return ``self`` as the loader for ``fullname`` if it can be found.

        Declines (returns None) while ``fullname`` is already being loaded
        by this hook, and when no other finder can locate it.
        """
        if fullname in self._handling:
            DEBUG('already handling {}'.format(fullname))
            return None
        if not self._module_exists(fullname, path):
            DEBUG('not found {}'.format(fullname))
            return None
        DEBUG('found {}'.format(fullname))
        return self

    def load_module(self, fullname):
        """Import ``fullname`` normally, then apply every patch to it."""
        # importmod re-enters the import system and therefore this hook;
        # recording fullname first makes find_module decline that lookup
        # instead of recursing forever.
        with self.handling(fullname):
            module = importmod(fullname)
            for name, patch in self._hooks.items():
                patch(module, self._entry_data[name])
        return module

    def set_entry_data(self, entry, data):
        """Store ``data`` to be passed to the patch registered as ``entry``."""
        self._entry_data[entry] = data
@contextlib.contextmanager
def assert_no_other_modules_imported(imported_modname):
    """Fail if the wrapped block imports modules other than the expected one.

    ``imported_modname`` itself and its parent packages (``a`` and ``a.b``
    for ``a.b.c``) may appear in ``sys.modules`` during the block.  Any
    other newly imported module is an error.

    :param imported_modname: dotted name of the module being imported.
    :raises PymonkeyError: listing the unexpected modules, one per line.
    """
    def getmods():
        # Ignore entries whose value is falsy (e.g. None placeholders).
        return {modname for modname, mod in sys.modules.items() if mod}

    def is_expected(modname):
        # Compare whole dotted components: importing `foo_bar` must not
        # excuse an import of `foo`, but importing `foo.bar` may import `foo`.
        return (
            modname == imported_modname or
            imported_modname.startswith(modname + '.')
        )

    before = getmods()
    yield
    after = getmods()

    unexpected_imports = sorted(
        modname for modname in after - before if not is_expected(modname)
    )
    if unexpected_imports:
        raise PymonkeyError(
            'pymonkey modules must not trigger imports at the module scope. '
            'The following modules were imported while importing {}:\n'
            '{}'.format(
                imported_modname,
                '\n'.join('\t' + modname for modname in unexpected_imports),
            ),
        )
def get_entry_callables(all_patches, patches, pymonkey_entry_points, attr):
    """Resolve pymonkey entry points to callables, keyed by entry point name.

    :param all_patches: when true, use every entry point in
        ``pymonkey_entry_points`` and ignore ``patches``.
    :param patches: names of the entry points to use otherwise.
    :param pymonkey_entry_points: iterable of pkg_resources entry points
        (objects with ``name``, ``module_name`` and ``attrs``).
    :param attr: attribute fetched when an entry point resolves to something
        that is not callable (e.g. a module).
    :returns: dict mapping entry point name to its callable.
    :raises PymonkeySystemExit: (status 1) if a requested patch is unknown.
    :raises PymonkeyError: if importing a patch module imports other modules.
    """
    def _to_callable(entry_point):
        """If they give us a module, retrieve `attr`"""
        with assert_no_other_modules_imported(entry_point.module_name):
            # Load the module manually to avoid pkg_resources side-effects
            loaded = importmod(entry_point.module_name)
        for entry_attr in entry_point.attrs:
            loaded = getattr(loaded, entry_attr)
        if callable(loaded):
            return loaded
        else:
            return getattr(loaded, attr)

    if all_patches:
        entry_points = pymonkey_entry_points
    else:
        all_entries = {entry.name: entry for entry in pymonkey_entry_points}
        missing = set(patches) - set(all_entries)
        if missing:
            # Sorted so the message is identical across runs (set iteration
            # order varies with hash randomization).
            print_std_err(
                'Could not find patch(es): {!r}'.format(sorted(missing)),
            )
            raise PymonkeySystemExit(1)
        entry_points = [all_entries[name] for name in patches]
    return {entry.name: _to_callable(entry) for entry in entry_points}
def main(argv=None):
    """Run a console script with the selected pymonkey patches applied.

    :param argv: arguments without the program name, laid out as
        ``[--all] [patches ...] -- cmd [args ...]``; defaults to
        ``sys.argv[1:]``.
    :returns: whatever the target console script's entry function returns.
    :raises PymonkeySystemExit: on bad arguments, unknown patches, or when
        ``cmd`` does not name exactly one ``console_scripts`` entry point.
    """
    argv = argv if argv is not None else sys.argv[1:]
    args = manual_argument_parsing(argv)

    # Register patches
    callables = get_entry_callables(
        args.all, args.patches,
        tuple(pkg_resources.iter_entry_points('pymonkey')),
        attr='pymonkey_patch',
    )
    hook = PymonkeyImportHook(callables)
    # Important to insert at the beginning to be ahead of the stdlib importer
    sys.meta_path.insert(0, hook)

    # Allow hooks to do argument parsing
    argv_callables = get_entry_callables(
        args.all, args.patches,
        tuple(pkg_resources.iter_entry_points('pymonkey.argparse')),
        attr='pymonkey_argparse',
    )
    cmd, rest = args.cmd[0], tuple(args.cmd[1:])
    for entry_name, argv_callable in argv_callables.items():
        # Each callable returns a (data, remaining_args) pair; the data is
        # handed to that entry's patch.  A separate name keeps `args` (the
        # parsed Arguments tuple) from being clobbered.
        entry_args, rest = tuple(argv_callable(rest))
        hook.set_entry_data(entry_name, entry_args)

    # Call the thing
    entries = tuple(pkg_resources.iter_entry_points('console_scripts', cmd))
    if len(entries) != 1:
        print_std_err(
            'Expected exactly one `console_scripts` entry point named `{}`, '
            'found {}'.format(cmd, len(entries)),
        )
        raise PymonkeySystemExit(1)
    entry, = entries
    sys.argv = [cmd] + list(rest)
    return entry.load()()
def make_entry_point(patches, original_entry_point):
    """Build a console_scripts function that runs a command with patches.

    :param patches: iterable of pymonkey patch names to apply,
        e.g. ``('my-patch',)``.  It is iterated on every call of the
        returned function, so pass a re-iterable such as a tuple.
    :param original_entry_point: name of the console script to run,
        e.g. ``'pip'``.
    :returns: a function ``entry(argv=None)`` that forwards ``argv``
        (default ``sys.argv[1:]``) to ``main`` after the patches and ``--``
        and returns ``main``'s result.
    """
    def entry(argv=None):
        if argv is None:
            argv = sys.argv[1:]
        prefix = tuple(patches) + ('--', original_entry_point)
        return main(prefix + tuple(argv))
    return entry
if __name__ == '__main__':
    # Use main()'s return value as the process exit status.
    exit_status = main()
    sys.exit(exit_status)