Skip to content

Commit 93fb3d5

Browse files
authored
JIT argument order fix (#639)
* Fix argument ordering in JIT * Format * Update JIT tests * Fix JIT test
1 parent b3f6c12 commit 93fb3d5

File tree

3 files changed

+75
-57
lines changed

3 files changed

+75
-57
lines changed

jit/codon/__init__.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# Copyright (C) 2022-2025 Exaloop Inc. <https://exaloop.io>
22

3-
__all__ = ["jit", "convert", "JITError"]
3+
__all__ = [
4+
"jit", "convert", "JITError", "JITWrapper", "_jit_register_fn", "_jit"
5+
]
46

5-
from .decorator import jit, convert, execute, JITError
7+
from .decorator import jit, convert, execute, JITError, JITWrapper, _jit_register_fn, _jit_callback_fn, _jit

jit/codon/decorator.py

+58-55
Original file line numberDiff line numberDiff line change
@@ -23,16 +23,14 @@
2323
if codon_lib_path:
2424
codon_path.append(Path(codon_lib_path).parent / "stdlib")
2525
codon_path.append(
26-
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib"
27-
)
26+
Path(os.path.expanduser("~")) / ".codon" / "lib" / "codon" / "stdlib")
2827
for path in codon_path:
2928
if path.exists():
3029
os.environ["CODON_PATH"] = str(path.resolve())
3130
break
3231
else:
3332
raise RuntimeError(
34-
"Cannot locate Codon. Please install Codon or set CODON_PATH."
35-
)
33+
"Cannot locate Codon. Please install Codon or set CODON_PATH.")
3634

