Skip to content

Commit

Permalink
Merge commit for internal changes
Browse files Browse the repository at this point in the history
  • Loading branch information
yifeif committed Nov 23, 2017
2 parents 79422ab + c0b8a07 commit 0d3a49a
Show file tree
Hide file tree
Showing 143 changed files with 3,091 additions and 1,542 deletions.
22 changes: 22 additions & 0 deletions configure.py
Original file line number Diff line number Diff line change
Expand Up @@ -905,6 +905,28 @@ def set_trisycl_include_dir(environ_cp):
write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR',
trisycl_include_dir)

def set_trisycl_include_dir(environ_cp):
  """Set TRISYCL_INCLUDE_DIR.

  Obtains a candidate triSYCL include directory through
  get_from_env_or_user_or_default, repeating the request until the returned
  path exists on disk. The accepted path is then stored in `environ_cp` and
  forwarded to write_action_env_to_bazelrc.

  Args:
    environ_cp: mapping that is handed to the lookup helper and that receives
      the accepted path under the 'TRISYCL_INCLUDE_DIR' key.
  """
  # NOTE(review): the definition immediately preceding this one appears to
  # carry the same name and the same tail (writing TRISYCL_INCLUDE_DIR);
  # confirm this copy is not an accidental duplicate.
  prompt = ('Please specify the location of the triSYCL '
            'include directory. (Use --config=sycl_trisycl '
            'when building with Bazel) '
            '[Default is %s]: ') % (_DEFAULT_TRISYCL_INCLUDE_DIR)

  include_dir = get_from_env_or_user_or_default(
      environ_cp, 'TRISYCL_INCLUDE_DIR', prompt, _DEFAULT_TRISYCL_INCLUDE_DIR)
  # NOTE(review): nothing here clears a rejected value from environ_cp before
  # asking again; if the helper prefers a value already present in environ_cp
  # this could keep returning the same invalid path -- verify against the
  # helper.
  while not os.path.exists(include_dir):
    print('Invalid triSYCL include directory, %s cannot be found' %
          (include_dir))
    include_dir = get_from_env_or_user_or_default(
        environ_cp, 'TRISYCL_INCLUDE_DIR', prompt,
        _DEFAULT_TRISYCL_INCLUDE_DIR)

  # Record the validated directory for later steps and for Bazel.
  environ_cp['TRISYCL_INCLUDE_DIR'] = include_dir
  write_action_env_to_bazelrc('TRISYCL_INCLUDE_DIR', include_dir)


