Skip to content

Commit 595ca84

Browse files
committed
fix: dynamic input dim in DiffusionModelEncoder
Signed-off-by: IamTingTing <[email protected]>
1 parent e499362 commit 595ca84

File tree

1 file changed

+8
-1
lines changed

1 file changed

+8
-1
lines changed

monai/networks/nets/diffusion_model_unet.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333

3434
import math
3535
from collections.abc import Sequence
36+
from typing import Optional
3637

3738
import torch
3839
from torch import nn
@@ -2005,7 +2006,7 @@ def __init__(
20052006

20062007
self.down_blocks.append(down_block)
20072008

2008-
self.out = nn.Sequential(nn.Linear(4096, 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels))
2009+
self.out: Optional[nn.Module] = None
20092010

20102011
def forward(
20112012
self,
@@ -2048,6 +2049,12 @@ def forward(
20482049
h, _ = downsample_block(hidden_states=h, temb=emb, context=context)
20492050

20502051
h = h.reshape(h.shape[0], -1)
2052+
2053+
# 5. out
2054+
if self.out is None:
2055+
self.out = nn.Sequential(
2056+
nn.Linear(h.shape[1], 512), nn.ReLU(), nn.Dropout(0.1), nn.Linear(512, self.out_channels)
2057+
)
20512058
output: torch.Tensor = self.out(h)
20522059

20532060
return output

0 commit comments

Comments
 (0)