
Commit 9ee2b7f

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent c4a4951 commit 9ee2b7f
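
The changes below are mechanical reflows (double quotes, trailing commas, multi-line calls collapsed onto one line) of the kind produced by an auto-formatter hook. The repository's hook configuration is not shown here, so the following is only a minimal sketch, assuming black with a 120-character line length, of how the same reformatting could be reproduced locally before committing:

# Hypothetical helper, not part of this repository: runs black programmatically to
# apply the same kind of reflow seen in this commit. The formatter choice and the
# line length are assumptions; check the project's .pre-commit-config.yaml.
import black

def reformat_file(path: str, line_length: int = 120) -> None:
    """Rewrite a Python source file in place using black's formatting rules."""
    with open(path, "r", encoding="utf-8") as handle:
        source = handle.read()
    formatted = black.format_str(source, mode=black.Mode(line_length=line_length))
    if formatted != source:
        with open(path, "w", encoding="utf-8") as handle:
            handle.write(formatted)

if __name__ == "__main__":
    reformat_file("deployment/fastapi_inference/app/inference.py")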

File tree: 5 files changed (+44, -106 lines)


deployment/fastapi_inference/app/inference.py

Lines changed: 18 additions & 21 deletions
@@ -32,19 +32,17 @@ class InferenceEngine:
 
     def __init__(self):
         """Initialize the inference engine with preprocessing transforms."""
-        self.preprocess = Compose([
-            LoadImage(image_only=True),
-            EnsureChannelFirst(),
-            Spacing(pixdim=(1.5, 1.5, 2.0)),
-            ScaleIntensity(),
-            EnsureType(dtype=torch.float32),
-        ])
-
-    async def process_image(
-        self,
-        image_bytes: bytes,
-        filename: str
-    ) -> Dict:
+        self.preprocess = Compose(
+            [
+                LoadImage(image_only=True),
+                EnsureChannelFirst(),
+                Spacing(pixdim=(1.5, 1.5, 2.0)),
+                ScaleIntensity(),
+                EnsureType(dtype=torch.float32),
+            ]
+        )
+
+    async def process_image(self, image_bytes: bytes, filename: str) -> Dict:
         """
         Process an uploaded image and return predictions.
 
@@ -80,13 +78,15 @@ async def process_image
         result = {
             "success": True,
             "prediction": self._format_prediction(prediction),
-            "segmentation_shape": list(prediction.shape) if isinstance(prediction, (np.ndarray, torch.Tensor)) else None,
+            "segmentation_shape": (
+                list(prediction.shape) if isinstance(prediction, (np.ndarray, torch.Tensor)) else None
+            ),
             "metadata": {
                 "image_shape": list(image_tensor.shape),
                 "processing_time": round(processing_time, 3),
                 "device": str(model_loader.device),
             },
-            "message": f"Inference completed successfully in {processing_time:.3f}s"
+            "message": f"Inference completed successfully in {processing_time:.3f}s",
         }
 
         logger.info(f"Inference completed in {processing_time:.3f}s")
@@ -112,15 +112,12 @@ def _load_image(self, image_buffer: BytesIO, filename: str) -> np.ndarray:
         """
         try:
             # Support NIfTI format (.nii, .nii.gz)
-            if filename.endswith(('.nii', '.nii.gz')):
+            if filename.endswith((".nii", ".nii.gz")):
                 image_buffer.seek(0)
                 img = nib.load(image_buffer)
                 return img.get_fdata()
             else:
-                raise ValueError(
-                    f"Unsupported image format: {filename}. "
-                    "Supported formats: .nii, .nii.gz"
-                )
+                raise ValueError(f"Unsupported image format: {filename}. " "Supported formats: .nii, .nii.gz")
         except Exception as e:
             raise ValueError(f"Failed to load image: {str(e)}")
 
@@ -170,7 +167,7 @@ async def _run_inference(self, image_tensor: torch.Tensor) -> torch.Tensor:
 
         # Run inference with no gradient computation
         with torch.no_grad():
-            if hasattr(model, '__call__'):
+            if hasattr(model, "__call__"):
                 prediction = model(image_tensor)
             else:
                 raise RuntimeError("Model is not callable")
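
For reference, the preprocessing chain reformatted above is a standard MONAI Compose pipeline. A minimal standalone sketch of applying the same transforms to a NIfTI file (the file path below is a placeholder, not something shipped with this repository):

# Illustrative only: applies the transforms from InferenceEngine.__init__ to a
# local NIfTI file. "example_ct.nii.gz" is a placeholder path.
import torch
from monai.transforms import Compose, EnsureChannelFirst, EnsureType, LoadImage, ScaleIntensity, Spacing

preprocess = Compose(
    [
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        Spacing(pixdim=(1.5, 1.5, 2.0)),
        ScaleIntensity(),
        EnsureType(dtype=torch.float32),
    ]
)

image = preprocess("example_ct.nii.gz")  # channel-first tensor, e.g. (1, H, W, D)
print(image.shape)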

deployment/fastapi_inference/app/main.py

Lines changed: 15 additions & 44 deletions
@@ -17,10 +17,7 @@
 from .schemas import HealthResponse, PredictionResponse, ErrorResponse
 
 # Configure logging
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
-)
+logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(name)s - %(levelname)s - %(message)s")
 logger = logging.getLogger(__name__)
 
 
@@ -33,10 +30,7 @@ async def lifespan(app: FastAPI):
     # Startup: Load the MONAI model
     logger.info("Starting up: Loading MONAI model...")
     try:
-        model_loader.load_model(
-            model_name="spleen_ct_segmentation",
-            bundle_dir="./models"
-        )
+        model_loader.load_model(model_name="spleen_ct_segmentation", bundle_dir="./models")
         logger.info("Model loaded successfully!")
     except Exception as e:
         logger.error(f"Failed to load model: {e}")
@@ -77,8 +71,8 @@ async def global_exception_handler(request, exc):
         content={
             "error": "InternalServerError",
             "detail": "An unexpected error occurred. Please try again.",
-            "status_code": 500
-        }
+            "status_code": 500,
+        },
     )
 
 
@@ -97,7 +91,7 @@ async def root():
             "health": "/health",
             "predict": "/predict",
             "docs": "/docs",
-        }
+        },
     }
 
 
@@ -132,14 +126,9 @@ async def health_check():
         200: {"description": "Successful prediction"},
         400: {"model": ErrorResponse, "description": "Bad request"},
         500: {"model": ErrorResponse, "description": "Internal server error"},
-    }
+    },
 )
-async def predict(
-    file: UploadFile = File(
-        ...,
-        description="Medical image file (NIfTI format: .nii or .nii.gz)"
-    )
-):
+async def predict(file: UploadFile = File(..., description="Medical image file (NIfTI format: .nii or .nii.gz)")):
     """
     Run inference on uploaded medical image.
 
@@ -153,64 +142,46 @@ async def predict(
         HTTPException: If file format is invalid or inference fails
     """
     # Validate file format
-    if not file.filename.endswith(('.nii', '.nii.gz')):
+    if not file.filename.endswith((".nii", ".nii.gz")):
         raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail="Invalid file format. Supported formats: .nii, .nii.gz"
+            status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid file format. Supported formats: .nii, .nii.gz"
         )
 
     # Check if model is loaded
     if not model_loader.is_loaded():
         raise HTTPException(
-            status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
-            detail="Model not loaded. Please try again later."
+            status_code=status.HTTP_503_SERVICE_UNAVAILABLE, detail="Model not loaded. Please try again later."
        )
 
     try:
         # Read file content
         contents = await file.read()
 
         # Run inference
-        result = await inference_engine.process_image(
-            image_bytes=contents,
-            filename=file.filename
-        )
+        result = await inference_engine.process_image(image_bytes=contents, filename=file.filename)
 
         return PredictionResponse(**result)
 
     except ValueError as e:
         # Client error (bad input)
         logger.warning(f"Bad request: {str(e)}")
-        raise HTTPException(
-            status_code=status.HTTP_400_BAD_REQUEST,
-            detail=str(e)
-        )
+        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(e))
 
     except RuntimeError as e:
         # Server error (inference failed)
         logger.error(f"Inference error: {str(e)}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"Inference failed: {str(e)}"
-        )
+        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"Inference failed: {str(e)}")
 
     except Exception as e:
         # Unexpected error
         logger.error(f"Unexpected error during prediction: {str(e)}")
         raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail="An unexpected error occurred during prediction"
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred during prediction"
         )
 
 
 if __name__ == "__main__":
     import uvicorn
 
     # For development only - use proper ASGI server in production
-    uvicorn.run(
-        "main:app",
-        host="0.0.0.0",
-        port=8000,
-        reload=True,
-        log_level="info"
-    )
+    uvicorn.run("main:app", host="0.0.0.0", port=8000, reload=True, log_level="info")
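
The /predict endpoint reformatted above accepts a single multipart file upload named "file". A minimal sketch of exercising it with the requests library (the base URL and scan path are placeholders; examples/client.py in this repository is the canonical client):

# Illustrative only: calls the /health and /predict endpoints shown above.
import requests

BASE_URL = "http://localhost:8000"  # placeholder; match your deployment

# Check that the model is loaded before sending work
print(requests.get(f"{BASE_URL}/health", timeout=10).json())

# Upload a NIfTI volume for inference
with open("spleen_ct.nii.gz", "rb") as scan:  # placeholder file name
    response = requests.post(
        f"{BASE_URL}/predict",
        files={"file": ("spleen_ct.nii.gz", scan, "application/gzip")},
        timeout=300,
    )
response.raise_for_status()
print(response.json()["message"])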

deployment/fastapi_inference/app/model_loader.py

Lines changed: 5 additions & 15 deletions
@@ -46,11 +46,7 @@ def _setup_device(self):
             self._device = torch.device("cpu")
             logger.info("Using CPU for inference")
 
-    def load_model(
-        self,
-        model_name: str = "spleen_ct_segmentation",
-        bundle_dir: str = "./models"
-    ) -> None:
+    def load_model(self, model_name: str = "spleen_ct_segmentation", bundle_dir: str = "./models") -> None:
         """
         Load a MONAI model bundle.
 
@@ -74,18 +70,14 @@ def load_model(
 
         # Load the model
         logger.info("Loading model into memory...")
-        self._model = load(
-            name=model_name,
-            bundle_dir=bundle_dir,
-            source="monaihosting"
-        )
+        self._model = load(name=model_name, bundle_dir=bundle_dir, source="monaihosting")
 
         # Move model to device
-        if hasattr(self._model, 'to'):
+        if hasattr(self._model, "to"):
            self._model = self._model.to(self._device)
 
         # Set model to evaluation mode
-        if hasattr(self._model, 'eval'):
+        if hasattr(self._model, "eval"):
             self._model.eval()
 
         logger.info("Model loaded successfully")
@@ -98,9 +90,7 @@ def load_model(
     def model(self):
         """Get the loaded model instance."""
         if self._model is None:
-            raise RuntimeError(
-                "Model not loaded. Call load_model() first."
-            )
+            raise RuntimeError("Model not loaded. Call load_model() first.")
         return self._model
 
     @property
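
The collapsed load(...) call above is monai.bundle.load. A minimal sketch, assuming network access to the MONAI model zoo, of fetching and preparing the same bundle outside the service:

# Illustrative only: mirrors the load_model() call above for the spleen bundle.
import torch
from monai.bundle import load

# Downloads the bundle on first use and returns the instantiated network
model = load(name="spleen_ct_segmentation", bundle_dir="./models", source="monaihosting")

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()
print(type(model).__name__)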

deployment/fastapi_inference/app/schemas.py

Lines changed: 2 additions & 8 deletions
@@ -29,14 +29,8 @@ class PredictionResponse(BaseModel):
     """Response model for inference predictions."""
 
     success: bool = Field(..., description="Whether prediction was successful")
-    prediction: Optional[Dict] = Field(
-        None,
-        description="Prediction results (format depends on model output)"
-    )
-    segmentation_shape: Optional[List[int]] = Field(
-        None,
-        description="Shape of segmentation mask if applicable"
-    )
+    prediction: Optional[Dict] = Field(None, description="Prediction results (format depends on model output)")
+    segmentation_shape: Optional[List[int]] = Field(None, description="Shape of segmentation mask if applicable")
     metadata: PredictionMetadata = Field(..., description="Prediction metadata")
     message: Optional[str] = Field(None, description="Additional information or error message")

deployment/fastapi_inference/examples/client.py

Lines changed: 4 additions & 18 deletions
@@ -63,24 +63,10 @@ def predict(self, image_path: str) -> dict:
 
 
 def main():
     """Main function for command-line usage."""
-    parser = argparse.ArgumentParser(
-        description="MONAI FastAPI Inference Client"
-    )
-    parser.add_argument(
-        "--url",
-        default="http://localhost:8000",
-        help="API base URL (default: http://localhost:8000)"
-    )
-    parser.add_argument(
-        "--health",
-        action="store_true",
-        help="Check API health status"
-    )
-    parser.add_argument(
-        "--image",
-        type=str,
-        help="Path to medical image file for prediction"
-    )
+    parser = argparse.ArgumentParser(description="MONAI FastAPI Inference Client")
+    parser.add_argument("--url", default="http://localhost:8000", help="API base URL (default: http://localhost:8000)")
+    parser.add_argument("--health", action="store_true", help="Check API health status")
+    parser.add_argument("--image", type=str, help="Path to medical image file for prediction")
 
     args = parser.parse_args()
