dmarro89 changed the title from "Unsupported Type in Metal PJRT Plugin with rng_bit_generator" to "Unsupported type in metal PJRT plugin with rng_bit_generator" on Nov 12, 2024.
Description
Hi all,
When I execute an HLO program with the Metal PJRT plugin, it fails because the rng_bit_generator operation returns an unsupported data type.
Specifically, the generated HLO includes:

```
%output_state, %output = "mhlo.rng_bit_generator"(%1) <{rng_algorithm = #mhlo.rng_algorithm<PHILOX>}> : (tensor<3xi64>) -> (tensor<3xi64>, tensor<3xui32>)
```
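For context, the `tensor<3xui32>` result holds raw random bits, and later ops usually turn those bits into floats. Below is a minimal stdlib sketch of one common conversion, the 23-bit mantissa trick. `bits_to_unit_float` is a hypothetical helper. It is my understanding, not something confirmed in this issue, that JAX's float32 uniform sampler uses the same construction.

```python
import struct

def bits_to_unit_float(bits: int) -> float:
    """Map 32 random bits to a float32 value in [0, 1) via the mantissa trick (illustrative only)."""
    if not 0 <= bits <= 0xFFFFFFFF:
        raise ValueError("expected a uint32 value")
    mantissa = bits >> 9  # keep the top 23 bits as the float32 mantissa
    one_bits = struct.unpack("<I", struct.pack("<f", 1.0))[0]  # bit pattern of 1.0f (0x3F800000)
    # Exponent of 1.0 plus random mantissa gives a float32 in [1, 2)
    value = struct.unpack("<f", struct.pack("<I", one_bits | mantissa))[0]
    return value - 1.0  # shift into [0, 1)

print(bits_to_unit_float(0x80000000))
```

Because every step operates on the 32-bit pattern, the unsigned type matters mainly for the logical right shift. That is one reason a signed reinterpretation could still work if Metal cannot hold ui32 tensors.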
The error message indicates that:
Metal only supports MPSDataTypeFloat16, MPSDataTypeBFloat16, MPSDataTypeFloat32, MPSDataTypeInt32, and MPSDataTypeInt64.
The ui32 result (tensor<3xui32>) appears to be the incompatible part: unsigned 32-bit integers are not in Metal's list of supported types.
I'm trying to understand whether the ui32 output is the problem or whether I'm using rng_bit_generator incorrectly.
Could you clarify whether there is a workaround or planned support for ui32 output in this context? Alternatively, any guidance on configuring rng_bit_generator to produce one of Metal's supported types would be greatly appreciated.
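One possible direction is to carry the random bits as int32, which Metal does list. This has not been verified against the Metal plugin. Reinterpreting a uint32 bit pattern as int32 is lossless under two's complement. In JAX the corresponding op would be `jax.lax.bitcast_convert_type`, but the stdlib sketch below only demonstrates the bit-level round trip. `u32_to_i32` and `i32_to_u32` are hypothetical helper names.

```python
import struct

def u32_to_i32(x: int) -> int:
    """Reinterpret the 32 bits of an unsigned value as a signed int32 (two's complement)."""
    if not 0 <= x <= 0xFFFFFFFF:
        raise ValueError("value out of uint32 range")
    return struct.unpack("<i", struct.pack("<I", x))[0]

def i32_to_u32(x: int) -> int:
    """Inverse: reinterpret the 32 bits of a signed int32 as uint32."""
    if not -0x80000000 <= x <= 0x7FFFFFFF:
        raise ValueError("value out of int32 range")
    return struct.unpack("<I", struct.pack("<i", x))[0]

samples = [0, 1, 0x7FFFFFFF, 0x80000000, 0xFFFFFFFF]
for s in samples:
    signed = u32_to_i32(s)
    assert i32_to_u32(signed) == s  # the round trip preserves every bit
    print(f"{s:#010x} -> {signed}")
```

Only the interpretation changes, not the bits, so values at or above 0x80000000 show up as negative int32 numbers. Any downstream shift would need to be a logical (not arithmetic) shift to reproduce the unsigned behavior.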
Thanks
System info (python version, jaxlib version, accelerator, etc.)