Skip to content

Commit

Permalink
Merge pull request #7 from SiLiKhon/s3_logging
Browse files Browse the repository at this point in the history
S3 logging
  • Loading branch information
SiLiKhon authored Jul 14, 2022
2 parents cf60f9d + b8f5cab commit 7a4f590
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
2 changes: 1 addition & 1 deletion cuda_gpu_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,4 @@ def setup_gpu(gpu_num=None):
tf.config.experimental.set_memory_growth(gpu, True)

logical_devices = tf.config.experimental.list_logical_devices('GPU')
assert len(logical_devices) > 0, "Not enough GPU hardware devices available"
assert len(logical_devices) > 0 or gpu_num == "", "Not enough GPU hardware devices available"
5 changes: 3 additions & 2 deletions run_model_v4.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def make_parser():
parser.add_argument('--checkpoint_name', type=str, required=True)
parser.add_argument('--gpu_num', type=str, required=False)
parser.add_argument('--prediction_only', action='store_true', default=False)
parser.add_argument('--logging_dir', type=str, default='logs')

return parser

Expand Down Expand Up @@ -101,8 +102,8 @@ def main():
Y_train, Y_test, X_train, X_test = train_test_split(data_scaled, features, test_size=0.25, random_state=42)

if not args.prediction_only:
writer_train = tf.summary.create_file_writer(f'logs/{args.checkpoint_name}/train')
writer_val = tf.summary.create_file_writer(f'logs/{args.checkpoint_name}/validation')
writer_train = tf.summary.create_file_writer(f'{args.logging_dir}/{args.checkpoint_name}/train')
writer_val = tf.summary.create_file_writer(f'{args.logging_dir}/{args.checkpoint_name}/validation')

if args.prediction_only:
epoch = latest_epoch(model_path)
Expand Down

0 comments on commit 7a4f590

Please sign in to comment.