From d784e7be9742b4d1bcfa0f389ea063fec984fe50 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?= Date: Wed, 15 Dec 2021 21:56:33 -0800 Subject: [PATCH] update tb streaming app tests (#70) --- .../tb_streaming/custom/custom_controller.py | 14 ++++++++++++++ .../tb_streaming/custom/custom_executor.py | 17 +++++++++++++++-- .../test_apps/validators/tb_result_validator.py | 17 +++++++++++++++++ 3 files changed, 46 insertions(+), 2 deletions(-) diff --git a/test/app_testing/test_apps/tb_streaming/custom/custom_controller.py b/test/app_testing/test_apps/tb_streaming/custom/custom_controller.py index 0dbaee9d60..869cdb161f 100755 --- a/test/app_testing/test_apps/tb_streaming/custom/custom_controller.py +++ b/test/app_testing/test_apps/tb_streaming/custom/custom_controller.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from nvflare.apis.client import Client from nvflare.apis.fl_context import FLContext from nvflare.apis.impl.controller import ClientTask, Controller, Task diff --git a/test/app_testing/test_apps/tb_streaming/custom/custom_executor.py b/test/app_testing/test_apps/tb_streaming/custom/custom_executor.py index a0022e99ea..0a8d1f46c0 100755 --- a/test/app_testing/test_apps/tb_streaming/custom/custom_executor.py +++ b/test/app_testing/test_apps/tb_streaming/custom/custom_executor.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import logging import random import time @@ -17,7 +31,6 @@ def __init__(self, task_name: str = "poc"): raise TypeError("task name should be a string.") self.task_name = task_name - self.logger = logging.getLogger("POCExecutor") def execute( self, @@ -26,7 +39,7 @@ def execute( fl_ctx: FLContext, abort_signal: Signal, ) -> Shareable: - if task_name in self.task_name: + if task_name == self.task_name: peer_ctx = fl_ctx.get_prop(FLContextKey.PEER_CONTEXT) r = peer_ctx.get_prop("current_round") diff --git a/test/app_testing/test_apps/validators/tb_result_validator.py b/test/app_testing/test_apps/validators/tb_result_validator.py index 66e2aa367c..b03132e53d 100644 --- a/test/app_testing/test_apps/validators/tb_result_validator.py +++ b/test/app_testing/test_apps/validators/tb_result_validator.py @@ -1,3 +1,17 @@ +# Copyright (c) 2021, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + import os from test.app_testing.app_result_validator import AppResultValidator @@ -12,15 +26,18 @@ def validate_results(self, server_data, client_data, run_data) -> bool: server_run_dir = os.path.join(server_path, f"run_{run_number}") server_tb_root_dir = os.path.join(server_run_dir, TB_PATH) if not os.path.exists(server_tb_root_dir): + print(f"tb validate results: server_tb_root_dir {server_tb_root_dir} doesn't exist.") return False for i, client_path in enumerate(client_data["client_paths"]): client_run_dir = os.path.join(client_path, f"run_{run_number}") client_side_client_tb_dir = os.path.join(client_run_dir, TB_PATH, client_data["client_names"][i]) if not os.path.exists(client_side_client_tb_dir): + print(f"tb validate results: client_side_client_tb_dir {client_side_client_tb_dir} doesn't exist.") return False server_side_client_tb_dir = os.path.join(server_tb_root_dir, client_data["client_names"][i]) if not os.path.exists(server_side_client_tb_dir): + print(f"tb validate results: server_side_client_tb_dir {server_side_client_tb_dir} doesn't exist.") return False return True