Describe the issue
When ONNX Runtime produces a training graph for a model that uses RMS Norm (e.g. Llama2), how does it recognize a node as SimplifiedLayerNormalization, together with SimplifiedLayerNormalizationGrad for the gradient?

Also, the forward pass of the RMSNorm class has a single output, but the corresponding SimplifiedLayerNormalization node in the training graph has two outputs. How can the two be reconciled?

The issue occurs when the torch model (with a single output) is converted with ONNX-MLIR: SimplifiedLayerNormalization cannot find the second output and treats it as None, which causes issues in the subsequent nodes. (See the inspection snippet after the reproduction script below.)

To reproduce
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import onnx
import io
class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))
    def _norm(self, x):
        # Scale by the reciprocal root mean square over the last dimension.
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
    def forward(self, x):
        return self._norm(x) * self.weight

class SimpleNetRMS(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.rmsnorm = RMSNorm(hidden_dim)
    def forward(self, x):
        x = self.linear1(x)
        x = self.rmsnorm(x)
        return x
rmsmodel = SimpleNetRMS(3,3)
model_inputs = torch.randn(2,2,3)
model_outputs = rmsmodel(model_inputs)
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]
input_names = ["input"]
output_names = ["output"]
dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}
torch.onnx.export(
    rmsmodel,
    model_inputs,
    "rmsnorm.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)
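# Optional sanity check: the exported inference graph has a single output,
# and RMSNorm appears only as decomposed primitive ops; there is no
# SimplifiedLayerNormalization node at this stage.
exported_model = onnx.load("rmsnorm.onnx")
print("graph outputs:", [o.name for o in exported_model.graph.output])
print("op types:", sorted({n.op_type for n in exported_model.graph.node}))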
requires_grad = [name for name, param in rmsmodel.named_parameters() if param.requires_grad]
frozen_params = [name for name, param in rmsmodel.named_parameters() if not param.requires_grad]
onnx_model = onnx.load("rmsnorm.onnx")  # load the exported model for artifact generation
artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="RMS",
    additional_output_names=["output"],
)
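For reference, this is how I inspect the fused node pair in the generated training graph. It is only a small inspection sketch; it assumes the default training_model.onnx file name that generate_artifacts writes into the RMS artifact directory.

import os

training_model = onnx.load(os.path.join("RMS", "training_model.onnx"))
for node in training_model.graph.node:
    if node.op_type in ("SimplifiedLayerNormalization", "SimplifiedLayerNormalizationGrad"):
        # As far as I can tell, the forward node lists two outputs, and the second
        # one (the saved normalization statistic) is consumed by the Grad node.
        print(node.op_type)
        print("  inputs :", list(node.input))
        print("  outputs:", list(node.output))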
Urgency
Urgent.
Platform
Linux
OS Version
6.5.0-1019-nvidia-64k
ONNX Runtime Installation
Built from Source
ONNX Runtime Version or Commit ID
1.18.1
ONNX Runtime API
Python
Architecture
X64
Execution Provider
CUDA
Execution Provider Library Version
CUDA 12.4