-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
177 lines (147 loc) · 4.86 KB
/
app.py
File metadata and controls
177 lines (147 loc) · 4.86 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import os
import io
import numpy as np
from PIL import Image
import tensorflow as tf
from fastapi import FastAPI, File, UploadFile, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
# --- GPU configuration (runs at import time) ---
# Enable memory growth so TensorFlow allocates GPU memory incrementally
# instead of grabbing the whole device up front, then cap the first GPU
# at 3.5 GB (headroom for the 4 GB GTX 1650 this was tuned for).
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        # Limit GPU memory to 3.5GB for GTX 1650
        tf.config.set_logical_device_configuration(
            gpus[0],
            [tf.config.LogicalDeviceConfiguration(memory_limit=3500)]
        )
    except RuntimeError as e:
        # Configuration must happen before the GPU is initialized; if it
        # already was, TF raises RuntimeError — log and continue on defaults.
        print(f"GPU configuration error: {e}")
# --- FastAPI application setup ---
app = FastAPI(
    title="Deepfake Detection API",
    description="CNN-based deepfake image detection system",
    version="1.0.0"
)

# Allow cross-origin requests so a browser frontend on another host/port
# can call this API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # In production, specify exact origins
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Loaded Keras model (set by load_model() on startup); None until then.
model = None

# Candidate model files, tried in order — first existing, loadable one wins.
MODEL_PATHS = [
    "deepfake_cnn_gpu.h5",
    "src/models/basic_cnn_best.h5",
    "src/models/basic_cnn_final.h5"
]
def load_model():
    """Load the trained CNN model into the module-level ``model`` global.

    Tries each path in MODEL_PATHS in order; the first file that exists
    and loads successfully wins. Load failures are logged and the next
    candidate is tried.

    Raises:
        RuntimeError: if no candidate path yields a loadable model.
    """
    global model
    for model_path in MODEL_PATHS:
        if os.path.exists(model_path):
            print(f"Loading model from: {model_path}")
            try:
                model = tf.keras.models.load_model(model_path)
                print("Model loaded successfully!")
                return
            except Exception as e:
                # A corrupt/incompatible file shouldn't abort the search.
                print(f"Error loading model from {model_path}: {e}")
                continue
    raise RuntimeError("No trained model found! Please train the CNN model first.")
def preprocess_image(image: Image.Image) -> np.ndarray:
    """Preprocess a PIL image for the CNN model.

    Steps:
    - Convert to RGB (handles grayscale / RGBA / palette inputs)
    - Resize to the model's 224x224 input size
    - Normalize pixel values to [0, 1] as float32
    - Add a leading batch dimension

    Args:
        image: Source image in any PIL-supported mode.

    Returns:
        Array of shape (1, 224, 224, 3), dtype float32, values in [0, 1].
    """
    # Convert to RGB if needed
    if image.mode != 'RGB':
        image = image.convert('RGB')
    # Resize to model input size
    image = image.resize((224, 224))
    # Convert to numpy array and normalize
    img_array = np.array(image, dtype=np.float32) / 255.0
    # Add batch dimension
    img_array = np.expand_dims(img_array, axis=0)
    return img_array
@app.on_event("startup")
async def startup_event():
    """Load the model once when the server starts.

    Raises RuntimeError (via load_model) if no model file is found,
    which aborts startup rather than serving with no model.
    """
    load_model()
@app.get("/")
async def root():
    """Root endpoint: basic service banner plus model-load status."""
    return {
        "message": "Deepfake Detection API",
        "status": "running",
        "model_loaded": model is not None
    }
@app.get("/health")
async def health_check():
    """Health check endpoint for probes; reports whether the model is loaded."""
    return {
        "status": "healthy",
        "model_loaded": model is not None
    }
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Predict whether an uploaded image is real or fake.

    Args:
        file: Uploaded image file (any PIL-readable format).

    Returns:
        JSON with the label ("Real"/"Fake"), a display confidence in
        percent, the raw sigmoid score, and per-class probabilities.

    Raises:
        HTTPException 503: model not loaded yet.
        HTTPException 400: upload is not an image content type.
        HTTPException 500: the image could not be decoded or scored.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded")

    # Validate file type. content_type can be None if the client omits the
    # header — treat that as invalid rather than crashing with AttributeError.
    if not file.content_type or not file.content_type.startswith('image/'):
        raise HTTPException(
            status_code=400,
            detail="Invalid file type. Please upload an image file."
        )

    try:
        # Read and decode the upload
        contents = await file.read()
        image = Image.open(io.BytesIO(contents))

        # Preprocess to the model's (1, 224, 224, 3) float input
        processed_image = preprocess_image(image)

        # Single sigmoid output in [0, 1]
        prediction = model.predict(processed_image, verbose=0)[0][0]

        # Convert to percentage
        confidence = float(prediction) * 100

        # Determine label (sigmoid output: 0 = Fake, 1 = Real)
        # prediction > 0.5 means closer to 1 (Real)
        is_real = prediction > 0.5
        label = "Real" if is_real else "Fake"

        # Adjust confidence for display:
        # If Real (p > 0.5), confidence is p * 100
        # If Fake (p <= 0.5), confidence is (1 - p) * 100
        display_confidence = confidence if is_real else (100 - confidence)

        return JSONResponse(content={
            "success": True,
            "prediction": label,
            "confidence": round(display_confidence, 2),
            "raw_score": round(float(prediction), 4),
            "details": {
                "is_fake": not is_real,
                "fake_probability": round(100 - confidence, 2),
                "real_probability": round(confidence, 2)
            }
        })
    except Exception as e:
        # Decoding or inference failure — surface as a 500 with the cause.
        raise HTTPException(
            status_code=500,
            detail=f"Error processing image: {str(e)}"
        )
# Script entry point: run a development server on all interfaces, port 8000.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)