def set_mpi_home(environ_cp):
"""Set MPI_HOME."""
default_mpi_home = which('mpirun') or which('mpiexec') or ''
Expand Down
2 changes: 1 addition & 1 deletion tensorflow/compiler/aot/tfcompile.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,7 @@ def tf_library(name, graph, config,
" --cpp_class=" + cpp_class +
" --target_triple=" + target_llvm_triple() +
" --out_session_module=$(@D)/" + session_module_pb +
flags),
" " + flags),
tools=[tfcompile_tool],
visibility=visibility,
testonly=testonly,
Expand Down
6 changes: 4 additions & 2 deletions tensorflow/compiler/tests/fused_batchnorm_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,8 @@ def testInference(self):
# To avoid constant folding
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
offset = array_ops.placeholder(
np.float32, shape=scale_shape, name="offset")
epsilon = 0.001
y_ref, mean_ref, var_ref = self._reference_training(
x_val, scale_val, offset_val, epsilon, data_format)
Expand Down Expand Up @@ -112,7 +113,8 @@ def _testLearning(self, use_gradient_checker):
# To avoid constant folding
t_val = array_ops.placeholder(np.float32, shape=x_shape, name="x")
scale = array_ops.placeholder(np.float32, shape=scale_shape, name="scale")
offset = array_ops.placeholder(np.float32, shape=scale_shape, name="offset")
offset = array_ops.placeholder(
np.float32, shape=scale_shape, name="offset")
epsilon = 0.001
y, mean, var = nn.fused_batch_norm(
t_val,
Expand Down
11 changes: 10 additions & 1 deletion tensorflow/compiler/xla/client/client.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,15 @@ class Client {
std::vector<GlobalData*> arguments;
ExecutionOptions execution_options;
ExecutionProfile* execution_profile;

// Builds an instance with every field supplied by the caller.
//
// `arguments` and `execution_options` are taken by value and moved into the
// corresponding members, so passing temporaries (or std::move-ing existing
// objects) avoids an extra copy. `execution_profile` is stored as the raw
// pointer that was passed in.
ComputationInstance(const Computation& computation,
                    std::vector<GlobalData*> arguments,
                    ExecutionOptions execution_options,
                    ExecutionProfile* execution_profile)
    : computation(computation),
      arguments(std::move(arguments)),
      // Sink parameter: move rather than copy, matching `arguments` above.
      execution_options(std::move(execution_options)),
      execution_profile(execution_profile) {}
};

// Executes a list of ComputationInstances and returns global data produced from
Expand Down Expand Up @@ -133,7 +142,7 @@ class Client {

// Returns a vector of global data handles that point to the tuple elements.
StatusOr<std::vector<std::unique_ptr<GlobalData>>> DeconstructTuple(
const GlobalData& computation);
const GlobalData& data);

// Retrieves the statistics of the given computation.
StatusOr<ComputationStats> GetComputationStats(
Expand Down
63 changes: 44 additions & 19 deletions tensorflow/compiler/xla/service/batchnorm_rewriter.cc
Original file line number Diff line number Diff line change
Expand Up @@ -85,9 +85,9 @@ class BatchNormRewriterVisitor : public DfsHloVisitorWithDefault {
HloOpcode opcode) {
HloComputation::Builder b("scalar_computation");
auto scalar_lhs = b.AddInstruction(HloInstruction::CreateParameter(
0, ShapeUtil::MakeShape(F32, {}), "scalar_lhs"));
0, ShapeUtil::MakeShape(primitive_type, {}), "scalar_lhs"));
auto scalar_rhs = b.AddInstruction(HloInstruction::CreateParameter(
1, ShapeUtil::MakeShape(F32, {}), "scalar_rhs"));
1, ShapeUtil::MakeShape(primitive_type, {}), "scalar_rhs"));
auto scalar_op = b.AddInstruction(
HloInstruction::CreateBinary(ShapeUtil::MakeShape(primitive_type, {}),
opcode, scalar_lhs, scalar_rhs));
Expand Down Expand Up @@ -152,22 +152,30 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
// Expand batch norm training into smaller HLO ops.
HloInstruction* operand = batch_norm->mutable_operand(0);
const Shape operand_shape = operand->shape();
PrimitiveType ptype = operand_shape.element_type();
int64 feature_index = batch_norm->feature_index();
const int64 feature_count = operand_shape.dimensions(feature_index);
const int64 size_in_elements = ShapeUtil::ElementsIn(operand_shape);
auto elements_per_feature =
computation_->AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(size_in_elements / feature_count)));
auto elements_per_feature_literal =
Literal::CreateR0<float>(size_in_elements / feature_count);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

HloInstruction* scale = batch_norm->mutable_operand(1);
HloInstruction* offset = batch_norm->mutable_operand(2);
const Shape feature_shape = scale->shape();

auto zero_literal = Literal::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
HloInstruction::CreateConstant(std::move(zero_literal)));

auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
HloInstruction::CreateConstant(std::move(epsilon_literal)));

std::vector<int64> dimensions_without_feature;

Expand All @@ -184,7 +192,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
HloInstruction::CreateBroadcast(operand_shape, offset, {feature_index}));

HloComputation* add_reduce_computation =
GetScalarBinaryComputation(F32, HloOpcode::kAdd);
GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

// X^2.
auto operand_squared =
Expand Down Expand Up @@ -243,8 +251,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormTraining(
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
HloInstruction::CreateConstant(std::move(neg_half_literal)));

// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon =
Expand Down Expand Up @@ -286,15 +296,18 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
HloInstruction* operand = batch_norm->mutable_operand(0);
const Shape operand_shape = operand->shape();
int64 feature_index = batch_norm->feature_index();
PrimitiveType ptype = operand_shape.element_type();

HloInstruction* scale = batch_norm->mutable_operand(1);
HloInstruction* offset = batch_norm->mutable_operand(2);
HloInstruction* mean = batch_norm->mutable_operand(3);
HloInstruction* var = batch_norm->mutable_operand(4);
const Shape feature_shape = scale->shape();

auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
HloInstruction::CreateConstant(std::move(epsilon_literal)));

std::vector<int64> dimensions_without_feature;

Expand All @@ -321,8 +334,10 @@ Status BatchNormRewriterVisitor::HandleBatchNormInference(
computation_->AddInstruction(HloInstruction::CreateBinary(
operand_shape, HloOpcode::kAdd, var_broadcasted, epsilon));

auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
HloInstruction::CreateConstant(std::move(neg_half_literal)));

