From 9c0b81659c6544eb8ea30be6b45c465b3446fb96 Mon Sep 17 00:00:00 2001 From: James Ball <37094972+PatBall1@users.noreply.github.com> Date: Thu, 11 May 2023 15:32:47 +0100 Subject: [PATCH] Jb/fixes (#97) * added to doc strings and updated clean_crowns function * change reset index in clean crowns * readme tutorial edits * clean_crowns function can filter by confidence score --- README.md | 2 +- detectree2/models/outputs.py | 70 +++++++++++++++++++++++++++--- detectree2/models/predict.py | 7 ++- detectree2/models/train.py | 42 ++++++++++-------- detectree2/preprocessing/tiling.py | 19 +++++--- docs/source/tutorial.rst | 21 ++++++++- 6 files changed, 128 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index f812aad7..d3ea1440 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -# detectree2 +# :robot: detectree2 [![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![Detectree CI](https://github.com/patball1/detectree2/actions/workflows/python-ci.yml/badge.svg)](https://github.com/patball1/detectree2/actions/workflows/python-ci.yml) [![PEP8](https://img.shields.io/badge/code%20style-pep8-orange.svg)](https://www.python.org/dev/peps/pep-0008/) [![DOI](https://zenodo.org/badge/470698486.svg)](https://zenodo.org/badge/latestdoi/470698486) diff --git a/detectree2/models/outputs.py b/detectree2/models/outputs.py index 881b9931..6f0d49ba 100644 --- a/detectree2/models/outputs.py +++ b/detectree2/models/outputs.py @@ -136,6 +136,14 @@ def project_to_geojson(tiles_path, pred_fold=None, output_fold=None): # noqa:N8 Takes a json and changes it to a geojson so it can overlay with orthomosaic. Another copy is produced to overlay with PNGs. + + Args: + tiles_path (str): Path to the tiles folder. + pred_fold (str): Path to the predictions folder. + output_fold (str): Path to the output folder. 
+ + Returns: + None """ Path(output_fold).mkdir(parents=True, exist_ok=True) @@ -232,14 +240,33 @@ def filename_geoinfo(filename): def box_filter(filename, shift: int = 0): - """Create a bounding box from a file name to filter edge crowns.""" + """Create a bounding box from a file name to filter edge crowns. + + Args: + filename: Name of the file. + shift: Number of meters to shift the size of the bounding box in by. This is to avoid edge crowns. + + Returns: + gpd.GeoDataFrame: A GeoDataFrame containing the bounding box.""" minx, miny, width, buffer, crs = filename_geoinfo(filename) bounding_box = box_make(minx, miny, width, buffer, crs, shift) return bounding_box def box_make(minx: int, miny: int, width: int, buffer: int, crs, shift: int = 0): - """Generate bounding box from geographic specifications.""" + """Generate bounding box from geographic specifications. + + Args: + minx: Minimum x coordinate. + miny: Minimum y coordinate. + width: Width of the tile. + buffer: Buffer around the tile. + crs: Coordinate reference system. + shift: Number of meters to shift the size of the bounding box in by. This is to avoid edge crowns. + + Returns: + gpd.GeoDataFrame: A GeoDataFrame containing the bounding box. + """ bbox = box( minx - buffer + shift, miny - buffer + shift, @@ -251,7 +278,15 @@ def box_make(minx: int, miny: int, width: int, buffer: int, crs, shift: int = 0) def stitch_crowns(folder: str, shift: int = 1): - """Stitch together predicted crowns.""" + """Stitch together predicted crowns. + + Args: + folder: Path to folder containing geojson files. + shift: Number of meters to shift the size of the bounding box in by. This is to avoid edge crowns. + + Returns: + gpd.GeoDataFrame: A GeoDataFrame containing all the crowns. 
+ """ crowns_path = Path(folder) files = crowns_path.glob("*geojson") _, _, _, _, crs = filename_geoinfo(list(files)[0]) @@ -285,13 +320,28 @@ def calc_iou(shape1, shape2): return iou -def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7): +def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7, confidence=0.2): """Clean overlapping crowns. Outputs can contain highly overlapping crowns including in the buffer region. This function removes crowns with a high degree of overlap with others but a lower Confidence Score. + + Args: + crowns (gpd.GeoDataFrame): Crowns to be cleaned. + iou_threshold (float, optional): IoU threshold that determines whether crowns are overlapping. + confidence (float, optional): Minimum confidence score for crowns to be retained. Defaults to 0.2. + + Returns: + gpd.GeoDataFrame: Cleaned crowns. """ + # Filter any rows with empty geometry + crowns = crowns[crowns.is_empty == False] + # Filter any rows with invalid geometry + crowns = crowns[crowns.is_valid] + # Reset the index + crowns = crowns.reset_index(drop=True) + # Create an object to store the cleaned crowns crowns_out = gpd.GeoDataFrame() for index, row in crowns.iterrows(): # iterate over each crown if index % 1000 == 0: @@ -302,7 +352,7 @@ def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7): else: # Find those crowns that intersect with it intersecting = crowns.loc[crowns.intersects(shape(row.geometry))] - intersecting = intersecting.reset_index().drop("index", axis=1) + intersecting = intersecting.reset_index(drop=True) iou = [] for ( index1, @@ -313,7 +363,7 @@ def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7): # print(iou) intersecting["iou"] = iou matches = intersecting[intersecting["iou"] > iou_threshold] # Remove those crowns with a poor match - matches = matches.sort_values("Confidence_score", ascending=False).reset_index().drop("index", axis=1) + matches = matches.sort_values("Confidence_score", ascending=False).reset_index(drop=True) 
match = matches.loc[[0]] # Of the remaining crowns select the crown with the highest confidence if match["iou"][0] < 1: # If the most confident is not the initial crown continue @@ -321,7 +371,13 @@ def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7): match = match.drop("iou", axis=1) # print(index) crowns_out = crowns_out.append(match) - return crowns_out.reset_index() + # Convert pandas into back geopandas if it is not already + if not isinstance(crowns_out, gpd.GeoDataFrame): + crowns_out = gpd.GeoDataFrame(crowns_out) + # Filter remaining crowns based on confidence score + if confidence != 0: + crowns_out = crowns_out[crowns_out["Confidence_score"] > confidence] + return crowns_out.reset_index(drop=True) def clean_predictions(directory, iou_threshold=0.7): diff --git a/detectree2/models/predict.py b/detectree2/models/predict.py index 44c63810..4fdc2c39 100644 --- a/detectree2/models/predict.py +++ b/detectree2/models/predict.py @@ -1,4 +1,7 @@ -"""Generate predictions.""" +"""Generate predictions. + +This module contains the code to generate predictions on tiled data. 
+""" import json import os import random @@ -63,4 +66,4 @@ def predict_on_data( if __name__ == "__main__": - print("something") + predict_on_data() diff --git a/detectree2/models/train.py b/detectree2/models/train.py index ab7ed99e..d7606793 100644 --- a/detectree2/models/train.py +++ b/detectree2/models/train.py @@ -53,9 +53,9 @@ class LossEvalHook(HookBase): See https://gist.github.com/ortegatron/c0dad15e49c2b74de8bb09a5615d9f6b Attributes: - model: - period: - data_loader: + model: model to train + period: number of iterations between evaluations + data_loader: data loader to use for evaluation patience: number of evaluation periods to wait for improvement """ @@ -295,7 +295,8 @@ def get_tree_dicts(directory: str, classes: List[str] = None, classes_at: str = Args: directory: Path to directory - classes: Signifies which column (if any) corresponds to the class labels + classes: List of classes to include + classes_at: Signifies which column (if any) corresponds to the class labels Returns: List of dictionaries corresponding to segmentations of trees. Each dictionary includes @@ -471,18 +472,25 @@ def read_data(out_dir): def remove_registered_data(name="tree"): + """Remove registered data from catalog. 
+ + Args: + name: string of named registered data + """ for d in ["train", "val"]: DatasetCatalog.remove(name + "_" + d) MetadataCatalog.remove(name + "_" + d) def register_test_data(test_location, name="tree"): + """Register data for testing.""" d = "test" DatasetCatalog.register(name + "_" + d, lambda d=d: get_tree_dicts(test_location)) MetadataCatalog.get(name + "_" + d).set(thing_classes=["tree"]) def load_json_arr(json_path): + """Load json array.""" lines = [] with open(json_path, "r") as f: for line in f: @@ -517,19 +525,19 @@ def setup_cfg( trains: names of registered data to use for training tests: names of registered data to use for evaluating models update_model: updated pre-trained model from detectree2 model_garden - workers: - ims_per_batch: - gamma: - backbone_freeze: - warm_iter: - momentum: - batch_size_per_im: - base_lr: - weight_decay - max_iter: - num_classes: - eval_period: - out_dir: + workers: number of workers for dataloader + ims_per_batch: number of images per batch + gamma: gamma for learning rate scheduler + backbone_freeze: backbone layer to freeze + warm_iter: number of iterations for warmup + momentum: momentum for optimizer + batch_size_per_im: batch size per image + base_lr: base learning rate + weight_decay: weight decay for optimizer + max_iter: maximum number of iterations + num_classes: number of classes + eval_period: number of iterations between evaluations + out_dir: directory to save outputs """ cfg = get_cfg() cfg.merge_from_file(model_zoo.get_config_file(base_model)) diff --git a/detectree2/preprocessing/tiling.py b/detectree2/preprocessing/tiling.py index 09d28456..3600ab21 100644 --- a/detectree2/preprocessing/tiling.py +++ b/detectree2/preprocessing/tiling.py @@ -423,7 +423,16 @@ def is_overlapping_box(test_boxes_array, train_box): def record_data(crowns, out_dir, column='status'): - """Function that will record a list of classes into a file that can be readed during training.""" + """Function that will record a list of 
classes into a file that can be read during training. + + Args: + crowns: gpd dataframe with the crowns + out_dir: directory to save the file + column: column name to get the classes from + + Returns: + None + """ list_of_classes = crowns[column].unique().tolist() @@ -447,10 +456,10 @@ def to_traintest_folders(tiles_folder: str = "./", """Send tiles to training (+validation) and test dir and automatically make sure no overlap. Args: - tiles_folder: - out_folder: - test_frac: - folds: + tiles_folder: folder with tiles + out_folder: folder to save train and test folders + test_frac: fraction of tiles to be used for testing + folds: number of folds to split the data into Returns: None diff --git a/docs/source/tutorial.rst b/docs/source/tutorial.rst index d8ab9360..7dfef55b 100644 --- a/docs/source/tutorial.rst +++ b/docs/source/tutorial.rst @@ -301,10 +301,29 @@ predictions with lower confidence). .. code-block:: python - crowns = stitch_crowns(tiles_path + "predictions_geo/", 1) + crowns = stitch_crowns(tiles_path + "predictions_geo/", 1, confidence=0) crowns = crowns[crowns.is_valid] crowns = clean_crowns(crowns, 0.6) +By default the ``clean_crowns`` function will remove crowns with a confidence of less than 20%. The above 'clean' crowns +include crowns of all confidence scores (0%-100%) as ``confidence=0``. It is likely that crowns with very low +confidence will be poor quality so it is usually preferable to filter these out. A suitable threshold can be determined +by eye in QGIS or implemented as a single line in Python. ``Confidence_score`` is a column in the ``crowns`` GeoDataFrame +and is considered a tunable parameter. + +.. code-block:: python + + crowns = crowns[crowns["Confidence_score"] > 0.5] + +The outputted crown polygons will have many vertices because they are generated from a mask which is pixelwise. If you +need to edit the crowns in QGIS it is best to simplify them to a reasonable number of vertices.
This can be done +with the ``simplify`` method. The ``tolerance`` will determine the coarseness of the simplification; it has the same units as +the coordinate reference system of the GeoSeries (meters when working with UTM). + +.. code-block:: python + + clean = clean.set_geometry(crowns.simplify(0.3)) + Once we're happy with the crown map, save the crowns to file. .. code-block:: python