diff --git a/nvflare/app_opt/tracking/tb/tb_receiver.py b/nvflare/app_opt/tracking/tb/tb_receiver.py index aad8133297..7eecb86eee 100644 --- a/nvflare/app_opt/tracking/tb/tb_receiver.py +++ b/nvflare/app_opt/tracking/tb/tb_receiver.py @@ -69,6 +69,54 @@ def initialize(self, fl_ctx: FLContext): os.makedirs(root_log_dir, exist_ok=True) self.root_log_dir = root_log_dir + def _convert_params_to_records(self, analytic_data: AnalyticsData) -> List[AnalyticsData]: + # break dict of stuff to smaller items to support + # AnalyticsDataType.PARAMETER and AnalyticsDataType.PARAMETERS + if analytic_data.data_type == AnalyticsDataType.PARAMETER: + if isinstance(analytic_data.value, int) or isinstance(analytic_data.value, float): + return [ + AnalyticsData( + key=analytic_data.tag, + value=float(analytic_data.value), + data_type=AnalyticsDataType.SCALAR, + sender=analytic_data.sender, + ) + ] + elif isinstance(analytic_data.value, str): + return [ + AnalyticsData( + key=analytic_data.tag, + value=analytic_data.value, + data_type=AnalyticsDataType.TEXT, + sender=analytic_data.sender, + ) + ] + else: + return [analytic_data] + elif analytic_data.data_type == AnalyticsDataType.PARAMETERS: + records = [] + for k, v in analytic_data.value.items(): + if isinstance(v, int) or isinstance(v, float): + new_data = AnalyticsData( + key=analytic_data.tag, + value={k: float(v)}, + data_type=AnalyticsDataType.SCALARS, + sender=analytic_data.sender, + ) + elif isinstance(v, str): + new_data = AnalyticsData( + key=analytic_data.tag, value=v, data_type=AnalyticsDataType.TEXT, sender=analytic_data.sender + ) + else: + new_data = AnalyticsData( + key=analytic_data.tag, + value={k: v}, + data_type=AnalyticsDataType.PARAMETERS, + sender=analytic_data.sender, + ) + records.append(new_data) + return records + def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin): dxo = from_shareable(shareable) analytic_data = AnalyticsData.from_dxo(dxo) @@ -84,19 +132,28 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin): # do different things depending on the type in dxo self.log_debug( fl_ctx, - f"save data {analytic_data} from {record_origin}", + f"try to save data {analytic_data} from {record_origin}", fire_event=False, ) - func_name = FUNCTION_MAPPING.get(analytic_data.data_type, None) - if func_name is None: - self.log_warning(fl_ctx, f"The data_type {analytic_data.data_type} is not supported.", fire_event=False) - return - - func = getattr(writer, func_name) - if analytic_data.step: - func(analytic_data.tag, analytic_data.value, analytic_data.step) + if ( + analytic_data.data_type == AnalyticsDataType.PARAMETER + or analytic_data.data_type == AnalyticsDataType.PARAMETERS + ): + data_records = self._convert_params_to_records(analytic_data) else: - func(analytic_data.tag, analytic_data.value) + data_records = [analytic_data] + + for data_record in data_records: + func_name = FUNCTION_MAPPING.get(data_record.data_type, None) + if func_name is None: + self.log_warning(fl_ctx, f"The data_type {data_record.data_type} is not supported.", fire_event=False) + return + + func = getattr(writer, func_name) + if data_record.step: + func(data_record.tag, data_record.value, data_record.step) + else: + func(data_record.tag, data_record.value) def finalize(self, fl_ctx: FLContext): for writer in self.writers_table.values():