From 0fe6277fa3b80c685ff4685955ee2c0bb86ad849 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Yuan-Ting=20Hsieh=20=28=E8=AC=9D=E6=B2=85=E5=BB=B7=29?=
Date: Fri, 21 Jan 2022 14:55:12 -0800
Subject: [PATCH] Update docstrings and clean up (#135)

---
 nvflare/apis/utils/local_logger.py         |  12 +-
 .../{log_receiver.py => log_streaming.py}  |  50 +++++++-
 nvflare/app_common/widgets/streaming.py    | 109 ++----------------
 .../tb_streaming/custom/custom_executor.py |   9 +-
 test/test_streaming.py                     |  23 ++--
 5 files changed, 77 insertions(+), 126 deletions(-)
 rename nvflare/app_common/widgets/{log_receiver.py => log_streaming.py} (61%)

diff --git a/nvflare/apis/utils/local_logger.py b/nvflare/apis/utils/local_logger.py
index b42bed04bc..17170207d3 100644
--- a/nvflare/apis/utils/local_logger.py
+++ b/nvflare/apis/utils/local_logger.py
@@ -23,10 +23,7 @@ class LocalLogger:
 
     @staticmethod
     def initialize():
-        """ Initialize the LocalLogger to keep all the handlers before the adding of LogSender handler.
-        Returns:
-
-        """
+        """Initializes the LocalLogger."""
         if not LocalLogger.handlers:
             LocalLogger.handlers = []
             for handler in logging.root.handlers:
@@ -34,12 +31,13 @@ def initialize():
 
     @staticmethod
     def get_logger(name=None) -> logging.Logger:
-        """ Get a logger only do the local logging.
+        """Gets a logger that only does local logging.
+
         Args:
             name: logger name
 
-        Returns: local_logger
-
+        Returns:
+            A local logger.
         """
         with LocalLogger.lock:
             if not LocalLogger.handlers:
diff --git a/nvflare/app_common/widgets/log_receiver.py b/nvflare/app_common/widgets/log_streaming.py
similarity index 61%
rename from nvflare/app_common/widgets/log_receiver.py
rename to nvflare/app_common/widgets/log_streaming.py
index ebcf8aa821..23cf2f27c1 100644
--- a/nvflare/app_common/widgets/log_receiver.py
+++ b/nvflare/app_common/widgets/log_streaming.py
@@ -15,15 +15,53 @@
 import logging
 import os
 from logging import LogRecord
+from threading import Lock
 from typing import List, Optional
 
 from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType
 from nvflare.apis.dxo import from_shareable
+from nvflare.apis.event_type import EventType
 from nvflare.apis.fl_constant import LogMessageTag
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
 from nvflare.app_common.app_event_type import AppEventType
-from nvflare.app_common.widgets.streaming import AnalyticsReceiver
+from nvflare.app_common.widgets.streaming import AnalyticsReceiver, create_analytic_dxo, send_analytic_dxo
+from nvflare.widgets.widget import Widget
+
+
+class LogAnalyticsSender(Widget, logging.StreamHandler):
+    def __init__(self, log_level=""):
+        """Sends log records as analytics data.
+
+        Args:
+            log_level: log_level threshold
+        """
+        Widget.__init__(self)
+        logging.StreamHandler.__init__(self)
+        self.log_level = getattr(logging, log_level, logging.INFO)
+        self.engine = None
+
+    def handle_event(self, event_type: str, fl_ctx: FLContext):
+        if event_type == EventType.ABOUT_TO_START_RUN:
+            self.engine = fl_ctx.get_engine()
+            logging.root.addHandler(self)
+        elif event_type == EventType.END_RUN:
+            logging.root.removeHandler(self)
+
+    def emit(self, record: LogRecord):
+        """Sends the log record.
+
+        When the record's level is equal to or higher than the configured log level, sends the log record.
+        Args:
+            record: logging record
+        """
+        if record.levelno >= self.log_level and self.engine:
+            dxo = create_analytic_dxo(
+                tag=LogMessageTag.LOG_RECORD, value=record, data_type=AnalyticsDataType.LOG_RECORD
+            )
+            with self.engine.new_context() as fl_ctx:
+                send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=AppEventType.LOGGING_EVENT_TYPE)
+            self.flush()
 
 
 class LogAnalyticsReceiver(AnalyticsReceiver):
@@ -40,6 +78,7 @@ def __init__(self, events: Optional[List[str]] = None, formatter=None):
 
         self.formatter = formatter
         self.handlers = {}
+        self.handlers_lock = Lock()
 
     def initialize(self, fl_ctx: FLContext):
         workspace = fl_ctx.get_engine().get_workspace()
@@ -54,9 +93,11 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin: str):
 
         data_type = analytic_data.data_type
         if data_type == AnalyticsDataType.LOG_RECORD:
-            handler = self.handlers.get(record_origin)
-            if not handler:
-                handler = self._create_log_handler(record_origin)
+            with self.handlers_lock:
+                handler = self.handlers.get(record_origin)
+                if not handler:
+                    handler = self._create_log_handler(record_origin)
+                    self.handlers[record_origin] = handler
 
             record: LogRecord = dxo.data.get(LogMessageTag.LOG_RECORD)
             handler.emit(record)
@@ -67,7 +108,6 @@ def _create_log_handler(self, record_origin):
         filename = os.path.join(self.root_log_dir, record_origin + ".log")
         handler = logging.FileHandler(filename)
         handler.setFormatter(logging.Formatter(self.formatter))
-        self.handlers[record_origin] = handler
 
         return handler
 
     def finalize(self, fl_ctx: FLContext):
diff --git a/nvflare/app_common/widgets/streaming.py b/nvflare/app_common/widgets/streaming.py
index 438792a8fb..cd5270f6b7 100644
--- a/nvflare/app_common/widgets/streaming.py
+++ b/nvflare/app_common/widgets/streaming.py
@@ -11,9 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import logging
 from abc import ABC, abstractmethod
-from logging import LogRecord
 from threading import Lock
 from typing import List, Optional
 
@@ -21,7 +19,7 @@
 from nvflare.apis.dxo import DXO
 from nvflare.apis.event_type import EventType
 from nvflare.apis.fl_component import FLComponent
-from nvflare.apis.fl_constant import EventScope, FLContextKey, LogMessageTag, ReservedKey
+from nvflare.apis.fl_constant import EventScope, FLContextKey, ReservedKey
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
 from nvflare.app_common.app_event_type import AppEventType
@@ -50,14 +48,14 @@ def send_analytic_dxo(
     comp.fire_event(event_type=event_type, fl_ctx=fl_ctx)
 
 
-def _write(tag: str, value, data_type: AnalyticsDataType, kwargs: Optional[dict] = None) -> DXO:
-    """Writes the analytic data.
+def create_analytic_dxo(tag: str, value, data_type: AnalyticsDataType, **kwargs) -> DXO:
+    """Creates the analytic DXO.
 
     Args:
         tag (str): the tag associated with this value.
         value: the analytic data.
         data_type (AnalyticsDataType): analytic data type.
-        kwargs (dict): additional arguments to be passed into the receiver side's function.
+        kwargs: additional arguments to be passed into the receiver side's function.
 
     Returns:
         A DXO object that contains the analytic data.
@@ -67,51 +65,11 @@ def _write(tag: str, value, data_type: AnalyticsDataType, kwargs: Optional[dict]
     return dxo
 
 
-def write_scalar(tag: str, scalar: float, **kwargs) -> DXO:
-    """Writes a scalar.
-
-    Args:
-        tag (str): the tag associated with this value.
-        scalar (float): a scalar to write.
-    """
-    return _write(tag, scalar, data_type=AnalyticsDataType.SCALAR, kwargs=kwargs)
-
-
-def write_scalars(tag: str, tag_scalar_dict: dict, **kwargs) -> DXO:
-    """Writes scalars.
-
-    Args:
-        tag (str): the tag associated with this dict.
-        tag_scalar_dict (dict): A dictionary that contains tag and scalars to write.
-    """
-    return _write(tag, tag_scalar_dict, data_type=AnalyticsDataType.SCALARS, kwargs=kwargs)
-
-
-def write_image(tag: str, image, **kwargs) -> DXO:
-    """Writes an image.
-
-    Args:
-        tag (str): the tag associated with this value.
-        image: the image to write.
-    """
-    return _write(tag, image, data_type=AnalyticsDataType.IMAGE, kwargs=kwargs)
-
-
-def write_text(tag: str, text: str, **kwargs) -> DXO:
-    """Writes text.
-
-    Args:
-        tag (str): the tag associated with this value.
-        text (str): the text to write.
-    """
-    return _write(tag, text, data_type=AnalyticsDataType.TEXT, kwargs=kwargs)
-
-
 class AnalyticsSender(Widget):
     def __init__(self):
         """Sends analytics data.
-        This class implements some common methods follows signatures from PyTorch SummaryWriter and Python logger.
+        This class implements some common methods that follow signatures from PyTorch SummaryWriter.
         It provides a convenient way for Learner to use.
         """
         super().__init__()
@@ -134,7 +92,7 @@ def _add(
             if not isinstance(global_step, int):
                 raise TypeError(f"Expect global step to be an instance of int, but got {type(global_step)}")
             kwargs["global_step"] = global_step
-        dxo = _write(tag=tag, value=value, data_type=data_type, kwargs=kwargs)
+        dxo = create_analytic_dxo(tag=tag, value=value, data_type=data_type, **kwargs)
         with self.engine.new_context() as fl_ctx:
             send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx)
@@ -182,21 +140,6 @@ def add_image(self, tag: str, image, global_step: Optional[int] = None, **kwargs
         """
         self._add(tag=tag, value=image, data_type=AnalyticsDataType.IMAGE, global_step=global_step, kwargs=kwargs)
 
-    def _log(self, tag: LogMessageTag, msg: str, event_type: str, *args, **kwargs):
-        """Logs a message.
-
-        Args:
-            tag (LogMessageTag): A tag that contains the level of the log message.
-            msg (str): Message to log.
-            event_type (str): Event type that associated with this message.
-            *args: From python logger api, args is used to format strings.
-            **kwargs: Additional arguments to be passed into the log function.
-        """
-        msg = msg.format(*args, **kwargs)
-        dxo = _write(tag=str(tag), value=msg, data_type=AnalyticsDataType.TEXT, kwargs=kwargs)
-        with self.engine.new_context() as fl_ctx:
-            send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=event_type)
-
     def flush(self):
         """Flushes out the message.
 
@@ -256,7 +199,10 @@ def finalize(self, fl_ctx: FLContext):
     def handle_event(self, event_type: str, fl_ctx: FLContext):
         if event_type == EventType.START_RUN:
             self.initialize(fl_ctx)
-        elif event_type in self.events and not self._end:
+        elif event_type in self.events:
+            if self._end:
+                self.log_debug(fl_ctx, f"Already received end run event, dropping event {event_type}.", fire_event=False)
+                return
             data = fl_ctx.get_prop(FLContextKey.EVENT_DATA, None)
             if data is None:
                 self.log_error(fl_ctx, "Missing event data.", fire_event=False)
@@ -281,38 +227,3 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
         elif event_type == EventType.END_RUN:
             self._end = True
             self.finalize(fl_ctx)
-
-
-class LogSender(Widget, logging.StreamHandler):
-    def __init__(self, log_level=""):
-        """
-        LogSender for sending the logging record to the FL server.
-        Args:
-            log_level: log_level threshold
-        """
-        Widget.__init__(self)
-        logging.StreamHandler.__init__(self)
-        self.log_level = getattr(logging, log_level, logging.INFO)
-        self.engine = None
-
-    def handle_event(self, event_type: str, fl_ctx: FLContext):
-        if event_type == EventType.ABOUT_TO_START_RUN:
-            self.engine = fl_ctx.get_engine()
-            logging.root.addHandler(self)
-        if event_type == EventType.END_RUN:
-            logging.root.removeHandler(self)
-
-    def emit(self, record: LogRecord) -> None:
-        """
-        When the log_level higher than the configured level, sends the log record to the FL server to collect.
-        Args:
-            record: logging record
-
-        Returns:
-
-        """
-        if record.levelno >= self.log_level and self.engine:
-            dxo = _write(tag=LogMessageTag.LOG_RECORD, value=record, data_type=AnalyticsDataType.LOG_RECORD)
-            with self.engine.new_context() as fl_ctx:
-                send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=AppEventType.LOGGING_EVENT_TYPE)
-            self.flush()
diff --git a/test/app_testing/test_apps/tb_streaming/custom/custom_executor.py b/test/app_testing/test_apps/tb_streaming/custom/custom_executor.py
index 3a11a95ca8..23279cad0a 100755
--- a/test/app_testing/test_apps/tb_streaming/custom/custom_executor.py
+++ b/test/app_testing/test_apps/tb_streaming/custom/custom_executor.py
@@ -12,7 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-import logging
 import random
 import time
 
@@ -21,7 +20,8 @@
 from nvflare.apis.fl_context import FLContext
 from nvflare.apis.shareable import Shareable
 from nvflare.apis.signal import Signal
-from nvflare.app_common.widgets.streaming import send_analytic_dxo, write_scalar, write_text
+from nvflare.apis.analytix import AnalyticsDataType
+from nvflare.app_common.widgets.streaming import create_analytic_dxo, send_analytic_dxo
 
 
 class CustomExecutor(Executor):
@@ -46,9 +46,10 @@ def execute(
                 number = random.random()
 
                 # send analytics
-                dxo = write_scalar("random_number", number, global_step=r)
+                dxo = create_analytic_dxo(tag="random_number", value=number, data_type=AnalyticsDataType.SCALAR, global_step=r)
                 send_analytic_dxo(comp=self, dxo=dxo, fl_ctx=fl_ctx)
-                dxo = write_text("debug_msg", "Hello world", global_step=r)
+                dxo = create_analytic_dxo(tag="debug_msg", value="Hello world", data_type=AnalyticsDataType.TEXT,
+                                          global_step=r)
                 send_analytic_dxo(comp=self, dxo=dxo, fl_ctx=fl_ctx)
 
                 time.sleep(2.0)
diff --git a/test/test_streaming.py b/test/test_streaming.py
index 6f109aa91d..76b9d6c0fd 100644
--- a/test/test_streaming.py
+++ b/test/test_streaming.py
@@ -14,10 +14,11 @@
 
 import pytest
 
+from nvflare.apis.analytix import AnalyticsDataType
 from nvflare.apis.dxo import DXO, DataKind
 from nvflare.apis.fl_component import FLComponent
 from nvflare.apis.fl_context import FLContext
-from nvflare.app_common.widgets.streaming import send_analytic_dxo, write_image, write_scalar, write_scalars, write_text
+from nvflare.app_common.widgets.streaming import create_analytic_dxo, send_analytic_dxo
 
 INVALID_TEST_CASES = [
     (list(), dict(), FLContext(), TypeError, f"expect comp to be an instance of FLComponent, but got {type(list())}"),
@@ -32,13 +33,13 @@
 ]
 
 INVALID_WRITE_TEST_CASES = [
-    (write_scalar, list(), 1.0, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
-    (write_scalar, "tag", list(), TypeError, f"expect value to be an instance of float, but got {type(list())}"),
-    (write_scalars, list(), 1.0, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
-    (write_scalars, "tag", 1.0, TypeError, f"expect value to be an instance of dict, but got {type(1.0)}"),
-    (write_text, list(), 1.0, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
-    (write_text, "tag", 1.0, TypeError, f"expect value to be an instance of str, but got {type(1.0)}"),
-    (write_image, list(), 1.0, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
+    (list(), 1.0, AnalyticsDataType.SCALAR, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
+    ("tag", list(), AnalyticsDataType.SCALAR, TypeError, f"expect value to be an instance of float, but got {type(list())}"),
+    (list(), 1.0, AnalyticsDataType.SCALARS, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
+    ("tag", 1.0, AnalyticsDataType.SCALARS, TypeError, f"expect value to be an instance of dict, but got {type(1.0)}"),
+    (list(), 1.0, AnalyticsDataType.TEXT, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
+    ("tag", 1.0, AnalyticsDataType.TEXT, TypeError, f"expect value to be an instance of str, but got {type(1.0)}"),
+    (list(), 1.0, AnalyticsDataType.IMAGE, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
 ]
 
 
@@ -48,7 +49,7 @@ def test_invalid_send_analytic_dxo(self, comp, dxo, fl_ctx, expected_error, expe
         with pytest.raises(expected_error, match=expected_msg):
             send_analytic_dxo(comp=comp, dxo=dxo, fl_ctx=fl_ctx)
 
-    @pytest.mark.parametrize("func,tag,value,expected_error,expected_msg", INVALID_WRITE_TEST_CASES)
-    def test_invalid_write_func(self, func, tag, value, expected_error, expected_msg):
+    @pytest.mark.parametrize("tag,value,data_type,expected_error,expected_msg", INVALID_WRITE_TEST_CASES)
+    def test_invalid_write_func(self, tag, value, data_type, expected_error, expected_msg):
         with pytest.raises(expected_error, match=expected_msg):
-            func(tag, value)
+            create_analytic_dxo(tag, value, data_type)
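
Usage note: the removed write_scalar/write_scalars/write_text/write_image helpers are replaced by building the DXO directly with create_analytic_dxo and sending it with send_analytic_dxo, as the custom_executor.py change above shows. A minimal sketch of the new call pattern (the tag, value, and global_step below are illustrative values, not taken from this patch):

    from nvflare.apis.analytix import AnalyticsDataType
    from nvflare.app_common.widgets.streaming import create_analytic_dxo, send_analytic_dxo

    # Build the analytics DXO explicitly, naming its data type (replaces write_scalar);
    # the tag and value here are illustrative.
    dxo = create_analytic_dxo(tag="train_loss", value=0.25, data_type=AnalyticsDataType.SCALAR, global_step=1)

    # Inside a running FLComponent, the DXO is then fired as an event, for example:
    #     send_analytic_dxo(comp=self, dxo=dxo, fl_ctx=fl_ctx)
    # where fl_ctx comes from the component's engine at run time.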