From 910179c62c83c51085c1f025b51f27b0d5644aeb Mon Sep 17 00:00:00 2001 From: Sean Yang Date: Wed, 4 Sep 2024 19:19:28 -0700 Subject: [PATCH] fix hello_world tf result printing (#2910) --- examples/hello-world/hello_world.ipynb | 52 ++++++++++++++++++-------- 1 file changed, 36 insertions(+), 16 deletions(-) diff --git a/examples/hello-world/hello_world.ipynb b/examples/hello-world/hello_world.ipynb index 5b0ef3ed04..efcdac1fbd 100644 --- a/examples/hello-world/hello_world.ipynb +++ b/examples/hello-world/hello_world.ipynb @@ -559,21 +559,30 @@ }, "outputs": [], "source": [ - "from nvflare.fuel.utils import fobs\n", - "from nvflare.app_common.decomposers import common_decomposers\n", "import pprint\n", - "\n", - "# This example stores numpy arrays in FOBS format. Decomposers for Numpy is not registered automatically.\n", - "common_decomposers.register()\n", + "from nvflare.app_opt.tf.utils import flat_layer_weights_dict\n", + "from tensorflow.keras import layers, models\n", "\n", "result = sess.download_job_result(job_id)\n", - "with open(result + \"/workspace/app_server/tf_model.weights.h5\", \"rb\") as f:\n", - " bytes = f.read()\n", "\n", - "weights = fobs.loads(bytes)\n", + "class Net(models.Sequential):\n", + " def __init__(self, input_shape=(None, 28, 28)):\n", + " super().__init__()\n", + " self._input_shape = input_shape\n", + " self.add(layers.Flatten())\n", + " self.add(layers.Dense(128, activation=\"relu\"))\n", + " self.add(layers.Dropout(0.2))\n", + " self.add(layers.Dense(10))\n", + "\n", + "model = Net()\n", + "model.build(input_shape=(None, 28, 28))\n", + "model.load_weights(result + \"/workspace/app_server/tf_model.weights.h5\")\n", + "model.summary()\n", + "\n", + "layer_weights_dict = flat_layer_weights_dict({layer.name: layer.get_weights() for layer in model.layers})\n", "\n", "pp = pprint.PrettyPrinter(indent=4)\n", - "pp.pprint(weights)" + "pp.pprint(layer_weights_dict)" ] }, { @@ -872,19 +881,30 @@ }, "outputs": [], "source": [ - "from nvflare.fuel.utils import fobs\n", - "from nvflare.app_common.decomposers import common_decomposers\n", "import pprint\n", + "from nvflare.app_opt.tf.utils import flat_layer_weights_dict\n", + "from tensorflow.keras import layers, models\n", "\n", - "common_decomposers.register()\n", "result = sess.download_job_result(job_id)\n", - "with open(result + \"/workspace/app_server/tf_model.weights.h5\", \"rb\") as f:\n", - " bytes = f.read()\n", "\n", - "weights = fobs.loads(bytes)\n", + "class Net(models.Sequential):\n", + " def __init__(self, input_shape=(None, 28, 28)):\n", + " super().__init__()\n", + " self._input_shape = input_shape\n", + " self.add(layers.Flatten())\n", + " self.add(layers.Dense(128, activation=\"relu\"))\n", + " self.add(layers.Dropout(0.2))\n", + " self.add(layers.Dense(10))\n", + "\n", + "model = Net()\n", + "model.build(input_shape=(None, 28, 28))\n", + "model.load_weights(result + \"/workspace/app_server/tf_model.weights.h5\")\n", + "model.summary()\n", + "\n", + "layer_weights_dict = flat_layer_weights_dict({layer.name: layer.get_weights() for layer in model.layers})\n", "\n", "pp = pprint.PrettyPrinter(indent=4)\n", - "pp.pprint(weights)" + "pp.pprint(layer_weights_dict)" ] }, {