Skip to content

Commit

Permalink
Merge pull request #223 from didi/transfer
Browse files Browse the repository at this point in the history
Transfer learning
  • Loading branch information
zh794390558 authored Jul 16, 2020
2 parents 0a06c4a + 3b5e8d9 commit 58a52cc
Show file tree
Hide file tree
Showing 19 changed files with 152 additions and 73 deletions.
5 changes: 3 additions & 2 deletions delta/utils/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,9 +63,10 @@ def len_to_padding(length, maxlen=None, dtype=tf.bool):

def log_vars(prefix, variables):
''' Log the name, shape and device of each TF variable in `variables`, labelled with `prefix`. '''
logging.info(f"{prefix}:")
for var in variables:
logging.info("{}: name: {} shape: {} device: {}".format(
prefix, var.name, var.shape, var.device))
logging.info(
f"\tname = {var.name}, shape = {var.shape}, device = {var.device}")


#pylint: disable=bad-continuation
Expand Down
7 changes: 0 additions & 7 deletions delta/utils/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,13 +30,6 @@ def print_ops(graph, prefix=''):
logging.info('{} : op name: {}'.format(prefix, operator.name))


def log_vars(prefix, variables):
"""Log the name, shape and device of each TensorFlow variable, tagged with `prefix`."""
for var in variables:
logging.info("{}: name: {} shape: {} device: {}".format(
prefix, var.name, var.shape, var.device))


