-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmain.py
60 lines (44 loc) · 1.26 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from functools import lru_cache
from io import BytesIO

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
import uvicorn
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.middleware.cors import CORSMiddleware
from PIL import Image
from resizeimage import resizeimage
# Application object; the route decorators further down register on it.
app = FastAPI()

# Origins permitted by the CORS middleware below.
# NOTE(review): these look like local front-end dev servers — confirm before
# deploying anywhere other than a developer machine.
origins = ["http://localhost", "http://localhost:3000"]

# Allow the listed origins to send credentialed requests with any method
# and any header.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Labels looked up by the index of the highest score in a prediction.
CLASS_NAMES = ["nevus", "not found"]
@app.get("/ping")
async def ping():
    """Reply with a fixed greeting so callers can check the service is up."""
    greeting = "Hello, I am alive"
    return greeting
def read_file_as_image(data) -> np.ndarray:
    """Decode the raw contents of an image file into a NumPy array.

    Args:
        data: Encoded image file contents; wrapped in a ``BytesIO`` so PIL
            can read it like a file.

    Returns:
        The array ``np.array`` builds from the opened PIL image. No resizing
        or channel conversion happens here.

    NOTE(review): an upload with a different channel count (e.g. RGBA or
    grayscale) is passed through unchanged — confirm the model accepts it.
    """
    buffer = BytesIO(data)
    pil_image = Image.open(buffer)
    return np.array(pil_image)
@lru_cache(maxsize=1)
def _load_model():
    """Load the Keras model from disk once and hand back the same object after.

    Loading is deferred to the first call and cached, so a request no longer
    pays the model-deserialisation cost every time it is served.
    """
    return tf.keras.models.load_model("./Train_Model/2")


@app.post("/predict")
async def predict(
    file: UploadFile = File(...)
):
    """Classify an uploaded image and report the best-scoring class.

    Args:
        file: The uploaded file; its full contents are read and decoded
            as an image.

    Returns:
        A dict with ``'class'`` (the ``CLASS_NAMES`` entry at the index of
        the highest score) and ``'confidence'`` (that highest score
        multiplied by 100, as a Python float).

    Raises:
        HTTPException: status 400 when the upload cannot be decoded as an
            image.
    """
    payload = await file.read()
    try:
        image = read_file_as_image(payload)
    except OSError as err:
        # PIL signals unreadable or truncated image data with OSError
        # subclasses; that is a bad request, not a server failure.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image",
        ) from err

    # Add a leading axis so the single image is passed as a batch of one.
    img_batch = np.expand_dims(image, 0)
    predictions = _load_model().predict(img_batch)
    scores = predictions[0]

    predicted_class = CLASS_NAMES[np.argmax(scores)]
    confidence = np.max(scores) * 100
    return {
        'class': predicted_class,
        'confidence': float(confidence),
    }
if __name__ == "__main__":
    # Run directly as a script: serve the app with uvicorn on localhost:8000.
    uvicorn.run(app, host="localhost", port=8000)