3735
pod_conversions = {
3836
type(None): "pyobj",
@@ -61,7 +59,6 @@
6159
custom_conversions = {}
6260
_error_msgs = set()
6361

64-
6562
def _common_type(t, debug, sample_size):
6663
sub, is_optional = None, False
6764
for i in itertools.islice(t, sample_size):
@@ -76,7 +73,6 @@ def _common_type(t, debug, sample_size):
7673
sub = "Optional[{}]".format(sub)
7774
return sub if sub else "pyobj"
7875

79-
8076
def _codon_type(arg, **kwargs):
8177
t = type(arg)
8278

@@ -88,11 +84,11 @@ def _codon_type(arg, **kwargs):
8884
if issubclass(t, set):
8985
return "Set[{}]".format(_common_type(arg, **kwargs))
9086
if issubclass(t, dict):
91-
return "Dict[{},{}]".format(
92-
_common_type(arg.keys(), **kwargs), _common_type(arg.values(), **kwargs)
93-
)
87+
return "Dict[{},{}]".format(_common_type(arg.keys(), **kwargs),
88+
_common_type(arg.values(), **kwargs))
9489
if issubclass(t, tuple):
95-
return "Tuple[{}]".format(",".join(_codon_type(a, **kwargs) for a in arg))
90+
return "Tuple[{}]".format(",".join(
91+
_codon_type(a, **kwargs) for a in arg))
9692
if issubclass(t, np.ndarray):
9793
if arg.dtype == np.bool_:
9894
dtype = "bool"
@@ -134,7 +130,8 @@ def _codon_type(arg, **kwargs):
134130

135131
s = custom_conversions.get(t, "")
136132
if s:
137-
j = ",".join(_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
133+
j = ",".join(
134+
_codon_type(getattr(arg, slot), **kwargs) for slot in t.__slots__)
138135
return "{}[{}]".format(s, j)
139136

140137
debug = kwargs.get("debug", None)
@@ -145,28 +142,22 @@ def _codon_type(arg, **kwargs):
145142
_error_msgs.add(msg)
146143
return "pyobj"
147144

148-
149145
def _codon_types(args, **kwargs):
150146
return tuple(_codon_type(arg, **kwargs) for arg in args)
151147

152-
153148
def _reset_jit():
154149
global _jit
155150
_jit = JITWrapper()
156-
init_code = (
157-
"from internal.python import "
158-
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
159-
"setup_decorator()\n"
160-
"import numpy as np\n"
161-
"import numpy.pybridge\n"
162-
)
151+
init_code = ("from internal.python import "
152+
"setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n"
153+
"setup_decorator()\n"
154+
"import numpy as np\n"
155+
"import numpy.pybridge\n")
163156
_jit.execute(init_code, "", 0, False)
164157
return _jit
165158

166-
167159
_jit = _reset_jit()
168160

169-
170161
class RewriteFunctionArgs(ast.NodeTransformer):
171162
def __init__(self, args):
172163
self.args = args
@@ -176,7 +167,6 @@ def visit_FunctionDef(self, node):
176167
node.args.args.append(ast.arg(arg=a, annotation=None))
177168
return node
178169

179-
180170
def _obj_to_str(obj, **kwargs) -> str:
181171
if inspect.isclass(obj):
182172
lines = inspect.getsourcelines(obj)[0]
@@ -185,8 +175,10 @@ def _obj_to_str(obj, **kwargs) -> str:
185175
obj_name = obj.__name__
186176
elif callable(obj) or isinstance(obj, str):
187177
is_str = isinstance(obj, str)
188-
lines = [i + '\n' for i in obj.split('\n')] if is_str else inspect.getsourcelines(obj)[0]
189-
if not is_str: lines = lines[1:]
178+
lines = [i + '\n' for i in obj.split('\n')
179+
] if is_str else inspect.getsourcelines(obj)[0]
180+
if not is_str:
181+
lines = lines[1:]
190182
obj_str = textwrap.dedent(''.join(lines))
191183

192184
pyvars = kwargs.get("pyvars", None)
@@ -195,8 +187,7 @@ def _obj_to_str(obj, **kwargs) -> str:
195187
if not isinstance(i, str):
196188
raise ValueError("pyvars only takes string literals")
197189
node = ast.fix_missing_locations(
198-
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str))
199-
)
190+
RewriteFunctionArgs(pyvars).visit(ast.parse(obj_str)))
200191
obj_str = astunparse.unparse(node)
201192
if is_str:
202193
try:
@@ -206,46 +197,38 @@ def _obj_to_str(obj, **kwargs) -> str:
206197
else:
207198
obj_name = obj.__name__
208199
else:
209-
raise TypeError("Function or class expected, got " + type(obj).__name__)
200+
raise TypeError("Function or class expected, got " +
201+
type(obj).__name__)
210202
return obj_name, obj_str.replace("_@par", "@par")
211203

212-
213204
def _parse_decorated(obj, **kwargs):
214-
return _obj_to_str(obj, **kwargs)
215-
205+
return _obj_to_str(obj, **kwargs)
216206

217207
def convert(t):
218208
if not hasattr(t, "__slots__"):
219-
raise JITError("class '{}' does not have '__slots__' attribute".format(str(t)))
209+
raise JITError("class '{}' does not have '__slots__' attribute".format(
210+
str(t)))
220211

221212
name = t.__name__
222213
slots = t.__slots__
223-
code = (
224-
"@tuple\n"
225-
"class "
226-
+ name
227-
+ "["
228-
+ ",".join("T{}".format(i) for i in range(len(slots)))
229-
+ "]:\n"
230-
)
214+
code = ("@tuple\n"
215+
"class " + name + "[" +
216+
",".join("T{}".format(i) for i in range(len(slots))) + "]:\n")
231217
for i, slot in enumerate(slots):
232218
code += " {}: T{}\n".format(slot, i)
233219

234220
# PyObject_GetAttrString
235221
code += " def __from_py__(p: cobj):\n"
236222
for i, slot in enumerate(slots):
237223
code += " a{} = T{}.__from_py__(PyObject_GetAttrString(p, '{}'.ptr))\n".format(
238-
i, i, slot
239-
)
224+
i, i, slot)
240225
code += " return {}({})\n".format(
241-
name, ", ".join("a{}".format(i) for i in range(len(slots)))
242-
)
226+
name, ", ".join("a{}".format(i) for i in range(len(slots))))
243227

