Skip to content

Commit c00b1ae

Browse files
authored
Merge pull request #35 from allenai/favyen/predict-only-ingest-given-scene
Update prediction pipelines to only ingest the given scene
2 parents 7938c2f + 954894a commit c00b1ae

File tree

17 files changed

+187
-35
lines changed

rslp/landsat_vessels/predict_pipeline.py

Lines changed: 50 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,17 @@
1111
import shapely
1212
from PIL import Image
1313
from rslearn.const import WGS84_PROJECTION
14-
from rslearn.data_sources import data_source_from_config
14+
from rslearn.data_sources import Item, data_source_from_config
1515
from rslearn.data_sources.aws_landsat import LandsatOliTirs
16-
from rslearn.dataset import Dataset, Window
16+
from rslearn.dataset import Dataset, Window, WindowLayerData
1717
from rslearn.utils import Projection, STGeometry
1818
from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
1919
from typing_extensions import TypedDict
2020
from upath import UPath
2121

2222
from rslp.utils.rslearn import materialize_dataset, run_model_predict
2323

24+
LANDSAT_LAYER_NAME = "landsat"
2425
LOCAL_FILES_DATASET_CONFIG = "data/landsat_vessels/predict_dataset_config.json"
2526
AWS_DATASET_CONFIG = "data/landsat_vessels/predict_dataset_config_aws.json"
2627
DETECT_MODEL_CONFIG = "data/landsat_vessels/config.yaml"
@@ -72,6 +73,7 @@ def get_vessel_detections(
7273
projection: Projection,
7374
bounds: tuple[int, int, int, int],
7475
time_range: tuple[datetime, datetime] | None = None,
76+
item: Item | None = None,
7577
) -> list[VesselDetection]:
7678
"""Apply the vessel detector.
7779
@@ -85,22 +87,30 @@ def get_vessel_detections(
8587
bounds: the bounds to apply the detector in.
8688
time_range: optional time range to apply the detector in (in case the data
8789
source needs an actual time range).
90+
item: only ingest this item. This is set if we are getting the scene directly
91+
from a Landsat data source, not a local file.
8892
"""
8993
# Create a window for applying detector.
9094
group = "default"
9195
window_path = ds_path / "windows" / group / "default"
92-
Window(
96+
window = Window(
9397
path=window_path,
9498
group=group,
9599
name="default",
96100
projection=projection,
97101
bounds=bounds,
98102
time_range=time_range,
99-
).save()
103+
)
104+
window.save()
105+
106+
# Restrict to the item if set.
107+
if item:
108+
layer_data = WindowLayerData(LANDSAT_LAYER_NAME, [[item.serialize()]])
109+
window.save_layer_datas(dict(LANDSAT_LAYER_NAME=layer_data))
100110

101111
print("materialize dataset")
102112
materialize_dataset(ds_path, group=group)
103-
assert (window_path / "layers" / "landsat" / "B8" / "geotiff.tif").exists()
113+
assert (window_path / "layers" / LANDSAT_LAYER_NAME / "B8" / "geotiff.tif").exists()
104114

