Skip to content

Commit

Permalink
ACLSD model with 2 UNets
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Nov 7, 2023
1 parent f6a4b94 commit e12ea4b
Showing 1 changed file with 30 additions and 7 deletions.
37 changes: 30 additions & 7 deletions src/raygun/torch/models/ACLSDModel.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# from funlib.learn.torch.models import UNet, ConvPass
from raygun.torch.networks import UNet
from raygun.torch.networks.UNet import ConvPass
import torch
Expand All @@ -11,7 +12,15 @@
class ACLSDModel(torch.nn.Module):
def __init__(
self,
unet_kwargs={
mt_unet_kwargs={
"input_nc": 1,
"ngf": 12,
"fmap_inc_factor": 6,
"num_heads": 2,
"downsample_factors": [(2, 2, 2), (2, 2, 2), (2, 2, 2)],
"constant_upsample": True,
},
ac_unet_kwargs={
"input_nc": 1,
"ngf": 12,
"fmap_inc_factor": 6,
Expand All @@ -22,13 +31,22 @@ def __init__(
):
super().__init__()

self.unet = UNet(**unet_kwargs)
self.mt_unet = UNet(**mt_unet_kwargs)
self.ac_unet = UNet(**ac_unet_kwargs)

self.aff_head = ConvPass(
unet_kwargs["ngf"], num_affs, [[1, 1, 1]], activation="Sigmoid"
mt_unet_kwargs["ngf"], num_affs, [[1, 1, 1]], activation="Sigmoid"
)

self.lsd_head = ConvPass( # TODO: Make work without LSD
mt_unet_kwargs["ngf"], 10, [[1, 1, 1]], activation="Sigmoid"
)

self.ac_aff_head = ConvPass(
ac_unet_kwargs["ngf"], num_affs, [[1, 1, 1]], activation="Sigmoid"
)

self.output_arrays = ["pred_affs"]
self.output_arrays = ["pred_affs", "pred_lsds", "pred_affs_ac"]
self.data_dict = {}

def add_log(self, writer, step):
Expand All @@ -51,7 +69,12 @@ def add_log(self, writer, step):

def forward(self, raw):
self.data_dict.update({"raw": raw.detach()})
z = self.unet(raw)
affs = self.aff_head(z)
a = self.mt_unet(raw)
# conv passes for MTLSD
affs = self.aff_head(a)
lsds = self.lsd_head(a)
b = self.ac_unet(lsds)
# conv pass for ACLSD
affs_ac = self.ac_aff_head(b)

return affs
return affs, lsds, affs_ac

0 comments on commit e12ea4b

Please sign in to comment.