diff --git a/.github/requirements-ci.txt b/.github/requirements-ci.txt index 7318dc814..1692c318f 100644 --- a/.github/requirements-ci.txt +++ b/.github/requirements-ci.txt @@ -1,4 +1,4 @@ mysql-connector-python -transformers +transformers==4.37 accelerate sentencepiece \ No newline at end of file diff --git a/.github/scripts/bench/bench_transformer.py b/.github/scripts/bench/bench_transformer.py deleted file mode 100644 index 24cd3cd03..000000000 --- a/.github/scripts/bench/bench_transformer.py +++ /dev/null @@ -1,59 +0,0 @@ -import sys -import os -import argparse -import numpy as np -import torch -import hidet -from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM, logging -from bench_utils import enable_compile_server, setup_hidet_flags, bench_torch_model -os.environ["TOKENIZERS_PARALLELISM"] = "false" -logging.set_verbosity_error() - -model_class = { - 'bert-base-uncased': 'AutoModelForMaskedLM', -} - -def bench_hf_transformers(model_name, seqlen, dtype): - setup_hidet_flags(dtype) - enable_compile_server(True) - dtype = getattr(torch, dtype) - tokenizer = AutoTokenizer.from_pretrained(model_name) - AutoModel_cls = eval(model_class[model_name]) - model = AutoModel_cls.from_pretrained(model_name, - max_position_embeddings=8192, ignore_mismatched_sizes=True) - model = model.eval().to(dtype).cuda() - inputs = tokenizer("Dummy sentence", padding='max_length', max_length=seqlen, - return_tensors='pt') - inputs = {'input_ids': inputs['input_ids']} - torch_inputs = tuple(i.clone().cuda() for i in inputs.values()) - with torch.no_grad(), torch.autocast("cuda"): - model = torch.compile(model, backend='hidet') - latency = bench_torch_model(model, torch_inputs) - del model - return latency - -if __name__ == '__main__': - parser = argparse.ArgumentParser(prog='Benchmark Transformers') - parser.add_argument( - 'model', - type=str, - help='Specify model' - ) - parser.add_argument( - '--params', - type=str, - default='seqlen=1024', - help='Specify Input Parameters. E.g., seqlen=1024' - ) - parser.add_argument( - '--dtype', - type=str, - default='float16', - help='Specify precision. 
E.g., float32' - ) - args = parser.parse_args() - - model, dtype = args.model, args.dtype - seqlen = int(args.params.split('=')[1]) - latency = bench_hf_transformers(model, seqlen, dtype) - print(latency) \ No newline at end of file diff --git a/.github/scripts/bench/bench_utils.py b/.github/scripts/bench/bench_utils.py deleted file mode 100644 index 3921eea7a..000000000 --- a/.github/scripts/bench/bench_utils.py +++ /dev/null @@ -1,44 +0,0 @@ -import os -import hidet - -def setup_hidet_flags(dtype, dynamo=True): - if dynamo: - import torch - use_fp16 = dtype == 'float16' - hidet.torch.dynamo_config.search_space(2) - hidet.torch.dynamo_config.use_fp16(use_fp16) - hidet.torch.dynamo_config.use_fp16_reduction(use_fp16) - hidet.torch.dynamo_config.use_attention(True) - hidet.torch.dynamo_config.use_tensor_core(True) - hidet.torch.dynamo_config.use_cuda_graph(True) - else: - hidet.option.search_space(2) - hidet.option.cache_dir(hidet.option.get_cache_dir() + '/regression') - -def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10): - import torch - for _ in range(warmup_iters): - torch_out = model(*torch_inputs) - torch.cuda.empty_cache() - - start = torch.cuda.Event(enable_timing=True) - end = torch.cuda.Event(enable_timing=True) - torch.cuda.synchronize() - start.record() - for _ in range(bench_iters): - torch_out = model(*torch_inputs) - end.record() - end.synchronize() - torch.cuda.empty_cache() - - latency = start.elapsed_time(end) / bench_iters - return latency - -def enable_compile_server(enable=True): - if os.environ.get('CI_CS_HOSTNAME'): - hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME')) - hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT'))) - hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME')) - hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD')) - hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip()) - hidet.option.compile_server.enable(flag=enable) \ No newline at end of file diff --git a/.github/scripts/bench/bench_vision.py b/.github/scripts/bench/bench_vision.py deleted file mode 100644 index 7d11da809..000000000 --- a/.github/scripts/bench/bench_vision.py +++ /dev/null @@ -1,52 +0,0 @@ -import sys -import os -import argparse -import numpy as np -import torch -import torchvision -import hidet -from bench_utils import enable_compile_server, setup_hidet_flags, bench_torch_model - -def bench_torchvision(model_name, shape, dtype): - setup_hidet_flags(dtype) - enable_compile_server(True) - dtype = getattr(torch, dtype) - if any(name in model_name for name in ['deeplab', 'fcn', 'lraspp']): - model_cls = getattr(torchvision.models.segmentation, model_name) - else: - model_cls = getattr(torchvision.models, model_name) - model = model_cls(weights=None) - model = model.eval().to(dtype).cuda() - torch_inputs = [torch.randn(shape, device='cuda', dtype=dtype)] - with torch.no_grad(), torch.autocast("cuda"): - model = torch.compile(model, backend='hidet') - latency = bench_torch_model(model, torch_inputs) - del model - return latency - - -if __name__ == '__main__': - parser = argparse.ArgumentParser(prog='Benchmark Vision Models') - parser.add_argument( - 'model', - type=str, - help='Specify model' - ) - parser.add_argument( - '--params', - type=str, - default='1x3x224x224', - help='Specify Input Size. E.g., 1x3x224x224' - ) - parser.add_argument( - '--dtype', - type=str, - default='float16', - help='Specify precision. 
E.g., float32' - ) - args = parser.parse_args() - - model, dtype = args.model, args.dtype - shape = [int(d) for d in args.params.split('x')] - latency = bench_torchvision(model, shape, dtype) - print(latency) \ No newline at end of file diff --git a/.github/workflows/lint.yaml b/.github/workflows/lint.yaml index 7c31653aa..22dbf81fd 100644 --- a/.github/workflows/lint.yaml +++ b/.github/workflows/lint.yaml @@ -7,7 +7,7 @@ on: jobs: format-and-lint: - if: github.repository == 'hidet-org/hidet' + if: github.repository == 'hidet-org/hidet' || github.repository == 'CentML/hidet' concurrency: group: ${{ github.workflow }}-${{ github.head_ref || github.run_id }} cancel-in-progress: true diff --git a/.github/workflows/regression.yaml b/.github/workflows/regression.yaml index a202251b3..7deefc270 100644 --- a/.github/workflows/regression.yaml +++ b/.github/workflows/regression.yaml @@ -47,12 +47,12 @@ jobs: AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }} AWS_DEFAULT_REGION: us-east-1 - - name: Upload run configs - uses: actions/upload-artifact@v3 - with: - name: run_configs - path: run_configs.json - retention-days: 1 + #- name: Upload run configs + # uses: actions/upload-artifact@v3 + # with: + # name: run_configs + # path: run_configs.json + # retention-days: 1 run_tests: needs: start_instances @@ -85,12 +85,16 @@ jobs: path: models ref: ci - - name: Download run configs - uses: actions/download-artifact@v3 - with: - name: run_configs - path: ./mount - + #- name: Download run configs + # uses: actions/download-artifact@v3 + # with: + # name: run_configs + # path: ./mount + + # Put run_configs.json in shared folder. Intup and output of tests is saved in it. + - name: Copy run_config.json + run: cp hidet/tests/benchmarks/run_configs.json ./mount + # Build the image - name: Build docker image from base image run: docker build -t hidet-ci -f hidet/.github/Dockerfile . @@ -105,7 +109,7 @@ jobs: -e HW_CONFIG -e REPO_NAME -e REPO_BRANCH -e CI_CS_HOSTNAME -e CI_CS_PORT -e CI_CS_USERNAME -e CI_CS_PASSWORD -e HF_TOKEN -v ./mount:/workspace/mount - hidet-ci python hidet/.github/scripts/run_tests.py --configs /workspace/mount/run_configs.json' + hidet-ci python hidet/tests/benchmarks/run_tests.py --configs /workspace/mount/run_configs.json' env: HW_CONFIG: ${{ matrix.hw_configs }} REPO_NAME: ${{ inputs.source_repo == 'this' && github.repository || inputs.source_repo }} @@ -162,7 +166,7 @@ jobs: CI_DB_PASSWORD: ${{ secrets.CI_DB_PASSWORD }} stop_instances: - if: inputs.shutdown_instances + if: ${{ always() && inputs.shutdown_instances }} runs-on: ubuntu-latest needs: [start_instances, run_tests] steps: diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 45fa5ca8c..c09d69a8f 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -60,6 +60,7 @@ jobs: - name: Run tests run: | + rm -rf ~/.config/hidet python -m pytest -v --durations=20 --clear-cache ./tests # Build the docs diff --git a/.gitignore b/.gitignore index 36e20a4d7..90196d6d0 100644 --- a/.gitignore +++ b/.gitignore @@ -184,6 +184,9 @@ cython_debug/ # option (not recommended) you can uncomment the following to ignore the entire idea folder. 
.idea/ +# Ignore VSCode files +.vscode/* + # pycharm line profiler result **/*.pclprof @@ -213,3 +216,6 @@ build-release # experiments folder /experiments + +# vscode +.vscode \ No newline at end of file diff --git a/apps/compile_server/resources/compile_worker.py b/apps/compile_server/resources/compile_worker.py index 40ab456d9..b7fc795e6 100644 --- a/apps/compile_server/resources/compile_worker.py +++ b/apps/compile_server/resources/compile_worker.py @@ -1,4 +1,4 @@ -from typing import Dict, Any, List, Tuple +from typing import Dict, Any, List, Tuple, Sequence, Union import os import traceback import argparse @@ -51,7 +51,7 @@ def compile_job(job_id: str): # load the workload workload: Dict[str, Any] = pickle.loads(job['workload']) - ir_module: hidet.ir.IRModule = workload['ir_module'] + ir_module: Union[hidet.ir.IRModule, Sequence[hidet.ir.IRModule]] = workload['ir_module'] target: str = workload['target'] output_kind: str = workload['output_kind'] diff --git a/python/hidet/apps/__init__.py b/python/hidet/apps/__init__.py index 4d9a92490..c0fac1823 100644 --- a/python/hidet/apps/__init__.py +++ b/python/hidet/apps/__init__.py @@ -9,3 +9,10 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .pretrained import PretrainedModel + +# from .pipeline import Pipeline +from .registry import Registry, RegistryEntry + +# from .processing import BaseProcessor diff --git a/python/hidet/apps/compile_server/compilation.py b/python/hidet/apps/compile_server/compilation.py index 78caa786d..7e23c3a78 100644 --- a/python/hidet/apps/compile_server/compilation.py +++ b/python/hidet/apps/compile_server/compilation.py @@ -9,6 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
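[Editor's note] compile_worker.py above and remote_build below are widened in the same way: the 'ir_module' field may now be a single IRModule or a sequence of IRModules. A minimal sketch of the workload payload the worker unpickles; the field names come from compile_worker.py, while module_a/module_b and the target value are illustrative assumptions, not part of this patch:

    import pickle

    workload = {
        'ir_module': [module_a, module_b],  # previously: exactly one hidet.ir.IRModule
        'target': 'cuda',
        'output_kind': '.so',
    }
    payload = pickle.dumps(workload)        # what arrives as job['workload']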
+from typing import Sequence, Union import zipfile import shutil import tempfile @@ -21,7 +22,9 @@ from .core import api_url, access_token -def remote_build(ir_module: IRModule, output_dir: str, *, target: str, output_kind: str = '.so'): +def remote_build( + ir_module: Union[IRModule, Sequence[IRModule]], output_dir: str, *, target: str, output_kind: str = '.so' +): # upload the IRModule if 'cuda' in target: if 'arch' not in target: diff --git a/python/hidet/apps/diffusion/modeling/stable_diffusion/downsample.py b/python/hidet/apps/diffusion/modeling/stable_diffusion/downsample.py new file mode 100644 index 000000000..e0c570c4b --- /dev/null +++ b/python/hidet/apps/diffusion/modeling/stable_diffusion/downsample.py @@ -0,0 +1,19 @@ +from hidet.graph import nn +from hidet.graph.tensor import Tensor + + +class Downsample2D(nn.Module): + def __init__(self, channels: int, **kwargs): + super().__init__() + + self.channels = channels + self.out_channels = kwargs.get("output_channels", None) or channels + + self.conv = nn.Conv2d( + self.channels, self.out_channels, kernel_size=3, stride=2, padding=kwargs["downsample_padding"], bias=True + ) + + def forward(self, hidden_states: Tensor) -> Tensor: + assert hidden_states.shape[1] == self.channels + + return self.conv(hidden_states) diff --git a/python/hidet/apps/diffusion/modeling/stable_diffusion/resnet_blocks.py b/python/hidet/apps/diffusion/modeling/stable_diffusion/resnet_blocks.py new file mode 100644 index 000000000..ac2586bb6 --- /dev/null +++ b/python/hidet/apps/diffusion/modeling/stable_diffusion/resnet_blocks.py @@ -0,0 +1,79 @@ +from hidet.graph import nn +from hidet.graph.tensor import Tensor +from hidet.graph.ops import split + + +class ResnetBlock2D(nn.Module): + def __init__(self, **kwargs): + super().__init__() + + input_channels = kwargs["input_channels"] + output_channels = kwargs["output_channels"] or input_channels + groups_out = kwargs["resnet_groups"] + + self.norm1 = nn.GroupNorm( + num_groups=kwargs["resnet_groups"], num_channels=kwargs["input_channels"], eps=kwargs["resnet_eps"] + ) + + self.conv1 = nn.Conv2d( + in_channels=input_channels, out_channels=output_channels, kernel_size=3, padding=1, bias=True + ) + + temb_channels = kwargs["temb_channels"] + self.time_embedding_norm = kwargs["resnet_time_scale_shift"] + + self.time_emb_proj = None + if temb_channels is not None: + if self.time_embedding_norm == "default": + self.time_emb_proj = nn.Linear(temb_channels, output_channels) + elif self.time_embedding_norm == "scale_shift": + self.time_emb_proj = nn.Linear(temb_channels, 2 * output_channels) + else: + raise ValueError(f"Unknown time_embedding_norm: {self.time_embedding_norm}") + + self.norm2 = nn.GroupNorm(num_groups=groups_out, num_channels=output_channels, eps=kwargs["resnet_eps"]) + + if kwargs["dropout"] != 0.0: + raise NotImplementedError("No dropout should be used for inference") + + self.conv2 = nn.Conv2d( + in_channels=output_channels, out_channels=output_channels, kernel_size=3, padding=1, bias=True + ) + + self.nonlinearity = kwargs["resnet_act_fn"] + + self.use_in_shortcut = input_channels != output_channels + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv2d( + in_channels=input_channels, out_channels=output_channels, kernel_size=1, bias=True + ) + + def forward(self, x: Tensor, temb: Tensor): + input_tensor = x + x = self.norm1(x) + x = self.nonlinearity(x) + x = self.conv1(x) + + if self.time_emb_proj is not None: + temb = self.nonlinearity(temb) + temb = 
self.time_emb_proj(temb)[:, :, None, None] + + if self.time_embedding_norm == "default": + x = x + temb + x = self.norm2(x) + elif self.time_embedding_norm == "scale_shift": + time_scale, time_shift = split(temb, 2, axis=1)[:2] + x = self.norm2(x) + x = x * (1 + time_scale) + time_shift + else: + x = self.norm2(x) + + x = self.nonlinearity(x) + x = self.conv2(x) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + x + return output_tensor diff --git a/python/hidet/apps/diffusion/modeling/stable_diffusion/upsample.py b/python/hidet/apps/diffusion/modeling/stable_diffusion/upsample.py new file mode 100644 index 000000000..619adc50f --- /dev/null +++ b/python/hidet/apps/diffusion/modeling/stable_diffusion/upsample.py @@ -0,0 +1,24 @@ +from typing import Optional +from hidet.graph import nn +from hidet.graph.tensor import Tensor +from hidet.graph.ops import resize2d + + +class Upsample2D(nn.Module): + def __init__(self, channels: int, **kwargs): + super().__init__() + + self.channels = channels + self.out_channels = kwargs.get("output_channels", None) or channels + + self.conv = nn.Conv2d(self.channels, self.out_channels, kernel_size=3, padding=1, bias=True) + + def forward(self, hidden_states: Tensor, output_size: Optional[int] = None): + assert hidden_states.shape[1] == self.channels + + if output_size is None: + hidden_states = resize2d(hidden_states, scale_factor=2.0, method="nearest") + else: + hidden_states = resize2d(hidden_states, size=output_size, method="nearest") + + return self.conv(hidden_states) diff --git a/python/hidet/apps/hf.py b/python/hidet/apps/hf.py new file mode 100644 index 000000000..381afdd4d --- /dev/null +++ b/python/hidet/apps/hf.py @@ -0,0 +1,10 @@ +from typing import Optional + +from transformers import AutoConfig, PretrainedConfig + +import hidet + + +def load_pretrained_config(model: str, revision: Optional[str] = None) -> PretrainedConfig: + huggingface_token = hidet.option.get_option('auth_tokens.for_huggingface') + return AutoConfig.from_pretrained(model, revision=revision, token=huggingface_token) diff --git a/python/hidet/apps/image_classification/__init__.py b/python/hidet/apps/image_classification/__init__.py new file mode 100644 index 000000000..617af58f3 --- /dev/null +++ b/python/hidet/apps/image_classification/__init__.py @@ -0,0 +1 @@ +from .modeling import * diff --git a/python/hidet/apps/image_classification/app.py b/python/hidet/apps/image_classification/app.py new file mode 100644 index 000000000..fe152cdf9 --- /dev/null +++ b/python/hidet/apps/image_classification/app.py @@ -0,0 +1,13 @@ +from typing import Sequence + +from hidet.graph.tensor import Tensor +from hidet.runtime.compiled_app import CompiledApp + + +class ImageClassificationApp: + def __init__(self, compiled_app: CompiledApp): + super().__init__() + self.compiled_app: CompiledApp = compiled_app + + def classify(self, input_images: Sequence[Tensor]): + return self.compiled_app.graphs["image_classifier"].run_async(input_images) diff --git a/python/hidet/apps/image_classification/builder.py b/python/hidet/apps/image_classification/builder.py new file mode 100644 index 000000000..e0d17e1e7 --- /dev/null +++ b/python/hidet/apps/image_classification/builder.py @@ -0,0 +1,54 @@ +from typing import Optional + +from transformers import PretrainedConfig + +from hidet.apps import hf +from hidet.apps.image_classification.app import ImageClassificationApp +from hidet.apps.image_classification.modeling.pretrained import 
PretrainedModelForImageClassification +from hidet.apps.modeling_outputs import ImageClassifierOutput +from hidet.graph import trace_from +from hidet.graph.flow_graph import FlowGraph +from hidet.graph.tensor import Tensor, symbol +from hidet.runtime.compiled_app import create_compiled_app + +import hidet + + +def create_image_classifier( + name: str, + revision: Optional[str] = None, + dtype: str = "float32", + device: str = "cuda", + kernel_search_space: int = 2, +): + # load the huggingface config according to (model, revision) pair + config: PretrainedConfig = hf.load_pretrained_config(name, revision=revision) + + # load model instance by architecture, assume only 1 architecture for now + model = PretrainedModelForImageClassification.create_pretrained_model( + config, revision=revision, dtype=dtype, device=device + ) + inputs: Tensor = symbol(["bs", 3, 224, 224], dtype=dtype, device=device) + outputs: ImageClassifierOutput = model.forward(inputs) + graph: FlowGraph = trace_from(outputs.logits, inputs) + + graph = hidet.graph.optimize(graph) + + compiled_graph = graph.build(space=kernel_search_space) + + return ImageClassificationApp( + compiled_app=create_compiled_app(graphs={"image_classifier": compiled_graph}, name=name) + ) + + +# def create_image_processor( +# name: str, +# revision: Optional[str] = None, +# **kwargs +# ) -> BaseProcessor: +# # load the huggingface config according to (model, revision) pair +# config: PretrainedConfig = hf.load_pretrained_config(name, revision=revision) + +# processor = BaseImageProcessor.load_module(config, module_type=ModuleType.PROCESSING) + +# return processor(**kwargs) diff --git a/python/hidet/apps/image_classification/modeling/__init__.py b/python/hidet/apps/image_classification/modeling/__init__.py new file mode 100644 index 000000000..b792ca6ec --- /dev/null +++ b/python/hidet/apps/image_classification/modeling/__init__.py @@ -0,0 +1 @@ +from .resnet import * diff --git a/python/hidet/apps/image_classification/modeling/pretrained.py b/python/hidet/apps/image_classification/modeling/pretrained.py new file mode 100644 index 000000000..b4f8f1f3f --- /dev/null +++ b/python/hidet/apps/image_classification/modeling/pretrained.py @@ -0,0 +1,39 @@ +from typing import Optional + +import torch +from transformers import AutoModelForImageClassification, PretrainedConfig +from transformers import PreTrainedModel as TransformersPretrainedModel +from hidet.apps.modeling_outputs import ImageClassifierOutput +from hidet.apps.pretrained import PretrainedModel + +import hidet + + +class PretrainedModelForImageClassification(PretrainedModel[ImageClassifierOutput]): + @classmethod + def create_pretrained_model( + cls, config: PretrainedConfig, revision: Optional[str] = None, dtype: Optional[str] = None, device: str = "cuda" + ): + # dynamically load model subclass + pretrained_model_class = cls.load_module(config) + + # load the pretrained huggingface model into cpu + with torch.device("cuda"): # reduce the time to load the model + huggingface_token = hidet.option.get_option("auth_tokens.for_huggingface") + torch_model: TransformersPretrainedModel = AutoModelForImageClassification.from_pretrained( + pretrained_model_name_or_path=config.name_or_path, + torch_dtype=torch.float32, + revision=revision, + token=huggingface_token, + ) + + torch_model = torch_model.cpu() + torch.cuda.empty_cache() + + dtype = cls.parse_dtype(config) + hidet_model = pretrained_model_class(config) + hidet_model.to(dtype=dtype, device=device) + + cls.copy_weights(torch_model, 
hidet_model) + + return hidet_model diff --git a/python/hidet/apps/image_classification/modeling/resnet/__init__.py b/python/hidet/apps/image_classification/modeling/resnet/__init__.py new file mode 100644 index 000000000..a74368312 --- /dev/null +++ b/python/hidet/apps/image_classification/modeling/resnet/__init__.py @@ -0,0 +1 @@ +from .modeling import ResNetForImageClassification, ResNetModel diff --git a/python/hidet/apps/image_classification/modeling/resnet/modeling.py b/python/hidet/apps/image_classification/modeling/resnet/modeling.py new file mode 100644 index 000000000..081e4e482 --- /dev/null +++ b/python/hidet/apps/image_classification/modeling/resnet/modeling.py @@ -0,0 +1,222 @@ +from dataclasses import asdict +from typing import Sequence +from transformers import ResNetConfig + +from hidet.apps import PretrainedModel +from hidet.apps.image_classification.modeling.pretrained import PretrainedModelForImageClassification +from hidet.apps.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling, ImageClassifierOutput +from hidet.apps.registry import RegistryEntry +from hidet.graph import nn +from hidet.graph.tensor import Tensor + +PretrainedModel.register( + arch="ResNetForImageClassification", + entry=RegistryEntry( + model_category="image_classification", module_name="resnet", klass="ResNetForImageClassification" + ), +) + +# Contents below reflects transformers.models.resnet.modeling_resnet.py +# with minor API changes + + +class ResNetConvLayer(nn.Module[Tensor]): + def __init__( + self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: bool = True + ): + super().__init__() + self.convolution = nn.Conv2d( + in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2 + ) + self.normalization = nn.BatchNorm2d(out_channels) + self.apply_activation = activation + self.activation = nn.Relu() + + def forward(self, x: Tensor) -> Tensor: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state) + if self.apply_activation: + hidden_state = self.activation(hidden_state) + return hidden_state + + +class ResNetEmbeddings(nn.Module[Tensor]): + """ + ResNet Embeddings (stem) composed of a single aggressive convolution. + """ + + def __init__(self, config: ResNetConfig): + super().__init__() + self.embedder = ResNetConvLayer(config.num_channels, config.embedding_size, kernel_size=7, stride=2) + self.pooler = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) + self.num_channels = config.num_channels + + def forward(self, pixel_values: Tensor) -> Tensor: + num_channels = pixel_values.shape[1] + if num_channels != self.num_channels: + raise ValueError( + "Make sure that the channel dimension of the pixel values match with the one set in the configuration." + ) + embedding = self.embedder(pixel_values) + embedding = self.pooler(embedding) + return embedding + + +class ResNetShortCut(nn.Module[Tensor]): + """ + ResNet shortcut, used to project the residual features to the correct size. If needed, it is also used to + downsample the input using `stride=2`. 
+ """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 2): + super().__init__() + self.convolution = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride) + self.normalization = nn.BatchNorm2d(out_channels) + + def forward(self, x: Tensor) -> Tensor: + hidden_state = self.convolution(x) + hidden_state = self.normalization(hidden_state) + return hidden_state + + +class ResNetBottleNeckLayer(nn.Module[Tensor]): + """ + A classic ResNet's bottleneck layer composed by three `3x3` convolutions. + + The first `1x1` convolution reduces the input by a factor of `reduction` in order to make the second `3x3` + convolution faster. The last `1x1` convolution remaps the reduced features to `out_channels`. + """ + + def __init__(self, in_channels: int, out_channels: int, stride: int = 1, reduction: int = 4): + super().__init__() + self.should_apply_shortcut = in_channels != out_channels or stride != 1 + if self.should_apply_shortcut: + self.shortcut = ResNetShortCut(in_channels, out_channels, stride=stride) + + reduces_channels = out_channels // reduction + layer = [ + ResNetConvLayer(in_channels, reduces_channels, kernel_size=1), + ResNetConvLayer(reduces_channels, reduces_channels, stride=stride), + ResNetConvLayer(reduces_channels, out_channels, kernel_size=1, activation=False), + ] + + self.layer = nn.Sequential(layer) + self.activation = nn.Relu() + + def forward(self, hidden_state: Tensor) -> Tensor: + residual = hidden_state + hidden_state = self.layer(hidden_state) + if self.should_apply_shortcut: + residual = self.shortcut(residual) + hidden_state += residual + hidden_state = self.activation(hidden_state) + return hidden_state + + +class ResNetStage(nn.Module[Tensor]): + """ + A ResNet stage composed by stacked layers. + """ + + def __init__(self, config: ResNetConfig, in_channels: int, out_channels: int, stride: int = 2, depth: int = 2): + super().__init__() + + if config.layer_type != "bottleneck": + raise NotImplementedError( + "Only ResNet bottleneck layers supported. See ResNetBasicLayer in transformers source." 
+ ) + + if config.hidden_act != "relu": + raise NotImplementedError("Only ReLU supported for ResNet activation.") + + self.layers = nn.Sequential( + # downsampling is done in the first layer with stride of 2 + ResNetBottleNeckLayer(in_channels, out_channels, stride=stride), + *[ResNetBottleNeckLayer(out_channels, out_channels) for _ in range(depth - 1)], + ) + + def forward(self, x: Tensor) -> Tensor: + return self.layers.forward(x) + + +class ResNetEncoder(nn.Module[BaseModelOutput]): + def __init__(self, config: ResNetConfig): + super().__init__() + stages = [ + ResNetStage( + config, + config.embedding_size, + config.hidden_sizes[0], + stride=2 if config.downsample_in_first_stage else 1, + depth=config.depths[0], + ) + ] + + # based on `downsample_in_first_stage` the first layer of the first stage may or may not downsample the input + in_out_channels = zip(config.hidden_sizes, config.hidden_sizes[1:]) + for (in_channels, out_channels), depth in zip(in_out_channels, config.depths[1:]): + stages.append(ResNetStage(config, in_channels, out_channels, depth=depth)) + + self.stages: nn.ModuleList = nn.ModuleList(stages) + + def forward(self, hidden_state: Tensor) -> BaseModelOutput: + hidden_states = [hidden_state] + + for stage_module in self.stages: + if stage_module is not None: + hidden_state = stage_module(hidden_state) + hidden_states.append(hidden_state) + + return BaseModelOutput(last_hidden_state=hidden_state, hidden_states=hidden_states) + + +class ResNetClassifier(nn.Sequential): + class Flatten(nn.Module): + def __init__(self, dims: Sequence[int]): + super().__init__() + self.dims = dims + + def forward(self, x: Tensor) -> Tensor: + return x.squeeze(self.dims) + + def __init__(self, config: ResNetConfig): + super().__init__() + assert config.num_labels > 0 + + layers = [self.Flatten((2, 3)), nn.Linear(config.hidden_sizes[-1], config.num_labels)] + for idx, module in enumerate(layers): + self.__setattr__(str(idx), module) + + +class ResNetModel(nn.Module[BaseModelOutputWithPooling]): + def __init__(self, config: ResNetConfig): + super().__init__() + self.config = config + self.embedder = ResNetEmbeddings(config) + self.encoder = ResNetEncoder(config) + self.pooler = nn.AdaptiveAvgPool2d((1, 1)) + + def forward(self, input_images: Tensor) -> BaseModelOutputWithPooling: + embedding_output = self.embedder(input_images) + encoder_outputs: BaseModelOutput = self.encoder(embedding_output) + + pooled_output = self.pooler(encoder_outputs.last_hidden_state) + + return BaseModelOutputWithPooling(**asdict(encoder_outputs), pooler_output=pooled_output) + + +class ResNetForImageClassification(PretrainedModelForImageClassification): + def __init__(self, config: ResNetConfig): + assert isinstance(config, ResNetConfig) + super().__init__(config) + self.num_labels = config.num_labels + self.resnet = ResNetModel(config) + # classification head + self.classifier = ResNetClassifier(config) + + def forward(self, input_images: Tensor) -> ImageClassifierOutput: + outputs: BaseModelOutputWithPooling = self.resnet(input_images) + + logits = self.classifier(outputs.pooler_output) + + return ImageClassifierOutput(**asdict(outputs), logits=logits) diff --git a/python/hidet/apps/modeling_outputs.py b/python/hidet/apps/modeling_outputs.py new file mode 100644 index 000000000..f8cfb1a09 --- /dev/null +++ b/python/hidet/apps/modeling_outputs.py @@ -0,0 +1,64 @@ +from collections import OrderedDict +from dataclasses import dataclass, fields, is_dataclass +from typing import Any, List, Tuple + +from hidet.graph 
import Tensor + + +class ModelOutput(OrderedDict): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + if self.__class__ != ModelOutput and not is_dataclass(self): + raise TypeError(f"{self.__module__}.{self.__class__} must be a dataclass to inherit from ModelOutput.") + + def __post_init__(self): + """ + Called by dataclasses after initialization of dataclass values. + + Here to enable dict-like access + """ + + class_fields = fields(self) # type: ignore + if len(class_fields) == 0: + raise ValueError(f"{self.__class__.__name__} has no fields.") + + for field in class_fields: + v = getattr(self, field.name) + if v is not None: + self[field.name] = v + + def __getitem__(self, k): + if isinstance(k, str): + inner_dict = dict(self.items()) + return inner_dict[k] + else: + return self.to_tuple()[k] + + def __setattr__(self, name, value): + if name in self.keys() and value is not None: + super().__setitem__(name, value) + super().__setattr__(name, value) + + def __setitem__(self, key, value): + super().__setitem__(key, value) + super().__setattr__(key, value) + + def to_tuple(self) -> Tuple[Any, ...]: + return tuple(self[k] for k in self.keys()) + + +@dataclass +class BaseModelOutput(ModelOutput): + last_hidden_state: Tensor + hidden_states: List[Tensor] + + +@dataclass +class BaseModelOutputWithPooling(BaseModelOutput): + pooler_output: Tensor + + +@dataclass +class ImageClassifierOutput(BaseModelOutputWithPooling): + logits: Tensor diff --git a/python/hidet/apps/pretrained.py b/python/hidet/apps/pretrained.py new file mode 100644 index 000000000..5a0734188 --- /dev/null +++ b/python/hidet/apps/pretrained.py @@ -0,0 +1,52 @@ +from typing import Generic, List, Set + +import torch +from transformers import PretrainedConfig + +from hidet.apps.registry import Registry +from hidet.graph import Tensor, nn +from hidet.graph.nn.module import R +from hidet.graph.tensor import from_torch + + +class PretrainedModel(nn.Module[R], Registry, Generic[R]): + def __init__(self, config: PretrainedConfig): + super().__init__() + self.config = config + + def forward(self, *args, **kwargs): + raise NotImplementedError() + + @classmethod + def copy_weights(cls, torch_model: torch.nn.Module, hidet_model: nn.Module): + found_tensors: List[Tensor] = [] + for name, tensor in torch_model.state_dict().items(): + member = hidet_model + for m_name in name.split("."): + member = getattr(member, m_name) + + if not isinstance(member, Tensor): + raise ValueError( + 'PyTorch model "{}" defined a parameter "{}" that is not in the hidet model'.format( + torch_model.__class__.__name__, name + ) + ) + + src = from_torch(tensor).to(member.dtype, member.device) + if src.shape != member.shape: + raise ValueError(f"Parameter {name} shape mismatch, hidet: {member.shape}, torch: {src.shape}") + found_tensors.append(member) + member.copy_(src) + + buffer_names: Set[str] = set(name for name, _ in torch_model.named_buffers()) + for name, tensor in hidet_model.named_parameters(): + if tensor not in found_tensors and name not in buffer_names: + raise ValueError(f"Parameter {name} in hidet model does not find equivalent in PyTorch model.") + + @classmethod + def parse_dtype(cls, config: PretrainedConfig, default: str = "float16"): + if config.torch_dtype: + assert isinstance(config.torch_dtype, torch.dtype) + return str(config.torch_dtype).rsplit(".", maxsplit=1)[-1] + else: + return default diff --git a/python/hidet/apps/registry.py b/python/hidet/apps/registry.py new file mode 100644 index 000000000..57d5ad8d2 --- 
/dev/null +++ b/python/hidet/apps/registry.py @@ -0,0 +1,73 @@ +import importlib +from dataclasses import astuple, dataclass +from typing import Dict + +from transformers import PretrainedConfig + + +@dataclass +class RegistryEntry: + """ + Configuration for dynamic loading of classes. + + We expect app directories to follow this file structure: + + apps/ + ├──/ + │ ├── modeling/ + │ │ ├── / + │ │ │ ├── __init__.py + │ │ └── ... + │ ├── processing/ + │ │ ├── / + │ │ │ ├── __init__.py + │ │ └── ... + ├──/ + └── ... + + For example, model_category could be "image_classification", under which "resnet" + is a model_name. The "resnet" module could contain class ResNetImageProcessor + representing a callable for processing images. + + Use this to dynamically load pre-processors under a general naming scheme. + """ + + model_category: str + module_name: str + klass: str + + def __init__(self, model_category: str, module_name: str, klass: str): + self.model_category = model_category + self.module_name = module_name + self.klass = klass + + +class Registry: + module_registry: Dict[str, RegistryEntry] = {} + + @classmethod + def load_module(cls, config: PretrainedConfig): + architectures = getattr(config, "architectures") + if not architectures: + raise ValueError(f"Config {config.name_or_path} has no architecture.") + + # assume only 1 architecture available for now + architecture = architectures[0] + if architecture not in cls.module_registry: + raise KeyError( + f"No model class with architecture {architecture} found." + f"Registered architectures: {', '.join(cls.module_registry.keys())}." + ) + + model_category, module_name, klass = astuple(cls.module_registry[architecture]) + + module = importlib.import_module(f"hidet.apps.{model_category}.modeling.{module_name}") + + if klass not in dir(module): + raise KeyError(f"No processor class named {klass} found in module {module}.") + + return getattr(module, klass) + + @classmethod + def register(cls, arch: str, entry: RegistryEntry): + cls.module_registry[arch] = entry diff --git a/python/hidet/backend/codegen.py b/python/hidet/backend/codegen.py index 5c141c6a8..bdaadfd40 100644 --- a/python/hidet/backend/codegen.py +++ b/python/hidet/backend/codegen.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
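[Editor's note] The Registry added in python/hidet/apps/registry.py above resolves a HuggingFace `architectures` string to a class through importlib. A minimal round-trip sketch, reusing the ResNet registration this patch performs in resnet/modeling.py; `config` is assumed to be a transformers PretrainedConfig loaded elsewhere:

    from hidet.apps.registry import Registry, RegistryEntry

    Registry.register(
        arch="ResNetForImageClassification",
        entry=RegistryEntry(
            model_category="image_classification",
            module_name="resnet",
            klass="ResNetForImageClassification",
        ),
    )

    # config.architectures == ["ResNetForImageClassification"] makes load_module
    # import hidet.apps.image_classification.modeling.resnet and return the class
    model_cls = Registry.load_module(config)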
-from typing import Optional, List, Tuple, Dict, Union +from typing import Sequence, Optional, List, Tuple, Dict, Union import os import numpy as np from hidet.ir.dialects.pattern import PlaceholderExpr @@ -263,7 +263,6 @@ def visit_IRModule(self, module: IRModule) -> Doc: doc += '} // namespace ' + module.namespace + NewLine() doc = self.require_headers() + doc - return doc def visit_Function(self, func: Function) -> Doc: @@ -831,7 +830,7 @@ def visit_Function(self, func: Function) -> Doc: return doc -def codegen(ir_module: IRModule, src_out_path: str, target: Union[str, Target]) -> str: +def codegen(ir_module: Union[IRModule, Sequence[IRModule]], src_out_path: str, target: Union[str, Target]) -> str: if isinstance(target, str): target = Target.from_string(target) @@ -842,8 +841,14 @@ def codegen(ir_module: IRModule, src_out_path: str, target: Union[str, Target]) else: raise ValueError(f'Unknown target: {target}') - doc = gen(ir_module) - code = str(doc) + code = '' + if isinstance(ir_module, Sequence): + for m in ir_module: + doc = gen(m) + code += str(doc) + '\n' + else: + doc = gen(ir_module) + code = str(doc) if src_out_path is not None: dir_path = os.path.dirname(src_out_path) if not os.path.exists(dir_path): diff --git a/python/hidet/drivers/build_module.py b/python/hidet/drivers/build_module.py index 99fa64ec3..d44a4a832 100644 --- a/python/hidet/drivers/build_module.py +++ b/python/hidet/drivers/build_module.py @@ -9,7 +9,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Sequence, Dict +from typing import Sequence, Dict, Union import logging import os import pickle @@ -34,10 +34,20 @@ def can_remote_build(ir_module: IRModule) -> bool: - return not (len(ir_module.object_files) > 0 or len(ir_module.linking_dirs) > 0 or len(ir_module.include_dirs) > 0) + def can_remote_single_build(ir_module: IRModule) -> bool: + return not ( + len(ir_module.object_files) > 0 or len(ir_module.linking_dirs) > 0 or len(ir_module.include_dirs) > 0 + ) + + if isinstance(ir_module, IRModule): + return can_remote_single_build(ir_module) + else: + return all(can_remote_single_build(m) for m in ir_module) -def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output_kind: str = '.so', force=False): +def build_ir_module( + ir_module: Union[IRModule, Sequence[IRModule]], output_dir: str, target: str, output_kind: str = '.so', force=False +): """ Build an IR module to a shared library or object file. @@ -50,8 +60,8 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output Parameters ---------- - ir_module: IRModule - The IR module to be built. + ir_module: Union[IRModule, Sequence[IRModule]] + The IR module to be built. This can be a single IRModule or a sequence of IRModules. output_dir: str The directory to save the generated source code and the compiled library. 
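[Editor's note] A minimal usage sketch of the widened build_ir_module above; module_a/module_b stand in for already-implemented IRModules and are not names from this patch:

    from hidet.drivers.build_module import build_ir_module

    # single module: unchanged behaviour
    build_ir_module(ir_module, output_dir='./outs/single', target='cuda')

    # new: a sequence is lowered module-by-module, the generated sources are
    # concatenated into one source file, and the include/linking/object lists
    # are merged before a single compile_source call (see the hunks below)
    build_ir_module([module_a, module_b], output_dir='./outs/batch', target='cuda', output_kind='.o')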
@@ -118,20 +128,39 @@ def build_ir_module(ir_module: IRModule, output_dir: str, *, target: str, output if target.name == 'cpu' and 'arch' in target.attrs: hidet.option.cpu.arch(target.attrs['arch']) with PassContext(instruments=instruments): - ir_module = lower(ir_module) + if isinstance(ir_module, Sequence): + for i in range(len(ir_module)): + ir_module[i] = lower(ir_module[i]) + else: + ir_module = lower(ir_module) # code generation codegen(ir_module, src_out_path=src_path, target=target) + include_dir = [] + linking_dir = [] + linking_lib = [] + object_file = [] + if isinstance(ir_module, Sequence): + for im in ir_module: + include_dir.extend(im.include_dirs) + linking_dir.extend(im.linking_dirs) + linking_lib.extend(im.linking_libs) + object_file.extend(im.object_files) + else: + include_dir.extend(ir_module.include_dirs) + linking_dir.extend(ir_module.linking_dirs) + linking_lib.extend(ir_module.linking_libs) + object_file.extend(ir_module.object_files) # compile source code compile_source( src_path, output_library_file=lib_path, target=target, - include_dirs=ir_module.include_dirs, - linking_dirs=ir_module.linking_dirs, - linking_libraries=ir_module.linking_libs, - object_files=ir_module.object_files, + include_dirs=include_dir, + linking_dirs=linking_dir, + linking_libraries=linking_lib, + object_files=object_file, ) # write the function types @@ -154,8 +183,8 @@ def build_ir_module_batch( ir_modules: Sequence[IRModule] A sequence of ir modules to build. - output_dirs: Sequence[str] - The output directory to save the compiled library and source code (lib.so and source.cu). + output_dirs: Squence[str] + Directories for compilation artifacts output_kind: str The output kind of the compiled library. Can be '.so' or '.o'. @@ -172,19 +201,40 @@ def build_job(args): ir_module, output_dir = args build_ir_module(ir_module, output_dir, output_kind=output_kind, target=target, force=force) - jobs = [(ir_module, output_dir) for ir_module, output_dir in zip(ir_modules, output_dirs)] + def regroup_modules(modules, size): + if size > 1: + return [modules[i : i + size] for i in range(0, len(modules), size)] + else: + return modules + + # check if regrouped IRModules have unique function names + def check_function_singular(module_list: Union[Sequence[IRModule], Sequence[Sequence[IRModule]]]) -> bool: + if len(module_list) == 0 or isinstance(module_list[0], IRModule): + return True + name_set = set() + for modules in module_list: + for module in modules: + namespace_str = module.namespace + function_name_list = list(module.extern_functions.keys()) + list(module.functions.keys()) + for func_name in function_name_list: + func_str = namespace_str + '::' + func_name + if func_str in name_set: + return False + else: + name_set.add(func_str) + return True # calculate the number of workers cpu_count = os.cpu_count() if hidet.option.compile_server.enabled(): - num_workers = min(len(jobs), 128) + num_workers = min(len(ir_modules), 128) else: max_jobs, mem_for_worker = option.get_parallel_tune() max_jobs = cpu_count if max_jobs == -1 else min(max_jobs, cpu_count) mem_for_worker *= 1024**3 num_workers = min(max(int(psutil.virtual_memory().available // mem_for_worker), 1), max_jobs) - if num_workers > 1 and len(jobs) > 1: + if num_workers > 1 and len(ir_modules) > 1: # Set the affinity of current process. Some package such as numpy will change affinity of current process, # which might limit the parallelism of compilation. 
from contextlib import suppress @@ -194,9 +244,19 @@ def build_job(args): lazy_initialize_cuda() + per_worker_jobs = 1 if len(ir_modules) < num_workers else len(ir_modules) // num_workers + ir_modules_list = regroup_modules(ir_modules, per_worker_jobs) + assert check_function_singular( + ir_modules_list + ), 'duplicate function names detected after regrouping candidates for batch compilation' + jobs = [ + (ir_modules, output_dir) + for ir_modules, output_dir in zip(ir_modules_list, output_dirs[: len(ir_modules_list)]) + ] + for _ in tqdm(parallel_imap(build_job, jobs, num_workers), desc='Compiling', total=len(jobs), ncols=80): pass + return output_dirs[: len(ir_modules_list)] else: - # sequential build - for job in tqdm(jobs, desc='Compiling', ncols=80, disable=len(jobs) == 1): - build_job(job) + build_ir_module(ir_modules, output_dir=output_dirs[0], output_kind=output_kind, target=target, force=force) + return [output_dirs[0]] diff --git a/python/hidet/drivers/build_task.py b/python/hidet/drivers/build_task.py index 9a51bce95..92e858b1d 100644 --- a/python/hidet/drivers/build_task.py +++ b/python/hidet/drivers/build_task.py @@ -16,7 +16,6 @@ import shutil from hashlib import sha256 from typing import List, Optional, Tuple - import hidet.cuda from hidet import option from hidet.ir.stmt import AssertStmt @@ -111,9 +110,8 @@ def get_output_shape(idx: int32, dims: ~int32): # generate the candidate summary _generate_candidate_summary(candidates, task_dir) - # build each candidate to an object file (.o) - build_ir_module_batch( + objects_path_list = build_ir_module_batch( ir_modules=candidates, output_dirs=[os.path.join(task_dir, 'candidates', str(i)) for i in range(len(candidates))], output_kind='.o', @@ -143,9 +141,7 @@ def launch(arg: meta.types(param_types)): ir_module = script_module.ir_module() ir_module.add_function(get_input_shape.name, get_input_shape) ir_module.add_function(get_output_shape.name, get_output_shape) - ir_module.object_files.extend( - [os.path.join(task_dir, 'candidates', str(i), 'lib.o') for i in range(len(candidates))] - ) + ir_module.object_files.extend([os.path.join(object_path, 'lib.o') for object_path in objects_path_list]) task_ir_module = ir_module # add assertions to the launch function @@ -162,7 +158,6 @@ def launch(arg: meta.types(param_types)): # build task ir module build_ir_module(ir_module=task_ir_module, output_dir=task_dir, output_kind='.so', target=target) - # clear the candidate object files that are no longer needed if not hidet.option.get_option('debug_cache_tuning'): shutil.rmtree(os.path.join(task_dir, 'candidates'), ignore_errors=True) @@ -277,7 +272,6 @@ def build_task(task: Task, target='cuda', load=True) -> Optional[CompiledTask]: # write version with open(version_path, 'w') as f: f.write(hidet.__version__) - # implement task to IRModule, each task may produce multiple IRModules (candidates) # they have the same functionality but different performance candidates = task.implement(target=target, working_dir=task_dir) diff --git a/python/hidet/graph/frontend/torch/dynamo_backends.py b/python/hidet/graph/frontend/torch/dynamo_backends.py index 4605b90c2..5b3c56b98 100644 --- a/python/hidet/graph/frontend/torch/dynamo_backends.py +++ b/python/hidet/graph/frontend/torch/dynamo_backends.py @@ -22,6 +22,7 @@ from hidet.graph.flow_graph import FlowGraph from hidet.graph.transforms import PassContext, optimize from hidet.cuda.graph import CudaGraphCreationError +from hidet.ffi import runtime_api from .dynamo_config import dynamo_config from .interpreter 
import Interpreter from .utils import serialize_output, deserialize_output, resolve_save_dir_multigraph @@ -108,46 +109,45 @@ def preprocess_inputs(inputs: Sequence[torch.Tensor]) -> List[hidet.Tensor]: return hidet_inputs -def get_wrapper(cgraph: CompiledGraph, inputs, output_format): - use_cuda_graph = dynamo_config['use_cuda_graph'] - if use_cuda_graph: - try: - runner = cgraph.cuda_graph() - except CudaGraphCreationError: - runner = cgraph - else: - runner = cgraph +class HidetCompiledModel: + def __init__(self, cgraph: CompiledGraph, inputs, output_format): + super().__init__() + self.inputs = inputs + self.output_format = output_format + self.cgraph_configured = False + self.cgraph = cgraph - def run(*inputs: torch.Tensor): - hidet_inputs = preprocess_inputs(inputs) - hidet_outputs: List[hidet.Tensor] = runner.run_async(hidet_inputs) - torch_outputs: List[torch.Tensor] = [tensor.torch() for tensor in hidet_outputs] - return torch_outputs + def configure_cgraph(self): + if dynamo_config['use_cuda_graph']: + try: + self.cgraph = self.cgraph.cuda_graph() + except CudaGraphCreationError: + pass # Leave cgraph as is + + def __call__(self, *args): + if not self.cgraph_configured: + self.configure_cgraph() + self.cgraph_configured = True - def wrapper(*args: Tensor): tensor_args = [] - for param, arg in zip(inputs, args): + for param, arg in zip(self.inputs, args): if isinstance(param, Tensor): tensor_args.append(arg) elif isinstance(param, SymbolVar): dtype = param.type assert isinstance(dtype, DataType) if dtype.name == 'int32': - from hidet.ffi import runtime_api - runtime_api.set_symbol_value(param.name, int(arg)) else: raise ValueError(f'hidet_backend: unsupported symbolic dtype {dtype}. We only support int32 now.') else: # ignore constant pass - outputs: Sequence[torch.Tensor] = run(*tensor_args) - ret = deserialize_output(output_format, outputs) - return ret - - logger.info('finish generating the executor') - return wrapper + hidet_inputs = preprocess_inputs(tensor_args) + hidet_outputs: List[hidet.Tensor] = self.cgraph.run_async(hidet_inputs) + outputs: Sequence[torch.Tensor] = [tensor.torch() for tensor in hidet_outputs] + return deserialize_output(self.output_format, outputs) def hidet_backend(graph_module, example_inputs): @@ -178,4 +178,4 @@ def wrapper(*args): cgraph = get_compiled_graph(flow_graph) - return get_wrapper(cgraph, inputs, output_format) + return HidetCompiledModel(cgraph, inputs, output_format) diff --git a/python/hidet/graph/frontend/torch/register_functions.py b/python/hidet/graph/frontend/torch/register_functions.py index 8b66104b3..8c1f98053 100644 --- a/python/hidet/graph/frontend/torch/register_functions.py +++ b/python/hidet/graph/frontend/torch/register_functions.py @@ -30,7 +30,7 @@ @register_function(torch.nn.functional.conv1d) -def conv1d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups): +def conv1d(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride=1, padding=0, dilation=1, groups=1): x = ops.conv_pad(x, padding) y = ops.conv1d(x, weight, stride=stride, dilations=dilation, groups=groups) if bias is not None: @@ -40,7 +40,14 @@ def conv1d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d @register_function(torch.nn.functional.conv_transpose1d) def conv1d_transpose( - x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + stride=1, + padding=0, + output_padding=0, + 
groups=1, + dilation=1, ): if dilation != 1 and not same_list(dilation, [1]): raise NotImplementedError("dilation != 1") @@ -51,7 +58,7 @@ def conv1d_transpose( @register_function(torch.nn.functional.conv2d) -def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups): +def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride=1, padding=0, dilation=1, groups=1): y = ops.conv2d(x, weight, stride, dilation, groups, padding=padding) if bias is not None: y = y + ops.unsqueeze(bias, [0, 2, 3]) @@ -60,7 +67,7 @@ def conv2d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d @register_function(torch.nn.functional.conv_transpose2d) def conv2d_transpose( - x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation + x: Tensor, weight: Tensor, bias: Optional[Tensor], stride=1, padding=0, output_padding=0, groups=1, dilation=1 ): if dilation != 1 and not same_list(dilation, [1, 1]): raise NotImplementedError("dilation != 1") @@ -71,7 +78,7 @@ def conv2d_transpose( @register_function(torch.nn.functional.conv3d) -def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, dilation, groups): +def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor] = None, stride=1, padding=0, dilation=1, groups=1): x = ops.conv_pad(x, padding) y = ops.conv3d(x, weight, stride, dilation, groups) if bias is not None: @@ -81,7 +88,14 @@ def conv3d(x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, d @register_function(torch.nn.functional.conv_transpose3d) def conv3d_transpose( - x: Tensor, weight: Tensor, bias: Optional[Tensor], stride, padding, output_padding, groups, dilation + x: Tensor, + weight: Tensor, + bias: Optional[Tensor] = None, + stride=1, + padding=0, + output_padding=0, + groups=1, + dilation=1, ): if dilation != 1 and not same_list(dilation, [1, 1, 1]): raise NotImplementedError("dilation != 1") @@ -160,6 +174,11 @@ def iadd(x: Tensor, y: Tensor): return x + y +@register_function(operator.imul) +def imul(x: Tensor, y: Tensor): + return x * y + + @register_function(torch.sin) @register_function(torch.ops.aten.sin.default) def sin(x: Tensor): @@ -306,7 +325,7 @@ def mul(x: Tensor, y: Tensor): @register_function(torch.cat) -def cat(tensors: List[Tensor], dim: int): +def cat(tensors: List[Tensor], dim: int = 0): dtype = functools.reduce(promote_type, [t.dtype for t in tensors]) tensors = [ops.cast(t, dtype) for t in tensors] return ops.concat(tensors, dim) @@ -674,6 +693,11 @@ def pow(base: Tensor, exponent: Union[Number, Tensor]): return ops.pow(base, exponent) +@register_function(torch.scalar_tensor) +def scalar_tensor(value): + return ops.full([1], value) + + @register_function(torch.full) def full(size, fill_value, *, out=None, dtype=None, layout=None, device=None, requires_grad=False): if out is not None: @@ -713,7 +737,9 @@ def empty( hidet_dtype: DataType = dtype_from_torch(torch_dtype=dtype) if len(size) == 1 and isinstance(size[0], (tuple, list)): size = size[0] - return ops.full(size, dtype=hidet_dtype, device=hidet_device, value=hidet_dtype.zero) + return ops.full( + size, dtype=hidet_dtype, device=hidet_device, value=hidet_dtype.zero if hidet_dtype is not None else 0 + ) @register_function(torch.bmm) @@ -1003,6 +1029,7 @@ def ge(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, Number]) -> Tensor return a >= b +@register_method(torch.Tensor.eq) @register_function(operator.eq) def eq(a: Union[Tensor, Expr, Number], b: Union[Tensor, Expr, 
Number]) -> Tensor: if isinstance(a, Tensor) or isinstance(b, Tensor): @@ -1146,6 +1173,7 @@ def torch_conj(x: Tensor) -> Tensor: @register_function(torch._C._log_api_usage_once) +@register_function(torch._assert_async) @register_function(torch.cuda.synchronize) def torch_noop(*args, **kwargs): return diff --git a/python/hidet/graph/frontend/torch/register_methods.py b/python/hidet/graph/frontend/torch/register_methods.py index a9b5c9d21..56b963d23 100644 --- a/python/hidet/graph/frontend/torch/register_methods.py +++ b/python/hidet/graph/frontend/torch/register_methods.py @@ -63,6 +63,16 @@ def tensor_type_as(self: Tensor, other: Tensor) -> Tensor: return ops.cast(self, other.dtype) +@register_method(torch.Tensor.index_select) +def index_select(self: Tensor, dim: int, index: Tensor): + return ops.index_select(self, index, dim) + + +@register_method(torch.Tensor.fill_) +def fill_(self: Tensor, value): + return ops.full(self.shape, value, dtype=self.dtype, device=self.device) + + @register_method(torch.Tensor.to) def tensor_to(self: Tensor, *args, **kwargs) -> Tensor: """ diff --git a/python/hidet/graph/nn/__init__.py b/python/hidet/graph/nn/__init__.py index b127c0da8..def2ba30e 100644 --- a/python/hidet/graph/nn/__init__.py +++ b/python/hidet/graph/nn/__init__.py @@ -13,8 +13,10 @@ from . import container from .module import Module +from .identity import Identity from .container import Sequential, ModuleList -from .activations import Relu, Gelu, Tanh +from .attention import CrossAttention +from .activations import Relu, Gelu, Geglu, Tanh from .convolutions import Conv2d from .linear import Linear, LinearTransposed from .norms import BatchNorm2d, LayerNorm diff --git a/python/hidet/graph/nn/activations.py b/python/hidet/graph/nn/activations.py index 4d6cbe60e..e5a32c229 100644 --- a/python/hidet/graph/nn/activations.py +++ b/python/hidet/graph/nn/activations.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from hidet.graph import ops +from hidet.graph.nn.linear import Linear from hidet.graph.nn.module import Module @@ -23,6 +24,17 @@ def forward(self, x): return ops.gelu(x) +class Geglu(Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = Linear(dim_in, dim_out * 2, bias=bias) + + def forward(self, x): + x = self.proj(x) + hidden_states, gate = ops.split(x, 2, axis=2) + return hidden_states * ops.gelu(gate) + + class Tanh(Module): def forward(self, x): return ops.tanh(x) diff --git a/python/hidet/graph/nn/attention.py b/python/hidet/graph/nn/attention.py new file mode 100644 index 000000000..7857ba125 --- /dev/null +++ b/python/hidet/graph/nn/attention.py @@ -0,0 +1,89 @@ +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
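[Editor's note] Geglu above is the GEGLU gating variant of GELU: a single projection produces both the hidden half and the gate half, split along the feature axis. Illustrative shapes, assuming dim_in=320, dim_out=1280 and a (bs, seq, 320) input:

    x = self.proj(x)                               # (bs, seq, 2560)
    hidden_states, gate = ops.split(x, 2, axis=2)  # two (bs, seq, 1280) halves
    out = hidden_states * ops.gelu(gate)           # (bs, seq, 1280)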
+from typing import Optional +import math + +from hidet.graph import ops +from hidet.graph.nn.container import ModuleList +from hidet.graph.nn.linear import Linear +from hidet.graph.nn.module import Module +from hidet.graph.tensor import Tensor +from hidet.utils.py import prod + + +class CrossAttention(Module): + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + upcast: bool = False, + out_bias: bool = True, + ): + super().__init__() + self.inner_dim = dim_head * heads + self.query_dim = query_dim + self.cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.heads = heads + self.upcast = upcast + self.out_bias = out_bias + + self.to_q = Linear(query_dim, self.inner_dim, bias=False) + self.to_k = Linear(self.cross_attention_dim, self.inner_dim, bias=False) + self.to_v = Linear(self.cross_attention_dim, self.inner_dim, bias=False) + + self.to_out = ModuleList([Linear(self.inner_dim, self.query_dim, bias=out_bias)]) + + def forward( + self, hidden_states: Tensor, encoder_hidden_states: Optional[Tensor] = None, temperature_scaling: float = 1.0 + ) -> Tensor: + ndim = len(hidden_states.shape) + if ndim == 4: + bs, c, h, w = hidden_states.shape + hidden_states = hidden_states.reshape([bs, c, h * w]).transpose(1, 2) + + bs, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape + + q = self.to_q(hidden_states) + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + + k = self.to_k(encoder_hidden_states) + v = self.to_v(encoder_hidden_states) + + inner_dim = k.shape[-1] + head_dim = inner_dim // self.heads + + other_dims = bs * inner_dim + assert (prod(q.shape) % other_dims) == 0 + + q, k, v = tuple(t.reshape((bs, -1, self.heads, head_dim)).transpose(1, 2).to("float16") for t in (q, k, v)) + q = q * (1 / math.sqrt(head_dim)) + k = k.transpose(-1, -2) + + # Use softmax temperature parameter to prevent QK matmul causing float overflow + # due to limited fp16 range. May cause accuracy issues, should only be applied + # for attention layers that have overflow issue. Alternate solution is to + # cast to fp32 and use mm/softmax/mm attention + assert temperature_scaling >= 1.0 + if temperature_scaling != 1.0: + q = q / temperature_scaling + + hidden_states = ops.attention(q, k, v).to(dtype=hidden_states.dtype) + hidden_states = hidden_states.transpose(1, 2).reshape((bs, -1, inner_dim)) + hidden_states = self.to_out[0](hidden_states) + + if ndim == 4: + hidden_states = hidden_states.transpose(1, 2).reshape((bs, c, h, w)) + + return hidden_states diff --git a/python/hidet/graph/nn/container.py b/python/hidet/graph/nn/container.py index a28ad3684..ac3af30bf 100644 --- a/python/hidet/graph/nn/container.py +++ b/python/hidet/graph/nn/container.py @@ -10,7 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
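[Editor's note] A worked shape trace for CrossAttention.forward above, under assumed stable-diffusion-like sizes (bs=2, seq=4096, query_dim=320, heads=8, dim_head=40, encoder seq=77, cross_attention_dim=768); the concrete numbers are illustrative only:

    hidden_states                          # (2, 4096, 320)
    encoder_hidden_states                  # (2, 77, 768)
    q = self.to_q(hidden_states)           # (2, 4096, 320); inner_dim = 8 * 40 = 320
    k = self.to_k(encoder_hidden_states)   # (2, 77, 320)
    q: reshape + transpose(1, 2)           # (2, 8, 4096, 40)
    k: reshape + transpose(1, 2), then k.transpose(-1, -2)  # (2, 8, 40, 77)
    ops.attention(q, k, v)                 # (2, 8, 4096, 40)
    transpose(1, 2) + reshape              # (2, 4096, 320)
    self.to_out[0](...)                    # (2, 4096, 320)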
from __future__ import annotations -from typing import Iterable +from typing import Iterable, List, Union from collections import OrderedDict from .module import Module @@ -30,12 +30,19 @@ def __init__(self, *args): def forward(self, x): # pylint: disable=arguments-differ for module in self._submodules.values(): - x = module(x) + if module is not None: + x = module(x) return x + def __iter__(self): + return iter(self._submodules.values()) + + def __len__(self): + return len(self._submodules.keys()) + class ModuleList(Module): - def __init__(self, modules: Iterable[Module] = None): + def __init__(self, modules: Iterable[Module]): super().__init__() for idx, module in enumerate(modules): self._submodules[str(idx)] = module @@ -43,5 +50,15 @@ def __init__(self, modules: Iterable[Module] = None): def __iter__(self): return iter(self._submodules.values()) + def __getitem__(self, index: int) -> Union[Module, List[Module]]: + if isinstance(index, slice): + module_list = [self._submodules[str(idx)] for idx in range(len(self._submodules))] + return module_list[index] + else: + return self._submodules[str(index)] + + def __len__(self): + return len(self._submodules) + def forward(self, *args): raise ValueError('Should not forward ModuleList.') diff --git a/python/hidet/graph/nn/convolutions.py b/python/hidet/graph/nn/convolutions.py index 86e500d7f..d7d086954 100644 --- a/python/hidet/graph/nn/convolutions.py +++ b/python/hidet/graph/nn/convolutions.py @@ -16,7 +16,7 @@ class Conv2d(Module): - def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, groups=1): + def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, groups=1, bias=False): super().__init__() self.in_channels = in_channels self.out_channels = out_channels @@ -25,6 +25,8 @@ def __init__(self, in_channels, out_channels, kernel_size, padding=0, stride=1, self.stride = normalize(stride) self.groups = groups self.weight = empty(shape=[out_channels, in_channels, *self.kernel], dtype='float32') + # use shape (oc, 1, 1) for broadcast + self.bias = empty(shape=[out_channels, 1, 1], dtype="float32") if bias else None def extra_str(self) -> str: return 'in_channels={}, out_channels={}, kernel_size={}, stride={}, padding={}'.format( @@ -33,4 +35,7 @@ def extra_str(self) -> str: def forward(self, x): x = ops.pad(x, ops.utils.normalize_padding(self.padding)) - return ops.conv2d(x, self.weight, stride=self.stride, groups=self.groups) + x = ops.conv2d(x, self.weight, stride=self.stride, groups=self.groups) + if self.bias is not None: + x = ops.add(x, self.bias) + return x diff --git a/python/hidet/graph/nn/identity.py b/python/hidet/graph/nn/identity.py new file mode 100644 index 000000000..d7a578333 --- /dev/null +++ b/python/hidet/graph/nn/identity.py @@ -0,0 +1,19 @@ +from hidet.graph.nn.module import Module + + +class Identity(Module): + """ + Identity function. + + Used as a dummy for replacing modules (e.g. remove a layer in module list + but need to keep indices in container to match torch model) + """ + + def __init__(self, *args, **kwargs): + super().__init__() + + self.args = args + self.kwargs = kwargs + + def forward(self, x): + return x diff --git a/python/hidet/graph/nn/module.py b/python/hidet/graph/nn/module.py index b45d6fd4e..9b1230a3c 100644 --- a/python/hidet/graph/nn/module.py +++ b/python/hidet/graph/nn/module.py @@ -10,35 +10,44 @@ # See the License for the specific language governing permissions and # limitations under the License. 
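# The module.py rework below parameterizes Module over its forward return
# type (Module[R]) and routes every attribute assignment through the
# _parameters/_submodules dicts. A hedged sketch of what the Generic
# annotation buys (the class name is illustrative):
#
#   class Scale(Module[Tensor]):
#       def forward(self, x: Tensor) -> Tensor:
#           return x * 2.0
#
#   y = Scale()(x)   # type checkers infer y: Tensor from __call__ -> R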
from __future__ import annotations -from typing import Optional, Sequence, Iterator, Dict, Any +from typing import Optional, Sequence, Iterator, Dict, Any, Generic, TypeVar from collections import OrderedDict from hidet.graph.tensor import symbol_like from hidet.graph.flow_graph import FlowGraph, trace_from from hidet.graph.tensor import Tensor +# forward method return type +R = TypeVar('R') -class Module: + +class Module(Generic[R]): def __init__(self): self.name = None self._parameters: OrderedDict[str, Optional[Tensor]] = OrderedDict() self._submodules: OrderedDict[str, Optional[Module]] = OrderedDict() def __setattr__(self, key, value): - parameters = self.__dict__.get('parameters') - submodules = self.__dict__.get('submodules') + if key in ['name', '_submodules', '_parameters']: + super().__setattr__(key, value) + return + + parameters = self.__dict__.get('_parameters') + submodules = self.__dict__.get('_submodules') + + if key in parameters: + del self._parameters[key] + elif key in submodules: + del self._submodules[key] + elif key in self.__dict__: + del self.__dict__[key] + if isinstance(value, Tensor): - value.name = key - self._parameters[key] = value + parameters[key] = value elif isinstance(value, Module): - value.name = '{}.{}'.format(self.name, key) if self.name else key - self._submodules[key] = value - elif parameters and submodules and value is None and (key in parameters or key in submodules): - if key in self._parameters: - self._parameters[key] = value - if key in self._submodules: - self._submodules[key] = value + submodules[key] = value else: - super().__setattr__(key, value) + self.__dict__[key] = value + cnt = sum(1 for collection in [parameters, submodules, self.__dict__] if collection and key in collection) assert cnt <= 1, 'duplicated definition of {}'.format(key) @@ -70,7 +79,7 @@ def __str__(self): lines = [' ' * indent + line for line in lines] return '{}(\n{}\n)'.format(name, '\n'.join(lines)) - def __call__(self, *args, **kwargs): + def __call__(self, *args, **kwargs) -> R: return self.forward(*args, **kwargs) def state_dict(self) -> Dict[str, Any]: @@ -86,7 +95,7 @@ def load_state_dict(self, state_dict: Dict[str, Any]): def extra_str(self) -> str: return '' - def forward(self, *args, **kwargs): + def forward(self, *args, **kwargs) -> R: raise NotImplementedError() def parameters(self, recursive: bool = True) -> Iterator[Tensor]: diff --git a/python/hidet/graph/nn/norms.py b/python/hidet/graph/nn/norms.py index 3d288c751..0e65539df 100644 --- a/python/hidet/graph/nn/norms.py +++ b/python/hidet/graph/nn/norms.py @@ -23,6 +23,7 @@ def __init__(self, num_features, eps=1e-5, affine=True): self.affine = affine self.running_mean = empty(shape=[num_features]) self.running_var = empty(shape=[num_features]) + self.num_batches_tracked = empty(shape=[]) if affine: self.weight: Tensor = empty(shape=[num_features]) self.bias = empty(shape=[num_features]) @@ -62,3 +63,26 @@ def forward(self, x: Tensor) -> Tensor: if self.bias is not None: x = x + self.bias return x + + +class GroupNorm(Module): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True): + super().__init__() + self.eps = eps + self.affine = affine + self.num_groups = num_groups + self.num_channels = num_channels + if affine: + # add extra dims for broadcast + self.weight: Tensor = empty(shape=[num_channels, 1, 1]) + self.bias: Tensor = empty(shape=[num_channels, 1, 1]) + else: + self.weight = None + self.bias = None + + def forward(self, x: Tensor): + x = ops.group_norm(x, self.num_groups, 
self.eps) + if self.affine: + x = x * self.weight + self.bias + + return x diff --git a/python/hidet/graph/ops/__init__.py b/python/hidet/graph/ops/__init__.py index 52a710a53..85e583309 100644 --- a/python/hidet/graph/ops/__init__.py +++ b/python/hidet/graph/ops/__init__.py @@ -41,7 +41,7 @@ from .reduce import mean, sum, var, min, max, std, prod, argmin, argmax, all, any from .cumulative import cumsum from .transform import squeeze, unsqueeze, flatten, concat, cast, take, rearrange, strided_slice, reshape -from .transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims, gather +from .transform import transpose, broadcast, pad, tile, split, conv_pad, expand_dims, gather, index_select from .transform import permute_dims from .fusion import fused_operator from .transfer import transfer diff --git a/python/hidet/graph/ops/attention/attention.py b/python/hidet/graph/ops/attention/attention.py index 278deb49b..df7f0f0be 100644 --- a/python/hidet/graph/ops/attention/attention.py +++ b/python/hidet/graph/ops/attention/attention.py @@ -859,6 +859,9 @@ def __init__(self, q: Tensor, k: Tensor, v: Tensor, is_causal: bool = False): def attention(q: Tensor, k: Tensor, v: Tensor, mask: Optional[Tensor] = None, is_causal: bool = False) -> Tensor: + # Note: does not apply scaling factor (1/sqrt(E)) in softmax, + # requires k transposed relative to q + # (ie. returns softmax(Q @ K) @ V ) if mask is not None and is_causal is True: raise ValueError("mask and is_causal cannot be set at the same time") diff --git a/python/hidet/graph/ops/conv2d/conv2d.py b/python/hidet/graph/ops/conv2d/conv2d.py index 6c3891285..c7f9eb9d5 100644 --- a/python/hidet/graph/ops/conv2d/conv2d.py +++ b/python/hidet/graph/ops/conv2d/conv2d.py @@ -12,7 +12,14 @@ from typing import List, Union, Sequence from hidet import ir from hidet.graph.ops.utils import Task, Operator, Tensor, TensorNode -from hidet.graph.ops.utils import compute, input_like, normalize_stride, normalize_dilations, reduce +from hidet.graph.ops.utils import ( + compute, + input_like, + normalize_stride, + normalize_dilations, + normalize_conv_padding, + reduce, +) from hidet.utils.py import cdiv @@ -147,6 +154,7 @@ def __init__( ): stride = normalize_stride(stride) dilations = normalize_dilations(dilations) + padding = normalize_conv_padding(padding, 2) super().__init__( inputs=[x, w], attributes={'padding': padding, 'stride': stride, 'groups': groups, 'dilations': dilations}, diff --git a/python/hidet/graph/ops/image.py b/python/hidet/graph/ops/image.py index 2100d1be5..42ee99e02 100644 --- a/python/hidet/graph/ops/image.py +++ b/python/hidet/graph/ops/image.py @@ -107,6 +107,11 @@ def resize2d_nchw_compute( extrapolation_value: Optional[float], recompute_scale_factor: Optional[bool], ): # pylint: disable=unused-argument + """ + Resize data to size or by scale, whichever is provided, according to method (one of `nearest`, `linear`, `cubic`). + + `nearest` and `linear` methods preserve dtype, cubic does not. 
+ """ _ = roi # not supported yet image_size = data.shape[2:] diff --git a/python/hidet/graph/ops/reduce/reduce.py b/python/hidet/graph/ops/reduce/reduce.py index 7a0979aac..05cc491e3 100644 --- a/python/hidet/graph/ops/reduce/reduce.py +++ b/python/hidet/graph/ops/reduce/reduce.py @@ -398,7 +398,14 @@ def reduce_fcompute(reduce_index): class ReduceBaseOp(Operator): - def __init__(self, x: Tensor, dims: Optional[Sequence[int]], keep_dim: bool, reduce_type: str): + def __init__( + self, + x: Tensor, + dims: Optional[Sequence[int]], + keep_dim: bool, + reduce_type: str, + accumulate_dtype: str = 'float32', + ): rank = len(x.shape) if dims is None: dims = list(range(rank)) @@ -406,7 +413,7 @@ def __init__(self, x: Tensor, dims: Optional[Sequence[int]], keep_dim: bool, red super().__init__( inputs=[x], attributes={'dims': dims, 'keepdims': keep_dim}, - task=ReduceTask(input_like(x, 'x'), dims, keep_dim, reduce_type), + task=ReduceTask(input_like(x, 'x'), dims, keep_dim, reduce_type, accumulate_dtype=accumulate_dtype), ) @@ -444,12 +451,12 @@ def __init__(self, x: Tensor, dims: Optional[Sequence[int]], keepdims: bool = Fa class ReduceOrOp(ReduceBaseOp): def __init__(self, x: Tensor, dims: Optional[Sequence[int]], keepdims: bool = False): - super().__init__(x, dims, keepdims, ReduceType.Or.value) + super().__init__(x, dims, keepdims, ReduceType.Or.value, 'bool') class ReduceAndOp(ReduceBaseOp): def __init__(self, x: Tensor, dims: Optional[Sequence[int]], keepdims: bool = False): - super().__init__(x, dims, keepdims, ReduceType.And.value) + super().__init__(x, dims, keepdims, ReduceType.And.value, 'bool') class ReduceProdOp(ReduceBaseOp): diff --git a/python/hidet/graph/ops/transform.py b/python/hidet/graph/ops/transform.py index 9f22c6da1..fbc6ab334 100644 --- a/python/hidet/graph/ops/transform.py +++ b/python/hidet/graph/ops/transform.py @@ -238,6 +238,20 @@ def fmap(*output_indices): super().__init__(name='gather', inputs=[data, indices], outputs=[output]) +class IdxSelTask(Task): + def __init__(self, data: TensorInput, index: TensorInput, dim=0): + output_shape = data.shape[:dim] + [index.shape[0]] + data.shape[dim + 1 :] + + def fmap(*output_indices): + index_value = index[output_indices[dim]] + index_value = if_then_else(index_value < 0, index_value + data.shape[dim], index_value) + data_indices = output_indices[:dim] + (index_value,) + output_indices[dim + 1 :] + return data[data_indices] + + output = compute(name='output', shape=output_shape, fcompute=lambda *output_indices: fmap(*output_indices)) + super().__init__(name='idxsel', inputs=[data, index], outputs=[output]) + + class StridedSliceTask(Task): def __init__( self, @@ -426,6 +440,15 @@ def __init__(self, data: Tensor, indices: Tensor, axis: int): ) +class IdxSelOp(Operator): + def __init__(self, data: Tensor, index: Tensor, dim: int): + super().__init__( + inputs=[data, index], + attributes={'dim': dim}, + task=IdxSelTask(input_like(data, 'data'), input_like(index, 'index'), dim=dim), + ) + + class StridedSliceOp(Operator): def __init__( self, @@ -583,6 +606,10 @@ def gather(data: Tensor, indices: Tensor, axis: int = 0) -> Tensor: return GatherOp(data, indices, axis).outputs[0] +def index_select(data: Tensor, index: Tensor, dim: int) -> Tensor: + return IdxSelOp(data, index, dim).outputs[0] + + def strided_slice( data: Tensor, starts: Sequence[Optional[int]], diff --git a/python/hidet/graph/ops/utils/tensor_utils.py b/python/hidet/graph/ops/utils/tensor_utils.py index 93f13a712..398536f89 100644 --- 
a/python/hidet/graph/ops/utils/tensor_utils.py +++ b/python/hidet/graph/ops/utils/tensor_utils.py @@ -92,6 +92,15 @@ def normalize_padding(padding: Union[Int, Sequence[Int]], dim=2) -> List[Int]: ) +def normalize_conv_padding(padding: Union[Int, Sequence[Int]], dim) -> List[Int]: + if isinstance(padding, int): + return [padding for _ in range(dim)] + elif isinstance(padding, (list, tuple)): + assert len(padding) == dim + return padding + raise ValueError('Incorrect conv padding: {}; dim is {}'.format(padding, dim)) + + def normalize_dim(dim: Optional[Union[Int, Sequence[Int]]], rank: int) -> Union[Int, List[Int]]: """ normalize a dim from [-rank, rank] or None to [0, rank]. diff --git a/python/hidet/ir/dtypes/boolean.py b/python/hidet/ir/dtypes/boolean.py index b613c8314..d1814b676 100644 --- a/python/hidet/ir/dtypes/boolean.py +++ b/python/hidet/ir/dtypes/boolean.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any +from functools import cached_property import warnings from hidet.ir.type import DataType @@ -42,19 +43,19 @@ def constant(self, value: Any): value = bool(value) return constant(value, self) - @property + @cached_property def one(self): return self.constant(True) - @property + @cached_property def zero(self): return self.constant(False) - @property + @cached_property def true(self): return self.constant(True) - @property + @cached_property def false(self): return self.constant(False) diff --git a/python/hidet/ir/dtypes/complex.py b/python/hidet/ir/dtypes/complex.py index 948668086..62135e49c 100644 --- a/python/hidet/ir/dtypes/complex.py +++ b/python/hidet/ir/dtypes/complex.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any +from functools import cached_property from hidet.ir.type import DataType from hidet.ir.dtypes.floats import float32, float64 @@ -47,11 +48,11 @@ def constant(self, value: Any): else: raise RuntimeError("Invalid constant value for complex type: {}".format(value)) - @property + @cached_property def one(self): return self.constant(1.0 + 0.0j) - @property + @cached_property def zero(self): return self.constant(0.0 + 0.0j) diff --git a/python/hidet/ir/dtypes/floats.py b/python/hidet/ir/dtypes/floats.py index 4e3074443..eaede9cd2 100644 --- a/python/hidet/ir/dtypes/floats.py +++ b/python/hidet/ir/dtypes/floats.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. from typing import Any +from functools import cached_property from dataclasses import dataclass import warnings import numpy as np @@ -77,11 +78,11 @@ def constant(self, value: Any): return constant(value, self) - @property + @cached_property def one(self): return self.constant(1.0) - @property + @cached_property def zero(self): return self.constant(0.0) diff --git a/python/hidet/ir/dtypes/integer.py b/python/hidet/ir/dtypes/integer.py index 27268c106..7447a22f1 100644 --- a/python/hidet/ir/dtypes/integer.py +++ b/python/hidet/ir/dtypes/integer.py @@ -10,6 +10,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
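# integer.py below gets the same treatment as the boolean/complex/float
# dtypes above: constant singletons move from @property to cached_property.
# A hedged sketch of the observable difference (assuming dtype objects are
# long-lived module-level singletons, as they are in hidet.ir.dtypes):
#
#   from hidet.ir.dtypes import int32
#   assert int32.zero is int32.zero   # cached: built once, then reused
#
# Under @property, each access constructed a fresh constant expression.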
from typing import Any +from functools import cached_property from dataclasses import dataclass import warnings from hidet.ir.type import DataType @@ -60,19 +61,19 @@ def constant(self, value: Any): def signedness(self): return self._min_value < 0 - @property + @cached_property def one(self): return self.constant(1) - @property + @cached_property def zero(self): return self.constant(0) - @property + @cached_property def min_value(self): return self.constant(self._min_value) - @property + @cached_property def max_value(self): return self.constant(self._max_value) diff --git a/python/hidet/ir/primitives/cuda/smem.py b/python/hidet/ir/primitives/cuda/smem.py index e9b4b9366..bc4fb12d6 100644 --- a/python/hidet/ir/primitives/cuda/smem.py +++ b/python/hidet/ir/primitives/cuda/smem.py @@ -24,7 +24,7 @@ def register_functions(): from hidet.lang import script, attrs, cast - for dtype in ['int8', 'uint8', 'uint32', 'int32', 'float16', 'float32']: + for dtype in ['int8', 'uint8', 'uint32', 'int32', 'float16', 'float32', 'bool']: func_name = f'cuda_dynamic_shared_memory_{dtype}' dtype = data_type(dtype) diff --git a/python/hidet/option.py b/python/hidet/option.py index ad225d514..98c31e0ae 100644 --- a/python/hidet/option.py +++ b/python/hidet/option.py @@ -285,6 +285,12 @@ def register_hidet_options(): default_value=True, description='Whether to enable imperative execution when op arguments allows', ) + register_option( + name='auth_tokens.for_huggingface', + type_hint='str', + default_value='', + description='The auth token to use for accessing private huggingface models.', + ) config_file_path = os.path.join(os.path.expanduser('~'), '.config', 'hidet') if not os.path.exists(config_file_path): diff --git a/python/hidet/runtime/compiled_graph.py b/python/hidet/runtime/compiled_graph.py index a8262f5f1..9c7b3583a 100644 --- a/python/hidet/runtime/compiled_graph.py +++ b/python/hidet/runtime/compiled_graph.py @@ -136,6 +136,21 @@ def __init__( # the weights are already loaded, initialize the graph directly self._init_compiled_graph() + def __getstate__(self): + # Create a temporary file and save the CompiledGraph zip in it + with tempfile.NamedTemporaryFile() as temp_file: + self.save(temp_file.name, save_dispatch_table=True) + with open(temp_file.name, 'rb') as f: + state = f.read() + return state + + def __setstate__(self, state): + # Load the CompiledGraph + with tempfile.NamedTemporaryFile() as temp_file: + with open(temp_file.name, 'wb') as f: + f.write(state) + self.__dict__.update(load_compiled_graph(temp_file.name).__dict__) + def __str__(self): """ Get the basic information of this compiled graph. diff --git a/python/hidet/version.py b/python/hidet/version.py index 42f341b65..23c7332e2 100644 --- a/python/hidet/version.py +++ b/python/hidet/version.py @@ -9,4 +9,4 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
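# The __getstate__/__setstate__ pair added to CompiledGraph above makes
# compiled graphs picklable by round-tripping through the on-disk zip
# format. A hedged sketch (obtaining `compiled_graph` is elided here):
#
#   import pickle
#   payload = pickle.dumps(compiled_graph)   # serialized via save(..., save_dispatch_table=True)
#   restored = pickle.loads(payload)         # rebuilt via load_compiled_graph
#
# Note this round-trips through temporary files, so pickling large graphs
# is I/O-bound.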
-__version__ = "0.3.1.dev" +__version__ = "0.3.1" diff --git a/requirements-dev.txt b/requirements-dev.txt index b6cb3e29a..afcfabba3 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -14,6 +14,7 @@ pylint==2.13.9 # for models to test torch torchvision +datasets transformers==4.37 sentencepiece sacremoses diff --git a/setup.py b/setup.py index fbcb908bd..abe6a64fc 100644 --- a/setup.py +++ b/setup.py @@ -14,7 +14,7 @@ setup( name="hidet", - version="0.3.1.dev", + version="0.3.1", description="Hidet: a compilation-based DNN inference framework.", long_description=open("README.md").read(), long_description_content_type="text/markdown", diff --git a/tests/apps/image_classification/test_builder.py b/tests/apps/image_classification/test_builder.py new file mode 100644 index 000000000..148a7503d --- /dev/null +++ b/tests/apps/image_classification/test_builder.py @@ -0,0 +1,30 @@ +import pytest +import torch +from datasets import load_dataset +from hidet.apps.image_classification.builder import create_image_classifier +from hidet.graph.tensor import from_torch +from transformers import AutoImageProcessor + + +@pytest.mark.slow +def test_create_image_classifier(): + dataset = load_dataset("huggingface/cats-image", split="test", trust_remote_code=True) + + # using huggingface pre-processor + image_processor = AutoImageProcessor.from_pretrained("microsoft/resnet-50") + image = image_processor(dataset[0]["image"], return_tensors="pt")["pixel_values"] + image = from_torch(image).cuda() + + resnet = create_image_classifier("microsoft/resnet-50") + assert "image_classifier" in resnet.compiled_app.meta.graphs + assert resnet.compiled_app.meta.name == "microsoft/resnet-50" + + res = resnet.compiled_app.graphs["image_classifier"].run_async([image]) + res = res[0].torch() + res = torch.argmax(res, dim=1) + + assert res[0].item() == 282 # tiger cat label + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/apps/test_pretrained.py b/tests/apps/test_pretrained.py new file mode 100644 index 000000000..547f304d1 --- /dev/null +++ b/tests/apps/test_pretrained.py @@ -0,0 +1,58 @@ +import pytest +import torch +from hidet.apps import PretrainedModel, hf +from hidet.apps.image_classification.modeling.resnet.modeling import ResNetForImageClassification +from hidet.graph.tensor import empty +from hidet.option import get_option +from transformers import AutoModelForImageClassification, PretrainedConfig, ResNetConfig + + +@pytest.mark.slow +@pytest.mark.parametrize( + "model_name, dtype", + [ + ("microsoft/codebert-base", "float16"), # resolve to default float16 + ("microsoft/resnet-50", "float32"), # use config float32 + ], +) +def test_parse_dtype(model_name: str, dtype: str): + config: PretrainedConfig = hf.load_pretrained_config(model_name) + assert PretrainedModel.parse_dtype(config) == dtype + + +@pytest.mark.slow +def test_copy_weights(): + + with torch.device("cuda"): + config: ResNetConfig = hf.load_pretrained_config("microsoft/resnet-50") + huggingface_token = get_option("auth_tokens.for_huggingface") + + torch_model = AutoModelForImageClassification.from_pretrained( + pretrained_model_name_or_path=config.name_or_path, torch_dtype=torch.float32, token=huggingface_token + ) + hidet_model = ResNetForImageClassification(config) + hidet_model.to(dtype="float32", device="cuda") + PretrainedModel.copy_weights(torch_model, hidet_model) + + normalization_stage = ( + hidet_model.resnet.encoder.stages._submodules["0"] + .layers._submodules["0"] + .layer._submodules["0"] + 
.normalization + ) + weight_set = [ + normalization_stage.weight, + normalization_stage.bias, + normalization_stage.running_mean, + normalization_stage.running_var, + hidet_model.classifier._submodules["1"].weight, + hidet_model.resnet.embedder.embedder.convolution.weight, + ] + + for weight in weight_set: + weight = weight.torch() + assert not torch.equal(weight, torch.zeros_like(weight)) + + +if __name__ == "__main__": + pytest.main([__file__]) diff --git a/tests/apps/test_registry.py b/tests/apps/test_registry.py new file mode 100644 index 000000000..be24ddfa9 --- /dev/null +++ b/tests/apps/test_registry.py @@ -0,0 +1,15 @@ +import pytest +from hidet.apps import Registry, hf +from hidet.apps.image_classification.modeling.resnet.modeling import ResNetForImageClassification +from transformers import PretrainedConfig + + +@pytest.mark.slow +@pytest.mark.parametrize('model_name', ["microsoft/resnet-50"]) +def test_load_module(model_name: str): + config: PretrainedConfig = hf.load_pretrained_config(model_name) + assert Registry.load_module(config) is ResNetForImageClassification + + +if __name__ == '__main__': + pytest.main([__file__]) diff --git a/.github/scripts/bench/bench_op.py b/tests/benchmarks/bench_op.py similarity index 89% rename from .github/scripts/bench/bench_op.py rename to tests/benchmarks/bench_op.py index 7bbce06e9..fb14b6aa6 100644 --- a/.github/scripts/bench/bench_op.py +++ b/tests/benchmarks/bench_op.py @@ -3,7 +3,8 @@ import argparse import numpy as np import hidet -from bench_utils import enable_compile_server, setup_hidet_flags, bench_torch_model +from bench_utils import bench_torch_model, Backend + def bench_matmul_f16(params: str, *args, **kwargs) -> float: a_shape, b_shape = params.split(',') @@ -17,6 +18,7 @@ def bench_matmul_f16(params: str, *args, **kwargs) -> float: g = g.cuda_graph() return bench_torch_model(lambda: g.run_async(), []) + def bench_batch_matmul(params: str, *args, **kwargs) -> float: # Default to benchmarking f32 for now, though this op can run other dtypes a_shape, b_shape = params.split(',') @@ -30,6 +32,7 @@ def bench_batch_matmul(params: str, *args, **kwargs) -> float: g = g.cuda_graph() return bench_torch_model(lambda: g.run_async(), []) + def bench_conv2d(params: str, *args, **kwargs) -> float: x_shape, w_shape = params.split(',') x_shape = [int(s) for s in x_shape.split('x')] @@ -42,6 +45,7 @@ def bench_conv2d(params: str, *args, **kwargs) -> float: g = g.cuda_graph() return bench_torch_model(lambda: g.run_async(), []) + def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float: x_shape, w_shape = params.split(',') x_shape = [int(s) for s in x_shape.split('x')] @@ -54,6 +58,7 @@ def bench_conv2d_gemm_f16(params: str, *args, **kwargs) -> float: g = g.cuda_graph() return bench_torch_model(lambda: g.run_async(), []) + def bench_attn(params: str, *args, **kwargs) -> float: bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')] q_shape = [bs, nhead, seqlen, hdim] @@ -68,6 +73,7 @@ def bench_attn(params: str, *args, **kwargs) -> float: g = g.cuda_graph() return bench_torch_model(lambda: g.run_async(), []) + def bench_attn_mask_add(params: str, *args, **kwargs) -> float: bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')] q_shape = [bs, nhead, seqlen, hdim] @@ -84,6 +90,7 @@ def bench_attn_mask_add(params: str, *args, **kwargs) -> float: g = g.cuda_graph() return bench_torch_model(lambda: g.run_async(), []) + def bench_reduce(params: str, *args, **kwargs) -> float: x_shape, axis = params.split(',', maxsplit=1) start = 
axis.find('axis=[') + len('axis=[') @@ -97,6 +104,7 @@ def bench_reduce(params: str, *args, **kwargs) -> float: g = g.cuda_graph() return bench_torch_model(lambda: g.run_async(), []) + bench_func_map = { 'matmul_f16': bench_matmul_f16, 'batch_matmul': bench_batch_matmul, @@ -109,23 +117,14 @@ def bench_reduce(params: str, *args, **kwargs) -> float: if __name__ == '__main__': parser = argparse.ArgumentParser(prog='Benchmark Operators') + parser.add_argument('operator', type=str, help='Specify operator. E.g., matmul_f16') parser.add_argument( - 'operator', - type=str, - help='Specify operator. E.g., matmul_f16' - ) - parser.add_argument( - '--params', - type=str, - help='Specify Input Parameters. Different operators have different formats.' - ) - parser.add_argument( - '--dtype', - type=str, - default='float16', - help='Specify precision. E.g., float32' + '--params', type=str, help='Specify Input Parameters. Different operators have different formats.' ) + parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float32') + parser.add_argument('--backend', type=str, default='hidet', help='Only hidet supported in this script') args = parser.parse_args() + assert args.backend == 'hidet' operator, dtype = args.operator, args.dtype params = args.params @@ -134,11 +133,11 @@ def bench_reduce(params: str, *args, **kwargs) -> float: else: raise ValueError(f'Benchmark function for operator {operator} not implemented') - setup_hidet_flags(dtype, dynamo=False) - enable_compile_server(True) + Backend(backend='hidet', dtype=dtype).init_hidet() + with hidet.graph.PassContext() as ctx: ctx.set_reduce_precision(dtype) ctx.set_use_attention(True) ctx.set_mma('mma') latency = bench_func(params, dtype) - print(latency) \ No newline at end of file + print(latency) diff --git a/tests/benchmarks/bench_op_torch_api.py b/tests/benchmarks/bench_op_torch_api.py new file mode 100644 index 000000000..ba5efe5fe --- /dev/null +++ b/tests/benchmarks/bench_op_torch_api.py @@ -0,0 +1,166 @@ +import sys +import argparse +import torch +from bench_utils import bench_torch_model, Backend + + +# MATMUL BENCHMARKS # +class torch_matmul(torch.nn.Module): + def __init__(self): + super(torch_matmul, self).__init__() + + def forward(self, a, b): + return torch.matmul(a, b) + + +def create_model_matmul(params: str, dtype: torch.dtype): + a_shape, b_shape = params.split(',') + a_shape = [int(s) for s in a_shape.split('x')] + b_shape = [int(s) for s in b_shape.split('x')] + a = torch.randn(*a_shape, dtype=dtype, device='cuda') + b = torch.randn(*b_shape, dtype=dtype, device='cuda') + model = torch_matmul() + return model, [a, b] + + +# CONV BENCHMARKS # +class torch_conv2d(torch.nn.Module): + def __init__(self, w_shape, dtype: torch.dtype): + super(torch_conv2d, self).__init__() + self.w = torch.randn(*w_shape, dtype=dtype, device='cuda') + + def forward(self, x): + return torch.nn.functional.conv2d(x, self.w) + + +def create_model_conv2d(params: str, dtype: torch.dtype): + x_shape, w_shape = params.split(',') + x_shape = [int(s) for s in x_shape.split('x')] + w_shape = [int(s) for s in w_shape.split('x')] + x = torch.randn(*x_shape, dtype=dtype, device='cuda') + model = torch_conv2d(w_shape, dtype) + return model, [x] + + +# ATTENTION BENCHMARKS # +class torch_attn(torch.nn.Module): + def __init__(self, mask_shape=None): + super(torch_attn, self).__init__() + if mask_shape: + self.mask = torch.randn(*mask_shape, dtype=torch.float16, device='cuda') + else: + self.mask = None + + def forward(self, q, k, 
v): + return torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=self.mask) + + +def create_model_attn(params: str, dtype: torch.dtype) -> float: + bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')] + q_shape = [bs, nhead, seqlen, hdim] + q = torch.randn(*q_shape, dtype=dtype, device='cuda') + k = torch.randn(*q_shape, dtype=dtype, device='cuda') + v = torch.randn(*q_shape, dtype=dtype, device='cuda') + model = torch_attn() + return model, [q, k, v] + + +def create_model_attn_mask_add(params: str, dtype: torch.dtype) -> float: + bs, seqlen, nhead, hdim = [int(s) for s in params.split('x')] + q_shape = [bs, nhead, seqlen, hdim] + mask_shape = [bs, nhead, seqlen, seqlen] + q = torch.randn(*q_shape, dtype=dtype, device='cuda') + k = torch.randn(*q_shape, dtype=dtype, device='cuda') + v = torch.randn(*q_shape, dtype=dtype, device='cuda') + model = torch_attn(mask_shape=mask_shape) + return model, [q, k, v] + + +# REDUCE # +class torch_sum(torch.nn.Module): + def __init__(self, axis): + super(torch_sum, self).__init__() + self.axis = axis + + def forward(self, x): + return torch.sum(x, dim=self.axis) + + +def create_model_reduce(params: str, dtype): + x_shape, axis = params.split(',', maxsplit=1) + start = axis.find('axis=[') + len('axis=[') + end = axis.find(']', start) + axis = [int(s) for s in axis[start:end].split(',')] + x_shape = [int(s) for s in x_shape.split('x')] + x = torch.randn(*x_shape, dtype=dtype, device='cuda') + model = torch_sum(axis=axis) + return model, [x] + + +# RESHAPE # +class torch_reshape(torch.nn.Module): + def __init__(self, shape): + super(torch_reshape, self).__init__() + self.new_shape = shape + + def forward(self, x): + return torch.reshape(x, self.new_shape) + + +def create_model_reshape(params: str, dtype): + input_shape, output_shape = params.split(',', maxsplit=1) + input_shape = [int(s) for s in input_shape.split('x')] + output_shape = [int(s) for s in output_shape.split('x')] + x = torch.randn(*input_shape, dtype=dtype, device='cuda') + model = torch_reshape(output_shape) + return model, [x] + + +# TRANSPOSE 2D # +class torch_transpose(torch.nn.Module): + def __init__(self, input_shape, dim0, dim1): + super(torch_transpose, self).__init__() + self.input_shape = input_shape + self.dim0 = int(dim0) + self.dim1 = int(dim1) + + def forward(self, x): + return torch.transpose(x, self.dim0, self.dim1).flatten() + + +def create_model_transpose(params: str, dtype): + input_shape, dim0, dim1 = params.split(',', maxsplit=2) + input_shape = [int(s) for s in input_shape.split('x')] + x = torch.randn(*input_shape, dtype=dtype, device='cuda') + model = torch_transpose(input_shape, dim0, dim1) + return model, [x] + + +# Main benchmark function for ops. +# Calls bench_torch_model +def bench_op(operator, params, dtype, backend): + dtype = getattr(torch, dtype) + comp_backend = Backend(backend, dtype) + model_creator = getattr(sys.modules[__name__], "create_model_" + operator) + model, model_inputs = model_creator(params, dtype) + with torch.no_grad(), torch.autocast("cuda"): + opt_model = comp_backend.compile(model) + latency = bench_torch_model(opt_model, model_inputs) + + return latency + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(prog='Benchmark Operators') + parser.add_argument('operator', type=str, help='Specify operator. E.g., matmul_f16') + parser.add_argument( + '--params', type=str, help='Specify Input Parameters. Different operators have different formats.' 
+ ) + parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float32') + parser.add_argument('--backend', type=str, default='hidet', help='torch.compile backend: hidet or max-autotune') + args = parser.parse_args() + + operator, dtype, backend = args.operator, args.dtype, args.backend + params = args.params + latency = bench_op(operator, params, dtype, backend) + print(latency) diff --git a/tests/benchmarks/bench_transformer.py b/tests/benchmarks/bench_transformer.py new file mode 100644 index 000000000..8ce9d9929 --- /dev/null +++ b/tests/benchmarks/bench_transformer.py @@ -0,0 +1,48 @@ +import os +import argparse +import torch +from transformers import AutoTokenizer, AutoModelForMaskedLM, AutoModelForCausalLM, logging +from bench_utils import bench_torch_model, Backend + +os.environ["TOKENIZERS_PARALLELISM"] = "false" +logging.set_verbosity_error() + +model_class = {'bert-base-uncased': 'AutoModelForMaskedLM'} + + +def bench_hf_transformers(model_name, seqlen, dtype, backend): + comp_backend = Backend(backend, dtype) + + dtype = getattr(torch, dtype) + tokenizer = AutoTokenizer.from_pretrained(model_name) + AutoModel_cls = eval(model_class[model_name]) + model = AutoModel_cls.from_pretrained(model_name, max_position_embeddings=8192, ignore_mismatched_sizes=True) + model = model.eval().to(dtype).cuda() + inputs = tokenizer("Dummy sentence", padding='max_length', max_length=seqlen, return_tensors='pt') + inputs = {'input_ids': inputs['input_ids']} + torch_inputs = tuple(i.clone().cuda() for i in inputs.values()) + + with torch.no_grad(), torch.autocast("cuda"): + model = comp_backend.compile(model) + latency = bench_torch_model(model, torch_inputs) + del model + return latency + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(prog='Benchmark Transformers') + parser.add_argument('model', type=str, help='Specify model') + parser.add_argument('--params', type=str, default='seqlen=1024', help='Specify Input Parameters. E.g., seqlen=1024') + parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. 
E.g., float32') + parser.add_argument( + '--backend', + type=str, + default='hidet', + help='torch.compile backend: hidet or max-autotune or max-autotune-no-cudagraphs', + ) + args = parser.parse_args() + + model, dtype, backend = args.model, args.dtype, args.backend + seqlen = int(args.params.split('=')[1]) + latency = bench_hf_transformers(model, seqlen, dtype, backend) + print(latency) diff --git a/tests/benchmarks/bench_utils.py b/tests/benchmarks/bench_utils.py new file mode 100644 index 000000000..62db27c6a --- /dev/null +++ b/tests/benchmarks/bench_utils.py @@ -0,0 +1,68 @@ +# Class to initialise backend, run compilation +class Backend: + def __init__(self, backend, dtype) -> None: + assert ( + backend == 'hidet' or backend == 'max-autotune' or backend == 'max-autotune-no-cudagraphs' + ), 'backend is hidet or max-autotune or max-autotune-no-cudagraphs supported only' + self.backend = backend + self.dtype = dtype + if self.backend == 'hidet': + self.init_hidet() + + def init_hidet(self): + import hidet, os + + use_fp16 = self.dtype == 'float16' + hidet.torch.dynamo_config.search_space(2) + hidet.torch.dynamo_config.use_fp16(use_fp16) + hidet.torch.dynamo_config.use_fp16_reduction(use_fp16) + hidet.torch.dynamo_config.use_attention(True) + hidet.torch.dynamo_config.use_tensor_core(True) + hidet.torch.dynamo_config.use_cuda_graph(True) + hidet.option.search_space(2) + hidet.option.cache_dir(hidet.option.get_cache_dir() + '/regression') + + # hidet.option.parallel_tune(max_parallel_jobs=1) + # hidet.option.debug_cache_tuning(True) + # hidet.option.save_lower_ir(True) + # hidet.option.debug_show_verbose_flow_graph(True) + + # Initialise compiler server + if os.environ.get('CI_CS_HOSTNAME'): + hidet.option.compile_server.addr(os.environ.get('CI_CS_HOSTNAME')) + hidet.option.compile_server.port(int(os.environ.get('CI_CS_PORT'))) + hidet.option.compile_server.username(os.environ.get('CI_CS_USERNAME')) + hidet.option.compile_server.password(os.environ.get('CI_CS_PASSWORD')) + hidet.option.compile_server.repo(os.environ.get('REPO_NAME').strip(), os.environ.get('REPO_BRANCH').strip()) + hidet.option.compile_server.enable(flag=True) + + def compile(self, model): + import torch + + if self.backend == 'hidet': + model = torch.compile(model, backend=self.backend) + else: + model = torch.compile(model, mode=self.backend) + return model + + +# Make benchmarking of given torch model +def bench_torch_model(model, torch_inputs, bench_iters=100, warmup_iters=10): + import torch + + for _ in range(warmup_iters): + out = model(*torch_inputs) + torch.cuda.empty_cache() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + for _ in range(bench_iters): + out = model(*torch_inputs) + end.record() + end.synchronize() + torch.cuda.empty_cache() + + latency = start.elapsed_time(end) / bench_iters + return latency diff --git a/tests/benchmarks/bench_vision.py b/tests/benchmarks/bench_vision.py new file mode 100644 index 000000000..f43892060 --- /dev/null +++ b/tests/benchmarks/bench_vision.py @@ -0,0 +1,45 @@ +import argparse +import torch +import torchvision +from bench_utils import bench_torch_model, Backend + + +def bench_torchvision(model_name, shape, dtype, backend): + comp_backend = Backend(backend, dtype) + + dtype = getattr(torch, dtype) + if any(name in model_name for name in ['deeplab', 'fcn', 'lraspp']): + model_cls = getattr(torchvision.models.segmentation, model_name) + model = model_cls(weights=None) + model = 
model.eval().to(dtype).cuda() + elif model_name == 'yolov7': + # TODO: yolov7 don't work right now via pytorch + model = torch.hub.load( + 'WongKinYiu/yolov7', 'custom', '/tmp/yolov7.pt', autoshape=False, force_reload=True, trust_repo=True + ) + else: + model_cls = getattr(torchvision.models, model_name) + model = model_cls(weights=None) + model = model.eval().to(dtype).cuda() + + model_inputs = [torch.randn(shape, device='cuda', dtype=dtype)] + + with torch.no_grad(), torch.autocast("cuda"): + model = comp_backend.compile(model) + latency = bench_torch_model(model, model_inputs) + del model + return latency + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(prog='Benchmark Vision Models') + parser.add_argument('model', type=str, help='Specify model') + parser.add_argument('--params', type=str, default='1x3x224x224', help='Specify Input Size. E.g., 1x3x224x224') + parser.add_argument('--dtype', type=str, default='float16', help='Specify precision. E.g., float32') + parser.add_argument('--backend', type=str, default='hidet', help='torch.compile backend: hidet or max-autotune') + args = parser.parse_args() + + model, dtype, backend = args.model, args.dtype, args.backend + shape = [int(d) for d in args.params.split('x')] + latency = bench_torchvision(model, shape, dtype, backend) + print(latency) diff --git a/tests/benchmarks/run_configs.json b/tests/benchmarks/run_configs.json new file mode 100644 index 000000000..b5e0bf65e --- /dev/null +++ b/tests/benchmarks/run_configs.json @@ -0,0 +1,192 @@ +[ + { + "type": "model", + "id": 1, + "name": "bert-base-uncased", + "runfile": "bench_transformer.py", + "param_id": 1, + "param_name": "seqlen=256", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "model", + "id": 4, + "name": "resnet50", + "runfile": "bench_vision.py", + "param_id": 13, + "param_name": "128x3x224x224", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "model", + "id": 5, + "name": "efficientnet_b0", + "runfile": "bench_vision.py", + "param_id": 13, + "param_name": "128x3x224x224", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "model", + "id": 6, + "name": "densenet121", + "runfile": "bench_vision.py", + "param_id": 13, + "param_name": "128x3x224x224", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "model", + "id": 12, + "name": "mobilenet_v2", + "runfile": "bench_vision.py", + "param_id": 13, + "param_name": "128x3x224x224", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "model", + "id": 14, + "name": "vit_b_16", + "runfile": "bench_vision.py", + "param_id": 13, + "param_name": "128x3x224x224", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 1, + "name": "batch_matmul", + "runfile": "bench_op.py", + "param_id": 5, + "param_name": "1x4096x4096,1x4096x4096", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 1, + "name": "batch_matmul", + "runfile": "bench_op.py", + "param_id": 6, + "param_name": "1x1024x128,1x128x512", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 2, + "name": "matmul_f16", + "runfile": "bench_op.py", + "param_id": 5, + "param_name": "1x4096x4096,1x4096x4096", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 2, + "name": "matmul_f16", + "runfile": "bench_op.py", + "param_id": 6, + "param_name": "1x1024x128,1x128x512", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 3, + "name": "conv2d", + "runfile": 
"bench_op.py", + "param_id": 7, + "param_name": "1x3x224x224,64x3x3x3", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 3, + "name": "conv2d", + "runfile": "bench_op.py", + "param_id": 8, + "param_name": "1x3x1280x768,32x3x3x3", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 4, + "name": "conv2d_gemm_f16", + "runfile": "bench_op.py", + "param_id": 7, + "param_name": "1x3x224x224,64x3x3x3", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 4, + "name": "conv2d_gemm_f16", + "runfile": "bench_op.py", + "param_id": 8, + "param_name": "1x3x1280x768,32x3x3x3", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 5, + "name": "attn", + "runfile": "bench_op.py", + "param_id": 9, + "param_name": "1x4096x16x64", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 5, + "name": "attn", + "runfile": "bench_op.py", + "param_id": 10, + "param_name": "1x1024x16x128", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 6, + "name": "attn_mask_add", + "runfile": "bench_op.py", + "param_id": 9, + "param_name": "1x4096x16x64", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 6, + "name": "attn_mask_add", + "runfile": "bench_op.py", + "param_id": 10, + "param_name": "1x1024x16x128", + "dtype_id": 1, + "dtype_name": "float16" + }, + { + "type": "operator", + "id": 7, + "name": "reduce", + "runfile": "bench_op.py", + "param_id": 11, + "param_name": "1x224x224x3,axis=[1,2]", + "dtype_id": 1, + "dtype_name": "float16" + } +] \ No newline at end of file diff --git a/.github/scripts/run_tests.py b/tests/benchmarks/run_tests.py similarity index 76% rename from .github/scripts/run_tests.py rename to tests/benchmarks/run_tests.py index 5913a1bd5..2b807c9c7 100644 --- a/.github/scripts/run_tests.py +++ b/tests/benchmarks/run_tests.py @@ -7,6 +7,7 @@ external_models = ['llama-7b', 'gpt2'] + def run_command(cmd): cmd = " ".join(cmd) print("Running command: " + cmd) @@ -20,33 +21,29 @@ def run_command(cmd): raise RuntimeError(f'Command {cmd} failed with return code {ret}.') return stdout -def get_bench_cmd(run_type, run_id, run_name, runfile, run_param_name, dtype): + +def get_bench_cmd(run_name, runfile, run_param_name, dtype, backend): if run_name in external_models: runfile = './models/bench/' + runfile else: - runfile = str(pathlib.Path(__file__).parent.resolve()) + '/bench/' + runfile - cmd = ['python', runfile, run_name, '--params', run_param_name, '--dtype', dtype] + runfile = str(pathlib.Path(__file__).parent.resolve()) + '/' + runfile + cmd = ['python', runfile, run_name, '--params', run_param_name, '--dtype', dtype, '--backend', backend] return cmd + if __name__ == '__main__': parser = argparse.ArgumentParser(prog='Run Benchmarks') + parser.add_argument('--print', action='store_true', default=False, help='Print results') parser.add_argument( - '--print', - action='store_true', - default=False, - help='Print results' - ) - parser.add_argument( - '--configs', - type=str, - default='run_configs.json', - help='Specify configurations file to use for benchmarking' + '--configs', type=str, default='run_configs.json', help='Specify configurations file to use for benchmarking' ) + parser.add_argument('--backend', type=str, default='hidet', help='torch.compile backend: hidet or max-autotune') args = parser.parse_args() configs_file = args.configs fh = open(configs_file) run_configs = json.load(fh) 
fh.close() + backend = args.backend hw_config = os.environ.get('HW_CONFIG') for run_config in run_configs: # Append hardware_config column @@ -60,7 +57,7 @@ def get_bench_cmd(run_type, run_id, run_name, runfile, run_param_name, dtype): run_param_name = run_config['param_name'] run_dtype_id = run_config['dtype_id'] run_dtype_name = run_config['dtype_name'] - cmd = get_bench_cmd(run_type, run_id, run_name, runfile, run_param_name, run_dtype_name) + cmd = get_bench_cmd(run_name, runfile, run_param_name, run_dtype_name, backend) outputs = run_command(cmd) if outputs: # The second last line of All benchmark scripts' stdout is the latency. (Last line is empty) @@ -72,4 +69,4 @@ def get_bench_cmd(run_type, run_id, run_name, runfile, run_param_name, dtype): json.dump(run_configs, fh) if args.print: - print(tabulate(run_configs, headers="keys")) \ No newline at end of file + print(tabulate(run_configs, headers="keys"))
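Taken together, the relocated benchmark scripts share one flow: build a Backend, compile the model through torch.compile, and time it with bench_torch_model. A minimal end-to-end sketch under those assumptions (model and input shape are illustrative):

    import torch
    import torchvision
    from bench_utils import Backend, bench_torch_model

    # fp16 resnet50 with random weights, as bench_vision.py constructs it
    model = torchvision.models.resnet50(weights=None).eval().to(torch.float16).cuda()
    x = torch.randn(1, 3, 224, 224, device='cuda', dtype=torch.float16)

    backend = Backend('hidet', 'float16')   # or 'max-autotune'
    with torch.no_grad(), torch.autocast('cuda'):
        compiled = backend.compile(model)
        print(bench_torch_model(compiled, [x]), 'ms')  # mean latency over 100 iters

The equivalent CLI invocation would be: python tests/benchmarks/bench_vision.py resnet50 --params 1x3x224x224 --dtype float16 --backend hidet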