def model_size(variables):
"""Get model size."""
total_params = sum(
Expand Down
83 changes: 81 additions & 2 deletions delta/utils/solver/estimator_solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# ==============================================================================
''' Estimator base class for classification '''
import os
import re
import functools
import collections
from absl import logging
import delta.compat as tf
from tensorflow.python import debug as tf_debug #pylint: disable=no-name-in-module
Expand Down Expand Up @@ -72,6 +74,82 @@ def l2_loss(self, tvars=None):
summary_lib.scalar('l2_loss', _l2_loss)
return _l2_loss

def get_assignment_map_from_checkpoint(self, tvars, init_checkpoint):
    """Find the checkpoint variables that also exist among `tvars`.

    Args:
      tvars: iterable of variables; only their `name` attribute is read.
      init_checkpoint: checkpoint location handed to `tf.train.list_variables`.

    Returns:
      A tuple `(assignment_map, initialized_variable_names)`.
      `assignment_map` is an OrderedDict mapping each shared name to itself,
      in checkpoint listing order. `initialized_variable_names` is a dict
      holding every shared name both bare and with a ":0" suffix.
    """
    # Index the current variables by name with any ":<digits>" suffix removed,
    # so they can be compared against the names listed in the checkpoint.
    current_by_name = collections.OrderedDict()
    for variable in tvars:
        suffix_match = re.match("^(.*):\\d+$", variable.name)
        if suffix_match is None:
            bare_name = variable.name
        else:
            bare_name = suffix_match.group(1)
        current_by_name[bare_name] = variable

    assignment_map = collections.OrderedDict()
    initialized_variable_names = {}
    for entry in tf.train.list_variables(init_checkpoint):
        # Each entry is indexed as (name, <second field>); only the name is used.
        ckpt_name, _ = entry[0], entry[1]
        if ckpt_name in current_by_name:
            assignment_map[ckpt_name] = ckpt_name
            initialized_variable_names[ckpt_name] = 1
            initialized_variable_names[ckpt_name + ":0"] = 1

    return (assignment_map, initialized_variable_names)

def init_from_checkpoint(self):
    ''' Do transfer learning by initializing a subset of vars from a checkpoint.

    Reads `self.config['solver']['transfer']` and returns immediately when
    that section is missing or its `enable` flag is falsy. Otherwise the
    trainable variables are filtered by the `include_reg` / `exclude_reg`
    pattern lists, matched against the variables stored in `ckpt_path`, and
    the matches are registered with `tf.train.init_from_checkpoint`.

    Filtering rule (patterns are applied with `re.match`, i.e. anchored at
    the start of the variable name):
      * with a non-empty include list, a variable must match at least one
        include pattern; an empty or missing include list accepts all;
      * a variable matching any exclude pattern is always dropped;
      * each variable is selected at most once.

    Returns:
      None. The selected variables are logged, with those found in the
      checkpoint tagged `*INIT_FROM_CKPT*`.
    '''
    if 'transfer' not in self.config['solver']:
        return
    transfer_cfg = self.config['solver']['transfer']
    if not transfer_cfg['enable']:
        return
    init_checkpoint = transfer_cfg['ckpt_path']
    exclude = transfer_cfg['exclude_reg']
    include = transfer_cfg['include_reg']
    logging.info(f"Transfer from checkpoint: {init_checkpoint}")
    logging.info(f"Transfer exclude: {exclude}")
    logging.info(f"Transfer include: {include}")

    tvars = tf.trainable_variables()
    initialized_variable_names = {}
    if init_checkpoint:

        def _filter_by_reg(tvars, include, exclude):
            ''' Keep vars passing the include filter and matching no exclude pattern. '''
            # Compile once instead of re-parsing every pattern for every variable.
            include_res = [re.compile(reg_str) for reg_str in include or []]
            exclude_res = [re.compile(reg_str) for reg_str in exclude or []]
            outs = []
            for var in tvars:
                name = var.name
                if include_res and not any(reg.match(name) for reg in include_res):
                    logging.debug(f"var:{name}, skipped: no include pattern matched")
                    continue
                # Exclusion wins over inclusion, and a single hit is enough.
                if any(reg.match(name) for reg in exclude_res):
                    logging.debug(f"var:{name}, skipped: exclude pattern matched")
                    continue
                outs.append(var)
            return outs

        tvars = _filter_by_reg(tvars, include, exclude)
        assignment_map, initialized_variable_names = \
            self.get_assignment_map_from_checkpoint(tvars, init_checkpoint)
        tf.train.init_from_checkpoint(init_checkpoint, assignment_map)

    logging.info("**** Trainable Variables ****")
    for var in tvars:
        init_string = ""
        if var.name in initialized_variable_names:
            init_string = ", *INIT_FROM_CKPT*"
        logging.info(" name = %s, shape = %s%s", var.name, var.shape,
                     init_string)

def model_fn(self):
''' return model_fn '''
model_class = super().model_fn()
Expand Down Expand Up @@ -144,10 +222,11 @@ def _model_fn(features, labels, mode, params):
# L2 loss
loss_all += self.l2_loss()

utils.log_vars('****** Global Vars *****', tf.global_variables())
self.init_from_checkpoint()
train_op = self.get_train_op(loss_all)
train_hooks = self.get_train_hooks(labels, logits, alpha=alignment)

utils.log_vars('Global Vars', tf.global_variables())
return tf.estimator.EstimatorSpec( #pylint: disable=no-member
mode=mode,
loss=loss_all,
Expand Down Expand Up @@ -179,7 +258,7 @@ def create_estimator(self):
# multi-gpus
devices, num_gpu = utils.gpu_device_names()
distribution = utils.get_distribution_strategy(num_gpu)
logging.info('Device: {}/{}'.format(num_gpu, devices))
logging.info('Device: num = {}, list = {}'.format(num_gpu, devices))

# run config
tfconf = self.config['solver']['run_config']
Expand Down
8 changes: 4 additions & 4 deletions deltann/api/c_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <iostream>
#include <string>
#include <vector>
#include <iostream>

#include "api/c_api.h"
#include "core/config.h"
Expand Down Expand Up @@ -47,10 +47,10 @@ DeltaStatus DeltaSetInputs(InferHandel inf, Input* inputs, int num) {
Runtime* rt = static_cast<Runtime*>(inf);
std::vector<In> ins;
for (int i = 0; i < num; ++i) {
//std::cout << "set inputs name : " << inputs[i].input_name << "\n";
//std::cout << "set inputs nelms: " << inputs[i].nelms << "\n";
// std::cout << "set inputs name : " << inputs[i].input_name << "\n";
// std::cout << "set inputs nelms: " << inputs[i].nelms << "\n";

const int *data = static_cast<const int*>(inputs[i].ptr);
const int* data = static_cast<const int*>(inputs[i].ptr);
if (inputs[i].shape == NULL) {
ins.push_back(In(inputs[i].graph_name, inputs[i].input_name,
inputs[i].ptr, inputs[i].nelms));
Expand Down
8 changes: 4 additions & 4 deletions deltann/core/buffer.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,16 +85,16 @@ class Buffer {
void copy_from(const void* src, const std::size_t size) {
DELTA_CHECK(_ptr);
DELTA_CHECK(src);
DELTA_CHECK(size <= _size)
<< "expect size: " << size << " real size:" << _size;
DELTA_CHECK(size <= _size) << "expect size: " << size
<< " real size:" << _size;
std::memcpy(_ptr, src, size);
}

void copy_to(void* dst, const std::size_t size) {
DELTA_CHECK(_ptr);
DELTA_CHECK(dst);
DELTA_CHECK(size <= _size)
<< "expect size: " << size << " real size:" << _size;
DELTA_CHECK(size <= _size) << "expect size: " << size
<< " real size:" << _size;
std::memcpy(dst, _ptr, size);
}

Expand Down
4 changes: 2 additions & 2 deletions deltann/core/io.h
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class BaseInOutData {
std::size_t bytes = nelms * delta_dtype_size(dtype);
this->resize(bytes);
_data->copy_from(src, bytes);
}
}

void copy_from(const float* src) { copy_from(src, this->nelms()); }

Expand All @@ -203,7 +203,7 @@ class BaseInOutData {

#ifdef USE_TF
tensorflow::TensorShape tensor_shape() const {
//tensorflow::Status::Status status;
// tensorflow::Status::Status status;
const Shape& shape = this->shape();
tensorflow::TensorShape ts;
auto s = shape.vec();
Expand Down
5 changes: 2 additions & 3 deletions deltann/core/runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,8 @@ DeltaStatus Runtime::set_inputs(const std::vector<In>& ins) {
<< in._shape;
input.set_shape(in._shape);
}
DELTA_CHECK_EQ(in._nelms, input.nelms())
<< in._nelms << ":"
<< input.nelms();
DELTA_CHECK_EQ(in._nelms, input.nelms()) << in._nelms << ":"
<< input.nelms();

InputData input_data(input);
input_data.copy_from(in._ptr, in._nelms);
Expand Down
2 changes: 1 addition & 1 deletion deltann/core/runtime.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct In {
std::string _graph_name;
std::string _input_name;
const void* _ptr;
std::size_t _nelms; // elements
std::size_t _nelms; // elements
Shape _shape;
};

Expand Down
52 changes: 26 additions & 26 deletions deltann/core/tfmodel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,28 +74,28 @@ TFModel::TFModel(ModelMeta model_meta, int num_threads)
void TFModel::feed_tensor(Tensor* tensor, const InputData& input) {
std::int64_t num_elements = tensor->NumElements();
switch (input.dtype()) {
case DataType::DELTA_FLOAT32:{
std::cout << "input: " << num_elements << " " << tensor->TotalBytes() << std::endl;
case DataType::DELTA_FLOAT32: {
std::cout << "input: " << num_elements << " " << tensor->TotalBytes()
<< std::endl;
auto ptr = tensor->flat<float>().data();
std::fill_n(ptr, num_elements, 0.0);
std::copy_n(static_cast<float*>(input.ptr()), num_elements,
ptr);
std::copy_n(static_cast<float*>(input.ptr()), num_elements, ptr);
break;
}
case DataType::DELTA_INT32:{
case DataType::DELTA_INT32: {
std::copy_n(static_cast<int*>(input.ptr()), num_elements,
tensor->flat<int>().data());
break;
}
}
case DataType::DELTA_CHAR: {
char* cstr = static_cast<char*>(input.ptr());
std::string str = std::string(cstr);
tensor->scalar<tensorflow::tstring>()() = str;
break;
}
default:{
default: {
LOG_FATAL << "Not support dtype:" << delta_dtype_str(input.dtype());
}
}
}
}

Expand All @@ -107,7 +107,7 @@ void TFModel::fetch_tensor(const Tensor& tensor, OutputData* output) {
// copy data
std::size_t num_elements = tensor.NumElements();
std::size_t total_bytes = tensor.TotalBytes();
std::cout << "output: " << num_elements << " " << total_bytes << "\n";
std::cout << "output: " << num_elements << " " << total_bytes << "\n";
DELTA_CHECK(num_elements == output->nelms())
<< "expect " << num_elements << "elems, but given " << output->nelms();

Expand Down Expand Up @@ -186,33 +186,32 @@ int TFModel::run(const std::vector<InputData>& inputs,
set_feeds(&feeds, inputs);
set_fetches(&fetches, *output);

//std::cout << "input xxxxxxxxxxxxxxxxx"<< "\n";
//auto ti = feeds[0].second;
//for (auto i = 0; i < ti.NumElements(); i++){
// std::cout << "input xxxxxxxxxxxxxxxxx"<< "\n";
// auto ti = feeds[0].second;
// for (auto i = 0; i < ti.NumElements(); i++){
// std::cout << std::showpoint << ti.flat<float>()(i) << " ";
// if (i % 40 == 1){std::cout << "\n";}
//}
//std::cout << "\n";
//std::cout << "input -------------------"<< "\n";

// std::cout << "\n";
// std::cout << "input -------------------"<< "\n";

// Session run
RunOptions run_options;
RunMetadata run_meta;
tensorflow::Status s = _bundle.GetSession()->Run(run_options, feeds, fetches, {},
&output_tensors, &run_meta);
tensorflow::Status s = _bundle.GetSession()->Run(
run_options, feeds, fetches, {}, &output_tensors, &run_meta);
if (!s.ok()) {
LOG_FATAL << "Error, TF Model run failed: " << s;
exit(-1);
}

//std::cout << "output xxxxxxxxxxxxxxxxx"<< "\n";
//auto t = output_tensors[0];
//for (auto i = 0; i < t.NumElements(); i++){
// std::cout << "output xxxxxxxxxxxxxxxxx"<< "\n";
// auto t = output_tensors[0];
// for (auto i = 0; i < t.NumElements(); i++){
// std::cout << std::showpoint << t.flat<float>()(i) << " ";
//}
//std::cout << "\n";
//std::cout << "output -------------------"<< "\n";
// std::cout << "\n";
// std::cout << "output -------------------"<< "\n";

get_featches(output_tensors, output);

Expand Down Expand Up @@ -287,13 +286,14 @@ DeltaStatus TFModel::load_from_saved_model() {
LOG_INFO << "load saved model from path: " << path;
if (!MaybeSavedModelDirectory(path)) {
LOG_FATAL << "SaveModel not in :" << path;
return DeltaStatus::STATUS_ERROR;
return DeltaStatus::STATUS_ERROR;
}

tensorflow::Status s = LoadSavedModel(options, run_options, path,
{tensorflow::kSavedModelTagServe}, &_bundle);
tensorflow::Status s = LoadSavedModel(
options, run_options, path, {tensorflow::kSavedModelTagServe}, &_bundle);
if (!s.ok()) {
LOG_FATAL << "Failed Load model from saved_model.pb : " << s.error_message();
LOG_FATAL << "Failed Load model from saved_model.pb : "
<< s.error_message();
}

return DeltaStatus::STATUS_OK;
Expand Down
3 changes: 2 additions & 1 deletion deltann/examples/speaker/test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,8 @@ struct DeltaModel {
}

DeltaStatus SetInputs(T* buf, const std::vector<int> shape) {
return this->SetInputs(buf, this->NumElems(shape), shape.data(), shape.size());
return this->SetInputs(buf, this->NumElems(shape), shape.data(),
shape.size());
}

DeltaStatus SetInputs(T* buf, int nelms, const int* shape, const int ndims) {
Expand Down
3 changes: 1 addition & 2 deletions deltann/infer/delta_infer/core/scatter_search.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@ struct Compare {
Compare() {
std::vector<absl::flat_hash_set<std::string>> op_maps = {
{
"BatchMatMulV2",
"BatchMatMul",
"BatchMatMulV2", "BatchMatMul",
},
{"Const", "Shape"},
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -211,7 +211,7 @@ void TransformerCellFunctor<GPUDevice, float>::operator()(
sum);
};

// TODO
// TODO
#endif
}

Expand Down
5 changes: 2 additions & 3 deletions deltann/infer/delta_infer/custom_ops/transformer_cell.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,8 @@ class TansformerCellOp : public OpKernel {
errors::InvalidArgument("output_layernorm_gamma is null"));
Tensor *output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0,
{_param.batch_size * _param.from_seq_len,
_param.head_num * _param.size_per_head},
0, {_param.batch_size * _param.from_seq_len,
_param.head_num * _param.size_per_head},
&output));
_param.transformer_out =
reinterpret_cast<typename traits::DataType *>(output->flat<T>().data());
Expand Down
5 changes: 2 additions & 3 deletions deltann/infer/delta_infer/custom_ops/transformer_cell_bert.cc
Original file line number Diff line number Diff line change
Expand Up @@ -230,9 +230,8 @@ class TansformerCellBertOp : public OpKernel {
errors::InvalidArgument("output_layernorm_gamma is null"));
Tensor *output = nullptr;
OP_REQUIRES_OK(context, context->allocate_output(
0,
{_param.batch_size * _param.from_seq_len,
_param.head_num * _param.size_per_head},
0, {_param.batch_size * _param.from_seq_len,
_param.head_num * _param.size_per_head},
&output));
_param.transformer_out =
reinterpret_cast<typename traits::DataType *>(output->flat<T>().data());
Expand Down
Loading

0 comments on commit 58a52cc

Please sign in to comment.