diff --git a/examples/notebooks/ogbn_mag_e2e.ipynb b/examples/notebooks/ogbn_mag_e2e.ipynb index 0b773ad7..c59e1803 100644 --- a/examples/notebooks/ogbn_mag_e2e.ipynb +++ b/examples/notebooks/ogbn_mag_e2e.ipynb @@ -102,14 +102,14 @@ "base_uri": "https://localhost:8080/" }, "id": "oA4_zh0EyNHv", - "outputId": "78cd1fa3-45f7-40f1-895b-d3f5b7dac7da" + "outputId": "4e3e16b7-64dd-4516-99da-8cea252750d8" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running TF-GNN 0.6.0 under TensorFlow 2.12.0.\n" + "Running TF-GNN 1.0.2 under TensorFlow 2.12.0.\n" ] } ], @@ -187,7 +187,7 @@ "source": [ "### Approach\n", "\n", - "OGBN-MAG asks to classify each of the \"paper\" nodes. The number of nodes is on the order of a million, and we intuit that the most informative other nodes are found just a few hops away (cited papers, papers with overlapping authors, etc.).\n", + "OGBN-MAG asks to classify each of the \"paper\" nodes. The number of nodes is on the order of a million, and we intuit that the most informative other nodes are found just a few hops away (cited papers, papers with a common author, etc.).\n", "\n", "Therefore, and to stay scalable for even bigger datasets, we approach this task with **graph sampling**: Each \"paper\" node becomes one training example, expressed by a subgraph that has the node to be classified as its root and stores a sample of its neighborhood in the original graph. The sample is taken by going out a fixed number of steps along specific edge sets, and randomly downsampling the edges in each step if they are too numerous.\n", "\n", @@ -220,7 +220,7 @@ "id": "kFF1w8sGzM6m" }, "source": [ - "We provide the entire OGBN-MAG graph data casted as a TF-GNN graph tensor as input to the graph sampler. 
The command below loads the entire OGBN-MAG as a single graph tensor from the already-saved serialized Tensorflow Example message (subject to this [license](https://storage.googleapis.com/download.tensorflow.org/data/ogbn-mag/npz/LICENSE.txt)). Additionally, it loads the supporting OGBN-MAG graph schema.\n" + "We provide the entire OGBN-MAG graph data cast as a TF-GNN graph tensor as input to the graph sampler. The command below loads the entire OGBN-MAG as a single graph tensor from the already-saved serialized Tensorflow Example message (subject to this [license](https://storage.googleapis.com/download.tensorflow.org/data/ogbn-mag/npz/LICENSE.txt)). Additionally, it loads the supporting OGBN-MAG graph schema.\n" ] }, { @@ -339,8 +339,7 @@ "id": "hZpeVtalnkHc" }, "source": [ - "## Data Split Preparation\n", - "\n", + "### Data Split Preparation\n", "\n", "Under [OGB rules](https://ogb.stanford.edu/docs/leader_rules/), we can sample subgraphs for the training, validation and test dataset from the full graph, just with different seed nodes, selected by the year of publication. We define the `seed_dataset` responsible for providing the seeds for the different splits. 
(Models for production systems should probably use separate validation and test data, to prevent leakage of their seed nodes into the sampled subgraphs of other splits.)" ] @@ -450,14 +449,14 @@ "base_uri": "https://localhost:8080/" }, "id": "2oBuJEZ3izQm", - "outputId": "75f00503-38a9-49ab-ea7a-3b4427a31f1e" + "outputId": "680d981b-9ee6-4ffe-d696-c70110edadca" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "Running on TPU ['10.31.81.194:8470']\n", + "Running on TPU ['10.113.155.218:8470']\n", "Using TPUStrategy\n", "Found 8 replicas in sync\n" ] @@ -609,7 +608,7 @@ " return {\"hashed_id\": tf.keras.layers.Hashing(6_500)(node_set[\"#id\"])}\n", " if node_set_name == \"paper\":\n", " # Keep `labels` for eventual extraction.\n", - " return {\"feat\": node_set[\"feat\"], \"labels\": node_set[\"label\"]}\n", + " return {\"feat\": node_set[\"feat\"], \"label\": node_set[\"label\"]}\n", " if node_set_name == \"author\":\n", " return {\"empty_state\": tfgnn.keras.layers.MakeEmptyFeature()(node_set)}\n", " raise KeyError(f\"Unexpected node_set_name='{node_set_name}'\")\n", @@ -618,12 +617,62 @@ "def drop_all_features(_, **unused_kwargs):\n", " return {}\n", "\n", - "# The combined feature mapping of context, edges and nodes\n", - "# is all the preprocessing we need for this dataset.\n", + "# The combined feature preprocessing of context, edges and nodes.\n", + "process_features = tfgnn.keras.layers.MapFeatures(\n", + " context_fn=drop_all_features,\n", + " node_sets_fn=process_node_features,\n", + " edge_sets_fn=drop_all_features)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "Tju6tg3swpI5" + }, + "source": [ + "### Readout structure and labels\n", + "\n", + "GNNs can be applied to a wide range of problems, and it depends on the problem\n", + "which nodes have the hidden state(s) from which a prediction can be made. 
For node classification on sampled subgraphs, we want to read out the final hidden state from the seed node of each subgraph. By convention, the sampler stores the seed as the first `\"paper\"` node in each subgraph, but recall there are multiple of them in a training batch.\n", + "\n", + "The `AddReadoutFromFirstNode` helper lets us express readout from seeds by adding the following **readout structure** (explained further in the [Schema](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/schema.md#about-labels-and-reading-out-the-final-gnn-states) and [Data Preparation](https://github.com/tensorflow/gnn/blob/main/tensorflow_gnn/docs/guide/data_prep.md#readout) guides):\n", + "\n", + " * a node set `\"_readout\"` with as many nodes as there are sampled subgraphs in the training batch;\n", + " * an edge set `\"_readout/seed\"` that connects the seed of the *i*-th sampled subgraph to the *i*-th readout node.\n", + "\n", + "The GNN model itself ignores auxiliary graph pieces like these, whose names start with an underscore.\n", + "\n", + "The readout structure is also useful for handling labels: Originally provided as node features, the labels need to be read out from the seed nodes as well, and they need to be removed from the node features seen by the model. 
The `StructuredReadoutIntoFeature` helper does just that: it creates a new feature with the read-out labels on the `\"_readout\"` node set (conveniently aligned with the eventual predictions) and optionally deletes the original label feature.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "VQJgy_ypHfBe" + }, + "outputs": [], + "source": [ + "add_readout = tfgnn.keras.layers.AddReadoutFromFirstNode(\n", + " \"seed\", node_set_name=\"paper\")\n", + "move_label_to_readout = tfgnn.keras.layers.StructuredReadoutIntoFeature(\n", + " \"seed\", feature_name=\"label\", new_feature_name=\"paper_venue\",\n", + " remove_input_feature=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_btPx4aEHe_E" + }, + "outputs": [], + "source": [ + "# The complete list of feature processors.\n", "feature_processors = [\n", - " tfgnn.keras.layers.MapFeatures(context_fn=drop_all_features,\n", - " node_sets_fn=process_node_features,\n", - " edge_sets_fn=drop_all_features),\n", + " process_features,\n", + " add_readout,\n", + " move_label_to_readout,\n", "]" ] }, @@ -635,13 +684,13 @@ "source": [ "## Model Architecture\n", "\n", - "Typically, a model with a GNN architecture at its core consists of three parts:\n", + "Typically, a model with a GNN architecture as its base consists of three parts:\n", "\n", "1. The initialization of hidden states on nodes (and possibly also edges and/or the graph context) from their respective preprocessed features.\n", "2. The base Graph Neural Network: several rounds of updating hidden states from neighboring items in the graph.\n", - "3. The readout of one or more hidden states into some prediction head, such as a linear classifier.\n", + "3. 
A prediction head, such as a linear classifier, applied to the final hidden states read out from the nodes of interest.\n", "\n", - "We are going to use one model for training, validation, and export for inference, so we need to build it from an input type spec with generic tensor shapes. (For TPUs, it's good enough to use it on a *dataset* with fixed-size elements.) Before defining the base Graph Neural Network, we show how to initialize the hidden states of all the necessary components (nodes, edges and context) given the pre-processed features." + "We are going to use one model for training, validation, and export for inference, so we need to build it from an input type spec with generic tensor shapes. (For TPUs, it's good enough if each *dataset* it gets called on has fixed-size elements.) Before defining the base Graph Neural Network, we show how to initialize the hidden states of all the necessary components (nodes, edges and context) given the pre-processed features." ] }, { @@ -754,8 +803,9 @@ "source": [ "### The Task\n", "\n", - "A Task collects the ancillary pieces for training a Keras model\n", - "with the graph learning objective. It also provides losses and metrics for that objective. Common implementations for classification and regression (by graph or root node) are provided in TF-GNN library.\n" + "A `Task` object defines the learning objective for a GNN and the model pieces that are needed around the `model_fn` to adapt it to the task at hand.\n", + "\n", + "The library defines `Task` subclasses for a variety of standard prediction tasks, including the suitable training loss and matching metrics. Here, we use `NodeMulticlassClassification`, because our task is to predict one of many mutually exclusive classes (venues) for the nodes (papers) of interest. The `Node*` tasks rely on the readout structure in the model's input graph to identify the nodes of interest and to provide the labels for them. 
(Recall how the code above set up structured readout from the root nodes of sampled subgraphs and moved their labels there.)" ] }, { @@ -766,11 +816,9 @@ }, "outputs": [], "source": [ - "label_fn = runner.RootNodeLabelFn(node_set_name=\"paper\", feature_name=\"labels\")\n", - "task = runner.RootNodeMulticlassClassification(\n", - " node_set_name=\"paper\",\n", + "task = runner.NodeMulticlassClassification(\n", " num_classes=349,\n", - " label_fn=label_fn)" + " label_feature_name=\"paper_venue\")" ] }, { @@ -816,7 +864,7 @@ " model_dir=\"/tmp/gnn_model/\",\n", " callbacks=None,\n", " steps_per_epoch=steps_per_epoch,\n", - " validation_steps=validation_steps,\n", + " validation_steps=validation_steps, # \u003c\u003c\u003c Remove if not training for real.\n", " restore_best_weights=False,\n", " checkpoint_every_n_steps=\"never\",\n", " summarize_every_n_steps=\"never\",\n", @@ -846,7 +894,8 @@ "outputs": [], "source": [ "save_options = tf.saved_model.SaveOptions(experimental_io_device=\"/job:localhost\")\n", - "model_exporter = runner.KerasModelExporter(options=save_options)" + "model_exporter = runner.KerasModelExporter(output_names=\"paper_venue_logits\",\n", + " options=save_options)" ] }, { @@ -874,7 +923,7 @@ "base_uri": "https://localhost:8080/" }, "id": "Ay2hhL3d0dZz", - "outputId": "38d76c8d-4845-4e3a-c956-6697b24152c9" + "outputId": "6f0d261d-7bea-4b73-ba6d-f5fccbe2b249" }, "outputs": [ { @@ -882,70 +931,61 @@ "output_type": "stream", "text": [ "Epoch 1/10\n", - "4918/4918 [==============================] - 141s 29ms/step - loss: 2.5981 - sparse_categorical_accuracy: 0.3266 - sparse_categorical_crossentropy: 2.7027 - val_loss: 2.1224 - val_sparse_categorical_accuracy: 0.4123 - val_sparse_categorical_crossentropy: 2.1855\n", + "4918/4918 [==============================] - 138s 28ms/step - loss: 2.6087 - sparse_categorical_accuracy: 0.3227 - sparse_categorical_crossentropy: 2.7150 - val_loss: 2.0895 - val_sparse_categorical_accuracy: 0.4181 - 
val_sparse_categorical_crossentropy: 2.1522\n", "Epoch 2/10\n", - "4918/4918 [==============================] - 90s 18ms/step - loss: 2.0927 - sparse_categorical_accuracy: 0.4256 - sparse_categorical_crossentropy: 2.1475 - val_loss: 1.9497 - val_sparse_categorical_accuracy: 0.4474 - val_sparse_categorical_crossentropy: 1.9906\n", + "4918/4918 [==============================] - 90s 18ms/step - loss: 2.1015 - sparse_categorical_accuracy: 0.4214 - sparse_categorical_crossentropy: 2.1580 - val_loss: 1.9503 - val_sparse_categorical_accuracy: 0.4458 - val_sparse_categorical_crossentropy: 1.9923\n", "Epoch 3/10\n", - "4918/4918 [==============================] - 89s 18ms/step - loss: 1.9615 - sparse_categorical_accuracy: 0.4556 - sparse_categorical_crossentropy: 1.9998 - val_loss: 1.9174 - val_sparse_categorical_accuracy: 0.4524 - val_sparse_categorical_crossentropy: 1.9509\n", + "4918/4918 [==============================] - 90s 18ms/step - loss: 1.9663 - sparse_categorical_accuracy: 0.4530 - sparse_categorical_crossentropy: 2.0059 - val_loss: 1.8771 - val_sparse_categorical_accuracy: 0.4692 - val_sparse_categorical_crossentropy: 1.9089\n", "Epoch 4/10\n", - "4918/4918 [==============================] - 89s 18ms/step - loss: 1.8757 - sparse_categorical_accuracy: 0.4744 - sparse_categorical_crossentropy: 1.9056 - val_loss: 1.8417 - val_sparse_categorical_accuracy: 0.4782 - val_sparse_categorical_crossentropy: 1.8690\n", + "4918/4918 [==============================] - 92s 19ms/step - loss: 1.8823 - sparse_categorical_accuracy: 0.4733 - sparse_categorical_crossentropy: 1.9131 - val_loss: 1.8625 - val_sparse_categorical_accuracy: 0.4673 - val_sparse_categorical_crossentropy: 1.8919\n", "Epoch 5/10\n", - "4918/4918 [==============================] - 89s 18ms/step - loss: 1.8065 - sparse_categorical_accuracy: 0.4908 - sparse_categorical_crossentropy: 1.8320 - val_loss: 1.8294 - val_sparse_categorical_accuracy: 0.4797 - val_sparse_categorical_crossentropy: 1.8571\n", + 
"4918/4918 [==============================] - 91s 18ms/step - loss: 1.8097 - sparse_categorical_accuracy: 0.4893 - sparse_categorical_crossentropy: 1.8360 - val_loss: 1.8237 - val_sparse_categorical_accuracy: 0.4772 - val_sparse_categorical_crossentropy: 1.8517\n", "Epoch 6/10\n", - "4918/4918 [==============================] - 89s 18ms/step - loss: 1.7430 - sparse_categorical_accuracy: 0.5035 - sparse_categorical_crossentropy: 1.7666 - val_loss: 1.7843 - val_sparse_categorical_accuracy: 0.4924 - val_sparse_categorical_crossentropy: 1.8119\n", + "4918/4918 [==============================] - 89s 18ms/step - loss: 1.7476 - sparse_categorical_accuracy: 0.5030 - sparse_categorical_crossentropy: 1.7720 - val_loss: 1.8024 - val_sparse_categorical_accuracy: 0.4844 - val_sparse_categorical_crossentropy: 1.8317\n", "Epoch 7/10\n", - "4918/4918 [==============================] - 89s 18ms/step - loss: 1.6875 - sparse_categorical_accuracy: 0.5165 - sparse_categorical_crossentropy: 1.7104 - val_loss: 1.7536 - val_sparse_categorical_accuracy: 0.4978 - val_sparse_categorical_crossentropy: 1.7819\n", + "4918/4918 [==============================] - 91s 19ms/step - loss: 1.6904 - sparse_categorical_accuracy: 0.5155 - sparse_categorical_crossentropy: 1.7141 - val_loss: 1.7533 - val_sparse_categorical_accuracy: 0.4953 - val_sparse_categorical_crossentropy: 1.7822\n", "Epoch 8/10\n", - "4918/4918 [==============================] - 91s 18ms/step - loss: 1.6454 - sparse_categorical_accuracy: 0.5260 - sparse_categorical_crossentropy: 1.6680 - val_loss: 1.7462 - val_sparse_categorical_accuracy: 0.4992 - val_sparse_categorical_crossentropy: 1.7761\n", + "4918/4918 [==============================] - 91s 18ms/step - loss: 1.6432 - sparse_categorical_accuracy: 0.5268 - sparse_categorical_crossentropy: 1.6663 - val_loss: 1.7408 - val_sparse_categorical_accuracy: 0.4973 - val_sparse_categorical_crossentropy: 1.7709\n", "Epoch 9/10\n", - "4918/4918 [==============================] - 90s 18ms/step 
- loss: 1.6120 - sparse_categorical_accuracy: 0.5341 - sparse_categorical_crossentropy: 1.6340 - val_loss: 1.7419 - val_sparse_categorical_accuracy: 0.4998 - val_sparse_categorical_crossentropy: 1.7724\n", + "4918/4918 [==============================] - 89s 18ms/step - loss: 1.6137 - sparse_categorical_accuracy: 0.5331 - sparse_categorical_crossentropy: 1.6364 - val_loss: 1.7367 - val_sparse_categorical_accuracy: 0.4992 - val_sparse_categorical_crossentropy: 1.7675\n", "Epoch 10/10\n", - "4918/4918 [==============================] - 89s 18ms/step - loss: 1.5973 - sparse_categorical_accuracy: 0.5373 - sparse_categorical_crossentropy: 1.6189 - val_loss: 1.7394 - val_sparse_categorical_accuracy: 0.5018 - val_sparse_categorical_crossentropy: 1.7699\n" + "4918/4918 [==============================] - 91s 19ms/step - loss: 1.6014 - sparse_categorical_accuracy: 0.5355 - sparse_categorical_crossentropy: 1.6238 - val_loss: 1.7309 - val_sparse_categorical_accuracy: 0.4997 - val_sparse_categorical_crossentropy: 1.7614\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "WARNING:absl:Found untraced functions such as _update_step_xla, node_set_update_layer_call_fn, node_set_update_layer_call_and_return_conditional_losses, node_set_update_1_layer_call_fn, node_set_update_1_layer_call_and_return_conditional_losses while saving (showing 5 of 143). 
These functions will not be directly callable after loading.\n", - "/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.GraphTensorSpec; loading this StructuredValue will require that this type be imported and registered.\n", - " warnings.warn(\"Encoding a StructuredValue with type %s; loading this \"\n", - "/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.ContextSpec.v2; loading this StructuredValue will require that this type be imported and registered.\n", - " warnings.warn(\"Encoding a StructuredValue with type %s; loading this \"\n", - "/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.NodeSetSpec; loading this StructuredValue will require that this type be imported and registered.\n", - " warnings.warn(\"Encoding a StructuredValue with type %s; loading this \"\n", - "/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.EdgeSetSpec; loading this StructuredValue will require that this type be imported and registered.\n", - " warnings.warn(\"Encoding a StructuredValue with type %s; loading this \"\n", - "/usr/local/lib/python3.10/dist-packages/tensorflow/python/saved_model/nested_structure_coder.py:497: UserWarning: Encoding a StructuredValue with type tensorflow_gnn.AdjacencySpec; loading this StructuredValue will require that this type be imported and registered.\n", - " warnings.warn(\"Encoding a StructuredValue with type %s; loading this \"\n" + "WARNING:absl:Found untraced functions such as _update_step_xla while saving (showing 1 of 1). 
These functions will not be directly callable after loading.\n" ] }, { "data": { "text/plain": [ - "RunResult(preprocess_model=\u003ckeras.engine.functional.Functional object at 0x7df54cbbe6b0\u003e, base_model=\u003ckeras.engine.sequential.Sequential object at 0x7df54bebe920\u003e, trained_model=\u003ckeras.engine.functional.Functional object at 0x7df54bbaff10\u003e)" + "RunResult(preprocess_model=\u003ckeras.engine.functional.Functional object at 0x7b8b4c62beb0\u003e, base_model=\u003ckeras.engine.sequential.Sequential object at 0x7b8b500d9240\u003e, trained_model=\u003ckeras.engine.functional.Functional object at 0x7b8b4fc4d840\u003e)" ] }, - "execution_count": 15, + "execution_count": 17, "metadata": {}, "output_type": "execute_result" } ], "source": [ "runner.run(\n", + " gtspec=example_input_graph_spec,\n", " train_ds_provider=train_ds_provider,\n", " train_padding=train_padding,\n", + " valid_ds_provider=valid_ds_provider, # \u003c\u003c\u003c Remove if not training for real.\n", + " valid_padding=valid_padding, # \u003c\u003c\u003c Remove if not training for real.\n", + " global_batch_size=global_batch_size,\n", + " epochs=epochs,\n", + " feature_processors=feature_processors,\n", " model_fn=model_fn,\n", + " task=task,\n", " optimizer_fn=optimizer_fn,\n", - " epochs=epochs,\n", " trainer=trainer,\n", - " task=task,\n", - " gtspec=example_input_graph_spec,\n", - " global_batch_size=global_batch_size,\n", " model_exporters=[model_exporter],\n", - " feature_processors=feature_processors,\n", - " valid_ds_provider=valid_ds_provider, # \u003c\u003c\u003c Remove if not training for real.\n", - " valid_padding=valid_padding)" + ")" ] }, { @@ -971,47 +1011,53 @@ "base_uri": "https://localhost:8080/" }, "id": "ki33s9EpsQnF", - "outputId": "89ac34e7-7129-4bf8-f968-dd0f975f882a" + "outputId": "2f45a98b-70d0-4ebc-d29f-741f30ae80f4" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The predicted class for input 0 is 289 with predicted 
probability 0.3882\n", - "The predicted class for input 1 is 281 with predicted probability 0.4415\n", - "The predicted class for input 2 is 189 with predicted probability 0.3256\n", - "The predicted class for input 3 is 158 with predicted probability 0.7522\n", - "The predicted class for input 4 is 82 with predicted probability 0.2598\n", - "The predicted class for input 5 is 247 with predicted probability 0.8446\n", - "The predicted class for input 6 is 209 with predicted probability 0.4843\n", - "The predicted class for input 7 is 247 with predicted probability 0.672\n", - "The predicted class for input 8 is 192 with predicted probability 0.5376\n", - "The predicted class for input 9 is 311 with predicted probability 0.8969\n" + "The predicted class for input 0 is 9 with predicted probability 0.346\n", + "The predicted class for input 1 is 281 with predicted probability 0.5623\n", + "The predicted class for input 2 is 189 with predicted probability 0.2929\n", + "The predicted class for input 3 is 158 with predicted probability 0.9645\n", + "The predicted class for input 4 is 200 with predicted probability 0.1749\n", + "The predicted class for input 5 is 247 with predicted probability 0.9088\n", + "The predicted class for input 6 is 209 with predicted probability 0.5486\n", + "The predicted class for input 7 is 189 with predicted probability 0.5403\n", + "The predicted class for input 8 is 192 with predicted probability 0.5332\n", + "The predicted class for input 9 is 311 with predicted probability 0.7223\n" ] } ], "source": [ + "# Load model.\n", "load_options = tf.saved_model.LoadOptions(experimental_io_device=\"/job:localhost\")\n", "saved_model = tf.saved_model.load(os.path.join(trainer.model_dir, \"export\"),\n", " options=load_options)\n", + "signature_fn = saved_model.signatures[\n", + " tf.saved_model.DEFAULT_SERVING_SIGNATURE_DEF_KEY]\n", "\n", "def _clean_example_for_serving(graph_tensor):\n", + " graph_tensor = 
graph_tensor.remove_features(node_sets={\"paper\": [\"label\"]})\n", " serialized_example = tfgnn.write_example(graph_tensor)\n", " return serialized_example.SerializeToString()\n", "\n", "# Convert 10 examples to serialized string format.\n", "num_examples = 10\n", "demo_ds = valid_ds_provider.get_dataset(tf.distribute.InputContext())\n", - "serialized_examples = [_clean_example_for_serving(gt) for gt in itertools.islice(demo_ds, num_examples)]\n", + "serialized_examples = [_clean_example_for_serving(gt)\n", + " for gt in itertools.islice(demo_ds, num_examples)]\n", "\n", "# Inference on 10 examples\n", "ds = tf.data.Dataset.from_tensor_slices(serialized_examples)\n", - "kwargs = {\"examples\": next(iter(ds.batch(10)))}\n", - "output = saved_model.signatures[\"serving_default\"](**kwargs)\n", + "# The name \"examples\" for serialized tf.Example protos is defined by the runner.\n", + "input_dict = {\"examples\": next(iter(ds.batch(10)))}\n", "\n", - "# Outputs are in the form of logits\n", - "logits = next(iter(output.values()))\n", + "# Outputs are in the form of logits.\n", + "output_dict = signature_fn(**input_dict)\n", + "logits = output_dict[\"paper_venue_logits\"] # As configured above.\n", "probabilities = tf.math.softmax(logits).numpy()\n", "classes = probabilities.argmax(axis=1)\n", "\n",