// 1 / Sqrt[Var[X] + epsilon].
auto rsqrt_var_add_epsilon =
Expand Down Expand Up @@ -373,6 +388,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(

HloInstruction* activation = batch_norm->mutable_operand(0);
const Shape activation_shape = activation->shape();
PrimitiveType ptype = activation_shape.element_type();
HloInstruction* scale = batch_norm->mutable_operand(1);
const Shape feature_shape = scale->shape();
HloInstruction* mean = batch_norm->mutable_operand(2);
Expand All @@ -383,18 +399,27 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(

const int64 size_in_elements = ShapeUtil::ElementsIn(activation_shape);
const int64 feature_count = activation_shape.dimensions(feature_index);
auto elements_per_feature =
computation_->AddInstruction(HloInstruction::CreateConstant(
Literal::CreateR0<float>(size_in_elements / feature_count)));

auto elements_per_feature_literal =
Literal::CreateR0<float>(size_in_elements / feature_count);
TF_ASSIGN_OR_RETURN(elements_per_feature_literal,
elements_per_feature_literal->Convert(ptype));
auto elements_per_feature = computation_->AddInstruction(
HloInstruction::CreateConstant(std::move(elements_per_feature_literal)));

auto zero_literal = Literal::CreateR0(0.0f);
TF_ASSIGN_OR_RETURN(zero_literal, zero_literal->Convert(ptype));
auto zero = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(0.0f)));
HloInstruction::CreateConstant(std::move(zero_literal)));

auto neg_half_literal = Literal::CreateR0(-0.5f);
TF_ASSIGN_OR_RETURN(neg_half_literal, neg_half_literal->Convert(ptype));
auto neg_half = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(-0.5f)));
HloInstruction::CreateConstant(std::move(neg_half_literal)));

auto epsilon_literal = Literal::CreateR0(batch_norm->epsilon());
TF_ASSIGN_OR_RETURN(epsilon_literal, epsilon_literal->Convert(ptype));
auto epsilon = computation_->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0(batch_norm->epsilon())));
HloInstruction::CreateConstant(std::move(epsilon_literal)));

std::vector<int64> dimensions_without_feature;

Expand Down Expand Up @@ -442,7 +467,7 @@ Status BatchNormRewriterVisitor::HandleBatchNormGrad(
grad_output, activation_minus_mean));

HloComputation* add_reduce_computation =
GetScalarBinaryComputation(F32, HloOpcode::kAdd);
GetScalarBinaryComputation(ptype, HloOpcode::kAdd);

// sum(Grad[Y] * (X - E[X])).
auto sum_grad_output_times_activiation_minus_mean =
Expand Down
75 changes: 60 additions & 15 deletions tensorflow/compiler/xla/service/cpu/cpu_compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -197,28 +197,35 @@ void InitializeLLVMCommandLineOptions(const HloModuleConfig& config) {
class CollectProfileCandidates : public DfsHloVisitorWithDefault {
public:
static StatusOr<std::unordered_map<const HloInstruction*, size_t>>
GetCandidatesForComputation(HloComputation* computation) {
GetCandidatesForComputation(
HloComputation* computation,
const std::unordered_map<const HloInstruction*, int64>&
assigned_indices) {
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
CollectProfileCandidates profile_candidates_for_computation(
&hlo_to_profile_idx);
&hlo_to_profile_idx, assigned_indices);
TF_RETURN_IF_ERROR(
computation->Accept(&profile_candidates_for_computation));
return hlo_to_profile_idx;
}

private:
explicit CollectProfileCandidates(
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx)
: hlo_to_profile_idx_(hlo_to_profile_idx) {}
CollectProfileCandidates(
std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx,
const std::unordered_map<const HloInstruction*, int64>& assigned_indices)
: hlo_to_profile_idx_(hlo_to_profile_idx),
assigned_indices_(assigned_indices) {}

Status DefaultAction(HloInstruction* hlo_instruction) override {
hlo_to_profile_idx_->insert({hlo_instruction, hlo_to_profile_idx_->size()});
hlo_to_profile_idx_->insert(
{hlo_instruction, FindOrDie(assigned_indices_, hlo_instruction)});
return Status::OK();
}

Status HandleCall(HloInstruction* call) override {
TF_RETURN_IF_ERROR(DefaultAction(call));
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_);
CollectProfileCandidates candidates_for_call(hlo_to_profile_idx_,
assigned_indices_);
TF_RETURN_IF_ERROR(call->to_apply()->Accept(&candidates_for_call));
return Status::OK();
}
Expand All @@ -232,17 +239,20 @@ class CollectProfileCandidates : public DfsHloVisitorWithDefault {
Status HandleWhile(HloInstruction* xla_while) override {
TF_RETURN_IF_ERROR(DefaultAction(xla_while));

CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_);
CollectProfileCandidates candidates_for_condition(hlo_to_profile_idx_,
assigned_indices_);
TF_RETURN_IF_ERROR(
xla_while->while_condition()->Accept(&candidates_for_condition));

CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_);
CollectProfileCandidates candidates_for_body(hlo_to_profile_idx_,
assigned_indices_);
TF_RETURN_IF_ERROR(xla_while->while_body()->Accept(&candidates_for_body));

return Status::OK();
}

