diff --git a/Code/autoencoder_model.py b/Code/autoencoder_model.py
index 975d07e..c4e34d2 100644
--- a/Code/autoencoder_model.py
+++ b/Code/autoencoder_model.py
@@ -16,9 +16,9 @@ class Grey2RGBAutoEncoder(nn.Module):
     def __init__(self):
         super(Grey2RGBAutoEncoder, self).__init__()
         # Define the Encoder
-        self.encoder = self._make_layers([1, 64, 128, 256])
+        self.encoder = self._make_layers([1, 8, 16, 32])
         # Define the Decoder
-        self.decoder = self._make_layers([256, 128, 64, 3], decoder=True)
+        self.decoder = self._make_layers([32, 16, 8, 3], decoder=True)
 
     # Helper function to create the encoder or decoder layers.
     def _make_layers(self, channels, decoder=False):
diff --git a/Code/main.py b/Code/main.py
index 79e7ea7..883d85b 100644
--- a/Code/main.py
+++ b/Code/main.py
@@ -10,52 +10,82 @@ from losses import LossMSE, LossMEP, SSIMLoss
 from training import Trainer
-
 # Import Necessary Libraries
 import os
 import traceback
 import torch
+import torch.multiprocessing as mp
+import torch.distributed as dist
+import platform
 
 # Define Working Directories
 grayscale_dir = '../Dataset/Greyscale'
 rgb_dir = '../Dataset/RGB'
 
 # Define Universal Parameters
-image_height = 400
-image_width = 600
+image_height = 4000
+image_width = 6000
 batch_size = 2
 
-
-def main():
+def get_backend():
+    system_type = platform.system()
+    if system_type == "Linux":
+        return "nccl"
+    else:
+        return "gloo"
+
+def main_worker(rank, world_size):
+    # Set environment variables
+    os.environ['MASTER_ADDR'] = 'localhost'
+    os.environ['MASTER_PORT'] = '12345'
+    # Initialize the distributed environment.
+    torch.manual_seed(0)
+    torch.backends.cudnn.enabled = True
+    torch.backends.cudnn.benchmark = True
+    dist.init_process_group(backend=get_backend(), init_method="env://", world_size=world_size, rank=rank)
+    main(rank) # Call the existing main function.
+
+def main(rank):
     # Initialize Dataset Object (PyTorch Tensors)
     try:
         dataset = CustomDataset(grayscale_dir, rgb_dir, (image_height, image_width), batch_size)
-        print('Importing Dataset Complete.')
+        if rank == 0:
+            print('Importing Dataset Complete.')
     except Exception as e:
-        print(f"Importing Dataset In-Complete : \n{e}")
+        if rank == 0:
+            print(f"Importing Dataset In-Complete : \n{e}")
+    if rank == 0:
+        print('-'*20) # Makes Output Readable
     # Import Loss Functions
     try:
         loss_mse = LossMSE() # Mean Squared Error Loss
         loss_mep = LossMEP(alpha=0.4) # Maximum Entropy Loss
         loss_ssim = SSIMLoss() # Structural Similarity Index Measure Loss
-        print('Importing Loss Functions Complete.')
+        if rank == 0:
+            print('Importing Loss Functions Complete.')
     except Exception as e:
-        print(f"Importing Loss Functions In-Complete : \n{e}")
-    print('-'*20) # Makes Output Readable
+        if rank == 0:
+            print(f"Importing Loss Functions In-Complete : \n{e}")
+    if rank == 0:
+        print('-'*20) # Makes Output Readable
     # Initialize AutoEncoder Model and Import Dataloader (Training, Validation)
     data_autoencoder_train, data_autoencoder_val = dataset.get_autoencoder_batches(val_split=0.2)
-    print('AutoEncoder Model Data Imported.')
+    if rank == 0:
+        print('AutoEncoder Model Data Imported.')
     model_autoencoder = Grey2RGBAutoEncoder()
-    print('AutoEncoder Model Initialized.')
-    print('-'*20) # Makes Output Readable
+    if rank == 0:
+        print('AutoEncoder Model Initialized.')
+        print('-'*20) # Makes Output Readable
     # Initialize LSTM Model and Import Dataloader (Training, Validation)
     data_lstm_train, data_lstm_val = dataset.get_lstm_batches(val_split=0.25, sequence_length=2)
-    print('LSTM Model Data Imported.')
+    if rank == 0:
+        print('LSTM Model Data Imported.')
     model_lstm = ConvLSTM(input_dim=1, hidden_dims=[1,1,1], kernel_size=(3, 3), num_layers=3, alpha=0.5)
-    print('LSTM Model Initialized.')
-    print('-'*20) # Makes Output Readable
+    if rank == 0:
+        print('LSTM Model Initialized.')
+        print('-'*20) # Makes Output Readable
 
     '''
     Initialize Trainer Objects
@@ -63,34 +93,55 @@ def main():
     # Method 1 : Baseline : Mean Squared Error Loss for AutoEncoder and LSTM
     os.makedirs('../Models/Method1', exist_ok=True) # Creating Directory for Model Saving
     model_save_path_ae = '../Models/Method1/model_autoencoder_m1.pth'
-    trainer_autoencoder_baseline = Trainer(model_autoencoder, loss_mse, optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), model_save_path=model_save_path_ae)
-    print('Method-1 AutoEncoder Trainer Initialized.')
+    trainer_autoencoder_baseline = Trainer(model=model_autoencoder,
+                                           loss_function=loss_mse,
+                                           optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001),
+                                           model_save_path=model_save_path_ae,
+                                           rank=rank)
+    if rank == 0:
+        print('Method-1 AutoEncoder Trainer Initialized.')
     model_save_path_lstm = '../Models/Method1/model_lstm_m1.pth'
-    trainer_lstm_baseline = Trainer(model_lstm, loss_mse, optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001), model_save_path=model_save_path_lstm)
-    print('Method-1 LSTM Trainer Initialized.')
-    print('-'*10) # Makes Output Readable
+    trainer_lstm_baseline = Trainer(model=model_lstm,
+                                    loss_function=loss_mse,
+                                    optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001),
+                                    model_save_path=model_save_path_lstm,
+                                    rank=rank)
+    if rank == 0:
+        print('Method-1 LSTM Trainer Initialized.')
+        print('-'*10) # Makes Output Readable
     # Method 2 : Composite Loss (MSE + MaxEnt) for AutoEncoder and Mean Squared Error Loss for LSTM
     os.makedirs('../Models/Method2', exist_ok=True) # Creating Directory for Model Saving
     model_save_path_ae = '../Models/Method2/model_autoencoder_m2.pth'
-    trainer_autoencoder_m2 = Trainer(model=model_autoencoder, loss_function=loss_mep, optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001), model_save_path=model_save_path_ae)
-    print('Method-2 AutoEncoder Trainer Initialized.')
-    print('Method-2 LSTM == Method-1 LSTM')
-    print('-'*10) # Makes Output Readable
+    trainer_autoencoder_m2 = Trainer(model=model_autoencoder,
+                                     loss_function=loss_mep,
+                                     optimizer=torch.optim.Adam(model_autoencoder.parameters(), lr=0.001),
+                                     model_save_path=model_save_path_ae,
+                                     rank=rank)
+    if rank == 0:
+        print('Method-2 AutoEncoder Trainer Initialized.')
+        print('Method-2 LSTM == Method-1 LSTM')
+        print('-'*10) # Makes Output Readable
     # Method 3 : Mean Squared Error Loss for AutoEncoder and SSIM Loss for LSTM
     os.makedirs('../Models/Method3', exist_ok=True) # Creating Directory for Model Saving
-    print('Method-3 AutoEncoder == Method-1 AutoEncoder')
+    if rank == 0:
+        print('Method-3 AutoEncoder == Method-1 AutoEncoder')
     model_save_path_lstm = '../Models/Method3/model_lstm_m3.pth'
-    trainer_lstm_m3 = Trainer(model_lstm, loss_ssim, optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001), model_save_path=model_save_path_lstm)
-    print('Method-3 LSTM Trainer Initialized.')
-    print('-'*10) # Makes Output Readable
+    trainer_lstm_m3 = Trainer(model=model_lstm,
+                              loss_function=loss_ssim,
+                              optimizer=torch.optim.Adam(model_lstm.parameters(), lr=0.001),
+                              model_save_path=model_save_path_lstm,
+                              rank=rank)
+    if rank == 0:
+        print('Method-3 LSTM Trainer Initialized.')
+        print('-'*10) # Makes Output Readable
     # Method 4 : Proposed Method : Composite Loss (MSE + MaxEnt) for AutoEncoder and SSIM Loss for LSTM
-    print('Method-4 AutoEncoder == Method-2 AutoEncoder')
-    print('Method-4 LSTM == Method-3 LSTM')
-
-    print('-'*20) # Makes Output Readable
+    if rank == 0:
+        print('Method-4 AutoEncoder == Method-2 AutoEncoder')
+        print('Method-4 LSTM == Method-3 LSTM')
+        print('-'*20) # Makes Output Readable
 
     '''
@@ -99,55 +150,84 @@ def main():
     # Method-1
     try:
         epochs = 1
-        print('Method-1 AutoEncoder Training Start')
+        if rank == 0:
+            print('Method-1 AutoEncoder Training Start')
         model_autoencoder_m1 = trainer_autoencoder_baseline.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val)
-        print('Method-1 AutoEncoder Training Complete.')
+        if rank == 0:
+            print('Method-1 AutoEncoder Training Complete.')
     except Exception as e:
-        print(f"Method-1 AutoEncoder Training Error : \n{e}")
+        if rank == 0:
+            print(f"Method-1 AutoEncoder Training Error : \n{e}")
         traceback.print_exc()
-    print('-'*10) # Makes Output Readable
+    finally:
+        if rank == 0:
+            trainer_autoencoder_baseline.cleanup_ddp()
+    if rank == 0:
+        print('-'*10) # Makes Output Readable
     try:
         epochs = 1
-        print('Method-1 LSTM Training Start')
+        if rank == 0:
+            print('Method-1 LSTM Training Start')
         model_lstm_m1 = trainer_lstm_baseline.train_lstm(epochs, data_lstm_train, data_lstm_val)
-        print('Method-1 LSTM Training Complete.')
+        if rank == 0:
+            print('Method-1 LSTM Training Complete.')
     except Exception as e:
-        print(f"Method-1 LSTM Training Error : \n{e}")
+        if rank == 0:
+            print(f"Method-1 LSTM Training Error : \n{e}")
         traceback.print_exc()
-    print('-'*20) # Makes Output Readable
+    finally:
+        if rank == 0:
+            trainer_lstm_baseline.cleanup_ddp()
+    if rank == 0:
+        print('-'*20) # Makes Output Readable
     # Method-2
     try:
         epochs = 1
-        print('Method-2 AutoEncoder Training Start')
+        if rank == 0:
+            print('Method-2 AutoEncoder Training Start')
         model_autoencoder_m2 = trainer_autoencoder_m2.train_autoencoder(epochs, data_autoencoder_train, data_autoencoder_val)
-        print('Method-2 AutoEncoder Training Complete.')
+        if rank == 0:
+            print('Method-2 AutoEncoder Training Complete.')
     except Exception as e:
-        print(f"Method-2 AutoEncoder Training Error : \n{e}")
+        if rank == 0:
+            print(f"Method-2 AutoEncoder Training Error : \n{e}")
         traceback.print_exc()
-    print('-'*10) # Makes Output Readable
-    print("Method-2 LSTM == Method-1 LSTM, No Need To Train Again.")
-    print('-'*20) # Makes Output Readable
+    finally:
+        trainer_autoencoder_m2.cleanup_ddp()
+    if rank == 0:
+        print('-'*10) # Makes Output Readable
+        print("Method-2 LSTM == Method-1 LSTM, No Need To Train Again.")
+        print('-'*20) # Makes Output Readable
     # Method-3
-    print("Method-3 AutoEncoder == Method-1 AutoEncoder, No Need To Train Again.")
-    print('-'*10) # Makes Output Readable
+    if rank == 0:
+        print("Method-3 AutoEncoder == Method-1 AutoEncoder, No Need To Train Again.")
+        print('-'*10) # Makes Output Readable
     try:
         epochs = 1
-        print('Method-3 LSTM Training Start.')
+        if rank == 0:
+            print('Method-3 LSTM Training Start.')
         model_lstm_m3 = trainer_lstm_m3.train_lstm(epochs, data_lstm_train, data_lstm_val)
-        print('Method-3 LSTM Training Complete.')
+        if rank == 0:
+            print('Method-3 LSTM Training Complete.')
     except Exception as e:
-        print(f"Method-3 LSTM Training Error : \n{e}")
+        if rank == 0:
+            print(f"Method-3 LSTM Training Error : \n{e}")
         traceback.print_exc()
-    print('-'*20) # Makes Output Readable
+    finally:
+        trainer_lstm_m3.cleanup_ddp()
+    if rank == 0:
+        print('-'*20) # Makes Output Readable
     # Method-4
-    print("Method-4 AutoEncoder == Method-2 AutoEncoder, No Need To Train Again.")
-    print('-'*10) # Makes Output Readable
-    print("Method-4 LSTM == Method-3 LSTM, No Need To Train Again.")
-    print('-'*20) # Makes Output Readable
+    if rank == 0:
+        print("Method-4 AutoEncoder == Method-2 AutoEncoder, No Need To Train Again.")
+        print('-'*10) # Makes Output Readable
+        print("Method-4 LSTM == Method-3 LSTM, No Need To Train Again.")
+        print('-'*20) # Makes Output Readable
 
 if __name__ == '__main__':
-    main()
+    world_size = torch.cuda.device_count() # Number of available GPUs
+    mp.spawn(main_worker, args=(world_size,), nprocs=world_size, join=True)
\ No newline at end of file
diff --git a/Code/training.py b/Code/training.py
index 0639c45..373eb6e 100644
--- a/Code/training.py
+++ b/Code/training.py
@@ -10,30 +10,37 @@
 # Import Necessary Libraries
 import torch
-import torch.nn as nn
+from torch.nn.parallel import DistributedDataParallel as DDP
+import torch.distributed as dist
 
 # Define Training Class
 class Trainer():
-    def __init__(self, model, loss_function, optimizer=None, model_save_path=None):
-        # Use All Available CUDA GPUs for Training (if Available)
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if torch.cuda.device_count() > 1:
-            model = nn.DataParallel(model)
+    def __init__(self, model, loss_function, optimizer=None, model_save_path=None, rank=None):
+        self.rank = rank # Rank of the current process
+        self.device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
         self.model = model.to(self.device)
         # Define the loss function
         self.loss_function = loss_function
         # Define the optimizer
         self.optimizer = optimizer if optimizer is not None else torch.optim.Adam(self.model.parameters(), lr=0.001)
+        # Wrap model with DDP
+        if torch.cuda.device_count() > 1 and rank is not None:
+            self.model = DDP(self.model, device_ids=[rank], find_unused_parameters=True)
         # Define the path to save the model
-        self.model_save_path = model_save_path
+        self.model_save_path = model_save_path if rank == 0 else None # Only save on master process
+
+    def cleanup_ddp(self):
+        if dist.is_initialized():
+            dist.destroy_process_group()
 
     def save_model(self):
-        # Save the model
-        torch.save(self.model.state_dict(), self.model_save_path)
+        if self.rank == 0:
+            # Save the model
+            torch.save(self.model.state_dict(), self.model_save_path)
 
     def train_autoencoder(self, epochs, train_loader, val_loader):
         # Print Names of All Available GPUs (if any) to Train the Model
-        if torch.cuda.device_count() > 0:
+        if torch.cuda.device_count() > 0 and self.rank == 0:
             gpu_names = ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
             print("\tGPUs being used for Training : ",gpu_names)
         best_val_loss = float('inf')
@@ -54,7 +61,8 @@ def train_autoencoder(self, epochs, train_loader, val_loader):
             val_loss = sum(self.loss_function(self.model(input.to(self.device)), target.to(self.device)).item() for input, target in val_loader) # Compute Total Validation Loss
             val_loss /= len(val_loader) # Compute Average Validation Loss
             # Print epochs and losses
-            print(f'\tAutoEncoder Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
+            if self.rank == 0:
+                print(f'\tAutoEncoder Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
             # If the current validation loss is lower than the best validation loss, save the model
             if val_loss < best_val_loss:
                 best_val_loss = val_loss # Update the best validation loss
@@ -64,7 +72,7 @@ def train_autoencoder(self, epochs, train_loader, val_loader):
 
     def train_lstm(self, epochs, train_loader, val_loader):
         # Print Names of All Available GPUs (if any) to Train the Model
-        if torch.cuda.device_count() > 0:
+        if torch.cuda.device_count() > 0 and self.rank == 0:
             gpu_names = ', '.join([torch.cuda.get_device_name(i) for i in range(torch.cuda.device_count())])
             print("\tGPUs being used for Training : ",gpu_names)
         best_val_loss = float('inf')
@@ -88,7 +96,8 @@ def train_lstm(self, epochs, train_loader, val_loader):
                 val_loss += self.loss_function(output_sequence, target_sequence).item() # Accumulate loss
             val_loss /= len(val_loader) # Average validation loss
             # Print epochs and losses
-            print(f'\tLSTM Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
+            if self.rank == 0:
+                print(f'\tLSTM Epoch {epoch+1}/{epochs} --- Training Loss: {loss.item()} --- Validation Loss: {val_loss}')
             # Model saving based on validation loss
             if val_loss < best_val_loss:
                 best_val_loss = val_loss
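
For reference, the spawn/init/teardown flow this change adds to main.py and training.py follows the standard torch.multiprocessing + DistributedDataParallel recipe. Below is a minimal, self-contained sketch of that recipe that can be run on its own to smoke-test a distributed setup. It is illustrative only and not part of the change: the worker function, the nn.Linear stand-in model, port 12355, and the extra CUDA guard in the backend choice are all assumptions of the sketch.

# ddp_smoke_test.py - minimal, self-contained sketch of the spawn/init/teardown
# pattern used in this change. All names here (worker, the nn.Linear stand-in,
# port 12355) are illustrative placeholders, not part of the repository.
import os
import platform

import torch
import torch.distributed as dist
import torch.multiprocessing as mp
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP

def get_backend():
    # Same heuristic as main.py, plus a CUDA check (an assumption of this
    # sketch) so CPU-only Linux machines fall back to gloo instead of nccl.
    if platform.system() == "Linux" and torch.cuda.is_available():
        return "nccl"
    return "gloo"

def worker(rank, world_size):
    os.environ['MASTER_ADDR'] = 'localhost'
    os.environ['MASTER_PORT'] = '12355'  # placeholder port
    dist.init_process_group(backend=get_backend(), init_method="env://",
                            world_size=world_size, rank=rank)
    if torch.cuda.is_available():
        torch.cuda.set_device(rank)  # one GPU per spawned process
    device = torch.device(f'cuda:{rank}' if torch.cuda.is_available() else 'cpu')
    model = nn.Linear(10, 10).to(device)  # stand-in for the real models
    ddp_model = DDP(model, device_ids=[rank] if device.type == 'cuda' else None)
    loss = ddp_model(torch.randn(4, 10, device=device)).sum()
    loss.backward()  # DDP all-reduces the gradients across ranks here
    dist.destroy_process_group()  # teardown: every rank must reach this call

if __name__ == '__main__':
    world_size = max(torch.cuda.device_count(), 1)  # at least one CPU process
    mp.spawn(worker, args=(world_size,), nprocs=world_size, join=True)

One API detail relevant when reading the cleanup_ddp calls above: dist.destroy_process_group() must be reached by every rank, and once it runs the process group is gone for good, so no later DDP collective can run in that worker. A single init and a single teardown per spawned worker is the usual arrangement.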