🐛 [Bug] Torch-TensorRT runs out of memory when compiling a small PointNet model with otherwise low memory requirements #1854
If I restructure the forward pass to get rid of the batch dimension (which is OK here because we only have a batch of one and all of the layers only operate on the last dimensions of the input tensor), compiling works fine even though the memory requirements for the model in the forward pass are the same.

```python
import torch
import torch_tensorrt
from torch import nn


class PointNet(nn.Module):
    def __init__(
        self,
        n_classes=5,
        first_mlp_dimensions=[64, 64],
        second_mlp_dimensions=[64, 128, 1024],
        segmentation_mlp_dimensions=[512, 256, 128, 128],
    ):
        super().__init__()
        in_dimensions = 3
        first_mlp_layers = []
        for out_dimensions in first_mlp_dimensions:
            first_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.first_mlp = nn.Sequential(*first_mlp_layers)

        second_mlp_layers = []
        for out_dimensions in second_mlp_dimensions:
            second_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.second_mlp = nn.Sequential(*second_mlp_layers)

        in_dimensions = first_mlp_dimensions[-1] + second_mlp_dimensions[-1]
        segmentation_layers = []
        for out_dimensions in segmentation_mlp_dimensions:
            segmentation_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.segmentation_mlp = nn.Sequential(*segmentation_layers)
        self.classifier = nn.Linear(in_dimensions, n_classes)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        n_points, _ = points.shape
        first_mlp_features = self.first_mlp(points)
        second_mlp_features = self.second_mlp(first_mlp_features)
        global_features = second_mlp_features.max(dim=0)[0]
        global_features = global_features.unsqueeze(0).expand(n_points, -1)
        concatenated_features = torch.cat([first_mlp_features, global_features], dim=-1)
        segmentation_features = self.segmentation_mlp(concatenated_features)
        preds = self.classifier(segmentation_features)
        return preds


def do_forward_pass(model, P, device):
    with torch.no_grad():
        points = torch.rand(P, 3)
        _ = model(points.to(device))
        # nvidia-smi --> 1822MiB.


def main():
    device = torch.device("cuda:0")
    model = PointNet().to(device)
    model.eval()
    # nvidia-smi --> 778MiB.
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 892,677 parameters.
    print(f"Parameters: {n_params}")
    P = 50000
    inputs = [torch_tensorrt.Input((P, 3))]
    enabled_precisions = {torch.float}
    # Works.
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )


if __name__ == "__main__":
    main()
```
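If it helps, this is roughly how the compiled module gets used afterwards (a sketch, assuming the default TorchScript path that `torch_tensorrt.compile` takes here; the file name is just an example):

```python
# Run the compiled TorchScript module and serialize it (sketch).
points = torch.rand(50000, 3).to(device)
with torch.no_grad():
    preds = trt_ts_module(points)
torch.jit.save(trt_ts_module, "pointnet_trt.ts")  # Example file name.
```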
While I could compile on the GeForce RTX 3080 Ti using the refactored model, I was still running into memory issues when trying to compile on a Jetson Xavier. I refactored the model again such that I (1) treated the input point cloud as a 2D feature map and (2) used 2D convolutional layers with a kernel size of one in place of linear layers. Interestingly, this new model compiles fine on the Xavier even though the forward pass requires slightly more memory.

```python
import torch
import torch_tensorrt
from torch import nn


class PointNet(nn.Module):
    def __init__(
        self,
        n_classes=5,
        first_mlp_dimensions=[64, 64],
        second_mlp_dimensions=[64, 128, 1024],
        segmentation_mlp_dimensions=[512, 256, 128, 128],
    ):
        super().__init__()
        in_dimensions = 3
        first_mlp_layers = []
        for out_dimensions in first_mlp_dimensions:
            first_mlp_layers.extend(
                [
                    nn.Conv2d(in_dimensions, out_dimensions, 1),
                    nn.BatchNorm2d(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.first_mlp = nn.Sequential(*first_mlp_layers)

        second_mlp_layers = []
        for out_dimensions in second_mlp_dimensions:
            second_mlp_layers.extend(
                [
                    nn.Conv2d(in_dimensions, out_dimensions, 1),
                    nn.BatchNorm2d(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.second_mlp = nn.Sequential(*second_mlp_layers)
        self.max_pool = nn.MaxPool2d((50000, 1))

        in_dimensions = first_mlp_dimensions[-1] + second_mlp_dimensions[-1]
        segmentation_layers = []
        for out_dimensions in segmentation_mlp_dimensions:
            segmentation_layers.extend(
                [
                    nn.Conv2d(in_dimensions, out_dimensions, 1),
                    nn.BatchNorm2d(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.segmentation_mlp = nn.Sequential(*segmentation_layers)
        self.classifier = nn.Conv2d(in_dimensions, n_classes, 1)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        first_mlp_features = self.first_mlp(points)
        second_mlp_features = self.second_mlp(first_mlp_features)
        global_features = self.max_pool(second_mlp_features)
        global_features = global_features.expand(-1, -1, 50000, -1)
        concatenated_features = torch.cat([first_mlp_features, global_features], dim=1)
        segmentation_features = self.segmentation_mlp(concatenated_features)
        preds = self.classifier(segmentation_features)
        return preds


def do_forward_pass(model, P, device):
    with torch.no_grad():
        points = torch.rand(P, 3).permute(1, 0).unsqueeze(0).unsqueeze(3)
        _ = model(points.to(device))
        # nvidia-smi --> 2174MiB.


def main():
    device = torch.device("cuda:0")
    model = PointNet().to(device)
    model.eval()
    # nvidia-smi --> 778MiB.
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 892,677 parameters.
    print(f"Parameters: {n_params}")
    P = 50000
    inputs = [torch_tensorrt.Input((1, 3, P, 1))]
    enabled_precisions = {torch.float}
    # Works on the Xavier.
    trt_ts_module = torch_tensorrt.compile(
        model, inputs=inputs, enabled_precisions=enabled_precisions
    )


if __name__ == "__main__":
    main()
```
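As an aside, the `nn.MaxPool2d((50000, 1))` and the matching `expand` hard-code the point count. A fixed size is fine for this issue, but if the cloud size ever needs to vary, a shape-agnostic sketch (my suggestion, not part of the original model) could look like this:

```python
import torch
from torch import nn

# Adaptive pooling produces (N, C, 1, 1) regardless of the number of points.
max_pool = nn.AdaptiveMaxPool2d((1, 1))

def global_max_features(second_mlp_features: torch.Tensor) -> torch.Tensor:
    # second_mlp_features has shape (N, C, P, 1).
    pooled = max_pool(second_mlp_features)  # (N, C, 1, 1).
    n_points = second_mlp_features.shape[2]
    return pooled.expand(-1, -1, n_points, -1)  # (N, C, P, 1).
```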
This issue has not seen activity for 90 days. Remove the stale label or comment, or this will be closed in 10 days.
The test passes locally with
It seems to be an input error.
It passes with `torch_tensorrt.compile(..., ir="torch_compile")`.
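For anyone landing here later, a minimal sketch of that call (using the model and shapes from this issue; the exact set of supported keyword arguments may differ by Torch-TensorRT version):

```python
import torch
import torch_tensorrt

# Compile through the torch.compile-based IR instead of the default path.
trt_module = torch_tensorrt.compile(
    model,
    ir="torch_compile",
    inputs=[torch_tensorrt.Input((1, 50000, 3))],
    enabled_precisions={torch.float},
)
```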
Closing as the issue seems resolved; reopen if necessary.
To add an explanation, Torch-TensorRT can now be used as a backend for `torch.compile`:

```python
import torch
import torch_tensorrt
from torch import nn


class PointNet(nn.Module):
    def __init__(
        self,
        n_classes=5,
        first_mlp_dimensions=[64, 64],
        second_mlp_dimensions=[64, 128, 1024],
        segmentation_mlp_dimensions=[512, 256, 128, 128],
    ):
        super().__init__()
        in_dimensions = 3
        first_mlp_layers = []
        for out_dimensions in first_mlp_dimensions:
            first_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.first_mlp = nn.Sequential(*first_mlp_layers)

        second_mlp_layers = []
        for out_dimensions in second_mlp_dimensions:
            second_mlp_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.second_mlp = nn.Sequential(*second_mlp_layers)

        in_dimensions = first_mlp_dimensions[-1] + second_mlp_dimensions[-1]
        segmentation_layers = []
        for out_dimensions in segmentation_mlp_dimensions:
            segmentation_layers.extend(
                [
                    nn.Linear(in_dimensions, out_dimensions),
                    nn.LayerNorm(out_dimensions),
                    nn.ReLU(),
                ]
            )
            in_dimensions = out_dimensions

        self.segmentation_mlp = nn.Sequential(*segmentation_layers)
        self.classifier = nn.Linear(in_dimensions, n_classes)

    def forward(self, points: torch.Tensor) -> torch.Tensor:
        _, n_points, _ = points.shape
        first_mlp_features = self.first_mlp(points)
        second_mlp_features = self.second_mlp(first_mlp_features)
        global_features = second_mlp_features.max(dim=1)[0]
        global_features = global_features.unsqueeze(1).expand(-1, n_points, -1)
        concatenated_features = torch.cat([first_mlp_features, global_features], dim=-1)
        segmentation_features = self.segmentation_mlp(concatenated_features)
        preds = self.classifier(segmentation_features)
        return preds


def main():
    device = torch.device("cuda:0")
    model = PointNet().to(device)
    model.eval()
    # nvidia-smi --> 322MiB.
    print(model)
    n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    # 892,677 parameters.
    print(f"Parameters: {n_params}")
    optimized_model = torch.compile(model, backend="torch_tensorrt")
    P = 50000
    _ = optimized_model(torch.rand((1, P, 3)).to(device))


if __name__ == "__main__":
    main()
```
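Compilation settings can also be passed to the Torch-TensorRT backend through `torch.compile`'s `options` dict (a sketch; the available option names can differ between releases):

```python
# Pass backend-specific settings via the options dict.
optimized_model = torch.compile(
    model,
    backend="torch_tensorrt",
    dynamic=False,
    options={"enabled_precisions": {torch.float}, "min_block_size": 1},
)
```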
@narendasan and @apbose - I would like to re-open this. My models that are compiled using `torch.compile` with the `torch_tensorrt` backend are considerably slower than expected.
@narendasan @apbose I am also encountering the same issue as @airalcorn2. I'm using the same environment. For some reason, when running the following

@airalcorn2 Regarding the considerably slow compiled models - I think this is due to compilation happening on the first execution of the model. By the next execution, you should see reduced inference time.
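To make the warm-up effect concrete, here is a quick way to check it (a sketch; absolute timings will vary by hardware):

```python
import time

import torch

def timed_forward(model, x):
    # CUDA kernels launch asynchronously, so synchronize around the timer.
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        _ = model(x)
    torch.cuda.synchronize()
    return time.perf_counter() - start

x = torch.rand((1, 50000, 3)).to("cuda:0")
print(f"First call (triggers compilation): {timed_forward(optimized_model, x):.3f}s")
print(f"Second call (already compiled): {timed_forward(optimized_model, x):.3f}s")
```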
I'm actually confused about what exactly
It appears that
Hi @airalcorn2, apologies for the confusion. I am checking the issues you mentioned, #2506 and #2507. They seem to be argument issues, which I am looking into.
Bug Description
When trying to compile the small PointNet model below, Torch-TensorRT runs out of memory on a GeForce RTX 3080. The model has fairly low memory requirements: it's only 892,677 parameters (~778MiB of GPU memory once loaded), and a forward pass on a tensor of shape (1, 50000, 3) only uses ~1822MiB total, so it's not clear why compiling the model takes so much memory.
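For scale, the raw weights themselves are tiny; most of the ~778MiB baseline is presumably the CUDA context and allocator overhead rather than the parameters (a quick sanity-check sketch):

```python
# 892,677 float32 parameters * 4 bytes ≈ 3.4 MiB.
param_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
print(f"Parameter memory: {param_bytes / 2**20:.1f} MiB")
```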
To Reproduce
Steps to reproduce the behavior:
Expected behavior
Compile without error.
Environment
1.3.0
1.13.1+cu117
conda
,pip
,libtorch
, source):pip
3.8.16
11.7
Additional context