diff --git a/README.md b/README.md index bcd8867..20f8611 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# Oxford5k-Paris6k-ObjectDetection +# Oxford5k-Paris6k-ObjectDetection [not finished yet] This project aims to create an object detection model for monument recognition using the Oxford5k and Paris6k datasets. The model is built using MediaPipe Model Maker for transfer learning, starting from a pre-trained model. diff --git a/scripts/create_annotations.py b/scripts/create_annotations.py index ef4bfac..3754cec 100644 --- a/scripts/create_annotations.py +++ b/scripts/create_annotations.py @@ -295,6 +295,39 @@ def convert_bbox(bbox): return [xmin, ymin, width, height] +def validate_bbox(bbox, image_width, image_height): + """ + Validates the bounding box to ensure it fits within the image dimensions. + + :param bbox: Bounding box in (xmin, ymin, width, height) format. + :param image_width: Width of the image. + :param image_height: Height of the image. + :return: True if the bounding box is valid, False otherwise. + """ + x, y, w, h = bbox + + return 0 <= x < image_width and 0 <= y < image_height and x + w <= image_width and y + h <= image_height + + +def clip_bbox(bbox, image_width, image_height): + """ + Clips the bounding box to fit within the image dimensions. + + :param bbox: Bounding box in (xmin, ymin, width, height) format. + :param image_width: Width of the image. + :param image_height: Height of the image. + :return: Clipped bounding box. + """ + x_min, y_min, width, height = bbox + + x_min = max(0, min(x_min, image_width - 1)) # TODO: check -1 + y_min = max(0, min(y_min, image_height - 1)) # TODO: check -1 + width = min(width, image_width - x_min) + height = min(height, image_height - y_min) + + return [x_min, y_min, width, height] + + def _process_data_xml(folder_name, data, image_folder, output_folder, monuments_list, levels=2): """ Processes the dataset to create annotations in XML format. @@ -462,6 +495,7 @@ def _process_data_json(data, image_folder, output_folder, monuments_list, levels width, height = img.size _bbox = gnd[idx]['bbx'] _bbox = convert_bbox(_bbox) # TODO: check + _bbox = clip_bbox(_bbox, width, height) # TODO: check bbox = BoundingBox(_bbox[0], _bbox[1], _bbox[2], _bbox[3]) monument = find_monument_by_query_number(idx, monuments_dict) query_images_objects[idx] = [] @@ -538,11 +572,12 @@ def _process_data_json(data, image_folder, output_folder, monuments_list, levels ymax_avg = round(ymax_avg / len(_objects[monument]), 1) _bbox = [xmin_avg, ymin_avg, xmax_avg, ymax_avg] _bbox = convert_bbox(_bbox) # TODO: check + _bbox = clip_bbox(_bbox, width, height) # TODO: check bbox = BoundingBox(_bbox[0], _bbox[1], _bbox[2], _bbox[3]) objects.append(Object(f"{monument}", "Unspecified", "0", str(difficulty), bbox)) other_images_objects[idx] = objects - i = len(annotations) # TODO: check + i = len(annotations) offset = len(qimlist) for idx in other_images_objects.keys(): for obj in other_images_objects[idx]: diff --git a/scripts/prepare_dataset.py b/scripts/prepare_dataset.py index b9ba1b3..b8ef57b 100644 --- a/scripts/prepare_dataset.py +++ b/scripts/prepare_dataset.py @@ -206,6 +206,8 @@ def prepare_dataset(dataset_name, type='xml', levels=1): print("Error: Invalid type of annotation") return + # FIXME: handle the two cases separately + # Check if annotations dir are already created if not os.path.exists(annotations_dir) or \ (os.path.exists(annotations_dir) and type == 'xml' and len(os.listdir(annotations_dir)) == 0) or \ diff --git a/training/mp_training_paris6k.ipynb b/training/mp_training_paris6k.ipynb index 4fd4cc2..b95ef82 100644 --- a/training/mp_training_paris6k.ipynb +++ b/training/mp_training_paris6k.ipynb @@ -1 +1 @@ -{"cells":[{"cell_type":"markdown","source":["# Object detection model customization"],"metadata":{"id":"rQIYsEciHsMN"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"NReBxyCQH1Rx"}},{"cell_type":"markdown","source":["To install the libraries for customizing a model, run the following commands:"],"metadata":{"id":"xbuuejtPH97m"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install mediapipe-model-maker"],"metadata":{"id":"ceFO9GPSHyRv","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the following code to import the required Python classes:"],"metadata":{"id":"2beHFLLIIIv0"}},{"cell_type":"code","source":["from google.colab import files\n","import os\n","import json\n","from tqdm import tqdm\n","import tensorflow as tf\n","assert tf.__version__.startswith('2')\n","\n","from mediapipe_model_maker import object_detector\n","\n","from google.colab import drive\n","import shutil\n","\n","%matplotlib inline"],"metadata":{"id":"0OpjgSuBIJJC","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Set paths"],"metadata":{"id":"ThZpHZPKKK1O"}},{"cell_type":"code","source":["# Mount Google Drive\n","drive.mount('/content/drive')"],"metadata":{"id":"-Lrh4ORlDGgu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 'validation/'\n","test_dataset_path = dest_base_path + 'test/'"],"metadata":{"id":"BlfRX-r_3YUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["os.makedirs(dest_base_path, exist_ok=True)\n","\n","os.makedirs(train_dataset_path, exist_ok=True)\n","os.makedirs(validation_dataset_path, exist_ok=True)\n","os.makedirs(test_dataset_path, exist_ok=True)\n","\n","os.makedirs(os.path.join(train_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(validation_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(test_dataset_path, 'images'), exist_ok=True)"],"metadata":{"id":"4auASQMqjYRt"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Prepare data"],"metadata":{"id":"Yq6UT0j9ILD_"}},{"cell_type":"markdown","source":["### Copy images"],"metadata":{"id":"Ai5-oF_zjkxz"}},{"cell_type":"code","source":["# FIXME: fix check_and_delete_augmented_images function"],"metadata":{"id":"VOf82qC3Sr3k"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Function to search for augmented data\n","def check_and_delete_augmented_images(folder_path):\n"," augmented_images = [f for f in os.listdir(folder_path) if 'aug' in f]\n","\n"," if len(augmented_images) > 0:\n"," print(f\"Found {len(augmented_images)} augmented images in {folder_path}.\")\n"," user_input = input(\"Do you want to delete these images? (yes/no): \").strip().lower()\n","\n"," if user_input == 'yes':\n"," # delete all the elements in the image folder\n"," # delete the labels.json file\n"," for img in tqdm(augmented_images, desc=\"Deleting augmented images\"):\n"," img_path = os.path.join(folder_path, img)\n"," os.remove(img_path)\n"," print(\"Augmented images deleted successfully.\")\n"," else:\n"," print(\"Deletion aborted by user.\")\n"," else:\n"," print(\"No augmented images found.\")\n","\n","# folder_path = augmented_train_dataset_path"],"metadata":{"id":"tttYb0xMSlir"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: mention that the train set contain augmented images if it's the case\n","# TODO: use also a boolean value to memorize the aug data deletion"],"metadata":{"id":"Jsgo4eA9R0xv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," lines = f.readlines()\n"," for line in tqdm(lines, desc=f\"Copying images to {dest_folder}\"):\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if (len(os.listdir(train_dataset_path + 'images/')) == 0 and\n"," len(os.listdir(validation_dataset_path + 'images/')) == 0 and\n"," len(os.listdir(test_dataset_path + 'images/')) == 0):\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"-A3ZtrofqptN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review data"],"metadata":{"id":"IUcrpj6aIzqt"}},{"cell_type":"markdown","source":["Verify the dataset content by printing the categories from the `labels.json` file. There should be 13 total categories. Index 0 is always set to be the `background` class which may be unused in the dataset."],"metadata":{"id":"xYXzxYC3I0h5"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"v6rYvXVwIwKp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the training data\n","import matplotlib.pyplot as plt\n","from matplotlib import patches, text, patheffects\n","from collections import defaultdict\n","import math\n","\n","# TODO: it may be interesting if it shows two random images per category\n","\n","def draw_outline(obj):\n"," obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n","\n","def draw_box(ax, bb):\n"," patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n"," draw_outline(patch)\n","\n","def draw_text(ax, bb, txt, disp):\n"," text = ax.text(bb[0], (bb[1]-disp), txt, verticalalignment='top', color='white', fontsize=10, weight='bold')\n"," draw_outline(text)\n","\n","def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n"," for annotation in annotations_list:\n"," cat_id = annotation[\"category_id\"]\n"," bbox = annotation[\"bbox\"]\n"," draw_box(ax, bbox)\n"," draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n","\n","def visualize(dataset_folder, max_examples=None):\n"," with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","\n"," images = labels_json[\"images\"]\n"," cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n"," image_annots = defaultdict(list)\n","\n"," for annotation_obj in labels_json[\"annotations\"]:\n"," image_id = annotation_obj[\"image_id\"]\n"," image_annots[image_id].append(annotation_obj)\n","\n"," if max_examples is None:\n"," max_examples = len(image_annots.items())\n","\n"," n_rows = math.ceil(max_examples / 3)\n"," fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n","\n"," for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n"," ax = axs[ind//3, ind%3]\n"," img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n"," ax.imshow(img)\n"," draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n","\n"," plt.show()\n","\n","visualize(train_dataset_path, 9)"],"metadata":{"id":"8D17VhVAI33W","collapsed":true,"cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Create dataset"],"metadata":{"id":"fPyRkkKYJEOB"}},{"cell_type":"code","source":["if os.path.exists(\"/tmp/od_data/train\"):\n"," shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ?"],"metadata":{"id":"6MICrUKmFEsS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: is it possible to add a progress bar?"],"metadata":{"id":"7f4Uke50b1X9"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","print(\"train_data size: \", train_data.size)\n","print(\"validation_data size: \", validation_data.size)"],"metadata":{"id":"cooLOJrmJDbo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Augmentation"],"metadata":{"id":"JGLtQeX3UG_s"}},{"cell_type":"markdown","source":["### Augment data"],"metadata":{"id":"77QQeKsFKCBx"}},{"cell_type":"code","source":["import albumentations as A\n","import numpy as np\n","import cv2"],"metadata":{"id":"7R9rK4PXUKsa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: check the aug data deletion boolean value\n","# if true -> ask to execute augmentation\n","# if false -> ask to delete aug data before and then to execute augmentation (update train_data_path)"],"metadata":{"id":"aOLGUAwkTGq2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Augmentation pipeline\n","bbox_params=A.BboxParams(format='coco', label_fields=['category_ids'])\n","\n","transform = A.Compose([\n"," A.RandomResizedCrop(height=384, width=384, scale=(0.5, 1.0), ratio=(0.8, 1.2), p=1.0), # scale=(0.8, 1.0), ratio=(0.9, 1.1)\n"," A.HorizontalFlip(p=0.5), #\n"," A.RandomBrightnessContrast(p=0.3), # p=0.2\n"," A.RGBShift(r_shift_limit=30, g_shift_limit=30, b_shift_limit=20, p=0.3), # r_shift_limit=20, g_shift_limit=20\n"," A.ShiftScaleRotate(shift_limit=0.2, scale_limit=0.2, rotate_limit=30, p=0.5), # shift_limit=0.1, scale_limit=0.1, rotate_limit=15\n"," A.RandomShadow(p=0.2), #\n"," A.CLAHE(p=0.3), #\n","], bbox_params=bbox_params) # min_area=1024 min_area=256, min_visibility=0.1"],"metadata":{"id":"3ppVnY8EhUBm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def normalize_bbox(bbox, image_width, image_height):\n"," \"\"\"\n"," Normalize bbox coordinates to [0, 1] range.\n"," \"\"\"\n"," x_min, y_min, w, h = bbox\n"," x_max = x_min + w\n"," y_max = y_min + h\n","\n"," x_min_norm = max(0, min(1, x_min / image_width))\n"," y_min_norm = max(0, min(1, y_min / image_height))\n"," x_max_norm = max(0, min(1, x_max / image_width))\n"," y_max_norm = max(0, min(1, y_max / image_height))\n","\n"," return [x_min_norm, y_min_norm, x_max_norm, y_max_norm]"],"metadata":{"id":"lQRFiNuZbEl8"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def denormalize_bbox(bbox_norm, image_width, image_height):\n"," \"\"\"\n"," Convert normalized bbox coordinates back to pixel coordinates.\n"," \"\"\"\n","\n"," x_min_norm, y_min_norm, x_max_norm, y_max_norm = bbox_norm\n","\n"," x_min = round(x_min_norm * image_width, 1)\n"," y_min = round(y_min_norm * image_height, 1)\n"," x_max = round(x_max_norm * image_width, 1)\n"," y_max = round(y_max_norm * image_height, 1)\n","\n"," w = round(x_max - x_min, 1)\n"," h = round(y_max - y_min, 1)\n","\n"," return [x_min, y_min, w, h]"],"metadata":{"id":"V38bRoNSueDk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def apply_augmentation(image_path, bboxes, category_ids, output_path, output_filename, transform):\n"," \"\"\"\n"," Apply augmentation to an image and save the results.\n","\n"," Parameters:\n"," - image_path: Path to the original image.\n"," - bboxes: Bounding boxes in the format [(x_min, y_min, x_max, y_max), ...].\n"," - category_ids: List of category IDs corresponding to each bounding box.\n"," - output_path: Directory where the augmented image will be saved.\n"," - output_filename: Filename for the augmented image.\n"," - transform: Albumentations transform to be applied.\n","\n"," Returns:\n"," - A tuple containing the transformed bounding boxes and category IDs.\n"," \"\"\"\n"," # Read the image\n"," image = cv2.imread(image_path)\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n"," image_height, image_width = image.shape[:2]\n","\n"," # Normalize bounding boxes\n"," normalized_bboxes = [normalize_bbox(bbox, image_width, image_height) for bbox in bboxes]\n","\n"," # Apply the augmentation\n"," try:\n"," transformed = transform(image=image, bboxes=normalized_bboxes, category_ids=category_ids) # FIXME: empty returning values\n"," except Exception as e:\n"," print(f\"Error during transformation: {e}\")\n"," return [], []\n","\n"," # Save the augmented image\n"," augmented_image_path = os.path.join(output_path, output_filename)\n"," #cv2.imwrite(augmented_image_path, cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR)) # TODO: uncomment\n","\n"," print(transformed.keys())\n","\n"," return transformed['bboxes'], transformed['category_ids']"],"metadata":{"id":"eZWBxtDrhonL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, 'labels.json'), 'r') as f:\n"," train_json = json.load(f)\n","\n","with open(os.path.join(validation_dataset_path, 'labels.json'), 'r') as f:\n"," val_json = json.load(f)\n","\n","with open(os.path.join(test_dataset_path, 'labels.json'), 'r') as f:\n"," test_json = json.load(f)\n","\n","n_images = max(train_json['images'][-1]['id'], val_json['images'][-1]['id'], test_json['images'][-1]['id'])\n","n_annotations = max(train_json['annotations'][-1]['id'], val_json['annotations'][-1]['id'], test_json['annotations'][-1]['id'])\n","\n","print(f\"Number of images: {n_images}\")\n","print(f\"Number of annotations: {n_annotations}\")"],"metadata":{"id":"4l6m63pOLFhV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def augment_dataset(input_path, output_path, transform, n_images, n_annotations, num_augmentations=5):\n"," \"\"\"\n"," Augment a dataset based on COCO format and save the augmented images and annotations.\n","\n"," Parameters:\n"," - input_path: Path to the directory containing the original dataset and 'labels.json'.\n"," - output_path: Path to the directory where the augmented dataset will be saved.\n"," - num_augmentations: Number of augmentations to apply per image.\n","\n"," The function reads the original COCO JSON file, applies specified augmentations to each image,\n"," and saves the augmented images and updated annotations in a new COCO JSON file.\n"," \"\"\"\n"," # Load the original COCO JSON file\n"," with open(os.path.join(input_path, 'labels.json'), 'r') as f:\n"," coco_data = json.load(f)\n","\n"," new_images = []\n"," new_annotations = []\n","\n"," # Copy original images and annotations\n"," for img in tqdm(coco_data['images'], desc=\"Copying original images\"):\n","\n"," # Copy original image\n"," src_path = os.path.join(input_path, 'images', img['file_name'])\n"," dst_path = os.path.join(output_path, 'images', img['file_name'])\n"," #shutil.copy2(src_path, dst_path) # TODO: uncomment\n","\n"," new_images.append(img)\n"," img_anns = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']]\n"," new_annotations.extend(img_anns)\n","\n"," # Apply augmentations\n"," for img in tqdm(coco_data['images'], desc=\"Augmenting images\"):\n"," image_path = os.path.join(input_path, 'images', img['file_name'])\n","\n"," # Find annotations for this image\n"," annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']]\n","\n"," # Get original image dimensions\n"," image = cv2.imread(image_path)\n"," image_height, image_width = image.shape[:2]\n","\n"," for i in range(num_augmentations):\n"," # Prepare data for augmentation\n"," bboxes = [ann['bbox'] for ann in annotations]\n"," category_ids = [ann['category_id'] for ann in annotations]\n","\n"," # Generate a new filename\n"," new_filename = f\"{os.path.splitext(img['file_name'])[0]}_aug_{i}.jpg\"\n","\n"," # Apply augmentation\n"," new_bboxes, new_category_ids = apply_augmentation( # FIXME: new_bboxes and new_category_ids are always empty\n"," image_path, bboxes, category_ids,\n"," os.path.join(output_path, 'images'), new_filename, transform\n"," )\n","\n"," new_bboxes = [denormalize_bbox(bbox, image_width, image_height) for bbox in new_bboxes]\n","\n"," # Create a new image entry\n"," new_img_id = n_images + 1\n"," new_images.append({\n"," 'id': new_img_id,\n"," 'file_name': new_filename\n"," })\n","\n"," n_images = n_images + 1\n","\n"," # Create new annotations\n"," for bbox, cat_id in zip(new_bboxes, new_category_ids):\n"," new_annotations.append({\n"," 'id': n_annotations + 1,\n"," 'image_id': new_img_id,\n"," 'category_id': cat_id,\n"," 'bbox': bbox\n"," })\n","\n"," n_annotations = n_annotations + 1\n","\n"," # Create the new COCO JSON file\n"," new_coco_data = {\n"," 'categories': coco_data['categories'],\n"," 'images': new_images,\n"," 'annotations': new_annotations\n"," }\n","\n"," ''' # TODO: uncomment\n"," # Save the new COCO JSON file\n"," with open(os.path.join(output_path, 'labels.json'), 'w') as f:\n"," json.dump(new_coco_data, f, indent=4)\n"," '''\n","\n","\n","augmented_train_dataset_path = dest_base_path + 'train_augmented/'\n","\n","os.makedirs(os.path.join(augmented_train_dataset_path, 'images'), exist_ok=True)\n","\n","if (len(os.listdir(augmented_train_dataset_path + 'images/'))) == 0:\n"," augment_dataset(train_dataset_path, augmented_train_dataset_path, transform, n_images, n_annotations, num_augmentations=5)\n","else:\n"," print(\"Augmentation has already been made.\")"],"metadata":{"id":"qzU7uPm9iuny","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["extensions=['.jpg', '.jpeg']\n","count = 0\n","count = sum(1 for filename in os.listdir(os.path.join(augmented_train_dataset_path, 'images')) if any(filename.lower().endswith(ext) for ext in extensions))\n","\n","print(f\"Number of images in the train_augmented folder: {count}\")"],"metadata":{"id":"ER3exrc3ciP5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the augmented training data\n","\n","visualize(augmented_train_dataset_path, 12) # FIXME"],"metadata":{"collapsed":true,"id":"lhULs9Lxjgyt"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Rewrite train dataset"],"metadata":{"id":"zbBGHtdpJ98E"}},{"cell_type":"code","source":["# TODO: add if condition (if augmentation has been executed)"],"metadata":{"id":"HBithwlITduj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ?"],"metadata":{"id":"pUpntlCgYgUh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(augmented_train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","print(\"Updated train_data size: \", train_data.size)"],"metadata":{"id":"2A_QSKYdUs2U"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Retrain model"],"metadata":{"id":"IzhODm2CJaMB"}},{"cell_type":"markdown","source":["### Set retraining options"],"metadata":{"id":"jF7sZHYyJcl7"}},{"cell_type":"code","source":["spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384\n","\n","hparams = object_detector.HParams(\n"," learning_rate=0.015,\n"," batch_size=64,\n"," epochs=100,\n"," cosine_decay_epochs=100,\n"," cosine_decay_alpha=0.1,\n"," export_dir='exported_model'\n",")\n","\n","model_options = object_detector.ModelOptions(\n"," l2_weight_decay=1e-4\n",")\n","\n","options = object_detector.ObjectDetectorOptions(\n"," supported_model=spec,\n"," hparams=hparams,\n"," model_options=model_options\n",")"],"metadata":{"id":"iC6cVSpVJWFw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Run retraining"],"metadata":{"id":"bfQoIMPOJhK5"}},{"cell_type":"code","source":["model = object_detector.ObjectDetector.create(\n"," train_data=train_data,\n"," validation_data=validation_data,\n"," options=options\n",")"],"metadata":{"id":"BxT1UHQWJfAX","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Evaluate the model performance\n","\n","After training the model, evaluate it on validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision."],"metadata":{"id":"y3h_cOytJjXk"}},{"cell_type":"code","source":["# TODO: is it possible to add a progress bar?"],"metadata":{"id":"WaUndk7ZZgbk"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=4)\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"kJabrlEjJkmH","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export model\n","\n","After creating the model, convert and export it to a Tensorflow Lite model format for later use on an on-device application. The export also includes model metadata, which includes the label map."],"metadata":{"id":"Vkg5K-xWJmj6"}},{"cell_type":"code","source":["# TODO: do I need to remove the existing model first?"],"metadata":{"id":"eQ3BCV__LXaB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model.export_model()\n","!ls exported_model\n","files.download('exported_model/model.tflite')"],"metadata":{"id":"ADUK65tBJn7x","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model quantization"],"metadata":{"id":"6g6QkA6YJpf8"}},{"cell_type":"markdown","source":["Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n","\n","This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n","1. Quantization Aware Training: 8 bit integer precision for CPU usage\n","2. Post-Training Quantization: 16 bit floating point precision for GPU usage"],"metadata":{"id":"Jl1Dc9EhJtRo"}},{"cell_type":"markdown","source":["### Quantization aware training (int8 quantization)\n","Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n","\n","To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. See the **Hyperparameters** section below on detailed usage of `QATHParams`."],"metadata":{"id":"D88f-mqgJvwi"}},{"cell_type":"code","source":["qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"NbTDJtEJJr7Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model(After running the `create` method) in order to run QAT again."],"metadata":{"id":"5KEPghm9JyUi"}},{"cell_type":"code","source":["new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n","model.restore_float_ckpt()\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"vALj8IqaJ1B4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Finally, us the `export_model` to export to an int8 quantized model. The `export_model` function will automatically export to either float32 or int8 model depending on whether `quantization_aware_training` was run."],"metadata":{"id":"HdZAaTJKJ4ky"}},{"cell_type":"code","source":["model.export_model('model_int8_qat.tflite')\n","!ls -lh exported_model\n","files.download('exported_model/model_int8_qat.tflite')"],"metadata":{"id":"AePYkJHwJ2OG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Post-training quantization (fp16 quantization)\n","\n","Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is reccomended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n","\n","First, import the MediaPipe Model Maker quantization module:"],"metadata":{"id":"mm4gCymvJ7wK"}},{"cell_type":"code","source":["from mediapipe_model_maker import quantization"],"metadata":{"id":"mt8zeY52J8Gk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class."],"metadata":{"id":"r6eGgib4J-Ac"}},{"cell_type":"code","source":["quantization_config = quantization.QuantizationConfig.for_float16()"],"metadata":{"id":"IjAn-GBKJ_ab"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`."],"metadata":{"id":"ChqPtssVKCA8"}},{"cell_type":"code","source":["model.restore_float_ckpt()\n","model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n","!ls -lh exported_model\n","files.download('exported_model/model_fp16.tflite')"],"metadata":{"id":"f90WxGNNKCSW"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"last_runtime":{"build_target":"//learning/grp/tools/ml_python:ml_notebook","kind":"private"},"private_outputs":true,"provenance":[{"file_id":"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb","timestamp":1720270505355},{"file_id":"11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb","timestamp":1677706798050}],"gpuType":"T4","collapsed_sections":["zbBGHtdpJ98E","bfQoIMPOJhK5","y3h_cOytJjXk","Vkg5K-xWJmj6","6g6QkA6YJpf8"]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} \ No newline at end of file +{"cells":[{"cell_type":"markdown","source":["# Object detection model customization"],"metadata":{"id":"rQIYsEciHsMN"}},{"cell_type":"markdown","source":["## Setup"],"metadata":{"id":"NReBxyCQH1Rx"}},{"cell_type":"markdown","source":["To install the libraries for customizing a model, run the following commands:"],"metadata":{"id":"xbuuejtPH97m"}},{"cell_type":"code","source":["!python --version\n","!pip install --upgrade pip\n","!pip install mediapipe-model-maker"],"metadata":{"id":"ceFO9GPSHyRv","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Use the following code to import the required Python classes:"],"metadata":{"id":"2beHFLLIIIv0"}},{"cell_type":"code","source":["from google.colab import files\n","import os\n","import json\n","from tqdm import tqdm\n","import tensorflow as tf\n","assert tf.__version__.startswith('2')\n","\n","from mediapipe_model_maker import object_detector\n","\n","from google.colab import drive\n","import shutil\n","\n","%matplotlib inline"],"metadata":{"id":"0OpjgSuBIJJC","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Colab Pro"],"metadata":{"id":"GYUY-Sw33nJD"}},{"cell_type":"code","source":["gpu_info = !nvidia-smi\n","gpu_info = '\\n'.join(gpu_info)\n","if gpu_info.find('failed') >= 0:\n"," print('Not connected to a GPU')\n","else:\n"," print(gpu_info)"],"metadata":{"id":"9bHLSYKW3pEt"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["from psutil import virtual_memory\n","ram_gb = virtual_memory().total / 1e9\n","print('Your runtime has {:.1f} gigabytes of available RAM\\n'.format(ram_gb))\n","\n","if ram_gb < 20:\n"," print('Not using a high-RAM runtime')\n","else:\n"," print('You are using a high-RAM runtime!')"],"metadata":{"id":"VvbzIdMd3s5N"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Set paths"],"metadata":{"id":"ThZpHZPKKK1O"}},{"cell_type":"code","source":["# Mount Google Drive\n","drive.mount('/content/drive')"],"metadata":{"id":"-Lrh4ORlDGgu"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["base_path = '/content/drive/MyDrive/'\n","source_path = base_path + 'Datasets/revisitop/rparis6k/data/'\n","\n","dest_base_path = base_path + 'MyProject/rparis6k/'\n","\n","train_dataset_path = dest_base_path + 'train/'\n","validation_dataset_path = dest_base_path + 'validation/'\n","test_dataset_path = dest_base_path + 'test/'"],"metadata":{"id":"BlfRX-r_3YUc"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Create directories\n","\n","os.makedirs(dest_base_path, exist_ok=True)\n","\n","os.makedirs(train_dataset_path, exist_ok=True)\n","os.makedirs(validation_dataset_path, exist_ok=True)\n","os.makedirs(test_dataset_path, exist_ok=True)\n","\n","os.makedirs(os.path.join(train_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(validation_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(test_dataset_path, 'images'), exist_ok=True)"],"metadata":{"id":"4auASQMqjYRt","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Prepare data"],"metadata":{"id":"Yq6UT0j9ILD_"}},{"cell_type":"markdown","source":["### Copy images"],"metadata":{"id":"Ai5-oF_zjkxz"}},{"cell_type":"code","source":["# FIXME:\n","def check_and_delete_augmented_images(folder_path):\n"," augmented_images = [f for f in os.listdir(folder_path) if 'aug' in f]\n","\n"," if len(augmented_images) > 0:\n"," print(f\"Found {len(augmented_images)} augmented images in {folder_path}.\")\n"," user_input = input(\"Do you want to delete these images? (yes/no): \").strip().lower()\n","\n"," if user_input == 'yes':\n"," # delete all the elements in the image folder\n"," # delete the labels.json file\n"," for img in tqdm(augmented_images, desc=\"Deleting augmented images\"):\n"," img_path = os.path.join(folder_path, img)\n"," os.remove(img_path)\n"," print(\"Augmented images deleted successfully.\")\n"," else:\n"," print(\"Deletion aborted by user.\")\n"," else:\n"," print(\"No augmented images found.\")"],"metadata":{"id":"tttYb0xMSlir"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: mention that the train set contain augmented images if it's the case\n","# TODO: use also a boolean value to memorize the aug data deletion"],"metadata":{"id":"Jsgo4eA9R0xv"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# Function to copy images\n","def copy_images(file_list, dest_folder):\n"," with open(file_list, 'r') as f:\n"," lines = f.readlines()\n"," for line in tqdm(lines, desc=f\"Copying images to {dest_folder}\"):\n"," img_name = line.strip()\n"," src = os.path.join(source_path, img_name)\n"," dst = os.path.join(dest_folder, img_name)\n"," os.makedirs(os.path.dirname(dst), exist_ok=True)\n"," shutil.copy2(src, dst)\n","\n","# Copy images for each set\n","if os.listdir(train_dataset_path + 'images/') == [] and\n"," os.listdir(validation_dataset_path + 'images/') == [] and\n"," os.listdir(test_dataset_path + 'images/') == []:\n"," copy_images(dest_base_path + 'train.txt', train_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'val.txt', validation_dataset_path + 'images/')\n"," copy_images(dest_base_path + 'test.txt', test_dataset_path + 'images/')\n"," print(\"Dataset division completed!\\n\")\n","else:\n"," print(\"One or more directories are not empty. Copy operation aborted.\\n\")\n","\n","print(f\"Number of images in train set: {len(os.listdir(train_dataset_path + 'images/'))}\")\n","print(f\"Number of images in validation set: {len(os.listdir(validation_dataset_path + 'images/'))}\")\n","print(f\"Number of images in test set: {len(os.listdir(test_dataset_path + 'images/'))}\")"],"metadata":{"id":"-A3ZtrofqptN"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Review data"],"metadata":{"id":"IUcrpj6aIzqt"}},{"cell_type":"markdown","source":["Verify the dataset content by printing the categories from the `labels.json` file. There should be 13 total categories. Index 0 is always set to be the `background` class which may be unused in the dataset."],"metadata":{"id":"xYXzxYC3I0h5"}},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","for category_item in labels_json[\"categories\"]:\n"," print(f\"{category_item['id']}: {category_item['name']}\")"],"metadata":{"id":"v6rYvXVwIwKp"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the training data [FIXME]\n","import matplotlib.pyplot as plt\n","from matplotlib import patches, text, patheffects\n","from collections import defaultdict\n","import math\n","\n","# TODO: it may be interesting if it shows two (or n) random images per category\n","\n","def draw_outline(obj):\n"," obj.set_path_effects([patheffects.Stroke(linewidth=4, foreground='black'), patheffects.Normal()])\n","\n","def draw_box(ax, bb):\n"," patch = ax.add_patch(patches.Rectangle((bb[0],bb[1]), bb[2], bb[3], fill=False, edgecolor='red', lw=2))\n"," draw_outline(patch)\n","\n","def draw_text(ax, bb, txt, disp):\n"," text = ax.text(bb[0], (bb[1]-disp), txt, verticalalignment='top', color='white', fontsize=10, weight='bold')\n"," draw_outline(text)\n","\n","def draw_bbox(ax, annotations_list, id_to_label, image_shape):\n"," for annotation in annotations_list:\n"," cat_id = annotation[\"category_id\"]\n"," bbox = annotation[\"bbox\"]\n"," draw_box(ax, bbox)\n"," draw_text(ax, bbox, id_to_label[cat_id], image_shape[0] * 0.05)\n","\n","def visualize(dataset_folder, max_examples=None):\n"," with open(os.path.join(dataset_folder, \"labels.json\"), \"r\") as f:\n"," labels_json = json.load(f)\n","\n"," images = labels_json[\"images\"]\n"," cat_id_to_label = {item[\"id\"]:item[\"name\"] for item in labels_json[\"categories\"]}\n"," image_annots = defaultdict(list)\n","\n"," for annotation_obj in labels_json[\"annotations\"]:\n"," image_id = annotation_obj[\"image_id\"]\n"," image_annots[image_id].append(annotation_obj)\n","\n"," if max_examples is None:\n"," max_examples = len(image_annots.items())\n","\n"," n_rows = math.ceil(max_examples / 3)\n"," fig, axs = plt.subplots(n_rows, 3, figsize=(24, n_rows*8)) # 3 columns(2nd index), 8x8 for each image\n","\n"," for ind, (image_id, annotations_list) in enumerate(list(image_annots.items())[:max_examples]):\n"," ax = axs[ind//3, ind%3]\n"," img = plt.imread(os.path.join(dataset_folder, \"images\", images[image_id][\"file_name\"]))\n"," ax.imshow(img)\n"," draw_bbox(ax, annotations_list, cat_id_to_label, img.shape)\n","\n"," plt.show()\n","\n","visualize(train_dataset_path, 9)"],"metadata":{"id":"8D17VhVAI33W","collapsed":true,"cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Create dataset"],"metadata":{"id":"fPyRkkKYJEOB"}},{"cell_type":"code","source":["# TODO: do I need this instruction ?\n","\n","cache_dirs = [\"/tmp/od_data/train\", \"/tmp/od_data/validation\"]\n","\n","for cache_dir in cache_dirs:\n"," if os.path.exists(cache_dir):\n"," shutil.rmtree(cache_dir)"],"metadata":{"id":"6MICrUKmFEsS"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","\n","print(f\"{'Training Dataset Size:':<25} {train_data.size:>4}\")\n","print(f\"{'Validation Dataset Size:':<25} {validation_data.size:>4}\")"],"metadata":{"id":"cooLOJrmJDbo"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Augmentation"],"metadata":{"id":"JGLtQeX3UG_s"}},{"cell_type":"markdown","source":["### Augment data"],"metadata":{"id":"77QQeKsFKCBx"}},{"cell_type":"code","source":["import albumentations as A\n","import numpy as np\n","import cv2"],"metadata":{"id":"7R9rK4PXUKsa"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["# TODO: check the aug data deletion boolean value\n","# if true -> ask to execute augmentation\n","# if false -> ask to delete aug data before and then to execute augmentation (update train_data_path)"],"metadata":{"id":"aOLGUAwkTGq2"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def get_transform(set='train'):\n"," bboxes_params = A.BboxParams(format='coco', min_visibility=0.3, label_fields=['class_labels']) # TODO: check min_visibility\n","\n"," if set == 'train':\n"," transform = A.Compose([ # TODO: update pipeline (?)\n"," #A.RandomResizedCrop(height=640, width=640, scale=(0.8, 1.0), ratio=(0.9, 1.1), p=1.0), # TODO: check h,w\n"," A.HorizontalFlip(p=0.5),\n"," A.RandomBrightnessContrast(brightness_limit=0.2, contrast_limit=0.2, p=0.5),\n"," A.HueSaturationValue(hue_shift_limit=20, sat_shift_limit=30, val_shift_limit=20, p=0.5),\n"," A.GaussNoise(var_limit=(10.0, 50.0), p=0.5),\n"," A.RandomShadow(num_shadows_lower=1, num_shadows_upper=3, shadow_dimension=5, shadow_roi=(0, 0.5, 1, 1), p=0.3),\n"," A.CLAHE(clip_limit=4.0, tile_grid_size=(8, 8), p=0.5),\n"," A.OneOf([\n"," A.MotionBlur(blur_limit=7, p=0.5),\n"," A.MedianBlur(blur_limit=7, p=0.5),\n"," A.GaussianBlur(blur_limit=7, p=0.5),\n"," ], p=0.3),\n"," A.ShiftScaleRotate(shift_limit=0.1, scale_limit=0.2, rotate_limit=15, border_mode=0, p=0.5),\n"," A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n"," ], bbox_params=bboxes_params)\n","\n"," elif set == 'validation':\n"," transform = A.Compose([ # TODO: update pipeline\n"," #A.Resize(height=640, width=640),\n"," #A.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),\n"," ], bbox_params=bboxes_params)\n","\n"," return transform"],"metadata":{"id":"3ppVnY8EhUBm"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["with open(os.path.join(train_dataset_path, 'labels.json'), 'r') as f:\n"," train_json = json.load(f)\n","\n","with open(os.path.join(validation_dataset_path, 'labels.json'), 'r') as f:\n"," val_json = json.load(f)\n","\n","with open(os.path.join(test_dataset_path, 'labels.json'), 'r') as f:\n"," test_json = json.load(f)\n","\n","n_images = max(train_json['images'][-1]['id'], val_json['images'][-1]['id'], test_json['images'][-1]['id'])\n","n_annotations = max(train_json['annotations'][-1]['id'], val_json['annotations'][-1]['id'], test_json['annotations'][-1]['id'])"],"metadata":{"id":"4l6m63pOLFhV"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def clip_bbox(bbox, image_width, image_height):\n"," x_min, y_min, width, height = bbox\n","\n"," x_min = max(0, min(x_min, image_width - 1)) # TODO: check -1\n"," y_min = max(0, min(y_min, image_height - 1)) # TODO: check -1\n"," width = min(width, image_width - x_min)\n"," height = min(height, image_height - y_min)\n","\n"," return [x_min, y_min, width, height]"],"metadata":{"id":"7BO5fOS6PnEB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def validate_bbox(bbox, image_width, image_height):\n"," x, y, w, h = bbox\n","\n"," return 0 <= x < image_width and 0 <= y < image_height and x + w <= image_width and y + h <= image_height"],"metadata":{"id":"KVDrkEmpXGD6"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def apply_augmentation(image_path, bboxes, class_labels, output_path, output_filename, transform):\n"," # Read the image\n"," image = cv2.imread(image_path)\n"," image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)\n"," image_height, image_width = image.shape[:2]\n","\n"," # Apply the augmentation\n"," try:\n"," transformed = transform(image=image, bboxes=bboxes, class_labels=class_labels)\n"," except Exception as e:\n"," print(f\"Error during transformation: {e}\")\n"," return [], []\n","\n"," # Save the augmented image\n"," augmented_image_path = os.path.join(output_path, output_filename)\n"," cv2.imwrite(augmented_image_path, cv2.cvtColor(transformed['image'], cv2.COLOR_RGB2BGR))\n","\n"," return transformed['bboxes'], transformed['class_labels']"],"metadata":{"id":"eZWBxtDrhonL"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["def augment_dataset(input_path, output_path, transform, n_images, n_annotations, num_augmentations=5):\n"," # Load the original COCO JSON file\n"," with open(os.path.join(input_path, 'labels.json'), 'r') as f:\n"," coco_data = json.load(f)\n","\n"," new_images = []\n"," new_annotations = []\n","\n"," # Copy original images and annotations\n"," for img in tqdm(coco_data['images'], desc=\"Copying original images\"):\n","\n"," src_path = os.path.join(input_path, 'images', img['file_name'])\n"," dst_path = os.path.join(output_path, 'images', img['file_name'])\n"," shutil.copy2(src_path, dst_path)\n","\n"," new_images.append(img)\n"," img_anns = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']]\n"," new_annotations.extend(img_anns)\n","\n"," '''debug'''\n"," print(\"Before augmentation:\")\n"," print(f\"Number of images: {len(new_images)}\")\n"," print(f\"Number of annotations: {len(new_annotations)}\")\n","\n"," # Apply augmentations\n"," for img in tqdm(coco_data['images'], desc=\"Augmenting images\"):\n"," image_path = os.path.join(input_path, 'images', img['file_name'])\n","\n"," annotations = [ann for ann in coco_data['annotations'] if ann['image_id'] == img['id']]\n","\n"," image = cv2.imread(image_path)\n"," image_height, image_width = image.shape[:2]\n","\n"," for i in range(num_augmentations):\n"," bboxes = [ann['bbox'] for ann in annotations]\n"," class_labels = [ann['category_id'] for ann in annotations]\n","\n"," # TODO: should I call the function clip_bbox() regardless of the function validate_bbox()?\n"," for bbox in bboxes:\n"," if not validate_bbox(bbox, image_width, image_height):\n"," bboxes = [clip_bbox(bbox, image_width, image_height) for bbox in bboxes]\n","\n"," new_filename = f\"{os.path.splitext(img['file_name'])[0]}_aug_{i}.jpg\"\n","\n"," new_bboxes, new_class_labels = apply_augmentation(\n"," image_path, bboxes, class_labels,\n"," os.path.join(output_path, 'images'), new_filename, transform\n"," )\n","\n"," new_img_id = n_images + 1\n"," new_images.append({\n"," 'id': new_img_id,\n"," 'file_name': new_filename\n"," })\n","\n"," n_images = n_images + 1\n","\n"," for bbox, cat_id in zip(new_bboxes, new_class_labels):\n"," new_annotations.append({\n"," 'id': n_annotations + 1,\n"," 'image_id': new_img_id,\n"," 'category_id': cat_id,\n"," 'bbox': bbox\n"," })\n","\n"," n_annotations = n_annotations + 1\n","\n"," '''debug'''\n"," print(\"After augmentation:\")\n"," print(f\"Number of images: {len(new_images)}\")\n"," print(f\"Number of annotations: {len(new_annotations)}\")\n","\n"," # Create the new COCO JSON file\n"," new_coco_data = {\n"," 'categories': coco_data['categories'],\n"," 'images': new_images,\n"," 'annotations': new_annotations\n"," }\n","\n"," # Save the new COCO JSON file\n"," with open(os.path.join(output_path, 'labels.json'), 'w') as f:\n"," json.dump(new_coco_data, f, indent=4)"],"metadata":{"id":"qzU7uPm9iuny","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["augmented_train_dataset_path = dest_base_path + 'train_augmented/'\n","augmented_validation_dataset_path = dest_base_path + 'validation_augmented/'\n","\n","os.makedirs(os.path.join(augmented_train_dataset_path, 'images'), exist_ok=True)\n","os.makedirs(os.path.join(augmented_validation_dataset_path, 'images'), exist_ok=True)\n","\n","if os.listdir(augmented_train_dataset_path + 'images/') == []:\n"," augment_dataset(train_dataset_path, augmented_train_dataset_path, get_transform('train'), n_images, n_annotations, num_augmentations=5)\n","else:\n"," print(\"Augmentation on the training set has already been made.\")\n","\n","if os.listdir(augmented_validation_dataset_path + 'images/') == []:\n"," augment_dataset(validation_dataset_path, augmented_validation_dataset_path, get_transform('validation'), n_images, n_annotations, num_augmentations=5)\n","else:\n"," print(\"Augmentation on the validation set has already been made.\")"],"metadata":{"id":"PEdCirIuvIbZ"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["count1 = sum(1 for filename in os.listdir(os.path.join(augmented_train_dataset_path, 'images')) if any(filename.lower().endswith(ext) for ext in ['.jpg', '.jpeg']))\n","count2 = sum(1 for filename in os.listdir(os.path.join(augmented_validation_dataset_path, 'images')) if any(filename.lower().endswith(ext) for ext in ['.jpg', '.jpeg']))\n","\n","print(f\"Number of images in the train_augmented folder: {count1}\") # TODO: make prettier\n","print(f\"Number of images in the validation_augmented folder: {count2}\") # TODO: make prettier"],"metadata":{"id":"ER3exrc3ciP5"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["#@title Visualize the augmented training data [FIXME]\n","\n","visualize(augmented_train_dataset_path, 12)"],"metadata":{"collapsed":true,"id":"lhULs9Lxjgyt","cellView":"form"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Rewrite train dataset"],"metadata":{"id":"zbBGHtdpJ98E"}},{"cell_type":"code","source":["# TODO: add if condition (if augmentation has been executed)"],"metadata":{"id":"HBithwlITduj"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["shutil.rmtree(\"/tmp/od_data/train\") # TODO: do I need this instruction ? do I need to use other cache dirs?"],"metadata":{"id":"pUpntlCgYgUh"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["train_data = object_detector.Dataset.from_coco_folder(augmented_train_dataset_path, cache_dir=\"/tmp/od_data/train\")\n","validation_data = object_detector.Dataset.from_coco_folder(validation_dataset_path, cache_dir=\"/tmp/od_data/validation\")\n","\n","print(f\"{'New Training Dataset Size:':<25} {train_data.size:>6} images\")\n","print(f\"{'New Validation Dataset Size:':<25} {validation_data.size:>4} images\")"],"metadata":{"id":"2A_QSKYdUs2U"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Retrain model"],"metadata":{"id":"IzhODm2CJaMB"}},{"cell_type":"markdown","source":["### Set retraining options"],"metadata":{"id":"jF7sZHYyJcl7"}},{"cell_type":"code","source":["spec = object_detector.SupportedModels.MOBILENET_MULTI_AVG_I384\n","\n","hparams = object_detector.HParams(\n"," learning_rate=0.015, # reduce to 0.01 (is it possible to implement a scheduler?)\n"," batch_size=64, # try 128, 256\n"," epochs=100,\n"," cosine_decay_epochs=100,\n"," cosine_decay_alpha=0.1,\n"," #shuffle=True, # TODO: check\n"," #repeat=True, # TODO: check\n"," export_dir='exported_model'\n",")\n","\n","model_options = object_detector.ModelOptions(\n"," l2_weight_decay=1e-4 # 3e-5\n",")\n","\n","options = object_detector.ObjectDetectorOptions(\n"," supported_model=spec,\n"," hparams=hparams,\n"," model_options=model_options\n",")"],"metadata":{"id":"iC6cVSpVJWFw"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Run retraining"],"metadata":{"id":"bfQoIMPOJhK5"}},{"cell_type":"code","source":["model = object_detector.ObjectDetector.create(\n"," train_data=train_data,\n"," validation_data=validation_data,\n"," options=options\n",")"],"metadata":{"id":"BxT1UHQWJfAX","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Evaluate the model performance\n","\n","After training the model, evaluate it on validation dataset and print the loss and coco_metrics. The most important metric for evaluating the model performance is typically the \"AP\" coco metric for Average Precision."],"metadata":{"id":"y3h_cOytJjXk"}},{"cell_type":"code","source":["loss, coco_metrics = model.evaluate(validation_data, batch_size=4) # TODO: update batch_size (?)\n","print(f\"Validation loss: {loss}\")\n","print(f\"Validation coco metrics: {coco_metrics}\")"],"metadata":{"id":"kJabrlEjJkmH","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Export model\n","\n","After creating the model, convert and export it to a Tensorflow Lite model format for later use on an on-device application. The export also includes model metadata, which includes the label map."],"metadata":{"id":"Vkg5K-xWJmj6"}},{"cell_type":"code","source":["# TODO: do I need to remove the existing model first?"],"metadata":{"id":"eQ3BCV__LXaB"},"execution_count":null,"outputs":[]},{"cell_type":"code","source":["model.export_model()\n","!ls exported_model\n","files.download('exported_model/model.tflite')"],"metadata":{"id":"ADUK65tBJn7x","collapsed":true},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["## Model quantization"],"metadata":{"id":"6g6QkA6YJpf8"}},{"cell_type":"markdown","source":["Model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy.\n","\n","This section of the guide explains how to apply quantization to your model. Model Maker supports two forms of quantization for object detector:\n","1. Quantization Aware Training: 8 bit integer precision for CPU usage\n","2. Post-Training Quantization: 16 bit floating point precision for GPU usage"],"metadata":{"id":"Jl1Dc9EhJtRo"}},{"cell_type":"markdown","source":["### Quantization aware training (int8 quantization)\n","Quantization aware training (QAT) is a fine-tuning step which happens after fully training your model. This technique further tunes a model which emulates inference time quantization in order to account for the lower precision of 8 bit integer quantization. For on-device applications with a standard CPU, use Int8 precision. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/training) documentation.\n","\n","To apply quantization aware training and export to an int8 model, create a `QATHParams` configuration and run the `quantization_aware_training` method. See the **Hyperparameters** section below on detailed usage of `QATHParams`."],"metadata":{"id":"D88f-mqgJvwi"}},{"cell_type":"code","source":["qat_hparams = object_detector.QATHParams(learning_rate=0.3, batch_size=4, epochs=10, decay_steps=6, decay_rate=0.96)\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"NbTDJtEJJr7Y"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["The QAT step often requires multiple runs to tune the parameters of training. To avoid having to rerun model training using the `create` method, use the `restore_float_ckpt` method to restore the model state back to the fully trained float model(After running the `create` method) in order to run QAT again."],"metadata":{"id":"5KEPghm9JyUi"}},{"cell_type":"code","source":["new_qat_hparams = object_detector.QATHParams(learning_rate=0.9, batch_size=4, epochs=15, decay_steps=5, decay_rate=0.96)\n","model.restore_float_ckpt()\n","model.quantization_aware_training(train_data, validation_data, qat_hparams=new_qat_hparams)\n","qat_loss, qat_coco_metrics = model.evaluate(validation_data)\n","print(f\"QAT validation loss: {qat_loss}\")\n","print(f\"QAT validation coco metrics: {qat_coco_metrics}\")"],"metadata":{"id":"vALj8IqaJ1B4"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Finally, us the `export_model` to export to an int8 quantized model. The `export_model` function will automatically export to either float32 or int8 model depending on whether `quantization_aware_training` was run."],"metadata":{"id":"HdZAaTJKJ4ky"}},{"cell_type":"code","source":["model.export_model('model_int8_qat.tflite')\n","!ls -lh exported_model\n","files.download('exported_model/model_int8_qat.tflite')"],"metadata":{"id":"AePYkJHwJ2OG"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["### Post-training quantization (fp16 quantization)\n","\n","Post-training model quantization is a model modification technique that can reduce the model size and improve the speed of predictions with only a relatively minor decrease in accuracy. This approach reduces the size of the data processed by the model, for example by transforming 32-bit floating point numbers to 16-bit floats. Float16 quantization is reccomended for GPU usage. For more information, see the [TensorFlow Lite](https://www.tensorflow.org/model_optimization/guide/quantization/post_training) documentation.\n","\n","First, import the MediaPipe Model Maker quantization module:"],"metadata":{"id":"mm4gCymvJ7wK"}},{"cell_type":"code","source":["from mediapipe_model_maker import quantization"],"metadata":{"id":"mt8zeY52J8Gk"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Define a QuantizationConfig object using the `for_float16()` class method. This configuration modifies a trained model to use 16-bit floating point numbers instead of 32-bit floating point numbers. You can further customize the quantization process by setting additional parameters for the QuantizationConfig class."],"metadata":{"id":"r6eGgib4J-Ac"}},{"cell_type":"code","source":["quantization_config = quantization.QuantizationConfig.for_float16()"],"metadata":{"id":"IjAn-GBKJ_ab"},"execution_count":null,"outputs":[]},{"cell_type":"markdown","source":["Export the model using the additional quantization_config object to apply post-training quantization. Note that if you previously ran `quantization_aware_training`, you must first convert the model back to a float model by using `restore_float_ckpt`."],"metadata":{"id":"ChqPtssVKCA8"}},{"cell_type":"code","source":["model.restore_float_ckpt()\n","model.export_model(model_name=\"model_fp16.tflite\", quantization_config=quantization_config)\n","!ls -lh exported_model\n","files.download('exported_model/model_fp16.tflite')"],"metadata":{"id":"f90WxGNNKCSW"},"execution_count":null,"outputs":[]}],"metadata":{"colab":{"last_runtime":{"build_target":"//learning/grp/tools/ml_python:ml_notebook","kind":"private"},"private_outputs":true,"provenance":[{"file_id":"https://github.com/googlesamples/mediapipe/blob/main/examples/customization/object_detector.ipynb","timestamp":1720270505355},{"file_id":"11PG1YgsQWWLJ8jpqJ6QY7hjYWzxVwoCb","timestamp":1677706798050}],"gpuType":"L4","collapsed_sections":["GYUY-Sw33nJD","ThZpHZPKKK1O","Vkg5K-xWJmj6","6g6QkA6YJpf8"]},"kernelspec":{"display_name":"Python 3","name":"python3"},"language_info":{"name":"python"},"accelerator":"GPU"},"nbformat":4,"nbformat_minor":0} \ No newline at end of file