diff --git a/src/main.py b/src/main.py index 32f74aa..3b5ae3d 100644 --- a/src/main.py +++ b/src/main.py @@ -154,8 +154,8 @@ def main(path_to_data: str, ## Get/Save param dict logger.info('Saving model in cache dir {}'.format(cache_dir)) - torch.save(Model.state_dict(), cache_dir+'state_dict.pt') - with open('config_model.json', 'w') as file: + torch.save(Model.state_dict(), os.path.join(cache_dir,'state_dict.pt')) + with open(os.path.join(cache_dir,'config_model.json'), 'w') as file: json.dump(config_dict, file)