diff --git a/stream_attention/core/config.py b/stream_attention/core/config.py
index 36d6f2a..a880840 100644
--- a/stream_attention/core/config.py
+++ b/stream_attention/core/config.py
@@ -6,7 +6,6 @@
 from dataclasses import dataclass, field
 from typing import Optional, Dict, Any, Tuple
-import yaml
 import os
@@ -60,6 +59,13 @@ class StreamAttentionConfig:
     @classmethod
     def from_yaml(cls, yaml_path: str) -> "StreamAttentionConfig":
         """Load configuration from YAML file"""
+        try:
+            import yaml  # Lazy import to avoid hard dependency during module import
+        except Exception as e:  # pragma: no cover - depends on env
+            raise ImportError(
+                "PyYAML is required for from_yaml(). Install with `pip install pyyaml`."
+            ) from e
+
         with open(yaml_path, "r") as f:
             config_dict = yaml.safe_load(f)
         return cls(**config_dict)
diff --git a/stream_attention/core/flashattention_v3.py b/stream_attention/core/flashattention_v3.py
index dacbe3b..1119572 100644
--- a/stream_attention/core/flashattention_v3.py
+++ b/stream_attention/core/flashattention_v3.py
@@ -80,17 +80,43 @@ def forward(
         if _use_flash_sdpa() and q.device.type == "cuda":
             try:
                 # Prefer the newer torch.nn.attention API when available
-                sdpa_ctx = torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION)
+                sdpa_ctx = torch.nn.attention.sdpa_kernel(
+                    SDPBackend.FLASH_ATTENTION
+                )
             except (AttributeError, TypeError):
+                # Fallback for older PyTorch releases
                 try:
-                    # Older PyTorch versions expose the context in torch.backends
                     sdpa_ctx = torch.backends.cuda.sdp_kernel(
-                        enable_math=False, enable_flash=True, enable_mem_efficient=False
+                        enable_math=False,
+                        enable_flash=True,
+                        enable_mem_efficient=False,
+                    )
+                except Exception as e:  # pragma: no cover - depends on env
+                    # Gracefully degrade to default kernel selection when the
+                    # CUDA SDPA context manager is unavailable or unsupported.
+                    logger.debug(
+                        "torch.backends.cuda.sdp_kernel unavailable or unsupported: %s",
+                        e,
                     )
-                except (AttributeError, RuntimeError):
                     sdpa_ctx = nullcontext()

-            with sdpa_ctx:
+            try:
+                with sdpa_ctx:
+                    out = F.scaled_dot_product_attention(
+                        q,
+                        k,
+                        v,
+                        attn_mask,
+                        dropout_p=self.dropout if self.training else 0.0,
+                        is_causal=causal,
+                    )
+            except RuntimeError as e:  # pragma: no cover - device/kernel dependent
+                # If a forced-flash configuration leads to "no available kernel",
+                # retry without any forced backend so PyTorch can choose a valid one.
+                logger.debug(
+                    "FlashAttention SDPA failed under forced settings, retrying default: %s",
+                    e,
+                )
                 out = F.scaled_dot_product_attention(
                     q,
                     k,
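
For reviewers who want to exercise the new kernel-selection path outside the repository, below is a minimal, self-contained sketch of the same pattern: request the flash SDPA backend through whichever API the installed PyTorch provides, fall back to a no-op context if neither is usable, and retry with default backend selection when the forced configuration has no available kernel. The helper name flash_sdpa_context and the tensor shapes are illustrative only and do not appear in the patch.

# Standalone sketch of the fallback pattern applied inside forward() above.
from contextlib import nullcontext

import torch
import torch.nn.functional as F


def flash_sdpa_context():
    """Best-effort context manager that asks for the flash SDPA backend."""
    try:
        # Newer PyTorch exposes the backend selector in torch.nn.attention.
        from torch.nn.attention import SDPBackend, sdpa_kernel
        return sdpa_kernel(SDPBackend.FLASH_ATTENTION)
    except (ImportError, AttributeError, TypeError):
        pass
    try:
        # Older releases only have the (now deprecated) CUDA context manager.
        return torch.backends.cuda.sdp_kernel(
            enable_math=False, enable_flash=True, enable_mem_efficient=False
        )
    except Exception:
        # No way to force a backend; let PyTorch pick one.
        return nullcontext()


device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32
q = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)
k = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)
v = torch.randn(1, 8, 128, 64, device=device, dtype=dtype)

try:
    with flash_sdpa_context():
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
except RuntimeError:
    # Forcing the flash backend can raise "No available kernel"; retrying
    # without a forced backend mirrors the new fallback in forward().
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(out.shape)  # torch.Size([1, 8, 128, 64])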