-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathevaluate.py
240 lines (190 loc) · 7.38 KB
/
evaluate.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
import argparse
import logging
import os
from typing import Optional, Union

import fiona
import matplotlib.pyplot as plt
import rasterio
import rasterio.features
import rasterio.warp
import rasterio.windows
import shapely
import torch
from rasterio.enums import Resampling
from shapely.geometry import box
from torchmetrics import (
    Accuracy,
    Dice,
    F1Score,
    JaccardIndex,
    MetricCollection,
    Precision,
    Recall,
)
from tqdm import tqdm

from tcd_pipeline.data.tiling import generate_tiles
# Configure root logging at INFO so progress and warnings from this script
# are visible, then obtain this module's named logger.
logging.basicConfig(level="INFO")
logger = logging.getLogger(__name__)
def get_overlap(ras1, ras2):
    """Find where two rasters overlap and how their pixel sizes compare.

    Args:
        ras1: Reference raster (opened rasterio dataset). The returned window
            is expressed in this raster's pixel grid.
        ras2: Second raster (opened rasterio dataset).

    Returns:
        Tuple ``(window, transform, x_scale, y_scale)``:
        ``window`` is the intersection of both rasters' bounds as a window on
        ``ras1``; ``transform`` is ``ras2``'s affine transform; ``x_scale`` and
        ``y_scale`` are ``ras2``'s pixel width/height divided by ``ras1``'s,
        i.e. how many ``ras1`` pixels one ``ras2`` pixel spans on each axis.

    Raises:
        ValueError: If the rasters' bounding boxes do not overlap.
    """
    if ras1.crs != ras2.crs:
        # Bounds are compared numerically below, which is only meaningful
        # when both rasters share a CRS.
        logger.warning("Raster CRSs differ (%s vs %s); overlap may be wrong", ras1.crs, ras2.crs)

    x_scale = ras2.transform.a / ras1.transform.a
    y_scale = ras2.transform.e / ras1.transform.e

    # A rasterio dataset's ``shape`` is (height, width), so rescaling the
    # transform by width/shape[-1] and height/shape[-2] is the identity;
    # ras2's transform is used as-is.
    transform = ras2.transform

    intersection = box(*ras1.bounds).intersection(box(*ras2.bounds))
    if intersection.is_empty or intersection.area == 0:
        raise ValueError("Raster extents do not overlap")

    window = rasterio.windows.from_bounds(*intersection.bounds, ras1.transform)
    return window, transform, x_scale, y_scale
def sample_raster(src, bounds, x_scale, y_scale, transform):
    """Read ``src`` over ``bounds`` and resample it by the given scale factors.

    Args:
        src: Opened rasterio dataset to read from.
        bounds: ``(left, bottom, right, top)`` of the region to read.
        x_scale: Number of output pixels per source pixel along x.
        y_scale: Number of output pixels per source pixel along y.
        transform: Affine transform used to turn ``bounds`` into a window
            on ``src``.

    Returns:
        Array of shape ``(src.count, rows, cols)`` read with bilinear
        resampling, where ``rows``/``cols`` are the window size multiplied by
        ``y_scale``/``x_scale``.
    """
    window = rasterio.windows.from_bounds(*bounds, transform)
    # Window sizes from from_bounds are floats and can carry rounding error
    # (e.g. 99.9999 for 100). Truncating with int() would drop a row/column
    # and the result would no longer line up with the caller's tile.
    out_rows = int(round(window.height * y_scale))
    out_cols = int(round(window.width * x_scale))
    return src.read(
        window=window,
        out_shape=(src.count, out_rows, out_cols),
        resampling=Resampling.bilinear,
    )
