Skip to content

Commit

Permalink
Merge pull request #545 from microsoft/PreRelease
Browse files Browse the repository at this point in the history
Added AI4GAmazonRainforestv2 model
  • Loading branch information
aa-hernandez authored Dec 6, 2024
2 parents ca21678 + 011381e commit e07f35f
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 2 deletions.
3 changes: 2 additions & 1 deletion PytorchWildlife/models/classification/resnet/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,5 @@
from .opossum import *
from .amazon import *
from .serengeti import *
from .custom_weights import *
from .custom_weights import *
from .amazon_v2 import *
108 changes: 108 additions & 0 deletions PytorchWildlife/models/classification/resnet/amazon_v2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import torch
from .base_classifier import PlainResNetInference

__all__ = [
"AI4GAmazonRainforest_v2"
]


class AI4GAmazonRainforest_v2(PlainResNetInference):
"""
Amazon Ranforest Animal Classifier that inherits from PlainResNetInference.
This classifier is specialized for recognizing 36 different animals in the Amazon Rainforest.
"""

# Image size for the Opossum classifier
IMAGE_SIZE = 224

# Class names for prediction
CLASS_NAMES = {
0: 'Dasyprocta',
1: 'Bos',
2: 'Pecari',
3: 'Mazama',
4: 'Cuniculus',
5: 'Leptotila',
6: 'Human',
7: 'Aramides',
8: 'Tinamus',
9: 'Eira',
10: 'Crax',
11: 'Procyon',
12: 'Capra',
13: 'Dasypus',
14: 'Sciurus',
15: 'Crypturellus',
16: 'Tamandua',
17: 'Proechimys',
18: 'Leopardus',
19: 'Equus',
20: 'Columbina',
21: 'Nyctidromus',
22: 'Ortalis',
23: 'Emballonura',
24: 'Odontophorus',
25: 'Geotrygon',
26: 'Metachirus',
27: 'Catharus',
28: 'Cerdocyon',
29: 'Momotus',
30: 'Tapirus',
31: 'Canis',
32: 'Furnarius',
33: 'Didelphis',
34: 'Sylvilagus',
35: 'Unknown'
}

def __init__(self, weights=None, device="cpu", pretrained=True):
"""
Initialize the Amazon animal Classifier.
Args:
weights (str, optional): Path to the model weights. Defaults to None.
device (str, optional): Device for model inference. Defaults to "cpu".
pretrained (bool, optional): Whether to use pretrained weights. Defaults to True.
"""

# If pretrained, use the provided URL to fetch the weights
if pretrained:
url = "https://zenodo.org/records/14252214/files/AI4GAmazonDeforestationv2?download=1"
else:
url = None

super(AI4GAmazonRainforest_v2, self).__init__(weights=weights, device=device,
num_cls=36, num_layers=50, url=url)

def results_generation(self, logits, img_ids, id_strip=None):
"""
Generate results for classification.
Args:
logits (torch.Tensor): Output tensor from the model.
img_id (str): Image identifier.
id_strip (str): stiping string for better image id saving.
Returns:
dict: Dictionary containing image ID, prediction, and confidence score.
"""

probs = torch.softmax(logits, dim=1)
preds = probs.argmax(dim=1)
confs = probs.max(dim=1)[0]
confidences = probs[0].tolist()
result = [[self.CLASS_NAMES[i], confidence] for i, confidence in enumerate(confidences)]

results = []
for pred, img_id, conf in zip(preds, img_ids, confs):
r = {"img_id": str(img_id).strip(id_strip)}
r["prediction"] = self.CLASS_NAMES[pred.item()]
r["class_id"] = pred.item()
r["confidence"] = conf.item()
r["all_confidences"] = result
results.append(r)

return results
2 changes: 1 addition & 1 deletion demo/gradio_demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def update_ui_elements(det_model):
if "HerdNet" in det_model: # Disable all the classification model dropdown because HerdNet does not require a classification model apart
return gr.Dropdown(choices=["None"], interactive=True, label="Classification model", value="None")
else:
return gr.Dropdown(choices=["None", "AI4GOpossum", "AI4GAmazonRainforest", "AI4GSnapshotSerengeti", "CustomWeights"], interactive=True, label="Classification model", value="None")
return gr.Dropdown(choices=["None", "AI4GOpossum", "AI4GAmazonRainforest", "AI4GAmazonRainforest_v2", "AI4GSnapshotSerengeti", "CustomWeights"], interactive=True, label="Classification model", value="None")

det_drop.change(update_ui_elements, det_drop, [clf_drop])

Expand Down

0 comments on commit e07f35f

Please sign in to comment.