505 changes: 0 additions & 505 deletions physicsnemo/models/rnn/layers.py

This file was deleted.

161 changes: 106 additions & 55 deletions physicsnemo/models/rnn/rnn_one2many.py
@@ -18,18 +18,19 @@

import torch
import torch.nn as nn
from jaxtyping import Float
from torch import Tensor

import physicsnemo # noqa: F401 for docs
from physicsnemo.core.meta import ModelMetaData
from physicsnemo.core.module import Module
from physicsnemo.models.rnn.layers import (
_ConvGRULayer,
_ConvLayer,
_ConvResidualBlock,
_TransposeConvLayer,
)
from physicsnemo.nn import get_activation
from physicsnemo.nn.conv_layers import (
ConvGRULayer,
ConvLayer,
ConvResidualBlock,
TransposeConvLayer,
)


@dataclass
@@ -48,38 +49,56 @@ class MetaData(ModelMetaData):


class One2ManyRNN(Module):
"""A RNN model with encoder/decoder for 2d/3d problems that provides predictions
r"""
An RNN model with an encoder/decoder for 2D/3D problems that provides predictions
based on a single initial condition.

Parameters
----------
input_channels : int
Number of channels in the input
dimension : int, optional
Spatial dimension of the input. Only 2d and 3d are supported, by default 2
nr_latent_channels : int, optional
Channels for encoding/decoding, by default 512
nr_residual_blocks : int, optional
Number of residual blocks, by default 2
activation_fn : str, optional
Activation function to use, by default "relu"
nr_downsamples : int, optional
Number of downsamples, by default 2
nr_tsteps : int, optional
Time steps to predict, by default 32

Example
Number of channels in the input.
dimension : int, optional, default=2
Spatial dimension of the input. Only 2D and 3D are supported.
nr_latent_channels : int, optional, default=512
Channels for encoding/decoding.
nr_residual_blocks : int, optional, default=2
Number of residual blocks.
activation_fn : str, optional, default="relu"
Activation function to use.
nr_downsamples : int, optional, default=2
Number of downsamples.
nr_tsteps : int, optional, default=32
Time steps to predict.

Forward
-------
x : torch.Tensor
Input tensor of shape :math:`(N, C, 1, H, W)` for 2D or
:math:`(N, C, 1, D, H, W)` for 3D, where :math:`N` is the batch size,
:math:`C` is the number of channels, ``1`` is the number of input
timesteps, and :math:`D, H, W` are spatial dimensions.

Outputs
-------
torch.Tensor
Output tensor of shape :math:`(N, C, T, H, W)` for 2D or
:math:`(N, C, T, D, H, W)` for 3D, where :math:`T` is the number of
timesteps being predicted.

Examples
--------
>>> import torch
>>> import physicsnemo
>>> model = physicsnemo.models.rnn.One2ManyRNN(
... input_channels=6,
... dimension=2,
... nr_latent_channels=32,
... activation_fn="relu",
... nr_downsamples=2,
... nr_tsteps=16,
... input_channels=6,
... dimension=2,
... nr_latent_channels=32,
... activation_fn="relu",
... nr_downsamples=2,
... nr_tsteps=16,
... )
>>> input = invar = torch.randn(4, 6, 1, 16, 16) # [N, C, T, H, W]
>>> output = model(input)
>>> input_tensor = torch.randn(4, 6, 1, 16, 16) # [N, C, T, H, W]
>>> output = model(input_tensor)
>>> output.size()
torch.Size([4, 6, 16, 16, 16])
"""
@@ -118,7 +137,7 @@ def __init__(
channels_out = channels_out * 2
stride = 2
self.encoder_layers.append(
_ConvResidualBlock(
ConvResidualBlock(
in_channels=channels_in,
out_channels=channels_out,
stride=stride,
@@ -130,7 +149,7 @@ def __init__(
)
)

self.rnn_layer = _ConvGRULayer(
self.rnn_layer = ConvGRULayer(
in_features=channels_out, hidden_size=channels_out, dimension=dimension
)

@@ -141,7 +160,7 @@ def __init__(
channels_in = channels_out
channels_out = channels_out // 2
self.upsampling_layers.append(
_TransposeConvLayer(
TransposeConvLayer(
in_channels=channels_in,
out_channels=channels_out,
kernel_size=4,
@@ -151,7 +170,7 @@ def __init__(
)
for j in range(nr_residual_blocks):
self.upsampling_layers.append(
_ConvResidualBlock(
ConvResidualBlock(
in_channels=channels_out,
out_channels=channels_out,
stride=1,
@@ -163,7 +182,7 @@ def __init__(
)
)
self.conv_layers.append(
_ConvLayer(
ConvLayer(
in_channels=channels_in,
out_channels=nr_latent_channels,
kernel_size=1,
@@ -187,54 +206,86 @@ def __init__(
padding="valid",
)
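
As a quick orientation on the constructor logic above: each downsample doubles the working channel count, and each upsampling stage halves it back. A back-of-envelope sketch (assuming the encoder starts from nr_latent_channels, which matches the visible hunks; not part of the PR):

# Illustrative channel bookkeeping, not part of the PR
nr_latent_channels, nr_downsamples = 32, 2
encoder_widths = [nr_latent_channels * 2**i for i in range(nr_downsamples + 1)]
print(encoder_widths)  # [32, 64, 128]; the decoder halves back down to 32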

def forward(self, x: Tensor) -> Tensor:
"""Forward pass
def forward(
self,
x: Float[Tensor, "batch channels 1 ..."], # noqa: F722
) -> Float[Tensor, "batch channels tsteps ..."]: # noqa: F722
r"""
Forward pass.

Parameters
----------
x : Tensor
Expects a tensor of size [N, C, 1, H, W] for 2D or [N, C, 1, D, H, W] for 3D
Where, N is the batch size, C is the number of channels, 1 is the number of
input timesteps and D, H, W are spatial dimensions.
x : torch.Tensor
Expects a tensor of shape :math:`(N, C, 1, H, W)` for 2D or
:math:`(N, C, 1, D, H, W)` for 3D, where :math:`N` is the batch size,
:math:`C` is the number of channels, ``1`` is the number of input
timesteps, and :math:`D, H, W` are spatial dimensions.

Returns
-------
Tensor
Size [N, C, T, H, W] for 2D or [N, C, T, D, H, W] for 3D.
Where, T is the number of timesteps being predicted.
torch.Tensor
Tensor of shape :math:`(N, C, T, H, W)` for 2D or :math:`(N, C, T, D, H, W)` for 3D,
where :math:`T` is the number of timesteps being predicted.
"""
# Encoding step
### Input validation
if not torch.compiler.is_compiling():
# Check number of dimensions
expected_ndim = 5 if self.encoder_layers[0].dimension == 2 else 6
if x.ndim != expected_ndim:
raise ValueError(
f"Expected {expected_ndim}D input tensor, "
f"got {x.ndim}D tensor with shape {tuple(x.shape)}"
)

# Check time dimension is 1
if x.shape[2] != 1:
raise ValueError(
f"Expected input with 1 timestep (dimension 2), "
f"got {x.shape[2]} timesteps in tensor with shape {tuple(x.shape)}"
)
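# Illustrative only (not part of the PR): with dimension=2, a valid input is
#   x = torch.randn(4, 6, 1, 16, 16)   # ndim == 5 and x.shape[2] == 1
# whereas torch.randn(4, 6, 3, 16, 16) would trip the second check, since
# 3 input timesteps are supplied instead of the required 1.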

# Encoding step - encode the single input timestep
encoded_inputs = []
for t in range(1):
x_in = x[:, :, t, ...]
x_in = x[:, :, t, ...] # (B, C, *spatial)
# Pass through encoder layers
for layer in self.encoder_layers:
x_in = layer(x_in)
encoded_inputs.append(x_in)

# RNN step
# RNN step - autoregressively generate future timesteps
rnn_output = []
for t in range(self.nr_tsteps):
if t == 0:
h = torch.zeros(list(x_in.size())).to(x.device)
# Initialize hidden state to zeros
h = torch.zeros_like(x_in)  # (B, C_latent, *spatial)
x_in_rnn = encoded_inputs[0]
h = self.rnn_layer(x_in_rnn, h)
# Update hidden state
h = self.rnn_layer(x_in_rnn, h) # (B, C_latent, *spatial)
x_in_rnn = h
rnn_output.append(h)
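# For intuition, a conventional ConvGRU cell computes (sketch only; the
# imported ConvGRULayer may differ in its exact formulation):
#   z = sigmoid(Conv_z([x, h]))          # update gate
#   r = sigmoid(Conv_r([x, h]))          # reset gate
#   h_tilde = tanh(Conv_h([x, r * h]))   # candidate state
#   h_new = (1 - z) * h + z * h_tilde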

# Decoding step - decode each hidden state to output
decoded_output = []
for t in range(self.nr_tsteps):
x_out = rnn_output[t]
# Decoding step
x_out = rnn_output[t] # (B, C_latent, *spatial)

# Multi-resolution decoding with skip connections
latent_context_grid = []
for conv_layer, decoder in zip(self.conv_layers, self.decoder_layers):
latent_context_grid.append(conv_layer(x_out))
upsampling_layers = decoder
# Progressively upsample
for upsampling_layer in upsampling_layers:
x_out = upsampling_layer(x_out)

# Add a convolution here to make the channel dimensions same as output
# Only last latent context grid is used, but mult-resolution is available
out = self.final_conv(latent_context_grid[-1])
# Final convolution to match output channels
# Only last latent context grid is used, but multi-resolution is available
out = self.final_conv(latent_context_grid[-1]) # (B, C, *spatial)
decoded_output.append(out)

decoded_output = torch.stack(decoded_output, dim=2)
# Stack outputs along time dimension
decoded_output = torch.stack(decoded_output, dim=2) # (B, C, T, *spatial)
return decoded_output
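
Putting the pieces together, a minimal smoke test of the one-to-many rollout could look like the following (a sketch under assumed defaults; the data and loss are placeholders, not taken from the PR):

import torch
import physicsnemo

model = physicsnemo.models.rnn.One2ManyRNN(input_channels=2, nr_tsteps=4)
x0 = torch.randn(1, 2, 1, 32, 32)      # single initial condition [N, C, 1, H, W]
target = torch.randn(1, 2, 4, 32, 32)  # placeholder targets for 4 timesteps
pred = model(x0)                       # expected shape (1, 2, 4, 32, 32)
loss = torch.nn.functional.mse_loss(pred, target)
loss.backward()                        # gradients flow through the full rollout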