Skip to content

Commit bc424fd

Browse files
authored
Merge pull request #34 from allenai/pbeukema/fastapi
FastAPI for Landsat vessel detection
2 parents e509899 + b1cea89 commit bc424fd

File tree

10 files changed

+254
-25
lines changed

10 files changed

+254
-25
lines changed

.github/workflows/build_test.yaml

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,10 +91,20 @@ jobs:
9191
run: |
9292
COMPOSE_DOCKER_CLI_BUILD=1 DOCKER_BUILDKIT=1 docker compose -f docker-compose.yaml build
9393
94+
- name: Authenticate into gcp
95+
uses: "google-github-actions/auth@v2"
96+
with:
97+
credentials_json: ${{ secrets.GOOGLE_CREDENTIALS }}
9498

9599
- name: Run tests with Docker Compose
96100
run: |
97-
docker compose -f docker-compose.yaml run test pytest tests/
101+
docker compose -f docker-compose.yaml run \
102+
-e AWS_ACCESS_KEY_ID=${{ secrets.AWS_ACCESS_KEY_ID }} \
103+
-e AWS_SECRET_ACCESS_KEY=${{ secrets.AWS_SECRET_ACCESS_KEY }} \
104+
-v ${{env.GOOGLE_GHA_CREDS_PATH}}:/tmp/gcp-credentials.json:ro \
105+
-e GOOGLE_APPLICATION_CREDENTIALS=/tmp/gcp-credentials.json \
106+
-e RSLP_BUCKET=rslearn-eai \
107+
test pytest tests/ --ignore tests/integration_slow/
98108
99109
- name: Clean up
100110
if: always()

Dockerfile

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,8 @@
11
FROM pytorch/pytorch:2.4.0-cuda11.8-cudnn9-runtime@sha256:58a28ab734f23561aa146fbaf777fb319a953ca1e188832863ed57d510c9f197
22

3-
# TEMPORARY Until RSLEARN Is Public
4-
ARG GIT_USERNAME
5-
ARG GIT_TOKEN
6-
73
RUN apt update
84
RUN apt install -y libpq-dev ffmpeg libsm6 libxext6 git
9-
RUN git clone https://${GIT_USERNAME}:${GIT_TOKEN}@github.com/allenai/rslearn.git /opt/rslearn_projects/rslearn
5+
RUN git clone https://github.com/allenai/rslearn.git /opt/rslearn_projects/rslearn
106
RUN pip install -r /opt/rslearn_projects/rslearn/requirements.txt
117
RUN pip install -r /opt/rslearn_projects/rslearn/extra_requirements.txt
128
COPY requirements.txt /opt/rslearn_projects/requirements.txt

landsat/recheck_landsat_labels/phase123_config.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,12 @@ data:
5252
allow_invalid: true
5353
skip_unknown_categories: true
5454
prob_property: "prob"
55+
positive_class: "correct"
56+
positive_class_threshold: 0.85
5557
input_mapping:
5658
class:
5759
label: "targets"
58-
batch_size: 64
60+
batch_size: 32
5961
num_workers: 32
6062
default_config:
6163
transforms:

requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
11
beaker-py
2+
fastapi
23
interrogate
4+
pydantic
35
pytest
46
python-dotenv
7+
ruff
8+
typing-extensions
9+
uvicorn

rslp/landsat_vessels/Dockerfile

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
# Base image
2+
FROM base-image:latest
3+
4+
# Environment variables
5+
ENV PYTHONPATH="/opt/rslearn_projects:${PYTHONPATH}"
6+
ENV LANDSAT_PORT=5555
7+
8+
# Make port 5555 available to the world outside this container
9+
EXPOSE $LANDSAT_PORT
10+
11+
# Run app.py when the container launches
12+
CMD ["python3", "rslp/landsat_vessels/api_main.py"]

