From a319c88b8535c72864f9e77c120817cd58a76483 Mon Sep 17 00:00:00 2001
From: Rania Sridi
Date: Tue, 17 Jun 2025 10:49:05 +0200
Subject: [PATCH] add Arabic vocab and make detection-model errors clearer
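
Detection training used to fail with an anonymous ValueError whenever a
label contained out-of-range box coordinates. This patch threads optional
image filenames through forward(), compute_loss() and build_target() so
the error names the offending file. As an illustration (filenames are
hypothetical), a target box given in pixels, e.g. [120, 45, 380, 90] on an
800x600 image, now fails with:

    ValueError: Invalid box coordinates in scan_001.jpg, class 'words':
    values should be between 0 & 1, but found range [45.0000, 380.0000].
    Please normalize your coordinates by dividing x by image width and y by
    image height. Images in batch: scan_001.jpg, scan_002.jpg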

---
 doctr/datasets/vocabs.py                          |   5 +-
 .../differentiable_binarization/base.py           |  30 +-
 .../differentiable_binarization/pytorch.py        |  27 +-
 references/detection/train_pytorch.py             | 286 +++++++++++++-----
 4 files changed, 261 insertions(+), 87 deletions(-)

diff --git a/doctr/datasets/vocabs.py b/doctr/datasets/vocabs.py
index cf973f2235..d39b713975 100644
--- a/doctr/datasets/vocabs.py
+++ b/doctr/datasets/vocabs.py
@@ -13,7 +13,7 @@
     # Arabic & Persian
     "arabic_diacritics": "ًٌٍَُِّْ",
     "arabic_digits": "٠١٢٣٤٥٦٧٨٩",
-    "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي",
+    "arabic_letters": "ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي٪ٰیٕﻻﻷﻹﻵﻼﺀﺁﺃﺅﺇﺉﺊﺋﺍﺎﺏﺐﺑﺒﺓﺔﺕﺖﺗﺘﺙﺚﺛﺜﺝﺞﺟﺠﺡﺢﺣﺤﺥﺦﺧﺨﺩﺪﺫﺬﺭﺮﺯﺰﺱﺲﺳﺴﺵﺶﺷﺸﺹﺺﺻﺼﺽﺾﺿﻀﻁﻂﻃﻄﻅﻆﻇﻈﻉﻊﻋﻌﻍﻎﻏﻐﻑﻒﻓﻔﻕﻖﻗﻘﻙﻚﻛﻜﻝﻞﻟﻠﻡﻢﻣﻤﻥﻦﻧﻨﻩﻪﻫﻬﻭﻮﻯﻰﻱﻲﻳﻴ",
     "arabic_punctuation": "؟؛«»—",
     "persian_letters": "پچڢڤگ",
     # Bangla
@@ -786,7 +786,8 @@
 VOCABS["multilingual"] = "".join(
     dict.fromkeys(
-        # latin_based
-        VOCABS["english"]
+        VOCABS["arabic"]
+        # latin_based
+        + VOCABS["english"]
         + VOCABS["albanian"]
         + VOCABS["afrikaans"]
         + VOCABS["azerbaijani"]
diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py
index 8eb9a5f4c8..c70c524d8b 100644
--- a/doctr/models/detection/differentiable_binarization/base.py
+++ b/doctr/models/detection/differentiable_binarization/base.py
@@ -28,6 +28,20 @@ class DBPostProcessor(DetectionPostProcessor):
         bin_thresh: threshold used to binarize p_map at inference time
     """

+    # NOTE: informational for callers; build_target raises ValueError with the same message
+    class InvalidCoordinatesError(Exception):
+        """Raised when target box coordinates fall outside the [0, 1] range."""
+
+        def __init__(self, image_name, class_name, min_val, max_val):
+            self.image_name = image_name
+            self.class_name = class_name
+            self.min_val = min_val
+            self.max_val = max_val
+            message = (
+                f"Invalid box coordinates in {image_name}, class '{class_name}': "
+                f"values should be between 0 & 1, but found range [{min_val:.4f}, {max_val:.4f}]."
+            )
+            super().__init__(message)
+
     def __init__(
         self,
@@ -270,11 +281,24 @@ def build_target(
     target: list[dict[str, np.ndarray]],
     output_shape: tuple[int, int, int],
     channels_last: bool = True,
+    image_names: list[str] | None = None,  # optional image filenames for error reporting
 ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     if any(t.dtype != np.float32 for tgt in target for t in tgt.values()):
         raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.")
-    if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for tgt in target for t in tgt.values()):
-        raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.")
+
+    # Check coordinates per image so the offending file can be identified
+    for idx, tgt in enumerate(target):
+        for class_name, t in tgt.items():
+            if np.any((t[:, :4] > 1) | (t[:, :4] < 0)):
+                image_id = f"image #{idx}" if image_names is None else image_names[idx]
+                # Report the actual out-of-range values for easier debugging
+                min_val = t[:, :4].min()
+                max_val = t[:, :4].max()
+                raise ValueError(
+                    f"Invalid box coordinates in {image_id}, class '{class_name}': "
+                    f"values should be between 0 & 1, but found range [{min_val:.4f}, {max_val:.4f}]. "
+                    f"Please normalize your coordinates by dividing x by image width and y by image height."
+                )

     input_dtype = next(iter(target[0].values())).dtype if len(target) > 0 else np.float32
diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py
index 5b82b47149..58d259b5c4 100644
--- a/doctr/models/detection/differentiable_binarization/pytorch.py
+++ b/doctr/models/detection/differentiable_binarization/pytorch.py
@@ -185,6 +185,7 @@ def forward(
         target: list[np.ndarray] | None = None,
         return_model_output: bool = False,
         return_preds: bool = False,
+        image_names: list[str] | None = None,  # optional image filenames for error reporting
     ) -> dict[str, torch.Tensor]:
         # Extract feature maps at different stages
         feats = self.feat_extractor(x)
@@ -218,7 +218,7 @@ def _postprocess(prob_map: torch.Tensor) -> list[dict[str, Any]]:

         if target is not None:
             thresh_map = self.thresh_head(feat_concat)
-            loss = self.compute_loss(logits, thresh_map, target)
+            loss = self.compute_loss(logits, thresh_map, target, image_names=image_names)
             out["loss"] = loss

         return out
@@ -231,6 +231,7 @@ def compute_loss(
         gamma: float = 2.0,
         alpha: float = 0.5,
         eps: float = 1e-8,
+        image_names: list[str] | None = None,  # image filenames used in error messages
     ) -> torch.Tensor:
         """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes and a list of masks for each image.
        From there it computes the loss with the model output
@@ -242,6 +243,7 @@ def compute_loss(
             gamma: modulating factor in the focal loss formula
             alpha: balancing factor in the focal loss formula
             eps: epsilon factor in dice loss
+            image_names: list of image filenames for error reporting

         Returns:
             A loss tensor
@@ -252,13 +254,26 @@ def compute_loss(
         prob_map = torch.sigmoid(out_map)
         thresh_map = torch.sigmoid(thresh_map)

-        targets = self.build_target(target, out_map.shape[1:], False)  # type: ignore[arg-type]
+        try:
+            targets = self.build_target(target, out_map.shape[1:], False, image_names=image_names)  # type: ignore[arg-type]
+        except ValueError as e:
+            # Re-raise with more context about which images caused the problem
+            if "Invalid box coordinates" in str(e) and image_names:
+                batch_info = ", ".join(image_names)
+                raise ValueError(f"{str(e)} Images in batch: {batch_info}") from e
+            else:
+                raise

         seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1])
         seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device)
         thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3])
         thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device)

+        # Initialize the loss components so they are defined even when a mask is empty
+        focal_loss = torch.tensor(0.0, device=out_map.device)
+        dice_loss = torch.tensor(0.0, device=out_map.device)
+        l1_loss = torch.tensor(0.0, device=out_map.device)
+
         if torch.any(seg_mask):
             # Focal loss
             focal_scale = 10.0
@@ -269,7 +284,7 @@ def compute_loss(
             # Unreduced version
             focal_loss = alpha_t * (1 - p_t) ** gamma * bce_loss
             # Class reduced
-            focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / seg_mask.sum((0, 1, 2, 3))
+            focal_loss = (seg_mask * focal_loss).sum((0, 1, 2, 3)) / (seg_mask.sum((0, 1, 2, 3)) + eps)

         # Compute dice loss for each class or for approx binary_map
         if len(self.class_names) > 1:
diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py
index 33ab6ebebd..e9b4066da2 100644
--- a/references/detection/train_pytorch.py
+++ b/references/detection/train_pytorch.py
@@ -12,6 +12,7 @@
+import json
 import logging
 import multiprocessing
 import time
 from pathlib import Path

 import numpy as np
@@ -36,6 +37,38 @@
 from doctr.utils.metrics import LocalizationConfusion
 from utils import EarlyStopper, plot_recorder, plot_samples

+
+def extract_image_names(dataset, train_path=None):
+    """Extract image filenames from the dataset if possible."""
+    # If we have the path to the labels.json file, use it directly
+    if train_path:
+        labels_path = os.path.join(train_path, "labels.json")
+        if os.path.exists(labels_path):
+            try:
+                with open(labels_path, "r") as f:
+                    data = json.load(f)
+                return list(data.keys())
+            except Exception:
+                pass
+
+    # Fall back to dataset attributes
+    if hasattr(dataset, "data") and isinstance(dataset.data, dict):
+        return list(dataset.data.keys())
+    elif hasattr(dataset, "img_paths") and isinstance(dataset.img_paths, list):
+        return [Path(img_path).name for img_path in dataset.img_paths]
+    else:
+        # Default to an empty list if the image names cannot be determined
+        return []
+
+
+def get_image_names_from_json(json_path):
+    """Extract image names directly from a labels.json file."""
+    try:
+        with open(json_path, "r") as f:
+            data = json.load(f)
+        # The keys of labels.json are the image filenames
+        return list(data.keys())
+    except Exception as e:
+        print(f"Error loading JSON file: {e}")
+        return []
+

 def record_lr(
     model: torch.nn.Module,
@@ -68,129 +101,230 @@ def record_lr(
     if amp:
         scaler = torch.cuda.amp.GradScaler()

+    # Try to extract image names from the dataset
+    image_names = extract_image_names(train_loader.dataset)
+
     for batch_idx, (images, targets) in enumerate(train_loader):
         if torch.cuda.is_available():
             images = images.cuda()
         images = batch_transforms(images)

+        # Get the batch filenames if available
+        # NOTE: this index-based lookup assumes the loader preserves dataset order
+        # (no shuffling); otherwise the reported names may not match the images
+        batch_filenames = []
+        if image_names:
+            batch_start_idx = batch_idx * train_loader.batch_size
+            batch_end_idx = min(batch_start_idx + len(images), len(image_names))
+            batch_filenames = image_names[batch_start_idx:batch_end_idx]
+
         # Forward, Backward & update
         optimizer.zero_grad()
-        if amp:
-            with torch.cuda.amp.autocast():
-                train_loss = model(images, targets)["loss"]
-            scaler.scale(train_loss).backward()
-            # Gradient clipping
-            scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-            # Update the params
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            train_loss = model(images, targets)["loss"]
-            train_loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-            optimizer.step()
+        try:
+            if amp:
+                with torch.cuda.amp.autocast():
+                    train_loss = model(images, targets, image_names=batch_filenames)["loss"]
+                scaler.scale(train_loss).backward()
+                # Gradient clipping
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+                # Update the params
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                train_loss = model(images, targets, image_names=batch_filenames)["loss"]
+                train_loss.backward()
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+                optimizer.step()
+
+            # Record
+            if not torch.isfinite(train_loss):
+                if batch_idx == 0:
+                    raise ValueError("loss value is NaN or inf.")
+                else:
+                    break
+            loss_recorder.append(train_loss.item())
+        except ValueError as e:
+            # Skip this batch if its box coordinates are invalid
+            if "Invalid box coordinates" in str(e):
+                print(f"Skipping batch with invalid coordinates: {e}")
+                # Still advance the LR so the sweep keeps moving
+                scheduler.step()
+                continue
+            else:
+                # Re-raise if it's a different ValueError
+                raise
+
         # Update LR
         scheduler.step()

-        # Record
-        if not torch.isfinite(train_loss):
-            if batch_idx == 0:
-                raise ValueError("loss value is NaN or inf.")
-            else:
-                break
-        loss_recorder.append(train_loss.item())
         # Stop after the number of iterations
         if batch_idx + 1 == num_it:
             break

     return lr_recorder[: len(loss_recorder)], loss_recorder


 def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, amp=False, log=None, rank=0):
     if amp:
         scaler = torch.cuda.amp.GradScaler()

+    # Load the image names from the labels.json file
+    # NOTE: relies on the module-level `args` parsed in the __main__ block
+    image_names = get_image_names_from_json(os.path.join(args.train_path, "labels.json"))
+
     model.train()
     # Iterate over the batches of the dataset
     epoch_train_loss, batch_cnt = 0, 0
+    skipped_batches = 0
     pbar = tqdm(train_loader, dynamic_ncols=True, disable=(rank != 0))
-    for images, targets in pbar:
+    for batch_idx, (images, targets) in enumerate(pbar):
+        # Calculate the indices of the images in this batch
+        # (again assuming the loader preserves dataset order)
+        start_idx = batch_idx * train_loader.batch_size
+        end_idx = start_idx + len(images)
+
+        # Get the image names for this batch
+        if start_idx < len(image_names):
+            batch_filenames = image_names[start_idx:end_idx]
+            if rank == 0 and batch_idx % 10 == 0:  # Show every 10 batches to avoid excessive output
+                print(f"Batch {batch_idx} contains: {batch_filenames}")
+        else:
+            batch_filenames = [f"unknown_{i}" for i in range(len(images))]
+
         if torch.cuda.is_available():
             images = images.cuda()
         images = batch_transforms(images)

         optimizer.zero_grad()
-        if amp:
-            with torch.cuda.amp.autocast():
-                train_loss = model(images, targets)["loss"]
-            scaler.scale(train_loss).backward()
-            # Gradient clipping
-            scaler.unscale_(optimizer)
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-            # Update the params
-            scaler.step(optimizer)
-            scaler.update()
-        else:
-            train_loss = model(images, targets)["loss"]
-            train_loss.backward()
-            torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
-            optimizer.step()
-
-        scheduler.step()
-        last_lr = scheduler.get_last_lr()[0]
-
-        pbar.set_description(f"Training loss: {train_loss.item():.6} | LR: {last_lr:.6}")
-        if log:
-            log(train_loss=train_loss.item(), lr=last_lr)
-
-        epoch_train_loss += train_loss.item()
-        batch_cnt += 1
+
+        try:
+            if amp:
+                with torch.cuda.amp.autocast():
+                    # Pass the file names to the model for error reporting
+                    train_loss = model(images, targets, image_names=batch_filenames)["loss"]
+                scaler.scale(train_loss).backward()
+                # Gradient clipping
+                scaler.unscale_(optimizer)
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+                # Update the params
+                scaler.step(optimizer)
+                scaler.update()
+            else:
+                train_loss = model(images, targets, image_names=batch_filenames)["loss"]
+                train_loss.backward()
+                torch.nn.utils.clip_grad_norm_(model.parameters(), 5)
+                optimizer.step()
+
+            scheduler.step()
+            last_lr = scheduler.get_last_lr()[0]
+
+            pbar.set_description(f"Training loss: {train_loss.item():.6} | LR: {last_lr:.6}")
+            if log:
+                log(train_loss=train_loss.item(), lr=last_lr)
+
+            epoch_train_loss += train_loss.item()
+            batch_cnt += 1
+
+        except ValueError as e:
+            # Skip this batch if its box coordinates are invalid
+            if "Invalid box coordinates" in str(e):
+                skipped_batches += 1
+                print(f"Skipping batch {batch_idx} with invalid coordinates: {e}")
+                # Still step the scheduler to maintain the learning rate schedule
+                scheduler.step()
+                last_lr = scheduler.get_last_lr()[0]
+
+                # Update the progress bar to show that the batch was skipped
+                pbar.set_description(f"Skipped batch ({skipped_batches} total) | LR: {last_lr:.6}")
+                continue
+            else:
+                # Re-raise if it's a different ValueError
+                raise

-    epoch_train_loss /= batch_cnt
+    if batch_cnt > 0:
+        epoch_train_loss /= batch_cnt
+    else:
+        epoch_train_loss = float("nan")
+
+    if skipped_batches > 0:
+        print(f"Skipped {skipped_batches} batches with invalid coordinates in this epoch")
+
     return epoch_train_loss, last_lr


 @torch.no_grad()
 def evaluate(model, val_loader, batch_transforms, val_metric, args, amp=False, log=None):
     # Model in eval mode
     model.eval()
     # Reset val metric
     val_metric.reset()
+
+    # Load the image names from the labels.json file
+    image_names = get_image_names_from_json(os.path.join(args.val_path, "labels.json"))
+
     # Validation loop
     val_loss, batch_cnt = 0, 0
+    skipped_batches = 0
     pbar = tqdm(val_loader, dynamic_ncols=True)
-    for images, targets in pbar:
+    for batch_idx, (images, targets) in enumerate(pbar):
+        # Calculate the indices of the images in this batch
+        start_idx = batch_idx * val_loader.batch_size
+        end_idx = start_idx + len(images)
+
+        # Get the image names for this batch
+        if start_idx < len(image_names):
+            batch_filenames = image_names[start_idx:end_idx]
+        else:
+            batch_filenames = [f"unknown_{i}" for i in range(len(images))]
+
         if torch.cuda.is_available():
             images = images.cuda()
         images = batch_transforms(images)
-        if amp:
-            with torch.cuda.amp.autocast():
-                out = model(images, targets, return_preds=True)
-        else:
-            out = model(images, targets, return_preds=True)
-        # Compute metric
-        loc_preds = out["preds"]
-        for target, loc_pred in zip(targets, loc_preds):
-            for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
-                if args.rotation and args.eval_straight:
-                    # Convert pred to boxes [xmin, ymin, xmax, ymax] N, 5, 2 (with scores) --> N, 4
-                    boxes_pred = np.concatenate((boxes_pred[:, :4].min(axis=1), boxes_pred[:, :4].max(axis=1)), axis=-1)
-                val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4])
-
-        pbar.set_description(f"Validation loss: {out['loss'].item():.6}")
-        if log:
-            log(val_loss=out["loss"].item())
-
-        val_loss += out["loss"].item()
-        batch_cnt += 1
-
-    val_loss /= batch_cnt
+
+        try:
+            if amp:
+                with torch.cuda.amp.autocast():
+                    out = model(images, targets, return_preds=True, image_names=batch_filenames)
+            else:
+                out = model(images, targets, return_preds=True, image_names=batch_filenames)
+
+            # Compute metric
+            loc_preds = out["preds"]
+            for target, loc_pred in zip(targets, loc_preds):
+                for boxes_gt, boxes_pred in zip(target.values(), loc_pred.values()):
+                    if args.rotation and args.eval_straight:
+                        # Convert pred to boxes [xmin, ymin, xmax, ymax] N, 5, 2 (with scores) --> N, 4
+                        boxes_pred = np.concatenate((boxes_pred[:, :4].min(axis=1), boxes_pred[:, :4].max(axis=1)), axis=-1)
+                    val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4])
+
+            pbar.set_description(f"Validation loss: {out['loss'].item():.6}")
+            if log:
+                log(val_loss=out["loss"].item())
+
+            val_loss += out["loss"].item()
+            batch_cnt += 1
+
+        except ValueError as e:
+            # Skip this batch if its box coordinates are invalid
+            if "Invalid box coordinates" in str(e):
+                skipped_batches += 1
+                print(f"Skipping validation batch {batch_idx} with invalid coordinates: {e}")
+                pbar.set_description(f"Skipped batch ({skipped_batches} total)")
+                continue
+            else:
+                # Re-raise if it's a different ValueError
+                raise
+
+    if batch_cnt > 0:
+        val_loss /= batch_cnt
+    else:
+        val_loss = float("nan")
+
+    if skipped_batches > 0:
+        print(f"Skipped {skipped_batches} validation batches with invalid coordinates")
+
     recall, precision, mean_iou = val_metric.summary()
     return val_loss, recall, precision, mean_iou


 def main(args):
     # Detect distributed setup
     # variable is set by torchrun
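
---

Note for reviewers (not part of the patch): the new error message asks users
to normalize their coordinates. A minimal sketch of that fix, assuming
absolute pixel boxes of shape (N, 4) in (xmin, ymin, xmax, ymax) order
(the helper name is hypothetical):

    import numpy as np

    def normalize_boxes(boxes: np.ndarray, width: int, height: int) -> np.ndarray:
        """Scale absolute pixel boxes (xmin, ymin, xmax, ymax) into [0, 1]."""
        boxes = boxes.astype(np.float32).copy()
        boxes[:, [0, 2]] /= width   # x coordinates
        boxes[:, [1, 3]] /= height  # y coordinates
        return boxes.clip(0, 1)

    # e.g. normalize_boxes(np.array([[120, 45, 380, 90]]), 800, 600)
    # -> [[0.15, 0.075, 0.475, 0.15]]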