Skip to content

Commit 609a0ab

Browse files
author
Flax Authors
committed
Merge pull request #4981 from google:jit-wrapped-types
PiperOrigin-RevId: 811527474
2 parents f8c5943 + 566daf7 commit 609a0ab

File tree

1 file changed

+13
-11
lines changed

1 file changed

+13
-11
lines changed

flax/nnx/transforms/compilation.py

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,11 @@
3030
statelib,
3131
variablelib,
3232
)
33-
from flax.typing import Missing
33+
from flax.typing import MISSING, Missing
3434

3535
F = tp.TypeVar('F', bound=tp.Callable[..., tp.Any])
36+
P = tp.ParamSpec('P')
37+
R = tp.TypeVar('R')
3638
Specs = tp.Any
3739
AxisName = tp.Hashable
3840

@@ -150,10 +152,10 @@ def jit(
150152
backend: tp.Optional[str] = None,
151153
inline: bool = False,
152154
abstracted_axes: tp.Optional[tp.Any] = None,
153-
) -> tp.Callable[[tp.Callable[..., tp.Any]], JitWrapped]: ...
155+
) -> tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]: ...
154156
@tp.overload
155157
def jit(
156-
fun: tp.Callable[..., tp.Any],
158+
fun: tp.Callable[P, R],
157159
*,
158160
in_shardings: tp.Any = None,
159161
out_shardings: tp.Any = None,
@@ -166,9 +168,9 @@ def jit(
166168
backend: tp.Optional[str] = None,
167169
inline: bool = False,
168170
abstracted_axes: tp.Optional[tp.Any] = None,
169-
) -> JitWrapped: ...
171+
) -> JitWrapped[P, R]: ...
170172
def jit(
171-
fun: tp.Callable[..., tp.Any] | type[Missing] = Missing,
173+
fun: tp.Callable[P, R] | Missing = MISSING,
172174
*,
173175
in_shardings: tp.Any = None,
174176
out_shardings: tp.Any = None,
@@ -181,7 +183,7 @@ def jit(
181183
backend: tp.Optional[str] = None,
182184
inline: bool = False,
183185
abstracted_axes: tp.Optional[tp.Any] = None,
184-
) -> JitWrapped | tp.Callable[[tp.Callable[..., tp.Any]], JitWrapped]:
186+
) -> JitWrapped[P, R] | tp.Callable[[tp.Callable[P, R]], JitWrapped[P, R]]:
185187
"""
186188
Lifted version of ``jax.jit`` that can handle Modules / graph nodes as
187189
arguments.
@@ -302,7 +304,7 @@ def jit(
302304
A wrapped version of ``fun``, set up for just-in-time compilation.
303305
"""
304306

305-
if fun is Missing:
307+
if isinstance(fun, Missing):
306308
return functools.partial(
307309
jit,
308310
in_shardings=in_shardings,
@@ -317,7 +319,6 @@ def jit(
317319
inline=inline,
318320
abstracted_axes=abstracted_axes,
319321
) # type: ignore[return-value]
320-
321322
return JitWrapped(
322323
fun,
323324
in_shardings=in_shardings,
@@ -334,7 +335,7 @@ def jit(
334335
)
335336

336337

337-
class JitWrapped:
338+
class JitWrapped(tp.Generic[P, R]):
338339
"""A function ready to be traced, lowered, and compiled.
339340
340341
This protocol reflects the output of functions such as
@@ -345,7 +346,7 @@ class JitWrapped:
345346

346347
def __init__(
347348
self,
348-
fun: tp.Callable[..., tp.Any],
349+
fun: tp.Callable[P, R],
349350
in_shardings: tp.Any,
350351
out_shardings: tp.Any,
351352
static_argnums: int | tp.Sequence[int] | None = None,
@@ -359,6 +360,7 @@ def __init__(
359360
abstracted_axes: tp.Optional[tp.Any] = None,
360361
):
361362
functools.update_wrapper(self, fun)
363+
self.fun: tp.Callable[P, R] = fun
362364
kwarg_shardings = None
363365
self.jax_in_shardings = jax.tree.map(
364366
lambda x: extract.NodeStates.from_prefixes(x.shardings, metadata=x)
@@ -424,7 +426,7 @@ def _get_non_pure_out(self, pure_args_out, pure_kwargs_out, pure_out, /):
424426
)
425427
return out
426428

427-
def __call__(self, *args, **kwargs):
429+
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
428430
# run dynamic_cache_context before update_context
429431
with graph.update_context(self):
430432
pure_args, pure_kwargs = self._get_pure_args_kwargs(args, kwargs)

0 commit comments

Comments
 (0)