diff --git a/.gitignore b/.gitignore index 2a65be8..15236e8 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,9 @@ __pycache__/ **/__pycache__/ +# models +models/ + # git # TODO: unignore when ready .gitattributes diff --git a/inference/test.py b/inference/test.py index e5d3b23..5bdebaa 100644 --- a/inference/test.py +++ b/inference/test.py @@ -11,7 +11,7 @@ base_path = "../../../Data/" # Load the TFLite model -interpreter = tf.lite.Interpreter(model_path='../models/model.tflite') +interpreter = tf.lite.Interpreter(model_path='../models/model2.tflite') interpreter.allocate_tensors() # Get input and output details @@ -76,13 +76,54 @@ def visualize_detections(image_path, boxes, classes_scores, threshold=0.5): plt.show() -def main(): +def train_images(): + """ + + :return: + """ + train_path = '../data/rparis6k/sets/train/train.txt' + + if not os.path.exists(train_path): + print(f"Error: Train file not found: {train_path}") + return + + with open(train_path, 'r') as f: + train_images = [line.strip() for line in f] + + for image_name in train_images[:5]: + image_path = os.path.join(base_path, 'datasets', 'rparis6k', 'images', image_name) + image_np = load_image_into_numpy_array(image_path) + boxes, classes_scores = run_inference(image_np) + visualize_detections(image_path, boxes, classes_scores) + + +def validation_images(): """ :return: """ + validation_path = '../data/rparis6k/sets/validation/val.txt' + + if not os.path.exists(validation_path): + print(f"Error: Validation file not found: {validation_path}") + return + + with open(validation_path, 'r') as f: + validation_images = [line.strip() for line in f] + + for image_name in validation_images[:5]: + image_path = os.path.join(base_path, 'datasets', 'rparis6k', 'images', image_name) + image_np = load_image_into_numpy_array(image_path) + boxes, classes_scores = run_inference(image_np) + visualize_detections(image_path, boxes, classes_scores) + + +def test_images(): + """ - test_path = '../data/rparis6k/sets/test.txt' + :return: + """ + test_path = '../data/rparis6k/sets/test/test.txt' if not os.path.exists(test_path): print(f"Error: Test file not found: {test_path}") @@ -98,5 +139,16 @@ def main(): visualize_detections(image_path, boxes, classes_scores) +def main(): + """ + + :return: + """ + + #train_images() + #validation_images() + #test_images() + + if __name__ == '__main__': main() diff --git a/models/model.tflite b/models/model.tflite deleted file mode 100644 index 9323452..0000000 Binary files a/models/model.tflite and /dev/null differ diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000..04f59b2 --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,26 @@ +# Data scripts + +This directory contains scripts that are used to get, process and analyze +the data used in the project. + +## Get the data + +```bash +python get_data.py +``` + +## Process the data + +```bash +python prepare_dataset.py +``` + +```mermaid +graph TD; + A-->B; + A-->C; + B-->D; + C-->D; +``` + + diff --git a/training/mediapipe_object_detector_model_customization.ipynb b/training/mediapipe_object_detector_model_customization.ipynb deleted file mode 100644 index 651fb71..0000000 --- a/training/mediapipe_object_detector_model_customization.ipynb +++ /dev/null @@ -1 +0,0 @@ -{"cells":[{"cell_type":"markdown","source":["# Object detection model customization"],"metadata":{"id":"rQIYsEciHsMN"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"NReBxyCQH1Rx"}},{"cell_type":"markdown","source":["To install the libraries for customizing a model, run the following commands:"],"metadata":{"id":"xbuuejtPH97m"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install mediapipe-model-maker"],"metadata":{"id":"ceFO9GPSHyRv","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the following code to import the required Python classes:"],"metadata":{"id":"2beHFLLIIIv0"}},{"cell_type":"code","source":["from google.colab import files\n","import os\n","import json\n","import tensorflow as tf\n","assert tf.__version__.startswith('2')\n","\n","from mediapipe_model_maker import object_detector"],"metadata":{"id":"0OpjgSuBIJJC","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Prepare data"],"metadata":{"id":"Yq6UT0j9ILD_"}},{"cell_type":"markdown","source":["1. Mount your Google Drive in Colab\n","2. Upload the files train.txt, val.txt and test.txt to your Drive\n","3. Use Python to read these files and copy the images to the correct locations"],"metadata":{"id":"Q9HH1m0IqsA6"}},{"cell_type":"code","source":["from google.colab import drive\n","import os\n","import shutil\n","\n","# Mount Google Drive\n","drive.mount('/content/drive')"],"metadata":{"id":"-Lrh4ORlDGgu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Define paths\n","base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 'validation/'\n","test_dataset_path = dest_base_path + 'test/'"],"metadata":{"id":"BlfRX-r_3YUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," for line in f:\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if not os.listdir(train_dataset_path) and not os.listdir(validation_dataset_path) and not os.listdir(test_dataset_path):\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"-A3ZtrofqptN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review dataset"],"metadata":{"id":"IUcrpj6aIzqt"}},{"cell_type":"markdown","source":["Verify the dataset content by printing the categories from the `labels.json` file. There should be 13 total categories. Index 0 is always set to be the `background` class which may be unused in the dataset."],"metadata":{"id":"xYXzxYC3I0h5"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"v6rYvXVwIwKp"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["To better understand the dataset, plot a couple of example images along with their bounding boxes."],"metadata":{"id":"-M33tF2KI8bU"}},{"cell_type":"code","source":["#@title Visualize the training dataset\n","import matplotlib.pyplot as plt\n","from matplotlib import patches, text, patheffects\n","from collections import defaultdict\n","import math\n","\n","def draw_outline(obj):\n"," obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n","def draw_box(ax, bb):\n"," patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n"," draw_outline(patch)\n","def draw_text(ax, bb, txt, disp):\n"," text = ax.text(bb[0],(bb[1]-disp),txt,verticalalignment='top'\n"," ,color='white',fontsize=10,weight='bold')\n"," draw_outline(text)\n","def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n"," for annotation in annotations_list:\n"," cat_id = annotation[\"category_id\"]\n"," bbox = annotation[\"bbox\"]\n"," draw_box(ax, bbox)\n"," draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n","def visualize(dataset_folder, max_examples=None):\n"," with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n"," images = labels_json[\"images\"]\n"," cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n"," image_annots = defaultdict(list)\n"," for annotation_obj in labels_json[\"annotations\"]:\n"," image_id = annotation_obj[\"image_id\"]\n"," image_annots[image_id].append(annotation_obj)\n","\n"," if max_examples is None:\n"," max_examples = len(image_annots.items())\n"," n_rows = math.ceil(max_examples / 3)\n"," fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n"," for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n"," ax = axs[ind//3, ind%3]\n"," img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n"," ax.imshow(img)\n"," draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n"," plt.show()\n","\n","visualize(train_dataset_path, 9)"],"metadata":{"id":"8D17VhVAI33W","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Augmentation"],"metadata":{"id":"_Igtupjrvpph"}},{"cell_type":"code","source":["import albumentations as A\n","\n","transform = A.Compose([\n"," A.HorizontalFlip(p=0.5),\n"," A.RandomBrightnessContrast(p=0.2),\n"," A.Perspective(p=0.5),\n"," A.Rotate(limit=30, p=0.5),\n"," # add transformations\n","])\n","\n","def augment_dataset(dataset):\n"," augmented_images = []\n"," augmented_annotations = []\n","\n"," for image, annotation in dataset:\n"," augmented = transform(image=image, bboxes=annotation['bboxes'], category_ids=annotation['category_ids'])\n"," augmented_images.append(augmented['image'])\n"," augmented_annotations.append({\n"," 'bboxes': augmented['bboxes'],\n"," 'category_ids': augmented['category_ids']\n"," })\n","\n"," return augmented_images, augmented_annotations\n","\n","train_data_augmented = augment_dataset(train_data)"],"metadata":{"id":"AfDaOm8cvs52"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Create dataset\n","\n","The Dataset class has two methods for loading in COCO or PASCAL VOC datasets:\n","* `Dataset.from_coco_folder`\n","* `Dataset.from_pascal_voc_folder`\n","\n","Since the android_figurines dataset is in the COCO dataset format, use the `from_coco_folder` method to load the dataset located at `train_dataset_path` and `validation_dataset_path`. When loading the dataset, the data will be parsed from the provided path and converted into a standardized [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) format which is cached for later use. You should create a `cache_dir` location and reuse it for all your training to avoid saving multiple caches of the same dataset."],"metadata":{"id":"fPyRkkKYJEOB"}},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","print(\"train_data size: \", train_data.size)\n","print(\"validation_data size: \", validation_data.size)"],"metadata":{"id":"cooLOJrmJDbo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Retrain model\n","\n","Once you have completed preparing your data, you can begin retraining a model to recognize the new objects, or classes, defined by your training data. The instructions below use the data prepared in the previous section to retrain an image classification model to recognize the two types of android figurines."],"metadata":{"id":"IzhODm2CJaMB"}},{"cell_type":"markdown","source":["### Set retraining options\n","\n","There are a few required settings to run retraining aside from your training dataset: output directory for the model, and the model architecture. Use `HParams` to specify the `export_dir` parameter for the output directory. Use the `SupportedModels` class to specify the model architecture. The object detector solution supports the following model architectures:\n","* `MobileNet-V2`\n","* `MobileNet-MultiHW-AVG`\n","\n","For more advanced customization of training parameters, see the [Hyperparameters](#hyperparameters) section below.\n","\n","To set the required parameters, use the following code:"],"metadata":{"id":"jF7sZHYyJcl7"}},{"cell_type":"code","source":["spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384\n","hparams = object_detector.HParams(\n"," learning_rate=0.1,\n"," batch_size=16,\n"," epochs=50,\n"," cosine_decay_epochs=45,\n"," cosine_decay_alpha=0.1,\n"," export_dir='exported_model' # TODO: check\n",")\n","model_options = object_detector.ModelOptions(\n"," l2_weight_decay=1e-4\n",")\n","options = object_detector.ObjectDetectorOptions(\n"," supported_model=spec,\n"," hparams=hparams,\n"," model_options=model_options\n",")"],"metadata":{"id":"iC6cVSpVJWFw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Run retraining\n","With your training dataset and retraining options prepared, you are ready to start the retraining process. This process is resource intensive and can take a few minutes to a few hours depending on your available compute resources. Using a Google Colab environment with standard GPU runtimes, the example retraining below takes about 2~4 minutes.\n","\n","To begin the retraining process, use the `create()` method with dataset and options you previously defined:"],"metadata":{"id":"bfQoIMPOJhK5"}},{"cell_type":"code","source":["model = object_detector.ObjectDetector.create(\n"," train_data=train_data,\n"," validation_data=validation_data,\n"," options=options)"],"metadata":{"id":"BxT1UHQWJfAX","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Evaluate the model performance\n","\n","After training the model, evaluate it on validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision."],"metadata":{"id":"y3h_cOytJjXk"}},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=4)\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"kJabrlEjJkmH","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export model\n","\n","After creating the model, convert and export it to a Tensorflow Lite model format for later use on an on-device application. The export also includes model metadata, which includes the label map."],"metadata":{"id":"Vkg5K-xWJmj6"}},{"cell_type":"code","source":["model.export_model()\n","!ls exported_model\n","files.download('exported_model/model.tflite')"],"metadata":{"id":"ADUK65tBJn7x","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model quantization"],"metadata":{"id":"6g6QkA6YJpf8"}},{"cell_type":"markdown","source":["Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n","\n","This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n","1. Quantization Aware Training: 8 bit integer precision for CPU usage\n","2. Post-Training Quantization: 16 bit floating point precision for GPU usage"],"metadata":{"id":"Jl1Dc9EhJtRo"}},{"cell_type":"markdown","source":["### Quantization aware training (int8 quantization)\n","Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n","\n","To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. See the **Hyperparameters** section below on detailed usage of `QATHParams`."],"metadata":{"id":"D88f-mqgJvwi"}},{"cell_type":"code","source":["qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"NbTDJtEJJr7Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model(After running the `create` method) in order to run QAT again."],"metadata":{"id":"5KEPghm9JyUi"}},{"cell_type":"code","source":["new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n","model.restore_float_ckpt()\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"vALj8IqaJ1B4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Finally, us the `export_model` to export to an int8 quantized model. The `export_model` function will automatically export to either float32 or int8 model depending on whether `quantization_aware_training` was run."],"metadata":{"id":"HdZAaTJKJ4ky"}},{"cell_type":"code","source":["model.export_model('model_int8_qat.tflite')\n","!ls -lh exported_model\n","files.download('exported_model/model_int8_qat.tflite')"],"metadata":{"id":"AePYkJHwJ2OG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Post-training quantization (fp16 quantization)\n","\n","Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is reccomended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n","\n","First, import the MediaPipe Model Maker quantization module:"],"metadata":{"id":"mm4gCymvJ7wK"}},{"cell_type":"code","source":["from mediapipe_model_maker import quantization"],"metadata":{"id":"mt8zeY52J8Gk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class."],"metadata":{"id":"r6eGgib4J-Ac"}},{"cell_type":"code","source":["quantization_config = quantization.QuantizationConfig.for_float16()"],"metadata":{"id":"IjAn-GBKJ_ab"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`."],"metadata":{"id":"ChqPtssVKCA8"}},{"cell_type":"code","source":["model.restore_float_ckpt()\n","model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n","!ls -lh exported_model\n","files.download('exported_model/model_fp16.tflite')"],"metadata":{"id":"f90WxGNNKCSW"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Hyperparameters\n","You can further customize the model using the ObjectDetectorOptions class, which has three parameters for `SupportedModels`, `ModelOptions`, and `HParams`.\n","\n","Use the `SupportedModels` enum class to specify the model architecture to use for training. The following model architectures are supported:\n","* MOBILENET_V2\n","* MOBILENET_V2_I320\n","* MOBILENET_MULTI_AVG\n","* MOBILENET_MULTI_AVG_I384\n","\n","Use the `HParams` class to customize other parameters related to training and saving the model:\n","* `learning_rate`: Learning rate to use for gradient descent training. Defaults to 0.3.\n","* `batch_size`: Batch size for training. Defaults to 8.\n","* `epochs`: Number of training iterations over the dataset. Defaults to 30.\n","* `cosine_decay_epochs`: The number of epochs for cosine decay learning rate. See [tf.keras.optimizers.schedules.CosineDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay) for more info. Defaults to None, which is equivalent to setting it to `epochs`.\n","* `cosine_decay_alpha`: The alpha value for cosine decay learning rate. See [tf.keras.optimizers.schedules.CosineDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay) for more info. Defaults to 1.0, which means no cosine decay.\n","\n","Use the `ModelOptions` class to customize parameters related to the model itself:\n","* `l2_weight_decay`: L2 regularization penalty used in [tf.keras.regularizers.L2](https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/L2). Defaults to 3e-5.\n","\n","Uset the `QATHParams` class to customize training parameters for Quantization Aware Training:\n","* `learning_rate`: Learning rate to use for gradient descent QAT. Defaults to 0.3.\n","* `batch_size`: Batch size for QAT. Defaults to 8\n","* `epochs`: Number of training iterations over the dataset. Defaults to 15.\n","* `decay_steps`: Learning rate decay steps for Exponential Decay. See [tf.keras.optimizers.schedules.ExponentialDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay) for more information. Defaults to 8\n","* `decay_rate`: Learning rate decay rate for Exponential Decay. See [tf.keras.optimizers.schedules.ExponentialDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay) for more information. Defaults to 0.96."],"metadata":{"id":"XMT2xM7jKD--"}},{"cell_type":"markdown","source":["## Benchmarking\n","Below is a summary of our benchmarking results for the supported model architectures. These models were trained and evaluated on the same android figurines dataset as this notebook. When considering the model benchmarking results, there are a few important caveats to keep in mind:\n","* The android figurines dataset is a small and simple dataset with 62 training examples and 10 validation examples. Since the dataset is quite small, metrics may vary drastically due to variances in the training process. This dataset was provided for demo purposes and it is recommended to collect more data samples for better performing models.\n","* The float32 models were trained with the default HParams, and the QAT step for the int8 models was run with `QATHParams(learning_rate=0.1, batch_size=4, epochs=30, decay_rate=1)`.\n","* For your own dataset, you will likely need to tune values for both HParams and QATHParams in order to achieve the best results. See the [Hyperparameters](#hyperparameters) section above for more information on configuring training parameters.\n","* All latency numbers are benchmarked on the Pixel 6.\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","
Model architectureInput Image SizeTest APCPU LatencyModel Size
float32QAT int8float32QAT int8float32QAT int8
MobileNetV2256x25688.4%73.5%48ms16ms11MB3.2MB
MobileNetV2 I320320x32089.1%75.5%75ms33.38ms10MB3.3MB
MobileNet MultiHW AVG256x25688.5%70.0%56ms19ms13MB3.6MB
MobileNet MultiHW AVG I384384x38492.7%73.4%238ms41ms13MB3.6MB
\n","\n"],"metadata":{"id":"3l1ReIyUKIVW"}}],"metadata":{"colab":{"last_runtime":{"build_target":"//learning/grp/tools/ml_python:ml_notebook","kind":"private"},"private_outputs":true,"provenance":[{"file_id":"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb","timestamp":1720270505355},{"file_id":"11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb","timestamp":1677706798050}],"gpuType":"T4"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} \ No newline at end of file diff --git a/training/mediapipe_object_detector_model_customization_template.ipynb b/training/mediapipe_object_detector_model_customization_template.ipynb new file mode 100644 index 0000000..55919e3 --- /dev/null +++ b/training/mediapipe_object_detector_model_customization_template.ipynb @@ -0,0 +1,790 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "eOlNZgOafs5b" + }, + "source": [ + "Project: /mediapipe/_project.yaml\n", + "Book: /mediapipe/_book.yaml\n", + "\n", + "\n", + "\n", + "# Object detection model customization guide\n", + "\n", + "\u003ctable align=\"left\" class=\"buttons\"\u003e\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://colab.research.google.com/github/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb\" target=\"_blank\"\u003e\n", + " \u003cimg src=\"https://developers.google.com/static/mediapipe/solutions/customization/colab-logo-32px_1920.png\" alt=\"Colab logo\"\u003e Run in Colab\n", + " \u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\n", + " \u003ctd\u003e\n", + " \u003ca href=\"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb\" target=\"_blank\"\u003e\n", + " \u003cimg src=\"https://developers.google.com/static/mediapipe/solutions/customization/github-logo-32px_1920.png\" alt=\"GitHub logo\"\u003e\n", + " View on GitHub\n", + " \u003c/a\u003e\n", + " \u003c/td\u003e\n", + "\u003c/table\u003e" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "wGYgvC7P7faD" + }, + "outputs": [], + "source": [ + "#@title License information\n", + "# Copyright 2023 The MediaPipe Authors.\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "#\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "#\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "#\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "WzM-FLsJKwij" + }, + "source": [ + "The MediaPipe object detection solution provides several models you can use immediately for machine learning (ML) in your application. However, if you need to detect objects not covered by the provided models, you can customize any of the provided models with your own data and MediaPipe Model Maker. This model modification tool rebuilds the model using data you provide. This method is faster than training a new model and can produce a model that is more useful for your specific application.\n", + "\n", + "The following sections show you how to use Model Maker to retrain a pre-built model for object detection with your own data, which you can then use with the MediaPipe [Object Detector](https://developers.google.com/mediapipe/solutions/vision/object_detector). The example retrains a general purpose object detection model to detect android figurines in images." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "1_9q7fppLbPA" + }, + "source": [ + "## Setup\n", + "\n", + "This section describes key steps for setting up your development environment to retrain a model. These instructions describe how to update a model using [Google Colab](https://colab.research.google.com/), and you can also use Python in your own development environment. For general information on setting up your development environment for using MediaPipe, including platform version requirements, see the [Setup guide for Python](https://developers.google.com/mediapipe/solutions/setup_python).\n", + "\n", + "**Attention:** This MediaPipe Solutions Preview is an early release. [Learn more](https://developers.google.com/mediapipe/solutions/about).\n", + "\n", + "To install the libraries for customizing a model, run the following commands:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "GMSCcHh1LScM" + }, + "outputs": [], + "source": [ + "!python --version\n", + "!pip install --upgrade pip\n", + "!pip install mediapipe-model-maker" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "7NXvZgLPLh6n" + }, + "source": [ + "Use the following code to import the required Python classes:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "oazmbPzKHYFq" + }, + "outputs": [], + "source": [ + "from google.colab import files\n", + "import os\n", + "import json\n", + "import tensorflow as tf\n", + "assert tf.__version__.startswith('2')\n", + "\n", + "from mediapipe_model_maker import object_detector" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "_zP1AkaRL72Z" + }, + "source": [ + "## Prepare data" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "EVnSV0ZQMA5-" + }, + "source": [ + "Retraining a model for object detection requires a dataset that includes the items, or classes, that you want the completed model to be able to identify. You can do this by trimming down a public dataset to only the classes that are relevant to your usecase, compiling your own dataset, or some combination of both, The dataset can be significantly smaller than what would be required to train a new model. For example, the [COCO](https://cocodataset.org/) dataset used to train many reference models contains hundreds of thousands of images with 91 classes of objects. Transfer learning with Model Maker can retrain an existing model with a smaller dataset and still perform well, depending on your inference accuracy goals. These instructions use a smaller dataset containing 2 types of android figurines, or 2 classes, with 62 total training images.\n", + "\n", + "To download the example dataset, use the following code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Mz3-eHe07FEX" + }, + "outputs": [], + "source": [ + "!wget https://storage.googleapis.com/mediapipe-tasks/object_detector/android_figurine.zip\n", + "!unzip android_figurine.zip\n", + "train_dataset_path = \"android_figurine/train\"\n", + "validation_dataset_path = \"android_figurine/validation\"" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "oEVcacUj7L1l" + }, + "source": [ + "This code stores the dataset at the directory location `android_figurine`. The directory contains two subdirectories for the training and validation datasets, located in `android_figurine/train` and `android_figurine/validation` respectively. Each of the train and validation datasets follow the COCO Dataset format described below." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "vbIAPjPUMHiK" + }, + "source": [ + "### Supported dataset formats\n", + "Model Maker Object Detection API supports reading the following dataset formats:\n", + "\n", + "#### COCO format\n", + "\n", + "The COCO dataset format has a `data` directory which stores all of the images and a single `labels.json` file which contains the object annotations for all images.\n", + "```\n", + "\u003cdataset_dir\u003e/\n", + " data/\n", + " \u003cimg0\u003e.\u003cjpg/jpeg\u003e\n", + " \u003cimg1\u003e.\u003cjpg/jpeg\u003e\n", + " ...\n", + " labels.json\n", + "```\n", + "where `labels.json` is formatted as:\n", + "```\n", + "{\n", + " \"categories\":[\n", + " {\"id\":1, \"name\":\u003ccat1_name\u003e},\n", + " ...\n", + " ],\n", + " \"images\":[\n", + " {\"id\":0, \"file_name\":\"\u003cimg0\u003e.\u003cjpg/jpeg\u003e\"},\n", + " ...\n", + " ],\n", + " \"annotations\":[\n", + " {\"id\":0, \"image_id\":0, \"category_id\":1, \"bbox\":[x-top left, y-top left, width, height]},\n", + " ...\n", + " ]\n", + "}\n", + "```\n", + "\n", + "#### PASCAL VOC format\n", + "\n", + "The PASCAL VOC dataset format also has a `data` directory which stores all of the images, however the annotations are split up per image into corresponding xml files in the `Annotations` directory.\n", + "```\n", + "\u003cdataset_dir\u003e/\n", + " data/\n", + " \u003cfile0\u003e.\u003cjpg/jpeg\u003e\n", + " ...\n", + " Annotations/\n", + " \u003cfile0\u003e.xml\n", + " ...\n", + "```\n", + "where the xml files are formatted as:\n", + "```\n", + "\u003cannotation\u003e\n", + " \u003cfilename\u003efile0.jpg\u003c/filename\u003e\n", + " \u003cobject\u003e\n", + " \u003cname\u003ekangaroo\u003c/name\u003e\n", + " \u003cbndbox\u003e\n", + " \u003cxmin\u003e233\u003c/xmin\u003e\n", + " \u003cymin\u003e89\u003c/ymin\u003e\n", + " \u003cxmax\u003e386\u003c/xmax\u003e\n", + " \u003cymax\u003e262\u003c/ymax\u003e\n", + " \u003c/bndbox\u003e\n", + " \u003c/object\u003e\n", + " \u003cobject\u003e\n", + " ...\n", + " \u003c/object\u003e\n", + " ...\n", + "\u003c/annotation\u003e\n", + "```" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "G7TPn8Mb_aJb" + }, + "source": [ + "### Review dataset" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "L64U7mgPNKec" + }, + "source": [ + "Verify the dataset content by printing the categories from the `labels.json` file. There should be 3 total categories. Index 0 is always set to be the `background` class which may be unused in the dataset. There should be two non-background categories of `android` and `pig_android`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "2f_Z-TAwNK3n" + }, + "outputs": [], + "source": [ + "with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n", + " labels_json = json.load(f)\n", + "for category_item in labels_json[\"categories\"]:\n", + " print(f\"{category_item['id']}: {category_item['name']}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "xr-7QJ05PmyS" + }, + "source": [ + "To better understand the dataset, plot a couple of example images along with their bounding boxes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "6kTw3uodPl7-" + }, + "outputs": [], + "source": [ + "#@title Visualize the training dataset\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib import patches, text, patheffects\n", + "from collections import defaultdict\n", + "import math\n", + "\n", + "def draw_outline(obj):\n", + " obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n", + "def draw_box(ax, bb):\n", + " patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n", + " draw_outline(patch)\n", + "def draw_text(ax, bb, txt, disp):\n", + " text = ax.text(bb[0],(bb[1]-disp),txt,verticalalignment='top'\n", + " ,color='white',fontsize=10,weight='bold')\n", + " draw_outline(text)\n", + "def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n", + " for annotation in annotations_list:\n", + " cat_id = annotation[\"category_id\"]\n", + " bbox = annotation[\"bbox\"]\n", + " draw_box(ax, bbox)\n", + " draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n", + "def visualize(dataset_folder, max_examples=None):\n", + " with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n", + " labels_json = json.load(f)\n", + " images = labels_json[\"images\"]\n", + " cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n", + " image_annots = defaultdict(list)\n", + " for annotation_obj in labels_json[\"annotations\"]:\n", + " image_id = annotation_obj[\"image_id\"]\n", + " image_annots[image_id].append(annotation_obj)\n", + "\n", + " if max_examples is None:\n", + " max_examples = len(image_annots.items())\n", + " n_rows = math.ceil(max_examples / 3)\n", + " fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n", + " for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n", + " ax = axs[ind//3, ind%3]\n", + " img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n", + " ax.imshow(img)\n", + " draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n", + " plt.show()\n", + "\n", + "visualize(train_dataset_path, 9)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ANqfl-ghQM74" + }, + "source": [ + "### Create dataset\n", + "\n", + "The Dataset class has two methods for loading in COCO or PASCAL VOC datasets:\n", + "* `Dataset.from_coco_folder`\n", + "* `Dataset.from_pascal_voc_folder`\n", + "\n", + "Since the android_figurines dataset is in the COCO dataset format, use the `from_coco_folder` method to load the dataset located at `train_dataset_path` and `validation_dataset_path`. When loading the dataset, the data will be parsed from the provided path and converted into a standardized [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) format which is cached for later use. You should create a `cache_dir` location and reuse it for all your training to avoid saving multiple caches of the same dataset." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "EOdyImqyI6s-" + }, + "outputs": [], + "source": [ + "train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n", + "validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n", + "print(\"train_data size: \", train_data.size)\n", + "print(\"validation_data size: \", validation_data.size)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "det3wYoWRRDs" + }, + "source": [ + "## Retrain model\n", + "\n", + "Once you have completed preparing your data, you can begin retraining a model to recognize the new objects, or classes, defined by your training data. The instructions below use the data prepared in the previous section to retrain an image classification model to recognize the two types of android figurines." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "f9AzF_CGA7mj" + }, + "source": [ + "### Set retraining options\n", + "\n", + "There are a few required settings to run retraining aside from your training dataset: output directory for the model, and the model architecture. Use `HParams` to specify the `export_dir` parameter for the output directory. Use the `SupportedModels` class to specify the model architecture. The object detector solution supports the following model architectures:\n", + "* `MobileNet-V2`\n", + "* `MobileNet-MultiHW-AVG`\n", + "\n", + "For more advanced customization of training parameters, see the [Hyperparameters](#hyperparameters) section below.\n", + "\n", + "To set the required parameters, use the following code:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "4ZHjWHM1JyiN" + }, + "outputs": [], + "source": [ + "spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG\n", + "hparams = object_detector.HParams(export_dir='exported_model')\n", + "options = object_detector.ObjectDetectorOptions(\n", + " supported_model=spec,\n", + " hparams=hparams\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "t3Kto1RlCcPj" + }, + "source": [ + "### Run retraining\n", + "With your training dataset and retraining options prepared, you are ready to start the retraining process. This process is resource intensive and can take a few minutes to a few hours depending on your available compute resources. Using a Google Colab environment with standard GPU runtimes, the example retraining below takes about 2~4 minutes.\n", + "\n", + "To begin the retraining process, use the `create()` method with dataset and options you previously defined:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "V5bIsWBZCb8d" + }, + "outputs": [], + "source": [ + "model = object_detector.ObjectDetector.create(\n", + " train_data=train_data,\n", + " validation_data=validation_data,\n", + " options=options)" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "RuRapoFiRp34" + }, + "source": [ + "### Evaluate the model performance\n", + "\n", + "After training the model, evaluate it on validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xJvB_nf7RwzJ" + }, + "outputs": [], + "source": [ + "loss, coco_metrics = model.evaluate(validation_data, batch_size=4)\n", + "print(f\"Validation loss: {loss}\")\n", + "print(f\"Validation coco metrics: {coco_metrics}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "ax8TkYA9VJUv" + }, + "source": [ + "## Export model\n", + "\n", + "After creating the model, convert and export it to a Tensorflow Lite model format for later use on an on-device application. The export also includes model metadata, which includes the label map." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "PWZqnGEKVP13" + }, + "outputs": [], + "source": [ + "model.export_model()\n", + "!ls exported_model\n", + "files.download('exported_model/model.tflite')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UcYu5ENbT4T6" + }, + "source": [ + "## Model quantization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "nIeJjCfWTnBj" + }, + "source": [ + "Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n", + "\n", + "This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n", + "1. Quantization Aware Training: 8 bit integer precision for CPU usage\n", + "2. Post-Training Quantization: 16 bit floating point precision for GPU usage" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "UcB3DRfHWDfs" + }, + "source": [ + "### Quantization aware training (int8 quantization)\n", + "Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n", + "\n", + "To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. See the **Hyperparameters** section below on detailed usage of `QATHParams`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "_7nRSQT9WCS-" + }, + "outputs": [], + "source": [ + "qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n", + "model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n", + "qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n", + "print(f\"QAT validation loss: {qat_loss}\")\n", + "print(f\"QAT validation coco metrics: {qat_coco_metrics}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "rT7grgHOW048" + }, + "source": [ + "The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model(After running the `create` method) in order to run QAT again." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "xHDPWbIaXjR4" + }, + "outputs": [], + "source": [ + "new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n", + "model.restore_float_ckpt()\n", + "model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n", + "qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n", + "print(f\"QAT validation loss: {qat_loss}\")\n", + "print(f\"QAT validation coco metrics: {qat_coco_metrics}\")" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "AfWo6TVpWJfr" + }, + "source": [ + "Finally, us the `export_model` to export to an int8 quantized model. The `export_model` function will automatically export to either float32 or int8 model depending on whether `quantization_aware_training` was run." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "rsixePxCWJDp" + }, + "outputs": [], + "source": [ + "model.export_model('model_int8_qat.tflite')\n", + "!ls -lh exported_model\n", + "files.download('exported_model/model_int8_qat.tflite')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "eO8nZR3Cgx8_" + }, + "source": [ + "### Post-training quantization (fp16 quantization)\n", + "\n", + "Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is reccomended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n", + "\n", + "First, import the MediaPipe Model Maker quantization module:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "Yo7cQ_N-ZE8A" + }, + "outputs": [], + "source": [ + "from mediapipe_model_maker import quantization" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "OE8j5cloZTo-" + }, + "source": [ + "Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "kbwc3g_Fa3dv" + }, + "outputs": [], + "source": [ + "quantization_config = quantization.QuantizationConfig.for_float16()" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "qTrkDXi8bM_L" + }, + "source": [ + "Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "mmzEu_AjbMPI" + }, + "outputs": [], + "source": [ + "model.restore_float_ckpt()\n", + "model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n", + "!ls -lh exported_model\n", + "files.download('exported_model/model_fp16.tflite')" + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "npaRBUB3ZevY" + }, + "source": [ + "## Hyperparameters\n", + "You can further customize the model using the ObjectDetectorOptions class, which has three parameters for `SupportedModels`, `ModelOptions`, and `HParams`.\n", + "\n", + "Use the `SupportedModels` enum class to specify the model architecture to use for training. The following model architectures are supported:\n", + "* MOBILENET_V2\n", + "* MOBILENET_V2_I320\n", + "* MOBILENET_MULTI_AVG\n", + "* MOBILENET_MULTI_AVG_I384\n", + "\n", + "Use the `HParams` class to customize other parameters related to training and saving the model:\n", + "* `learning_rate`: Learning rate to use for gradient descent training. Defaults to 0.3.\n", + "* `batch_size`: Batch size for training. Defaults to 8.\n", + "* `epochs`: Number of training iterations over the dataset. Defaults to 30.\n", + "* `cosine_decay_epochs`: The number of epochs for cosine decay learning rate. See [tf.keras.optimizers.schedules.CosineDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay) for more info. Defaults to None, which is equivalent to setting it to `epochs`.\n", + "* `cosine_decay_alpha`: The alpha value for cosine decay learning rate. See [tf.keras.optimizers.schedules.CosineDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay) for more info. Defaults to 1.0, which means no cosine decay.\n", + "\n", + "Use the `ModelOptions` class to customize parameters related to the model itself:\n", + "* `l2_weight_decay`: L2 regularization penalty used in [tf.keras.regularizers.L2](https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/L2). Defaults to 3e-5.\n", + "\n", + "Uset the `QATHParams` class to customize training parameters for Quantization Aware Training:\n", + "* `learning_rate`: Learning rate to use for gradient descent QAT. Defaults to 0.3.\n", + "* `batch_size`: Batch size for QAT. Defaults to 8\n", + "* `epochs`: Number of training iterations over the dataset. Defaults to 15.\n", + "* `decay_steps`: Learning rate decay steps for Exponential Decay. See [tf.keras.optimizers.schedules.ExponentialDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay) for more information. Defaults to 8\n", + "* `decay_rate`: Learning rate decay rate for Exponential Decay. See [tf.keras.optimizers.schedules.ExponentialDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay) for more information. Defaults to 0.96." + ] + }, + { + "cell_type": "markdown", + "metadata": { + "id": "9HCrUl8z6liX" + }, + "source": [ + "## Benchmarking\n", + "Below is a summary of our benchmarking results for the supported model architectures. These models were trained and evaluated on the same android figurines dataset as this notebook. When considering the model benchmarking results, there are a few important caveats to keep in mind:\n", + "* The android figurines dataset is a small and simple dataset with 62 training examples and 10 validation examples. Since the dataset is quite small, metrics may vary drastically due to variances in the training process. This dataset was provided for demo purposes and it is recommended to collect more data samples for better performing models.\n", + "* The float32 models were trained with the default HParams, and the QAT step for the int8 models was run with `QATHParams(learning_rate=0.1, batch_size=4, epochs=30, decay_rate=1)`.\n", + "* For your own dataset, you will likely need to tune values for both HParams and QATHParams in order to achieve the best results. See the [Hyperparameters](#hyperparameters) section above for more information on configuring training parameters.\n", + "* All latency numbers are benchmarked on the Pixel 6.\n", + "\n", + "\n", + "\u003ctable\u003e\n", + "\u003cthead\u003e\n", + "\u003ccol\u003e\n", + "\u003ccol\u003e\n", + "\u003ccolgroup span=\"2\"\u003e\u003c/colgroup\u003e\n", + "\u003ccolgroup span=\"2\"\u003e\u003c/colgroup\u003e\n", + "\u003ccolgroup span=\"2\"\u003e\u003c/colgroup\u003e\n", + "\u003ctr\u003e\n", + "\u003cth rowspan=\"2\"\u003eModel architecture\u003c/th\u003e\n", + "\u003cth rowspan=\"2\"\u003eInput Image Size\u003c/th\u003e\n", + "\u003cth colspan=\"2\" scope=\"colgroup\"\u003eTest AP\u003c/th\u003e\n", + "\u003cth colspan=\"2\" scope=\"colgroup\"\u003eCPU Latency\u003c/th\u003e\n", + "\u003cth colspan=\"2\" scope=\"colgroup\"\u003eModel Size\u003c/th\u003e\n", + "\u003c/tr\u003e\n", + "\u003ctr\u003e\n", + "\u003cth\u003efloat32\u003c/th\u003e\n", + "\u003cth\u003eQAT int8\u003c/th\u003e\n", + "\u003cth\u003efloat32\u003c/th\u003e\n", + "\u003cth\u003eQAT int8\u003c/th\u003e\n", + "\u003cth\u003efloat32\u003c/th\u003e\n", + "\u003cth\u003eQAT int8\u003c/th\u003e\n", + "\u003c/tr\u003e\n", + "\u003c/thead\u003e\n", + "\u003ctbody\u003e\n", + "\u003ctr\u003e\n", + "\u003ctd\u003eMobileNetV2\u003c/td\u003e\n", + "\u003ctd\u003e256x256\u003c/td\u003e\n", + "\u003ctd\u003e88.4%\u003c/td\u003e\n", + "\u003ctd\u003e73.5%\u003c/td\u003e\n", + "\u003ctd\u003e48ms\u003c/td\u003e\n", + "\u003ctd\u003e16ms\u003c/td\u003e\n", + "\u003ctd\u003e11MB\u003c/td\u003e\n", + "\u003ctd\u003e3.2MB\u003c/td\u003e\n", + "\u003c/tr\u003e\n", + "\u003ctr\u003e\n", + "\u003ctd\u003eMobileNetV2 I320\u003c/td\u003e\n", + "\u003ctd\u003e320x320\u003c/td\u003e\n", + "\u003ctd\u003e89.1%\u003c/td\u003e\n", + "\u003ctd\u003e75.5%\u003c/td\u003e\n", + "\u003ctd\u003e75ms\u003c/td\u003e\n", + "\u003ctd\u003e33.38ms\u003c/td\u003e\n", + "\u003ctd\u003e10MB\u003c/td\u003e\n", + "\u003ctd\u003e3.3MB\u003c/td\u003e\n", + "\u003c/tr\u003e\n", + "\u003ctr\u003e\n", + "\u003ctd\u003eMobileNet MultiHW AVG\u003c/td\u003e\n", + "\u003ctd\u003e256x256\u003c/td\u003e\n", + "\u003ctd\u003e88.5%\u003c/td\u003e\n", + "\u003ctd\u003e70.0%\u003c/td\u003e\n", + "\u003ctd\u003e56ms\u003c/td\u003e\n", + "\u003ctd\u003e19ms\u003c/td\u003e\n", + "\u003ctd\u003e13MB\u003c/td\u003e\n", + "\u003ctd\u003e3.6MB\u003c/td\u003e\n", + "\u003c/tr\u003e\n", + "\u003ctr\u003e\n", + "\u003ctd\u003eMobileNet MultiHW AVG I384\u003c/td\u003e\n", + "\u003ctd\u003e384x384\u003c/td\u003e\n", + "\u003ctd\u003e92.7%\u003c/td\u003e\n", + "\u003ctd\u003e73.4%\u003c/td\u003e\n", + "\u003ctd\u003e238ms\u003c/td\u003e\n", + "\u003ctd\u003e41ms\u003c/td\u003e\n", + "\u003ctd\u003e13MB\u003c/td\u003e\n", + "\u003ctd\u003e3.6MB\u003c/td\u003e\n", + "\u003c/tr\u003e\n", + "\n", + "\u003c/tbody\u003e\n", + "\u003c/table\u003e\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "TOCPzKohXxy6" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "last_runtime": { + "build_target": "//learning/grp/tools/ml_python:ml_notebook", + "kind": "private" + }, + "private_outputs": true, + "provenance": [ + { + "file_id": "11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb", + "timestamp": 1677706798050 + } + ] + }, + "kernelspec": { + "display_name": "Python 3", + "name": "python3" + }, + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +} diff --git a/training/mp_training_paris6k.ipynb b/training/mp_training_paris6k.ipynb new file mode 100644 index 0000000..a3a6dd3 --- /dev/null +++ b/training/mp_training_paris6k.ipynb @@ -0,0 +1 @@ +{"cells":[{"cell_type":"markdown","source":["# Object detection model customization"],"metadata":{"id":"rQIYsEciHsMN"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"NReBxyCQH1Rx"}},{"cell_type":"markdown","source":["To install the libraries for customizing a model, run the following commands:"],"metadata":{"id":"xbuuejtPH97m"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install mediapipe-model-maker"],"metadata":{"id":"ceFO9GPSHyRv","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the following code to import the required Python classes:"],"metadata":{"id":"2beHFLLIIIv0"}},{"cell_type":"code","source":["from google.colab import files\n","import os\n","import json\n","import tensorflow as tf\n","assert tf.__version__.startswith('2')\n","\n","from mediapipe_model_maker import object_detector"],"metadata":{"id":"0OpjgSuBIJJC","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Prepare data"],"metadata":{"id":"Yq6UT0j9ILD_"}},{"cell_type":"markdown","source":["1. Mount your Google Drive in Colab\n","2. Upload the files train.txt, val.txt and test.txt to your Drive\n","3. Use Python to read these files and copy the images to the correct locations"],"metadata":{"id":"Q9HH1m0IqsA6"}},{"cell_type":"code","source":["from google.colab import drive\n","import os\n","import shutil\n","\n","# Mount Google Drive\n","drive.mount('/content/drive')"],"metadata":{"id":"-Lrh4ORlDGgu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Define paths\n","base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 'validation/'\n","test_dataset_path = dest_base_path + 'test/'"],"metadata":{"id":"BlfRX-r_3YUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," for line in f:\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if not os.listdir(train_dataset_path) and not os.listdir(validation_dataset_path) and not os.listdir(test_dataset_path):\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"-A3ZtrofqptN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review dataset"],"metadata":{"id":"IUcrpj6aIzqt"}},{"cell_type":"markdown","source":["Verify the dataset content by printing the categories from the `labels.json` file. There should be 13 total categories. Index 0 is always set to be the `background` class which may be unused in the dataset."],"metadata":{"id":"xYXzxYC3I0h5"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"v6rYvXVwIwKp"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["To better understand the dataset, plot a couple of example images along with their bounding boxes."],"metadata":{"id":"-M33tF2KI8bU"}},{"cell_type":"code","source":["#@title Visualize the training dataset\n","import matplotlib.pyplot as plt\n","from matplotlib import patches, text, patheffects\n","from collections import defaultdict\n","import math\n","\n","def draw_outline(obj):\n"," obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n","def draw_box(ax, bb):\n"," patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n"," draw_outline(patch)\n","def draw_text(ax, bb, txt, disp):\n"," text = ax.text(bb[0],(bb[1]-disp),txt,verticalalignment='top'\n"," ,color='white',fontsize=10,weight='bold')\n"," draw_outline(text)\n","def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n"," for annotation in annotations_list:\n"," cat_id = annotation[\"category_id\"]\n"," bbox = annotation[\"bbox\"]\n"," draw_box(ax, bbox)\n"," draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n","def visualize(dataset_folder, max_examples=None):\n"," with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n"," images = labels_json[\"images\"]\n"," cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n"," image_annots = defaultdict(list)\n"," for annotation_obj in labels_json[\"annotations\"]:\n"," image_id = annotation_obj[\"image_id\"]\n"," image_annots[image_id].append(annotation_obj)\n","\n"," if max_examples is None:\n"," max_examples = len(image_annots.items())\n"," n_rows = math.ceil(max_examples / 3)\n"," fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n"," for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n"," ax = axs[ind//3, ind%3]\n"," img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n"," ax.imshow(img)\n"," draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n"," plt.show()\n","\n","visualize(train_dataset_path, 9)"],"metadata":{"id":"8D17VhVAI33W","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Create dataset\n","\n","The Dataset class has two methods for loading in COCO or PASCAL VOC datasets:\n","* `Dataset.from_coco_folder`\n","* `Dataset.from_pascal_voc_folder`\n","\n","Since the android_figurines dataset is in the COCO dataset format, use the `from_coco_folder` method to load the dataset located at `train_dataset_path` and `validation_dataset_path`. When loading the dataset, the data will be parsed from the provided path and converted into a standardized [TFRecord](https://www.tensorflow.org/tutorials/load_data/tfrecord) format which is cached for later use. You should create a `cache_dir` location and reuse it for all your training to avoid saving multiple caches of the same dataset."],"metadata":{"id":"fPyRkkKYJEOB"}},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","print(\"train_data size: \", train_data.size)\n","print(\"validation_data size: \", validation_data.size)"],"metadata":{"id":"cooLOJrmJDbo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Augmentation"],"metadata":{"id":"JGLtQeX3UG_s"}},{"cell_type":"code","source":["# TODO: remove all the images that have been augmentated (if it's necessary)"],"metadata":{"id":"dcpPbAIhPzOZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import albumentations as A\n","import numpy as np\n","import cv2\n","import json\n","import os\n","from tqdm import tqdm"],"metadata":{"id":"7R9rK4PXUKsa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: check the paths and the json files name\n","# TODO: check label_fields\n","# FIXME: manage the case in which augmentation has already been done"],"metadata":{"id":"F9PAySGhXvgG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_dynamic_transform(image_height, image_width):\n"," crop_height = min(224, image_height)\n"," crop_width = min(224, image_width)\n","\n"," return A.Compose([\n"," A.HorizontalFlip(p=0.5),\n"," A.RandomBrightnessContrast(p=0.2),\n"," A.Perspective(p=0.5),\n"," A.Rotate(limit=30, p=0.5),\n"," A.RandomCrop(height=crop_height, width=crop_width, p=0.5),\n"," A.Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, p=0.5),\n"," ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids']))"],"metadata":{"id":"Kvs_HcPNUOFW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def clip_bboxes(bboxes):\n"," return [[max(0, min(1, coord)) for coord in bbox] for bbox in bboxes]\n","\n","def resize_if_needed(image, min_size=224): # TODO: check if it's necessary\n"," height, width = image.shape[:2]\n"," if height < min_size or width < min_size:\n"," scale = min_size / min(height, width)\n"," new_height = int(height * scale)\n"," new_width = int(width * scale)\n"," image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)\n"," return image\n","\n","def apply_augmentations(image, bboxes, categories):\n"," height, width = image.shape[:2]\n"," transform = get_dynamic_transform(height, width)\n","\n"," # Clip bboxes before applying augmentations\n"," clipped_bboxes = clip_bboxes(bboxes)\n","\n"," augmented = transform(image=image, bboxes=clipped_bboxes, category_ids=categories)\n","\n"," # Clip bboxes after augmentations as well\n"," augmented['bboxes'] = clip_bboxes(augmented['bboxes'])\n","\n"," return augmented['image'], augmented['bboxes'], augmented['category_ids']\n","\n","def update_coco_annotations(annotations, new_image_id, new_bboxes, new_categories):\n"," new_annotations = []\n"," for i, (bbox, category) in enumerate(zip(new_bboxes, new_categories)):\n"," new_annotations.append({\n"," \"id\": len(annotations) + i,\n"," \"image_id\": new_image_id,\n"," \"category_id\": category,\n"," \"bbox\": [round(coord, 1) for coord in bbox]\n"," })\n"," return new_annotations"],"metadata":{"id":"rFEifv__UMhZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Load COCO annotations\n","with open(os.path.join(train_dataset_path, 'labels.json'), 'r') as f:\n"," coco_data = json.load(f)\n","\n","augmented_images = []\n","augmented_annotations = []\n","new_image_id = len(coco_data['images'])\n","\n","for image_info in tqdm(coco_data['images']):\n"," # Load image\n"," image_path = os.path.join(train_dataset_path, 'images', image_info['file_name'])\n","\n"," if not os.path.exists(image_path):\n"," print(f\"Image not found: {image_path}\")\n"," continue\n","\n"," image = cv2.imread(image_path)\n","\n"," if image is None:\n"," print(f\"Failed to load image: {image_path}\")\n"," continue\n","\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n","\n"," # Get annotations for this image\n"," image_annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] == image_info['id']]\n"," bboxes = [ann['bbox'] for ann in image_annotations]\n"," categories = [ann['category_id'] for ann in image_annotations]\n","\n"," # Clip bboxes before applying augmentations\n"," bboxes = clip_bboxes(bboxes)\n","\n"," #image = resize_if_needed(image) # TODO: check if it's necessary\n","\n"," # Apply augmentations\n"," aug_image, aug_bboxes, aug_categories = apply_augmentations(image, bboxes, categories)\n","\n"," # Save augmented image\n"," aug_image_name = f\"aug_{image_info['file_name']}\"\n"," cv2.imwrite(os.path.join(train_dataset_path, 'images', aug_image_name), cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR)) # FIXME: why do I do COLOR_RGB2BGR?\n","\n"," # Update COCO annotations\n"," augmented_images.append({\n"," \"id\": new_image_id,\n"," \"file_name\": aug_image_name\n"," })\n"," augmented_annotations.extend(update_coco_annotations(coco_data['annotations'], new_image_id, aug_bboxes, aug_categories))\n"," new_image_id += 1\n","\n","# Update COCO data\n","coco_data['images'].extend(augmented_images)\n","coco_data['annotations'].extend(augmented_annotations)\n","\n","# Save updated COCO annotations\n","with open(os.path.join(train_dataset_path, 'labels.json'), 'w') as f:\n"," json.dump(coco_data, f)\n","\n","print(f\"Added {len(augmented_images)} augmented images to the dataset.\")"],"metadata":{"id":"cOtUZ5tnUPwJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import shutil\n","shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ?"],"metadata":{"id":"pUpntlCgYgUh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","print(\"Updated train_data size: \", train_data.size)"],"metadata":{"id":"2A_QSKYdUs2U"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Retrain model\n","\n","Once you have completed preparing your data, you can begin retraining a model to recognize the new objects, or classes, defined by your training data. The instructions below use the data prepared in the previous section to retrain an image classification model to recognize the two types of android figurines."],"metadata":{"id":"IzhODm2CJaMB"}},{"cell_type":"markdown","source":["### Set retraining options\n","\n","There are a few required settings to run retraining aside from your training dataset: output directory for the model, and the model architecture. Use `HParams` to specify the `export_dir` parameter for the output directory. Use the `SupportedModels` class to specify the model architecture. The object detector solution supports the following model architectures:\n","* `MobileNet-V2`\n","* `MobileNet-MultiHW-AVG`\n","\n","For more advanced customization of training parameters, see the [Hyperparameters](#hyperparameters) section below.\n","\n","To set the required parameters, use the following code:"],"metadata":{"id":"jF7sZHYyJcl7"}},{"cell_type":"code","source":["spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384\n","hparams = object_detector.HParams(\n"," learning_rate=0.1,\n"," batch_size=16,\n"," epochs=50,\n"," cosine_decay_epochs=45,\n"," cosine_decay_alpha=0.1,\n"," export_dir='exported_model' # TODO: check\n",")\n","model_options = object_detector.ModelOptions(\n"," l2_weight_decay=1e-4\n",")\n","options = object_detector.ObjectDetectorOptions(\n"," supported_model=spec,\n"," hparams=hparams,\n"," model_options=model_options\n",")"],"metadata":{"id":"iC6cVSpVJWFw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Run retraining\n","With your training dataset and retraining options prepared, you are ready to start the retraining process. This process is resource intensive and can take a few minutes to a few hours depending on your available compute resources. Using a Google Colab environment with standard GPU runtimes, the example retraining below takes about 2~4 minutes.\n","\n","To begin the retraining process, use the `create()` method with dataset and options you previously defined:"],"metadata":{"id":"bfQoIMPOJhK5"}},{"cell_type":"code","source":["model = object_detector.ObjectDetector.create(\n"," train_data=train_data,\n"," validation_data=validation_data,\n"," options=options)"],"metadata":{"id":"BxT1UHQWJfAX","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Evaluate the model performance\n","\n","After training the model, evaluate it on validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision."],"metadata":{"id":"y3h_cOytJjXk"}},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=4)\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"kJabrlEjJkmH","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export model\n","\n","After creating the model, convert and export it to a Tensorflow Lite model format for later use on an on-device application. The export also includes model metadata, which includes the label map."],"metadata":{"id":"Vkg5K-xWJmj6"}},{"cell_type":"code","source":["model.export_model()\n","!ls exported_model\n","files.download('exported_model/model.tflite')"],"metadata":{"id":"ADUK65tBJn7x","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model quantization"],"metadata":{"id":"6g6QkA6YJpf8"}},{"cell_type":"markdown","source":["Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n","\n","This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n","1. Quantization Aware Training: 8 bit integer precision for CPU usage\n","2. Post-Training Quantization: 16 bit floating point precision for GPU usage"],"metadata":{"id":"Jl1Dc9EhJtRo"}},{"cell_type":"markdown","source":["### Quantization aware training (int8 quantization)\n","Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n","\n","To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. See the **Hyperparameters** section below on detailed usage of `QATHParams`."],"metadata":{"id":"D88f-mqgJvwi"}},{"cell_type":"code","source":["qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"NbTDJtEJJr7Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model(After running the `create` method) in order to run QAT again."],"metadata":{"id":"5KEPghm9JyUi"}},{"cell_type":"code","source":["new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n","model.restore_float_ckpt()\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"vALj8IqaJ1B4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Finally, us the `export_model` to export to an int8 quantized model. The `export_model` function will automatically export to either float32 or int8 model depending on whether `quantization_aware_training` was run."],"metadata":{"id":"HdZAaTJKJ4ky"}},{"cell_type":"code","source":["model.export_model('model_int8_qat.tflite')\n","!ls -lh exported_model\n","files.download('exported_model/model_int8_qat.tflite')"],"metadata":{"id":"AePYkJHwJ2OG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Post-training quantization (fp16 quantization)\n","\n","Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is reccomended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n","\n","First, import the MediaPipe Model Maker quantization module:"],"metadata":{"id":"mm4gCymvJ7wK"}},{"cell_type":"code","source":["from mediapipe_model_maker import quantization"],"metadata":{"id":"mt8zeY52J8Gk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class."],"metadata":{"id":"r6eGgib4J-Ac"}},{"cell_type":"code","source":["quantization_config = quantization.QuantizationConfig.for_float16()"],"metadata":{"id":"IjAn-GBKJ_ab"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`."],"metadata":{"id":"ChqPtssVKCA8"}},{"cell_type":"code","source":["model.restore_float_ckpt()\n","model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n","!ls -lh exported_model\n","files.download('exported_model/model_fp16.tflite')"],"metadata":{"id":"f90WxGNNKCSW"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Hyperparameters\n","You can further customize the model using the ObjectDetectorOptions class, which has three parameters for `SupportedModels`, `ModelOptions`, and `HParams`.\n","\n","Use the `SupportedModels` enum class to specify the model architecture to use for training. The following model architectures are supported:\n","* MOBILENET_V2\n","* MOBILENET_V2_I320\n","* MOBILENET_MULTI_AVG\n","* MOBILENET_MULTI_AVG_I384\n","\n","Use the `HParams` class to customize other parameters related to training and saving the model:\n","* `learning_rate`: Learning rate to use for gradient descent training. Defaults to 0.3.\n","* `batch_size`: Batch size for training. Defaults to 8.\n","* `epochs`: Number of training iterations over the dataset. Defaults to 30.\n","* `cosine_decay_epochs`: The number of epochs for cosine decay learning rate. See [tf.keras.optimizers.schedules.CosineDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay) for more info. Defaults to None, which is equivalent to setting it to `epochs`.\n","* `cosine_decay_alpha`: The alpha value for cosine decay learning rate. See [tf.keras.optimizers.schedules.CosineDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/CosineDecay) for more info. Defaults to 1.0, which means no cosine decay.\n","\n","Use the `ModelOptions` class to customize parameters related to the model itself:\n","* `l2_weight_decay`: L2 regularization penalty used in [tf.keras.regularizers.L2](https://www.tensorflow.org/api_docs/python/tf/keras/regularizers/L2). Defaults to 3e-5.\n","\n","Uset the `QATHParams` class to customize training parameters for Quantization Aware Training:\n","* `learning_rate`: Learning rate to use for gradient descent QAT. Defaults to 0.3.\n","* `batch_size`: Batch size for QAT. Defaults to 8\n","* `epochs`: Number of training iterations over the dataset. Defaults to 15.\n","* `decay_steps`: Learning rate decay steps for Exponential Decay. See [tf.keras.optimizers.schedules.ExponentialDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay) for more information. Defaults to 8\n","* `decay_rate`: Learning rate decay rate for Exponential Decay. See [tf.keras.optimizers.schedules.ExponentialDecay](https://www.tensorflow.org/api_docs/python/tf/keras/optimizers/schedules/ExponentialDecay) for more information. Defaults to 0.96."],"metadata":{"id":"XMT2xM7jKD--"}},{"cell_type":"markdown","source":["## Benchmarking\n","Below is a summary of our benchmarking results for the supported model architectures. These models were trained and evaluated on the same android figurines dataset as this notebook. When considering the model benchmarking results, there are a few important caveats to keep in mind:\n","* The android figurines dataset is a small and simple dataset with 62 training examples and 10 validation examples. Since the dataset is quite small, metrics may vary drastically due to variances in the training process. This dataset was provided for demo purposes and it is recommended to collect more data samples for better performing models.\n","* The float32 models were trained with the default HParams, and the QAT step for the int8 models was run with `QATHParams(learning_rate=0.1, batch_size=4, epochs=30, decay_rate=1)`.\n","* For your own dataset, you will likely need to tune values for both HParams and QATHParams in order to achieve the best results. See the [Hyperparameters](#hyperparameters) section above for more information on configuring training parameters.\n","* All latency numbers are benchmarked on the Pixel 6.\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","\n","
Model architectureInput Image SizeTest APCPU LatencyModel Size
float32QAT int8float32QAT int8float32QAT int8
MobileNetV2256x25688.4%73.5%48ms16ms11MB3.2MB
MobileNetV2 I320320x32089.1%75.5%75ms33.38ms10MB3.3MB
MobileNet MultiHW AVG256x25688.5%70.0%56ms19ms13MB3.6MB
MobileNet MultiHW AVG I384384x38492.7%73.4%238ms41ms13MB3.6MB
\n","\n"],"metadata":{"id":"3l1ReIyUKIVW"}}],"metadata":{"colab":{"last_runtime":{"build_target":"//learning/grp/tools/ml_python:ml_notebook","kind":"private"},"private_outputs":true,"provenance":[{"file_id":"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb","timestamp":1720270505355},{"file_id":"11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb","timestamp":1677706798050}],"gpuType":"T4","collapsed_sections":["JGLtQeX3UG_s","6g6QkA6YJpf8"]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} \ No newline at end of file diff --git a/training/tf_training_paris6k.ipynb b/training/tf_training_paris6k.ipynb new file mode 100644 index 0000000..088f9de --- /dev/null +++ b/training/tf_training_paris6k.ipynb @@ -0,0 +1 @@ +{"nbformat":4,"nbformat_minor":0,"metadata":{"colab":{"provenance":[],"collapsed_sections":["Z_jqsWcdF4lb","PULQlqJTMRrL"],"authorship_tag":"ABX9TyNWwrWmcKvj6r4SMhEk+QZl"},"kernelspec":{"name":"python3","display_name":"Python 3"},"language_info":{"name":"python"}},"cells":[{"cell_type":"markdown","source":["# Monument Classification with Transfer Learning and TensorFlow Lite"],"metadata":{"id":"trNnIFVNIBpX"}},{"cell_type":"markdown","source":["1. Starting with a pre-trained model (in this case, MobileNetV2)\n","2. Fine-tuning it on paris6k\n","3. Converting the fine-tuned model to TensorFlow Lite format"],"metadata":{"id":"QS-vm_qFLvCC"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"_IcSP6xC--0g"}},{"cell_type":"markdown","source":["### Install required packages"],"metadata":{"id":"BM8YC5nkIE33"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install tensorflow\n","!pip install albumentations\n","!pip install pycocotools"],"metadata":{"id":"6hTPNnZp_DiG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Import necessary libraries"],"metadata":{"id":"Mo-UEKdOIQXJ"}},{"cell_type":"code","source":["from google.colab import drive\n","import os\n","import json\n","import tensorflow as tf\n","import albumentations as A\n","import numpy as np\n","import matplotlib.pyplot as plt\n","from pycocotools.coco import COCO\n","from tensorflow.keras.applications import MobileNetV2\n","from tensorflow.keras.models import Model\n","from tensorflow.keras.layers import Dense, GlobalAveragePooling2D\n","\n","#from tensorflow.keras.preprocessing.image import ImageDataGenerator\n","#import cv2\n","\n","assert tf.__version__.startswith('2')"],"metadata":{"id":"I1rxFMi6_A5S"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Mount Google Drive and Set Paths"],"metadata":{"id":"Vut4cGcH-3m-"}},{"cell_type":"markdown","source":["### Mount Google Drive"],"metadata":{"id":"Uwiszy6EIeEH"}},{"cell_type":"code","source":["drive.mount('/content/drive')"],"metadata":{"id":"T_gfxE-_CFp_"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Define paths"],"metadata":{"id":"dqivQXQDIfYd"}},{"cell_type":"code","source":["base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 'validation/'\n","test_dataset_path = dest_base_path + 'test/'"],"metadata":{"id":"0SmkSGAvB-5h"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Dataset Preparation"],"metadata":{"id":"jPMWHe1BIjE6"}},{"cell_type":"markdown","source":["### Copy images"],"metadata":{"id":"F41txtrnFxe_"}},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," for line in f:\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if not os.listdir(train_dataset_path) and not os.listdir(validation_dataset_path) and not os.listdir(test_dataset_path):\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"Q0BXXZ71Fz5q"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review dataset"],"metadata":{"id":"Z_jqsWcdF4lb"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"k0_pG4aaF7qS"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Data Augmentation"],"metadata":{"id":"i2JaY6MqBvly"}},{"cell_type":"code","source":["# Define Albumentations transformations\n","train_transform = A.Compose([\n"," A.RandomRotate90(),\n"," A.Flip(),\n"," A.Transpose(),\n"," A.OneOf([\n"," A.IAAAdditiveGaussianNoise(),\n"," A.GaussNoise(),\n"," ], p=0.2),\n"," A.OneOf([\n"," A.MotionBlur(p=0.2),\n"," A.MedianBlur(blur_limit=3, p=0.1),\n"," A.Blur(blur_limit=3, p=0.1),\n"," ], p=0.2),\n"," A.ShiftScaleRotate(shift_limit=0.0625, scale_limit=0.5, rotate_limit=45, p=0.2),\n"," A.OneOf([\n"," A.OpticalDistortion(p=0.3),\n"," A.GridDistortion(p=0.1),\n"," A.IAAPiecewiseAffine(p=0.3),\n"," ], p=0.2),\n"," A.OneOf([\n"," A.CLAHE(clip_limit=2),\n"," A.IAASharpen(),\n"," A.IAAEmboss(),\n"," A.RandomBrightnessContrast(),\n"," ], p=0.3),\n"," A.HueSaturationValue(p=0.3),\n","])\n","\n","def augment_image(image, transform=train_transform):\n"," image = np.array(image)\n"," augmented = transform(image=image)\n"," return augmented['image']\n"],"metadata":{"id":"BRlaPD3-BxlI"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Dataset Creation"],"metadata":{"id":"idjelgoRB0U6"}},{"cell_type":"code","source":["def load_and_preprocess_image(path):\n"," image = tf.io.read_file(path)\n"," image = tf.image.decode_jpeg(image, channels=3)\n"," image = tf.image.resize(image, [224, 224])\n"," image = tf.cast(image, tf.float32) / 255.0 # Normalize to [0,1]\n"," return image\n","\n","def preprocess_and_augment(image, label):\n"," image = tf.numpy_function(augment_image, [image], tf.float32)\n"," image.set_shape([224, 224, 3])\n"," return image, label\n","\n","def get_dataset(image_paths, labels, batch_size, is_training=False):\n"," dataset = tf.data.Dataset.from_tensor_slices((image_paths, labels))\n"," dataset = dataset.map(lambda x, y: (load_and_preprocess_image(x), y), num_parallel_calls=tf.data.AUTOTUNE)\n"," if is_training:\n"," dataset = dataset.map(preprocess_and_augment, num_parallel_calls=tf.data.AUTOTUNE)\n"," dataset = dataset.shuffle(buffer_size=1000)\n"," dataset = dataset.batch(batch_size)\n"," dataset = dataset.prefetch(buffer_size=tf.data.AUTOTUNE)\n"," return dataset"],"metadata":{"id":"dyzcJs04E1Fe"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Load the dataset"],"metadata":{"id":"irw8zVStJpwg"}},{"cell_type":"code","source":["image_paths, labels = load_your_dataset() # TODO: implement this function"],"metadata":{"id":"EeedjGy7JpMA"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# This function should return two lists: one containing the paths to the images,\n","# and another containing the corresponding labels (as one-hot encoded vectors)."],"metadata":{"id":"qQLiCyCPMgwP"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Split the dataset"],"metadata":{"id":"uKY0g7GEJx5W"}},{"cell_type":"code","source":["# Split the dataset\n","split_index = int(len(image_paths) * 0.8)\n","image_paths_train = image_paths[:split_index]\n","labels_train = labels[:split_index]\n","image_paths_val = image_paths[split_index:]\n","labels_val = labels[split_index:]\n","\n","train_dataset = get_dataset(image_paths_train, labels_train, batch_size=32, is_training=True)\n","val_dataset = get_dataset(image_paths_val, labels_val, batch_size=32, is_training=False)"],"metadata":{"id":"fSrzwTK9J1x-"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model"],"metadata":{"id":"mHbrAGItKJEW"}},{"cell_type":"markdown","source":["### Model creations"],"metadata":{"id":"1q-vKKrm_OoU"}},{"cell_type":"code","source":["# Load pre-trained MobileNetV2 model\n","base_model = MobileNetV2(weights='imagenet', include_top=False, input_shape=(224, 224, 3))\n","\n","# Freeze base model layers\n","for layer in base_model.layers:\n"," layer.trainable = False\n","\n","# Add custom layers\n","x = base_model.output\n","x = GlobalAveragePooling2D()(x)\n","x = Dense(1024, activation='relu')(x)\n","predictions = Dense(13, activation='softmax')(x) # 13 classes (12 monuments + background)\n","\n","model = Model(inputs=base_model.input, outputs=predictions)"],"metadata":{"id":"ogWWkKa2_QAu"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Model compilation"],"metadata":{"id":"O5bGdIfAKESs"}},{"cell_type":"code","source":["model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])"],"metadata":{"colab":{"base_uri":"https://localhost:8080/","height":144},"collapsed":true,"id":"MTDMWS1m_RU8","executionInfo":{"status":"error","timestamp":1720637611502,"user_tz":-120,"elapsed":381,"user":{"displayName":"Elia Innocenti","userId":"07866908462894922277"}},"outputId":"2af75516-14fd-4d40-87a1-d35bc36a34f1"},"execution_count":1,"outputs":[{"output_type":"error","ename":"NameError","evalue":"name 'model' is not defined","traceback":["\u001b[0;31m---------------------------------------------------------------------------\u001b[0m","\u001b[0;31mNameError\u001b[0m Traceback (most recent call last)","\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcompile\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'adam'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m'categorical_crossentropy'\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmetrics\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'accuracy'\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m","\u001b[0;31mNameError\u001b[0m: name 'model' is not defined"]}]},{"cell_type":"markdown","source":["### Fine-tuning"],"metadata":{"id":"G1WUgy_5M7j0"}},{"cell_type":"code","source":["# Unfreeze the top layers of the base model\n","for layer in base_model.layers[-20:]:\n"," layer.trainable = True\n","\n","# Recompile the model with a lower learning rate\n","model.compile(optimizer=tf.keras.optimizers.Adam(1e-5),\n"," loss='categorical_crossentropy',\n"," metrics=['accuracy'])\n","\n","# Continue training\n","model.fit(...)"],"metadata":{"id":"xQkqRvtnM9fU"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Training"],"metadata":{"id":"ZbPYOlUDKYKb"}},{"cell_type":"markdown","source":["- Start by training for a few epochs and monitor the validation loss and accuracy.\n","- If the model is underfitting (high training and validation loss):\n"," 1. Unfreeze more layers of the base model\n"," 2. Train for more epochs\n"," 3. Increase model capacity (add more dense layers)\n","- If the model is overfitting (low training loss, high validation loss):\n"," 1. Add regularization (e.g., dropout layers)\n"," 2. Use data augmentation (already implemented)\n"," 3. Reduce model capacity"],"metadata":{"id":"NqvcuM_AL85p"}},{"cell_type":"markdown","source":["- Start with a small number of epochs (e.g., 10) and monitor the training and validation metrics.\n","- Gradually increase the number of epochs if needed.\n","- Use the ModelCheckpoint callback to save the best model based on validation accuracy.\n","- Use the EarlyStopping callback to prevent overfitting by stopping training when the validation loss stops improving."],"metadata":{"id":"1cspPobzMmuU"}},{"cell_type":"markdown","source":["### Resume training"],"metadata":{"id":"PULQlqJTMRrL"}},{"cell_type":"code","source":["# Load the saved model\n","model = tf.keras.models.load_model('best_model.h5')\n","\n","# Continue training\n","model.fit(\n"," train_dataset,\n"," validation_data=val_dataset,\n"," epochs=10,\n"," initial_epoch=history.epoch[-1], # Start from the last epoch\n"," callbacks=[\n"," tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_accuracy'),\n"," tf.keras.callbacks.EarlyStopping(patience=3, monitor='val_loss')\n"," ]\n",")"],"metadata":{"id":"jFmlaG9lMZ4q"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Train the model"],"metadata":{"id":"0HAl3H3SKgGF"}},{"cell_type":"code","source":["history = model.fit(\n"," train_dataset,\n"," validation_data=val_dataset,\n"," epochs=10,\n"," callbacks=[\n"," tf.keras.callbacks.ModelCheckpoint('best_model.h5', save_best_only=True, monitor='val_accuracy'),\n"," tf.keras.callbacks.EarlyStopping(patience=3, monitor='val_loss')\n"," ]\n",")"],"metadata":{"id":"L0Jije_c_TWh"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Evaluate"],"metadata":{"id":"dWwHp7oj_a_k"}},{"cell_type":"markdown","source":["### Evaluate the model"],"metadata":{"id":"GIicAZ6vKodk"}},{"cell_type":"code","source":["loss, accuracy = model.evaluate(val_dataset)\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation accuracy: {accuracy}\")"],"metadata":{"id":"lwjCx0Ir_aqt"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export to TensorFlow Lite"],"metadata":{"id":"4tArnfxu_ceT"}},{"cell_type":"markdown","source":["### Convert to TFLite"],"metadata":{"id":"U2agEoARKxJ6"}},{"cell_type":"code","source":["model.save('monumenti_model.h5') # TODO: check line\n","\n","converter = tf.lite.TFLiteConverter.from_keras_model(model)\n","tflite_model = converter.convert()"],"metadata":{"id":"1RwFPHZU_e3c"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Save the TFLite model"],"metadata":{"id":"jCwfjbWkKycO"}},{"cell_type":"code","source":["with open('monuments_model.tflite', 'wb') as f:\n"," f.write(tflite_model)"],"metadata":{"id":"KoNhYZ5YK1Pn"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["# Quantization"],"metadata":{"id":"ERb5D4Uc_fIb"}},{"cell_type":"markdown","source":["### Quantize the model"],"metadata":{"id":"XBQPKRTlLEUU"}},{"cell_type":"code","source":["converter.optimizations = [tf.lite.Optimize.DEFAULT]\n","tflite_model_quantized = converter.convert()"],"metadata":{"id":"sMDtUnD7_gx2"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Save the quantized TFLite model"],"metadata":{"id":"JNh-G1ZVLF1n"}},{"cell_type":"code","source":["with open('monuments_model_quantized.tflite', 'wb') as f:\n"," f.write(tflite_model_quantized)"],"metadata":{"id":"lYl3k1x1LIyM"},"execution_count":null,"outputs":[]}]} \ No newline at end of file