Skip to content

Commit

Permalink
Add support of param and params with str and float
Browse files Browse the repository at this point in the history
  • Loading branch information
YuanTingHsieh authored and IsaacYangSLA committed Feb 2, 2024
1 parent 2a8e805 commit 4fcf2ca
Showing 1 changed file with 67 additions and 10 deletions.
77 changes: 67 additions & 10 deletions nvflare/app_opt/tracking/tb/tb_receiver.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,54 @@ def initialize(self, fl_ctx: FLContext):
os.makedirs(root_log_dir, exist_ok=True)
self.root_log_dir = root_log_dir

def _convert_params_to_records(self, analytic_data: AnalyticsData) -> List[AnalyticsData]:
# break dict of stuff to smaller items to support
# AnalyticsDataType.PARAMETER and AnalyticsDataType.PARAMETERS
if analytic_data.data_type == AnalyticsDataType.PARAMETER:
if isinstance(analytic_data.value, int) or isinstance(analytic_data.value, float):
return [
AnalyticsData(
key=analytic_data.tag,
value=float(analytic_data.value),
data_type=AnalyticsDataType.SCALAR,
sender=analytic_data.sender,
)
]
elif isinstance(analytic_data.value, str):
return [
AnalyticsData(
key=analytic_data.tag,
value=analytic_data.value,
data_type=AnalyticsDataType.TEXT,
sender=analytic_data.sender,
)
]
else:
return [analytic_data]
elif analytic_data.data_type == AnalyticsDataType.PARAMETERS:
records = []
for k, v in analytic_data.value.items():
if isinstance(v, int) or isinstance(v, float):
new_data = AnalyticsData(
key=analytic_data.tag,
value={k: float(v)},
data_type=AnalyticsDataType.SCALARS,
sender=analytic_data.sender,
)
elif isinstance(v, str):
new_data = AnalyticsData(
key=analytic_data.tag, value=v, data_type=AnalyticsDataType.TEXT, sender=analytic_data.sender
)
else:
new_data = AnalyticsData(
key=analytic_data.tag,
value={k: v},
data_type=AnalyticsDataType.PARAMETERS,
sender=analytic_data.sender,
)
records.append(new_data)
return records

def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin):
dxo = from_shareable(shareable)
analytic_data = AnalyticsData.from_dxo(dxo)
Expand All @@ -84,19 +132,28 @@ def save(self, fl_ctx: FLContext, shareable: Shareable, record_origin):
# do different things depending on the type in dxo
self.log_debug(
fl_ctx,
f"save data {analytic_data} from {record_origin}",
f"try to save data {analytic_data} from {record_origin}",
fire_event=False,
)
func_name = FUNCTION_MAPPING.get(analytic_data.data_type, None)
if func_name is None:
self.log_warning(fl_ctx, f"The data_type {analytic_data.data_type} is not supported.", fire_event=False)
return

func = getattr(writer, func_name)
if analytic_data.step:
func(analytic_data.tag, analytic_data.value, analytic_data.step)
if (
analytic_data.data_type == AnalyticsDataType.PARAMETER
or analytic_data.data_type == AnalyticsDataType.PARAMETERS
):
data_records = self._convert_params_to_records(analytic_data)
else:
func(analytic_data.tag, analytic_data.value)
data_records = [analytic_data]

for data_record in data_records:
func_name = FUNCTION_MAPPING.get(data_record.data_type, None)
if func_name is None:
self.log_warning(fl_ctx, f"The data_type {data_record.data_type} is not supported.", fire_event=False)
return

func = getattr(writer, func_name)
if data_record.step:
func(data_record.tag, data_record.value, data_record.step)
else:
func(data_record.tag, data_record.value)

def finalize(self, fl_ctx: FLContext):
for writer in self.writers_table.values():
Expand Down

0 comments on commit 4fcf2ca

Please sign in to comment.