monai/networks/nets/diffusion_model_unet.py: 15 additions & 5 deletions
@@ -33,8 +33,10 @@
 
 import math
 from collections.abc import Sequence
+from functools import reduce
 from typing import Optional
 
+import numpy as np
 import torch
 from torch import nn
 
@@ -1882,6 +1884,7 @@ class DiffusionModelEncoder(nn.Module):
         spatial_dims: number of spatial dimensions.
         in_channels: number of input channels.
         out_channels: number of output channels.
+        input_shape: spatial shape of the input (without batch and channel dims).
         num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
         channels: tuple of block output channels.
         attention_levels: list of levels to add attention.
@@ -1901,6 +1904,7 @@ def __init__(
         spatial_dims: int,
         in_channels: int,
         out_channels: int,
+        input_shape: Sequence[int] = (64, 64),
         num_res_blocks: Sequence[int] | int = (2, 2, 2, 2),
         channels: Sequence[int] = (32, 64, 64, 64),
         attention_levels: Sequence[bool] = (False, False, True, True),
@@ -2007,7 +2011,14 @@ def __init__(
 
             self.down_blocks.append(down_block)
 
-        self.out: Optional[nn.Module] = None
+        for _ in channels:
+            input_shape = [int(np.ceil(i_ / 2)) for i_ in input_shape]
+
+        last_dim_flattened = int(reduce(lambda x, y: x * y, input_shape) * channels[-1])
+
+        self.out: Optional[nn.Module] = nn.Sequential(
+            nn.Linear(last_dim_flattened, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
+        )
 
     def forward(
         self,
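As a quick sanity check of the size computation in the hunk above (an illustrative aside, not part of the patch): with the defaults input_shape=(64, 64) and channels=(32, 64, 64, 64), each of the four levels halves the spatial shape with ceiling division, (64, 64) → (32, 32) → (16, 16) → (8, 8) → (4, 4), so last_dim_flattened = 4 * 4 * 64 = 1024.

    from functools import reduce

    import numpy as np

    input_shape = (64, 64)       # default spatial shape
    channels = (32, 64, 64, 64)  # default block output channels

    # mirror the loop added to __init__: one ceil-halving per level
    for _ in channels:
        input_shape = [int(np.ceil(i_ / 2)) for i_ in input_shape]

    assert input_shape == [4, 4]
    assert int(reduce(lambda x, y: x * y, input_shape) * channels[-1]) == 1024  # features into the first Linear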
@@ -2052,10 +2063,9 @@ def forward(
         h = h.reshape(h.shape[0], -1)
 
         # 5. out
-        if self.out is None:
-            self.out = nn.Sequential(
-                nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
-            )
+        self.out = nn.Sequential(
+            nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
+        )
         output: torch.Tensor = self.out(h)
 
         return output
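For context, a minimal usage sketch of the updated constructor (an illustration under assumptions, not code from this PR: it assumes DiffusionModelEncoder is exported from monai.networks.nets, and that input_shape must match the spatial size of the tensor later passed to forward):

    import torch

    from monai.networks.nets import DiffusionModelEncoder

    # hypothetical 2-class classifier over single-channel 64x64 images
    encoder = DiffusionModelEncoder(
        spatial_dims=2,
        in_channels=1,
        out_channels=2,
        input_shape=(64, 64),  # must match the spatial dims of x below
    )

    x = torch.randn(8, 1, 64, 64)                    # (batch, channels, H, W)
    timesteps = torch.randint(0, 1000, (8,)).long()  # one diffusion timestep per sample
    logits = encoder(x, timesteps)                   # (8, 2): Linear(512, out_channels) output

Sizing the first Linear from input_shape in __init__ means the layer can be constructed before any data is seen, rather than inferred from h.shape[1] at call time.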