diff --git a/physicsnemo/models/transolver/Embedding.py b/physicsnemo/models/transolver/Embedding.py index d376606590..ebcfa494f0 100644 --- a/physicsnemo/models/transolver/Embedding.py +++ b/physicsnemo/models/transolver/Embedding.py @@ -1,9 +1,14 @@ # ignore_header_test # ruff: noqa: E402 -"""""" -""" -Transolver model. This code was modified from, https://github.com/thuml/Transolver +r""" +Transolver model embedding utilities. + +This module provides positional encoding and embedding utilities for the +Transolver model, including rotary position embeddings (RoPE), sinusoidal +positional encodings, and timestep embeddings for diffusion-style models. + +This code was modified from https://github.com/thuml/Transolver The following license is provided from their source, @@ -35,89 +40,304 @@ import torch import torch.nn as nn from einops import rearrange +from jaxtyping import Float class RotaryEmbedding(nn.Module): - "ROPE: Rotary Position Embedding" + r""" + Rotary Position Embedding (RoPE). + + Implements rotary position embeddings that encode positional information + by rotating query and key vectors in attention mechanisms. + + For more details, see: `RoFormer paper `_ + + Parameters + ---------- + dim : int + Embedding dimension (must be even). + min_freq : float, optional, default=0.5 + Minimum frequency for the sinusoidal embeddings. + scale : float, optional, default=1.0 + Scaling factor for the input coordinates. - def __init__(self, dim, min_freq=1 / 2, scale=1.0): + Forward + ------- + coordinates : torch.Tensor + Coordinate values of shape :math:`(B, N)` where :math:`B` is batch size + and :math:`N` is sequence length. + device : torch.device + Device to place the output tensor on. + + Outputs + ------- + torch.Tensor + Rotary frequencies of shape :math:`(B, N, D)` where :math:`D` is the + embedding dimension. + + Examples + -------- + >>> import torch + >>> rope = RotaryEmbedding(dim=64) + >>> coords = torch.linspace(0, 1, 100).unsqueeze(0) # (1, 100) + >>> freqs = rope(coords, device="cpu") + >>> freqs.shape + torch.Size([1, 100, 64]) + """ + + def __init__(self, dim: int, min_freq: float = 1 / 2, scale: float = 1.0): super().__init__() + # Compute inverse frequencies for sinusoidal embeddings inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) self.min_freq = min_freq self.scale = scale self.register_buffer("inv_freq", inv_freq) - def forward(self, coordinates, device): - # coordinates [b, n] - t = coordinates.to(device).type_as(self.inv_freq) + def forward( + self, + coordinates: Float[torch.Tensor, "batch seq"], + device: torch.device, + ) -> Float[torch.Tensor, "batch seq dim"]: + r""" + Compute rotary frequencies for given coordinates. + + Parameters + ---------- + coordinates : torch.Tensor + Coordinate values of shape :math:`(B, N)`. + device : torch.device + Target device for output tensor. + + Returns + ------- + torch.Tensor + Rotary frequencies of shape :math:`(B, N, D)`. + """ + # Scale coordinates + inv_freq: torch.Tensor = self.inv_freq # type: ignore[assignment] + t = coordinates.to(device).type_as(inv_freq) t = t * (self.scale / self.min_freq) - freqs = torch.einsum("... i , j -> ... i j", t, self.inv_freq) # [b, n, d//2] - return torch.cat((freqs, freqs), dim=-1) # [b, n, d] + # Compute frequencies via outer product + freqs = torch.einsum("... i , j -> ... 
i j", t, self.inv_freq) # (B, N, D//2) -def rotate_half(x): + # Concatenate to get full dimension + return torch.cat((freqs, freqs), dim=-1) # (B, N, D) + + +def rotate_half(x: Float[torch.Tensor, "... dim"]) -> Float[torch.Tensor, "... dim"]: + r""" + Rotate the last dimension by splitting and swapping halves. + + This is a helper function for applying rotary position embeddings. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(*, D)` where the last dimension will be + split and rotated. + + Returns + ------- + torch.Tensor + Rotated tensor of the same shape as input. + """ + # Split into two halves and swap with negation x = rearrange(x, "... (j d) -> ... j d", j=2) x1, x2 = x.unbind(dim=-2) return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(t, freqs): +def apply_rotary_pos_emb( + t: Float[torch.Tensor, "... dim"], + freqs: Float[torch.Tensor, "... dim"], +) -> Float[torch.Tensor, "... dim"]: + r""" + Apply rotary position embeddings to input tensor. + + Parameters + ---------- + t : torch.Tensor + Input tensor of shape :math:`(*, D)`. + freqs : torch.Tensor + Rotary frequencies of shape :math:`(*, D)` from + :class:`RotaryEmbedding`. + + Returns + ------- + torch.Tensor + Tensor with rotary embeddings applied, same shape as input. + """ return (t * freqs.cos()) + (rotate_half(t) * freqs.sin()) -def apply_2d_rotary_pos_emb(t, freqs_x, freqs_y): - # split t into first half and second half - # t: [b, h, n, d] - # freq_x/y: [b, n, d] +def apply_2d_rotary_pos_emb( + t: Float[torch.Tensor, "batch heads seq dim"], + freqs_x: Float[torch.Tensor, "batch seq dim_half"], + freqs_y: Float[torch.Tensor, "batch seq dim_half"], +) -> Float[torch.Tensor, "batch heads seq dim"]: + r""" + Apply 2D rotary position embeddings for spatial data. + + Splits the input tensor in half along the feature dimension and applies + separate rotary embeddings for x and y coordinates. + + Parameters + ---------- + t : torch.Tensor + Input tensor of shape :math:`(B, H, N, D)` where :math:`H` is heads. + freqs_x : torch.Tensor + X-coordinate frequencies of shape :math:`(B, N, D/2)`. + freqs_y : torch.Tensor + Y-coordinate frequencies of shape :math:`(B, N, D/2)`. + + Returns + ------- + torch.Tensor + Tensor with 2D rotary embeddings applied, shape :math:`(B, H, N, D)`. + """ d = t.shape[-1] + + # Split input into x and y halves t_x, t_y = t[..., : d // 2], t[..., d // 2 :] + # Apply rotary embeddings separately and concatenate return torch.cat( (apply_rotary_pos_emb(t_x, freqs_x), apply_rotary_pos_emb(t_y, freqs_y)), dim=-1 ) class PositionalEncoding(nn.Module): - "Implement the PE function." + r""" + Sinusoidal positional encoding. + + Implements fixed sinusoidal positional encodings as described in + `Attention Is All You Need `_. + + Parameters + ---------- + d_model : int + Model dimension (embedding size). + dropout : float + Dropout probability applied after adding positional encoding. + max_len : int, optional, default=177241 + Maximum sequence length supported (default is 421*421). + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, D)` where :math:`N` is sequence + length and :math:`D` is ``d_model``. - def __init__(self, d_model, dropout, max_len=421 * 421): + Outputs + ------- + torch.Tensor + Input with positional encoding added, shape :math:`(B, N, D)`. 
+ + Examples + -------- + >>> import torch + >>> pe = PositionalEncoding(d_model=128, dropout=0.1) + >>> x = torch.randn(2, 100, 128) + >>> out = pe(x) + >>> out.shape + torch.Size([2, 100, 128]) + """ + + def __init__(self, d_model: int, dropout: float, max_len: int = 421 * 421): super(PositionalEncoding, self).__init__() self.dropout = nn.Dropout(p=dropout) - # Compute the positional encodings once in log space. + # Compute positional encodings in log space for numerical stability pe = torch.zeros(max_len, d_model) position = torch.arange(0, max_len).unsqueeze(1) div_term = torch.exp( torch.arange(0, d_model, 2) * -(math.log(10000.0) / d_model) ) + + # Apply sin to even indices, cos to odd indices pe[:, 0::2] = torch.sin(position * div_term) pe[:, 1::2] = torch.cos(position * div_term) + + # Add batch dimension and register as buffer pe = pe.unsqueeze(0) self.register_buffer("pe", pe) - def forward(self, x): - x = x + self.pe[:, : x.size(1)].requires_grad_(False) + def forward( + self, x: Float[torch.Tensor, "batch seq dim"] + ) -> Float[torch.Tensor, "batch seq dim"]: + r""" + Add positional encoding to input. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, D)`. + + Returns + ------- + torch.Tensor + Input with positional encoding added, shape :math:`(B, N, D)`. + """ + # Add positional encoding (no gradient needed for PE) + pe: torch.Tensor = self.pe # type: ignore[assignment] + x = x + pe[:, : x.size(1)].requires_grad_(False) return self.dropout(x) -def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): - """ +def timestep_embedding( + timesteps: Float[torch.Tensor, " batch"], + dim: int, + max_period: int = 10000, + repeat_only: bool = False, +) -> Float[torch.Tensor, "batch dim"]: + r""" Create sinusoidal timestep embeddings. - :param timesteps: a 1-D Tensor of N indices, one per batch element. - These may be fractional. - :param dim: the dimension of the output. - :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ + Generates embeddings for diffusion model timesteps using sinusoidal + functions, similar to transformer positional encodings. + + Parameters + ---------- + timesteps : torch.Tensor + 1-D tensor of :math:`N` timestep indices, one per batch element. + These may be fractional values. + dim : int + Dimension of the output embeddings. + max_period : int, optional, default=10000 + Controls the minimum frequency of the embeddings. + repeat_only : bool, optional, default=False + Currently unused, kept for API compatibility. + + Returns + ------- + torch.Tensor + Positional embeddings of shape :math:`(N, D)` where :math:`D` is + ``dim``. 
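+
+    Notes
+    -----
+    A brief sketch of the formula implemented below: with
+    :math:`h = \lfloor D/2 \rfloor` and frequencies
+    :math:`f_k = \exp(-k \ln(\text{max\_period}) / h)`, the embedding of a
+    timestep :math:`t` is
+    :math:`[\cos(t f_0), \ldots, \cos(t f_{h-1}), \sin(t f_0), \ldots, \sin(t f_{h-1})]`,
+    zero-padded on the right when :math:`D` is odd.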
+ + Examples + -------- + >>> import torch + >>> timesteps = torch.tensor([0.0, 0.5, 1.0]) + >>> emb = timestep_embedding(timesteps, dim=64) + >>> emb.shape + torch.Size([3, 64]) + """ half = dim // 2 + + # Compute frequencies freqs = torch.exp( -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half ).to(device=timesteps.device) + + # Compute embeddings via outer product args = timesteps[:, None].float() * freqs[None] embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + + # Handle odd dimensions by padding with zeros if dim % 2: embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding diff --git a/physicsnemo/models/transolver/Physics_Attention.py b/physicsnemo/models/transolver/Physics_Attention.py index 72ccaeee7a..d5c7eb93b3 100644 --- a/physicsnemo/models/transolver/Physics_Attention.py +++ b/physicsnemo/models/transolver/Physics_Attention.py @@ -1,9 +1,14 @@ # ignore_header_test # ruff: noqa: E402 -"""""" -""" -Transolver model. This code was modified from, https://github.com/thuml/Transolver +r""" +Physics attention modules for the Transolver model. + +This module provides physics-informed attention mechanisms that project inputs +onto learned physics slices before applying attention. These attention variants +support irregular meshes, 2D structured grids, and 3D volumetric data. + +This code was modified from https://github.com/thuml/Transolver The following license is provided from their source, @@ -35,8 +40,13 @@ import torch import torch.nn as nn +from einops import rearrange +from jaxtyping import Float +from torch.autograd.profiler import record_function +from torch.distributed.tensor.placement_types import Replicate from physicsnemo.core.version_check import check_version_spec +from physicsnemo.domain_parallel import ShardTensor TE_AVAILABLE = check_version_spec("transformer_engine", hard_fail=False) @@ -45,52 +55,97 @@ else: te = None -from einops import rearrange -from torch.autograd.profiler import record_function -from torch.distributed.tensor.placement_types import Replicate - -from physicsnemo.domain_parallel import ShardTensor +def gumbel_softmax( + logits: Float[torch.Tensor, "... num_categories"], + tau: torch.Tensor | float = 1.0, +) -> Float[torch.Tensor, "... num_categories"]: + r""" + Implementation of Gumbel Softmax from Transolver++. -def gumbel_softmax(logits: torch.Tensor, tau: float = 1.0) -> torch.Tensor: - """ - Implementation of Gumblel Softmax from transolver++. + Applies a differentiable approximation to sampling from a categorical + distribution using the Gumbel-Softmax trick. Original code: https://github.com/thuml/Transolver_plus/blob/main/models/Transolver_plus.py#L69 - Args: - logits (torch.Tensor): The logits to apply Gumblel Softmax to. - tau (float): The temperature parameter for the Gumblel Softmax. - - Returns: - torch.Tensor: The Gumblel Softmax of the logits. + Parameters + ---------- + logits : torch.Tensor + Input logits tensor of shape :math:`(*, K)` where :math:`K` is the + number of categories. + tau : torch.Tensor | float, optional, default=1.0 + Temperature parameter. Can be a scalar float or a tensor for + per-element temperature. Lower values make the distribution more + concentrated. + + Returns + ------- + torch.Tensor + Gumbel-Softmax output of the same shape as ``logits``. 
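+
+    Examples
+    --------
+    A minimal usage sketch (shapes are illustrative):
+
+    >>> import torch
+    >>> logits = torch.randn(2, 10, 4, 8)  # (B, N, H, S)
+    >>> weights = gumbel_softmax(logits, tau=0.5)
+    >>> weights.shape
+    torch.Size([2, 10, 4, 8])
+    >>> torch.allclose(weights.sum(dim=-1), torch.ones(2, 10, 4))
+    True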
""" + # Sample Gumbel noise u = torch.rand_like(logits) gumbel_noise = -torch.log(-torch.log(u + 1e-8) + 1e-8) + # Add noise and apply temperature-scaled softmax y = logits + gumbel_noise y = y / tau - y = torch.nn.functional.softmax(y, dim=-1) return y class PhysicsAttentionBase(nn.Module, ABC): - """ - Base class for all physics attention modules. - - Implements key functionality that is common across domains: - - Slice weighting and computation - - Attention among slices - - Deslicing - - Output Projection - - Each subclass must implement it's own methods for projecting input domain tokens onto the slice space. - - Deliberately, there are not default values for any of the parameters. It's assumed you will - assign them in the subclass. - + r""" + Base class for physics attention modules. + + This class implements the core physics attention mechanism that projects + inputs onto learned physics-informed slices before applying attention. + Subclasses implement domain-specific input projections. + + The physics attention mechanism consists of: + + 1. Project inputs onto learned slice space + 2. Compute slice weights via temperature-scaled softmax + 3. Aggregate features for each slice + 4. Apply attention among slices + 5. Project attended features back to original space + + Parameters + ---------- + dim : int + Input feature dimension. + heads : int + Number of attention heads. + dim_head : int + Dimension per attention head. + dropout : float + Dropout rate. + slice_num : int + Number of physics slices. + use_te : bool + Whether to use transformer engine. + plus : bool + Whether to use Transolver++ variant. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)` where :math:`B` is batch size, + :math:`N` is number of tokens, :math:`C` is feature dimension. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, N, C)`. + + Note + ---- + This is an abstract base class. 
Use one of the concrete implementations: + + - :class:`PhysicsAttentionIrregularMesh` for unstructured mesh data + - :class:`PhysicsAttentionStructuredMesh2D` for 2D image-like data + - :class:`PhysicsAttentionStructuredMesh3D` for 3D volumetric data """ def __init__( @@ -113,9 +168,12 @@ def __init__( self.softmax = nn.Softmax(dim=-1) self.dropout = nn.Dropout(dropout) + + # Learnable temperature parameter for slice weighting self.temperature = nn.Parameter(torch.ones([1, 1, heads, 1]) * 0.5) if plus: + # Transolver++ uses learned temperature projection linear_layer = te.Linear if self.use_te else nn.Linear self.proj_temperature = torch.nn.Sequential( linear_layer(self.dim_head, slice_num), @@ -124,17 +182,20 @@ def __init__( nn.GELU(), ) + # Projection from head dimension to slice space if self.use_te: self.in_project_slice = te.Linear(dim_head, slice_num) else: self.in_project_slice = nn.Linear(dim_head, slice_num) + # Initialize with orthogonal weights for better slice diversity for l_i in [self.in_project_slice]: - torch.nn.init.orthogonal_(l_i.weight) # use a principled initialization + torch.nn.init.orthogonal_(l_i.weight) + + # QKV projection for slice attention if not use_te: self.qkv_project = nn.Linear(dim_head, 3 * dim_head, bias=False) else: - # These are used in the transformer engine pass function: self.qkv_project = te.Linear(dim_head, 3 * dim_head, bias=False) self.attn_fn = te.DotProductAttention( num_attention_heads=self.heads, @@ -144,6 +205,7 @@ def __init__( softmax_scale=self.scale, ) + # Output projection if self.use_te: self.out_linear = te.Linear(inner_dim, dim) else: @@ -152,63 +214,84 @@ def __init__( self.out_dropout = nn.Dropout(dropout) @abstractmethod - def project_input_onto_slices(self, x) -> tuple[torch.Tensor, torch.Tensor]: - """ - Project the input onto the slice space. + def project_input_onto_slices( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> ( + Float[torch.Tensor, "batch tokens heads head_dim"] + | tuple[ + Float[torch.Tensor, "batch tokens heads head_dim"], + Float[torch.Tensor, "batch tokens heads head_dim"], + ] + ): + r""" + Project input tensor onto the slice space. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)`. + + Returns + ------- + torch.Tensor | tuple[torch.Tensor, torch.Tensor] + For Transolver++: single projected tensor of shape + :math:`(B, N, H, D_h)`. + For standard Transolver: tuple of (x_mid, fx_mid) both of shape + :math:`(B, N, H, D_h)`. """ raise NotImplementedError("Subclasses must implement this method") def compute_slices_from_projections( - self, slice_projections: torch.Tensor, fx: torch.Tensor - ) -> tuple[torch.Tensor, torch.Tensor]: - """ - Compute slice weights and slice tokens from input projections and latent features. - - In a domain-parallel setting, this function will do an implicit allreduce. - When we sum over the slice_weights over a sharded dimension - and use the output, it will resolve Partial->Replicated placement (aka - allreduce) implicitly. - - Args: - slice_projections (torch.Tensor): - The projected input tensor of shape [Batch, N_tokens, N_heads, Slice_num], - representing the projection of each token onto each slice for each attention head. - fx (torch.Tensor): - The latent feature tensor of shape [Batch, N_tokens, N_heads, Head_dim], - representing the learned states to be aggregated by the slice weights. 
- - Returns: - tuple[torch.Tensor, torch.Tensor]: - - slice_weights: Tensor of shape [Batch, N_tokens, N_heads, Slice_num], - representing the normalized weights for each slice per token and head. - - slice_token: Tensor of shape [Batch, N_heads, Slice_num, Head_dim], - representing the aggregated latent features for each slice, head, and batch. - - Notes: - - The function first computes a temperature-scaled softmax over the slice projections to obtain slice weights. - - It then aggregates the latent features (fx) for each slice using these weights. - - The aggregated features are normalized by the sum of weights for numerical stability. + self, + slice_projections: Float[torch.Tensor, "batch tokens heads slice_num"], + fx: Float[torch.Tensor, "batch tokens heads head_dim"], + ) -> tuple[ + Float[torch.Tensor, "batch tokens heads slice_num"], + Float[torch.Tensor, "batch heads slice_num head_dim"], + ]: + r""" + Compute slice weights and slice tokens from input projections. + + This method computes soft assignments of tokens to physics slices using + temperature-scaled softmax, then aggregates features for each slice. + + In domain-parallel settings, this performs an implicit allreduce when + summing over the sharded token dimension. + + Parameters + ---------- + slice_projections : torch.Tensor + Projected input of shape :math:`(B, N, H, S)` where :math:`S` is + the number of slices. + fx : torch.Tensor + Latent features of shape :math:`(B, N, H, D_h)`. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + - ``slice_weights``: Shape :math:`(B, N, H, S)`, normalized weights + for each slice per token. + - ``slice_token``: Shape :math:`(B, H, S, D_h)`, aggregated features + per slice. """ - - # Project the latent space vectors on to the weight computation space, - # and compute a temperature adjusted softmax. - + # Compute temperature-scaled softmax over slices if self.plus: + # Transolver++ uses learned per-token temperature temperature = self.temperature + self.proj_temperature(fx) clamped_temp = torch.clamp(temperature, min=0.01).to( slice_projections.dtype ) slice_weights = gumbel_softmax( slice_projections, clamped_temp - ) # [Batch, N_tokens, N_heads, Slice_num] - + ) # (B, N, H, S) else: + # Standard Transolver uses global temperature clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( slice_projections.dtype ) slice_weights = nn.functional.softmax( slice_projections / clamped_temp, dim=-1 - ) # [Batch, N_heads, N_tokens, Slice_num] + ) # (B, N, H, S) # Cast to the computation type (since the parameter is probably fp32) slice_weights = slice_weights.to(slice_projections.dtype) @@ -234,6 +317,7 @@ def compute_slices_from_projections( # Like the weight norm, this sum is a **partial** sum since we are summing # over the tokens + # Aggregate features: (B, N, H, S)^T @ (B, N, H, D_h) -> (B, H, S, D_h) slice_token = torch.matmul( normed_weights.permute(0, 2, 3, 1), fx.permute(0, 2, 1, 3) ) @@ -242,33 +326,50 @@ def compute_slices_from_projections( return slice_weights, slice_token - def compute_slice_attention_te(self, slice_tokens: torch.Tensor) -> torch.Tensor: + def compute_slice_attention_te( + self, slice_tokens: Float[torch.Tensor, "batch heads slice_num head_dim"] + ) -> Float[torch.Tensor, "batch heads slice_num head_dim"]: + r""" + Compute attention among slices using Transformer Engine. + + Parameters + ---------- + slice_tokens : torch.Tensor + Slice features of shape :math:`(B, H, S, D_h)`. 
+ + Returns + ------- + torch.Tensor + Attended slice features of shape :math:`(B, H, S, D_h)`. """ - TE implementation of slice attention - """ - + # Project to Q, K, V qkv = self.qkv_project(slice_tokens) - qkv = rearrange(qkv, " b h s (t d) -> t b s h d", t=3, d=self.dim_head) + qkv = rearrange(qkv, "b h s (t d) -> t b s h d", t=3, d=self.dim_head) q_slice_token, k_slice_token, v_slice_token = qkv.unbind(0) - out_slice_token2 = self.attn_fn(q_slice_token, k_slice_token, v_slice_token) - out_slice_token2 = rearrange( - out_slice_token2, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head + # Apply TE attention + out_slice_token = self.attn_fn(q_slice_token, k_slice_token, v_slice_token) + out_slice_token = rearrange( + out_slice_token, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head ) - return out_slice_token2 + return out_slice_token - def compute_slice_attention_sdpa(self, slice_tokens: torch.Tensor) -> torch.Tensor: - """ - Torch SDPA implementation of slice attention + def compute_slice_attention_sdpa( + self, slice_tokens: Float[torch.Tensor, "batch heads slice_num head_dim"] + ) -> Float[torch.Tensor, "batch heads slice_num head_dim"]: + r""" + Compute attention among slices using PyTorch SDPA. - Args: - slice_tokens (torch.Tensor): - The slice tokens tensor of shape [Batch, N_heads, Slice_num, Head_dim]. + Parameters + ---------- + slice_tokens : torch.Tensor + Slice features of shape :math:`(B, H, S, D_h)`. - Returns: - torch.Tensor: - The output tensor of shape [Batch, N_heads, Slice_num, Head_dim]. + Returns + ------- + torch.Tensor + Attended slice features of shape :math:`(B, H, S, D_h)`. """ with record_function("compute_slice_attention_sdpa"): # In this case we're using ShardTensor, ensure slice_token is *replicated* @@ -283,6 +384,7 @@ def compute_slice_attention_sdpa(self, slice_tokens: torch.Tensor) -> torch.Tens q_slice_token, k_slice_token, v_slice_token = qkv.unbind(3) + # Apply scaled dot-product attention out_slice_token = torch.nn.functional.scaled_dot_product_attention( q_slice_token, k_slice_token, v_slice_token, is_causal=False ) @@ -290,100 +392,143 @@ def compute_slice_attention_sdpa(self, slice_tokens: torch.Tensor) -> torch.Tens return out_slice_token def project_attention_outputs( - self, out_slice_token: torch.Tensor, slice_weights: torch.Tensor - ) -> torch.Tensor: - """ - Project the attended slice tokens back onto the original token space. - - Note that in the distributed case, this will have a replicated and - sharded inputs. Slice tokens will be replicated, and slice weights will be sharded. - - Args: - out_slice_token (torch.Tensor): - The output tensor from the attention mechanism over slices, - of shape [Batch, N_heads, Slice_num, Head_dim]. - slice_weights (torch.Tensor): - The slice weights tensor of shape [Batch, N_tokens, N_heads, Slice_num], - representing the contribution of each slice to each token. - - Returns: - torch.Tensor: - The reconstructed output tensor of shape [Batch, N_tokens, N_heads * Head_dim], - representing the attended features for each token, with all heads concatenated. - - Notes: - - The function projects the attended slice tokens back to the token space using the slice weights. - - The output is reshaped to concatenate all attention heads for each token. 
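The de-slicing step below reduces to a single einsum; here is a small, self-contained shape sketch of that contraction (sizes are illustrative, and the per-head matmul is only an equivalent reference formulation, not library code):

>>> import torch
>>> B, N, H, S, Dh = 2, 5, 3, 4, 6
>>> weights = torch.randn(B, N, H, S).softmax(dim=-1)       # (B, N, H, S)
>>> tokens = torch.randn(B, H, S, Dh)                       # (B, H, S, D_h)
>>> out = torch.einsum("bths,bhsd->bthd", weights, tokens)  # (B, N, H, D_h)
>>> ref = torch.stack([weights[:, :, h] @ tokens[:, h] for h in range(H)], dim=2)
>>> torch.allclose(out, ref, atol=1e-6)
True
>>> out.shape
torch.Size([2, 5, 3, 6])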
+ self, + out_slice_token: Float[torch.Tensor, "batch heads slice_num head_dim"], + slice_weights: Float[torch.Tensor, "batch tokens heads slice_num"], + ) -> Float[torch.Tensor, "batch tokens channels"]: + r""" + Project attended slice features back to token space. + + In distributed settings, ``out_slice_token`` is replicated while + ``slice_weights`` may be sharded over tokens. + + Parameters + ---------- + out_slice_token : torch.Tensor + Attended slice features of shape :math:`(B, H, S, D_h)`. + slice_weights : torch.Tensor + Slice weights of shape :math:`(B, N, H, S)`. + + Returns + ------- + torch.Tensor + Output features of shape :math:`(B, N, H \cdot D_h)`. """ with record_function("project_attention_outputs"): - # Slice weights has shape (Batch, n_tokens, n_heads, slice_num) - # Out slice tokens has shape (Batch, n_heads, slice_num, head_dim) - # The output of this function needs to have shape - # (Batch, n_tokens, n_channels) == (Batch, n_tokens, n_heads * head_dim) - # Note that tokens may be sharded, in which case slice_weights - # is a sharded tensor and out_slice_token is a replicated tensor - + # Weighted combination: (B, N, H, S) @ (B, H, S, D_h) -> (B, N, H, D_h) out_x = torch.einsum("bths,bhsd->bthd", slice_weights, out_slice_token) - # Condense the last two dimensions: + # Concatenate heads: (B, N, H, D_h) -> (B, N, H*D_h) out_x = rearrange(out_x, "b t h d -> b t (h d)") + # Output projection with dropout out_x = self.out_linear(out_x) return self.out_dropout(out_x) - def forward(self, x: torch.Tensor) -> torch.Tensor: + def forward( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> Float[torch.Tensor, "batch tokens channels"]: + r""" + Forward pass of physics attention. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)`. + + Returns + ------- + torch.Tensor + Output tensor of shape :math:`(B, N, C)`. """ - Forward pass of the Physics Attention module. - - Input x should have shape of [Batch, N_tokens, N_Channels] ([B, N, C]) - """ - - # Project the inputs onto learned spaces: + # Input validation (skip during torch.compile) + if not torch.compiler.is_compiling(): + if x.ndim != 3: + raise ValueError( + f"Expected 3D input tensor (B, N, C), " + f"got {x.ndim}D tensor with shape {tuple(x.shape)}" + ) + + # Project inputs onto learned spaces + projected = self.project_input_onto_slices(x) if self.plus: - x_mid = self.project_input_onto_slices(x) - # In transolver ++, fx_mid is gone. - # x_mid is used to compute the projections instead: - fx_mid = x_mid + x_mid = projected + fx_mid = x_mid # Transolver++ reuses x_mid else: - x_mid, fx_mid = self.project_input_onto_slices(x) - - # Perform the linear projection of learned latent space onto slices: + x_mid, fx_mid = projected # type: ignore[misc] + # Project onto slice space slice_projections = self.in_project_slice(x_mid) + # slice_projections: (B, N, H, S) - # Slice projections has shape [B, N_tokens, N_head, Head_dim], but head_dim may have changed! 
- - # Use the slice projections and learned spaces to compute the slices, and their weights: + # Compute slice weights and aggregate features per slice slice_weights, slice_tokens = self.compute_slices_from_projections( slice_projections, fx_mid ) - # slice_weights has shape [Batch, N_tokens, N_heads, Slice_num] - # slice_tokens has shape [Batch, N_tokens, N_heads, head_dim] + # slice_weights: (B, N, H, S) + # slice_tokens: (B, H, S, D_h) - # Apply attention to the slice tokens + # Apply attention among slices if self.use_te: out_slice_token = self.compute_slice_attention_te(slice_tokens) else: out_slice_token = self.compute_slice_attention_sdpa(slice_tokens) + # out_slice_token: (B, H, S, D_h) - # Shape unchanged - - # Deslice: + # Project back to token space outputs = self.project_attention_outputs(out_slice_token, slice_weights) - - # Outputs now has the same shape as the original input x + # outputs: (B, N, C) return outputs class PhysicsAttentionIrregularMesh(PhysicsAttentionBase): - """ - Specialization of PhysicsAttention to Irregular Meshes + r""" + Physics attention for irregular/unstructured mesh data. + + Uses linear projections to map input tokens to the slice space, suitable + for meshes with arbitrary connectivity. + + Parameters + ---------- + dim : int + Input feature dimension. + heads : int, optional, default=8 + Number of attention heads. + dim_head : int, optional, default=64 + Dimension per attention head. + dropout : float, optional, default=0.0 + Dropout rate. + slice_num : int, optional, default=64 + Number of physics slices. + use_te : bool, optional, default=True + Whether to use transformer engine. + plus : bool, optional, default=False + Whether to use Transolver++ variant. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)`. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, N, C)`. + + Examples + -------- + >>> import torch + >>> attn = PhysicsAttentionIrregularMesh(dim=128, heads=4, dim_head=32, dropout=0.0, slice_num=16, use_te=False) + >>> x = torch.randn(2, 1000, 128) + >>> out = attn(x) + >>> out.shape + torch.Size([2, 1000, 128]) """ def __init__( self, - dim, + dim: int, heads: int = 8, dim_head: int = 64, dropout: float = 0.0, @@ -393,6 +538,8 @@ def __init__( ): super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) inner_dim = dim_head * heads + + # Linear projections for irregular mesh data if use_te: self.in_project_x = te.Linear(dim, inner_dim) if not plus: @@ -403,21 +550,32 @@ def __init__( self.in_project_fx = nn.Linear(dim, inner_dim) def project_input_onto_slices( - self, x - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """ - Project the input onto the slice space. - - Args: - x (torch.Tensor): The input tensor of shape [Batch, N_tokens, N_Channels] - - Returns: - tuple[torch.Tensor, torch.Tensor]: The projected x and fx tensors of shape [Batch, N_tokens, N_Channels], [Batch, N_tokens, N_heads, Head_dim] - + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> ( + Float[torch.Tensor, "batch tokens heads head_dim"] + | tuple[ + Float[torch.Tensor, "batch tokens heads head_dim"], + Float[torch.Tensor, "batch tokens heads head_dim"], + ] + ): + r""" + Project input onto slice space using linear layers. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)`. + + Returns + ------- + torch.Tensor | tuple[torch.Tensor, torch.Tensor] + Projected tensors of shape :math:`(B, N, H, D_h)`. 
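+
+        Examples
+        --------
+        A small shape sketch (sizes are illustrative):
+
+        >>> import torch
+        >>> attn = PhysicsAttentionIrregularMesh(
+        ...     dim=32, heads=4, dim_head=8, slice_num=8, use_te=False
+        ... )
+        >>> x = torch.randn(2, 10, 32)
+        >>> x_mid, fx_mid = attn.project_input_onto_slices(x)
+        >>> x_mid.shape, fx_mid.shape
+        (torch.Size([2, 10, 4, 8]), torch.Size([2, 10, 4, 8]))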
""" + # Project and reshape to multi-head format x_mid = rearrange( self.in_project_x(x), "B N (h d) -> B N h d", h=self.heads, d=self.dim_head ) + if self.plus: return x_mid else: @@ -427,15 +585,63 @@ def project_input_onto_slices( h=self.heads, d=self.dim_head, ) - return x_mid, fx_mid class PhysicsAttentionStructuredMesh2D(PhysicsAttentionBase): - """ - Specialization for 2d image-like meshes - - Only implements the projection onto the slice space. + r""" + Physics attention for 2D structured/image-like data. + + Uses 2D convolutions to project inputs, leveraging spatial locality in + structured grids. + + Parameters + ---------- + dim : int + Input feature dimension. + spatial_shape : tuple[int, int] + Spatial dimensions (height, width) of the input. + heads : int, optional, default=8 + Number of attention heads. + dim_head : int, optional, default=64 + Dimension per attention head. + dropout : float, optional, default=0.0 + Dropout rate. + slice_num : int, optional, default=64 + Number of physics slices. + kernel : int, optional, default=3 + Convolution kernel size. + use_te : bool, optional, default=True + Whether to use transformer engine. + plus : bool, optional, default=False + Whether to use Transolver++ variant. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, H \times W, C)` (flattened spatial). + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, H \times W, C)`. + + Examples + -------- + >>> import torch + >>> attn = PhysicsAttentionStructuredMesh2D( + ... dim=128, + ... spatial_shape=(32, 32), + ... heads=4, + ... dim_head=32, + ... dropout=0.0, + ... slice_num=16, + ... use_te=False, + ... ) + >>> x = torch.randn(2, 32*32, 128) + >>> out = attn(x) + >>> out.shape + torch.Size([2, 1024, 128]) """ def __init__( @@ -443,35 +649,54 @@ def __init__( dim: int, spatial_shape: tuple[int, int], heads: int = 8, - dim_head=64, + dim_head: int = 64, dropout: float = 0.0, slice_num: int = 64, kernel: int = 3, use_te: bool = True, plus: bool = False, - ): # kernel=3): + ): super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) inner_dim = dim_head * heads self.H = spatial_shape[0] self.W = spatial_shape[1] + # 2D convolution projections self.in_project_x = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) if not plus: self.in_project_fx = nn.Conv2d(dim, inner_dim, kernel, 1, kernel // 2) def project_input_onto_slices( - self, x - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - # Rearrange the input tokens back to an image shape: + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> ( + Float[torch.Tensor, "batch tokens heads head_dim"] + | tuple[ + Float[torch.Tensor, "batch tokens heads head_dim"], + Float[torch.Tensor, "batch tokens heads head_dim"], + ] + ): + r""" + Project input onto slice space using 2D convolutions. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, H \times W, C)`. + + Returns + ------- + torch.Tensor | tuple[torch.Tensor, torch.Tensor] + Projected tensors of shape :math:`(B, H \times W, H_{heads}, D_h)`. 
+ """ b = x.shape[0] c = x.shape[-1] + # Reshape from flattened to 2D spatial format x = x.view(b, self.H, self.W, c) - x = x.permute(0, 3, 1, 2) - - # Apply the projections, here they are convolutions in 2D: + x = x.permute(0, 3, 1, 2) # (B, C, H, W) + # Apply 2D convolution and reshape to multi-head format input_projected_x = self.in_project_x(x) input_projected_x = rearrange( input_projected_x, @@ -479,28 +704,74 @@ def project_input_onto_slices( head_dim=self.dim_head, n_heads=self.heads, ) + if self.plus: return input_projected_x else: input_projected_fx = self.in_project_fx(x) - - # Next, re-reshape the projections into token-like shapes: input_projected_fx = rearrange( input_projected_fx, "b (n_heads head_dim) h w -> b (h w) n_heads head_dim", head_dim=self.dim_head, n_heads=self.heads, ) - - # Return the projections: return input_projected_x, input_projected_fx class PhysicsAttentionStructuredMesh3D(PhysicsAttentionBase): - """ - Specialization for 3D-image like meshes - - Only implements the projection onto the slice space. + r""" + Physics attention for 3D structured/volumetric data. + + Uses 3D convolutions to project inputs, suitable for voxel-based + representations. + + Parameters + ---------- + dim : int + Input feature dimension. + spatial_shape : tuple[int, int, int] + Spatial dimensions (height, width, depth) of the input. + heads : int, optional, default=8 + Number of attention heads. + dim_head : int, optional, default=64 + Dimension per attention head. + dropout : float, optional, default=0.0 + Dropout rate. + slice_num : int, optional, default=32 + Number of physics slices. + kernel : int, optional, default=3 + Convolution kernel size. + use_te : bool, optional, default=True + Whether to use transformer engine. + plus : bool, optional, default=False + Whether to use Transolver++ variant. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, H \times W \times D, C)` (flattened). + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, H \times W \times D, C)`. + + Examples + -------- + >>> import torch + >>> attn = PhysicsAttentionStructuredMesh3D( + ... dim=64, + ... spatial_shape=(16, 16, 16), + ... heads=4, + ... dim_head=16, + ... dropout=0.0, + ... slice_num=8, + ... use_te=False, + ... ) + >>> x = torch.randn(2, 16*16*16, 64) + >>> out = attn(x) + >>> out.shape + torch.Size([2, 4096, 64]) """ def __init__( @@ -512,7 +783,7 @@ def __init__( dropout: float = 0.0, slice_num: int = 32, kernel: int = 3, - use_te: int = True, + use_te: bool = True, plus: bool = False, ): super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) @@ -522,43 +793,56 @@ def __init__( self.W = spatial_shape[1] self.D = spatial_shape[2] + # 3D convolution projections self.in_project_x = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) if not plus: self.in_project_fx = nn.Conv3d(dim, inner_dim, kernel, 1, kernel // 2) def project_input_onto_slices( - self, x - ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: - """ - Project the input onto the slice space. - - Input tensor has shape [Batch, N_tokens, N_Channels] + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> ( + Float[torch.Tensor, "batch tokens heads head_dim"] + | tuple[ + Float[torch.Tensor, "batch tokens heads head_dim"], + Float[torch.Tensor, "batch tokens heads head_dim"], + ] + ): + r""" + Project input onto slice space using 3D convolutions. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, H \times W \times D, C)`. 
+ + Returns + ------- + torch.Tensor | tuple[torch.Tensor, torch.Tensor] + Projected tensors of shape :math:`(B, H \times W \times D, H_{heads}, D_h)`. """ - b = x.shape[0] c = x.shape[-1] - # x = rearrange(x, "b (h w d) c -> b c h w d", h=self.H, w=self.W, d=self.D) + # Reshape from flattened to 3D spatial format x = x.view(b, self.H, self.W, self.D, c) - x = x.permute(0, 4, 1, 2, 3) + x = x.permute(0, 4, 1, 2, 3) # (B, C, H, W, D) - # Apply the projections, here they are convolutions: + # Apply 3D convolution and reshape to multi-head format input_projected_x = self.in_project_x(x) - - # Next, re-reshape the projections into token-like shapes: input_projected_x = rearrange( input_projected_x, "b (n_heads head_dim) h w d -> b (h w d) n_heads head_dim", head_dim=self.dim_head, n_heads=self.heads, ) + if self.plus: return input_projected_x else: input_projected_fx = self.in_project_fx(x) input_projected_fx = rearrange( input_projected_fx, - "b (n_heads head_dim) h w -> b (h w d) n_heads head_dim", + "b (n_heads head_dim) h w d -> b (h w d) n_heads head_dim", head_dim=self.dim_head, n_heads=self.heads, ) diff --git a/physicsnemo/models/transolver/__init__.py b/physicsnemo/models/transolver/__init__.py index 79bd53b5e4..e1f4919055 100644 --- a/physicsnemo/models/transolver/__init__.py +++ b/physicsnemo/models/transolver/__init__.py @@ -1,9 +1,17 @@ # ignore_header_test # ruff: noqa: E402 -"""""" -""" -Transolver model. This code was modified from, https://github.com/thuml/Transolver +r""" +Transolver model for physics-informed neural operator learning. + +This module provides the Transolver model, which adapts the transformer +architecture with a physics-attention mechanism for solving partial +differential equations on both structured and unstructured meshes. + +The Transolver model learns to project inputs onto physics-informed slices +before applying attention, enabling efficient learning of physical systems. + +This code was modified from https://github.com/thuml/Transolver The following license is provided from their source, @@ -28,6 +36,49 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +References +---------- +- `Transolver paper `_ +- `Transolver++ paper `_ + +Examples +-------- +Structured 2D data with unified position: + +>>> import torch +>>> from physicsnemo.models.transolver import Transolver +>>> model = Transolver( +... functional_dim=3, +... out_dim=1, +... structured_shape=(64, 64), +... unified_pos=True, +... n_hidden=128, +... n_head=4, +... use_te=False, +... ) +>>> x = torch.randn(2, 64, 64, 3) +>>> out = model(x) +>>> out.shape +torch.Size([2, 64, 64, 1]) + +Unstructured mesh data: + +>>> model = Transolver( +... functional_dim=2, +... embedding_dim=3, +... out_dim=1, +... structured_shape=None, +... unified_pos=False, +... n_hidden=128, +... n_head=4, +... use_te=False, +... ) +>>> fx = torch.randn(2, 1000, 2) +>>> emb = torch.randn(2, 1000, 3) +>>> out = model(fx, embedding=emb) +>>> out.shape +torch.Size([2, 1000, 1]) """ from .transolver import Transolver diff --git a/physicsnemo/models/transolver/transolver.py b/physicsnemo/models/transolver/transolver.py index 225aa1f3d4..7d16691b4e 100644 --- a/physicsnemo/models/transolver/transolver.py +++ b/physicsnemo/models/transolver/transolver.py @@ -1,9 +1,14 @@ # ignore_header_test # ruff: noqa: E402 -"""""" -""" -Transolver model. 
This code was modified from, https://github.com/thuml/Transolver +r""" +Transolver model and building blocks for physics-informed neural operator learning. + +This module provides the main Transolver model class along with its internal +building blocks (MLP, Transolver_block) for solving PDEs on structured and +unstructured meshes. + +This code was modified from https://github.com/thuml/Transolver The following license is provided from their source, @@ -36,6 +41,7 @@ import numpy as np import torch import torch.nn as nn +from jaxtyping import Float import physicsnemo # noqa: F401 for docs from physicsnemo.core.meta import ModelMetaData @@ -43,8 +49,6 @@ from physicsnemo.core.version_check import check_version_spec from .Embedding import timestep_embedding - -# from .Physics_Attention import Physics_Attention_Structured_Mesh_2D from .Physics_Attention import ( PhysicsAttentionIrregularMesh, PhysicsAttentionStructuredMesh2D, @@ -71,22 +75,76 @@ class MLP(nn.Module): + r""" + Multi-layer perceptron with optional residual connections. + + This MLP supports transformer engine linear layers for optimized performance + and optional residual connections in hidden layers. + + Parameters + ---------- + n_input : int + Number of input features. + n_hidden : int + Number of hidden features in each layer. + n_output : int + Number of output features. + n_layers : int, optional, default=1 + Number of hidden layers with residual connections. + act : str, optional, default="gelu" + Activation function name. Must be one of: ``"gelu"``, ``"tanh"``, + ``"sigmoid"``, ``"relu"``, ``"leaky_relu"``, ``"softplus"``, + ``"ELU"``, ``"silu"``. + res : bool, optional, default=True + Whether to use residual connections in hidden layers. + use_te : bool, optional, default=True + Whether to use transformer engine linear layers. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(*, D_{in})` where :math:`*` denotes + any number of batch dimensions. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(*, D_{out})`. + + Note + ---- + This MLP differs from :class:`~physicsnemo.nn.Mlp` by supporting: + + - Transformer engine linear layers via ``use_te`` + - Residual connections via ``res`` + - Fixed hidden dimension across all layers + """ + def __init__( - self, n_input, n_hidden, n_output, n_layers=1, act="gelu", res=True, use_te=True + self, + n_input: int, + n_hidden: int, + n_output: int, + n_layers: int = 1, + act: str = "gelu", + res: bool = True, + use_te: bool = True, ): super(MLP, self).__init__() if act in ACTIVATION.keys(): - act = ACTIVATION[act] + act_fn = ACTIVATION[act] else: - raise NotImplementedError + raise NotImplementedError( + f"Activation '{act}' not supported. Choose from: {list(ACTIVATION.keys())}" + ) self.n_input = n_input self.n_hidden = n_hidden self.n_output = n_output self.n_layers = n_layers self.res = res - self.act = act() + self.act = act_fn() linear_layer = nn.Linear if not use_te else te.Linear @@ -94,52 +152,118 @@ def __init__( self.linear_post = linear_layer(n_hidden, n_output) self.linears = nn.ModuleList( [ - nn.Sequential(linear_layer(n_hidden, n_hidden), act()) + nn.Sequential(linear_layer(n_hidden, n_hidden), act_fn()) for _ in range(n_layers) ] ) - def forward(self, x): + def forward( + self, x: Float[torch.Tensor, "... d_in"] + ) -> Float[torch.Tensor, "... d_out"]: + r""" + Forward pass of the MLP. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(*, D_{in})`. 
+ + Returns + ------- + torch.Tensor + Output tensor of shape :math:`(*, D_{out})`. + """ + # Project input to hidden dimension x = self.act(self.linear_pre(x)) + + # Apply hidden layers with optional residual connections for i in range(self.n_layers): if self.res: x = self.linears[i](x) + x else: x = self.linears[i](x) + + # Project to output dimension x = self.linear_post(x) return x class Transolver_block(nn.Module): - """Transformer encoder block, replacing standard attention with physics attention.""" + r""" + Transformer encoder block with physics attention mechanism. + + This block replaces standard attention with physics attention, which learns + to project inputs onto physics-informed slices before applying attention. + + Parameters + ---------- + num_heads : int + Number of attention heads. + hidden_dim : int + Hidden dimension of the block. + dropout : float + Dropout rate. + act : str, optional, default="gelu" + Activation function name. + mlp_ratio : int, optional, default=4 + Ratio of MLP hidden dimension to ``hidden_dim``. + last_layer : bool, optional, default=False + Whether this is the last layer (applies output projection). + out_dim : int, optional, default=1 + Output dimension (only used if ``last_layer=True``). + slice_num : int, optional, default=32 + Number of physics slices. + spatial_shape : tuple[int, ...] | None, optional, default=None + Spatial shape for structured data. ``None`` for irregular meshes. + use_te : bool, optional, default=True + Whether to use transformer engine. + plus : bool, optional, default=False + Whether to use Transolver++ variant. + + Forward + ------- + fx : torch.Tensor + Input tensor of shape :math:`(B, N, C)` where :math:`B` is batch size, + :math:`N` is number of tokens, :math:`C` is hidden dimension. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, N, C)` or :math:`(B, N, D_{out})` + if ``last_layer=True``. + """ def __init__( self, num_heads: int, hidden_dim: int, dropout: float, - act="gelu", - mlp_ratio=4, - last_layer=False, - out_dim=1, - slice_num=32, + act: str = "gelu", + mlp_ratio: int = 4, + last_layer: bool = False, + out_dim: int = 1, + slice_num: int = 32, spatial_shape: tuple[int, ...] | None = None, - use_te=True, + use_te: bool = True, plus: bool = False, ): super().__init__() if use_te and not TE_AVAILABLE: raise ImportError( - "Transformer Engine is not installed. Please install it with `pip install transformer-engine`." + "Transformer Engine is not installed. Please install it with " + "`pip install transformer-engine`." ) self.last_layer = last_layer + + # Layer normalization before attention if use_te: self.ln_1 = te.LayerNorm(hidden_dim) else: self.ln_1 = nn.LayerNorm(hidden_dim) + # Select appropriate physics attention based on spatial structure if spatial_shape is None: self.Attn = PhysicsAttentionIrregularMesh( hidden_dim, @@ -174,10 +298,12 @@ def __init__( plus=plus, ) else: - raise Exception( - f"Unexpected length of spatial shape encountered in Transolver_block: {len(spatial_shape)}" + raise ValueError( + f"Unexpected length of spatial shape encountered in Transolver_block: " + f"{len(spatial_shape)}. Expected 2 or 3." 
) + # Feed-forward network with layer norm if use_te: self.ln_mlp1 = te.LayerNormMLP( hidden_size=hidden_dim, @@ -196,6 +322,8 @@ def __init__( use_te=False, ), ) + + # Output projection for final layer if self.last_layer: if use_te: self.ln_mlp2 = te.LayerNormLinear( @@ -207,9 +335,29 @@ def __init__( nn.Linear(hidden_dim, out_dim), ) - def forward(self, fx): + def forward( + self, fx: Float[torch.Tensor, "batch tokens hidden"] + ) -> Float[torch.Tensor, "batch tokens out"]: + r""" + Forward pass of the Transolver block. + + Parameters + ---------- + fx : torch.Tensor + Input tensor of shape :math:`(B, N, C)`. + + Returns + ------- + torch.Tensor + Output tensor of shape :math:`(B, N, C)` or :math:`(B, N, D_{out})`. + """ + # Apply physics attention with residual connection fx = self.Attn(self.ln_1(fx)) + fx + + # Apply feed-forward network with residual connection fx = self.ln_mlp1(fx) + fx + + # Apply output projection if last layer if self.last_layer: return self.ln_mlp2(fx) else: @@ -218,6 +366,8 @@ def forward(self, fx): @dataclass class MetaData(ModelMetaData): + r"""Metadata for the Transolver model.""" + # Optimization jit: bool = False cuda_graphs: bool = False @@ -233,82 +383,113 @@ class MetaData(ModelMetaData): class Transolver(Module): - """ - Transolver model, adapted from original transolver code. - - Transolver is an adaptation of the transformer architecture, with a physics-attention - mechanism replacing the standard attention mechanism. - - For more architecture details, see: https://arxiv.org/pdf/2402.02366 and https://arxiv.org/pdf/2502.02414 - - Transolver can work on structured or unstructured data points as a model construction choice: - - unstructured data (like a mesh) should provide some sort of positional encoding to accompany inputs - - structured data (2D and 3D grids) can provide positional encodings optionally - - When constructing Transolver, you can choose to use "unified position" or not. If you select "unified - position" (`unified_pos=True`), then + r""" + Transolver model for physics-informed neural operator learning. - If using structured data, pass the structured shape as a tuple in the model constructor. - Length 2 tuples are assumed to be image-like, length 3 tuples are assumed to be 3D voxel like. - Other structured shape sizes are not supported. Passing a structured_shape of None assumes irregular data. + Transolver adapts the transformer architecture with a physics-attention + mechanism replacing standard attention. It can work on both structured + (2D/3D grids) and unstructured (mesh) data. - Output shape will have the same spatial shape as the input shape, with potentially more features + For architecture details, see: - Also can support Transolver++ implementation. When using the distributed algorithm - of Transolver++, use PhysicsNeMo's ShardTensor implementation to support automatic - domain parallelism and 2D parallelization (data parallel + domain parallel, for example). + - `Transolver paper `_ + - `Transolver++ paper `_ - Note - ---- + .. note:: + When using structured data, pass the ``structured_shape`` as a tuple. + Length-2 tuples are treated as 2D image-like data, length-3 tuples as + 3D volumetric data. Parameters ---------- functional_dim : int - The dimension of the input values, not including any embeddings. No Default. - Input will be concatenated with embeddings or unified position before processing - with PhysicsAttention blocks. Originally known as "fun_dim" + Dimension of input values, not including embeddings. 
out_dim : int - The dimension of the output of the model. This is a mandatory parameter. - embedding_dim : int | None - The spatial dimension of the input data embeddings. Should include not just - position but all computed embedding features. Default is None, but if - `unified_pos=False` this is a mandatory parameter. Originally named "space_dim" - n_layers : int - The number of transformer PhysicsAttention layers in the model. Default of 4. - n_hidden : int - The hidden dimension of the transformer. Default of 256. Projection is made - from the input data + embeddings in the early preprocessing, before the - PhysicsAttention layers. - dropout : float - The dropout rate, applied across the PhysicsAttention Layers. Default is 0.0 - n_head : int - The number of attention heads in each PhysicsAttention Layer. Default is 8. Note - that the number of heads must evenly divide the `n_hidden` parameter to yield an - integer head dimension. - act : str - The activation function, default is gelu. - mlp_ratio : int - The ratio of hidden dimension in the MLP, default is 4. Used in the MLPs in the - PhysicsAttention Layers. - slice_num : int - The number of slices in the PhysicsAttention layers. Default is 32. Represents the - number of learned states each layer should project inputs onto. - unified_pos : bool - Whether to use unified positional embeddings. Unified positions are only available for - structured data (2D grids, 3D grids). They are computed once initially, and reused through - training in place of embeddings. - ref : int - The reference dimension size when using unified positions. Default is 8. Will be - used to create a linear grid in spatial dimensions to serve as spatial embeddings. - If `unified_pos=False`, this value is unused. - structured_shape : None | tuple(int) - The shape of the latent space. If None, assumes irregular latent space. If not - `None`, this parameter can only be a length-2 or length-3 tuple of ints. - use_te: bool - Whether to use transformer engine backend when possible. - time_input : bool - Whether to include time embeddings. Default is false + Dimension of model output. + embedding_dim : int | None, optional, default=None + Dimension of input embeddings. Required if ``unified_pos=False``. + n_layers : int, optional, default=4 + Number of Transolver blocks. + n_hidden : int, optional, default=256 + Hidden dimension of the transformer. + dropout : float, optional, default=0.0 + Dropout rate. + n_head : int, optional, default=8 + Number of attention heads. Must evenly divide ``n_hidden``. + act : str, optional, default="gelu" + Activation function name. + mlp_ratio : int, optional, default=4 + Ratio of MLP hidden dimension to ``n_hidden``. + slice_num : int, optional, default=32 + Number of physics slices in attention layers. + unified_pos : bool, optional, default=False + Whether to use unified positional embeddings (structured data only). + ref : int, optional, default=8 + Reference grid size for unified position encoding. + structured_shape : None | tuple[int, ...], optional, default=None + Shape of structured data. ``None`` for unstructured mesh data. + use_te : bool, optional, default=True + Whether to use transformer engine. + time_input : bool, optional, default=False + Whether to include time embeddings. + plus : bool, optional, default=False + Whether to use Transolver++ variant. 
+ + Forward + ------- + fx : torch.Tensor + Functional input tensor of shape :math:`(B, N, D_{func})` for flattened + data or :math:`(B, H, W, D_{func})` / :math:`(B, H, W, D, D_{func})` + for structured data. + embedding : torch.Tensor | None, optional + Embedding tensor. Required if ``unified_pos=False``. Shape should + match ``fx`` spatial dimensions. + time : torch.Tensor | None, optional + Time tensor of shape :math:`(B,)` for time-dependent models. + + Outputs + ------- + torch.Tensor + Output tensor with same spatial shape as input and ``out_dim`` features. + + Examples + -------- + Structured 2D data with unified position: + + >>> import torch + >>> from physicsnemo.models.transolver import Transolver + >>> model = Transolver( + ... functional_dim=3, + ... out_dim=1, + ... structured_shape=(64, 64), + ... unified_pos=True, + ... n_hidden=128, + ... n_head=4, + ... use_te=False, + ... ) + >>> x = torch.randn(2, 64, 64, 3) + >>> out = model(x) + >>> out.shape + torch.Size([2, 64, 64, 1]) + + Unstructured mesh data: + + >>> model = Transolver( + ... functional_dim=2, + ... embedding_dim=3, + ... out_dim=1, + ... structured_shape=None, + ... unified_pos=False, + ... n_hidden=128, + ... n_head=4, + ... use_te=False, + ... ) + >>> fx = torch.randn(2, 1000, 2) + >>> emb = torch.randn(2, 1000, 3) + >>> out = model(fx, embedding=emb) + >>> out.shape + torch.Size([2, 1000, 1]) """ def __init__( @@ -325,7 +506,7 @@ def __init__( slice_num: int = 32, unified_pos: bool = False, ref: int = 8, - structured_shape: None | tuple[int] = None, + structured_shape: None | tuple[int, ...] = None, use_te: bool = True, time_input: bool = False, plus: bool = False, @@ -334,53 +515,56 @@ def __init__( self.__name__ = "Transolver" self.use_te = use_te - # Check that the hidden dimension and head dimensions are compatible: + + # Validate hidden dimension and head compatibility if not n_hidden % n_head == 0: raise ValueError( - f"Transolver requires n_hidden % n_head == 0, but instead got {n_hidden % n_head}" + f"Transolver requires n_hidden % n_head == 0, " + f"but got n_hidden={n_hidden}, n_head={n_head} " + f"(remainder={n_hidden % n_head})" ) - # Check the shape of the data, if it's structured data: + # Validate structured shape if provided if structured_shape is not None: - # Has to be 2D or 3D data: if len(structured_shape) not in [2, 3]: raise ValueError( - f"Transolver can only use structured data in 2D or 3D, got {structured_shape}" + f"Transolver only supports 2D or 3D structured data, " + f"got shape with {len(structured_shape)} dimensions" ) - - # Ensure it's all integers > 0: if not all([s > 0 and s == int(s) for s in structured_shape]): raise ValueError( - f"Transolver can only use integer shapes > 0, got {structured_shape}" + f"Transolver requires positive integer shapes, " + f"got {structured_shape}" ) else: - # It's mandatory for unified position: if unified_pos: raise ValueError( - "Transolver requires structured_shape to be passed if using unified_pos=True" + "Transolver requires structured_shape when using unified_pos=True" ) self.structured_shape = structured_shape - - # If we're using the unified position, create and save the position embeddings: self.unified_pos = unified_pos + # Set up positional embeddings if unified_pos: if structured_shape is None: raise ValueError( - "Transolver can not use unified position without a structured_shape argument (got None)" + "Transolver cannot use unified position without " + "structured_shape argument (got None)" ) - - # This ensures embedding is 
tracked by torch and moves to the GPU, and saves/loads + # Register unified position embedding as buffer self.register_buffer("embedding", self.get_grid(ref)) self.embedding_dim = ref * ref mlp_input_dimension = functional_dim + ref * ref - else: + if embedding_dim is None: + raise ValueError( + "Transolver requires embedding_dim when unified_pos=False" + ) self.embedding_dim = embedding_dim mlp_input_dimension = functional_dim + embedding_dim - # This MLP is the initial projection onto the hidden space + # Initial projection MLP self.preprocess = MLP( mlp_input_dimension, n_hidden * 2, @@ -393,11 +577,16 @@ def __init__( self.time_input = time_input self.n_hidden = n_hidden + + # Time embedding projection if time_input: self.time_fc = nn.Sequential( - nn.Linear(n_hidden, n_hidden), nn.SiLU(), nn.Linear(n_hidden, n_hidden) + nn.Linear(n_hidden, n_hidden), + nn.SiLU(), + nn.Linear(n_hidden, n_hidden), ) + # Build transformer blocks self.blocks = nn.ModuleList( [ Transolver_block( @@ -418,55 +607,74 @@ def __init__( ) self.initialize_weights() - def initialize_weights(self): + def initialize_weights(self) -> None: + r"""Initialize model weights using truncated normal distribution.""" self.apply(self._init_weights) - def _init_weights(self, m): + def _init_weights(self, m: nn.Module) -> None: + r""" + Initialize weights for a single module. + + Parameters + ---------- + m : nn.Module + Module to initialize. + """ linear_layers = (nn.Linear,) if self.use_te: linear_layers = linear_layers + (te.Linear,) if isinstance(m, linear_layers): - nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.trunc_normal_(m.weight, std=0.02) # type: ignore[arg-type] if isinstance(m, linear_layers) and m.bias is not None: - nn.init.constant_(m.bias, 0) + nn.init.constant_(m.bias, 0) # type: ignore[arg-type] + norm_layers = (nn.LayerNorm, nn.BatchNorm1d) if self.use_te: norm_layers = norm_layers + (te.LayerNorm,) if isinstance(m, norm_layers): - nn.init.constant_(m.bias, 0) - nn.init.constant_(m.weight, 1.0) + nn.init.constant_(m.bias, 0) # type: ignore[arg-type] + nn.init.constant_(m.weight, 1.0) # type: ignore[arg-type] def get_grid(self, ref: int, batchsize: int = 1) -> torch.Tensor: - """ - Generate a unified positional encoding grid for structured 2D data. + r""" + Generate unified positional encoding grid for structured 2D data. Parameters ---------- ref : int - The reference grid size for the unified position encoding. - batchsize : int, optional - The batch size for the generated grid (default is 1). + Reference grid size for unified position encoding. + batchsize : int, optional, default=1 + Batch size for the generated grid. Returns ------- torch.Tensor - A tensor of shape (batchsize, H*W, ref*ref) containing the positional encodings, - where H and W are the spatial dimensions from self.structured_shape. + Positional encoding tensor of shape + :math:`(B, H \times W, \text{ref}^2)`. """ + if self.structured_shape is None: + raise ValueError( + "Cannot generate positional encoding grid: structured_shape is None. " + "This method requires structured_shape to be set." 
+ ) size_x, size_y = self.structured_shape + + # Create spatial grid for the structured shape gridx = torch.tensor(np.linspace(0, 1, size_x), dtype=torch.float) gridx = gridx.reshape(1, size_x, 1, 1).repeat([batchsize, 1, size_y, 1]) gridy = torch.tensor(np.linspace(0, 1, size_y), dtype=torch.float) gridy = gridy.reshape(1, 1, size_y, 1).repeat([batchsize, size_x, 1, 1]) - grid = torch.cat((gridx, gridy), dim=-1) # B H W 2 + grid = torch.cat((gridx, gridy), dim=-1) # (B, H, W, 2) + # Create reference grid gridx = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float) gridx = gridx.reshape(1, ref, 1, 1).repeat([batchsize, 1, ref, 1]) gridy = torch.tensor(np.linspace(0, 1, ref), dtype=torch.float) gridy = gridy.reshape(1, 1, ref, 1).repeat([batchsize, ref, 1, 1]) - grid_ref = torch.cat((gridx, gridy), dim=-1) # B H W 8 8 2 + grid_ref = torch.cat((gridx, gridy), dim=-1) # (B, ref, ref, 2) + # Compute distance-based positional encoding pos = ( torch.sqrt( torch.sum( @@ -482,34 +690,49 @@ def get_grid(self, ref: int, batchsize: int = 1) -> torch.Tensor: def forward( self, - fx: torch.Tensor | None, - embedding: torch.Tensor | None = None, - time: torch.Tensor | None = None, - ) -> torch.Tensor: - """ - Forward pass of the transolver model. - - Args: - fx (torch.Tensor | None): Functional input tensor. For structured data, - shape should be [B, N, C] or [B, *structure, C]. For unstructured data, - shape should be [B, N, C]. Can be None if not used. - embedding (torch.Tensor | None, optional): Embedding tensor. For structured - data, shape should be [B, N, C] or [B, *structure, C]. For unstructured - data, shape should be [B, N, C]. Defaults to None. - time (torch.Tensor | None, optional): Optional time tensor. Shape and usage - depend on the model configuration. Defaults to None. + fx: Float[torch.Tensor, "batch *spatial functional_dim"], + embedding: Float[torch.Tensor, "batch *spatial embedding_dim"] | None = None, + time: Float[torch.Tensor, " batch"] | None = None, + ) -> Float[torch.Tensor, "batch *spatial out_dim"]: + r""" + Forward pass of the Transolver model. - Returns: - torch.Tensor: Output tensor with the same shape as the input. + Parameters + ---------- + fx : torch.Tensor + Functional input tensor. Shape :math:`(B, N, D_{func})` for + flattened data or :math:`(B, H, W, D_{func})` for structured 2D. + embedding : torch.Tensor | None, optional + Embedding tensor. Required if ``unified_pos=False``. + time : torch.Tensor | None, optional + Time tensor of shape :math:`(B,)` for time-dependent models. + + Returns + ------- + torch.Tensor + Output tensor with same spatial shape as input.
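# --- Editor's sketch (not part of the PR) ---------------------------------
# The unified position features built by get_grid() above are simply the
# Euclidean distances from every point of the structured (H, W) grid to the
# nodes of a ref x ref reference grid, i.e. ref*ref features per point.
# Standalone illustration without the batch dimension; sizes are arbitrary.
import torch

H, W, ref = 4, 5, 3
gy, gx = torch.meshgrid(
    torch.linspace(0, 1, H), torch.linspace(0, 1, W), indexing="ij"
)
grid = torch.stack((gy, gx), dim=-1).reshape(H * W, 1, 2)
ry, rx = torch.meshgrid(
    torch.linspace(0, 1, ref), torch.linspace(0, 1, ref), indexing="ij"
)
grid_ref = torch.stack((ry, rx), dim=-1).reshape(1, ref * ref, 2)
pos = torch.sqrt(((grid - grid_ref) ** 2).sum(dim=-1))
assert pos.shape == (H * W, ref * ref)  # matches (B, H*W, ref*ref) sans batch
# ---------------------------------------------------------------------------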
""" + # Input validation (skip during torch.compile for performance) + if not torch.compiler.is_compiling(): + if fx.ndim < 2: + raise ValueError( + f"Expected input tensor with at least 2 dimensions, " + f"got {fx.ndim}D tensor with shape {tuple(fx.shape)}" + ) + if not self.unified_pos and embedding is None: + raise ValueError("Embedding is required when unified_pos=False") + + # Track whether we need to unflatten output + unflatten_output = False + n_tokens = 0 + if self.unified_pos: - # Extend the embedding to the batch size: - embedding = self.embedding.repeat(fx.shape[0], 1, 1) + # Extend unified position embedding to batch size + emb_buffer: torch.Tensor = self.embedding # type: ignore[assignment] + embedding = emb_buffer.repeat(fx.shape[0], 1, 1) - # Reshape automatically, if necessary: + # Reshape structured data to flattened format if necessary if self.structured_shape is not None: - unflatten_output = False if len(fx.shape) != 3: unflatten_output = True fx = fx.reshape(fx.shape[0], -1, fx.shape[-1]) @@ -521,23 +744,28 @@ def forward( if embedding is None: raise ValueError("Embedding is required for unstructured data") - # Combine the embedding and functional input: + # Store n_tokens for time embedding + if embedding is not None: + n_tokens = embedding.shape[1] + + # Concatenate embedding with functional input if embedding is not None: fx = torch.cat((embedding, fx), -1) - # Apply preprocessing + # Project to hidden dimension fx = self.preprocess(fx) + # Add time embedding if provided if time is not None: - time_emb = timestep_embedding(time, self.n_hidden).repeat( - 1, embedding.shape[1], 1 - ) + time_emb = timestep_embedding(time, self.n_hidden).repeat(1, n_tokens, 1) time_emb = self.time_fc(time_emb) fx = fx + time_emb - for i, block in enumerate(self.blocks): + # Apply transformer blocks + for block in self.blocks: fx = block(fx) + # Reshape back to structured format if needed if self.structured_shape is not None: if unflatten_output: fx = fx.reshape(fx.shape[0], *self.structured_shape, -1) diff --git a/pyproject.toml b/pyproject.toml index 260758bb02..f042459c33 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,7 +65,7 @@ managed = true torch-sparse = ["torch"] torch-cluster = ["torch"] torch-scatter = ["torch"] -earth2grid = ["setuptools", "torch"] +earth2grid = ["setuptools", "torch"] [[tool.uv.index]] name = "nvidia" @@ -190,7 +190,7 @@ lint.fixable = ["I"] # Never enforce `E501` (line length violations), # and `S311` (random number generators) -lint.ignore = ["E501", "S311"] +lint.ignore = ["E501", "S311", "F722"] # Exclude the docs and experimental folders (this applies to both lint and format) exclude = ["docs", "physicsnemo/experimental"] diff --git a/test/models/transolver/test_transolver.py b/test/models/transolver/test_transolver.py index be52d93b56..37c85bd2cb 100644 --- a/test/models/transolver/test_transolver.py +++ b/test/models/transolver/test_transolver.py @@ -19,6 +19,7 @@ import pytest import torch +from physicsnemo.core.module import Module from physicsnemo.models.transolver import Transolver from test.common import ( check_ort_version, @@ -34,6 +35,65 @@ from test.conftest import requires_module +@pytest.mark.parametrize( + "config", + ["default_structured", "custom_irregular"], + ids=["with_defaults_structured", "with_custom_irregular"], +) +def test_transolver_constructor(config): + """Test Transolver model constructor and attributes per MOD-008a.""" + if config == "default_structured": + # Test with structured 2D data and default parameters + 
model = Transolver( + functional_dim=3, + out_dim=1, + structured_shape=(64, 64), + unified_pos=True, + use_te=False, + ) + # Verify default attribute values + assert model.n_hidden == 256, "Default n_hidden should be 256" + assert model.time_input is False, "Default time_input should be False" + assert model.unified_pos is True + assert model.structured_shape == (64, 64) + assert model.embedding_dim == 64 # ref * ref = 8 * 8 = 64 + assert len(model.blocks) == 4, "Default n_layers should be 4" + else: + # Test with irregular mesh data and custom parameters + model = Transolver( + functional_dim=2, + out_dim=4, + embedding_dim=3, + n_layers=8, + n_hidden=64, + dropout=0.1, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=16, + unified_pos=False, + structured_shape=None, + use_te=False, + time_input=True, + plus=True, + ) + # Verify custom attribute values + assert model.n_hidden == 64 + assert model.time_input is True + assert model.unified_pos is False + assert model.structured_shape is None + assert model.embedding_dim == 3 + assert len(model.blocks) == 8 + + # Common assertions for all configurations + assert isinstance(model, Module), ( + "Transolver should inherit from physicsnemo.Module" + ) + assert hasattr(model, "preprocess"), "Model should have preprocess MLP" + assert hasattr(model, "blocks"), "Model should have transformer blocks" + assert hasattr(model, "meta"), "Model should have metadata" + + def test_transolver2d_forward(device): """Test Transolver2D forward pass""" torch.manual_seed(0)
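# --- Editor's sketch (not part of the PR's test file) ----------------------
# Because forward() flattens and un-flattens structured inputs (see the model
# diff above), grid-shaped and pre-flattened inputs are both accepted and the
# output mirrors the input layout. A minimal smoke test under that assumption:
import torch
from physicsnemo.models.transolver import Transolver

model = Transolver(
    functional_dim=3,
    out_dim=1,
    structured_shape=(32, 32),
    unified_pos=True,
    n_hidden=128,
    n_head=4,
    use_te=False,
)
x_grid = torch.randn(2, 32, 32, 3)
x_flat = x_grid.reshape(2, 32 * 32, 3)
assert model(x_grid).shape == (2, 32, 32, 1)
assert model(x_flat).shape == (2, 32 * 32, 1)
# ---------------------------------------------------------------------------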