Skip to content

Commit b733c86

Browse files
Pathways-on-Cloud Team authored and copybara-github committed
Group bulk read requests by device list. This avoids setting the devices repeatedly for each request.
PiperOrigin-RevId: 743346435
1 parent 5df1e25 commit b733c86

File tree

1 file changed

+70
-21
lines changed

1 file changed

+70
-21
lines changed

pathwaysutils/persistence/helper.py

+70-21
Original file line number | Diff line number | Diff line change
@@ -168,35 +168,52 @@ def get_read_request(
168168
dtype: np.dtype,
169169
shape: Sequence[int],
170170
sharding: jax.sharding.Sharding,
171-
devices: Sequence[jax.Device],
171+
devices: Sequence[jax.Device] | None,
172172
timeout: datetime.timedelta,
173173
return_dict: bool = False,
174174
) -> Union[str, dict[str, Any]]:
175175
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
176-
if not isinstance(devices, np.ndarray):
177-
devices = np.array(devices)
178-
179176
timeout_seconds, timeout_fractional_seconds = divmod(
180177
timeout.total_seconds(), 1
181178
)
182179
timeout_nanoseconds = timeout_fractional_seconds * 1e9
183-
d = {
184-
"persistenceReadRequest": {
185-
"b64_location": string_to_base64(location_path),
186-
"shape": get_shape_info(dtype, shape),
187-
"b64_name": string_to_base64(name),
188-
"b64_hlo_sharding_string": get_hlo_sharding_string(
189-
sharding, len(shape)
190-
),
191-
"devices": {
192-
"device_ids": [device.id for device in devices.flatten()]
193-
},
194-
"timeout": {
195-
"seconds": int(timeout_seconds),
196-
"nanos": int(timeout_nanoseconds),
197-
},
198-
}
199-
}
180+
181+
if devices is None:
182+
d = {
183+
"persistenceReadRequest": {
184+
"b64_location": string_to_base64(location_path),
185+
"shape": get_shape_info(dtype, shape),
186+
"b64_name": string_to_base64(name),
187+
"b64_hlo_sharding_string": get_hlo_sharding_string(
188+
sharding, len(shape)
189+
),
190+
"timeout": {
191+
"seconds": int(timeout_seconds),
192+
"nanos": int(timeout_nanoseconds),
193+
},
194+
}
195+
}
196+
else:
197+
if not isinstance(devices, np.ndarray):
198+
devices = np.array(devices)
199+
200+
d = {
201+
"persistenceReadRequest": {
202+
"b64_location": string_to_base64(location_path),
203+
"shape": get_shape_info(dtype, shape),
204+
"b64_name": string_to_base64(name),
205+
"b64_hlo_sharding_string": get_hlo_sharding_string(
206+
sharding, len(shape)
207+
),
208+
"devices": {
209+
"device_ids": [device.id for device in devices.flatten()]
210+
},
211+
"timeout": {
212+
"seconds": int(timeout_seconds),
213+
"nanos": int(timeout_nanoseconds),
214+
},
215+
}
216+
}
200217

201218
if return_dict:
202219
return d
@@ -224,6 +241,38 @@ def get_bulk_read_request(
224241
)
225242

226243

244+
def get_bulk_read_request_per_device_list(
    location_path: str,
    names: Sequence[str],
    dtypes: Sequence[np.dtype],
    shapes: Sequence[Sequence[int]],
    shardings: Sequence[jax.sharding.Sharding],
    devices: Sequence[jax.Device],
    timeout: datetime.timedelta,
) -> str:
  """Returns the JSON string for a bulk read request sharing one device list.

  One read request is built per (name, dtype, shape, sharding) tuple by
  calling `get_read_request` with `devices=None`, so the individual requests
  carry no device list of their own. The device ids are emitted once, next to
  the list of read requests.

  Args:
    location_path: Location passed through to every per-array read request.
    names: Names of the arrays to read.
    dtypes: Dtype of each array, paired positionally with `names`.
    shapes: Shape of each array, paired positionally with `names`.
    shardings: Sharding of each array, paired positionally with `names`.
    devices: Devices shared by all the read requests; a sequence or a numpy
      array of objects exposing an `id` attribute.
    timeout: Timeout passed through to every per-array read request.

  Returns:
    The JSON-serialized bulk read request.
  """
  read_requests = []
  for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings):
    # `None` for devices and `True` for return_dict: we want the raw dict
    # without a per-request device list.
    single_request = get_read_request(
        location_path, name, dtype, shape, sharding, None, timeout, True
    )
    read_requests.append(single_request["persistenceReadRequest"])

  # Normalize to an ndarray so that nested device layouts can be flattened.
  device_array = (
      devices if isinstance(devices, np.ndarray) else np.array(devices)
  )
  device_ids = [device.id for device in device_array.flatten()]

  per_device_list = {
      "device_list": {"device_ids": device_ids},
      "read_requests": read_requests,
  }
  bulk_request = {
      "bulk_persistence_read_request": {
          "read_requests_per_device_list": per_device_list,
      }
  }
  return json.dumps(bulk_request)
274+
275+
227276
def write_one_array(
228277
location: str,
229278
name: str,

0 commit comments

Comments
 (0)