diff --git a/.gitignore b/.gitignore index da17793..60f66b8 100644 --- a/.gitignore +++ b/.gitignore @@ -10,7 +10,7 @@ __pycache__/ **/__pycache__/ # models -models/ +#models/ # git # TODO: unignore when ready .gitattributes diff --git a/inference/test.py b/inference/test.py index c87d28a..e65e739 100644 --- a/inference/test.py +++ b/inference/test.py @@ -3,6 +3,8 @@ """ import os.path +import random + import tensorflow as tf import matplotlib.pyplot as plt import matplotlib.patches as patches @@ -65,6 +67,13 @@ def visualize_detections(image_path, boxes, classes_scores, threshold=0.5): height, width, _ = image.shape for box, class_score in zip(boxes[0], classes_scores[0]): score = np.max(class_score) + + '''debug''' + print(f"image_path: {image_path}") + print(f"box: {box}") + print(f"class_score: {class_score}") + print(f"score: {score}") + if score > threshold: ymin, xmin, ymax, xmax = box rect = patches.Rectangle((xmin*width, ymin*height), (xmax-xmin)*width, (ymax-ymin)*height, @@ -90,10 +99,18 @@ def train_images(): with open(train_path, 'r') as f: train_images = [line.strip() for line in f] + random.shuffle(train_images) # TODO: check + for image_name in train_images[:5]: image_path = os.path.join(base_path, 'datasets', 'rparis6k', 'images', image_name) image_np = load_image_into_numpy_array(image_path) boxes, classes_scores = run_inference(image_np) + + '''debug''' + print(f"Image: {image_name}") + print(boxes) + print(classes_scores) + visualize_detections(image_path, boxes, classes_scores) @@ -132,7 +149,7 @@ def test_images(): with open(test_path, 'r') as f: test_images = [line.strip() for line in f] - for image_name in test_images[:5]: + for image_name in test_images[:1]: # TODO: check image_path = os.path.join(base_path, 'datasets', 'rparis6k', 'images', image_name) image_np = load_image_into_numpy_array(image_path) boxes, classes_scores = run_inference(image_np) diff --git a/models/model.tflite b/models/model.tflite new file mode 100644 index 0000000..b4e8916 Binary files /dev/null and b/models/model.tflite differ diff --git a/models/model_fp16.tflite b/models/model_fp16.tflite new file mode 100644 index 0000000..403f928 Binary files /dev/null and b/models/model_fp16.tflite differ diff --git a/scripts/augment_dataset.py b/scripts/augment_dataset.py new file mode 100644 index 0000000..737510a --- /dev/null +++ b/scripts/augment_dataset.py @@ -0,0 +1,209 @@ +import albumentations as A +import cv2 +import json +import os +import shutil +from tqdm import tqdm + + +def get_transform(set='train'): + bboxes_params = A.BboxParams(format='coco', min_visibility=0.3, label_fields=['class_labels']) # TODO: check min_visibility + + if set == 'train': + transform = A.Compose([ # TODO: update pipeline (?) + # TODO: do I need to resize images? 
+ #A.RandomResizedCrop(height=640, width=640, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=1.0), # TODO: check h,w + A.HorizontalFlip(p=0.5), + A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5), + A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5), + A.GaussNoise(var_limit=(10.0, 50.0), p=0.5), + A.RandomShadow(num_shadows_lower=1, num_shadows_upper=3, shadow_dimension=5, shadow_roi=(0, 0.5, 1, 1), p=0.3), + A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5), + A.OneOf([ + A.MotionBlur(blur_limit=7, p=0.5), + A.MedianBlur(blur_limit=7, p=0.5), + A.GaussianBlur(blur_limit=7, p=0.5), + ], p=0.3), + A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, border_mode=0, p=0.5), + A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # TODO: check + ], bbox_params=bboxes_params) + + elif set == 'validation': + transform = A.Compose([ # TODO: update pipeline + # TODO: do I need to resize images? + A.HorizontalFlip(p=0.5), + A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5), + A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # TODO: check + ], bbox_params=bboxes_params) + + return transform + + +def clip_bbox(bbox, image_width, image_height): + x_min, y_min, width, height = bbox + + x_min = max(0, min(x_min, image_width - 1)) # TODO: check -1 + y_min = max(0, min(y_min, image_height - 1)) # TODO: check -1 + width = min(width, image_width - x_min) + height = min(height, image_height - y_min) + + return [x_min, y_min, width, height] + + +def validate_bbox(bbox, image_width, image_height): + x, y, w, h = bbox + + return 0 <= x < image_width and 0 <= y < image_height and x + w <= image_width and y + h <= image_height + + +def apply_augmentation(image_path, bboxes, class_labels, output_path, output_filename, transform): + # Read the image + image = cv2.imread(image_path) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + image_height, image_width = image.shape[:2] + + # Apply the augmentation + try: + transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels) + except Exception as e: + print(f"Error during transformation: {e}") + return [], [] + + # Save the augmented image + augmented_image_path = os.path.join(output_path, output_filename) + cv2.imwrite(augmented_image_path, cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR)) + + return transformed['bboxes'], transformed['class_labels'] + + +def augment_dataset(input_path, output_path, transform, n_images, n_annotations, num_augmentations=5): + # Load the original COCO JSON file + with open(os.path.join(input_path, 'labels.json'), 'r') as f: + coco_data = json.load(f) + + new_images = [] + new_annotations = [] + + # Copy original images and annotations + for img in tqdm(coco_data['images'], desc="Copying original images"): + + src_path = os.path.join(input_path, 'images', img['file_name']) + dst_path = os.path.join(output_path, 'images', img['file_name']) + shutil.copy2(src_path, dst_path) + + new_images.append(img) + img_anns = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']] + new_annotations.extend(img_anns) + + '''debug''' + print("Before augmentation:") + print(f"Number of images: {len(new_images)}") + print(f"Number of annotations: {len(new_annotations)}") + + # Apply augmentations + for img in tqdm(coco_data['images'], desc="Augmenting images"): + image_path = os.path.join(input_path, 'images', img['file_name']) + + annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] 
== img['id']] + + image = cv2.imread(image_path) + image_height, image_width = image.shape[:2] + + for i in range(num_augmentations): + bboxes = [ann['bbox'] for ann in annotations] + class_labels = [ann['category_id'] for ann in annotations] + + # TODO: should I call the function clip_bbox() regardless of the function validate_bbox()? + for bbox in bboxes: + if not validate_bbox(bbox, image_width, image_height): + bboxes = [clip_bbox(bbox, image_width, image_height) for bbox in bboxes] + + new_filename = f"{os.path.splitext(img['file_name'])[0]}_aug_{i}.jpg" + + new_bboxes, new_class_labels = apply_augmentation( + image_path, bboxes, class_labels, + os.path.join(output_path, 'images'), new_filename, transform + ) + + new_img_id = n_images + 1 + new_images.append({ + 'id': new_img_id, + 'file_name': new_filename + }) + + n_images = n_images + 1 + + for bbox, cat_id in zip(new_bboxes, new_class_labels): + new_annotations.append({ + 'id': n_annotations + 1, + 'image_id': new_img_id, + 'category_id': cat_id, + 'bbox': bbox + }) + + n_annotations = n_annotations + 1 + + '''debug''' + print("After augmentation:") + print(f"Number of images: {len(new_images)}") + print(f"Number of annotations: {len(new_annotations)}") + + # Create the new COCO JSON file + new_coco_data = { + 'categories': coco_data['categories'], + 'images': new_images, + 'annotations': new_annotations + } + + # Save the new COCO JSON file + with open(os.path.join(output_path, 'labels.json'), 'w') as f: + json.dump(new_coco_data, f, indent=4) + + return n_images, n_annotations + + +def main(): + + train_dataset_path = 'path/to/dataset/train/' + validation_dataset_path = 'path/to/dataset/validation/' + test_dataset_path = 'path/to/dataset/test/' + + augmented_train_dataset_path = 'path/to/dest_base_path' + 'train_augmented/' + augmented_validation_dataset_path = 'path/to/dest_base_path' + 'validation_augmented/' + + os.makedirs(os.path.join(augmented_train_dataset_path, 'images'), exist_ok=True) + os.makedirs(os.path.join(augmented_validation_dataset_path, 'images'), exist_ok=True) + + with open(os.path.join(train_dataset_path, 'labels.json'), 'r') as f: + train_json = json.load(f) + + with open(os.path.join(validation_dataset_path, 'labels.json'), 'r') as f: + val_json = json.load(f) + + with open(os.path.join(test_dataset_path, 'labels.json'), 'r') as f: + test_json = json.load(f) + + n_images = max(train_json['images'][-1]['id'], val_json['images'][-1]['id'], test_json['images'][-1]['id']) + n_annotations = max(train_json['annotations'][-1]['id'], val_json['annotations'][-1]['id'], + test_json['annotations'][-1]['id']) + + if os.listdir(augmented_train_dataset_path + 'images/') == []: + n_images, n_annotations = augment_dataset(train_dataset_path, augmented_train_dataset_path, get_transform('train'), n_images, n_annotations, num_augmentations=5) + else: + print("Augmentation on the training set has already been made.") + + if os.listdir(augmented_validation_dataset_path + 'images/') == []: + augment_dataset(validation_dataset_path, augmented_validation_dataset_path, get_transform('validation'), n_images, n_annotations, num_augmentations=5) + else: + print("Augmentation on the validation set has already been made.") + + + count1 = sum(1 for filename in os.listdir(os.path.join(augmented_train_dataset_path, 'images')) if any(filename.lower().endswith(ext) for ext in ['.jpg', '.jpeg'])) + count2 = sum(1 for filename in os.listdir(os.path.join(augmented_validation_dataset_path, 'images')) if any(filename.lower().endswith(ext) for 
ext in ['.jpg', '.jpeg'])) + + print(f"Number of images in the train_augmented folder: {count1}") # TODO: make prettier + print(f"Number of images in the validation_augmented folder: {count2}") # TODO: make prettier + + +if __name__ == '__main__': + main() diff --git a/training/mp_training_paris6k.ipynb b/training/mp_training_paris6k.ipynb index b95ef82..239f5bc 100644 --- a/training/mp_training_paris6k.ipynb +++ b/training/mp_training_paris6k.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","source":["# Object detection model customization"],"metadata":{"id":"rQIYsEciHsMN"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"NReBxyCQH1Rx"}},{"cell_type":"markdown","source":["To install the libraries for customizing a model, run the following commands:"],"metadata":{"id":"xbuuejtPH97m"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install mediapipe-model-maker"],"metadata":{"id":"ceFO9GPSHyRv","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the following code to import the required Python classes:"],"metadata":{"id":"2beHFLLIIIv0"}},{"cell_type":"code","source":["from google.colab import files\n","import os\n","import json\n","from tqdm import tqdm\n","import tensorflow as tf\n","assert tf.__version__.startswith('2')\n","\n","from mediapipe_model_maker import object_detector\n","\n","from google.colab import drive\n","import shutil\n","\n","%matplotlib inline"],"metadata":{"id":"0OpjgSuBIJJC","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Colab Pro"],"metadata":{"id":"GYUY-Sw33nJD"}},{"cell_type":"code","source":["gpu_info = !nvidia-smi\n","gpu_info = '\\n'.join(gpu_info)\n","if gpu_info.find('failed') >= 0:\n"," print('Not connected to a GPU')\n","else:\n"," print(gpu_info)"],"metadata":{"id":"9bHLSYKW3pEt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from psutil import virtual_memory\n","ram_gb = virtual_memory().total / 1e9\n","print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n","\n","if ram_gb < 20:\n"," print('Not using a high-RAM runtime')\n","else:\n"," print('You are using a high-RAM runtime!')"],"metadata":{"id":"VvbzIdMd3s5N"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Set paths"],"metadata":{"id":"ThZpHZPKKK1O"}},{"cell_type":"code","source":["# Mount Google Drive\n","drive.mount('/content/drive')"],"metadata":{"id":"-Lrh4ORlDGgu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 'validation/'\n","test_dataset_path = dest_base_path + 'test/'"],"metadata":{"id":"BlfRX-r_3YUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Create directories\n","\n","os.makedirs(dest_base_path, exist_ok=True)\n","\n","os.makedirs(train_dataset_path, exist_ok=True)\n","os.makedirs(validation_dataset_path, exist_ok=True)\n","os.makedirs(test_dataset_path, exist_ok=True)\n","\n","os.makedirs(os.path.join(train_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(validation_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(test_dataset_path, 'images'), 
exist_ok=True)"],"metadata":{"id":"4auASQMqjYRt","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Prepare data"],"metadata":{"id":"Yq6UT0j9ILD_"}},{"cell_type":"markdown","source":["### Copy images"],"metadata":{"id":"Ai5-oF_zjkxz"}},{"cell_type":"code","source":["# FIXME:\n","def check_and_delete_augmented_images(folder_path):\n"," augmented_images = [f for f in os.listdir(folder_path) if 'aug' in f]\n","\n"," if len(augmented_images) > 0:\n"," print(f\"Found {len(augmented_images)} augmented images in {folder_path}.\")\n"," user_input = input(\"Do you want to delete these images? (yes/no): \").strip().lower()\n","\n"," if user_input == 'yes':\n"," # delete all the elements in the image folder\n"," # delete the labels.json file\n"," for img in tqdm(augmented_images, desc=\"Deleting augmented images\"):\n"," img_path = os.path.join(folder_path, img)\n"," os.remove(img_path)\n"," print(\"Augmented images deleted successfully.\")\n"," else:\n"," print(\"Deletion aborted by user.\")\n"," else:\n"," print(\"No augmented images found.\")"],"metadata":{"id":"tttYb0xMSlir"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: mention that the train set contain augmented images if it's the case\n","# TODO: use also a boolean value to memorize the aug data deletion"],"metadata":{"id":"Jsgo4eA9R0xv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," lines = f.readlines()\n"," for line in tqdm(lines, desc=f\"Copying images to {dest_folder}\"):\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if os.listdir(train_dataset_path + 'images/') == [] and\n"," os.listdir(validation_dataset_path + 'images/') == [] and\n"," os.listdir(test_dataset_path + 'images/') == []:\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"-A3ZtrofqptN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review data"],"metadata":{"id":"IUcrpj6aIzqt"}},{"cell_type":"markdown","source":["Verify the dataset content by printing the categories from the `labels.json` file. There should be 13 total categories. 
Index 0 is always set to be the `background` class which may be unused in the dataset."],"metadata":{"id":"xYXzxYC3I0h5"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"v6rYvXVwIwKp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the training data [FIXME]\n","import matplotlib.pyplot as plt\n","from matplotlib import patches, text, patheffects\n","from collections import defaultdict\n","import math\n","\n","# TODO: it may be interesting if it shows two (or n) random images per category\n","\n","def draw_outline(obj):\n"," obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n","\n","def draw_box(ax, bb):\n"," patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n"," draw_outline(patch)\n","\n","def draw_text(ax, bb, txt, disp):\n"," text = ax.text(bb[0], (bb[1]-disp), txt, verticalalignment='top', color='white', fontsize=10, weight='bold')\n"," draw_outline(text)\n","\n","def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n"," for annotation in annotations_list:\n"," cat_id = annotation[\"category_id\"]\n"," bbox = annotation[\"bbox\"]\n"," draw_box(ax, bbox)\n"," draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n","\n","def visualize(dataset_folder, max_examples=None):\n"," with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","\n"," images = labels_json[\"images\"]\n"," cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n"," image_annots = defaultdict(list)\n","\n"," for annotation_obj in labels_json[\"annotations\"]:\n"," image_id = annotation_obj[\"image_id\"]\n"," image_annots[image_id].append(annotation_obj)\n","\n"," if max_examples is None:\n"," max_examples = len(image_annots.items())\n","\n"," n_rows = math.ceil(max_examples / 3)\n"," fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n","\n"," for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n"," ax = axs[ind//3, ind%3]\n"," img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n"," ax.imshow(img)\n"," draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n","\n"," plt.show()\n","\n","visualize(train_dataset_path, 9)"],"metadata":{"id":"8D17VhVAI33W","collapsed":true,"cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Create dataset"],"metadata":{"id":"fPyRkkKYJEOB"}},{"cell_type":"code","source":["# TODO: do I need this instruction ?\n","\n","cache_dirs = [\"/tmp/od_data/train\", \"/tmp/od_data/validation\"]\n","\n","for cache_dir in cache_dirs:\n"," if os.path.exists(cache_dir):\n"," shutil.rmtree(cache_dir)"],"metadata":{"id":"6MICrUKmFEsS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","\n","print(f\"{'Training Dataset Size:':<25} {train_data.size:>4}\")\n","print(f\"{'Validation Dataset Size:':<25} 
{validation_data.size:>4}\")"],"metadata":{"id":"cooLOJrmJDbo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Augmentation"],"metadata":{"id":"JGLtQeX3UG_s"}},{"cell_type":"markdown","source":["### Augment data"],"metadata":{"id":"77QQeKsFKCBx"}},{"cell_type":"code","source":["import albumentations as A\n","import numpy as np\n","import cv2"],"metadata":{"id":"7R9rK4PXUKsa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: check the aug data deletion boolean value\n","# if true -> ask to execute augmentation\n","# if false -> ask to delete aug data before and then to execute augmentation (update train_data_path)"],"metadata":{"id":"aOLGUAwkTGq2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_transform(set='train'):\n"," bboxes_params = A.BboxParams(format='coco', min_visibility=0.3, label_fields=['class_labels']) # TODO: check min_visibility\n","\n"," if set == 'train':\n"," transform = A.Compose([ # TODO: update pipeline (?)\n"," #A.RandomResizedCrop(height=640, width=640, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=1.0), # TODO: check h,w\n"," A.HorizontalFlip(p=0.5),\n"," A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),\n"," A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),\n"," A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),\n"," A.RandomShadow(num_shadows_lower=1, num_shadows_upper=3, shadow_dimension=5, shadow_roi=(0, 0.5, 1, 1), p=0.3),\n"," A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),\n"," A.OneOf([\n"," A.MotionBlur(blur_limit=7, p=0.5),\n"," A.MedianBlur(blur_limit=7, p=0.5),\n"," A.GaussianBlur(blur_limit=7, p=0.5),\n"," ], p=0.3),\n"," A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, border_mode=0, p=0.5),\n"," A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n"," ], bbox_params=bboxes_params)\n","\n"," elif set == 'validation':\n"," transform = A.Compose([ # TODO: update pipeline\n"," #A.Resize(height=640, width=640),\n"," #A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n"," ], bbox_params=bboxes_params)\n","\n"," return transform"],"metadata":{"id":"3ppVnY8EhUBm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, 'labels.json'), 'r') as f:\n"," train_json = json.load(f)\n","\n","with open(os.path.join(validation_dataset_path, 'labels.json'), 'r') as f:\n"," val_json = json.load(f)\n","\n","with open(os.path.join(test_dataset_path, 'labels.json'), 'r') as f:\n"," test_json = json.load(f)\n","\n","n_images = max(train_json['images'][-1]['id'], val_json['images'][-1]['id'], test_json['images'][-1]['id'])\n","n_annotations = max(train_json['annotations'][-1]['id'], val_json['annotations'][-1]['id'], test_json['annotations'][-1]['id'])"],"metadata":{"id":"4l6m63pOLFhV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def clip_bbox(bbox, image_width, image_height):\n"," x_min, y_min, width, height = bbox\n","\n"," x_min = max(0, min(x_min, image_width - 1)) # TODO: check -1\n"," y_min = max(0, min(y_min, image_height - 1)) # TODO: check -1\n"," width = min(width, image_width - x_min)\n"," height = min(height, image_height - y_min)\n","\n"," return [x_min, y_min, width, height]"],"metadata":{"id":"7BO5fOS6PnEB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def validate_bbox(bbox, image_width, image_height):\n"," x, y, w, h = bbox\n","\n"," return 0 <= x < image_width 
and 0 <= y < image_height and x + w <= image_width and y + h <= image_height"],"metadata":{"id":"KVDrkEmpXGD6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def apply_augmentation(image_path, bboxes, class_labels, output_path, output_filename, transform):\n"," # Read the image\n"," image = cv2.imread(image_path)\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n"," image_height, image_width = image.shape[:2]\n","\n"," # Apply the augmentation\n"," try:\n"," transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)\n"," except Exception as e:\n"," print(f\"Error during transformation: {e}\")\n"," return [], []\n","\n"," # Save the augmented image\n"," augmented_image_path = os.path.join(output_path, output_filename)\n"," cv2.imwrite(augmented_image_path, cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR))\n","\n"," return transformed['bboxes'], transformed['class_labels']"],"metadata":{"id":"eZWBxtDrhonL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def augment_dataset(input_path, output_path, transform, n_images, n_annotations, num_augmentations=5):\n"," # Load the original COCO JSON file\n"," with open(os.path.join(input_path, 'labels.json'), 'r') as f:\n"," coco_data = json.load(f)\n","\n"," new_images = []\n"," new_annotations = []\n","\n"," # Copy original images and annotations\n"," for img in tqdm(coco_data['images'], desc=\"Copying original images\"):\n","\n"," src_path = os.path.join(input_path, 'images', img['file_name'])\n"," dst_path = os.path.join(output_path, 'images', img['file_name'])\n"," shutil.copy2(src_path, dst_path)\n","\n"," new_images.append(img)\n"," img_anns = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']]\n"," new_annotations.extend(img_anns)\n","\n"," '''debug'''\n"," print(\"Before augmentation:\")\n"," print(f\"Number of images: {len(new_images)}\")\n"," print(f\"Number of annotations: {len(new_annotations)}\")\n","\n"," # Apply augmentations\n"," for img in tqdm(coco_data['images'], desc=\"Augmenting images\"):\n"," image_path = os.path.join(input_path, 'images', img['file_name'])\n","\n"," annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']]\n","\n"," image = cv2.imread(image_path)\n"," image_height, image_width = image.shape[:2]\n","\n"," for i in range(num_augmentations):\n"," bboxes = [ann['bbox'] for ann in annotations]\n"," class_labels = [ann['category_id'] for ann in annotations]\n","\n"," # TODO: should I call the function clip_bbox() regardless of the function validate_bbox()?\n"," for bbox in bboxes:\n"," if not validate_bbox(bbox, image_width, image_height):\n"," bboxes = [clip_bbox(bbox, image_width, image_height) for bbox in bboxes]\n","\n"," new_filename = f\"{os.path.splitext(img['file_name'])[0]}_aug_{i}.jpg\"\n","\n"," new_bboxes, new_class_labels = apply_augmentation(\n"," image_path, bboxes, class_labels,\n"," os.path.join(output_path, 'images'), new_filename, transform\n"," )\n","\n"," new_img_id = n_images + 1\n"," new_images.append({\n"," 'id': new_img_id,\n"," 'file_name': new_filename\n"," })\n","\n"," n_images = n_images + 1\n","\n"," for bbox, cat_id in zip(new_bboxes, new_class_labels):\n"," new_annotations.append({\n"," 'id': n_annotations + 1,\n"," 'image_id': new_img_id,\n"," 'category_id': cat_id,\n"," 'bbox': bbox\n"," })\n","\n"," n_annotations = n_annotations + 1\n","\n"," '''debug'''\n"," print(\"After augmentation:\")\n"," print(f\"Number of images: {len(new_images)}\")\n"," print(f\"Number 
of annotations: {len(new_annotations)}\")\n","\n"," # Create the new COCO JSON file\n"," new_coco_data = {\n"," 'categories': coco_data['categories'],\n"," 'images': new_images,\n"," 'annotations': new_annotations\n"," }\n","\n"," # Save the new COCO JSON file\n"," with open(os.path.join(output_path, 'labels.json'), 'w') as f:\n"," json.dump(new_coco_data, f, indent=4)"],"metadata":{"id":"qzU7uPm9iuny","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["augmented_train_dataset_path = dest_base_path + 'train_augmented/'\n","augmented_validation_dataset_path = dest_base_path + 'validation_augmented/'\n","\n","os.makedirs(os.path.join(augmented_train_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(augmented_validation_dataset_path, 'images'), exist_ok=True)\n","\n","if os.listdir(augmented_train_dataset_path + 'images/') == []:\n"," augment_dataset(train_dataset_path, augmented_train_dataset_path, get_transform('train'), n_images, n_annotations, num_augmentations=5)\n","else:\n"," print(\"Augmentation on the training set has already been made.\")\n","\n","if os.listdir(augmented_validation_dataset_path + 'images/') == []:\n"," augment_dataset(validation_dataset_path, augmented_validation_dataset_path, get_transform('validation'), n_images, n_annotations, num_augmentations=5)\n","else:\n"," print(\"Augmentation on the validation set has already been made.\")"],"metadata":{"id":"PEdCirIuvIbZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["count1 = sum(1 for filename in os.listdir(os.path.join(augmented_train_dataset_path, 'images')) if any(filename.lower().endswith(ext) for ext in ['.jpg', '.jpeg']))\n","count2 = sum(1 for filename in os.listdir(os.path.join(augmented_validation_dataset_path, 'images')) if any(filename.lower().endswith(ext) for ext in ['.jpg', '.jpeg']))\n","\n","print(f\"Number of images in the train_augmented folder: {count1}\") # TODO: make prettier\n","print(f\"Number of images in the validation_augmented folder: {count2}\") # TODO: make prettier"],"metadata":{"id":"ER3exrc3ciP5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the augmented training data [FIXME]\n","\n","visualize(augmented_train_dataset_path, 12)"],"metadata":{"collapsed":true,"id":"lhULs9Lxjgyt","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Rewrite train dataset"],"metadata":{"id":"zbBGHtdpJ98E"}},{"cell_type":"code","source":["# TODO: add if condition (if augmentation has been executed)"],"metadata":{"id":"HBithwlITduj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ? 
do I need to use other cache dirs?"],"metadata":{"id":"pUpntlCgYgUh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(augmented_train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","\n","print(f\"{'New Training Dataset Size:':<25} {train_data.size:>6} images\")\n","print(f\"{'New Validation Dataset Size:':<25} {validation_data.size:>4} images\")"],"metadata":{"id":"2A_QSKYdUs2U"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Retrain model"],"metadata":{"id":"IzhODm2CJaMB"}},{"cell_type":"markdown","source":["### Set retraining options"],"metadata":{"id":"jF7sZHYyJcl7"}},{"cell_type":"code","source":["spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384\n","\n","hparams = object_detector.HParams(\n"," learning_rate=0.015, # reduce to 0.01 (is it possible to implement a scheduler?)\n"," batch_size=64, # try 128, 256\n"," epochs=100,\n"," cosine_decay_epochs=100,\n"," cosine_decay_alpha=0.1,\n"," #shuffle=True, # TODO: check\n"," #repeat=True, # TODO: check\n"," export_dir='exported_model'\n",")\n","\n","model_options = object_detector.ModelOptions(\n"," l2_weight_decay=1e-4 # 3e-5\n",")\n","\n","options = object_detector.ObjectDetectorOptions(\n"," supported_model=spec,\n"," hparams=hparams,\n"," model_options=model_options\n",")"],"metadata":{"id":"iC6cVSpVJWFw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Run retraining"],"metadata":{"id":"bfQoIMPOJhK5"}},{"cell_type":"code","source":["model = object_detector.ObjectDetector.create(\n"," train_data=train_data,\n"," validation_data=validation_data,\n"," options=options\n",")"],"metadata":{"id":"BxT1UHQWJfAX","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Evaluate the model performance\n","\n","After training the model, evaluate it on validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision."],"metadata":{"id":"y3h_cOytJjXk"}},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=4) # TODO: update batch_size (?)\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"kJabrlEjJkmH","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export model\n","\n","After creating the model, convert and export it to a Tensorflow Lite model format for later use on an on-device application. 
The export also includes model metadata, which includes the label map."],"metadata":{"id":"Vkg5K-xWJmj6"}},{"cell_type":"code","source":["# TODO: do I need to remove the existing model first?"],"metadata":{"id":"eQ3BCV__LXaB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model.export_model()\n","!ls exported_model\n","files.download('exported_model/model.tflite')"],"metadata":{"id":"ADUK65tBJn7x","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model quantization"],"metadata":{"id":"6g6QkA6YJpf8"}},{"cell_type":"markdown","source":["Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n","\n","This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n","1. Quantization Aware Training: 8 bit integer precision for CPU usage\n","2. Post-Training Quantization: 16 bit floating point precision for GPU usage"],"metadata":{"id":"Jl1Dc9EhJtRo"}},{"cell_type":"markdown","source":["### Quantization aware training (int8 quantization)\n","Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n","\n","To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. See the **Hyperparameters** section below on detailed usage of `QATHParams`."],"metadata":{"id":"D88f-mqgJvwi"}},{"cell_type":"code","source":["qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"NbTDJtEJJr7Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model(After running the `create` method) in order to run QAT again."],"metadata":{"id":"5KEPghm9JyUi"}},{"cell_type":"code","source":["new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n","model.restore_float_ckpt()\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"vALj8IqaJ1B4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Finally, us the `export_model` to export to an int8 quantized model. 
The `export_model` function will automatically export to either float32 or int8 model depending on whether `quantization_aware_training` was run."],"metadata":{"id":"HdZAaTJKJ4ky"}},{"cell_type":"code","source":["model.export_model('model_int8_qat.tflite')\n","!ls -lh exported_model\n","files.download('exported_model/model_int8_qat.tflite')"],"metadata":{"id":"AePYkJHwJ2OG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Post-training quantization (fp16 quantization)\n","\n","Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is reccomended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n","\n","First, import the MediaPipe Model Maker quantization module:"],"metadata":{"id":"mm4gCymvJ7wK"}},{"cell_type":"code","source":["from mediapipe_model_maker import quantization"],"metadata":{"id":"mt8zeY52J8Gk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class."],"metadata":{"id":"r6eGgib4J-Ac"}},{"cell_type":"code","source":["quantization_config = quantization.QuantizationConfig.for_float16()"],"metadata":{"id":"IjAn-GBKJ_ab"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Export the model using the additional quantization_config object to apply post-training quantization. 
Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`."],"metadata":{"id":"ChqPtssVKCA8"}},{"cell_type":"code","source":["model.restore_float_ckpt()\n","model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n","!ls -lh exported_model\n","files.download('exported_model/model_fp16.tflite')"],"metadata":{"id":"f90WxGNNKCSW"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"last_runtime":{"build_target":"//learning/grp/tools/ml_python:ml_notebook","kind":"private"},"private_outputs":true,"provenance":[{"file_id":"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb","timestamp":1720270505355},{"file_id":"11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb","timestamp":1677706798050}],"gpuType":"L4","collapsed_sections":["GYUY-Sw33nJD","ThZpHZPKKK1O","Vkg5K-xWJmj6","6g6QkA6YJpf8"]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} \ No newline at end of file +{"cells":[{"cell_type":"markdown","source":["# Object detection model customization"],"metadata":{"id":"rQIYsEciHsMN"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"NReBxyCQH1Rx"}},{"cell_type":"markdown","source":["To install the libraries for customizing a model, run the following commands:"],"metadata":{"id":"xbuuejtPH97m"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install mediapipe-model-maker"],"metadata":{"id":"ceFO9GPSHyRv","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the following code to import the required Python classes:"],"metadata":{"id":"2beHFLLIIIv0"}},{"cell_type":"code","source":["from google.colab import files\n","import os\n","import json\n","from tqdm import tqdm\n","import tensorflow as tf\n","assert tf.__version__.startswith('2')\n","\n","from mediapipe_model_maker import object_detector\n","\n","from google.colab import drive\n","import shutil\n","\n","%matplotlib inline"],"metadata":{"id":"0OpjgSuBIJJC","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Colab Pro"],"metadata":{"id":"GYUY-Sw33nJD"}},{"cell_type":"code","source":["gpu_info = !nvidia-smi\n","gpu_info = '\\n'.join(gpu_info)\n","if gpu_info.find('failed') >= 0:\n"," print('Not connected to a GPU')\n","else:\n"," print(gpu_info)"],"metadata":{"id":"9bHLSYKW3pEt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from psutil import virtual_memory\n","ram_gb = virtual_memory().total / 1e9\n","print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n","\n","if ram_gb < 20:\n"," print('Not using a high-RAM runtime')\n","else:\n"," print('You are using a high-RAM runtime!')"],"metadata":{"id":"VvbzIdMd3s5N"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Set paths"],"metadata":{"id":"ThZpHZPKKK1O"}},{"cell_type":"code","source":["# Mount Google Drive\n","drive.mount('/content/drive')"],"metadata":{"id":"-Lrh4ORlDGgu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 
'validation/'\n","test_dataset_path = dest_base_path + 'test/'"],"metadata":{"id":"BlfRX-r_3YUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Create directories\n","\n","os.makedirs(dest_base_path, exist_ok=True)\n","\n","os.makedirs(train_dataset_path, exist_ok=True)\n","os.makedirs(validation_dataset_path, exist_ok=True)\n","os.makedirs(test_dataset_path, exist_ok=True)\n","\n","os.makedirs(os.path.join(train_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(validation_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(test_dataset_path, 'images'), exist_ok=True)"],"metadata":{"id":"4auASQMqjYRt"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Prepare data"],"metadata":{"id":"Yq6UT0j9ILD_"}},{"cell_type":"markdown","source":["### Copy images"],"metadata":{"id":"Ai5-oF_zjkxz"}},{"cell_type":"code","source":["# FIXME:\n","def check_and_delete_augmented_images(folder_path):\n"," augmented_images = [f for f in os.listdir(folder_path) if 'aug' in f]\n","\n"," if len(augmented_images) > 0:\n"," print(f\"Found {len(augmented_images)} augmented images in {folder_path}.\")\n"," user_input = input(\"Do you want to delete these images? (yes/no): \").strip().lower()\n","\n"," if user_input == 'yes':\n"," # delete all the elements in the image folder\n"," # delete the labels.json file\n"," for img in tqdm(augmented_images, desc=\"Deleting augmented images\"):\n"," img_path = os.path.join(folder_path, img)\n"," os.remove(img_path)\n"," print(\"Augmented images deleted successfully.\")\n"," else:\n"," print(\"Deletion aborted by user.\")\n"," else:\n"," print(\"No augmented images found.\")"],"metadata":{"id":"tttYb0xMSlir"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: mention that the train set contain augmented images if it's the case\n","# TODO: use also a boolean value to memorize the aug data deletion"],"metadata":{"id":"Jsgo4eA9R0xv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," lines = f.readlines()\n"," for line in tqdm(lines, desc=f\"Copying images to {dest_folder}\"):\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if os.listdir(train_dataset_path + 'images/') == [] and \\\n"," os.listdir(validation_dataset_path + 'images/') == [] and \\\n"," os.listdir(test_dataset_path + 'images/') == []:\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. 
Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"-A3ZtrofqptN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review data"],"metadata":{"id":"IUcrpj6aIzqt"}},{"cell_type":"markdown","source":["Verify the dataset content by printing the categories from the `labels.json` file. There should be 13 total categories. Index 0 is always set to be the `background` class which may be unused in the dataset."],"metadata":{"id":"xYXzxYC3I0h5"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"v6rYvXVwIwKp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the training data [FIXME]\n","import matplotlib.pyplot as plt\n","from matplotlib import patches, text, patheffects\n","from collections import defaultdict\n","import math\n","\n","# TODO: it may be interesting if it shows two (or n) random images per category\n","\n","def draw_outline(obj):\n"," obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n","\n","def draw_box(ax, bb):\n"," patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n"," draw_outline(patch)\n","\n","def draw_text(ax, bb, txt, disp):\n"," text = ax.text(bb[0], (bb[1]-disp), txt, verticalalignment='top', color='white', fontsize=10, weight='bold')\n"," draw_outline(text)\n","\n","def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n"," for annotation in annotations_list:\n"," cat_id = annotation[\"category_id\"]\n"," bbox = annotation[\"bbox\"]\n"," draw_box(ax, bbox)\n"," draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n","\n","def visualize(dataset_folder, max_examples=None):\n"," with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","\n"," images = labels_json[\"images\"]\n"," cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n"," image_annots = defaultdict(list)\n","\n"," for annotation_obj in labels_json[\"annotations\"]:\n"," image_id = annotation_obj[\"image_id\"]\n"," image_annots[image_id].append(annotation_obj)\n","\n"," if max_examples is None:\n"," max_examples = len(image_annots.items())\n","\n"," n_rows = math.ceil(max_examples / 3)\n"," fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n","\n"," for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n"," ax = axs[ind//3, ind%3]\n"," img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n"," ax.imshow(img)\n"," draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n","\n"," plt.show()\n","\n","visualize(train_dataset_path, 9)"],"metadata":{"id":"8D17VhVAI33W","collapsed":true,"cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Create dataset"],"metadata":{"id":"fPyRkkKYJEOB"}},{"cell_type":"code","source":["# TODO: do I need this instruction 
?\n","\n","cache_dirs = [\"/tmp/od_data/train\", \"/tmp/od_data/validation\"]\n","\n","for cache_dir in cache_dirs:\n"," if os.path.exists(cache_dir):\n"," shutil.rmtree(cache_dir)"],"metadata":{"id":"6MICrUKmFEsS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","\n","print(f\"{'Training Dataset Size:':<25} {train_data.size:>4}\")\n","print(f\"{'Validation Dataset Size:':<25} {validation_data.size:>4}\")"],"metadata":{"id":"cooLOJrmJDbo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Augmentation"],"metadata":{"id":"JGLtQeX3UG_s"}},{"cell_type":"markdown","source":["### Augment data"],"metadata":{"id":"77QQeKsFKCBx"}},{"cell_type":"code","source":["import albumentations as A\n","import numpy as np\n","import cv2"],"metadata":{"id":"7R9rK4PXUKsa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: check the aug data deletion boolean value\n","# if true -> ask to execute augmentation\n","# if false -> ask to delete aug data before and then to execute augmentation (update train_data_path)"],"metadata":{"id":"aOLGUAwkTGq2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_transform(set='train'):\n"," bboxes_params = A.BboxParams(format='coco', min_visibility=0.3, label_fields=['class_labels']) # TODO: check min_visibility\n","\n"," if set == 'train':\n"," transform = A.Compose([ # TODO: update pipeline (?)\n"," # TODO: do I need to resize images?\n"," #A.RandomResizedCrop(height=640, width=640, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=1.0), # TODO: check h,w\n"," A.HorizontalFlip(p=0.5),\n"," A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),\n"," A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),\n"," A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),\n"," A.RandomShadow(num_shadows_lower=1, num_shadows_upper=3, shadow_dimension=5, shadow_roi=(0, 0.5, 1, 1), p=0.3),\n"," A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),\n"," A.OneOf([\n"," A.MotionBlur(blur_limit=7, p=0.5),\n"," A.MedianBlur(blur_limit=7, p=0.5),\n"," A.GaussianBlur(blur_limit=7, p=0.5),\n"," ], p=0.3),\n"," A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, border_mode=0, p=0.5),\n"," A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # TODO: check\n"," ], bbox_params=bboxes_params)\n","\n"," elif set == 'validation':\n"," transform = A.Compose([ # TODO: update pipeline\n"," # TODO: do I need to resize images?\n"," A.HorizontalFlip(p=0.5),\n"," A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),\n"," A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), # TODO: check\n"," ], bbox_params=bboxes_params)\n","\n"," return transform"],"metadata":{"id":"3ppVnY8EhUBm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, 'labels.json'), 'r') as f:\n"," train_json = json.load(f)\n","\n","with open(os.path.join(validation_dataset_path, 'labels.json'), 'r') as f:\n"," val_json = json.load(f)\n","\n","with open(os.path.join(test_dataset_path, 'labels.json'), 'r') as f:\n"," test_json = json.load(f)\n","\n","n_images = max(train_json['images'][-1]['id'], val_json['images'][-1]['id'], test_json['images'][-1]['id'])\n","n_annotations = 
max(train_json['annotations'][-1]['id'], val_json['annotations'][-1]['id'], test_json['annotations'][-1]['id'])"],"metadata":{"id":"4l6m63pOLFhV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def clip_bbox(bbox, image_width, image_height):\n"," x_min, y_min, width, height = bbox\n","\n"," x_min = max(0, min(x_min, image_width - 1)) # TODO: check -1\n"," y_min = max(0, min(y_min, image_height - 1)) # TODO: check -1\n"," width = min(width, image_width - x_min)\n"," height = min(height, image_height - y_min)\n","\n"," return [x_min, y_min, width, height]"],"metadata":{"id":"7BO5fOS6PnEB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def validate_bbox(bbox, image_width, image_height):\n"," x, y, w, h = bbox\n","\n"," return 0 <= x < image_width and 0 <= y < image_height and x + w <= image_width and y + h <= image_height"],"metadata":{"id":"KVDrkEmpXGD6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def apply_augmentation(image_path, bboxes, class_labels, output_path, output_filename, transform):\n"," # Read the image\n"," image = cv2.imread(image_path)\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n"," image_height, image_width = image.shape[:2]\n","\n"," # Apply the augmentation\n"," try:\n"," transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)\n"," except Exception as e:\n"," print(f\"Error during transformation: {e}\")\n"," return [], []\n","\n"," # Save the augmented image\n"," augmented_image_path = os.path.join(output_path, output_filename)\n"," cv2.imwrite(augmented_image_path, cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR))\n","\n"," return transformed['bboxes'], transformed['class_labels']"],"metadata":{"id":"eZWBxtDrhonL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def augment_dataset(input_path, output_path, transform, n_images, n_annotations, num_augmentations=5):\n"," # Load the original COCO JSON file\n"," with open(os.path.join(input_path, 'labels.json'), 'r') as f:\n"," coco_data = json.load(f)\n","\n"," new_images = []\n"," new_annotations = []\n","\n"," # Copy original images and annotations\n"," for img in tqdm(coco_data['images'], desc=\"Copying original images\"):\n","\n"," src_path = os.path.join(input_path, 'images', img['file_name'])\n"," dst_path = os.path.join(output_path, 'images', img['file_name'])\n"," shutil.copy2(src_path, dst_path)\n","\n"," new_images.append(img)\n"," img_anns = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']]\n"," new_annotations.extend(img_anns)\n","\n"," '''debug'''\n"," print(\"Before augmentation:\")\n"," print(f\"Number of images: {len(new_images)}\")\n"," print(f\"Number of annotations: {len(new_annotations)}\")\n","\n"," # Apply augmentations\n"," for img in tqdm(coco_data['images'], desc=\"Augmenting images\"):\n"," image_path = os.path.join(input_path, 'images', img['file_name'])\n","\n"," annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']]\n","\n"," image = cv2.imread(image_path)\n"," image_height, image_width = image.shape[:2]\n","\n"," for i in range(num_augmentations):\n"," bboxes = [ann['bbox'] for ann in annotations]\n"," class_labels = [ann['category_id'] for ann in annotations]\n","\n"," # TODO: should I call the function clip_bbox() regardless of the function validate_bbox()?\n"," for bbox in bboxes:\n"," if not validate_bbox(bbox, image_width, image_height):\n"," bboxes = [clip_bbox(bbox, image_width, image_height) for bbox in 
bboxes]\n","\n"," new_filename = f\"{os.path.splitext(img['file_name'])[0]}_aug_{i}.jpg\"\n","\n"," new_bboxes, new_class_labels = apply_augmentation(\n"," image_path, bboxes, class_labels,\n"," os.path.join(output_path, 'images'), new_filename, transform\n"," )\n","\n"," new_img_id = n_images + 1\n"," new_images.append({\n"," 'id': new_img_id,\n"," 'file_name': new_filename\n"," })\n","\n"," n_images = n_images + 1\n","\n"," for bbox, cat_id in zip(new_bboxes, new_class_labels):\n"," new_annotations.append({\n"," 'id': n_annotations + 1,\n"," 'image_id': new_img_id,\n"," 'category_id': cat_id,\n"," 'bbox': bbox\n"," })\n","\n"," n_annotations = n_annotations + 1\n","\n"," '''debug'''\n"," print(\"After augmentation:\")\n"," print(f\"Number of images: {len(new_images)}\")\n"," print(f\"Number of annotations: {len(new_annotations)}\")\n","\n"," # Create the new COCO JSON file\n"," new_coco_data = {\n"," 'categories': coco_data['categories'],\n"," 'images': new_images,\n"," 'annotations': new_annotations\n"," }\n","\n"," # Save the new COCO JSON file\n"," with open(os.path.join(output_path, 'labels.json'), 'w') as f:\n"," json.dump(new_coco_data, f, indent=4)\n","\n"," return n_images, n_annotations"],"metadata":{"id":"qzU7uPm9iuny","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["augmented_train_dataset_path = dest_base_path + 'train_augmented/'\n","augmented_validation_dataset_path = dest_base_path + 'validation_augmented/'\n","\n","os.makedirs(os.path.join(augmented_train_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(augmented_validation_dataset_path, 'images'), exist_ok=True)\n","\n","if os.listdir(augmented_train_dataset_path + 'images/') == []:\n"," n_images, n_annotations = augment_dataset(train_dataset_path, augmented_train_dataset_path, get_transform('train'), n_images, n_annotations, num_augmentations=5)\n","else:\n"," print(\"Augmentation on the training set has already been made.\")\n","\n","if os.listdir(augmented_validation_dataset_path + 'images/') == []:\n"," augment_dataset(validation_dataset_path, augmented_validation_dataset_path, get_transform('validation'), n_images, n_annotations, num_augmentations=5)\n","else:\n"," print(\"Augmentation on the validation set has already been made.\")"],"metadata":{"id":"PEdCirIuvIbZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["count1 = sum(1 for filename in os.listdir(os.path.join(augmented_train_dataset_path, 'images')) if any(filename.lower().endswith(ext) for ext in ['.jpg', '.jpeg']))\n","count2 = sum(1 for filename in os.listdir(os.path.join(augmented_validation_dataset_path, 'images')) if any(filename.lower().endswith(ext) for ext in ['.jpg', '.jpeg']))\n","\n","print(f\"Number of images in the train_augmented folder: {count1}\") # TODO: make prettier\n","print(f\"Number of images in the validation_augmented folder: {count2}\") # TODO: make prettier"],"metadata":{"id":"ER3exrc3ciP5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the augmented training data [FIXME]\n","\n","visualize(augmented_train_dataset_path, 12)"],"metadata":{"collapsed":true,"id":"lhULs9Lxjgyt","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Rewrite train dataset"],"metadata":{"id":"zbBGHtdpJ98E"}},{"cell_type":"code","source":["# TODO: add if condition (if augmentation has been 
{"cell_type":"code","source":["shutil.rmtree(\"/tmp/od_data/augmented_train\", ignore_errors=True) # TODO: do I need this instruction?\n","shutil.rmtree(\"/tmp/od_data/augmented_validation\", ignore_errors=True) # TODO: do I need this instruction?"],"metadata":{"id":"pUpntlCgYgUh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(augmented_train_dataset_path, cache_dir=\"/tmp/od_data/augmented_train\")\n","validation_data = object_detector.Dataset.from_coco_folder(augmented_validation_dataset_path, cache_dir=\"/tmp/od_data/augmented_validation\")\n","\n","print(f\"{'New Training Dataset Size:':<25} {train_data.size:>6} images\")\n","print(f\"{'New Validation Dataset Size:':<25} {validation_data.size:>4} images\")"],"metadata":{"id":"2A_QSKYdUs2U"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Retrain model"],"metadata":{"id":"IzhODm2CJaMB"}},{"cell_type":"markdown","source":["### Set retraining options"],"metadata":{"id":"jF7sZHYyJcl7"}},{"cell_type":"code","source":["spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384\n","\n","hparams = object_detector.HParams(\n"," learning_rate=0.01, # 0.015 (is it possible to implement a scheduler?)\n"," batch_size=64, # try 128, 256\n"," epochs=100,\n"," cosine_decay_epochs=100,\n"," cosine_decay_alpha=0.1,\n"," shuffle=True, # TODO: check\n"," #repeat=True, # TODO: check\n"," export_dir='exported_model'\n",")\n","\n","model_options = object_detector.ModelOptions(\n"," l2_weight_decay=1e-4 # 3e-5\n",")\n","\n","options = object_detector.ObjectDetectorOptions(\n"," supported_model=spec,\n"," hparams=hparams,\n"," model_options=model_options\n",")"],"metadata":{"id":"iC6cVSpVJWFw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Run retraining"],"metadata":{"id":"bfQoIMPOJhK5"}},{"cell_type":"code","source":["model = object_detector.ObjectDetector.create(\n"," train_data=train_data,\n"," validation_data=validation_data,\n"," options=options\n",")"],"metadata":{"id":"BxT1UHQWJfAX","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Evaluate the model performance\n","\n","After training the model, evaluate it on the validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision."],"metadata":{"id":"y3h_cOytJjXk"}},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=32) # TODO: check batch_size\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"H9eJns4FsSXv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=64) # TODO: check batch_size\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"kJabrlEjJkmH","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export model\n","\n","After creating the model, convert and export it to a TensorFlow Lite model format for later use on an on-device application. 
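As a quick sanity check (a minimal sketch, assuming the `export_dir='exported_model'` set in `HParams` above), the exported file can be opened with the TFLite interpreter once the export cell below has run, to confirm that it loads and to inspect its input shape before downloading it:\n","\n","```python\n","# Hedged sketch: verify that the exported model loads (path assumed from export_dir above)\n","import tensorflow as tf\n","\n","interpreter = tf.lite.Interpreter(model_path='exported_model/model.tflite')\n","interpreter.allocate_tensors()\n","print(interpreter.get_input_details()[0]['shape'])\n","```\n","\n","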
The export also includes model metadata, which includes the label map."],"metadata":{"id":"Vkg5K-xWJmj6"}},{"cell_type":"code","source":["# TODO: do I need to remove the existing model first?"],"metadata":{"id":"eQ3BCV__LXaB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model.export_model()\n","!ls exported_model\n","files.download('exported_model/model.tflite')"],"metadata":{"id":"ADUK65tBJn7x","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model quantization"],"metadata":{"id":"6g6QkA6YJpf8"}},{"cell_type":"markdown","source":["Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n","\n","This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n","1. Quantization Aware Training: 8 bit integer precision for CPU usage\n","2. Post-Training Quantization: 16 bit floating point precision for GPU usage"],"metadata":{"id":"Jl1Dc9EhJtRo"}},{"cell_type":"markdown","source":["### Quantization aware training (int8 quantization)\n","Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n","\n","To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. See the **Hyperparameters** section below for detailed usage of `QATHParams`."],"metadata":{"id":"D88f-mqgJvwi"}},{"cell_type":"code","source":["qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"NbTDJtEJJr7Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model (after running the `create` method) in order to run QAT again."],"metadata":{"id":"5KEPghm9JyUi"}},{"cell_type":"code","source":["new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n","model.restore_float_ckpt()\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"vALj8IqaJ1B4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Finally, use the `export_model` method to export to an int8 quantized model. 
The `export_model` function will automatically export to either a float32 or an int8 model depending on whether `quantization_aware_training` was run."],"metadata":{"id":"HdZAaTJKJ4ky"}},{"cell_type":"code","source":["model.export_model('model_int8_qat.tflite')\n","!ls -lh exported_model\n","files.download('exported_model/model_int8_qat.tflite')"],"metadata":{"id":"AePYkJHwJ2OG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Post-training quantization (fp16 quantization)\n","\n","Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is recommended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n","\n","First, import the MediaPipe Model Maker quantization module:"],"metadata":{"id":"mm4gCymvJ7wK"}},{"cell_type":"code","source":["from mediapipe_model_maker import quantization"],"metadata":{"id":"mt8zeY52J8Gk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class."],"metadata":{"id":"r6eGgib4J-Ac"}},{"cell_type":"code","source":["quantization_config = quantization.QuantizationConfig.for_float16()"],"metadata":{"id":"IjAn-GBKJ_ab"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`."],"metadata":{"id":"ChqPtssVKCA8"}},{"cell_type":"code","source":["model.restore_float_ckpt()\n","model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n","!ls -lh exported_model\n","files.download('exported_model/model_fp16.tflite')"],"metadata":{"id":"f90WxGNNKCSW"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"last_runtime":{"build_target":"//learning/grp/tools/ml_python:ml_notebook","kind":"private"},"private_outputs":true,"provenance":[{"file_id":"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb","timestamp":1720270505355},{"file_id":"11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb","timestamp":1677706798050}],"gpuType":"L4"},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} \ No newline at end of file