From 51b97ef71631a7c466119bba086381ac91ce2b09 Mon Sep 17 00:00:00 2001 From: gwirn <71886945+gwirn@users.noreply.github.com> Date: Fri, 31 May 2024 10:10:12 +0200 Subject: [PATCH] added formatting to verbose and logfile, fixed multiprocessing porblem and GPU usage in analysis_example --- README.md | 5 + environment.yml | 1 + examples/analysis_example.py | 144 +++++++------ src/molearn/trainers/trainer.py | 372 +++++++++++++++++++------------- 4 files changed, 310 insertions(+), 212 deletions(-) diff --git a/README.md b/README.md index 736c87a..6636a87 100644 --- a/README.md +++ b/README.md @@ -63,6 +63,11 @@ Manual installation requires the following three steps: #### Using molearn without installation #### Molearn can used without installation by making the sure the requirements above are met, and adding the `src` directory to your path at the beginning of every script, e.g.: +installation using conda while creating a new environment `molearn_env` +``` +conda env create --file environment.yml -n molearn_env +``` + ``` import sys sys.path.insert(0, 'path/to/molearn/src') diff --git a/environment.yml b/environment.yml index 565af49..b68e492 100644 --- a/environment.yml +++ b/environment.yml @@ -18,5 +18,6 @@ dependencies: - ipywidgets - plotly - nglview + - openmmtorchplugin - pip: - geomloss diff --git a/examples/analysis_example.py b/examples/analysis_example.py index 7e7c8b8..aec3673 100644 --- a/examples/analysis_example.py +++ b/examples/analysis_example.py @@ -6,69 +6,81 @@ import matplotlib.pyplot as plt -print("> Loading network parameters...") - -fname = f'xbb_foldingnet_checkpoints{os.sep}checkpoint_no_optimizer_state_dict_epoch167_loss0.003259085263643.ckpt' -# change 'cpu' to 'cuda' if you have a suitable cuda enabled device -checkpoint = torch.load(fname, map_location=torch.device('cpu')) -net = AutoEncoder(**checkpoint['network_kwargs']) -net.load_state_dict(checkpoint['model_state_dict']) - -print("> Loading training data...") - -MA = MolearnAnalysis() -MA.set_network(net) - -# increasing the batch size makes encoding/decoding operations faster, -# but more memory demanding -MA.batch_size = 4 - -# increasing processes makes DOPE and Ramachandran scores calculations faster, -# but more more memory demanding -MA.processes = 2 - -# what follows is a method to re-create the training and test set -# by defining the manual see and loading the dataset in the same order as when -#the neural network was trained, the same train-test split will be obtained -data = PDBData() -data.import_pdb(f'data{os.sep}MurD_closed_selection.pdb') -data.import_pdb(f'data{os.sep}MurD_open_selection.pdb') -data.fix_terminal() -data.atomselect(atoms = ['CA', 'C', 'N', 'CB', 'O']) -data.prepare_dataset() -data_train, data_test = data.split(manual_seed=25) - -# store the training and test set in the MolearnAnalysis instance -# the second parameter of the sollowing commands can be both a PDBData instance -# or a path to a multi-PDB file -MA.set_dataset("training", data_train) -MA.set_dataset("test", data_test) - -print("> calculating RMSD of training and test set") - -err_train = MA.get_error('training') -err_test = MA.get_error('test') - -print(f'Mean RMSD is {err_train.mean()} for training set and {err_test.mean()} for test set') -fig, ax = plt.subplots() -violin = ax.violinplot([err_train, err_test], showmeans = True, ) -ax.set_xticks([1,2]) -ax.set_title('RMSD of training and test set') -ax.set_xticklabels(['Training', 'Test']) -plt.savefig('RMSD_plot.png') - - -print("> generating error landscape") -# 
build a 50x50 grid. By default, it will be 10% larger than the region occupied -# by all loaded datasets -MA.setup_grid(50) -landscape_err_latent, landscape_err_3d, xaxis, yaxis = MA.scan_error() - -fig, ax = plt.subplots() -c = ax.pcolormesh(xaxis, yaxis, landscape_err_latent) -plt.savefig('Error_grid.png') - - -## to visualise the GUI, execute the code above in a Jupyter notebook, then call: -# from molearn.analysis import MolearnGUI -# MolearnGUI(MA) \ No newline at end of file +def main(): + print("> Loading network parameters...") + + fname = f"xbb_foldingnet_checkpoints{os.sep}checkpoint_no_optimizer_state_dict_epoch167_loss0.003259085263643.ckpt" + # if GPU is available we will use the GPU else the CPU + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + checkpoint = torch.load(fname, map_location=device) + net = AutoEncoder(**checkpoint["network_kwargs"]) + net.load_state_dict(checkpoint["model_state_dict"]) + if torch.cuda.is_available(): + # otherwise net is still not on the GPU + net.to(device) + + print("> Loading training data...") + + MA = MolearnAnalysis() + MA.set_network(net) + + # increasing the batch size makes encoding/decoding operations faster, + # but more memory demanding + MA.batch_size = 4 + + # increasing processes makes DOPE and Ramachandran scores calculations faster, + # but more more memory demanding + MA.processes = 2 + + # what follows is a method to re-create the training and test set + # by defining the manual see and loading the dataset in the same order as when + # the neural network was trained, the same train-test split will be obtained + data = PDBData() + data.import_pdb(f"data{os.sep}MurD_closed_selection.pdb") + data.import_pdb(f"data{os.sep}MurD_open_selection.pdb") + data.fix_terminal() + data.atomselect(atoms=["CA", "C", "N", "CB", "O"]) + data.prepare_dataset() + data_train, data_test = data.split(manual_seed=25) + + # store the training and test set in the MolearnAnalysis instance + # the second parameter of the sollowing commands can be both a PDBData instance + # or a path to a multi-PDB file + MA.set_dataset("training", data_train) + MA.set_dataset("test", data_test) + + print("> calculating RMSD of training and test set") + + err_train = MA.get_error("training") + err_test = MA.get_error("test") + + print( + f"Mean RMSD is {err_train.mean()} for training set and {err_test.mean()} for test set" + ) + fig, ax = plt.subplots() + _ = ax.violinplot( + [err_train, err_test], + showmeans=True, + ) + ax.set_xticks([1, 2]) + ax.set_title("RMSD of training and test set") + ax.set_xticklabels(["Training", "Test"]) + plt.savefig("RMSD_plot.png") + + print("> generating error landscape") + # build a 50x50 grid. By default, it will be 10% larger than the region occupied + # by all loaded datasets + MA.setup_grid(50) + landscape_err_latent, landscape_err_3d, xaxis, yaxis = MA.scan_error() + + fig, ax = plt.subplots() + _ = ax.pcolormesh(xaxis, yaxis, landscape_err_latent) + plt.savefig("Error_grid.png") + + ## to visualise the GUI, execute the code above in a Jupyter notebook, then call: + # from molearn.analysis import MolearnGUI + # MolearnGUI(MA) + + +if __name__ == "__main__": + main() diff --git a/src/molearn/trainers/trainer.py b/src/molearn/trainers/trainer.py index 866e6a1..d3eec39 100644 --- a/src/molearn/trainers/trainer.py +++ b/src/molearn/trainers/trainer.py @@ -13,7 +13,7 @@ class TrainingFailure(Exception): class Trainer: - ''' + """ Trainer class that defines a number of useful methods for training an autoencoder. 
:ivar autoencoder: any torch.nn.module network that has methods ``autoencoder.encode`` and ``autoencoder.decode`` with the weights associated with these operations accessible via ``autoencoder.encoder`` and ``autoencoder.decoder``. This can be set using set_autoencoder @@ -29,32 +29,41 @@ class Trainer: :ivar torch.Dataloader valid_dataloader: Validation data :ivar _data: (:func:`molearn.data ` Data object given to :func:`set_data ` - ''' - - def __init__(self, device=None, log_filename='log_file.dat'): - ''' + """ + + def __init__(self, device=None, log_filename="log_file.dat"): + """ :param torch.Device device: if not given will be determinined automatically based on torch.cuda.is_available() :param str log_filename: (default: 'default_log_filename.json') file used to log outputs to - ''' + """ if not device: - self.device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu') + self.device = ( + torch.device("cuda") + if torch.cuda.is_available() + else torch.device("cpu") + ) else: self.device = device - print(f'device: {self.device}') + print(f"device: {self.device}") self.best = None self.best_name = None self.epoch = 0 self.scheduler = None self.verbose = True - self.log_filename = 'default_log_filename.json' + self.log_filename = "default_log_filename.json" self.scheduler_key = None def get_network_summary(self): - ''' + """ returns a dictionary containing information about the size of the autoencoder. - ''' + """ + def get_parameters(trainable_only, model): - return sum(p.numel() for p in model.parameters() if (p.requires_grad and trainable_only)) + return sum( + p.numel() + for p in model.parameters() + if (p.requires_grad and trainable_only) + ) return dict( encoder_trainable=get_parameters(True, self.autoencoder.encoder), @@ -62,13 +71,14 @@ def get_parameters(trainable_only, model): decoder_trainable=get_parameters(True, self.autoencoder.decoder), decoder_total=get_parameters(False, self.autoencoder.decoder), autoencoder_trainable=get_parameters(True, self.autoencoder), - autoencoder_total=get_parameters(False, self.autoencoder)) + autoencoder_total=get_parameters(False, self.autoencoder), + ) def set_autoencoder(self, autoencoder, **kwargs): - ''' + """ :param autoencoder: (:func:`autoencoder `,) torch network class that implements ``autoencoder.encode``, and ``autoencoder.decode``. Please pass the class not the instance :param \*\*kwargs: any other kwargs given to this method will be used to initialise the network ``self.autoencoder = autoencoder(**kwargs)`` - ''' + """ if isinstance(autoencoder, type): self.autoencoder = autoencoder(**kwargs).to(self.device) else: @@ -76,35 +86,37 @@ def set_autoencoder(self, autoencoder, **kwargs): self._autoencoder_kwargs = kwargs def set_dataloader(self, train_dataloader=None, valid_dataloader=None): - ''' + """ :param torch.DataLoader train_dataloader: Alternatively set using ``trainer.train_dataloader = dataloader`` - :param torch.DataLoader valid_dataloader: Alternatively set using ``trainer.valid_dataloader = dataloader`` - ''' + :param torch.DataLoader valid_dataloader: Alternatively set using ``trainer.valid_dataloader = dataloader`` + """ if train_dataloader is not None: self.train_dataloader = train_dataloader if valid_dataloader is not None: self.valid_dataloader = valid_dataloader def set_data(self, data, **kwargs): - ''' + """ Sets up internal variables and gives trainer access to dataloaders. 
``self.train_dataloader``, ``self.valid_dataloader``, ``self.std``, ``self.mean``, ``self.mol`` will all be obtained from this object. :param :func:`PDBData ` data: data object to be set. :param \*\*kwargs: will be passed on to :func:`data.get_dataloader(**kwargs) ` - ''' + """ if isinstance(data, PDBData): self.set_dataloader(*data.get_dataloader(**kwargs)) else: - raise NotImplementedError('Have not implemented this method to use any data other than PDBData yet') + raise NotImplementedError( + "Have not implemented this method to use any data other than PDBData yet" + ) self.std = data.std self.mean = data.mean self.mol = data.mol self._data = data def prepare_optimiser(self, lr=1e-3, weight_decay=0.0001, **optimiser_kwargs): - ''' + """ The Default optimiser is ``AdamW`` and is saved in ``self.optimiser``. With no optional arguments this function is the same as doing: ``trainer.optimiser = torch.optim.AdawW(self.autoencoder.parameters(), lr=1e-3, weight_decay = 0.0001)`` @@ -112,30 +124,52 @@ def prepare_optimiser(self, lr=1e-3, weight_decay=0.0001, **optimiser_kwargs): :param float lr: (default: 1e-3) optimiser learning rate. :param float weight_decay: (default: 0.0001) optimiser weight_decay :param \*\*optimiser_kwargs: other kwargs that are passed onto AdamW - ''' - self.optimiser = torch.optim.AdamW(self.autoencoder.parameters(), lr=lr, weight_decay=weight_decay, **optimiser_kwargs) + """ + self.optimiser = torch.optim.AdamW( + self.autoencoder.parameters(), + lr=lr, + weight_decay=weight_decay, + **optimiser_kwargs, + ) def log(self, log_dict, verbose=None): - ''' + """ Then contents of log_dict are dumped using ``json.dumps(log_dict)`` and printed and/or appended to ``self.log_filename`` This function is called from :func:`self.run ` :param dict log_dict: dictionary to be printed or saved :param bool verbose: (default: False) if True or self.verbose is true the output will be printed - ''' + """ - dump = json.dumps(log_dict) if verbose or self.verbose: - print(dump) - with open(self.log_filename, 'a') as f: - f.write(dump+'\n') + max_key_len = max([len(k) for k in log_dict.keys()]) + if "epoch" in log_dict: + cur_epoch = log_dict["epoch"] + print(f"{'epoch': <{max_key_len+1}}: {cur_epoch}") + for k, v in log_dict.items(): + if k != "epoch": + print(f"{k: <{max_key_len+1}}: {v:.6f}") + print() + + # create header if file doesn't exist => first epoch + if not os.path.isfile(self.log_filename): + with open(self.log_filename, "a") as f: + f.write(f"{','.join([str(k) for k in log_dict.keys()])}\n") + + with open(self.log_filename, "a") as f: + # just try to format if it is not a Failure + if "Failure" not in log_dict.values(): + f.write(f"{','.join([str(v) for v in log_dict.values()])}\n") + else: + dump = json.dumps(log_dict) + f.write(dump + "\n") def scheduler_step(self, logs): - ''' + """ This function does nothing. It is called after :func:`self.valid_epoch ` in :func:`Trainer.run() ` and before :func:`checkpointing `. It is designed to be overridden if you wish to use a scheduler. - :param dict logs: Dictionary passed passed containing all logs returned from ``self.train_epoch`` and ``self.valid_epoch``. - ''' + :param dict logs: Dictionary passed passed containing all logs returned from ``self.train_epoch`` and ``self.valid_epoch``. 
+ """ pass def prepare_logs(self, log_filename, log_folder=None): @@ -143,14 +177,23 @@ def prepare_logs(self, log_filename, log_folder=None): if log_folder is not None: if not os.path.exists(log_folder): os.mkdir(log_folder) - if hasattr(self, "_repeat") and self._repeat >0: - self.log_filename = f'{log_folder}/{self._repeat}_{self.log_filename}' + if hasattr(self, "_repeat") and self._repeat > 0: + self.log_filename = f"{log_folder}/{self._repeat}_{self.log_filename}" else: - self.log_filename = f'{log_folder}/{self.log_filename}' - - - def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_frequency=1, checkpoint_folder='checkpoint_folder', allow_n_failures=10, verbose=None, allow_grad_in_valid=False): - ''' + self.log_filename = f"{log_folder}/{self.log_filename}" + + def run( + self, + max_epochs=100, + log_filename=None, + log_folder=None, + checkpoint_frequency=1, + checkpoint_folder="checkpoint_folder", + allow_n_failures=10, + verbose=None, + allow_grad_in_valid=False, + ): + """ Calls the following in a loop: - :func:`Trainer.train_epoch ` @@ -161,17 +204,19 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre - :func:`Trainer.log ` :param int max_epochs: (default: 100). run until ``self.epoch`` matches max_epochs - :param str log_filename: (default: None) If log_filename already exists, all logs are appended to the existing file. Else new log file file is created. + :param str log_filename: (default: None) If log_filename already exists, all logs are appended to the existing file. Else new log file file is created. :param str log_folder: (default: None) If not None log_folder directory is created and the log file is saved within this folder :param int checkpoint_frequency: (default: 1) The frequency at which last.ckpt is saved. A checkpoint is saved every epoch if ``'valid_loss'`` is lower else when ``self.epoch`` is divisible by checkpoint_frequency. :param str checkpoint_folder: (default: 'checkpoint_folder') Where to save checkpoints. :param int allow_n_failures: (default: 10) How many times should training be restarted on error. Each epoch is run in a try except block. If an error is raised training is continued from the best checkpoint. - :param bool verbose: (default: None) set trainer.verbose. If True, the epoch logs will be printed as well as written to log_filename + :param bool verbose: (default: None) set trainer.verbose. 
If True, the epoch logs will be printed as well as written to log_filename - ''' + """ self.get_repeat(checkpoint_folder) - self.prepare_logs(log_filename if log_filename is not None else self.log_filename, log_folder) - #if log_filename is not None: + self.prepare_logs( + log_filename if log_filename is not None else self.log_filename, log_folder + ) + # if log_filename is not None: # self.log_filename = log_filename # if log_folder is not None: # if not os.path.exists(log_folder): @@ -193,34 +238,38 @@ def run(self, max_epochs=100, log_filename=None, log_folder=None, checkpoint_fre logs.update(self.valid_epoch(epoch)) time3 = time.time() self.scheduler_step(logs) - if self.best is None or self.best > logs['valid_loss']: + if self.best is None or self.best > logs["valid_loss"]: self.checkpoint(epoch, logs, checkpoint_folder) elif epoch % checkpoint_frequency == 0: self.checkpoint(epoch, logs, checkpoint_folder) time4 = time.time() - logs.update(epoch=epoch, - train_seconds=time2-time1, - valid_seconds=time3-time2, - checkpoint_seconds=time4-time3, - total_seconds=time4-time1) + logs.update( + epoch=epoch, + train_seconds=time2 - time1, + valid_seconds=time3 - time2, + checkpoint_seconds=time4 - time3, + total_seconds=time4 - time1, + ) self.log(logs) - if np.isnan(logs['valid_loss']) or np.isnan(logs['train_loss']): - raise TrainingFailure('nan received, failing') - self.epoch+= 1 + if np.isnan(logs["valid_loss"]) or np.isnan(logs["train_loss"]): + raise TrainingFailure("nan received, failing") + self.epoch += 1 except TrainingFailure: - if attempt == (allow_n_failures-1): - failure_message = f'Training Failure due to Nan in attempt {attempt}, end now/n' - self.log({'Failure':failure_message}) - raise TrainingFailure('nan received, failing') - failure_message = f'Training Failure due to Nan in attempt {attempt}, try again from best/n' - self.log({'Failure':failure_message}) - if hasattr(self, 'best'): - self.load_checkpoint('best', checkpoint_folder) + if attempt == (allow_n_failures - 1): + failure_message = ( + f"Training Failure due to Nan in attempt {attempt}, end now\n" + ) + self.log({"Failure": failure_message}) + raise TrainingFailure("nan received, failing") + failure_message = f"Training Failure due to Nan in attempt {attempt}, try again from best\n" + self.log({"Failure": failure_message}) + if hasattr(self, "best"): + self.load_checkpoint("best", checkpoint_folder) else: break - def train_epoch(self,epoch): - ''' + def train_epoch(self, epoch): + """ Train one epoch. Called once an epoch from :func:`trainer.run ` This method performs the following functions: - Sets network to train mode via ``self.autoencoder.train()`` @@ -231,12 +280,12 @@ def train_epoch(self,epoch): * Determine gradients using keyword ``'loss'`` e.g. ``result['loss'].backward()`` * Update network gradients. ``self.optimiser.step`` - - All results are aggregated via averaging and returned with ``'train_'`` prepended on the dictionary key + - All results are aggregated via averaging and returned with ``'train_'`` prepended on the dictionary key :param int epoch: The epoch is passed as an argument however epoch number can also be accessed from self.epoch. :returns: Return all results from train_step averaged. 
These results will be printed and/or logged in :func:`trainer.run() ` via a call to :func:`self.log(results) ` :rtype: dict - ''' + """ self.autoencoder.train() N = 0 results = {} @@ -244,59 +293,62 @@ def train_epoch(self,epoch): batch = batch[0].to(self.device) self.optimiser.zero_grad() train_result = self.train_step(batch) - train_result['loss'].backward() + train_result["loss"].backward() self.optimiser.step() if i == 0: - results = {key:value.item()*len(batch) for key, value in train_result.items()} + results = { + key: value.item() * len(batch) + for key, value in train_result.items() + } else: for key in train_result.keys(): - results[key] += train_result[key].item()*len(batch) - N+=len(batch) - return {f'train_{key}': results[key]/N for key in results.keys()} + results[key] += train_result[key].item() * len(batch) + N += len(batch) + return {f"train_{key}": results[key] / N for key in results.keys()} def train_step(self, batch): - ''' + """ Called from :func:`Trainer.train_epoch `. :param torch.Tensor batch: Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of protein frames normalised. To recover original data multiple by ``self.std``. :returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that :func:`self.train_epoch ` will call ``result['loss'].backwards()`` to obtain gradients. :rtype: dict - ''' + """ results = self.common_step(batch) - results['loss'] = results['mse_loss'] + results["loss"] = results["mse_loss"] return results def common_step(self, batch): - ''' + """ Called from both train_step and valid_step. - Calculates the mean squared error loss for self.autoencoder. - Encoded and decoded frames are saved in self._internal under keys ``encoded`` and ``decoded`` respectively should you wish to use them elsewhere. + Calculates the mean squared error loss for self.autoencoder. + Encoded and decoded frames are saved in self._internal under keys ``encoded`` and ``decoded`` respectively should you wish to use them elsewhere. :param torch.Tensor batch: Tensor of shape [Batch size, 3, Number of Atoms] A mini-batch of protein frames normalised. To recover original data multiple by ``self.std``. - :returns: Return calculated mse_loss + :returns: Return calculated mse_loss :rtype: dict - ''' + """ self._internal = {} encoded = self.autoencoder.encode(batch) - self._internal['encoded'] = encoded - decoded = self.autoencoder.decode(encoded)[:,:,:batch.size(2)] - self._internal['decoded'] = decoded - return dict(mse_loss=((batch-decoded)**2).mean()) + self._internal["encoded"] = encoded + decoded = self.autoencoder.decode(encoded)[:, :, : batch.size(2)] + self._internal["decoded"] = decoded + return dict(mse_loss=((batch - decoded) ** 2).mean()) def valid_epoch(self, epoch): - ''' + """ Called once an epoch from :func:`trainer.run ` within a no_grad context. This method performs the following functions: - Sets network to eval mode via ``self.autoencoder.eval()`` - for each batch in ``self.valid_dataloader`` calls :func:`trainer.valid_step ` to retrieve validation loss - - All results are aggregated via averaging and returned with ``'valid_'`` prepended on the dictionary key + - All results are aggregated via averaging and returned with ``'valid_'`` prepended on the dictionary key * The loss with key ``'loss'`` is returned as ``'valid_loss'`` this will be the loss value by which the best checkpoint is determined. :param int epoch: The epoch is passed as an argument however epoch number can also be accessed from self.epoch. 
:returns: Return all results from valid_step averaged. These results will be printed and/or logged in :func:`Trainer.run() ` via a call to :func:`self.log(results) ` :rtype: dict - ''' + """ self.autoencoder.eval() N = 0 results = {} @@ -304,51 +356,62 @@ def valid_epoch(self, epoch): batch = batch[0].to(self.device) valid_result = self.valid_step(batch) if i == 0: - results = {key:value.item()*len(batch) for key, value in valid_result.items()} + results = { + key: value.item() * len(batch) + for key, value in valid_result.items() + } else: for key in valid_result.keys(): - results[key] += valid_result[key].item()*len(batch) - N+=len(batch) - return {f'valid_{key}': results[key]/N for key in results.keys()} + results[key] += valid_result[key].item() * len(batch) + N += len(batch) + return {f"valid_{key}": results[key] / N for key in results.keys()} def valid_step(self, batch): - ''' + """ Called from :func:`Trainer.valid_epoch` on every mini-batch. :param torch.Tensor batch: Tensor of shape [Batch size, 3, Number of Atoms]. A mini-batch of protein frames normalised. To recover original data multiple by ``self.std``. :returns: Return loss. The dictionary must contain an entry with key ``'loss'`` that will be the score via which the best checkpoint is determined. :rtype: dict - ''' + """ results = self.common_step(batch) - results['loss'] = results['mse_loss'] + results["loss"] = results["mse_loss"] return results - def learning_rate_sweep(self, max_lr=100, min_lr=1e-5, number_of_iterations=1000, checkpoint_folder='checkpoint_sweep', train_on='mse_loss', save=['loss', 'mse_loss']): - ''' + def learning_rate_sweep( + self, + max_lr=100, + min_lr=1e-5, + number_of_iterations=1000, + checkpoint_folder="checkpoint_sweep", + train_on="mse_loss", + save=["loss", "mse_loss"], + ): + """ Deprecated method. - Performs a sweep of learning rate between ``max_lr`` and ``min_lr`` over ``number_of_iterations``. + Performs a sweep of learning rate between ``max_lr`` and ``min_lr`` over ``number_of_iterations``. See `Finding Good Learning Rate and The One Cycle Policy `_ :param float max_lr: (default: 100.0) final/maximum learning rate to be used :param float min_lr: (default: 1e-5) Starting learning rate - :param int number_of_iterations: (default: 1000) Number of steps to run sweep over. + :param int number_of_iterations: (default: 1000) Number of steps to run sweep over. :param str train_on: (default: 'mse_loss') key returned from trainer.train_step(batch) on which to train :param list save: (default: ['loss', 'mse_loss']) what loss values to return. :returns: array of shape [len(save), min(number_of_iterations, iterations before NaN)] containing loss values defined in `save` key word. 
:rtype: numpy.ndarray - ''' + """ self.autoencoder.train() - + def cycle(iterable): while True: for i in iterable: yield i - + init_loss = 0.0 values = [] data = iter(cycle(self.train_dataloader)) for i in range(number_of_iterations): - lr = min_lr*((max_lr/min_lr)**(i/number_of_iterations)) + lr = min_lr * ((max_lr / min_lr) ** (i / number_of_iterations)) self.update_optimiser_hyperparameters(lr=lr) batch = next(data)[0].to(self.device).float() @@ -357,103 +420,120 @@ def cycle(iterable): # result['loss']/=len(batch) result[train_on].backward() self.optimiser.step() - values.append((lr,)+tuple((result[name].item() for name in save))) - if i==0: + values.append((lr,) + tuple((result[name].item() for name in save))) + if i == 0: init_loss = result[train_on].item() # if result[train_on].item()>1e6*init_loss: # break values = np.array(values) - print('min value ', values[np.nanargmin(values[:,1])]) + print("min value ", values[np.nanargmin(values[:, 1])]) return values def update_optimiser_hyperparameters(self, **kwargs): - ''' + """ Update optimeser hyperparameter e.g. ``trainer.update_optimiser_hyperparameters(lr = 1e3)`` :param \*\*kwargs: each key value pair in \*\*kwargs is inserted into ``self.optimiser`` - ''' + """ for g in self.optimiser.param_groups: for key, value in kwargs.items(): g[key] = value - def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key='valid_loss'): - ''' + def checkpoint(self, epoch, valid_logs, checkpoint_folder, loss_key="valid_loss"): + """ Checkpoint the current network. The checkpoint will be saved as ``'last.ckpt'``. If valid_logs[loss_key] is better than self.best then this checkpoint will replace self.best and ``'last.ckpt'`` will be renamed to ``f'{checkpoint_folder}/checkpoint_epoch{epoch}_loss{valid_loss}.ckpt'`` and the former best (filename saved as ``self.best_name``) will be deleted :param int epoch: current epoch, will be saved within the ckpt. Current epoch can usually be obtained with ``self.epoch`` - :param dict valid_logs: results dictionary containing loss_key. - :param str checkpoint_folder: The folder in which to save the checkpoint. + :param dict valid_logs: results dictionary containing loss_key. + :param str checkpoint_folder: The folder in which to save the checkpoint. :param str loss_key: (default: 'valid_loss') The key with which to get loss from valid_logs. 
- ''' + """ valid_loss = valid_logs[loss_key] if not os.path.exists(checkpoint_folder): os.mkdir(checkpoint_folder) - torch.save({'epoch':epoch, - 'model_state_dict': self.autoencoder.state_dict(), - 'optimizer_state_dict': self.optimiser.state_dict(), - 'loss': valid_loss, - 'network_kwargs': self._autoencoder_kwargs, - 'atoms': self._data.atoms, - 'std': self.std, - 'mean': self.mean}, - f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt') + torch.save( + { + "epoch": epoch, + "model_state_dict": self.autoencoder.state_dict(), + "optimizer_state_dict": self.optimiser.state_dict(), + "loss": valid_loss, + "network_kwargs": self._autoencoder_kwargs, + "atoms": self._data.atoms, + "std": self.std, + "mean": self.mean, + }, + f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat > 0 else ""}.ckpt', + ) if self.best is None or self.best > valid_loss: filename = f'{checkpoint_folder}/checkpoint{f"_{self._repeat}" if self._repeat>0 else ""}_epoch{epoch}_loss{valid_loss}.ckpt' - shutil.copyfile(f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt', filename) + shutil.copyfile( + f'{checkpoint_folder}/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt', + filename, + ) if self.best is not None: os.remove(self.best_name) self.best_name = filename self.best_epoch = epoch self.best = valid_loss - def load_checkpoint(self, checkpoint_name='best', checkpoint_folder='', load_optimiser=True): - ''' - Load checkpoint. + def load_checkpoint( + self, checkpoint_name="best", checkpoint_folder="", load_optimiser=True + ): + """ + Load checkpoint. :param str checkpoint_name: (default: ``'best'``) if ``'best'`` then checkpoint_folder is searched for all files beginning with ``'checkpoint_'`` and loss values are extracted from the filename by assuming all characters after ``'loss'`` and before ``'.ckpt'`` are a float. The checkpoint with the lowest loss is loaded. checkpoint_name is not ``'best'`` we search for a checkpoint file at ``f'{checkpoint_folder}/{checkpoint_name}'``. :param str checkpoint_folder: Folder whithin which to search for checkpoints. :param bool load_optimiser: (default: True) Should optimiser state dictionary be loaded. 
- ''' - if checkpoint_name=='best': + """ + if checkpoint_name == "best": if self.best_name is not None: _name = self.best_name else: - ckpts = glob.glob(checkpoint_folder+'/checkpoint_*') - indexs = [x.rfind('loss') for x in ckpts] - losses = [float(x[y+4:-5]) for x,y in zip(ckpts, indexs)] + ckpts = glob.glob(checkpoint_folder + "/checkpoint_*") + indexs = [x.rfind("loss") for x in ckpts] + losses = [float(x[y + 4 : -5]) for x, y in zip(ckpts, indexs)] _name = ckpts[np.argmin(losses)] - elif checkpoint_name =='last': - _name = f'{checkpoint_folder}/last.ckpt' + elif checkpoint_name == "last": + _name = f"{checkpoint_folder}/last.ckpt" else: - _name = f'{checkpoint_folder}/{checkpoint_name}' + _name = f"{checkpoint_folder}/{checkpoint_name}" checkpoint = torch.load(_name, map_location=self.device) - if not hasattr(self, 'autoencoder'): - raise NotImplementedError('self.autoencoder does not exist, I have no way of knowing what network you want to load checkoint weights into yet, please set the network first') + if not hasattr(self, "autoencoder"): + raise NotImplementedError( + "self.autoencoder does not exist, I have no way of knowing what network you want to load checkoint weights into yet, please set the network first" + ) - self.autoencoder.load_state_dict(checkpoint['model_state_dict']) + self.autoencoder.load_state_dict(checkpoint["model_state_dict"]) if load_optimiser: - if not hasattr(self, 'optimiser'): - raise NotImplementedError('self.optimiser does not exist, I have no way of knowing what optimiser you previously used, please set it first.') - self.optimiser.load_state_dict(checkpoint['optimizer_state_dict']) - epoch = checkpoint['epoch'] - self.epoch = epoch+1 + if not hasattr(self, "optimiser"): + raise NotImplementedError( + "self.optimiser does not exist, I have no way of knowing what optimiser you previously used, please set it first." + ) + self.optimiser.load_state_dict(checkpoint["optimizer_state_dict"]) + epoch = checkpoint["epoch"] + self.epoch = epoch + 1 def get_repeat(self, checkpoint_folder): if not os.path.exists(checkpoint_folder): os.mkdir(checkpoint_folder) - if not hasattr(self, '_repeat'): + if not hasattr(self, "_repeat"): self._repeat = 0 for i in range(1000): - if not os.path.exists(checkpoint_folder+f'/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt'): - break#os.mkdir(checkpoint_folder) + if not os.path.exists( + checkpoint_folder + + f'/last{f"_{self._repeat}" if self._repeat>0 else ""}.ckpt' + ): + break # os.mkdir(checkpoint_folder) else: self._repeat += 1 else: - raise Exception('Something went wrong, you surely havnt done 1000 repeats?') - + raise Exception( + "Something went wrong, you surely havnt done 1000 repeats?" + ) -if __name__=='__main__': +if __name__ == "__main__": pass
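
Note on the reworked `Trainer.log`: the log file is now written as a comma-separated table (a header row of keys on the first write, then one row of values per epoch), while failure messages are still dumped as JSON lines. Below is a minimal sketch of how such a log could be read back and plotted; it assumes the log was written to `log_file.dat` (substitute whatever `log_filename` was actually used), relies on the default `epoch`/`train_loss`/`valid_loss` keys produced by `Trainer.run`, and the output name `loss_curves.png` is arbitrary.

```
import matplotlib.pyplot as plt
import pandas as pd

# read the CSV-style log written by Trainer.log: a header row of keys,
# then one comma-separated row of values per epoch
logs = pd.read_csv("log_file.dat", on_bad_lines="skip")

fig, ax = plt.subplots()
ax.plot(logs["epoch"], logs["train_loss"], label="train_loss")
ax.plot(logs["epoch"], logs["valid_loss"], label="valid_loss")
ax.set_xlabel("epoch")
ax.set_ylabel("loss")
ax.legend()
plt.savefig("loss_curves.png")
```

Because failure entries are JSON objects rather than comma-separated rows, `on_bad_lines="skip"` simply drops them instead of raising a parse error.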
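
Note on the multiprocessing fix in `examples/analysis_example.py`: moving the script body into `main()` behind an `if __name__ == "__main__":` guard matters because the DOPE and Ramachandran score calculations use multiple worker processes (`MA.processes = 2`), and on spawn-based platforms each worker re-imports the script, so unguarded top-level code would run again in every worker. A generic sketch of the pattern (the `score` function and pool size below are placeholders, not part of molearn):

```
from multiprocessing import Pool

def score(x):
    # placeholder for a per-structure calculation such as a DOPE score
    return x * x

if __name__ == "__main__":
    # without this guard, spawn-based platforms (e.g. Windows, macOS)
    # re-execute top-level module code in every worker process
    with Pool(processes=2) as pool:
        print(pool.map(score, range(4)))
```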