rslp/landsat_vessels/api_main.py

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,98 @@
1+
"""Landsat Vessel Detection Service."""
2+
3+
from __future__ import annotations
4+
5+
import logging
6+
import multiprocessing
7+
import os
8+
9+
import uvicorn
10+
from fastapi import FastAPI, Response
11+
from pydantic import BaseModel
12+
13+
from rslp.landsat_vessels.predict_pipeline import FormattedPrediction, predict_pipeline
14+
15+
app = FastAPI()
16+
17+
# Set up the logger
18+
logging.basicConfig(level=logging.INFO)
19+
logger = logging.getLogger(__name__)
20+
21+
LANDSAT_HOST = "0.0.0.0"
22+
LANDSAT_PORT = 5555
23+
24+
25+
class LandsatResponse(BaseModel):
26+
"""Response object for vessel detections."""
27+
28+
status: list[str]
29+
predictions: list[FormattedPrediction]
30+
31+
32+
class LandsatRequest(BaseModel):
33+
"""Request object for vessel detections."""
34+
35+
scene_id: str | None = None
36+
image_files: dict[str, str] | None = None
37+
crop_path: str | None = None
38+
scratch_path: str | None = None
39+
json_path: str | None = None
40+
41+
42+
@app.on_event("startup")
43+
async def rslp_init() -> None:
44+
"""Landsat Vessel Service Initialization."""
45+
logger.info("Initializing")
46+
multiprocessing.set_start_method("forkserver", force=True)
47+
multiprocessing.set_forkserver_preload(
48+
[
49+
"rslp.utils.rslearn.materialize_dataset",
50+
"rslp.utils.rslearn.run_model_predict",
51+
]
52+
)
53+
54+
55+
@app.get("/")
56+
async def home() -> dict:
57+
"""Returns a simple message to indicate the service is running."""
58+
return {"message": "Landsat Detections App"}
59+
60+
61+
@app.post("/detections", response_model=LandsatResponse)
62+
async def get_detections(info: LandsatRequest, response: Response) -> LandsatResponse:
63+
"""Returns vessel detections Response object for a given Request object."""
64+
# Ensure that either scene_id or image_files is specified.
65+
if info.scene_id is None and info.image_files is None:
66+
raise ValueError("Either scene_id or image_files must be specified.")
67+
68+
try:
69+
if info.scene_id is not None:
70+
logger.info(f"Received request with scene_id: {info.scene_id}")
71+
elif info.image_files is not None:
72+
logger.info("Received request with image_files")
73+
json_data = predict_pipeline(
74+
crop_path=info.crop_path,
75+
scene_id=info.scene_id,
76+
image_files=info.image_files,
77+
scratch_path=info.scratch_path,
78+
json_path=info.json_path,
79+
)
80+
return LandsatResponse(
81+
status=["success"],
82+
predictions=[pred for pred in json_data],
83+
)
84+
except ValueError as e:
85+
logger.error(f"Value error during prediction pipeline: {e}")
86+
return LandsatResponse(status=["error"], predictions=[])
87+
except Exception as e:
88+
logger.error(f"Unexpected error during prediction pipeline: {e}")
89+
return LandsatResponse(status=["error"], predictions=[])
90+
91+
92+
if __name__ == "__main__":
93+
uvicorn.run(
94+
"api_main:app",
95+
host=os.getenv("LANDSAT_HOST", default="0.0.0.0"),
96+
port=int(os.getenv("LANDSAT_PORT", default=5555)),
97+
proxy_headers=True,
98+
)
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
version: "3.9"
2+
3+
services:
4+
# Define the base image
5+
base-image:
6+
build:
7+
context: ../..
8+
dockerfile: Dockerfile
9+
image: base-image:latest # Tag it as "base-image"
10+
11+
# Define the landsat-vessels service
12+
landsat-vessels:
13+
build:
14+
context: .
15+
dockerfile: Dockerfile
16+
shm_size: '10G' # This adds the shared memory size
17+
depends_on:
18+
- base-image
19+
ports:
20+
- "5555:5555"
21+
environment:
22+
- RSLP_BUCKET
23+
- S3_ACCESS_KEY_ID
24+
- S3_SECRET_ACCESS_KEY
25+
- AWS_ACCESS_KEY_ID
26+
- AWS_SECRET_ACCESS_KEY
27+
- NVIDIA_VISIBLE_DEVICES=all # Make all GPUs visible
28+
deploy:
29+
resources:
30+
reservations:
31+
devices:
32+
- capabilities: [gpu] # Ensure this service can access GPUs
33+
runtime: nvidia # Use the NVIDIA runtime

rslp/landsat_vessels/predict_pipeline.py

