Skip to content

Commit

Permalink
Update docstrings and clean up (#135)
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored Jan 21, 2022
1 parent 57152d3 commit 0fe6277
Show file tree
Hide file tree
Showing 5 changed files with 77 additions and 126 deletions.
12 changes: 5 additions & 7 deletions nvflare/apis/utils/local_logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,23 +23,21 @@ class LocalLogger:

@staticmethod
def initialize():
    """Snapshots the handlers currently attached to the root logger.

    Copies ``logging.root``'s handlers into ``LocalLogger.handlers`` so the
    original handler set is preserved before a log-forwarding handler
    (e.g. LogSender) is added to the root logger. Does nothing if a
    snapshot has already been taken (``LocalLogger.handlers`` is non-empty).
    """
    if not LocalLogger.handlers:
        LocalLogger.handlers = []
        # Snapshot whatever is on the root logger right now.
        for handler in logging.root.handlers:
            LocalLogger.handlers.append(handler)

@staticmethod
def get_logger(name=None) -> logging.Logger:
""" Get a logger only do the local logging.
"""Gets a logger only do the local logging.
Args:
name: logger name
Returns: local_logger
Returns:
A local logger.
"""
with LocalLogger.lock:
if not LocalLogger.handlers:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,53 @@
import logging
import os
from logging import LogRecord
from threading import Lock
from typing import List, Optional

from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType
from nvflare.apis.dxo import from_shareable
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_constant import LogMessageTag
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.app_event_type import AppEventType
from nvflare.app_common.widgets.streaming import AnalyticsReceiver
from nvflare.app_common.widgets.streaming import AnalyticsReceiver, create_analytic_dxo, send_analytic_dxo
from nvflare.widgets.widget import Widget


class LogAnalyticsSender(Widget, logging.StreamHandler):
    def __init__(self, log_level=""):
        """A widget that is also a logging handler: forwards log records as analytics DXOs.

        It attaches itself to the root logger when a run is about to start and
        detaches at run end; every record at or above the configured threshold
        is wrapped in a DXO and fired as a logging analytics event.

        Args:
            log_level: name of the logging level threshold (e.g. "INFO").
                An empty or unknown name falls back to logging.INFO.
        """
        Widget.__init__(self)
        logging.StreamHandler.__init__(self)
        # getattr falls back to INFO when log_level is "" or not a valid level name.
        self.log_level = getattr(logging, log_level, logging.INFO)
        self.engine = None

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Attaches this handler to the root logger at run start; detaches it at run end."""
        if event_type == EventType.ABOUT_TO_START_RUN:
            self.engine = fl_ctx.get_engine()
            logging.root.addHandler(self)
        elif event_type == EventType.END_RUN:
            logging.root.removeHandler(self)

    def emit(self, record: LogRecord):
        """Sends the log record.

        When the record's level is at or above the configured threshold (and an
        engine is available), wraps the record in a DXO and sends it with the
        LOGGING_EVENT_TYPE analytics event.

        Args:
            record: logging record
        """
        if record.levelno >= self.log_level and self.engine:
            dxo = create_analytic_dxo(
                tag=LogMessageTag.LOG_RECORD, value=record, data_type=AnalyticsDataType.LOG_RECORD
            )
            with self.engine.new_context() as fl_ctx:
                send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=AppEventType.LOGGING_EVENT_TYPE)
            self.flush()


class LogAnalyticsReceiver(AnalyticsReceiver):
Expand All @@ -40,6 +78,7 @@ def __init__(self, events: Optional[List[str]] = None, formatter=None):
self.formatter = formatter

self.handlers = {}
self.handlers_lock = Lock()

def initialize(self, fl_ctx: FLContext):
workspace = fl_ctx.get_engine().get_workspace()
Expand All @@ -54,9 +93,11 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin: str):
data_type = analytic_data.data_type

if data_type == AnalyticsDataType.LOG_RECORD:
handler = self.handlers.get(record_origin)
if not handler:
handler = self._create_log_handler(record_origin)
with self.handlers_lock:
handler = self.handlers.get(record_origin)
if not handler:
handler = self._create_log_handler(record_origin)
self.handlers[record_origin] = handler