std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx_;
const std::unordered_map<const HloInstruction*, int64>& assigned_indices_;
};
} // namespace

Expand Down Expand Up @@ -475,10 +485,27 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(

HloComputation* computation = module->entry_computation();
std::unordered_map<const HloInstruction*, size_t> hlo_to_profile_idx;
std::unique_ptr<HloProfileIndexMap> hlo_profile_index_map;
std::unique_ptr<HloProfilePrinter> hlo_profile_printer;
if (module->config().hlo_profiling_enabled()) {
hlo_profile_index_map = MakeUnique<HloProfileIndexMap>(*module);

TF_ASSIGN_OR_RETURN(
hlo_to_profile_idx,
CollectProfileCandidates::GetCandidatesForComputation(computation));
CollectProfileCandidates::GetCandidatesForComputation(
computation, hlo_profile_index_map->instruction_to_profile_idx()));

auto shape_size_bytes = [](const Shape& shape) {
// On the cpu, opaques are pointers.
if (ShapeUtil::IsOpaque(shape)) {
return static_cast<int64>(sizeof(void*));
}
return ShapeUtil::ByteSizeOf(shape, sizeof(void*));
};

HloCostAnalysis cost_analysis(shape_size_bytes);
hlo_profile_printer =
CreateHloProfilePrinter(*hlo_profile_index_map, cost_analysis);
}

std::unique_ptr<Executable> cpu_executable;
Expand Down Expand Up @@ -544,8 +571,16 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
parallel_computations.emplace(to_apply, instruction);
}

// We always profile the entire computation as a whole, even if hlo
// profiling is disabled. When hlo profiling is disabled, we pass in a
// profile counter array of just one element, which corresponds to the whole
// computation.
size_t entry_computation_profile_idx =
hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor(
*module->entry_computation())
: 0;
IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
hlo_to_profile_idx, hlo_to_profile_idx.size(),
hlo_to_profile_idx, entry_computation_profile_idx,
jit->target_machine(), jit->external_constant_pool());

std::unique_ptr<HloInstructionMap<string>> function_names(
Expand Down Expand Up @@ -586,8 +621,8 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new ParallelCpuExecutable(
std::move(jit), std::move(assignment), std::move(module),
std::move(function_names), std::move(hlo_to_profile_idx),
std::move(aligned_constants)));
std::move(function_names), std::move(aligned_constants),
std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));

if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable)
Expand Down Expand Up @@ -620,12 +655,22 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
TF_RETURN_IF_ERROR(protobuf_util::DumpProtoToDirectory(
proto, xla_dump_hlo_proto_to, module->name()));
}
// We always profile the entire computation as a whole, even if hlo
// profiling is disabled. When hlo profiling is disabled, we pass in a
// profile counter array of just one element, which corresponds to the whole
// computation.
size_t entry_computation_profile_idx =
hlo_profile_index_map ? hlo_profile_index_map->GetProfileIndexFor(
*module->entry_computation())
: 0;

// Each computation is a single function. Emit all embedded computations
// before the entry computation. The order of computations returned from
// GetEmbeddedComputations guarantees that a called computation occurs
// before a caller computation.

IrEmitter ir_emitter(*module, *assignment, llvm_module.get(),
hlo_to_profile_idx, hlo_to_profile_idx.size(),
hlo_to_profile_idx, entry_computation_profile_idx,
jit->target_machine(), jit->external_constant_pool());

for (auto embedded_computation :
Expand Down Expand Up @@ -659,7 +704,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
jit->AddModule(std::move(llvm_module));
cpu_executable.reset(new CpuExecutable(
std::move(jit), std::move(assignment), std::move(module), function_name,
std::move(hlo_to_profile_idx)));
std::move(hlo_profile_printer), std::move(hlo_profile_index_map)));

if (embed_ir_in_executable) {
static_cast<CpuExecutable&>(*cpu_executable)
Expand Down
Loading

0 comments on commit 0d3a49a

Please sign in to comment.