diff --git a/dnn_predict_accuracy/AUTHORS b/dnn_predict_accuracy/AUTHORS
new file mode 100644
index 000000000000..fa0a16d1a5b2
--- /dev/null
+++ b/dnn_predict_accuracy/AUTHORS
@@ -0,0 +1,12 @@
+# This is the list of dnn-predict-accuracy authors for copyright purposes.
+#
+# This does not necessarily list everyone who has contributed code, since in
+# some cases, their employer may be the copyright holder. To see the full list
+# of contributors, see the revision history in source control.
+Google LLC
+Daniel Keysers
+Ilya Tolstikhin
+Olivier Bousquet
+Sylvain Gelly
+Thomas Unterthiner
+
diff --git a/dnn_predict_accuracy/CONTRIBUTING.md b/dnn_predict_accuracy/CONTRIBUTING.md
new file mode 100644
index 000000000000..654a071648d6
--- /dev/null
+++ b/dnn_predict_accuracy/CONTRIBUTING.md
@@ -0,0 +1,28 @@
+# How to Contribute
+
+We'd love to accept your patches and contributions to this project. There are
+just a few small guidelines you need to follow.
+
+## Contributor License Agreement
+
+Contributions to this project must be accompanied by a Contributor License
+Agreement. You (or your employer) retain the copyright to your contribution;
+this simply gives us permission to use and redistribute your contributions as
+part of the project. Head over to <https://cla.developers.google.com/> to see
+your current agreements on file or to sign a new one.
+
+You generally only need to submit a CLA once, so if you've already submitted one
+(even if it was for a different project), you probably don't need to do it
+again.
+
+## Code reviews
+
+All submissions, including submissions by project members, require review. We
+use GitHub pull requests for this purpose. Consult
+[GitHub Help](https://help.github.com/articles/about-pull-requests/) for more
+information on using pull requests.
+
+## Community Guidelines
+
+This project follows [Google's Open Source Community
+Guidelines](https://opensource.google/conduct/).
diff --git a/dnn_predict_accuracy/README.md b/dnn_predict_accuracy/README.md
new file mode 100644
index 000000000000..12b63a2e70b3
--- /dev/null
+++ b/dnn_predict_accuracy/README.md
@@ -0,0 +1,16 @@
+# dnn-predict-accuracy
+
+This is the source code accompanying the paper "Predicting Neural Network
+Accuracy from Weights". Citation/Author names will be added after the anonymous
+review.
+
+## Data
+
+The data accompanying this paper will be released in the near future. Stay
+tuned! Meanwhile, the code here should be sufficient to re-generate the data or
+see how the models were trained.
+
+## License
+
+This repository is licensed under the Apache License, Version 2.0. See LICENSE
+for details.
diff --git a/dnn_predict_accuracy/colab/dnn_predict_accuracy.ipynb b/dnn_predict_accuracy/colab/dnn_predict_accuracy.ipynb new file mode 100644 index 000000000000..dc5d4eb53cfd --- /dev/null +++ b/dnn_predict_accuracy/colab/dnn_predict_accuracy.ipynb @@ -0,0 +1,926 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VwlxkU9cEGPR" + }, + "source": [ + "Copyright 2020 The dnn-predict-accuracy Authors.\n", + "\n", + "Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "you may not use this file except in compliance with the License.\n", + "You may obtain a copy of the License at\n", + "\n", + " https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "Unless required by applicable law or agreed to in writing, software\n", + "distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "See the License for the specific language governing permissions and\n", + "limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "x08yky7rytbD" + }, + "source": [ + "# README\n", + "\n", + "This notebook contains code for training predictors of DNN accuracy. \n", + "\n", + "Contents:\n", + "\n", + "(1) Loading the Small CNN Zoo dataset\n", + "\n", + "(2) Figure 2 of the paper\n", + "\n", + "(3) Examples of training Logit-Linear / GBM / DNN predictors\n", + "\n", + "(4) Transfer of predictors across CNN collections\n", + "\n", + "(5) Various visualizations of CNN collections\n", + "\n", + "Code dependencies:\n", + "Light-GBM package\n" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "aj14NoLVykBz" + }, + "outputs": [], + "source": [ + "from __future__ import division\n", + "\n", + "import time\n", + "import os\n", + "import json\n", + "import sys\n", + "import numpy as np\n", + "from matplotlib import pyplot as plt\n", + "from matplotlib import colors\n", + "import pandas as pd\n", + "import seaborn as sns\n", + "from scipy import stats\n", + "from tensorflow import keras\n", + "from tensorflow.io import gfile\n", + "import lightgbm as lgb\n", + "\n", + "DATAFRAME_CONFIG_COLS = [\n", + " 'config.w_init',\n", + " 'config.activation',\n", + " 'config.learning_rate',\n", + " 'config.init_std',\n", + " 'config.l2reg',\n", + " 'config.train_fraction',\n", + " 'config.dropout']\n", + "CATEGORICAL_CONFIG_PARAMS = ['config.w_init', 'config.activation']\n", + "CATEGORICAL_CONFIG_PARAMS_PREFIX = ['winit', 'act']\n", + "DATAFRAME_METRIC_COLS = [\n", + " 'test_accuracy',\n", + " 'test_loss',\n", + " 'train_accuracy',\n", + " 'train_loss']\n", + "TRAIN_SIZE = 15000\n", + "\n", + "# TODO: modify the following lines\n", + "CONFIGS_PATH_BASE = 'path_to_the_file_with_best_configs'\n", + "MNIST_OUTDIR = \"path_to_files_with_mnist_collection\"\n", + "FMNIST_OUTDIR = 'path_to_files_with_fmnist_collection'\n", + "CIFAR_OUTDIR = 'path_to_files_with_cifar10gs_collection'\n", + "SVHN_OUTDIR = 'path_to_files_with_svhngs_collection'\n", + "\n", + "def filter_checkpoints(weights, dataframe,\n", + " target='test_accuracy',\n", + " stage='final', binarize=True):\n", + " \"\"\"Take one checkpoint per run and do some pre-processing.\n", + "\n", + " Args:\n", + " weights: numpy array of shape (num_runs, num_weights)\n", + " dataframe: pandas DataFrame which has num_runs rows. 
First 4 columns should\n", + " contain test_accuracy, test_loss, train_accuracy, train_loss respectively.\n", + " target: string, what to use as an output\n", + " stage: flag defining which checkpoint out of potentially many we will take\n", + " for the run.\n", + " binarize: Do we want to binarize the categorical hyperparams?\n", + "\n", + " Returns:\n", + " tuple (weights_new, metrics, hyperparams, ckpts), where\n", + " weights_new is a numpy array of shape (num_remaining_ckpts, num_weights),\n", + " metrics is a numpy array of shape (num_remaining_ckpts, num_metrics) with\n", + " num_metric being the length of DATAFRAME_METRIC_COLS,\n", + " hyperparams is a pandas DataFrame of num_remaining_ckpts rows and columns\n", + " listed in DATAFRAME_CONFIG_COLS.\n", + " ckpts is an instance of pandas Index, keeping filenames of the checkpoints\n", + " All the num_remaining_ckpts rows correspond to one checkpoint out of each\n", + " run we had.\n", + " \"\"\"\n", + "\n", + " assert target in DATAFRAME_METRIC_COLS, 'unknown target'\n", + " ids_to_take = []\n", + " # Keep in mind that the rows of the DataFrame were sorted according to ckpt\n", + " # Fetch the unit id corresponding to the ckpt of the first row\n", + " current_uid = dataframe.axes[0][0].split('/')[-2] # get the unit id\n", + " steps = []\n", + " for i in range(len(dataframe.axes[0])):\n", + " # Fetch the new unit id\n", + " ckpt = dataframe.axes[0][i]\n", + " parts = ckpt.split('/')\n", + " if parts[-2] == current_uid:\n", + " steps.append(int(parts[-1].split('-')[-1]))\n", + " else:\n", + " # We need to process the previous unit\n", + " # and choose which ckpt to take\n", + " steps_sort = sorted(steps)\n", + " target_step = -1\n", + " if stage == 'final':\n", + " target_step = steps_sort[-1]\n", + " elif stage == 'early':\n", + " target_step = steps_sort[0]\n", + " else: # middle\n", + " target_step = steps_sort[int(len(steps) / 2)]\n", + " offset = [j for (j, el) in enumerate(steps) if el == target_step][0]\n", + " # Take the DataFrame row with the corresponding row id\n", + " ids_to_take.append(i - len(steps) + offset)\n", + " current_uid = parts[-2]\n", + " steps = [int(parts[-1].split('-')[-1])]\n", + "\n", + " # Fetch the hyperparameters of the corresponding checkpoints\n", + " hyperparams = dataframe[DATAFRAME_CONFIG_COLS]\n", + " hyperparams = hyperparams.iloc[ids_to_take]\n", + " if binarize:\n", + " # Binarize categorical features\n", + " hyperparams = pd.get_dummies(\n", + " hyperparams,\n", + " columns=CATEGORICAL_CONFIG_PARAMS,\n", + " prefix=CATEGORICAL_CONFIG_PARAMS_PREFIX)\n", + " else:\n", + " # Make the categorical features have pandas type \"category\"\n", + " # Then LGBM can use those as categorical\n", + " hyperparams.is_copy = False\n", + " for col in CATEGORICAL_CONFIG_PARAMS:\n", + " hyperparams[col] = hyperparams[col].astype('category')\n", + "\n", + " # Fetch the file paths of the corresponding checkpoints\n", + " ckpts = dataframe.axes[0][ids_to_take]\n", + "\n", + " return (weights[ids_to_take, :],\n", + " dataframe[DATAFRAME_METRIC_COLS].values[ids_to_take, :].astype(\n", + " np.float32),\n", + " hyperparams,\n", + " ckpts)\n", + "\n", + "def build_fcn(n_layers, n_hidden, n_outputs, dropout_rate, activation,\n", + " w_regularizer, w_init, b_init, last_activation='softmax'):\n", + " \"\"\"Fully connected deep neural network.\"\"\"\n", + " model = keras.Sequential()\n", + " model.add(keras.layers.Flatten())\n", + " for _ in range(n_layers):\n", + " model.add(\n", + " keras.layers.Dense(\n", + " 
n_hidden,\n", + " activation=activation,\n", + " kernel_regularizer=w_regularizer,\n", + " kernel_initializer=w_init,\n", + " bias_initializer=b_init))\n", + " if dropout_rate \u003e 0.0:\n", + " model.add(keras.layers.Dropout(dropout_rate))\n", + " if n_layers \u003e 0:\n", + " model.add(keras.layers.Dense(n_outputs, activation=last_activation))\n", + " else:\n", + " model.add(keras.layers.Dense(\n", + " n_outputs,\n", + " activation='sigmoid',\n", + " kernel_regularizer=w_regularizer,\n", + " kernel_initializer=w_init,\n", + " bias_initializer=b_init))\n", + " return model\n", + "\n", + "def extract_summary_features(w, qts=(0, 25, 50, 75, 100)):\n", + " \"\"\"Extract various statistics from the flat vector w.\"\"\"\n", + " features = np.percentile(w, qts)\n", + " features = np.append(features, [np.std(w), np.mean(w)])\n", + " return features\n", + "\n", + "\n", + "def extract_per_layer_features(w, qts=None, layers=(0, 1, 2, 3)):\n", + " \"\"\"Extract per-layer statistics from the weight vector and concatenate.\"\"\"\n", + " # Indices of the location of biases/kernels in the flattened vector\n", + " all_boundaries = {\n", + " 0: [(0, 16), (16, 160)], \n", + " 1: [(160, 176), (176, 2480)], \n", + " 2: [(2480, 2496), (2496, 4800)], \n", + " 3: [(4800, 4810), (4810, 4970)]}\n", + " boundaries = []\n", + " for layer in layers:\n", + " boundaries += all_boundaries[layer]\n", + " \n", + " if not qts:\n", + " features = [extract_summary_features(w[a:b]) for (a, b) in boundaries]\n", + " else:\n", + " features = [extract_summary_features(w[a:b], qts) for (a, b) in boundaries]\n", + " all_features = np.concatenate(features)\n", + " return all_features\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "SBM6xNSjz8Bs" + }, + "source": [ + "# 1. Loading the Small CNN Zoo dataset\n", + "\n", + "The following code loads the dataset (trained weights from *.npy files and all the relevant metrics, including accuracy, from *.csv files). " + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "mp-5POSc0ap2" + }, + "outputs": [], + "source": [ + "all_dirs = [MNIST_OUTDIR, FMNIST_OUTDIR, CIFAR_OUTDIR, SVHN_OUTDIR]\n", + "weights = {'mnist': None,\n", + " 'fashion_mnist': None,\n", + " 'cifar10': None,\n", + " 'svhn_cropped': None}\n", + "metrics = {'mnist': None,\n", + " 'fashion_mnist': None,\n", + " 'cifar10': None,\n", + " 'svhn_cropped': None}\n", + "for (dirname, dataname) in zip(\n", + " all_dirs, ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']):\n", + " print('Loading %s' % dataname)\n", + " with gfile.GFile(os.path.join(dirname, \"all_weights.npy\"), \"rb\") as f:\n", + " # Weights of the trained models\n", + " weights[dataname] = np.load(f)\n", + " with gfile.GFile(os.path.join(dirname, \"all_metrics.csv\")) as f:\n", + " # pandas DataFrame with metrics\n", + " metrics[dataname] = pd.read_csv(f, index_col=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "FNqWMZcx1y5m" + }, + "source": [ + "Next it filters the dataset by keeping only checkpoints corresponding to 18 epochs and discarding runs that resulted in numerical instabilities. Finally, it performs the train / test splits." 
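For orientation before the actual filtering cell, here is a minimal sketch (not part of the original notebook) of what a single `filter_checkpoints` call returns for one collection; the 4,970 in the shape comment is the number of trainable parameters of the Small CNN, matching the offsets hard-coded in `extract_per_layer_features` above.

# Hypothetical illustration of the filtering step for one collection (MNIST).
weights_flt, metrics_flt, configs_flt, ckpts = filter_checkpoints(
    weights['mnist'], metrics['mnist'],
    target='test_accuracy', stage='final', binarize=True)
print(weights_flt.shape)   # (num_runs, 4970): one flattened Small CNN per run
print(metrics_flt.shape)   # (num_runs, 4): test/train accuracy and loss
print(configs_flt.shape)   # (num_runs, n_hyperparameter_columns), one-hot categoricals
print(ckpts[:1])           # checkpoint path of the first kept run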
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "1qL5_-FZ11gm" + }, + "outputs": [], + "source": [ + "weights_train = {}\n", + "weights_test = {}\n", + "configs_train = {}\n", + "configs_test = {}\n", + "outputs_train = {}\n", + "outputs_test = {}\n", + "\n", + "for dataset in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:\n", + " # Take one checkpoint per each run\n", + " # If using GBM as predictor, set binarize=False\n", + " weights_flt, metrics_flt, configs_flt, ckpts = filter_checkpoints(\n", + " weights[dataset], metrics[dataset], binarize=True)\n", + "\n", + " # Filter out DNNs with NaNs and Inf in the weights\n", + " idx_valid = (np.isfinite(weights_flt).mean(1) == 1.0)\n", + " inputs = np.asarray(weights_flt[idx_valid], dtype=np.float32)\n", + " outputs = np.asarray(metrics_flt[idx_valid], dtype=np.float32)\n", + " configs = configs_flt.iloc[idx_valid]\n", + " ckpts = ckpts[idx_valid]\n", + "\n", + " # Shuffle and split the data\n", + " random_idx = list(range(inputs.shape[0]))\n", + " np.random.shuffle(random_idx)\n", + " weights_train[dataset], weights_test[dataset] = (\n", + " inputs[random_idx[:TRAIN_SIZE]], inputs[random_idx[TRAIN_SIZE:]])\n", + " outputs_train[dataset], outputs_test[dataset] = (\n", + " 1. * outputs[random_idx[:TRAIN_SIZE]],\n", + " 1. * outputs[random_idx[TRAIN_SIZE:]])\n", + " configs_train[dataset], configs_test[dataset] = (\n", + " configs.iloc[random_idx[:TRAIN_SIZE]], \n", + " configs.iloc[random_idx[TRAIN_SIZE:]])" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "K7cpDNyB2tCc" + }, + "source": [ + "# 2. Figure 2 of the paper\n", + "\n", + "Next we plot distribution of CNNs from 4 collections in Small CNN Zoo according to their train / test accuracy" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "VnToqYeT25pb" + }, + "outputs": [], + "source": [ + "plt.figure(figsize = (16, 8))\n", + "pic_id = 0\n", + "\n", + "for dataset in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:\n", + " pic_id += 1\n", + " sp = plt.subplot(2, 4, pic_id)\n", + "\n", + " outputs = outputs_train[dataset]\n", + "\n", + " if dataset == 'mnist':\n", + " plt.title('MNIST', fontsize=24)\n", + " if dataset == 'fashion_mnist':\n", + " plt.title('Fashion MNIST', fontsize=24)\n", + " if dataset == 'cifar10':\n", + " plt.title('CIFAR10-GS', fontsize=24)\n", + " if dataset == 'svhn_cropped':\n", + " plt.title('SVHN-GS', fontsize=24)\n", + "\n", + " # 1. test accuracy hist plots\n", + " sns.distplot(np.array(outputs[:, 0]), bins=15, kde=False, color='green')\n", + " plt.xlim((0.0, 1.0))\n", + " sp.axes.get_xaxis().set_ticklabels([])\n", + " sp.axes.get_yaxis().set_ticklabels([])\n", + " pic_id += 4\n", + " sp = plt.subplot(2, 4, pic_id)\n", + "\n", + " # 2. 
test / train accuracy scatter plots\n", + " NUM_POINTS = 1000\n", + " random_idx = range(len(outputs))\n", + " np.random.shuffle(random_idx)\n", + " plt.plot([0.0, 1.0], [0.0, 1.0], 'r--')\n", + " sns.scatterplot(np.array(outputs[random_idx[:NUM_POINTS], 0]), # test acc\n", + " np.array(outputs[random_idx[:NUM_POINTS], 2]), # train acc\n", + " s=30\n", + " )\n", + " if pic_id == 5:\n", + " plt.ylabel('Train accuracy', fontsize=22)\n", + " sp.axes.get_yaxis().set_ticklabels([0.0, 0.2, .4, .6, .8, 1.])\n", + " else:\n", + " sp.axes.get_yaxis().set_ticklabels([])\n", + " plt.xlim((0.0, 1.0))\n", + " plt.ylim((0.0, 1.0))\n", + " sp.axes.get_xaxis().set_ticks([0.0, 0.2, .4, .6, .8, 1.])\n", + " sp.axes.tick_params(axis='both', labelsize=18)\n", + " plt.xlabel('Test accuracy', fontsize=22)\n", + "\n", + " pic_id -= 4\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "fxtGdIK55t9B" + }, + "source": [ + "# 3. Examples of training Logit-Linear / GBM / DNN predictors\n", + "\n", + "Next we train 3 models on all 4 CNN collections with the best hyperparameter configurations we found during our studies (documented in Table 2 and Section 4 of the paper).\n", + "\n", + "First, we load the best hyperparameter configurations we found.\n", + "The file best_configs.json contains a list. \n", + "Each entry of that list corresponds to the single hyperparameter configuration. \n", + "It consists of: \n", + "\n", + " (1) name of the CNN collection (mnist/fashion mnist/cifar10/svhn) \n", + " \n", + " (2) predictor type (linear/dnn/lgbm)\n", + " \n", + " (3) type of inputs, (refer to Table 2)\n", + " \n", + " (4) value of MSE you will get training with these settings, \n", + " \n", + " (5) dictionary of \"parameter name\"-\u003e \"parameter value\" for the given type of predictor." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "B7oCx5rr6y4D" + }, + "outputs": [], + "source": [ + "with gfile.GFile(os.path.join(CONFIGS_PATH_BASE, 'best_configs.json'), 'r') as file:\n", + " best_configs = json.load(file)\n", + "\n" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "nQsP1aA5UhqT" + }, + "source": [ + "# 3.1 Training GBM predictors\n", + "\n", + "GBM code below requires the lightgbm package.\n", + "\n", + "This is an example of training GBM on CIFAR10-GS CNN collection using per-layer weights statistics as inputs." 
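The per-layer statistics are read from fixed offsets into the flattened weight vector (see `extract_per_layer_features` above): each of the eight bias/kernel groups is summarized by `extract_summary_features` into 7 numbers (5 percentiles, std, mean), giving 56 features per CNN. As a sanity check, those hard-coded offsets can be re-derived from the Small CNN layout used by `train_network.py` (three 3x3 convolutions with 16 channels on grayscale input, followed by a dense layer over 10 classes); the helper below is a hypothetical sketch, not part of the original notebook.

def small_cnn_boundaries(kernel=3, channels=16, in_channels=1, n_classes=10, n_conv=3):
  """Recompute the (bias, kernel) offsets assumed by extract_per_layer_features."""
  boundaries, offset, fan_in = {}, 0, in_channels
  for layer in range(n_conv):
    n_bias, n_kernel = channels, kernel * kernel * fan_in * channels
    boundaries[layer] = [(offset, offset + n_bias),
                         (offset + n_bias, offset + n_bias + n_kernel)]
    offset += n_bias + n_kernel
    fan_in = channels
  # Final dense layer after global average pooling: channels -> n_classes.
  boundaries[n_conv] = [(offset, offset + n_classes),
                        (offset + n_classes, offset + n_classes + channels * n_classes)]
  return boundaries

assert small_cnn_boundaries() == {
    0: [(0, 16), (16, 160)],
    1: [(160, 176), (176, 2480)],
    2: [(2480, 2496), (2496, 4800)],
    3: [(4800, 4810), (4810, 4970)]}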
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "t4KzPiTAXWuo" + }, + "outputs": [], + "source": [ + "# Take the best config we found\n", + "config = [el[-1] for el in best_configs if\n", + " el[0] == 'cifar10' and\n", + " el[1] == 'lgbm' and\n", + " el[2] == 'wstats-perlayer'][0]\n", + "\n", + "# Pre-process the weights\n", + "train_x = np.apply_along_axis(\n", + " extract_per_layer_features, 1,\n", + " weights_train['cifar10'],\n", + " qts=None,\n", + " layers=(0, 1, 2, 3))\n", + "test_x = np.apply_along_axis(\n", + " extract_per_layer_features, 1,\n", + " weights_test['cifar10'], \n", + " qts=None, \n", + " layers=(0, 1, 2, 3))\n", + "# Get the target values\n", + "train_y, test_y = outputs_train['cifar10'][:, 0], outputs_test['cifar10'][:, 0]\n", + "\n", + "# Define the GBM model\n", + "lgbm_model = lgb.LGBMRegressor(\n", + " num_leaves=config['num_leaves'],\n", + " max_depth=config['max_depth'],\n", + " learning_rate=config['learning_rate'],\n", + " max_bin=int(config['max_bin']),\n", + " min_child_weight=config['min_child_weight'],\n", + " reg_lambda=config['reg_lambda'],\n", + " reg_alpha=config['reg_alpha'],\n", + " subsample=config['subsample'],\n", + " subsample_freq=1, # it means always subsample\n", + " colsample_bytree=config['colsample_bytree'],\n", + " n_estimators=2000,\n", + " first_metric_only=True\n", + ")\n", + "\n", + "# Train the GBM model;\n", + "# Early stopping will be based on rmse of test set\n", + "eval_metric = ['rmse', 'l1']\n", + "eval_set = [(test_x, test_y)]\n", + "lgbm_model.fit(train_x, train_y, verbose=100,\n", + " early_stopping_rounds=500,\n", + " eval_metric=eval_metric,\n", + " eval_set=eval_set,\n", + " eval_names=['test'])\n", + "\n", + "# Evaluate the GBM model\n", + "assert hasattr(lgbm_model, 'best_iteration_')\n", + "# Choose the step which had the best rmse on the test set\n", + "best_iter = lgbm_model.best_iteration_ - 1\n", + "lgbm_history = lgbm_model.evals_result_\n", + "mse = lgbm_history['test']['rmse'][best_iter] ** 2.\n", + "mad = lgbm_history['test']['l1'][best_iter]\n", + "var = np.mean((test_y - np.mean(test_y)) ** 2.)\n", + "r2 = 1. - mse / var\n", + "print('Test MSE = ', mse)\n", + "print('Test MAD = ', mad)\n", + "print('Test R2 = ', r2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "1Sf5cFosZcmk" + }, + "source": [ + "# 3.2 Training DNN predictors\n", + "\n", + "This is an example of training DNN on MNIST CNN collection using all weights as inputs." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "cVsPbhQYZodD" + }, + "outputs": [], + "source": [ + "# Take the best config we found\n", + "config = [el[-1] for el in best_configs if\n", + " el[0] == 'mnist' and\n", + " el[1] == 'dnn' and\n", + " el[2] == 'weights'][0]\n", + "\n", + "train_x, test_x = weights_train['cifar10'], weights_test['cifar10']\n", + "train_y, test_y = outputs_train['cifar10'][:, 0], outputs_test['cifar10'][:, 0]\n", + "\n", + "# Get the optimizer, initializers, and regularizers\n", + "optimizer = keras.optimizers.get(config['optimizer_name'])\n", + "optimizer.learning_rate = config['learning_rate']\n", + "w_init = keras.initializers.get(config['w_init_name'])\n", + "if config['w_init_name'].lower() in ['truncatednormal', 'randomnormal']:\n", + " w_init.stddev = config['init_stddev']\n", + "b_init = keras.initializers.get('zeros')\n", + "w_reg = (keras.regularizers.l2(config['l2_penalty']) \n", + " if config['l2_penalty'] \u003e 0 else None)\n", + "\n", + "# Get the fully connected DNN architecture\n", + "dnn_model = build_fcn(int(config['n_layers']),\n", + " int(config['n_hiddens']),\n", + " 1, # number of outputs\n", + " config['dropout_rate'],\n", + " 'relu',\n", + " w_reg, w_init, b_init,\n", + " 'sigmoid') # Last activation\n", + "dnn_model.compile(\n", + " optimizer=optimizer,\n", + " loss='mean_squared_error',\n", + " metrics=['mse', 'mae'])\n", + "\n", + "# Train the model\n", + "dnn_model.fit(\n", + " train_x, train_y,\n", + " batch_size=int(config['batch_size']),\n", + " epochs=300,\n", + " validation_data=(test_x, test_y),\n", + " verbose=1,\n", + " callbacks=[keras.callbacks.EarlyStopping(\n", + " monitor='val_loss',\n", + " min_delta=0,\n", + " patience=10,\n", + " verbose=0,\n", + " mode='auto',\n", + " baseline=None,\n", + " restore_best_weights=False)]\n", + " )\n", + "\n", + "# Evaluate the model\n", + "eval_train = dnn_model.evaluate(train_x, train_y, batch_size=128, verbose=0)\n", + "eval_test = dnn_model.evaluate(test_x, test_y, batch_size=128, verbose=0)\n", + "assert dnn_model.metrics_names[1] == 'mean_squared_error'\n", + "assert dnn_model.metrics_names[2] == 'mean_absolute_error'\n", + "mse = eval_test[1]\n", + "var = np.mean((test_y - np.mean(test_y)) ** 2.)\n", + "r2 = 1. - mse / var\n", + "print('Test MSE = ', mse)\n", + "print('Test MAD = ', eval_test[2])\n", + "print('Test R2 = ', r2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "DF3N5jZ9JQMs" + }, + "source": [ + "# 3.3 Train Logit-Linear predictors\n", + "\n", + "This is an example of training Logit-Linear model on CIFAR10 CNN collection using hyperparameters as inputs." + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "_S_183RnJUZu" + }, + "outputs": [], + "source": [ + "# Take the best config we found\n", + "config = [el[-1] for el in best_configs if\n", + " el[0] == 'cifar10' and\n", + " el[1] == 'linear' and\n", + " el[2] == 'hyper'][0]\n", + "\n", + "# Turn DataFrames to numpy arrays. 
\n", + "# Since we used \"binarize=True\" when calling filter_checkpoints all the\n", + "# categorical columns were binarized.\n", + "train_x = configs_train['cifar10'].values.astype(np.float32)\n", + "test_x = configs_test['cifar10'].values.astype(np.float32)\n", + "train_y, test_y = outputs_train['cifar10'][:, 0], outputs_test['cifar10'][:, 0]\n", + "\n", + "# Get the optimizer, initializers, and regularizers\n", + "optimizer = keras.optimizers.get(config['optimizer_name'])\n", + "optimizer.learning_rate = config['learning_rate']\n", + "w_init = keras.initializers.get(config['w_init_name'])\n", + "if config['w_init_name'].lower() in ['truncatednormal', 'randomnormal']:\n", + " w_init.stddev = config['init_stddev']\n", + "b_init = keras.initializers.get('zeros')\n", + "w_reg = (keras.regularizers.l2(config['l2_penalty']) \n", + " if config['l2_penalty'] \u003e 0 else None)\n", + "\n", + "# Get the linear architecture (DNN with 0 layers)\n", + "dnn_model = build_fcn(int(config['n_layers']),\n", + " int(config['n_hiddens']),\n", + " 1, # number of outputs\n", + " None, # Dropout is not used\n", + " 'relu',\n", + " w_reg, w_init, b_init,\n", + " 'sigmoid') # Last activation\n", + "dnn_model.compile(\n", + " optimizer=optimizer,\n", + " loss='mean_squared_error',\n", + " metrics=['mse', 'mae'])\n", + "\n", + "# Train the model\n", + "dnn_model.fit(\n", + " train_x, train_y,\n", + " batch_size=int(config['batch_size']),\n", + " epochs=300,\n", + " validation_data=(test_x, test_y),\n", + " verbose=1,\n", + " callbacks=[keras.callbacks.EarlyStopping(\n", + " monitor='val_loss',\n", + " min_delta=0,\n", + " patience=10,\n", + " verbose=0,\n", + " mode='auto',\n", + " baseline=None,\n", + " restore_best_weights=False)]\n", + " )\n", + "\n", + "# Evaluate the model\n", + "eval_train = dnn_model.evaluate(train_x, train_y, batch_size=128, verbose=0)\n", + "eval_test = dnn_model.evaluate(test_x, test_y, batch_size=128, verbose=0)\n", + "assert dnn_model.metrics_names[1] == 'mean_squared_error'\n", + "assert dnn_model.metrics_names[2] == 'mean_absolute_error'\n", + "mse = eval_test[1]\n", + "var = np.mean((test_y - np.mean(test_y)) ** 2.)\n", + "r2 = 1. - mse / var\n", + "print('Test MSE = ', mse)\n", + "print('Test MAD = ', eval_test[2])\n", + "print('Test R2 = ', r2)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "i97PjpsxWQWS" + }, + "source": [ + "# 4. Figure 4: Transfer across datasets\n", + "\n", + "Train GBM predictor using statistics of all layers as inputs on all 4 CNN collections. Then evaluate them on each of the 4 CNN collections (without fine-tuning). Store all results." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "xRFiVulhWeQ9" + }, + "outputs": [], + "source": [ + "transfer_results = {}\n", + "\n", + "for dataset in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:\n", + " print('Training on %s' % dataset)\n", + " transfer_results[dataset] = {}\n", + " \n", + " train_x = weights_train[dataset]\n", + " test_x = weights_test[dataset]\n", + " train_y = outputs_train[dataset][:, 0]\n", + " test_y = outputs_test[dataset][:, 0]\n", + "\n", + " # Pre-process the weights by taking the statistics across layers\n", + " train_x = np.apply_along_axis(\n", + " extract_per_layer_features, 1, \n", + " train_x, qts=None, layers=(0, 1, 2, 3))\n", + " test_x = np.apply_along_axis(\n", + " extract_per_layer_features, 1,\n", + " test_x, qts=None, layers=(0, 1, 2, 3))\n", + "\n", + " # Take the best config we found\n", + " config = [el[-1] for el in best_configs if\n", + " el[0] == dataset and\n", + " el[1] == 'lgbm' and\n", + " el[2] == 'wstats-perlayer'][0]\n", + "\n", + " lgbm_model = lgb.LGBMRegressor(\n", + " num_leaves=config['num_leaves'],\n", + " max_depth=config['max_depth'], \n", + " learning_rate=config['learning_rate'], \n", + " max_bin=int(config['max_bin']),\n", + " min_child_weight=config['min_child_weight'],\n", + " reg_lambda=config['reg_lambda'],\n", + " reg_alpha=config['reg_alpha'],\n", + " subsample=config['subsample'],\n", + " subsample_freq=1, # Always subsample\n", + " colsample_bytree=config['colsample_bytree'],\n", + " n_estimators=4000,\n", + " first_metric_only=True,\n", + " )\n", + " \n", + " # Train the GBM model\n", + " lgbm_model.fit(\n", + " train_x,\n", + " train_y,\n", + " verbose=100,\n", + " # verbose=False,\n", + " early_stopping_rounds=500,\n", + " eval_metric=['rmse', 'l1'],\n", + " eval_set=[(test_x, test_y)],\n", + " eval_names=['test'])\n", + " \n", + " # Evaluate on all 4 CNN collections\n", + " for transfer_to in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:\n", + " print('Evaluating on %s' % transfer_to)\n", + " # Take the test split of the dataset\n", + " transfer_x = weights_test[transfer_to]\n", + " transfer_x = np.apply_along_axis(\n", + " extract_per_layer_features, 1,\n", + " transfer_x, qts=None, layers=(0, 1, 2, 3))\n", + " y_hat = lgbm_model.predict(transfer_x)\n", + " transfer_results[dataset][transfer_to] = y_hat" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "VvkJS4CKYDj_" + }, + "source": [ + "And plot everything" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "U9J8nA4BYF4P" + }, + "outputs": [], + "source": [ + "plt.figure(figsize = (15, 15))\n", + "pic_id = 0\n", + "for dataset in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:\n", + " for transfer_to in ['mnist', 'fashion_mnist', 'cifar10', 'svhn_cropped']:\n", + " pic_id += 1\n", + " sp = plt.subplot(4, 4, pic_id)\n", + " # Take true labels\n", + " y_true = outputs_test[transfer_to][:, 0]\n", + " # Take the predictions of the model\n", + " y_hat = transfer_results[dataset][transfer_to]\n", + " plt.plot([0.01, .99], [0.01, .99], 'r--', linewidth=2)\n", + " sns.scatterplot(y_true, y_hat)\n", + " # Compute the Kendall's tau coefficient\n", + " tau = stats.kendalltau(y_true, y_hat)[0]\n", + " plt.text(0.05, 0.9, r\"$\\tau=%.3f$\" % tau, fontsize=25)\n", + " plt.xlim((0.0, 1.0))\n", + " plt.ylim((0.0, 1.0))\n", + "\n", + " if pic_id % 4 != 1:\n", + " 
sp.axes.get_yaxis().set_ticklabels([])\n", + " else:\n", + " plt.ylabel('Predictions', fontsize=22)\n", + " sp.axes.tick_params(axis='both', labelsize=15)\n", + "\n", + " if pic_id \u003c 13:\n", + " sp.axes.get_xaxis().set_ticklabels([])\n", + " else:\n", + " plt.xlabel('Test accuracy', fontsize=22)\n", + " sp.axes.tick_params(axis='both', labelsize=15)\n", + "\n", + " if pic_id == 1:\n", + " plt.title('MNIST', fontsize=22)\n", + " if pic_id == 2:\n", + " plt.title('Fashion-MNIST', fontsize=22)\n", + " if pic_id == 3:\n", + " plt.title('CIFAR10-GS', fontsize=22)\n", + " if pic_id == 4:\n", + " plt.title('SVHN-GS', fontsize=22)\n", + "\n", + "plt.tight_layout()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "colab_type": "text", + "id": "Iahn92bHY8kQ" + }, + "source": [ + "# 5. Figure 3: various 2d plots based on subsets of weights statistics\n", + "\n", + "Take weight statistics for the CIFAR10 CNN collection. Plot various 2d plots" + ] + }, + { + "cell_type": "code", + "execution_count": 0, + "metadata": { + "colab": {}, + "colab_type": "code", + "id": "nBXxv0-2ZfZA" + }, + "outputs": [], + "source": [ + "# Take the per-layer weights stats for the train split of CIFAR10-GS collection\n", + "per_layer_stats = np.apply_along_axis(\n", + " extract_per_layer_features, 1,\n", + " weights_train['cifar10'])\n", + "train_test_accuracy = outputs_train['cifar10'][:, 0]\n", + "# Positions of various stats\n", + "b0min = 0 # min of the first layer\n", + "b0max = 4 # max of the first layer\n", + "bnmin = 6*7 + 0 # min of the last layer\n", + "bnmax = 6*7 + 4 # max of the last layer\n", + "x = per_layer_stats[:,b0max] - per_layer_stats[:,b0min]\n", + "y = per_layer_stats[:,bnmax] - per_layer_stats[:,bnmin]\n", + "\n", + "plt.figure(figsize=(10,8))\n", + "plt.scatter(x, y, s=15,\n", + " c=train_test_accuracy,\n", + " cmap=\"jet\",\n", + " vmin=0.1,\n", + " vmax=0.54,\n", + " linewidths=0)\n", + "plt.yscale(\"log\")\n", + "plt.xscale(\"log\")\n", + "plt.ylim(0.1, 10)\n", + "plt.xlim(0.1, 10)\n", + "plt.xlabel(\"Bias range, first layer\", fontsize=22)\n", + "plt.ylabel(\"Bias range, final layer\", fontsize=22)\n", + "cbar = plt.colorbar()\n", + "cbar.ax.tick_params(labelsize=18) \n", + "plt.tight_layout()" + ] + } + ], + "metadata": { + "colab": { + "collapsed_sections": [], + "last_runtime": { + "build_target": "//learning/brain/python/client:colab_notebook", + "kind": "private" + }, + "name": "dnn-predict-accuracy.ipynb", + "private_outputs": true, + "provenance": [ + { + "file_id": "1AV92_u26P4KyTmopFOKgROWg3GostHAA", + "timestamp": 1581610553676 + } + ], + "toc_visible": true + }, + "kernelspec": { + "display_name": "Python 2", + "name": "python2" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/dnn_predict_accuracy/requirements.txt b/dnn_predict_accuracy/requirements.txt new file mode 100644 index 000000000000..d2b0fa434c70 --- /dev/null +++ b/dnn_predict_accuracy/requirements.txt @@ -0,0 +1,4 @@ +tensorflow>=2.0 +tensorflow_datasets>=2.0 +lightgbm>=2.3 +numpy>=1.15.2 diff --git a/dnn_predict_accuracy/train_network.py b/dnn_predict_accuracy/train_network.py new file mode 100644 index 000000000000..0f86c701cf72 --- /dev/null +++ b/dnn_predict_accuracy/train_network.py @@ -0,0 +1,421 @@ +# coding=utf-8 +# Copyright 2020 The Google Research Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Train DNN of a specified architecture on a specified data set.""" + +from __future__ import absolute_import +from __future__ import division + +from __future__ import print_function + +import json +import os +import sys +import time + +from absl import app +from absl import flags +from absl import logging + +import numpy as np +import tensorflow.compat.v2 as tf +from tensorflow.io import gfile +import tensorflow_datasets as tfds + +FLAGS = flags.FLAGS +CNN_KERNEL_SIZE = 3 + +flags.DEFINE_integer('num_layers', 3, 'Number of layers in the network.') +flags.DEFINE_integer('num_units', 16, 'Number of units in a dense layer.') +flags.DEFINE_integer('batchsize', 512, 'Size of the mini-batch.') +flags.DEFINE_float( + 'train_fraction', 1.0, 'How much of the dataset to use for' + 'training [as fraction]: eg. 0.15, 0.5, 1.0') +flags.DEFINE_integer('epochs', 18, 'How many epochs to train for') +flags.DEFINE_integer('epochs_between_checkpoints', 6, + 'How many epochs to train between creating checkpoints') +flags.DEFINE_integer('random_seed', 42, 'Random seed.') +flags.DEFINE_integer('cnn_stride', 2, 'Stride of the CNN') +flags.DEFINE_float('dropout', 0.0, 'Dropout Rate') +flags.DEFINE_float('l2reg', 0.0, 'L2 regularization strength') +flags.DEFINE_float('init_std', 0.05, 'Standard deviation of the initializer.') +flags.DEFINE_float('learning_rate', 0.01, 'Learning rate.') +flags.DEFINE_string('optimizer', 'sgd', + 'Optimizer algorithm: sgd / adam / momentum.') +flags.DEFINE_string('activation', 'relu', + 'Nonlinear activation: relu / tanh / sigmoind / selu.') +flags.DEFINE_string( + 'w_init', 'he_normal', 'Initialization for weights. ' + 'see tf.keras.initializers for options') +flags.DEFINE_string( + 'b_init', 'zero', 'Initialization for biases.' + 'see tf.keras.initializers for options') +flags.DEFINE_boolean('grayscale', True, 'Convert input images to grayscale.') +flags.DEFINE_boolean('augment_traindata', False, 'Augmenting Training data.') +flags.DEFINE_boolean('reduce_learningrate', False, + 'Reduce LR towards end of training.') +flags.DEFINE_string('dataset', 'mnist', 'Name of the dataset compatible ' + 'with TFDS.') +flags.DEFINE_string('dnn_architecture', 'cnn', + 'Architecture of the DNN [fc, cnn, cnnbn]') +flags.DEFINE_string( + 'workdir', '/tmp/dnn_science_workdir', 'Base working directory for storing' + 'checkpoints, summaries, etc.') +flags.DEFINE_integer('verbose', 0, 'Verbosity') +flags.DEFINE_bool('use_tpu', False, 'Whether running on TPU or not.') +flags.DEFINE_string('master', 'local', + 'Name of the TensorFlow master to use. "local" for GPU.') +flags.DEFINE_string( + 'tpu_job_name', 'tpu_worker', + 'Name of the TPU worker job. This is required when having multiple TPU ' + 'worker jobs.') + + +def _get_workunit_params(): + """Get command line parameters of the current process as dict.""" + main_flags = FLAGS.get_key_flags_for_module(sys.argv[0]) + params = {'config.' 
+ k.name: k.value for k in main_flags} + return params + + +def store_results(info_dict, filepath): + """Save results in the json file.""" + with gfile.GFile(filepath, 'w') as json_fp: + json.dump(info_dict, json_fp) + + +def restore_results(filepath): + """Retrieve results in the json file.""" + with gfile.GFile(filepath, 'r') as json_fp: + info = json.load(json_fp) + return info + + +def _preprocess_batch(batch, + normalize, + to_grayscale, + augment=False): + """Preprocessing function for each batch of data.""" + min_out = -1.0 + max_out = 1.0 + image = tf.cast(batch['image'], tf.float32) + image /= 255.0 + + if augment: + shape = image.shape + image = tf.image.resize_with_crop_or_pad(image, shape[1] + 2, shape[2] + 2) + image = tf.image.random_crop(image, size=shape) + + image = tf.image.random_flip_left_right(image) + image = tf.image.random_hue(image, 0.08) + image = tf.image.random_saturation(image, 0.6, 1.6) + image = tf.image.random_brightness(image, 0.05) + image = tf.image.random_contrast(image, 0.7, 1.3) + + if normalize: + image = min_out + image * (max_out - min_out) + if to_grayscale: + image = tf.math.reduce_mean(image, axis=-1, keepdims=True) + return image, batch['label'] + + +def get_dataset(dataset, + batchsize, + to_grayscale=True, + train_fraction=1.0, + shuffle_buffer=1024, + random_seed=None, + normalize=True, + augment=False): + """Load and preprocess the dataset. + + Args: + dataset: The dataset name. Either 'toy' or a TFDS dataset + batchsize: the desired batch size + to_grayscale: if True, all images will be converted into grayscale + train_fraction: what fraction of the overall training set should we use + shuffle_buffer: size of the shuffle.buffer for tf.data.Dataset.shuffle + random_seed: random seed for shuffling operations + normalize: whether to normalize the data into [-1, 1] + augment: use data augmentation on the training set. + + Returns: + tuple (training_dataset, test_dataset, info), where info is a dictionary + with some relevant information about the dataset. 
+ """ + data_tr, ds_info = tfds.load(dataset, split='train', with_info=True) + effective_train_size = ds_info.splits['train'].num_examples + + if train_fraction < 1.0: + effective_train_size = int(effective_train_size * train_fraction) + data_tr = data_tr.shuffle(shuffle_buffer, seed=random_seed) + data_tr = data_tr.take(effective_train_size) + + fn_tr = lambda b: _preprocess_batch(b, normalize, to_grayscale, augment) + data_tr = data_tr.shuffle(shuffle_buffer, seed=random_seed) + data_tr = data_tr.batch(batchsize, drop_remainder=True) + data_tr = data_tr.map(fn_tr, tf.data.experimental.AUTOTUNE) + data_tr = data_tr.prefetch(tf.data.experimental.AUTOTUNE) + + fn_te = lambda b: _preprocess_batch(b, normalize, to_grayscale, False) + data_te = tfds.load(dataset, split='test') + data_te = data_te.batch(batchsize) + data_te = data_te.map(fn_te, tf.data.experimental.AUTOTUNE) + data_te = data_te.prefetch(tf.data.experimental.AUTOTUNE) + + dataset_info = { + 'num_classes': ds_info.features['label'].num_classes, + 'data_shape': ds_info.features['image'].shape, + 'train_num_examples': effective_train_size + } + return data_tr, data_te, dataset_info + + +def build_cnn(n_layers, n_hidden, n_outputs, dropout_rate, activation, stride, + w_regularizer, w_init, b_init, use_batchnorm): + """Convolutional deep neural network.""" + model = tf.keras.Sequential() + for _ in range(n_layers): + model.add( + tf.keras.layers.Conv2D( + n_hidden, + kernel_size=CNN_KERNEL_SIZE, + strides=stride, + activation=activation, + kernel_regularizer=w_regularizer, + kernel_initializer=w_init, + bias_initializer=b_init)) + if dropout_rate > 0.0: + model.add(tf.keras.layers.Dropout(dropout_rate)) + if use_batchnorm: + model.add(tf.keras.layers.BatchNormalization()) + model.add(tf.keras.layers.GlobalAveragePooling2D()) + model.add( + tf.keras.layers.Dense( + n_outputs, + kernel_regularizer=w_regularizer, + kernel_initializer=w_init, + bias_initializer=b_init)) + return model + + + + +def eval_model(model, data_tr, data_te, info, logger, cur_epoch, workdir): + """Runs Model Evaluation.""" + # get training set metrics in eval-mode (no dropout etc.) 
+ metrics_te = model.evaluate(data_te, verbose=0) + res_te = dict(zip(model.metrics_names, metrics_te)) + metrics_tr = model.evaluate(data_tr, verbose=0) + res_tr = dict(zip(model.metrics_names, metrics_tr)) + metrics = { + 'train_accuracy': res_tr['accuracy'], + 'train_loss': res_tr['loss'], + 'test_accuracy': res_te['accuracy'], + 'test_loss': res_te['loss'], + } + for k in metrics: + info[k][cur_epoch] = float(metrics[k]) + metrics['epoch'] = cur_epoch # so it's included in the logging output + print(metrics) + savepath = os.path.join(workdir, 'permanent_ckpt-%d' % cur_epoch) + model.save(savepath) + + +def run(workdir, + data, + strategy, + architecture, + n_layers, + n_hiddens, + activation, + dropout_rate, + l2_penalty, + w_init_name, + b_init_name, + optimizer_name, + learning_rate, + n_epochs, + epochs_between_checkpoints, + init_stddev, + cnn_stride, + reduce_learningrate=False, + verbosity=0): + """Runs the whole training procedure.""" + data_tr, data_te, dataset_info = data + n_outputs = dataset_info['num_classes'] + + with strategy.scope(): + optimizer = tf.keras.optimizers.get(optimizer_name) + optimizer.learning_rate = learning_rate + w_init = tf.keras.initializers.get(w_init_name) + if w_init_name.lower() in ['truncatednormal', 'randomnormal']: + w_init.stddev = init_stddev + b_init = tf.keras.initializers.get(b_init_name) + if b_init_name.lower() in ['truncatednormal', 'randomnormal']: + b_init.stddev = init_stddev + w_reg = tf.keras.regularizers.l2(l2_penalty) if l2_penalty > 0 else None + + if architecture == 'cnn' or architecture == 'cnnbn': + model = build_cnn(n_layers, n_hiddens, n_outputs, dropout_rate, + activation, cnn_stride, w_reg, w_init, b_init, + architecture == 'cnnbn') + elif architecture == 'nin': + model = build_nin(n_hiddens, n_outputs, dropout_rate, activation, w_reg, + w_init, b_init) + else: + assert False, 'Unknown architecture: ' % architecture + + model.compile( + optimizer=optimizer, + loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True), + metrics=['accuracy', 'mse', 'sparse_categorical_crossentropy']) + + # force the model to set input shapes and init weights + for x, _ in data_tr: + model.predict(x) + if verbosity: + model.summary() + break + + ckpt = tf.train.Checkpoint( + step=optimizer.iterations, optimizer=optimizer, model=model) + ckpt_dir = os.path.join(workdir, 'temporary-ckpt') + ckpt_manager = tf.train.CheckpointManager(ckpt, ckpt_dir, max_to_keep=3) + if ckpt_manager.latest_checkpoint: + logging.info('restoring checkpoint: %s', ckpt_manager.latest_checkpoint) + print('restoring from %s' % ckpt_manager.latest_checkpoint) + with strategy.scope(): + ckpt.restore(ckpt_manager.latest_checkpoint) + info = restore_results(os.path.join(workdir, '.intermediate-results.json')) + print(info, flush=True) + else: + info = { + 'steps': 0, + 'start_time': time.time(), + 'train_loss': dict(), + 'train_accuracy': dict(), + 'test_loss': dict(), + 'test_accuracy': dict(), + } + info.update(_get_workunit_params()) # Add command line parameters. 
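  # Main training loop below: one model.fit() epoch at a time, with periodic
  # evaluation/checkpointing. When reduce_learningrate is set, the learning rate is
  # divided by 10 for roughly the last 10% of the epochs and by 100 for the final
  # two epochs.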
+ + logger = None + starting_epoch = len(info['train_loss']) + cur_epoch = starting_epoch + for cur_epoch in range(starting_epoch, n_epochs): + if reduce_learningrate and cur_epoch == n_epochs - (n_epochs // 10): + optimizer.learning_rate = learning_rate / 10 + elif reduce_learningrate and cur_epoch == n_epochs - 2: + optimizer.learning_rate = learning_rate / 100 + + # Train until we reach the criterion or get NaNs + try: + # always keep checkpoints for the first few epochs + # we evaluate first and train afterwards so we have the at-init data + if cur_epoch < 4 or (cur_epoch % epochs_between_checkpoints) == 0: + eval_model(model, data_tr, data_te, info, logger, cur_epoch, workdir) + + model.fit(data_tr, epochs=1, verbose=verbosity) + ckpt_manager.save() + store_results(info, os.path.join(workdir, '.intermediate-results.json')) + + dt = time.time() - info['start_time'] + logging.info('epoch %d (%3.2fs)', cur_epoch, dt) + + except tf.errors.InvalidArgumentError as e: + # We got NaN in the loss, most likely gradients resulted in NaNs + logging.info(str(e)) + info['status'] = 'NaN' + logging.info('Stop training because NaNs encountered') + break + + eval_model(model, data_tr, data_te, info, logger, cur_epoch+1, workdir) + store_results(info, os.path.join(workdir, 'results.json')) + + # we don't need the temporary checkpoints anymore + gfile.rmtree(os.path.join(workdir, 'temporary-ckpt')) + gfile.remove(os.path.join(workdir, '.intermediate-results.json')) + + +def main(unused_argv): + workdir = FLAGS.workdir + + + if not gfile.isdir(workdir): + gfile.makedirs(workdir) + + tf.random.set_seed(FLAGS.random_seed) + np.random.seed(FLAGS.random_seed) + data = get_dataset( + FLAGS.dataset, + FLAGS.batchsize, + to_grayscale=FLAGS.grayscale, + train_fraction=FLAGS.train_fraction, + random_seed=FLAGS.random_seed, + augment=FLAGS.augment_traindata) + + # Figure out TPU related stuff and create distribution strategy + use_remote_eager = FLAGS.master and FLAGS.master != 'local' + if FLAGS.use_tpu: + logging.info("Use TPU at %s with job name '%s'.", FLAGS.master, + FLAGS.tpu_job_name) + resolver = tf.distribute.cluster_resolver.TPUClusterResolver( + tpu=FLAGS.master, job_name=FLAGS.tpu_job_name) + if use_remote_eager: + tf.config.experimental_connect_to_cluster(resolver) + logging.warning('Remote eager configured. Remote eager can be slow.') + tf.tpu.experimental.initialize_tpu_system(resolver) + strategy = tf.distribute.experimental.TPUStrategy(resolver) + else: + if use_remote_eager: + tf.config.experimental_connect_to_host( + FLAGS.master, job_name='gpu_worker') + logging.warning('Remote eager configured. 
Remote eager can be slow.') + gpus = tf.config.experimental.list_logical_devices(device_type='GPU') + if gpus: + logging.info('Found GPUs: %s', gpus) + strategy = tf.distribute.MirroredStrategy() + else: + logging.info('Devices: %s', tf.config.list_logical_devices()) + strategy = tf.distribute.OneDeviceStrategy('CPU') + logging.info('Devices: %s', tf.config.list_logical_devices()) + logging.info('Distribution strategy: %s', strategy) + logging.info('Model directory: %s', workdir) + + run(workdir, + data, + strategy, + architecture=FLAGS.dnn_architecture, + n_layers=FLAGS.num_layers, + n_hiddens=FLAGS.num_units, + activation=FLAGS.activation, + dropout_rate=FLAGS.dropout, + l2_penalty=FLAGS.l2reg, + w_init_name=FLAGS.w_init, + b_init_name=FLAGS.b_init, + optimizer_name=FLAGS.optimizer, + learning_rate=FLAGS.learning_rate, + n_epochs=FLAGS.epochs, + epochs_between_checkpoints=FLAGS.epochs_between_checkpoints, + init_stddev=FLAGS.init_std, + cnn_stride=FLAGS.cnn_stride, + reduce_learningrate=FLAGS.reduce_learningrate, + verbosity=FLAGS.verbose) + + +if __name__ == '__main__': + tf.enable_v2_behavior() + app.run(main)
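As a bridge back to the colab above: train_network.py writes one `permanent_ckpt-<epoch>` SavedModel per evaluation point, while the colab consumes flat weight vectors (`all_weights.npy`). The snippet below is a hypothetical companion sketch, not part of this change; it assumes the biases-before-kernels ordering implied by the offsets hard-coded in the colab, since the script that actually assembled `all_weights.npy` is not included here.

# Hypothetical helper: turn a checkpoint written by train_network.py into a flat
# weight vector like the rows of all_weights.npy used in the colab.
# A run could be produced e.g. with:
#   python train_network.py --dataset=mnist --workdir=/tmp/dnn_science_workdir
import numpy as np
import tensorflow.compat.v2 as tf


def flatten_checkpoint(savepath):
  """Load a permanent_ckpt-* SavedModel and flatten its weights (biases first)."""
  model = tf.keras.models.load_model(savepath, compile=False)  # weights only
  pieces = []
  for layer in model.layers:
    layer_weights = layer.get_weights()  # Conv2D/Dense: [kernel, bias]; Dropout/GAP: []
    if len(layer_weights) == 2:
      kernel, bias = layer_weights
      pieces.append(bias.ravel())
      pieces.append(kernel.ravel())
  return np.concatenate(pieces)


# Example with the default Small CNN (3 conv layers, 16 units, grayscale input):
# flatten_checkpoint('/tmp/dnn_science_workdir/permanent_ckpt-18').shape == (4970,)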