Skip to content

Commit

Permalink
update decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
brianhou0208 committed Dec 7, 2024
1 parent 8d15c28 commit aa1e005
Showing 1 changed file with 27 additions and 14 deletions.
41 changes: 27 additions & 14 deletions segmentation_models_pytorch/decoders/pan/decoder.py
Original file line number Diff line number Diff line change
Expand Up @@ -195,28 +195,41 @@ def __init__(
)
)

encoder_channels = encoder_channels[2:][::-1]
encoder_channels = encoder_channels[2:]

self.fpa = FPABlock(
in_channels=encoder_channels[0], out_channels=decoder_channels
in_channels=encoder_channels[-1], out_channels=decoder_channels
)

for i in range(1, len(encoder_channels)):
self.add_module(
f"gau{len(encoder_channels)-i}",
GAUBlock(
in_channels=encoder_channels[i],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
),
if encoder_depth == 5:
self.gau3 = GAUBlock(
in_channels=encoder_channels[2],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)
if encoder_depth >= 4:
self.gau2 = GAUBlock(
in_channels=encoder_channels[1],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)
if encoder_depth >= 3:
self.gau1 = GAUBlock(
in_channels=encoder_channels[0],
out_channels=decoder_channels,
upscale_mode=upscale_mode,
)

def forward(self, *features):
features = features[2:] # remove first and second skip
features = features[::-1] # reverse channels to start from head of encoder

out = self.fpa(features[0])
out = self.fpa(features[-1]) # 1/16 or 1/32

if hasattr(self, "gau3"):
out = self.gau3(features[2], out) # 1/16
if hasattr(self, "gau2"):
out = self.gau2(features[1], out) # 1/8
if hasattr(self, "gau1"):
out = self.gau1(features[0], out) # 1/4

for i in range(1, len(features)):
out = getattr(self, f"gau{len(features)-i}")(features[i], out)
return out

0 comments on commit aa1e005

Please sign in to comment.