
❓ [Question] Examples not working in nvcr.io/nvidia/pytorch:23.09-py3. #2415

Open · sayakpaul opened this issue Oct 26, 2023 · 26 comments
Labels: question (Further information is requested)

@sayakpaul

❓ Question

I am within the nvcr.io/nvidia/pytorch:23.09-py3 container. Trying out some snippets from:
https://youtu.be/eGDMJ3MY4zk?si=MhkbgwAPVQSFZEha.

Both the JIT and AoT examples failed. For JIT, it complained that the "tensorrt" backend isn't available; for AoT, it complained that "The user code is using a feature we don't support. Please try torchdynamo.explain() to get the possible reasons".

I am on an A100. What's going on?

sayakpaul added the question label on Oct 26, 2023
@gs-olive (Collaborator)

Hello - for the backend="tensorrt" issue, please ensure the line import torch_tensorrt is in the script (even though it may not be used directly). This is important because the import itself registers the backend, as here:

@td.register_backend(name="tensorrt") # type: ignore[misc]

If this does not address the issue, could you please also share the version of Torch-TensorRT being used?
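For reference, here is a minimal sketch of the JIT path with the import in place (the toy Conv2d model and shapes are placeholders, not taken from your script):

import torch
import torch_tensorrt  # the import itself registers the "tensorrt" Dynamo backend

model = torch.nn.Conv2d(3, 8, kernel_size=3).eval().cuda()
x = torch.randn(1, 3, 32, 32, device="cuda")

# Without the import above, this line fails with: backend="tensorrt" is not available
optimized = torch.compile(model, backend="tensorrt")
_ = optimized(x)  # the first call triggers the actual TRT compilation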

Regarding the AoT approach, @peri044 may be able to help further, but based on the error it sounds like the model may not be natively trace-able. Which specific example from the presentation are you referencing for the AoT sample?

@sayakpaul (Author)

This one:
[image: screenshot of the AoT example from the presentation]

@peri044 (Collaborator)

peri044 commented Oct 27, 2023

Hello @sayakpaul, I would suggest trying our latest main branch; 23.09 is old now. Let us know if you see any issues with the main branch.
Alternatively, if you don't want to build from source, you can try our nightly container: https://github.com/pytorch/TensorRT/pkgs/container/tensorrt%2Ftorch_tensorrt

@sayakpaul (Author)

sayakpaul commented Oct 28, 2023

With the latest container (https://github.com/pytorch/TensorRT/pkgs/container/tensorrt%2Ftorch_tensorrt), the JIT workflow works. I have a few questions, so let me use this thread to ask them; the community might find them useful.

@sayakpaul (Author)

sayakpaul commented Oct 28, 2023

Here is the code snippet I am using to benchmark the JIT workflow:

import torch
import torch_tensorrt
import torch.utils.benchmark as benchmark

class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 128, 3)
    def forward(self, a):
        if a.sum() > 3:
            b = torch.nn.functional.elu(a)
        else:
            b = a + 1
        return self.conv(b)


# Taken from
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


model = MyModel().eval().cuda()
inputs = torch.randn(32, 64, 64, 64, device="cuda")

print("Using regular model:")
print(benchmark_fn(model, inputs))


optimized_model = torch.compile(
    model,
    backend="tensorrt",
    dynamic=False,
    options={
        "debug": True,
        "enabled_precisions": {torch.half},
        "min_block_size": 1,
    }
)
_ = optimized_model(inputs)
print("Using optimized model:")
print(benchmark_fn(optimized_model, inputs))

I am getting:

620.6526302793078 microseconds (for the non-compiled model)
2095.269737765193 microseconds (for the compiled model)

I am on an 80 GB A100. I understand the compute load might not be enough to see evident speedups here, but is this expected? I wanted to get your opinion.

@sayakpaul (Author)

The AOT workflow still fails:

import torch
import torch_tensorrt
import torch.utils.benchmark as benchmark

class MyModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 128, 3)
    def forward(self, a):
        if a.sum() > 3:
            b = torch.nn.functional.elu(a)
        else:
            b = a + 1
        return self.conv(b)


# Taken from
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


inputs = torch.randn((256, 64, 32, 32)).cuda()
model = MyModel().eval().cuda()

print("Using regular model:")
print(benchmark_fn(model, inputs))

exp_program = torch_tensorrt.dynamo.trace(model, [inputs])
trt_model = torch_tensorrt.dynamo.compile(exp_program, inputs=[inputs])

_ = trt_model(inputs)
print("Using optimized model:")
print(benchmark_fn(trt_model, inputs))

print("Serializing...")
trt_ser_model = torch_tensorrt.dynamo.serialize(trt_model, *inputs)
torch.save(trt_ser_model, "trt_model.pt")

Trace:

Traceback (most recent call last):
  File "/opt/torch_tensorrt/trt_aot.py", line 33, in <module>
    exp_program = torch_tensorrt.dynamo.trace(model, [inputs])
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/aten_tracer.py", line 51, in trace
    if input.shape_mode == Input._ShapeMode.DYNAMIC:
AttributeError: 'Tensor' object has no attribute 'shape_mode'

@sayakpaul (Author)

When I swapped the inputs with the following:

inputs = [torch_tensorrt.Input(
    min_shape=(1, 512, 1, 1),
    opt_shape=(4, 512, 1, 1),
    max_shape=(8, 512, 1, 1)
)]

it led to the following:

Traceback (most recent call last):
  File "/opt/torch_tensorrt/trt_aot.py", line 37, in <module>
    exp_program = torch_tensorrt.dynamo.trace(model, inputs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/aten_tracer.py", line 86, in trace
    exp_program = export(model, tuple(trace_inputs), constraints=constraints)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/__init__.py", line 556, in export
    return _export(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/__init__.py", line 596, in _export
    gm_torch_level = _export_to_torch_ir(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_export/__init__.py", line 517, in _export_to_torch_ir
    gm_torch_level, _ = torch._dynamo.export(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 1226, in inner
    result_traced = opt_f(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 410, in _fn
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 558, in catch_errors
    return callback(frame, cache_entry, hooks, frame_state)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 148, in _fn
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 402, in _convert_frame_assert
    return _compile(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 610, in _compile
    guarded_code = compile_inner(code, one_graph, hooks, transform)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/utils.py", line 221, in time_wrapper
    r = func(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 527, in compile_inner
    out_code = transform_code_object(code, transform)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/bytecode_transformation.py", line 1028, in transform_code_object
    transformations(instructions, code_options)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/convert_frame.py", line 497, in transform
    tracer.run()
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 2127, in run
    super().run()
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 751, in run
    and self.step()
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 714, in step
    getattr(self, inst.opname)(inst)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_dynamo/symbolic_convert.py", line 382, in inner
    raise exc.UserError(
torch._dynamo.exc.UserError: Dynamic control flow is not supported at the moment. Please use functorch.experimental.control_flow.cond to explicitly capture the control flow. For more information about this error, see: https://pytorch.org/docs/main/generated/exportdb/index.html#cond-operands

from user code:
   File "/opt/torch_tensorrt/trt_aot.py", line 10, in forward
    if a.sum() > 3:

Set TORCH_LOGS="+dynamo" and TORCHDYNAMO_VERBOSE=1 for more information

I changed the code accordingly as well:

exp_program = torch_tensorrt.dynamo.trace(model, inputs)
trt_model = torch_tensorrt.dynamo.compile(exp_program, inputs=inputs)

@gs-olive (Collaborator)

Hi @sayakpaul - thank you for the comments. Regarding the JIT workflow, this specific example was intended to illustrate how graph breaks are handled in torch.compile with the Torch-TRT backend. As you suggested, the model itself does not have sufficient operations to make the overhead of using a TRT engine worthwhile here. Additionally, this graph has an intentionally-added control-flow break, which means that a minimum of two TRT engines will be generated, further increasing the overhead for this small model. Since this model only has one computational operator per TRT block, the default options would not have converted these operators to TRT at all ("min_block_size": 1 overrides this default). As a result, I believe the metrics are roughly as expected here.
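As a side note, you can verify the graph break count directly; here is a quick sketch using the Dynamo explain API (assuming the PyTorch 2.1 callable form), reusing the MyModel with the conditional from your earlier snippet:

import torch
import torch._dynamo as dynamo

model = MyModel().eval().cuda()  # the conditional model defined in the comment above
inputs = torch.randn(32, 64, 64, 64, device="cuda")

explanation = dynamo.explain(model)(inputs)
print(explanation.graph_count)    # > 1 here, since `if a.sum() > 3` forces a break
print(explanation.break_reasons)  # should cite the data-dependent conditional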

The AOT workflow is not expected to work for this specific model (the one with the conditional). It would only work for the model without the conditional, as here. This is because the Torch ATen tracer cannot handle this sort of Python control flow. In order to use the AOT workflow on a model like this, the code would need to be modified to either remove the conditional or use tracer-allowed conditionals like torch.cond, as sketched below.
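For illustration, here is a rough sketch of the forward rewritten with the traceable conditional primitive the exporter error message points to; treat it as a starting point rather than a tested recipe:

import torch
from functorch.experimental.control_flow import cond

class MyCondModel(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(64, 128, 3)

    def forward(self, a):
        # Both branches must accept and return tensors of identical shape/dtype
        b = cond(
            a.sum() > 3,
            lambda x: torch.nn.functional.elu(x),
            lambda x: x + 1,
            [a],
        )
        return self.conv(b)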

@sayakpaul (Author)

Will try and get back. Thanks for your inputs!

Could you maybe provide an example for JIT where the improvements are evident? FWIW, we did try something with SD but the benefits were not there: huggingface/diffusers#5564.

Perhaps we are missing something?

@gs-olive (Collaborator)

Thanks for the follow-up! I will look into the sample you provided a bit more - it looks like the TRT conversion errored out for some reason, which would cause compilation to fall back to Torch eager. One other example to try might be the Stable Diffusion version highlighted here.

@gs-olive (Collaborator)

gs-olive commented Oct 30, 2023

Also, regarding the error cited in huggingface/diffusers#5564, I retried your sample on the latest nightly version of the repository, with the following small code modifications, and it is working:

    elif run_compile and with_tensorrt:
        print("Run torch compile with TensorRT backend")
        pipe.unet = torch.compile(
            pipe.unet, fullgraph=True, backend="tensorrt", dynamic=False,
            options={"min_block_size": 1,
                     "truncate_long_and_double": True,
                     "enabled_precisions": {torch.half}}
        )

Please let me know if this resolves the issue for you as well.

@sayakpaul (Author)

I retried your sample on the latest nightly version of the repository

I don't want to build from source. Will https://github.com/pytorch/TensorRT/pkgs/container/tensorrt%2Ftorch_tensorrt work here? (I will prune the previous container before mounting this one.)

@gs-olive (Collaborator)

gs-olive commented Oct 31, 2023

Yes, I believe the Docker container should also work. Alternatively, you could install the latest nightly distribution of Torch-TRT with:

pip install --pre torch torch-tensorrt --extra-index-url https://download.pytorch.org/whl/nightly/cu121
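After installing, a quick sanity check to confirm which versions were picked up (both packages expose a standard __version__ attribute):

python -c "import torch, torch_tensorrt; print(torch.__version__, torch_tensorrt.__version__)"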

@sayakpaul (Author)

sayakpaul commented Nov 1, 2023

Hello @gs-olive! Apologies for the delay on my end.

I did run the SDXL sample script on the latest nightly container (ensuring that the previous one was pruned from the system beforehand), but I am still not seeing the expected speedups (for a batch size of 4):

With compilation: False, and TensorRT: False in 23988496.810 microseconds
With compilation: True, and TensorRT: False in 20895331.856 microseconds
With compilation: True, and TensorRT: True in 32650460.551 microseconds

Here is the trace:

Traceback (most recent call last):
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 95, in _pretraced_backend
    trt_compiled = compile_module(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 304, in compile_module
    trt_module = convert_module(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 34, in convert_module
    module_outputs = module(*torch_inputs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 736, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 315, in __call__
    raise e
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 302, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1519, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1528, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.14", line 6, in forward
    view_14 = torch.ops.aten.view.default(permute_15, [8, -1, 640]);  permute_15 = None
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
    return self._op(*args, **kwargs or {})
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1378, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1676, in dispatch
    r = func(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 513, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

Are you seeing anything different?

I am using an 80 GB A100, FYI. I also modified the code with your suggestion; below is the full code snippet:

import argparse

import torch
import torch_tensorrt
import torch.utils.benchmark as benchmark

from diffusers import DiffusionPipeline

CKPT = "stabilityai/stable-diffusion-xl-base-1.0"
NUM_INFERENCE_STEPS = 50
PROMPT = "ghibli style, a fantasy landscape with castles"


def load_pipeline(run_compile=False, with_tensorrt=False):
    pipe = DiffusionPipeline.from_pretrained(
        CKPT, torch_dtype=torch.float16, use_safetensors=True
    )
    pipe = pipe.to("cuda")
    pipe.unet.to(memory_format=torch.channels_last)

    if run_compile and not with_tensorrt:
        print("Run torch compile")
        pipe.unet = torch.compile(pipe.unet, mode="reduce-overhead", fullgraph=True)
    elif run_compile and with_tensorrt:
        print("Run torch compile with TensorRT backend")
        pipe.unet = torch.compile(
            pipe.unet, fullgraph=True, backend="tensorrt", dynamic=False,
            options={"min_block_size": 1,
                     "truncate_long_and_double": True,
                     "enabled_precisions": {torch.half}}
        )

    pipe.set_progress_bar_config(disable=True)
    return pipe


def run_inference(pipe, batch_size=1):
    _ = pipe(
        prompt=PROMPT,
        num_inference_steps=NUM_INFERENCE_STEPS,
        num_images_per_prompt=batch_size,
    )


# Taken from
# https://pytorch.org/tutorials/intermediate/scaled_dot_product_attention_tutorial.html
def benchmark_fn(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--run_compile", action="store_true")
    parser.add_argument("--with_tensorrt", action="store_true")
    args = parser.parse_args()

    pipeline = load_pipeline(
        run_compile=args.run_compile, with_tensorrt=args.with_tensorrt
    )
    print(
        f"With compilation: {args.run_compile}, and TensorRT: {args.with_tensorrt} in {benchmark_fn(run_inference, pipeline, args.batch_size):.3f} microseconds"
    )

@gs-olive (Collaborator)

gs-olive commented Nov 1, 2023

Hi @sayakpaul - thanks for the details. I am able to reproduce the issue you are facing, but it only appears to happen with the latest nightly torch distribution. Specifically, I do not see the error when using torch==2.1.0. We will be pushing a Docker container with that build later today. I will update this issue when that is ready. It seems that something changed in the Torch tracer which resulted in this error.

Additionally, we do not have full converter support yet for this specific model on the latest Torch nightly (but we do on 2.1.0). The missing operator we need to implement is aten._scaled_dot_product_flash_attention; I have opened an issue for this (#2427) so we can prioritize it.
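Until that converter lands, one possible stopgap - and this is an assumption on my part, not a tested fix - is to force that specific op to fall back to Torch via the torch_executed_ops option:

pipe.unet = torch.compile(
    pipe.unet, fullgraph=True, backend="tensorrt", dynamic=False,
    options={
        "min_block_size": 1,
        "truncate_long_and_double": True,
        "enabled_precisions": {torch.half},
        # Hypothetical workaround: keep the unsupported op in Torch eager
        "torch_executed_ops": {"torch.ops.aten._scaled_dot_product_flash_attention"},
    },
)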

@sayakpaul (Author)

Cool, let me know :-) Looking forward to it. I will try to pin the SHA of the Docker container so that it's reproducible for the community.

@sayakpaul (Author)

@gs-olive a gentle ping if the container is available :)

@gs-olive (Collaborator)

gs-olive commented Nov 6, 2023

Thanks for the message and apologies for the delay - I expect it to be ready at some point today or tomorrow, and I will update this issue once it is.

@gs-olive (Collaborator)

gs-olive commented Nov 7, 2023

Hi @sayakpaul - the container built against PyTorch 2.1.0 can be found here:

ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.1
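For example, to pull and start it (assuming a standard Docker + NVIDIA Container Toolkit setup on your end):

docker pull ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.1
docker run --gpus all -it --rm ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.1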

@sayakpaul (Author)

This is what I get now:

Traceback (most recent call last):
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/backend/backends.py", line 95, in _pretraced_backend
    trt_compiled = compile_module(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/_compiler.py", line 312, in compile_module
    trt_module = convert_module(
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch_tensorrt/dynamo/conversion/_conversion.py", line 33, in convert_module
    module_outputs = module(*torch_inputs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 678, in call_wrapped
    return self._wrapped_call(self, *args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 284, in __call__
    raise e
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/fx/graph_module.py", line 274, in __call__
    return super(self.cls, obj).__call__(*args, **kwargs)  # type: ignore[misc]
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "<eval_with_key>.14", line 6, in forward
    view_14 = torch.ops.aten.view.default(permute_15, [8, -1, 640]);  permute_15 = None
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/utils/_stats.py", line 20, in wrapper
    return fn(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1250, in __torch_dispatch__
    return self.dispatch(func, types, args, kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_subclasses/fake_tensor.py", line 1541, in dispatch
    r = func(*args, **kwargs)
  File "/root/.pyenv/versions/3.10.13/lib/python3.10/site-packages/torch/_ops.py", line 448, in __call__
    return self._op(*args, **kwargs or {})
RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

@gs-olive (Collaborator)

gs-olive commented Nov 7, 2023

Ok - thanks for the update. Could you also try with transformers==4.33.2 and diffusers==0.21.4? I am trying to determine whether the transformers and diffusers versions may also be related to the issue.
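For convenience, that pin would be:

pip install transformers==4.33.2 diffusers==0.21.4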

@sayakpaul (Author)

sayakpaul commented Nov 10, 2023 via email

@sayakpaul (Author)

@gs-olive I had a chance to do the benchmarking for Stable Diffusion. Here are the timings:

With compilation: False, and TensorRT: False in 3.767 seconds
With compilation: True, and TensorRT: False in 3.045 seconds
With compilation: True, and TensorRT: True in 1.157 seconds

Env:

  1. ghcr.io/pytorch/tensorrt/torch_tensorrt:release_2.1
  2. transformers==4.33.2 diffusers==0.21.4
  3. Code: https://gist.github.com/sayakpaul/2e4534b205841ccce085400b8c42da85

Will see with SDXL too.

@sayakpaul (Author)

sayakpaul commented Nov 30, 2023

Also, a bit surprising: with the latest stable releases of diffusers and transformers, the timings are quite off:

With compilation: False, and TensorRT: False in 1.708 seconds
With compilation: True, and TensorRT: False in 1.447 seconds
With compilation: True, and TensorRT: True in 2.379 seconds

But if I pin the transformers version to 4.33.2 and keep diffusers at the latest stable release, then the timings are:

With compilation: False, and TensorRT: False in 3.512 seconds
With compilation: True, and TensorRT: False in 3.061 seconds
With compilation: True, and TensorRT: True in 1.181 seconds

Are you able to reproduce these numbers on an 80 GB A100?

@sayakpaul (Author)

I also extended the above script to use SDXL; here are the timings (updated here):

With compilation: False, and TensorRT: False in 6.713 seconds
With compilation: True, and TensorRT: False in 6.417 seconds
With compilation: True, and TensorRT: True in 5.537 seconds

I have pinned the transformers version to 4.33.2 and I am using the latest diffusers.

@gs-olive (Collaborator)

gs-olive commented Dec 13, 2023

Thanks for the follow-up. It seems that in the newer transformers and diffusers versions, some of the operators in the Stable Diffusion models have changed, so our converter support differs, which can affect performance. We have issues filed to support these operators as well, so that performance can become more uniform across transformers and diffusers versions.
