From 742a6b64c7acb204e8040ada5de42cc7b9527791 Mon Sep 17 00:00:00 2001 From: Kevin Tong Date: Sat, 27 Apr 2024 05:58:03 -0400 Subject: [PATCH] [App] ResNet Compiled App (2/2) - Pipeline (#165) Adds ResNet and image classifier pipeline functionality. Includes changes from https://github.com/hidet-org/hidet/pull/428 See huggingface implementation for original API inspiration. Resolves https://github.com/CentML/hidet/issues/60 --- python/hidet/apps/hf.py | 11 ++ .../hidet/apps/image_classification/README.md | 29 +++ .../apps/image_classification/__init__.py | 11 ++ python/hidet/apps/image_classification/app.py | 11 ++ .../apps/image_classification/builder.py | 29 ++- .../image_classification/modeling/__init__.py | 11 ++ .../modeling/pretrained.py | 16 +- .../modeling/resnet/__init__.py | 11 ++ .../modeling/resnet/modeling.py | 12 ++ .../image_classification/pipeline/__init__.py | 12 ++ .../image_classification/pipeline/pipeline.py | 75 ++++++++ .../processing/__init__.py | 13 ++ .../processing/image_processor.py | 91 +++++++++ .../processing/resnet/__init__.py | 12 ++ .../processing/resnet/processing.py | 174 ++++++++++++++++++ python/hidet/apps/modeling_outputs.py | 11 ++ python/hidet/apps/pretrained.py | 11 ++ python/hidet/apps/registry.py | 11 ++ .../processing/resnet/test_processing.py | 56 ++++++ .../apps/image_classification/test_builder.py | 42 +++++ .../test_image_classifier_builder.py | 14 +- .../image_classification/test_pipeline.py | 30 +++ tests/apps/test_pretrained.py | 13 +- tests/apps/test_registry.py | 11 ++ 24 files changed, 702 insertions(+), 15 deletions(-) create mode 100644 python/hidet/apps/image_classification/README.md create mode 100644 python/hidet/apps/image_classification/pipeline/__init__.py create mode 100644 python/hidet/apps/image_classification/pipeline/pipeline.py create mode 100644 python/hidet/apps/image_classification/processing/__init__.py create mode 100644 python/hidet/apps/image_classification/processing/image_processor.py create mode 100644 python/hidet/apps/image_classification/processing/resnet/__init__.py create mode 100644 python/hidet/apps/image_classification/processing/resnet/processing.py create mode 100644 tests/apps/image_classification/processing/resnet/test_processing.py create mode 100644 tests/apps/image_classification/test_builder.py create mode 100644 tests/apps/image_classification/test_pipeline.py diff --git a/python/hidet/apps/hf.py b/python/hidet/apps/hf.py index 8d3d7cd8a..b72b54c2c 100644 --- a/python/hidet/apps/hf.py +++ b/python/hidet/apps/hf.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Optional import torch diff --git a/python/hidet/apps/image_classification/README.md b/python/hidet/apps/image_classification/README.md new file mode 100644 index 000000000..252116464 --- /dev/null +++ b/python/hidet/apps/image_classification/README.md @@ -0,0 +1,29 @@ +## Hidet Image Classification Compiled App + +### Quickstart + +``` +from hidet.apps.image_classification.pipeline.pipeline import ImageClassificationPipeline +from hidet.apps.image_classification.processing.image_processor import ChannelDimension +from datasets import load_dataset + + +dataset = load_dataset("huggingface/cats-image", split="test", trust_remote_code=True) + +pipeline = ImageClassificationPipeline("microsoft/resnet-50", batch_size=1, kernel_search_space=0) + +res = pipeline(dataset["image"], input_data_format=ChannelDimension.CHANNEL_LAST, top_k=3) +``` + +An image classifier app currently only supports ResNet50 from Huggingface. Currently supports PIL + torch/hidet tensors as image input. + +Load sample datasets using the datasets library, and change label ids back to string labels using the Huggingface config. Returns the top k candidates with the highest score. + +If the weights used are not public, be sure to modify `hidet.toml` so that option `auth_tokens.for_huggingface` is set to your Huggingface account credential. + +### Model Details + +A `PretrainedModelForImageClassification` is a `PretrainedModel` that allows us to `create_pretrained_model` from a Huggingface identifier. `PretrainedModelForImageClassification` defines a forward function that accepts Hidet tensors as input and returns logits as output. + +Interact with a `PretrainedModelForImageClassification` using `ImageClassificationPipeline`. The pipeline instantiates a pre-processor that adapts the image type for Hidet and performs transformations on the image before calling the pretrained model graph. Specify batch size and model name using the pipeline constructor. + diff --git a/python/hidet/apps/image_classification/__init__.py b/python/hidet/apps/image_classification/__init__.py index 617af58f3..5ae610ab4 100644 --- a/python/hidet/apps/image_classification/__init__.py +++ b/python/hidet/apps/image_classification/__init__.py @@ -1 +1,12 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .modeling import * diff --git a/python/hidet/apps/image_classification/app.py b/python/hidet/apps/image_classification/app.py index fe152cdf9..15e6d0dc8 100644 --- a/python/hidet/apps/image_classification/app.py +++ b/python/hidet/apps/image_classification/app.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Sequence from hidet.graph.tensor import Tensor diff --git a/python/hidet/apps/image_classification/builder.py b/python/hidet/apps/image_classification/builder.py index b0b1d87c7..a83e6eca9 100644 --- a/python/hidet/apps/image_classification/builder.py +++ b/python/hidet/apps/image_classification/builder.py @@ -1,8 +1,20 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Optional from transformers import PretrainedConfig from hidet.apps import hf +from hidet.apps.image_classification.processing.image_processor import BaseImageProcessor from hidet.apps.image_classification.app import ImageClassificationApp from hidet.apps.image_classification.modeling.pretrained import PretrainedModelForImageClassification from hidet.apps.modeling_outputs import ImageClassifierOutput @@ -19,6 +31,7 @@ def create_image_classifier( revision: Optional[str] = None, dtype: str = "float32", device: str = "cuda", + batch_size: int = 1, kernel_search_space: int = 2, ): # load the huggingface config according to (model, revision) pair @@ -28,7 +41,7 @@ def create_image_classifier( model = PretrainedModelForImageClassification.create_pretrained_model( config, revision=revision, dtype=dtype, device=device ) - inputs: Tensor = symbol(["bs", 3, 224, 224], dtype=dtype, device=device) + inputs: Tensor = symbol([batch_size, 3, 224, 224], dtype=dtype, device=device) outputs: ImageClassifierOutput = model.forward(inputs) graph: FlowGraph = trace_from(outputs.logits, inputs) @@ -43,14 +56,10 @@ def create_image_classifier( ) -# def create_image_processor( -# name: str, -# revision: Optional[str] = None, -# **kwargs -# ) -> BaseProcessor: -# # load the huggingface config according to (model, revision) pair -# config: PretrainedConfig = hf.load_pretrained_config(name, revision=revision) +def create_image_processor(name: str, revision: Optional[str] = None, **kwargs): + # load the huggingface config according to (model, revision) pair + config: PretrainedConfig = hf.load_pretrained_config(name, revision=revision) -# processor = BaseImageProcessor.load_module(config, module_type=ModuleType.PROCESSING) + processor = BaseImageProcessor.load_module(config.architectures[0]) -# return processor(**kwargs) + return processor(**kwargs) diff --git a/python/hidet/apps/image_classification/modeling/__init__.py b/python/hidet/apps/image_classification/modeling/__init__.py index b792ca6ec..a67e86695 100644 --- a/python/hidet/apps/image_classification/modeling/__init__.py +++ b/python/hidet/apps/image_classification/modeling/__init__.py @@ -1 +1,12 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .resnet import * diff --git a/python/hidet/apps/image_classification/modeling/pretrained.py b/python/hidet/apps/image_classification/modeling/pretrained.py index a9d139335..75d5b3c32 100644 --- a/python/hidet/apps/image_classification/modeling/pretrained.py +++ b/python/hidet/apps/image_classification/modeling/pretrained.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Optional import torch @@ -9,7 +20,7 @@ import hidet -class PretrainedModelForImageClassification(PretrainedModel[ImageClassifierOutput]): +class PretrainedModelForImageClassification(PretrainedModel): @classmethod def create_pretrained_model( cls, config: PretrainedConfig, revision: Optional[str] = None, dtype: Optional[str] = None, device: str = "cuda" @@ -37,3 +48,6 @@ def create_pretrained_model( cls.copy_weights(torch_model, hidet_model) return hidet_model + + def forward(self, *args, **kwargs) -> ImageClassifierOutput: + raise NotImplementedError() diff --git a/python/hidet/apps/image_classification/modeling/resnet/__init__.py b/python/hidet/apps/image_classification/modeling/resnet/__init__.py index a74368312..91f1a791f 100644 --- a/python/hidet/apps/image_classification/modeling/resnet/__init__.py +++ b/python/hidet/apps/image_classification/modeling/resnet/__init__.py @@ -1 +1,12 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from .modeling import ResNetForImageClassification, ResNetModel diff --git a/python/hidet/apps/image_classification/modeling/resnet/modeling.py b/python/hidet/apps/image_classification/modeling/resnet/modeling.py index 081e4e482..471c41775 100644 --- a/python/hidet/apps/image_classification/modeling/resnet/modeling.py +++ b/python/hidet/apps/image_classification/modeling/resnet/modeling.py @@ -1,3 +1,15 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from dataclasses import asdict from typing import Sequence from transformers import ResNetConfig diff --git a/python/hidet/apps/image_classification/pipeline/__init__.py b/python/hidet/apps/image_classification/pipeline/__init__.py new file mode 100644 index 000000000..602686e79 --- /dev/null +++ b/python/hidet/apps/image_classification/pipeline/__init__.py @@ -0,0 +1,12 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .pipeline import ImageClassificationPipeline diff --git a/python/hidet/apps/image_classification/pipeline/pipeline.py b/python/hidet/apps/image_classification/pipeline/pipeline.py new file mode 100644 index 000000000..9ad2dcac1 --- /dev/null +++ b/python/hidet/apps/image_classification/pipeline/pipeline.py @@ -0,0 +1,75 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import Any, Iterable, Optional, Sequence +from hidet.apps import hf +from hidet.apps.image_classification.builder import create_image_classifier, create_image_processor +from hidet.apps.image_classification.processing import BaseImageProcessor, ChannelDimension, ImageInput +from hidet.graph.tensor import Tensor + + +class ImageClassificationPipeline: + def __init__( + self, + name: str, + revision: Optional[str] = None, + batch_size: int = 1, + pre_processor: Optional[BaseImageProcessor] = None, + dtype: str = "float32", + device: str = "cuda", + kernel_search_space: int = 2, + ): + if pre_processor is None: + self.pre_processor = create_image_processor(name, revision) + else: + self.pre_processor = pre_processor + + self.model = create_image_classifier(name, revision, dtype, device, batch_size, kernel_search_space) + self.config = hf.load_pretrained_config(name, revision) + + def __call__(self, model_inputs: Any, **kwargs): + """ + Run through image classification pipeline end to end. + images: ImageInput + List or single instance of numpy array, PIL image, or torch tensor + input_data_format: ChannelDimension + Input data is channel first or last + batch_size: int (default 1) + Batch size to feed model inputs + top_k: int (default 5) + Return scores for top k results + """ + if not isinstance(model_inputs, Iterable): + model_inputs = [model_inputs] + if not isinstance(model_inputs, Sequence): + model_inputs = list(model_inputs) + + assert isinstance(model_inputs, Sequence) + + processed_inputs = self.preprocess(model_inputs, **kwargs) + model_outputs = self.forward(processed_inputs, **kwargs) + outputs = self.postprocess(model_outputs, **kwargs) + + return outputs + + def preprocess(self, images: ImageInput, input_data_format: ChannelDimension, **kwargs): + # TODO accept inputs other than ImageInput type, e.g. url or dataset + return self.pre_processor(images, input_data_format=input_data_format, **kwargs) + + def postprocess(self, model_outputs: Tensor, top_k: int = 5, **kwargs): + top_k = min(top_k, self.config.num_labels) + torch_outputs = model_outputs.torch() + values, indices = torch_outputs.topk(top_k, sorted=False) + labels = [[self.config.id2label[int(x.item())] for x in t] for t in indices] + return [[{"label": label, "score": value.item()} for label, value in zip(a, b)] for a, b in zip(labels, values)] + + def forward(self, model_inputs: Tensor, **kwargs) -> Tensor: + return self.model.classify([model_inputs])[0] diff --git a/python/hidet/apps/image_classification/processing/__init__.py b/python/hidet/apps/image_classification/processing/__init__.py new file mode 100644 index 000000000..60061e413 --- /dev/null +++ b/python/hidet/apps/image_classification/processing/__init__.py @@ -0,0 +1,13 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .image_processor import ChannelDimension, BaseImageProcessor, ImageInput +from .resnet import * diff --git a/python/hidet/apps/image_classification/processing/image_processor.py b/python/hidet/apps/image_classification/processing/image_processor.py new file mode 100644 index 000000000..fbb5da7f7 --- /dev/null +++ b/python/hidet/apps/image_classification/processing/image_processor.py @@ -0,0 +1,91 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from enum import Enum, auto, unique +from typing import Dict, List, Optional, Sequence, Type, Union +import torch +import numpy as np +from hidet.graph.tensor import Tensor, from_torch + + +@unique +class ChannelDimension(Enum): + CHANNEL_FIRST = auto() + CHANNEL_LAST = auto() + CHANNEL_SINGLE = auto() + + +ImageInput = Union[ + "PIL.Image.Image", np.ndarray, "torch.Tensor", List["PIL.Image.Image"], List[np.ndarray], List["torch.Tensor"] +] # noqa + + +class BaseImageProcessor: + + processor_registry: Dict[str, Type["BaseImageProcessor"]] = {} + + def __init__(self, dtype: Optional[str] = None, device: str = "cuda"): + super().__init__() + + self.dtype = dtype + self.device = device + + @classmethod + def register(cls, arch: str, processor_class: Type["BaseImageProcessor"]): + cls.processor_registry[arch] = processor_class + + @classmethod + def load_module(cls, arch: str): + return cls.processor_registry[arch] + + def __call__(self, images: ImageInput, **kwargs) -> Tensor: + return self.preprocess(images, **kwargs) + + def preprocess(self, images: ImageInput, **kwargs) -> Tensor: + raise NotImplementedError("Image processors should implement their own preprocess step.") + + def rescale(self, image: Tensor, scale: float) -> Tensor: + return image * scale + + def normalize( + self, image: Tensor, mean: Union[float, Sequence[float]], std: Union[float, Sequence[float]] + ) -> Tensor: + """ + Normalize image on per channel basis as + (mean - pixel) / std + mean and std are broadcast across channels if scalar value provided. + """ + num_channels: int = image.shape[-3] + + if isinstance(mean, Sequence): + if len(mean) != num_channels: + raise ValueError(f"means need {num_channels} values, one for each channel, got {len(mean)}.") + else: + mean = [mean] * num_channels + channel_means = from_torch(torch.Tensor(mean).view(num_channels, 1, 1)).to(self.dtype, self.device) + + if isinstance(std, Sequence): + if len(std) != num_channels: + raise ValueError(f"stds need {num_channels} values, one for each channel, got {len(std)}.") + else: + std = [std] * num_channels + channel_stds = from_torch(torch.Tensor(std).view(num_channels, 1, 1)).to(self.dtype, self.device) + + return (image - channel_means) / channel_stds + + def center_square_crop(self, image: Tensor, size: int): + assert image.shape[-2:] >= (size, size) + + pad_width = image.shape[-2] - size + start = (pad_width // 2) + (pad_width % 2) + end = image.shape[-2] - (pad_width // 2) + + return image[:, :, start:end, start:end] diff --git a/python/hidet/apps/image_classification/processing/resnet/__init__.py b/python/hidet/apps/image_classification/processing/resnet/__init__.py new file mode 100644 index 000000000..7727fe0ec --- /dev/null +++ b/python/hidet/apps/image_classification/processing/resnet/__init__.py @@ -0,0 +1,12 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from .processing import ResNetImageProcessor diff --git a/python/hidet/apps/image_classification/processing/resnet/processing.py b/python/hidet/apps/image_classification/processing/resnet/processing.py new file mode 100644 index 000000000..b6618b6f3 --- /dev/null +++ b/python/hidet/apps/image_classification/processing/resnet/processing.py @@ -0,0 +1,174 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from typing import List, Optional, Union +import numpy as np +import torch +import PIL + +from hidet.graph.flow_graph import FlowGraph, trace_from +from hidet.graph.ops.image import resize2d +from hidet.graph.tensor import Tensor, from_torch, symbol + +import hidet +from ..image_processor import BaseImageProcessor, ChannelDimension, ImageInput + + +class ResNetImageProcessor(BaseImageProcessor): + def __init__( + self, + do_resize: bool = True, + size: int = 224, + crop_pct: float = 0.875, + resample: str = "cubic", + do_rescale: bool = True, + rescale_factor: Union[int, float] = 1 / 255, + do_normalize: bool = True, + image_mean: Optional[Union[float, List[float]]] = None, + image_std: Optional[Union[float, List[float]]] = None, + dtype: str = "float32", + device: str = "cuda", + **kwargs, + ) -> None: + """ + Pre-process images before ResNet model. Produces square images of (size, size) and transforms + input images to channel first. + Assumes inputs are uint8 RGB images on CPU memory. + Default values taken from `AutoImageProcessor.from_pretrained("microsoft/resnet-50")`, not ImageNet + standards. + See transformers library ConvNextImageProcessor for reference implementation. + """ + super().__init__(dtype, device) + + assert 0 < crop_pct < 1 + + self.do_resize = do_resize + self.size = size + self.crop_pct = crop_pct + self.resample = resample + self.do_rescale = do_rescale + self.rescale_factor = rescale_factor + self.do_normalize = do_normalize + self.image_mean = image_mean or [0.485, 0.456, 0.406] + self.image_std = image_std or [0.229, 0.224, 0.225] + + resize_inputs: Tensor = symbol([1, 3, "h", "w"], dtype="uint8", device="cuda") + resize_outputs = self.resize(resize_inputs.to(self.dtype)) + resize_graph: FlowGraph = trace_from(resize_outputs, resize_inputs) + + scaling_inputs: Tensor = symbol(["n", 3, "h", "w"], dtype="int32" if self.do_resize else "uint8", device="cuda") + scaling_outputs = scaling_inputs.to(self.dtype) + if do_rescale: + scaling_outputs = self.rescale(scaling_outputs, scale=rescale_factor) + if do_normalize: + scaling_outputs = self.normalize(scaling_outputs, self.image_mean, self.image_std) + rescale_graph: FlowGraph = trace_from(scaling_outputs, scaling_inputs) + + resize_graph = hidet.graph.optimize(resize_graph) + rescale_graph = hidet.graph.optimize(rescale_graph) + + self.resize_graph = resize_graph.build(space=2) + self.rescale_graph = rescale_graph.build(space=2) + + def preprocess( + self, images: ImageInput, input_data_format: ChannelDimension, _do_resize: Optional[bool] = None, **kwargs + ) -> Tensor: + assert isinstance(images, (PIL.Image.Image, np.ndarray, torch.Tensor, list)) + + do_resize = _do_resize if _do_resize is not None else self.do_resize + + def _preprocess_one(images): + if isinstance(images, PIL.Image.Image): + images = np.asarray(images) + # fall through + if isinstance(images, np.ndarray): + images = torch.from_numpy(images) + + assert isinstance(images, torch.Tensor) + + if len(images.shape) == 2: + # broadcast grayscale to 3 channels + images = images.expand(3, -1, -1) + if len(images.shape) == 3: + # batch size 1 + images = images.reshape(1, *images.shape) + + if input_data_format == ChannelDimension.CHANNEL_LAST: + # change to channel first + images = images.permute(0, 3, 1, 2) + + hidet_images: Tensor = from_torch(images.contiguous()).to(device=self.device) + + if do_resize: + hidet_images = self.resize_graph.run_async([hidet_images])[0] + + if self.do_rescale or self.do_normalize: + hidet_images = self.rescale_graph.run_async([hidet_images])[0] + return hidet_images + + def _preprocess_many(images): + common_type = type(images[0]) + assert all(isinstance(image, common_type) for image in images) + + if not do_resize: + # batching images requires same size + common_size = images[0].shape + assert all(image.shape == common_size for image in images) + + # change to torch + if issubclass(common_type, PIL.Image.Image): + images = [torch.from_numpy(np.asarray(image).copy()) for image in images] + elif common_type is np.ndarray: + images = [torch.from_numpy(image) for image in images] + + if input_data_format == ChannelDimension.CHANNEL_FIRST: + images = [image.expand(3, -1, -1) if len(image.shape) < 3 else image for image in images] + else: + images = [image.unsqueeze(-1).repeat(1, 1, 3) if len(image.shape) < 3 else image for image in images] + + if input_data_format == ChannelDimension.CHANNEL_LAST: + images = [image.permute(2, 0, 1).contiguous() for image in images] + + if do_resize: + resized_images = [] + for image in images: + image = from_torch(image.reshape(1, *image.shape)).to(device=self.device) + image = self.resize_graph.run_async([image])[0] + image = image.torch() + resized_images.append(image) + + images = resized_images + + # combine to single tensor, recurse + images = torch.stack(images).squeeze() + return self.preprocess(images, ChannelDimension.CHANNEL_FIRST, _do_resize=False) + + return _preprocess_many(images) if isinstance(images, list) else _preprocess_one(images) + + def resize(self, image: Tensor): + """ + If size is <384, resize to size / crop_pct and then apply center crop (to preserve image quality). + Mirrors ConvNextImageProcessor resize operation. Assumes input image with shape (bs, 3, h, w). + """ + assert len(image.shape) == 4 + + if self.size < 384: + resize_shortest_edge = int(self.size / self.crop_pct) + image = resize2d(image, size=(resize_shortest_edge, resize_shortest_edge), method='cubic').to(dtype="int32") + + x = self.center_square_crop(image, self.size) + return x + + else: + return resize2d(image, size=(self.size, self.size), method=self.resample) + + +BaseImageProcessor.register(arch="ResNetForImageClassification", processor_class=ResNetImageProcessor) diff --git a/python/hidet/apps/modeling_outputs.py b/python/hidet/apps/modeling_outputs.py index f8cfb1a09..6f5257f8e 100644 --- a/python/hidet/apps/modeling_outputs.py +++ b/python/hidet/apps/modeling_outputs.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from collections import OrderedDict from dataclasses import dataclass, fields, is_dataclass from typing import Any, List, Tuple diff --git a/python/hidet/apps/pretrained.py b/python/hidet/apps/pretrained.py index 52f43f9cf..7630b81b5 100644 --- a/python/hidet/apps/pretrained.py +++ b/python/hidet/apps/pretrained.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. from typing import Generic, List, Set, Union import logging diff --git a/python/hidet/apps/registry.py b/python/hidet/apps/registry.py index 3953aa679..287c570b4 100644 --- a/python/hidet/apps/registry.py +++ b/python/hidet/apps/registry.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import importlib from dataclasses import astuple, dataclass from typing import Dict diff --git a/tests/apps/image_classification/processing/resnet/test_processing.py b/tests/apps/image_classification/processing/resnet/test_processing.py new file mode 100644 index 000000000..0468cd505 --- /dev/null +++ b/tests/apps/image_classification/processing/resnet/test_processing.py @@ -0,0 +1,56 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.apps.image_classification.processing.image_processor import ChannelDimension +from hidet.apps.image_classification.processing.resnet.processing import ResNetImageProcessor +import pytest +import torch + + +def test_resnet_processor_resize(): + # Channel first + image = torch.zeros((3, 10, 15), dtype=torch.uint8) + image += torch.arange(1, 16) + + processor = ResNetImageProcessor(size=4, do_rescale=False, do_normalize=False) + res = processor(image, input_data_format=ChannelDimension.CHANNEL_FIRST) + assert res.shape == (1, 3, 4, 4) + assert ((0 < res.torch()) & (res.torch() < 15)).all() + + # Channel last + image = torch.zeros((10, 15, 3), dtype=torch.uint8) + image += torch.arange(1, 16).view(1, 15, 1) + + processor = ResNetImageProcessor(size=4, do_rescale=False, do_normalize=False) + res = processor(image, input_data_format=ChannelDimension.CHANNEL_LAST) + assert res.shape == (1, 3, 4, 4) + assert ((0 < res.torch()) & (res.torch() < 15)).all() + + # Batch resize + images = [] + + import random + + random.seed(0) + + for _ in range(10): + rows = random.randint(10, 20) + cols = random.randint(10, 20) + tensor = torch.randint(1, 9, (rows, cols, 3), dtype=torch.uint8) + images.append(tensor) + + res = processor(images, input_data_format=ChannelDimension.CHANNEL_LAST) + assert res.shape == (10, 3, 4, 4) + assert ((0 <= res.torch()) & (res.torch() < 10)).all() + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/apps/image_classification/test_builder.py b/tests/apps/image_classification/test_builder.py new file mode 100644 index 000000000..c04364fd7 --- /dev/null +++ b/tests/apps/image_classification/test_builder.py @@ -0,0 +1,42 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.apps.image_classification.processing.image_processor import ChannelDimension +import pytest +import torch +from datasets import load_dataset +from hidet.apps.image_classification.builder import create_image_classifier, create_image_processor +from hidet.graph.tensor import from_torch +from transformers import AutoImageProcessor + + +def test_create_image_classifier(): + dataset = load_dataset("huggingface/cats-image", split="test", trust_remote_code=True) + + # using huggingface pre-processor + image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + images = image_processor(dataset["image"], return_tensors="pt") + images = images["pixel_values"] + images = from_torch(images).cuda() + + resnet = create_image_classifier("microsoft/resnet-50", kernel_search_space=0) + assert "image_classifier" in resnet.compiled_app.meta.graphs + assert resnet.compiled_app.meta.name == "microsoft/resnet-50" + + res = resnet.compiled_app.graphs["image_classifier"].run_async([images]) + res = res[0].torch() + res = torch.argmax(res, dim=1) + + assert res[0].item() == 282 # tiger cat label + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/apps/image_classification/test_image_classifier_builder.py b/tests/apps/image_classification/test_image_classifier_builder.py index 148a7503d..865e4f9ca 100644 --- a/tests/apps/image_classification/test_image_classifier_builder.py +++ b/tests/apps/image_classification/test_image_classifier_builder.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest import torch from datasets import load_dataset @@ -6,7 +17,6 @@ from transformers import AutoImageProcessor -@pytest.mark.slow def test_create_image_classifier(): dataset = load_dataset("huggingface/cats-image", split="test", trust_remote_code=True) @@ -15,7 +25,7 @@ def test_create_image_classifier(): image = image_processor(dataset[0]["image"], return_tensors="pt")["pixel_values"] image = from_torch(image).cuda() - resnet = create_image_classifier("microsoft/resnet-50") + resnet = create_image_classifier("microsoft/resnet-50", kernel_search_space=0) assert "image_classifier" in resnet.compiled_app.meta.graphs assert resnet.compiled_app.meta.name == "microsoft/resnet-50" diff --git a/tests/apps/image_classification/test_pipeline.py b/tests/apps/image_classification/test_pipeline.py new file mode 100644 index 000000000..3fea13c0c --- /dev/null +++ b/tests/apps/image_classification/test_pipeline.py @@ -0,0 +1,30 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from hidet.apps.image_classification.pipeline.pipeline import ImageClassificationPipeline +from hidet.apps.image_classification.processing.image_processor import ChannelDimension +import pytest +from datasets import load_dataset + + +def test_image_classifier_pipeline(): + dataset = load_dataset("huggingface/cats-image", split="test", trust_remote_code=True) + + pipeline = ImageClassificationPipeline("microsoft/resnet-50", batch_size=1, kernel_search_space=0) + + res = pipeline(dataset["image"], input_data_format=ChannelDimension.CHANNEL_LAST, top_k=3) + + assert len(res) == 1 + assert all([len(x) == 3 for x in res]) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/apps/test_pretrained.py b/tests/apps/test_pretrained.py index c06cc0799..6aa2b56dc 100644 --- a/tests/apps/test_pretrained.py +++ b/tests/apps/test_pretrained.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest import torch from hidet.apps import PretrainedModel, hf @@ -6,7 +17,6 @@ from transformers import AutoModelForImageClassification, PretrainedConfig, ResNetConfig -@pytest.mark.slow @pytest.mark.parametrize( "model_name, dtype", [ @@ -19,7 +29,6 @@ def test_parse_dtype(model_name: str, dtype: str): assert PretrainedModel.parse_dtype(config) == dtype -@pytest.mark.slow def test_copy_weights(): with torch.device("cuda"): diff --git a/tests/apps/test_registry.py b/tests/apps/test_registry.py index b32820bd0..6f416579e 100644 --- a/tests/apps/test_registry.py +++ b/tests/apps/test_registry.py @@ -1,3 +1,14 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. import pytest from hidet.apps import Registry, hf from hidet.apps.image_classification.modeling.resnet.modeling import ResNetForImageClassification