-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
56 lines (46 loc) · 1.38 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# Put the code for your API here.
import os
from pandas import DataFrame
from fastapi import FastAPI
from api.schema import ModelInput
from starter.ml.data import process_data
from starter.ml.model import load_model, load_encoder, load_lb, inference
# Categorical input columns. The /predict handler passes this list to
# process_data as `categorical_features`. The names use underscores and
# presumably must match the keys produced by ModelInput.dict() — TODO confirm.
cat_features = [
    "workclass", "education", "marital_status", "occupation",
    "relationship", "race", "sex", "native_country",
]
# Instantiate the app. The route handlers in this module attach to this object
# through the @app.on_event / @app.get / @app.post decorators.
app = FastAPI()
@app.on_event("startup")
async def startup_event():
    """Load the serialized model artifacts once, when the server starts.

    Doing the file reads here keeps them off the per-request prediction path.
    The loaded objects are published as the module-level globals ``model``,
    ``encoder`` and ``lb``, which the ``/predict`` handler reads.

    The ``model`` directory is a relative path, so it is resolved against the
    process's current working directory. The server is presumably launched
    from the repository root — TODO confirm.

    NOTE(review): newer FastAPI releases deprecate ``on_event`` in favour of a
    ``lifespan`` handler passed to ``FastAPI(...)``; consider migrating.
    """
    global model, encoder, lb
    artifact_dir = 'model'
    model = load_model(os.path.join(artifact_dir, 'model_dtc.pkl'))
    encoder = load_encoder(os.path.join(artifact_dir, 'encoder_dtc.pkl'))
    lb = load_lb(os.path.join(artifact_dir, 'lb_dtc.pkl'))
@app.get("/")
async def say_hello():
    """Handle GET ``/`` by returning a fixed greeting payload.

    Reads no state; presumably intended as a simple smoke-test endpoint.
    """
    greeting = 'Hello World!'
    return {'greeting': greeting}
@app.post("/predict")
def predict(input_data: ModelInput):
    """Run the trained model on one request body and return the decoded label.

    Declared with plain ``def`` rather than ``async def`` on purpose. The
    feature encoding and model inference below are synchronous calls with no
    ``await`` points. Inside ``async def`` they would block the event loop
    and stall every other in-flight request. FastAPI runs plain ``def`` path
    operations in its threadpool instead. The HTTP route, request schema and
    response body are unchanged.

    Relies on the module-level ``model``, ``encoder`` and ``lb`` globals that
    ``startup_event`` assigns when the server starts.

    Args:
        input_data: Request body validated against ``ModelInput``.

    Returns:
        ``{"Prediction": label}``, where ``label`` is the first entry of
        ``lb.inverse_transform`` applied to the model's output.
    """
    # Single-row frame whose columns are the ModelInput field names.
    X_input = DataFrame([input_data.dict()])

    # Encode the features with the artifacts loaded at startup.
    # With training=False, process_data presumably applies (rather than
    # refits) the encoder/lb — confirm in starter.ml.data.
    # Only the encoded feature matrix is needed here.
    X_infer, _, _, _ = process_data(
        X_input,
        categorical_features=cat_features,
        encoder=encoder,
        lb=lb,
        training=False,
    )

    pred = inference(model=model, X=X_infer)

    # Invert the label binarizer to recover the original class label,
    # e.g. "<=50K" or ">50K".
    return {"Prediction": lb.inverse_transform(pred)[0]}