def maybe_warp_geometry(
    image,
    shape: Union[dict, shapely.geometry.Polygon],
    crs: Optional[rasterio.crs.CRS] = None,
):
    """Turn a region-of-interest geometry into a polygon in the image's CRS.

    The region should be a simple Polygon, e.g. a convex hull that defines
    the area of interest for analysis.

    Args:
        image: Opened rasterio dataset; its ``crs`` is the target CRS.
        shape: GeoJSON-like mapping or shapely geometry describing the region.
        crs: CRS of ``shape``. When given and different from ``image.crs``,
            the geometry is reprojected into ``image.crs``.

    Returns:
        The region as a shapely Polygon passed through
        ``shapely.geometry.polygon.orient``, or ``None`` if the geometry is
        not a Polygon/MultiPolygon (a warning is logged in that case).
        For a MultiPolygon only the first part is returned.
    """
    if crs is not None and crs != image.crs:
        logger.warning("Geometry CRS is not the same as the image CRS, warping")
        if isinstance(shape, shapely.geometry.base.BaseGeometry):
            # transform_geom operates on GeoJSON-like mappings.
            shape = shapely.geometry.mapping(shape)
        shape = rasterio.warp.transform_geom(crs, image.crs, shape)

    if not isinstance(shape, shapely.geometry.base.BaseGeometry):
        shape = shapely.geometry.shape(shape)

    if isinstance(shape, shapely.geometry.MultiPolygon):
        if len(shape.geoms) > 1:
            logger.warning(
                "Region is a MultiPolygon with %d parts; only the first is used",
                len(shape.geoms),
            )
        shape = shape.geoms[0]

    if not isinstance(shape, shapely.geometry.Polygon):
        logger.warning("Input shape is not a polygon, not applying filter")
        return None

    return shapely.geometry.polygon.orient(shape)
def load_shape_from_geofile(geometry_path):
    """Load the geometry of the first feature in a vector file.

    Args:
        geometry_path: Path to a GeoJSON or ESRI Shapefile.

    Returns:
        Tuple ``(geometry, crs)``: the ``"geometry"`` entry of the first
        feature and the collection's ``crs`` as reported by fiona.

    Raises:
        ValueError: If the file contains no features.
    """
    with fiona.open(
        geometry_path, "r", enabled_drivers=["GeoJSON", "ESRI Shapefile"]
    ) as geo:
        geom_crs = geo.crs
        features = iter(geo)
        first = next(features, None)
        if first is None:
            raise ValueError(f"No features found in {geometry_path}")
        # Only one region is supported; make it visible when others are ignored.
        if next(features, None) is not None:
            logger.warning(
                "%s contains more than one feature; only the first is used",
                geometry_path,
            )
        geometry = first["geometry"]
    return geometry, geom_crs
def _semantic_metrics():
    """Build the binary pixel-wise metric collection used for semantic evaluation."""
    return MetricCollection(
        {
            "accuracy": Accuracy(task="binary"),
            "iou": JaccardIndex(task="binary"),
            "f1": F1Score(task="binary"),
            "precision": Precision(task="binary"),
            "recall": Recall(task="binary"),
            "dice": Dice(),
        }
    )


def _tile_window(tile, overlap_window):
    """Offset a tile's pixel bounds (relative to the overlap) into a window on the prediction."""
    minx, miny, maxx, maxy = tile.bounds
    return rasterio.windows.Window(
        minx + overlap_window.col_off,
        miny + overlap_window.row_off,
        width=int(maxx - minx),
        height=int(maxy - miny),
    )


