From fe29f48baa64692e52e32b13a0e2658d70a5d045 Mon Sep 17 00:00:00 2001 From: J-shang <33053116+J-shang@users.noreply.github.com> Date: Fri, 12 May 2023 10:52:51 +0800 Subject: [PATCH 1/4] fix example config list (#5554) --- examples/compression/fusion/pd_fuse.py | 1 - examples/compression/fusion/pqd_fuse.py | 1 - examples/compression/pruning/norm_pruning.py | 1 - examples/compression/pruning/scheduled_pruning.py | 1 - examples/compression/pruning/slim_pruning.py | 1 - examples/compression/pruning/taylor_pruning.py | 1 - 6 files changed, 6 deletions(-) diff --git a/examples/compression/fusion/pd_fuse.py b/examples/compression/fusion/pd_fuse.py index 546f65b38c..fe4ed2186c 100644 --- a/examples/compression/fusion/pd_fuse.py +++ b/examples/compression/fusion/pd_fuse.py @@ -72,7 +72,6 @@ def teacher_predict(batch, teacher_model): config_list = [{ 'op_types': ['Conv2d'], - 'op_names_re': ['features.*'], 'lambda': 0.1, 'apply_method': 'mse', }] diff --git a/examples/compression/fusion/pqd_fuse.py b/examples/compression/fusion/pqd_fuse.py index 07905b703f..763e70f75e 100644 --- a/examples/compression/fusion/pqd_fuse.py +++ b/examples/compression/fusion/pqd_fuse.py @@ -94,7 +94,6 @@ def teacher_predict(batch, teacher_model): d_config_list = [{ 'op_types': ['Conv2d'], - 'op_names_re': ['features.*'], 'lambda': 0.1, 'apply_method': 'mse', }] diff --git a/examples/compression/pruning/norm_pruning.py b/examples/compression/pruning/norm_pruning.py index ec0a1858ef..fa54af6c35 100644 --- a/examples/compression/pruning/norm_pruning.py +++ b/examples/compression/pruning/norm_pruning.py @@ -35,7 +35,6 @@ config_list = [{ 'op_types': ['Conv2d'], - 'op_names_re': ['features.*'], 'sparse_ratio': 0.5 }] dummy_input = torch.rand(8, 3, 224, 224).to(device) diff --git a/examples/compression/pruning/scheduled_pruning.py b/examples/compression/pruning/scheduled_pruning.py index c16f4153fc..30392e40e4 100644 --- a/examples/compression/pruning/scheduled_pruning.py +++ b/examples/compression/pruning/scheduled_pruning.py @@ -32,7 +32,6 @@ config_list = [{ 'op_types': ['Conv2d'], - 'op_names_re': ['features.*'], 'sparse_ratio': 0.5 }] dummy_input = torch.rand(8, 3, 224, 224).to(device) diff --git a/examples/compression/pruning/slim_pruning.py b/examples/compression/pruning/slim_pruning.py index 03397c33aa..5e975dad6c 100644 --- a/examples/compression/pruning/slim_pruning.py +++ b/examples/compression/pruning/slim_pruning.py @@ -30,7 +30,6 @@ config_list = [{ 'op_types': ['Conv2d'], - 'op_names_re': ['features.*'], 'sparse_ratio': 0.5 }] dummy_input = torch.rand(8, 3, 224, 224).to(device) diff --git a/examples/compression/pruning/taylor_pruning.py b/examples/compression/pruning/taylor_pruning.py index 971c06a994..7d0e6d8291 100644 --- a/examples/compression/pruning/taylor_pruning.py +++ b/examples/compression/pruning/taylor_pruning.py @@ -30,7 +30,6 @@ config_list = [{ 'op_types': ['Conv2d'], - 'op_names_re': ['features.*'], 'sparse_ratio': 0.5 }] dummy_input = torch.rand(8, 3, 224, 224).to(device) From 89fed7b624508267964f01aaa961d05babc7d590 Mon Sep 17 00:00:00 2001 From: v-hongyiyao Date: Wed, 28 Jun 2023 11:02:37 +0000 Subject: [PATCH 2/4] dynamic_input_output for torch2onnx & onnx2trt --- examples/tutorials/quantization_bert_glue.py | 113 +++++++++++++++- .../quantization_speedup/frontend_to_onnx.py | 77 +++++++---- .../integrated_tensorrt.py | 125 +++++++++++------- .../quantization_speedup/trt_pycuda.py | 14 +- 4 files changed, 247 insertions(+), 82 deletions(-) diff --git a/examples/tutorials/quantization_bert_glue.py b/examples/tutorials/quantization_bert_glue.py index 6a58e0670d..cc5295d3db 100644 --- a/examples/tutorials/quantization_bert_glue.py +++ b/examples/tutorials/quantization_bert_glue.py @@ -54,18 +54,18 @@ from transformers.training_args import TrainingArguments -task_name = 'qnli' +task_name = 'rte' #'qnli''rte' finetune_lr = 4e-5 quant_lr = 1e-5 -quant_method = 'lsq' -dev_mode = True +quant_method = 'lsq'# 'lsq' 'ptq' +dev_mode = False if dev_mode: quant_max_epochs = 1 finetune_max_epochs = 1 else: - quant_max_epochs = 10 - finetune_max_epochs = 10 + quant_max_epochs = 10 #10 + finetune_max_epochs = 10 #10 # %% @@ -212,13 +212,42 @@ def build_finetuning_model(state_dict_path: str, is_quant=False): from nni.contrib.compression.quantization import QATQuantizer, LsqQuantizer, PtqQuantizer from nni.contrib.compression.utils import TransformersEvaluator +# dummy_input is used for torch2onnx and onnx2trt + +# transfer dummy_input type into dict +def transfer_dummy_input(dummy_input,input_names): + dict_dummy_input = {} + if isinstance(dummy_input,dict): + for input_name,input_tensor in dummy_input.items(): + if torch.is_tensor(input_tensor): + continue + else: + dummy_input[input_name] = torch.tensor(input_tensor) + dict_dummy_input = dummy_input + elif isinstance(dummy_input,tuple): + for i in range(len(dummy_input)): + if torch.is_tensor(dummy_input[i]): + continue + else: + temp_dummy_input = torch.tensor(dummy_input[i]) + dict_dummy_input[input_names[i]] = temp_dummy_input + elif torch.is_tensor(dummy_input): + dict_dummy_input[input_names[0]] = dummy_input + else : + print('the dummy_input type is not allowed !') + return dict_dummy_input + +dummy_input = ([[101, 11271, 20726, 1010, 1996, 7794, 1997, 1996, 3364, 5696, 20726, 1010, 2038, 2351, 1997, 11192, 4456, 2012, 2287, 4008, 1010, 2429, 2000, 1996, 5696, 20726, 3192, 1012, 102, 5696, 20726, 2018, 2019, 4926, 1012, 102]],[[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]],[[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]) +input_names=['input_ids','token_type_ids','attention_mask'] +dummy_input = transfer_dummy_input(dummy_input,input_names) + def fake_quantize(): config_list = [{ 'op_types': ['Linear'], 'op_names_re': ['bert.encoder.layer.{}'.format(i) for i in range(12)], - 'target_names': ['weight', '_output_'], + 'target_names': ['weight', '_input_','_output_'], 'quant_dtype': 'int8', - 'quant_scheme': 'affine', + 'quant_scheme': 'symmetric',#'affine''symmetric' 'granularity': 'default', }] @@ -243,6 +272,75 @@ def fake_quantize(): quantizer.evaluator.bind_model(model, quantizer._get_param_names_map()) print(quantizer.evaluator.evaluate()) + model.eval() + model.to('cpu') + print('quantized torch-model output: ', model(**dummy_input)) + model.to('cuda') + quantizer.unwrap_model() + evaluate() + + # Speed up the model with TensorRT + from nni.compression.pytorch.quantization_speedup import ModelSpeedupTensorRT + engine = ModelSpeedupTensorRT(model, dummy_input=dummy_input, config=calibration_config, onnx_path='bert_rte.onnx',input_names=['input_ids','token_type_ids','attention_mask'],output_names=['output'], + dynamic_axes={'input_ids' : {1 : 'seq_len'}, + 'token_type_ids' : {1 : 'seq_len'}, + 'attention_mask' : {1 : 'seq_len'}}, + dynamic_shape_setting ={'min_shape' : (1,18), + 'opt_shape' : (1,72), + 'max_shape' : (1,360)}) + engine.compress() + import time + start_time = time.time() + output, time_span = engine.inference(dummy_input) + infer_time = time.time() - start_time + print('test dummy_input inference output: ', output) + print('test dummy_input inference time: ', time_span, infer_time) + test_Accuracy(engine) + +def to_numpy(tensor): + return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy() + +def test_Accuracy(engine): + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + _, validation_datasets = prepare_datasets(task_name, tokenizer, '') + merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()]) # type: ignore + true_cnt = 0 + total_time = 0 + for input_data in merged_validation_dataset: + for input_name,input_tensor in input_data.items(): + if 'labels' != input_name: + input_data[input_name] = torch.tensor([input_tensor]) + test_data = {key: input_data[key] for key in list(input_data.keys())[:-1]} + output, time_span = engine.inference(test_data,reset_context=True) + total_time += time_span + prediction = torch.argmax(output,-1) + if input_data['labels'] == prediction: + true_cnt +=1 + Accuracy = true_cnt/len(merged_validation_dataset) + print('inference time: ', total_time /len(merged_validation_dataset)) + print('Accuracy of mode #1: ', Accuracy) + +def test_onnx_Accuracy(onnx_model): + import onnxruntime + ort_session = onnxruntime.InferenceSession(onnx_model) + tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased') + _, validation_datasets = prepare_datasets(task_name, tokenizer, '') + merged_validation_dataset = ConcatDataset([d for d in validation_datasets.values()]) # type: ignore + true_cnt = 0 + for input_data in merged_validation_dataset: + for input_name,input_tensor in input_data.items(): + if 'labels' != input_name: + input_data[input_name] = to_numpy(torch.tensor([input_tensor])) + test_data = {key: input_data[key] for key in list(input_data.keys())[:-1]} + output = ort_session.run(None, test_data) + prediction = np.argmax(output,-1) + if input_data['labels'] == prediction: + true_cnt +=1 + Accuracy = true_cnt/len(merged_validation_dataset) + print('Accuracy of mode #1: ', Accuracy) + + + def evaluate(): model = build_finetuning_model(f'./output/bert_finetuned/{task_name}.bin', is_quant=False) trainer = prepare_traced_trainer(model, is_quant=False) @@ -251,6 +349,7 @@ def evaluate(): fake_quantize() +test_onnx_Accuracy('bert_rte.onnx') evaluate() diff --git a/nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py b/nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py index 2e043c8a43..7fa2c52097 100644 --- a/nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py +++ b/nni/compression/pytorch/quantization_speedup/frontend_to_onnx.py @@ -54,30 +54,52 @@ def unwrapper(model_onnx, index2name, config): dict The configuration of onnx model layers and calibration parameters """ - # Support Gemm, Conv, Relu, Clip(Relu6) and Maxpool - support_op = ['Gemm', 'Conv', 'Relu', 'Clip', 'MaxP'] + # Support Gemm, Conv, Relu, Clip(Relu6) and Maxpool + MatMul + support_op = ['Gemm', 'Conv', 'Relu', 'Clip', 'MaxP', 'MatMul'] idx = 0 onnx_config = {} - while idx < len(model_onnx.graph.node): - nd = model_onnx.graph.node[idx] - if nd.name[0:4] in support_op and idx > 1: - # Grad constant node and multiply node - const_nd = model_onnx.graph.node[idx-2] - mul_nd = model_onnx.graph.node[idx-1] - # Get index number which is transferred by constant node - index = int(onnx.numpy_helper.to_array(const_nd.attribute[0].t)) - if index != -1: - name = index2name[index] - onnx_config[nd.name] = config[name] - nd.input[0] = mul_nd.input[0] - # Remove constant node and multiply node - model_onnx.graph.node.remove(const_nd) - model_onnx.graph.node.remove(mul_nd) - idx = idx-2 - idx = idx+1 + mul_name_list =[] + const_name_list = [] + const_list = [] + mul_list = [] + #find mul node output name + for node in model_onnx.graph.node: + for op in support_op: + if op in node.name: + for node_input_name in node.input: + if 'Mul_output' in node_input_name: + mul_name_list.append(node_input_name) + #find const node output name by mul node output name + for node in model_onnx.graph.node: + if node.output[0] in mul_name_list: + for node_input_name in node.input: + if 'Constant_output' in node_input_name: + const_name_list.append(node_input_name) + # find mul node and const node + for node in model_onnx.graph.node: + for nd_name in mul_name_list: + if node.output[0] == nd_name: + mul_list.append(node) + for nd_name in const_name_list: + if node.output[0] == nd_name: + const_list.append(node) + for node in model_onnx.graph.node: + for mul_node in mul_list: + if mul_node.output[0] in node.input: + # import pdb;pdb.set_trace() + for const_node in const_list: + if const_node.output[0] in mul_node.input: + # import pdb;pdb.set_trace() + index = int(onnx.numpy_helper.to_array(const_node.attribute[0].t)) + if index != -1: + name = index2name[index] + onnx_config[node.name] = config[name] + node.input[0] = mul_node.input[0] + model_onnx.graph.node.remove(const_node) + model_onnx.graph.node.remove(mul_node) return model_onnx, onnx_config -def torch_to_onnx(model, config, input_shape, model_path, input_names, output_names): +def torch_to_onnx(model, config, dummy_input, model_path, input_names, output_names,dynamic_axes=None): """ Convert torch model to onnx model and get layer bits config of onnx model. @@ -103,6 +125,8 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_na dict The configuration of onnx model layers and calibration parameters """ + device = torch.device('cpu') + model.to(device) # Support Gemm, Conv, Relu, Clip(Relu6) and MaxPool support_op = [torch.nn.Conv2d, torch.nn.Linear, torch.nn.ReLU, torch.nn.ReLU6, torch.nn.MaxPool2d] # Transfer bits number to onnx layer by using wrapper @@ -124,14 +148,15 @@ def torch_to_onnx(model, config, input_shape, model_path, input_names, output_na set_nested_attr(model, name, wrapper_module) # Convert torch model to onnx model and save it in model_path device = torch.device('cpu') - dummy_input = torch.randn(input_shape) - dummy_input = dummy_input.to(device) - model.to(device) - torch.onnx.export(model, dummy_input, model_path, verbose=False, input_names=input_names, output_names=output_names, export_params=True) + if(dynamic_axes == None): + dynamic_axes = {'input' : {2 : 'image_height',3:'image_wdith'}, #for image + 'output' : {2 : 'image_height',3:'image_wdith'}} + # dummy_input = dummy_input.to(device) + # model.to(device) + torch.onnx.export(model, dummy_input, model_path, verbose=False, input_names=input_names, output_names=output_names, export_params=True,opset_version=11,dynamic_axes=dynamic_axes) # Load onnx model model_onnx = onnx.load(model_path) model_onnx, onnx_config = unwrapper(model_onnx, index2name, config) onnx.save(model_onnx, model_path) - onnx.checker.check_model(model_onnx) - return model_onnx, onnx_config \ No newline at end of file + return model_onnx, onnx_config diff --git a/nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py b/nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py index 7b7d54520b..3bee77cf50 100644 --- a/nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py +++ b/nni/compression/pytorch/quantization_speedup/integrated_tensorrt.py @@ -54,22 +54,22 @@ def _handle_gemm(layer, config, out2layer, in2layer): LayerType.Constant (bias) -> Shuffle ->| assume quantize input, output, and weight """ - w_bits = config['weight_bits'] + w_bits = config['weight']['quant_bits'] layer.precision = Precision_Dict[w_bits] # handle the input tensor in_tensor = layer.get_input(0) - in_tensor.dynamic_range = (config['tracked_min_input'], config['tracked_max_input']) + in_tensor.dynamic_range = (config['_input_0']['tracked_min'], config['_input_0']['tracked_max']) # handle the output tensor out_tensor = layer.get_output(0) - out_tensor.dynamic_range = (config['tracked_min_output'], config['tracked_max_output']) + out_tensor.dynamic_range = (config['_output_0']['tracked_min'], config['_output_0']['tracked_max']) # handle weight w_in_tensor = layer.get_input(1) weight_layer = out2layer[w_in_tensor.name] - assert weight_layer.type == trt.LayerType.CONSTANT + # assert weight_layer.type == trt.LayerType.CONSTANT weight_layer.precision = Precision_Dict[w_bits] weight_layer.set_output_type(0, Precision_Dict[w_bits]) w_out_tensor = weight_layer.get_output(0) - w_out_tensor.dynamic_range = (config['min_weight'], config['max_weight']) + w_out_tensor.dynamic_range = (config['weight']['tracked_min'], config['weight']['tracked_max']) print('special gemm: ', w_out_tensor.dynamic_range) # TODO: handle sum & bias # NOTE: a feasible way is setting bias to 0 in quantization algorithm size @@ -108,6 +108,7 @@ def propagate_from_low_bit_predecessor(layer, out2layer, default_precision=trt.f dynamic range of current layer's output tensor """ dynamic_range = None + tensor = layer.get_input(0) if tensor is not None: predecessor = out2layer[tensor.name] @@ -123,7 +124,7 @@ def propagate_from_low_bit_predecessor(layer, out2layer, default_precision=trt.f return trt.int32, None else: logger.warning(f'set op {layer.name} to default precision {default_precision}') - return default_precision, None + return layer.get_output_type(0),None def config_network_precision(network, config): """ @@ -134,42 +135,54 @@ def config_network_precision(network, config): # build two auxiliary indices out2layer = {} in2layer = {} + for layer_idx in range(network.num_layers): - layer = network.get_layer(layer_idx) + layer = network.get_layer(layer_idx) for i in range(layer.num_outputs): output = layer.get_output(i) - out2layer[output.name] = layer + out2layer[output.name] = layer for i in range(layer.num_inputs): - _input = layer.get_input(i) - if _input.name in in2layer: - in2layer[_input.name].append(layer) - else: - in2layer[_input.name] = [layer] - - net_input = network.get_input(0) - assert net_input.name in in2layer + if layer.get_input(i) != None: + _input = layer.get_input(i) + if _input.name in in2layer: + in2layer[_input.name].append(layer) + else: + in2layer[_input.name] = [layer] + + # fliter out the input_layer + input_list = [] + for i in range(network.num_inputs): + net_input = network.get_input(i) + assert net_input.name in in2layer + input_list.append(net_input.name) # traverse the network/graph and specify precision and dynamic range for layer_idx in range(network.num_layers): # assume the traverse order is topological layer = network.get_layer(layer_idx) - if layer.name in config: - if layer.name[0:4] == 'Gemm': - _handle_gemm(layer, config[layer.name], out2layer, in2layer) + for i in range(layer.num_inputs): + if layer.get_input(i) != None: + _input = layer.get_input(i) + if _input.name in input_list: + break + else: + if layer.name in config: + if ('Gemm' in layer.name) or ('MatMul' in layer.name): + _handle_gemm(layer, config[layer.name], out2layer, in2layer) + else: + apply_precision_to_layer(layer, config[layer.name]) else: - apply_precision_to_layer(layer, config[layer.name]) - else: - precision, dynamic_range = propagate_from_low_bit_predecessor(layer, out2layer) - if precision: - layer.precision = precision - layer.set_output_type(0, precision) - if dynamic_range: - out_tensor = layer.get_output(0) - out_tensor.dynamic_range = dynamic_range + precision, dynamic_range = propagate_from_low_bit_predecessor(layer, out2layer) + if precision: + layer.precision = precision + layer.set_output_type(0, precision) + if dynamic_range: + out_tensor = layer.get_output(0) + out_tensor.dynamic_range = dynamic_range print_layer_precisions(network) -def build_engine_without_calib(onnx_model_file, config): +def build_engine_without_calib(onnx_model_file, config, dummy_input,dynamic_shape_setting): """ This function builds an engine from an onnx model following the precisions and dynamic range in config without calibrator. @@ -207,6 +220,15 @@ def build_engine_without_calib(onnx_model_file, config): logger.error(parser.get_error(error)) raise ValueError('Failed to parse the ONNX file.') + profile = builder.create_optimization_profile() + # #input is a dict + for input_name, input_tensor in dummy_input.items(): + profile.set_shape(input_name, min = dynamic_shape_setting['min_shape'], + opt = dynamic_shape_setting['opt_shape'], + max = dynamic_shape_setting['max_shape']) + + trt_config.add_optimization_profile(profile) + config_network_precision(network, config) # Build engine and do int8 calibration. @@ -270,16 +292,18 @@ class ModelSpeedupTensorRT(BaseModelSpeedup): The path user want to store onnx model which is converted from pytorch model. """ - def __init__(self, model, input_shape, config=None, onnx_path="default_model.onnx"): + def __init__(self, model, dummy_input, config=None, onnx_path="default_model.onnx", input_names = ["actual_input_1"],output_names = ["output1"],dynamic_axes=None,dynamic_shape_setting=None): super().__init__(model, config) self.model = model - self.input_shape = input_shape self.config = config self.onnx_path = onnx_path + self.dummy_input = dummy_input # Input name of onnx model providing for torch.onnx.export to generate onnx model # Output name of onnx model providing for torch.onnx.export to generate onnx model - self.input_names = ["actual_input_1"] - self.output_names = ["output1"] + self.input_names = input_names + self.output_names = output_names + self.dynamic_axes = dynamic_axes + self.dynamic_shape_setting = dynamic_shape_setting self.engine = None self.context = None @@ -301,10 +325,10 @@ def compress(self): """ assert self.config is not None # Convert pytorch model to onnx model and save onnx model in onnx_path - _, onnx_config = fonnx.torch_to_onnx(self.model, self.config, input_shape=self.input_shape, - model_path=self.onnx_path, input_names=self.input_names, output_names=self.output_names) + _, onnx_config = fonnx.torch_to_onnx(self.model, self.config, dummy_input=self.dummy_input, + model_path=self.onnx_path, input_names=self.input_names, output_names=self.output_names,dynamic_axes=self.dynamic_axes) valid_config(onnx_config) - self.engine = build_engine_without_calib(self.onnx_path, onnx_config) + self.engine = build_engine_without_calib(self.onnx_path, onnx_config,dummy_input=self.dummy_input,dynamic_shape_setting=self.dynamic_shape_setting) def compress_with_calibrator(self, calib): """ @@ -340,23 +364,30 @@ def inference(self, test_data, reset_context=False): """ if self.context is None or reset_context: self.context = self.engine.create_execution_context() - self.inputs, self.outputs, self.bindings, self.stream = common.allocate_buffers(self.engine) + self.context.active_optimization_profile = 0 # for dynamic shape + self.inputs, self.outputs, self.bindings, self.stream = common.allocate_buffers(self.engine,test_data) self.context.set_optimization_profile_async(0, self.stream.handle) + + # test_data is a dict + for input_name,input_tensor in test_data.items(): + index = self.engine.get_binding_index(input_name) + engine_input_shape = self.engine.get_binding_shape(index) + for div in range(len(input_tensor.shape)): + if (engine_input_shape[div]==-1): + engine_input_shape[div]=input_tensor.shape[div] + self.context.set_binding_shape(index,(engine_input_shape)) + if input_tensor.device != torch.device('cpu'): + logger.warning('test_data should be placed on CPU.') + input_tensor = input_tensor.to(torch.device('cpu')) + + input_tensor = input_tensor.numpy() + np.copyto(self.inputs[index].host, input_tensor.ravel()) - engine_input_shape = self.engine.get_binding_shape(0) - assert engine_input_shape[0] == test_data.size()[0] - if test_data.device != torch.device('cpu'): - logger.warning('test_data should be placed on CPU.') - test_data = test_data.to(torch.device('cpu')) - test_data = test_data.numpy() - assert test_data.dtype == np.float32 - - np.copyto(self.inputs[0].host, test_data.ravel()) start_time = time.time() trt_outputs = common.do_inference_v2(self.context, bindings=self.bindings, inputs=self.inputs, outputs=self.outputs, stream=self.stream) time_span = time.time() - start_time - return torch.as_tensor(trt_outputs[0]), time_span + return torch.as_tensor(trt_outputs), time_span def export_quantized_model(self, path): """ diff --git a/nni/compression/pytorch/quantization_speedup/trt_pycuda.py b/nni/compression/pytorch/quantization_speedup/trt_pycuda.py index f0c93b5351..e601895c4b 100644 --- a/nni/compression/pytorch/quantization_speedup/trt_pycuda.py +++ b/nni/compression/pytorch/quantization_speedup/trt_pycuda.py @@ -48,7 +48,7 @@ def __str__(self): def __repr__(self): return self.__str__() -def allocate_buffers(engine): +def allocate_buffers(engine,test_data): """ Allocates all buffers required for an engine, i.e. host/device inputs/outputs. NOTE: currently this function only supports NetworkDefinitionCreationFlag::kEXPLICIT_BATCH flag. @@ -73,8 +73,17 @@ def allocate_buffers(engine): outputs = [] bindings = [] stream = cuda.Stream() + # find real shape + binding_index = 0 + input_list = list(test_data.keys()) + #test_data is a dict,which is for setting real input shape while model shape is dynamic for binding in engine: - size = trt.volume(engine.get_binding_shape(binding)) # * engine.max_batch_size, batch size already in + if binding_index < len(test_data):#dynamic input shape + size = trt.volume(test_data[input_list[binding_index]].shape) + else: + #static input shape or static output shape + size = trt.volume(engine.get_binding_shape(binding)) # * engine.max_batch_size, batch size already in + dtype = trt.nptype(engine.get_binding_dtype(binding)) # Allocate host and device buffers host_mem = cuda.pagelocked_empty(size, dtype) @@ -86,6 +95,7 @@ def allocate_buffers(engine): inputs.append(HostDeviceMem(host_mem, device_mem)) else: outputs.append(HostDeviceMem(host_mem, device_mem)) + binding_index += 1 return inputs, outputs, bindings, stream # This function is generalized for multiple inputs/outputs for full dimension networks. From 5655bd5ba69f43469befa8605d2ecfc82be67112 Mon Sep 17 00:00:00 2001 From: v-hongyiyao Date: Wed, 28 Jun 2023 14:26:41 +0000 Subject: [PATCH 3/4] delete something unimportant --- examples/tutorials/quantization_bert_glue.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/examples/tutorials/quantization_bert_glue.py b/examples/tutorials/quantization_bert_glue.py index cc5295d3db..fa78379a1a 100644 --- a/examples/tutorials/quantization_bert_glue.py +++ b/examples/tutorials/quantization_bert_glue.py @@ -54,18 +54,18 @@ from transformers.training_args import TrainingArguments -task_name = 'rte' #'qnli''rte' +task_name = 'rte' finetune_lr = 4e-5 quant_lr = 1e-5 -quant_method = 'lsq'# 'lsq' 'ptq' +quant_method = 'ptq' dev_mode = False if dev_mode: quant_max_epochs = 1 finetune_max_epochs = 1 else: - quant_max_epochs = 10 #10 - finetune_max_epochs = 10 #10 + quant_max_epochs = 10 + finetune_max_epochs = 10 # %% From f479ec1d5d8899d0e36187af5e9eccb0b826ad27 Mon Sep 17 00:00:00 2001 From: BillAmihom Date: Fri, 30 Jun 2023 10:13:18 +0000 Subject: [PATCH 4/4] docs for dynamic-shape deployment --- docs/source/compression/dynamic_shape.rst | 84 +++++++++++++++++++++++ 1 file changed, 84 insertions(+) create mode 100644 docs/source/compression/dynamic_shape.rst diff --git a/docs/source/compression/dynamic_shape.rst b/docs/source/compression/dynamic_shape.rst new file mode 100644 index 0000000000..b1b53c979f --- /dev/null +++ b/docs/source/compression/dynamic_shape.rst @@ -0,0 +1,84 @@ +Compression for Models with Dynamic-shape Input +================== +Compression for models with dynamic-shape input is a novel experimental feature incorporated into NNI 3.0. +This feature makes deployment more convenient. For example, when we feed multiple images to the neural network, we don't have to create multiple models decided by their height and width. And when we feed a piece of text into a neural network, the length of the text is no longer limited by a fixed format. +This feature is mainly achieved through two steps. First, create a dynamic ONNX model, and then create a dynamic TensorRT engine through the dynamic ONNX model. +.. Note:: + + NNI strives to ensure maximum compatibility among different compressors in dynamic-shape compression. + Nevertheless, it is impossible to avoid mutual interference in model modification between different compression algorithms in some individual scenarios. + We encourage users to integrate algorithms after acquiring a comprehensive understanding of the fundamental principles of compression methods. + If you encounter any problems or doubts that cannot be resolved while using dynamic-shape compression, you are welcome to raise an issue for discussion. + +Main API +-------- + +To explain how dynamic-shape compression worked, we should know that each module in the model has a corresponding wrapper in the compressor. +The wrapper stores the necessary data required for compression. +``ModelSpeedupTensorRT`` append ``dummy_input`` as a parameter instead of ``input_shape``. +``dummy_input`` is an input that satisfies the torch model you want to deploy. It is used to create a onnx-model. +In addition, you should provide two parameters, ``dynamic_axes`` and ``dynamic_shape_setting``. +``dynamic_axes`` determine which dimensions of the model's input (or output) you set as dynamic. +``dynamic_shape_setting`` is to determine the specific range of the dynamic shape you set. + +Example +------- +Quantize Bert and Deploy Model into ONNX&TensorRT with Dynamic-shape input + +The full example can be found `here `__. + +The following code is a common pipeline with quantization first and then deployment. + +.. code-block:: python + ... + task_name = 'rte' + finetune_lr = 4e-5 + quant_lr = 1e-5 + quant_method = 'ptq' + ... + config_list = [{ + 'op_types': ['Linear'], + 'op_names_re': ['bert.encoder.layer.{}'.format(i) for i in range(12)], + 'target_names': ['weight', '_input_','_output_'], + 'quant_dtype': 'int8', + 'quant_scheme': 'symmetric',#'affine''symmetric' + 'granularity': 'default', + }] + +The same steps as the normal quantization by nni, first set the hyperparameters of the quantizer configuration. +When the fake-quantize finished, save the parameters of the quantization node as ``calibration_config``, +and then remove the quantization node in the model by ``quantizer.unwrap_model()``. +Prepare the ``dummy_input`` required by the input model. +In order to more accurately meet the model input requirements, it is recommended to extract ``dummy_input`` directly from the training-dataset or val-dataset of the task. +Modify the ``dummy_input`` to the ``dict`` data type through the function ``transfer_dummy_input``. + +.. code-block:: python + ... + input_names=['input_ids','token_type_ids','attention_mask'] + dummy_input = transfer_dummy_input(dummy_input,input_names) + + +``dynamic_axes`` is a dict. The dict keys are names of input and output whose shape is dynamic, + the dict values are dicts which specify dimensions are dynamic. +``dynamic_shape_setting`` requires you to provide three parameters, which are the smallest shape of your input, the commonly used shape, and the largest shape. + These three parameters facilitate TensorRT to allocate memory space to the model. + +.. code-block:: python + ... + dynamic_axes={'input_ids' : {1 : 'seq_len'}, + 'token_type_ids' : {1 : 'seq_len'}, + 'attention_mask' : {1 : 'seq_len'}} + dynamic_shape_setting ={'min_shape' : (1,18), + 'opt_shape' : (1,72), + 'max_shape' : (1,360)} + ... +.. code-block:: python + ... + engine = ModelSpeedupTensorRT(model, dummy_input=dummy_input, config=calibration_config, onnx_path='bert_rte.onnx',input_names=['input_ids','token_type_ids','attention_mask'],output_names=['output'], + dynamic_axes = dynamic_axes, + dynamic_shape_setting = dynamic_shape_setting) + engine.compress() + +After ``engine.compress()``,you get a TensorRT engine of original model. +You can test model's output and inference time by ``output, time_span = engine.inference(dummy_input)`` +You can test model's accuracy by ``test_Accuracy(engine)`` \ No newline at end of file