-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
2 changed files
with
138 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
from markupsafe import Markup | ||
import requests | ||
from disease_detection.disease_dic import disease_dic | ||
from disease_detection.model import ResNet9 | ||
from torchvision import transforms | ||
from PIL import Image | ||
|
||
import io | ||
import torch | ||
|
||
|
||
disease_classes = ['Apple___Apple_scab', | ||
'Apple___Black_rot', | ||
'Apple___Cedar_apple_rust', | ||
'Apple___healthy', | ||
'Blueberry___healthy', | ||
'Cherry_(including_sour)___Powdery_mildew', | ||
'Cherry_(including_sour)___healthy', | ||
'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot', | ||
'Corn_(maize)___Common_rust_', | ||
'Corn_(maize)___Northern_Leaf_Blight', | ||
'Corn_(maize)___healthy', | ||
'Grape___Black_rot', | ||
'Grape___Esca_(Black_Measles)', | ||
'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)', | ||
'Grape___healthy', | ||
'Orange___Haunglongbing_(Citrus_greening)', | ||
'Peach___Bacterial_spot', | ||
'Peach___healthy', | ||
'Pepper,_bell___Bacterial_spot', | ||
'Pepper,_bell___healthy', | ||
'Potato___Early_blight', | ||
'Potato___Late_blight', | ||
'Potato___healthy', | ||
'Raspberry___healthy', | ||
'Soybean___healthy', | ||
'Squash___Powdery_mildew', | ||
'Strawberry___Leaf_scorch', | ||
'Strawberry___healthy', | ||
'Tomato___Bacterial_spot', | ||
'Tomato___Early_blight', | ||
'Tomato___Late_blight', | ||
'Tomato___Leaf_Mold', | ||
'Tomato___Septoria_leaf_spot', | ||
'Tomato___Spider_mites Two-spotted_spider_mite', | ||
'Tomato___Target_Spot', | ||
'Tomato___Tomato_Yellow_Leaf_Curl_Virus', | ||
'Tomato___Tomato_mosaic_virus', | ||
'Tomato___healthy'] | ||
disease_model_path = 'Trained_Model/plant_disease_model.pth' | ||
disease_model = ResNet9(3, len(disease_classes)) | ||
disease_model.load_state_dict(torch.load( | ||
disease_model_path, map_location=torch.device('cpu'))) | ||
disease_model.eval() | ||
|
||
|
||
|
||
def disease(msg_received): | ||
try: | ||
file = msg_received['file'] | ||
if not file: | ||
return {"code": -1,"Error":"We are not able to find the image file"} | ||
|
||
prediction = predict_image(file) | ||
|
||
prediction = Markup(str(disease_dic[prediction])) | ||
|
||
return prediction | ||
except Exception as e: | ||
return { | ||
"code": -1, | ||
"Error": str(e), | ||
} | ||
|
||
def predict_image(img, model=disease_model): | ||
""" | ||
Transforms image to tensor and predicts disease label | ||
:params: image | ||
:return: prediction (string) | ||
""" | ||
transform = transforms.Compose([ | ||
transforms.Resize(256), | ||
transforms.ToTensor(), | ||
]) | ||
# Sending an HTTP GET request to the URL | ||
response = requests.get(img) | ||
# Reading the image data from the response content and openig the image using Pillow (PIL) | ||
image = Image.open(io.BytesIO(response.content )) | ||
img_t = transform(image) | ||
img_u = torch.unsqueeze(img_t, 0) | ||
|
||
# Get predictions from model | ||
yb = model(img_u) | ||
# Pick index with highest probability | ||
_, preds = torch.max(yb, dim=1) | ||
prediction = disease_classes[preds[0].item()] | ||
# Retrieve the class label | ||
return prediction |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
|
||
|
||
def ConvBlock(in_channels, out_channels, pool=False): | ||
layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1), | ||
nn.BatchNorm2d(out_channels), | ||
nn.ReLU(inplace=True)] | ||
if pool: | ||
layers.append(nn.MaxPool2d(4)) | ||
return nn.Sequential(*layers) | ||
|
||
|
||
# Model Architecture | ||
class ResNet9(nn.Module): | ||
def __init__(self, in_channels, num_diseases): | ||
super().__init__() | ||
|
||
self.conv1 = ConvBlock(in_channels, 64) | ||
self.conv2 = ConvBlock(64, 128, pool=True) # out_dim : 128 x 64 x 64 | ||
self.res1 = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128)) | ||
|
||
self.conv3 = ConvBlock(128, 256, pool=True) # out_dim : 256 x 16 x 16 | ||
self.conv4 = ConvBlock(256, 512, pool=True) # out_dim : 512 x 4 x 44 | ||
self.res2 = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512)) | ||
|
||
self.classifier = nn.Sequential(nn.MaxPool2d(4), | ||
nn.Flatten(), | ||
nn.Linear(512, num_diseases)) | ||
|
||
def forward(self, xb): # xb is the loaded batch | ||
out = self.conv1(xb) | ||
out = self.conv2(out) | ||
out = self.res1(out) + out | ||
out = self.conv3(out) | ||
out = self.conv4(out) | ||
out = self.res2(out) + out | ||
out = self.classifier(out) | ||
return out |