-
Notifications
You must be signed in to change notification settings - Fork 29
Converting pytorch model to keras
Nick Tustison edited this page Apr 29, 2020
·
4 revisions
This is simply a record of manually converting PyTorch weights to Keras weights for a simple network. Specifically, I was interested in the work of Han Peng, Weikang Gong, Christian F. Beckmann, Andrea Vedaldi, and Stephen M. Smith, "Accurate brain age prediction with lightweight deep neural networks," bioRxiv, 2019, with the architecture and weights available in the repository referenced below.
# First clone the repo:
#
# https://github.com/ha-ha-ha-han/UKBiobank_deep_pretrain/blob/master/dp_model/model_files/sfcn.py
#
from dp_model.model_files.sfcn import SFCN
from dp_model import dp_loss as dpl
from dp_model import dp_utils as dpu
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import antspynet
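# https://github.com/keras-team/keras/blob/master/keras/utils/conv_utils.py#L62-L82
# Keras' convert_kernel is reproduced below (commented out) for reference only;
# it is not used in this script.  Note that it reverses the spatial axes of the
# kernel, whereas the conversion from channels first (PyTorch) to channels last
# (Keras) done further down only reorders the axes with np.transpose.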
# def convert_kernel(kernel):
# """Converts a Numpy kernel matrix from Theano format to TensorFlow format.
# Also works reciprocally, since the transformation is its own inverse.
# # Arguments
# kernel: Numpy array (3D, 4D or 5D).
# # Returns
# The converted kernel.
# # Raises
# ValueError: in case of invalid kernel shape or invalid data_format.
# """
# kernel = np.asarray(kernel)
# if not 3 <= kernel.ndim <= 5:
# raise ValueError('Invalid kernel shape:', kernel.shape)
# slices = [slice(None, None, -1) for _ in range(kernel.ndim)]
# no_flip = (slice(None, None), slice(None, None))
# slices[-2:] = no_flip
# return np.copy(kernel[tuple(slices)])
# Build the PyTorch network and restore the pretrained weights on the CPU.
# The network is wrapped in DataParallel before loading, which is why the
# state-dict keys used in this script carry the "module." prefix.
# NOTE(review): torch.load unpickles the file, so only load checkpoints
# obtained from a trusted source.
fp_ = './brain_age/run_20190719_00_epoch_best_mae.p'
pytorch_model = torch.nn.DataParallel(SFCN())
pytorch_model.load_state_dict(
    torch.load(fp_, map_location=torch.device('cpu')))
# Print out the model so we can see what we have to convert.
# Since the model is simple, we're going to do it by hand.
# - Each convolution layer has the weight kernel and bias vector.
# - Each batch norm. layer has a weight vector and bias vector.
# >>> model
# DataParallel(
# (module): SFCN(
# (feature_extractor): Sequential(
# (conv_0): Sequential(
# (0): Conv3d(1, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
# (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (3): ReLU()
# )
# (conv_1): Sequential(
# (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
# (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (3): ReLU()
# )
# (conv_2): Sequential(
# (0): Conv3d(64, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
# (1): BatchNorm3d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (3): ReLU()
# )
# (conv_3): Sequential(
# (0): Conv3d(128, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
# (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (3): ReLU()
# )
# (conv_4): Sequential(
# (0): Conv3d(256, 256, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
# (1): BatchNorm3d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
# (3): ReLU()
# )
# (conv_5): Sequential(
# (0): Conv3d(256, 64, kernel_size=(1, 1, 1), stride=(1, 1, 1))
# (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
# (2): ReLU()
# )
# )
# (classifier): Sequential(
# (average_pool): AvgPool3d(kernel_size=[5, 6, 5], stride=[5, 6, 5], padding=0)
# (dropout): Dropout(p=0.5, inplace=False)
# (conv_6): Conv3d(64, 40, kernel_size=(1, 1, 1), stride=(1, 1, 1))
# )
# )
# )
# Flatten the learnable parameters into a list so they can be addressed by
# position.
pytorch_weights = [*pytorch_model.parameters()]
# >>> len(pytorch_weights)
# 26
#
# Every *sequential* block contributes four consecutive entries.  The first
# two are the kernel and the bias of the convolution:
#   pytorch_weights[0].shape = torch.Size([32, 1, 3, 3, 3])
#   pytorch_weights[1].shape = torch.Size([32])
# The next two are the weight and the bias of the batch normalization:
#   pytorch_weights[2].shape = torch.Size([32])
#   pytorch_weights[3].shape = torch.Size([32])

# Keras model that receives the converted weights; the input size is given
# channels last, with a single channel.
input_image_size = (160, 192, 160, 1)
keras_model = antspynet.create_simple_fully_convolutional_network_model_3d(
    input_image_size)
# Copy the PyTorch parameters into the corresponding Keras layers.
#
# The state dict is fetched once.  Besides the learnable parameters it holds
# the batch normalization running statistics, which are not part of
# pytorch_model.parameters() and therefore not in pytorch_weights.
pytorch_state = pytorch_model.state_dict()

# PyTorch stores a Conv3d kernel as (out, in, k1, k2, k3), e.g.
# [32, 1, 3, 3, 3] for conv_0.  This permutation reorders it to
# (k1, k2, k3, in, out) before it is handed to the Keras convolution layer.
keras_kernel_axes = [2, 3, 4, 1, 0]

# One row per feature-extractor block:
#   (PyTorch block name, Keras conv layer index, Keras batch norm layer index)
# Block number i owns pytorch_weights[4 * i:4 * i + 4], ordered as
# conv kernel, conv bias, batch norm weight (scale), batch norm bias (offset).
feature_blocks = [
    ('conv_0', 1, 3),
    ('conv_1', 6, 8),
    ('conv_2', 11, 13),
    ('conv_3', 16, 18),
    ('conv_4', 21, 23),
    ('conv_5', 26, 27),  # 1x1x1 convolution block
]

for block_number, (block_name, conv_index, batch_norm_index) in enumerate(feature_blocks):
    first = 4 * block_number
    pytorch_kernel, pytorch_bias, pytorch_scale, pytorch_offset = [
        weight.detach().numpy() for weight in pytorch_weights[first:first + 4]]

    # Convolution: reordered kernel followed by the bias.
    keras_model.get_layer(index=conv_index).set_weights(
        [np.transpose(pytorch_kernel, keras_kernel_axes), pytorch_bias])

    # Batch normalization: scale, offset, running mean, running variance.
    # Indexing (rather than .get) makes a missing key fail with a KeyError
    # that names the key.
    stats_prefix = 'module.feature_extractor.{}.1.'.format(block_name)
    keras_model.get_layer(index=batch_norm_index).set_weights([
        pytorch_scale,
        pytorch_offset,
        pytorch_state[stats_prefix + 'running_mean'].numpy(),
        pytorch_state[stats_prefix + 'running_var'].numpy(),
    ])

# last convolution: only a kernel and a bias, no batch normalization
first = 4 * len(feature_blocks)
pytorch_kernel, pytorch_bias = [
    weight.detach().numpy() for weight in pytorch_weights[first:first + 2]]
keras_model.get_layer(index=31).set_weights(
    [np.transpose(pytorch_kernel, keras_kernel_axes), pytorch_bias])
# Write the converted weights to an .h5 file in the same directory as the
# PyTorch checkpoint.
keras_weights_file = "./brain_age/run_20190719_00_epoch_best_mae_keras.h5"
keras_model.save_weights(keras_weights_file)