diff --git a/README.md b/README.md
index e8953cd..bcd8867 100644
--- a/README.md
+++ b/README.md
@@ -1 +1,59 @@
-# Oxford5k-Paris6k-ObjectDetection
\ No newline at end of file
+# Oxford5k-Paris6k-ObjectDetection
+
+This project aims to create an object detection model for monument recognition using the Oxford5k and Paris6k datasets. The model is built using MediaPipe Model Maker for transfer learning, starting from a pre-trained model.
+
+## Project Overview
+
+The main objective of this project is to adapt the Oxford5k and Paris6k datasets, originally designed for image retrieval, for object detection tasks. This involved significant work in converting the annotations from their original format (stored in .pkl files) to standard object detection formats such as Pascal VOC and COCO.
+
+## Key Features
+
+- Adaptation of Oxford5k and Paris6k datasets for object detection
+- Custom scripts for data preprocessing and annotation conversion
+- Transfer learning using MediaPipe Model Maker
+- Support for both Pascal VOC and COCO annotation formats
+
+## Getting Started
+
+1. Clone the repository
+2. Install the required dependencies
+3. Run the data preparation scripts in the `scripts/` directory
+4. Use the Jupyter notebooks in the `training/` directory for model training
+
+## Data Preparation
+
+The `scripts/` directory contains various Python scripts for data preparation:
+
+- `get_data.py`: downloads the original datasets
+- `create_annotations.py`: converts original annotations to Pascal VOC and COCO formats
+- `prepare_dataset.py`: prepares the dataset for training
+- `check_annotations.py`: verifies the correctness of the converted annotations
+
+## Training
+
+The `training/` directory contains Jupyter notebooks for model training:
+
+- `mediapipe_object_detector_model_customization_template.ipynb`: template for MediaPipe Model Maker
+- `mp_training_paris6k.ipynb`: specific training notebook for the Paris6k dataset
+
+## Inference
+
+Use the scripts in the `inference/` directory to run object detection on new images.
+
+## License
+
+This project is licensed under the [LICENSE NAME] - see the [LICENSE.txt](LICENSE.txt) file for details.
+
+## Acknowledgments
+
+- Original Oxford5k and Paris6k dataset creators
+- MediaPipe team for their Model Maker tool
+
+## References
+
+- [Oxford5k Dataset](http://www.robots.ox.ac.uk/~vgg/data/oxbuildings/)
+- [Paris6k Dataset](http://www.robots.ox.ac.uk/~vgg/data/parisbuildings/)
+
+## Authors
+
+- [Elia Innocenti](https://github.com/eliainnocenti)
diff --git a/inference/test.py b/inference/test.py
index 5bdebaa..c87d28a 100644
--- a/inference/test.py
+++ b/inference/test.py
@@ -11,7 +11,7 @@
 base_path = "../../../Data/"
 
 # Load the TFLite model
-interpreter = tf.lite.Interpreter(model_path='../models/model2.tflite')
+interpreter = tf.lite.Interpreter(model_path='../models/model.tflite')
 interpreter.allocate_tensors()
 
 # Get input and output details
@@ -145,7 +145,7 @@ def main():
     :return:
     """
 
-    #train_images()
+    train_images()
     #validation_images()
     #test_images()
 
diff --git a/scripts/README.md b/scripts/README.md
index 8ebb0ae..61bcd4c 100644
--- a/scripts/README.md
+++ b/scripts/README.md
@@ -15,7 +15,7 @@
 python get_data.py
 python prepare_dataset.py
 ```
-
+
 
 ```mermaid
 graph TD;
diff --git a/scripts/create_annotations.py b/scripts/create_annotations.py
index 6bb9e40..5f77964 100644
--- a/scripts/create_annotations.py
+++ b/scripts/create_annotations.py
@@ -36,6 +36,7 @@
     Creates a list of classes from the dataset.
 
 10. get_id_by_name(categories, name):
+    Gets the ID of a category by its name.
 11. process_data(folder_name, data, image_folder, output_folder, monuments_list, type='xml', levels=2):
     Processes the dataset to create annotations in XML or JSON format.
 
@@ -503,7 +504,7 @@ def _process_data_json(data, image_folder, output_folder, monuments_list, levels
                 _objects[monument] = []
             _objects[monument].append(Object(f"{monument}", "Unspecified", "0", str(level), bbox))
 
-    # merge bbox for the same monument # FIXME: choose another way to merge
+    # merge bbox for the same monument # FIXME: choose another way to merge (?)
     for monument in _objects.keys():
         xmin_avg, ymin_avg, xmax_avg, ymax_avg = 0, 0, 0, 0
         difficulty = 0
@@ -585,7 +586,7 @@ def process_data(folder_name, data, image_folder, output_folder, monuments_list,
 
     print("Annotations created successfully")
 
 
-def main(datasets=None, type='xml', levels=2):
+def main(datasets=None, type='xml', levels=1):
     """
     Main function to create annotations for the specified datasets.
diff --git a/scripts/prepare_dataset.py b/scripts/prepare_dataset.py
index edaedfe..b9ba1b3 100644
--- a/scripts/prepare_dataset.py
+++ b/scripts/prepare_dataset.py
@@ -90,7 +90,6 @@ def split_train_val_test(dataset_name, train_percent=0.7, val_percent=0.2, test_
     with open(labels_file, 'r') as file:
         labels_json = json.load(file)
     # insert in images only the images that have at least one annotation
-    # TODO: check
    images = [image['file_name'] for image in labels_json['images']]
    for image in labels_json['images']:
        if image['id'] not in [annotation['image_id'] for annotation in labels_json['annotations']]:
@@ -190,7 +189,7 @@ def split_annotations(dataset_name, type='json'):
 
     return
 
 
-def prepare_dataset(dataset_name, type='xml', levels=3):
+def prepare_dataset(dataset_name, type='xml', levels=1):
     """
     Prepares the dataset by creating annotations and splitting it into training, validation, and test sets.
@@ -239,7 +238,7 @@ def main(): :return: None """ datasets = [ - #'roxford5k', + #'roxford5k', # TODO: uncomment 'rparis6k' ] diff --git a/training/mp_training_paris6k.ipynb b/training/mp_training_paris6k.ipynb index 77d35cd..82ed752 100644 --- a/training/mp_training_paris6k.ipynb +++ b/training/mp_training_paris6k.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","source":["# Object detection model customization"],"metadata":{"id":"rQIYsEciHsMN"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"NReBxyCQH1Rx"}},{"cell_type":"markdown","source":["To install the libraries for customizing a model, run the following commands:"],"metadata":{"id":"xbuuejtPH97m"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install mediapipe-model-maker"],"metadata":{"id":"ceFO9GPSHyRv","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the following code to import the required Python classes:"],"metadata":{"id":"2beHFLLIIIv0"}},{"cell_type":"code","source":["from google.colab import files\n","import os\n","import json\n","import tensorflow as tf\n","assert tf.__version__.startswith('2')\n","\n","from mediapipe_model_maker import object_detector"],"metadata":{"id":"0OpjgSuBIJJC","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Prepare data"],"metadata":{"id":"Yq6UT0j9ILD_"}},{"cell_type":"code","source":["from google.colab import drive\n","import os\n","import shutil\n","\n","# Mount Google Drive\n","drive.mount('/content/drive')"],"metadata":{"id":"-Lrh4ORlDGgu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Define paths\n","base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 'validation/'\n","test_dataset_path = dest_base_path + 'test/'"],"metadata":{"id":"BlfRX-r_3YUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: is it possible to add a progress bar?"],"metadata":{"id":"-noKxHj9ZVzE"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," for line in f:\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if len(os.listdir(train_dataset_path + 'images/')) == 0 and len(os.listdir(validation_dataset_path + 'images/')) == 0 and len(os.listdir(test_dataset_path + 'images/')) == 0:\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. 
Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"-A3ZtrofqptN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review dataset"],"metadata":{"id":"IUcrpj6aIzqt"}},{"cell_type":"markdown","source":["Verify the dataset content by printing the categories from the `labels.json` file. There should be 13 total categories. Index 0 is always set to be the `background` class which may be unused in the dataset."],"metadata":{"id":"xYXzxYC3I0h5"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"v6rYvXVwIwKp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the training dataset\n","import matplotlib.pyplot as plt\n","from matplotlib import patches, text, patheffects\n","from collections import defaultdict\n","import math\n","\n","def draw_outline(obj):\n"," obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n","def draw_box(ax, bb):\n"," patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n"," draw_outline(patch)\n","def draw_text(ax, bb, txt, disp):\n"," text = ax.text(bb[0],(bb[1]-disp),txt,verticalalignment='top'\n"," ,color='white',fontsize=10,weight='bold')\n"," draw_outline(text)\n","def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n"," for annotation in annotations_list:\n"," cat_id = annotation[\"category_id\"]\n"," bbox = annotation[\"bbox\"]\n"," draw_box(ax, bbox)\n"," draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n","def visualize(dataset_folder, max_examples=None):\n"," with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n"," images = labels_json[\"images\"]\n"," cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n"," image_annots = defaultdict(list)\n"," for annotation_obj in labels_json[\"annotations\"]:\n"," image_id = annotation_obj[\"image_id\"]\n"," image_annots[image_id].append(annotation_obj)\n","\n"," if max_examples is None:\n"," max_examples = len(image_annots.items())\n"," n_rows = math.ceil(max_examples / 3)\n"," fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n"," for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n"," ax = axs[ind//3, ind%3]\n"," img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n"," ax.imshow(img)\n"," draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n"," plt.show()\n","\n","visualize(train_dataset_path, 9)"],"metadata":{"id":"8D17VhVAI33W","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Create dataset"],"metadata":{"id":"fPyRkkKYJEOB"}},{"cell_type":"code","source":["import shutil\n","shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ?"],"metadata":{"id":"6MICrUKmFEsS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# 
TODO: is it possible to add a progress bar?"],"metadata":{"id":"7f4Uke50b1X9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","print(\"train_data size: \", train_data.size)\n","print(\"validation_data size: \", validation_data.size)"],"metadata":{"id":"cooLOJrmJDbo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Augmentation"],"metadata":{"id":"JGLtQeX3UG_s"}},{"cell_type":"code","source":["# TODO: remove all the images that have been augmentated (if it's necessary)"],"metadata":{"id":"dcpPbAIhPzOZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import albumentations as A\n","import numpy as np\n","import cv2\n","import json\n","import os\n","from tqdm import tqdm"],"metadata":{"id":"7R9rK4PXUKsa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: check the paths and the json files name\n","# TODO: check label_fields\n","# FIXME: manage the case in which augmentation has already been done"],"metadata":{"id":"F9PAySGhXvgG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_dynamic_transform(image_height, image_width):\n"," crop_height = min(224, image_height)\n"," crop_width = min(224, image_width)\n","\n"," return A.Compose([\n"," A.HorizontalFlip(p=0.5),\n"," A.RandomBrightnessContrast(p=0.2),\n"," A.Perspective(p=0.5),\n"," A.Rotate(limit=30, p=0.5),\n"," A.RandomCrop(height=crop_height, width=crop_width, p=0.5),\n"," A.Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, p=0.5),\n"," ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids']))"],"metadata":{"id":"Kvs_HcPNUOFW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def clip_bboxes(bboxes):\n"," return [[max(0, min(1, coord)) for coord in bbox] for bbox in bboxes]\n","\n","def resize_if_needed(image, min_size=224): # TODO: check if it's necessary\n"," height, width = image.shape[:2]\n"," if height < min_size or width < min_size:\n"," scale = min_size / min(height, width)\n"," new_height = int(height * scale)\n"," new_width = int(width * scale)\n"," image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)\n"," return image\n","\n","def apply_augmentations(image, bboxes, categories):\n"," height, width = image.shape[:2]\n"," transform = get_dynamic_transform(height, width)\n","\n"," # Clip bboxes before applying augmentations\n"," clipped_bboxes = clip_bboxes(bboxes)\n","\n"," augmented = transform(image=image, bboxes=clipped_bboxes, category_ids=categories)\n","\n"," # Clip bboxes after augmentations as well\n"," augmented['bboxes'] = clip_bboxes(augmented['bboxes'])\n","\n"," return augmented['image'], augmented['bboxes'], augmented['category_ids']\n","\n","def update_coco_annotations(annotations, new_image_id, new_bboxes, new_categories):\n"," new_annotations = []\n"," for i, (bbox, category) in enumerate(zip(new_bboxes, new_categories)):\n"," new_annotations.append({\n"," \"id\": len(annotations) + i,\n"," \"image_id\": new_image_id,\n"," \"category_id\": category,\n"," \"bbox\": [round(coord, 1) for coord in bbox]\n"," })\n"," return new_annotations"],"metadata":{"id":"rFEifv__UMhZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Load COCO annotations\n","with 
open(os.path.join(train_dataset_path, 'labels.json'), 'r') as f:\n"," coco_data = json.load(f)\n","\n","augmented_images = []\n","augmented_annotations = []\n","new_image_id = len(coco_data['images'])\n","\n","for image_info in tqdm(coco_data['images']):\n"," # Load image\n"," image_path = os.path.join(train_dataset_path, 'images', image_info['file_name'])\n","\n"," if not os.path.exists(image_path):\n"," print(f\"Image not found: {image_path}\")\n"," continue\n","\n"," image = cv2.imread(image_path)\n","\n"," if image is None:\n"," print(f\"Failed to load image: {image_path}\")\n"," continue\n","\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n","\n"," # Get annotations for this image\n"," image_annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] == image_info['id']]\n"," bboxes = [ann['bbox'] for ann in image_annotations]\n"," categories = [ann['category_id'] for ann in image_annotations]\n","\n"," # Clip bboxes before applying augmentations\n"," bboxes = clip_bboxes(bboxes)\n","\n"," #image = resize_if_needed(image) # TODO: check if it's necessary\n","\n"," # Apply augmentations\n"," aug_image, aug_bboxes, aug_categories = apply_augmentations(image, bboxes, categories)\n","\n"," # Save augmented image\n"," aug_image_name = f\"aug_{image_info['file_name']}\"\n"," cv2.imwrite(os.path.join(train_dataset_path, 'images', aug_image_name), cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR)) # FIXME: why do I do COLOR_RGB2BGR?\n","\n"," # Update COCO annotations\n"," augmented_images.append({\n"," \"id\": new_image_id,\n"," \"file_name\": aug_image_name\n"," })\n"," augmented_annotations.extend(update_coco_annotations(coco_data['annotations'], new_image_id, aug_bboxes, aug_categories))\n"," new_image_id += 1\n","\n","# Update COCO data\n","coco_data['images'].extend(augmented_images)\n","coco_data['annotations'].extend(augmented_annotations)\n","\n","# Save updated COCO annotations\n","with open(os.path.join(train_dataset_path, 'labels.json'), 'w') as f:\n"," json.dump(coco_data, f)\n","\n","print(f\"Added {len(augmented_images)} augmented images to the dataset.\")"],"metadata":{"id":"cOtUZ5tnUPwJ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["import shutil\n","shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ?"],"metadata":{"id":"pUpntlCgYgUh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","print(\"Updated train_data size: \", train_data.size)"],"metadata":{"id":"2A_QSKYdUs2U"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Retrain model"],"metadata":{"id":"IzhODm2CJaMB"}},{"cell_type":"markdown","source":["### Set retraining options"],"metadata":{"id":"jF7sZHYyJcl7"}},{"cell_type":"code","source":["spec = object_detector.SupportedModels.MOBILENET_V2\n","hparams = object_detector.HParams(\n"," learning_rate=0.3,\n"," batch_size=8,\n"," epochs=50,\n"," #cosine_decay_epochs=50,\n"," #cosine_decay_alpha=0.2,\n"," export_dir='exported_model'\n",")\n","#model_options = object_detector.ModelOptions(\n","# l2_weight_decay=3e-4\n","#)\n","options = object_detector.ObjectDetectorOptions(\n"," supported_model=spec,\n"," hparams=hparams,\n"," #model_options=model_options\n",")"],"metadata":{"id":"iC6cVSpVJWFw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Run retraining"],"metadata":{"id":"bfQoIMPOJhK5"}},{"cell_type":"code","source":["model = 
object_detector.ObjectDetector.create(\n"," train_data=train_data,\n"," validation_data=validation_data,\n"," options=options)"],"metadata":{"id":"BxT1UHQWJfAX","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Evaluate the model performance\n","\n","After training the model, evaluate it on validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision."],"metadata":{"id":"y3h_cOytJjXk"}},{"cell_type":"code","source":["# TODO: is it possible to add a progress bar?"],"metadata":{"id":"WaUndk7ZZgbk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=4)\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"kJabrlEjJkmH","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export model\n","\n","After creating the model, convert and export it to a Tensorflow Lite model format for later use on an on-device application. The export also includes model metadata, which includes the label map."],"metadata":{"id":"Vkg5K-xWJmj6"}},{"cell_type":"code","source":["model.export_model()\n","!ls exported_model\n","files.download('exported_model/model.tflite')"],"metadata":{"id":"ADUK65tBJn7x","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model quantization"],"metadata":{"id":"6g6QkA6YJpf8"}},{"cell_type":"markdown","source":["Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n","\n","This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n","1. Quantization Aware Training: 8 bit integer precision for CPU usage\n","2. Post-Training Quantization: 16 bit floating point precision for GPU usage"],"metadata":{"id":"Jl1Dc9EhJtRo"}},{"cell_type":"markdown","source":["### Quantization aware training (int8 quantization)\n","Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n","\n","To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. 
See the **Hyperparameters** section below on detailed usage of `QATHParams`."],"metadata":{"id":"D88f-mqgJvwi"}},{"cell_type":"code","source":["qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"NbTDJtEJJr7Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model(After running the `create` method) in order to run QAT again."],"metadata":{"id":"5KEPghm9JyUi"}},{"cell_type":"code","source":["new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n","model.restore_float_ckpt()\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"vALj8IqaJ1B4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Finally, us the `export_model` to export to an int8 quantized model. The `export_model` function will automatically export to either float32 or int8 model depending on whether `quantization_aware_training` was run."],"metadata":{"id":"HdZAaTJKJ4ky"}},{"cell_type":"code","source":["model.export_model('model_int8_qat.tflite')\n","!ls -lh exported_model\n","files.download('exported_model/model_int8_qat.tflite')"],"metadata":{"id":"AePYkJHwJ2OG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Post-training quantization (fp16 quantization)\n","\n","Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is reccomended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n","\n","First, import the MediaPipe Model Maker quantization module:"],"metadata":{"id":"mm4gCymvJ7wK"}},{"cell_type":"code","source":["from mediapipe_model_maker import quantization"],"metadata":{"id":"mt8zeY52J8Gk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. 
You can further customize the quantization process by setting additional parameters for the QuantizationConfig class."],"metadata":{"id":"r6eGgib4J-Ac"}},{"cell_type":"code","source":["quantization_config = quantization.QuantizationConfig.for_float16()"],"metadata":{"id":"IjAn-GBKJ_ab"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`."],"metadata":{"id":"ChqPtssVKCA8"}},{"cell_type":"code","source":["model.restore_float_ckpt()\n","model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n","!ls -lh exported_model\n","files.download('exported_model/model_fp16.tflite')"],"metadata":{"id":"f90WxGNNKCSW"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"last_runtime":{"build_target":"//learning/grp/tools/ml_python:ml_notebook","kind":"private"},"private_outputs":true,"provenance":[{"file_id":"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb","timestamp":1720270505355},{"file_id":"11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb","timestamp":1677706798050}],"gpuType":"T4","collapsed_sections":["6g6QkA6YJpf8"]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} \ No newline at end of file +{"cells":[{"cell_type":"markdown","source":["# Object detection model customization"],"metadata":{"id":"rQIYsEciHsMN"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"NReBxyCQH1Rx"}},{"cell_type":"markdown","source":["To install the libraries for customizing a model, run the following commands:"],"metadata":{"id":"xbuuejtPH97m"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install mediapipe-model-maker"],"metadata":{"id":"ceFO9GPSHyRv","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the following code to import the required Python classes:"],"metadata":{"id":"2beHFLLIIIv0"}},{"cell_type":"code","source":["from google.colab import files\n","import os\n","import json\n","from tqdm import tqdm\n","import tensorflow as tf\n","assert tf.__version__.startswith('2')\n","\n","from mediapipe_model_maker import object_detector"],"metadata":{"id":"0OpjgSuBIJJC","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Prepare data"],"metadata":{"id":"Yq6UT0j9ILD_"}},{"cell_type":"code","source":["from google.colab import drive\n","import shutil"],"metadata":{"id":"CdowIoy4KgAw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Set paths"],"metadata":{"id":"ThZpHZPKKK1O"}},{"cell_type":"code","source":["# Mount Google Drive\n","drive.mount('/content/drive')"],"metadata":{"id":"-Lrh4ORlDGgu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Define paths\n","base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 'validation/'\n","test_dataset_path = dest_base_path + 
'test/'"],"metadata":{"id":"BlfRX-r_3YUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["os.makedirs(dest_base_path, exist_ok=True)\n","\n","os.makedirs(train_dataset_path, exist_ok=True)\n","os.makedirs(validation_dataset_path, exist_ok=True)\n","os.makedirs(test_dataset_path, exist_ok=True)\n","\n","os.makedirs(os.path.join(train_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(validation_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(test_dataset_path, 'images'), exist_ok=True)"],"metadata":{"id":"4auASQMqjYRt"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Copy images"],"metadata":{"id":"Ai5-oF_zjkxz"}},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," lines = f.readlines()\n"," for line in tqdm(lines, desc=f\"Copying images to {dest_folder}\"):\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if (len(os.listdir(train_dataset_path + 'images/')) == 0 and\n"," len(os.listdir(validation_dataset_path + 'images/')) == 0 and\n"," len(os.listdir(test_dataset_path + 'images/')) == 0):\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"-A3ZtrofqptN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review dataset"],"metadata":{"id":"IUcrpj6aIzqt"}},{"cell_type":"markdown","source":["Verify the dataset content by printing the categories from the `labels.json` file. There should be 13 total categories. 
Index 0 is always set to be the `background` class which may be unused in the dataset."],"metadata":{"id":"xYXzxYC3I0h5"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"v6rYvXVwIwKp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the training dataset\n","import matplotlib.pyplot as plt\n","from matplotlib import patches, text, patheffects\n","from collections import defaultdict\n","import math\n","\n","def draw_outline(obj):\n"," obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n","def draw_box(ax, bb):\n"," patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n"," draw_outline(patch)\n","def draw_text(ax, bb, txt, disp):\n"," text = ax.text(bb[0],(bb[1]-disp),txt,verticalalignment='top'\n"," ,color='white',fontsize=10,weight='bold')\n"," draw_outline(text)\n","def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n"," for annotation in annotations_list:\n"," cat_id = annotation[\"category_id\"]\n"," bbox = annotation[\"bbox\"]\n"," draw_box(ax, bbox)\n"," draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n","def visualize(dataset_folder, max_examples=None):\n"," with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n"," images = labels_json[\"images\"]\n"," cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n"," image_annots = defaultdict(list)\n"," for annotation_obj in labels_json[\"annotations\"]:\n"," image_id = annotation_obj[\"image_id\"]\n"," image_annots[image_id].append(annotation_obj)\n","\n"," if max_examples is None:\n"," max_examples = len(image_annots.items())\n"," n_rows = math.ceil(max_examples / 3)\n"," fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n"," for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n"," ax = axs[ind//3, ind%3]\n"," img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n"," ax.imshow(img)\n"," draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n"," plt.show()\n","\n","visualize(train_dataset_path, 9)"],"metadata":{"id":"8D17VhVAI33W","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Create dataset"],"metadata":{"id":"fPyRkkKYJEOB"}},{"cell_type":"code","source":["if os.path.exists(\"/tmp/od_data/train\"):\n"," shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ?"],"metadata":{"id":"6MICrUKmFEsS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: is it possible to add a progress bar?"],"metadata":{"id":"7f4Uke50b1X9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","print(\"train_data size: \", train_data.size)\n","print(\"validation_data size: \", validation_data.size)"],"metadata":{"id":"cooLOJrmJDbo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### 
Augmentation"],"metadata":{"id":"JGLtQeX3UG_s"}},{"cell_type":"code","source":["import albumentations as A\n","import numpy as np\n","import cv2"],"metadata":{"id":"7R9rK4PXUKsa"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["#### Search for augmented data"],"metadata":{"id":"J5WnmuRFJ1-6"}},{"cell_type":"code","source":["# FIXME"],"metadata":{"id":"Bi-iMOfRW74c"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def check_and_delete_augmented_images(folder_path):\n"," augmented_images = [f for f in os.listdir(folder_path) if 'aug' in f]\n","\n"," if len(augmented_images) > 0:\n"," print(f\"Found {len(augmented_images)} augmented images in {folder_path}.\")\n"," user_input = input(\"Do you want to delete these images? (yes/no): \").strip().lower()\n","\n"," if user_input == 'yes':\n"," for img in tqdm(augmented_images, desc=\"Deleting augmented images\"):\n"," img_path = os.path.join(folder_path, img)\n"," os.remove(img_path)\n"," print(\"Augmented images deleted successfully.\")\n"," else:\n"," print(\"Deletion aborted by user.\")\n"," else:\n"," print(\"No augmented images found.\")\n","\n","check_and_delete_augmented_images(train_dataset_path + 'images/')"],"metadata":{"id":"KK83p2AQJ5BG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["#### Augment data"],"metadata":{"id":"77QQeKsFKCBx"}},{"cell_type":"code","source":["# TODO: check label_fields\n","# FIXME: fix augmentation"],"metadata":{"id":"F9PAySGhXvgG"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_dynamic_transform(image_height, image_width):\n"," crop_height = min(224, image_height)\n"," crop_width = min(224, image_width)\n","\n"," return A.Compose([ # FIXME\n"," A.HorizontalFlip(p=0.5),\n"," #A.RandomRotate90(p=0.5),\n"," A.RandomBrightnessContrast(p=0.2),\n"," A.Perspective(p=0.5), # TODO: check perspective transformations\n"," A.RandomGamma(p=0.2),\n"," A.GaussianBlur(blur_limit=(3, 7), p=0.1),\n"," A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.2),\n"," #A.RandomCrop(height=crop_height, width=crop_width, p=0.5),\n"," #A.Cutout(num_holes=8, max_h_size=8, max_w_size=8, fill_value=0, p=0.5),\n"," ], bbox_params=A.BboxParams(format='coco', label_fields=['category_ids']))"],"metadata":{"id":"Kvs_HcPNUOFW"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def clip_bboxes(bboxes):\n"," return [[max(0, min(1, coord)) for coord in bbox] for bbox in bboxes]\n","\n","def resize_if_needed(image, min_size=224): # TODO: check if it's necessary\n"," height, width = image.shape[:2]\n"," if height < min_size or width < min_size:\n"," scale = min_size / min(height, width)\n"," new_height = int(height * scale)\n"," new_width = int(width * scale)\n"," image = cv2.resize(image, (new_width, new_height), interpolation=cv2.INTER_LINEAR)\n"," return image\n","\n","def apply_augmentations(image, bboxes, categories):\n"," height, width = image.shape[:2]\n"," transform = get_dynamic_transform(height, width)\n","\n"," # Clip bboxes before applying augmentations\n"," clipped_bboxes = clip_bboxes(bboxes)\n","\n"," augmented = transform(image=image, bboxes=clipped_bboxes, category_ids=categories)\n","\n"," # Clip bboxes after augmentations as well\n"," augmented['bboxes'] = clip_bboxes(augmented['bboxes'])\n","\n"," return augmented['image'], augmented['bboxes'], augmented['category_ids']\n","\n","def update_coco_annotations(annotations, new_image_id, new_bboxes, new_categories):\n"," new_annotations = []\n"," for i, (bbox, category) in 
enumerate(zip(new_bboxes, new_categories)):\n"," new_annotations.append({\n"," \"id\": len(annotations) + i,\n"," \"image_id\": new_image_id,\n"," \"category_id\": category,\n"," \"bbox\": [round(coord, 1) for coord in bbox]\n"," })\n"," return new_annotations"],"metadata":{"id":"rFEifv__UMhZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Load COCO annotations\n","with open(os.path.join(train_dataset_path, 'labels.json'), 'r') as f:\n"," coco_data = json.load(f)\n","\n","augmented_images = []\n","augmented_annotations = []\n","new_image_id = len(coco_data['images'])\n","\n","for image_info in tqdm(coco_data['images']):\n"," # Load image\n"," image_path = os.path.join(train_dataset_path, 'images', image_info['file_name'])\n","\n"," if not os.path.exists(image_path):\n"," print(f\"Image not found: {image_path}\")\n"," continue\n","\n"," image = cv2.imread(image_path)\n","\n"," if image is None:\n"," print(f\"Failed to load image: {image_path}\")\n"," continue\n","\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n","\n"," # Get annotations for this image\n"," image_annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] == image_info['id']]\n"," bboxes = [ann['bbox'] for ann in image_annotations]\n"," categories = [ann['category_id'] for ann in image_annotations]\n","\n"," # Clip bboxes before applying augmentations\n"," bboxes = clip_bboxes(bboxes)\n","\n"," #image = resize_if_needed(image) # TODO: check if it's necessary\n","\n"," # Apply augmentations\n"," aug_image, aug_bboxes, aug_categories = apply_augmentations(image, bboxes, categories)\n","\n"," # Save augmented image\n"," aug_image_name = f\"aug_{image_info['file_name']}\"\n"," cv2.imwrite(os.path.join(train_dataset_path, 'images', aug_image_name), cv2.cvtColor(aug_image, cv2.COLOR_RGB2BGR)) # FIXME: why do I do COLOR_RGB2BGR?\n","\n"," # Update COCO annotations\n"," augmented_images.append({\n"," \"id\": new_image_id,\n"," \"file_name\": aug_image_name\n"," })\n"," augmented_annotations.extend(update_coco_annotations(coco_data['annotations'], new_image_id, aug_bboxes, aug_categories))\n"," new_image_id += 1\n","\n","# Update COCO data\n","coco_data['images'].extend(augmented_images)\n","coco_data['annotations'].extend(augmented_annotations)\n","\n","# Save updated COCO annotations\n","with open(os.path.join(train_dataset_path, 'labels.json'), 'w') as f:\n"," json.dump(coco_data, f)\n","\n","print(f\"Added {len(augmented_images)} augmented images to the dataset.\")"],"metadata":{"id":"cOtUZ5tnUPwJ"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["#### Rewrite train data"],"metadata":{"id":"zbBGHtdpJ98E"}},{"cell_type":"code","source":["shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ?"],"metadata":{"id":"pUpntlCgYgUh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","print(\"Updated train_data size: \", train_data.size)"],"metadata":{"id":"2A_QSKYdUs2U"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Retrain model"],"metadata":{"id":"IzhODm2CJaMB"}},{"cell_type":"markdown","source":["### Set retraining options"],"metadata":{"id":"jF7sZHYyJcl7"}},{"cell_type":"code","source":["spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384\n","\n","hparams = object_detector.HParams(\n"," learning_rate=0.1,\n"," batch_size=16,\n"," epochs=100,\n"," cosine_decay_epochs=100,\n"," 
cosine_decay_alpha=0.2,\n"," export_dir='exported_model'\n",")\n","\n","model_options = object_detector.ModelOptions(\n"," l2_weight_decay=3e-4\n",")\n","\n","options = object_detector.ObjectDetectorOptions(\n"," supported_model=spec,\n"," hparams=hparams,\n"," model_options=model_options\n",")"],"metadata":{"id":"iC6cVSpVJWFw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Run retraining"],"metadata":{"id":"bfQoIMPOJhK5"}},{"cell_type":"code","source":["model = object_detector.ObjectDetector.create(\n"," train_data=train_data,\n"," validation_data=validation_data,\n"," options=options\n",")"],"metadata":{"id":"BxT1UHQWJfAX","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Evaluate the model performance\n","\n","After training the model, evaluate it on validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision."],"metadata":{"id":"y3h_cOytJjXk"}},{"cell_type":"code","source":["# TODO: is it possible to add a progress bar?"],"metadata":{"id":"WaUndk7ZZgbk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=4)\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"kJabrlEjJkmH","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export model\n","\n","After creating the model, convert and export it to a Tensorflow Lite model format for later use on an on-device application. The export also includes model metadata, which includes the label map."],"metadata":{"id":"Vkg5K-xWJmj6"}},{"cell_type":"code","source":["model.export_model()\n","!ls exported_model\n","files.download('exported_model/model.tflite')"],"metadata":{"id":"ADUK65tBJn7x","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model quantization"],"metadata":{"id":"6g6QkA6YJpf8"}},{"cell_type":"markdown","source":["Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n","\n","This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n","1. Quantization Aware Training: 8 bit integer precision for CPU usage\n","2. Post-Training Quantization: 16 bit floating point precision for GPU usage"],"metadata":{"id":"Jl1Dc9EhJtRo"}},{"cell_type":"markdown","source":["### Quantization aware training (int8 quantization)\n","Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n","\n","To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. 
See the **Hyperparameters** section below for detailed usage of `QATHParams`."],"metadata":{"id":"D88f-mqgJvwi"}},{"cell_type":"code","source":["qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"NbTDJtEJJr7Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model (after running the `create` method) in order to run QAT again."],"metadata":{"id":"5KEPghm9JyUi"}},{"cell_type":"code","source":["new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n","model.restore_float_ckpt()\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"vALj8IqaJ1B4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Finally, use the `export_model` method to export to an int8 quantized model. The `export_model` function will automatically export to either a float32 or an int8 model depending on whether `quantization_aware_training` was run."],"metadata":{"id":"HdZAaTJKJ4ky"}},{"cell_type":"code","source":["model.export_model('model_int8_qat.tflite')\n","!ls -lh exported_model\n","files.download('exported_model/model_int8_qat.tflite')"],"metadata":{"id":"AePYkJHwJ2OG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Post-training quantization (fp16 quantization)\n","\n","Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is recommended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n","\n","First, import the MediaPipe Model Maker quantization module:"],"metadata":{"id":"mm4gCymvJ7wK"}},{"cell_type":"code","source":["from mediapipe_model_maker import quantization"],"metadata":{"id":"mt8zeY52J8Gk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. 
You can further customize the quantization process by setting additional parameters for the QuantizationConfig class."],"metadata":{"id":"r6eGgib4J-Ac"}},{"cell_type":"code","source":["quantization_config = quantization.QuantizationConfig.for_float16()"],"metadata":{"id":"IjAn-GBKJ_ab"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`."],"metadata":{"id":"ChqPtssVKCA8"}},{"cell_type":"code","source":["model.restore_float_ckpt()\n","model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n","!ls -lh exported_model\n","files.download('exported_model/model_fp16.tflite')"],"metadata":{"id":"f90WxGNNKCSW"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"last_runtime":{"build_target":"//learning/grp/tools/ml_python:ml_notebook","kind":"private"},"private_outputs":true,"provenance":[{"file_id":"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb","timestamp":1720270505355},{"file_id":"11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb","timestamp":1677706798050}],"gpuType":"T4","collapsed_sections":["6g6QkA6YJpf8"]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} \ No newline at end of file
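
For reference, `object_detector.Dataset.from_coco_folder` in the training notebook reads a COCO-style `labels.json` that sits next to the `images/` folder, with the `background` class reserved at index 0 as the notebook notes. A minimal sketch of that layout is below; the image file name and monument class are hypothetical placeholders, and the real categories are produced by `scripts/create_annotations.py`.

```python
# Minimal sketch of the COCO-style labels.json read alongside the images/ folder.
# The file name and monument class are hypothetical placeholders.
import json

labels = {
    "categories": [
        {"id": 0, "name": "background"},    # index 0 is reserved for the background class
        {"id": 1, "name": "eiffel_tower"},  # hypothetical monument category
    ],
    "images": [
        {"id": 0, "file_name": "paris_example_000001.jpg"},  # hypothetical image
    ],
    "annotations": [
        # bbox is [x_min, y_min, width, height] in absolute pixels
        {"id": 0, "image_id": 0, "category_id": 1, "bbox": [120.0, 45.0, 310.0, 420.0]},
    ],
}

with open("labels.json", "w") as f:
    json.dump(labels, f, indent=2)
```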
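The augmentation cells in the notebook clip box coordinates to the [0, 1] range while passing boxes to Albumentations in COCO format, which uses absolute `[x_min, y_min, width, height]` pixel values; the cell is already flagged `FIXME: fix augmentation`. One possible pixel-space clipping helper, assuming absolute pixel boxes, is sketched below as an illustration rather than the author's confirmed fix.

```python
# Hedged sketch: clip COCO-format [x_min, y_min, width, height] boxes (absolute
# pixels) to the image bounds before augmentation, instead of clipping the raw
# coordinate values to [0, 1]. Illustrative only; not the notebook's final fix.
def clip_bboxes_pixels(bboxes, image_height, image_width):
    clipped = []
    for x, y, w, h in bboxes:
        x = max(0.0, min(x, image_width - 1.0))   # keep the corner inside the image
        y = max(0.0, min(y, image_height - 1.0))
        w = max(1.0, min(w, image_width - x))     # trim width/height to the border
        h = max(1.0, min(h, image_height - y))
        clipped.append([x, y, w, h])
    return clipped

# Example: a box hanging over the right edge of a 640x480 image gets trimmed.
print(clip_bboxes_pixels([[600.0, 100.0, 80.0, 50.0]], 480, 640))
# -> [[600.0, 100.0, 40.0, 50.0]]
```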
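The README's Inference section points to `inference/test.py`, which loads the exported `../models/model.tflite` with the TensorFlow Lite interpreter. A minimal standalone sketch of that flow is below; the image path, the input normalization, and the meaning of the output tensors are assumptions that depend on the exported detector, so the sketch only prints the output shapes.

```python
# Hedged sketch of running the exported model.tflite on a single image.
# The image path is a placeholder and the [0, 1] normalization is an assumption.
import numpy as np
import tensorflow as tf

interpreter = tf.lite.Interpreter(model_path="../models/model.tflite")
interpreter.allocate_tensors()

input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

# Resize the image to the model's expected input size (shape is [1, H, W, 3]).
height, width = input_details[0]["shape"][1:3]
image = tf.io.decode_image(tf.io.read_file("example.jpg"), channels=3)
image = tf.image.resize(image, (height, width))

if input_details[0]["dtype"] == np.uint8:
    input_tensor = tf.cast(image, tf.uint8)[tf.newaxis, ...]
else:
    input_tensor = (tf.cast(image, tf.float32) / 255.0)[tf.newaxis, ...]

interpreter.set_tensor(input_details[0]["index"], input_tensor.numpy())
interpreter.invoke()

# Output tensor order and semantics (boxes, classes, scores) vary by export,
# so just inspect what the model returns.
for detail in output_details:
    print(detail["name"], interpreter.get_tensor(detail["index"]).shape)
```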