3333
3434import math
3535from collections .abc import Sequence
36+ from functools import reduce
3637from typing import Optional
3738
39+ import numpy as np
3840import torch
3941from torch import nn
4042
@@ -1882,6 +1884,7 @@ class DiffusionModelEncoder(nn.Module):
18821884 spatial_dims: number of spatial dimensions.
18831885 in_channels: number of input channels.
18841886 out_channels: number of output channels.
1887+ input_shape: spatial shape of the input (without batch and channel dims).
18851888 num_res_blocks: number of residual blocks (see _ResnetBlock) per level.
18861889 channels: tuple of block output channels.
18871890 attention_levels: list of levels to add attention.
@@ -1901,6 +1904,7 @@ def __init__(
19011904 spatial_dims : int ,
19021905 in_channels : int ,
19031906 out_channels : int ,
1907+ input_shape : Sequence [int ] = (64 , 64 ),
19041908 num_res_blocks : Sequence [int ] | int = (2 , 2 , 2 , 2 ),
19051909 channels : Sequence [int ] = (32 , 64 , 64 , 64 ),
19061910 attention_levels : Sequence [bool ] = (False , False , True , True ),
@@ -2007,7 +2011,14 @@ def __init__(
20072011
20082012 self .down_blocks .append (down_block )
20092013
2010- self .out : Optional [nn .Module ] = None
2014+ for _ in channels :
2015+ input_shape = [int (np .ceil (i_ / 2 )) for i_ in input_shape ]
2016+
2017+ last_dim_flattened = int (reduce (lambda x , y : x * y , input_shape ) * channels [- 1 ])
2018+
2019+ self .out : Optional [nn .Module ] = nn .Sequential (
2020+ nn .Linear (last_dim_flattened , 512 ), nn .ReLU (), nn .Dropout (0.1 ), nn .Linear (512 , self .out_channels )
2021+ )
20112022
20122023 def forward (
20132024 self ,
@@ -2052,10 +2063,9 @@ def forward(
20522063 h = h .reshape (h .shape [0 ], - 1 )
20532064
20542065 # 5. out
2055- if self .out is None :
2056- self .out = nn .Sequential (
2057- nn .Linear (h .shape [1 ], 512 ), nn .ReLU (), nn .Dropout (0.1 ), nn .Linear (512 , self .out_channels )
2058- )
2066+ self .out = nn .Sequential (
2067+ nn .Linear (h .shape [1 ], 512 ), nn .ReLU (), nn .Dropout (0.1 ), nn .Linear (512 , self .out_channels )
2068+ )
20592069 output : torch .Tensor = self .out (h )
20602070
20612071 return output
0 commit comments