"""Landsat vessel prediction pipeline."""

import json
+import tempfile
+import time
from datetime import datetime, timedelta

import numpy as np
from rslearn.dataset import Dataset, Window
from rslearn.utils import Projection, STGeometry
from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
+from typing_extensions import TypedDict
from upath import UPath

from rslp.utils.rslearn import materialize_dataset, run_model_predict

@@ -54,6 +57,16 @@ def __init__(
        self.crop_window_dir = crop_window_dir


+class FormattedPrediction(TypedDict):
+    """Formatted prediction for a single vessel detection."""
+
+    latitude: float
+    longitude: float
+    score: float
+    rgb_fname: str
+    b8_fname: str
+
+
def get_vessel_detections(
    ds_path: UPath,
    projection: Projection,

@@ -180,12 +193,12 @@ def run_classifier(


def predict_pipeline(
-    scratch_path: str,
-    json_path: str,
-    crop_path: str,
+    crop_path: str | None = None,
+    scratch_path: str | None = None,
+    json_path: str | None = None,
    image_files: dict[str, str] | None = None,
    scene_id: str | None = None,
-) -> None:
+) -> list[FormattedPrediction]:
    """Run the Landsat vessel prediction pipeline.

    This inputs a Landsat scene (consisting of per-band GeoTIFFs) and produces the

@@ -201,6 +214,15 @@ def predict_pipeline(
        scene_id: Landsat scene ID. Exactly one of image_files or scene_id should be
            specified.
    """
+    start_time = time.time()  # Start the timer
+    time_profile = {}
+
+    if scratch_path is None:
+        tmp_dir = tempfile.TemporaryDirectory()
+        scratch_path = tmp_dir.name
+    else:
+        tmp_dir = None
+
    ds_path = UPath(scratch_path)
    ds_path.mkdir(parents=True, exist_ok=True)

@@ -259,18 +281,29 @@ def predict_pipeline(
        dst_geom.time_range[1] + timedelta(minutes=30),
    )

+    time_profile["setup"] = time.time() - start_time
+
    # Run pipeline.
+    step_start_time = time.time()
+    print("run detector")
    detections = get_vessel_detections(
        ds_path,
        projection,
        scene_bounds,  # type: ignore
        time_range=time_range,
    )
+    time_profile["get_vessel_detections"] = time.time() - step_start_time
+
+    step_start_time = time.time()
+    print("run classifier")
    detections = run_classifier(ds_path, detections, time_range=time_range)
+    time_profile["run_classifier"] = time.time() - step_start_time

    # Write JSON and crops.
-    json_upath = UPath(json_path)
-    crop_upath = UPath(crop_path)
+    step_start_time = time.time()
+    if crop_path:
+        crop_upath = UPath(crop_path)
+        crop_upath.mkdir(parents=True, exist_ok=True)

    json_data = []
    for idx, detection in enumerate(detections):

@@ -304,13 +337,17 @@ def predict_pipeline(
            [images["B4_sharp"], images["B3_sharp"], images["B2_sharp"]], axis=2
        )

-        rgb_fname = crop_upath / f"{idx}_rgb.png"
-        with rgb_fname.open("wb") as f:
-            Image.fromarray(rgb).save(f, format="PNG")
+        if crop_path:
+            rgb_fname = crop_upath / f"{idx}_rgb.png"
+            with rgb_fname.open("wb") as f:
+                Image.fromarray(rgb).save(f, format="PNG")

-        b8_fname = crop_upath / f"{idx}_b8.png"
-        with b8_fname.open("wb") as f:
-            Image.fromarray(images["B8"]).save(f, format="PNG")
+            b8_fname = crop_upath / f"{idx}_b8.png"
+            with b8_fname.open("wb") as f:
+                Image.fromarray(images["B8"]).save(f, format="PNG")
+        else:
+            rgb_fname = ""
+            b8_fname = ""

        # Get longitude/latitude.
        src_geom = STGeometry(

@@ -321,14 +358,31 @@ def predict_pipeline(
        lat = dst_geom.shp.y

        json_data.append(
-            dict(
+            FormattedPrediction(
                longitude=lon,
                latitude=lat,
                score=detection.score,
-                rgb_fname=str(rgb_fname),
-                b8_fname=str(b8_fname),
-            )
+                rgb_fname=rgb_fname,
+                b8_fname=b8_fname,
+            ),
        )

-    with json_upath.open("w") as f:
-        json.dump(json_data, f)
+    time_profile["write_json_and_crops"] = time.time() - step_start_time
+
+    elapsed_time = time.time() - start_time  # Calculate elapsed time
+    time_profile["total"] = elapsed_time
+
+    # Clean up any temporary directories.
+    if tmp_dir:
+        tmp_dir.cleanup()
+
+    if json_path:
+        json_upath = UPath(json_path)
+        with json_upath.open("w") as f:
+            json.dump(json_data, f)
+
+    print(f"Prediction pipeline completed in {elapsed_time:.2f} seconds")
+    for step, duration in time_profile.items():
+        print(f"{step} took {duration:.2f} seconds")
+
+    return json_data
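
For reference, a minimal sketch of how the updated entry point could be called after this change; the import path and scene ID below are placeholders assumed for illustration, not taken from the diff:

# Hypothetical usage sketch: module path and scene ID are assumed placeholders.
from rslp.landsat_vessels.predict_pipeline import predict_pipeline

# scratch_path, json_path, and crop_path are now optional: with none supplied,
# the pipeline runs in a temporary scratch directory, skips writing crops and
# JSON, and simply returns the formatted predictions.
predictions = predict_pipeline(scene_id="LC08_...")  # placeholder scene ID

for pred in predictions:
    # Each item is a FormattedPrediction TypedDict.
    print(pred["latitude"], pred["longitude"], pred["score"])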