
Dimension issues with predict method #8

Open
julesmorata opened this issue Dec 19, 2023 · 2 comments

Comments


julesmorata commented Dec 19, 2023

Hello,

I am trying to use your framework to extract masks from my images with the following code:

```python
import os

import torchvision.transforms as transforms
from backbones_unet.model.unet import Unet
from PIL import Image

PATH = "path_to_folder"  # Replaced for confidentiality reasons
CELLS_PATH = PATH + "cells/"

# Model init
model = Unet(backbone='xception71', in_channels=3, num_classes=1)
transform = transforms.Compose([transforms.ToTensor()])

# Data processing
for filename in os.listdir(CELLS_PATH):
    image = Image.open(CELLS_PATH + filename)
    tensor = transform(image).unsqueeze(0)
    print(tensor.shape)
    mask = model.predict(tensor)
    print(mask)
```

For now I just print the masks to check what they look like before saving them. However, I get the following error when I run the code:

```
Traceback (most recent call last):
  File "hidden/mask.py", line 19, in <module>
    mask = model.predict(tensor)
  File "env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 160, in predict
    x = self.forward(x)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 142, in forward
    x = self.decoder(x)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 304, in forward
    x = b(x, skip)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 246, in forward
    x = torch.cat([x, skip], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 14 but got size 13 for tensor number 1 in the list.
```

When I print the shapes of my input tensor and of the two tensors being concatenated (`x` and `skip`), I get, in that order:

```
torch.Size([1, 3, 207, 204])
torch.Size([1, 2048, 14, 14])
torch.Size([1, 1024, 13, 13])
```
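The printed shapes are consistent with an encoder that halves the spatial size five times and rounds up at each stride-2 step (as a 3x3, stride-2 convolution with padding 1 does). The sketch below is an assumption about the downsampling arithmetic, not code taken from `backbones_unet` or the xception71 backbone, and `encoder_sizes` is a hypothetical helper: it shows that 207 shrinks to 13 and then 7, and doubling 7 gives 14, which no longer matches the 13x13 skip connection.

```python
def encoder_sizes(size: int, num_stages: int = 5) -> list[int]:
    """Spatial size after each stride-2 stage, assuming ceil rounding (hypothetical model)."""
    sizes = [size]
    for _ in range(num_stages):
        size = (size + 1) // 2  # ceil(size / 2) for non-negative integers
        sizes.append(size)
    return sizes


for side in (207, 204, 224):
    sizes = encoder_sizes(side)
    # The decoder upsamples the deepest map by 2 and concatenates it with the next-shallower map.
    upsampled, skip = sizes[-1] * 2, sizes[-2]
    print(side, sizes, "upsampled:", upsampled, "skip:", skip, "ok" if upsampled == skip else "MISMATCH")

# 207 [207, 104, 52, 26, 13, 7] upsampled: 14 skip: 13 MISMATCH
# 204 [204, 102, 51, 26, 13, 7] upsampled: 14 skip: 13 MISMATCH
# 224 [224, 112, 56, 28, 14, 7] upsampled: 14 skip: 14 ok
```

Under this assumption, only sides that are multiples of 32 survive all five halvings without rounding, which is why 224 works while 207 and 204 do not.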

Do you know where this comes from, or how to fix it?


physgorg commented Mar 8, 2024

I also encountered this issue. It comes from the rescaling line in the `forward` method of the `DecoderBlock` class in `unet.py`. I modified the code to interpolate directly to the spatial size of the skip connection:
```python
if self.scale_factor != 1.0:
    if skip is not None:
        target_size = (skip.shape[2], skip.shape[3])
        x = F.interpolate(x, size=target_size, mode='nearest')
    else:
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
```

This guarantees that `x` and `skip` have the same height and width, so the concatenation succeeds. I'm not pushing this until I've tested that it doesn't break anything else.
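To check the idea in isolation, here is a standalone sketch of the patched resize-and-concatenate step using dummy tensors with the channel counts from the report. The function `upsample_and_concat` is hypothetical and not part of `backbones_unet`; it mirrors only the resize and `torch.cat` logic of the patch, not the convolutions the real `DecoderBlock` applies afterwards. The 7x7 size of the deepest feature map is inferred from the 14x14 shape printed after upsampling.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def upsample_and_concat(x: torch.Tensor, skip: Optional[torch.Tensor], scale_factor: float = 2.0) -> torch.Tensor:
    """Resize a decoder feature map and concatenate it with its skip connection (hypothetical helper)."""
    if scale_factor != 1.0:
        if skip is not None:
            # Match the skip tensor's exact (H, W) instead of multiplying by scale_factor,
            # so an odd encoder size such as 13 cannot cause an off-by-one mismatch.
            x = F.interpolate(x, size=tuple(skip.shape[-2:]), mode="nearest")
        else:
            x = F.interpolate(x, scale_factor=scale_factor, mode="nearest")
    if skip is not None:
        x = torch.cat([x, skip], dim=1)  # concatenate along the channel dimension
    return x


x = torch.randn(1, 2048, 7, 7)       # deepest encoder feature (size inferred, see above)
skip = torch.randn(1, 1024, 13, 13)  # skip connection from the previous stage

print(F.interpolate(x, scale_factor=2.0, mode="nearest").shape)  # torch.Size([1, 2048, 14, 14])
print(upsample_and_concat(x, skip).shape)                        # torch.Size([1, 3072, 13, 13])
```

Plain x2 upsampling reproduces the 14x14 tensor that `torch.cat` rejects, while resizing to the skip's shape yields a tensor that concatenates cleanly.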


physgorg commented Mar 8, 2024

It is probably better to define the `target_size` tuple by indexing from the end of `skip.shape`, i.e. `skip.shape[-2:]`.
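Until a fix like the one above lands, a user-side workaround that leaves the library untouched is to pad each input so its height and width are multiples of 32, assuming the encoder downsamples by a total factor of 32, then run the model and crop the mask back. The helper `pad_to_multiple` below is hypothetical, and zero padding on the bottom and right is an assumption (reflection padding may suit some images better). The final crop also assumes `predict` returns a mask at the padded input's resolution.

```python
import torch
import torch.nn.functional as F


def pad_to_multiple(tensor: torch.Tensor, multiple: int = 32) -> tuple[torch.Tensor, tuple[int, int]]:
    """Zero-pad the bottom and right of an NCHW tensor so H and W are multiples of `multiple`.

    Returns the padded tensor and the original (H, W) so the output can be cropped back.
    """
    height, width = tensor.shape[-2:]
    pad_h = (-height) % multiple  # rows needed to reach the next multiple
    pad_w = (-width) % multiple   # columns needed to reach the next multiple
    # F.pad lists (left, right, top, bottom) for the last two dimensions.
    padded = F.pad(tensor, (0, pad_w, 0, pad_h))
    return padded, (height, width)


tensor = torch.rand(1, 3, 207, 204)  # same shape as the input in the report
padded, (h, w) = pad_to_multiple(tensor)
print(padded.shape)  # torch.Size([1, 3, 224, 224])

# With the model from the report (not run here):
# mask = model.predict(padded)[..., :h, :w]
```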
