23
23
if codon_lib_path :
24
24
codon_path .append (Path (codon_lib_path ).parent / "stdlib" )
25
25
codon_path .append (
26
- Path (os .path .expanduser ("~" )) / ".codon" / "lib" / "codon" / "stdlib"
27
- )
26
+ Path (os .path .expanduser ("~" )) / ".codon" / "lib" / "codon" / "stdlib" )
28
27
for path in codon_path :
29
28
if path .exists ():
30
29
os .environ ["CODON_PATH" ] = str (path .resolve ())
31
30
break
32
31
else :
33
32
raise RuntimeError (
34
- "Cannot locate Codon. Please install Codon or set CODON_PATH."
35
- )
33
+ "Cannot locate Codon. Please install Codon or set CODON_PATH." )
36
34
37
35
pod_conversions = {
38
36
type (None ): "pyobj" ,
61
59
custom_conversions = {}
62
60
_error_msgs = set ()
63
61
64
-
65
62
def _common_type (t , debug , sample_size ):
66
63
sub , is_optional = None , False
67
64
for i in itertools .islice (t , sample_size ):
@@ -76,7 +73,6 @@ def _common_type(t, debug, sample_size):
76
73
sub = "Optional[{}]" .format (sub )
77
74
return sub if sub else "pyobj"
78
75
79
-
80
76
def _codon_type (arg , ** kwargs ):
81
77
t = type (arg )
82
78
@@ -88,11 +84,11 @@ def _codon_type(arg, **kwargs):
88
84
if issubclass (t , set ):
89
85
return "Set[{}]" .format (_common_type (arg , ** kwargs ))
90
86
if issubclass (t , dict ):
91
- return "Dict[{},{}]" .format (
92
- _common_type (arg .keys (), ** kwargs ), _common_type (arg .values (), ** kwargs )
93
- )
87
+ return "Dict[{},{}]" .format (_common_type (arg .keys (), ** kwargs ),
88
+ _common_type (arg .values (), ** kwargs ))
94
89
if issubclass (t , tuple ):
95
- return "Tuple[{}]" .format ("," .join (_codon_type (a , ** kwargs ) for a in arg ))
90
+ return "Tuple[{}]" .format ("," .join (
91
+ _codon_type (a , ** kwargs ) for a in arg ))
96
92
if issubclass (t , np .ndarray ):
97
93
if arg .dtype == np .bool_ :
98
94
dtype = "bool"
@@ -134,7 +130,8 @@ def _codon_type(arg, **kwargs):
134
130
135
131
s = custom_conversions .get (t , "" )
136
132
if s :
137
- j = "," .join (_codon_type (getattr (arg , slot ), ** kwargs ) for slot in t .__slots__ )
133
+ j = "," .join (
134
+ _codon_type (getattr (arg , slot ), ** kwargs ) for slot in t .__slots__ )
138
135
return "{}[{}]" .format (s , j )
139
136
140
137
debug = kwargs .get ("debug" , None )
@@ -145,28 +142,22 @@ def _codon_type(arg, **kwargs):
145
142
_error_msgs .add (msg )
146
143
return "pyobj"
147
144
148
-
149
145
def _codon_types (args , ** kwargs ):
150
146
return tuple (_codon_type (arg , ** kwargs ) for arg in args )
151
147
152
-
153
148
def _reset_jit ():
154
149
global _jit
155
150
_jit = JITWrapper ()
156
- init_code = (
157
- "from internal.python import "
158
- "setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n "
159
- "setup_decorator()\n "
160
- "import numpy as np\n "
161
- "import numpy.pybridge\n "
162
- )
151
+ init_code = ("from internal.python import "
152
+ "setup_decorator, PyTuple_GetItem, PyObject_GetAttrString\n "
153
+ "setup_decorator()\n "
154
+ "import numpy as np\n "
155
+ "import numpy.pybridge\n " )
163
156
_jit .execute (init_code , "" , 0 , False )
164
157
return _jit
165
158
166
-
167
159
_jit = _reset_jit ()
168
160
169
-
170
161
class RewriteFunctionArgs (ast .NodeTransformer ):
171
162
def __init__ (self , args ):
172
163
self .args = args
@@ -176,7 +167,6 @@ def visit_FunctionDef(self, node):
176
167
node .args .args .append (ast .arg (arg = a , annotation = None ))
177
168
return node
178
169
179
-
180
170
def _obj_to_str (obj , ** kwargs ) -> str :
181
171
if inspect .isclass (obj ):
182
172
lines = inspect .getsourcelines (obj )[0 ]
@@ -185,8 +175,10 @@ def _obj_to_str(obj, **kwargs) -> str:
185
175
obj_name = obj .__name__
186
176
elif callable (obj ) or isinstance (obj , str ):
187
177
is_str = isinstance (obj , str )
188
- lines = [i + '\n ' for i in obj .split ('\n ' )] if is_str else inspect .getsourcelines (obj )[0 ]
189
- if not is_str : lines = lines [1 :]
178
+ lines = [i + '\n ' for i in obj .split ('\n ' )
179
+ ] if is_str else inspect .getsourcelines (obj )[0 ]
180
+ if not is_str :
181
+ lines = lines [1 :]
190
182
obj_str = textwrap .dedent ('' .join (lines ))
191
183
192
184
pyvars = kwargs .get ("pyvars" , None )
@@ -195,8 +187,7 @@ def _obj_to_str(obj, **kwargs) -> str:
195
187
if not isinstance (i , str ):
196
188
raise ValueError ("pyvars only takes string literals" )
197
189
node = ast .fix_missing_locations (
198
- RewriteFunctionArgs (pyvars ).visit (ast .parse (obj_str ))
199
- )
190
+ RewriteFunctionArgs (pyvars ).visit (ast .parse (obj_str )))
200
191
obj_str = astunparse .unparse (node )
201
192
if is_str :
202
193
try :
@@ -206,46 +197,38 @@ def _obj_to_str(obj, **kwargs) -> str:
206
197
else :
207
198
obj_name = obj .__name__
208
199
else :
209
- raise TypeError ("Function or class expected, got " + type (obj ).__name__ )
200
+ raise TypeError ("Function or class expected, got " +
201
+ type (obj ).__name__ )
210
202
return obj_name , obj_str .replace ("_@par" , "@par" )
211
203
212
-
213
204
def _parse_decorated (obj , ** kwargs ):
214
- return _obj_to_str (obj , ** kwargs )
215
-
205
+ return _obj_to_str (obj , ** kwargs )
216
206
217
207
def convert (t ):
218
208
if not hasattr (t , "__slots__" ):
219
- raise JITError ("class '{}' does not have '__slots__' attribute" .format (str (t )))
209
+ raise JITError ("class '{}' does not have '__slots__' attribute" .format (
210
+ str (t )))
220
211
221
212
name = t .__name__
222
213
slots = t .__slots__
223
- code = (
224
- "@tuple\n "
225
- "class "
226
- + name
227
- + "["
228
- + "," .join ("T{}" .format (i ) for i in range (len (slots )))
229
- + "]:\n "
230
- )
214
+ code = ("@tuple\n "
215
+ "class " + name + "[" +
216
+ "," .join ("T{}" .format (i ) for i in range (len (slots ))) + "]:\n " )
231
217
for i , slot in enumerate (slots ):
232
218
code += " {}: T{}\n " .format (slot , i )
233
219
234
220
# PyObject_GetAttrString
235
221
code += " def __from_py__(p: cobj):\n "
236
222
for i , slot in enumerate (slots ):
237
223
code += " a{} = T{}.__from_py__(PyObject_GetAttrString(p, '{}'.ptr))\n " .format (
238
- i , i , slot
239
- )
224
+ i , i , slot )
240
225
code += " return {}({})\n " .format (
241
- name , ", " .join ("a{}" .format (i ) for i in range (len (slots )))
242
- )
226
+ name , ", " .join ("a{}" .format (i ) for i in range (len (slots ))))
243
227
244
228
_jit .execute (code , "" , 0 , False )
245
229
custom_conversions [t ] = name
246
230
return t
247
231
248
-
249
232
def _jit_register_fn (f , pyvars , debug ):
250
233
try :
251
234
obj_name , obj_str = _parse_decorated (f , pyvars = pyvars )
@@ -258,29 +241,46 @@ def _jit_register_fn(f, pyvars, debug):
258
241
_reset_jit ()
259
242
raise
260
243
261
- def _jit_callback_fn (obj_name , module , debug = None , sample_size = 5 , pyvars = None , * args , ** kwargs ):
262
- try :
244
+ def _jit_callback_fn (fn ,
245
+ obj_name ,
246
+ module ,
247
+ debug = None ,
248
+ sample_size = 5 ,
249
+ pyvars = None ,
250
+ * args ,
251
+ ** kwargs ):
252
+ if fn is not None :
253
+ sig = inspect .signature (fn )
254
+ bound_args = sig .bind (* args , ** kwargs )
255
+ bound_args .apply_defaults ()
256
+ args = tuple (bound_args .arguments [param ] for param in sig .parameters )
257
+ else :
263
258
args = (* args , * kwargs .values ())
259
+
260
+ try :
264
261
types = _codon_types (args , debug = debug , sample_size = sample_size )
265
262
if debug :
266
- print ("[python] {}({})" .format (obj_name , list (types )), file = sys . stderr )
267
- return _jit . run_wrapper (
268
- obj_name , list (types ), module , list (pyvars ), args , 1 if debug else 0
269
- )
263
+ print ("[python] {}({})" .format (obj_name , list (types )),
264
+ file = sys . stderr )
265
+ return _jit . run_wrapper ( obj_name , list (types ), module , list (pyvars ),
266
+ args , 1 if debug else 0 )
270
267
except JITError :
271
268
_reset_jit ()
272
269
raise
273
270
274
271
def _jit_str_fn (fstr , debug = None , sample_size = 5 , pyvars = None ):
275
272
obj_name = _jit_register_fn (fstr , pyvars , debug )
273
+
276
274
def wrapped (* args , ** kwargs ):
277
- return _jit_callback_fn (obj_name , "__main__" , debug , sample_size , pyvars , * args , ** kwargs )
278
- return wrapped
275
+ return _jit_callback_fn (None , obj_name , "__main__" , debug , sample_size ,
276
+ pyvars , * args , ** kwargs )
279
277
278
+ return wrapped
280
279
281
280
def jit (fn = None , debug = None , sample_size = 5 , pyvars = None ):
282
281
if not pyvars :
283
282
pyvars = []
283
+
284
284
if not isinstance (pyvars , list ):
285
285
raise ArgumentError ("pyvars must be a list" )
286
286
@@ -289,12 +289,15 @@ def jit(fn=None, debug=None, sample_size=5, pyvars=None):
289
289
290
290
def _decorate (f ):
291
291
obj_name = _jit_register_fn (f , pyvars , debug )
292
+
292
293
@functools .wraps (f )
293
294
def wrapped (* args , ** kwargs ):
294
- return _jit_callback_fn (obj_name , f .__module__ , debug , sample_size , pyvars , * args , ** kwargs )
295
+ return _jit_callback_fn (f , obj_name , f .__module__ , debug ,
296
+ sample_size , pyvars , * args , ** kwargs )
297
+
295
298
return wrapped
296
- return _decorate (fn ) if fn else _decorate
297
299
300
+ return _decorate (fn ) if fn else _decorate
298
301
299
302
def execute (code , debug = False ):
300
303
try :
0 commit comments