105115
# Run object detector.
106116
run_model_predict(DETECT_MODEL_CONFIG, ds_path)
@@ -131,6 +141,7 @@ def run_classifier(
131141
ds_path: UPath,
132142
detections: list[VesselDetection],
133143
time_range: tuple[datetime, datetime] | None = None,
144+
item: Item | None = None,
134145
) -> list[VesselDetection]:
135146
"""Run the classifier to try to prune false positive detections.
136147
@@ -140,6 +151,7 @@ def run_classifier(
140151
detections: the detections from the detector.
141152
time_range: optional time range to apply the detector in (in case the data
142153
source needs an actual time range).
154+
item: only ingest this item.
143155
144156
Returns:
145157
the subset of detections that pass the classifier.
@@ -161,20 +173,27 @@ def run_classifier(
161173
detection.col + CLASSIFY_WINDOW_SIZE // 2,
162174
detection.row + CLASSIFY_WINDOW_SIZE // 2,
163175
]
164-
Window(
176+
window = Window(
165177
path=window_path,
166178
group=group,
167179
name=window_name,
168180
projection=detection.projection,
169181
bounds=bounds,
170182
time_range=time_range,
171-
).save()
183+
)
184+
window.save()
172185
window_paths.append(window_path)
173186

187+
if item:
188+
layer_data = WindowLayerData(LANDSAT_LAYER_NAME, [[item.serialize()]])
189+
window.save_layer_datas(dict(LANDSAT_LAYER_NAME=layer_data))
190+
174191
print("materialize dataset")
175192
materialize_dataset(ds_path, group=group)
176193
for window_path in window_paths:
177-
assert (window_path / "layers" / "landsat" / "B8" / "geotiff.tif").exists()
194+
assert (
195+
window_path / "layers" / LANDSAT_LAYER_NAME / "B8" / "geotiff.tif"
196+
).exists()
178197

179198
# Run classification model.
180199
run_model_predict(CLASSIFY_MODEL_CONFIG, ds_path)
@@ -225,6 +244,7 @@ def predict_pipeline(
225244

226245
ds_path = UPath(scratch_path)
227246
ds_path.mkdir(parents=True, exist_ok=True)
247+
item = None
228248

229249
if image_files:
230250
# Setup the dataset configuration file with the provided image files.
@@ -238,7 +258,7 @@ def predict_pipeline(
238258
cfg["src_dir"] = str(UPath(image_path).parent)
239259
item_spec["fnames"].append(image_path)
240260
item_spec["bands"].append([band])
241-
cfg["layers"]["landsat"]["data_source"]["item_specs"] = [item_spec]
261+
cfg["layers"][LANDSAT_LAYER_NAME]["data_source"]["item_specs"] = [item_spec]
242262

243263
with (ds_path / "config.json").open("w") as f:
244264
json.dump(cfg, f)
@@ -251,7 +271,12 @@ def predict_pipeline(
251271
)
252272
left = int(raster.transform.c / projection.x_resolution)
253273
top = int(raster.transform.f / projection.y_resolution)
254-
scene_bounds = [left, top, left + raster.width, top + raster.height]
274+
scene_bounds = (
275+
left,
276+
top,
277+
left + int(raster.width),
278+
top + int(raster.height),
279+
)
255280

256281
time_range = None
257282

@@ -264,7 +289,7 @@ def predict_pipeline(
264289
# Get the projection and scene bounds using the Landsat data source.
265290
dataset = Dataset(ds_path)
266291
data_source: LandsatOliTirs = data_source_from_config(
267-
dataset.layers["landsat"], dataset.path
292+
dataset.layers[LANDSAT_LAYER_NAME], dataset.path
268293
)
269294
item = data_source.get_item_by_name(scene_id)
270295
wgs84_geom = item.geometry.to_projection(WGS84_PROJECTION)
@@ -275,7 +300,12 @@ def predict_pipeline(
275300
-LANDSAT_RESOLUTION,
276301
)
277302
dst_geom = item.geometry.to_projection(projection)
278-
scene_bounds = [int(value) for value in dst_geom.shp.bounds]
303+
scene_bounds = (
304+
int(dst_geom.shp.bounds[0]),
305+
int(dst_geom.shp.bounds[1]),
306+
int(dst_geom.shp.bounds[2]),
307+
int(dst_geom.shp.bounds[3]),
308+
)
279309
time_range = (
280310
dst_geom.time_range[0] - timedelta(minutes=30),
281311
dst_geom.time_range[1] + timedelta(minutes=30),
@@ -289,14 +319,15 @@ def predict_pipeline(
289319
detections = get_vessel_detections(
290320
ds_path,
291321
projection,
292-
scene_bounds, # type: ignore
322+
scene_bounds,
293323
time_range=time_range,
324+
item=item,
294325
)
295326
time_profile["get_vessel_detections"] = time.time() - step_start_time
296327

297328
step_start_time = time.time()
298329
print("run classifier")
299-
detections = run_classifier(ds_path, detections, time_range=time_range)
330+
detections = run_classifier(ds_path, detections, time_range=time_range, item=item)
300331
time_profile["run_classifier"] = time.time() - step_start_time
301332

302333
# Write JSON and crops.
@@ -313,7 +344,11 @@ def predict_pipeline(
313344
raise ValueError("Crop window directory is None")
314345
for band in ["B2", "B3", "B4", "B8"]:
315346
image_fname = (
316-
detection.crop_window_dir / "layers" / "landsat" / band / "geotiff.tif"
347+
detection.crop_window_dir
348+
/ "layers"
349+
/ LANDSAT_LAYER_NAME
350+
/ band
351+
/ "geotiff.tif"
317352
)
318353
with image_fname.open("rb") as f:
319354
with rasterio.open(f) as src:

rslp/launch_beaker.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,13 +57,17 @@ def launch_job(
5757
value="prior-satlas", # nosec
5858
),
5959
EnvVar(
60-
name="S3_ACCESS_KEY_ID", # nosec
60+
name="WEKA_ACCESS_KEY_ID", # nosec
6161
secret="RSLEARN_WEKA_KEY", # nosec
6262
),
6363
EnvVar(
64-
name="S3_SECRET_ACCESS_KEY", # nosec
64+
name="WEKA_SECRET_ACCESS_KEY", # nosec
6565
secret="RSLEARN_WEKA_SECRET", # nosec
6666
),
67+
EnvVar(
68+
name="WEKA_ENDPOINT_URL", # nosec
69+
value="https://weka-aus.beaker.org:9000", # nosec
70+
),
6771
EnvVar(
6872
name="RSLP_PROJECT", # nosec
6973
value=project_id,

rslp/lightning_cli.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from rslearn.main import RslearnLightningCLI
1111
from upath import UPath
1212

13+
import rslp.utils.fs # noqa: F401 (imported but unused)
1314
from rslp import launcher_lib
1415

1516
CHECKPOINT_DIR = "gs://{rslp_bucket}/projects/{project_id}/{experiment_id}/checkpoints/"

rslp/sentinel2_vessels/predict_pipeline.py

Lines changed: 29 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,16 @@
88
import shapely
99
from PIL import Image
1010
from rslearn.const import WGS84_PROJECTION
11-
from rslearn.data_sources import data_source_from_config
11+
from rslearn.data_sources import Item, data_source_from_config
1212
from rslearn.data_sources.gcp_public_data import Sentinel2
13-
from rslearn.dataset import Dataset, Window
13+
from rslearn.dataset import Dataset, Window, WindowLayerData
1414
from rslearn.utils import Projection, STGeometry
1515
from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
1616
from upath import UPath
1717

1818
from rslp.utils.rslearn import materialize_dataset, run_model_predict
1919

20+
SENTINEL2_LAYER_NAME = "sentinel2"
2021
DATASET_CONFIG = "data/sentinel2_vessels/config.json"
2122
DETECT_MODEL_CONFIG = "data/sentinel2_vessels/config.yaml"
2223
SENTINEL2_RESOLUTION = 10
@@ -59,6 +60,7 @@ def get_vessel_detections(
5960
projection: Projection,
6061
bounds: tuple[int, int, int, int],
6162
ts: datetime,
63+
item: Item,
6264
) -> list[VesselDetection]:
6365
"""Apply the vessel detector.
6466
@@ -71,22 +73,30 @@ def get_vessel_detections(
7173
projection: the projection to apply the detector in.
7274
bounds: the bounds to apply the detector in.
7375
ts: timestamp to apply the detector on.
76+
item: the item to ingest.
7477
"""
7578
# Create a window for applying detector.
7679
group = "detector_predict"
7780
window_path = ds_path / "windows" / group / "default"
78-
Window(
81+
window = Window(
7982
path=window_path,
8083
group=group,
8184
name="default",
8285
projection=projection,
8386
bounds=bounds,
8487
time_range=(ts - timedelta(minutes=20), ts + timedelta(minutes=20)),
85-
).save()
88+
)
89+
window.save()
90+
91+
if item:
92+
layer_data = WindowLayerData(SENTINEL2_LAYER_NAME, [[item.serialize()]])
93+
window.save_layer_datas(dict(SENTINEL2_LAYER_NAME=layer_data))
8694

8795
print("materialize dataset")
8896
materialize_dataset(ds_path, group=group, workers=1)
89-
assert (window_path / "layers" / "sentinel2" / "R_G_B" / "geotiff.tif").exists()
97+
assert (
98+
window_path / "layers" / SENTINEL2_LAYER_NAME / "R_G_B" / "geotiff.tif"
99+
).exists()
90100

91101
# Run object detector.
92102
run_model_predict(DETECT_MODEL_CONFIG, ds_path)
@@ -141,7 +151,7 @@ def predict_pipeline(
141151
# Determine the bounds and timestamp of this scene using the data source.
142152
dataset = Dataset(ds_path)
143153
data_source: Sentinel2 = data_source_from_config(
144-
dataset.layers["sentinel2"], dataset.path
154+
dataset.layers[SENTINEL2_LAYER_NAME], dataset.path
145155
)
146156
item = data_source.get_item_by_name(scene_id)
147157
wgs84_geom = item.geometry.to_projection(WGS84_PROJECTION)
@@ -152,12 +162,15 @@ def predict_pipeline(
152162
-SENTINEL2_RESOLUTION,
153163
)
154164
dst_geom = item.geometry.to_projection(projection)
155-
bounds = tuple(int(value) for value in dst_geom.shp.bounds)
156-
if len(bounds) != 4:
157-
raise ValueError(f"Expected 4 bounds, got {len(bounds)}")
165+
bounds = (
166+
int(dst_geom.shp.bounds[0]),
167+
int(dst_geom.shp.bounds[1]),
168+
int(dst_geom.shp.bounds[2]),
169+
int(dst_geom.shp.bounds[3]),
170+
)
158171
ts = item.geometry.time_range[0]
159172

160-
detections = get_vessel_detections(ds_path, projection, bounds, ts) # type: ignore
173+
detections = get_vessel_detections(ds_path, projection, bounds, ts, item)
161174

162175
# Create windows just to collect crops for each detection.
163176
group = "crops"
@@ -166,13 +179,11 @@ def predict_pipeline(
166179
window_name = f"{detection.col}_{detection.row}"
167180
window_path = ds_path / "windows" / group / window_name
168181
detection.crop_window_dir = window_path
169-
bounds = tuple(
170-
[
171-
detection.col - CROP_WINDOW_SIZE // 2,
172-
detection.row - CROP_WINDOW_SIZE // 2,
173-
detection.col + CROP_WINDOW_SIZE // 2,
174-
detection.row + CROP_WINDOW_SIZE // 2,
175-
]
182+
bounds = (
183+
detection.col - CROP_WINDOW_SIZE // 2,
184+
detection.row - CROP_WINDOW_SIZE // 2,
185+
detection.col + CROP_WINDOW_SIZE // 2,
186+
detection.row + CROP_WINDOW_SIZE // 2,
176187
)
177188
Window(
178189
path=window_path,
@@ -193,7 +204,7 @@ def predict_pipeline(
193204
for detection, crop_window_path in zip(detections, window_paths):
194205
# Get RGB crop.
195206
image_fname = (
196-
crop_window_path / "layers" / "sentinel2" / "R_G_B" / "geotiff.tif"
207+
crop_window_path / "layers" / SENTINEL2_LAYER_NAME / "R_G_B" / "geotiff.tif"
197208
)
198209
with image_fname.open("rb") as f:
199210
with rasterio.open(f) as src:

rslp/transforms/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""rslp transforms.
2+
3+
These transforms should be ones that are not general enough to include in rslearn, but
4+
still relevant across multiple rslp projects.
5+
6+
If it is project-specific, it should go in rslp/[project_name]/train.py or similar.
7+
"""

rslp/utils/__init__.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
"""rslp utilities.
2+
3+
These utilities should be ones that are not general enough to include in rslearn, but
4+
still relevant across multiple rslp projects.
5+
6+
If it is project-specific, it should go in rslp/[project_name]/util.py or similar.
7+
"""

0 commit comments

Comments
 (0)