
Llama2 RMS Norm : SimplifiedLayerNormalization #21924

Closed
srijanie03 opened this issue Aug 29, 2024 · 0 comments

Describe the issue

When ONNX Runtime produces a training graph for a model that uses RMS Norm (e.g., Llama2), how does it recognize a node as SimplifiedLayerNormalization, along with SimplifiedLayerNormalizationGrad for the gradient?

Also, the forward pass of the RMSNorm class returns a single output, but the SimplifiedLayerNormalization node in the graph has two outputs. How can the two be reconciled?

The issue occurs when the torch model (with a single output) is converted with ONNX-MLIR: SimplifiedLayerNormalization cannot find the second output, treats it as None, and this causes issues in the subsequent nodes.
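For context, one quick way to see what the graph actually contains is to dump the ops and their outputs. A minimal sketch, assuming the fused training graph produced by generate_artifacts is written to RMS/training_model.onnx (the path and file name are an assumption, adjust to your setup):

import onnx

# Hypothetical path: point this at wherever the training graph was written.
model = onnx.load("RMS/training_model.onnx")
for node in model.graph.node:
    if "SimplifiedLayerNormalization" in node.op_type:
        # The contrib op declares two outputs (the normalized tensor and a
        # saved inverse-std-dev used by the grad op); an empty string in the
        # output list means that output slot is unused/None.
        print(node.op_type, "inputs:", list(node.input), "outputs:", list(node.output))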

To reproduce

import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training.api import CheckpointState, Module, Optimizer
from onnxruntime.training import artifacts
from onnxruntime import InferenceSession
from torchvision import datasets, transforms
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import onnx
import io

class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def _norm(self, x):
        # x / sqrt(mean(x^2) + eps)
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        return output * self.weight

class SimpleNetRMS(nn.Module):
    def __init__(self, input_dim, hidden_dim):
        super().__init__()
        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.rmsnorm = RMSNorm(3, 1e-6)

    def forward(self, x):
        x = self.linear1(x)
        x = self.rmsnorm(x)
        return x

rmsmodel = SimpleNetRMS(3, 3)
model_inputs = torch.randn(2, 2, 3)

model_outputs = rmsmodel(model_inputs)
if isinstance(model_outputs, torch.Tensor):
    model_outputs = [model_outputs]

input_names = ["input"]
output_names = ["output"]
dynamic_axes = {"input": {0: "batch_size"}, "output": {0: "batch_size"}}

torch.onnx.export(
    rmsmodel,
    model_inputs,
    "rmsnorm.onnx",
    input_names=input_names,
    output_names=output_names,
    opset_version=14,
    do_constant_folding=False,
    training=torch.onnx.TrainingMode.TRAINING,
    dynamic_axes=dynamic_axes,
    export_params=True,
    keep_initializers_as_inputs=False,
)

# Load the exported model back so it can be passed to generate_artifacts.
onnx_model = onnx.load("rmsnorm.onnx")

requires_grad = [name for name, param in rmsmodel.named_parameters() if param.requires_grad]
frozen_params = [name for name, param in rmsmodel.named_parameters() if not param.requires_grad]

artifacts.generate_artifacts(
    onnx_model,
    optimizer=artifacts.OptimType.AdamW,
    loss=artifacts.LossType.MSELoss,
    requires_grad=requires_grad,
    frozen_params=frozen_params,
    artifact_directory="RMS",
    additional_output_names=["output"],
)
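To sanity-check the generated artifacts, they can be loaded with the training API and run for a single step. A minimal sketch, assuming the default file names that generate_artifacts writes under the RMS directory (training_model.onnx, eval_model.onnx, optimizer_model.onnx, checkpoint); the random input/target shapes below mirror the repro and are assumptions:

# Smoke test of the generated artifacts (file names assumed to be the
# generate_artifacts defaults).
state = CheckpointState.load_checkpoint("RMS/checkpoint")
module = Module("RMS/training_model.onnx", state, "RMS/eval_model.onnx")
optimizer = Optimizer("RMS/optimizer_model.onnx", module)

module.train()
x = np.random.randn(2, 2, 3).astype(np.float32)       # model input
target = np.random.randn(2, 2, 3).astype(np.float32)  # MSELoss target
# With additional_output_names=["output"], the fetches are the loss
# followed by the extra output.
outputs = module(x, target)
optimizer.step()
module.lazy_reset_grad()
print("loss:", outputs[0])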

Urgency

Urgent.

Platform

Linux

OS Version

6.5.0-1019-nvidia-64k

ONNX Runtime Installation

Built from Source

ONNX Runtime Version or Commit ID

1.18.1

ONNX Runtime API

Python

Architecture

X64

Execution Provider

CUDA

Execution Provider Library Version

CUDA 12.4
