Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pattern match dot algorithm spec to preset name #24820

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

dfm
Copy link
Collaborator

@dfm dfm commented Nov 9, 2024

Since some dot algorithm presets have special cased input and output storage type behavior (e.g. all the BF16 algorithms), and since JAX's handling of these cases is (for better or worse) handled at the "preset" level, this PR provides a small quality of life improvement to convert known lax.DotAlgorithm specs to explicit lax.DotAlgorithmPreset members whenever possible. For example, if a user specifies:

precision = lax.DotAlgorithm(dtypes.bfloat16, dtypes.bfloat16, np.float32,
                             num_primitive_operations=6)

this will be canonicalized to lax.DotAlgorithmPreset.BF16_BF16_F32_X6 and the input and output casting will be handled properly.

@dfm dfm self-assigned this Nov 9, 2024
@dfm dfm added the pull ready Ready for copybara import and testing label Nov 9, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
pull ready Ready for copybara import and testing
Projects
None yet
Development

Successfully merging this pull request may close these issues.

1 participant