[BUG] Zero3 for torch.compile with compiled_autograd when running LayerNorm #6719

Open
yitingw1 opened this issue Nov 6, 2024 · 2 comments
Labels: bug (Something isn't working), training

yitingw1 commented Nov 6, 2024

Describe the bug

When running a simple model that includes torch.nn.LayerNorm using DeepSpeed ZeRO-3 with torch.compile and compiled_autograd, the following error occurs:

site-packages/torch/_subclasses/fake_tensor.py:2017] RuntimeError: Attempting to broadcast a dimension of length 0 at -1! Mismatching argument at index 1 had torch.Size([0]); but expected shape should be broadcastable to [100, 120]

We first found this error in a BERT model running with DeepSpeed ZeRO-3, torch.compile, and compiled_autograd.

  • It works with DeepSpeed ZeRO-1/2 with torch.compile and compiled_autograd.
  • It works with DeepSpeed ZeRO-3 with torch.compile but without compiled_autograd.
  • There are many graph breaks and recompiles with DeepSpeed ZeRO-3 and torch.compile.
  • To simplify the issue, I made a small reproducer that isolates the failing op (torch.nn.LayerNorm).

Expected behavior
The model runs with DeepSpeed ZeRO-3 without error.

Investigation

The error: "RuntimeError: Attempting to broadcast a dimension of length 0 at -1! Mismatching argument at index 1 had torch.Size([0]); but expected shape should be broadcastable to [128, 128, 1600]"
It occurs when compiled autograd traces the backward graph, inside the LayerNorm backward decomposition: it tries to broadcast weight_cast (torch.Size([0])) to grad_out_cast's shape ([128, 128, 1600]) and fails.

if weight_cast is not None:         
    grad_x_hat = grad_out_cast * weight_cast 
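The same shape failure can be illustrated in plain eager PyTorch, without DeepSpeed or the compiler. This is a sketch under an assumption not confirmed in this issue: that the torch.Size([0]) weight is the empty placeholder ZeRO-3 leaves in a parameter's data while that parameter is partitioned. A gathered [120] weight broadcasts against a [100, 120] gradient, while a [0] tensor cannot, exactly at the `grad_out_cast * weight_cast` step.

```python
import torch

# Shapes from the small reproducer: batch 100, LayerNorm over 120 features.
grad_out_cast = torch.randn(100, 120)

# Full (gathered) LayerNorm weight: broadcasts over the batch dimension.
full_weight = torch.ones(120)
grad_x_hat = grad_out_cast * full_weight
print(grad_x_hat.shape)  # torch.Size([100, 120])

# Assumption: a partitioned ZeRO-3 parameter seen outside its gather
# holds an empty placeholder of shape [0].
empty_weight = torch.empty(0)
try:
    grad_out_cast * empty_weight
except RuntimeError as err:
    # Trailing dims 120 and 0 are neither equal nor 1, so broadcasting fails.
    print("broadcast failed:", err)
```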

If the LayerNorm weight is bypassed by setting nn.LayerNorm(120, eps=1e-12, elementwise_affine=False) instead of elementwise_affine=True in deepspeed_reproducer_cpu.py, the run completes without error.
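A quick eager check of why that workaround sidesteps the failing line: with elementwise_affine=False the module registers no weight or bias, so there is nothing for ZeRO-3 to partition and the `weight_cast is not None` branch is never taken. This is only an illustration outside DeepSpeed; note that it also removes the learnable scale and shift, so it changes the model and is only useful for narrowing down the bug.

```python
import torch
import torch.nn as nn

affine = nn.LayerNorm(120, eps=1e-12, elementwise_affine=True)
no_affine = nn.LayerNorm(120, eps=1e-12, elementwise_affine=False)

# The affine variant owns a [120] weight and a [120] bias.
print([tuple(p.shape) for p in affine.parameters()])  # [(120,), (120,)]

# The non-affine variant registers weight and bias as None.
print(no_affine.weight, no_affine.bias)  # None None

# Backward through the non-affine LayerNorm needs no weight tensor at all.
x = torch.randn(100, 120, requires_grad=True)
no_affine(x).sum().backward()
print(x.grad.shape)  # torch.Size([100, 120])
```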

System info:

  • OS: Ubuntu 22.04
  • No GPU (it's device-independent, so we use CPU to reproduce)
  • Python version 3.10.12
  • PyTorch version 2.5.1
  • DeepSpeed version 0.15.3

To Reproduce
Steps to reproduce the behavior:

  1. Set environment variable for more verbose logs: TORCH_LOGS="+dynamo,graph,graph_code,graph_breaks,recompiles,aot_graphs,aot_joint_graph,compiled_autograd_verbose"
  2. Run with: deepspeed --num_nodes 1 --num_gpus 1 deepspeed_reproducer_cpu.py
  3. Use --num_gpus 2/4/8 for multiple devices.
  4. Below is deepspeed_reproducer_cpu.py:
import torch
import torchvision
import torchvision.transforms as transforms
import torch.distributed as dist
import deepspeed
from deepspeed.accelerator import get_accelerator
from tqdm import tqdm
from torch.utils.data import DataLoader
from torch.utils.data.distributed import DistributedSampler
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(32 * 32 * 3, 120)
        self.fc2 = nn.Linear(120, 10)
        self.LayerNorm1 = nn.LayerNorm(120, eps=1e-12, elementwise_affine=True)

    def forward(self, x):
        x = torch.flatten(x, 1)  # flatten all dimensions except batch
        x = F.relu(self.fc1(x))
        x = self.LayerNorm1(x)
        x = self.fc2(x)
        return x

compile_kwargs = {"dynamic": False}
device = torch.device('cpu')

model = Net()
model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
model_engine, optimizer, *_ = deepspeed.initialize(
    model=model,
    model_parameters=model.parameters(),
    optimizer=optimizer,
    config="./deepspeed_config.json",
)
# torch_compile
model_engine.compile(
    compile_kwargs=compile_kwargs,
)

# dataset
transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)
batch_size = 100
trainset = torchvision.datasets.CIFAR10(
    root="./DATA/CIFAR10", train=True, download=True, transform=transform
)
# process dataset
trainloader = DataLoader(
    trainset,
    batch_size=batch_size,
    sampler=DistributedSampler(trainset, shuffle=True),
    num_workers=16,
    pin_memory=True,
)
progress_bar = tqdm(
    total=len(trainloader),
    desc=f"Training 1/1 epoch",
    position=0,
    leave=True,
    disable= dist.is_initialized() and dist.get_rank() != 0,
)
for epoch in range(100):
    with torch._dynamo.compiled_autograd.enable(
        torch.compile(backend=get_accelerator().get_compile_backend(), **compile_kwargs)
    ):
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            inputs, labels = data
            inputs, labels = inputs.to(device), labels.to(device)
            optimizer.zero_grad()
            
            # forward + backward + optimize
            outputs = model_engine(inputs)
            loss = criterion(outputs, labels)
            model_engine.backward(loss)
            model_engine.step()

            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print(f"[{epoch + 1}, {i + 1:5d}] loss: {running_loss / 2000:.3f}")
                running_loss = 0.0
            progress_bar.update(1)
print("Finished Training")
  5. Below is deepspeed_config.json:
{
    "train_batch_size": 32, 
    "optimizer": {
        "type": "SGD",
        "params": {
            "lr": 0.001,
            "momentum": 0.9
        }
    },
    "zero_allow_untested_optimizer": true,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": false,
        "reduce_scatter" : false,
        "contiguous_gradients" : false
    }
}
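Steps 1 to 3 can also be scripted. A minimal sketch, assuming the `deepspeed` launcher is on PATH; `build_launch` is a hypothetical helper, not a DeepSpeed API, and it only builds the command line and environment without running anything.

```python
import os
import shlex

# Verbose torch.compile logging, as in step 1.
TORCH_LOGS = (
    "+dynamo,graph,graph_code,graph_breaks,recompiles,"
    "aot_graphs,aot_joint_graph,compiled_autograd_verbose"
)


def build_launch(num_gpus: int = 1, script: str = "deepspeed_reproducer_cpu.py"):
    """Hypothetical helper: return (argv, env) for the deepspeed launcher."""
    env = dict(os.environ, TORCH_LOGS=TORCH_LOGS)
    argv = ["deepspeed", "--num_nodes", "1", "--num_gpus", str(num_gpus), script]
    return argv, env


argv, env = build_launch(num_gpus=2)
print(shlex.join(argv))
# To actually launch (requires DeepSpeed installed):
# import subprocess; subprocess.run(argv, env=env, check=True)
```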
yitingw1 added the bug (Something isn't working) and training labels Nov 6, 2024
tohtana self-assigned this Nov 8, 2024
tohtana (Contributor) commented Nov 8, 2024

Hi @yitingw1, I wonder if persistent parameters might not work well with the compiler.
Can you try setting stage3_param_persistence_threshold to zero?
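A minimal sketch of applying that suggestion to the reproducer's config. `stage3_param_persistence_threshold` goes under `zero_optimization`; per the DeepSpeed docs, parameters smaller than this threshold are not partitioned, so 0 presumably makes every parameter, including the 120-element LayerNorm weight, partitioned. The dict can be written to deepspeed_config.json or passed directly as `config=ds_config` to `deepspeed.initialize`, which accepts a path or a dict.

```python
import json

ds_config = {
    "train_batch_size": 32,
    "optimizer": {"type": "SGD", "params": {"lr": 0.001, "momentum": 0.9}},
    "zero_allow_untested_optimizer": True,
    "zero_optimization": {
        "stage": 3,
        "overlap_comm": False,
        "reduce_scatter": False,
        "contiguous_gradients": False,
        # 0: no parameter is kept persistent (unpartitioned) across steps.
        "stage3_param_persistence_threshold": 0,
    },
}

text = json.dumps(ds_config, indent=4)
print(text)
```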

yitingw1 (Author)

Hi @tohtana, I tried setting stage3_param_persistence_threshold to zero, but it doesn't seem to help; the error still occurs.
I also opened an issue in PyTorch.
