Skip to content

Commit

Permalink
Fix full model shareable generator (#196)
Browse files Browse the repository at this point in the history
* Fix full model shareable generator

* Pring data_kind value in error message
  • Loading branch information
YuanTingHsieh authored Feb 11, 2022
1 parent f6df3ef commit b7e98b7
Showing 1 changed file with 19 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,38 @@


class FullModelShareableGenerator(ShareableGenerator):
def learnable_to_shareable(self, ml: ModelLearnable, fl_ctx: FLContext) -> Shareable:
"""Convert Learnable to Shareable.
def learnable_to_shareable(self, model_learnable: ModelLearnable, fl_ctx: FLContext) -> Shareable:
"""Convert ModelLearnable to Shareable.
Args:
model (Learnable): model to be converted
model_learnable (ModelLearnable): model to be converted
fl_ctx (FLContext): FL context
Returns:
Shareable: a shareable containing a DXO object,
Shareable: a shareable containing a DXO object.
"""
dxo = model_learnable_to_dxo(ml)
dxo = model_learnable_to_dxo(model_learnable)
return dxo.to_shareable()

def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> ModelLearnable:
"""Convert Shareable to Learnable.
"""Convert Shareable to ModelLearnable.
Supporting TYPE == TYPE_WEIGHT_DIFF or TYPE_WEIGHTS
Args:
shareable (Shareable): Shareable that contains a DXO object
fl_ctx (FLContext): FL context
Returns: a ModelLearnable object
Returns:
A ModelLearnable object
Raises:
TypeError: if shareable is not of type shareable
ValueError: if data_kind is not `DataKind.WEIGHTS` and is not `DataKind.WEIGHT_DIFF`
"""
if not isinstance(shareable, Shareable):
raise TypeError("shareable must be Shareable, but got {}.".format(type(shareable)))

base_model = fl_ctx.get_prop(AppConstants.GLOBAL_MODEL)
if not base_model:
self.system_panic(reason="No global base model!", fl_ctx=fl_ctx)
Expand All @@ -64,6 +72,10 @@ def shareable_to_learnable(self, shareable: Shareable, fl_ctx: FLContext) -> Mod
self.log_info(fl_ctx, "No model weights found. Model will not be updated.")
else:
base_model[ModelLearnableKey.WEIGHTS] = weights
else:
raise ValueError(
"data_kind should be either DataKind.WEIGHTS or DataKind.WEIGHT_DIFF, but got {}".format(dxo.data_kind)
)

base_model[ModelLearnableKey.META] = dxo.get_meta_props()
return base_model

0 comments on commit b7e98b7

Please sign in to comment.