🐛 [Bug] Can't load UNet on H100 after compiling ExportedProgram with torch_tensorrt.dynamo.compile and saving #3163
Comments
@readleyj I tried in my environment with today's latest main branch using an RTX 4080, and I don't get the error you pasted.
@lanluo-nvidia Thank you for the reply. That is very strange. I will try with today's nightly and report back. Also, I am running this on an H100; could that possibly be the source of the issue?
I tried again with today's nightly (…).
I also tried with release 2.4. There, I can successfully save and load the model, but the compiled model outputs are full of NaNs. In general, Stable Diffusion with Torch-TensorRT seems very problematic.
@readleyj Yes, we have bugs in release 2.4 which are fixed in the current main branch. If you could paste the code:
@lanluo-nvidia After loading the UNet, I first check if the results match:

```python
with torch.inference_mode():
    tensorrt_outputs_unet = loaded_unet(*arg_inputs_unet)

for expected_output, tensorrt_output in zip(expected_outputs_unet, tensorrt_outputs_unet):
    assert torch.allclose(
        expected_output, tensorrt_output, 1e-2, 1e-2
    ), "UNet results do not match"

print("UNet results match for Torch-TensorRT and Diffusers")
```

To generate an image, I plug the loaded UNet into a StableDiffusionPipeline as follows (the code block assumes the compiled UNet has already been loaded as `loaded_unet`):

```python
import torch
import torch_tensorrt
from diffusers import StableDiffusionPipeline, EulerDiscreteScheduler

PROMPT = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"

with torch.inference_mode():
    pipe = StableDiffusionPipeline.from_pretrained(
        "CompVis/stable-diffusion-v1-4",
        torch_dtype=torch.float16,
        variant="fp16",
        use_safetensors=True,
    ).to("cuda")

    pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)

    # Thin wrapper that exposes the attributes the pipeline expects
    # (in_channels, config, device) and delegates forward() to the
    # loaded Torch-TensorRT UNet.
    class LoadedUNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.in_channels = pipe.unet.config.in_channels
            setattr(self, "config", pipe.unet.config)
            self.device = pipe.unet.device

        def forward(self, latent_model_input, t, encoder_hidden_states, **kwargs):
            sample = loaded_unet(latent_model_input, t, encoder_hidden_states)
            return sample

    pipe.unet = LoadedUNet()

    image = pipe(
        PROMPT,
        num_inference_steps=50,
        height=512,
        width=512,
    ).images[0]
```

Note that you may receive a …
@readleyj I have tried with the release/2.5 branch (this is our upcoming release branch and it is more stable than the main branch, since main is getting all the latest changes from both PyTorch and Torch-TensorRT).

Test 1) I can see it does throw the "UNet results do not match" error (the rtol/atol is actually very close to 1e-2, 1e-2).
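For reference, a minimal sketch of how one might quantify how close the two sets of outputs actually are; it assumes the `expected_outputs_unet` and `tensorrt_outputs_unet` tensors from the snippet above and is not part of the original reproduction code:

```python
import torch

def report_max_error(expected_outputs, tensorrt_outputs):
    # Report the maximum absolute and relative error between the
    # reference (Diffusers) and Torch-TensorRT UNet outputs.
    for i, (expected, actual) in enumerate(zip(expected_outputs, tensorrt_outputs)):
        expected = expected.float()
        actual = actual.float()
        abs_err = (expected - actual).abs().max().item()
        rel_err = ((expected - actual).abs() / expected.abs().clamp_min(1e-6)).max().item()
        print(f"output {i}: max abs err = {abs_err:.3e}, max rel err = {rel_err:.3e}")

# report_max_error(expected_outputs_unet, tensorrt_outputs_unet)
```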
Test 2) Tested locally on my RTX 4080 with the release/2.5 branch: if I do not save and load the model, and instead run inference directly with the torch_tensorrt-compiled model, the UNet results do match and it can also generate the images as expected, same as Test 1).
Test 3) Tested on an H100 using the release/2.5 Docker image, inside the Docker container: it does throw the following error:
Yes, the error in Test 3) is exactly what I'm getting on my H100. I thought the problem might be with torch.export, so I already created an issue on the PyTorch repo (pytorch/pytorch#136317).
@readleyj It seems like it only happens on H100. I did the exact same test on an RTX 4080 using the same image and the same test code you provided, and it works.

Test 4) Tested on an RTX 4080 using the release/2.5 Docker image, inside the Docker container:
@lanluo-nvidia Also, in my H100 tests, the model successfully compiles and the UNet results match (using …
@lanluo-nvidia Any updates on this? Should I expect this issue to be resolved soon, or will this be on the backlog for a while? Unfortunately, I only have H100s at my disposal, and this is blocking progress for me.
@readleyj Have you tried saving in TorchScript format instead of exported_program? Simply change these two lines:

```python
torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)
loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
```

to:

```python
torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ts", output_format="torchscript", inputs=arg_inputs_unet)
loaded_unet = torch.jit.load("sd_unet_compiled.ts").eval()
```
@HolyWu Thanks for the suggestion. I hadn't looked at this for a while, but I just tried my original code (torch_tensorrt.save, torch.export.load) on torch 2.5.1 and torch_tensorrt 2.5.0, and everything seems to be working. I can compile, save, and load successfully. The problem seems resolved, so I'll close the issue. Thanks a lot.
Bug Description
I am trying to use `torch_tensorrt.dynamo.compile()` to AOT compile the UNet portion of a `StableDiffusionPipeline` from the diffusers library (version 0.30.2). I am able to export the UNet with `torch.export.export()`, compile it with `torch_tensorrt.dynamo.compile()`, and save it with `torch_tensorrt.save()`. However, I am encountering a runtime error when attempting to load the saved compiled UNet with `torch.export.load()`.

To Reproduce
Run the code below
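The original reproduction script is not preserved in this issue text. Below is a minimal sketch of the workflow described above (export the UNet, compile it with `torch_tensorrt.dynamo.compile()`, save it, then load it). The way `functools.partial(..., return_dict=False)` is applied, the example input shapes, and the compile options are assumptions for illustration, not the exact original code:

```python
import functools

import torch
import torch_tensorrt
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
).to("cuda")

unet = pipe.unet.eval()
# Assumption: force the UNet to return a plain tuple instead of the
# UNet2DConditionOutput dataclass, as described in "Additional context".
unet.forward = functools.partial(unet.forward, return_dict=False)

# Assumed example inputs: batch 2 (classifier-free guidance), 64x64 latents,
# and 77-token text embeddings for SD 1.4.
arg_inputs_unet = (
    torch.randn(2, 4, 64, 64, dtype=torch.float16, device="cuda"),
    torch.tensor(999, dtype=torch.float16, device="cuda"),
    torch.randn(2, 77, 768, dtype=torch.float16, device="cuda"),
)

exported_unet = torch.export.export(unet, arg_inputs_unet)
compiled_unet = torch_tensorrt.dynamo.compile(
    exported_unet,
    inputs=arg_inputs_unet,
    enabled_precisions={torch.float16},
)

torch_tensorrt.save(compiled_unet, "sd_unet_compiled.ep", inputs=arg_inputs_unet)

# This is the step that fails on H100.
loaded_unet = torch.export.load("sd_unet_compiled.ep").module()
```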
Error message
Expected behavior
The code should load the saved compiled model without erroring out.
Environment

How you installed PyTorch (conda, pip, libtorch, source): pip

Additional context
I have to use `functools.partial()` in the code above because the default output of the UNet's forward method is the `UNet2DConditionOutput` dataclass. I tried to get rid of `functools.partial()` by instead using `torch.export.register_dataclass()`, but was met with the same runtime error mentioned above.

Additionally, saving and loading the ExportedProgram (without Torch-TensorRT compilation) works fine.
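For reference, a minimal sketch of the `torch.export.register_dataclass()` alternative mentioned above; the import path (for diffusers 0.30.x) and the `serialized_type_name` value are assumptions, and this is not the exact code that was tried:

```python
import torch
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

# Register the dataclass so torch.export can treat it as a pytree node
# instead of requiring forward() to return plain tensors/tuples.
torch.export.register_dataclass(
    UNet2DConditionOutput,
    serialized_type_name="diffusers.models.unets.unet_2d_condition.UNet2DConditionOutput",
)

# exported_unet = torch.export.export(unet, arg_inputs_unet)  # no functools.partial needed
```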