record: LogRecord = dxo.data.get(LogMessageTag.LOG_RECORD)
handler.emit(record)
Expand All @@ -67,7 +108,6 @@ def _create_log_handler(self, record_origin):
filename = os.path.join(self.root_log_dir, record_origin + ".log")
handler = logging.FileHandler(filename)
handler.setFormatter(logging.Formatter(self.formatter))
self.handlers[record_origin] = handler
return handler

def finalize(self, fl_ctx: FLContext):
Expand Down
109 changes: 10 additions & 99 deletions nvflare/app_common/widgets/streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,15 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from abc import ABC, abstractmethod
from logging import LogRecord
from threading import Lock
from typing import List, Optional

from nvflare.apis.analytix import AnalyticsData, AnalyticsDataType
from nvflare.apis.dxo import DXO
from nvflare.apis.event_type import EventType
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_constant import EventScope, FLContextKey, LogMessageTag, ReservedKey
from nvflare.apis.fl_constant import EventScope, FLContextKey, ReservedKey
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.app_common.app_event_type import AppEventType
Expand Down Expand Up @@ -50,14 +48,14 @@ def send_analytic_dxo(
comp.fire_event(event_type=event_type, fl_ctx=fl_ctx)


def _write(tag: str, value, data_type: AnalyticsDataType, kwargs: Optional[dict] = None) -> DXO:
"""Writes the analytic data.
def create_analytic_dxo(tag: str, value, data_type: AnalyticsDataType, **kwargs) -> DXO:
"""Creates the analytic DXO.
Args:
tag (str): the tag associated with this value.
value: the analytic data.
data_type (AnalyticsDataType): analytic data type.
kwargs (dict): additional arguments to be passed into the receiver side's function.
kwargs: additional arguments to be passed into the receiver side's function.
Returns:
A DXO object that contains the analytic data.
Expand All @@ -67,51 +65,11 @@ def _write(tag: str, value, data_type: AnalyticsDataType, kwargs: Optional[dict]
return dxo


def write_scalar(tag: str, scalar: float, **kwargs) -> DXO:
    """Writes a scalar.

    Args:
        tag (str): the tag associated with this value.
        scalar (float): a scalar to write.
        **kwargs: additional arguments passed through to the receiver side.

    Returns:
        A DXO object that contains the scalar analytic data.
    """
    return _write(tag, scalar, data_type=AnalyticsDataType.SCALAR, kwargs=kwargs)


def write_scalars(tag: str, tag_scalar_dict: dict, **kwargs) -> DXO:
    """Writes scalars.

    Args:
        tag (str): the tag associated with this dict.
        tag_scalar_dict (dict): A dictionary that contains tag and scalars to write.
        **kwargs: additional arguments passed through to the receiver side.

    Returns:
        A DXO object that contains the scalars analytic data.
    """
    return _write(tag, tag_scalar_dict, data_type=AnalyticsDataType.SCALARS, kwargs=kwargs)


def write_image(tag: str, image, **kwargs) -> DXO:
    """Writes an image.

    Args:
        tag (str): the tag associated with this value.
        image: the image to write.
        **kwargs: additional arguments passed through to the receiver side.

    Returns:
        A DXO object that contains the image analytic data.
    """
    return _write(tag, image, data_type=AnalyticsDataType.IMAGE, kwargs=kwargs)


def write_text(tag: str, text: str, **kwargs) -> DXO:
    """Writes text.

    Args:
        tag (str): the tag associated with this value.
        text (str): the text to write.
        **kwargs: additional arguments passed through to the receiver side.

    Returns:
        A DXO object that contains the text analytic data.
    """
    return _write(tag, text, data_type=AnalyticsDataType.TEXT, kwargs=kwargs)


class AnalyticsSender(Widget):
def __init__(self):
"""Sends analytics data.
This class implements some common methods follows signatures from PyTorch SummaryWriter and Python logger.
This class implements some common methods follows signatures from PyTorch SummaryWriter.
It provides a convenient way for Learner to use.
"""
super().__init__()
Expand All @@ -134,7 +92,7 @@ def _add(
if not isinstance(global_step, int):
raise TypeError(f"Expect global step to be an instance of int, but got {type(global_step)}")
kwargs["global_step"] = global_step
dxo = _write(tag=tag, value=value, data_type=data_type, kwargs=kwargs)
dxo = create_analytic_dxo(tag=tag, value=value, data_type=data_type, kwargs=kwargs)
with self.engine.new_context() as fl_ctx:
send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx)

Expand Down Expand Up @@ -182,21 +140,6 @@ def add_image(self, tag: str, image, global_step: Optional[int] = None, **kwargs
"""
self._add(tag=tag, value=image, data_type=AnalyticsDataType.IMAGE, global_step=global_step, kwargs=kwargs)

def _log(self, tag: LogMessageTag, msg: str, event_type: str, *args, **kwargs):
    """Logs a message by sending it as a TEXT analytics DXO.

    Args:
        tag (LogMessageTag): A tag that contains the level of the log message.
        msg (str): Message to log.
        event_type (str): Event type that associated with this message.
        *args: From python logger api, args is used to format strings.
        **kwargs: Additional arguments to be passed into the log function.
    """
    # NOTE(review): formatting uses str.format (not logging's %-style), and
    # kwargs are consumed twice — here for formatting AND forwarded to the
    # receiver via _write. Confirm this double use is intentional.
    msg = msg.format(*args, **kwargs)
    dxo = _write(tag=str(tag), value=msg, data_type=AnalyticsDataType.TEXT, kwargs=kwargs)
    with self.engine.new_context() as fl_ctx:
        send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=event_type)

def flush(self):
"""Flushes out the message.
Expand Down Expand Up @@ -256,7 +199,10 @@ def finalize(self, fl_ctx: FLContext):
def handle_event(self, event_type: str, fl_ctx: FLContext):
if event_type == EventType.START_RUN:
self.initialize(fl_ctx)
elif event_type in self.events and not self._end:
elif event_type in self.events:
if self._end:
self.log_debug(fl_ctx, f"Already received end run event, drop event {event_type}.", fire_event=False)
return
data = fl_ctx.get_prop(FLContextKey.EVENT_DATA, None)
if data is None:
self.log_error(fl_ctx, "Missing event data.", fire_event=False)
Expand All @@ -281,38 +227,3 @@ def handle_event(self, event_type: str, fl_ctx: FLContext):
elif event_type == EventType.END_RUN:
self._end = True
self.finalize(fl_ctx)


class LogSender(Widget, logging.StreamHandler):
    def __init__(self, log_level=""):
        """LogSender for sending the logging record to the FL server.

        A widget that is also a logging handler: it attaches itself to the
        root logger when a run is about to start and detaches at run end.

        Args:
            log_level: name of the logging level threshold (e.g. "INFO").
                An empty or unknown name falls back to logging.INFO.
        """
        Widget.__init__(self)
        logging.StreamHandler.__init__(self)
        # getattr falls back to INFO when log_level is "" or not a valid level name.
        self.log_level = getattr(logging, log_level, logging.INFO)
        self.engine = None

    def handle_event(self, event_type: str, fl_ctx: FLContext):
        """Attaches this handler to the root logger at run start; detaches it at run end."""
        if event_type == EventType.ABOUT_TO_START_RUN:
            self.engine = fl_ctx.get_engine()
            logging.root.addHandler(self)
        if event_type == EventType.END_RUN:
            logging.root.removeHandler(self)

    def emit(self, record: LogRecord) -> None:
        """Sends the log record to the FL server to collect.

        Only records whose level is at or above the configured threshold are
        sent, and only while an engine is available (i.e. during a run).

        Args:
            record: logging record
        """
        if record.levelno >= self.log_level and self.engine:
            # NOTE(review): _write is documented to take tag: str but receives
            # LogMessageTag.LOG_RECORD here — confirm it is (or coerces to) str.
            dxo = _write(tag=LogMessageTag.LOG_RECORD, value=record, data_type=AnalyticsDataType.LOG_RECORD)
            with self.engine.new_context() as fl_ctx:
                send_analytic_dxo(self, dxo=dxo, fl_ctx=fl_ctx, event_type=AppEventType.LOGGING_EVENT_TYPE)
            self.flush()
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import random
import time

Expand All @@ -21,7 +20,8 @@
from nvflare.apis.fl_context import FLContext
from nvflare.apis.shareable import Shareable
from nvflare.apis.signal import Signal
from nvflare.app_common.widgets.streaming import send_analytic_dxo, write_scalar, write_text
from nvflare.app_common.widgets.streaming import create_analytic_dxo, send_analytic_dxo
from nvflare.apis.analytix import AnalyticsDataType


class CustomExecutor(Executor):
Expand All @@ -46,9 +46,10 @@ def execute(
number = random.random()

# send analytics
dxo = write_scalar("random_number", number, global_step=r)
dxo = create_analytic_dxo(tag="random_number", value=number, data_type=AnalyticsDataType.SCALAR, global_step=r)
send_analytic_dxo(comp=self, dxo=dxo, fl_ctx=fl_ctx)
dxo = write_text("debug_msg", "Hello world", global_step=r)
dxo = create_analytic_dxo(tag="debug_msg", value="Hello world", data_type=AnalyticsDataType.TEXT,
global_step=r)
send_analytic_dxo(comp=self, dxo=dxo, fl_ctx=fl_ctx)
time.sleep(2.0)

Expand Down
23 changes: 12 additions & 11 deletions test/test_streaming.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,11 @@

import pytest

from nvflare.apis.analytix import AnalyticsDataType
from nvflare.apis.dxo import DXO, DataKind
from nvflare.apis.fl_component import FLComponent
from nvflare.apis.fl_context import FLContext
from nvflare.app_common.widgets.streaming import send_analytic_dxo, write_image, write_scalar, write_scalars, write_text
from nvflare.app_common.widgets.streaming import send_analytic_dxo, create_analytic_dxo

INVALID_TEST_CASES = [
(list(), dict(), FLContext(), TypeError, f"expect comp to be an instance of FLComponent, but got {type(list())}"),
Expand All @@ -32,13 +33,13 @@
]

# Invalid (tag, value, data_type) combinations and the TypeError message
# create_analytic_dxo is expected to raise for each.
# (Defect fixed: this list previously mixed the pre-change write_* style
# tuples with the post-change tuples — a diff/merge artifact; only the
# coherent post-change form is kept.)
INVALID_WRITE_TEST_CASES = [
    (list(), 1.0, AnalyticsDataType.SCALAR, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
    ("tag", list(), AnalyticsDataType.SCALAR, TypeError, f"expect value to be an instance of float, but got {type(list())}"),
    (list(), 1.0, AnalyticsDataType.SCALARS, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
    ("tag", 1.0, AnalyticsDataType.SCALARS, TypeError, f"expect value to be an instance of dict, but got {type(1.0)}"),
    (list(), 1.0, AnalyticsDataType.TEXT, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
    ("tag", 1.0, AnalyticsDataType.TEXT, TypeError, f"expect value to be an instance of str, but got {type(1.0)}"),
    (list(), 1.0, AnalyticsDataType.IMAGE, TypeError, f"expect tag to be an instance of str, but got {type(list())}"),
]


Expand All @@ -48,7 +49,7 @@ def test_invalid_send_analytic_dxo(self, comp, dxo, fl_ctx, expected_error, expe
with pytest.raises(expected_error, match=expected_msg):
send_analytic_dxo(comp=comp, dxo=dxo, fl_ctx=fl_ctx)

@pytest.mark.parametrize("func,tag,value,expected_error,expected_msg", INVALID_WRITE_TEST_CASES)
def test_invalid_write_func(self, func, tag, value, expected_error, expected_msg):
@pytest.mark.parametrize("tag,value,data_type,expected_error,expected_msg", INVALID_WRITE_TEST_CASES)
def test_invalid_write_func(self, tag, value, data_type, expected_error, expected_msg):
with pytest.raises(expected_error, match=expected_msg):
func(tag, value)
create_analytic_dxo(tag, value, data_type)

0 comments on commit 0fe6277

Please sign in to comment.