From 19780e0cef8bb3aa90ca595517288ac8840f38ad Mon Sep 17 00:00:00 2001 From: HessTaha Date: Sun, 31 May 2020 16:25:33 +0200 Subject: [PATCH] Add saving features --- src/main.py | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/src/main.py b/src/main.py index 14c3d1f..eba8924 100644 --- a/src/main.py +++ b/src/main.py @@ -7,6 +7,8 @@ logger.setLevel(logging.INFO) import pandas as pd +import json + import torch import progressbar @@ -16,6 +18,7 @@ from sklearn.metrics import f1_score def main(path_to_data: str, + cache_dir: str, texts_col: str, labels_col: str, n_classes: int, @@ -30,6 +33,14 @@ def main(path_to_data: str, ''' df = pd.read_csv(path_to_data) + if os.path.isdir(cache_dir): + logger.info('Cache dir found here {}'.format(cache_dir)) + pass + else: + logger.info('Creating cache dir') + os.makedirs(cache_dir) + + # Preprocess optimal_length = get_length(df, texts_col) X, vocab_size = encode_texts(df, texts_col, max_seq_length=optimal_length, return_vocab_size=True) @@ -42,6 +53,11 @@ def main(path_to_data: str, vocab_size=vocab_size, n_classes=n_classes ) + + config_dict = { + "vocab_size" : vocab_size, + "n_classes" : n_classes + } if n_classes > 2: criterion = torch.nn.CrossEntropyLoss() @@ -133,11 +149,22 @@ def main(path_to_data: str, logger.info("Evaluation at iteration {} done: eval loss: {}\n eval f1: {}".format(epoch, eval_loss.item(), tmp_f1)) + ## Bring back model to cpu + Model.cpu() + + ## Save weights and config together inside cache_dir + logger.info('Saving model in cache dir {}'.format(cache_dir)) + torch.save(Model.state_dict(), os.path.join(cache_dir, 'state_dict.pt')) + with open(os.path.join(cache_dir, 'config_model.json'), 'w') as file: + json.dump(config_dict, file) + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--data_path", help="path to the data directory", type=str) + parser.add_argument("--cache_dir", help="cache directory", type=str) parser.add_argument("--texts_col", help="name of the column containing textual data", type=str)
parser.add_argument("--labels_col", help="name of the column containing labels", type=str) parser.add_argument("--n_classes", type=int) @@ -152,6 +179,7 @@ def main(path_to_data: str, main( path_to_data = args.data_path, + cache_dir = args.cache_dir, texts_col = args.texts_col, labels_col = args.labels_col, n_classes = args.n_classes,