Lines changed: 72 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Landsat vessel prediction pipeline."""
22

33
import json
4+
import tempfile
5+
import time
46
from datetime import datetime, timedelta
57

68
import numpy as np
@@ -14,6 +16,7 @@
1416
from rslearn.dataset import Dataset, Window
1517
from rslearn.utils import Projection, STGeometry
1618
from rslearn.utils.get_utm_ups_crs import get_utm_ups_projection
19+
from typing_extensions import TypedDict
1720
from upath import UPath
1821

1922
from rslp.utils.rslearn import materialize_dataset, run_model_predict
@@ -54,6 +57,16 @@ def __init__(
5457
self.crop_window_dir = crop_window_dir
5558

5659

60+
class FormattedPrediction(TypedDict):
61+
"""Formatted prediction for a single vessel detection."""
62+
63+
latitude: float
64+
longitude: float
65+
score: float
66+
rgb_fname: str
67+
b8_fname: str
68+
69+
5770
def get_vessel_detections(
5871
ds_path: UPath,
5972
projection: Projection,
@@ -180,12 +193,12 @@ def run_classifier(
180193

181194

182195
def predict_pipeline(
183-
scratch_path: str,
184-
json_path: str,
185-
crop_path: str,
196+
crop_path: str | None = None,
197+
scratch_path: str | None = None,
198+
json_path: str | None = None,
186199
image_files: dict[str, str] | None = None,
187200
scene_id: str | None = None,
188-
) -> None:
201+
) -> list[FormattedPrediction]:
189202
"""Run the Landsat vessel prediction pipeline.
190203
191204
This inputs a Landsat scene (consisting of per-band GeoTIFFs) and produces the
@@ -201,6 +214,15 @@ def predict_pipeline(
201214
scene_id: Landsat scene ID. Exactly one of image_files or scene_id should be
202215
specified.
203216
"""
217+
start_time = time.time() # Start the timer
218+
time_profile = {}
219+
220+
if scratch_path is None:
221+
tmp_dir = tempfile.TemporaryDirectory()
222+
scratch_path = tmp_dir.name
223+
else:
224+
tmp_dir = None
225+
204226
ds_path = UPath(scratch_path)
205227
ds_path.mkdir(parents=True, exist_ok=True)
206228

@@ -259,18 +281,29 @@ def predict_pipeline(
259281
dst_geom.time_range[1] + timedelta(minutes=30),
260282
)
261283

284+
time_profile["setup"] = time.time() - start_time
285+
262286
# Run pipeline.
287+
step_start_time = time.time()
288+
print("run detector")
263289
detections = get_vessel_detections(
264290
ds_path,
265291
projection,
266292
scene_bounds, # type: ignore
267293
time_range=time_range,
268294
)
295+
time_profile["get_vessel_detections"] = time.time() - step_start_time
296+
297+
step_start_time = time.time()
298+
print("run classifier")
269299
detections = run_classifier(ds_path, detections, time_range=time_range)
300+
time_profile["run_classifier"] = time.time() - step_start_time
270301

271302
# Write JSON and crops.
272-
json_upath = UPath(json_path)
273-
crop_upath = UPath(crop_path)
303+
step_start_time = time.time()
304+
if crop_path:
305+
crop_upath = UPath(crop_path)
306+
crop_upath.mkdir(parents=True, exist_ok=True)
274307

275308
json_data = []
276309
for idx, detection in enumerate(detections):
@@ -304,13 +337,17 @@ def predict_pipeline(
304337
[images["B4_sharp"], images["B3_sharp"], images["B2_sharp"]], axis=2
305338
)
306339

307-
rgb_fname = crop_upath / f"{idx}_rgb.png"
308-
with rgb_fname.open("wb") as f:
309-
Image.fromarray(rgb).save(f, format="PNG")
340+
if crop_path:
341+
rgb_fname = crop_upath / f"{idx}_rgb.png"
342+
with rgb_fname.open("wb") as f:
343+
Image.fromarray(rgb).save(f, format="PNG")
310344

311-
b8_fname = crop_upath / f"{idx}_b8.png"
312-
with b8_fname.open("wb") as f:
313-
Image.fromarray(images["B8"]).save(f, format="PNG")
345+
b8_fname = crop_upath / f"{idx}_b8.png"
346+
with b8_fname.open("wb") as f:
347+
Image.fromarray(images["B8"]).save(f, format="PNG")
348+
else:
349+
rgb_fname = ""
350+
b8_fname = ""
314351

315352
# Get longitude/latitude.
316353
src_geom = STGeometry(
@@ -321,14 +358,31 @@ def predict_pipeline(
321358
lat = dst_geom.shp.y
322359

323360
json_data.append(
324-
dict(
361+
FormattedPrediction(
325362
longitude=lon,
326363
latitude=lat,
327364
score=detection.score,
328-
rgb_fname=str(rgb_fname),
329-
b8_fname=str(b8_fname),
330-
)
365+
rgb_fname=rgb_fname,
366+
b8_fname=b8_fname,
367+
),
331368
)
332369

333-
with json_upath.open("w") as f:
334-
json.dump(json_data, f)
370+
time_profile["write_json_and_crops"] = time.time() - step_start_time
371+
372+
elapsed_time = time.time() - start_time # Calculate elapsed time
373+
time_profile["total"] = elapsed_time
374+
375+
# Clean up any temporary directories.
376+
if tmp_dir:
377+
tmp_dir.cleanup()
378+
379+
if json_path:
380+
json_upath = UPath(json_path)
381+
with json_upath.open("w") as f:
382+
json.dump(json_data, f)
383+
384+
print(f"Prediction pipeline completed in {elapsed_time:.2f} seconds")
385+
for step, duration in time_profile.items():
386+
print(f"{step} took {duration:.2f} seconds")
387+
388+
return json_data

rslp/utils/rslearn.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,14 @@ def materialize_dataset(
3636
dataset,
3737
workers=workers,
3838
group=group,
39+
use_initial_job=False,
3940
)
4041
apply_on_windows(
4142
MaterializeHandler(),
4243
dataset,
4344
workers=workers,
4445
group=group,
46+
use_initial_job=False,
4547
)
4648

4749

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
from fastapi.testclient import TestClient
2+
3+
from rslp.landsat_vessels.api_main import app
4+
5+
client = TestClient(app)
6+
7+
8+
def test_singapore_dense_scene() -> None:
9+
# LC08_L1TP_125059_20240913_20240920_02_T1 is a scene that includes southeast coast
10+
# of Singapore where there are hundreds of vessels.
11+
response = client.post(
12+
"/detections", json={"scene_id": "LC08_L1TP_125059_20240913_20240920_02_T1"}
13+
)
14+
assert response.status_code == 200
15+
predictions = response.json()["predictions"]
16+
# There are many correct vessels in this scene.
17+
assert len(predictions) >= 100

0 commit comments

Comments
 (0)