diff --git a/examples/cfd/external_aerodynamics/transolver/README.md b/examples/cfd/external_aerodynamics/transformer_models/README.md similarity index 65% rename from examples/cfd/external_aerodynamics/transolver/README.md rename to examples/cfd/external_aerodynamics/transformer_models/README.md index be8943b89b..463285822e 100644 --- a/examples/cfd/external_aerodynamics/transolver/README.md +++ b/examples/cfd/external_aerodynamics/transformer_models/README.md @@ -1,134 +1,74 @@ -# `Transolver` for External Aerodynamics on Irregular Meshes +# Transformer Models for External Aerodynamics on Irregular Meshes -This example is an end to end training recipe for the `Transolver` model, which can -be run on surface or volume data. +This directory contains training and inference recipes for transformer-based surrogate models for CFD applications. This is a collection of transformer models including `Transolver` and `GeoTransolver`, both of which can be run on surface or volume data. -`Transolver` is a high-performance surrogate model for CFD solvers. The Transolver model -adapts the Attention mechanism, encouraging the learning of meaningful representations. -In each PhysicsAttention layer, input points are projected onto state vectors through -learnable transformations and weights. These transformations are then used to compute -self-attention among all state vectors, and the same weights are reused to project -states back to each input point. +## Models Overview -## External Aerodynamics CFD Example: Overview +### Transolver + +`Transolver` is a high-performance surrogate model for CFD solvers. The Transolver model adapts the Attention mechanism, encouraging the learning of meaningful representations. In each PhysicsAttention layer, input points are projected onto state vectors through learnable transformations and weights. These transformations are then used to compute self-attention among all state vectors, and the same weights are reused to project states back to each input point. + +By stacking multiple PhysicsAttention layers, the `Transolver` model learns to map from the functional input space to the output space with high fidelity. The PhysicsNeMo implementation closely follows the original Transolver architecture ([https://github.com/thuml/Transolver](https://github.com/thuml/Transolver)), but introduces modifications for improved numerical stability and compatibility with NVIDIA TransformerEngine. -This directory contains the essential components for training and evaluating a -model tailored to external aerodynamics CFD problems built on `Transolver`. +### GeoTranSolver -By stacking multiple PhysicsAttention layers, the `Transolver` model learns to map from -the functional input space to the output space with high fidelity. The PhysicsNeMo -implementation closely follows the original Transolver architecture -([https://github.com/thuml/Transolver](https://github.com/thuml/Transolver)), but -introduces modifications for improved numerical stability and compatibility with NVIDIA -TransformerEngine. +GeoTransolver adapts the Transolver backbone by replacing standard attention with GALE (Geometry-Aware Latent Embeddings) attention, which unifies physics-aware self-attention on learned state slices with cross-attention to geometry and global context embeddings. Inspired by Domino's multi-scale ball query formulations, GeoTransolver learns global geometry encodings and local latent encodings that capture neighborhoods at multiple radii, preserving fine-grained near-boundary behavior and far-field interactions. Crucially, geometry and global features are projected into physical state spaces and injected as context in every transformer block, ensuring persistent conditioning and alignment between evolving latent states and the underlying domain. + +GALE directly targets core challenges in AI physics modeling. By structuring self-attention around physics-aware slices, GeoTransolver encourages interactions that reflect operator couplings (e.g., pressure–velocity or field–material). Multi-scale ball queries enforce locality where needed while maintaining access to global signals, balancing efficiency with nonlocal reasoning. Continuous geometry-context projection at depth mitigates representation drift and improves stability, while providing a natural interface for constraint-aware training and regularization. Together, these design choices enhance accuracy, robustness to geometric and regime shifts, and scalability on large, irregular discretizations. + +## External Aerodynamics CFD Example: Overview -The training example for Transolver uses the [DrivaerML dataset](https://caemldatasets.org/drivaerml/). +This directory contains the essential components for training and evaluating models tailored to external aerodynamics CFD problems. The training examples use the [DrivaerML dataset](https://caemldatasets.org/drivaerml/). -As a concrete example, we are training external aerodynamics surrogate models for automobiles. -`Transolver` takes as input a point cloud on the surface or surrounding the surface, -iteratively processing it with PhysicsAttention to produce high-fidelity predictions. +As a concrete example, we are training external aerodynamics surrogate models for automobiles. These models take as input a point cloud on the surface or surrounding the surface, iteratively processing it with transformer-based attention mechanisms to produce high-fidelity predictions. ## Requirements -Transolver can use TransformerEngine from NVIDIA, as well as tensorstore (for IO), -zarr, einops and a few other python packages. Install them with `pip install -r requirements.txt` -as well as physicsnemo 25.11 or higher. +These transformer models can use TransformerEngine from NVIDIA, as well as tensorstore (for IO), zarr, einops and a few other python packages. Install them with `pip install -r requirements.txt` as well as physicsnemo 25.11 or higher. -## Using Transolver for External Aerodynamics +## Using Transformer Models for External Aerodynamics -1. Prepare the Dataset. Transolver uses the same Zarr outputs as other models with DrivaerML. -`PhysicsNeMo` has a related project to help with data processing, called [PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator). -Using `PhysicsNeMo-Curator`, the data needed to train can be setup easily. -Please refer to [these instructions on getting started](https://github.com/NVIDIA/physicsnemo-curator?tab=readme-ov-file#what-is-physicsnemo-curator) -with `PhysicsNeMo-Curator`. For specifics of preparing the dataset for this example, -see the [download](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/README.md#download-drivaerml-dataset) -and [preprocessing](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/README.md) -instructions from `physicsnemo-curator`. Users should apply the -preprocessing steps locally to produce `zarr` output files. +1. Prepare the Dataset. These models use the same Zarr outputs as other models with DrivaerML. `PhysicsNeMo` has a related project to help with data processing, called [PhysicsNeMo-Curator](https://github.com/NVIDIA/physicsnemo-curator). Using `PhysicsNeMo-Curator`, the data needed to train can be setup easily. Please refer to [these instructions on getting started](https://github.com/NVIDIA/physicsnemo-curator?tab=readme-ov-file#what-is-physicsnemo-curator) with `PhysicsNeMo-Curator`. For specifics of preparing the dataset for this example, see the [download](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/README.md#download-drivaerml-dataset) and [preprocessing](https://github.com/NVIDIA/physicsnemo-curator/blob/main/examples/external_aerodynamics/README.md) instructions from `physicsnemo-curator`. Users should apply the preprocessing steps locally to produce `zarr` output files. -2. Train your model. The model and training configuration is configured with -`hydra`, and two configurations are available: `transolver_surface` and `transolver_volume`. -Find configurations in `src/conf`, where you can control both network properties -and training properties. See below for an overview and explanation of key -parameters that may be of special interest. +2. Train your model. The model and training configuration is configured with `hydra`, and configurations are available for both surface and volume modes (e.g., `transolver_surface`, `transolver_volume`, `geotransolver_surface`, `geotransolver_volume`). Find configurations in `src/conf`, where you can control both network properties and training properties. See below for an overview and explanation of key parameters that may be of special interest. -3. Use the trained model to perform inference. This example contains two -inference examples: one for inference on the validation set, already in -Zarr format. The `.vtp` inference pipeline is being updated to accommodate Transolver. +3. Use the trained model to perform inference. This example contains inference examples for the validation set, already in Zarr format. The `.vtp` inference pipeline is being updated to accommodate these models. -The following sections contain further details on the training and inference -recipe. +The following sections contain further details on the training and inference recipe. ## Model Training -To train the model, first we compute normalization factors on the dataset to -make the predictive quantities output in a well-defined range. The included -script, `compute_normalizations.py`, will compute the normalization -factors. Once run, it should save to an output file similar to -"surface_fields_normalization.npz". This will get loaded during training. -The normalization file location can be configured via `data.normalization_dir` -in the training configuration (defaults to current directory). - -> By default, the normalization sets the mean to 0.0 and std to 1.0 of all labels -> in the dataset, computing the mean across the train dataset. You could adapt -> this to a different normalization, however take care to update both the -> preprocessing as well as inference scripts. Min/Max is another popular strategy. - -To configure your training run, use `hydra`. The -config contains sections for the model, data, optimizer, and training settings. -For details on the model parameters, see the API for `physicsnemo.models.transolver`. - -To fit the training into memory, you can apply on-the-fly downsampling to the data -with `data.resolution=N`, where `N` is how many points per GPU to use. This dataloader -will yield the full data examples in shapes of `[1, K, f]` where `K` is the resolution -of the mesh, and `f` is the feature space (3 for points, normals, etc. 4 for surface -fields). Downsampling happens in the preprocessing pipeline. - -During training, the configuration uses a flat learning rate that decays every 100 -epochs, and bfloat16 format by default. The scheduler and learning rate -may be configured. - -The Optimizer for this training is the `Muon` optimizer - available only in -`pytorch>=2.9.0`. While not strictly required, we have found the `muon` optimizer -performs substantially better on these architectures than standard `AdamW` and -a oneCycle schedule. +To train the model, first we compute normalization factors on the dataset to make the predictive quantities output in a well-defined range. The included script, `compute_normalizations.py`, will compute the normalization factors. Once run, it should save to an output file similar to "surface_fields_normalization.npz". This will get loaded during training. The normalization file location can be configured via `data.normalization_dir` in the training configuration (defaults to current directory). + +> By default, the normalization sets the mean to 0.0 and std to 1.0 of all labels in the dataset, computing the mean across the train dataset. You could adapt this to a different normalization, however take care to update both the preprocessing as well as inference scripts. Min/Max is another popular strategy. + +To configure your training run, use `hydra`. The config contains sections for the model, data, optimizer, and training settings. For details on the model parameters, see the API for `physicsnemo.models.transolver` and `physicsnemo.experimental.models.geotransolver`. + +To fit the training into memory, you can apply on-the-fly downsampling to the data with `data.resolution=N`, where `N` is how many points per GPU to use. This dataloader will yield the full data examples in shapes of `[1, K, f]` where `K` is the resolution of the mesh, and `f` is the feature space (3 for points, normals, etc. 4 for surface fields). Downsampling happens in the preprocessing pipeline. + +During training, the configuration uses a flat learning rate that decays every 100 epochs, and bfloat16 format by default. The scheduler and learning rate may be configured. + +The Optimizer for this training is the `Muon` optimizer - available only in `pytorch>=2.9.0`. While not strictly required, we have found the `muon` optimizer performs substantially better on these architectures than standard `AdamW` and a oneCycle schedule. ### Training Precision -Transolver, as a transformer-like architecture, has support for NVIDIA's -[TransformerEngine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html) -built in. You can enable/disable the transformer engine path in the model with -`model.use_te=[True | False]`. Available precisions for training with `transformer_engine` -are `training.precision=["float32" | "float16" | "bfloat16" | "float8" ]`. In `float8` -precision, the TransformerEngine Hybrid recipe is used for casting weights and inputs -in the forward and backwards passes. For more details on `float8` precision, see -the fp8 guide from -[TransformerEngine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html). -When using fp8, the training script will automatically pad and unpad the input and output, -respectively, to use the fp8 hardware correctly. - -> **Float8** precisions are only available on GPUs with fp8 tensorcore support, such -> as Hopper, Blackwell, Ada Lovelace, and others. +These transformer architectures have support for NVIDIA's [TransformerEngine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/index.html) built in. You can enable/disable the transformer engine path in the model with `model.use_te=[True | False]`. Available precisions for training with `transformer_engine` are `training.precision=["float32" | "float16" | "bfloat16" | "float8" ]`. In `float8` precision, the TransformerEngine Hybrid recipe is used for casting weights and inputs in the forward and backwards passes. For more details on `float8` precision, see the fp8 guide from [TransformerEngine](https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/examples/fp8_primer.html). When using fp8, the training script will automatically pad and unpad the input and output, respectively, to use the fp8 hardware correctly. + +> **Float8** precisions are only available on GPUs with fp8 tensorcore support, such as Hopper, Blackwell, Ada Lovelace, and others. ### Other Configuration Settings Several other important configuration settings are available: -- `checkpoint_dir` sets the directory for saving model checkpoints (defaults to `output_dir` -if not specified), allowing separation of checkpoints from other outputs. -- `compile` will use `torch.compile` for optimized performance. It is not -compatible with `transformer_engine` (`model.use_te=True`). If TransformerEngine is -not used, and half precision is, `torch.compile` is recommended for improved performance. +- `checkpoint_dir` sets the directory for saving model checkpoints (defaults to `output_dir` if not specified), allowing separation of checkpoints from other outputs. +- `compile` will use `torch.compile` for optimized performance. It is not compatible with `transformer_engine` (`model.use_te=True`). If TransformerEngine is not used, and half precision is, `torch.compile` is recommended for improved performance. - `training.num_epochs` controls the total number of epochs used during training. -- `training.save_interval` will dictate how often the model weights and training -tools are checkpointed. +- `training.save_interval` will dictate how often the model weights and training tools are checkpointed. -> **Note** Like other parameters of the model, changing the value of `model.use_te` -> will make checkpoints incompatible. +> **Note** Like other parameters of the model, changing the value of `model.use_te` will make checkpoints incompatible. -The training script supports data-parallel training via PyTorch DDP. In a future -update, we may enable domain parallelism via FSDP and ShardTensor. +The training script supports data-parallel training via PyTorch DDP. In a future update, we may enable domain parallelism via FSDP and ShardTensor. The script can be launched on a single GPU with, for example, @@ -185,9 +125,7 @@ Epoch 47 Validation Average Metrics: ## Dataset Inference -The validation dataset in Zarr format can be loaded, processed, and the L2 -metrics summarized in `inference_on_zarr.py`. For surface data, this script will also -compute the drag and lift coefficients and the R^2 correlation of the predictions. +The validation dataset in Zarr format can be loaded, processed, and the L2 metrics summarized in `inference_on_zarr.py`. For surface data, this script will also compute the drag and lift coefficients and the R^2 correlation of the predictions. To run inference on surface data, it's necessary to add a line to your launch command: @@ -196,15 +134,10 @@ python src/inference_on_zarr.py --config-name transolver_surface run_id=/path/to ``` -The `data.return_mesh_features` flag can also be set in the config file. It is -disabled for training but necessary for inference. The model path should be -the folder containing your saved checkpoints. +The `data.return_mesh_features` flag can also be set in the config file. It is disabled for training but necessary for inference. The model path should be the folder containing your saved checkpoints. -To ensure correct calculation of drag and lift, and accurate overall metrics, -the inference script will chunk a full-resolution training example into batches, -and stitch the outputs together at the end. Output will appear as a table -with all metrics for that mode, for example: +To ensure correct calculation of drag and lift, and accurate overall metrics, the inference script will chunk a full-resolution training example into batches, and stitch the outputs together at the end. Output will appear as a table with all metrics for that mode, for example: ``` | Batch | Loss | L2 Pressure | L2 Shear X | L2 Shear Y | L2 Shear Z | Predicted Drag Coefficient | Pred Lift Coefficient | True Drag Coefficient | True Lift Coefficient | Elapsed (s) | @@ -279,6 +212,4 @@ entire mesh. The outputs are then saved to .vtp files for downstream analysis. ## Transolver++ -Transolver++ is supported with the `plus` flag to the model. In -our experiments, we did not see gains, but you are welcome to try it and share -your results with us on GitHub! +Transolver++ is supported with the `plus` flag to the model. In our experiments, we did not see gains, but you are welcome to try it and share your results with us on GitHub! diff --git a/examples/cfd/external_aerodynamics/transolver/conf/train_surface.yaml b/examples/cfd/external_aerodynamics/transformer_models/deprecated/conf/train_surface.yaml similarity index 100% rename from examples/cfd/external_aerodynamics/transolver/conf/train_surface.yaml rename to examples/cfd/external_aerodynamics/transformer_models/deprecated/conf/train_surface.yaml diff --git a/examples/cfd/external_aerodynamics/transolver/datapipe.py b/examples/cfd/external_aerodynamics/transformer_models/deprecated/datapipe.py similarity index 100% rename from examples/cfd/external_aerodynamics/transolver/datapipe.py rename to examples/cfd/external_aerodynamics/transformer_models/deprecated/datapipe.py diff --git a/examples/cfd/external_aerodynamics/transolver/inference_on_vtp.py b/examples/cfd/external_aerodynamics/transformer_models/deprecated/inference_on_vtp.py similarity index 100% rename from examples/cfd/external_aerodynamics/transolver/inference_on_vtp.py rename to examples/cfd/external_aerodynamics/transformer_models/deprecated/inference_on_vtp.py diff --git a/examples/cfd/external_aerodynamics/transolver/requirements.txt b/examples/cfd/external_aerodynamics/transformer_models/requirements.txt similarity index 100% rename from examples/cfd/external_aerodynamics/transolver/requirements.txt rename to examples/cfd/external_aerodynamics/transformer_models/requirements.txt diff --git a/examples/cfd/external_aerodynamics/transolver/src/benchmark_dataloading.py b/examples/cfd/external_aerodynamics/transformer_models/src/benchmark_dataloading.py similarity index 100% rename from examples/cfd/external_aerodynamics/transolver/src/benchmark_dataloading.py rename to examples/cfd/external_aerodynamics/transformer_models/src/benchmark_dataloading.py diff --git a/examples/cfd/external_aerodynamics/transolver/src/compute_normalizations.py b/examples/cfd/external_aerodynamics/transformer_models/src/compute_normalizations.py similarity index 99% rename from examples/cfd/external_aerodynamics/transolver/src/compute_normalizations.py rename to examples/cfd/external_aerodynamics/transformer_models/src/compute_normalizations.py index 749a7ab2f7..13dcff04f1 100644 --- a/examples/cfd/external_aerodynamics/transolver/src/compute_normalizations.py +++ b/examples/cfd/external_aerodynamics/transformer_models/src/compute_normalizations.py @@ -89,7 +89,6 @@ def compute_mean_std_min_max( # Update running mean and M2 (Welford's algorithm) delta = batch_mean - mean - N += batch_n mean = mean + delta * (batch_n / N) M2 = M2 + batch_M2 + delta**2 * (batch_n * N) / N time_end = time.time() diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/core.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/core.yaml new file mode 100644 index 0000000000..56e1f4e601 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/core.yaml @@ -0,0 +1,60 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Paths to your data: +train: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/train/ +val: + data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/val/ + +# You can set a normalization factor directory: +normalization_dir: "src/" + +# How many events in advance should we be preloading? +preload_depth: 1 + +# Pin memory for GPU transfers? +pin_memory: true + +# Sampling resolution of the point clouds: +resolution: 200_000 + +# Surface / Volume / (combined, if supported) +mode: ??? + +# For building embeddings: include normal directions for each point? +include_normals: true +# Include SDF? (It's 0 for surface data...) +include_sdf: true +# Apply translation invariance via center-of-mass subtraction? +translational_invariance: true +# Rescale x/y/z inputs to the model for scale invariance? +scale_invariance: true +reference_scale: [12.0, 4.5, 3.25] + +# Which parts of the data files to read? No need to read everything, all the time. +data_keys: ??? + +# Load and return the STL geometry info in the dataloader? +include_geometry: false + +# Broadcast global features to the same resolution as points? +broadcast_global_features: true + +# Return the mesh areas and normals? You don't usually want this for training. +# We switch it on automatically for inference. +return_mesh_features: false + diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/surface.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/surface.yaml new file mode 100644 index 0000000000..fed0255591 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/surface.yaml @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - core + +# Overrides for surface data: +mode: surface + +# Surface-speficic needs: +data_keys: + - "surface_fields" + - "surface_mesh_centers" + - "surface_normals" + - "surface_areas" + - "stl_faces" + - "stl_centers" + - "stl_coordinates" + - "air_density" + - "stream_velocity" \ No newline at end of file diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/volume.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/volume.yaml new file mode 100644 index 0000000000..9d34cc406a --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/data/volume.yaml @@ -0,0 +1,30 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# You may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - core + +# Overrides for volume data: +mode: volume + +# volume-specific needs: +data_keys: + - "volume_fields" + - "volume_mesh_centers" + - "stl_faces" + - "stl_centers" + - "stl_coordinates" + \ No newline at end of file diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/geotransolver_surface.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/geotransolver_surface.yaml new file mode 100644 index 0000000000..173c5cfc02 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/geotransolver_surface.yaml @@ -0,0 +1,48 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - training: base + - model: geotransolver + - data: surface + +output_dir: "runs" +checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir +run_id: "geotransolver/surface/bq" + +# Performance considerations: +precision: float32 # float32, float16, bfloat16, or float8 +compile: true +profile: false + +model: + functional_dim: 6 + include_local_features: true # use local features + radii: [0.01, 0.05, 0.25, 1.0, 2.5, 5.0] # radius for local features + neighbors_in_radius: [4, 8, 16, 64, 128, 256] # neighbors in radius for local features + n_hidden_local: 32 # hidden dimension for local features + +data: + include_sdf: false + include_geometry: true + geometry_sampling: 300_000 + broadcast_global_features: true + + +# Logging configuration +logging: + level: INFO + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/geotransolver_volume.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/geotransolver_volume.yaml new file mode 100644 index 0000000000..cd91b485cc --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/geotransolver_volume.yaml @@ -0,0 +1,49 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - training: base + - model: geotransolver + - data: volume + +output_dir: "runs" +checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir +run_id: "geotransolver/volume/bq" + +# Performance considerations: +precision: float32 # float32, float16, bfloat16, or float8 +compile: true +profile: false + +data: + include_geometry: true + geometry_sampling: 300_000 + broadcast_global_features: false + volume_sample_from_disk: true + + +model: + functional_dim: 7 + out_dim: 5 + include_local_features: true # use local features + radii: [0.01, 0.05, 0.25, 1.0, 2.5, 5.0] # radius for local features + neighbors_in_radius: [4, 8, 16, 64, 128, 256] # neighbors in radius for local features + n_hidden_local: 32 # hidden dimension for local features + +# Logging configuration +logging: + level: INFO + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/model/geotransolver.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/model/geotransolver.yaml new file mode 100644 index 0000000000..456c1eea8f --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/model/geotransolver.yaml @@ -0,0 +1,35 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_target_: physicsnemo.experimental.models.geotransolver.GeoTransolver +functional_dim: 6 +global_dim: 2 +geometry_dim: 3 +out_dim: 4 +n_layers: 20 +n_hidden: 256 +dropout: 0.0 +n_head: 8 +act: "gelu" +mlp_ratio: 2 +slice_num: 128 +use_te: false +plus: false +include_local_features: true # use local features +radii: [0.05, 0.25, 1.0, 2.5] # radius for local features +neighbors_in_radius: [8, 32, 64, 128] # neighbors in radius for local features +n_hidden_local: 32 # hidden dimension for local features + diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/model/transolver.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/model/transolver.yaml new file mode 100644 index 0000000000..c43fb8560c --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/model/transolver.yaml @@ -0,0 +1,34 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +_target_: physicsnemo.models.transolver.Transolver +functional_dim: 2 +out_dim: 4 +embedding_dim: 6 +n_layers: 8 +n_hidden: 256 +dropout: 0.0 +n_head: 8 +act: "gelu" +mlp_ratio: 2 +slice_num: 512 +unified_pos: false +ref: 8 +structured_shape: null +use_te: false +time_input: false +plus: false + diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/training/base.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/training/base.yaml new file mode 100644 index 0000000000..18797ea051 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/training/base.yaml @@ -0,0 +1,32 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +num_epochs: 501 +save_interval: 25 + +scheduler: + name: "StepLR" + params: + step_size: 100 + gamma: 0.5 + +optimizer: + _target_: torch.optim.AdamW + lr: 1.0e-3 + weight_decay: 1.0e-4 + betas: [0.9, 0.999] + eps: 1.0e-8 + diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_surface.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_surface.yaml new file mode 100644 index 0000000000..47c3f8a59d --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_surface.yaml @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - training: base + - model: transolver + - data: surface + +output_dir: "runs" +checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir +run_id: "surface/float32" + +# Performance considerations: +precision: float32 # float32, float16, bfloat16, or float8 +compile: true +profile: false + +data: + include_sdf: false + +# Logging configuration +logging: + level: INFO + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' + \ No newline at end of file diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_volume.yaml b/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_volume.yaml new file mode 100644 index 0000000000..7d0c8eb249 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/conf/transolver_volume.yaml @@ -0,0 +1,38 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +defaults: + - training: base + - model: transolver + - data: volume + +output_dir: "runs" +checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir +run_id: "volume/float32" + +# Performance considerations: +precision: float32 # float32, float16, bfloat16, or float8 +compile: true +profile: false + +model: + out_dim: 5 + embedding_dim: 7 + +# Logging configuration +logging: + level: INFO + format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/inference_on_vtk.py b/examples/cfd/external_aerodynamics/transformer_models/src/inference_on_vtk.py new file mode 100644 index 0000000000..87cbf3a9a9 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/inference_on_vtk.py @@ -0,0 +1,729 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +Inference script for running trained Transolver/GeoTransolver models on raw VTK files. + +This script reads VTP (surface) and VTU (volume) files directly, processes them through +the TransolverDataPipe, runs batched inference, and saves predictions back to VTK files. + +Usage (surface inference with GeoTransolver): + python inference_on_vtk.py --config-name=geotransolver_surface \ + +vtk_inference.input_dir=/path/to/runs \ + +vtk_inference.output_dir=/path/to/output \ + +vtk_inference.air_density=1.2050 \ + +vtk_inference.stream_velocity=30.0 + +Usage (volume inference with GeoTransolver): + python inference_on_vtk.py --config-name=geotransolver_volume \ + +vtk_inference.input_dir=/path/to/runs \ + +vtk_inference.output_dir=/path/to/output + +Usage (surface inference with Transolver): + python inference_on_vtk.py --config-name=transolver_surface \ + +vtk_inference.input_dir=/path/to/runs \ + +vtk_inference.output_dir=/path/to/output + +Note: The '+' prefix adds new config keys that don't exist in the base config. + +Expected input directory structure: + input_dir/ + ├── run_1/ + │ ├── boundary_1.vtp # Surface mesh + │ ├── volume_1.vtu # Volume mesh + │ └── drivaer_1_single_solid.stl # STL geometry + ├── run_2/ + │ └── ... + └── ... +""" + +from pathlib import Path +from typing import Literal +import time + +import numpy as np +import torch +import torchinfo +import pyvista as pv + +import hydra +import omegaconf +from omegaconf import DictConfig + +from physicsnemo.distributed import DistributedManager +from physicsnemo.utils import load_checkpoint +from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper + +from physicsnemo.datapipes.cae.transolver_datapipe import TransolverDataPipe + +from train import update_model_params_for_fp8 + +from inference_on_zarr import batched_inference_loop + + +# ============================================================================= +# VTK File Reading Functions +# ============================================================================= + + +def read_stl_geometry(stl_path: str, device: torch.device) -> dict[str, torch.Tensor]: + """ + Read STL file and extract geometry data for SDF calculation. + + Parameters + ---------- + stl_path : str + Path to the STL file (e.g., drivaer_N_single_solid.stl). + device : torch.device + Device to place tensors on. + + Returns + ------- + dict[str, torch.Tensor] + Dictionary containing: + - stl_coordinates: Vertex coordinates, shape (num_vertices, 3) + - stl_faces: Face indices (flattened), shape (num_faces * 3,) + - stl_centers: Cell centers, shape (num_cells, 3) + """ + mesh = pv.read(stl_path) + + # Get vertex coordinates + stl_coordinates = torch.from_numpy(np.asarray(mesh.points)).to( + device=device, dtype=torch.float32 + ) + + # Get face indices - pyvista stores as [n_verts, v0, v1, v2, n_verts, v0, v1, v2, ...] + # We reshape to extract just the vertex indices for triangles + faces = mesh.faces.reshape(-1, 4)[:, 1:] # Remove the count column + stl_faces = torch.from_numpy(faces.flatten()).to(device=device, dtype=torch.int32) + + # Get cell centers + stl_centers = torch.from_numpy(np.asarray(mesh.cell_centers().points)).to( + device=device, dtype=torch.float32 + ) + + return { + "stl_coordinates": stl_coordinates, + "stl_faces": stl_faces, + "stl_centers": stl_centers, + } + + +def read_surface_from_vtp( + vtp_path: str, device: torch.device, n_output_fields: int = 4 +) -> dict[str, torch.Tensor]: + """ + Read VTP (PolyData) file and extract surface mesh data. + + Parameters + ---------- + vtp_path : str + Path to the VTP file (e.g., boundary_N.vtp). + device : torch.device + Device to place tensors on. + n_output_fields : int + Number of output fields (default 4: pressure + 3 wall shear stress components). + + Returns + ------- + dict[str, torch.Tensor] + Dictionary containing: + - surface_mesh_centers: Cell center coordinates, shape (num_cells, 3) + - surface_normals: Cell normals, shape (num_cells, 3) + - surface_areas: Cell areas, shape (num_cells,) + - surface_fields: Dummy zeros for inference, shape (num_cells, n_output_fields) + """ + mesh = pv.read(vtp_path) + + # Get cell centers + surface_mesh_centers = torch.from_numpy(np.asarray(mesh.cell_centers().points)).to( + device=device, dtype=torch.float32 + ) + + # Get cell normals (normalized) + normals = np.asarray(mesh.cell_normals) + normals = normals / (np.linalg.norm(normals, axis=1, keepdims=True) + 1e-8) + surface_normals = torch.from_numpy(normals).to(device=device, dtype=torch.float32) + + # Compute cell areas + cell_sizes = mesh.compute_cell_sizes(length=False, area=True, volume=False) + surface_areas = torch.from_numpy(np.asarray(cell_sizes.cell_data["Area"])).to( + device=device, dtype=torch.float32 + ) + + # Create dummy fields for inference (zeros) + num_cells = surface_mesh_centers.shape[0] + surface_fields = torch.zeros( + (num_cells, n_output_fields), device=device, dtype=torch.float32 + ) + + return { + "surface_mesh_centers": surface_mesh_centers, + "surface_normals": surface_normals, + "surface_areas": surface_areas, + "surface_fields": surface_fields, + } + + +def read_volume_from_vtu( + vtu_path: str, device: torch.device, n_output_fields: int = 5 +) -> dict[str, torch.Tensor]: + """ + Read VTU (UnstructuredGrid) file and extract volume mesh data. + + Parameters + ---------- + vtu_path : str + Path to the VTU file (e.g., volume_N.vtu). + device : torch.device + Device to place tensors on. + n_output_fields : int + Number of output fields (default 5: 3 velocity + pressure + turbulent viscosity). + + Returns + ------- + dict[str, torch.Tensor] + Dictionary containing: + - volume_mesh_centers: Cell center coordinates, shape (num_cells, 3) + - volume_fields: Dummy zeros for inference, shape (num_cells, n_output_fields) + """ + mesh = pv.read(vtu_path) + + # Get cell centers + volume_mesh_centers = torch.from_numpy(np.asarray(mesh.cell_centers().points)).to( + device=device, dtype=torch.float32 + ) + + # Create dummy fields for inference (zeros) + num_cells = volume_mesh_centers.shape[0] + volume_fields = torch.zeros( + (num_cells, n_output_fields), device=device, dtype=torch.float32 + ) + + return { + "volume_mesh_centers": volume_mesh_centers, + "volume_fields": volume_fields, + } + + +# ============================================================================= +# Data Dict Builder +# ============================================================================= + + +def build_data_dict( + run_dir: Path, + data_mode: Literal["surface", "volume", "combined"], + device: torch.device, + air_density: float, + stream_velocity: float, + run_idx: int, +) -> dict[str, torch.Tensor]: + """ + Build a complete data dictionary from VTK files for a single run. + + This function reads VTP, VTU, and STL files from a run directory and + combines them into a dictionary compatible with TransolverDataPipe.process_data(). + + Parameters + ---------- + run_dir : Path + Path to the run directory containing VTK files. + data_mode : Literal["surface", "volume", "combined"] + Which data to load - surface, volume, or both. + device : torch.device + Device to place tensors on. + air_density : float + Air density value for the simulation. + stream_velocity : float + Stream velocity value for the simulation. + run_idx : int + The run index (used for file naming conventions). + + Returns + ------- + dict[str, torch.Tensor] + Complete data dictionary for the datapipe. + """ + data_dict = {} + + # Always read STL geometry (needed for SDF in volume mode, center of mass calculation) + stl_path = run_dir / f"drivaer_{run_idx}_single_solid.stl" + if stl_path.exists(): + stl_data = read_stl_geometry(str(stl_path), device) + data_dict.update(stl_data) + else: + # Try alternative naming + stl_files = list(run_dir.glob("*_single_solid.stl")) + if stl_files: + stl_data = read_stl_geometry(str(stl_files[0]), device) + data_dict.update(stl_data) + else: + raise FileNotFoundError(f"No STL file found in {run_dir}") + + # Read surface data if needed + if data_mode in ["surface", "combined"]: + vtp_path = run_dir / f"boundary_{run_idx}.vtp" + if not vtp_path.exists(): + # Try alternative naming + vtp_files = list(run_dir.glob("boundary_*.vtp")) + if vtp_files: + vtp_path = vtp_files[0] + else: + raise FileNotFoundError(f"No VTP file found in {run_dir}") + + surface_data = read_surface_from_vtp(str(vtp_path), device) + data_dict.update(surface_data) + + # Read volume data if needed + if data_mode in ["volume", "combined"]: + vtu_path = run_dir / f"volume_{run_idx}.vtu" + if not vtu_path.exists(): + # Try alternative naming + vtu_files = list(run_dir.glob("volume_*.vtu")) + if vtu_files: + vtu_path = vtu_files[0] + else: + raise FileNotFoundError(f"No VTU file found in {run_dir}") + + volume_data = read_volume_from_vtu(str(vtu_path), device) + data_dict.update(volume_data) + + # Add flow parameters + data_dict["air_density"] = torch.tensor( + [air_density], device=device, dtype=torch.float32 + ) + data_dict["stream_velocity"] = torch.tensor( + [stream_velocity], device=device, dtype=torch.float32 + ) + + return data_dict + + +# ============================================================================= +# Prediction Writer +# ============================================================================= + + +def write_surface_predictions_to_vtk( + vtp_path: str, + output_path: str, + predictions: torch.Tensor, + air_density: float, + stream_velocity: float, +) -> None: + """ + Write surface predictions to a VTP file. + + Parameters + ---------- + vtp_path : str + Path to the original VTP file (to copy mesh structure). + output_path : str + Path to write the output VTP file. + predictions : torch.Tensor + Model predictions, shape (num_cells, 4) - [pressure, wss_x, wss_y, wss_z]. + air_density : float + Air density for dimensional scaling. + stream_velocity : float + Stream velocity for dimensional scaling. + """ + mesh = pv.read(vtp_path) + output_mesh = mesh.copy() + + # Convert to numpy + pred_np = predictions.cpu().numpy() + + # Split into pressure and wall shear stress + pred_pressure = pred_np[:, 0] # Shape: (num_cells,) + pred_wss = pred_np[:, 1:4] # Shape: (num_cells, 3) + + # Scale to physical units + dynamic_pressure = air_density * stream_velocity**2 + pred_pressure = pred_pressure * dynamic_pressure + pred_wss = pred_wss * dynamic_pressure + + # Add to mesh + output_mesh.cell_data["PredictedPressure"] = pred_pressure + output_mesh.cell_data["PredictedWallShearStress"] = pred_wss + + # Save + output_mesh.save(output_path) + + +def write_volume_predictions_to_vtk( + vtu_path: str, + output_path: str, + predictions: torch.Tensor, + air_density: float, + stream_velocity: float, +) -> None: + """ + Write volume predictions to a VTU file. + + Parameters + ---------- + vtu_path : str + Path to the original VTU file (to copy mesh structure). + output_path : str + Path to write the output VTU file. + predictions : torch.Tensor + Model predictions, shape (num_cells, 5) - [vel_x, vel_y, vel_z, pressure, nut]. + air_density : float + Air density for dimensional scaling. + stream_velocity : float + Stream velocity for dimensional scaling. + """ + mesh = pv.read(vtu_path) + output_mesh = mesh.copy() + + # Convert to numpy + pred_np = predictions.cpu().numpy() + + # Split into velocity, pressure, and turbulent viscosity + pred_velocity = pred_np[:, 0:3] # Shape: (num_cells, 3) + pred_pressure = pred_np[:, 3] # Shape: (num_cells,) + pred_nut = pred_np[:, 4] # Shape: (num_cells,) + + # Scale to physical units + dynamic_pressure = air_density * stream_velocity**2 + pred_velocity = pred_velocity * stream_velocity + pred_pressure = pred_pressure * dynamic_pressure + pred_nut = pred_nut * dynamic_pressure + + # Add to mesh + output_mesh.cell_data["PredictedVelocity"] = pred_velocity + output_mesh.cell_data["PredictedPressure"] = pred_pressure + output_mesh.cell_data["PredictedNut"] = pred_nut + + # Save + output_mesh.save(output_path) + + +# ============================================================================= +# Main Inference Function +# ============================================================================= + + +def create_datapipe( + cfg: DictConfig, + data_mode: Literal["surface", "volume", "combined"], + device: torch.device, + surface_factors: dict | None, + volume_factors: dict | None, +) -> TransolverDataPipe: + """ + Create a TransolverDataPipe configured for inference. + + Parameters + ---------- + cfg : DictConfig + Hydra configuration. + data_mode : Literal["surface", "volume", "combined"] + Data mode for the datapipe. + device : torch.device + Device for tensors. + surface_factors : dict | None + Normalization factors for surface fields. + volume_factors : dict | None + Normalization factors for volume fields. + + Returns + ------- + TransolverDataPipe + Configured datapipe for inference. + """ + # Build overrides from config + overrides = {} + + optional_keys = [ + "include_normals", + "include_sdf", + "broadcast_global_features", + "include_geometry", + "geometry_sampling", + "translational_invariance", + "reference_origin", + "scale_invariance", + "reference_scale", + ] + + for key in optional_keys: + if cfg.data.get(key, None) is not None: + overrides[key] = cfg.data[key] + + # Create the datapipe with no resolution limit (we handle batching ourselves) + datapipe = TransolverDataPipe( + input_path=None, # We're not using the dataset iterator + model_type=data_mode, + resolution=None, # No downsampling - we batch manually + surface_factors=surface_factors, + volume_factors=volume_factors, + scaling_type="mean_std_scaling", + return_mesh_features=True, # For surface areas/normals if needed + **overrides, + ) + + # Move reference scale to device if needed + if datapipe.config.scale_invariance and datapipe.config.reference_scale is not None: + datapipe.config.reference_scale = datapipe.config.reference_scale.to(device) + + return datapipe + + +def inference_on_vtk(cfg: DictConfig) -> None: + """ + Main inference function for VTK files. + + Parameters + ---------- + cfg : DictConfig + Hydra configuration object. + """ + # Initialize distributed + DistributedManager.initialize() + dist_manager = DistributedManager() + + logger = RankZeroLoggingWrapper(PythonLogger(name="vtk_inference"), dist_manager) + + # Update config for FP8 if needed + cfg, output_pad_size = update_model_params_for_fp8(cfg, logger) + + logger.info(f"Config:\n{omegaconf.OmegaConf.to_yaml(cfg, resolve=True)}") + + # Get VTK inference config - these are added via command line with '+' prefix + if not cfg.get("vtk_inference", None): + raise ValueError( + "vtk_inference config section is required. " + "Add it via command line with '+vtk_inference.input_dir=...' etc." + ) + + vtk_cfg = cfg.vtk_inference + + # Required parameters + if not vtk_cfg.get("input_dir", None): + raise ValueError("vtk_inference.input_dir is required") + if not vtk_cfg.get("output_dir", None): + raise ValueError("vtk_inference.output_dir is required") + + input_dir = Path(vtk_cfg.input_dir) + output_dir = Path(vtk_cfg.output_dir) + + # Optional parameters with defaults + air_density = vtk_cfg.get("air_density", 1.2050) + stream_velocity = vtk_cfg.get("stream_velocity", 30.0) + run_indices = vtk_cfg.get("run_indices", None) + + logger.info(f"VTK Inference Settings:") + logger.info(f" input_dir: {input_dir}") + logger.info(f" output_dir: {output_dir}") + logger.info(f" air_density: {air_density}") + logger.info(f" stream_velocity: {stream_velocity}") + logger.info(f" run_indices: {run_indices}") + + # Create output directory + output_dir.mkdir(parents=True, exist_ok=True) + + # Determine data mode + data_mode = cfg.data.mode + + # Set up model + model = hydra.utils.instantiate(cfg.model) + logger.info(f"\n{torchinfo.summary(model, verbose=0)}") + + # Load checkpoint + if cfg.checkpoint_dir is not None: + checkpoint_dir = cfg.checkpoint_dir + else: + checkpoint_dir = f"{cfg.output_dir}/{cfg.run_id}/checkpoints" + + ckpt_args = { + "path": checkpoint_dir, + "models": model, + } + + loaded_epoch = load_checkpoint(device=dist_manager.device, **ckpt_args) + logger.info(f"Loaded checkpoint from epoch: {loaded_epoch}") + + model.to(dist_manager.device) + model.eval() + + if cfg.compile: + model = torch.compile(model, dynamic=True) + + num_params = sum(p.numel() for p in model.parameters()) + logger.info(f"Number of model parameters: {num_params}") + + # Load normalization factors + norm_dir = getattr(cfg.data, "normalization_dir", ".") + + surface_factors = None + volume_factors = None + + if data_mode in ["surface", "combined"]: + norm_file = str(Path(norm_dir) / "surface_fields_normalization.npz") + if Path(norm_file).exists(): + norm_data = np.load(norm_file) + surface_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } + logger.info(f"Loaded surface normalization from {norm_file}") + + if data_mode in ["volume", "combined"]: + norm_file = str(Path(norm_dir) / "volume_fields_normalization.npz") + if Path(norm_file).exists(): + norm_data = np.load(norm_file) + volume_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } + logger.info(f"Loaded volume normalization from {norm_file}") + + # Create datapipe + datapipe = create_datapipe( + cfg, data_mode, dist_manager.device, surface_factors, volume_factors + ) + + # Get batch resolution from config + batch_resolution = cfg.data.resolution + + # Find all run directories + if run_indices is not None: + run_dirs = [input_dir / f"run_{idx}" for idx in run_indices] + else: + run_dirs = sorted( + [d for d in input_dir.iterdir() if d.is_dir() and d.name.startswith("run_")] + ) + + logger.info(f"Found {len(run_dirs)} run directories to process") + + # Distribute runs across ranks + this_device_runs = run_dirs[dist_manager.rank :: dist_manager.world_size] + logger.info(f"Rank {dist_manager.rank} processing {len(this_device_runs)} runs") + + # Process each run + for run_dir in this_device_runs: + run_idx = int(run_dir.name.split("_")[1]) + logger.info(f"Processing run {run_idx}: {run_dir}") + + start_time = time.time() + + try: + # Build data dictionary from VTK files + data_dict = build_data_dict( + run_dir=run_dir, + data_mode=data_mode, + device=dist_manager.device, + air_density=air_density, + stream_velocity=stream_velocity, + run_idx=run_idx, + ) + + # Process through datapipe (adds batch dimension) + batch = datapipe(data_dict) + + # Run batched inference using imported function from inference_on_zarr + with torch.no_grad(): + _, _, (predictions, _) = batched_inference_loop( + batch=batch, + model=model, + precision=cfg.precision, + data_mode=data_mode, + batch_resolution=batch_resolution, + output_pad_size=output_pad_size, + dist_manager=dist_manager, + datapipe=datapipe, + ) + + # Remove batch dimension and get predictions + predictions = predictions.squeeze(0) + + # Write predictions to output files + run_output_dir = output_dir / run_dir.name + run_output_dir.mkdir(parents=True, exist_ok=True) + + if data_mode in ["surface", "combined"]: + vtp_path = run_dir / f"boundary_{run_idx}.vtp" + if not vtp_path.exists(): + vtp_path = list(run_dir.glob("boundary_*.vtp"))[0] + + output_vtp = run_output_dir / f"pred_boundary_{run_idx}.vtp" + write_surface_predictions_to_vtk( + str(vtp_path), + str(output_vtp), + predictions, + air_density, + stream_velocity, + ) + logger.info(f"Saved surface predictions to {output_vtp}") + + if data_mode in ["volume", "combined"]: + vtu_path = run_dir / f"volume_{run_idx}.vtu" + if not vtu_path.exists(): + vtu_path = list(run_dir.glob("volume_*.vtu"))[0] + + output_vtu = run_output_dir / f"pred_volume_{run_idx}.vtu" + write_volume_predictions_to_vtk( + str(vtu_path), + str(output_vtu), + predictions, + air_density, + stream_velocity, + ) + logger.info(f"Saved volume predictions to {output_vtu}") + + elapsed = time.time() - start_time + logger.info(f"Completed run {run_idx} in {elapsed:.2f} seconds") + + except Exception as e: + logger.error(f"Error processing run {run_idx}: {e}") + import traceback + + traceback.print_exc() + continue + + logger.info("Inference complete!") + + +# ============================================================================= +# Entry Point +# ============================================================================= + + +@hydra.main(version_base=None, config_path="conf", config_name="geotransolver_surface") +def launch(cfg: DictConfig) -> None: + """ + Launch VTK inference with Hydra configuration. + + Uses existing geotransolver/transolver configs. VTK-specific parameters + must be added via command line with '+' prefix: + +vtk_inference.input_dir=/path/to/runs + +vtk_inference.output_dir=/path/to/output + +vtk_inference.air_density=1.2050 (optional, default: 1.2050) + +vtk_inference.stream_velocity=30.0 (optional, default: 30.0) + +vtk_inference.run_indices=[1,2,3] (optional, default: all runs) + + Parameters + ---------- + cfg : DictConfig + Hydra configuration object. + """ + inference_on_vtk(cfg) + + +if __name__ == "__main__": + launch() diff --git a/examples/cfd/external_aerodynamics/transolver/src/inference_on_zarr.py b/examples/cfd/external_aerodynamics/transformer_models/src/inference_on_zarr.py similarity index 69% rename from examples/cfd/external_aerodynamics/transolver/src/inference_on_zarr.py rename to examples/cfd/external_aerodynamics/transformer_models/src/inference_on_zarr.py index d2e3498c76..b8195ff327 100644 --- a/examples/cfd/external_aerodynamics/transolver/src/inference_on_zarr.py +++ b/examples/cfd/external_aerodynamics/transformer_models/src/inference_on_zarr.py @@ -19,9 +19,10 @@ import numpy as np import torch import torchinfo -import typing +import typing, csv import collections from typing import Literal +from datetime import datetime import hydra import omegaconf @@ -31,6 +32,7 @@ from physicsnemo.utils.logging import PythonLogger, RankZeroLoggingWrapper from sklearn.metrics import r2_score +from metrics import metrics_fn_surface, metrics_fn_volume from physicsnemo.distributed import DistributedManager @@ -199,8 +201,8 @@ def batched_inference_loop( metrics = {k: v / global_weight for k, v in metrics.items()} loss = loss / global_weight - global_predictions = torch.cat([l[0] for l in global_preds_targets], dim=1) - global_targets = torch.cat([l[1] for l in global_preds_targets], dim=1) + global_predictions = torch.cat([l[0][0] for l in global_preds_targets], dim=1) + global_targets = torch.cat([l[1][0] for l in global_preds_targets], dim=1) # Now, we have to *unshuffle* the prediction to the original index inverse_indices = torch.empty_like(indices) @@ -254,16 +256,25 @@ def inference(cfg: DictConfig) -> None: # Load the normalization file from configured directory (defaults to current dir) norm_dir = getattr(cfg.data, "normalization_dir", ".") - if cfg.data.mode == "surface": + if cfg.data.mode == "surface" or cfg.data.mode == "combined": norm_file = str(Path(norm_dir) / "surface_fields_normalization.npz") - elif cfg.data.mode == "volume": - norm_file = str(Path(norm_dir) / "volume_fields_normalization.npz") + norm_data = np.load(norm_file) + surface_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } + else: + surface_factors = None - norm_data = np.load(norm_file) - norm_factors = { - "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), - "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), - } + if cfg.data.mode == "volume" or cfg.data.mode == "combined": + norm_file = str(Path(norm_dir) / "volume_fields_normalization.npz") + norm_data = np.load(norm_file) + volume_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } + else: + volume_factors = None if cfg.compile: model = torch.compile(model, dynamic=True) @@ -287,7 +298,8 @@ def inference(cfg: DictConfig) -> None: val_dataset = create_transolver_dataset( cfg.data, phase="val", - scaling_factors=norm_factors, + surface_factors=surface_factors, + volume_factors=volume_factors, ) results = [] @@ -311,9 +323,23 @@ def inference(cfg: DictConfig) -> None: logger.info(f"Finished batch {batch_idx} in {elapsed:.4f} seconds") start = time.time() + air_density = batch["air_density"] if "air_density" in batch.keys() else None + stream_velocity = ( + batch["stream_velocity"] if "stream_velocity" in batch.keys() else None + ) + if cfg.data.mode == "surface": coeff = 1.0 + if stream_velocity is not None: + global_predictions = ( + global_predictions * stream_velocity**2.0 * air_density + ) + global_targets = global_targets * stream_velocity**2.0 * air_density + + metrics = metrics_fn_surface( + global_predictions, global_targets, dist_manager + ) # Compute the drag and loss coefficients: # (Index on [0] is to remove the 1 batch index) pred_pressure, pred_shear = torch.split( @@ -339,8 +365,6 @@ def inference(cfg: DictConfig) -> None: torch.tensor([[0, 0, 1]], device=dist_manager.device), ) - # air_density = batch["air_density"] if "air_density" in batch.keys() else None - # stream_velocity = batch["stream_velocity"] if "stream_velocity" in batch.keys() else None # true_fields = val_dataset.unscale_model_targets(batch["fields"], air_density=air_density, stream_velocity=stream_velocity) true_pressure, true_shear = torch.split(global_targets[0], (1, 3), dim=-1) @@ -372,20 +396,30 @@ def inference(cfg: DictConfig) -> None: if hasattr(metrics["l2_pressure_surf"], "item") else metrics["l2_pressure_surf"] ) - l2_shear_x = ( - metrics["l2_shear_x"].item() - if hasattr(metrics["l2_shear_x"], "item") - else metrics["l2_shear_x"] + l1_pressure = ( + metrics["l1_pressure_surf"].item() + if hasattr(metrics["l1_pressure_surf"], "item") + else metrics["l1_pressure_surf"] ) - l2_shear_y = ( - metrics["l2_shear_y"].item() - if hasattr(metrics["l2_shear_y"], "item") - else metrics["l2_shear_y"] + mae_pressure = ( + metrics["mae_pressure_surf"].item() + if hasattr(metrics["mae_pressure_surf"], "item") + else metrics["mae_pressure_surf"] ) - l2_shear_z = ( - metrics["l2_shear_z"].item() - if hasattr(metrics["l2_shear_z"], "item") - else metrics["l2_shear_z"] + l2_wall_shear_stress = ( + metrics["l2_wall_shear_stress"].item() + if hasattr(metrics["l2_wall_shear_stress"], "item") + else metrics["l2_wall_shear_stress"] + ) + l1_wall_shear_stress = ( + metrics["l1_wall_shear_stress"].item() + if hasattr(metrics["l1_wall_shear_stress"], "item") + else metrics["l1_wall_shear_stress"] + ) + mae_wall_shear_stress = ( + metrics["mae_wall_shear_stress"].item() + if hasattr(metrics["mae_wall_shear_stress"], "item") + else metrics["mae_wall_shear_stress"] ) results.append( @@ -393,9 +427,11 @@ def inference(cfg: DictConfig) -> None: batch_idx, f"{loss:.4f}", f"{l2_pressure:.4f}", - f"{l2_shear_x:.4f}", - f"{l2_shear_y:.4f}", - f"{l2_shear_z:.4f}", + f"{l1_pressure:.4f}", + f"{mae_pressure:.4f}", + f"{l2_wall_shear_stress:.4f}", + f"{l1_wall_shear_stress:.4f}", + f"{mae_wall_shear_stress:.4f}", f"{pred_drag_coeff:.4f}", f"{pred_lift_coeff:.4f}", f"{true_drag_coeff:.4f}", @@ -405,51 +441,97 @@ def inference(cfg: DictConfig) -> None: ) elif cfg.data.mode == "volume": + if stream_velocity is not None: + global_predictions[:, :, 3] = ( + global_predictions[:, :, 3] * stream_velocity**2.0 * air_density + ) + global_targets[:, :, 3] = ( + global_targets[:, :, 3] * stream_velocity**2.0 * air_density + ) + global_predictions[:, :, 0:3] = ( + global_predictions[:, :, 0:3] * stream_velocity + ) + global_targets[:, :, 0:3] = global_targets[:, :, 0:3] * stream_velocity + global_predictions[:, :, 4] = ( + global_predictions[:, :, 4] * stream_velocity**2.0 * air_density + ) + global_targets[:, :, 4] = ( + global_targets[:, :, 4] * stream_velocity**2.0 * air_density + ) + + metrics = metrics_fn_volume( + global_predictions, global_targets, dist_manager + ) # Extract metric values and convert tensors to floats l2_pressure = ( metrics["l2_pressure_vol"].item() if hasattr(metrics["l2_pressure_vol"], "item") else metrics["l2_pressure_vol"] ) - l2_velocity_x = ( - metrics["l2_velocity_x"].item() - if hasattr(metrics["l2_velocity_x"], "item") - else metrics["l2_velocity_x"] + l1_pressure = ( + metrics["l1_pressure_vol"].item() + if hasattr(metrics["l1_pressure_vol"], "item") + else metrics["l1_pressure_vol"] + ) + mae_pressure = ( + metrics["mae_pressure_vol"].item() + if hasattr(metrics["mae_pressure_vol"], "item") + else metrics["mae_pressure_vol"] + ) + l2_velocity = ( + metrics["l2_velocity"].item() + if hasattr(metrics["l2_velocity"], "item") + else metrics["l2_velocity"] ) - l2_velocity_y = ( - metrics["l2_velocity_y"].item() - if hasattr(metrics["l2_velocity_y"], "item") - else metrics["l2_velocity_y"] + l1_velocity = ( + metrics["l1_velocity"].item() + if hasattr(metrics["l1_velocity"], "item") + else metrics["l1_velocity"] ) - l2_velocity_z = ( - metrics["l2_velocity_z"].item() - if hasattr(metrics["l2_velocity_z"], "item") - else metrics["l2_velocity_z"] + mae_velocity = ( + metrics["mae_velocity"].item() + if hasattr(metrics["mae_velocity"], "item") + else metrics["mae_velocity"] ) + l2_nut = ( metrics["l2_nut"].item() if hasattr(metrics["l2_nut"], "item") else metrics["l2_nut"] ) + l1_nut = ( + metrics["l1_nut"].item() + if hasattr(metrics["l1_nut"], "item") + else metrics["l1_nut"] + ) + mae_nut = ( + metrics["mae_nut"].item() + if hasattr(metrics["mae_nut"], "item") + else metrics["mae_nut"] + ) results.append( [ batch_idx, f"{loss:.4f}", f"{l2_pressure:.4f}", - f"{l2_velocity_x:.4f}", - f"{l2_velocity_y:.4f}", - f"{l2_velocity_z:.4f}", + f"{l1_pressure:.4f}", + f"{mae_pressure:.4f}", + f"{l2_velocity:.4f}", + f"{l1_velocity:.4f}", + f"{mae_velocity:.4f}", f"{l2_nut:.4f}", + f"{l1_nut:.4f}", + f"{mae_nut:.4f}", f"{elapsed:.4f}", ] ) if cfg.data.mode == "surface": - pred_drag_coeffs = [r[6] for r in results] - pred_lift_coeffs = [r[7] for r in results] - true_drag_coeffs = [r[8] for r in results] - true_lift_coeffs = [r[9] for r in results] + pred_drag_coeffs = [r[8] for r in results] + pred_lift_coeffs = [r[9] for r in results] + true_drag_coeffs = [r[10] for r in results] + true_lift_coeffs = [r[11] for r in results] # Compute the R2 scores for lift and drag: r2_lift = r2_score(true_lift_coeffs, pred_lift_coeffs) @@ -459,9 +541,11 @@ def inference(cfg: DictConfig) -> None: "Batch", "Loss", "L2 Pressure", - "L2 Shear X", - "L2 Shear Y", - "L2 Shear Z", + "L1 Pressure", + "MAE Pressure", + "L2 Wall Shear Stress", + "L1 Wall Shear Stress", + "MAE Wall Shear Stress", "Predicted Drag Coefficient", "Pred Lift Coefficient", "True Drag Coefficient", @@ -473,21 +557,37 @@ def inference(cfg: DictConfig) -> None: ) logger.info(f"R2 score for lift: {r2_lift:.4f}") logger.info(f"R2 score for drag: {r2_drag:.4f}") + csv_filename = f"{cfg.output_dir}/{cfg.run_id}/surface_inference_results_{datetime.now()}.csv" + with open(csv_filename, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(headers) + writer.writerows(results) + logger.info(f"Results saved to {csv_filename}") elif cfg.data.mode == "volume": headers = [ "Batch", "Loss", "L2 Pressure", - "L2 Velocity X", - "L2 Velocity Y", - "L2 Velocity Z", + "L1 Pressure", + "MAE Pressure", + "L2 Velocity", + "L1 Velocity", + "MAE Velocity", "L2 Nut", + "L1 Nut", + "MAE Nut", "Elapsed (s)", ] logger.info( f"Results:\n{tabulate(results, headers=headers, tablefmt='github')}" ) + csv_filename = f"{cfg.output_dir}/{cfg.run_id}/volume_inference_results_{datetime.now()}.csv" + with open(csv_filename, "w", newline="") as f: + writer = csv.writer(f) + writer.writerow(headers) + writer.writerows(results) + logger.info(f"Results saved to {csv_filename}") # Calculate means for each metric (skip batch index) if results: diff --git a/examples/cfd/external_aerodynamics/transolver/src/metrics.py b/examples/cfd/external_aerodynamics/transformer_models/src/metrics.py similarity index 56% rename from examples/cfd/external_aerodynamics/transolver/src/metrics.py rename to examples/cfd/external_aerodynamics/transformer_models/src/metrics.py index 143e4fa338..be86d0cdc9 100644 --- a/examples/cfd/external_aerodynamics/transolver/src/metrics.py +++ b/examples/cfd/external_aerodynamics/transformer_models/src/metrics.py @@ -16,9 +16,11 @@ import torch import torch.distributed as dist -from physicsnemo.distributed import ShardTensor +from physicsnemo.domain_parallel import ShardTensor from physicsnemo.distributed import DistributedManager +from utils import tensorwise + def all_reduce_dict( metrics: dict[str, torch.Tensor], dm: DistributedManager @@ -49,6 +51,7 @@ def all_reduce_dict( return metrics +@tensorwise def metrics_fn( pred: torch.Tensor, target: torch.Tensor, @@ -98,6 +101,38 @@ def metrics_fn_volume( Raises: NotImplementedError: Always, as this function is not yet implemented. """ + + # + pressure_pred = pred[:, :, 3] + pressure_target = target[:, :, 3] + + velocity_pred = torch.sqrt(torch.sum(pred[:, :, 0:3] ** 2.0, dim=2)) + velocity_target = torch.sqrt(torch.sum(target[:, :, 0:3] ** 2.0, dim=2)) + + # L1 errors + l1_num = torch.abs(pred - target) + l1_num = torch.sum(l1_num, dim=1) + + l1_denom = torch.abs(target) + l1_denom = torch.sum(l1_denom, dim=1) + + l1 = l1_num / l1_denom + + # L1 errors velocity + l1_num_vel = torch.abs(velocity_pred - velocity_target) + l1_num_vel = torch.sum(l1_num_vel) + + l1_denom_vel = torch.abs(velocity_target) + l1_denom_vel = torch.sum(l1_denom_vel) + + l1_vel = l1_num_vel / l1_denom_vel + + # MAE + mae_num = torch.abs(pred - target) + mae_num_vel = torch.abs(velocity_pred - velocity_target) + mae_pressure = torch.abs(pressure_pred - pressure_target) + + # L2 errors l2_num = (pred - target) ** 2 l2_num = torch.sum(l2_num, dim=1) l2_num = torch.sqrt(l2_num) @@ -108,12 +143,36 @@ def metrics_fn_volume( l2 = l2_num / l2_denom + # L2 errors velocity + l2_num_vel = (velocity_pred - velocity_target) ** 2 + l2_num_vel = torch.sum(l2_num_vel) + l2_num_vel = torch.sqrt(l2_num_vel) + + l2_denom_vel = velocity_target**2 + l2_denom_vel = torch.sum(l2_denom_vel) + l2_denom_vel = torch.sqrt(l2_denom_vel) + + l2_vel = l2_num_vel / l2_denom_vel + metrics = { "l2_pressure_vol": torch.mean(l2[:, 3]), "l2_velocity_x": torch.mean(l2[:, 0]), "l2_velocity_y": torch.mean(l2[:, 1]), "l2_velocity_z": torch.mean(l2[:, 2]), "l2_nut": torch.mean(l2[:, 4]), + "l1_pressure_vol": torch.mean(l1[:, 3]), + "l1_velocity_x": torch.mean(l1[:, 0]), + "l1_velocity_y": torch.mean(l1[:, 1]), + "l1_velocity_z": torch.mean(l1[:, 2]), + "l1_nut": torch.mean(l1[:, 4]), + "mae_pressure_vol": torch.mean(mae_pressure), + "mae_velocity_x": torch.mean(mae_num[:, :, 0]), + "mae_velocity_y": torch.mean(mae_num[:, :, 1]), + "mae_velocity_z": torch.mean(mae_num[:, :, 2]), + "mae_nut": torch.mean(mae_num[:, 4]), + "l2_velocity": torch.mean(l2_vel), + "l1_velocity": torch.mean(l1_vel), + "mae_velocity": torch.mean(mae_num_vel), } return metrics @@ -141,6 +200,36 @@ def metrics_fn_surface( # target = target * norm_factors["std"] + norm_factors["mean"] # pred = pred * norm_factors["std"] + norm_factors["mean"] + pressure_pred = pred[:, :, 0] + pressure_target = target[:, :, 0] + + wall_shear_pred = torch.sqrt(torch.sum(pred[:, :, 1:4] ** 2.0, dim=2)) + wall_shear_target = torch.sqrt(torch.sum(target[:, :, 1:4] ** 2.0, dim=2)) + + # MAE + mae_num = torch.abs(pred - target) + mae_wall_shear = torch.abs(wall_shear_pred - wall_shear_target) + mae_pressure = torch.abs(pressure_pred - pressure_target) + + # L1 errors + l1_num = torch.abs(pred - target) + l1_num = torch.sum(l1_num, dim=1) + + l1_denom = torch.abs(target) + l1_denom = torch.sum(l1_denom, dim=1) + + l1 = l1_num / l1_denom + + # L1 errors for wall shear stress + l1_num_ws = torch.abs(wall_shear_pred - wall_shear_target) + l1_num_ws = torch.sum(l1_num_ws) + + l1_denom_ws = torch.abs(wall_shear_target) + l1_denom_ws = torch.sum(l1_denom_ws) + + l1_ws = l1_num_ws / l1_denom_ws + + # L2 errors l2_num = (pred - target) ** 2 l2_num = torch.sum(l2_num, dim=1) l2_num = torch.sqrt(l2_num) @@ -151,11 +240,33 @@ def metrics_fn_surface( l2 = l2_num / l2_denom + # L2 errors for wall shear stress + l2_num_ws = (wall_shear_pred - wall_shear_target) ** 2 + l2_num_ws = torch.sum(l2_num_ws) + l2_num_ws = torch.sqrt(l2_num_ws) + + l2_denom_ws = wall_shear_target**2 + l2_denom_ws = torch.sum(l2_denom_ws) + l2_denom_ws = torch.sqrt(l2_denom_ws) + + l2_ws = l2_num_ws / l2_denom_ws + metrics = { "l2_pressure_surf": torch.mean(l2[:, 0]), "l2_shear_x": torch.mean(l2[:, 1]), "l2_shear_y": torch.mean(l2[:, 2]), "l2_shear_z": torch.mean(l2[:, 3]), + "l1_pressure_surf": torch.mean(l1[:, 0]), + "l1_shear_x": torch.mean(l1[:, 1]), + "l1_shear_y": torch.mean(l1[:, 2]), + "l1_shear_z": torch.mean(l1[:, 3]), + "mae_pressure_surf": torch.mean(mae_pressure), + "mae_shear_x": torch.mean(mae_num[:, :, 1]), + "mae_shear_y": torch.mean(mae_num[:, :, 2]), + "mae_shear_z": torch.mean(mae_num[:, :, 3]), + "l2_wall_shear_stress": torch.mean(l2_ws), + "l1_wall_shear_stress": torch.mean(l1_ws), + "mae_wall_shear_stress": torch.mean(mae_wall_shear), } return metrics diff --git a/examples/cfd/external_aerodynamics/transolver/src/preprocess.py b/examples/cfd/external_aerodynamics/transformer_models/src/preprocess.py similarity index 100% rename from examples/cfd/external_aerodynamics/transolver/src/preprocess.py rename to examples/cfd/external_aerodynamics/transformer_models/src/preprocess.py diff --git a/examples/cfd/external_aerodynamics/transolver/src/surface_fields_normalization.npz b/examples/cfd/external_aerodynamics/transformer_models/src/surface_fields_normalization.npz similarity index 51% rename from examples/cfd/external_aerodynamics/transolver/src/surface_fields_normalization.npz rename to examples/cfd/external_aerodynamics/transformer_models/src/surface_fields_normalization.npz index b6809d416c..228f7550cc 100644 Binary files a/examples/cfd/external_aerodynamics/transolver/src/surface_fields_normalization.npz and b/examples/cfd/external_aerodynamics/transformer_models/src/surface_fields_normalization.npz differ diff --git a/examples/cfd/external_aerodynamics/transolver/src/train.py b/examples/cfd/external_aerodynamics/transformer_models/src/train.py similarity index 85% rename from examples/cfd/external_aerodynamics/transolver/src/train.py rename to examples/cfd/external_aerodynamics/transformer_models/src/train.py index a7cf2299c6..b4ddacfcb1 100644 --- a/examples/cfd/external_aerodynamics/transolver/src/train.py +++ b/examples/cfd/external_aerodynamics/transformer_models/src/train.py @@ -22,6 +22,8 @@ import collections from contextlib import nullcontext +from collections.abc import Sequence + # Configuration: import hydra import omegaconf @@ -54,10 +56,11 @@ # Local folder imports for this example from metrics import metrics_fn -from preprocess import ( - preprocess_surface_data, - downsample_surface, -) + +# tensorwise is to handle single-point-cloud or multi-point-cloud running. +# it's a decorator that will automatically unzip one or more of a list of tensors, +# run the funtcion, and rezip the results. +from utils import tensorwise # Special import, if transformer engine is available: from physicsnemo.core.version_check import check_version_spec @@ -166,9 +169,12 @@ def get_autocast_context(precision: str) -> nullcontext: return nullcontext() -def cast_precisions(*tensors: torch.Tensor, precision: str) -> list[torch.Tensor]: +@tensorwise +def cast_precisions(tensor: torch.Tensor, precision: str) -> torch.Tensor: """ Casts the tensors to the specified precision. + + We are careful to take either a tensor or list of tensors, and return the same format. """ match precision: @@ -180,11 +186,12 @@ def cast_precisions(*tensors: torch.Tensor, precision: str) -> list[torch.Tensor dtype = None if dtype is not None: - tensors = [t.to(dtype) for t in tensors] - - return tensors + return tensor.to(dtype) + else: + return tensor +@tensorwise def pad_input_for_fp8( features: torch.Tensor, embeddings: torch.Tensor, @@ -217,6 +224,7 @@ def pad_input_for_fp8( return features, geometry +@tensorwise def unpad_output_for_fp8( outputs: torch.Tensor, output_pad_size: int | None ) -> torch.Tensor: @@ -236,6 +244,14 @@ def unpad_output_for_fp8( return outputs +@tensorwise +def loss_fn(outputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor: + """ + Compute the loss for the model. + """ + return torch.nn.functional.mse_loss(outputs, targets) + + def forward_pass( batch: dict, model: torch.nn.Module, @@ -247,6 +263,12 @@ def forward_pass( ): """ Run the forward pass of the model for one batch, including metrics and loss calculation. + + Transolver takes just one tensor for features, embeddings. + Typhon takes a list of tensors, for each. + + Typhon needs a `geometry` tensor, so that's the switch we use to distinguish. + """ features = batch["fx"] @@ -254,44 +276,87 @@ def forward_pass( targets = batch["fields"] # Cast precisions: - features, embeddings = cast_precisions(features, embeddings, precision=precision) - + features = cast_precisions(features, precision=precision) + embeddings = cast_precisions(embeddings, precision=precision) if "geometry" in batch.keys(): - (geometry,) = cast_precisions(batch["geometry"], precision=precision) + geometry = cast_precisions(batch["geometry"], precision=precision) else: geometry = None + all_metrics = {} + if datapipe.config.model_type == "combined": + # This is hard coded for Typhon. If you have more point clouds, + # your mileage may vary. + modes = ["surface", "volume"] + elif datapipe.config.model_type == "surface": + modes = [ + "surface", + ] + elif datapipe.config.model_type == "volume": + modes = [ + "volume", + ] + with get_autocast_context(precision): # For fp8, we may have to pad the inputs: if precision == "float8" and TE_AVAILABLE: features, geometry = pad_input_for_fp8(features, embeddings, geometry) if "geometry" in batch.keys(): + local_positions = embeddings[:, :, :3] + # This is the Typhon path outputs = model( - global_embedding=features, local_embedding=embeddings, geometry=geometry + global_embedding=features, + local_embedding=embeddings, + geometry=geometry, + local_positions=local_positions, ) + + outputs = unpad_output_for_fp8(outputs, output_pad_size) + # Loss per point cloud: + loss = loss_fn(outputs, targets) + # Log them too: + for i, mode in enumerate(modes): + all_metrics[f"loss/{mode}"] = loss.item() + # Averaging over point cloud inputs, instead of summing. + full_loss = torch.mean(loss) + else: + # This is the Transolver path outputs = model(fx=features, embedding=embeddings) + outputs = unpad_output_for_fp8(outputs, output_pad_size) + full_loss = torch.nn.functional.mse_loss(outputs, targets) - outputs = unpad_output_for_fp8(outputs, output_pad_size) - - loss = torch.nn.functional.mse_loss(outputs, targets) + all_metrics[f"loss/{modes[0]}"] = full_loss air_density = batch["air_density"] if "air_density" in batch.keys() else None stream_velocity = ( batch["stream_velocity"] if "stream_velocity" in batch.keys() else None ) - unscaled_outputs = datapipe.unscale_model_targets( - outputs, air_density=air_density, stream_velocity=stream_velocity + unscaled_outputs = tensorwise(datapipe.unscale_model_targets)( + outputs, + air_density=air_density, + stream_velocity=stream_velocity, + factor_type=modes, ) - unscaled_targets = datapipe.unscale_model_targets( - targets, air_density=air_density, stream_velocity=stream_velocity + unscaled_targets = tensorwise(datapipe.unscale_model_targets)( + targets, + air_density=air_density, + stream_velocity=stream_velocity, + factor_type=modes, ) + metrics = metrics_fn(unscaled_outputs, unscaled_targets, dist_manager, modes) - metrics = metrics_fn(unscaled_outputs, unscaled_targets, dist_manager, data_mode) + # In the combined mode, this is a list of dicts. Merge them. + metrics = ( + {k: v for d in metrics for k, v in d.items()} + if isinstance(metrics, list) + else metrics + ) + all_metrics.update(metrics) - return loss, metrics, (unscaled_outputs, unscaled_targets) + return full_loss, all_metrics, (unscaled_outputs, unscaled_targets) @profile @@ -369,9 +434,7 @@ def train_epoch( if i == 0: total_metrics = metrics else: - total_metrics = { - k: total_metrics[k] + metrics[k].item() for k in metrics.keys() - } + total_metrics = {k: total_metrics[k] + metrics[k] for k in metrics.keys()} duration = end_time - start_time start_time = end_time @@ -468,7 +531,7 @@ def val_epoch( total_metrics = metrics else: total_metrics = { - k: total_metrics[k] + metrics[k].item() for k in metrics.keys() + k: total_metrics[k] + metrics[k] for k in metrics.keys() } # Logging @@ -592,7 +655,8 @@ def main(cfg: DictConfig): cfg, output_pad_size = update_model_params_for_fp8(cfg, logger) # Set up model - model = hydra.utils.instantiate(cfg.model) + # (Using partial convert to get lists, etc., instead of ListConfigs.) + model = hydra.utils.instantiate(cfg.model, _convert_="partial") logger.info(f"\n{torchinfo.summary(model, verbose=0)}") model.to(dist_manager.device) @@ -608,22 +672,32 @@ def main(cfg: DictConfig): # Load the normalization file from configured directory (defaults to current dir) norm_dir = getattr(cfg.data, "normalization_dir", ".") - if cfg.data.mode == "surface": + if cfg.data.mode == "surface" or cfg.data.mode == "combined": norm_file = str(Path(norm_dir) / "surface_fields_normalization.npz") - elif cfg.data.mode == "volume": - norm_file = str(Path(norm_dir) / "volume_fields_normalization.npz") + norm_data = np.load(norm_file) + surface_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } + else: + surface_factors = None - norm_data = np.load(norm_file) - norm_factors = { - "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), - "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), - } + if cfg.data.mode == "volume" or cfg.data.mode == "combined": + norm_file = str(Path(norm_dir) / "volume_fields_normalization.npz") + norm_data = np.load(norm_file) + volume_factors = { + "mean": torch.from_numpy(norm_data["mean"]).to(dist_manager.device), + "std": torch.from_numpy(norm_data["std"]).to(dist_manager.device), + } + else: + volume_factors = None # Training dataset train_dataloader = create_transolver_dataset( cfg.data, phase="train", - scaling_factors=norm_factors, + surface_factors=surface_factors, + volume_factors=volume_factors, ) # Validation dataset @@ -631,7 +705,8 @@ def main(cfg: DictConfig): val_dataloader = create_transolver_dataset( cfg.data, phase="val", - scaling_factors=norm_factors, + surface_factors=surface_factors, + volume_factors=volume_factors, ) num_replicas = dist_manager.world_size diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/utils.py b/examples/cfd/external_aerodynamics/transformer_models/src/utils.py new file mode 100644 index 0000000000..a5484e9747 --- /dev/null +++ b/examples/cfd/external_aerodynamics/transformer_models/src/utils.py @@ -0,0 +1,102 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +from collections.abc import Iterable, Sequence +import torch +import functools + +_SEQUENCE_BLOCKLIST = (torch.Tensor, str, bytes) + + +def _is_tensor_sequence(x): + return isinstance(x, Sequence) and not isinstance(x, _SEQUENCE_BLOCKLIST) + + +def _coerce_iterable(arg): + """ + Normalize iterable inputs so tensorwise can unzip any sequence-like object, + even if it is only an iterator (e.g., zip objects of strings or constants). + """ + if _is_tensor_sequence(arg): + return arg, True + if isinstance(arg, Iterable) and not isinstance(arg, _SEQUENCE_BLOCKLIST): + return tuple(arg), True + return arg, False + + +def tensorwise(fn): + """ + Decorator: allow fn(tensor, ...) or fn(list-of-tensors, ...). + If any argument is a sequence of tensors, apply fn elementwise. Non-sequence + iterables (zip objects, generators of strings, etc.) are automatically + materialized so they can participate in the elementwise zip as well. + All sequences must be the same length. + """ + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + # Detect sequences while allowing generic iterables (e.g., zip objects) + normalized_args = [] + seq_flags = [] + for arg in args: + normalized_arg, is_seq = _coerce_iterable(arg) + normalized_args.append(normalized_arg) + seq_flags.append(is_seq) + + normalized_kwargs = {} + kw_seq_flags = {} + for key, value in kwargs.items(): + normalized_value, is_seq = _coerce_iterable(value) + normalized_kwargs[key] = normalized_value + kw_seq_flags[key] = is_seq + + any_seq = any(seq_flags) or any(kw_seq_flags.values()) + + if not any_seq: + # Nothing is a sequence — call normally + return fn(*normalized_args, **normalized_kwargs) + + # All sequence arguments must be sequences of the same length + # Collect all sequences (positional + keyword) + seq_lengths = {len(a) for a, flag in zip(normalized_args, seq_flags) if flag} + seq_lengths.update( + len(normalized_kwargs[k]) for k, flag in kw_seq_flags.items() if flag + ) + lengths = seq_lengths + if len(lengths) != 1: + raise ValueError( + f"Sequence arguments must have same length; got lengths {lengths}." + ) + + L = lengths.pop() + + outs = [] + for i in range(L): + # Rebuild ith positional args + ith_args = [ + (a[i] if is_s else a) for a, is_s in zip(normalized_args, seq_flags) + ] + # Rebuild ith keyword args + ith_kwargs = { + k: (v[i] if kw_seq_flags[k] else v) + for k, v in normalized_kwargs.items() + } + outs.append(fn(*ith_args, **ith_kwargs)) + + return outs + + return wrapper diff --git a/examples/cfd/external_aerodynamics/transformer_models/src/volume_fields_normalization.npz b/examples/cfd/external_aerodynamics/transformer_models/src/volume_fields_normalization.npz new file mode 100644 index 0000000000..c1f0e6f463 Binary files /dev/null and b/examples/cfd/external_aerodynamics/transformer_models/src/volume_fields_normalization.npz differ diff --git a/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_surface.yaml b/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_surface.yaml deleted file mode 100644 index db7f5938c9..0000000000 --- a/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_surface.yaml +++ /dev/null @@ -1,105 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - - -output_dir: "runs" -checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir -run_id: "surface/bfloat16" - -# Performance considerations: -precision: bfloat16 # float32, float16, bfloat16, or float8 -compile: true -profile: false - -# Training configuration -training: - num_epochs: 501 # Add one to save at 250 - save_interval: 25 # Save checkpoint every N epochs - - # StepLR scheduler: Decays the learning rate by gamma every step_size epochs - scheduler: - name: "StepLR" - params: - step_size: 100 # Decay every 200 epochs (set X as desired) - gamma: 0.5 # Decay factor - - # Optimizer configuration - optimizer: - _target_: torch.optim.AdamW - lr: 1.0e-3 - weight_decay: 1.0e-4 - betas: [0.9, 0.999] - eps: 1.0e-8 - -# Model configuration -model: - _target_: physicsnemo.models.transolver.Transolver - functional_dim: 2 # Input feature dimension - out_dim: 4 # Output feature dimension - embedding_dim: 6 # Spatial embedding dimension - n_layers: 8 # Number of transformer layers - n_hidden: 256 # Hidden dimension - dropout: 0.0 # Dropout rate - n_head: 8 # Number of attention heads - act: "gelu" # Activation function - mlp_ratio: 2 # MLP ratio in attention blocks - slice_num: 512 # Number of slices in physics attention - unified_pos: false # Whether to use unified positional embeddings - ref: 8 # Reference dimension for unified pos - structured_shape: null - use_te: false # Use transformer engine - time_input: false # Whether to use time embeddings - plus: false - - -# Data configuration -data: - train: - data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/train/ - val: - data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/val/ - max_workers: 8 - normalization_dir: "src/" # Directory for normalization files - preload_depth: 1 - pin_memory: true - resolution: 300_000 - mode: surface - # Preprocessing switches: - # (Changing thes will change the embedding dim) - include_normals: true - include_sdf: false - translational_invariance: true - scale_invariance: true - reference_scale: [12.0, 4.5, 3.25] - data_keys: - - "surface_fields" - - "surface_mesh_centers" - - "surface_normals" - - "surface_areas" - - "air_density" - - "stream_velocity" - - "stl_faces" - - "stl_centers" - - "stl_coordinates" - include_geometry: false - broadcast_global_features: true - return_mesh_features: false - -# Logging configuration -logging: - level: INFO - format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' \ No newline at end of file diff --git a/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_volume.yaml b/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_volume.yaml deleted file mode 100644 index 04a907c1b4..0000000000 --- a/examples/cfd/external_aerodynamics/transolver/src/conf/transolver_volume.yaml +++ /dev/null @@ -1,101 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. -# SPDX-FileCopyrightText: All rights reserved. -# SPDX-License-Identifier: Apache-2.0 -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - - -output_dir: "runs" -checkpoint_dir: null # Optional: set custom checkpoint path, defaults to output_dir -run_id: "volume/bfloat16" - -# Performance considerations: -precision: bfloat16 # float32, float16, bfloat16, or float8 -compile: true -profile: false - -# Training configuration -training: - num_epochs: 501 # Add one to save at 250 - save_interval: 25 # Save checkpoint every N epochs - - # StepLR scheduler: Decays the learning rate by gamma every step_size epochs - scheduler: - name: "StepLR" - params: - step_size: 100 # Decay every 200 epochs (set X as desired) - gamma: 0.5 # Decay factor - - # Optimizer configuration - optimizer: - _target_: torch.optim.AdamW - lr: 1.0e-3 - weight_decay: 1.0e-4 - betas: [0.9, 0.999] - eps: 1.0e-8 - -# Model configuration -model: - _target_: physicsnemo.models.transolver.Transolver - functional_dim: 2 # Input feature dimension - out_dim: 5 # Output feature dimension - embedding_dim: 7 # Spatial embedding dimension - n_layers: 8 # Number of transformer layers - n_hidden: 256 # Hidden dimension - dropout: 0.0 # Dropout rate - n_head: 8 # Number of attention heads - act: "gelu" # Activation function - mlp_ratio: 2 # MLP ratio in attention blocks - slice_num: 512 # Number of slices in physics attention - unified_pos: false # Whether to use unified positional embeddings - ref: 8 # Reference dimension for unified pos - structured_shape: null - use_te: false # Use transformer engine - time_input: false # Whether to use time embeddings - plus: false - - -# Data configuration -data: - train: - data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/train/ - val: - data_path: /lustre/fsw/portfolios/coreai/projects/coreai_modulus_cae/datasets/drivaer_aws/domino/val/ - max_workers: 8 - normalization_dir: "src/" # Directory for normalization files - preload_depth: 1 - volume_sample_from_disk: true # Enable faster IO on pre-shuffled volumetric data - pin_memory: true - resolution: 300_000 - # Preprocessing switches: - # (Changing thes will change the embedding dim) - include_normals: true - include_sdf: true - translational_invariance: true - scale_invariance: true - reference_scale: [12.0, 4.5, 3.25] - mode: volume - data_keys: - - "volume_fields" - - "volume_mesh_centers" - - "air_density" - - "stream_velocity" - - "stl_faces" - - "stl_centers" - - "stl_coordinates" - -# Logging configuration -logging: - level: INFO - format: '%(asctime)s - %(name)s - %(levelname)s - %(message)s' diff --git a/physicsnemo/datapipes/cae/cae_dataset.py b/physicsnemo/datapipes/cae/cae_dataset.py index 5272beefaf..e404c7272a 100644 --- a/physicsnemo/datapipes/cae/cae_dataset.py +++ b/physicsnemo/datapipes/cae/cae_dataset.py @@ -652,19 +652,26 @@ def read_file_attributes( keys = store.list().result() + def to_tensor_dict(attributes_dict): + attributes = {} + for k, v in attributes_dict.items(): + try: + attributes[k] = torch.tensor(v) + except (TypeError, ValueError, RuntimeError): # noqa PERF203 + pass + return attributes + # Zarr 3 check: if b"/zarr.json" in keys: zarr_json = store.read(b"/zarr.json").result() # load into json's parser: attributes_dict = json.loads(zarr_json.value)["attributes"] - attributes = {k: torch.tensor(v) for k, v in attributes_dict.items()} - return attributes + return to_tensor_dict(attributes_dict) elif b"/.zattrs" in keys: # Zarr 2: zarr_attrs = store.read(b"/.zattrs").result() attributes_dict = json.loads(zarr_attrs.value) - attributes = {k: torch.tensor(v) for k, v in attributes_dict.items()} - return attributes + return to_tensor_dict(attributes_dict) else: return {} diff --git a/physicsnemo/datapipes/cae/transolver_datapipe.py b/physicsnemo/datapipes/cae/transolver_datapipe.py index 8a4faae344..0161c3a638 100644 --- a/physicsnemo/datapipes/cae/transolver_datapipe.py +++ b/physicsnemo/datapipes/cae/transolver_datapipe.py @@ -69,7 +69,7 @@ class TransolverDataConfig: """ data_path: Path | None - model_type: Literal["surface", "volume"] = "surface" + model_type: Literal["surface", "volume", "combined"] = "surface" resolution: int = 200_000 # Control what features are added to the inputs to the model: @@ -82,7 +82,8 @@ class TransolverDataConfig: # For controlling the normalization of target values: scaling_type: Optional[Literal["min_max_scaling", "mean_std_scaling"]] = None - normalization_factors: Optional[torch.Tensor] = None + surface_factors: Optional[torch.Tensor] = None + volume_factors: Optional[torch.Tensor] = None ############################################################ # Translation invariance configuration: @@ -199,11 +200,6 @@ def preprocess_surface_data( # Build the embeddings: embeddings_inputs = [positions] - # Surface SDF is always 0: - if self.config.include_sdf: - sdf = torch.zeros_like(positions[:, 0:1]) - embeddings_inputs.append(sdf) - if self.config.include_normals: normals = data_dict["surface_normals"] if idx is not None: @@ -213,30 +209,37 @@ def preprocess_surface_data( embeddings = torch.cat(embeddings_inputs, dim=-1) - # Build fx: - fx_inputs = [ - data_dict["air_density"], - data_dict["stream_velocity"], - ] - fx = torch.stack(fx_inputs, dim=-1) - - if self.config.broadcast_global_features: - fx = fx.broadcast_to(embeddings.shape[0], -1) - else: - fx = fx.unsqueeze(0) - fields = data_dict["surface_fields"] if idx is not None: fields = fields[idx] if self.config.scaling_type is not None: - fields = self.scale_model_targets(fields, self.config.normalization_factors) + fields = self.scale_model_targets(fields, self.config.surface_factors) - return { - "embeddings": embeddings, - "fx": fx, - "fields": fields, - } + if "air_density" in data_dict and "stream_velocity" in data_dict: + # Build fx: + fx_inputs = [ + data_dict["air_density"], + data_dict["stream_velocity"], + ] + fx = torch.stack(fx_inputs, dim=-1) + + if self.config.broadcast_global_features: + fx = fx.broadcast_to(embeddings.shape[0], -1) + else: + fx = fx.unsqueeze(0) + + return { + "embeddings": embeddings, + "fx": fx, + "fields": fields, + } + + else: + return { + "embeddings": embeddings, + "fields": fields, + } def preprocess_volume_data( self, @@ -316,30 +319,36 @@ def preprocess_volume_data( embeddings = torch.cat(embeddings_inputs, dim=-1) - # Build fx: - fx_inputs = [ - data_dict["air_density"], - data_dict["stream_velocity"], - ] - fx = torch.stack(fx_inputs, dim=-1) - - if self.config.broadcast_global_features: - fx = fx.broadcast_to(embeddings.shape[0], -1) - else: - fx = fx.unsqueeze(0) - fields = data_dict["volume_fields"] if idx is not None: fields = fields[idx] if self.config.scaling_type is not None: - fields = self.scale_model_targets(fields, self.config.normalization_factors) + fields = self.scale_model_targets(fields, self.config.volume_factors) - return { - "embeddings": embeddings, - "fx": fx, - "fields": fields, - } + if "air_density" in data_dict and "stream_velocity" in data_dict: + # Build fx: + fx_inputs = [ + data_dict["air_density"], + data_dict["stream_velocity"], + ] + fx = torch.stack(fx_inputs, dim=-1) + + if self.config.broadcast_global_features: + fx = fx.broadcast_to(embeddings.shape[0], -1) + else: + fx = fx.unsqueeze(0) + + return { + "embeddings": embeddings, + "fx": fx, + "fields": fields, + } + else: + return { + "embeddings": embeddings, + "fields": fields, + } def process_geometry( self, @@ -422,7 +431,7 @@ def process_data(self, data_dict): "stl_centers", ] - if self.config.model_type == "volume": + if self.config.model_type == "volume" or self.config.model_type == "combined": # We need these for the SDF calculation: required_keys.extend( [ @@ -430,7 +439,9 @@ def process_data(self, data_dict): "stl_faces", ] ) - elif self.config.model_type == "surface": + elif ( + self.config.model_type == "surface" or self.config.model_type == "combined" + ): required_keys.extend( [ "surface_normals", @@ -446,15 +457,20 @@ def process_data(self, data_dict): else: center_of_mass = None - field_key = f"{self.config.model_type}_fields" - coords_key = f"{self.config.model_type}_mesh_centers" - - required_keys.extend( - [ - field_key, - coords_key, - ] - ) + if self.config.model_type == "surface" or self.config.model_type == "combined": + required_keys.extend( + [ + "surface_fields", + "surface_mesh_centers", + ] + ) + elif self.config.model_type == "volume" or self.config.model_type == "combined": + required_keys.extend( + [ + "volume_fields", + "volume_mesh_centers", + ] + ) missing_keys = [key for key in required_keys if key not in data_dict] if missing_keys: @@ -475,6 +491,23 @@ def process_data(self, data_dict): outputs = self.preprocess_volume_data( data_dict, center_of_mass, scale_factor ) + elif self.config.model_type == "combined": + outputs_surf = self.preprocess_surface_data( + data_dict, center_of_mass, scale_factor + ) + + outputs_vol = self.preprocess_volume_data( + data_dict, center_of_mass, scale_factor + ) + + outputs = {} + outputs["embeddings"] = [ + outputs_surf["embeddings"], + outputs_vol["embeddings"], + ] + # This should be the same in either: + outputs["fx"] = outputs_surf["fx"] + outputs["fields"] = [outputs_surf["fields"], outputs_vol["fields"]] if self.config.include_geometry: outputs["geometry"] = self.process_geometry( @@ -512,6 +545,7 @@ def unscale_model_targets( fields: torch.Tensor | None = None, air_density: torch.Tensor | None = None, stream_velocity: torch.Tensor | None = None, + factor_type: Literal["surface", "volume", "auto"] = "auto", ): """ Unscale the model outputs based on the configured scaling factors. @@ -521,7 +555,18 @@ def unscale_model_targets( """ - factors = self.config.normalization_factors + match factor_type: + case "surface": + factors = self.config.surface_factors + case "volume": + factors = self.config.volume_factors + case "auto": + if self.config.model_type == "surface": + factors = self.config.surface_factors + elif self.config.model_type == "volume": + factors = self.config.volume_factors + else: + raise ValueError(f"Invalid model type {self.config.model_type}") if self.config.scaling_type == "mean_std_scaling": field_mean = factors["mean"] @@ -532,8 +577,8 @@ def unscale_model_targets( field_max = factors["max"] fields = unnormalize(fields, field_max, field_min) - if air_density is not None and stream_velocity is not None: - fields = fields * air_density * stream_velocity**2 + # if air_density is not None and stream_velocity is not None: + # fields = fields * air_density * stream_velocity**2 return fields @@ -591,9 +636,11 @@ def __call__(self, data_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor """ outputs = self.process_data(data_dict) - for key in outputs.keys(): - outputs[key] = outputs[key].unsqueeze(0) + if isinstance(outputs[key], list): + outputs[key] = [item.unsqueeze(0) for item in outputs[key]] + else: + outputs[key] = outputs[key].unsqueeze(0) return outputs @@ -610,10 +657,8 @@ def __iter__(self): def create_transolver_dataset( cfg: DictConfig, phase: Literal["train", "val", "test"], - # keys_to_read: list[str], - # keys_to_read_if_available: dict[str, torch.Tensor], - scaling_factors: list[float], - # normalize_coordinates: bool = True, + surface_factors: dict[str, torch.Tensor] | None = None, + volume_factors: dict[str, torch.Tensor] | None = None, device_mesh: torch.distributed.DeviceMesh | None = None, placements: dict[str, torch.distributed.tensor.Placement] | None = None, ): @@ -694,7 +739,8 @@ def create_transolver_dataset( datapipe = TransolverDataPipe( input_path, resolution=cfg.resolution, - normalization_factors=scaling_factors, + surface_factors=surface_factors, + volume_factors=volume_factors, model_type=model_type, scaling_type="mean_std_scaling", **overrides, diff --git a/physicsnemo/experimental/models/geotransolver/__init__.py b/physicsnemo/experimental/models/geotransolver/__init__.py new file mode 100644 index 0000000000..2a5672e971 --- /dev/null +++ b/physicsnemo/experimental/models/geotransolver/__init__.py @@ -0,0 +1,65 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""GeoTransolver: Geometry-Aware Physics Attention Transformer. + +This module provides the GeoTransolver model and its components for learning +physics-based representations with geometry and global context awareness. + +Classes +------- +GeoTransolver + Main model class combining GALE attention with geometry and global context. +GALE + Geometry-Aware Latent Embeddings attention layer. +GALE_block + Transformer block using GALE attention. +ContextProjector + Projects context features onto physical state slices. +GlobalContextBuilder + Orchestrates context construction for the model. + +Examples +-------- +Basic usage: + +>>> import torch +>>> from physicsnemo.experimental.models.geotransolver import GeoTransolver +>>> model = GeoTransolver( +... functional_dim=64, +... out_dim=3, +... n_hidden=256, +... n_layers=4, +... use_te=False, +... ) +>>> x = torch.randn(2, 1000, 64) +>>> output = model(x) +>>> output.shape +torch.Size([2, 1000, 3]) +""" + +from .context_projector import ContextProjector, GlobalContextBuilder +from .gale import GALE, GALE_block +from .geotransolver import GeoTransolver, GeoTransolverMetaData + +__all__ = [ + "GeoTransolver", + "GeoTransolverMetaData", + "GALE", + "GALE_block", + "ContextProjector", + "GlobalContextBuilder", +] \ No newline at end of file diff --git a/physicsnemo/experimental/models/geotransolver/context_projector.py b/physicsnemo/experimental/models/geotransolver/context_projector.py new file mode 100644 index 0000000000..56bde60b06 --- /dev/null +++ b/physicsnemo/experimental/models/geotransolver/context_projector.py @@ -0,0 +1,823 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Context Projector for GeoTransolver model. + +This module provides classes for projecting context features (geometry or global +embeddings) onto learned physical state spaces for use in GALE attention layers. + +Classes +------- +ContextProjector + Projects context features onto physical state slices. +GeometricFeatureProcessor + Processes geometric features at a single spatial scale using BQWarp. +MultiScaleFeatureExtractor + Multi-scale geometric feature extraction with minimal complexity. +GlobalContextBuilder + Orchestrates all context construction for the GeoTransolver model. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from einops import rearrange +from jaxtyping import Float + +from physicsnemo.core.version_check import check_version_spec +from physicsnemo.models.transolver.Physics_Attention import gumbel_softmax +from physicsnemo.nn.ball_query import BQWarp +from physicsnemo.nn.mlp_layers import Mlp + +# Check optional dependency availability +TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) +if TE_AVAILABLE: + import transformer_engine.pytorch as te + + +class ContextProjector(nn.Module): + r"""Projects context features onto physical state space. + + This context projector is conceptually similar to half of a GALE attention layer. + It projects context values (geometry or global embeddings) onto a learned physical + state space, but unlike a full attention layer, it never projects back to the + original space. The projected features are used as context in all GALE blocks + of the GeoTransolver model. + + Parameters + ---------- + dim : int + Input dimension of the context features. + heads : int, optional + Number of projection heads. Default is 8. + dim_head : int, optional + Dimension of each projection head. Default is 64. + dropout : float, optional + Dropout rate. Default is 0.0. + slice_num : int, optional + Number of learned physical state slices. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is ``True``. + plus : bool, optional + Whether to use Transolver++ features. Default is ``False``. + + Forward + ------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)` where :math:`B` is batch size, + :math:`N` is number of tokens, and :math:`C` is number of channels. + + Outputs + ------- + torch.Tensor + Slice tokens of shape :math:`(B, H, S, D)` where :math:`H` is number of heads, + :math:`S` is number of slices, and :math:`D` is head dimension. + + Notes + ----- + The global features are reused in all blocks of the model, so the learned + projections must capture globally useful features rather than layer-specific ones. + + See Also + -------- + :class:`~physicsnemo.experimental.models.geotransolver.gale.GALE` : Full GALE attention layer that uses these projected context features. + :class:`~physicsnemo.experimental.models.geotransolver.GeoTransolver` : Main model that uses ContextProjector for geometry and global embeddings. + + Examples + -------- + >>> import torch + >>> projector = ContextProjector(dim=64, heads=8, dim_head=32, slice_num=32) + >>> x = torch.randn(2, 100, 64) # (batch, tokens, features) + >>> slice_tokens = projector(x) + >>> slice_tokens.shape + torch.Size([2, 8, 32, 32]) + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + use_te: bool = True, + plus: bool = False, + ) -> None: + super().__init__() + inner_dim = dim_head * heads + self.dim_head = dim_head + self.heads = heads + self.plus = plus + self.scale = dim_head**-0.5 + self.use_te = use_te + + # Choose linear layer implementation based on backend + linear_layer = te.Linear if (use_te and TE_AVAILABLE) else nn.Linear + + # Input projection layers for query and key + self.in_project_x = linear_layer(dim, inner_dim) + if not plus: + self.in_project_fx = linear_layer(dim, inner_dim) + + # Attention components + self.softmax = nn.Softmax(dim=-1) + self.dropout = nn.Dropout(dropout) + self.temperature = nn.Parameter(torch.ones([1, heads, 1, 1]) * 0.5) + + # Transolver++ adaptive temperature projection + if plus: + self.proj_temperature = nn.Sequential( + linear_layer(self.dim_head, slice_num), + nn.GELU(), + linear_layer(slice_num, 1), + nn.GELU(), + ) + + # Slice projection layer maps from head dimension to slice space + self.in_project_slice = linear_layer(dim_head, slice_num) + + def project_input_onto_slices( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> ( + Float[torch.Tensor, "batch heads tokens dim"] + | tuple[ + Float[torch.Tensor, "batch heads tokens dim"], + Float[torch.Tensor, "batch heads tokens dim"], + ] + ): + r"""Project the input onto the slice space. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)` where :math:`B` is batch size, + :math:`N` is number of tokens, and :math:`C` is number of channels. + + Returns + ------- + torch.Tensor or tuple[torch.Tensor, torch.Tensor] + If ``plus=True``, returns single tensor of shape :math:`(B, H, N, D)` where + :math:`H` is number of heads and :math:`D` is head dimension. If ``plus=False``, + returns tuple of two tensors both of shape :math:`(B, H, N, D)`, representing + the query and key projections respectively. + """ + # Project input to multi-head representation: (B, N, C) -> (B, H, N, D) + projected_x = rearrange( + self.in_project_x(x), "B N (h d) -> B h N d", h=self.heads, d=self.dim_head + ) + + if self.plus: + # Transolver++ uses single projection for both paths + return projected_x + else: + # Standard Transolver uses separate query and key projections + feature_projection = rearrange( + self.in_project_fx(x), + "B N (h d) -> B h N d", + h=self.heads, + d=self.dim_head, + ) + return projected_x, feature_projection + + def compute_slices_from_projections( + self, + slice_projections: Float[torch.Tensor, "batch heads tokens slices"], + fx: Float[torch.Tensor, "batch heads tokens dim"], + ) -> tuple[ + Float[torch.Tensor, "batch heads tokens slices"], + Float[torch.Tensor, "batch heads slices dim"], + ]: + r"""Compute slice weights and slice tokens from input projections and latent features. + + Parameters + ---------- + slice_projections : torch.Tensor + Projected input tensor of shape :math:`(B, H, N, S)` where :math:`B` is batch size, + :math:`H` is number of heads, :math:`N` is number of tokens, and :math:`S` is number of + slices, representing the projection of each token onto each slice for each + attention head. + fx : torch.Tensor + Latent feature tensor of shape :math:`(B, H, N, D)` where :math:`D` is head dimension, + representing the learned states to be aggregated by the slice weights. + + Returns + ------- + tuple[torch.Tensor, torch.Tensor] + - ``slice_weights``: Tensor of shape :math:`(B, H, N, S)`, normalized weights for + each slice per token and head. + - ``slice_token``: Tensor of shape :math:`(B, H, S, D)`, aggregated latent features + for each slice, head, and batch. + + Notes + ----- + The function computes a temperature-scaled softmax over the slice projections to + obtain slice weights, then aggregates the latent features for each slice using + these weights. The aggregated features are normalized by the sum of weights for + numerical stability. + """ + # Compute temperature-adjusted softmax weights + if self.plus: + # Transolver++ uses adaptive temperature with Gumbel softmax + temperature = self.temperature + self.proj_temperature(fx) + clamped_temp = torch.clamp(temperature, min=0.01).to( + slice_projections.dtype + ) + slice_weights = gumbel_softmax(slice_projections, clamped_temp) + else: + # Standard Transolver uses fixed temperature with regular softmax + clamped_temp = torch.clamp(self.temperature, min=0.5, max=5).to( + slice_projections.dtype + ) + slice_weights = nn.functional.softmax( + slice_projections / clamped_temp, dim=-1 + ) + + # Ensure weights match the computation dtype + slice_weights = slice_weights.to(slice_projections.dtype) + + # Aggregate features by slice weights with normalization + # Normalize first to prevent overflow in reduced precision + slice_norm = slice_weights.sum(2) # Sum over tokens: (B, H, S) + normed_weights = slice_weights / (slice_norm[:, :, None, :] + 1e-2) + + # Weighted aggregation: (B, H, S, N) @ (B, H, N, D) -> (B, H, S, D) + slice_token = torch.matmul(normed_weights.transpose(2, 3), fx) + + return slice_weights, slice_token + + def forward( + self, x: Float[torch.Tensor, "batch tokens channels"] + ) -> Float[torch.Tensor, "batch heads slices dim"]: + r"""Project inputs to physical state slices. + + This performs a partial physics attention operation: it projects the input onto + learned physical state slices but does not project back to the original space. + The resulting slice tokens serve as context for GALE attention layers. + + Parameters + ---------- + x : torch.Tensor + Input tensor of shape :math:`(B, N, C)` where :math:`B` is batch size, :math:`N` is + number of tokens, and :math:`C` is number of channels. + + Returns + ------- + torch.Tensor + Slice tokens of shape :math:`(B, H, S, D)` where :math:`H` is number of heads, + :math:`S` is number of slices, and :math:`D` is head dimension. + + Notes + ----- + This method implements the encoding portion of the physics attention mechanism. + The slice tokens capture learned physical state representations that are used + as cross-attention context throughout the model. + """ + ### Input validation + if not torch.compiler.is_compiling(): + if x.ndim != 3: + raise ValueError( + f"Expected 3D input tensor (B, N, C), " + f"got {x.ndim}D tensor with shape {tuple(x.shape)}" + ) + + # Project inputs onto learned latent spaces + if self.plus: + projected_x = self.project_input_onto_slices(x) + # Transolver++ reuses the same projection for both paths + feature_projection = projected_x + else: + projected_x, feature_projection = self.project_input_onto_slices(x) + + # Project latent representations onto physical state slices: (B, H, N, D) -> (B, H, N, S) + slice_projections = self.in_project_slice(projected_x) + + # Compute weighted aggregation of features into slice tokens + _, slice_tokens = self.compute_slices_from_projections( + slice_projections, feature_projection + ) + + return slice_tokens + + +class GeometricFeatureProcessor(nn.Module): + r"""Processes geometric features at a single spatial scale using BQWarp. + + This is a simple, reusable component that handles neighbor querying and + feature processing for one radius scale. It encapsulates the BQWarp + + MLP pattern used throughout the model. + + Parameters + ---------- + radius : float + Query radius for neighbor search. + neighbors_in_radius : int + Maximum number of neighbors within the radius. + feature_dim : int + Dimension of the input features to query. + hidden_dim : int + Output dimension after MLP processing. + + Forward + ------- + query_points : torch.Tensor + Query coordinates of shape :math:`(B, N, 3)` where :math:`B` is batch size + and :math:`N` is number of query points. + key_features : torch.Tensor + Features to query from of shape :math:`(B, N, C)` where :math:`C` is + ``feature_dim``. + + Outputs + ------- + torch.Tensor + Processed features of shape :math:`(B, N, D)` where :math:`D` is ``hidden_dim``. + + See Also + -------- + :class:`MultiScaleFeatureExtractor` : Uses multiple GeometricFeatureProcessor instances. + :class:`~physicsnemo.nn.ball_query.BQWarp` : The ball query operation used internally. + + Examples + -------- + >>> import torch + >>> processor = GeometricFeatureProcessor( + ... radius=0.1, neighbors_in_radius=16, feature_dim=3, hidden_dim=64 + ... ) + >>> query_points = torch.randn(2, 100, 3) # (batch, points, xyz) + >>> key_features = torch.randn(2, 100, 3) # (batch, points, features) + >>> output = processor(query_points, key_features) + >>> output.shape + torch.Size([2, 100, 64]) + """ + + def __init__( + self, + radius: float, + neighbors_in_radius: int, + feature_dim: int, + hidden_dim: int, + ) -> None: + super().__init__() + + # Ball query for neighbor search within radius + self.bq_warp = BQWarp(radius=radius, neighbors_in_radius=neighbors_in_radius) + + # MLP to process flattened neighbor features + self.mlp = Mlp( + in_features=feature_dim * neighbors_in_radius, + hidden_features=[hidden_dim, hidden_dim // 2], + out_features=hidden_dim, + act_layer=nn.GELU, + drop=0.0, + ) + + def forward( + self, + query_points: Float[torch.Tensor, "batch points spatial_dim"], + key_features: Float[torch.Tensor, "batch points features"], + ) -> Float[torch.Tensor, "batch points hidden_dim"]: + r"""Query neighbors and process features. + + Parameters + ---------- + query_points : torch.Tensor + Query coordinates of shape :math:`(B, N, 3)` where :math:`B` is batch size + and :math:`N` is number of query points. + key_features : torch.Tensor + Features to query from of shape :math:`(B, N, C)` where :math:`C` is the + feature dimension. + + Returns + ------- + torch.Tensor + Processed features of shape :math:`(B, N, D)` where :math:`D` is the + hidden dimension. + """ + ### Input validation + if not torch.compiler.is_compiling(): + if query_points.ndim != 3: + raise ValueError( + f"Expected 3D query_points tensor (B, N, 3), " + f"got {query_points.ndim}D tensor with shape {tuple(query_points.shape)}" + ) + if key_features.ndim != 3: + raise ValueError( + f"Expected 3D key_features tensor (B, N, C), " + f"got {key_features.ndim}D tensor with shape {tuple(key_features.shape)}" + ) + + # Query neighbors within radius: (B, N, K, C) + _, neighbors = self.bq_warp(query_points, key_features) + + # Flatten neighbor features for MLP: (B, N, K, C) -> (B, N, K*C) + neighbors_flat = rearrange(neighbors, "b n k c -> b n (k c)") + + # Process through MLP with tanh activation for bounded output + return torch.nn.functional.tanh(self.mlp(neighbors_flat)) + + +class MultiScaleFeatureExtractor(nn.Module): + r"""Multi-scale geometric feature extraction with minimal complexity. + + Manages multiple GeometricFeatureProcessor instances for different radii. + Provides both tokenized context and concatenated local features. + + Parameters + ---------- + geometry_dim : int + Dimension of geometry features. + radii : list[float] + Radii for multi-scale processing. + neighbors_in_radius : list[int] + Neighbors per radius (must have same length as ``radii``). + hidden_dim : int + Hidden dimension for processing. + n_head : int + Number of attention heads. + dim_head : int + Dimension per head. + dropout : float, optional + Dropout rate. Default is 0.0. + slice_num : int, optional + Number of slices for context tokenization. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine. Default is ``True``. + plus : bool, optional + Whether to use Transolver++ features. Default is ``False``. + + Forward + ------- + This class does not implement a standard ``forward`` method. Instead, use: + + - :meth:`extract_context_features`: Get tokenized features for GALE context. + - :meth:`extract_local_features`: Get concatenated features for local pathway. + + See Also + -------- + :class:`GeometricFeatureProcessor` : Single-scale processor used by this class. + :class:`ContextProjector` : Tokenizer used for context features. + :class:`GlobalContextBuilder` : High-level builder that uses this class. + + Examples + -------- + >>> import torch + >>> extractor = MultiScaleFeatureExtractor( + ... geometry_dim=3, + ... radii=[0.05, 0.25], + ... neighbors_in_radius=[8, 32], + ... hidden_dim=32, + ... n_head=8, + ... dim_head=32, + ... ) + >>> spatial_coords = torch.randn(2, 100, 3) + >>> geometry = torch.randn(2, 100, 3) + >>> context_feats = extractor.extract_context_features(spatial_coords, geometry) + >>> len(context_feats) # One per scale + 2 + >>> local_feats = extractor.extract_local_features(spatial_coords, geometry) + >>> local_feats.shape # Concatenated across scales + torch.Size([2, 100, 64]) + """ + + def __init__( + self, + geometry_dim: int, + radii: list[float], + neighbors_in_radius: list[int], + hidden_dim: int, + n_head: int, + dim_head: int, + dropout: float = 0.0, + slice_num: int = 64, + use_te: bool = True, + plus: bool = False, + ) -> None: + super().__init__() + self.num_scales = len(radii) + + # One processor per scale for geometric feature extraction + self.processors = nn.ModuleList( + [ + GeometricFeatureProcessor( + radii[i], neighbors_in_radius[i], geometry_dim, hidden_dim + ) + for i in range(self.num_scales) + ] + ) + + # One tokenizer per scale for projecting to context space + self.tokenizers = nn.ModuleList( + [ + ContextProjector( + hidden_dim, n_head, dim_head, dropout, slice_num, use_te, plus + ) + for _ in range(self.num_scales) + ] + ) + + def extract_context_features( + self, + spatial_coords: Float[torch.Tensor, "batch points spatial_dim"], + geometry: Float[torch.Tensor, "batch points geometry_dim"], + ) -> list[Float[torch.Tensor, "batch heads slices dim"]]: + r"""Extract and tokenize features for context. + + Parameters + ---------- + spatial_coords : torch.Tensor + Spatial coordinates of shape :math:`(B, N, 3)`. + geometry : torch.Tensor + Geometry features of shape :math:`(B, N, C_{geo})`. + + Returns + ------- + list[torch.Tensor] + List of tokenized context features, one per scale, each of shape + :math:`(B, H, S, D)`. + """ + return [ + tokenizer(processor(spatial_coords, geometry)) + for processor, tokenizer in zip(self.processors, self.tokenizers) + ] + + def extract_local_features( + self, + spatial_coords: Float[torch.Tensor, "batch points spatial_dim"], + geometry: Float[torch.Tensor, "batch points geometry_dim"], + ) -> Float[torch.Tensor, "batch points total_hidden"]: + r"""Extract and concatenate features for local pathway. + + Parameters + ---------- + spatial_coords : torch.Tensor + Spatial coordinates of shape :math:`(B, N, 3)`. + geometry : torch.Tensor + Geometry features of shape :math:`(B, N, C_{geo})`. + + Returns + ------- + torch.Tensor + Concatenated local features of shape :math:`(B, N, D_{total})` where + :math:`D_{total}` is ``hidden_dim * num_scales``. + """ + return torch.cat( + [processor(geometry, spatial_coords) for processor in self.processors], + dim=-1, + ) + + +class GlobalContextBuilder(nn.Module): + r"""Orchestrates all context construction with a clean, simple interface. + + Manages geometry tokenization, global embedding tokenization, and optional + multi-scale local features. This is the main entry point for building context + in the GeoTransolver model. + + Parameters + ---------- + functional_dims : tuple[int, ...] + Dimensions of each functional input type. + geometry_dim : int | None, optional + Geometry feature dimension. If ``None``, geometry context is disabled. + Default is ``None``. + global_dim : int | None, optional + Global embedding dimension. If ``None``, global context is disabled. + Default is ``None``. + radii : list[float], optional + Radii for local features. Default is ``[0.05, 0.25]``. + neighbors_in_radius : list[int], optional + Neighbors per radius. Default is ``[8, 32]``. + n_hidden_local : int, optional + Hidden dim for local features. Default is 32. + n_hidden : int, optional + Model hidden dimension. Default is 256. + n_head : int, optional + Number of attention heads. Default is 8. + dropout : float, optional + Dropout rate. Default is 0.0. + slice_num : int, optional + Number of slices for tokenization. Default is 32. + use_te : bool, optional + Whether to use Transformer Engine. Default is ``True``. + plus : bool, optional + Whether to use Transolver++ features. Default is ``False``. + include_local_features : bool, optional + Enable local feature extraction. Default is ``False``. + + Forward + ------- + This class does not implement a standard ``forward`` method. Instead, use + :meth:`build_context` to construct context and local features. + + See Also + -------- + :class:`ContextProjector` : Used for tokenizing geometry and global embeddings. + :class:`MultiScaleFeatureExtractor` : Used for multi-scale local features. + :class:`~physicsnemo.experimental.models.geotransolver.GeoTransolver` : Main model that uses this builder. + + Examples + -------- + >>> import torch + >>> builder = GlobalContextBuilder( + ... functional_dims=(64,), + ... geometry_dim=3, + ... global_dim=16, + ... n_hidden=256, + ... n_head=8, + ... ) + >>> local_embeddings = (torch.randn(2, 100, 64),) + >>> geometry = torch.randn(2, 100, 3) + >>> global_embedding = torch.randn(2, 1, 16) + >>> context, local_feats = builder.build_context( + ... local_embeddings, None, geometry, global_embedding + ... ) + >>> context.shape + torch.Size([2, 8, 32, 64]) + """ + + def __init__( + self, + functional_dims: tuple[int, ...], + geometry_dim: int | None = None, + global_dim: int | None = None, + radii: list[float] | None = None, + neighbors_in_radius: list[int] | None = None, + n_hidden_local: int = 32, + n_hidden: int = 256, + n_head: int = 8, + dropout: float = 0.0, + slice_num: int = 32, + use_te: bool = True, + plus: bool = False, + include_local_features: bool = False, + ) -> None: + super().__init__() + + # Set defaults for mutable arguments + if radii is None: + radii = [0.05, 0.25] + if neighbors_in_radius is None: + neighbors_in_radius = [8, 32] + + dim_head = n_hidden // n_head + context_dim = 0 + + # Multi-scale extractors for local features (one per functional dim) + if geometry_dim is not None and include_local_features: + self.local_extractors = nn.ModuleList( + [ + MultiScaleFeatureExtractor( + geometry_dim, + radii, + neighbors_in_radius, + n_hidden_local, + n_head, + dim_head, + dropout, + slice_num, + use_te, + plus, + ) + for _ in functional_dims + ] + ) + context_dim += dim_head * len(radii) * len(functional_dims) + else: + self.local_extractors = None + + # Geometry tokenizer for global geometry context + if geometry_dim is not None: + self.geometry_tokenizer = ContextProjector( + geometry_dim, n_head, dim_head, dropout, slice_num, use_te, plus + ) + context_dim += dim_head + else: + self.geometry_tokenizer = None + + # Global embedding tokenizer + if global_dim is not None: + self.global_tokenizer = ContextProjector( + global_dim, n_head, dim_head, dropout, slice_num, use_te, plus + ) + context_dim += dim_head + else: + self.global_tokenizer = None + + self._context_dim = context_dim + + def get_context_dim(self) -> int: + r"""Return total context dimension. + + Returns + ------- + int + Total dimension of the concatenated context features. + """ + return self._context_dim + + def build_context( + self, + local_embeddings: tuple[Float[torch.Tensor, "batch tokens features"], ...], + local_positions: ( + tuple[Float[torch.Tensor, "batch tokens spatial_dim"], ...] | None + ), + geometry: Float[torch.Tensor, "batch tokens geometry_dim"] | None = None, + global_embedding: Float[torch.Tensor, "batch global_tokens global_dim"] + | None = None, + ) -> tuple[ + Float[torch.Tensor, "batch heads slices context_dim"] | None, + list[Float[torch.Tensor, "batch tokens local_features"]] | None, + ]: + r"""Build all context and local features. + + Parameters + ---------- + local_embeddings : tuple[torch.Tensor, ...] + Input embeddings, each of shape :math:`(B, N, C_i)` where :math:`B` is + batch size, :math:`N` is number of tokens, and :math:`C_i` is the feature + dimension for input type :math:`i`. + local_positions : tuple[torch.Tensor, ...] | None + Local positions, each of shape :math:`(B, N, 3)`. These are used to query + neighbors for local features. Required if ``include_local_features=True``. + geometry : torch.Tensor | None, optional + Geometry features of shape :math:`(B, N, C_{geo})`. Default is ``None``. + global_embedding : torch.Tensor | None, optional + Global embedding of shape :math:`(B, N_g, C_g)`. Default is ``None``. + + Returns + ------- + tuple[torch.Tensor | None, list[torch.Tensor] | None] + - ``context``: Concatenated context tensor of shape :math:`(B, H, S, D_c)` + where :math:`D_c` is the total context dimension, or ``None`` if no + context sources are provided. + - ``local_features``: List of local feature tensors, one per input type, + each of shape :math:`(B, N, D_l)`, or ``None`` if local features are + disabled. + + Raises + ------ + ValueError + If ``local_positions`` is ``None`` but local features are enabled. + """ + ### Input validation + if not torch.compiler.is_compiling(): + if len(local_embeddings) == 0: + raise ValueError("Expected non-empty tuple of local embeddings") + for i, emb in enumerate(local_embeddings): + if emb.ndim != 3: + raise ValueError( + f"Expected 3D local_embedding tensor (B, N, C) at index {i}, " + f"got {emb.ndim}D tensor with shape {tuple(emb.shape)}" + ) + + context_parts = [] + local_features = None + + if local_positions is None and self.local_extractors is not None: + raise ValueError( + "Local positions are required if local features are enabled." + ) + + # Extract multi-scale features if enabled + if self.local_extractors is not None and geometry is not None: + local_features = [] + for i, embedding in enumerate(local_embeddings): + spatial_coords = local_positions[i] # Extract coordinates + + # Get tokenized context features from multi-scale extractor + context_feats = self.local_extractors[i].extract_context_features( + spatial_coords, geometry + ) + context_parts.extend(context_feats) + + # Get concatenated local features for skip connection + local_feats = self.local_extractors[i].extract_local_features( + spatial_coords, geometry + ) + local_features.append(local_feats) + + # Tokenize geometry features + if self.geometry_tokenizer is not None and geometry is not None: + context_parts.append(self.geometry_tokenizer(geometry)) + + # Tokenize global embedding + if self.global_tokenizer is not None and global_embedding is not None: + context_parts.append(self.global_tokenizer(global_embedding)) + + # Concatenate all context features along the last dimension + context = torch.cat(context_parts, dim=-1) if context_parts else None + + return context, local_features \ No newline at end of file diff --git a/physicsnemo/experimental/models/geotransolver/gale.py b/physicsnemo/experimental/models/geotransolver/gale.py new file mode 100644 index 0000000000..75461f6e64 --- /dev/null +++ b/physicsnemo/experimental/models/geotransolver/gale.py @@ -0,0 +1,467 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""GALE (Geometry-Aware Latent Embeddings) attention layer and transformer block. + +This module provides the GALE attention mechanism and GALE_block transformer block, +which extend the Transolver physics attention with cross-attention capabilities for +geometry and global context embeddings. +""" + +from __future__ import annotations + +import torch +import torch.nn as nn +from einops import rearrange +from jaxtyping import Float + +import physicsnemo # noqa: F401 for docs +from physicsnemo.core.version_check import check_version_spec +from physicsnemo.models.transolver.Physics_Attention import ( + PhysicsAttentionIrregularMesh, +) +from physicsnemo.models.transolver.transolver import MLP + +# Check optional dependency availability +TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) +if TE_AVAILABLE: + import transformer_engine.pytorch as te + + +class GALE(PhysicsAttentionIrregularMesh): + r"""Geometry-Aware Latent Embeddings (GALE) attention layer. + + This is an extension of the Transolver PhysicsAttention mechanism to support + cross-attention with a context vector, built from geometry and global embeddings. + GALE combines self-attention on learned physical state slices with cross-attention + to geometry-aware context, using a learnable mixing weight to blend the two. + + Parameters + ---------- + dim : int + Input dimension of the features. + heads : int, optional + Number of attention heads. Default is 8. + dim_head : int, optional + Dimension of each attention head. Default is 64. + dropout : float, optional + Dropout rate. Default is 0.0. + slice_num : int, optional + Number of learned physical state slices. Default is 64. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is True. + plus : bool, optional + Whether to use Transolver++ features. Default is False. + context_dim : int, optional + Dimension of the context vector for cross-attention. Default is 0. + + Forward + ------- + x : tuple[torch.Tensor, ...] + Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is + batch size, :math:`N` is number of tokens, and :math:`C` is number of channels. + context : tuple[torch.Tensor, ...] | None, optional + Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` where + :math:`H` is number of heads, :math:`S_c` is number of context slices, and + :math:`D_c` is context dimension. If ``None``, only self-attention is applied. + Default is ``None``. + + Outputs + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. + + Notes + ----- + The mixing between self-attention and cross-attention is controlled by a learnable + parameter ``state_mixing`` which is passed through a sigmoid function to ensure + the mixing weight stays in :math:`[0, 1]`. + + See Also + -------- + :class:`physicsnemo.models.transolver.Physics_Attention.PhysicsAttentionIrregularMesh` : Base physics attention class. + :class:`GALE_block` : Transformer block using GALE attention. + + Examples + -------- + >>> import torch + >>> gale = GALE(dim=256, heads=8, dim_head=32, context_dim=32) + >>> x = (torch.randn(2, 100, 256),) # Single input tensor in tuple + >>> context = torch.randn(2, 8, 64, 32) # Context for cross-attention + >>> outputs = gale(x, context) + >>> len(outputs) + 1 + >>> outputs[0].shape + torch.Size([2, 100, 256]) + """ + + def __init__( + self, + dim: int, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + slice_num: int = 64, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + ) -> None: + super().__init__(dim, heads, dim_head, dropout, slice_num, use_te, plus) + + linear_layer = te.Linear if self.use_te else nn.Linear + + # Cross-attention projection layers for context integration + self.cross_q = linear_layer(dim_head, dim_head) + self.cross_k = linear_layer(context_dim, dim_head) + self.cross_v = linear_layer(context_dim, dim_head) + + # Learnable mixing weight between self and cross attention + # Initialize near 0.0 since sigmoid(0) = 0.5, giving balanced initial mixing + self.state_mixing = nn.Parameter(torch.tensor(0.0)) + + def compute_slice_attention_cross( + self, + slice_tokens: list[Float[torch.Tensor, "batch heads slices dim"]], + context: Float[torch.Tensor, "batch heads context_slices context_dim"], + ) -> list[Float[torch.Tensor, "batch heads slices dim"]]: + r"""Compute cross-attention between slice tokens and context. + + Parameters + ---------- + slice_tokens : list[torch.Tensor] + List of slice token tensors, each of shape :math:`(B, H, N, D)` where + :math:`B` is batch size, :math:`H` is number of heads, :math:`N` is + number of slices, and :math:`D` is head dimension. + context : torch.Tensor + Context tensor of shape :math:`(B, H, N_c, D_c)` where :math:`N_c` is + number of context slices and :math:`D_c` is context dimension. + + Returns + ------- + list[torch.Tensor] + List of cross-attention outputs, each of shape :math:`(B, H, N, D)`. + """ + # Concatenate all slice tokens for batched projection + q_input = torch.cat(slice_tokens, dim=-2) # (B, H, total_slices, D) + + # Project queries from slice tokens + q = self.cross_q(q_input) # (B, H, total_slices, D) + + # Project keys and values from context + k = self.cross_k(context) # (B, H, N_c, D) + v = self.cross_v(context) # (B, H, N_c, D) + + # Compute cross-attention using appropriate backend + if self.use_te: + # Transformer Engine expects (B, S, H, D) format + q = rearrange(q, "b h s d -> b s h d") + k = rearrange(k, "b h s d -> b s h d") + v = rearrange(v, "b h s d -> b s h d") + cross_attention = self.attn_fn(q, k, v) + cross_attention = rearrange( + cross_attention, "b s (h d) -> b h s d", h=self.heads, d=self.dim_head + ) + else: + # Use PyTorch's scaled dot-product attention + cross_attention = torch.nn.functional.scaled_dot_product_attention( + q, k, v, is_causal=False + ) + + # Split back into individual slice token outputs + cross_attention = torch.split( + cross_attention, slice_tokens[0].shape[-2], dim=-2 + ) + + return list(cross_attention) + + def forward( + self, + x: tuple[Float[torch.Tensor, "batch tokens channels"], ...], + context: Float[torch.Tensor, "batch heads context_slices context_dim"] + | None = None, + ) -> list[Float[torch.Tensor, "batch tokens channels"]]: + r"""Forward pass of the GALE module. + + Applies physics-aware self-attention combined with optional cross-attention + to geometry and global context. + + Parameters + ---------- + x : tuple[torch.Tensor, ...] + Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` + is batch size, :math:`N` is number of tokens, and :math:`C` is number + of channels. + context : torch.Tensor | None, optional + Context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` + where :math:`H` is number of heads, :math:`S_c` is number of context + slices, and :math:`D_c` is context dimension. If ``None``, only + self-attention is applied. Default is ``None``. + + Returns + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)``, same shape + as inputs. + """ + ### Input validation + if not torch.compiler.is_compiling(): + if len(x) == 0: + raise ValueError("Expected non-empty tuple of input tensors") + for i, tensor in enumerate(x): + if tensor.ndim != 3: + raise ValueError( + f"Expected 3D input tensor (B, N, C) at index {i}, " + f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" + ) + + # Project inputs onto learned latent spaces + if self.plus: + x_mid = [self.project_input_onto_slices(_x) for _x in x] + # In Transolver++, x_mid is reused for both projections + fx_mid = [_x_mid for _x_mid in x_mid] + else: + x_mid, fx_mid = zip( + *[self.project_input_onto_slices(_x) for _x in x] + ) + + # Project latent representations onto physical state slices + slice_projections = [self.in_project_slice(_x_mid) for _x_mid in x_mid] + + # Compute slice weights and aggregated slice tokens + slice_weights, slice_tokens = zip( + *[ + self.compute_slices_from_projections(proj, _fx_mid) + for proj, _fx_mid in zip(slice_projections, fx_mid) + ] + ) + + # Apply self-attention to slice tokens + if self.use_te: + self_slice_token = [ + self.compute_slice_attention_te(_slice_token) + for _slice_token in slice_tokens + ] + else: + self_slice_token = [ + self.compute_slice_attention_sdpa(_slice_token) + for _slice_token in slice_tokens + ] + + # Apply cross-attention with context if provided + if context is not None: + cross_slice_token = [ + self.compute_slice_attention_cross([_slice_token], context)[0] + for _slice_token in slice_tokens + ] + + # Blend self-attention and cross-attention with learnable mixing weight + mixing_weight = torch.sigmoid(self.state_mixing) + out_slice_token = [ + mixing_weight * sst + (1 - mixing_weight) * cst + for sst, cst in zip(self_slice_token, cross_slice_token) + ] + else: + # Use only self-attention when no context is provided + out_slice_token = self_slice_token + + # Project attention outputs back to original space using slice weights + outputs = [ + self.project_attention_outputs(ost, sw) + for ost, sw in zip(out_slice_token, slice_weights) + ] + + return outputs + + +class GALE_block(nn.Module): + r"""Transformer encoder block using GALE attention. + + This block replaces standard self-attention with the GALE (Geometry-Aware Latent + Embeddings) attention mechanism, which combines physics-aware self-attention with + cross-attention to geometry and global context. + + Parameters + ---------- + num_heads : int + Number of attention heads. + hidden_dim : int + Hidden dimension of the transformer. + dropout : float + Dropout rate. + act : str, optional + Activation function name. Default is ``"gelu"``. + mlp_ratio : int, optional + Ratio of MLP hidden dimension to ``hidden_dim``. Default is 4. + last_layer : bool, optional + Whether this is the last layer in the model. Default is ``False``. + out_dim : int, optional + Output dimension (only used if ``last_layer=True``). Default is 1. + slice_num : int, optional + Number of learned physical state slices. Default is 32. + use_te : bool, optional + Whether to use Transformer Engine backend. Default is ``True``. + plus : bool, optional + Whether to use Transolver++ features. Default is ``False``. + context_dim : int, optional + Dimension of the context vector for cross-attention. Default is 0. + + Forward + ------- + fx : tuple[torch.Tensor, ...] + Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` is + batch size, :math:`N` is number of tokens, and :math:`C` is hidden dimension. + global_context : tuple[torch.Tensor, ...] + Global context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` + where :math:`H` is number of heads, :math:`S_c` is number of context slices, + and :math:`D_c` is context dimension. + + Outputs + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. + + Notes + ----- + The block applies layer normalization before the attention operation and uses + residual connections after both the attention and MLP layers. + + See Also + -------- + :class:`GALE` : The attention mechanism used in this block. + :class:`physicsnemo.experimental.models.geotransolver.GeoTransolver` : Main model using GALE_block. + + Examples + -------- + >>> import torch + >>> block = GALE_block(num_heads=8, hidden_dim=256, dropout=0.1, context_dim=32) + >>> fx = (torch.randn(2, 100, 256),) # Single input tensor in tuple + >>> context = torch.randn(2, 8, 64, 32) # Global context + >>> outputs = block(fx, context) + >>> len(outputs) + 1 + >>> outputs[0].shape + torch.Size([2, 100, 256]) + """ + + def __init__( + self, + num_heads: int, + hidden_dim: int, + dropout: float, + act: str = "gelu", + mlp_ratio: int = 4, + last_layer: bool = False, + out_dim: int = 1, + slice_num: int = 32, + use_te: bool = True, + plus: bool = False, + context_dim: int = 0, + ) -> None: + super().__init__() + + if use_te and not TE_AVAILABLE: + raise ImportError( + "Transformer Engine is not installed. " + "Please install it with: pip install transformer-engine>=0.1.0" + ) + + self.last_layer = last_layer + + # Layer normalization before attention + if use_te: + self.ln_1 = te.LayerNorm(hidden_dim) + else: + self.ln_1 = nn.LayerNorm(hidden_dim) + + # GALE attention layer + self.Attn = GALE( + hidden_dim, + heads=num_heads, + dim_head=hidden_dim // num_heads, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + + # Feed-forward network with layer normalization + if use_te: + self.ln_mlp1 = te.LayerNormMLP( + hidden_size=hidden_dim, + ffn_hidden_size=hidden_dim * mlp_ratio, + ) + else: + self.ln_mlp1 = nn.Sequential( + nn.LayerNorm(hidden_dim), + MLP( + hidden_dim, + hidden_dim * mlp_ratio, + hidden_dim, + n_layers=0, + res=False, + act=act, + use_te=False, + ), + ) + + def forward( + self, + fx: tuple[Float[torch.Tensor, "batch tokens hidden_dim"], ...], + global_context: Float[torch.Tensor, "batch heads context_slices context_dim"], + ) -> list[Float[torch.Tensor, "batch tokens hidden_dim"]]: + r"""Forward pass of the GALE block. + + Parameters + ---------- + fx : tuple[torch.Tensor, ...] + Tuple of input tensors, each of shape :math:`(B, N, C)` where :math:`B` + is batch size, :math:`N` is number of tokens, and :math:`C` is hidden + dimension. + global_context : torch.Tensor + Global context tensor for cross-attention of shape :math:`(B, H, S_c, D_c)` + where :math:`H` is number of heads, :math:`S_c` is number of context slices, + and :math:`D_c` is context dimension. + + Returns + ------- + list[torch.Tensor] + List of output tensors, each of shape :math:`(B, N, C)`, same shape as inputs. + """ + ### Input validation + if not torch.compiler.is_compiling(): + if len(fx) == 0: + raise ValueError("Expected non-empty tuple of input tensors") + for i, tensor in enumerate(fx): + if tensor.ndim != 3: + raise ValueError( + f"Expected 3D input tensor (B, N, C) at index {i}, " + f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" + ) + + # Apply pre-normalization to all inputs + normed_inputs = [self.ln_1(_fx) for _fx in fx] + + # Apply GALE attention with cross-attention to global context + attn = self.Attn(tuple(normed_inputs), global_context) + + # Residual connection after attention + fx_out = [attn[i] + normed_inputs[i] for i in range(len(normed_inputs))] + + # Feed-forward network with residual connection + fx_out = [self.ln_mlp1(_fx) + _fx for _fx in fx_out] + + return fx_out \ No newline at end of file diff --git a/physicsnemo/experimental/models/geotransolver/geotransolver.py b/physicsnemo/experimental/models/geotransolver/geotransolver.py new file mode 100644 index 0000000000..fe58f8804e --- /dev/null +++ b/physicsnemo/experimental/models/geotransolver/geotransolver.py @@ -0,0 +1,561 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""GeoTransolver: Geometry-Aware Physics Attention Transformer. + +This module provides the GeoTransolver model, which extends the Transolver architecture +with GALE (Geometry-Aware Latent Embeddings) attention for incorporating geometric +structure and global context throughout the forward pass. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass + +import torch +import torch.nn as nn +from jaxtyping import Float + +import physicsnemo # noqa: F401 for docs +from physicsnemo.core.meta import ModelMetaData +from physicsnemo.core.module import Module +from physicsnemo.core.version_check import check_version_spec +from physicsnemo.models.transolver.transolver import MLP + +from .context_projector import GlobalContextBuilder +from .gale import GALE_block + +# Check optional dependency availability +TE_AVAILABLE = check_version_spec("transformer_engine", "0.1.0", hard_fail=False) +if TE_AVAILABLE: + import transformer_engine.pytorch as te + + +@dataclass +class GeoTransolverMetaData(ModelMetaData): + r"""Data class for storing essential meta data needed for the GeoTransolver model. + + Attributes + ---------- + name : str + Model name. Default is ``"GeoTransolver"``. + jit : bool + Whether JIT compilation is supported. Default is ``False``. + cuda_graphs : bool + Whether CUDA graphs are supported. Default is ``False``. + amp : bool + Whether automatic mixed precision is supported. Default is ``True``. + onnx_cpu : bool + Whether ONNX export to CPU is supported. Default is ``False``. + onnx_gpu : bool + Whether ONNX export to GPU is supported. Default is ``True``. + onnx_runtime : bool + Whether ONNX runtime is supported. Default is ``True``. + var_dim : int + Variable dimension for physics-informed features. Default is 1. + func_torch : bool + Whether torch functions are used. Default is ``False``. + auto_grad : bool + Whether automatic differentiation is used. Default is ``False``. + """ + + name: str = "GeoTransolver" + # Optimization + jit: bool = False + cuda_graphs: bool = False + amp: bool = True + # Inference + onnx_cpu: bool = False # No FFT op on CPU + onnx_gpu: bool = True + onnx_runtime: bool = True + # Physics informed + var_dim: int = 1 + func_torch: bool = False + auto_grad: bool = False + + +def _normalize_dim(x: int | Sequence[int]) -> tuple[int, ...]: + r"""Normalize dimension specification to tuple format. + + Parameters + ---------- + x : int | Sequence[int] + Dimension specification as scalar or sequence. + + Returns + ------- + tuple[int, ...] + Normalized dimension tuple. + + Raises + ------ + TypeError + If ``x`` is not an int or valid sequence. + """ + # Accept int as scalar + if isinstance(x, int): + return (x,) + # Accept any non-string sequence of ints + if isinstance(x, Sequence) and not isinstance(x, (str, bytes)): + return tuple(int(v) for v in x) + raise TypeError(f"Invalid dim specifier {x!r}") + + +def _normalize_tensor( + x: torch.Tensor | Sequence[torch.Tensor], +) -> tuple[torch.Tensor, ...]: + r"""Normalize tensor input to tuple format. + + Parameters + ---------- + x : torch.Tensor | Sequence[torch.Tensor] + Single tensor or sequence of tensors. + + Returns + ------- + tuple[torch.Tensor, ...] + Normalized tensor tuple. + + Raises + ------ + TypeError + If ``x`` is not a tensor or valid sequence. + """ + # Accept single tensor + if isinstance(x, torch.Tensor): + return (x,) + if isinstance(x, Sequence): + return tuple(x) + raise TypeError(f"Invalid tensor structure") + + +class GeoTransolver(Module): + r"""GeoTransolver: Geometry-Aware Physics Attention Transformer. + + GeoTransolver is an adaptation of the Transolver architecture, replacing standard + attention with GALE (Geometry-Aware Latent Embeddings) attention. GALE combines + physics-aware self-attention on learned state slices with cross-attention to + geometry and global context embeddings. + + The model projects geometry and global features onto physical state spaces, which + are then used as context in all transformer blocks. This design enables the model + to incorporate geometric structure and global information throughout the forward + pass. + + Parameters + ---------- + functional_dim : int | tuple[int, ...] + Dimension of the input values (local embeddings), not including global + embeddings or geometry features. Input will be projected to ``n_hidden`` + before processing. Can be a single int or tuple for multiple input types. + out_dim : int | tuple[int, ...] + Dimension of the output of the model. Must have same length as + ``functional_dim`` if both are tuples. + geometry_dim : int | None, optional + Pointwise dimension of the geometry input features. If provided, geometry + features will be projected onto physical states and used as context in all + GALE layers. Default is ``None``. + global_dim : int | None, optional + Dimension of the global embedding features. If provided, global features + will be projected onto physical states and used as context in all GALE + layers. Default is ``None``. + n_layers : int, optional + Number of GALE layers in the model. Default is 4. + n_hidden : int, optional + Hidden dimension of the transformer. Default is 256. + dropout : float, optional + Dropout rate applied across the GALE layers. Default is 0.0. + n_head : int, optional + Number of attention heads in each GALE layer. Must evenly divide + ``n_hidden`` to yield an integer head dimension. Default is 8. + act : str, optional + Activation function name. Default is ``"gelu"``. + mlp_ratio : int, optional + Ratio of MLP hidden dimension to ``n_hidden``. Default is 4. + slice_num : int, optional + Number of learned physical state slices in the GALE layers, representing + the number of learned states each layer should project inputs onto. + Default is 32. + use_te : bool, optional + Whether to use Transformer Engine backend when available. Default is ``True``. + time_input : bool, optional + Whether to include time embeddings. Default is ``False``. + plus : bool, optional + Whether to use Transolver++ features in the GALE layers. Default is ``False``. + include_local_features : bool, optional + Whether to include local features in the global context. Default is ``False``. + radii : list[float], optional + Radii for the local features. Default is ``[0.05, 0.25]``. + neighbors_in_radius : list[int], optional + Neighbors in radius for the local features. Default is ``[8, 32]``. + n_hidden_local : int, optional + Hidden dimension for the local features. Default is 32. + + Forward + ------- + local_embedding : torch.Tensor | tuple[torch.Tensor, ...] + Local embedding of the input data of shape :math:`(B, N, C)` where :math:`B` + is batch size, :math:`N` is number of nodes/tokens, and :math:`C` is + ``functional_dim``. Can be a single tensor or tuple for multiple input types. + local_positions : torch.Tensor | tuple[torch.Tensor, ...] | None, optional + Local positions for each input, each of shape :math:`(B, N, 3)`. Required if + ``include_local_features=True``. Default is ``None``. + global_embedding : torch.Tensor | None, optional + Global embedding of the input data of shape :math:`(B, N_g, C_g)` where + :math:`N_g` is number of global tokens and :math:`C_g` is ``global_dim``. + If ``None``, global context is not used. Default is ``None``. + geometry : torch.Tensor | None, optional + Geometry features of the input data of shape :math:`(B, N, C_{geo})` where + :math:`C_{geo}` is ``geometry_dim``. If ``None``, geometry context is not + used. Default is ``None``. + time : torch.Tensor | None, optional + Time embedding (currently not implemented). Default is ``None``. + + Outputs + ------- + torch.Tensor | tuple[torch.Tensor, ...] + Output tensor of shape :math:`(B, N, C_{out})` where :math:`C_{out}` is + ``out_dim``. Returns a single tensor if input was a single tensor, or a + tuple if input was a tuple. + + Raises + ------ + ValueError + If ``n_hidden`` is not evenly divisible by ``n_head``. + ValueError + If ``functional_dim`` and ``out_dim`` have different lengths when both + are tuples. + NotImplementedError + If ``time`` is provided (not yet implemented). + + Notes + ----- + GeoTransolver currently supports unstructured mesh input only. Enhancements for + image-based and voxel-based inputs may be available in the future. + + For more details on Transolver, see: + + - `Transolver paper `_ + - `Transolver++ paper `_ + + See Also + -------- + :class:`~physicsnemo.experimental.models.geotransolver.gale.GALE` : The attention mechanism used in GeoTransolver. + :class:`~physicsnemo.experimental.models.geotransolver.gale.GALE_block` : Transformer block using GALE attention. + :class:`~physicsnemo.experimental.models.geotransolver.context_projector.ContextProjector` : Projects context features onto physical states. + + Examples + -------- + Basic usage with local embeddings only: + + >>> import torch + >>> from physicsnemo.experimental.models.geotransolver import GeoTransolver + >>> model = GeoTransolver( + ... functional_dim=64, + ... out_dim=3, + ... n_hidden=256, + ... n_layers=4, + ... use_te=False, + ... ) + >>> local_emb = torch.randn(2, 1000, 64) # (batch, nodes, features) + >>> output = model(local_emb) + >>> output.shape + torch.Size([2, 1000, 3]) + + Usage with geometry and global context: + + >>> model = GeoTransolver( + ... functional_dim=64, + ... out_dim=3, + ... geometry_dim=3, + ... global_dim=16, + ... n_hidden=256, + ... n_layers=4, + ... use_te=False, + ... ) + >>> local_emb = torch.randn(2, 1000, 64) + >>> geometry = torch.randn(2, 1000, 3) # (batch, nodes, spatial_dim) + >>> global_emb = torch.randn(2, 1, 16) # (batch, 1, global_features) + >>> output = model(local_emb, global_embedding=global_emb, geometry=geometry) + >>> output.shape + torch.Size([2, 1000, 3]) + """ + + def __init__( + self, + functional_dim: int | tuple[int, ...], + out_dim: int | tuple[int, ...], + geometry_dim: int | None = None, + global_dim: int | None = None, + n_layers: int = 4, + n_hidden: int = 256, + dropout: float = 0.0, + n_head: int = 8, + act: str = "gelu", + mlp_ratio: int = 4, + slice_num: int = 32, + use_te: bool = True, + time_input: bool = False, + plus: bool = False, + include_local_features: bool = False, + radii: list[float] | None = None, + neighbors_in_radius: list[int] | None = None, + n_hidden_local: int = 32, + ) -> None: + super().__init__(meta=GeoTransolverMetaData()) + self.__name__ = "GeoTransolver" + + # Set defaults for mutable arguments + if radii is None: + radii = [0.05, 0.25] + if neighbors_in_radius is None: + neighbors_in_radius = [8, 32] + + self.include_local_features = include_local_features + self.use_te = use_te + + # Validate head dimension compatibility + if not n_hidden % n_head == 0: + raise ValueError( + f"GeoTransolver requires n_hidden % n_head == 0, " + f"but instead got {n_hidden % n_head}" + ) + + # Normalize dimension specifications to tuples + functional_dims = _normalize_dim(functional_dim) + out_dims = _normalize_dim(out_dim) + + # Store radii for hidden dimension calculation + self.radii = radii if self.include_local_features else [] + + # Initialize the context builder - handles all context construction + self.context_builder = GlobalContextBuilder( + functional_dims=functional_dims, + geometry_dim=geometry_dim, + global_dim=global_dim, + radii=radii, + neighbors_in_radius=neighbors_in_radius, + n_hidden_local=n_hidden_local, + n_hidden=n_hidden, + n_head=n_head, + dropout=dropout, + slice_num=slice_num, + use_te=use_te, + plus=plus, + include_local_features=self.include_local_features, + ) + context_dim = self.context_builder.get_context_dim() + + # Validate dimension tuple lengths match + if len(functional_dims) != len(out_dims): + raise ValueError( + f"functional_dim and out_dim must be the same length, " + f"but instead got {len(functional_dims)} and {len(out_dims)}" + ) + + # Input projection MLPs - one per input type + self.preprocess = nn.ModuleList( + [ + MLP( + f, + n_hidden * 2, + n_hidden, + n_layers=0, + res=False, + act=act, + use_te=use_te, + ) + for f in functional_dims + ] + ) + + self.n_hidden = n_hidden + + # Compute effective hidden dimension including local features + effective_hidden = ( + n_hidden + n_hidden_local * len(self.radii) + if self.include_local_features + else n_hidden + ) + + # GALE transformer blocks + self.blocks = nn.ModuleList( + [ + GALE_block( + num_heads=n_head, + hidden_dim=effective_hidden, + dropout=dropout, + act=act, + mlp_ratio=mlp_ratio, + slice_num=slice_num, + last_layer=(layer_idx == n_layers - 1), + use_te=use_te, + plus=plus, + context_dim=context_dim, + ) + for layer_idx in range(n_layers) + ] + ) + + # Output projection layers - one per output type + if use_te: + self.ln_mlp_out = nn.ModuleList( + [ + te.LayerNormLinear(in_features=effective_hidden, out_features=o) + for o in out_dims + ] + ) + else: + self.ln_mlp_out = nn.ModuleList( + [ + nn.Sequential( + nn.LayerNorm(effective_hidden), + nn.Linear(effective_hidden, o), + ) + for o in out_dims + ] + ) + + # Time embedding network (optional, not yet implemented) + self.time_input = time_input + if time_input: + self.time_fc = nn.Sequential( + nn.Linear(n_hidden, n_hidden), + nn.SiLU(), + nn.Linear(n_hidden, n_hidden), + ) + + def forward( + self, + local_embedding: ( + Float[torch.Tensor, "batch tokens features"] + | tuple[Float[torch.Tensor, "batch tokens features"], ...] + ), + local_positions: ( + Float[torch.Tensor, "batch tokens spatial_dim"] + | tuple[Float[torch.Tensor, "batch tokens spatial_dim"], ...] + | None + ) = None, + global_embedding: Float[torch.Tensor, "batch global_tokens global_dim"] + | None = None, + geometry: Float[torch.Tensor, "batch tokens geometry_dim"] | None = None, + time: torch.Tensor | None = None, + ) -> ( + Float[torch.Tensor, "batch tokens out_dim"] + | tuple[Float[torch.Tensor, "batch tokens out_dim"], ...] + ): + r"""Forward pass of the GeoTransolver model. + + The model constructs global context embeddings from geometry and global features + by projecting them onto physical state spaces. These context embeddings are then + used in all GALE blocks via cross-attention, allowing geometric and global + information to guide the learned physical state dynamics. + + Parameters + ---------- + local_embedding : torch.Tensor | tuple[torch.Tensor, ...] + Local embedding of the input data of shape :math:`(B, N, C)` where + :math:`B` is batch size, :math:`N` is number of nodes/tokens, and + :math:`C` is ``functional_dim``. + local_positions : torch.Tensor | tuple[torch.Tensor, ...] | None, optional + Local positions for each input, each of shape :math:`(B, N, 3)`. + Required if ``include_local_features=True``. Default is ``None``. + global_embedding : torch.Tensor | None, optional + Global embedding of shape :math:`(B, N_g, C_g)`. Default is ``None``. + geometry : torch.Tensor | None, optional + Geometry features of shape :math:`(B, N, C_{geo})`. Default is ``None``. + time : torch.Tensor | None, optional + Time embedding (not yet implemented). Default is ``None``. + + Returns + ------- + torch.Tensor | tuple[torch.Tensor, ...] + Output tensor of shape :math:`(B, N, C_{out})`. Returns single tensor + if input was single tensor, tuple if input was tuple. + + Raises + ------ + NotImplementedError + If ``time`` is provided. + ValueError + If input tensors have incorrect dimensions. + """ + # Track whether input was a single tensor for output format + single_input = isinstance(local_embedding, torch.Tensor) + + # Time embedding not yet supported + if time is not None: + raise NotImplementedError( + "Time input is not implemented yet. " + "Error rather than silently ignoring it." + ) + + # Normalize inputs to tuple format + local_embedding = _normalize_tensor(local_embedding) + if local_positions is not None: + local_positions = _normalize_tensor(local_positions) + + ### Input validation + if not torch.compiler.is_compiling(): + if len(local_embedding) == 0: + raise ValueError("Expected non-empty local_embedding") + for i, tensor in enumerate(local_embedding): + if tensor.ndim != 3: + raise ValueError( + f"Expected 3D local_embedding tensor (B, N, C) at index {i}, " + f"got {tensor.ndim}D tensor with shape {tuple(tensor.shape)}" + ) + if geometry is not None and geometry.ndim != 3: + raise ValueError( + f"Expected 3D geometry tensor (B, N, C_geo), " + f"got {geometry.ndim}D tensor with shape {tuple(geometry.shape)}" + ) + if global_embedding is not None and global_embedding.ndim != 3: + raise ValueError( + f"Expected 3D global_embedding tensor (B, N_g, C_g), " + f"got {global_embedding.ndim}D tensor with shape {tuple(global_embedding.shape)}" + ) + + # Build context embeddings and extract local features + embedding_states, local_embedding_bq = self.context_builder.build_context( + local_embedding, local_positions, geometry, global_embedding + ) + + # Project inputs to hidden dimension: (B, N, C) -> (B, N, n_hidden) + x = [self.preprocess[i](le) for i, le in enumerate(local_embedding)] + + # Concatenate local features if enabled + if self.include_local_features and local_embedding_bq is not None: + x = [ + torch.cat([x[i], local_embedding_bq[i]], dim=-1) + for i in range(len(x)) + ] + + # Pass through GALE transformer blocks with context cross-attention + for block in self.blocks: + x = block(tuple(x), embedding_states) + + # Project to output dimensions: (B, N, n_hidden) -> (B, N, out_dim) + x = [self.ln_mlp_out[i](x[i]) for i in range(len(x))] + + # Return same format as input (single tensor or tuple) + if single_input: + x = x[0] + else: + x = tuple(x) + + return x \ No newline at end of file diff --git a/test/conftest.py b/test/conftest.py index ec03bfc6c6..ab86abd44d 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -196,3 +196,18 @@ def seed_random_state(): torch.cuda.manual_seed_all(SEED) yield + + +@pytest.fixture(autouse=True, scope="function") +def reset_dynamo_state(): + """Reset torch._dynamo state after each test. + + This ensures test isolation by cleaning up dynamo's compiled function cache + and resetting configuration options like error_on_recompile. Without this, + tests that set error_on_recompile=True can cause subsequent tests to fail + when they trigger recompilation with different tensor shapes. + """ + yield + # Reset after test completes + torch._dynamo.reset() + torch._dynamo.config.error_on_recompile = False diff --git a/test/models/geotransolver/__init__.py b/test/models/geotransolver/__init__.py new file mode 100644 index 0000000000..69e0c20f24 --- /dev/null +++ b/test/models/geotransolver/__init__.py @@ -0,0 +1,16 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + diff --git a/test/models/geotransolver/data/geotransolver_basic_output.pth b/test/models/geotransolver/data/geotransolver_basic_output.pth new file mode 100644 index 0000000000..4c616db8a6 Binary files /dev/null and b/test/models/geotransolver/data/geotransolver_basic_output.pth differ diff --git a/test/models/geotransolver/data/geotransolver_te_output.pth b/test/models/geotransolver/data/geotransolver_te_output.pth new file mode 100644 index 0000000000..6b85b2324d Binary files /dev/null and b/test/models/geotransolver/data/geotransolver_te_output.pth differ diff --git a/test/models/geotransolver/data/geotransolver_tuple_output.pth b/test/models/geotransolver/data/geotransolver_tuple_output.pth new file mode 100644 index 0000000000..63329a08eb Binary files /dev/null and b/test/models/geotransolver/data/geotransolver_tuple_output.pth differ diff --git a/test/models/geotransolver/test_context_projector.py b/test/models/geotransolver/test_context_projector.py new file mode 100644 index 0000000000..524fce9a44 --- /dev/null +++ b/test/models/geotransolver/test_context_projector.py @@ -0,0 +1,55 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from physicsnemo.experimental.models.geotransolver.context_projector import ( + ContextProjector, +) + +# ============================================================================= +# ContextProjector Tests +# ============================================================================= + + +def test_context_projector_forward(device): + """Test ContextProjector forward pass.""" + torch.manual_seed(42) + + dim = 64 + heads = 4 + dim_head = 16 + slice_num = 8 + batch_size = 2 + n_tokens = 100 + + projector = ContextProjector( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=0.0, + slice_num=slice_num, + use_te=False, + plus=False, + ).to(device) + + x = torch.randn(batch_size, n_tokens, dim).to(device) + + slice_tokens = projector(x) + + # Output shape: [Batch, Heads, Slice_num, dim_head] + assert slice_tokens.shape == (batch_size, heads, slice_num, dim_head) + assert not torch.isnan(slice_tokens).any() diff --git a/test/models/geotransolver/test_gale.py b/test/models/geotransolver/test_gale.py new file mode 100644 index 0000000000..32d9913c81 --- /dev/null +++ b/test/models/geotransolver/test_gale.py @@ -0,0 +1,205 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch + +from physicsnemo.experimental.models.geotransolver.gale import ( + GALE, + GALE_block, +) + +# ============================================================================= +# GALE (Geometry-Aware Latent Embeddings) Attention Tests +# ============================================================================= + + +def test_gale_forward_basic(device): + """Test GALE attention layer forward pass without context.""" + torch.manual_seed(42) + + dim = 64 + heads = 4 + dim_head = 16 + slice_num = 8 + batch_size = 2 + n_tokens = 100 + + gale = GALE( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=0.0, + slice_num=slice_num, + use_te=False, + plus=False, + context_dim=dim_head, # Must match dim_head for cross attention + ).to(device) + + # Single input tensor wrapped in tuple + x = torch.randn(batch_size, n_tokens, dim).to(device) + + outputs = gale((x,), context=None) + + assert len(outputs) == 1 + assert outputs[0].shape == (batch_size, n_tokens, dim) + assert not torch.isnan(outputs[0]).any() + + +def test_gale_forward_with_context(device): + """Test GALE attention layer forward pass with cross-attention context.""" + torch.manual_seed(42) + + dim = 64 + heads = 4 + dim_head = 16 + slice_num = 8 + batch_size = 2 + n_tokens = 100 + context_tokens = 32 + context_dim = dim_head + + gale = GALE( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=0.0, + slice_num=slice_num, + use_te=False, + plus=False, + context_dim=context_dim, + ).to(device) + + x = torch.randn(batch_size, n_tokens, dim).to(device) + context = torch.randn(batch_size, heads, context_tokens, context_dim).to(device) + + outputs = gale((x,), context=context) + + assert len(outputs) == 1 + assert outputs[0].shape == (batch_size, n_tokens, dim) + assert not torch.isnan(outputs[0]).any() + + +def test_gale_forward_multiple_inputs(device): + """Test GALE attention layer with multiple input tensors.""" + torch.manual_seed(42) + + dim = 64 + heads = 4 + dim_head = 16 + slice_num = 8 + batch_size = 2 + n_tokens_1 = 100 + n_tokens_2 = 150 + context_dim = dim_head + + gale = GALE( + dim=dim, + heads=heads, + dim_head=dim_head, + dropout=0.0, + slice_num=slice_num, + use_te=False, + plus=False, + context_dim=context_dim, + ).to(device) + + x1 = torch.randn(batch_size, n_tokens_1, dim).to(device) + x2 = torch.randn(batch_size, n_tokens_2, dim).to(device) + + outputs = gale((x1, x2), context=None) + + assert len(outputs) == 2 + assert outputs[0].shape == (batch_size, n_tokens_1, dim) + assert outputs[1].shape == (batch_size, n_tokens_2, dim) + assert not torch.isnan(outputs[0]).any() + assert not torch.isnan(outputs[1]).any() + + +# ============================================================================= +# GALE_block Tests +# ============================================================================= + + +def test_gale_block_forward(device): + """Test GALE_block transformer block forward pass.""" + torch.manual_seed(42) + + hidden_dim = 64 + n_head = 4 + batch_size = 2 + n_tokens = 100 + slice_num = 8 + context_dim = hidden_dim // n_head + + block = GALE_block( + num_heads=n_head, + hidden_dim=hidden_dim, + dropout=0.0, + act="gelu", + mlp_ratio=4, + last_layer=False, + out_dim=1, + slice_num=slice_num, + use_te=False, + plus=False, + context_dim=context_dim, + ).to(device) + + x = torch.randn(batch_size, n_tokens, hidden_dim).to(device) + context = torch.randn(batch_size, n_head, slice_num, context_dim).to(device) + + outputs = block((x,), global_context=context) + + assert len(outputs) == 1 + assert outputs[0].shape == (batch_size, n_tokens, hidden_dim) + assert not torch.isnan(outputs[0]).any() + + +def test_gale_block_multiple_inputs(device): + """Test GALE_block with multiple input tensors.""" + torch.manual_seed(42) + + hidden_dim = 64 + n_head = 4 + batch_size = 2 + n_tokens_1 = 100 + n_tokens_2 = 150 + slice_num = 8 + context_dim = hidden_dim // n_head + + block = GALE_block( + num_heads=n_head, + hidden_dim=hidden_dim, + dropout=0.0, + act="gelu", + mlp_ratio=4, + last_layer=False, + out_dim=1, + slice_num=slice_num, + use_te=False, + plus=False, + context_dim=context_dim, + ).to(device) + + x1 = torch.randn(batch_size, n_tokens_1, hidden_dim).to(device) + x2 = torch.randn(batch_size, n_tokens_2, hidden_dim).to(device) + context = torch.randn(batch_size, n_head, slice_num, context_dim).to(device) + + outputs = block((x1, x2), global_context=context) + + assert len(outputs) == 2 + assert outputs[0].shape == (batch_size, n_tokens_1, hidden_dim) + assert outputs[1].shape == (batch_size, n_tokens_2, hidden_dim) diff --git a/test/models/geotransolver/test_geotransolver.py b/test/models/geotransolver/test_geotransolver.py new file mode 100644 index 0000000000..47b39a3b8a --- /dev/null +++ b/test/models/geotransolver/test_geotransolver.py @@ -0,0 +1,748 @@ +# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES. +# SPDX-FileCopyrightText: All rights reserved. +# SPDX-License-Identifier: Apache-2.0 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + + +import pytest +import torch + +from physicsnemo.experimental.models.geotransolver.geotransolver import ( + GeoTransolver, +) +from test.common import ( # noqa E402 + validate_amp, + validate_checkpoint, + validate_combo_optims, + validate_cuda_graphs, + validate_forward_accuracy, + validate_jit, +) +from test.conftest import requires_module + +# ============================================================================= +# GeoTransolver End-to-End Model Tests +# ============================================================================= + + +@pytest.mark.parametrize("use_geometry", [False, True]) +@pytest.mark.parametrize("use_global", [False, True]) +def test_geotransolver_forward(device, use_geometry, use_global): + """Test GeoTransolver model forward pass with optional geometry and global context.""" + torch.manual_seed(42) + + batch_size = 2 + n_tokens = 100 + n_geom_tokens = 345 + n_global = 5 + geometry_dim = 3 + global_dim = 16 + + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=geometry_dim if use_geometry else None, + global_dim=global_dim if use_global else None, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + local_positions = local_emb[:, :, :3] + kwargs = {} + if use_geometry: + kwargs["geometry"] = torch.randn(batch_size, n_geom_tokens, geometry_dim).to( + device + ) + if use_global: + kwargs["global_embedding"] = torch.randn(batch_size, n_global, global_dim).to( + device + ) + + outputs = model(local_emb, local_positions, **kwargs) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (batch_size, n_tokens, 4) + assert not torch.isnan(outputs).any() + + +def test_geotransolver_forward_tuple_inputs(device): + """Test GeoTransolver model forward pass with tuple inputs/outputs (multi-head).""" + torch.manual_seed(42) + + functional_dims = (32, 48) + out_dims = (4, 6) + + model = GeoTransolver( + functional_dim=functional_dims, + out_dim=out_dims, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens_1 = 100 + n_tokens_2 = 150 + n_geom = 235 + n_global = 5 + + local_emb_1 = torch.randn(batch_size, n_tokens_1, functional_dims[0]).to(device) + local_emb_2 = torch.randn(batch_size, n_tokens_2, functional_dims[1]).to(device) + local_positions_1 = local_emb_1[:, :, :3] + local_positions_2 = local_emb_2[:, :, :3] + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + outputs = model( + (local_emb_1, local_emb_2), + local_positions=(local_positions_1, local_positions_2), + global_embedding=global_emb, + geometry=geometry, + ) + + assert len(outputs) == 2 + assert all(isinstance(output, torch.Tensor) for output in outputs) + assert outputs[0].shape == (batch_size, n_tokens_1, out_dims[0]) + assert outputs[1].shape == (batch_size, n_tokens_2, out_dims[1]) + assert not torch.isnan(outputs[0]).any() + assert not torch.isnan(outputs[1]).any() + + +@requires_module("warp") +def test_geotransolver_forward_with_local_features(device, pytestconfig): + """Test GeoTransolver model forward pass with local features (BQ warp).""" + torch.manual_seed(42) + + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=True, + radii=[0.05, 0.25], + neighbors_in_radius=[8, 32], + n_hidden_local=32, + ).to(device) + + batch_size = 1 + n_tokens = 100 + n_global = 5 + n_geom = 235 + + # For local features, the first 3 channels of local_emb should be coordinates + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + local_positions = local_emb[:, :, :3] + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + outputs = model( + local_emb, + local_positions=local_positions, + global_embedding=global_emb, + geometry=geometry, + ) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (batch_size, n_tokens, 4) + assert not torch.isnan(outputs).any() + + +# ============================================================================= +# Forward Accuracy Tests (reproducibility) +# ============================================================================= + + +def test_geotransolver_forward_accuracy_basic(device): + """Test GeoTransolver basic forward pass accuracy.""" + torch.manual_seed(42) + + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens = 100 + n_geom = 235 + n_global = 5 + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + local_positions = local_emb[:, :, :3] + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + assert validate_forward_accuracy( + model, + (local_emb, local_positions, global_emb, geometry), + file_name="models/geotransolver/data/geotransolver_basic_output.pth", + atol=1e-3, + ) + + +def test_geotransolver_forward_accuracy_tuple(device): + """Test GeoTransolver forward pass accuracy with tuple inputs.""" + torch.manual_seed(42) + + functional_dims = (32, 48) + out_dims = (4, 6) + + model = GeoTransolver( + functional_dim=functional_dims, + out_dim=out_dims, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens_1 = 100 + n_tokens_2 = 150 + n_global = 5 + n_geom = 235 + + local_emb_1 = torch.randn(batch_size, n_tokens_1, functional_dims[0]).to(device) + local_emb_2 = torch.randn(batch_size, n_tokens_2, functional_dims[1]).to(device) + + local_positions_1 = local_emb_1[:, :, :3] + local_positions_2 = local_emb_2[:, :, :3] + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + assert validate_forward_accuracy( + model, + ( + (local_emb_1, local_emb_2), + (local_positions_1, local_positions_2), + global_emb, + geometry, + ), + file_name="models/geotransolver/data/geotransolver_tuple_output.pth", + atol=1e-3, + ) + + +# ============================================================================= +# Optimization Tests +# ============================================================================= + + +def test_geotransolver_optimizations(device): + """Test GeoTransolver optimizations (CUDA graphs, JIT, AMP, combo).""" + torch.manual_seed(42) + + def setup_model(): + """Setup fresh GeoTransolver model and inputs for each optimization test.""" + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens = 100 + n_global = 5 + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + geometry = torch.randn(batch_size, n_tokens, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + local_positions = local_emb[:, :, :3] + return model, local_emb, local_positions, global_emb, geometry + + # Check CUDA graphs + model, local_emb, local_positions, global_emb, geometry = setup_model() + + assert validate_cuda_graphs( + model, + (local_emb, local_positions, global_emb, geometry), + ) + + # Check JIT + model, local_emb, local_positions, global_emb, geometry = setup_model() + assert validate_jit( + model, + (local_emb, local_positions, global_emb, geometry), + ) + + # Check AMP + model, local_emb, local_positions, global_emb, geometry = setup_model() + assert validate_amp( + model, + (local_emb, local_positions, global_emb, geometry), + ) + + # Check Combo + model, local_emb, local_positions, global_emb, geometry = setup_model() + assert validate_combo_optims( + model, + (local_emb, local_positions, global_emb, geometry), + ) + + +# ============================================================================= +# Transformer Engine Tests +# ============================================================================= + + +@requires_module("transformer_engine") +def test_geotransolver_te_basic(device, pytestconfig): + """Test GeoTransolver with Transformer Engine backend.""" + torch.manual_seed(42) + + if device == "cpu": + pytest.skip("TE Tests require cuda.") + + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=True, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens = 100 + n_geom = 235 + n_global = 5 + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + local_positions = local_emb[:, :, :3] + + outputs = model( + local_emb, + local_positions=local_positions, + global_embedding=global_emb, + geometry=geometry, + ) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (batch_size, n_tokens, 4) + assert not torch.isnan(outputs).any() + + +# ============================================================================= +# Checkpoint Tests +# ============================================================================= + + +def test_geotransolver_checkpoint(device): + """Test GeoTransolver checkpoint save/load.""" + torch.manual_seed(42) + + model_1 = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + model_2 = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens = 100 + n_global = 5 + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + geometry = torch.randn(batch_size, n_tokens, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + local_positions = local_emb[:, :, :3] + assert validate_checkpoint( + model_1, + model_2, + (local_emb, local_positions, global_emb, geometry), + ) + + +def test_geotransolver_checkpoint_tuple(device): + """Test GeoTransolver checkpoint save/load with tuple inputs.""" + torch.manual_seed(42) + + functional_dims = (32, 48) + out_dims = (4, 6) + + model_1 = GeoTransolver( + functional_dim=functional_dims, + out_dim=out_dims, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + model_2 = GeoTransolver( + functional_dim=functional_dims, + out_dim=out_dims, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens_1 = 100 + n_tokens_2 = 150 + n_global = 5 + + local_emb_1 = torch.randn(batch_size, n_tokens_1, functional_dims[0]).to(device) + local_emb_2 = torch.randn(batch_size, n_tokens_2, functional_dims[1]).to(device) + geometry = torch.randn(batch_size, n_tokens_1, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + assert validate_checkpoint( + model_1, + model_2, + ((local_emb_1, local_emb_2), (None, None), global_emb, geometry), + ) + + +# ============================================================================= +# Error Handling Tests +# ============================================================================= + + +def test_geotransolver_invalid_hidden_head_dims(): + """Test that GeoTransolver raises error for incompatible hidden/head dimensions.""" + with pytest.raises(ValueError, match="n_hidden % n_head == 0"): + GeoTransolver( + functional_dim=32, + out_dim=4, + n_hidden=65, # Not divisible by n_head=4 + n_head=4, + use_te=False, + ) + + +def test_geotransolver_mismatched_functional_out_dims(): + """Test that GeoTransolver raises error for mismatched functional/out dim lengths.""" + with pytest.raises( + ValueError, match="functional_dim and out_dim must be the same length" + ): + GeoTransolver( + functional_dim=(32, 48), + out_dim=(4,), # Length mismatch + use_te=False, + ) + + +# ============================================================================= +# Activation Function Tests +# ============================================================================= + + +@pytest.mark.parametrize("activation", ["gelu", "relu", "tanh", "silu"]) +def test_geotransolver_activations(device, activation): + """Test GeoTransolver with different activation functions.""" + torch.manual_seed(42) + + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act=activation, + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens = 100 + n_global = 5 + n_geom = 235 + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + outputs = model( + local_emb, local_positions=None, global_embedding=global_emb, geometry=geometry + ) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (batch_size, n_tokens, 4) + assert not torch.isnan(outputs).any() + + +# ============================================================================= +# Shape and Configuration Tests +# ============================================================================= + + +@pytest.mark.parametrize("n_layers", [1, 2, 4]) +def test_geotransolver_different_depths(device, n_layers): + """Test GeoTransolver with different numbers of layers.""" + torch.manual_seed(42) + + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=n_layers, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens = 100 + n_geom = 235 + n_global = 5 + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + outputs = model( + local_emb, local_positions=None, global_embedding=global_emb, geometry=geometry + ) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (batch_size, n_tokens, 4) + assert not torch.isnan(outputs).any() + + +@pytest.mark.parametrize("slice_num", [4, 16, 32]) +def test_geotransolver_different_slice_nums(device, slice_num): + """Test GeoTransolver with different numbers of physical state slices.""" + torch.manual_seed(42) + + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=64, + dropout=0.0, + n_head=4, + act="gelu", + mlp_ratio=2, + slice_num=slice_num, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens = 100 + n_geom = 235 + n_global = 5 + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + outputs = model( + local_emb, local_positions=None, global_embedding=global_emb, geometry=geometry + ) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (batch_size, n_tokens, 4) + assert not torch.isnan(outputs).any() + + +@pytest.mark.parametrize("n_hidden,n_head", [(64, 4), (128, 8), (256, 8)]) +def test_geotransolver_different_hidden_sizes(device, n_hidden, n_head): + """Test GeoTransolver with different hidden dimensions and head counts.""" + torch.manual_seed(42) + + model = GeoTransolver( + functional_dim=32, + out_dim=4, + geometry_dim=3, + global_dim=16, + n_layers=2, + n_hidden=n_hidden, + dropout=0.0, + n_head=n_head, + act="gelu", + mlp_ratio=2, + slice_num=8, + use_te=False, + time_input=False, + plus=False, + include_local_features=False, + ).to(device) + + batch_size = 2 + n_tokens = 100 + n_geom = 235 + n_global = 5 + + local_emb = torch.randn(batch_size, n_tokens, 32).to(device) + geometry = torch.randn(batch_size, n_geom, 3).to(device) + global_emb = torch.randn(batch_size, n_global, 16).to(device) + + outputs = model( + local_emb, local_positions=None, global_embedding=global_emb, geometry=geometry + ) + + assert isinstance(outputs, torch.Tensor) + assert outputs.shape == (batch_size, n_tokens, 4) + assert not torch.isnan(outputs[0]).any() + + +# ============================================================================= +# Model Metadata Tests +# ============================================================================= + + +def test_geotransolver_metadata(): + """Test GeoTransolver model metadata.""" + model = GeoTransolver( + functional_dim=32, + out_dim=4, + use_te=False, + ) + + assert model.meta.name == "GeoTransolver" + assert model.meta.amp is True + assert model.__name__ == "GeoTransolver" diff --git a/test/models/transolver/test_transolver.py b/test/models/transolver/test_transolver.py index be52d93b56..5758103d6f 100644 --- a/test/models/transolver/test_transolver.py +++ b/test/models/transolver/test_transolver.py @@ -133,10 +133,15 @@ def setup_model(): use_te=False, ).to(device) - bsize = 4 - - embedding = torch.randn(bsize, 12345, 3).to(device) - functional_input = torch.randn(bsize, 12345, 2).to(device) + if device == "cuda:0": + bsize = 4 + n_points = 12345 + else: + bsize = 1 + n_points = 123 + + embedding = torch.randn(bsize, n_points, 3).to(device) + functional_input = torch.randn(bsize, n_points, 2).to(device) return model, embedding, functional_input