-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapp.py
32 lines (26 loc) · 860 Bytes
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import numpy as np
import pandas as pd
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from simple_model.importing_data import import_data
from simple_model.model_training import train_model
from simple_model.model_inference import prediction
# Create the ASGI application object that route handlers are registered on.
app = FastAPI()

# Load the data and fit the model once, at import time, so that every
# request served by this process reuses the same fitted `model` object.
dataset = import_data()
model = train_model(dataset)
# We define the input data format
class Flower(BaseModel):
sepal_length: float
sepal_width: float
petal_length: float
petal_width: float
@app.post("/predict")
def predict(data: Flower):
    """Return the model's prediction for one set of flower measurements.

    The four measurements are packed into a single-row 2-D array, in the order
    sepal_length, sepal_width, petal_length, petal_width. That array is passed
    to ``prediction`` together with the module-level ``model``.

    Args:
        data: The validated request body.

    Returns:
        ``{"prediction": value}``, where ``value`` is the first element of the
        prediction result converted to plain Python types. The conversion is
        needed so FastAPI can JSON-encode the response.
    """
    input_array = np.array(
        [[data.sepal_length, data.sepal_width, data.petal_length, data.petal_width]]
    )
    pred = prediction(model, input_array)
    # NOTE(review): assumes `prediction` returns an indexable result with one
    # entry per input row (e.g. a NumPy array) — confirm against simple_model.
    first = pred[0]
    # FastAPI's encoder rejects NumPy integer/bool scalars and ndarrays, which
    # would turn a successful prediction into a 500 error. Convert them to
    # built-in Python types; other values (str, float, ...) pass through as-is.
    if isinstance(first, np.generic):
        first = first.item()
    elif isinstance(first, np.ndarray):
        first = first.tolist()
    return {"prediction": first}
# Run the FastAPI app on localhost
if __name__ == "__main__":
uvicorn.run(app, host="localhost", port=8000)