30
30
statelib ,
31
31
variablelib ,
32
32
)
33
- from flax .typing import Missing
33
+ from flax .typing import MISSING , Missing
34
34
35
35
F = tp .TypeVar ('F' , bound = tp .Callable [..., tp .Any ])
36
+ P = tp .ParamSpec ('P' )
37
+ R = tp .TypeVar ('R' )
36
38
Specs = tp .Any
37
39
AxisName = tp .Hashable
38
40
@@ -150,10 +152,10 @@ def jit(
150
152
backend : tp .Optional [str ] = None ,
151
153
inline : bool = False ,
152
154
abstracted_axes : tp .Optional [tp .Any ] = None ,
153
- ) -> tp .Callable [[tp .Callable [..., tp . Any ]], JitWrapped ]: ...
155
+ ) -> tp .Callable [[tp .Callable [P , R ]], JitWrapped [ P , R ] ]: ...
154
156
@tp .overload
155
157
def jit (
156
- fun : tp .Callable [..., tp . Any ],
158
+ fun : tp .Callable [P , R ],
157
159
* ,
158
160
in_shardings : tp .Any = None ,
159
161
out_shardings : tp .Any = None ,
@@ -166,9 +168,9 @@ def jit(
166
168
backend : tp .Optional [str ] = None ,
167
169
inline : bool = False ,
168
170
abstracted_axes : tp .Optional [tp .Any ] = None ,
169
- ) -> JitWrapped : ...
171
+ ) -> JitWrapped [ P , R ] : ...
170
172
def jit (
171
- fun : tp .Callable [..., tp . Any ] | type [ Missing ] = Missing ,
173
+ fun : tp .Callable [P , R ] | Missing = MISSING ,
172
174
* ,
173
175
in_shardings : tp .Any = None ,
174
176
out_shardings : tp .Any = None ,
@@ -181,7 +183,7 @@ def jit(
181
183
backend : tp .Optional [str ] = None ,
182
184
inline : bool = False ,
183
185
abstracted_axes : tp .Optional [tp .Any ] = None ,
184
- ) -> JitWrapped | tp .Callable [[tp .Callable [..., tp . Any ]], JitWrapped ]:
186
+ ) -> JitWrapped [ P , R ] | tp .Callable [[tp .Callable [P , R ]], JitWrapped [ P , R ] ]:
185
187
"""
186
188
Lifted version of ``jax.jit`` that can handle Modules / graph nodes as
187
189
arguments.
@@ -302,7 +304,7 @@ def jit(
302
304
A wrapped version of ``fun``, set up for just-in-time compilation.
303
305
"""
304
306
305
- if fun is Missing :
307
+ if isinstance ( fun , Missing ) :
306
308
return functools .partial (
307
309
jit ,
308
310
in_shardings = in_shardings ,
@@ -317,7 +319,6 @@ def jit(
317
319
inline = inline ,
318
320
abstracted_axes = abstracted_axes ,
319
321
) # type: ignore[return-value]
320
-
321
322
return JitWrapped (
322
323
fun ,
323
324
in_shardings = in_shardings ,
@@ -334,7 +335,7 @@ def jit(
334
335
)
335
336
336
337
337
- class JitWrapped :
338
+ class JitWrapped ( tp . Generic [ P , R ]) :
338
339
"""A function ready to be traced, lowered, and compiled.
339
340
340
341
This protocol reflects the output of functions such as
@@ -345,7 +346,7 @@ class JitWrapped:
345
346
346
347
def __init__ (
347
348
self ,
348
- fun : tp .Callable [..., tp . Any ],
349
+ fun : tp .Callable [P , R ],
349
350
in_shardings : tp .Any ,
350
351
out_shardings : tp .Any ,
351
352
static_argnums : int | tp .Sequence [int ] | None = None ,
@@ -359,6 +360,7 @@ def __init__(
359
360
abstracted_axes : tp .Optional [tp .Any ] = None ,
360
361
):
361
362
functools .update_wrapper (self , fun )
363
+ self .fun : tp .Callable [P , R ] = fun
362
364
kwarg_shardings = None
363
365
self .jax_in_shardings = jax .tree .map (
364
366
lambda x : extract .NodeStates .from_prefixes (x .shardings , metadata = x )
@@ -424,7 +426,7 @@ def _get_non_pure_out(self, pure_args_out, pure_kwargs_out, pure_out, /):
424
426
)
425
427
return out
426
428
427
- def __call__ (self , * args , ** kwargs ) :
429
+ def __call__ (self , * args : P . args , ** kwargs : P . kwargs ) -> R :
428
430
# run dynamic_cache_context before update_context
429
431
with graph .update_context (self ):
430
432
pure_args , pure_kwargs = self ._get_pure_args_kwargs (args , kwargs )
0 commit comments