Skip to content

Commit

Permalink
Simplify code
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored and IsaacYangSLA committed Feb 2, 2024
1 parent 4fcf2ca commit 9abd390
Showing 1 changed file with 29 additions and 51 deletions.
80 changes: 29 additions & 51 deletions nvflare/app_opt/tracking/tb/tb_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,17 @@
}


def _create_new_data(key, value, sender):
if isinstance(value, (int, float)):
data_type = AnalyticsDataType.SCALAR
elif isinstance(value, str):
data_type = AnalyticsDataType.TEXT
else:
return None

return AnalyticsData(key=key, value=value, data_type=data_type, sender=sender)


class TBAnalyticsReceiver(AnalyticsReceiver):
def __init__(self, tb_folder="tb_events", events: Optional[List[str]] = None):
"""Receives analytics data to save to TensorBoard.
Expand Down Expand Up @@ -69,53 +80,26 @@ def initialize(self, fl_ctx: FLContext):
os.makedirs(root_log_dir, exist_ok=True)
self.root_log_dir = root_log_dir

def _convert_params_to_records(self, analytic_data: AnalyticsData) -> List[AnalyticsData]:
def _convert_to_records(self, analytic_data: AnalyticsData, fl_ctx: FLContext) -> List[AnalyticsData]:
# break dict of stuff to smaller items to support
# AnalyticsDataType.PARAMETER and AnalyticsDataType.PARAMETERS
if analytic_data.data_type == AnalyticsDataType.PARAMETER:
if isinstance(analytic_data.value, int) or isinstance(analytic_data.value, float):
return [
AnalyticsData(
key=analytic_data.tag,
value=float(analytic_data.value),
data_type=AnalyticsDataType.SCALAR,
sender=analytic_data.sender,
)
]
elif isinstance(analytic_data.value, str):
return [
AnalyticsData(
key=analytic_data.tag,
value=analytic_data.value,
data_type=AnalyticsDataType.TEXT,
sender=analytic_data.sender,
)
]
else:
return [analytic_data]
elif analytic_data.data_type == AnalyticsDataType.PARAMETERS:
records = []
for k, v in analytic_data.value.items():
if isinstance(v, int) or isinstance(v, float):
new_data = AnalyticsData(
key=analytic_data.tag,
value={k: float(v)},
data_type=AnalyticsDataType.SCALARS,
sender=analytic_data.sender,
)
elif isinstance(v, str):
new_data = AnalyticsData(
key=analytic_data.tag, value=v, data_type=AnalyticsDataType.TEXT, sender=analytic_data.sender
)
records = []

if analytic_data.data_type in (AnalyticsDataType.PARAMETER, AnalyticsDataType.PARAMETERS):
for k, v in (
analytic_data.value.items()
if analytic_data.data_type == AnalyticsDataType.PARAMETERS
else [(analytic_data.tag, analytic_data.value)]
):
new_data = _create_new_data(k, v, analytic_data.sender)
if new_data is None:
self.log_warning(fl_ctx, f"Entry {k} of type {type(v)} is not supported.", fire_event=False)
else:
new_data = AnalyticsData(
key=analytic_data.tag,
value={k: v},
data_type=AnalyticsDataType.PARAMETERS,
sender=analytic_data.sender,
)
records.append(new_data)
return records
records.append(new_data)
else:
records.append(analytic_data)

return records

def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin):
dxo = from_shareable(shareable)
Expand All @@ -135,13 +119,7 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin):
f"try to save data {analytic_data} from {record_origin}",
fire_event=False,
)
if (
analytic_data.data_type == AnalyticsDataType.PARAMETER
or analytic_data.data_type == AnalyticsDataType.PARAMETERS
):
data_records = self._convert_params_to_records(analytic_data)
else:
data_records = [analytic_data]
data_records = self._convert_to_records(analytic_data, fl_ctx)

for data_record in data_records:
func_name = FUNCTION_MAPPING.get(data_record.data_type, None)
Expand Down

0 comments on commit 9abd390

Please sign in to comment.