diff --git a/.aitk/configs/checks.json b/.aitk/configs/checks.json index fa1bea44..76179618 100644 --- a/.aitk/configs/checks.json +++ b/.aitk/configs/checks.json @@ -1,16 +1,16 @@ { - "configCheck": 127, + "configCheck": 128, "copyCheck": 182, "extensionCheck": 1, - "gitignoreCheck": 37, + "gitignoreCheck": 38, "inferenceModelCheck": 25, - "ipynbCheck": 37, - "licenseCheck": 36, - "modelProjectCheck": 38, + "ipynbCheck": 38, + "licenseCheck": 37, + "modelProjectCheck": 39, "oliveCheck": 36, - "oliveJsonCheck": 127, - "pathCheck": 992, + "oliveJsonCheck": 128, + "pathCheck": 1005, "requirementsCheck": 37, "templateCheck": 1, - "venvRequirementsCheck": 10 + "venvRequirementsCheck": 11 } diff --git a/.aitk/configs/model_list.json b/.aitk/configs/model_list.json index 3bca5199..5091df7b 100644 --- a/.aitk/configs/model_list.json +++ b/.aitk/configs/model_list.json @@ -464,6 +464,19 @@ "version": 1, "p0": false }, + { + "displayName": "openai/whisper-large-v3-turbo", + "icon": "OpenAI", + "modelLink": "https://huggingface.co/openai/whisper-large-v3-turbo", + "id": "huggingface/openai/whisper-large-v3-turbo", + "runtimes": [ + "QNN" + ], + "architecture": "Transformer", + "status": "Hide", + "relativePath": "openai-whisper-large-v3-turbo/aitk", + "version": 1 + }, { "displayName": "Qwen/Qwen2.5-0.5B", "icon": "qwen", @@ -666,7 +679,8 @@ "timm/mini-imagenet": "https://huggingface.co/datasets/timm/mini-imagenet", "wikipedia": "https://huggingface.co/datasets/wikimedia/wikipedia", "google-research-datasets/conceptual_captions": "https://huggingface.co/datasets/google-research-datasets/conceptual_captions", - "AIMClab-RUC/COCO-CN": "https://huggingface.co/datasets/AIMClab-RUC/COCO-CN" + "AIMClab-RUC/COCO-CN": "https://huggingface.co/datasets/AIMClab-RUC/COCO-CN", + "librispeech_asr": "https://huggingface.co/datasets/openslr/librispeech_asr" }, "LoginRequiredDatasets": [ "imagenet-1k" diff --git a/.aitk/requirements/requirements-WCR-QAI.txt b/.aitk/requirements/requirements-WCR-QAI.txt new file mode 100644 index 00000000..27da1a1c --- /dev/null +++ b/.aitk/requirements/requirements-WCR-QAI.txt @@ -0,0 +1,9 @@ +gdown==5.2.0 +gitpython==3.1.46 +librosa==0.11.0 +qai_hub==0.42.0 +ruamel-yaml==0.19.1 +schema==0.7.8 +sounddevice==0.5.2 +# need to install without deps because it depends on onnxruntime +# uv pip:install qai_hub_models==0.39.1 --no-deps;post diff --git a/.aitk/scripts/sanitize/constants.py b/.aitk/scripts/sanitize/constants.py index a66770b0..031355a2 100644 --- a/.aitk/scripts/sanitize/constants.py +++ b/.aitk/scripts/sanitize/constants.py @@ -32,7 +32,6 @@ class ArchitectureEnum(Enum): class ModelStatusEnum(Enum): Ready = "Ready" - Coming = "Coming" Hide = "Hide" @@ -107,6 +106,7 @@ class OliveDeviceTypes(Enum): # Pass name is case insensitive, so we use lower case for all pass names # Should sort by value class OlivePassNames: + AitkPython = "aitkpython" ModelBuilder = "modelbuilder" NVModelOptQuantization = "nvmodeloptquantization" OnnxFloatToFloat16 = "onnxfloattofloat16" @@ -160,6 +160,7 @@ class OlivePropertyNames: TargetDevice = "target_device" Type = "type" UserConfig = "user_config" + UserScript = "user_script" WeightFormat = "weight_format" diff --git a/.aitk/scripts/sanitize/file_validation.py b/.aitk/scripts/sanitize/file_validation.py index d9e0a2d0..0cbf538f 100644 --- a/.aitk/scripts/sanitize/file_validation.py +++ b/.aitk/scripts/sanitize/file_validation.py @@ -172,8 +172,11 @@ def readCheckIpynb(ipynbFile: str, modelItems: dict[str, ModelParameter]): importStr = 
importOnnxgenairuntime elif modelParameter.runtime.values and modelParameter.isIntel: testPath = outputModelIntelNPURelativePath + elif modelParameter.aitkPython: + testPath = None + importStr = None for item in [testPath, importStr]: - if not re.search(item, ipynbContent): + if item and not re.search(item, ipynbContent): printError(f"{ipynbFile} does not have '{item}' for {name}, please use it as input") if modelParameter.evalRuntime: runtime = GlobalVars.RuntimeToEPName[modelParameter.evalRuntime] diff --git a/.aitk/scripts/sanitize/generator_common.py b/.aitk/scripts/sanitize/generator_common.py index e4a5c4a5..1353af2a 100644 --- a/.aitk/scripts/sanitize/generator_common.py +++ b/.aitk/scripts/sanitize/generator_common.py @@ -34,6 +34,25 @@ def create_model_parameter(aitk, name: str, configFile: Path): return parameter +def add_optimization_wa(optimizationPaths: list[OptimizationPath], k: str, v: dict) -> bool: + if OlivePropertyNames.Precision in v: + optimizationPaths.append( + OptimizationPath( + name="WeightType", + path=f"{OlivePropertyNames.Passes}.{k}.{OlivePropertyNames.Precision}", + ) + ) + # We require both weight and activation type for quantization + optimizationPaths.append( + OptimizationPath( + name="ActivationType", + path=f"{OlivePropertyNames.Passes}.{k}.{OlivePropertyNames.ActivationType}", + ) + ) + return True + return False + + def set_optimization_path(parameter: ModelParameter, configFile: str): parameter.optimizationPaths = [] with open_ex(configFile, "r") as f: @@ -64,18 +83,7 @@ def set_optimization_path(parameter: ModelParameter, configFile: str): OlivePassNames.OnnxStaticQuantization, OlivePassNames.OnnxDynamicQuantization, ]: - parameter.optimizationPaths.append( - OptimizationPath( - name="WeightType", - path=f"{OlivePropertyNames.Passes}.{k}.{OlivePropertyNames.Precision}", - ) - ) - parameter.optimizationPaths.append( - OptimizationPath( - name="ActivationType", - path=f"{OlivePropertyNames.Passes}.{k}.{OlivePropertyNames.ActivationType}", - ) - ) + add_optimization_wa(parameter.optimizationPaths, k, v) return elif vType == OlivePassNames.OnnxFloatToFloat16: parameter.optimizationPaths.append( @@ -85,3 +93,20 @@ def set_optimization_path(parameter: ModelParameter, configFile: str): ) ) return + elif vType == OlivePassNames.AitkPython: + # Check AitkPython specific properties + if k != OlivePassNames.AitkPython: + raise Exception(f"AitkPython pass key must be '{OlivePassNames.AitkPython}' in {configFile}") + if OlivePropertyNames.UserScript in v: + parameter.aitkPython = v[OlivePropertyNames.UserScript] + python_script = Path(configFile).parent / str(parameter.aitkPython) + if not python_script.exists(): + raise Exception(f"UserScript file {python_script} does not exist for AitkPython pass in {configFile}") + else: + raise Exception(f"UserScript is required for AitkPython pass in {configFile}") + wa_added = add_optimization_wa(parameter.optimizationPaths, k, v) + if wa_added: + return + else: + # TODO handle other optimization types if needed + return diff --git a/.aitk/scripts/sanitize/generator_dml.py b/.aitk/scripts/sanitize/generator_dml.py index cf6531eb..7666e51b 100644 --- a/.aitk/scripts/sanitize/generator_dml.py +++ b/.aitk/scripts/sanitize/generator_dml.py @@ -2,7 +2,7 @@ from pathlib import Path from typing import Optional -from .constants import OlivePassNames, OlivePropertyNames, PhaseTypeEnum, ParameterTypeEnum +from .constants import OlivePassNames, OlivePropertyNames, ParameterTypeEnum, PhaseTypeEnum from .generator_common import 
create_model_parameter, set_optimization_path from .model_info import ModelList from .model_parameter import ModelParameter, OptimizationPath, Section @@ -40,15 +40,14 @@ def generate_quantization_config(configFile: Path, parameter: ModelParameter) -> phase=PhaseTypeEnum.Quantization, parameters=parameters, disableToggleGeneration=True, - toggle = Parameter( - autoGenerated=True, - name="Optimize model", - type=ParameterTypeEnum.Bool, - path=optimize_path, - readOnly=True, - actions=[[], []], - ) - + toggle=Parameter( + autoGenerated=True, + name="Optimize model", + type=ParameterTypeEnum.Bool, + path=optimize_path, + readOnly=True, + actions=[[], []], + ), ) return None diff --git a/.aitk/scripts/sanitize/model_parameter.py b/.aitk/scripts/sanitize/model_parameter.py index 6f8548b0..9c4c130e 100644 --- a/.aitk/scripts/sanitize/model_parameter.py +++ b/.aitk/scripts/sanitize/model_parameter.py @@ -113,6 +113,7 @@ def Check( printError(f"{_file} section {sectionId} parameter {i} has error") # TODO move tag check into Parameter + # TODO guess for possible tags if parameter.path and Section.datasetPathPattern(parameter.path): if self.phase == PhaseTypeEnum.Quantization: if not parameter.tags or ParameterTagEnum.QuantizationDataset not in parameter.tags: @@ -120,10 +121,6 @@ def Check( elif self.phase == PhaseTypeEnum.Evaluation: if not parameter.tags or ParameterTagEnum.EvaluationDataset not in parameter.tags: printError(f"{_file} section {sectionId} parameter {i} should have EvaluationDataset tag") - if parameter.values: - missing_keys = [key for key in parameter.values if key not in modelList.HFDatasets] - if missing_keys: - printError(f"datasets are not in HFDatasets: {', '.join(str(key) for key in missing_keys)}") elif parameter.path and parameter.path.endswith("activation_type"): if not parameter.tags or ParameterTagEnum.ActivationType not in parameter.tags: printError(f"{_file} section {sectionId} parameter {i} should have ActivationType tag") @@ -258,6 +255,7 @@ class ModelParameter(BaseModelClass): runtimeInConversion: Optional[Parameter] = None optimizationPaths: Optional[List[OptimizationPath]] = None optimizationDefault: Optional[str] = None + aitkPython: Optional[str] = None sections: List[Section] = [] @staticmethod @@ -379,7 +377,8 @@ def Check(self, templates: Dict[str, Parameter], oliveJson: Any, modelList: Mode conversion = [ k for k, v in oliveJson[OlivePropertyNames.Passes].items() - if v[OlivePropertyNames.Type].lower() == OlivePassNames.OnnxConversion + if v[OlivePropertyNames.Type].lower() + in [OlivePassNames.OnnxConversion, OlivePassNames.AitkPython] ][0] conversionPath = f"{OlivePropertyNames.Passes}.{conversion}" section.toggle = Parameter( @@ -601,7 +600,7 @@ def checkPhase(self, oliveJson: Any): if ( PhaseTypeEnum.Evaluation in allPhases and PhaseTypeEnum.Quantization in allPhases - and len(oliveJson[OlivePropertyNames.DataConfigs]) != 2 + and (OlivePropertyNames.DataConfigs not in oliveJson or len(oliveJson[OlivePropertyNames.DataConfigs]) != 2) ): printWarning(f"{self._file}'s olive json should have two data configs for evaluation") diff --git a/.aitk/scripts/sanitize/parameters.py b/.aitk/scripts/sanitize/parameters.py index e6ea99b5..d1b9f5e3 100644 --- a/.aitk/scripts/sanitize/parameters.py +++ b/.aitk/scripts/sanitize/parameters.py @@ -207,16 +207,23 @@ def Check( if value != self.values[0]: printError(f"Value {value} not the first in values for {self.path}") return False - for i in range(len(self.values) - 1): - value_in_list = self.values[i + 1] - if 
modelList and value_in_list not in modelList.DatasetSplit: - printError(f"Value {value_in_list} not in DatasetSplit for {self.path}") - return False - if modelList and value_in_list not in modelList.DatasetSubset: - # No error for this, just warning - printWarning( - f"Value {value_in_list} not in DatasetSubset for {self.path}. Could be acceptable if it doesn't have subset" - ) + if modelList: + for i in range(len(self.values)): + value_in_list = self.values[i] + if value_in_list not in modelList.HFDatasets: + printError(f"Value {value_in_list} not in HFDatasets for {self.path}") + return False + if i == 0: + # The first one doesn't need to be in DatasetSplit or DatasetSubset + continue + if value_in_list not in modelList.DatasetSplit: + printError(f"Value {value_in_list} not in DatasetSplit for {self.path}") + return False + if value_in_list not in modelList.DatasetSubset: + # No error for this, just warning + printWarning( + f"Value {value_in_list} not in DatasetSubset for {self.path}. Could be acceptable if it doesn't have subset" + ) elif value and value not in self.values: printError(f"Value {value} not in values for {self.path}") return False diff --git a/openai-whisper-large-v3-turbo/aitk/.gitignore b/openai-whisper-large-v3-turbo/aitk/.gitignore new file mode 100644 index 00000000..eccf7b1a --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/.gitignore @@ -0,0 +1,6 @@ +__pycache__ +/cache +/history/*/* +!/history/*/history.config +!/history/*/olive_config.json +/data diff --git a/openai-whisper-large-v3-turbo/aitk/README.md b/openai-whisper-large-v3-turbo/aitk/README.md new file mode 100644 index 00000000..fc62acdf --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/README.md @@ -0,0 +1,31 @@ +## Whisper-large-v3-turbo Optimization with ONNX Runtime QNN EP + +This folder outlines the process for optimizing the Whisper-large-v3-turbo model using ONNX Runtime with the QNN Execution Provider. It includes steps for exporting FP32 models, generating representative data for static quantization, creating QDQ models, evaluating the models, and performing audio transcription using the optimized models. + +### Generate data for static quantization + +To get better results, we generate real calibration data from the original FP32 models instead of using random data for static quantization. Here we use 100 samples of the librispeech dataset to generate the required data, which requires around 164 GB of disk space. + +First, generate the FP32 ONNX models: + +1. Encoder FP32 model + + `olive run --config whisper_large_v3_turbo_encoder_fp32.json` +2. Decoder FP32 model + + `olive run --config whisper_large_v3_turbo_decoder_fp32.json` + +Then download the audio samples and generate the quantization data: + +1. `python .\qnn_run.py --audio-path .\data\librispeech_asr_clean_test --encoder "models\whisper_encoder_fp32\model\model.onnx" --decoder "models\whisper_decoder_fp32\model.onnx" --model_id "openai/whisper-large-v3-turbo" --save_data .\data\quantization_data --num_data 100` + +### Generate QDQ models + +1. `olive run --config whisper_large_v3_turbo_encoder_qdq.json` +2. `olive run --config whisper_large_v3_turbo_decoder_qdq.json` + +(Optional) Use whisper_large_v3_turbo_encoder_qdq_ctx.json and whisper_large_v3_turbo_decoder_qdq_ctx.json to create ONNX models with QNN context binaries embedded in them.
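+
+For reference, below is a minimal, illustrative sketch of how the QDQ models can be loaded with the QNN Execution Provider. It mirrors `add_ep_for_device` in `qnn_app.py` and assumes the QNN EP has already been registered with ONNX Runtime (e.g. via `winml.py`, as `qnn_run.py` does) and that an NPU device is available:
+
+```python
+import onnxruntime as ort
+
+options = ort.SessionOptions()
+for ep_device in ort.get_ep_devices():
+    # Select the QNN EP entry that targets the NPU, as qnn_app.py does
+    if ep_device.ep_name == "QNNExecutionProvider" and ep_device.device.type == ort.OrtHardwareDeviceType.NPU:
+        options.add_provider_for_devices([ep_device], {})
+        break
+
+encoder = ort.InferenceSession(r"models\whisper_encoder_qdq\model.onnx", sess_options=options)
+decoder = ort.InferenceSession(r"models\whisper_decoder_qdq\model.onnx", sess_options=options)
+```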
+ +### To transcribe a single sample: + +`python .\qnn_run.py --audio-path .\data\librispeech_asr_clean_test\1320-122617-0000.npy --encoder "models\whisper_encoder_qdq\model.onnx" --decoder "models\whisper_decoder_qdq\model.onnx" --model_id "openai/whisper-large-v3-turbo" --execution_provider QNNExecutionProvider` diff --git a/openai-whisper-large-v3-turbo/aitk/inference_sample.ipynb b/openai-whisper-large-v3-turbo/aitk/inference_sample.ipynb new file mode 100644 index 00000000..5d18ede0 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/inference_sample.ipynb @@ -0,0 +1,26 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": null, + "id": "eed9c231", + "metadata": { + "vscode": { + "languageId": "plaintext" + } + }, + "outputs": [], + "source": [ + "# TODO\n", + "ExecutionProvider=\"QNNExecutionProvider\"" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/openai-whisper-large-v3-turbo/aitk/info.yml b/openai-whisper-large-v3-turbo/aitk/info.yml new file mode 100644 index 00000000..223f0abd --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/info.yml @@ -0,0 +1,12 @@ +keywords: + aitk +arch: whisper +recipes: + - file: "qnn_workflow.json" + device: npu + ep: QNNExecutionProvider +aitk: + modelInfo: + id: "huggingface/openai/whisper-large-v3-turbo" + version: 1 + status: Hide diff --git a/openai-whisper-large-v3-turbo/aitk/model_project.config b/openai-whisper-large-v3-turbo/aitk/model_project.config new file mode 100644 index 00000000..bb3a40f1 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/model_project.config @@ -0,0 +1,12 @@ +{ + "workflows": [ + { + "file": "qnn_workflow.json", + "templateName": "qnn_workflow" + } + ], + "modelInfo": { + "id": "huggingface/openai/whisper-large-v3-turbo", + "version": 1 + } +} diff --git a/openai-whisper-large-v3-turbo/aitk/qnn_app.py b/openai-whisper-large-v3-turbo/aitk/qnn_app.py new file mode 100644 index 00000000..9bdb9f80 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/qnn_app.py @@ -0,0 +1,214 @@ +# --------------------------------------------------------------------- +# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries. 
+# SPDX-License-Identifier: BSD-3-Clause +# --------------------------------------------------------------------- + +from __future__ import annotations + +import os + +import numpy as np +import onnxruntime as ort +import torch +from qai_hub_models.models._shared.hf_whisper.app import HfWhisperApp, chunk_and_resample_audio +from qai_hub_models.models._shared.hf_whisper.model import ( + CHUNK_LENGTH, + SAMPLE_RATE, +) +from transformers.models.whisper import WhisperProcessor + + +def get_audio_name(audio_file): + return os.path.splitext(os.path.basename(audio_file))[0] + + +def infer_audio(app, model_id, audio_file, save_data: bool, audio_name: str | None = None): + audio_dict = np.load(audio_file, allow_pickle=True).item() + + audio = audio_dict["audio"]["array"] + sample_rate = audio_dict["audio"]["sampling_rate"] + audio_name = get_audio_name(audio_file) if not audio_name else audio_name + + processor = WhisperProcessor.from_pretrained(model_id) + reference = processor.tokenizer._normalize(audio_dict["text"]) + print("Reference: ", reference) + + # Perform transcription + transcription = app.transcribe(audio, sample_rate, audio_name, save_data) + print("done transcription") + prediction = processor.tokenizer._normalize(transcription) + print("Prediction:", prediction) + + +def add_ep_for_device(session_options, ep_name, device_type, ep_options=None): + ep_devices = ort.get_ep_devices() + for ep_device in ep_devices: + if ep_device.ep_name == ep_name and ep_device.device.type == device_type: + print(f"Adding {ep_name} for {device_type}") + session_options.add_provider_for_devices([ep_device], {} if ep_options is None else ep_options) + break + + +class HfWhisperAppWithSave(HfWhisperApp): + def __init__( + self, + encoder, + decoder, + hf_model_id: str, + execution_provider: str = "CPUExecutionProvider", + device_str: str = "cpu", + sample_rate: int = SAMPLE_RATE, + max_audio_seconds: int = CHUNK_LENGTH, + ): + super().__init__(None, None, hf_model_id, sample_rate, max_audio_seconds) + options = ort.SessionOptions() + device_type = ort.OrtHardwareDeviceType.CPU + if device_str.lower() == "gpu": + device_type = ort.OrtHardwareDeviceType.GPU + elif device_str.lower() == "npu": + device_type = ort.OrtHardwareDeviceType.NPU + add_ep_for_device(options, execution_provider, device_type) + + self.encoder = ort.InferenceSession( + encoder, sess_options=options, + ) + + self.decoder = ort.InferenceSession( + decoder, sess_options=options, + ) + + def transcribe_tokens(self, audio, sample_rate, audio_name, save_data=False) -> list[int]: + out_chunked_tokens = [] + for ind, x in enumerate(chunk_and_resample_audio(audio, sample_rate)): + out_chunked_tokens.append(self._transcribe_single_chunk(x, audio_name, ind, save_data)) + + out_tokens: list[int] = [] + for chunk_tokens in out_chunked_tokens: + out_tokens.extend(chunk_tokens) + return out_tokens + + def transcribe(self, audio, sample_rate, audio_name, save_data=False) -> str: + tokens = self.transcribe_tokens(audio, sample_rate, audio_name, save_data) + return self.tokenizer.decode(tokens, skip_special_tokens=True).strip() + + def _transcribe_single_chunk( + self, audio: np.ndarray, audio_name=None, chunk_number=None, save_data=False + ) -> list[int]: + # feature + input_features = self.feature_extractor(audio, sampling_rate=self.sample_rate, return_tensors="np")[ + "input_features" + ] + + # encoder + output_names_encoder = [output.name for output in self.encoder.get_outputs()] + # kv_cache_cross = self.encoder(input_features) + 
input_features_feed = {"input_features": input_features} + + if save_data: + input_features_save_path = os.path.join(save_data, audio_name, f"{chunk_number}_input_features.npy") + os.makedirs(os.path.dirname(input_features_save_path), exist_ok=True) + np.save(input_features_save_path, input_features_feed) + + kv_cache_cross_numpy = self.encoder.run(output_names_encoder, input_features_feed) + kv_cache_cross = [torch.from_numpy(arr) for arr in kv_cache_cross_numpy] + if not isinstance(kv_cache_cross, tuple): + kv_cache_cross = (kv_cache_cross,) + if not isinstance(kv_cache_cross[0], (tuple, list)): + kv_cache_cross = (kv_cache_cross,) + + sot = self.config.decoder_start_token_id + num_decoder_blocks = self.config.decoder_layers + attention_dim = self.config.d_model + num_decoder_heads = self.config.decoder_attention_heads + mask_neg = self.config.mask_neg + eot = self.config.eos_token_id + + # decoder + output_ids = torch.tensor([[sot]]) # Start of transcript + output_logits = [] + output_length = output_ids.shape[1] + + position_ids = torch.tensor([0], dtype=torch.int32) + attention_mask = torch.full( + (1, 1, 1, self.mean_decode_len), + mask_neg, + dtype=torch.float32, + ) + + # init kv_cache_self + k_cache_self = torch.zeros( + ( + num_decoder_heads, + 1, + attention_dim // num_decoder_heads, + self.mean_decode_len - 1, + ), + dtype=torch.float32, + ) + v_cache_self = torch.zeros( + ( + num_decoder_heads, + 1, + self.mean_decode_len - 1, + attention_dim // num_decoder_heads, + ), + dtype=torch.float32, + ) + kv_cache_self = tuple((k_cache_self, v_cache_self) for _ in range(num_decoder_blocks)) + + for n in range(self.mean_decode_len - 1): + # get current token + input_ids = output_ids[:, n : n + 1].to(torch.int32) + + # update attention_mask + attention_mask[:, :, :, self.mean_decode_len - n - 1] = 0.0 + + # flattened kv caches input + flattened_kv_cache_self = tuple(item for sublist in kv_cache_self for item in sublist) + flattened_kv_cache_cross = tuple(item for sublist in kv_cache_cross for item in sublist) + + # decode and update kv_cache_self + decoder_input = ( + (input_ids, attention_mask) + flattened_kv_cache_self + flattened_kv_cache_cross + (position_ids,) + ) + + # print("decoder_input: ", decoder_input) + input_names_decoder = [input.name for input in self.decoder.get_inputs()] + output_names_decoder = [output.name for output in self.decoder.get_outputs()] + + # decoder_input_feed = dict(zip(input_names_decoder, decoder_input)) + decoder_input_feed = { + name: tensor.numpy() if isinstance(tensor, torch.Tensor) else tensor + for name, tensor in zip(input_names_decoder, decoder_input) + } + + if save_data: + decoder_input_save_path = os.path.join(save_data, audio_name, f"{chunk_number}_{n}_decoder_input.npy") + os.makedirs(os.path.dirname(decoder_input_save_path), exist_ok=True) + np.save(decoder_input_save_path, decoder_input_feed) + + decoder_output_numpy = self.decoder.run(output_names_decoder, decoder_input_feed) + decoder_output = [torch.from_numpy(arr) for arr in decoder_output_numpy] + # decoder_output = self.decoder(*decoder_input) + if isinstance(decoder_output, tuple) and len(decoder_output) == 2: + logits, kv_cache_self = decoder_output + else: + logits = decoder_output[0] + kv_cache_self = tuple(decoder_output[i : i + 2] for i in range(1, len(decoder_output), 2)) + + # update output_logits + output_logits.append(logits.detach().clone()) + + # update output_ids + output_id = torch.argmax(logits, 1).squeeze(0) + # end of transcript + if len(output_logits) == 
(self.mean_decode_len - 1) or output_id == eot: + output_ids = torch.cat((output_ids, output_id), -1) + break + if n >= output_length - 1: + output_ids = torch.cat((output_ids, output_id), -1) + + # update position_ids + position_ids += 1 + + return output_ids[0].tolist() diff --git a/openai-whisper-large-v3-turbo/aitk/qnn_run.py b/openai-whisper-large-v3-turbo/aitk/qnn_run.py new file mode 100644 index 00000000..ccb59c48 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/qnn_run.py @@ -0,0 +1,125 @@ +# --------------------------------------------------------------------- +# Copyright (c) 2025 Qualcomm Technologies, Inc. and/or its subsidiaries. +# SPDX-License-Identifier: BSD-3-Clause +# --------------------------------------------------------------------- + +import argparse +import os +from qnn_app import HfWhisperAppWithSave, infer_audio, get_audio_name +import logging + +logger = logging.getLogger(os.path.basename(__file__)) +logging.basicConfig(level=logging.INFO) + +def register_execution_providers(): + import json + import subprocess + import sys + import onnxruntime as ort + + worker_script = os.path.abspath('winml.py') + result = subprocess.check_output([sys.executable, worker_script], text=True) + paths = json.loads(result) + for item in paths.items(): + try: + ort.register_execution_provider_library(item[0], item[1]) + except Exception as e: + logger.warning(f"Failed to register execution provider {item[0]}: {e}") + +def main(): + parser = argparse.ArgumentParser(description="Demo") + parser.add_argument( + "--audio-path", + type=str, + help="Path to folder containing audio files or a single audio file path", + ) + parser.add_argument( + "--encoder", + type=str, + help="Path to encoder onnx file", + ) + parser.add_argument( + "--decoder", + type=str, + help="Path to decoder onnx file", + ) + parser.add_argument( + "--model_id", + type=str, + default="openai/whisper-large-v3-turbo", + help="HuggingFace Whisper model id", + ) + parser.add_argument( + "--execution_provider", + type=str, + default="CPUExecutionProvider", + help="ORT Execution provider", + ) + parser.add_argument( + "--device_type", + type=str, + default="cpu", + ) + parser.add_argument( + "--save_data", + type=str, + default=None, + help="(Optional) Path to save quantization data", + ) + parser.add_argument( + "--dataset_name", + type=str, + default="librispeech_asr", + help="(Optional) dataset to download", + ) + parser.add_argument( + "--dataset_split", + type=str, + default="test", + help="(Optional) dataset split to download", + ) + parser.add_argument( + "--num_data", + type=int, + default=100, + help="Number of data samples to use for quantization. 
Only applicable if --save_data is enabled", + ) + + args = parser.parse_args() + + encoder_path = args.encoder + decoder_path = args.decoder + + register_execution_providers() + app = HfWhisperAppWithSave(encoder_path, decoder_path, args.model_id, args.execution_provider, args.device_type) + + if not os.path.exists(args.audio_path) or os.path.isdir(args.audio_path): + from datasets import load_dataset + import numpy as np + + os.makedirs(args.audio_path, exist_ok=True) + streamed_dataset = load_dataset(args.dataset_name, "clean", split=args.dataset_split, streaming=True) + i = 0 + for batch in streamed_dataset: + i += 1 + file_path = os.path.join(args.audio_path, f"{batch['id']}.npy") + if not os.path.exists(file_path): + np.save(file_path, batch) + + audio_name = get_audio_name(file_path) + if args.save_data and os.path.exists(os.path.join(args.save_data, audio_name)): + #print(f"Skipping {file_path} as data already exists.") + pass + else: + logger.info(f"Processing data {i} in {file_path} ...") + infer_audio(app, args.model_id, file_path, args.save_data, audio_name) + + if i >= args.num_data: + break + + else: + infer_audio(app, args.model_id, args.audio_path, args.save_data) + + +if __name__ == "__main__": + main() diff --git a/openai-whisper-large-v3-turbo/aitk/qnn_workflow.json b/openai-whisper-large-v3-turbo/aitk/qnn_workflow.json new file mode 100644 index 00000000..cee9e307 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/qnn_workflow.json @@ -0,0 +1,39 @@ +{ + "input_model": { + "model_path": "openai/whisper-large-v3-turbo" + }, + "systems": { + "target_system": { + "type": "LocalSystem", + "accelerators": [ + { + "device": "npu", + "execution_providers": [ + "QNNExecutionProvider" + ] + } + ] + } + }, + "evaluators": { + "common_evaluator": { + } + }, + "passes": { + "aitkpython": { + "type": "AitkPython", + "activation_type": "uint16", + "precision": "uint8", + "dataset_name": "librispeech_asr", + "split": "test", + "length": 100, + "user_script": "qnn_workflow.py" + } + }, + "evaluator": "common_evaluator", + "evaluate_input_model": false, + "target": "target_system", + "clean_cache": false, + "output_dir": "model/whisper_qnn", + "cache_dir": "cache" +} diff --git a/openai-whisper-large-v3-turbo/aitk/qnn_workflow.json.config b/openai-whisper-large-v3-turbo/aitk/qnn_workflow.json.config new file mode 100644 index 00000000..c00f74a3 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/qnn_workflow.json.config @@ -0,0 +1,201 @@ +{ + "name": "Convert to Qualcomm NPU", + "executeRuntimeFeatures": [ + "QAI" + ], + "evaluationRuntimeFeatures": [ + "QAI" + ], + "addCpu": true, + "runtime": { + "autoGenerated": true, + "name": "Evaluate on", + "type": "enum", + "displayNames": [ + "Qualcomm NPU", + "CPU" + ], + "path": "systems.target_system.accelerators.0.execution_providers.0", + "values": [ + "QNNExecutionProvider", + "CPUExecutionProvider" + ], + "readOnly": false + }, + "optimizationPaths": [ + { + "path": "passes.aitkpython.precision", + "name": "WeightType" + }, + { + "path": "passes.aitkpython.activation_type", + "name": "ActivationType" + } + ], + "optimizationDefault": "w8a16", + "aitkPython": "qnn_workflow.py", + "sections": [ + { + "autoGenerated": true, + "name": "Convert", + "phase": "Conversion", + "parameters": [], + "toggle": { + "autoGenerated": true, + "name": "Convert to ONNX format", + "type": "bool", + "path": "passes.aitkpython", + "actions": [ + [], + [] + ], + "readOnly": true + } + }, + { + "name": "Quantize", + "phase": "Quantization", + "parameters": 
[ + { + "name": "Activation Type", + "tags": [ + "ActivationType" + ], + "description": "Quantization data type of activation. ‘Int8’ for signed 8-bit integer, ‘UInt8’ for unsigned 8-bit integer etc.", + "descriptionLink": "https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html", + "type": "enum", + "displayNames": [ + "Int8", + "UInt8", + "Int16", + "UInt16" + ], + "displayType": "RadioGroup", + "path": "passes.aitkpython.activation_type", + "values": [ + "int8", + "uint8", + "int16", + "uint16" + ], + "template": { + "path": "passes.aitkpython.activation_type", + "template": "ActivationType" + } + }, + { + "name": "Weight Type", + "tags": [ + "WeightType" + ], + "description": "Data type for quantizing weights. ‘Int8’ for signed 8-bit integer, ‘UInt8’ for unsigned 8-bit integer etc.", + "descriptionLink": "https://onnxruntime.ai/docs/performance/model-optimizations/quantization.html", + "type": "enum", + "displayNames": [ + "Int8", + "UInt8", + "Int16", + "UInt16" + ], + "displayType": "RadioGroup", + "path": "passes.aitkpython.precision", + "values": [ + "int8", + "uint8", + "int16", + "uint16" + ], + "template": { + "path": "passes.aitkpython.precision", + "template": "WeightType" + } + }, + { + "name": "Quantization Dataset", + "tags": [ + "QuantizationDataset" + ], + "type": "enum", + "path": "passes.aitkpython.dataset_name", + "values": [ + "librispeech_asr" + ], + "template": { + "path": "passes.aitkpython.dataset_name", + "values": [ + "librispeech_asr" + ], + "template": "QuantizationDataset" + } + }, + { + "name": "Quantization Dataset Split", + "tags": [ + "QuantizationDatasetSplit", + "DependsOnDataset" + ], + "type": "enum", + "path": "passes.aitkpython.split", + "values": [ + "test", + "validation", + "train.100", + "train.360" + ], + "template": { + "path": "passes.aitkpython.split", + "values": [ + "test", + "validation", + "train.100", + "train.360" + ], + "template": "QuantizationDatasetSplit" + } + }, + { + "name": "Quantization Dataset Size", + "description": "ATTENTION! Using 100 samples of librispeech dataset to generate the required real data requires around 164 GB of disk space.", + "type": "int", + "path": "passes.aitkpython.length", + "template": { + "description": "ATTENTION! 
Using 100 samples of librispeech dataset to generate the required real data requires around 164 GB of disk space.", + "path": "passes.aitkpython.length", + "template": "QuantizationDatasetSize" + } + } + ], + "disableToggleGeneration": true, + "toggle": { + "name": "Quantize model", + "type": "bool", + "path": "passes.aitkpython", + "actions": [ + [], + [] + ], + "readOnly": true + } + }, + { + "name": "Evaluate", + "phase": "Evaluation", + "parameters": [], + "toggle": { + "autoGenerated": true, + "name": "Evaluate model performance", + "type": "bool", + "path": "evaluator", + "actions": [ + [], + [ + { + "type": "delete", + "path": "evaluator" + } + ] + ] + } + } + ] +} diff --git a/openai-whisper-large-v3-turbo/aitk/qnn_workflow.py b/openai-whisper-large-v3-turbo/aitk/qnn_workflow.py new file mode 100644 index 00000000..d9b22f86 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/qnn_workflow.py @@ -0,0 +1,117 @@ +import argparse +import json +import os +import olive.workflows +import subprocess +import sys +import logging + +logger = logging.getLogger(os.path.basename(__file__)) +logging.basicConfig(level=logging.INFO) + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", required=True, help="path to input config file") + parser.add_argument("--model_config", help="path to input model config file") + parser.add_argument("--runtime", required=True, help="runtime") + return parser.parse_args() + +def load_update_config( + config_path: str, + cache_dir: str, + output_dir: str, + activation_type: str | None = None, + precision: str | None = None, + data_path: str | None = None, + num_data: int | None = None) -> dict: + with open(config_path, 'r', encoding='utf-8') as file: + oliveJson = json.load(file) + + oliveJson["cache_dir"] = cache_dir + oliveJson["output_dir"] = output_dir + + if activation_type is not None: + oliveJson["passes"]["quantization"]["activation_type"] = activation_type + if precision is not None: + oliveJson["passes"]["quantization"]["precision"] = precision + if data_path is not None: + oliveJson["data_configs"][0]["dataloader_config"]["data_path"] = data_path + if num_data is not None: + oliveJson["data_configs"][0]["dataloader_config"]["num_data"] = num_data + + return oliveJson + + +def generate_model( + config_path: str, + cache_dir: str, + output_dir: str, + skip_existing: bool = True, + activation_type: str | None = None, + precision: str | None = None, + data_path: str | None = None, + num_data: int | None = None): + if skip_existing and os.path.exists(os.path.join(output_dir, "model.onnx")): + logger.info(f"Output model {output_dir} already exists, skipping {config_path}.") + return + logger.info(f"Generating model from {config_path}...") + oliveJson = load_update_config(config_path, cache_dir, output_dir, activation_type, precision, data_path, num_data) + output = olive.workflows.run(oliveJson) + if output is None or not output.has_output_model(): + error = f"Model file is not generated" + raise Exception(error) + + +def main(): + args = parse_arguments() + # When we have model_config, we are in evaluation + if args.model_config: + # TODO add evaluation + metrics = { + "latency-avg": 5.26205 + } + output_file = os.path.join(os.path.dirname(args.config), "metrics.json") + resultStr = json.dumps(metrics, indent=4) + with open(output_file, 'w') as file: + file.write(resultStr) + logger.info("Model lab succeeded for evaluation.\n%s", resultStr) + return + + # Get arguments + with open(args.config, 'r', encoding='utf-8') 
as file: + oliveJson = json.load(file) + output_dir: str = oliveJson["output_dir"] + cache_dir: str = oliveJson["cache_dir"] + config_pass = oliveJson["passes"]["aitkpython"] + activation_type: str = config_pass["activation_type"] + precision: str = config_pass["precision"] + dataset_name: str = config_pass["dataset_name"] + dataset_split: str = config_pass["split"] + num_data: int = config_pass["length"] + audio_path: str = os.path.join("data", dataset_name.replace("/", "_"), dataset_split) + save_data_path: str = os.path.join("data", "_data_" + dataset_name.replace("/", "_"), dataset_split) + # Generate original model + original_encoder = os.path.join("data", "_encoder_fp32") + generate_model("whisper_large_v3_turbo_encoder_fp32.json", cache_dir, original_encoder) + original_decoder = os.path.join("data", "_decoder_fp32") + generate_model("whisper_large_v3_turbo_decoder_fp32.json", cache_dir, original_decoder) + # Generate dataset + subprocess.run([sys.executable, "qnn_run.py", + "--audio-path", audio_path, + "--encoder", os.path.join(original_encoder, "model.onnx"), + "--decoder", os.path.join(original_decoder, "model.onnx"), + "--save_data", save_data_path, + "--dataset_name", dataset_name, + "--dataset_split", dataset_split, + "--num_data", str(num_data)], + check=True) + # Generate quantized model + generate_model("whisper_large_v3_turbo_encoder_qdq.json", cache_dir, os.path.join(output_dir, "encoder"), + False, activation_type, precision, save_data_path, num_data) + # decoder has more data for 1 sample, to keep variants, multiply num_data by 10 + generate_model("whisper_large_v3_turbo_decoder_qdq.json", cache_dir, os.path.join(output_dir, "decoder"), + False, activation_type, precision, save_data_path, num_data * 10) + + +if __name__ == "__main__": + main() diff --git a/openai-whisper-large-v3-turbo/aitk/whisper_decoder_load.py b/openai-whisper-large-v3-turbo/aitk/whisper_decoder_load.py new file mode 100644 index 00000000..3665ddd2 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/whisper_decoder_load.py @@ -0,0 +1,44 @@ +import glob +import os + +import numpy as np +from qai_hub_models.utils.input_spec import make_torch_inputs + +from olive.data.registry import Registry + + +def model_loader(model_name): + if model_name == "openai/whisper-large-v3-turbo": + from qai_hub_models.models.whisper_large_v3_turbo import Model + + model = Model.from_pretrained() + component = model.components["HfWhisperDecoder"] + return component + else: + raise ValueError(f"Invalid model id provided: {model_name}") + + +def generate_dummy_inputs(model=None): + from qai_hub_models.models.whisper_large_v3_turbo import Model + + model = Model.from_pretrained() + component = model.components["HfWhisperDecoder"] + input_spec = component.get_input_spec() + return tuple(make_torch_inputs(input_spec)) + + +class DecoderBaseDataLoader: + def __init__(self, data_path, num_data): + self.data_files = sorted(glob.glob(os.path.join(data_path, "**", "*_decoder_input.npy"), recursive=True))[:num_data] + print(f"Decoder data loader loaded {len(self.data_files)} samples from {data_path}") + + def __len__(self): + return len(self.data_files) + + def __getitem__(self, idx): + return np.load(self.data_files[idx], allow_pickle=True).item() + + +@Registry.register_dataloader() +def decoder_data_loader(dataset, data_path, num_data): + return DecoderBaseDataLoader(data_path, num_data) diff --git a/openai-whisper-large-v3-turbo/aitk/whisper_encoder_load.py b/openai-whisper-large-v3-turbo/aitk/whisper_encoder_load.py new 
file mode 100644 index 00000000..e6f2b6c7 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/whisper_encoder_load.py @@ -0,0 +1,47 @@ +import glob +import os + +import numpy as np +from qai_hub_models.utils.input_spec import make_torch_inputs + +from olive.data.registry import Registry + + +def model_loader(model_name): + if model_name == "openai/whisper-large-v3-turbo": + from qai_hub_models.models.whisper_large_v3_turbo import Model + + model = Model.from_pretrained() + component = model.components["HfWhisperEncoder"] + return component + else: + raise ValueError(f"Invalid model id provided: {model_name}") + + +def generate_dummy_inputs(model=None): + from qai_hub_models.models.whisper_large_v3_turbo import Model + + model = Model.from_pretrained() + component = model.components["HfWhisperEncoder"] + input_spec = component.get_input_spec() + return tuple(make_torch_inputs(input_spec)) + + +class EncoderBaseDataLoader: + def __init__(self, data_path, num_data): + # Ensure deterministic ordering of input feature files + self.data_files = sorted( + glob.glob(os.path.join(data_path, "**", "*_input_features.npy"), recursive=True) + )[:num_data] + print(f"Encoder data loader loaded {len(self.data_files)} samples from {data_path}") + + def __len__(self): + return len(self.data_files) + + def __getitem__(self, idx): + return np.load(self.data_files[idx], allow_pickle=True).item() + + +@Registry.register_dataloader() +def encoder_data_loader(dataset, data_path, num_data): + return EncoderBaseDataLoader(data_path, num_data) diff --git a/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_decoder_fp32.json b/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_decoder_fp32.json new file mode 100644 index 00000000..5c0e8908 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_decoder_fp32.json @@ -0,0 +1,53 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": "openai/whisper-large-v3-turbo", + "model_loader": "model_loader", + "model_script": "whisper_decoder_load.py", + "io_config": { + "input_names": ["input_ids", + "attention_mask", + "k_cache_self_0_in", + "v_cache_self_0_in", + "k_cache_self_1_in", + "v_cache_self_1_in", + "k_cache_self_2_in", + "v_cache_self_2_in", + "k_cache_self_3_in", + "v_cache_self_3_in", + "k_cache_cross_0", + "v_cache_cross_0", + "k_cache_cross_1", + "v_cache_cross_1", + "k_cache_cross_2", + "v_cache_cross_2", + "k_cache_cross_3", + "v_cache_cross_3", + "position_ids" ], + "output_names": ["logits", + "k_cache_self_0_out", + "v_cache_self_0_out", + "k_cache_self_1_out", + "v_cache_self_1_out", + "k_cache_self_2_out", + "v_cache_self_2_out", + "k_cache_self_3_out", + "v_cache_self_3_out"] + }, + "dummy_inputs_func": "generate_dummy_inputs" + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ] + } + }, + "passes": { + "convert": { "type": "OnnxConversion", "target_opset": 20 } + }, + "log_severity_level": 0, + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_dir": "models/whisper_decoder_fp32" +} diff --git a/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_decoder_qdq.json b/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_decoder_qdq.json new file mode 100644 index 00000000..8a67130a --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_decoder_qdq.json @@ -0,0 +1,71 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": 
"openai/whisper-large-v3-turbo", + "model_loader": "model_loader", + "model_script": "whisper_decoder_load.py", + "io_config": { + "input_names": ["input_ids", + "attention_mask", + "k_cache_self_0_in", + "v_cache_self_0_in", + "k_cache_self_1_in", + "v_cache_self_1_in", + "k_cache_self_2_in", + "v_cache_self_2_in", + "k_cache_self_3_in", + "v_cache_self_3_in", + "k_cache_cross_0", + "v_cache_cross_0", + "k_cache_cross_1", + "v_cache_cross_1", + "k_cache_cross_2", + "v_cache_cross_2", + "k_cache_cross_3", + "v_cache_cross_3", + "position_ids" ], + "output_names": ["logits", + "k_cache_self_0_out", + "v_cache_self_0_out", + "k_cache_self_1_out", + "v_cache_self_1_out", + "k_cache_self_2_out", + "v_cache_self_2_out", + "k_cache_self_3_out", + "v_cache_self_3_out"] + }, + "dummy_inputs_func": "generate_dummy_inputs" + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ] + } + }, + "data_configs": [ + { + "name": "quantize_data_config", + "user_script": "whisper_decoder_load.py", + "load_dataset_config": { "type": "local_dataset" }, + "dataloader_config": { "type": "decoder_data_loader", + "data_path": ".\\data\\quantization_data", + "num_data": 100 } + } + ], + "passes": { + "convert": { "type": "OnnxConversion", "target_opset": 20 }, + "quantization": { + "type": "OnnxStaticQuantization", + "data_config": "quantize_data_config", + "activation_type": "uint16", + "precision": "uint8", + "calibrate_method": "MinMax", + "quant_preprocess": true + } + }, + "log_severity_level": 0, + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_dir": "models/whisper_decoder_qdq" +} diff --git a/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_encoder_fp32.json b/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_encoder_fp32.json new file mode 100644 index 00000000..7978e1a3 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_encoder_fp32.json @@ -0,0 +1,34 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": "openai/whisper-large-v3-turbo", + "model_loader": "model_loader", + "model_script": "whisper_encoder_load.py", + "io_config": { + "input_names": [ "input_features" ], + "output_names": [ "k_cache_cross_0", + "v_cache_cross_0", + "k_cache_cross_1", + "v_cache_cross_1", + "k_cache_cross_2", + "v_cache_cross_2", + "k_cache_cross_3", + "v_cache_cross_3" ] + }, + "dummy_inputs_func": "generate_dummy_inputs" + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ] + } + }, + "passes": { + "convert": { "type": "OnnxConversion", "target_opset": 20 } + }, + "log_severity_level": 0, + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_dir": "models/whisper_encoder_fp32" +} diff --git a/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_encoder_qdq.json b/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_encoder_qdq.json new file mode 100644 index 00000000..9512e82d --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/whisper_large_v3_turbo_encoder_qdq.json @@ -0,0 +1,52 @@ +{ + "input_model": { + "type": "PyTorchModel", + "model_path": "openai/whisper-large-v3-turbo", + "model_loader": "model_loader", + "model_script": "whisper_encoder_load.py", + "io_config": { + "input_names": [ "input_features" ], + "output_names": [ "k_cache_cross_0", + "v_cache_cross_0", + 
"k_cache_cross_1", + "v_cache_cross_1", + "k_cache_cross_2", + "v_cache_cross_2", + "k_cache_cross_3", + "v_cache_cross_3" ] + }, + "dummy_inputs_func": "generate_dummy_inputs" + }, + "systems": { + "local_system": { + "type": "LocalSystem", + "accelerators": [ { "device": "cpu", "execution_providers": [ "CPUExecutionProvider" ] } ] + } + }, + "data_configs": [ + { + "name": "quantize_data_config", + "user_script": "whisper_encoder_load.py", + "load_dataset_config": { "type": "local_dataset" }, + "dataloader_config": { "type": "encoder_data_loader", + "data_path": ".\\data\\quantization_data", + "num_data": 100 } + } + ], + "passes": { + "convert": { "type": "OnnxConversion", "target_opset": 20 }, + "quantization": { + "type": "OnnxStaticQuantization", + "data_config": "quantize_data_config", + "activation_type": "uint16", + "precision": "uint8", + "calibrate_method": "MinMax", + "quant_preprocess": true + } + }, + "log_severity_level": 0, + "host": "local_system", + "target": "local_system", + "cache_dir": "cache", + "output_dir": "models/whisper_encoder_qdq" +} diff --git a/openai-whisper-large-v3-turbo/aitk/winml.py b/openai-whisper-large-v3-turbo/aitk/winml.py new file mode 100644 index 00000000..74a12c53 --- /dev/null +++ b/openai-whisper-large-v3-turbo/aitk/winml.py @@ -0,0 +1,21 @@ +import json + +def _get_ep_paths() -> dict[str, str]: + from winui3.microsoft.windows.applicationmodel.dynamicdependency.bootstrap import ( + InitializeOptions, + initialize + ) + import winui3.microsoft.windows.ai.machinelearning as winml + eps = {} + with initialize(options = InitializeOptions.ON_NO_MATCH_SHOW_UI): + catalog = winml.ExecutionProviderCatalog.get_default() + providers = catalog.find_all_providers() + for provider in providers: + provider.ensure_ready_async().get() + eps[provider.name] = provider.library_path + # DO NOT call provider.try_register in python. That will register to the native env. + return eps + +if __name__ == "__main__": + eps = _get_ep_paths() + print(json.dumps(eps))