Jb/fixes (#97)
* added to doc strings and updated clean_crowns function

* change reset index in clean crowns

* readme tutorial edits

* clean_crowns function can filter by confidence score
PatBall1 committed May 11, 2023
1 parent 1c114cd commit 9c0b816
Showing 6 changed files with 128 additions and 33 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -1,4 +1,4 @@
# detectree2
# :robot: detectree2

[![License: MIT](https://img.shields.io/badge/License-MIT-blue.svg)](https://opensource.org/licenses/MIT) [![Detectree CI](https://github.com/patball1/detectree2/actions/workflows/python-ci.yml/badge.svg)](https://github.com/patball1/detectree2/actions/workflows/python-ci.yml) [![PEP8](https://img.shields.io/badge/code%20style-pep8-orange.svg)](https://www.python.org/dev/peps/pep-0008/) [![DOI](https://zenodo.org/badge/470698486.svg)](https://zenodo.org/badge/latestdoi/470698486)

70 changes: 63 additions & 7 deletions detectree2/models/outputs.py
@@ -136,6 +136,14 @@ def project_to_geojson(tiles_path, pred_fold=None, output_fold=None): # noqa:N8
Takes a json and changes it to a geojson so it can overlay with orthomosaic. Another copy is produced to overlay
with PNGs.
Args:
tiles_path (str): Path to the tiles folder.
pred_fold (str): Path to the predictions folder.
output_fold (str): Path to the output folder.
Returns:
None
"""

Path(output_fold).mkdir(parents=True, exist_ok=True)
@@ -232,14 +240,33 @@ def filename_geoinfo(filename):


def box_filter(filename, shift: int = 0):
"""Create a bounding box from a file name to filter edge crowns."""
"""Create a bounding box from a file name to filter edge crowns.
Args:
filename: Name of the file.
shift: Number of meters to shift the size of the bounding box in by. This is to avoid edge crowns.
Returns:
gpd.GeoDataFrame: A GeoDataFrame containing the bounding box."""
minx, miny, width, buffer, crs = filename_geoinfo(filename)
bounding_box = box_make(minx, miny, width, buffer, crs, shift)
return bounding_box


def box_make(minx: int, miny: int, width: int, buffer: int, crs, shift: int = 0):
"""Generate bounding box from geographic specifications."""
"""Generate bounding box from geographic specifications.
Args:
minx: Minimum x coordinate.
miny: Minimum y coordinate.
width: Width of the tile.
buffer: Buffer around the tile.
crs: Coordinate reference system.
shift: Number of meters to shift the size of the bounding box in by. This is to avoid edge crowns.
Returns:
gpd.GeoDataFrame: A GeoDataFrame containing the bounding box.
"""
bbox = box(
minx - buffer + shift,
miny - buffer + shift,
@@ -251,7 +278,15 @@ def box_make(minx: int, miny: int, width: int, buffer: int, crs, shift: int = 0)


def stitch_crowns(folder: str, shift: int = 1):
"""Stitch together predicted crowns."""
"""Stitch together predicted crowns.
Args:
folder: Path to folder containing geojson files.
shift: Number of meters to shift the size of the bounding box in by. This is to avoid edge crowns.
Returns:
gpd.GeoDataFrame: A GeoDataFrame containing all the crowns.
"""
crowns_path = Path(folder)
files = crowns_path.glob("*geojson")
_, _, _, _, crs = filename_geoinfo(list(files)[0])
@@ -285,13 +320,28 @@ def calc_iou(shape1, shape2):
return iou


def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7):
def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7, confidence=0.2):
"""Clean overlapping crowns.
Outputs can contain highly overlapping crowns including in the buffer region.
This function removes crowns with a high degree of overlap with others but a
lower Confidence Score.
Args:
crowns (gpd.GeoDataFrame): Crowns to be cleaned.
iou_threshold (float, optional): IoU threshold that determines whether crowns are overlapping.
confidence (float, optional): Minimum confidence score for crowns to be retained. Defaults to 0.2.
Returns:
gpd.GeoDataFrame: Cleaned crowns.
"""
# Filter any rows with empty geometry
crowns = crowns[crowns.is_empty == False]
# Filter any rows with invalid geometry
crowns = crowns[crowns.is_valid]
# Reset the index
crowns = crowns.reset_index(drop=True)
# Create an object to store the cleaned crowns
crowns_out = gpd.GeoDataFrame()
for index, row in crowns.iterrows(): # iterate over each crown
if index % 1000 == 0:
@@ -302,7 +352,7 @@ def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7):
else:
# Find those crowns that intersect with it
intersecting = crowns.loc[crowns.intersects(shape(row.geometry))]
intersecting = intersecting.reset_index().drop("index", axis=1)
intersecting = intersecting.reset_index(drop=True)
iou = []
for (
index1,
@@ -313,15 +363,21 @@ def clean_crowns(crowns: gpd.GeoDataFrame, iou_threshold=0.7):
# print(iou)
intersecting["iou"] = iou
matches = intersecting[intersecting["iou"] > iou_threshold] # Remove those crowns with a poor match
matches = matches.sort_values("Confidence_score", ascending=False).reset_index().drop("index", axis=1)
matches = matches.sort_values("Confidence_score", ascending=False).reset_index(drop=True)
match = matches.loc[[0]] # Of the remaining crowns select the crown with the highest confidence
if match["iou"][0] < 1: # If the most confident is not the initial crown
continue
else:
match = match.drop("iou", axis=1)
# print(index)
crowns_out = crowns_out.append(match)
return crowns_out.reset_index()
# Convert pandas back into geopandas if it is not already
if not isinstance(crowns_out, gpd.GeoDataFrame):
crowns_out = gpd.GeoDataFrame(crowns_out)
# Filter remaining crowns based on confidence score
if confidence != 0:
crowns_out = crowns_out[crowns_out["Confidence_score"] > confidence]
return crowns_out.reset_index(drop=True)


def clean_predictions(directory, iou_threshold=0.7):
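As a reading aid for the `clean_crowns` change above, the selection logic can be sketched without GeoPandas. This is a minimal sketch under stated assumptions, not the detectree2 implementation: crowns are reduced to axis-aligned rectangles instead of polygons, and `Crown`, `box_iou`, and `clean_boxes` are hypothetical names introduced only for illustration. A crown is kept only when it is the most confident among the crowns it overlaps above `iou_threshold`, after which the confidence cut is applied.

```python
from typing import List, NamedTuple


class Crown(NamedTuple):
    """Hypothetical stand-in for one row of the crowns GeoDataFrame."""
    minx: float
    miny: float
    maxx: float
    maxy: float
    confidence: float


def box_iou(a: Crown, b: Crown) -> float:
    """Intersection over union of two axis-aligned rectangles."""
    inter_w = min(a.maxx, b.maxx) - max(a.minx, b.minx)
    inter_h = min(a.maxy, b.maxy) - max(a.miny, b.miny)
    if inter_w <= 0 or inter_h <= 0:
        return 0.0  # no overlap (or a degenerate rectangle)
    inter = inter_w * inter_h
    area_a = (a.maxx - a.minx) * (a.maxy - a.miny)
    area_b = (b.maxx - b.minx) * (b.maxy - b.miny)
    return inter / (area_a + area_b - inter)


def clean_boxes(crowns: List[Crown], iou_threshold: float = 0.7,
                confidence: float = 0.2) -> List[Crown]:
    """Drop strongly overlapping crowns, keeping the most confident one."""
    kept = []
    for crown in crowns:
        # Crowns that overlap this one strongly; a valid crown matches itself (IoU == 1)
        matches = [c for c in crowns if box_iou(crown, c) > iou_threshold]
        if not matches:
            continue  # degenerate rectangle with zero area
        best = max(matches, key=lambda c: c.confidence)
        # If the most confident match is a different shape, skip this crown
        if box_iou(crown, best) < 1.0:
            continue
        kept.append(crown)
    # Assumption mirrored from the diff: a confidence of 0 disables the filter
    if confidence != 0:
        kept = [c for c in kept if c.confidence > confidence]
    return kept
```

With the defaults, two rectangles that overlap by more than 70% collapse to the more confident one, and anything scoring 0.2 or below is removed afterwards. Passing `confidence=0` keeps every crown that survives the overlap step.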
7 changes: 5 additions & 2 deletions detectree2/models/predict.py
@@ -1,4 +1,7 @@
"""Generate predictions."""
"""Generate predictions.
This module contains the code to generate predictions on tiled data.
"""
import json
import os
import random
@@ -63,4 +66,4 @@ def predict_on_data(


if __name__ == "__main__":
print("something")
predict_on_data()
42 changes: 25 additions & 17 deletions detectree2/models/train.py
@@ -53,9 +53,9 @@ class LossEvalHook(HookBase):
See https://gist.github.com/ortegatron/c0dad15e49c2b74de8bb09a5615d9f6b
Attributes:
model:
period:
data_loader:
model: model to train
period: number of iterations between evaluations
data_loader: data loader to use for evaluation
patience: number of evaluation periods to wait for improvement
"""

@@ -295,7 +295,8 @@ def get_tree_dicts(directory: str, classes: List[str] = None, classes_at: str =
Args:
directory: Path to directory
classes: Signifies which column (if any) corresponds to the class labels
classes: List of classes to include
classes_at: Signifies which column (if any) corresponds to the class labels
Returns:
List of dictionaries corresponding to segmentations of trees. Each dictionary includes
@@ -471,18 +472,25 @@ def read_data(out_dir):


def remove_registered_data(name="tree"):
"""Remove registered data from catalog.
Args:
name: string of named registered data
"""
for d in ["train", "val"]:
DatasetCatalog.remove(name + "_" + d)
MetadataCatalog.remove(name + "_" + d)


def register_test_data(test_location, name="tree"):
"""Register data for testing."""
d = "test"
DatasetCatalog.register(name + "_" + d, lambda d=d: get_tree_dicts(test_location))
MetadataCatalog.get(name + "_" + d).set(thing_classes=["tree"])


def load_json_arr(json_path):
"""Load json array."""
lines = []
with open(json_path, "r") as f:
for line in f:
@@ -517,19 +525,19 @@ def setup_cfg(
trains: names of registered data to use for training
tests: names of registered data to use for evaluating models
update_model: updated pre-trained model from detectree2 model_garden
workers:
ims_per_batch:
gamma:
backbone_freeze:
warm_iter:
momentum:
batch_size_per_im:
base_lr:
weight_decay
max_iter:
num_classes:
eval_period:
out_dir:
workers: number of workers for dataloader
ims_per_batch: number of images per batch
gamma: gamma for learning rate scheduler
backbone_freeze: backbone layer to freeze
warm_iter: number of iterations for warmup
momentum: momentum for optimizer
batch_size_per_im: batch size per image
base_lr: base learning rate
weight_decay: weight decay for optimizer
max_iter: maximum number of iterations
num_classes: number of classes
eval_period: number of iterations between evaluations
out_dir: directory to save outputs
"""
cfg = get_cfg()
cfg.merge_from_file(model_zoo.get_config_file(base_model))
19 changes: 14 additions & 5 deletions detectree2/preprocessing/tiling.py
@@ -423,7 +423,16 @@ def is_overlapping_box(test_boxes_array, train_box):
def record_data(crowns,
out_dir,
column='status'):
"""Function that will record a list of classes into a file that can be read during training."""
"""Function that will record a list of classes into a file that can be read during training.
Args:
crowns: gpd dataframe with the crowns
out_dir: directory to save the file
column: column name to get the classes from
Returns:
None
"""

list_of_classes = crowns[column].unique().tolist()

@@ -447,10 +456,10 @@ def to_traintest_folders(tiles_folder: str = "./",
"""Send tiles to training (+validation) and test dir and automatically make sure no overlap.
Args:
tiles_folder:
out_folder:
test_frac:
folds:
tiles_folder: folder with tiles
out_folder: folder to save train and test folders
test_frac: fraction of tiles to be used for testing
folds: number of folds to split the data into
Returns:
None
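To make the `test_frac` and `folds` arguments documented above concrete, here is a minimal sketch of one way such a split could work. It is an illustration under stated assumptions rather than the body of `to_traintest_folders`: `split_tiles` and its `seed` parameter are hypothetical, the fold naming is invented, and the sketch neither copies files nor checks that test tiles do not overlap training tiles.

```python
import random
from typing import Dict, List


def split_tiles(tile_names: List[str], test_frac: float = 0.2,
                folds: int = 1, seed: int = 0) -> Dict[str, List[str]]:
    """Assign tile names to a test set and to `folds` training folds."""
    if folds < 1:
        raise ValueError("folds must be at least 1")
    rng = random.Random(seed)  # seeded so the split is reproducible
    shuffled = list(tile_names)
    rng.shuffle(shuffled)
    # The first test_frac of the shuffled tiles is held out for testing
    n_test = int(len(shuffled) * test_frac)
    test = shuffled[:n_test]
    train = shuffled[n_test:]
    # Remaining tiles are dealt round-robin into the folds
    split = {"test": test}
    for i in range(folds):
        split[f"fold_{i + 1}"] = train[i::folds]
    return split
```

For ten tiles with `test_frac=0.2` and `folds=4`, two tiles are held out for testing and the remaining eight are spread evenly across four folds.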
21 changes: 20 additions & 1 deletion docs/source/tutorial.rst
@@ -301,10 +301,29 @@ predictions with lower confidence).

.. code-block:: python
crowns = stitch_crowns(tiles_path + "predictions_geo/", 1)
crowns = crowns[crowns.is_valid]
crowns = clean_crowns(crowns, 0.6, confidence=0)
By default the ``clean_crowns`` function will remove crowns with a confidence of less than 20%. The 'clean' crowns above
include crowns of all confidence scores (0%-100%) because ``confidence=0``. Crowns with very low confidence are likely
to be of poor quality, so it is usually preferable to filter them out. A suitable threshold can be determined
by eye in QGIS or implemented as a single line in Python. ``Confidence_score`` is a column in the ``crowns`` GeoDataFrame
and is considered a tunable parameter.

.. code-block:: python
crowns = crowns[crowns["Confidence_score"] > 0.5]
The output crown polygons will have many vertices because they are generated from a pixelwise mask. If you
need to edit the crowns in QGIS, it is best to simplify them to a reasonable number of vertices. This can be done
with the ``simplify`` method. The ``tolerance`` determines the coarseness of the simplification; it has the same units as
the coordinate reference system of the GeoSeries (meters when working with UTM).

.. code-block:: python
crowns = crowns.set_geometry(crowns.simplify(0.3))
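To give a feel for what a tolerance does, the following is a minimal pure-Python sketch of Douglas-Peucker style simplification on an open line. It is an illustration only, not the GeoPandas or Shapely implementation, and `point_segment_distance` and `simplify_line` are hypothetical helper names. A vertex is dropped when it lies within ``tolerance`` of the straight segment that would replace it.

```python
import math
from typing import List, Tuple

Point = Tuple[float, float]


def point_segment_distance(p: Point, a: Point, b: Point) -> float:
    """Shortest distance from point p to the segment a-b."""
    ax, ay = a
    bx, by = b
    px, py = p
    dx, dy = bx - ax, by - ay
    if dx == 0 and dy == 0:
        return math.hypot(px - ax, py - ay)  # segment is a single point
    # Projection of p onto the segment, clamped to its end points
    t = ((px - ax) * dx + (py - ay) * dy) / (dx * dx + dy * dy)
    t = max(0.0, min(1.0, t))
    return math.hypot(px - (ax + t * dx), py - (ay + t * dy))


def simplify_line(points: List[Point], tolerance: float) -> List[Point]:
    """Recursively drop vertices closer than `tolerance` to the chord."""
    if len(points) < 3:
        return list(points)
    start, end = points[0], points[-1]
    distances = [point_segment_distance(p, start, end) for p in points[1:-1]]
    worst = max(range(len(distances)), key=distances.__getitem__)
    if distances[worst] <= tolerance:
        return [start, end]  # every interior vertex is within tolerance
    split = worst + 1  # index of the farthest vertex in `points`
    left = simplify_line(points[: split + 1], tolerance)
    right = simplify_line(points[split:], tolerance)
    return left[:-1] + right  # avoid duplicating the shared vertex
```

A larger tolerance removes more vertices, so with coordinates in meters a tolerance of 0.3 discards detail finer than roughly 30 cm.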
Once we're happy with the crown map, save the crowns to file.

.. code-block:: python