Skip to content

Commit

Permalink
fix hello_world tf result printing (#2910)
Browse files Browse the repository at this point in the history
  • Loading branch information
SYangster authored Sep 5, 2024
1 parent 81b57ca commit 910179c
Showing 1 changed file with 36 additions and 16 deletions.
52 changes: 36 additions & 16 deletions examples/hello-world/hello_world.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -559,21 +559,30 @@
},
"outputs": [],
"source": [
"from nvflare.fuel.utils import fobs\n",
"from nvflare.app_common.decomposers import common_decomposers\n",
"import pprint\n",
"\n",
"# This example stores numpy arrays in FOBS format. Decomposers for Numpy is not registered automatically.\n",
"common_decomposers.register()\n",
"from nvflare.app_opt.tf.utils import flat_layer_weights_dict\n",
"from tensorflow.keras import layers, models\n",
"\n",
"result = sess.download_job_result(job_id)\n",
"with open(result + \"/workspace/app_server/tf_model.weights.h5\", \"rb\") as f:\n",
" bytes = f.read()\n",
"\n",
"weights = fobs.loads(bytes)\n",
"class Net(models.Sequential):\n",
" def __init__(self, input_shape=(None, 28, 28)):\n",
" super().__init__()\n",
" self._input_shape = input_shape\n",
" self.add(layers.Flatten())\n",
" self.add(layers.Dense(128, activation=\"relu\"))\n",
" self.add(layers.Dropout(0.2))\n",
" self.add(layers.Dense(10))\n",
"\n",
"model = Net()\n",
"model.build(input_shape=(None, 28, 28))\n",
"model.load_weights(result + \"/workspace/app_server/tf_model.weights.h5\")\n",
"model.summary()\n",
"\n",
"layer_weights_dict = flat_layer_weights_dict({layer.name: layer.get_weights() for layer in model.layers})\n",
"\n",
"pp = pprint.PrettyPrinter(indent=4)\n",
"pp.pprint(weights)"
"pp.pprint(layer_weights_dict)"
]
},
{
Expand Down Expand Up @@ -872,19 +881,30 @@
},
"outputs": [],
"source": [
"from nvflare.fuel.utils import fobs\n",
"from nvflare.app_common.decomposers import common_decomposers\n",
"import pprint\n",
"from nvflare.app_opt.tf.utils import flat_layer_weights_dict\n",
"from tensorflow.keras import layers, models\n",
"\n",
"common_decomposers.register()\n",
"result = sess.download_job_result(job_id)\n",
"with open(result + \"/workspace/app_server/tf_model.weights.h5\", \"rb\") as f:\n",
" bytes = f.read()\n",
"\n",
"weights = fobs.loads(bytes)\n",
"class Net(models.Sequential):\n",
" def __init__(self, input_shape=(None, 28, 28)):\n",
" super().__init__()\n",
" self._input_shape = input_shape\n",
" self.add(layers.Flatten())\n",
" self.add(layers.Dense(128, activation=\"relu\"))\n",
" self.add(layers.Dropout(0.2))\n",
" self.add(layers.Dense(10))\n",
"\n",
"model = Net()\n",
"model.build(input_shape=(None, 28, 28))\n",
"model.load_weights(result + \"/workspace/app_server/tf_model.weights.h5\")\n",
"model.summary()\n",
"\n",
"layer_weights_dict = flat_layer_weights_dict({layer.name: layer.get_weights() for layer in model.layers})\n",
"\n",
"pp = pprint.PrettyPrinter(indent=4)\n",
"pp.pprint(weights)"
"pp.pprint(layer_weights_dict)"
]
},
{
Expand Down

0 comments on commit 910179c

Please sign in to comment.