22
33from __future__ import annotations
44
5- import copyreg
65import io
76import math
87import pickle
9- from collections .abc import Callable , Generator , Iterable , Iterator
10- from contextvars import ContextVar
11- from types import ModuleType
12- from typing import TYPE_CHECKING , Any , TypeVar , cast
8+ from collections .abc import Callable , Generator , Hashable , Iterable
9+ from functools import wraps
10+ from types import ModuleType , NoneType
11+ from typing import TYPE_CHECKING , Any , Literal , ParamSpec , TypeVar , cast
1312
1413from . import _compat
1514from ._compat import (
2322from ._typing import Array
2423
if TYPE_CHECKING:  # pragma: no cover
    # TODO import from typing
    # (`override` requires Python >=3.12, `TypeIs` requires Python >=3.13)
    from typing_extensions import TypeIs, override
else:
    # typing_extensions is only imported for type checkers (branch above).
    # At runtime, provide a no-op stand-in so that ``@override`` can still be
    # used as a decorator.

    def override(func):
        return func


# Generic placeholders used in the annotations of this module:
# P captures a callable's parameters, T is an arbitrary type.
P = ParamSpec("P")
T = TypeVar("T")
3035
3136
3540 "eager_shape" ,
3641 "in1d" ,
3742 "is_python_scalar" ,
43+ "jax_autojit" ,
3844 "mean" ,
3945 "meta_namespace" ,
4046 "pickle_without" ,
@@ -316,48 +322,39 @@ def capabilities(xp: ModuleType) -> dict[str, int]:
316322 return out
317323
318324
319- # Helper of ``extract_objects`` and ``repack_objects``
320- _repacking_objects : ContextVar [Iterator [object ]] = ContextVar ("_repacking_objects" )
321-
322-
323- def _expand () -> object : # numpydoc ignore=RT01
324- """
325- Helper of ``extract_objects`` and ``repack_objects``.
326-
327- Inverse of the reducer function.
328-
329- Notes
330- -----
331- This function must be global in order to be picklable.
332- """
333- try :
334- return next (_repacking_objects .get ())
335- except StopIteration :
336- msg = "Not enough objects to repack"
337- raise ValueError (msg )
325+ _BASIC_TYPES = frozenset ((
326+ NoneType , bool , int , float , complex , str , bytes , bytearray ,
327+ list , tuple , dict , set , frozenset , range , slice ,
328+ )) # fmt: skip
338329
339330
def pickle_without(
    obj: object, cls: type[T] | tuple[type[T], ...] = ()
) -> tuple[bytes, tuple[T, ...], tuple[object, ...]]:
    """
    Variant of ``pickle.dumps`` that always succeeds and extracts inner objects.

    Conceptually, this is similar to passing the ``buffer_callback`` argument to
    ``pickle.dumps``, but instead of extracting buffers it extracts entire objects,
    which are either not serializable with ``pickle`` (e.g. local classes or functions)
    or instances of an explicit list of classes.

    Parameters
    ----------
    obj : object
        The object to pickle.
    cls : type | tuple[type, ...], optional
        One or multiple classes to extract from the object.
        The instances of these classes inside ``obj`` will not be pickled.

    Returns
    -------
    bytes
        The pickled object. Must be unpickled with :func:`unpickle_without`.
    tuple
        All instances of ``cls`` found inside ``obj`` (not pickled).
    tuple
        All other objects which failed to pickle.

    See Also
    --------
    pickle.dumps : Standard pickle function.
    unpickle_without : Reverse function.

    Examples
    --------
    >>> class NS:
    ...     def __repr__(self):
    ...         return "<NS>"
    ...     def __reduce__(self):
    ...         assert False
    >>> class A:
    ...     def __repr__(self):
    ...         return "<A>"
    >>> obj = {1: A(), 2: [A(), NS(), A()]}
    >>> pik, instances, unpickleable = pickle_without(obj, A)
    >>> instances, unpickleable
    ((<A>, <A>, <A>), (<NS>,))
    >>> unpickle_without(pik, instances, unpickleable)
    {1: <A>, 2: [<A>, <NS>, <A>]}

    This can be also used to hot-swap inner objects; the only constraint is that
    the number of objects in and out must be the same:

    >>> class B:
    ...     def __repr__(self): return "<B>"
    >>> unpickle_without(pik, [B(), B(), B()], [NS()])
    {1: <B>, 2: [<B>, <NS>, <B>]}
    """
    instances: list[T] = []
    unpickleable: list[object] = []
    # id(obj) -> (obj, kind), where kind is
    #   0    : instance of ``cls``, stored in ``instances``
    #   1    : failed to pickle, stored in ``unpickleable``
    #   None : pickled normally
    # The object itself is stored next to its kind so that it stays alive for
    # the whole dump; otherwise the id() of a short-lived object could be
    # recycled by a different object, which would then inherit a stale kind.
    seen: dict[int, tuple[object, Literal[0, 1, None]]] = {}

    class Pickler(pickle.Pickler):  # numpydoc ignore=GL01,RT01
        """Override pickle.Pickler.persistent_id.

        TODO consider moving to top-level scope to allow for
        the full Pickler API to be used.
        """

        @override
        def persistent_id(self, obj: object) -> object:  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08
            # Fast exit in case of basic builtin types.
            # Note that basic collections (tuple, list, etc.) are in this;
            # persistent_id() will be called again with their contents.
            if type(obj) in _BASIC_TYPES:  # No subclasses!
                return None

            id_ = id(obj)
            try:
                _, kind = seen[id_]
            except KeyError:
                pass
            else:
                # Already classified: reuse the same persistent id, so that
                # repeated references map to a single extracted object.
                return None if kind is None else (id_, kind)

            if isinstance(obj, cls):
                instances.append(obj)  # type: ignore[arg-type]
                seen[id_] = obj, 0
                return id_, 0

            # Probe whether the object can be pickled; the first probe that
            # does not raise is taken as proof that it can.
            for func in (
                # Note: a class that defines __slots__ without defining __getstate__
                # cannot be pickled with __reduce__(), but can with __reduce_ex__(5)
                lambda: obj.__reduce_ex__(pickle.HIGHEST_PROTOCOL),
                lambda: obj.__reduce__(),
                # Last resort, e.g. for global functions: they can be pickled
                # even though they offer no usable __reduce__
                lambda: pickle.dumps(obj, protocol=pickle.HIGHEST_PROTOCOL),
            ):
                try:
                    func()
                except Exception:  # pylint: disable=broad-exception-caught
                    pass
                else:  # Can be pickled
                    seen[id_] = obj, None
                    return None

            # Can't be pickled
            unpickleable.append(obj)
            seen[id_] = obj, 1
            return id_, 1

    f = io.BytesIO()
    p = Pickler(f, protocol=pickle.HIGHEST_PROTOCOL)
    p.dump(obj)
    return f.getvalue(), tuple(instances), tuple(unpickleable)
410448
411-
def unpickle_without(  # type: ignore[explicit-any]
    pik: bytes,
    instances: Iterable[object],
    unpickleable: Iterable[object],
    /,
) -> Any:
    """
    Variant of ``pickle.loads``, reverse of ``pickle_without``.

    Parameters
    ----------
    pik : bytes
        The pickled object generated by ``pickle_without``.
    instances : Iterable[object]
        Replacements for the instances of the class or classes that were
        explicitly passed to ``pickle_without``; they are reinserted into the
        unpickled object.
    unpickleable : Iterable[object]
        Replacements for the objects that failed to pickle, as returned by
        ``pickle_without``.

    Returns
    -------
    object
        The unpickled object.

    See Also
    --------
    pickle_without : Serializing function.
    pickle.loads : Standard unpickle function.

    Notes
    -----
    The second and third parameter of this function must yield at least the same number
    of elements as the ones returned by ``pickle_without``, but do not need to be the
    same objects, or even the same types of objects. Excess objects, if any, will be
    quietly ignored.
    """
    # Indexed by the ``kind`` half of each persistent id:
    # 0 -> instances, 1 -> unpickleable.
    sources = (iter(instances), iter(unpickleable))
    # Persistent id -> object already handed out for it, so that a repeated
    # persistent id resolves to the same object every time.
    resolved: dict[tuple[int, int], object] = {}

    class Unpickler(pickle.Unpickler):  # numpydoc ignore=GL01,RT01
        """
        Override pickle.Unpickler.persistent_load.

        TODO consider moving to top-level scope to allow for
        the full Unpickler API to be used.
        """

        @override
        def persistent_load(self, pid: tuple[int, int]) -> object:  # pyright: ignore[reportIncompatibleMethodOverride]  # numpydoc ignore=GL08
            if pid in resolved:
                return resolved[pid]

            _obj_id, kind = pid
            try:
                replacement = next(sources[kind])
            except StopIteration as exc:
                msg = "Not enough objects to unpickle"
                raise ValueError(msg) from exc

            resolved[pid] = replacement
            return replacement

    return Unpickler(io.BytesIO(pik)).load()
515+
516+
def jax_autojit(
    func: Callable[P, T],
) -> Callable[P, T]:  # numpydoc ignore=PR01,RT01,SS03
    """
    Wrap `func` with ``jax.jit``, with the following differences:

    - Array-like arguments and return values are not automatically converted to
      ``jax.Array`` objects.
    - All non-array arguments are automatically treated as static.
      Unlike ``jax.jit``, static arguments must be either hashable or serializable with
      ``pickle``.
    - Unlike ``jax.jit``, non-array arguments and return values are not limited to
      tuple/list/dict, but can be any object serializable with ``pickle``.
    - Automatically descend into non-array arguments and find ``jax.Array`` objects
      inside them, then rebuild the arguments when entering `func`, swapping the JAX
      concrete arrays with tracer objects.
    - Automatically descend into non-array return values and find ``jax.Array`` objects
      inside them, then rebuild them downstream of exiting the JIT, swapping the JAX
      tracer objects with concrete arrays.
    """
    # Imported inside the function rather than at module level.
    # NOTE(review): presumably so that jax remains an optional dependency - confirm.
    import jax

    # Side channel written by inner() and read by outer(): it carries the
    # non-array part of func's return value, which inner() does not return.
    # {
    #     jit_cache_key(args_pik, args_arrays, args_unpickleable):
    #     (res_pik, res_unpickleable)
    # }
    static_return_values: dict[Hashable, tuple[bytes, tuple[object, ...]]] = {}

    # Key of static_return_values: the pickled arguments, the (shape, dtype) of
    # every extracted array (never the array values), and the objects that
    # could not be pickled.
    def jit_cache_key(  # type: ignore[no-any-unimported]  # numpydoc ignore=GL08
        args_pik: bytes,
        args_arrays: tuple[jax.Array, ...],  # pyright: ignore[reportUnknownParameterType]
        args_unpickleable: tuple[Hashable, ...],
    ) -> Hashable:
        return (
            args_pik,
            tuple((arr.shape, arr.dtype) for arr in args_arrays),  # pyright: ignore[reportUnknownArgumentType]
            args_unpickleable,
        )

    # The function actually handed to jax.jit. It takes the three outputs of
    # pickle_without() and returns only the arrays found in func's result.
    def inner(  # type: ignore[no-any-unimported]  # pyright: ignore[reportUnknownParameterType]
        args_pik: bytes,
        args_arrays: tuple[jax.Array, ...],  # pyright: ignore[reportUnknownParameterType]
        args_unpickleable: tuple[Hashable, ...],
    ) -> tuple[jax.Array, ...]:  # numpydoc ignore=GL08
        # Rebuild (args, kwargs), reinserting whatever array objects inner()
        # received in place of the ones extracted by outer().
        args, kwargs = unpickle_without(args_pik, args_arrays, args_unpickleable)  # pyright: ignore[reportUnknownArgumentType]
        res = func(*args, **kwargs)  # pyright: ignore[reportCallIssue]
        # Split the result into its arrays (returned) and everything else
        # (stored in static_return_values for outer() to pick up).
        res_pik, res_arrays, res_unpickleable = pickle_without(res, jax.Array)  # pyright: ignore[reportUnknownArgumentType]
        key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
        val = res_pik, res_unpickleable
        # setdefault() keeps an existing entry; the assertion checks that a
        # repeated key always maps to the same non-array result.
        prev = static_return_values.setdefault(key, val)
        assert prev == val, "cache key collision"
        return res_arrays

    # Positional arguments 0 (args_pik) and 2 (args_unpickleable) are declared
    # static; args_arrays (1) is the only non-static argument.
    # NOTE(review): presumably inner() only executes when jax.jit traces it for
    # a new combination of static arguments and array shapes/dtypes, which is
    # why the non-array results are kept in static_return_values - confirm
    # against the jax.jit documentation.
    jitted = jax.jit(inner, static_argnums=(0, 2))

    @wraps(func)
    def outer(*args: P.args, **kwargs: P.kwargs) -> T:  # numpydoc ignore=GL08
        # Extract all jax.Array objects (and unpicklable objects) from the
        # call arguments.
        args_pik, args_arrays, args_unpickleable = pickle_without(
            (args, kwargs),
            jax.Array,  # pyright: ignore[reportUnknownArgumentType]
        )
        # Must run before the lookup below: inner() is what populates
        # static_return_values.
        res_arrays = jitted(args_pik, args_arrays, args_unpickleable)
        key = jit_cache_key(args_pik, args_arrays, args_unpickleable)
        res_pik, res_unpickleable = static_return_values[key]
        # Rebuild the return value around the arrays returned by jitted().
        return unpickle_without(res_pik, res_arrays, res_unpickleable)  # pyright: ignore[reportUnknownArgumentType]

    return outer
0 commit comments