
Commit 14b2c90

yashk2810 authored and Google-ML-Automation committed
Allow the GSPMDSharding constructor to take a device_list (xc.DeviceList) as input, in addition to Sequence[jax.Device]. This avoids the extremely slow tuple(devices) -> DeviceList conversion in the GSPMDSharding constructor.
PiperOrigin-RevId: 778627673
1 parent 28fd600 commit 14b2c90
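In practice, the change lets call sites hand an existing xc.DeviceList straight to the constructor instead of materializing a tuple of devices first. A minimal before/after sketch (assumes at least one JAX device is available; `replicated` is just an example HloSharding, not taken from this commit):

import jax
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import GSPMDSharding

devices = tuple(jax.devices())
replicated = xc.HloSharding.replicate()  # example HloSharding

# Old path: Sequence[jax.Device]; the constructor had to rebuild a
# DeviceList from the tuple on every call.
s_old = GSPMDSharding(devices, replicated)

# New path: pass a DeviceList directly and skip that conversion.
s_new = GSPMDSharding(xc.DeviceList(devices), replicated)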

File tree

8 files changed: +115 additions, -70 deletions

jax/_src/array.py

Lines changed: 1 addition & 1 deletion
@@ -826,7 +826,7 @@ def get_data(index: Index | None) -> ArrayImpl | np.ndarray:
       )
 
   if dll is not None:
-    devices = [Format(dll, SingleDeviceSharding(d)) for d in devices]
+    devices = [Format(dll, SingleDeviceSharding(d)) for d in devices]  # type: ignore
   # pxla.batched_device_put doesn't support Layout... Take the slow route
   arrays = api.device_put(per_device_values, devices)
   return ArrayImpl(aval, sharding, arrays, committed=True)

jax/_src/interpreters/pxla.py

Lines changed: 41 additions & 22 deletions
@@ -59,6 +59,7 @@
 from jax._src.interpreters import xla
 from jax._src.layout import Layout, AutoLayout, Format
 from jax._src.lib import xla_client as xc
+from jax._src.lib import jaxlib_extension_version
 from jax._src.lib.mlir import ir
 from jax._src.lib.mlir.dialects import hlo
 from jax._src.partition_spec import PartitionSpec
@@ -2085,9 +2086,12 @@ class AllArgsInfo(NamedTuple):
 def to_gspmd_sharding(s: JSharding, ndim: int) -> GSPMDSharding:
   if isinstance(s, GSPMDSharding):
     return s
-  return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
-                       memory_kind=s.memory_kind,
-                       _device_list=getattr(s, '_internal_device_list', None))
+  if jaxlib_extension_version >= 360:
+    return GSPMDSharding(s._internal_device_list, s._to_xla_hlo_sharding(ndim),
+                         memory_kind=s.memory_kind)
+  else:
+    return GSPMDSharding(s._device_assignment, s._to_xla_hlo_sharding(ndim),
+                         memory_kind=s.memory_kind)
 
 
 def _discharge_refs_jaxpr(closed_jaxpr, in_shardings, in_layouts,
@@ -2477,7 +2481,7 @@ def get_pspec_from_executable(
 
 def get_out_shardings_from_executable(
     xla_executable,
-    device_assignment: Sequence[xc.Device],
+    device_list: xc.DeviceList,
     num_out_avals: int,
     num_ordered_effects: int,
 ) -> Sequence[sharding_impls.GSPMDSharding] | None:
@@ -2492,9 +2496,14 @@
 
   # When the device assignment only has 1 device, SPMD partitioner will not run.
   # Hence the op shardings will not be set on the `hlo_module`.
-  if len(device_assignment) == 1:
-    return [sharding_impls.GSPMDSharding.get_replicated(device_assignment, memory_kind=mk)
-            for mk in omk]
+  if len(device_list) == 1:
+    if jaxlib_extension_version >= 360:
+      return [sharding_impls.GSPMDSharding.get_replicated(device_list, memory_kind=mk)
+              for mk in omk]
+    else:
+      da = tuple(device_list)
+      return [sharding_impls.GSPMDSharding.get_replicated(da, memory_kind=mk)
+              for mk in omk]
 
   _, out_op_shardings = get_op_sharding_from_executable(xla_executable)
   if not out_op_shardings:
@@ -2518,19 +2527,27 @@
   assert len(out_op_shardings) == num_out_avals == len(omk), (
       len(out_op_shardings), num_out_avals, len(omk))
 
-  return [sharding_impls.GSPMDSharding(device_assignment, os, memory_kind=mk)
-          for os, mk in safe_zip(out_op_shardings, omk)]
+  if jaxlib_extension_version >= 360:
+    return [sharding_impls.GSPMDSharding(device_list, os, memory_kind=mk)
+            for os, mk in safe_zip(out_op_shardings, omk)]
+  else:
+    da = tuple(device_list)
+    return [sharding_impls.GSPMDSharding(da, os, memory_kind=mk)
+            for os, mk in safe_zip(out_op_shardings, omk)]
 
 
 def _get_in_shardings_from_xla(
-    xla_executable, device_assignment: Sequence[xc.Device], num_in_avals: int,
+    xla_executable, device_list: xc.DeviceList, num_in_avals: int,
     num_ordered_effects: int
 ) -> Sequence[GSPMDSharding] | None:
   """Returns input shardings from XLA."""
   # When the device assignment only has 1 device, SPMD partitioner will not run.
   # Hence the op shardings will not be set on the `hlo_module`.
-  if len(device_assignment) == 1:
-    return [GSPMDSharding.get_replicated(device_assignment)] * num_in_avals
+  if len(device_list) == 1:
+    if jaxlib_extension_version >= 360:
+      return [GSPMDSharding.get_replicated(device_list)] * num_in_avals
+    else:
+      return [GSPMDSharding.get_replicated(tuple(device_list))] * num_in_avals
 
   in_op_shardings, _ = get_op_sharding_from_executable(xla_executable)
   if not in_op_shardings:
@@ -2542,8 +2559,11 @@ def _get_in_shardings_from_xla(
   assert len(in_op_shardings) == num_in_avals, (
       len(in_op_shardings), num_in_avals)
 
-  return [GSPMDSharding(device_assignment, os)
-          for os in in_op_shardings]
+  if jaxlib_extension_version >= 360:
+    return [GSPMDSharding(device_list, os) for os in in_op_shardings]
+  else:
+    da = tuple(device_list)
+    return [GSPMDSharding(da, os) for os in in_op_shardings]
 
 
 # TODO(yashkatariya): Remove this function after `AUTO` can return shardings
@@ -2758,8 +2778,8 @@ def _cached_compilation(computation, name, mesh, spmd_lowering,
 
 
 def _maybe_get_and_check_in_shardings(
-    xla_executable, in_shardings, device_assignment,
-    global_in_avals, num_ordered_effects):
+    xla_executable, in_shardings, device_list, global_in_avals,
+    num_ordered_effects):
   """Returns in_shardings extracted from XLA or checks and returns original
   shardings.
 
@@ -2770,8 +2790,7 @@ def _maybe_get_and_check_in_shardings(
   If in_sharding is unspecified, then the sharding returned by XLA is returned.
   """
   in_shardings_xla = _get_in_shardings_from_xla(
-      xla_executable, device_assignment, len(global_in_avals),
-      num_ordered_effects)
+      xla_executable, device_list, len(global_in_avals), num_ordered_effects)
   if in_shardings_xla is None:
     return in_shardings
 
@@ -2802,11 +2821,11 @@
 
 
 def _maybe_get_and_check_out_shardings(
-    xla_executable, out_shardings, device_assignment, global_out_avals,
+    xla_executable, out_shardings, device_list, global_out_avals,
     num_ordered_effects
 ):
   out_shardings_xla = get_out_shardings_from_executable(
-      xla_executable, device_assignment, len(global_out_avals),
+      xla_executable, device_list, len(global_out_avals),
       num_ordered_effects)
   if out_shardings_xla is None:
     return out_shardings
@@ -2987,10 +3006,10 @@ def from_hlo(name: str,
   if pmap_nreps == 1:
     assert mesh is None
     in_shardings = _maybe_get_and_check_in_shardings(
-        xla_executable, in_shardings, tuple(device_list), global_in_avals,
+        xla_executable, in_shardings, device_list, global_in_avals,
        len(ordered_effects))
     out_shardings = _maybe_get_and_check_out_shardings(
-        xla_executable, out_shardings, tuple(device_list), global_out_avals,
+        xla_executable, out_shardings, device_list, global_out_avals,
        len(ordered_effects))
   else:
     in_shardings, out_shardings, committed, device_list = _get_metadata_jit_pmap(
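Every call site above gates on jaxlib_extension_version because an older installed jaxlib still exposes only the Sequence[Device] constructor. The recurring pattern, condensed into one hypothetical helper (the helper name is illustrative, not part of the diff):

from jax._src.lib import jaxlib_extension_version
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import GSPMDSharding

def gspmd_from_device_list(device_list: xc.DeviceList, hlo_sharding,
                           memory_kind=None):
  # Newer jaxlib (extension version >= 360) accepts a DeviceList directly;
  # older jaxlib needs the tuple(device_list) fallback.
  if jaxlib_extension_version >= 360:
    return GSPMDSharding(device_list, hlo_sharding, memory_kind=memory_kind)
  return GSPMDSharding(tuple(device_list), hlo_sharding,
                       memory_kind=memory_kind)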

jax/_src/pjit.py

Lines changed: 6 additions & 5 deletions
@@ -61,6 +61,7 @@
 from jax._src.lib.mlir.dialects import func as func_dialect
 from jax._src.lib import jax_jit
 from jax._src.lib import xla_client as xc
+from jax._src.lib import jaxlib_extension_version
 from jax._src.mesh import AbstractMesh
 from jax._src.sharding import Sharding
 from jax._src.sharding_impls import (
@@ -2151,8 +2152,7 @@ def _pjit_batcher(axis_data, vals_in,
 
 
 def _pjit_batcher_for_sharding(
-    s: Sharding | UnspecifiedValue,
-    dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None,
+    s, dim: int | batching.RaggedAxis, spmd_axis_name: tuple[str, ...] | None,
     mesh, ndim: int):
   if isinstance(s, UnspecifiedValue):
     return s
@@ -2167,9 +2167,10 @@
     tad = list(new_op.tile_assignment_dimensions)
     tad.insert(dim, 1)  # type: ignore
     new_op.tile_assignment_dimensions = tad
-    new_gs = GSPMDSharding(
-        s._device_assignment, new_op,
-        _device_list=getattr(s, '_internal_device_list', None))
+    if jaxlib_extension_version >= 360:
+      new_gs = GSPMDSharding(s._internal_device_list, new_op)
+    else:
+      new_gs = GSPMDSharding(s._device_assignment, new_op)
     return pxla._get_out_sharding_from_orig_sharding([new_gs], [None], s, None)[0]
   else:
     if isinstance(s, NamedSharding) and isinstance(s.mesh, AbstractMesh):
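The batching logic above replicates a sharding along the new batch dimension by inserting a size-1 entry into the tile assignment before rebuilding the GSPMDSharding. The same proto manipulation as a standalone sketch (hypothetical helper mirroring the lines above; a tile count of 1 on an axis means no partitioning there):

from jax._src.lib import xla_client as xc

def insert_replicated_batch_dim(hlo: xc.HloSharding, dim: int) -> xc.HloSharding:
  # Clone the proto so the original sharding is untouched.
  proto = hlo.to_proto().clone()
  tad = list(proto.tile_assignment_dimensions)
  tad.insert(dim, 1)  # size-1 tile => replicated along the new axis
  proto.tile_assignment_dimensions = tad
  return xc.HloSharding.from_proto(proto)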

jax/_src/sharding.py

Lines changed: 4 additions & 0 deletions
@@ -126,6 +126,10 @@ def with_memory_kind(self, kind: str) -> Sharding:
   def _device_assignment(self) -> XLADeviceAssignment:
     raise NotImplementedError('Subclasses should implement this method.')
 
+  @property
+  def _internal_device_list(self) -> xc.DeviceList:
+    raise NotImplementedError('Subclasses should implement this method.')
+
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     raise NotImplementedError('Subclasses should implement this method.')
 
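With this base-class addition, Sharding subclasses are expected to expose an _internal_device_list alongside _device_assignment. A sketch of how a concrete subclass might satisfy both (hypothetical class, not part of this commit):

import functools
from jax._src import sharding as jsharding
from jax._src.lib import xla_client as xc

class MySharding(jsharding.Sharding):  # hypothetical subclass
  def __init__(self, devices):
    self._devs = tuple(devices)

  @property
  def _device_assignment(self):
    return self._devs

  @functools.cached_property
  def _internal_device_list(self) -> xc.DeviceList:
    # Cache the DeviceList so the tuple -> DeviceList conversion happens
    # at most once per sharding instance.
    return xc.DeviceList(self._devs)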

jax/_src/sharding_impls.py

Lines changed: 25 additions & 19 deletions
@@ -33,6 +33,7 @@
 from jax._src import xla_bridge as xb
 from jax._src import mesh_utils
 from jax._src.lib import xla_client as xc
+from jax._src.lib import jaxlib_extension_version
 from jax._src.lib.mlir.dialects import sdy
 from jax._src.named_sharding import ( # noqa: F401
     SdyArray, SdyDim, UnspecifiedValue, AUTO,
@@ -360,22 +361,20 @@ def _unpickle_gspmd_sharding(devices, op_sharding, memory_kind):
 
 @use_cpp_class(xc.GSPMDSharding)
 class GSPMDSharding(jsharding.Sharding):
-  _devices: tuple[Device, ...]
+  _devices: xc.DeviceList
   _hlo_sharding: xc.HloSharding
   _memory_kind: str | None
-  _device_list: xc.DeviceList | None
   _internal_device_list: xc.DeviceList
 
   @use_cpp_method()
-  def __init__(self, devices: Sequence[Device],
+  def __init__(self, devices: Sequence[Device] | xc.DeviceList,
                op_sharding: xc.OpSharding | xc.HloSharding,
-               *, memory_kind: str | None = None,
-               _device_list: xc.DeviceList | None = None):
-    self._devices = tuple(devices)
-    if isinstance(op_sharding, xc.OpSharding):
-      self._hlo_sharding = xc.HloSharding.from_proto(op_sharding)
-    else:
-      self._hlo_sharding = op_sharding
+               *, memory_kind: str | None = None):
+    self._devices = (devices if isinstance(devices, xc.DeviceList) else
+                     xc.DeviceList(tuple(devices)))
+    self._hlo_sharding = (xc.HloSharding.from_proto(op_sharding)
+                          if isinstance(op_sharding, xc.OpSharding) else
+                          op_sharding)
     self._memory_kind = memory_kind
 
   def __reduce__(self):
@@ -417,7 +416,7 @@ def check_compatible_aval(self, aval_shape: Shape) -> None:
 
   @property
   def num_devices(self) -> int:
-    return len(self.device_set)
+    return len(self._internal_device_list)
 
   @functools.cached_property
   def device_set(self) -> set[Device]:
@@ -432,7 +431,7 @@ def with_memory_kind(self, kind: str) -> GSPMDSharding:
 
   @property
   def _device_assignment(self) -> XLADeviceAssignment:
-    return self._devices
+    return tuple(self._devices)
 
   def _to_xla_hlo_sharding(self, num_dimensions: int) -> xc.HloSharding:
     return self._hlo_sharding
@@ -468,7 +467,7 @@ def is_fully_addressable(self) -> bool:
 
   @classmethod
   def get_replicated(cls, device_assignment, *, memory_kind: str | None = None):
-    return cls(tuple(device_assignment), replicated_hlo_sharding,
+    return cls(device_assignment, replicated_hlo_sharding,
                memory_kind=memory_kind)
 
 
@@ -982,12 +981,15 @@ def make_key_array_phys_sharding(aval, sharding):
     return sharding.update(spec=PartitionSpec(*sharding.spec, *trailing_spec))
   else:
     hlos = sharding._to_xla_hlo_sharding(aval.ndim)
-    return GSPMDSharding(
-        sharding._device_assignment, physical_hlo_sharding(aval, hlos))
+    if jaxlib_extension_version >= 360:
+      return GSPMDSharding(
+          sharding._internal_device_list, physical_hlo_sharding(aval, hlos))
+    else:
+      return GSPMDSharding(
+          sharding._device_assignment, physical_hlo_sharding(aval, hlos))
 
 
-def physical_sharding(
-    aval, sharding: jsharding.Sharding) -> jsharding.Sharding:
+def physical_sharding(aval, sharding: jsharding.Sharding) -> jsharding.Sharding:
   return make_key_array_phys_sharding(aval, sharding)
 
 
@@ -1001,8 +1003,12 @@ def get_logical_gspmd_sharding(logical_shape, dtype, phys_sharding):
   logical_op_sharding = phys_hlo_sharding.to_proto().clone()
   tad = partitions[:-elt_aval.ndim] + suffix
   logical_op_sharding.tile_assignment_dimensions = tad
-  return GSPMDSharding(phys_sharding._device_assignment,
-                       xc.HloSharding.from_proto(logical_op_sharding))
+  if jaxlib_extension_version >= 360:
+    return GSPMDSharding(phys_sharding._internal_device_list,
+                         xc.HloSharding.from_proto(logical_op_sharding))
+  else:
+    return GSPMDSharding(phys_sharding._device_assignment,
+                         xc.HloSharding.from_proto(logical_op_sharding))
 
 def check_replicated_trailing_dims(sharding: jsharding.Sharding,
                                    logical_shape, dtype):
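After these changes the Python-level constructor accepts either spelling and normalizes to a DeviceList once, while _device_assignment still returns a tuple for backward compatibility. A usage sketch (assumes at least one JAX device):

import jax
from jax._src.lib import xla_client as xc
from jax._src.sharding_impls import GSPMDSharding

devices = tuple(jax.devices())
hlo = xc.HloSharding.replicate()

a = GSPMDSharding(devices, hlo)                 # Sequence[Device]
b = GSPMDSharding(xc.DeviceList(devices), hlo)  # xc.DeviceList

# Both spellings describe the same device assignment.
assert a._device_assignment == b._device_assignment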

jaxlib/sharding.cc

Lines changed: 18 additions & 16 deletions
@@ -271,22 +271,18 @@ PmapSharding::PmapSharding(xla::nb_numpy_ndarray devices,
   type_ = nanobind::type<PmapSharding>().inc_ref().ptr();
 }
 
-GSPMDSharding::GSPMDSharding(nb::sequence devices, xla::HloSharding op_sharding,
-                             nb::object memory_kind, nb::object device_list)
+GSPMDSharding::GSPMDSharding(xla::nb_class_ptr<PyDeviceList> devices,
+                             xla::HloSharding op_sharding,
+                             nb::object memory_kind)
     : Sharding(/*num_devices=*/nb::len(devices.ptr())),
-      devices_(nb::tuple(devices)),
+      devices_(std::move(devices)),
       hlo_sharding_(std::move(op_sharding)),
       memory_kind_(std::move(memory_kind)) {
-  if (device_list.is_none()) {
-    internal_device_list_ = xla::make_nb_class<PyDeviceList>(devices_);
-  } else {
-    internal_device_list_ =
-        nb::cast<xla::nb_class_ptr<jax::PyDeviceList>>(std::move(device_list));
-  }
+  internal_device_list_ = devices_;
   // This checks in python if the memory kind is correct for the given
   // devices. Currently in python this check is optimized but we want to
   // move that check to C++ after which we can remove this call.
-  CHECK(devices_.size() != 0)
+  CHECK(devices_->Len() != 0)
       << "Devices given to GSPMDSharding must not be empty";
   memory_kind_ =
       CheckAndCanonicalizeMemoryKind(memory_kind_, internal_device_list_);
@@ -346,14 +342,20 @@ void RegisterSharding(nb::module_& m) {
   PmapSharding::InitializeType();
 
   nb::class_<GSPMDSharding, Sharding>(m, "GSPMDSharding", nb::dynamic_attr())
-      .def(nb::init<nb::sequence, xla::OpSharding, nb::object, nb::object>(),
+      .def(nb::init<nb::sequence, xla::OpSharding, nb::object>(),
           nb::arg("devices"), nb::arg("op_sharding"),
-           nb::arg("memory_kind").none() = nb::none(),
-           nb::arg("_device_list").none() = nb::none())
-      .def(nb::init<nb::sequence, xla::HloSharding, nb::object, nb::object>(),
+           nb::arg("memory_kind").none() = nb::none())
+      .def(nb::init<nb::sequence, xla::HloSharding, nb::object>(),
           nb::arg("devices"), nb::arg("op_sharding"),
-           nb::arg("memory_kind").none() = nb::none(),
-           nb::arg("_device_list").none() = nb::none())
+           nb::arg("memory_kind").none() = nb::none())
+      .def(nb::init<xla::nb_class_ptr<PyDeviceList>, xla::OpSharding,
+                    nb::object>(),
+           nb::arg("devices"), nb::arg("op_sharding"),
+           nb::arg("memory_kind").none() = nb::none())
+      .def(nb::init<xla::nb_class_ptr<PyDeviceList>, xla::HloSharding,
+                    nb::object>(),
+           nb::arg("devices"), nb::arg("op_sharding"),
+           nb::arg("memory_kind").none() = nb::none())
       .def_prop_ro("_devices", &GSPMDSharding::devices)
       .def_prop_ro("_hlo_sharding", &GSPMDSharding::hlo_sharding)
       .def_prop_ro("_memory_kind", &GSPMDSharding::memory_kind)

jaxlib/sharding.h

Lines changed: 19 additions & 6 deletions
@@ -164,16 +164,29 @@ class PmapSharding : public Sharding {
 class GSPMDSharding : public Sharding {
  public:
  GSPMDSharding(nanobind::sequence devices, xla::OpSharding op_sharding,
-                nanobind::object memory_kind, nanobind::object device_list)
+                nanobind::object memory_kind)
      : GSPMDSharding(
-            std::move(devices),
+            xla::make_nb_class<PyDeviceList>(nanobind::tuple(devices)),
            xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)),
-            std::move(memory_kind), std::move(device_list)) {}
+            std::move(memory_kind)) {}
 
  GSPMDSharding(nanobind::sequence devices, xla::HloSharding op_sharding,
-                nanobind::object memory_kind, nanobind::object device_list);
+                nanobind::object memory_kind)
+      : GSPMDSharding(
+            xla::make_nb_class<PyDeviceList>(nanobind::tuple(devices)),
+            std::move(op_sharding), std::move(memory_kind)) {}
+
+  GSPMDSharding(xla::nb_class_ptr<PyDeviceList> devices,
+                xla::OpSharding op_sharding, nanobind::object memory_kind)
+      : GSPMDSharding(
+            std::move(devices),
+            xla::ValueOrThrow(xla::HloSharding::FromProto(op_sharding)),
+            std::move(memory_kind)) {}
+
+  GSPMDSharding(xla::nb_class_ptr<PyDeviceList> devices,
+                xla::HloSharding op_sharding, nanobind::object memory_kind);
 
-  const nanobind::tuple& devices() const { return devices_; }
+  xla::nb_class_ptr<PyDeviceList> devices() const { return devices_; }
  const nanobind::object& memory_kind() const { return memory_kind_; }
 
  size_t Hash() {
@@ -226,7 +239,7 @@
    return hlo_sharding().IsReplicated();
  }
 
-  nanobind::tuple devices_;
+  xla::nb_class_ptr<PyDeviceList> devices_;
  xla::HloSharding hlo_sharding_;
  nanobind::object memory_kind_;
  std::optional<size_t> hash_;
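Together with sharding.cc, the header now declares four constructor overloads: the two legacy nb::sequence forms plus two PyDeviceList forms, covering both OpSharding and HloSharding payloads, so existing callers keep working while new callers pass a DeviceList. From Python both spellings remain valid (sketch; which C++ overload nanobind selects for a given argument is an implementation detail):

import jax
from jax._src.lib import xla_client as xc

devs = tuple(jax.devices())
hlo = xc.HloSharding.replicate()

s1 = xc.GSPMDSharding(devs, hlo)                 # legacy sequence form
s2 = xc.GSPMDSharding(xc.DeviceList(devs), hlo)  # DeviceList form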
