@@ -168,35 +168,52 @@ def get_read_request(
168
168
dtype : np .dtype ,
169
169
shape : Sequence [int ],
170
170
sharding : jax .sharding .Sharding ,
171
- devices : Sequence [jax .Device ],
171
+ devices : Sequence [jax .Device ] | None ,
172
172
timeout : datetime .timedelta ,
173
173
return_dict : bool = False ,
174
174
) -> Union [str , dict [str , Any ]]:
175
175
"""Returns a string representation of the plugin program which reads the given array from the given location into the provided sharding."""
176
- if not isinstance (devices , np .ndarray ):
177
- devices = np .array (devices )
178
-
179
176
timeout_seconds , timeout_fractional_seconds = divmod (
180
177
timeout .total_seconds (), 1
181
178
)
182
179
timeout_nanoseconds = timeout_fractional_seconds * 1e9
183
- d = {
184
- "persistenceReadRequest" : {
185
- "b64_location" : string_to_base64 (location_path ),
186
- "shape" : get_shape_info (dtype , shape ),
187
- "b64_name" : string_to_base64 (name ),
188
- "b64_hlo_sharding_string" : get_hlo_sharding_string (
189
- sharding , len (shape )
190
- ),
191
- "devices" : {
192
- "device_ids" : [device .id for device in devices .flatten ()]
193
- },
194
- "timeout" : {
195
- "seconds" : int (timeout_seconds ),
196
- "nanos" : int (timeout_nanoseconds ),
197
- },
198
- }
199
- }
180
+
181
+ if devices is None :
182
+ d = {
183
+ "persistenceReadRequest" : {
184
+ "b64_location" : string_to_base64 (location_path ),
185
+ "shape" : get_shape_info (dtype , shape ),
186
+ "b64_name" : string_to_base64 (name ),
187
+ "b64_hlo_sharding_string" : get_hlo_sharding_string (
188
+ sharding , len (shape )
189
+ ),
190
+ "timeout" : {
191
+ "seconds" : int (timeout_seconds ),
192
+ "nanos" : int (timeout_nanoseconds ),
193
+ },
194
+ }
195
+ }
196
+ else :
197
+ if not isinstance (devices , np .ndarray ):
198
+ devices = np .array (devices )
199
+
200
+ d = {
201
+ "persistenceReadRequest" : {
202
+ "b64_location" : string_to_base64 (location_path ),
203
+ "shape" : get_shape_info (dtype , shape ),
204
+ "b64_name" : string_to_base64 (name ),
205
+ "b64_hlo_sharding_string" : get_hlo_sharding_string (
206
+ sharding , len (shape )
207
+ ),
208
+ "devices" : {
209
+ "device_ids" : [device .id for device in devices .flatten ()]
210
+ },
211
+ "timeout" : {
212
+ "seconds" : int (timeout_seconds ),
213
+ "nanos" : int (timeout_nanoseconds ),
214
+ },
215
+ }
216
+ }
200
217
201
218
if return_dict :
202
219
return d
@@ -224,6 +241,38 @@ def get_bulk_read_request(
224
241
)
225
242
226
243
244
+ def get_bulk_read_request_per_device_list (
245
+ location_path : str ,
246
+ names : Sequence [str ],
247
+ dtypes : Sequence [np .dtype ],
248
+ shapes : Sequence [Sequence [int ]],
249
+ shardings : Sequence [jax .sharding .Sharding ],
250
+ devices : Sequence [jax .Device ],
251
+ timeout : datetime .timedelta ,
252
+ ) -> str :
253
+ """Returns a string representation of a bulk read request, reads multiple arrays with one call."""
254
+ read_requests = [
255
+ get_read_request (
256
+ location_path , name , dtype , shape , sharding , None , timeout , True
257
+ )["persistenceReadRequest" ]
258
+ for name , dtype , shape , sharding in zip (names , dtypes , shapes , shardings )
259
+ ]
260
+
261
+ if not isinstance (devices , np .ndarray ):
262
+ devices = np .array (devices )
263
+
264
+ return json .dumps ({
265
+ "bulk_persistence_read_request" : {
266
+ "read_requests_per_device_list" : {
267
+ "device_list" : {
268
+ "device_ids" : [device .id for device in devices .flatten ()]
269
+ },
270
+ "read_requests" : read_requests ,
271
+ }
272
+ }
273
+ })
274
+
275
+
227
276
def write_one_array (
228
277
location : str ,
229
278
name : str ,
0 commit comments