Skip to content

Commit

Permalink
Add saving features
Browse files Browse the repository at this point in the history
  • Loading branch information
HessTaha committed May 31, 2020
1 parent 511432b commit 19780e0
Showing 1 changed file with 28 additions and 0 deletions.
28 changes: 28 additions & 0 deletions src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
logger.setLevel(logging.INFO)

import pandas as pd
import json

import torch
import progressbar

Expand All @@ -16,6 +18,7 @@
from sklearn.metrics import f1_score

def main(path_to_data: str,
cache_dir: str,
texts_col: str,
labels_col: str,
n_classes: int,
Expand All @@ -30,6 +33,14 @@ def main(path_to_data: str,
'''
df = pd.read_csv(path_to_data)

if os.path.isdir():
logger.info('Cache dir found here {}'.format(cache_dir))
pass
else:
logger.info('Creating cache dir')
os.mkdir(cache_dir)


# Preprocess
optimal_length = get_length(df, texts_col)
X, vocab_size = encode_texts(df, texts_col, max_seq_length=optimal_length, return_vocab_size=True)
Expand All @@ -42,6 +53,11 @@ def main(path_to_data: str,
vocab_size=vocab_size,
n_classes=n_classes
)

config_dict = {
"vocab_size" : vocab_size,
"n_classes" : n_classes
}

if n_classes > 2:
criterion = torch.nn.CrossEntropyLoss()
Expand Down Expand Up @@ -133,11 +149,22 @@ def main(path_to_data: str,
logger.info("Evaluation at iteration {} done: eval loss: {}\n eval f1: {}".format(epoch, eval_loss.item(), tmp_f1))


## Bring back model to cpu
Model.cpu()

## Get/Save param dict
logger.info('Saving model in cache dir {}'.format(cache_dir))
torch.save(Model.state_dict(), cache_dir+'state_dict.pt')
with open('config_model.json', 'w') as file:
json.dump(config_dict, file)


if __name__ == "__main__":

parser = argparse.ArgumentParser()

parser.add_argument("--data_path", help="path to the data directory", type=str)
parser.add_argument("--cache_dir", help="cache directory", type=str)
parser.add_argument("--texts_col", help="name of the column containing textual data", type=str)
parser.add_argument("--labels_col", help="name of the column containing labels", type=str)
parser.add_argument("--n_classes", type=int)
Expand All @@ -152,6 +179,7 @@ def main(path_to_data: str,

main(
path_to_data = args.data_path,
cache_dir = args.cache_dir,
texts_col = args.texts_col,
labels_col = args.labels_col,
n_classes = args.n_classes,
Expand Down

0 comments on commit 19780e0

Please sign in to comment.