diff --git a/physicsnemo/models/rnn/layers.py b/physicsnemo/models/rnn/layers.py deleted file mode 100644 index 44cbdac27c..0000000000 --- a/physicsnemo/models/rnn/layers.py +++ /dev/null @@ -1,505 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from torch import Tensor - - -def _get_same_padding(x: int, k: int, s: int) -> int: - """Function to compute "same" padding. Inspired from: - https://github.com/huggingface/pytorch-image-models/blob/0.5.x/timm/models/layers/padding.py - """ - return max(s * math.ceil(x / s) - s - x + k, 0) - - -class _ConvLayer(nn.Module): - """Generalized Convolution Block - - Parameters - ---------- - in_channels : int - Number of input channels - out_channels : int - Number of output channels - dimension : int - Dimensionality of the input, 1, 2, 3, or 4 - kernel_size : int - Kernel size for the convolution - stride : int - Stride for the convolution, by default 1 - activation_fn : nn.Module, optional - Activation function to use, by default nn.Identity() - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - dimension: int, # TODO check if there are ways to infer this - kernel_size: int, - stride: int = 1, - activation_fn: nn.Module = nn.Identity(), - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.dimension = dimension - self.activation_fn = activation_fn - - if self.dimension == 1: - self.conv = nn.Conv1d( - self.in_channels, - self.out_channels, - self.kernel_size, - self.stride, - bias=True, - ) - elif self.dimension == 2: - self.conv = nn.Conv2d( - self.in_channels, - self.out_channels, - self.kernel_size, - self.stride, - bias=True, - ) - elif self.dimension == 3: - self.conv = nn.Conv3d( - self.in_channels, - self.out_channels, - self.kernel_size, - self.stride, - bias=True, - ) - else: - raise ValueError("Only 1D, 2D and 3D dimensions are supported") - - self.reset_parameters() - - def exec_activation_fn(self, x: Tensor) -> Tensor: - """Executes activation function on the input""" - return self.activation_fn(x) - - def reset_parameters(self) -> None: - """Initialization for network parameters""" - nn.init.constant_(self.conv.bias, 0) - nn.init.xavier_uniform_(self.conv.weight) - - def forward(self, x: Tensor) -> Tensor: - input_length = len(x.size()) - 2 # exclude channel and batch dims - if input_length != self.dimension: - raise ValueError("Input dimension not compatible") - - if input_length == 1: - iw = x.size()[-1:][0] - pad_w = _get_same_padding(iw, self.kernel_size, self.stride) - x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2], mode="constant", value=0.0) - elif input_length == 2: - ih, iw = x.size()[-2:] - pad_h, pad_w = ( - _get_same_padding(ih, 
self.kernel_size, self.stride), - _get_same_padding(iw, self.kernel_size, self.stride), - ) - x = F.pad( - x, - [pad_h // 2, pad_h - pad_h // 2, pad_w // 2, pad_w - pad_w // 2], - mode="constant", - value=0.0, - ) - else: - _id, ih, iw = x.size()[-3:] - pad_d, pad_h, pad_w = ( - _get_same_padding(_id, self.kernel_size, self.stride), - _get_same_padding(ih, self.kernel_size, self.stride), - _get_same_padding(iw, self.kernel_size, self.stride), - ) - x = F.pad( - x, - [ - pad_d // 2, - pad_d - pad_d // 2, - pad_h // 2, - pad_h - pad_h // 2, - pad_w // 2, - pad_w - pad_w // 2, - ], - mode="constant", - value=0.0, - ) - - x = self.conv(x) - - if self.activation_fn is not nn.Identity(): - x = self.exec_activation_fn(x) - - return x - - -class _TransposeConvLayer(nn.Module): - """Generalized Transposed Convolution Block - - Parameters - ---------- - in_channels : int - Number of input channels - out_channels : int - Number of output channels - dimension : int - Dimensionality of the input, 1, 2, 3, or 4 - kernel_size : int - Kernel size for the convolution - stride : int - Stride for the convolution, by default 1 - activation_fn : nn.Module, optional - Activation function to use, by default nn.Identity() - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - dimension: int, - kernel_size: int, - stride: int = 1, - activation_fn=nn.Identity(), - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.kernel_size = kernel_size - self.stride = stride - self.dimension = dimension - self.activation_fn = activation_fn - - if dimension == 1: - self.trans_conv = nn.ConvTranspose1d( - self.in_channels, - self.out_channels, - self.kernel_size, - self.stride, - bias=True, - ) - elif dimension == 2: - self.trans_conv = nn.ConvTranspose2d( - self.in_channels, - self.out_channels, - self.kernel_size, - self.stride, - bias=True, - ) - elif dimension == 3: - self.trans_conv = nn.ConvTranspose3d( - self.in_channels, - self.out_channels, - self.kernel_size, - self.stride, - bias=True, - ) - else: - raise ValueError("Only 1D, 2D and 3D dimensions are supported") - - self.reset_parameters() - - def exec_activation_fn(self, x: Tensor) -> Tensor: - """Executes activation function on the input""" - return self.activation_fn(x) - - def reset_parameters(self) -> None: - """Initialization for network parameters""" - nn.init.constant_(self.trans_conv.bias, 0) - nn.init.xavier_uniform_(self.trans_conv.weight) - - def forward(self, x: Tensor) -> Tensor: - orig_x = x - input_length = len(orig_x.size()) - 2 # exclude channel and batch dims - if input_length != self.dimension: - raise ValueError("Input dimension not compatible") - - x = self.trans_conv(x) - - if input_length == 1: - iw = orig_x.size()[-1:][0] - pad_w = _get_same_padding(iw, self.kernel_size, self.stride) - x = x[ - :, - :, - pad_w // 2 : x.size(-1) - (pad_w - pad_w // 2), - ] - elif input_length == 2: - ih, iw = orig_x.size()[-2:] - pad_h, pad_w = ( - _get_same_padding( - ih, - self.kernel_size, - self.stride, - ), - _get_same_padding(iw, self.kernel_size, self.stride), - ) - x = x[ - :, - :, - pad_h // 2 : x.size(-2) - (pad_h - pad_h // 2), - pad_w // 2 : x.size(-1) - (pad_w - pad_w // 2), - ] - else: - _id, ih, iw = orig_x.size()[-3:] - pad_d, pad_h, pad_w = ( - _get_same_padding(_id, self.kernel_size, self.stride), - _get_same_padding(ih, self.kernel_size, self.stride), - _get_same_padding(iw, self.kernel_size, self.stride), - ) - x = x[ - :, - :, - pad_d // 2 : x.size(-3) - (pad_d - 
pad_d // 2), - pad_h // 2 : x.size(-2) - (pad_h - pad_h // 2), - pad_w // 2 : x.size(-1) - (pad_w - pad_w // 2), - ] - - if self.activation_fn is not nn.Identity(): - x = self.exec_activation_fn(x) - - return x - - -class _ConvGRULayer(nn.Module): - """Convolutional GRU layer - - Parameters - ---------- - in_features : int - Input features/channels - hidden_size : int - Hidden layer features/channels - dimension : int - Spatial dimension of the input - activation_fn : nn.Module, optional - Activation Function to use, by default nn.ReLU() - """ - - def __init__( - self, - in_features: int, - hidden_size: int, - dimension: int, - activation_fn: nn.Module = nn.ReLU(), - ) -> None: - super().__init__() - self.in_features = in_features - self.hidden_size = hidden_size - self.activation_fn = activation_fn - self.conv_1 = _ConvLayer( - in_channels=in_features + hidden_size, - out_channels=2 * hidden_size, - kernel_size=3, - stride=1, - dimension=dimension, - ) - self.conv_2 = _ConvLayer( - in_channels=in_features + hidden_size, - out_channels=hidden_size, - kernel_size=3, - stride=1, - dimension=dimension, - ) - - def exec_activation_fn(self, x: Tensor) -> Tensor: - """Executes activation function on the input""" - return self.activation_fn(x) - - def forward(self, x: Tensor, hidden: Tensor) -> Tensor: - concat = torch.cat((x, hidden), dim=1) - conv_concat = self.conv_1(concat) - conv_r, conv_z = torch.split(conv_concat, self.hidden_size, 1) - - reset_gate = torch.special.expit(conv_r) - update_gate = torch.special.expit(conv_z) - concat = torch.cat((x, torch.mul(hidden, reset_gate)), dim=1) - n = self.exec_activation_fn(self.conv_2(concat)) - h_next = torch.mul((1 - update_gate), n) + torch.mul(update_gate, hidden) - - return h_next - - -class _ConvResidualBlock(nn.Module): - """Convolutional ResNet Block - - Parameters - ---------- - in_channels : int - Number of input channels - out_channels : int - Number of output channels - dimension : int - Dimensionality of the input - stride : int - Stride of the convolutions, by default 1 - gated : bool, optional - Residual Gate, by default False - layer_normalization : bool, optional - Layer Normalization, by default False - begin_activation_fn : bool, optional - Whether to use activation function in the beginning, by default True - activation_fn : nn.Module, optional - Activation function to use, by default nn.ReLU() - - Raises - ------ - ValueError - Stride not supported - """ - - def __init__( - self, - in_channels: int, - out_channels: int, - dimension: int, - stride: int = 1, - gated: bool = False, - layer_normalization: bool = False, - begin_activation_fn: bool = True, - activation_fn: nn.Module = nn.ReLU(), - ) -> None: - super().__init__() - self.in_channels = in_channels - self.out_channels = out_channels - self.stride = stride - self.dimension = dimension - self.gated = gated - self.layer_normalization = layer_normalization - self.begin_activation_fn = begin_activation_fn - self.activation_fn = activation_fn - - if self.stride == 1: - self.conv_1 = _ConvLayer( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=3, - stride=self.stride, - dimension=self.dimension, - ) - elif self.stride == 2: - self.conv_1 = _ConvLayer( - in_channels=self.in_channels, - out_channels=self.out_channels, - kernel_size=4, - stride=self.stride, - dimension=self.dimension, - ) - else: - raise ValueError("stride > 2 is not supported") - - if not self.gated: - self.conv_2 = _ConvLayer( - in_channels=self.out_channels, - 
out_channels=self.out_channels, - kernel_size=3, - stride=1, - dimension=self.dimension, - ) - else: - self.conv_2 = _ConvLayer( - in_channels=self.out_channels, - out_channels=2 * self.out_channels, - kernel_size=3, - stride=1, - dimension=self.dimension, - ) - - def exec_activation_fn(self, x: Tensor) -> Tensor: - """Executes activation function on the input""" - return self.activation_fn(x) - - def forward(self, x: Tensor) -> Tensor: - orig_x = x - - if self.begin_activation_fn: - if self.layer_normalization: - layer_norm = nn.LayerNorm(x.size()[1:], elementwise_affine=False) - x = layer_norm(x) - x = self.exec_activation_fn(x) - - # first convolutional layer - x = self.conv_1(x) - - # add layer normalization - if self.layer_normalization: - layer_norm = nn.LayerNorm(x.size()[1:], elementwise_affine=False) - x = layer_norm(x) - - # second activation - x = self.exec_activation_fn(x) - # second convolutional layer - x = self.conv_2(x) - if self.gated: - x_1, x_2 = torch.split(x, x.size(1) // 2, 1) - x = x_1 * torch.special.expit(x_2) - - # possibly reshape skip connection - if orig_x.size(-1) > x.size(-1): # Check if widths are same) - if len(orig_x.size()) - 2 == 1: - iw = orig_x.size()[-1:][0] - pad_w = _get_same_padding(iw, 2, 2) - pool = torch.nn.AvgPool1d( - 2, 2, padding=pad_w // 2, count_include_pad=False - ) - elif len(orig_x.size()) - 2 == 2: - ih, iw = orig_x.size()[-2:] - pad_h, pad_w = ( - _get_same_padding( - ih, - 2, - 2, - ), - _get_same_padding(iw, 2, 2), - ) - pool = torch.nn.AvgPool2d( - 2, 2, padding=(pad_h // 2, pad_w // 2), count_include_pad=False - ) - elif len(orig_x.size()) - 2 == 3: - _id, ih, iw = orig_x.size()[-3:] - pad_d, pad_h, pad_w = ( - _get_same_padding(_id, 2, 2), - _get_same_padding(ih, 2, 2), - _get_same_padding(iw, 2, 2), - ) - pool = torch.nn.AvgPool3d( - 2, - 2, - padding=(pad_d // 2, pad_h // 2, pad_w // 2), - count_include_pad=False, - ) - else: - raise ValueError("Only 1D, 2D and 3D dimensions are supported") - orig_x = pool(orig_x) - - # possibly change the channels for skip connection - in_channels = int(orig_x.size(1)) - if self.out_channels > in_channels: - orig_x = F.pad( - orig_x, - (len(orig_x.size()) - 2) * (0, 0) - + (self.out_channels - self.in_channels, 0), - ) - elif self.out_channels < in_channels: - pass - - return orig_x + x diff --git a/physicsnemo/models/rnn/rnn_one2many.py b/physicsnemo/models/rnn/rnn_one2many.py index e11f092c0c..b20b2eb6ab 100644 --- a/physicsnemo/models/rnn/rnn_one2many.py +++ b/physicsnemo/models/rnn/rnn_one2many.py @@ -18,18 +18,19 @@ import torch import torch.nn as nn +from jaxtyping import Float from torch import Tensor import physicsnemo # noqa: F401 for docs from physicsnemo.core.meta import ModelMetaData from physicsnemo.core.module import Module -from physicsnemo.models.rnn.layers import ( - _ConvGRULayer, - _ConvLayer, - _ConvResidualBlock, - _TransposeConvLayer, -) from physicsnemo.nn import get_activation +from physicsnemo.nn.conv_layers import ( + ConvGRULayer, + ConvLayer, + ConvResidualBlock, + TransposeConvLayer, +) @dataclass @@ -48,38 +49,56 @@ class MetaData(ModelMetaData): class One2ManyRNN(Module): - """A RNN model with encoder/decoder for 2d/3d problems that provides predictions + r""" + A RNN model with encoder/decoder for 2D/3D problems that provides predictions based on single initial condition. Parameters ---------- input_channels : int - Number of channels in the input - dimension : int, optional - Spatial dimension of the input. 
Only 2d and 3d are supported, by default 2 - nr_latent_channels : int, optional - Channels for encoding/decoding, by default 512 - nr_residual_blocks : int, optional - Number of residual blocks, by default 2 - activation_fn : str, optional - Activation function to use, by default "relu" - nr_downsamples : int, optional - Number of downsamples, by default 2 - nr_tsteps : int, optional - Time steps to predict, by default 32 - - Example + Number of channels in the input. + dimension : int, optional, default=2 + Spatial dimension of the input. Only 2D and 3D are supported. + nr_latent_channels : int, optional, default=512 + Channels for encoding/decoding. + nr_residual_blocks : int, optional, default=2 + Number of residual blocks. + activation_fn : str, optional, default="relu" + Activation function to use. + nr_downsamples : int, optional, default=2 + Number of downsamples. + nr_tsteps : int, optional, default=32 + Time steps to predict. + + Forward ------- + x : torch.Tensor + Input tensor of shape :math:`(N, C, 1, H, W)` for 2D or + :math:`(N, C, 1, D, H, W)` for 3D, where :math:`N` is the batch size, + :math:`C` is the number of channels, ``1`` is the number of input + timesteps, and :math:`D, H, W` are spatial dimensions. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(N, C, T, H, W)` for 2D or + :math:`(N, C, T, D, H, W)` for 3D, where :math:`T` is the number of + timesteps being predicted. + + Examples + -------- + >>> import torch + >>> import physicsnemo >>> model = physicsnemo.models.rnn.One2ManyRNN( - ... input_channels=6, - ... dimension=2, - ... nr_latent_channels=32, - ... activation_fn="relu", - ... nr_downsamples=2, - ... nr_tsteps=16, + ... input_channels=6, + ... dimension=2, + ... nr_latent_channels=32, + ... activation_fn="relu", + ... nr_downsamples=2, + ... nr_tsteps=16, ... ) - >>> input = invar = torch.randn(4, 6, 1, 16, 16) # [N, C, T, H, W] - >>> output = model(input) + >>> input_tensor = torch.randn(4, 6, 1, 16, 16) # [N, C, T, H, W] + >>> output = model(input_tensor) >>> output.size() torch.Size([4, 6, 16, 16, 16]) """ @@ -118,7 +137,7 @@ def __init__( channels_out = channels_out * 2 stride = 2 self.encoder_layers.append( - _ConvResidualBlock( + ConvResidualBlock( in_channels=channels_in, out_channels=channels_out, stride=stride, @@ -130,7 +149,7 @@ def __init__( ) ) - self.rnn_layer = _ConvGRULayer( + self.rnn_layer = ConvGRULayer( in_features=channels_out, hidden_size=channels_out, dimension=dimension ) @@ -141,7 +160,7 @@ def __init__( channels_in = channels_out channels_out = channels_out // 2 self.upsampling_layers.append( - _TransposeConvLayer( + TransposeConvLayer( in_channels=channels_in, out_channels=channels_out, kernel_size=4, @@ -151,7 +170,7 @@ def __init__( ) for j in range(nr_residual_blocks): self.upsampling_layers.append( - _ConvResidualBlock( + ConvResidualBlock( in_channels=channels_out, out_channels=channels_out, stride=1, @@ -163,7 +182,7 @@ def __init__( ) ) self.conv_layers.append( - _ConvLayer( + ConvLayer( in_channels=channels_in, out_channels=nr_latent_channels, kernel_size=1, @@ -187,54 +206,86 @@ def __init__( padding="valid", ) - def forward(self, x: Tensor) -> Tensor: - """Forward pass + def forward( + self, + x: Float[Tensor, "batch channels 1 ..."], # noqa: F722 + ) -> Float[Tensor, "batch channels tsteps ..."]: # noqa: F722 + r""" + Forward pass. 
Parameters ---------- - x : Tensor - Expects a tensor of size [N, C, 1, H, W] for 2D or [N, C, 1, D, H, W] for 3D - Where, N is the batch size, C is the number of channels, 1 is the number of - input timesteps and D, H, W are spatial dimensions. + x : torch.Tensor + Expects a tensor of shape :math:`(N, C, 1, H, W)` for 2D or + :math:`(N, C, 1, D, H, W)` for 3D. Where :math:`N` is the batch size, + :math:`C` is the number of channels, ``1`` is the number of input + timesteps, and :math:`D, H, W` are spatial dimensions. + Returns ------- - Tensor - Size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D. - Where, T is the number of timesteps being predicted. + torch.Tensor + Size :math:`(N, C, T, H, W)` for 2D or :math:`(N, C, T, D, H, W)` for 3D, + where :math:`T` is the number of timesteps being predicted. """ - # Encoding step + ### Input validation + if not torch.compiler.is_compiling(): + # Check number of dimensions + expected_ndim = 5 if self.encoder_layers[0].dimension == 2 else 6 + if x.ndim != expected_ndim: + raise ValueError( + f"Expected {expected_ndim}D input tensor, " + f"got {x.ndim}D tensor with shape {tuple(x.shape)}" + ) + + # Check time dimension is 1 + if x.shape[2] != 1: + raise ValueError( + f"Expected input with 1 timestep (dimension 2), " + f"got {x.shape[2]} timesteps in tensor with shape {tuple(x.shape)}" + ) + + # Encoding step - encode the single input timestep encoded_inputs = [] for t in range(1): - x_in = x[:, :, t, ...] + x_in = x[:, :, t, ...] # (B, C, *spatial) + # Pass through encoder layers for layer in self.encoder_layers: x_in = layer(x_in) encoded_inputs.append(x_in) - # RNN step + # RNN step - autoregressively generate future timesteps rnn_output = [] for t in range(self.nr_tsteps): if t == 0: - h = torch.zeros(list(x_in.size())).to(x.device) + # Initialize hidden state to zeros + h = torch.zeros(list(x_in.size())).to( + x.device + ) # (B, C_latent, *spatial) x_in_rnn = encoded_inputs[0] - h = self.rnn_layer(x_in_rnn, h) + # Update hidden state + h = self.rnn_layer(x_in_rnn, h) # (B, C_latent, *spatial) x_in_rnn = h rnn_output.append(h) + # Decoding step - decode each hidden state to output decoded_output = [] for t in range(self.nr_tsteps): - x_out = rnn_output[t] - # Decoding step + x_out = rnn_output[t] # (B, C_latent, *spatial) + + # Multi-resolution decoding with skip connections latent_context_grid = [] for conv_layer, decoder in zip(self.conv_layers, self.decoder_layers): latent_context_grid.append(conv_layer(x_out)) upsampling_layers = decoder + # Progressively upsample for upsampling_layer in upsampling_layers: x_out = upsampling_layer(x_out) - # Add a convolution here to make the channel dimensions same as output - # Only last latent context grid is used, but mult-resolution is available - out = self.final_conv(latent_context_grid[-1]) + # Final convolution to match output channels + # Only last latent context grid is used, but multi-resolution is available + out = self.final_conv(latent_context_grid[-1]) # (B, C, *spatial) decoded_output.append(out) - decoded_output = torch.stack(decoded_output, dim=2) + # Stack outputs along time dimension + decoded_output = torch.stack(decoded_output, dim=2) # (B, C, T, *spatial) return decoded_output diff --git a/physicsnemo/models/rnn/rnn_seq2seq.py b/physicsnemo/models/rnn/rnn_seq2seq.py index 590e09b7f6..2eb4087192 100644 --- a/physicsnemo/models/rnn/rnn_seq2seq.py +++ b/physicsnemo/models/rnn/rnn_seq2seq.py @@ -18,18 +18,19 @@ import torch import torch.nn as nn +from jaxtyping import Float from 
torch import Tensor import physicsnemo # noqa: F401 for docs from physicsnemo.core.meta import ModelMetaData from physicsnemo.core.module import Module -from physicsnemo.models.rnn.layers import ( - _ConvGRULayer, - _ConvLayer, - _ConvResidualBlock, - _TransposeConvLayer, -) from physicsnemo.nn import get_activation +from physicsnemo.nn.conv_layers import ( + ConvGRULayer, + ConvLayer, + ConvResidualBlock, + TransposeConvLayer, +) @dataclass @@ -48,38 +49,57 @@ class MetaData(ModelMetaData): class Seq2SeqRNN(Module): - """A RNN model with encoder/decoder for 2d/3d problems. Given input 0 to t-1, - predicts signal t to t + nr_tsteps + r""" + A RNN model with encoder/decoder for 2D/3D problems. Given input from time 0 to + t-1, predicts signal from time t to t + nr_tsteps. Parameters ---------- input_channels : int - Number of channels in the input - dimension : int, optional - Spatial dimension of the input. Only 2d and 3d are supported, by default 2 - nr_latent_channels : int, optional - Channels for encoding/decoding, by default 512 - nr_residual_blocks : int, optional - Number of residual blocks, by default 2 - activation_fn : str, optional - Activation function to use, by default "relu" - nr_downsamples : int, optional - Number of downsamples, by default 2 - nr_tsteps : int, optional - Time steps to predict, by default 32 - - Example + Number of channels in the input. + dimension : int, optional, default=2 + Spatial dimension of the input. Only 2D and 3D are supported. + nr_latent_channels : int, optional, default=512 + Channels for encoding/decoding. + nr_residual_blocks : int, optional, default=2 + Number of residual blocks. + activation_fn : str, optional, default="relu" + Activation function to use. + nr_downsamples : int, optional, default=2 + Number of downsamples. + nr_tsteps : int, optional, default=32 + Time steps to predict. + + Forward ------- + x : torch.Tensor + Input tensor of shape :math:`(N, C, T, H, W)` for 2D or + :math:`(N, C, T, D, H, W)` for 3D, where :math:`N` is the batch size, + :math:`C` is the number of channels, :math:`T` is the number of input + timesteps, and :math:`D, H, W` are spatial dimensions. Currently, this + requires input time steps to be same as predicted time steps. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(N, C, T, H, W)` for 2D or + :math:`(N, C, T, D, H, W)` for 3D, where :math:`T` is the number of + timesteps being predicted. + + Examples + -------- + >>> import torch + >>> import physicsnemo >>> model = physicsnemo.models.rnn.Seq2SeqRNN( - ... input_channels=6, - ... dimension=2, - ... nr_latent_channels=32, - ... activation_fn="relu", - ... nr_downsamples=2, - ... nr_tsteps=16, + ... input_channels=6, + ... dimension=2, + ... nr_latent_channels=32, + ... activation_fn="relu", + ... nr_downsamples=2, + ... nr_tsteps=16, ... 
) - >>> input = invar = torch.randn(4, 6, 16, 16, 16) # [N, C, T, H, W] - >>> output = model(input) + >>> input_tensor = torch.randn(4, 6, 16, 16, 16) # [N, C, T, H, W] + >>> output = model(input_tensor) >>> output.size() torch.Size([4, 6, 16, 16, 16]) """ @@ -118,7 +138,7 @@ def __init__( channels_out = channels_out * 2 stride = 2 self.encoder_layers.append( - _ConvResidualBlock( + ConvResidualBlock( in_channels=channels_in, out_channels=channels_out, stride=stride, @@ -130,7 +150,7 @@ def __init__( ) ) - self.rnn_layer = _ConvGRULayer( + self.rnn_layer = ConvGRULayer( in_features=channels_out, hidden_size=channels_out, dimension=dimension ) @@ -141,7 +161,7 @@ def __init__( channels_in = channels_out channels_out = channels_out // 2 self.upsampling_layers.append( - _TransposeConvLayer( + TransposeConvLayer( in_channels=channels_in, out_channels=channels_out, kernel_size=4, @@ -151,7 +171,7 @@ def __init__( ) for j in range(nr_residual_blocks): self.upsampling_layers.append( - _ConvResidualBlock( + ConvResidualBlock( in_channels=channels_out, out_channels=channels_out, stride=1, @@ -163,7 +183,7 @@ def __init__( ) ) self.conv_layers.append( - _ConvLayer( + ConvLayer( in_channels=channels_in, out_channels=nr_latent_channels, kernel_size=1, @@ -187,62 +207,95 @@ def __init__( padding="valid", ) - def forward(self, x: Tensor) -> Tensor: - """Forward pass + def forward( + self, + x: Float[Tensor, "batch channels tsteps ..."], # noqa: F722 + ) -> Float[Tensor, "batch channels tsteps ..."]: # noqa: F722 + r""" + Forward pass. Parameters ---------- - x : Tensor - Expects a tensor of size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D - Where, N is the batch size, C is the number of channels, T is the number of - input timesteps and D, H, W are spatial dimensions. Currently, this + x : torch.Tensor + Expects a tensor of shape :math:`(N, C, T, H, W)` for 2D or + :math:`(N, C, T, D, H, W)` for 3D. Where :math:`N` is the batch size, + :math:`C` is the number of channels, :math:`T` is the number of input + timesteps, and :math:`D, H, W` are spatial dimensions. Currently, this requires input time steps to be same as predicted time steps. + Returns ------- - Tensor - Size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D. - Where, T is the number of timesteps being predicted. + torch.Tensor + Size :math:`(N, C, T, H, W)` for 2D or :math:`(N, C, T, D, H, W)` for 3D, + where :math:`T` is the number of timesteps being predicted. """ - # Encoding step + ### Input validation + if not torch.compiler.is_compiling(): + # Check number of dimensions + expected_ndim = 5 if self.encoder_layers[0].dimension == 2 else 6 + if x.ndim != expected_ndim: + raise ValueError( + f"Expected {expected_ndim}D input tensor, " + f"got {x.ndim}D tensor with shape {tuple(x.shape)}" + ) + + # Check time dimension matches nr_tsteps + if x.shape[2] != self.nr_tsteps: + raise ValueError( + f"Expected input with {self.nr_tsteps} timesteps (dimension 2), " + f"got {x.shape[2]} timesteps in tensor with shape {tuple(x.shape)}" + ) + + # Encoding step - encode all input timesteps encoded_inputs = [] for t in range(self.nr_tsteps): - x_in = x[:, :, t, ...] + x_in = x[:, :, t, ...] 
# (B, C, *spatial) + # Pass through encoder layers for layer in self.encoder_layers: x_in = layer(x_in) encoded_inputs.append(x_in) - # RNN step - # encode + # RNN step - encode all inputs into final hidden state for t in range(x.size(2)): # time dimension of the input signal if t == 0: - h = torch.zeros(list(x_in.size())).to(x.device) + # Initialize hidden state to zeros + h = torch.zeros(list(x_in.size())).to( + x.device + ) # (B, C_latent, *spatial) x_in_rnn = encoded_inputs[t] - h = self.rnn_layer(x_in_rnn, h) + # Update hidden state with current timestep + h = self.rnn_layer(x_in_rnn, h) # (B, C_latent, *spatial) - # decode + # RNN step - decode to generate future predictions rnn_output = [] for t in range(self.nr_tsteps): if t == 0: + # Start decoding from last encoded input x_in_rnn = encoded_inputs[-1] - h = self.rnn_layer(x_in_rnn, h) + # Autoregressively generate next hidden state + h = self.rnn_layer(x_in_rnn, h) # (B, C_latent, *spatial) x_in_rnn = h rnn_output.append(h) + # Decoding step - decode each hidden state to output decoded_output = [] for t in range(self.nr_tsteps): - x_out = rnn_output[t] - # Decoding step + x_out = rnn_output[t] # (B, C_latent, *spatial) + + # Multi-resolution decoding with skip connections latent_context_grid = [] for conv_layer, decoder in zip(self.conv_layers, self.decoder_layers): latent_context_grid.append(conv_layer(x_out)) upsampling_layers = decoder + # Progressively upsample for upsampling_layer in upsampling_layers: x_out = upsampling_layer(x_out) - # Add a convolution here to make the channel dimensions same as output - # Only last latent context grid is used, but mult-resolution is available - out = self.final_conv(latent_context_grid[-1]) + # Final convolution to match output channels + # Only last latent context grid is used, but multi-resolution is available + out = self.final_conv(latent_context_grid[-1]) # (B, C, *spatial) decoded_output.append(out) - decoded_output = torch.stack(decoded_output, dim=2) + # Stack outputs along time dimension + decoded_output = torch.stack(decoded_output, dim=2) # (B, C, T, *spatial) return decoded_output diff --git a/physicsnemo/nn/__init__.py b/physicsnemo/nn/__init__.py index 0960d03ec1..10224d508a 100644 --- a/physicsnemo/nn/__init__.py +++ b/physicsnemo/nn/__init__.py @@ -32,7 +32,15 @@ UNetAttention, ) from .ball_query import BQWarp -from .conv_layers import Conv2d, ConvBlock, CubeEmbedding +from .conv_layers import ( + Conv2d, + ConvBlock, + ConvGRULayer, + ConvLayer, + ConvResidualBlock, + CubeEmbedding, + TransposeConvLayer, +) from .dgm_layers import DGMLayer from .embedding_layers import FourierEmbedding, PositionalEmbedding from .fourier_layers import ( diff --git a/physicsnemo/nn/conv_layers.py b/physicsnemo/nn/conv_layers.py index d22f675f2d..700e506c1c 100644 --- a/physicsnemo/nn/conv_layers.py +++ b/physicsnemo/nn/conv_layers.py @@ -14,11 +14,16 @@ # See the License for the specific language governing permissions and # limitations under the License. +import math from typing import List import torch -from torch import nn +import torch.nn as nn +import torch.nn.functional as F +from jaxtyping import Float +from torch import Tensor +from physicsnemo.core.module import Module from physicsnemo.nn.utils.utils import _validate_amp from physicsnemo.nn.utils.weight_init import _weight_init @@ -105,6 +110,29 @@ def forward(self, x): return x + x_skip +def _get_same_padding(x: int, k: int, s: int) -> int: + r""" + Function to compute "same" padding. 
+ + Inspired from: `timm padding <https://github.com/huggingface/pytorch-image-models/blob/0.5.x/timm/models/layers/padding.py>`_ + + Parameters + ---------- + x : int + Input dimension size. + k : int + Kernel size. + s : int + Stride. + + Returns + ------- + int + Padding value to achieve "same" padding. + """ + return max(s * math.ceil(x / s) - s - x + k, 0) + + + class Conv2d(torch.nn.Module): """ A custom 2D convolutional layer implementation with support for up-sampling, @@ -282,3 +310,671 @@ def forward(self, x): if b is not None and not self.fused_conv_bias: x = x.add_(b.reshape(1, -1, 1, 1)) return x + + +class ConvLayer(Module): + r""" + Generalized Convolution Block. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + dimension : int + Dimensionality of the input (1, 2, or 3). + kernel_size : int + Kernel size for the convolution. + stride : int, optional, default=1 + Stride for the convolution. + activation_fn : nn.Module, optional, default=nn.Identity() + Activation function to use. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, C_{in}, *)` where :math:`*` represents + spatial dimensions matching ``dimension``. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, C_{out}, *)` where spatial dimensions + depend on stride and padding. + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + dimension: int, + kernel_size: int, + stride: int = 1, + activation_fn: nn.Module = nn.Identity(), + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.dimension = dimension + self.activation_fn = activation_fn + + if self.dimension == 1: + self.conv = nn.Conv1d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + bias=True, + ) + elif self.dimension == 2: + self.conv = nn.Conv2d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + bias=True, + ) + elif self.dimension == 3: + self.conv = nn.Conv3d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + bias=True, + ) + else: + raise ValueError("Only 1D, 2D and 3D dimensions are supported") + + self._reset_parameters() + + def _exec_activation_fn( + self, + x: Float[Tensor, "batch channels ..."], # noqa: F722 + ) -> Float[Tensor, "batch channels ..."]: # noqa: F722 + r""" + Executes activation function on the input. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, C, *)`. + + Returns + ------- + torch.Tensor + Output tensor of shape :math:`(B, C, *)`. + """ + return self.activation_fn(x) + + def _reset_parameters(self) -> None: + r""" + Initialization for network parameters.
+ """ + nn.init.constant_(self.conv.bias, 0) + nn.init.xavier_uniform_(self.conv.weight) + + def forward( + self, + x: Float[Tensor, "batch in_channels ..."], # noqa: F722 + ) -> Float[Tensor, "batch out_channels ..."]: # noqa: F722 + r"""Forward pass with same padding.""" + ### Input validation + if not torch.compiler.is_compiling(): + input_length = len(x.size()) - 2 # exclude channel and batch dims + if input_length != self.dimension: + raise ValueError( + f"Expected {self.dimension}D input tensor (excluding batch and channel dims), " + f"got {input_length}D tensor with shape {tuple(x.shape)}" + ) + + input_length = len(x.size()) - 2 # exclude channel and batch dims + + # Apply same padding based on dimensionality + if input_length == 1: + iw = x.size()[-1:][0] + pad_w = _get_same_padding(iw, self.kernel_size, self.stride) + x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2], mode="constant", value=0.0) + elif input_length == 2: + ih, iw = x.size()[-2:] + pad_h, pad_w = ( + _get_same_padding(ih, self.kernel_size, self.stride), + _get_same_padding(iw, self.kernel_size, self.stride), + ) + # F.pad expects padding in reverse dimension order: [left, right, top, bottom] + x = F.pad( + x, + [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], + mode="constant", + value=0.0, + ) + else: + _id, ih, iw = x.size()[-3:] + pad_d, pad_h, pad_w = ( + _get_same_padding(_id, self.kernel_size, self.stride), + _get_same_padding(ih, self.kernel_size, self.stride), + _get_same_padding(iw, self.kernel_size, self.stride), + ) + # F.pad expects padding in reverse dimension order: [left, right, top, bottom, front, back] + x = F.pad( + x, + [ + pad_w // 2, + pad_w - pad_w // 2, + pad_h // 2, + pad_h - pad_h // 2, + pad_d // 2, + pad_d - pad_d // 2, + ], + mode="constant", + value=0.0, + ) + + # Apply convolution + x = self.conv(x) + + # Apply activation if not identity + if self.activation_fn is not nn.Identity(): + x = self._exec_activation_fn(x) + + return x + + +class TransposeConvLayer(Module): + r""" + Generalized Transposed Convolution Block. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + dimension : int + Dimensionality of the input (1, 2, or 3). + kernel_size : int + Kernel size for the convolution. + stride : int, optional, default=1 + Stride for the convolution. + activation_fn : nn.Module, optional, default=nn.Identity() + Activation function to use. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, C_{in}, *)` where :math:`*` represents + spatial dimensions matching ``dimension``. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, C_{out}, *)` where spatial dimensions + are upsampled based on stride. 
+ """ + + def __init__( + self, + in_channels: int, + out_channels: int, + dimension: int, + kernel_size: int, + stride: int = 1, + activation_fn=nn.Identity(), + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.kernel_size = kernel_size + self.stride = stride + self.dimension = dimension + self.activation_fn = activation_fn + + if dimension == 1: + self.trans_conv = nn.ConvTranspose1d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + bias=True, + ) + elif dimension == 2: + self.trans_conv = nn.ConvTranspose2d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + bias=True, + ) + elif dimension == 3: + self.trans_conv = nn.ConvTranspose3d( + self.in_channels, + self.out_channels, + self.kernel_size, + self.stride, + bias=True, + ) + else: + raise ValueError("Only 1D, 2D and 3D dimensions are supported") + + self._reset_parameters() + + def _exec_activation_fn( + self, + x: Float[Tensor, "batch channels ..."], # noqa: F722 + ) -> Float[Tensor, "batch channels ..."]: # noqa: F722 + r""" + Executes activation function on the input. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, C, *)`. + + Returns + ------- + torch.Tensor + Output tensor of shape :math:`(B, C, *)`. + """ + return self.activation_fn(x) + + def _reset_parameters(self) -> None: + r""" + Initialization for network parameters. + """ + nn.init.constant_(self.trans_conv.bias, 0) + nn.init.xavier_uniform_(self.trans_conv.weight) + + def forward( + self, + x: Float[Tensor, "batch in_channels ..."], # noqa: F722 + ) -> Float[Tensor, "batch out_channels ..."]: # noqa: F722 + r"""Forward pass with transposed convolution and cropping.""" + ### Input validation + if not torch.compiler.is_compiling(): + input_length = len(x.size()) - 2 # exclude channel and batch dims + if input_length != self.dimension: + raise ValueError( + f"Expected {self.dimension}D input tensor (excluding batch and channel dims), " + f"got {input_length}D tensor with shape {tuple(x.shape)}" + ) + + orig_x = x + input_length = len(orig_x.size()) - 2 # exclude channel and batch dims + + # Apply transposed convolution + x = self.trans_conv(x) + + # Crop output to match expected output size (same padding logic) + if input_length == 1: + iw = orig_x.size()[-1:][0] + pad_w = _get_same_padding(iw, self.kernel_size, self.stride) + x = x[ + :, + :, + pad_w // 2 : x.size(-1) - (pad_w - pad_w // 2), + ] + elif input_length == 2: + ih, iw = orig_x.size()[-2:] + pad_h, pad_w = ( + _get_same_padding( + ih, + self.kernel_size, + self.stride, + ), + _get_same_padding(iw, self.kernel_size, self.stride), + ) + x = x[ + :, + :, + pad_h // 2 : x.size(-2) - (pad_h - pad_h // 2), + pad_w // 2 : x.size(-1) - (pad_w - pad_w // 2), + ] + else: + _id, ih, iw = orig_x.size()[-3:] + pad_d, pad_h, pad_w = ( + _get_same_padding(_id, self.kernel_size, self.stride), + _get_same_padding(ih, self.kernel_size, self.stride), + _get_same_padding(iw, self.kernel_size, self.stride), + ) + x = x[ + :, + :, + pad_d // 2 : x.size(-3) - (pad_d - pad_d // 2), + pad_h // 2 : x.size(-2) - (pad_h - pad_h // 2), + pad_w // 2 : x.size(-1) - (pad_w - pad_w // 2), + ] + + # Apply activation if not identity + if self.activation_fn is not nn.Identity(): + x = self._exec_activation_fn(x) + + return x + + +class ConvGRULayer(Module): + r""" + Convolutional GRU layer. + + Parameters + ---------- + in_features : int + Input features/channels. 
+ hidden_size : int + Hidden layer features/channels. + dimension : int + Spatial dimension of the input. + activation_fn : nn.Module, optional, default=nn.ReLU() + Activation Function to use. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, C_{in}, *)` where :math:`*` represents + spatial dimensions. + hidden : torch.Tensor + Hidden state tensor of shape :math:`(B, H, *)` where :math:`H` is + ``hidden_size``. + + Outputs + ------- + torch.Tensor + Next hidden state of shape :math:`(B, H, *)`. + """ + + def __init__( + self, + in_features: int, + hidden_size: int, + dimension: int, + activation_fn: nn.Module = nn.ReLU(), + ) -> None: + super().__init__() + self.in_features = in_features + self.hidden_size = hidden_size + self.activation_fn = activation_fn + self.conv_1 = ConvLayer( + in_channels=in_features + hidden_size, + out_channels=2 * hidden_size, + kernel_size=3, + stride=1, + dimension=dimension, + ) + self.conv_2 = ConvLayer( + in_channels=in_features + hidden_size, + out_channels=hidden_size, + kernel_size=3, + stride=1, + dimension=dimension, + ) + + def _exec_activation_fn( + self, + x: Float[Tensor, "batch channels ..."], # noqa: F722 + ) -> Float[Tensor, "batch channels ..."]: # noqa: F722 + r""" + Executes activation function on the input. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, C, *)`. + + Returns + ------- + torch.Tensor + Output tensor of shape :math:`(B, C, *)`. + """ + return self.activation_fn(x) + + def forward( + self, + x: Float[Tensor, "batch in_features ..."], # noqa: F722 + hidden: Float[Tensor, "batch hidden_size ..."], # noqa: F722 + ) -> Float[Tensor, "batch hidden_size ..."]: # noqa: F722 + r"""Forward pass implementing GRU update.""" + ### Input validation + if not torch.compiler.is_compiling(): + if x.shape[1] != self.in_features: + raise ValueError( + f"Expected input with {self.in_features} features, " + f"got {x.shape[1]} features in tensor with shape {tuple(x.shape)}" + ) + if hidden.shape[1] != self.hidden_size: + raise ValueError( + f"Expected hidden state with {self.hidden_size} features, " + f"got {hidden.shape[1]} features in tensor with shape {tuple(hidden.shape)}" + ) + if x.shape[0] != hidden.shape[0] or x.shape[2:] != hidden.shape[2:]: + raise ValueError( + f"Input and hidden state must have matching batch size and spatial dims. " + f"Got input shape {tuple(x.shape)} and hidden shape {tuple(hidden.shape)}" + ) + + # Concatenate input and hidden state + concat = torch.cat((x, hidden), dim=1) # (B, in_features + hidden_size, *) + + # Compute reset and update gates + conv_concat = self.conv_1(concat) # (B, 2 * hidden_size, *) + conv_r, conv_z = torch.split(conv_concat, self.hidden_size, 1) + + reset_gate = torch.special.expit(conv_r) # (B, hidden_size, *) + update_gate = torch.special.expit(conv_z) # (B, hidden_size, *) + + # Compute candidate hidden state + concat = torch.cat((x, torch.mul(hidden, reset_gate)), dim=1) + n = self._exec_activation_fn(self.conv_2(concat)) # (B, hidden_size, *) + + # Compute next hidden state + h_next = torch.mul((1 - update_gate), n) + torch.mul(update_gate, hidden) + + return h_next + + +class ConvResidualBlock(Module): + r""" + Convolutional ResNet Block. + + Parameters + ---------- + in_channels : int + Number of input channels. + out_channels : int + Number of output channels. + dimension : int + Dimensionality of the input. + stride : int, optional, default=1 + Stride of the convolutions. 
+ gated : bool, optional, default=False + Residual Gate activation. + layer_normalization : bool, optional, default=False + Whether to apply layer normalization. + begin_activation_fn : bool, optional, default=True + Whether to use activation function in the beginning. + activation_fn : nn.Module, optional, default=nn.ReLU() + Activation function to use. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, C_{in}, *)` where :math:`*` represents + spatial dimensions matching ``dimension``. + + Outputs + ------- + torch.Tensor + Output tensor of shape :math:`(B, C_{out}, *)` with residual connection. + + Raises + ------ + ValueError + If stride > 2 (not supported). + """ + + def __init__( + self, + in_channels: int, + out_channels: int, + dimension: int, + stride: int = 1, + gated: bool = False, + layer_normalization: bool = False, + begin_activation_fn: bool = True, + activation_fn: nn.Module = nn.ReLU(), + ) -> None: + super().__init__() + self.in_channels = in_channels + self.out_channels = out_channels + self.stride = stride + self.dimension = dimension + self.gated = gated + self.layer_normalization = layer_normalization + self.begin_activation_fn = begin_activation_fn + self.activation_fn = activation_fn + + if self.stride == 1: + self.conv_1 = ConvLayer( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=self.stride, + dimension=self.dimension, + ) + elif self.stride == 2: + self.conv_1 = ConvLayer( + in_channels=self.in_channels, + out_channels=self.out_channels, + kernel_size=4, + stride=self.stride, + dimension=self.dimension, + ) + else: + raise ValueError("stride > 2 is not supported") + + if not self.gated: + self.conv_2 = ConvLayer( + in_channels=self.out_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + dimension=self.dimension, + ) + else: + self.conv_2 = ConvLayer( + in_channels=self.out_channels, + out_channels=2 * self.out_channels, + kernel_size=3, + stride=1, + dimension=self.dimension, + ) + + def _exec_activation_fn( + self, + x: Float[Tensor, "batch channels ..."], # noqa: F722 + ) -> Float[Tensor, "batch channels ..."]: # noqa: F722 + r""" + Executes activation function on the input. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, C, *)`. + + Returns + ------- + torch.Tensor + Output tensor of shape :math:`(B, C, *)`. 
+ """ + return self.activation_fn(x) + + def forward( + self, + x: Float[Tensor, "batch in_channels ..."], # noqa: F722 + ) -> Float[Tensor, "batch out_channels ..."]: # noqa: F722 + r"""Forward pass with residual connection.""" + ### Input validation + if not torch.compiler.is_compiling(): + input_length = len(x.size()) - 2 # exclude channel and batch dims + if input_length != self.dimension: + raise ValueError( + f"Expected {self.dimension}D input tensor (excluding batch and channel dims), " + f"got {input_length}D tensor with shape {tuple(x.shape)}" + ) + + orig_x = x + + # Apply layer normalization and activation at the beginning if specified + if self.begin_activation_fn: + if self.layer_normalization: + layer_norm = nn.LayerNorm(x.size()[1:], elementwise_affine=False) + x = layer_norm(x) + x = self._exec_activation_fn(x) + + # First convolutional layer + x = self.conv_1(x) + + # Apply layer normalization after first convolution + if self.layer_normalization: + layer_norm = nn.LayerNorm(x.size()[1:], elementwise_affine=False) + x = layer_norm(x) + + # Second activation and convolution + x = self._exec_activation_fn(x) + x = self.conv_2(x) + + # Apply gating if specified + if self.gated: + x_1, x_2 = torch.split(x, x.size(1) // 2, 1) + x = x_1 * torch.special.expit(x_2) + + # Adjust skip connection if spatial dimensions differ (due to stride) + if orig_x.size(-1) > x.size(-1): # Check if widths are different + if len(orig_x.size()) - 2 == 1: + iw = orig_x.size()[-1:][0] + pad_w = _get_same_padding(iw, 2, 2) + pool = torch.nn.AvgPool1d( + 2, 2, padding=pad_w // 2, count_include_pad=False + ) + elif len(orig_x.size()) - 2 == 2: + ih, iw = orig_x.size()[-2:] + pad_h, pad_w = ( + _get_same_padding( + ih, + 2, + 2, + ), + _get_same_padding(iw, 2, 2), + ) + pool = torch.nn.AvgPool2d( + 2, 2, padding=(pad_h // 2, pad_w // 2), count_include_pad=False + ) + elif len(orig_x.size()) - 2 == 3: + _id, ih, iw = orig_x.size()[-3:] + pad_d, pad_h, pad_w = ( + _get_same_padding(_id, 2, 2), + _get_same_padding(ih, 2, 2), + _get_same_padding(iw, 2, 2), + ) + pool = torch.nn.AvgPool3d( + 2, + 2, + padding=(pad_d // 2, pad_h // 2, pad_w // 2), + count_include_pad=False, + ) + else: + raise ValueError("Only 1D, 2D and 3D dimensions are supported") + orig_x = pool(orig_x) + + # Adjust skip connection channels if needed + in_channels = int(orig_x.size(1)) + if self.out_channels > in_channels: + orig_x = F.pad( + orig_x, + (len(orig_x.size()) - 2) * (0, 0) + + (self.out_channels - self.in_channels, 0), + ) + elif self.out_channels < in_channels: + pass + + return orig_x + x diff --git a/test/models/rnn/data/conv_layer_output.pth b/test/models/rnn/data/conv_layer_output.pth new file mode 100644 index 0000000000..af08e17d54 Binary files /dev/null and b/test/models/rnn/data/conv_layer_output.pth differ diff --git a/test/models/rnn/data/conv_rnn_one2many_2d_checkpoint.mdlus b/test/models/rnn/data/conv_rnn_one2many_2d_checkpoint.mdlus new file mode 100644 index 0000000000..70b0bce79a Binary files /dev/null and b/test/models/rnn/data/conv_rnn_one2many_2d_checkpoint.mdlus differ diff --git a/test/models/rnn/data/conv_rnn_one2many_2d_output.pth b/test/models/rnn/data/conv_rnn_one2many_2d_output.pth index 263eba2605..842b51800e 100644 Binary files a/test/models/rnn/data/conv_rnn_one2many_2d_output.pth and b/test/models/rnn/data/conv_rnn_one2many_2d_output.pth differ diff --git a/test/models/rnn/data/conv_rnn_one2many_3d_checkpoint.mdlus b/test/models/rnn/data/conv_rnn_one2many_3d_checkpoint.mdlus new file mode 100644 
index 0000000000..b01f526aef Binary files /dev/null and b/test/models/rnn/data/conv_rnn_one2many_3d_checkpoint.mdlus differ diff --git a/test/models/rnn/data/conv_rnn_one2many_3d_output.pth b/test/models/rnn/data/conv_rnn_one2many_3d_output.pth index 3c58efcf85..228c81889d 100644 Binary files a/test/models/rnn/data/conv_rnn_one2many_3d_output.pth and b/test/models/rnn/data/conv_rnn_one2many_3d_output.pth differ diff --git a/test/models/rnn/data/conv_rnn_seq2seq_2d_checkpoint.mdlus b/test/models/rnn/data/conv_rnn_seq2seq_2d_checkpoint.mdlus new file mode 100644 index 0000000000..a8836622fc Binary files /dev/null and b/test/models/rnn/data/conv_rnn_seq2seq_2d_checkpoint.mdlus differ diff --git a/test/models/rnn/data/conv_rnn_seq2seq_2d_output.pth b/test/models/rnn/data/conv_rnn_seq2seq_2d_output.pth index 58ac0ef215..163e7ebf0e 100644 Binary files a/test/models/rnn/data/conv_rnn_seq2seq_2d_output.pth and b/test/models/rnn/data/conv_rnn_seq2seq_2d_output.pth differ diff --git a/test/models/rnn/data/conv_rnn_seq2seq_3d_checkpoint.mdlus b/test/models/rnn/data/conv_rnn_seq2seq_3d_checkpoint.mdlus new file mode 100644 index 0000000000..a645dabe95 Binary files /dev/null and b/test/models/rnn/data/conv_rnn_seq2seq_3d_checkpoint.mdlus differ diff --git a/test/models/rnn/data/conv_rnn_seq2seq_3d_output.pth b/test/models/rnn/data/conv_rnn_seq2seq_3d_output.pth index b8d64e3912..0e33cc2948 100644 Binary files a/test/models/rnn/data/conv_rnn_seq2seq_3d_output.pth and b/test/models/rnn/data/conv_rnn_seq2seq_3d_output.pth differ diff --git a/test/models/rnn/data/residual_block_output.pth b/test/models/rnn/data/residual_block_output.pth new file mode 100644 index 0000000000..70b3142bb4 Binary files /dev/null and b/test/models/rnn/data/residual_block_output.pth differ diff --git a/test/models/rnn/data/transconv_layer_output.pth b/test/models/rnn/data/transconv_layer_output.pth new file mode 100644 index 0000000000..3d92d2617d Binary files /dev/null and b/test/models/rnn/data/transconv_layer_output.pth differ diff --git a/test/models/rnn/test_rnn.py b/test/models/rnn/test_rnn.py index aea215f45b..91bdc7085b 100644 --- a/test/models/rnn/test_rnn.py +++ b/test/models/rnn/test_rnn.py @@ -19,6 +19,7 @@ import pytest import torch +import physicsnemo from physicsnemo.models.rnn.rnn_one2many import One2ManyRNN from physicsnemo.models.rnn.rnn_seq2seq import Seq2SeqRNN from test import common @@ -87,6 +88,23 @@ def test_conv_rnn_one2many_checkpoint(device, dimension): assert common.validate_checkpoint(model_1, model_2, (invar,)) +@pytest.mark.parametrize("dimension", [2, 3]) +def test_conv_rnn_one2many_load_checkpoint(device, dimension): + """Test loading model from pre-saved checkpoint file""" + from pathlib import Path + + test_dir = Path(__file__).parent.resolve() + checkpoint_path = test_dir / f"data/conv_rnn_one2many_{dimension}d_checkpoint.mdlus" + + # Load model from checkpoint file + model = physicsnemo.Module.from_checkpoint(str(checkpoint_path)).to(device) + + # Verify model attributes match expected values + assert model.nr_tsteps == 4 + assert model.nr_residual_blocks == 2 + assert model.nr_downsamples == 2 + + @pytest.mark.parametrize("dimension", [2, 3]) def test_conv_rnn_one2many_optimizations(device, dimension): """Test model optimizations""" @@ -118,9 +136,26 @@ def setup_model(): def test_conv_rnn_one2many_constructor(device): - """Test model constructor""" + """Test model constructor options""" - # Define dictionary of constructor args + # Test with default parameters + model = One2ManyRNN( + 
input_channels=1, + dimension=2, + ).to(device) + + # Check default attribute values + assert model.nr_tsteps == 32 + assert model.nr_residual_blocks == 2 + assert model.nr_downsamples == 2 + + # Test forward pass with defaults + bsize = 2 + invar = torch.randn(bsize, 1, 1, 16, 16).to(device) + outvar = model(invar) + assert outvar.shape == (bsize, 1, 32, 16, 16) + + # Define dictionary of constructor args with custom parameters arg_list = [ { "input_channels": 1, @@ -129,6 +164,7 @@ def test_conv_rnn_one2many_constructor(device): "activation_fn": "relu", "nr_downsamples": random.randint(2, 3), "nr_tsteps": random.randint(8, 16), + "nr_residual_blocks": random.randint(1, 3), } for dimension in [2, 3] ] @@ -137,6 +173,11 @@ def test_conv_rnn_one2many_constructor(device): # Construct model model = One2ManyRNN(**kw_args).to(device) + # Check that public attributes match constructor arguments + assert model.nr_tsteps == kw_args["nr_tsteps"] + assert model.nr_residual_blocks == kw_args["nr_residual_blocks"] + assert model.nr_downsamples == kw_args["nr_downsamples"] + bsize = random.randint(1, 4) if kw_args["dimension"] == 2: invar = torch.randn(bsize, kw_args["input_channels"], 1, 8, 8).to(device) @@ -229,6 +270,23 @@ def test_conv_rnn_seq2seq_checkpoint(device, dimension): assert common.validate_checkpoint(model_1, model_2, (invar,)) +@pytest.mark.parametrize("dimension", [2, 3]) +def test_conv_rnn_seq2seq_load_checkpoint(device, dimension): + """Test loading model from pre-saved checkpoint file""" + from pathlib import Path + + test_dir = Path(__file__).parent.resolve() + checkpoint_path = test_dir / f"data/conv_rnn_seq2seq_{dimension}d_checkpoint.mdlus" + + # Load model from checkpoint file + model = physicsnemo.Module.from_checkpoint(str(checkpoint_path)).to(device) + + # Verify model attributes match expected values + assert model.nr_tsteps == 4 + assert model.nr_residual_blocks == 2 + assert model.nr_downsamples == 2 + + @pytest.mark.parametrize("dimension", [2, 3]) def test_conv_rnn_seq2seq_optimizations(device, dimension): """Test model optimizations""" @@ -260,9 +318,26 @@ def setup_model(): def test_conv_rnn_seq2seq_constructor(device): - """Test model constructor""" + """Test model constructor options""" + + # Test with default parameters + model = Seq2SeqRNN( + input_channels=1, + dimension=2, + ).to(device) - # Define dictionary of constructor args + # Check default attribute values + assert model.nr_tsteps == 32 + assert model.nr_residual_blocks == 2 + assert model.nr_downsamples == 2 + + # Test forward pass with defaults + bsize = 2 + invar = torch.randn(bsize, 1, 32, 16, 16).to(device) + outvar = model(invar) + assert outvar.shape == (bsize, 1, 32, 16, 16) + + # Define dictionary of constructor args with custom parameters arg_list = [ { "input_channels": 1, @@ -271,13 +346,19 @@ def test_conv_rnn_seq2seq_constructor(device): "activation_fn": "relu", "nr_downsamples": random.randint(2, 3), "nr_tsteps": random.randint(2, 4), + "nr_residual_blocks": random.randint(1, 3), } for dimension in [2, 3] ] for kw_args in arg_list: # Construct model - model = One2ManyRNN(**kw_args).to(device) + model = Seq2SeqRNN(**kw_args).to(device) + + # Check that public attributes match constructor arguments + assert model.nr_tsteps == kw_args["nr_tsteps"] + assert model.nr_residual_blocks == kw_args["nr_residual_blocks"] + assert model.nr_downsamples == kw_args["nr_downsamples"] bsize = random.randint(1, 4) if kw_args["dimension"] == 2: diff --git a/test/models/rnn/test_rnn_layers.py 
b/test/models/rnn/test_rnn_layers.py index 15d300f934..30be83f190 100644 --- a/test/models/rnn/test_rnn_layers.py +++ b/test/models/rnn/test_rnn_layers.py @@ -17,11 +17,12 @@ import pytest import torch -from physicsnemo.models.rnn.layers import ( - _ConvLayer, - _ConvResidualBlock, - _TransposeConvLayer, +from physicsnemo.nn.conv_layers import ( + ConvLayer, + ConvResidualBlock, + TransposeConvLayer, ) +from test import common @pytest.mark.parametrize("activation_fn", [torch.nn.ReLU(), torch.nn.Identity()]) @@ -34,7 +35,7 @@ def test_conv_layer(activation_fn, stride, dimension): out_channels = 16 kernel_size = 3 - layer = _ConvLayer( + layer = ConvLayer( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -67,7 +68,7 @@ def test_transconv_layer(activation_fn, stride, dimension): out_channels = 16 kernel_size = 3 - layer = _TransposeConvLayer( + layer = TransposeConvLayer( in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size, @@ -109,7 +110,7 @@ def test_residual_block_layer( out_channels = 16 # Just test constructor - layer = _ConvResidualBlock( + layer = ConvResidualBlock( in_channels=in_channels, out_channels=out_channels, dimension=dimension, @@ -129,3 +130,74 @@ def test_residual_block_layer( size_out = fig_size assert outvar.shape == (bsize, out_channels) + dimension * (size_out,) + + +def test_conv_layer_forward_accuracy(device): + """Test ConvLayer forward pass accuracy""" + torch.manual_seed(0) + + layer = ConvLayer( + in_channels=4, + out_channels=8, + kernel_size=3, + stride=1, + dimension=2, + activation_fn=torch.nn.ReLU(), + ).to(device) + + invar = torch.randn(2, 4, 16, 16).to(device) + + assert common.validate_forward_accuracy( + layer, + (invar,), + file_name="models/rnn/data/conv_layer_output.pth", + atol=1e-4, + ) + + +def test_transconv_layer_forward_accuracy(device): + """Test TransposeConvLayer forward pass accuracy""" + torch.manual_seed(0) + + layer = TransposeConvLayer( + in_channels=8, + out_channels=4, + kernel_size=3, + stride=2, + dimension=2, + activation_fn=torch.nn.ReLU(), + ).to(device) + + invar = torch.randn(2, 8, 8, 8).to(device) + + assert common.validate_forward_accuracy( + layer, + (invar,), + file_name="models/rnn/data/transconv_layer_output.pth", + atol=1e-4, + ) + + +def test_residual_block_forward_accuracy(device): + """Test ConvResidualBlock forward pass accuracy""" + torch.manual_seed(0) + + layer = ConvResidualBlock( + in_channels=8, + out_channels=8, + dimension=2, + stride=1, + gated=True, + layer_normalization=False, + activation_fn=torch.nn.ReLU(), + begin_activation_fn=True, + ).to(device) + + invar = torch.randn(2, 8, 16, 16).to(device) + + assert common.validate_forward_accuracy( + layer, + (invar,), + file_name="models/rnn/data/residual_block_output.pth", + atol=1e-3, + )
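Usage sketch (not part of the patch): the change above promotes the RNN building blocks to the public physicsnemo.nn namespace (ConvLayer, TransposeConvLayer, ConvGRULayer, ConvResidualBlock, per the physicsnemo/nn/__init__.py hunk). A minimal sketch of that new surface, wired the way the One2ManyRNN encoder/RNN/decoder uses it; the commented shapes follow from the "same"-padding arithmetic in conv_layers.py and are expectations, not captured output.

import torch
import torch.nn as nn

from physicsnemo.nn import ConvGRULayer, ConvLayer, TransposeConvLayer

x = torch.randn(2, 4, 16, 16)  # (B, C_in, H, W), dimension=2

# Strided ConvLayer applies "same" padding internally, so stride=2 halves H and W exactly.
down = ConvLayer(
    in_channels=4,
    out_channels=8,
    dimension=2,
    kernel_size=3,
    stride=2,
    activation_fn=nn.ReLU(),
)
y = down(x)  # expected shape (2, 8, 8, 8)

# One ConvGRU update step; the hidden state starts at zeros, as in One2ManyRNN.forward.
gru = ConvGRULayer(in_features=8, hidden_size=8, dimension=2)
h = torch.zeros_like(y)
h = gru(y, h)  # expected shape (2, 8, 8, 8)

# TransposeConvLayer crops its output back to the "same"-padded size, so the decoder's
# kernel_size=4, stride=2 setting doubles H and W.
up = TransposeConvLayer(in_channels=8, out_channels=4, dimension=2, kernel_size=4, stride=2)
z = up(h)  # expected shape (2, 4, 16, 16)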