NOTE(review): this patch was reconstructed from a whitespace-collapsed copy.
Context-line indentation and blank context lines are best guesses -- confirm
against the target file. The "+" side no longer matches blob bb8c1ae6: a
docstring was added and the algorithm-name guards now test the returned names.

diff --git a/scripts/recon/train_unrolled.py b/scripts/recon/train_unrolled.py
index 4c3df596..bb8c1ae6 100644
--- a/scripts/recon/train_unrolled.py
+++ b/scripts/recon/train_unrolled.py
@@ -86,6 +86,53 @@ def simulate_dataset(config, psf):
     return ds_loader
 
 
+def create_process_network(network, depth, device="cpu"):
+    """
+    Create an optional pre- or post-processing network.
+
+    Parameters
+    ----------
+    network : str or None
+        "DruNet" or "UnetRes"; any other value creates no network.
+    depth : int
+        Passed to ``UNetRes`` as ``nb``; ignored for "DruNet".
+    device : str
+        Device the created network is moved to.
+
+    Returns
+    -------
+    tuple
+        ``(process, process_name)``, both ``None`` when no network
+        is created.
+    """
+    if network == "DruNet":
+        from lensless.utils.image import load_drunet
+
+        process = load_drunet(
+            os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True
+        ).to(device)
+        process_name = "DruNet"
+    elif network == "UnetRes":
+        from lensless.drunet.network_unet import UNetRes
+
+        n_channels = 3
+        process = UNetRes(
+            in_nc=n_channels + 1,
+            out_nc=n_channels,
+            nc=[64, 128, 256, 512],
+            nb=depth,
+            act_mode="R",
+            downsample_mode="strideconv",
+            upsample_mode="convtranspose",
+        ).to(device)
+        process_name = "UnetRes_d" + str(depth)
+    else:
+        process = None
+        process_name = None
+
+    return (process, process_name)
+
+
 def measure_gradient(model):
     # return the L2 norm of the gradient
     total_norm = 0.0
@@ -138,45 +185,20 @@ def train_unrolled(
         save = os.getcwd()
 
     start_time = time.time()
-    # Load post process model
-    if config.reconstruction.post_process.network == "DruNet":
-        from lensless.utils.image import load_drunet
-
-        post_process = load_drunet(
-            os.path.join(get_original_cwd(), "data/drunet_color.pth"), requires_grad=True
-        ).to(device)
-    elif config.reconstruction.post_process.network == "UnetRes":
-        from lensless.drunet.network_unet import UNetRes
-
-        n_channels = 3
-        post_process = UNetRes(
-            in_nc=n_channels + 1,
-            out_nc=n_channels,
-            nc=[64, 128, 256, 512],
-            nb=config.reconstruction.post_process.depth,
-            act_mode="R",
-            downsample_mode="strideconv",
-            upsample_mode="convtranspose",
-        ).to(device)
-    else:
-        post_process = None
-
-    if config.reconstruction.pre_process.network == "UnetRes":
-        from lensless.drunet.network_unet import UNetRes
-
-        n_channels = 3
-        pre_process = UNetRes(
-            in_nc=n_channels + 1,
-            out_nc=n_channels,
-            nc=[64, 128, 256, 512],
-            nb=config.reconstruction.pre_process.depth,
-            act_mode="R",
-            downsample_mode="strideconv",
-            upsample_mode="convtranspose",
-        ).to(device)
-    else:
-        pre_process = None
+    # Load pre process model
+    pre_process, pre_process_name = create_process_network(
+        config.reconstruction.pre_process.network,
+        config.reconstruction.pre_process.depth,
+        device=device,
+    )
+    # Load post process model
+    post_process, post_process_name = create_process_network(
+        config.reconstruction.post_process.network,
+        config.reconstruction.post_process.depth,
+        device=device,
+    )
 
+    # create reconstruction algorithm
     if config.reconstruction.method == "unrolled_fista":
         recon = UnrolledFISTA(
             psf,
@@ -203,6 +225,13 @@ def train_unrolled(
     else:
         raise ValueError(f"{config.reconstruction.method} is not a supported algorithm")
 
+    # build algorithm name from pre/post process names (None means no network was created)
+    algorithm_name = config.reconstruction.method
+    if pre_process_name is not None:
+        algorithm_name = pre_process_name + "_" + algorithm_name
+    if post_process_name is not None:
+        algorithm_name += "_" + post_process_name
+
     # print number of parameters
     print(f"Training model with {sum(p.numel() for p in recon.parameters())} parameters")
     # transform from BGR to RGB
@@ -283,14 +312,6 @@ def learning_rate_function(epoch):
 
     scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=learning_rate_function)
 
-    # constructing algorithm name by appending pre and post process
-    algorithm = config.reconstruction.method
-    if config.reconstruction.post_process.network == "DruNet":
-        algorithm += "_DruNet"
-    elif config.reconstruction.post_process.network == "UnetRes":
-        algorithm += "_UnetRes"
-    if config.reconstruction.pre_process.network == "UnetRes":
-        algorithm = "PreUnetRes_" + algorithm
     metrics = {
         "LOSS": [],
         "MSE": [],
@@ -301,7 +322,7 @@ def learning_rate_function(epoch):
         "SSIM": [],
         "ReconstructionError": [],
         "n_iter": n_iter,
-        "algorithm": algorithm,
+        "algorithm": algorithm_name,
     }
 
     # Backward hook that detect NAN in the gradient and print the layer weights