59
59
from jax ._src .interpreters import xla
60
60
from jax ._src .layout import Layout , AutoLayout , Format
61
61
from jax ._src .lib import xla_client as xc
62
+ from jax ._src .lib import jaxlib_extension_version
62
63
from jax ._src .lib .mlir import ir
63
64
from jax ._src .lib .mlir .dialects import hlo
64
65
from jax ._src .partition_spec import PartitionSpec
@@ -2085,9 +2086,12 @@ class AllArgsInfo(NamedTuple):
2085
2086
def to_gspmd_sharding (s : JSharding , ndim : int ) -> GSPMDSharding :
2086
2087
if isinstance (s , GSPMDSharding ):
2087
2088
return s
2088
- return GSPMDSharding (s ._device_assignment , s ._to_xla_hlo_sharding (ndim ),
2089
- memory_kind = s .memory_kind ,
2090
- _device_list = getattr (s , '_internal_device_list' , None ))
2089
+ if jaxlib_extension_version >= 360 :
2090
+ return GSPMDSharding (s ._internal_device_list , s ._to_xla_hlo_sharding (ndim ),
2091
+ memory_kind = s .memory_kind )
2092
+ else :
2093
+ return GSPMDSharding (s ._device_assignment , s ._to_xla_hlo_sharding (ndim ),
2094
+ memory_kind = s .memory_kind )
2091
2095
2092
2096
2093
2097
def _discharge_refs_jaxpr (closed_jaxpr , in_shardings , in_layouts ,
@@ -2477,7 +2481,7 @@ def get_pspec_from_executable(
2477
2481
2478
2482
def get_out_shardings_from_executable (
2479
2483
xla_executable ,
2480
- device_assignment : Sequence [ xc .Device ] ,
2484
+ device_list : xc .DeviceList ,
2481
2485
num_out_avals : int ,
2482
2486
num_ordered_effects : int ,
2483
2487
) -> Sequence [sharding_impls .GSPMDSharding ] | None :
@@ -2492,9 +2496,14 @@ def get_out_shardings_from_executable(
2492
2496
2493
2497
# When the device assignment only has 1 device, SPMD partitioner will not run.
2494
2498
# Hence the op shardings will not be set on the `hlo_module`.
2495
- if len (device_assignment ) == 1 :
2496
- return [sharding_impls .GSPMDSharding .get_replicated (device_assignment , memory_kind = mk )
2497
- for mk in omk ]
2499
+ if len (device_list ) == 1 :
2500
+ if jaxlib_extension_version >= 360 :
2501
+ return [sharding_impls .GSPMDSharding .get_replicated (device_list , memory_kind = mk )
2502
+ for mk in omk ]
2503
+ else :
2504
+ da = tuple (device_list )
2505
+ return [sharding_impls .GSPMDSharding .get_replicated (da , memory_kind = mk )
2506
+ for mk in omk ]
2498
2507
2499
2508
_ , out_op_shardings = get_op_sharding_from_executable (xla_executable )
2500
2509
if not out_op_shardings :
@@ -2518,19 +2527,27 @@ def get_out_shardings_from_executable(
2518
2527
assert len (out_op_shardings ) == num_out_avals == len (omk ), (
2519
2528
len (out_op_shardings ), num_out_avals , len (omk ))
2520
2529
2521
- return [sharding_impls .GSPMDSharding (device_assignment , os , memory_kind = mk )
2522
- for os , mk in safe_zip (out_op_shardings , omk )]
2530
+ if jaxlib_extension_version >= 360 :
2531
+ return [sharding_impls .GSPMDSharding (device_list , os , memory_kind = mk )
2532
+ for os , mk in safe_zip (out_op_shardings , omk )]
2533
+ else :
2534
+ da = tuple (device_list )
2535
+ return [sharding_impls .GSPMDSharding (da , os , memory_kind = mk )
2536
+ for os , mk in safe_zip (out_op_shardings , omk )]
2523
2537
2524
2538
2525
2539
def _get_in_shardings_from_xla (
2526
- xla_executable , device_assignment : Sequence [ xc .Device ] , num_in_avals : int ,
2540
+ xla_executable , device_list : xc .DeviceList , num_in_avals : int ,
2527
2541
num_ordered_effects : int
2528
2542
) -> Sequence [GSPMDSharding ] | None :
2529
2543
"""Returns input shardings from XLA."""
2530
2544
# When the device assignment only has 1 device, SPMD partitioner will not run.
2531
2545
# Hence the op shardings will not be set on the `hlo_module`.
2532
- if len (device_assignment ) == 1 :
2533
- return [GSPMDSharding .get_replicated (device_assignment )] * num_in_avals
2546
+ if len (device_list ) == 1 :
2547
+ if jaxlib_extension_version >= 360 :
2548
+ return [GSPMDSharding .get_replicated (device_list )] * num_in_avals
2549
+ else :
2550
+ return [GSPMDSharding .get_replicated (tuple (device_list ))] * num_in_avals
2534
2551
2535
2552
in_op_shardings , _ = get_op_sharding_from_executable (xla_executable )
2536
2553
if not in_op_shardings :
@@ -2542,8 +2559,11 @@ def _get_in_shardings_from_xla(
2542
2559
assert len (in_op_shardings ) == num_in_avals , (
2543
2560
len (in_op_shardings ), num_in_avals )
2544
2561
2545
- return [GSPMDSharding (device_assignment , os )
2546
- for os in in_op_shardings ]
2562
+ if jaxlib_extension_version >= 360 :
2563
+ return [GSPMDSharding (device_list , os ) for os in in_op_shardings ]
2564
+ else :
2565
+ da = tuple (device_list )
2566
+ return [GSPMDSharding (da , os ) for os in in_op_shardings ]
2547
2567
2548
2568
2549
2569
# TODO(yashkatariya): Remove this function after `AUTO` can return shardings
@@ -2758,8 +2778,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
2758
2778
2759
2779
2760
2780
def _maybe_get_and_check_in_shardings (
2761
- xla_executable , in_shardings , device_assignment ,
2762
- global_in_avals , num_ordered_effects ):
2781
+ xla_executable , in_shardings , device_list , global_in_avals ,
2782
+ num_ordered_effects ):
2763
2783
"""Returns in_shardings extracted from XLA or checks and returns original
2764
2784
shardings.
2765
2785
@@ -2770,8 +2790,7 @@ def _maybe_get_and_check_in_shardings(
2770
2790
If in_sharding is unspecified, then the sharding returned by XLA is returned.
2771
2791
"""
2772
2792
in_shardings_xla = _get_in_shardings_from_xla (
2773
- xla_executable , device_assignment , len (global_in_avals ),
2774
- num_ordered_effects )
2793
+ xla_executable , device_list , len (global_in_avals ), num_ordered_effects )
2775
2794
if in_shardings_xla is None :
2776
2795
return in_shardings
2777
2796
@@ -2802,11 +2821,11 @@ def _maybe_get_and_check_in_shardings(
2802
2821
2803
2822
2804
2823
def _maybe_get_and_check_out_shardings (
2805
- xla_executable , out_shardings , device_assignment , global_out_avals ,
2824
+ xla_executable , out_shardings , device_list , global_out_avals ,
2806
2825
num_ordered_effects
2807
2826
):
2808
2827
out_shardings_xla = get_out_shardings_from_executable (
2809
- xla_executable , device_assignment , len (global_out_avals ),
2828
+ xla_executable , device_list , len (global_out_avals ),
2810
2829
num_ordered_effects )
2811
2830
if out_shardings_xla is None :
2812
2831
return out_shardings
@@ -2987,10 +3006,10 @@ def from_hlo(name: str,
2987
3006
if pmap_nreps == 1 :
2988
3007
assert mesh is None
2989
3008
in_shardings = _maybe_get_and_check_in_shardings (
2990
- xla_executable , in_shardings , tuple ( device_list ) , global_in_avals ,
3009
+ xla_executable , in_shardings , device_list , global_in_avals ,
2991
3010
len (ordered_effects ))
2992
3011
out_shardings = _maybe_get_and_check_out_shardings (
2993
- xla_executable , out_shardings , tuple ( device_list ) , global_out_avals ,
3012
+ xla_executable , out_shardings , device_list , global_out_avals ,
2994
3013
len (ordered_effects ))
2995
3014
else :
2996
3015
in_shardings , out_shardings , committed , device_list = _get_metadata_jit_pmap (
0 commit comments