Skip to content

Commit

Permalink
disease detection
Browse files Browse the repository at this point in the history
  • Loading branch information
Piexie3 committed Sep 25, 2023
1 parent f3aa447 commit eaeac8b
Show file tree
Hide file tree
Showing 2 changed files with 138 additions and 0 deletions.
98 changes: 98 additions & 0 deletions disease_detection/disease.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
from markupsafe import Markup
import requests
from disease_detection.disease_dic import disease_dic
from disease_detection.model import ResNet9
from torchvision import transforms
from PIL import Image

import io
import torch


disease_classes = ['Apple___Apple_scab',
'Apple___Black_rot',
'Apple___Cedar_apple_rust',
'Apple___healthy',
'Blueberry___healthy',
'Cherry_(including_sour)___Powdery_mildew',
'Cherry_(including_sour)___healthy',
'Corn_(maize)___Cercospora_leaf_spot Gray_leaf_spot',
'Corn_(maize)___Common_rust_',
'Corn_(maize)___Northern_Leaf_Blight',
'Corn_(maize)___healthy',
'Grape___Black_rot',
'Grape___Esca_(Black_Measles)',
'Grape___Leaf_blight_(Isariopsis_Leaf_Spot)',
'Grape___healthy',
'Orange___Haunglongbing_(Citrus_greening)',
'Peach___Bacterial_spot',
'Peach___healthy',
'Pepper,_bell___Bacterial_spot',
'Pepper,_bell___healthy',
'Potato___Early_blight',
'Potato___Late_blight',
'Potato___healthy',
'Raspberry___healthy',
'Soybean___healthy',
'Squash___Powdery_mildew',
'Strawberry___Leaf_scorch',
'Strawberry___healthy',
'Tomato___Bacterial_spot',
'Tomato___Early_blight',
'Tomato___Late_blight',
'Tomato___Leaf_Mold',
'Tomato___Septoria_leaf_spot',
'Tomato___Spider_mites Two-spotted_spider_mite',
'Tomato___Target_Spot',
'Tomato___Tomato_Yellow_Leaf_Curl_Virus',
'Tomato___Tomato_mosaic_virus',
'Tomato___healthy']
disease_model_path = 'Trained_Model/plant_disease_model.pth'
disease_model = ResNet9(3, len(disease_classes))
disease_model.load_state_dict(torch.load(
disease_model_path, map_location=torch.device('cpu')))
disease_model.eval()



def disease(msg_received):
try:
file = msg_received['file']
if not file:
return {"code": -1,"Error":"We are not able to find the image file"}

prediction = predict_image(file)

prediction = Markup(str(disease_dic[prediction]))

return prediction
except Exception as e:
return {
"code": -1,
"Error": str(e),
}

def predict_image(img, model=disease_model):
"""
Transforms image to tensor and predicts disease label
:params: image
:return: prediction (string)
"""
transform = transforms.Compose([
transforms.Resize(256),
transforms.ToTensor(),
])
# Sending an HTTP GET request to the URL
response = requests.get(img)
# Reading the image data from the response content and openig the image using Pillow (PIL)
image = Image.open(io.BytesIO(response.content ))
img_t = transform(image)
img_u = torch.unsqueeze(img_t, 0)

# Get predictions from model
yb = model(img_u)
# Pick index with highest probability
_, preds = torch.max(yb, dim=1)
prediction = disease_classes[preds[0].item()]
# Retrieve the class label
return prediction
40 changes: 40 additions & 0 deletions disease_detection/model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn
import torch.nn.functional as F


def ConvBlock(in_channels, out_channels, pool=False):
layers = [nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
nn.BatchNorm2d(out_channels),
nn.ReLU(inplace=True)]
if pool:
layers.append(nn.MaxPool2d(4))
return nn.Sequential(*layers)


# Model Architecture
class ResNet9(nn.Module):
def __init__(self, in_channels, num_diseases):
super().__init__()

self.conv1 = ConvBlock(in_channels, 64)
self.conv2 = ConvBlock(64, 128, pool=True) # out_dim : 128 x 64 x 64
self.res1 = nn.Sequential(ConvBlock(128, 128), ConvBlock(128, 128))

self.conv3 = ConvBlock(128, 256, pool=True) # out_dim : 256 x 16 x 16
self.conv4 = ConvBlock(256, 512, pool=True) # out_dim : 512 x 4 x 44
self.res2 = nn.Sequential(ConvBlock(512, 512), ConvBlock(512, 512))

self.classifier = nn.Sequential(nn.MaxPool2d(4),
nn.Flatten(),
nn.Linear(512, num_diseases))

def forward(self, xb): # xb is the loaded batch
out = self.conv1(xb)
out = self.conv2(out)
out = self.res1(out) + out
out = self.conv3(out)
out = self.conv4(out)
out = self.res2(out) + out
out = self.classifier(out)
return out

0 comments on commit eaeac8b

Please sign in to comment.