@@ -306,24 +306,16 @@ def _compute_spec(config: Config, spec: MaybeSpec,
     spec = spec(**config.meta)
   return spec
 
-def specialize_kernel(config: Config,
+def specialize_kernel(config: pallas_core.KernelConfig,
                       func: Callable,
-                      grid: Optional[pallas_core.Grid],
                       name: Optional[str],
-                      in_specs: Optional[list[Optional[BlockSpec]]],
-                      out_specs: Optional[list[Optional[BlockSpec]]],
                       in_avals: tuple[jax_core.ShapedArray, ...],
                       out_avals: tuple[jax_core.ShapedArray, ...],
                       in_tree: tree_util.PyTreeDef,
                       compiler_params: dict[str, Any]
                       ) -> tuple[SpecializedKernel, ...]:
-  specialized_grid = grid
-  if callable(specialized_grid):
-    specialized_grid = specialized_grid(**config.meta)
-  specialized_grid = pallas_core.preprocess_grid(specialized_grid)
-  specialized_in_specs = map(partial(_compute_spec, config), in_specs)
-  specialized_out_specs = map(partial(_compute_spec, config), out_specs)
-  if specialized_grid == ():
+  grid = config.grid
+  if grid == ():
     in_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
                     for arg in in_avals]
     out_ref_avals = [state.shaped_array_ref(arg.shape, arg.dtype)
@@ -333,42 +325,76 @@ def specialize_kernel(config: Config,
         state.shaped_array_ref(
             pallas_core.compute_shape_from_block_spec(block_spec, aval.shape),
             aval.dtype)
-        for block_spec, aval in zip(specialized_in_specs, in_avals)]
+        for block_spec, aval in zip(config.in_specs, in_avals)]
     out_ref_avals = [
         state.shaped_array_ref(
             pallas_core.compute_shape_from_block_spec(block_spec, aval.shape),
             aval.dtype)
-        for block_spec, aval in zip(specialized_out_specs, out_avals)]
-    in_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid),
-                            specialized_in_specs)
-    out_block_mappings = map(partial(pallas_core.convert_block_spec_to_block_mapping, specialized_grid),
-                             specialized_out_specs)
-    grid_spec = pallas_core.GridSpec(specialized_grid, (*in_block_mappings, *out_block_mappings), ())
+        for block_spec, aval in zip(config.out_specs, out_avals)]
+    in_block_mappings = map(
+        partial(pallas_core.convert_block_spec_to_block_mapping, grid),
+        config.in_specs)
+    out_block_mappings = map(
+        partial(pallas_core.convert_block_spec_to_block_mapping, grid),
+        config.out_specs)
+    grid_spec = pallas_core.GridSpec(grid, (*in_block_mappings, *out_block_mappings), ())
   jaxpr, consts, out_tree = tracing_utils.initial_style_open_jaxpr(
       func, in_tree, tuple((*in_ref_avals, *out_ref_avals)), "pallas_call", **config.meta)
   return SpecializedKernel("foo", jaxpr, len(consts), grid_spec,
                            dict(compiler_params, **config.compiler_params)), consts, out_tree
 
-def pallas_call(f: Callable, out_shape: Any, *, debug: bool = False,
+def _canonicalize_kernel_config(
+    maybe_kernel_config: Optional[pallas_core.KernelConfig],
+    in_avals: Sequence[jax_core.AbstractValue],
+    out_avals: Sequence[jax_core.AbstractValue],
+    in_specs: Optional[Sequence[Optional[BlockSpec]]],
+    out_specs: Optional[Sequence[Optional[BlockSpec]]],
+    grid: Optional[Union[Grid, int]],
+    ) -> pallas_core.KernelConfig:
+  if not maybe_kernel_config:
+    config = pallas_core.KernelConfig(in_specs=in_specs, out_specs=out_specs, grid=grid)
+  else:
+    config = maybe_kernel_config
+    grid = maybe_kernel_config.grid
+  grid, in_specs, out_specs = config.grid, config.in_specs, config.out_specs
+  grid = pallas_core.preprocess_grid(grid)
+  if in_specs is not None and not isinstance(in_specs, (tuple, list)):
+    in_specs = (in_specs,)
+  if out_specs is not None and not isinstance(out_specs, (tuple, list)):
+    out_specs = (out_specs,)
+  if in_specs is None:
+    in_specs = [None] * len(in_avals)
+  if out_specs is None:
+    out_specs = [None] * len(out_avals)
+  return config.replace(grid=grid, in_specs=in_specs, out_specs=out_specs)
+
+def pallas_call(f: Callable, out_shape: Any, *,
                 grid: Optional[Grid] = None,
+                config: Optional[pallas_core.KernelConfig] = None,
                 in_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
                 out_specs: Optional[Sequence[Optional[BlockSpec]]] = None,
                 input_output_aliases: Dict[int, int] = {},
                 interpret: bool = False,
                 name: Optional[str] = None,
-                autotuning_configs: Optional[list[Config]] = None,
+                autotuning_configs: Optional[Sequence[pallas_core.KernelConfig]] = None,
+                debug: bool = False,
                 **compiler_params: Any):
+  if config is not None:
+    if grid is not None or in_specs is not None or out_specs is not None:
+      raise ValueError("Cannot specify both `config` and any of `grid`, "
+                       "`in_specs`, or `out_specs`.")
+    if autotuning_configs is not None:
+      raise ValueError("Cannot specify both `config` and `autotuning_configs`")
+  if autotuning_configs is not None:
+    if grid is not None or in_specs is not None or out_specs is not None:
+      raise ValueError("Cannot specify both `autotuning_configs` and any of `grid`, "
+                       "`in_specs`, or `out_specs`.")
   singleton = False
   if not isinstance(out_shape, (tuple, list)):
     out_shape = (out_shape,)
     singleton = True
   if not isinstance(out_shape, tuple):
     out_shape = tuple(out_shape)
-  if in_specs is not None and not isinstance(in_specs, (tuple, list)):
-    in_specs = (in_specs,)
-  if out_specs is not None and not isinstance(out_specs, (tuple, list)):
-    out_specs = (out_specs,)
-
   if not name:
     name = f.__name__ if hasattr(f, "__name__") else "unnamed"
 
@@ -382,29 +408,32 @@ def wrapped(*args):
                            for a in flat_args)
     flat_out_avals = tuple(jax_core.ShapedArray(a.shape, a.dtype)
                            for a in flat_out_shapes)
+    canonicalized_configs = []
+    if autotuning_configs is None:
+      canonicalized_configs.append(_canonicalize_kernel_config(config,
+                                                               flat_in_avals,
+                                                               flat_out_avals,
+                                                               in_specs,
+                                                               out_specs,
+                                                               grid))
+    else:
+      canonicalized_configs.extend(map(partial(_canonicalize_kernel_config,
+                                               in_avals=flat_in_avals,
+                                               out_avals=flat_out_avals,
+                                               in_specs=in_specs,
+                                               out_specs=out_specs,
+                                               grid=grid),
+                                       autotuning_configs))
     kernels = []
-    flat_in_specs = in_specs
-    flat_out_specs = out_specs
-    if flat_in_specs is None:
-      flat_in_specs = [None] * len(flat_in_avals)
-    if flat_out_specs is None:
-      flat_out_specs = [None] * len(flat_out_avals)
     all_consts = []
-    if autotuning_configs is None:
+    if len(canonicalized_configs) == 0:
+      raise ValueError("Cannot pass in empty autotuning configs")
+    for canonicalized_config in canonicalized_configs:
       specialized_kernel, consts, jaxpr_out_tree = specialize_kernel(
-          Config({}, {}), f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals,
+          canonicalized_config, f, name, flat_in_avals,
          flat_out_avals, jaxpr_in_tree, compiler_params)
       kernels.append(specialized_kernel)
       all_consts.extend(consts)
-    else:
-      if len(autotuning_configs) == 0:
-        raise ValueError("Cannot pass in empty autotuning configs")
-      for config in autotuning_configs:
-        specialized_kernel, consts, jaxpr_out_tree = specialize_kernel(
-            config, f, grid, name, flat_in_specs, flat_out_specs, flat_in_avals, flat_out_avals,
-            jaxpr_in_tree, compiler_params)
-        kernels.append(specialized_kernel)
-        all_consts.extend(consts)
     if all_consts:
       raise NotImplementedError("Cannot handle consts.")
     del jaxpr_out_tree
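
The diff above reads several attributes off the new pallas_core.KernelConfig (config.grid, config.in_specs, config.out_specs, config.meta, config.compiler_params, and config.replace(...)). A minimal sketch of such a container, inferred only from those uses rather than taken from the actual pallas_core module, could look like this:

import dataclasses
from typing import Any, Optional, Sequence

@dataclasses.dataclass(frozen=True)
class KernelConfig:
  # Per-input / per-output block specs; None means "use the whole array".
  in_specs: Optional[Sequence[Any]] = None
  out_specs: Optional[Sequence[Any]] = None
  # Grid (a tuple, an int, or None) as passed before preprocess_grid.
  grid: Optional[Any] = None
  # Keyword substitutions applied to callable specs and to the traced kernel.
  meta: dict = dataclasses.field(default_factory=dict)
  # Extra compiler options merged into pallas_call's **compiler_params.
  compiler_params: dict = dataclasses.field(default_factory=dict)

  def replace(self, **updates) -> "KernelConfig":
    # Copy-with-overrides, as used by _canonicalize_kernel_config above.
    return dataclasses.replace(self, **updates)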
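With those pieces in place, the reworked pallas_call takes either a single config or a list of autotuning_configs instead of the old grid/in_specs/out_specs keywords. The snippet below is a hypothetical call site written against the signatures in this diff; it assumes the module's own pallas_call and pallas_core are importable, and the kernel body, shapes, and compiler_params values are made up for illustration:

import jax
import jax.numpy as jnp

def add_kernel(x_ref, y_ref, o_ref):
  # Element-wise add over whatever block each program instance sees.
  o_ref[...] = x_ref[...] + y_ref[...]

out_shape = jax.ShapeDtypeStruct((1024,), jnp.float32)

# Single configuration: pass `config=` instead of grid/in_specs/out_specs.
add = pallas_call(add_kernel, out_shape,
                  config=pallas_core.KernelConfig(grid=(), in_specs=None, out_specs=None))

# Autotuning: pass several candidate configs; specialize_kernel builds one
# SpecializedKernel per entry so the backend can choose among them.
add_tuned = pallas_call(add_kernel, out_shape,
                        autotuning_configs=[
                            pallas_core.KernelConfig(grid=(), in_specs=None, out_specs=None,
                                                     compiler_params=dict(num_warps=4)),
                            pallas_core.KernelConfig(grid=(), in_specs=None, out_specs=None,
                                                     compiler_params=dict(num_warps=8)),
                        ])

x = jnp.ones((1024,), jnp.float32)
y = jnp.ones((1024,), jnp.float32)
z = add(x, y)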