Skip to content

Commit

Permalink
update tb streaming app tests (#70)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Dec 16, 2021
1 parent ba77afb commit d784e7b
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 2 deletions.
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from nvflare.apis.client import Client
from nvflare.apis.fl_context import FLContext
from nvflare.apis.impl.controller import ClientTask, Controller, Task
Expand Down
17 changes: 15 additions & 2 deletions test/app_testing/test_apps/tb_streaming/custom/custom_executor.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
import time
Expand All @@ -17,7 +31,6 @@ def __init__(self, task_name: str = "poc"):
raise TypeError("task name should be a string.")

self.task_name = task_name
self.logger = logging.getLogger("POCExecutor")

def execute(
self,
Expand All @@ -26,7 +39,7 @@ def execute(
fl_ctx: FLContext,
abort_signal: Signal,
) -> Shareable:
if task_name in self.task_name:
if task_name == self.task_name:
peer_ctx = fl_ctx.get_prop(FLContextKey.PEER_CONTEXT)
r = peer_ctx.get_prop("current_round")

Expand Down
17 changes: 17 additions & 0 deletions test/app_testing/test_apps/validators/tb_result_validator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,17 @@
# Copyright (c) 2021, NVIDIA CORPORATION.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
from test.app_testing.app_result_validator import AppResultValidator

Expand All @@ -12,15 +26,18 @@ def validate_results(self, server_data, client_data, run_data) -> bool:
server_run_dir = os.path.join(server_path, f"run_{run_number}")
server_tb_root_dir = os.path.join(server_run_dir, TB_PATH)
if not os.path.exists(server_tb_root_dir):
print(f"tb validate results: server_tb_root_dir {server_tb_root_dir} doesn't exist.")
return False

for i, client_path in enumerate(client_data["client_paths"]):
client_run_dir = os.path.join(client_path, f"run_{run_number}")
client_side_client_tb_dir = os.path.join(client_run_dir, TB_PATH, client_data["client_names"][i])
if not os.path.exists(client_side_client_tb_dir):
print(f"tb validate results: client_side_client_tb_dir {client_side_client_tb_dir} doesn't exist.")
return False

server_side_client_tb_dir = os.path.join(server_tb_root_dir, client_data["client_names"][i])
if not os.path.exists(server_side_client_tb_dir):
print(f"tb validate results: server_side_client_tb_dir {server_side_client_tb_dir} doesn't exist.")
return False
return True

0 comments on commit d784e7b

Please sign in to comment.