-
Notifications
You must be signed in to change notification settings - Fork 52
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[App] ResNet Compiled App (2/2) - Pipeline (#165)
Adds ResNet and image classifier pipeline functionality. Includes changes from #428 See huggingface implementation for original API inspiration. Resolves CentML/hidet#60
- Loading branch information
1 parent
b75e5d8
commit 742a6b6
Showing
24 changed files
with
702 additions
and
15 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
## Hidet Image Classification Compiled App | ||
|
||
### Quickstart | ||
|
||
``` | ||
from hidet.apps.image_classification.pipeline.pipeline import ImageClassificationPipeline | ||
from hidet.apps.image_classification.processing.image_processor import ChannelDimension | ||
from datasets import load_dataset | ||
dataset = load_dataset("huggingface/cats-image", split="test", trust_remote_code=True) | ||
pipeline = ImageClassificationPipeline("microsoft/resnet-50", batch_size=1, kernel_search_space=0) | ||
res = pipeline(dataset["image"], input_data_format=ChannelDimension.CHANNEL_LAST, top_k=3) | ||
``` | ||
|
||
An image classifier app currently only supports ResNet50 from Huggingface. Currently supports PIL + torch/hidet tensors as image input. | ||
|
||
Load sample datasets using the datasets library, and change label ids back to string labels using the Huggingface config. Returns the top k candidates with the highest score. | ||
|
||
If the weights used are not public, be sure to modify `hidet.toml` so that option `auth_tokens.for_huggingface` is set to your Huggingface account credential. | ||
|
||
### Model Details | ||
|
||
A `PretrainedModelForImageClassification` is a `PretrainedModel` that allows us to `create_pretrained_model` from a Huggingface identifier. `PretrainedModelForImageClassification` defines a forward function that accepts Hidet tensors as input and returns logits as output. | ||
|
||
Interact with a `PretrainedModelForImageClassification` using `ImageClassificationPipeline`. The pipeline instantiates a pre-processor that adapts the image type for Hidet and performs transformations on the image before calling the pretrained model graph. Specify batch size and model name using the pipeline constructor. | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,12 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from .modeling import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
python/hidet/apps/image_classification/modeling/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,12 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from .resnet import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
11 changes: 11 additions & 0 deletions
11
python/hidet/apps/image_classification/modeling/resnet/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,12 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from .modeling import ResNetForImageClassification, ResNetModel |
12 changes: 12 additions & 0 deletions
12
python/hidet/apps/image_classification/modeling/resnet/modeling.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
12 changes: 12 additions & 0 deletions
12
python/hidet/apps/image_classification/pipeline/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from .pipeline import ImageClassificationPipeline |
75 changes: 75 additions & 0 deletions
75
python/hidet/apps/image_classification/pipeline/pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from typing import Any, Iterable, Optional, Sequence | ||
from hidet.apps import hf | ||
from hidet.apps.image_classification.builder import create_image_classifier, create_image_processor | ||
from hidet.apps.image_classification.processing import BaseImageProcessor, ChannelDimension, ImageInput | ||
from hidet.graph.tensor import Tensor | ||
|
||
|
||
class ImageClassificationPipeline: | ||
def __init__( | ||
self, | ||
name: str, | ||
revision: Optional[str] = None, | ||
batch_size: int = 1, | ||
pre_processor: Optional[BaseImageProcessor] = None, | ||
dtype: str = "float32", | ||
device: str = "cuda", | ||
kernel_search_space: int = 2, | ||
): | ||
if pre_processor is None: | ||
self.pre_processor = create_image_processor(name, revision) | ||
else: | ||
self.pre_processor = pre_processor | ||
|
||
self.model = create_image_classifier(name, revision, dtype, device, batch_size, kernel_search_space) | ||
self.config = hf.load_pretrained_config(name, revision) | ||
|
||
def __call__(self, model_inputs: Any, **kwargs): | ||
""" | ||
Run through image classification pipeline end to end. | ||
images: ImageInput | ||
List or single instance of numpy array, PIL image, or torch tensor | ||
input_data_format: ChannelDimension | ||
Input data is channel first or last | ||
batch_size: int (default 1) | ||
Batch size to feed model inputs | ||
top_k: int (default 5) | ||
Return scores for top k results | ||
""" | ||
if not isinstance(model_inputs, Iterable): | ||
model_inputs = [model_inputs] | ||
if not isinstance(model_inputs, Sequence): | ||
model_inputs = list(model_inputs) | ||
|
||
assert isinstance(model_inputs, Sequence) | ||
|
||
processed_inputs = self.preprocess(model_inputs, **kwargs) | ||
model_outputs = self.forward(processed_inputs, **kwargs) | ||
outputs = self.postprocess(model_outputs, **kwargs) | ||
|
||
return outputs | ||
|
||
def preprocess(self, images: ImageInput, input_data_format: ChannelDimension, **kwargs): | ||
# TODO accept inputs other than ImageInput type, e.g. url or dataset | ||
return self.pre_processor(images, input_data_format=input_data_format, **kwargs) | ||
|
||
def postprocess(self, model_outputs: Tensor, top_k: int = 5, **kwargs): | ||
top_k = min(top_k, self.config.num_labels) | ||
torch_outputs = model_outputs.torch() | ||
values, indices = torch_outputs.topk(top_k, sorted=False) | ||
labels = [[self.config.id2label[int(x.item())] for x in t] for t in indices] | ||
return [[{"label": label, "score": value.item()} for label, value in zip(a, b)] for a, b in zip(labels, values)] | ||
|
||
def forward(self, model_inputs: Tensor, **kwargs) -> Tensor: | ||
return self.model.classify([model_inputs])[0] |
13 changes: 13 additions & 0 deletions
13
python/hidet/apps/image_classification/processing/__init__.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from .image_processor import ChannelDimension, BaseImageProcessor, ImageInput | ||
from .resnet import * |
91 changes: 91 additions & 0 deletions
91
python/hidet/apps/image_classification/processing/image_processor.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
from enum import Enum, auto, unique | ||
from typing import Dict, List, Optional, Sequence, Type, Union | ||
import torch | ||
import numpy as np | ||
from hidet.graph.tensor import Tensor, from_torch | ||
|
||
|
||
@unique | ||
class ChannelDimension(Enum): | ||
CHANNEL_FIRST = auto() | ||
CHANNEL_LAST = auto() | ||
CHANNEL_SINGLE = auto() | ||
|
||
|
||
ImageInput = Union[ | ||
"PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"] | ||
] # noqa | ||
|
||
|
||
class BaseImageProcessor: | ||
|
||
processor_registry: Dict[str, Type["BaseImageProcessor"]] = {} | ||
|
||
def __init__(self, dtype: Optional[str] = None, device: str = "cuda"): | ||
super().__init__() | ||
|
||
self.dtype = dtype | ||
self.device = device | ||
|
||
@classmethod | ||
def register(cls, arch: str, processor_class: Type["BaseImageProcessor"]): | ||
cls.processor_registry[arch] = processor_class | ||
|
||
@classmethod | ||
def load_module(cls, arch: str): | ||
return cls.processor_registry[arch] | ||
|
||
def __call__(self, images: ImageInput, **kwargs) -> Tensor: | ||
return self.preprocess(images, **kwargs) | ||
|
||
def preprocess(self, images: ImageInput, **kwargs) -> Tensor: | ||
raise NotImplementedError("Image processors should implement their own preprocess step.") | ||
|
||
def rescale(self, image: Tensor, scale: float) -> Tensor: | ||
return image * scale | ||
|
||
def normalize( | ||
self, image: Tensor, mean: Union[float, Sequence[float]], std: Union[float, Sequence[float]] | ||
) -> Tensor: | ||
""" | ||
Normalize image on per channel basis as | ||
(mean - pixel) / std | ||
mean and std are broadcast across channels if scalar value provided. | ||
""" | ||
num_channels: int = image.shape[-3] | ||
|
||
if isinstance(mean, Sequence): | ||
if len(mean) != num_channels: | ||
raise ValueError(f"means need {num_channels} values, one for each channel, got {len(mean)}.") | ||
else: | ||
mean = [mean] * num_channels | ||
channel_means = from_torch(torch.Tensor(mean).view(num_channels, 1, 1)).to(self.dtype, self.device) | ||
|
||
if isinstance(std, Sequence): | ||
if len(std) != num_channels: | ||
raise ValueError(f"stds need {num_channels} values, one for each channel, got {len(std)}.") | ||
else: | ||
std = [std] * num_channels | ||
channel_stds = from_torch(torch.Tensor(std).view(num_channels, 1, 1)).to(self.dtype, self.device) | ||
|
||
return (image - channel_means) / channel_stds | ||
|
||
def center_square_crop(self, image: Tensor, size: int): | ||
assert image.shape[-2:] >= (size, size) | ||
|
||
pad_width = image.shape[-2] - size | ||
start = (pad_width // 2) + (pad_width % 2) | ||
end = image.shape[-2] - (pad_width // 2) | ||
|
||
return image[:, :, start:end, start:end] |
Oops, something went wrong.