-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Allow present_key to be empty when past_key is provided in Attention #26303
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Pull Request Overview
This PR relaxes the validation logic for past/present key tensor requirements in the Attention operator. Previously, the code enforced that past_key and present_key must both be null or both non-null. The new logic allows present_key to be empty when past_key is provided, accommodating ONNX models where dead code elimination may remove unused outputs.
Key changes:
- Replaced strict bidirectional enforcement with unidirectional validation
- Removed implementation limitation comments that are no longer applicable
Tip: Customize your code reviews with copilot-instructions.md. Create the file or learn how to get started.
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
f03d6d1 to
a648379
Compare
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
Signed-off-by: Justin Chu <[email protected]>
|
For some reason present key and present values are not the same as k and v when there is no past: 2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 1.2701495885848999, which exceeds tolerance, where
2: cur_expected[i] evaluates to 2.0174980163574219,
2: cur_actual[i] evaluates to 0.74734842777252197, and
2: tolerance evaluates to 0.00023174978559836745.
2: i:0
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 1.2701497077941895, which exceeds tolerance, where
2: cur_expected[i] evaluates to 3.0174980163574219,
2: cur_actual[i] evaluates to 1.7473483085632324, and
2: tolerance evaluates to 0.00033174979034811258.
2: i:1
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 0.73147225379943848, which exceeds tolerance, where
2: cur_expected[i] evaluates to 2,
2: cur_actual[i] evaluates to 2.7314722537994385, and
2: tolerance evaluates to 0.00022999999055173248.
2: i:2
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 0.73147225379943848, which exceeds tolerance, where
2: cur_expected[i] evaluates to 3,
2: cur_actual[i] evaluates to 3.7314722537994385, and
2: tolerance evaluates to 0.00032999998074956238.
2: i:3
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 0.73406648635864258, which exceeds tolerance, where
2: cur_expected[i] evaluates to 1.4550299644470215,
2: cur_actual[i] evaluates to 0.72096347808837891, and
2: tolerance evaluates to 0.00017550299526192248.
2: i:4
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 0.73406648635864258, which exceeds tolerance, where
2: cur_expected[i] evaluates to 2.4550299644470215,
2: cur_actual[i] evaluates to 1.7209634780883789, and
2: tolerance evaluates to 0.00027550300001166761.
2: i:5
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 1.3152304887771606, which exceeds tolerance, where
2: cur_expected[i] evaluates to 1.4400719404220581,
2: cur_actual[i] evaluates to 2.7553024291992188, and
2: tolerance evaluates to 0.00017400718934368342.
2: i:6
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 1.3152303695678711, which exceeds tolerance, where
2: cur_expected[i] evaluates to 2.4400720596313477,
2: cur_actual[i] evaluates to 3.7553024291992188, and
2: tolerance evaluates to 0.00027400720864534378.
2: i:7
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 0.40000000596046448, which exceeds tolerance, where
2: cur_expected[i] evaluates to 0.40000000596046448,
2: cur_actual[i] evaluates to 0.80000001192092896, and
2: tolerance evaluates to 4.9999998736893758e-05.
2: i:2
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 0.30000001192092896, which exceeds tolerance, where
2: cur_expected[i] evaluates to 0.60000002384185791,
2: cur_actual[i] evaluates to 0.30000001192092896, and
2: tolerance evaluates to 7.0000001869630069e-05.
2: i:3
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 0.40000000596046448, which exceeds tolerance, where
2: cur_expected[i] evaluates to 0.80000001192092896,
2: cur_actual[i] evaluates to 0.40000000596046448, and
2: tolerance evaluates to 9.0000001364387572e-05.
2: i:4
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 0.30000001192092896, which exceeds tolerance, where
2: cur_expected[i] evaluates to 0.30000001192092896,
2: cur_actual[i] evaluates to 0.60000002384185791, and
2: tolerance evaluates to 3.9999998989515007e-05.
2: i:5
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 2.5, which exceeds tolerance, where
2: cur_expected[i] evaluates to 3,
2: cur_actual[i] evaluates to 0.5, and
2: tolerance evaluates to 0.00030999997397884727.
2: i:2
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 2.5, which exceeds tolerance, where
2: cur_expected[i] evaluates to 4,
2: cur_actual[i] evaluates to 1.5, and
2: tolerance evaluates to 0.0004099999787285924.
2: i:3
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 2.5, which exceeds tolerance, where
2: cur_expected[i] evaluates to 0.5,
2: cur_actual[i] evaluates to 3, and
2: tolerance evaluates to 5.999999848427251e-05.
2: i:4
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider
2:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(432): error: The difference between cur_expected[i] and cur_actual[i] is 2.5, which exceeds tolerance, where
2: cur_expected[i] evaluates to 1.5,
2: cur_actual[i] evaluates to 4, and
2: tolerance evaluates to 0.00015999999595806003.
2: i:5
2: Google Test trace:
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\checkers.cc(610): provider type: CPUExecutionProvider
2: E:\_work\onnxruntime\onnxruntime\onnxruntime\test\unittest_util\base_tester.cc(877): registered execution providers: CPUExecutionProvider |
|
Signed-off-by: Justin Chu <[email protected]>
d859a63 to
46f44e0
Compare
Signed-off-by: Justin Chu <[email protected]>
|
Sorry, I wasn't following this. Curious if this is going to change any assumption we have with MultiHeadAttention? I am using MultiHeadAttention as CUDA kernel for Attention. I wonder whether we enforced this because of contrib op cuda implementation? cc @tianleiwu |
|
Shouldn't be affecting as this is the cpu implementation. The cuda kernel can have its own constraints, but in general this is already supported as far as I understand. We just had redundant checks. |
The original check enforces both the present_key and the past_key must be present. But with IO-binding there may be an issue: The past_key can be nullptr even when present_key is allocated. In reality, the kernel should just do the computation when it has the data, or when the output is requested.