Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Group bulk write requests by device list. It avoids setting devices repeatedly for each request. #59

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
91 changes: 70 additions & 21 deletions pathwaysutils/persistence/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,35 +168,52 @@ def get_read_request(
dtype: np.dtype,
shape: Sequence[int],
sharding: jax.sharding.Sharding,
devices: Sequence[jax.Device],
devices: Sequence[jax.Device] | None,
timeout: datetime.timedelta,
return_dict: bool = False,
) -> Union[str, dict[str, Any]]:
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
if not isinstance(devices, np.ndarray):
devices = np.array(devices)

timeout_seconds, timeout_fractional_seconds = divmod(
timeout.total_seconds(), 1
)
timeout_nanoseconds = timeout_fractional_seconds * 1e9
d = {
"persistenceReadRequest": {
"b64_location": string_to_base64(location_path),
"shape": get_shape_info(dtype, shape),
"b64_name": string_to_base64(name),
"b64_hlo_sharding_string": get_hlo_sharding_string(
sharding, len(shape)
),
"devices": {
"device_ids": [device.id for device in devices.flatten()]
},
"timeout": {
"seconds": int(timeout_seconds),
"nanos": int(timeout_nanoseconds),
},
}
}

if devices is None:
d = {
"persistenceReadRequest": {
"b64_location": string_to_base64(location_path),
"shape": get_shape_info(dtype, shape),
"b64_name": string_to_base64(name),
"b64_hlo_sharding_string": get_hlo_sharding_string(
sharding, len(shape)
),
"timeout": {
"seconds": int(timeout_seconds),
"nanos": int(timeout_nanoseconds),
},
}
}
else:
if not isinstance(devices, np.ndarray):
devices = np.array(devices)

d = {
"persistenceReadRequest": {
"b64_location": string_to_base64(location_path),
"shape": get_shape_info(dtype, shape),
"b64_name": string_to_base64(name),
"b64_hlo_sharding_string": get_hlo_sharding_string(
sharding, len(shape)
),
"devices": {
"device_ids": [device.id for device in devices.flatten()]
},
"timeout": {
"seconds": int(timeout_seconds),
"nanos": int(timeout_nanoseconds),
},
}
}

if return_dict:
return d
Expand Down Expand Up @@ -224,6 +241,38 @@ def get_bulk_read_request(
)


def get_bulk_read_request_per_device_list(
location_path: str,
names: Sequence[str],
dtypes: Sequence[np.dtype],
shapes: Sequence[Sequence[int]],
shardings: Sequence[jax.sharding.Sharding],
devices: Sequence[jax.Device],
timeout: datetime.timedelta,
) -> str:
"""Returns a string representation of a bulk read request, reads multiple arrays with one call."""
read_requests = [
get_read_request(
location_path, name, dtype, shape, sharding, None, timeout, True
)["persistenceReadRequest"]
for name, dtype, shape, sharding in zip(names, dtypes, shapes, shardings)
]

if not isinstance(devices, np.ndarray):
devices = np.array(devices)

return json.dumps({
"bulk_persistence_read_request": {
"read_requests_per_device_list": {
"device_list": {
"device_ids": [device.id for device in devices.flatten()]
},
"read_requests": read_requests,
}
}
})


def write_one_array(
location: str,
name: str,
Expand Down