8 changes: 7 additions & 1 deletion stream_attention/core/config.py
@@ -6,7 +6,6 @@

from dataclasses import dataclass, field
from typing import Optional, Dict, Any, Tuple
import yaml
import os


@@ -60,6 +59,13 @@ class StreamAttentionConfig:
    @classmethod
    def from_yaml(cls, yaml_path: str) -> "StreamAttentionConfig":
        """Load configuration from YAML file"""
        try:
            import yaml  # Lazy import to avoid hard dependency during module import
        except Exception as e:  # pragma: no cover - depends on env
            raise ImportError(
                "PyYAML is required for from_yaml(). Install with `pip install pyyaml`."
            ) from e

        with open(yaml_path, "r") as f:
            config_dict = yaml.safe_load(f)
        return cls(**config_dict)
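For context, a minimal usage sketch of the lazy-import behavior introduced above; the config path below is hypothetical and the YAML keys must match whatever fields StreamAttentionConfig actually defines:

# Sketch only: importing the module no longer requires PyYAML; the
# ImportError now surfaces when from_yaml() is called without it installed.
from stream_attention.core.config import StreamAttentionConfig

try:
    cfg = StreamAttentionConfig.from_yaml("configs/example.yaml")  # hypothetical path
except ImportError as exc:
    print(f"PyYAML missing: {exc}")  # raised here, not at module import time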
36 changes: 31 additions & 5 deletions stream_attention/core/flashattention_v3.py
@@ -80,17 +80,43 @@ def forward(
        if _use_flash_sdpa() and q.device.type == "cuda":
            try:
                # Prefer the newer torch.nn.attention API when available
                sdpa_ctx = torch.nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION)
                sdpa_ctx = torch.nn.attention.sdpa_kernel(
                    SDPBackend.FLASH_ATTENTION
                )
            except (AttributeError, TypeError):
                # Fallback for older PyTorch releases
                try:
                    # Older PyTorch versions expose the context in torch.backends
                    sdpa_ctx = torch.backends.cuda.sdp_kernel(
                        enable_math=False, enable_flash=True, enable_mem_efficient=False
                        enable_math=False,
                        enable_flash=True,
                        enable_mem_efficient=False,
                    )
                except Exception as e:  # pragma: no cover - depends on env
                    # Gracefully degrade to default kernel selection when the
                    # CUDA SDPA context manager is unavailable or unsupported.
                    logger.debug(
                        "torch.backends.cuda.sdp_kernel unavailable or unsupported: %s",
                        e,
                    )
            except (AttributeError, RuntimeError):
                sdpa_ctx = nullcontext()

        with sdpa_ctx:
        try:
            with sdpa_ctx:
                out = F.scaled_dot_product_attention(
                    q,
                    k,
                    v,
                    attn_mask,
                    dropout_p=self.dropout if self.training else 0.0,
                    is_causal=causal,
                )
        except RuntimeError as e:  # pragma: no cover - device/kernel dependent
            # If a forced-flash configuration leads to "no available kernel",
            # retry without any forced backend so PyTorch can choose a valid one.
            logger.debug(
                "FlashAttention SDPA failed under forced settings, retrying default: %s",
                e,
            )
            out = F.scaled_dot_product_attention(
                q,
                k,
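For reference, a self-contained sketch of the same backend-selection-with-fallback pattern this diff implements, written against public PyTorch APIs; the tensor shapes, dtype, and causal flag are arbitrary choices, not taken from this repository:

import torch
import torch.nn.functional as F
from contextlib import nullcontext

# Arbitrary example tensors (batch, heads, seq_len, head_dim); assumes a CUDA device.
q = k = v = torch.randn(1, 8, 128, 64, device="cuda", dtype=torch.float16)

try:
    # Newer PyTorch exposes the kernel selector under torch.nn.attention
    from torch.nn.attention import sdpa_kernel, SDPBackend
    ctx = sdpa_kernel(SDPBackend.FLASH_ATTENTION)
except (ImportError, AttributeError):
    try:
        # Older releases expose an equivalent switch under torch.backends.cuda
        ctx = torch.backends.cuda.sdp_kernel(
            enable_math=False, enable_flash=True, enable_mem_efficient=False
        )
    except Exception:
        ctx = nullcontext()  # let PyTorch pick whichever kernel is available

try:
    with ctx:
        out = F.scaled_dot_product_attention(q, k, v, is_causal=True)
except RuntimeError:
    # Forcing flash can raise "no available kernel" on some devices/dtypes;
    # retry with PyTorch's default backend selection.
    out = F.scaled_dot_product_attention(q, k, v, is_causal=True)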