def evaluate_semantic(args, threshold=1):
    """Evaluate a semantic segmentation prediction against a raster ground truth.

    The prediction raster is processed tile by tile to bound memory use. For
    each tile the matching area of the ground truth is resampled onto the
    prediction's pixel grid. Pixels outside the region of interest are masked
    out, and the rest update a set of binary metrics. The final metrics are
    written to ``args.result`` as a JSON object.

    Args:
        args: Parsed CLI namespace with ``prediction`` and ``ground_truth``
            (raster paths), ``result`` (output JSON path), ``tile_size`` (tile
            edge length in prediction pixels) and optionally ``geometry``
            (vector file restricting the evaluated region). When no geometry
            is given, or it is not a polygon, the whole raster overlap is
            evaluated.
        threshold: Ground-truth pixel values strictly greater than this are
            treated as positive.

    Raises:
        ValueError: If a resampled ground-truth tile does not match the shape
            of the corresponding prediction tile.
    """
    import json

    metrics = _semantic_metrics()
    # argparse may hand this over as a string.
    tile_size = int(args.tile_size)
    geometry_path = getattr(args, "geometry", None)

    with rasterio.open(args.prediction) as pred, rasterio.open(args.ground_truth) as gt:
        # Window on the prediction covering the area shared with the ground truth.
        overlap_window, transform, x_scale, y_scale = get_overlap(pred, gt)

        valid_region = None
        if geometry_path is not None:
            geometry, geometry_crs = load_shape_from_geofile(geometry_path)
            valid_region = maybe_warp_geometry(pred, geometry, geometry_crs)
        if valid_region is None:
            # No usable region of interest: evaluate the whole overlap.
            valid_region = box(*rasterio.windows.bounds(overlap_window, pred.transform))

        tiles = generate_tiles(overlap_window.height, overlap_window.width, tile_size)
        # Iterating the tqdm wrapper advances the bar; no manual update() needed.
        pbar = tqdm(tiles, total=len(tiles))
        for tile in pbar:
            window = _tile_window(tile, overlap_window)
            tile_bounds = rasterio.windows.bounds(window, pred.transform)

            # Check the region of interest before reading any pixels.
            intersection = valid_region.intersection(box(*tile_bounds))
            if shapely.is_empty(intersection):
                continue

            # invert=True marks pixels inside the geometry as True, giving a
            # mask of valid pixels. out_shape is numpy (height, width) order.
            mask = rasterio.features.geometry_mask(
                [intersection],
                out_shape=(window.height, window.width),
                transform=rasterio.windows.transform(window, pred.transform),
                invert=True,
            )
            if not mask.any():
                continue

            # First prediction band divided by 255 (assumes 8-bit confidence
            # values -- NOTE(review): confirm against the prediction writer).
            pred_data = pred.read(1, window=window) / 255.0
            # Ground truth resampled onto the prediction's pixel grid.
            gt_data = sample_raster(gt, tile_bounds, x_scale, y_scale, transform)[0]
            if gt_data.shape != pred_data.shape:
                raise ValueError(
                    f"Ground truth tile shape {gt_data.shape} does not match "
                    f"prediction tile shape {pred_data.shape}"
                )

            res = metrics(
                torch.from_numpy(pred_data[mask]),
                torch.from_numpy(gt_data[mask] > threshold),
            )
            pbar.set_postfix({name: f"{float(value):.3f}" for name, value in res.items()})

    with open(args.result, "w") as fp:
        json.dump({k: float(v) for k, v in metrics.compute().items()}, fp)
def evaluate_instance(args):
    """Evaluate COCO-format instance segmentation predictions.

    Runs pycocotools' segmentation evaluation. ``summarize()`` prints the
    standard COCO table, and the twelve summary statistics are written to
    ``args.result`` as a JSON object.

    Args:
        args: Parsed CLI namespace with ``ground_truth`` (COCO annotation
            file), ``prediction`` (results file accepted by
            ``COCO.loadRes``) and ``result`` (output path).
    """
    import json

    from pycocotools import cocoeval
    from pycocotools.coco import COCO

    # Order in which COCOeval.summarize() fills ``stats`` for segm/bbox.
    stat_names = (
        "AP",
        "AP50",
        "AP75",
        "AP_small",
        "AP_medium",
        "AP_large",
        "AR_1",
        "AR_10",
        "AR_100",
        "AR_small",
        "AR_medium",
        "AR_large",
    )

    gt = COCO(args.ground_truth)
    pred = gt.loadRes(args.prediction)

    coco_eval = cocoeval.COCOeval(gt, pred, iouType="segm")
    coco_eval.evaluate()
    coco_eval.accumulate()
    coco_eval.summarize()

    # file.write() only accepts str, so serialise the numeric stats
    # (converted to plain floats) rather than the evaluator object.
    with open(args.result, "w") as fp:
        json.dump(
            {name: float(value) for name, value in zip(stat_names, coco_eval.stats)},
            fp,
        )
def main(args):
    """Run the evaluator selected by ``args.task``.

    ``"instance"`` runs ``evaluate_instance`` and ``"semantic"`` runs
    ``evaluate_semantic``; any other task value does nothing.
    """
    task = args.task
    if task == "instance":
        evaluate_instance(args)
    elif task == "semantic":
        evaluate_semantic(args)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Evaluate a prediction against ground truth."
    )
    parser.add_argument(
        "task",
        help="Evaluation type: 'semantic' (GeoTIFF rasters) or 'instance' (COCO JSON)",
        choices=["semantic", "instance"],
    )
    parser.add_argument("prediction", help="Prediction result (GeoTIFF or COCO JSON)")
    parser.add_argument("ground_truth", help="Ground truth (GeoTIFF or COCO JSON)")
    parser.add_argument("result", help="Output metric file")
    parser.add_argument(
        "-g",
        "--geometry",
        help="Region of interest to perform analysis on (semantic task only)",
    )
    parser.add_argument(
        "-s",
        "--tile_size",
        type=int,  # without this a user-supplied value arrives as a string
        default=10000,
        help="Tile size used for metric evaluation - RAM limited",
    )
    args = parser.parse_args()
    main(args)