244228
_jit.execute(code, "", 0, False)
245229
custom_conversions[t] = name
246230
return t
247231

248-
249232
def _jit_register_fn(f, pyvars, debug):
250233
try:
251234
obj_name, obj_str = _parse_decorated(f, pyvars=pyvars)
@@ -258,29 +241,46 @@ def _jit_register_fn(f, pyvars, debug):
258241
_reset_jit()
259242
raise
260243

261-
def _jit_callback_fn(obj_name, module, debug=None, sample_size=5, pyvars=None, *args, **kwargs):
262-
try:
244+
def _jit_callback_fn(fn,
245+
obj_name,
246+
module,
247+
debug=None,
248+
sample_size=5,
249+
pyvars=None,
250+
*args,
251+
**kwargs):
252+
if fn is not None:
253+
sig = inspect.signature(fn)
254+
bound_args = sig.bind(*args, **kwargs)
255+
bound_args.apply_defaults()
256+
args = tuple(bound_args.arguments[param] for param in sig.parameters)
257+
else:
263258
args = (*args, *kwargs.values())
259+
260+
try:
264261
types = _codon_types(args, debug=debug, sample_size=sample_size)
265262
if debug:
266-
print("[python] {}({})".format(obj_name, list(types)), file=sys.stderr)
267-
return _jit.run_wrapper(
268-
obj_name, list(types), module, list(pyvars), args, 1 if debug else 0
269-
)
263+
print("[python] {}({})".format(obj_name, list(types)),
264+
file=sys.stderr)
265+
return _jit.run_wrapper(obj_name, list(types), module, list(pyvars),
266+
args, 1 if debug else 0)
270267
except JITError:
271268
_reset_jit()
272269
raise
273270

274271
def _jit_str_fn(fstr, debug=None, sample_size=5, pyvars=None):
275272
obj_name = _jit_register_fn(fstr, pyvars, debug)
273+
276274
def wrapped(*args, **kwargs):
277-
return _jit_callback_fn(obj_name, "__main__", debug, sample_size, pyvars, *args, **kwargs)
278-
return wrapped
275+
return _jit_callback_fn(None, obj_name, "__main__", debug, sample_size,
276+
pyvars, *args, **kwargs)
279277

278+
return wrapped
280279

281280
def jit(fn=None, debug=None, sample_size=5, pyvars=None):
282281
if not pyvars:
283282
pyvars = []
283+
284284
if not isinstance(pyvars, list):
285285
raise ArgumentError("pyvars must be a list")
286286

@@ -289,12 +289,15 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
289289

290290
def _decorate(f):
291291
obj_name = _jit_register_fn(f, pyvars, debug)
292+
292293
@functools.wraps(f)
293294
def wrapped(*args, **kwargs):
294-
return _jit_callback_fn(obj_name, f.__module__, debug, sample_size, pyvars, *args, **kwargs)
295+
return _jit_callback_fn(f, obj_name, f.__module__, debug,
296+
sample_size, pyvars, *args, **kwargs)
297+
295298
return wrapped
296-
return _decorate(fn) if fn else _decorate
297299

300+
return _decorate(fn) if fn else _decorate
298301

299302
def execute(code, debug=False):
300303
try:

test/python/cython_jit.py

+13
Original file line numberDiff line numberDiff line change
@@ -181,3 +181,16 @@ def test_ndarray():
181181
assert np.datetime_data(y.dtype) == ('s', 2)
182182

183183
test_ndarray()
184+
185+
@codon.jit
186+
def e(x=2, y=99):
187+
return 2*x + y
188+
189+
def test_arg_order():
190+
assert e(1, 2) == 4
191+
assert e(1) == 101
192+
assert e(y=10, x=1) == 12
193+
assert e(x=1) == 101
194+
assert e() == 103
195+
196+
test_arg_order()

0 commit comments

Comments
 (0)