diff --git a/monai/networks/nets/diffusion_model_unet.py b/monai/networks/nets/diffusion_model_unet.py
index 113aa505e9..3e496979bf 100644
--- a/monai/networks/nets/diffusion_model_unet.py
+++ b/monai/networks/nets/diffusion_model_unet.py
@@ -33,8 +33,10 @@
 
 import math
 from collections.abc import Sequence
+from functools import reduce
 from typing import Optional
 
+import numpy as np
 import torch
 from torch import nn
 
@@ -1882,6 +1884,7 @@ class DiffusionModelEncoder(nn.Module):
         spatial_dims: number of spatial dimensions.
         in_channels: number of input channels.
         out_channels: number of output channels.
+        input_shape: spatial shape of the input (without batch and channel dims).
         num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
         channels: tuple of block output channels.
         attention_levels: list of levels to add attention.
@@ -1901,6 +1904,7 @@ def __init__(
         spatial_dims: int,
         in_channels: int,
         out_channels: int,
+        input_shape: Sequence[int] = (64, 64),
         num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
         channels: Sequence[int] = (32, 64, 64, 64),
         attention_levels: Sequence[bool] = (False, False, True, True),
@@ -2007,7 +2011,14 @@ def __init__(
 
             self.down_blocks.append(down_block)
 
-        self.out: Optional[nn.Module] = None
+        for _ in channels:
+            input_shape = [int(np.ceil(i_ / 2)) for i_ in input_shape]
+
+        last_dim_flattened = int(reduce(lambda x, y: x * y, input_shape) * channels[-1])
+
+        self.out: Optional[nn.Module] = nn.Sequential(
+            nn.Linear(last_dim_flattened, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
+        )
 
     def forward(
         self,
@@ -2052,10 +2063,6 @@ def forward(
         h = h.reshape(h.shape[0], -1)
 
         # 5. out
-        if self.out is None:
-            self.out = nn.Sequential(
-                nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
-            )
         output: torch.Tensor = self.out(h)
         return output
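
Building self.out in __init__ rather than on the first forward call matters because a lazily created submodule is missing from state_dict(), from the parameters() handed to an optimizer, and from any .to(device) call made before the model has seen data. Since every encoder level halves the spatial dimensions, the head's input size can be derived from input_shape up front. Below is a minimal sketch of the resulting usage, assuming the patch above is applied; the configuration values are illustrative, not mandated by the patch:

    from functools import reduce

    import numpy as np
    import torch

    from monai.networks.nets import DiffusionModelEncoder

    input_shape = (64, 64)
    channels = (32, 64, 64, 64)

    # The same shape arithmetic as the patched __init__: each level halves the
    # spatial size (ceil division), 64x64 -> 32x32 -> 16x16 -> 8x8 -> 4x4, so
    # the flattened feature size is 4 * 4 * channels[-1] = 1024.
    shape = list(input_shape)
    for _ in channels:
        shape = [int(np.ceil(s / 2)) for s in shape]
    print(reduce(lambda x, y: x * y, shape) * channels[-1])  # 1024

    # The classification head now exists at construction time, so
    # load_state_dict() and .to(device) work before any forward pass.
    encoder = DiffusionModelEncoder(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        input_shape=input_shape,
        channels=channels,
        attention_levels=(False, False, True, True),
    )

    x = torch.randn(1, 1, 64, 64)
    timesteps = torch.randint(0, 1000, (1,))
    print(encoder(x, timesteps).shape)  # torch.Size([1, 2])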