Commit 13fb978

fix: bugfix for matmul when use_fp32_acc
1 parent 33afb83 commit 13fb978

File tree

5 files changed: 61 additions, 17 deletions

py/torch_tensorrt/dynamo/conversion/impl/matmul.py
tests/py/dynamo/models/test_llm_models.py
tools/llm/run_llm.py
tools/llm/run_vlm.py
tools/llm/torchtrt_ext/register_sdpa.py

py/torch_tensorrt/dynamo/conversion/impl/matmul.py

Lines changed: 6 additions & 1 deletion
@@ -48,9 +48,13 @@ def matrix_multiply(
     input, other = broadcast(
         ctx, input, other, f"{name}_input", f"{name}_other", preset_diff
     )
+    # Get the original input dtype
+    input_dtype = _enums.dtype._from(input.dtype).to(torch.dtype)
+
     if (
         ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
         and ctx.compilation_settings.use_fp32_acc
+        and input_dtype == torch.float16
     ):
         input = cast_trt_tensor(ctx, input, torch.float32, f"{name}_input_casted")
         other = cast_trt_tensor(ctx, other, torch.float32, f"{name}_other_casted")
@@ -63,9 +67,10 @@ def matrix_multiply(
     if (
         ctx.net.get_flag(trt.NetworkDefinitionCreationFlag.STRONGLY_TYPED)
         and ctx.compilation_settings.use_fp32_acc
+        and input_dtype == torch.float16
     ):
         matmul_output = cast_trt_tensor(
-            ctx, matmul_output, torch.float16, f"{name}_output_casted"
+            ctx, matmul_output, input_dtype, f"{name}_output_casted"
         )

     set_layer_name(matmul_layer, target, name, source_ir)
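
The change limits the FP32-accumulation casts to the case the flag was written for: previously the casts were inserted whenever use_fp32_acc was set under strong typing, and the output was always cast back to torch.float16 regardless of the original dtype. A minimal, self-contained sketch of the casting pattern the fixed converter follows (illustrative only; matmul_with_fp32_acc is a hypothetical helper, not Torch-TensorRT code):

# Hypothetical helper, for illustration only: mirrors the dtype handling the
# fixed matrix_multiply converter applies when use_fp32_acc is enabled.
import torch

def matmul_with_fp32_acc(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    input_dtype = a.dtype
    if input_dtype == torch.float16:
        # FP16 inputs are promoted so the accumulation happens in FP32 ...
        out = torch.matmul(a.to(torch.float32), b.to(torch.float32))
        # ... and the result is cast back to the original dtype
        # (previously this was hard-coded to torch.float16).
        return out.to(input_dtype)
    # Any other dtype is multiplied as-is; no extra casts are inserted.
    return torch.matmul(a, b)

# FP16 inputs come back as FP16, FP32 inputs stay FP32:
assert matmul_with_fp32_acc(torch.ones(2, 2, dtype=torch.float16),
                            torch.ones(2, 2, dtype=torch.float16)).dtype == torch.float16
assert matmul_with_fp32_acc(torch.rand(2, 2), torch.rand(2, 2)).dtype == torch.float32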

tests/py/dynamo/models/test_llm_models.py

Lines changed: 1 addition & 4 deletions
@@ -44,10 +44,7 @@ def test_llm_decoder_layer(precision):
         .to("cuda")
     )

-    if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
-        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
-    else:
-        register_sdpa._SDPA_MAPPING["default"](model_config=model.config)
+    register_sdpa.enable_sdpa_converter(args.model, model.config)
     model = model.to(dtype)
     # use randint will generate nan values in the logits, use a fixed input_ids for now
     # input_ids = torch.randint(0, model.config.vocab_size, (1, args.num_tokens)).to("cuda")

tools/llm/run_llm.py

Lines changed: 1 addition & 4 deletions
@@ -59,10 +59,7 @@ def get_model(args):
         .cuda()
     )
     # register SDPA variant for the model
-    if register_sdpa._SDPA_MAPPING.get(args.model, None) is not None:
-        register_sdpa._SDPA_MAPPING[args.model](model_config=model.config)
-    else:
-        register_sdpa._SDPA_MAPPING["default"](model_config=model.config)
+    register_sdpa.enable_sdpa_converter(args.model, model.config)

     if args.precision == "FP16":
         model = model.to(torch.float16)

tools/llm/run_vlm.py

Lines changed: 2 additions & 0 deletions
@@ -589,6 +589,8 @@ def print_outputs(backend_name: str, gen_tokens: torch.Tensor, tokenizer):
     print("--- Registering SDPA lowering pass locally for LM compilation ---")
     from torchtrt_ext import register_sdpa

+    register_sdpa.enable_sdpa_converter(args.model, model.config)
+
     if args.cache == "static_v1":
         import static_cache_v1  # noqa: F401
     elif args.cache not in ("", None):

tools/llm/torchtrt_ext/register_sdpa.py

Lines changed: 51 additions & 8 deletions
@@ -19,16 +19,27 @@

 logger = logging.getLogger(__name__)

-# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
-# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it.
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None)
-TORCH_TRT_DECOMPOSITIONS.pop(
-    torch.ops.aten._scaled_dot_product_efficient_attention.default, None
-)
-TORCH_TRT_DECOMPOSITIONS.pop(
-    torch.ops.aten._scaled_dot_product_flash_attention.default, None
+_SDPA_OPS_TO_REMOVE = (
+    torch.ops.aten.scaled_dot_product_attention.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
 )

+
+def _remove_decompositions():
+    """
+    Remove decompositions for SDPA operators.
+
+    This function is idempotent. It ensures that the SDPA operators are removed
+    from the decomposition table, allowing a custom converter to be used.
+    """
+    # Check if any of the decompositions still exist before proceeding
+    if any(op in TORCH_TRT_DECOMPOSITIONS for op in _SDPA_OPS_TO_REMOVE):
+        logger.debug("Removing SDPA decompositions to enable custom converter.")
+        for op in _SDPA_OPS_TO_REMOVE:
+            TORCH_TRT_DECOMPOSITIONS.pop(op, None)
+
+
 REPLACEABLE_ATEN_OPS = {
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
@@ -271,3 +282,35 @@ def default_sdpa_pass(
     "google/gemma-3-1b-it": register_gemma3_sdpa_pass,
     "default": register_default_sdpa_pass,
 }
+
+
+def enable_sdpa_converter(model_name: str, model_config: Any) -> None:
+    """
+    Enables the custom SDPA converter for a given model.
+
+    This function performs two main actions:
+    1. Removes the default PyTorch SDPA decompositions from Torch-TensorRT's
+       lowering registry. This is necessary to prevent them from being used
+       instead of our custom converter.
+    2. Registers a model-specific or default lowering pass that replaces the
+       standard SDPA operators with a version optimized for TensorRT conversion.
+
+    Args:
+        model_name (str): The name of the model (e.g., from Hugging Face).
+        model_config (Any): The model's configuration object. This is used to
+                            extract parameters for model-specific optimizations,
+                            like sliding window attention.
+    """
+    _remove_decompositions()
+
+    pass_registrator = _SDPA_MAPPING.get(model_name)
+
+    if pass_registrator:
+        logger.info(f"Registering specific SDPA lowering pass for model: {model_name}")
+        pass_registrator(model_config=model_config)
+    else:
+        logger.info(
+            f"No specific SDPA lowering pass for model '{model_name}'. "
+            "Using default SDPA pass."
+        )
+        _SDPA_MAPPING["default"](model_config=model_config)
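
Call sites now go through this helper instead of indexing _SDPA_MAPPING directly, as the run_llm.py, run_vlm.py, and test changes above show. A minimal usage sketch, assuming torchtrt_ext is importable and the Hugging Face config for the chosen model is available locally (the scripts in this commit pass model.config from an already-loaded model instead):

# Usage sketch mirroring the updated call sites above. AutoConfig is used only
# to obtain a config object without loading model weights.
from transformers import AutoConfig
from torchtrt_ext import register_sdpa

model_name = "google/gemma-3-1b-it"  # has a dedicated pass in _SDPA_MAPPING
config = AutoConfig.from_pretrained(model_name)

# Removes the SDPA decompositions (idempotent) and registers the model-specific
# lowering pass; an unrecognized model name falls back to the "default" pass.
register_sdpa.enable_sdpa_converter(model_name, config)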

0 commit comments