forked from NVIDIA/NVFlare
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add notebook for gnn examples (NVIDIA#2289)
Co-authored-by: Chester Chen <[email protected]>
- Loading branch information
1 parent
e4f8d41
commit 9a7942e
Showing
5 changed files
with
343 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,336 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "cada310b-e776-4b9a-aabe-f111c31efcc2", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"source": [ | ||
"# Federated GNN on Graph Dataset using Inductive Learning" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "0653cbf2-92f2-4a22-8317-69cfb0266e92", | ||
"metadata": {}, | ||
"source": [ | ||
"## Introduction to GNN, Tasks, and federated GNN via Inductive Learning\n", | ||
"### GNN\n", | ||
"This example shows how to train a classification model using Graph Neural Network (GNN). GNNs show a promising future in research and industry, with potential applications in various domains, including social networks, e-commerce, recommendation systems, and more.\n", | ||
"GNNs excel in learning, modeling, and leveraging complex relationships within graph-structured data. They combine local and global information, incorporate structural knowledge, adapt to diverse tasks, handle heterogeneous data, support transfer learning, scale for large graphs, offer interpretable insights, and achieve impressive performance. \n", | ||
"\n", | ||
"### Tasks\n", | ||
"In this example, we provide two tasks:\n", | ||
"1. **Protein Classification**:\n", | ||
"The aim is to classify protein roles based on their cellular functions from gene ontology. The dataset we are using is PPI\n", | ||
"([protein-protein interaction](http://snap.stanford.edu/graphsage/#code)) graphs, where each graph represents a specific human tissue. Protein-protein interaction (PPI) dataset is commonly used in graph-based machine-learning tasks, especially in the field of bioinformatics. This dataset represents interactions between proteins as graphs, where nodes represent proteins and edges represent interactions between them.\n", | ||
"2. **Financial Transaction Classification**:\n", | ||
"The aim is to classify whether a given transaction is licit or illicit. For this financial application, we use the [Elliptic++](https://github.com/git-disl/EllipticPlusPlus) dataset. It consists of 203k Bitcoin transactions and 822k wallet addresses to enable both the detection of fraudulent transactions and the detection of illicit addresses (actors) in the Bitcoin network by leveraging graph data. For more details, please refer to this [paper](https://arxiv.org/pdf/2306.06108.pdf).\n", | ||
"\n", | ||
"\n", | ||
"### Federated GNN via Inductive Learning\n", | ||
"Both tasks are for node classification. We used the inductive representation learning method [GraphSAGE](https://arxiv.org/pdf/1706.02216.pdf) based on [Pytorch Geometric](https://github.com/pyg-team/pytorch_geometric)'s examples. \n", | ||
"[Pytorch Geometric](https://pytorch-geometric.readthedocs.io/en/latest/) is a library built upon PyTorch to easily write and train Graph Neural Networks (GNNs) for a wide range of applications related to structured data.\n", | ||
"\n", | ||
"For protein classification task, we used it in an unsupervised manner, following [PyG's unsupervised PPI example](https://github.com/pyg-team/pytorch_geometric/blob/master/examples/graph_sage_unsup_ppi.py).\n", | ||
"For financial transaction classification task, we used it in a supervised manner, directly using the node labels with supervised classification loss.\n", | ||
"\n", | ||
"Since the inductive learning mode is being used, the locally learnt model (a representation encoding / classification network) is irrelevant to the candidate graph, we are able to use the basic [FedAvg](https://arxiv.org/abs/1602.05629) as the federated learning algorithm. The workflow is Scatter and Gather (SAG).\n", | ||
"\n", | ||
"\n", | ||
"Below we listed steps to run this example." | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "a5a0292c-78b6-4bde-96d6-699dae996173", | ||
"metadata": {}, | ||
"source": [ | ||
"## 1. Setup NVFLARE\n", | ||
"\n", | ||
"Follow the [Getting_Started](https://nvflare.readthedocs.io/en/main/getting_started.html) to setup virtual environment and install NVFLARE\n", | ||
"\n", | ||
"We also provide a [Notebook](../../nvflare_setup.ipynb) for this setup process. \n", | ||
"\n", | ||
"Assume you have already setup the venv, lets first install required packages." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f4130b15-09e6-456f-a3c7-87c8ee9e07f0", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install -r requirements.txt" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1d872d8a-9e44-49dd-94b1-7862b3815ffe", | ||
"metadata": {}, | ||
"source": [ | ||
"To support functions of PyTorch Geometric necessary for this example, we need extra dependencies. Please refer to [installation guide](https://pytorch-geometric.readthedocs.io/en/latest/install/installation.html) and install accordingly:" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f906a1c9-dce0-476c-be65-79ebd8ad5da9", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%pip install pyg_lib torch_scatter torch_sparse torch_cluster torch_spline_conv -f https://data.pyg.org/whl/torch-2.1.0+cpu.html" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "e93b1bf2-6157-4ab6-9766-819450304038", | ||
"metadata": {}, | ||
"source": [ | ||
"## 2. Data preparation \n", | ||
"This example uses two datasets: \n", | ||
"- For Protein Classification, the PPI dataset is available from torch_geometric's dataset API. \n", | ||
"- For Financial Transaction Classification, we first download the [Elliptic++](https://github.com/git-disl/EllipticPlusPlus) dataset to `/tmp/nvflare/datasets/elliptic_pp` folder. In this example, we will use the following three files:\n", | ||
" - `txs_classes.csv`: transaction id and its class (licit or illicit)\n", | ||
" - `txs_edgelist.csv`: connections for transaction ids \n", | ||
" - `txs_features.csv`: transaction id and its features" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "af257e69-2bb7-49b6-ac6c-f007b0e6618e", | ||
"metadata": {}, | ||
"source": [ | ||
"## 3. Local Experiments\n", | ||
"For comparison with federated learning results, we first perform local experiments on each client's data and the whole dataset. Here we simulate 2 clients with uniform data split (client_id = 0 means the whole dataset). The 6 experiments will take a while to finish. The default epoch number is set to 70. " | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "fdb7290a-48ff-4e80-be58-5e6b0e0f9379", | ||
"metadata": { | ||
"scrolled": true, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"! python3 code/graphsage_protein_local.py --client_id 0\n", | ||
"! python3 code/graphsage_protein_local.py --client_id 1\n", | ||
"! python3 code/graphsage_protein_local.py --client_id 2 " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "9a2d55cf-4f7a-4030-8cba-b1619fdf1614", | ||
"metadata": {}, | ||
"source": [ | ||
"And for finance experiment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "f4cf2c09-1f78-4d28-9b86-af9f9cf86479", | ||
"metadata": { | ||
"scrolled": true, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"! python3 code/graphsage_finance_local.py --client_id 0\n", | ||
"! python3 code/graphsage_finance_local.py --client_id 1\n", | ||
"! python3 code/graphsage_finance_local.py --client_id 2 " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d178c6dc-c180-4ca6-8dea-3b0fe147665b", | ||
"metadata": {}, | ||
"source": [ | ||
"## 4. Prepare NVFlare job based on GNN template\n", | ||
"We are using NVFlare's FL simulator to run the FL experiments. First, we create jobs using GNN template. We reuse the job templates from [sag_gnn](../../../job_templates/sag_gnn), let's set the job template path with the following command." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "fd885e6b-ae4d-40aa-b89d-fe34217ad3da", | ||
"metadata": { | ||
"scrolled": true, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"! nvflare config -jt ../../../job_templates/" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "f608a992-5096-4452-8775-b89987970a75", | ||
"metadata": {}, | ||
"source": [ | ||
"Then we can check the available templates with the following command." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "1f8041c7-7fae-4c8a-8e07-1c6a6d59e541", | ||
"metadata": { | ||
"scrolled": true, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"! nvflare job list_templates" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "5bad4f55-d582-4f37-a523-927dc015e564", | ||
"metadata": {}, | ||
"source": [ | ||
"We shall see `sag_gnn` from the above command. We then create jobs using this template, and set local epochs to 10 with 7 rounds of FL to match local experiments' 70 epoch default training." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "7971a800-70fc-4213-96ed-c157801b5a11", | ||
"metadata": { | ||
"scrolled": true, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"! nvflare job create -force -j \"/tmp/nvflare/jobs/gnn_protein\" -w \"sag_gnn\" -sd \"code\" \\\n", | ||
" -f app_1/config_fed_client.conf app_script=\"graphsage_protein_fl.py\" app_config=\"--client_id 1 --epochs 10\" \\\n", | ||
" -f app_2/config_fed_client.conf app_script=\"graphsage_protein_fl.py\" app_config=\"--client_id 2 --epochs 10\" \\\n", | ||
" -f app_server/config_fed_server.conf num_rounds=7 key_metric=\"validation_f1\" model_class_path=\"torch_geometric.nn.GraphSAGE\" components[0].args.model.args.in_channels=50 components[0].args.model.args.hidden_channels=64 components[0].args.model.args.num_layers=2 components[0].args.model.args.out_channels=64 " | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "675bff95-dcfa-4a47-9a05-460da16760ef", | ||
"metadata": {}, | ||
"source": [ | ||
"And for finance experiment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a6d0b643-31f0-4d52-ae3c-1fafcd404072", | ||
"metadata": { | ||
"scrolled": true, | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"! nvflare job create -force -j \"/tmp/nvflare/jobs/gnn_finance\" -w \"sag_gnn\" -sd \"code\" \\\n", | ||
" -f app_1/config_fed_client.conf app_script=\"graphsage_finance_fl.py\" app_config=\"--client_id 1 --epochs 10\" \\\n", | ||
" -f app_2/config_fed_client.conf app_script=\"graphsage_finance_fl.py\" app_config=\"--client_id 2 --epochs 10\" \\\n", | ||
" -f app_server/config_fed_server.conf num_rounds=7 key_metric=\"validation_auc\" model_class_path=\"pyg_sage.SAGE\" components[0].args.model.args.in_channels=165 components[0].args.model.args.hidden_channels=256 components[0].args.model.args.num_layers=3 components[0].args.model.args.num_classes=2 \n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "bd0713e2-e393-41c0-9da0-392535cf8a54", | ||
"metadata": {}, | ||
"source": [ | ||
"## 5. Run simulated kmeans experiment\n", | ||
"Now that we have the jobs ready, we run the experiment using Simulator." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "9bb6cab4-9c24-400a-bc3c-f1e4a6d5a346", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"! nvflare simulator -w /tmp/nvflare/gnn/protein_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_protein" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "98c64648-1d09-42da-bd48-9a6ac48587af", | ||
"metadata": {}, | ||
"source": [ | ||
"And for finance experiment" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a9f256a0-ae99-4a7e-8bc2-e7fc8de2e6f6", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"! nvflare simulator -w /tmp/nvflare/gnn/finance_fl_workspace -n 2 -t 2 /tmp/nvflare/jobs/gnn_finance" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "913e9ee2-e993-442d-a525-d2baf92af539", | ||
"metadata": {}, | ||
"source": [ | ||
"## 6. Result visualization\n", | ||
"Results from both local and federated experiments can be visualized in tensorboard." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "a6814434-4e6d-4460-b480-709cb3e77cc8", | ||
"metadata": { | ||
"tags": [] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"%load_ext tensorboard\n", | ||
"%tensorboard --logdir /tmp/nvflare/gnn" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"id": "0f6ae6cb-12df-4279-b6af-9c4d356e727e", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "nvflare_example", | ||
"language": "python", | ||
"name": "nvflare_example" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.8.16" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,4 +1,3 @@ | ||
nvflare | ||
torch | ||
torch_geometric | ||
tensorboard | ||
|