From 9abd390fe0d04bc1ace40aac8db4aba04cf360bb Mon Sep 17 00:00:00 2001 From: YuanTingHsieh Date: Fri, 2 Feb 2024 10:48:46 -0800 Subject: [PATCH] Simplify code --- nvflare/app_opt/tracking/tb/tb_receiver.py | 80 ++++++++-------------- 1 file changed, 29 insertions(+), 51 deletions(-) diff --git a/nvflare/app_opt/tracking/tb/tb_receiver.py b/nvflare/app_opt/tracking/tb/tb_receiver.py index 7eecb86eee..585087dadc 100644 --- a/nvflare/app_opt/tracking/tb/tb_receiver.py +++ b/nvflare/app_opt/tracking/tb/tb_receiver.py @@ -33,6 +33,17 @@ } +def _create_new_data(key, value, sender): + if isinstance(value, (int, float)): + data_type = AnalyticsDataType.SCALAR + elif isinstance(value, str): + data_type = AnalyticsDataType.TEXT + else: + return None + + return AnalyticsData(key=key, value=value, data_type=data_type, sender=sender) + + class TBAnalyticsReceiver(AnalyticsReceiver): def __init__(self, tb_folder="tb_events", events: Optional[List[str]] = None): """Receives analytics data to save to TensorBoard. @@ -69,53 +80,26 @@ def initialize(self, fl_ctx: FLContext): os.makedirs(root_log_dir, exist_ok=True) self.root_log_dir = root_log_dir - def _convert_params_to_records(self, analytic_data: AnalyticsData) -> List[AnalyticsData]: + def _convert_to_records(self, analytic_data: AnalyticsData, fl_ctx: FLContext) -> List[AnalyticsData]: # break dict of stuff to smaller items to support # AnalyticsDataType.PARAMETER and AnalyticsDataType.PARAMETERS - if analytic_data.data_type == AnalyticsDataType.PARAMETER: - if isinstance(analytic_data.value, int) or isinstance(analytic_data.value, float): - return [ - AnalyticsData( - key=analytic_data.tag, - value=float(analytic_data.value), - data_type=AnalyticsDataType.SCALAR, - sender=analytic_data.sender, - ) - ] - elif isinstance(analytic_data.value, str): - return [ - AnalyticsData( - key=analytic_data.tag, - value=analytic_data.value, - data_type=AnalyticsDataType.TEXT, - sender=analytic_data.sender, - ) - ] - else: - return [analytic_data] - elif analytic_data.data_type == AnalyticsDataType.PARAMETERS: - records = [] - for k, v in analytic_data.value.items(): - if isinstance(v, int) or isinstance(v, float): - new_data = AnalyticsData( - key=analytic_data.tag, - value={k: float(v)}, - data_type=AnalyticsDataType.SCALARS, - sender=analytic_data.sender, - ) - elif isinstance(v, str): - new_data = AnalyticsData( - key=analytic_data.tag, value=v, data_type=AnalyticsDataType.TEXT, sender=analytic_data.sender - ) + records = [] + + if analytic_data.data_type in (AnalyticsDataType.PARAMETER, AnalyticsDataType.PARAMETERS): + for k, v in ( + analytic_data.value.items() + if analytic_data.data_type == AnalyticsDataType.PARAMETERS + else [(analytic_data.tag, analytic_data.value)] + ): + new_data = _create_new_data(k, v, analytic_data.sender) + if new_data is None: + self.log_warning(fl_ctx, f"Entry {k} of type {type(v)} is not supported.", fire_event=False) else: - new_data = AnalyticsData( - key=analytic_data.tag, - value={k: v}, - data_type=AnalyticsDataType.PARAMETERS, - sender=analytic_data.sender, - ) - records.append(new_data) - return records + records.append(new_data) + else: + records.append(analytic_data) + + return records def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin): dxo = from_shareable(shareable) @@ -135,13 +119,7 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin): f"try to save data {analytic_data} from {record_origin}", fire_event=False, ) - if ( - analytic_data.data_type == AnalyticsDataType.PARAMETER - or analytic_data.data_type == AnalyticsDataType.PARAMETERS - ): - data_records = self._convert_params_to_records(analytic_data) - else: - data_records = [analytic_data] + data_records = self._convert_to_records(analytic_data, fl_ctx) for data_record in data_records: func_name = FUNCTION_MAPPING.get(data_record.data_type, None)