@@ -19,16 +19,27 @@
 
 logger = logging.getLogger(__name__)
 
-# Remove decompositions for aten.scaled_dot_product_attention, aten._scaled_dot_product_efficient_attention, aten._scaled_dot_product_flash_attention
-# This is because we want to have SDPA as a standalone operator in the graph and invoke the custom converter for it.
-TORCH_TRT_DECOMPOSITIONS.pop(torch.ops.aten.scaled_dot_product_attention.default, None)
-TORCH_TRT_DECOMPOSITIONS.pop(
-    torch.ops.aten._scaled_dot_product_efficient_attention.default, None
-)
-TORCH_TRT_DECOMPOSITIONS.pop(
-    torch.ops.aten._scaled_dot_product_flash_attention.default, None
+_SDPA_OPS_TO_REMOVE = (
+    torch.ops.aten.scaled_dot_product_attention.default,
+    torch.ops.aten._scaled_dot_product_efficient_attention.default,
+    torch.ops.aten._scaled_dot_product_flash_attention.default,
 )
 
+
+def _remove_decompositions():
+    """
+    Remove decompositions for SDPA operators.
+
+    This function is idempotent. It ensures that the SDPA operators are removed
+    from the decomposition table, allowing a custom converter to be used.
+    """
+    # Check if any of the decompositions still exist before proceeding
+    if any(op in TORCH_TRT_DECOMPOSITIONS for op in _SDPA_OPS_TO_REMOVE):
+        logger.debug("Removing SDPA decompositions to enable custom converter.")
+        for op in _SDPA_OPS_TO_REMOVE:
+            TORCH_TRT_DECOMPOSITIONS.pop(op, None)
+
+
 REPLACEABLE_ATEN_OPS = {
     torch.ops.aten._scaled_dot_product_efficient_attention.default,
     torch.ops.aten._scaled_dot_product_flash_attention.default,
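Reviewer note: the guard-then-pop above is just an idempotent removal from a mapping. A minimal, self-contained sketch of the same pattern, with a plain dict standing in for the internal `TORCH_TRT_DECOMPOSITIONS` table and placeholder string keys standing in for the aten op overloads:

```python
# Sketch of the idempotent removal in _remove_decompositions().
# "table" stands in for TORCH_TRT_DECOMPOSITIONS; the keys are placeholders.
table = {"sdpa": object(), "flash": object()}
OPS_TO_REMOVE = ("sdpa", "flash", "efficient")  # "efficient" is already absent

def remove_ops(table: dict, ops: tuple) -> None:
    # Act only if at least one op is still registered; pop with a default
    # so missing keys are a no-op, which makes repeated calls safe.
    if any(op in table for op in ops):
        for op in ops:
            table.pop(op, None)

remove_ops(table, OPS_TO_REMOVE)
remove_ops(table, OPS_TO_REMOVE)  # second call removes nothing and raises nothing
assert not any(op in table for op in OPS_TO_REMOVE)
```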
@@ -271,3 +282,35 @@ def default_sdpa_pass(
     "google/gemma-3-1b-it": register_gemma3_sdpa_pass,
     "default": register_default_sdpa_pass,
 }
+
+
+def enable_sdpa_converter(model_name: str, model_config: Any) -> None:
+    """
+    Enable the custom SDPA converter for a given model.
+
+    This function performs two main actions:
+    1. Removes the default PyTorch SDPA decompositions from Torch-TensorRT's
+       lowering registry, so that they are not applied in place of the
+       custom converter.
+    2. Registers a model-specific or default lowering pass that replaces the
+       standard SDPA operators with a version optimized for TensorRT conversion.
+
+    Args:
+        model_name (str): The name of the model (e.g., from Hugging Face).
+        model_config (Any): The model's configuration object, used to
+            extract parameters for model-specific optimizations such as
+            sliding window attention.
+    """
+    _remove_decompositions()
+
+    pass_registrator = _SDPA_MAPPING.get(model_name)
+
+    if pass_registrator:
+        logger.info(f"Registering specific SDPA lowering pass for model: {model_name}")
+        pass_registrator(model_config=model_config)
+    else:
+        logger.info(
+            f"No model-specific SDPA lowering pass for '{model_name}'; "
+            "using the default SDPA pass."
+        )
+        _SDPA_MAPPING["default"](model_config=model_config)
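A hedged usage sketch of the new entry point, for reviewers: the `AutoConfig` import and the concrete model name below are illustrative assumptions, not part of this diff.

```python
# Hypothetical caller-side usage of enable_sdpa_converter (illustrative only).
from transformers import AutoConfig  # assumed HF-style config source

model_name = "google/gemma-3-1b-it"  # has a dedicated pass in _SDPA_MAPPING
model_config = AutoConfig.from_pretrained(model_name)

# Idempotently drops the SDPA decompositions, then registers the gemma3
# lowering pass; any name not in _SDPA_MAPPING falls back to "default".
enable_sdpa_converter(model_name, model_config)
```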