 import shapely
 from PIL import Image
 from rslearn.const import WGS84_PROJECTION
-from rslearn.data_sources import data_source_from_config
+from rslearn.data_sources import Item, data_source_from_config
 from rslearn.data_sources.aws_landsat import LandsatOliTirs
-from rslearn.dataset import Dataset, Window
+from rslearn.dataset import Dataset, Window, WindowLayerData
 from rslearn.utils import Projection, STGeometry
 from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
 from typing_extensions import TypedDict
 from upath import UPath

 from rslp.utils.rslearn import materialize_dataset, run_model_predict

+LANDSAT_LAYER_NAME = "landsat"
 LOCAL_FILES_DATASET_CONFIG = "data/landsat_vessels/predict_dataset_config.json"
 AWS_DATASET_CONFIG = "data/landsat_vessels/predict_dataset_config_aws.json"
 DETECT_MODEL_CONFIG = "data/landsat_vessels/config.yaml"
@@ -72,6 +73,7 @@ def get_vessel_detections(
     projection: Projection,
     bounds: tuple[int, int, int, int],
     time_range: tuple[datetime, datetime] | None = None,
+    item: Item | None = None,
 ) -> list[VesselDetection]:
     """Apply the vessel detector.

@@ -85,22 +87,30 @@ def get_vessel_detections(
         bounds: the bounds to apply the detector in.
         time_range: optional time range to apply the detector in (in case the data
             source needs an actual time range).
+        item: only ingest this item. This is set if we are getting the scene directly
+            from a Landsat data source rather than from local files.
     """
     # Create a window for applying detector.
     group = "default"
     window_path = ds_path / "windows" / group / "default"
-    Window(
+    window = Window(
         path=window_path,
         group=group,
         name="default",
         projection=projection,
         bounds=bounds,
         time_range=time_range,
-    ).save()
+    )
+    window.save()
+
+    # Restrict to the item if set.
+    if item:
+        layer_data = WindowLayerData(LANDSAT_LAYER_NAME, [[item.serialize()]])
+        window.save_layer_datas({LANDSAT_LAYER_NAME: layer_data})

     print("materialize dataset")
     materialize_dataset(ds_path, group=group)
-    assert (window_path / "layers" / "landsat" / "B8" / "geotiff.tif").exists()
+    assert (window_path / "layers" / LANDSAT_LAYER_NAME / "B8" / "geotiff.tif").exists()

     # Run object detector.
     run_model_predict(DETECT_MODEL_CONFIG, ds_path)
@@ -131,6 +141,7 @@ def run_classifier(
     ds_path: UPath,
     detections: list[VesselDetection],
     time_range: tuple[datetime, datetime] | None = None,
+    item: Item | None = None,
 ) -> list[VesselDetection]:
     """Run the classifier to try to prune false positive detections.

@@ -140,6 +151,7 @@ def run_classifier(
         detections: the detections from the detector.
         time_range: optional time range to apply the detector in (in case the data
             source needs an actual time range).
+        item: only ingest this item.

     Returns:
         the subset of detections that pass the classifier.
@@ -161,20 +173,27 @@ def run_classifier(
             detection.col + CLASSIFY_WINDOW_SIZE // 2,
             detection.row + CLASSIFY_WINDOW_SIZE // 2,
         ]
-        Window(
+        window = Window(
             path=window_path,
             group=group,
             name=window_name,
             projection=detection.projection,
             bounds=bounds,
             time_range=time_range,
-        ).save()
+        )
+        window.save()
         window_paths.append(window_path)

+        if item:
+            layer_data = WindowLayerData(LANDSAT_LAYER_NAME, [[item.serialize()]])
+            window.save_layer_datas({LANDSAT_LAYER_NAME: layer_data})
+
     print("materialize dataset")
     materialize_dataset(ds_path, group=group)
     for window_path in window_paths:
-        assert (window_path / "layers" / "landsat" / "B8" / "geotiff.tif").exists()
+        assert (
+            window_path / "layers" / LANDSAT_LAYER_NAME / "B8" / "geotiff.tif"
+        ).exists()

     # Run classification model.
     run_model_predict(CLASSIFY_MODEL_CONFIG, ds_path)
@@ -225,6 +244,7 @@ def predict_pipeline(

     ds_path = UPath(scratch_path)
     ds_path.mkdir(parents=True, exist_ok=True)
+    item = None

     if image_files:
         # Setup the dataset configuration file with the provided image files.
@@ -238,7 +258,7 @@ def predict_pipeline(
             cfg["src_dir"] = str(UPath(image_path).parent)
             item_spec["fnames"].append(image_path)
             item_spec["bands"].append([band])
-        cfg["layers"]["landsat"]["data_source"]["item_specs"] = [item_spec]
+        cfg["layers"][LANDSAT_LAYER_NAME]["data_source"]["item_specs"] = [item_spec]

         with (ds_path / "config.json").open("w") as f:
             json.dump(cfg, f)
@@ -251,7 +271,12 @@ def predict_pipeline(
         )
         left = int(raster.transform.c / projection.x_resolution)
         top = int(raster.transform.f / projection.y_resolution)
-        scene_bounds = [left, top, left + raster.width, top + raster.height]
+        scene_bounds = (
+            left,
+            top,
+            left + int(raster.width),
+            top + int(raster.height),
+        )

         time_range = None

@@ -264,7 +289,7 @@ def predict_pipeline(
         # Get the projection and scene bounds using the Landsat data source.
         dataset = Dataset(ds_path)
         data_source: LandsatOliTirs = data_source_from_config(
-            dataset.layers["landsat"], dataset.path
+            dataset.layers[LANDSAT_LAYER_NAME], dataset.path
         )
         item = data_source.get_item_by_name(scene_id)
         wgs84_geom = item.geometry.to_projection(WGS84_PROJECTION)
@@ -275,7 +300,12 @@ def predict_pipeline(
             -LANDSAT_RESOLUTION,
         )
         dst_geom = item.geometry.to_projection(projection)
-        scene_bounds = [int(value) for value in dst_geom.shp.bounds]
+        scene_bounds = (
+            int(dst_geom.shp.bounds[0]),
+            int(dst_geom.shp.bounds[1]),
+            int(dst_geom.shp.bounds[2]),
+            int(dst_geom.shp.bounds[3]),
+        )
         time_range = (
             dst_geom.time_range[0] - timedelta(minutes=30),
             dst_geom.time_range[1] + timedelta(minutes=30),
@@ -289,14 +319,15 @@ def predict_pipeline(
     detections = get_vessel_detections(
         ds_path,
         projection,
-        scene_bounds,  # type: ignore
+        scene_bounds,
         time_range=time_range,
+        item=item,
     )
     time_profile["get_vessel_detections"] = time.time() - step_start_time

     step_start_time = time.time()
     print("run classifier")
-    detections = run_classifier(ds_path, detections, time_range=time_range)
+    detections = run_classifier(ds_path, detections, time_range=time_range, item=item)
     time_profile["run_classifier"] = time.time() - step_start_time

     # Write JSON and crops.
@@ -313,7 +344,11 @@ def predict_pipeline(
             raise ValueError("Crop window directory is None")
         for band in ["B2", "B3", "B4", "B8"]:
             image_fname = (
-                detection.crop_window_dir / "layers" / "landsat" / band / "geotiff.tif"
+                detection.crop_window_dir
+                / "layers"
+                / LANDSAT_LAYER_NAME
+                / band
+                / "geotiff.tif"
             )
             with image_fname.open("rb") as f:
                 with rasterio.open(f) as src:
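
For orientation, here is a minimal, hypothetical sketch of how the new item argument could be exercised when starting from a scene ID rather than local GeoTIFFs. It assumes the module-level names from this file (LANDSAT_LAYER_NAME, LANDSAT_RESOLUTION, get_vessel_detections, run_classifier) are in scope, that the dataset config has already been copied to the scratch path, and that get_utm_ups_projection takes (lon, lat, x_resolution, y_resolution) arguments, consistent with the -LANDSAT_RESOLUTION argument visible above; the scratch path and scene ID are placeholders, not values from this change.

from datetime import timedelta

from rslearn.const import WGS84_PROJECTION
from rslearn.data_sources import data_source_from_config
from rslearn.dataset import Dataset
from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
from upath import UPath

# Placeholder scratch dataset path and Landsat scene ID (illustrative only).
ds_path = UPath("/tmp/landsat_vessels_scratch")  # must already contain config.json
scene_id = "LC08_L1TP_123032_20240101_20240101_02_T1"

# Look up the Landsat item by name, as predict_pipeline does above.
dataset = Dataset(ds_path)
data_source = data_source_from_config(dataset.layers[LANDSAT_LAYER_NAME], dataset.path)
item = data_source.get_item_by_name(scene_id)

# Derive a UTM/UPS projection and integer scene bounds from the item geometry.
wgs84_geom = item.geometry.to_projection(WGS84_PROJECTION)
projection = get_utm_ups_projection(
    wgs84_geom.shp.centroid.x,  # assumed (lon, lat, x_res, y_res) argument order
    wgs84_geom.shp.centroid.y,
    LANDSAT_RESOLUTION,
    -LANDSAT_RESOLUTION,
)
dst_geom = item.geometry.to_projection(projection)
scene_bounds = tuple(int(value) for value in dst_geom.shp.bounds)
time_range = (
    dst_geom.time_range[0] - timedelta(minutes=30),
    dst_geom.time_range[1] + timedelta(minutes=30),
)

# Passing item= restricts ingestion and materialization to this one scene.
detections = get_vessel_detections(
    ds_path, projection, scene_bounds, time_range=time_range, item=item
)
detections = run_classifier(ds_path, detections, time_range=time_range, item=item)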