From b0a0817d7a1eab1e471815b97614c8cc6dd5abc0 Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 16 Jul 2020 15:53:13 +0800 Subject: [PATCH 1/2] support tranfser learning for estimator --- delta/utils/misc.py | 4 +- delta/utils/model.py | 7 -- delta/utils/solver/estimator_solver.py | 81 ++++++++++++++++++++++- egs/voxceleb/spk/v1/conf/tdnn_arcface.yml | 6 ++ egs/voxceleb/spk/v1/conf/tdnn_softmax.yml | 6 ++ 5 files changed, 93 insertions(+), 11 deletions(-) diff --git a/delta/utils/misc.py b/delta/utils/misc.py index 2e8a5d18..146d134e 100644 --- a/delta/utils/misc.py +++ b/delta/utils/misc.py @@ -63,9 +63,9 @@ def len_to_padding(length, maxlen=None, dtype=tf.bool): def log_vars(prefix, variables): ''' logging TF varables metadata ''' + logging.info(f"{prefix}:") for var in variables: - logging.info("{}: name: {} shape: {} device: {}".format( - prefix, var.name, var.shape, var.device)) + logging.info(f"\tname = {var.name}, shape = {var.shape}, device = {var.device}") #pylint: disable=bad-continuation diff --git a/delta/utils/model.py b/delta/utils/model.py index 6e90c31a..a0e37ab3 100644 --- a/delta/utils/model.py +++ b/delta/utils/model.py @@ -30,13 +30,6 @@ def print_ops(graph, prefix=''): logging.info('{} : op name: {}'.format(prefix, operator.name)) -def log_vars(prefix, variables): - """Print tensorflow variables.""" - for var in variables: - logging.info("{}: name: {} shape: {} device: {}".format( - prefix, var.name, var.shape, var.device)) - - def model_size(variables): """Get model size.""" total_params = sum( diff --git a/delta/utils/solver/estimator_solver.py b/delta/utils/solver/estimator_solver.py index 84d360e2..3a39f0d8 100644 --- a/delta/utils/solver/estimator_solver.py +++ b/delta/utils/solver/estimator_solver.py @@ -15,7 +15,9 @@ # ============================================================================== ''' Estimator base class for classfication ''' import os +import re import functools +import collections from absl import logging import delta.compat as tf from tensorflow.python import debug as tf_debug #pylint: disable=no-name-in-module @@ -72,6 +74,80 @@ def l2_loss(self, tvars=None): summary_lib.scalar('l2_loss', _l2_loss) return _l2_loss + def get_assignment_map_from_checkpoint(self, tvars, init_checkpoint): + """Compute the union of the current variables and checkpoint variables.""" + assignment_map = {} + initialized_variable_names = {} + + name_to_variable = collections.OrderedDict() + for var in tvars: + name = var.name + m = re.match("^(.*):\\d+$", name) + if m is not None: + name = m.group(1) + name_to_variable[name] = var + + init_vars = tf.train.list_variables(init_checkpoint) + + assignment_map = collections.OrderedDict() + for x in init_vars: + (name, var) = (x[0], x[1]) + if name not in name_to_variable: + continue + assignment_map[name] = name + initialized_variable_names[name] = 1 + initialized_variable_names[name + ":0"] = 1 + + return (assignment_map, initialized_variable_names) + + def init_from_checkpoint(self): + ''' do transfer learning by init sub vars from other checkpoint. ''' + if 'transfer' not in self.config['solver']: + return + transfer_cfg = self.config['solver']['transfer'] + enable = transfer_cfg['enable'] + if not enable: + return + init_checkpoint = transfer_cfg['ckpt_path'] + exclude = transfer_cfg['exclude_reg'] + include = transfer_cfg['include_reg'] + logging.info(f"Transfer from checkpoint: {init_checkpoint}") + logging.info(f"Transfer exclude: {exclude}") + logging.info(f"Transfer include: {include}") + + tvars = tf.trainable_variables() + initialized_variable_names = {} + if init_checkpoint: + def _filter_by_reg(tvars, include, exclude): + include = include if include else [] + exclude = exclude if exclude else [] + outs = [] + for var in tvars: + name = var.name + for reg_str in include: + logging.debug(f"var:{name}, reg: {reg_str}") + m = re.match(reg_str, name) + if m is not None: + outs.append(var) + for reg_str in exclude: + logging.debug(f"var:{name}, reg: {reg_str}") + m = re.match(reg_str, name) + if m is None: + outs.append(var) + return outs + tvars = _filter_by_reg(tvars, include, exclude) + assignment_map, initialized_variable_names = \ + self.get_assignment_map_from_checkpoint(tvars, init_checkpoint) + tf.train.init_from_checkpoint(init_checkpoint, assignment_map) + + logging.info("**** Trainable Variables ****") + for var in tvars: + init_string = "" + if var.name in initialized_variable_names: + init_string = ", *INIT_FROM_CKPT*" + logging.info(" name = %s, shape = %s%s", var.name, var.shape, + init_string) + def model_fn(self): ''' return model_fn ''' model_class = super().model_fn() @@ -144,10 +220,11 @@ def _model_fn(features, labels, mode, params): # L2 loss loss_all += self.l2_loss() + utils.log_vars('****** Global Vars *****', tf.global_variables()) + self.init_from_checkpoint() train_op = self.get_train_op(loss_all) train_hooks = self.get_train_hooks(labels, logits, alpha=alignment) - utils.log_vars('Global Vars', tf.global_variables()) return tf.estimator.EstimatorSpec( #pylint: disable=no-member mode=mode, loss=loss_all, @@ -179,7 +256,7 @@ def create_estimator(self): # multi-gpus devices, num_gpu = utils.gpu_device_names() distribution = utils.get_distribution_strategy(num_gpu) - logging.info('Device: {}/{}'.format(num_gpu, devices)) + logging.info('Device: num = {}, list = {}'.format(num_gpu, devices)) # run config tfconf = self.config['solver']['run_config'] diff --git a/egs/voxceleb/spk/v1/conf/tdnn_arcface.yml b/egs/voxceleb/spk/v1/conf/tdnn_arcface.yml index 02174afd..b574a17f 100644 --- a/egs/voxceleb/spk/v1/conf/tdnn_arcface.yml +++ b/egs/voxceleb/spk/v1/conf/tdnn_arcface.yml @@ -207,6 +207,12 @@ solver: eval_on_dev_every_secs: 1 print_every: 50 resume_model_path: "" + transfer: + enable: false # transfer from checkpoint + ckpt_path: null + exclude_reg: + - '^.*/logits/weights:\d+$' + include_reg: null run_config: debug: false # use tfdbug tf_random_seed: null # 0-2**32; null is None, try to read data from /dev/urandom if available or seed from the clock otherwise diff --git a/egs/voxceleb/spk/v1/conf/tdnn_softmax.yml b/egs/voxceleb/spk/v1/conf/tdnn_softmax.yml index 0687e9df..476acd21 100644 --- a/egs/voxceleb/spk/v1/conf/tdnn_softmax.yml +++ b/egs/voxceleb/spk/v1/conf/tdnn_softmax.yml @@ -207,6 +207,12 @@ solver: eval_on_dev_every_secs: 1 print_every: 50 resume_model_path: "" + transfer: + enable: false # transfer from checkpoint + ckpt_path: null + exclude_reg: + - '^.*/logits/weights:\d+$' + include_reg: null run_config: debug: false # use tfdbug tf_random_seed: null # 0-2**32; null is None, try to read data from /dev/urandom if available or seed from the clock otherwise From 3b5e8d92f9e5add1262cb82fbce3a82bac6337dc Mon Sep 17 00:00:00 2001 From: Hui Zhang Date: Thu, 16 Jul 2020 16:07:16 +0800 Subject: [PATCH 2/2] format --- delta/utils/misc.py | 3 +- delta/utils/solver/estimator_solver.py | 8 +-- deltann/api/c_api.cc | 8 +-- deltann/core/buffer.h | 8 +-- deltann/core/io.h | 4 +- deltann/core/runtime.cc | 5 +- deltann/core/runtime.h | 2 +- deltann/core/tfmodel.cc | 52 +++++++++---------- deltann/examples/speaker/test.cc | 3 +- .../infer/delta_infer/core/scatter_search.cc | 3 +- .../platform/CUDA/transformer_functor_cu.cc | 2 +- .../custom_ops/transformer_cell.cc | 5 +- .../custom_ops/transformer_cell_bert.cc | 5 +- .../custom_ops/transformer_cell_nlp.cc | 9 ++-- .../delta_infer/cpp/delta_cpp_export_py.cc | 10 ++-- deltann/infer/python/delta_infer/optimizer.py | 2 +- 16 files changed, 63 insertions(+), 66 deletions(-) diff --git a/delta/utils/misc.py b/delta/utils/misc.py index 146d134e..6059f820 100644 --- a/delta/utils/misc.py +++ b/delta/utils/misc.py @@ -65,7 +65,8 @@ def log_vars(prefix, variables): ''' logging TF varables metadata ''' logging.info(f"{prefix}:") for var in variables: - logging.info(f"\tname = {var.name}, shape = {var.shape}, device = {var.device}") + logging.info( + f"\tname = {var.name}, shape = {var.shape}, device = {var.device}") #pylint: disable=bad-continuation diff --git a/delta/utils/solver/estimator_solver.py b/delta/utils/solver/estimator_solver.py index 3a39f0d8..1b8403c2 100644 --- a/delta/utils/solver/estimator_solver.py +++ b/delta/utils/solver/estimator_solver.py @@ -103,7 +103,7 @@ def get_assignment_map_from_checkpoint(self, tvars, init_checkpoint): def init_from_checkpoint(self): ''' do transfer learning by init sub vars from other checkpoint. ''' if 'transfer' not in self.config['solver']: - return + return transfer_cfg = self.config['solver']['transfer'] enable = transfer_cfg['enable'] if not enable: @@ -118,6 +118,7 @@ def init_from_checkpoint(self): tvars = tf.trainable_variables() initialized_variable_names = {} if init_checkpoint: + def _filter_by_reg(tvars, include, exclude): include = include if include else [] exclude = exclude if exclude else [] @@ -135,7 +136,8 @@ def _filter_by_reg(tvars, include, exclude): if m is None: outs.append(var) return outs - tvars = _filter_by_reg(tvars, include, exclude) + + tvars = _filter_by_reg(tvars, include, exclude) assignment_map, initialized_variable_names = \ self.get_assignment_map_from_checkpoint(tvars, init_checkpoint) tf.train.init_from_checkpoint(init_checkpoint, assignment_map) @@ -146,7 +148,7 @@ def _filter_by_reg(tvars, include, exclude): if var.name in initialized_variable_names: init_string = ", *INIT_FROM_CKPT*" logging.info(" name = %s, shape = %s%s", var.name, var.shape, - init_string) + init_string) def model_fn(self): ''' return model_fn ''' diff --git a/deltann/api/c_api.cc b/deltann/api/c_api.cc index 03bac7e1..b4d80b4b 100644 --- a/deltann/api/c_api.cc +++ b/deltann/api/c_api.cc @@ -14,9 +14,9 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include #include #include -#include #include "api/c_api.h" #include "core/config.h" @@ -47,10 +47,10 @@ DeltaStatus DeltaSetInputs(InferHandel inf, Input* inputs, int num) { Runtime* rt = static_cast(inf); std::vector ins; for (int i = 0; i < num; ++i) { - //std::cout << "set inputs name : " << inputs[i].input_name << "\n"; - //std::cout << "set inputs nelms: " << inputs[i].nelms << "\n"; + // std::cout << "set inputs name : " << inputs[i].input_name << "\n"; + // std::cout << "set inputs nelms: " << inputs[i].nelms << "\n"; - const int *data = static_cast(inputs[i].ptr); + const int* data = static_cast(inputs[i].ptr); if (inputs[i].shape == NULL) { ins.push_back(In(inputs[i].graph_name, inputs[i].input_name, inputs[i].ptr, inputs[i].nelms)); diff --git a/deltann/core/buffer.h b/deltann/core/buffer.h index 69cd4ba8..f55cf97a 100644 --- a/deltann/core/buffer.h +++ b/deltann/core/buffer.h @@ -85,16 +85,16 @@ class Buffer { void copy_from(const void* src, const std::size_t size) { DELTA_CHECK(_ptr); DELTA_CHECK(src); - DELTA_CHECK(size <= _size) - << "expect size: " << size << " real size:" << _size; + DELTA_CHECK(size <= _size) << "expect size: " << size + << " real size:" << _size; std::memcpy(_ptr, src, size); } void copy_to(void* dst, const std::size_t size) { DELTA_CHECK(_ptr); DELTA_CHECK(dst); - DELTA_CHECK(size <= _size) - << "expect size: " << size << " real size:" << _size; + DELTA_CHECK(size <= _size) << "expect size: " << size + << " real size:" << _size; std::memcpy(dst, _ptr, size); } diff --git a/deltann/core/io.h b/deltann/core/io.h index 0cfe60fa..69dde7e0 100644 --- a/deltann/core/io.h +++ b/deltann/core/io.h @@ -185,7 +185,7 @@ class BaseInOutData { std::size_t bytes = nelms * delta_dtype_size(dtype); this->resize(bytes); _data->copy_from(src, bytes); -} + } void copy_from(const float* src) { copy_from(src, this->nelms()); } @@ -203,7 +203,7 @@ class BaseInOutData { #ifdef USE_TF tensorflow::TensorShape tensor_shape() const { - //tensorflow::Status::Status status; + // tensorflow::Status::Status status; const Shape& shape = this->shape(); tensorflow::TensorShape ts; auto s = shape.vec(); diff --git a/deltann/core/runtime.cc b/deltann/core/runtime.cc index 0440bc3e..1552cac1 100644 --- a/deltann/core/runtime.cc +++ b/deltann/core/runtime.cc @@ -151,9 +151,8 @@ DeltaStatus Runtime::set_inputs(const std::vector& ins) { << in._shape; input.set_shape(in._shape); } - DELTA_CHECK_EQ(in._nelms, input.nelms()) - << in._nelms << ":" - << input.nelms(); + DELTA_CHECK_EQ(in._nelms, input.nelms()) << in._nelms << ":" + << input.nelms(); InputData input_data(input); input_data.copy_from(in._ptr, in._nelms); diff --git a/deltann/core/runtime.h b/deltann/core/runtime.h index ba7c5dda..ed4fe421 100644 --- a/deltann/core/runtime.h +++ b/deltann/core/runtime.h @@ -61,7 +61,7 @@ struct In { std::string _graph_name; std::string _input_name; const void* _ptr; - std::size_t _nelms; // elements + std::size_t _nelms; // elements Shape _shape; }; diff --git a/deltann/core/tfmodel.cc b/deltann/core/tfmodel.cc index 58621b18..3223a36f 100644 --- a/deltann/core/tfmodel.cc +++ b/deltann/core/tfmodel.cc @@ -74,28 +74,28 @@ TFModel::TFModel(ModelMeta model_meta, int num_threads) void TFModel::feed_tensor(Tensor* tensor, const InputData& input) { std::int64_t num_elements = tensor->NumElements(); switch (input.dtype()) { - case DataType::DELTA_FLOAT32:{ - std::cout << "input: " << num_elements << " " << tensor->TotalBytes() << std::endl; + case DataType::DELTA_FLOAT32: { + std::cout << "input: " << num_elements << " " << tensor->TotalBytes() + << std::endl; auto ptr = tensor->flat().data(); std::fill_n(ptr, num_elements, 0.0); - std::copy_n(static_cast(input.ptr()), num_elements, - ptr); + std::copy_n(static_cast(input.ptr()), num_elements, ptr); break; } - case DataType::DELTA_INT32:{ + case DataType::DELTA_INT32: { std::copy_n(static_cast(input.ptr()), num_elements, tensor->flat().data()); break; - } + } case DataType::DELTA_CHAR: { char* cstr = static_cast(input.ptr()); std::string str = std::string(cstr); tensor->scalar()() = str; break; } - default:{ + default: { LOG_FATAL << "Not support dtype:" << delta_dtype_str(input.dtype()); - } + } } } @@ -107,7 +107,7 @@ void TFModel::fetch_tensor(const Tensor& tensor, OutputData* output) { // copy data std::size_t num_elements = tensor.NumElements(); std::size_t total_bytes = tensor.TotalBytes(); - std::cout << "output: " << num_elements << " " << total_bytes << "\n"; + std::cout << "output: " << num_elements << " " << total_bytes << "\n"; DELTA_CHECK(num_elements == output->nelms()) << "expect " << num_elements << "elems, but given " << output->nelms(); @@ -186,33 +186,32 @@ int TFModel::run(const std::vector& inputs, set_feeds(&feeds, inputs); set_fetches(&fetches, *output); - //std::cout << "input xxxxxxxxxxxxxxxxx"<< "\n"; - //auto ti = feeds[0].second; - //for (auto i = 0; i < ti.NumElements(); i++){ + // std::cout << "input xxxxxxxxxxxxxxxxx"<< "\n"; + // auto ti = feeds[0].second; + // for (auto i = 0; i < ti.NumElements(); i++){ // std::cout << std::showpoint << ti.flat()(i) << " "; // if (i % 40 == 1){std::cout << "\n";} //} - //std::cout << "\n"; - //std::cout << "input -------------------"<< "\n"; - + // std::cout << "\n"; + // std::cout << "input -------------------"<< "\n"; // Session run RunOptions run_options; RunMetadata run_meta; - tensorflow::Status s = _bundle.GetSession()->Run(run_options, feeds, fetches, {}, - &output_tensors, &run_meta); + tensorflow::Status s = _bundle.GetSession()->Run( + run_options, feeds, fetches, {}, &output_tensors, &run_meta); if (!s.ok()) { LOG_FATAL << "Error, TF Model run failed: " << s; exit(-1); } - //std::cout << "output xxxxxxxxxxxxxxxxx"<< "\n"; - //auto t = output_tensors[0]; - //for (auto i = 0; i < t.NumElements(); i++){ + // std::cout << "output xxxxxxxxxxxxxxxxx"<< "\n"; + // auto t = output_tensors[0]; + // for (auto i = 0; i < t.NumElements(); i++){ // std::cout << std::showpoint << t.flat()(i) << " "; //} - //std::cout << "\n"; - //std::cout << "output -------------------"<< "\n"; + // std::cout << "\n"; + // std::cout << "output -------------------"<< "\n"; get_featches(output_tensors, output); @@ -287,13 +286,14 @@ DeltaStatus TFModel::load_from_saved_model() { LOG_INFO << "load saved model from path: " << path; if (!MaybeSavedModelDirectory(path)) { LOG_FATAL << "SaveModel not in :" << path; - return DeltaStatus::STATUS_ERROR; + return DeltaStatus::STATUS_ERROR; } - tensorflow::Status s = LoadSavedModel(options, run_options, path, - {tensorflow::kSavedModelTagServe}, &_bundle); + tensorflow::Status s = LoadSavedModel( + options, run_options, path, {tensorflow::kSavedModelTagServe}, &_bundle); if (!s.ok()) { - LOG_FATAL << "Failed Load model from saved_model.pb : " << s.error_message(); + LOG_FATAL << "Failed Load model from saved_model.pb : " + << s.error_message(); } return DeltaStatus::STATUS_OK; diff --git a/deltann/examples/speaker/test.cc b/deltann/examples/speaker/test.cc index 2b494e81..25ed0805 100644 --- a/deltann/examples/speaker/test.cc +++ b/deltann/examples/speaker/test.cc @@ -74,7 +74,8 @@ struct DeltaModel { } DeltaStatus SetInputs(T* buf, const std::vector shape) { - return this->SetInputs(buf, this->NumElems(shape), shape.data(), shape.size()); + return this->SetInputs(buf, this->NumElems(shape), shape.data(), + shape.size()); } DeltaStatus SetInputs(T* buf, int nelms, const int* shape, const int ndims) { diff --git a/deltann/infer/delta_infer/core/scatter_search.cc b/deltann/infer/delta_infer/core/scatter_search.cc index 5c94c9b6..6eeaa2a7 100644 --- a/deltann/infer/delta_infer/core/scatter_search.cc +++ b/deltann/infer/delta_infer/core/scatter_search.cc @@ -11,8 +11,7 @@ struct Compare { Compare() { std::vector> op_maps = { { - "BatchMatMulV2", - "BatchMatMul", + "BatchMatMulV2", "BatchMatMul", }, {"Const", "Shape"}, }; diff --git a/deltann/infer/delta_infer/custom_ops/platform/CUDA/transformer_functor_cu.cc b/deltann/infer/delta_infer/custom_ops/platform/CUDA/transformer_functor_cu.cc index 82bb9ddd..fd37ece1 100644 --- a/deltann/infer/delta_infer/custom_ops/platform/CUDA/transformer_functor_cu.cc +++ b/deltann/infer/delta_infer/custom_ops/platform/CUDA/transformer_functor_cu.cc @@ -211,7 +211,7 @@ void TransformerCellFunctor::operator()( sum); }; - // TODO +// TODO #endif } diff --git a/deltann/infer/delta_infer/custom_ops/transformer_cell.cc b/deltann/infer/delta_infer/custom_ops/transformer_cell.cc index 87fc470f..ead3ab05 100644 --- a/deltann/infer/delta_infer/custom_ops/transformer_cell.cc +++ b/deltann/infer/delta_infer/custom_ops/transformer_cell.cc @@ -230,9 +230,8 @@ class TansformerCellOp : public OpKernel { errors::InvalidArgument("output_layernorm_gamma is null")); Tensor *output = nullptr; OP_REQUIRES_OK(context, context->allocate_output( - 0, - {_param.batch_size * _param.from_seq_len, - _param.head_num * _param.size_per_head}, + 0, {_param.batch_size * _param.from_seq_len, + _param.head_num * _param.size_per_head}, &output)); _param.transformer_out = reinterpret_cast(output->flat().data()); diff --git a/deltann/infer/delta_infer/custom_ops/transformer_cell_bert.cc b/deltann/infer/delta_infer/custom_ops/transformer_cell_bert.cc index 54c57757..5353c53a 100644 --- a/deltann/infer/delta_infer/custom_ops/transformer_cell_bert.cc +++ b/deltann/infer/delta_infer/custom_ops/transformer_cell_bert.cc @@ -230,9 +230,8 @@ class TansformerCellBertOp : public OpKernel { errors::InvalidArgument("output_layernorm_gamma is null")); Tensor *output = nullptr; OP_REQUIRES_OK(context, context->allocate_output( - 0, - {_param.batch_size * _param.from_seq_len, - _param.head_num * _param.size_per_head}, + 0, {_param.batch_size * _param.from_seq_len, + _param.head_num * _param.size_per_head}, &output)); _param.transformer_out = reinterpret_cast(output->flat().data()); diff --git a/deltann/infer/delta_infer/custom_ops/transformer_cell_nlp.cc b/deltann/infer/delta_infer/custom_ops/transformer_cell_nlp.cc index 378de38e..805f7c57 100644 --- a/deltann/infer/delta_infer/custom_ops/transformer_cell_nlp.cc +++ b/deltann/infer/delta_infer/custom_ops/transformer_cell_nlp.cc @@ -434,11 +434,10 @@ class TansformerCellNLPOp : public OpKernel { errors::InvalidArgument("ff_layer_2_1_bias is null")); Tensor *output = nullptr; - OP_REQUIRES_OK( - context, - context->allocate_output( - 0, {w_shape.dim_size(0), w_shape.dim_size(1), w_shape.dim_size(2)}, - &output)); + OP_REQUIRES_OK(context, context->allocate_output( + 0, {w_shape.dim_size(0), w_shape.dim_size(1), + w_shape.dim_size(2)}, + &output)); _param.transformer_out = reinterpret_cast(output->flat().data()); // initial diff --git a/deltann/infer/python/delta_infer/cpp/delta_cpp_export_py.cc b/deltann/infer/python/delta_infer/cpp/delta_cpp_export_py.cc index e565380e..8d006a82 100644 --- a/deltann/infer/python/delta_infer/cpp/delta_cpp_export_py.cc +++ b/deltann/infer/python/delta_infer/cpp/delta_cpp_export_py.cc @@ -24,9 +24,8 @@ PYBIND11_MODULE(export_py, module) { .def("LoadModel", delta_overload_cast()(&grappler::Pattern::LoadModel), "Load tf model from pb files.") - .def("LoadModelCT", - delta_overload_cast()( - &grappler::Pattern::LoadModelCT), + .def("LoadModelCT", delta_overload_cast()( + &grappler::Pattern::LoadModelCT), "Load tf model from GraphDef string buffer.") .def("get_hint_node", &grappler::Pattern::get_hint_node) .def("set_hint_node_type", &grappler::Pattern::set_hint_node_type) @@ -40,9 +39,8 @@ PYBIND11_MODULE(export_py, module) { .def("node", delta_overload_cast()(&grappler::Pattern::node, py::const_), "Get Node by index.") - .def("node", - delta_overload_cast()(&grappler::Pattern::node, - py::const_), + .def("node", delta_overload_cast()( + &grappler::Pattern::node, py::const_), "Get Node by name.") .def_property_readonly("is_input", &grappler::Pattern::is_input); //.def(py::self == py::self) diff --git a/deltann/infer/python/delta_infer/optimizer.py b/deltann/infer/python/delta_infer/optimizer.py index dc6421ba..8d7cec8b 100644 --- a/deltann/infer/python/delta_infer/optimizer.py +++ b/deltann/infer/python/delta_infer/optimizer.py @@ -32,7 +32,7 @@ def __run(self): def __hint_op_type(self, pattern_name): assert (pattern_name in self.__hint_map), \ "Pattern name({}) with hint op must be registered by \ - function register_hint_op." .format(pattern_name) + function register_hint_op." .format(pattern_name) return self.__hint_map[pattern_name] def register_hint_op(self, pattern_name, hint_op_type):