From 89d6d2ea7f3148662b71d3ca6999f86eecc2a692 Mon Sep 17 00:00:00 2001 From: alexdrydew Date: Mon, 2 May 2022 21:49:11 +0300 Subject: [PATCH 1/2] support non default logging dir --- run_model_v4.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/run_model_v4.py b/run_model_v4.py index 970fe06..872966e 100644 --- a/run_model_v4.py +++ b/run_model_v4.py @@ -21,6 +21,7 @@ def make_parser(): parser.add_argument('--checkpoint_name', type=str, required=True) parser.add_argument('--gpu_num', type=str, required=False) parser.add_argument('--prediction_only', action='store_true', default=False) + parser.add_argument('--logging_dir', type=str, default='logs') return parser @@ -101,8 +102,8 @@ def main(): Y_train, Y_test, X_train, X_test = train_test_split(data_scaled, features, test_size=0.25, random_state=42) if not args.prediction_only: - writer_train = tf.summary.create_file_writer(f'logs/{args.checkpoint_name}/train') - writer_val = tf.summary.create_file_writer(f'logs/{args.checkpoint_name}/validation') + writer_train = tf.summary.create_file_writer(f'{args.logging_dir}/{args.checkpoint_name}/train') + writer_val = tf.summary.create_file_writer(f'{args.logging_dir}/{args.checkpoint_name}/validation') if args.prediction_only: epoch = latest_epoch(model_path) From b8f5cabaddc78d2a09cb0b4f7e0c238e1fce2edf Mon Sep 17 00:00:00 2001 From: alexdrydew Date: Sun, 8 May 2022 21:21:22 +0300 Subject: [PATCH 2/2] support gpu training --- cuda_gpu_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cuda_gpu_config.py b/cuda_gpu_config.py index deeaefb..6c4551b 100644 --- a/cuda_gpu_config.py +++ b/cuda_gpu_config.py @@ -13,4 +13,4 @@ def setup_gpu(gpu_num=None): tf.config.experimental.set_memory_growth(gpu, True) logical_devices = tf.config.experimental.list_logical_devices('GPU') - assert len(logical_devices) > 0, "Not enough GPU hardware devices available" + assert len(logical_devices) > 0 or gpu_num == "", "Not enough GPU hardware devices available"