Cleanup process creation in training
YohannPerron committed Jul 20, 2023
1 parent ed6d939 commit e317676
Showing 1 changed file with 50 additions and 47 deletions.
scripts/recon/train_unrolled.py (97 changes: 50 additions & 47 deletions)
@@ -86,6 +86,35 @@ def simulate_dataset(config, psf):
return ds_loader


def create_process_network(network, depth, device="cpu"):
if network == "DruNet":
from lensless.utils.image import load_drunet

process = load_drunet(
os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True
).to(device)
process_name = "DruNet"
elif network == "UnetRes":
from lensless.drunet.network_unet import UNetRes

n_channels = 3
process = UNetRes(
in_nc=n_channels + 1,
out_nc=n_channels,
nc=[64, 128, 256, 512],
nb=depth,
act_mode="R",
downsample_mode="strideconv",
upsample_mode="convtranspose",
).to(device)
process_name = "UnetRes_d" + str(depth)
else:
process = None
process_name = None

return (process, process_name)


def measure_gradient(model):
# return the L2 norm of the gradient
total_norm = 0.0
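
The hunk above factors the duplicated model-construction code into a single create_process_network helper. As a rough usage sketch (the "UnetRes" choice, depth, and device below are hypothetical examples, not values taken from this commit):

# Illustrative sketch only, not part of the committed file.
# Build a trainable UnetRes denoiser of depth 2 on the GPU and keep the
# generated name for logging; an unrecognized or None network returns (None, None).
process, process_name = create_process_network("UnetRes", depth=2, device="cuda")
print(process_name)  # "UnetRes_d2"
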
@@ -138,45 +167,20 @@ def train_unrolled(
save = os.getcwd()

start_time = time.time()
# Load post process model
if config.reconstruction.post_process.network == "DruNet":
from lensless.utils.image import load_drunet

post_process = load_drunet(
os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True
).to(device)
elif config.reconstruction.post_process.network == "UnetRes":
from lensless.drunet.network_unet import UNetRes

n_channels = 3
post_process = UNetRes(
in_nc=n_channels + 1,
out_nc=n_channels,
nc=[64, 128, 256, 512],
nb=config.reconstruction.post_process.depth,
act_mode="R",
downsample_mode="strideconv",
upsample_mode="convtranspose",
).to(device)
else:
post_process = None

if config.reconstruction.pre_process.network == "UnetRes":
from lensless.drunet.network_unet import UNetRes

n_channels = 3
pre_process = UNetRes(
in_nc=n_channels + 1,
out_nc=n_channels,
nc=[64, 128, 256, 512],
nb=config.reconstruction.pre_process.depth,
act_mode="R",
downsample_mode="strideconv",
upsample_mode="convtranspose",
).to(device)
else:
pre_process = None

# Load pre process model
pre_process, pre_process_name = create_process_network(
config.reconstruction.pre_process.network,
config.reconstruction.pre_process.depth,
device=device,
)
# Load post process model
post_process, post_process_name = create_process_network(
config.reconstruction.post_process.network,
config.reconstruction.post_process.depth,
device=device,
)
# create reconstruction algorithm
if config.reconstruction.method == "unrolled_fista":
recon = UnrolledFISTA(
psf,
@@ -203,6 +207,13 @@
else:
raise ValueError(f"{config.reconstruction.method} is not a supported algorithm")

# constructing algorithm name by appending pre and post process
algorithm_name = config.reconstruction.method
if config.reconstruction.pre_process.network is not None:
algorithm_name = pre_process_name + "_" + algorithm_name
if config.reconstruction.post_process.network is not None:
algorithm_name += "_" + post_process_name

# print number of parameters
print(f"Training model with {sum(p.numel() for p in recon.parameters())} parameters")
# transform from BGR to RGB
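
To make the renaming in the hunk above concrete: for a hypothetical configuration with a depth-2 UnetRes pre-processor, a DruNet post-processor, and the "unrolled_fista" method, the composed name would be as follows (values are illustrative, not from the commit):

# Illustrative only: how algorithm_name is composed for one hypothetical config.
pre_process_name = "UnetRes_d2"   # returned by create_process_network("UnetRes", 2)
post_process_name = "DruNet"      # returned by create_process_network("DruNet", depth)
algorithm_name = "unrolled_fista"
algorithm_name = pre_process_name + "_" + algorithm_name
algorithm_name = algorithm_name + "_" + post_process_name
assert algorithm_name == "UnetRes_d2_unrolled_fista_DruNet"
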
@@ -283,14 +294,6 @@ def learning_rate_function(epoch):

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=learning_rate_function)

# constructing algorithm name by appending pre and post process
algorithm = config.reconstruction.method
if config.reconstruction.post_process.network == "DruNet":
algorithm += "_DruNet"
elif config.reconstruction.post_process.network == "UnetRes":
algorithm += "_UnetRes"
if config.reconstruction.pre_process.network == "UnetRes":
algorithm = "PreUnetRes_" + algorithm
metrics = {
"LOSS": [],
"MSE": [],
@@ -301,7 +304,7 @@ def learning_rate_function(epoch):
"SSIM": [],
"ReconstructionError": [],
"n_iter": n_iter,
"algorithm": algorithm,
"algorithm": algorithm_name,
}

# Backward hook that detect NAN in the gradient and print the layer weights
