Utilizing resnet_50.pth for 3D Feature Map Extraction #83

Open
aeinkoupaei opened this issue Jan 29, 2024 · 3 comments

Comments

@aeinkoupaei

Hi, I want to use the pre-trained resnet_50.pth encoder to extract 3D feature maps from medical images. Is the following method correct? It seems strange that the width, height, depth, and number of channels can be adjusted manually. Wasn't the resnet_50.pth model trained with a specific architecture and specific input dimensions (depth, height, width, and channels)? If so, shouldn't the inputs used for feature extraction have the same dimensions as the inputs used during training?

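```python
import torch
from collections import OrderedDict
from models.resnet import resnet50  # resnet.py from this repository (import path assumed)

model = resnet50(
    sample_input_D=32,
    sample_input_H=256,
    sample_input_W=256,
    shortcut_type='B',
    no_cuda=True,
    num_seg_classes=1
)
# Load the weights from the pretrained file onto the CPU (no_cuda=True)
pretrain = torch.load("pretrain/resnet_50.pth", map_location="cpu")
pretrained_dict = pretrain['state_dict']
new_state_dict = OrderedDict()
for k, v in pretrained_dict.items():
    name = k[7:]  # Remove the 'module.' prefix
    new_state_dict[name] = v
# strict=False skips missing/unexpected keys silently; shape mismatches still raise
model.load_state_dict(new_state_dict, strict=False)

model.eval()  # inference mode for BatchNorm
with torch.no_grad():
    A_img_feature_map = model(A_img)  # A_img: 5D tensor (N, C, D, H, W)
```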
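On the question of input dimensions: as a general property of convolutions (not something specific to this repository), a Conv3d layer's learnable weights have shape (out_channels, in_channels, k, k, k), so their number depends on the channel counts and kernel size but not on the input's depth, height, or width; only the output size changes. The sketch below checks this with plain arithmetic, using hypothetical layer settings (1 input channel, 64 output channels, 7×7×7 kernel, stride 2, padding 3) and the output-size formula from PyTorch's Conv3d documentation.

```python
def conv3d_param_count(in_ch, out_ch, kernel, bias=False):
    """Number of learnable weights in a Conv3d layer with a cubic kernel."""
    return out_ch * in_ch * kernel ** 3 + (out_ch if bias else 0)


def conv_out_size(n, kernel, stride=1, padding=0, dilation=1):
    """Output length along one spatial axis (formula from PyTorch's Conv3d docs)."""
    return (n + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


# Hypothetical stem layer: 1 -> 64 channels, 7x7x7 kernel, stride 2, padding 3.
params = conv3d_param_count(in_ch=1, out_ch=64, kernel=7)
print(params)  # 21952, whatever the input size

for d, h, w in [(32, 256, 256), (56, 448, 448)]:
    out = tuple(conv_out_size(n, kernel=7, stride=2, padding=3) for n in (d, h, w))
    print((d, h, w), "->", out)
# (32, 256, 256) -> (16, 128, 128)
# (56, 448, 448) -> (28, 224, 224)
```

The same weights therefore apply to either input size; the input channel count, by contrast, is baked into the weight shape and must match the checkpoint. Whether sample_input_D/H/W affect anything else depends on how resnet.py uses them.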

@Ram2314

Ram2314 commented May 3, 2024

Hi @aeinkoupaei! Did you figure out whether this is the correct approach to get a feature map from an image? I also need to use this to make a feature map. Is there any reason you set num_seg_classes=1?

@aeinkoupaei
Author

aeinkoupaei commented May 8, 2024

Hi @Ram2314,

To use a pre-trained ResNet model for extracting 3D feature maps, modify the ResNet class in the resnet.py file as follows:
1- In the `__init__` method of the ResNet class, delete the entire `self.conv_seg` block. This removes the segmentation head, which is not needed for feature extraction.
2- In the `forward` method of the ResNet class, delete the line `x = self.conv_seg(x)`. The model then skips the final segmentation prediction and returns the feature map produced before that stage.
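A possible alternative to editing resnet.py (a different technique from the two steps above) is to replace the head at runtime with `nn.Identity()`, so the unchanged `forward` passes the encoder output straight through. Since the repository's model cannot be imported here, the sketch uses a hypothetical stand-in, `TinySegNet`, with an encoder followed by a `conv_seg` head; with the real model you would assign `model.conv_seg = nn.Identity()` after `load_state_dict`, so any `conv_seg` keys in the checkpoint still match.

```python
import torch
import torch.nn as nn


class TinySegNet(nn.Module):
    """Hypothetical stand-in for the ResNet class: encoder + conv_seg head."""

    def __init__(self, num_seg_classes=2):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv3d(1, 8, kernel_size=3, stride=2, padding=1, bias=False),
            nn.BatchNorm3d(8),
            nn.ReLU(inplace=True),
        )
        self.conv_seg = nn.Conv3d(8, num_seg_classes, kernel_size=1)

    def forward(self, x):
        x = self.encoder(x)
        x = self.conv_seg(x)  # becomes a no-op once conv_seg is an Identity
        return x


model = TinySegNet()
model.conv_seg = nn.Identity()  # forward() now returns the encoder's feature map
model.eval()
with torch.no_grad():
    features = model(torch.randn(1, 1, 8, 16, 16))  # (N, C, D, H, W)
print(tuple(features.shape))  # (1, 8, 4, 8, 8)
```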

Here's an example of how to use a pre-trained ResNet-10 model for feature extraction:

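```python
import torch
from collections import OrderedDict
from models.resnet import resnet10  # resnet.py with conv_seg removed (import path assumed)

# If your resnet.py still requires sample_input_D/H/W and num_seg_classes, pass them here too.
resnet_10 = resnet10(shortcut_type='B', no_cuda=True)
pretrain = torch.load("resnet_10_23dataset.pth", map_location="cpu")
pretrained_dict = pretrain['state_dict']
new_state_dict = OrderedDict()
for k, v in pretrained_dict.items():
    name = k[7:]  # Remove the 'module.' prefix
    new_state_dict[name] = v

resnet_10.load_state_dict(new_state_dict)
resnet_10.eval()  # use BatchNorm running statistics when extracting features
```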
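Both snippets strip the prefix with `k[7:]`, which assumes every key starts with `module.` (the prefix `nn.DataParallel` adds to keys of a wrapped model). A slightly more defensive variant, sketched below with a hypothetical helper named `strip_module_prefix`, strips the prefix only where it is present; with a plain `k[7:]`, a key such as `bn1.bias` would be truncated to `s`.

```python
from collections import OrderedDict


def strip_module_prefix(state_dict, prefix="module."):
    """Remove the DataParallel prefix only from keys that actually have it."""
    return OrderedDict(
        (k[len(prefix):] if k.startswith(prefix) else k, v)
        for k, v in state_dict.items()
    )


# Toy dict: one wrapped key, one plain key (ints stand in for tensors).
example = OrderedDict([("module.conv1.weight", 1), ("bn1.bias", 2)])
print(list(strip_module_prefix(example)))  # ['conv1.weight', 'bn1.bias']
```

On Python 3.9+, `k.removeprefix("module.")` does the same thing per key.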

@Ram2314

Ram2314 commented May 9, 2024

@aeinkoupaei Awesome, thanks! Also, did you figure out the length/width/height issue? It seems I can set them to anything and it still works.
