
🐛 [Bug] Torch-TensorRT runs out of memory when compiling a small PointNet model with otherwise low memory requirements #1854

Closed
airalcorn2 opened this issue Apr 24, 2023 · 13 comments
Assignees
Labels
bug Something isn't working

Comments

@airalcorn2

airalcorn2 commented Apr 24, 2023

Bug Description

When trying to compile the small PointNet model below, Torch-TensorRT runs out of memory on a GeForce RTX 3080 Ti. The model has fairly low memory requirements: it has only 892,677 parameters (nvidia-smi reports ~778MiB after moving it to the GPU, a figure that includes the CUDA context and not just the weights), and a forward pass of a tensor with shape (1, 50000, 3) uses only ~1822MiB in total, so it's not clear why compiling the model takes so much memory.
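As a rough cross-check of these numbers, the pure-Python sketch below (helper names are hypothetical) recomputes the parameter count from the layer sizes used in the repro and converts it to float32 bytes. It also sizes the largest single activation, the (1, 50000, 1088) tensor produced by torch.cat. This is arithmetic only: it ignores the CUDA context, PyTorch's caching allocator, and whatever TensorRT allocates while building the engine.

```python
def linear_params(n_in, n_out):
    return n_in * n_out + n_out  # weight + bias


def layernorm_params(n):
    return 2 * n  # elementwise-affine weight + bias


def pointnet_params(n_classes=5, first=(64, 64), second=(64, 128, 1024),
                    seg=(512, 256, 128, 128)):
    total, d = 0, 3  # input points have 3 coordinates
    for out in (*first, *second):  # second_mlp continues from first_mlp's output width
        total += linear_params(d, out) + layernorm_params(out)
        d = out
    d = first[-1] + second[-1]  # width after concatenating local and global features
    for out in seg:
        total += linear_params(d, out) + layernorm_params(out)
        d = out
    return total + linear_params(d, n_classes)  # final classifier


MiB = 2 ** 20
n_params = pointnet_params()
param_mib = n_params * 4 / MiB  # float32 = 4 bytes per parameter
P = 50000
concat_mib = P * (64 + 1024) * 4 / MiB  # the (1, P, 1088) concatenated features
print(n_params, round(param_mib, 1), round(concat_mib, 1))  # prints: 892677 3.4 207.5
```

So the weights themselves are only a few MiB; the large activations scale with the 50,000 points.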

To Reproduce

Steps to reproduce the behavior:

import torch
import torch_tensorrt

from torch import nn


class PointNet(nn.Module):
    def __init__(
        self,
        n_classes=5,
        first_mlp_dimensions=[64, 64],
        second_mlp_dimensions=[64, 128, 1024],
        segmentation_mlp_dimensions=[512, 256, 128, 128],
    ):
        super().__init__()

        in_dimensions = 3
        first_mlp_layers = []
        for out_dimensions in first_mlp_dimensions:
            first_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.first_mlp = nn.Sequential(*first_mlp_layers)

        second_mlp_layers = []
        for out_dimensions in second_mlp_dimensions:
            second_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.second_mlp = nn.Sequential(*second_mlp_layers)

        in_dimensions = first_mlp_dimensions[-1] + second_mlp_dimensions[-1]
        segmentation_layers = []
        for out_dimensions in segmentation_mlp_dimensions:
            segmentation_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.segmentation_mlp = nn.Sequential(*segmentation_layers)

        self.classifier = nn.Linear(in_dimensions, n_classes)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        _, n_points, _ = points.shape

        first_mlp_features = self.first_mlp(points)
        second_mlp_features = self.second_mlp(first_mlp_features)

        global_features = second_mlp_features.max(dim=1)[0]
        global_features = global_features.unsqueeze(1).expand(-1, n_points, -1)

        concatenated_features = torch.cat([first_mlp_features, global_features], dim=-1)

        segmentation_features = self.segmentation_mlp(concatenated_features)

        preds = self.classifier(segmentation_features)

        return preds


def do_forward_pass(model, P, device):
    with torch.no_grad():
        points = torch.rand(1, P, 3)
        _ = model(points.to(device))

    # nvidia-smi --> 1822MiB.


def main():
    device = torch.device("cuda:0")
    model = PointNet().to(device)
    model.eval()
    # nvidia-smi --> 778MiB.
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 892,677 parameters.
    print(f"Parameters: {n_params}")

    P = 50000
    inputs = [torch_tensorrt.Input((1, P, 3))]
    enabled_precisions = {torch.float}
    # Out of memory.
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )


if __name__ == "__main__":
    main()
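The nvidia-smi figures in the comments measure the whole process, CUDA context included. As a complement, here is a minimal sketch (the helper name peak_forward_mib is hypothetical) that reports only the peak memory held by PyTorch tensors during one forward pass, using torch.cuda.reset_peak_memory_stats and torch.cuda.max_memory_allocated. Expect it to be well below the nvidia-smi readings; it also does not capture TensorRT's own allocations during compilation.

```python
import torch
from torch import nn


def peak_forward_mib(model: nn.Module, input_shape, device="cuda:0"):
    """Peak MiB held by PyTorch tensors during one no-grad forward pass,
    or None when CUDA is unavailable. Excludes the CUDA context."""
    if not torch.cuda.is_available():
        return None
    model = model.to(device).eval()
    torch.cuda.synchronize(device)
    torch.cuda.reset_peak_memory_stats(device)  # start counting from here
    with torch.no_grad():
        _ = model(torch.rand(*input_shape, device=device))
    torch.cuda.synchronize(device)  # make sure all kernels have finished
    return torch.cuda.max_memory_allocated(device) / 2 ** 20
```

For the repro above this would be called as peak_forward_mib(PointNet(), (1, 50000, 3)).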

Expected behavior

Compile without error.

Environment

Build information about Torch-TensorRT can be found by turning on debug messages

  • Torch-TensorRT Version (e.g. 1.0.0): 1.3.0
  • PyTorch Version (e.g. 1.0): 1.13.1+cu117
  • CPU Architecture: i7-12800H
  • OS (e.g., Linux): Linux
  • How you installed PyTorch (conda, pip, libtorch, source): pip
  • Build command you used (if compiling from source):
  • Are you using local sources or building from archives:
  • Python version: 3.8.16
  • CUDA version: 11.7
  • GPU models and configuration: GeForce RTX 3080 Ti
  • Any other relevant information:

Additional context

@airalcorn2 airalcorn2 added the bug Something isn't working label Apr 24, 2023
@airalcorn2
Author

If I restructure the forward pass to get rid of the batch dimension (which is OK here because the batch size is one and all of the Linear and LayerNorm layers operate only on the last dimension of the input tensor), compiling works fine, even though the forward pass has the same memory requirements.

import torch
import torch_tensorrt

from torch import nn


class PointNet(nn.Module):
    def __init__(
        self,
        n_classes=5,
        first_mlp_dimensions=[64, 64],
        second_mlp_dimensions=[64, 128, 1024],
        segmentation_mlp_dimensions=[512, 256, 128, 128],
    ):
        super().__init__()

        in_dimensions = 3
        first_mlp_layers = []
        for out_dimensions in first_mlp_dimensions:
            first_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.first_mlp = nn.Sequential(*first_mlp_layers)

        second_mlp_layers = []
        for out_dimensions in second_mlp_dimensions:
            second_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.second_mlp = nn.Sequential(*second_mlp_layers)

        in_dimensions = first_mlp_dimensions[-1] + second_mlp_dimensions[-1]
        segmentation_layers = []
        for out_dimensions in segmentation_mlp_dimensions:
            segmentation_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.segmentation_mlp = nn.Sequential(*segmentation_layers)

        self.classifier = nn.Linear(in_dimensions, n_classes)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        n_points, _ = points.shape

        first_mlp_features = self.first_mlp(points)
        second_mlp_features = self.second_mlp(first_mlp_features)

        global_features = second_mlp_features.max(dim=0)[0]
        global_features = global_features.unsqueeze(0).expand(n_points, -1)

        concatenated_features = torch.cat([first_mlp_features, global_features], dim=-1)

        segmentation_features = self.segmentation_mlp(concatenated_features)

        preds = self.classifier(segmentation_features)

        return preds


def do_forward_pass(model, P, device):
    with torch.no_grad():
        points = torch.rand(P, 3)
        _ = model(points.to(device))

    # nvidia-smi --> 1822MiB.


def main():
    device = torch.device("cuda:0")
    model = PointNet().to(device)
    model.eval()
    # nvidia-smi --> 778MiB.
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 892,677 parameters.
    print(f"Parameters: {n_params}")

    P = 50000
    inputs = [torch_tensorrt.Input((P, 3))]
    enabled_precisions = {torch.float}
    # Works.
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )


if __name__ == "__main__":
    main()
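A minimal check that dropping a batch dimension of size one does not change the result, using a small stand-in MLP rather than the full PointNet. The only change needed in the forward pass is moving the max reduction from dim=1 to dim=0, as the refactored model does.

```python
import torch
from torch import nn

torch.manual_seed(0)
# Stand-in for first_mlp/second_mlp: Linear and LayerNorm act only on the last dimension.
mlp = nn.Sequential(nn.Linear(3, 16), nn.LayerNorm(16), nn.ReLU())
mlp.eval()

points = torch.rand(1, 200, 3)  # (batch=1, n_points, 3), as in the original repro

with torch.no_grad():
    batched = mlp(points).max(dim=1)[0]  # reduce over points (dim=1) -> (1, 16)
    unbatched = mlp(points.squeeze(0)).max(dim=0)[0]  # reduce over points (dim=0) -> (16,)

print(torch.allclose(batched.squeeze(0), unbatched))  # prints: True
```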

@airalcorn2
Author

airalcorn2 commented Apr 24, 2023

While I could compile on the GeForce RTX 3080 Ti using the refactored model, I was still running into memory issues when trying to compile on a Jetson Xavier. I refactored the model again such that I (1) treated the input point cloud as a 2D feature map and (2) used 2D convolutional layers with a kernel size of one in place of linear layers. Interestingly, this new model compiles fine on the Xavier even though the forward pass requires slightly more memory.

import torch
import torch_tensorrt

from torch import nn


class PointNet(nn.Module):
    def __init__(
        self,
        n_classes=5,
        first_mlp_dimensions=[64, 64],
        second_mlp_dimensions=[64, 128, 1024],
        segmentation_mlp_dimensions=[512, 256, 128, 128],
    ):
        super().__init__()

        in_dimensions = 3
        first_mlp_layers = []
        for out_dimensions in first_mlp_dimensions:
            first_mlp_layers.extend(
                [
                    nn.Conv2d(in_dimensions, out_dimensions, 1),
                    nn.BatchNorm2d(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.first_mlp = nn.Sequential(*first_mlp_layers)

        second_mlp_layers = []
        for out_dimensions in second_mlp_dimensions:
            second_mlp_layers.extend(
                [
                    nn.Conv2d(in_dimensions, out_dimensions, 1),
                    nn.BatchNorm2d(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.second_mlp = nn.Sequential(*second_mlp_layers)
        self.max_pool = nn.MaxPool2d((50000, 1))

        in_dimensions = first_mlp_dimensions[-1] + second_mlp_dimensions[-1]
        segmentation_layers = []
        for out_dimensions in segmentation_mlp_dimensions:
            segmentation_layers.extend(
                [
                    nn.Conv2d(in_dimensions, out_dimensions, 1),
                    nn.BatchNorm2d(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.segmentation_mlp = nn.Sequential(*segmentation_layers)

        self.classifier = nn.Conv2d(in_dimensions, n_classes, 1)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        first_mlp_features = self.first_mlp(points)

        second_mlp_features = self.second_mlp(first_mlp_features)

        global_features = self.max_pool(second_mlp_features)
        global_features = global_features.expand(-1, -1, 50000, -1)

        concatenated_features = torch.cat([first_mlp_features, global_features], dim=1)

        segmentation_features = self.segmentation_mlp(concatenated_features)

        preds = self.classifier(segmentation_features)

        return preds


def do_forward_pass(model, P, device):
    with torch.no_grad():
        points = torch.rand(P, 3).permute(1, 0).unsqueeze(0).unsqueeze(3)
        _ = model(points.to(device))

    # nvidia-smi --> 2174MiB.


def main():
    device = torch.device("cuda:0")
    model = PointNet().to(device)
    model.eval()
    # nvidia-smi --> 778MiB.
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 892,677 parameters.
    print(f"Parameters: {n_params}")

    P = 50000
    inputs = [torch_tensorrt.Input((1, 3, P, 1))]
    enabled_precisions = {torch.float}
    # Works on the Xavier.
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )


if __name__ == "__main__":
    main()
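A sketch, with small sizes for illustration, of why this refactor computes the same per-point features: a kernel-size-1 Conv2d over a (1, C, P, 1) feature map is a Linear layer applied at every point, and MaxPool2d((P, 1)) is a global max over the P points. Two caveats: swapping LayerNorm for BatchNorm2d is not numerically equivalent (BatchNorm2d in eval mode uses per-channel running statistics), and MaxPool2d((50000, 1)) plus expand(-1, -1, 50000, -1) tie the model to exactly 50,000 points. torch.amax is shown as a reduction that does not hardcode P; it was not tried with Torch-TensorRT in this thread.

```python
import torch
from torch import nn

torch.manual_seed(0)
P, c_in, c_out = 100, 3, 64  # small P for illustration; the issue uses P = 50000

# 1) A kernel-size-1 Conv2d is a Linear layer applied independently at every point.
linear = nn.Linear(c_in, c_out)
conv = nn.Conv2d(c_in, c_out, kernel_size=1)
with torch.no_grad():
    conv.weight.copy_(linear.weight.view(c_out, c_in, 1, 1))  # (out, in) -> (out, in, 1, 1)
    conv.bias.copy_(linear.bias)

points = torch.rand(P, c_in)  # (n_points, channels)
feature_map = points.permute(1, 0).unsqueeze(0).unsqueeze(3)  # (1, c_in, P, 1), as in do_forward_pass

with torch.no_grad():
    out_linear = linear(points)  # (P, c_out)
    out_conv = conv(feature_map)  # (1, c_out, P, 1)
    conv_as_rows = out_conv.squeeze(3).squeeze(0).permute(1, 0)  # back to (P, c_out)

    # 2) MaxPool2d((P, 1)) takes the max over all P points: a global max-pool.
    pooled = nn.MaxPool2d((P, 1))(out_conv)  # (1, c_out, 1, 1)
    global_max = torch.amax(out_conv, dim=2, keepdim=True)  # same values, P not hardcoded

print(torch.allclose(out_linear, conv_as_rows, atol=1e-6))  # prints: True
print(pooled.shape, torch.equal(pooled, global_max))  # prints: torch.Size([1, 64, 1, 1]) True
```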

@apbose apbose self-assigned this Apr 28, 2023
@github-actions

This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.

@apbose
Collaborator

apbose commented Nov 6, 2023

The test passes locally with torch_tensorrt.compile(..., ir="ts"). With the dynamo IR, however, I encounter this error:

torch._dynamo.exc.TorchRuntimeError: Failed running call_module L__self___first_mlp_1(*(FakeTensor(..., device='cuda:0', size=(1, 64, 50000, 1),
           grad_fn=<ConvolutionBackward0>),), **{}):
Unhandled FakeTensor Device Propagation for aten.sub.Tensor, found two different devices cuda:0, cpu

It seems to be an input error.
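The error reports a cuda:0 tensor meeting a cpu tensor in L__self___first_mlp_1, which is index 1 of first_mlp: the first BatchNorm2d in the conv-based snippet, whose running statistics are buffers rather than parameters. A common cause of this message is a model whose parameters or buffers sit on a different device from the input. A minimal pre-flight check, with hypothetical helper names, that collects those devices before compiling:

```python
import torch
from torch import nn


def tensor_devices(model: nn.Module) -> set:
    """Devices of all parameters and buffers (e.g. BatchNorm running stats)."""
    return {t.device for t in (*model.parameters(), *model.buffers())}


def check_same_device(model: nn.Module, example_input: torch.Tensor) -> None:
    """Raise if the model's tensors and the example input are not all on one device."""
    devices = tensor_devices(model) | {example_input.device}
    if len(devices) > 1:
        raise RuntimeError(f"model and input span multiple devices: {sorted(map(str, devices))}")


model = nn.Sequential(nn.Conv2d(3, 8, 1), nn.BatchNorm2d(8))
check_same_device(model, torch.rand(1, 3, 10, 1))  # passes: everything is on the CPU
```

Moving the model with .to(device) before compiling keeps its buffers on the same device as the inputs.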

@apbose
Collaborator

apbose commented Nov 6, 2023

It passes with torch_tensorrt.compile(..., ir="torch_compile").

@narendasan
Collaborator

Closing as the issue seems resolved; reopen if necessary.

@airalcorn2
Copy link
Author

To add an explanation: Torch-TensorRT can now be used as a backend for torch.compile, which was introduced in PyTorch 2.0. This was not possible with PyTorch 1.13.1, the version I was using in the initial bug report. The code below compiles quickly and without error using PyTorch 2.0.1, Torch-TensorRT 1.4.0, and CUDA 12.2.

import torch
import torch_tensorrt

from torch import nn


class PointNet(nn.Module):
    def __init__(
        self,
        n_classes=5,
        first_mlp_dimensions=[64, 64],
        second_mlp_dimensions=[64, 128, 1024],
        segmentation_mlp_dimensions=[512, 256, 128, 128],
    ):
        super().__init__()

        in_dimensions = 3
        first_mlp_layers = []
        for out_dimensions in first_mlp_dimensions:
            first_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.first_mlp = nn.Sequential(*first_mlp_layers)

        second_mlp_layers = []
        for out_dimensions in second_mlp_dimensions:
            second_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.second_mlp = nn.Sequential(*second_mlp_layers)

        in_dimensions = first_mlp_dimensions[-1] + second_mlp_dimensions[-1]
        segmentation_layers = []
        for out_dimensions in segmentation_mlp_dimensions:
            segmentation_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.segmentation_mlp = nn.Sequential(*segmentation_layers)

        self.classifier = nn.Linear(in_dimensions, n_classes)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        _, n_points, _ = points.shape

        first_mlp_features = self.first_mlp(points)
        second_mlp_features = self.second_mlp(first_mlp_features)

        global_features = second_mlp_features.max(dim=1)[0]
        global_features = global_features.unsqueeze(1).expand(-1, n_points, -1)

        concatenated_features = torch.cat([first_mlp_features, global_features], dim=-1)

        segmentation_features = self.segmentation_mlp(concatenated_features)

        preds = self.classifier(segmentation_features)

        return preds


def main():
    device = torch.device("cuda:0")
    model = PointNet().to(device)
    model.eval()
    # nvidia-smi --> 322MiB.
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 892,677 parameters.
    print(f"Parameters: {n_params}")

    optimized_model = torch.compile(model, backend="torch_tensorrt")
    P = 50000
    _ = optimized_model(torch.rand((1, P, 3)).to(device))


if __name__ == "__main__":
    main()

@airalcorn2
Author

airalcorn2 commented Nov 28, 2023

@narendasan and @apbose - I would like to re-open this. My models that are compiled using torch.compile with the Torch-TensorRT backend are considerably slower than the models compiled with torch_tensorrt.compile, so using torch.compile is not a fix.

@DPeled

DPeled commented Nov 30, 2023

@narendasan @apbose I am also encountering the same issue as @airalcorn2 .

I'm using the same environment. For some reason, compiling with torch_tensorrt.compile() takes a vast amount of GPU memory, which makes it unusable for optimizing models with large input batches.
Is there a solution specifically for torch_tensorrt v1.3, rather than migrating to PyTorch v2.x?

@airalcorn2 Regarding the considerably slower compiled models - I think this is because compilation happens during the first execution of the model. From the second execution on, you should see reduced inference time.
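To separate one-time compilation cost from inference speed, the first call can be timed on its own and excluded from the steady-state average. A minimal sketch (the helper name time_model is hypothetical); for CUDA inputs it synchronizes before reading the clock, because CUDA kernels run asynchronously.

```python
import time

import torch
from torch import nn


def time_model(model: nn.Module, x: torch.Tensor, iters: int = 10):
    """Return (first_call_s, mean_later_call_s) for model(x).

    The first call of a torch.compile'd model includes compilation, so it is
    reported separately from the steady-state average."""

    def run_once():
        with torch.no_grad():
            model(x)
        if x.is_cuda:
            torch.cuda.synchronize(x.device)  # wait for queued kernels before reading the clock

    start = time.perf_counter()
    run_once()
    first = time.perf_counter() - start

    start = time.perf_counter()
    for _ in range(iters):
        run_once()
    steady = (time.perf_counter() - start) / iters
    return first, steady
```

For the snippet above this could be used as first, steady = time_model(optimized_model, torch.rand(1, 50000, 3, device="cuda:0")), comparing steady rather than first across backends.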

@airalcorn2
Copy link
Author

I'm actually confused about what exactly torch.compile(..., ir = "ts") means. There's no ir argument for torch.compile.

@airalcorn2
Copy link
Author

airalcorn2 commented Nov 30, 2023

It appears that torch.compile(..., ir = "ts") should be torch_tensorrt.compile(..., ir = "ts"). I experience the same issue when using ir="ts". @apbose - did you test the code that was included in the original issue comment?

@airalcorn2
Copy link
Author

airalcorn2 commented Nov 30, 2023

@DPeled - it turns out my models were not actually compiling because I was getting RuntimeErrors (here and here).

@apbose
Copy link
Collaborator

apbose commented Dec 1, 2023

Hi @airalcorn2, apologies for the confusion. torch.compile(..., backend="tensorrt") and torch_tensorrt.compile(..., ir="torch_compile") are the same: effectively, torch_tensorrt.compile(model, ir="torch_compile") is a frontend for torch.compile(..., backend="tensorrt").
When I last checked, the model mentioned in the issue was passing with torch_tensorrt.compile(model, ir="torch_compile").

I am checking the issues you mentioned, #2506 and #2507. They seem to be argument issues, which I am looking into.
When you mention that you are seeing slower model compilation, which backend are you using?
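Based on the description in this comment, a minimal sketch of the two spellings side by side. The wrapper name compile_two_ways is hypothetical; it assumes a Torch-TensorRT release that accepts ir="torch_compile" and registers the "tensorrt" backend for torch.compile, plus a CUDA GPU to actually run the results. Both return a lazily compiled module, so the TensorRT engines are built on the first call rather than when these functions return.

```python
import torch
from torch import nn


def compile_two_ways(model: nn.Module):
    """Return the model compiled via the Torch-TensorRT frontend and via torch.compile."""
    # Assumption: importing torch_tensorrt registers the "tensorrt" torch.compile backend.
    import torch_tensorrt

    via_frontend = torch_tensorrt.compile(model, ir="torch_compile")
    via_torch = torch.compile(model, backend="tensorrt")
    return via_frontend, via_torch
```

Comparing the two on the same inputs, after a warm-up call each, would show whether they behave the same for a given model.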

5 participants