-
Notifications
You must be signed in to change notification settings - Fork 2
FastAPI for Landsat vessel detection #34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 17 commits
5a163ca
7877e59
c400985
8f05ca3
7b53fdb
2bfc38e
19421b5
02ab349
8c2136b
f7081e4
b64ab97
cbdeec9
9df535d
a1b6d1e
8691f53
1d7d4d6
164e895
0babbbe
12d6a5b
67f29b5
3f546bc
2fd1847
5a1149c
3a9cad1
ee16650
dd674b9
ec987af
d505f18
c0396a3
e2790b7
eb0b5a5
850429c
b1cea89
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,8 @@ | ||
beaker-py | ||
python-dotenv | ||
pytest | ||
uvicorn | ||
fastapi | ||
pydantic | ||
typing-extensions | ||
ruff |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Base image | ||
FROM base-image:latest | ||
|
||
# Environment variables | ||
ENV PYTHONPATH="/opt/rslearn_projects:${PYTHONPATH}" | ||
ENV LANDSAT_PORT=5555 | ||
|
||
# Make port 5555 available to the world outside this container | ||
EXPOSE $LANDSAT_PORT | ||
|
||
# Run app.py when the container launches | ||
CMD ["python3", "rslp/landsat_vessels/api_main.py"] |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,99 @@ | ||
"""Landsat Vessel Detection Service.""" | ||
|
||
from __future__ import annotations | ||
|
||
import logging | ||
import multiprocessing | ||
import os | ||
|
||
import uvicorn | ||
from fastapi import FastAPI, Response | ||
from pydantic import BaseModel | ||
from typing_extensions import TypedDict | ||
|
||
from rslp.landsat_vessels import predict_pipeline | ||
|
||
app = FastAPI() | ||
|
||
# Set up the logger | ||
logging.basicConfig(level=logging.INFO) | ||
logger = logging.getLogger(__name__) | ||
|
||
LANDSAT_HOST = "0.0.0.0" | ||
LANDSAT_PORT = 5555 | ||
|
||
|
||
class FormattedPrediction(TypedDict): | ||
"""Formatted prediction for a single vessel detection.""" | ||
|
||
latitude: float | ||
longitude: float | ||
score: float | ||
rgb_fname: str | ||
b8_fname: str | ||
|
||
|
||
class LandsatResponse(BaseModel): | ||
"""Response object for vessel detections.""" | ||
|
||
status: list[str] | ||
predictions: list[FormattedPrediction] | ||
|
||
|
||
class LandsatRequest(BaseModel): | ||
"""Request object for vessel detections.""" | ||
|
||
scene_id: str | None = None | ||
image_files: dict[str, str] | None = None | ||
crop_path: str | ||
scratch_path: str | ||
|
||
json_path: str | ||
|
||
|
||
|
||
@app.on_event("startup") | ||
async def rslp_init() -> None: | ||
"""Landsat Vessel Service Initialization.""" | ||
logger.info("Initializing") | ||
multiprocessing.set_start_method("forkserver", force=True) | ||
multiprocessing.set_forkserver_preload( | ||
[ | ||
"rslp.utils.rslearn.materialize_dataset", | ||
"rslp.utils.rslearn.run_model_predict", | ||
] | ||
) | ||
|
||
|
||
@app.get("/") | ||
async def home() -> dict: | ||
"""Returns a simple message to indicate the service is running.""" | ||
return {"message": "Landsat Detections App"} | ||
|
||
|
||
@app.post("/detections", response_model=LandsatResponse) | ||
async def get_detections(info: LandsatRequest, response: Response) -> LandsatResponse: | ||
"""Returns vessel detections Response object for a given Request object.""" | ||
try: | ||
logger.info(f"Received request with scene_id: {info.scene_id}") | ||
json_data = predict_pipeline( | ||
crop_path=info.crop_path, | ||
scene_id=info.scene_id, | ||
image_files=info.image_files, | ||
scratch_path=info.scratch_path, | ||
json_path=info.json_path, | ||
) | ||
return LandsatResponse(status=["success"], predictions=json_data) | ||
except ValueError as e: | ||
logger.error(f"Value error during prediction pipeline: {e}") | ||
return LandsatResponse(status=["error"], predictions=[]) | ||
except Exception as e: | ||
logger.error(f"Unexpected error during prediction pipeline: {e}") | ||
return LandsatResponse(status=["error"], predictions=[]) | ||
|
||
|
||
if __name__ == "__main__": | ||
uvicorn.run( | ||
"api_main:app", | ||
host=os.getenv("LANDSAT_HOST", default="0.0.0.0"), | ||
port=int(os.getenv("LANDSAT_PORT", default=5555)), | ||
proxy_headers=True, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
version: "3.9" | ||
|
||
services: | ||
# Define the base image | ||
base-image: | ||
build: | ||
context: ../.. | ||
dockerfile: Dockerfile | ||
image: base-image:latest # Tag it as "base-image" | ||
|
||
# Define the landsat-vessels service | ||
landsat-vessels: | ||
build: | ||
context: . | ||
dockerfile: Dockerfile | ||
shm_size: '10G' # This adds the shared memory size | ||
depends_on: | ||
- base-image | ||
ports: | ||
- "5555:5555" | ||
environment: | ||
- RSLP_BUCKET | ||
- S3_ACCESS_KEY_ID | ||
- S3_SECRET_ACCESS_KEY | ||
- AWS_ACCESS_KEY_ID | ||
- AWS_SECRET_ACCESS_KEY | ||
- NVIDIA_VISIBLE_DEVICES=all # Make all GPUs visible | ||
deploy: | ||
resources: | ||
reservations: | ||
devices: | ||
- capabilities: [gpu] # Ensure this service can access GPUs | ||
runtime: nvidia # Use the NVIDIA runtime |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,26 @@ | ||
import tempfile | ||
|
||
from fastapi.testclient import TestClient | ||
|
||
from rslp.landsat_vessels.api_main import app | ||
|
||
client = TestClient(app) | ||
|
||
|
||
def test_singapore_dense_scene(): | ||
# LC08_L1TP_125059_20240913_20240920_02_T1 is a scene that includes southeast coast | ||
# of Singapore where there are hundreds of vessels. | ||
with tempfile.TemporaryDirectory() as tmp_dir: | ||
response = client.post( | ||
"/detections", | ||
json={ | ||
"scene_id": "LC08_L1TP_125059_20240913_20240920_02_T1", | ||
"scratch_path": tmp_dir, | ||
"json_path": "", | ||
"crop_path": "", | ||
}, | ||
) | ||
assert response.status_code == 200 | ||
predictions = response.json()["predictions"] | ||
# There are many correct vessels in this scene. | ||
assert len(predictions) >= 100 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be nice for crop_path to be optional as well. If you don't make it optional then in the test I added, crop_path should be changed to be a subdirectory of the scratch_path, otherwise it writes the crops in the current directory.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed. Also leave it as None in the test, so that we can query API just with a scene_id.