Merge commit for internal changes
Jonathan Hseu committed Nov 14, 2017
2 parents 7497fca + 2e57e3f commit c4e416f
Showing 249 changed files with 6,558 additions and 17,741 deletions.
5 changes: 5 additions & 0 deletions configure.py
@@ -998,6 +998,10 @@ def create_android_bazelrc_configs():
write_to_bazelrc('build:android_arm64 --cpu=arm64-v8a')


def set_grpc_build_flags():
write_to_bazelrc('build --define grpc_no_ares=true')


def main():
# Make a copy of os.environ so it is clear when functions are getting and
# setting environment variables.
@@ -1071,6 +1075,7 @@ def main():
set_mpi_home(environ_cp)
set_other_mpi_vars(environ_cp)

set_grpc_build_flags()
set_cc_opt_flags(environ_cp)
set_mkl()
set_monolithic()
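For context, the new flag ends up in the generated bazelrc. A minimal sketch of the mechanism (the .tf_configure.bazelrc path and the append behavior of write_to_bazelrc are assumptions based on the rest of configure.py):

# Sketch: what set_grpc_build_flags() contributes to the generated bazelrc.
# Assumptions: write_to_bazelrc appends one line per call, and the output
# file is .tf_configure.bazelrc (hypothetical here).
_TF_BAZELRC = '.tf_configure.bazelrc'

def write_to_bazelrc(line):
  with open(_TF_BAZELRC, 'a') as f:
    f.write(line + '\n')

def set_grpc_build_flags():
  write_to_bazelrc('build --define grpc_no_ares=true')

set_grpc_build_flags()
# .tf_configure.bazelrc now contains the line
#   build --define grpc_no_ares=true
# so every bazel build defines grpc_no_ares=true, turning off gRPC's
# bundled c-ares DNS resolver.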
27 changes: 27 additions & 0 deletions tensorflow/c/python_api.cc
@@ -46,6 +46,33 @@ void SetRequestedDevice(TF_Graph* graph, TF_Operation* op, const char* device) {
void UpdateEdge(TF_Graph* graph, TF_Output new_src, TF_Input dst,
TF_Status* status) {
mutex_lock l(graph->mu);
tensorflow::shape_inference::InferenceContext* ic =
graph->refiner.GetContext(&new_src.oper->node);

if (ic->num_outputs() <= new_src.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Output index [", new_src.index,
"] is greater than the number of total outputs [", ic->num_outputs(),
"].");
return;
}
tensorflow::shape_inference::ShapeHandle shape = ic->output(new_src.index);

tensorflow::shape_inference::InferenceContext* ic_dst =
graph->refiner.GetContext(&dst.oper->node);
if (ic_dst->num_inputs() <= dst.index) {
status->status = tensorflow::errors::OutOfRange(
"Cannot update edge. Input index [", dst.index,
"] is greater than the number of total inputs [", ic_dst->num_inputs(),
"].");
return;
}
if (!ic_dst->MergeInput(dst.index, shape)) {
status->status = tensorflow::errors::InvalidArgument(
"Cannot update edge, incompatible shapes: ", ic_dst->DebugString(shape),
" and ", ic_dst->DebugString(ic_dst->input(dst.index)), ".");
return;
}
status->status = graph->graph.UpdateEdge(&new_src.oper->node, new_src.index,
&dst.oper->node, dst.index);
}
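The bounds checks above reject out-of-range indices, and MergeInput rejects shape-incompatible rewires. A rough Python sketch of the merge semantics (the real InferenceContext is C++; None stands in here for an unknown dimension or rank):

# Sketch of InferenceContext::MergeInput-style shape merging: two shapes are
# compatible when they agree wherever both dimensions are known.
def merge_shapes(a, b):
  if a is None:  # unknown rank
    return b
  if b is None:
    return a
  if len(a) != len(b):
    raise ValueError('incompatible ranks: %r vs %r' % (a, b))
  merged = []
  for da, db in zip(a, b):
    if da is None:
      merged.append(db)
    elif db is None or da == db:
      merged.append(da)
    else:
      raise ValueError('incompatible shapes: %r vs %r' % (a, b))
  return merged

print(merge_shapes([None, 3], [2, 3]))  # [2, 3]
merge_shapes([2, 3], [2, 4])            # ValueError, like InvalidArgument above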
50 changes: 44 additions & 6 deletions tensorflow/compiler/tests/binary_ops_test.py
@@ -366,16 +366,52 @@ def testComplexOps(self):

self._testBinary(
gen_math_ops._real_div,
np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j, 44 + 3j], dtype=dtype),
np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j, 0], dtype=dtype),
np.array([3, 3j, -1.5j, -8, 2 + 3j, 2 + 4j], dtype=dtype),
np.array([2, -2, 7j, -4j, 4 - 6j, 1 + 2j], dtype=dtype),
expected=np.array(
[1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2],
dtype=dtype))

# Test inf/nan scenarios.
self._testBinary(
gen_math_ops._real_div,
np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=dtype),
np.array([0, 0, 0, 0, 0, 0], dtype=dtype),
expected=np.array(
[
1.5, -1.5j, -0.2142857, -2j, (2 + 3j) / (4 - 6j), 2,
float("inf")
dtype(1 + 1j) / 0,
dtype(1) / 0,
dtype(1j) / 0,
dtype(-1) / 0,
dtype(-1j) / 0,
dtype(1 - 1j) / 0
],
dtype=dtype))

# TODO(b/65408531): support+test pow for cplx
atan2_supported = self.device == "XLA_GPU"
if atan2_supported:
self._testBinary(
math_ops.pow,
dtype(3 + 2j),
dtype(4 - 5j),
expected=np.power(dtype(3 + 2j), dtype(4 - 5j)))
self._testBinary( # empty rhs
math_ops.pow,
np.array([1 + 2j, 2 - 3j], dtype=dtype),
np.zeros(shape=[0, 2], dtype=dtype),
expected=np.zeros(shape=[0, 2], dtype=dtype))
self._testBinary( # to zero power
math_ops.pow,
np.array([1 + 2j, 2 - 3j], dtype=dtype),
np.zeros(shape=[1, 2], dtype=dtype),
expected=np.ones(shape=[1, 2], dtype=dtype))
lhs = np.array([1 - 2j, 4 + 3j, 2 - 3j, 3, 2j, 1, 4], dtype=dtype)
rhs = np.array([2, 3j, 3 + 4j, 2 + 3j, 3 - 2j, 2, 3 + 3j], dtype=dtype)
scalar = dtype(2 + 2j)
self._testBinary(math_ops.pow, lhs, rhs, expected=np.power(lhs, rhs))
self._testBinary(
math_ops.pow, scalar, rhs, expected=np.power(scalar, rhs))
self._testBinary(math_ops.pow, lhs, scalar, expected=np.power(lhs, scalar))

lhs = np.array([4 + 2j, -3 - 1j, 2j, 1], dtype=dtype)
rhs = np.array([5, -6j, 7 - 3j, -8j], dtype=dtype)
@@ -385,7 +421,9 @@ def testComplexOps(self):
self._testBinary(
gen_math_ops._sigmoid_grad, lhs, rhs, expected=rhs * lhs * (1 - lhs))

# TODO(b/65408531): support+test _rsqrt_grad for cplx (needs pow)
if atan2_supported:
self._testBinary(
gen_math_ops._rsqrt_grad, lhs, rhs, expected=lhs**3 * rhs / -2)

self._testBinary(
gen_math_ops._sqrt_grad, lhs, rhs, expected=rhs / (2 * lhs))
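A note on the inf/nan expectations above: they rely on numpy's complex division by zero warning rather than raising, so the expected values are computed the same way the op computes them:

import numpy as np

# dtype(x) / 0 in the test is plain numpy complex division by zero; it emits
# a RuntimeWarning and yields inf/nan components (the exact layout can vary
# across numpy versions and platforms):
with np.errstate(divide='ignore', invalid='ignore'):
  lhs = np.array([4 + 3j, 4, 3j, -4, -4j, 2 - 3j], dtype=np.complex64)
  print(lhs / 0)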
30 changes: 23 additions & 7 deletions tensorflow/compiler/tests/reduce_ops_test.py
@@ -67,25 +67,37 @@ def _testReduction(self, tf_reduce_fn, np_reduce_fn, dtype, test_inputs,
np.arange(-10, -4).reshape(2, 3),
np.arange(-4, 2).reshape(2, 3),
]
NONEMPTY_FLOAT_DATA = [
np.arange(1, 7).reshape(2, 3),
np.arange(-10, -4).reshape(2, 3),
np.arange(-4, 2).reshape(2, 3),
COMPLEX_DATA = [
np.zeros(shape=(2, 0)).astype(np.complex64),
np.zeros(shape=(0, 30)).astype(np.complex64),
np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3),
np.arange(-14, -2, dtype=np.float32).view(np.complex64).reshape(2, 3),
np.arange(-4, 8, dtype=np.float32).view(np.complex64).reshape(2, 3),
]
NONEMPTY_FLOAT_DATA = [x for x in FLOAT_DATA if np.size(x) > 0]
NONEMPTY_COMPLEX_DATA = [x for x in COMPLEX_DATA if np.size(x) > 0]
BOOL_DATA = [
np.array([], dtype=np.bool).reshape(2, 0),
np.array([], dtype=np.bool).reshape(0, 3),
np.array([[False, True, False], [True, True, False]]),
]

def testReduceSum(self):
def testReduceSumF32(self):
self._testReduction(math_ops.reduce_sum, np.sum, np.float32,
self.FLOAT_DATA)

def testReduceProd(self):
def testReduceSumC64(self):
self._testReduction(math_ops.reduce_sum, np.sum, np.complex64,
self.COMPLEX_DATA)

def testReduceProdF32(self):
self._testReduction(math_ops.reduce_prod, np.prod, np.float32,
self.FLOAT_DATA)

def testReduceProdC64(self):
self._testReduction(math_ops.reduce_prod, np.prod, np.complex64,
self.COMPLEX_DATA)

def testReduceMin(self):

def reference_min(inp, axis):
@@ -108,12 +120,16 @@ def reference_max(inp, axis):
self._testReduction(math_ops.reduce_max, reference_max, np.float32,
self.FLOAT_DATA)

def testReduceMean(self):
def testReduceMeanF32(self):
# TODO(phawkins): mean on XLA currently returns 0 instead of NaN when
# reducing across zero inputs.
self._testReduction(math_ops.reduce_mean, np.mean, np.float32,
self.NONEMPTY_FLOAT_DATA)

def testReduceMeanC64(self):
self._testReduction(math_ops.reduce_mean, np.mean, np.complex64,
self.NONEMPTY_COMPLEX_DATA)

def testReduceAll(self):
self._testReduction(math_ops.reduce_all, np.all, np.bool, self.BOOL_DATA)

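The COMPLEX_DATA entries build complex arrays by reinterpreting float32 buffers; spelled out:

import numpy as np

# .view(np.complex64) pairs consecutive float32 values into (real, imag):
x = np.arange(1, 13, dtype=np.float32).view(np.complex64).reshape(2, 3)
print(x)
# [[ 1. +2.j  3. +4.j  5. +6.j]
#  [ 7. +8.j  9.+10.j 11.+12.j]]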
31 changes: 24 additions & 7 deletions tensorflow/compiler/tests/unary_ops_test.py
@@ -330,12 +330,22 @@ def testFloatOps(self):

def testComplexOps(self):
for dtype in self.complex_types:
# TODO(b/65408531): math_ops.acosh (needs pow)
# TODO(b/65408531): math_ops.asinh (needs pow)

# TODO(b/65408531): Wider support for log (needs atan2).
atan2_supported = self.device == "XLA_GPU"
if atan2_supported:
self._assertOpOutputMatchesExpected(
math_ops.acosh,
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
expected=np.arccosh(
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))

self._assertOpOutputMatchesExpected(
math_ops.asinh,
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
expected=np.arcsinh(
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype)))

self._assertOpOutputMatchesExpected(
math_ops.atanh,
np.array([0.1, 0.2j, 0.3 - 0.1j, 0.4 + 0.5j], dtype=dtype),
@@ -392,19 +402,26 @@ def testComplexOps(self):
expected=np.log1p(
np.array([[1e-14, 1e-15j, 0.6 - 0.3j]], dtype=dtype)))

# TODO(b/34703906): math_ops.rsqrt (needs pow)
val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)
self._assertOpOutputMatchesExpected(
math_ops.rsqrt, val, expected=1 / np.sqrt(val))

# TODO(b/34703906): math_ops.sigmoid (needs tanh)
self._assertOpOutputMatchesExpected(
math_ops.sigmoid, val, expected=1 / (1 + np.exp(-val)))

# TODO(b/34703906): math_ops.sqrt (needs pow)
self._assertOpOutputMatchesExpected(
math_ops.sqrt, val, expected=np.sqrt(val))

self._assertOpOutputMatchesExpected(
math_ops.tanh,
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
expected=np.tanh(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))

self._assertOpOutputMatchesExpected(
math_ops.tan,
np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype),
expected=np.tan(np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=dtype)))

# TODO(b/34703906): math_ops.tanh (as itself)

ctypes = {np.complex64: np.float32}
self._assertOpOutputMatchesExpected(
math_ops.abs,
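The removed TODO noted that sigmoid "needs tanh": sigmoid(z) = (1 + tanh(z/2)) / 2, an identity that also holds for complex arguments, which is presumably how the op gets lowered. A quick numpy check of the identity behind the test's reference value:

import numpy as np

val = np.array([1, 2j, 2 - 3j, 4 + 5j], dtype=np.complex64)
np.testing.assert_allclose(
    1 / (1 + np.exp(-val)),       # sigmoid, as the test computes it
    (1 + np.tanh(val / 2)) / 2,   # equivalent tanh-based form
    rtol=1e-5)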
16 changes: 14 additions & 2 deletions tensorflow/compiler/xla/legacy_flags/debug_options_flags.cc
@@ -31,7 +31,6 @@ std::vector<tensorflow::Flag>* flag_objects;
std::once_flag flags_init;

void SetDebugOptionsDefaults(DebugOptions* flags) {
flags->set_xla_hlo_graph_path("/tmp/");
flags->set_xla_enable_fast_math(true);
flags->set_xla_llvm_enable_alias_scope_metadata(true);
flags->set_xla_llvm_enable_noalias_metadata(true);
@@ -117,9 +116,22 @@ void AllocateFlags() {
bool_setter_for(&DebugOptions::set_xla_hlo_dump_as_graphdef),
flag_values->xla_hlo_dump_as_graphdef(),
"Dump HLO graphs as TensorFlow GraphDefs."),
tensorflow::Flag(
"xla_hlo_graph_sharding_color",
bool_setter_for(&DebugOptions::set_xla_hlo_graph_sharding_color),
flag_values->xla_hlo_graph_sharding_color(),
"Assign colors based on sharding assignments when generating the "
"HLO graphs."),
tensorflow::Flag(
"xla_hlo_tfgraph_device_scopes",
bool_setter_for(&DebugOptions::set_xla_hlo_tfgraph_device_scopes),
flag_values->xla_hlo_tfgraph_device_scopes(),
"When generating TensorFlow HLO graphs, if the HLO instructions "
"are assigned to a specific device, prefix the name scope with "
"\"devX\" with X being the device ordinal."),
tensorflow::Flag(
"xla_log_hlo_text", flag_values->mutable_xla_log_hlo_text(),
"HLO modules matching this regex will be dumped to LOG(INFO). "),
"HLO modules matching this regex will be dumped to LOG(INFO)."),
tensorflow::Flag(
"xla_generate_hlo_text_to",
flag_values->mutable_xla_generate_hlo_text_to(),
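Usage note: XLA debug flags such as these are normally supplied through an environment variable parsed at startup; in current TensorFlow builds that variable is XLA_FLAGS (the variable name for this era of the codebase is an assumption):

import os
# Assumption: XLA parses XLA_FLAGS before the first computation is compiled,
# so set it before TensorFlow initializes XLA.
os.environ['XLA_FLAGS'] = '--xla_hlo_tfgraph_device_scopes=true'
import tensorflow as tf  # import after setting the variable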
4 changes: 2 additions & 2 deletions tensorflow/compiler/xla/literal_util_test.cc
@@ -114,10 +114,10 @@ TEST_F(LiteralUtilTest, LiteralScalarToString) {
auto bf16_lit = Literal::CreateR0<bfloat16>(static_cast<bfloat16>(0.5f));
ASSERT_EQ("0.5", bf16_lit->ToString());

// 3.14 will be rounded to 3.140625 in bfloat16 format (round to nearest even).
// 3.14 will be truncated to 3.125 in bfloat16 format.
auto bf16_lit_truncated =
Literal::CreateR0<bfloat16>(static_cast<bfloat16>(3.14f));
ASSERT_EQ("3.140625", bf16_lit_truncated->ToString());
ASSERT_EQ("3.125", bf16_lit_truncated->ToString());

auto bf16_lit_truncated2 =
Literal::CreateR0<bfloat16>(static_cast<bfloat16>(9.001f));
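The new expectation can be checked by hand: bfloat16 keeps the high 16 bits of a float32 (1 sign, 8 exponent, 7 mantissa bits), so truncating 3.14f yields 3.125, while round-to-nearest-even would give 3.140625, which is why both the comment and the expected string change. A standalone check:

import struct

def bfloat16_truncate(x):
  # Truncate to bfloat16 by keeping the high 16 bits of the float32 encoding.
  bits = struct.unpack('<I', struct.pack('<f', x))[0]
  return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]

print(bfloat16_truncate(3.14))  # 3.125 (round-to-nearest-even gives 3.140625)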
25 changes: 25 additions & 0 deletions tensorflow/compiler/xla/service/BUILD
@@ -630,6 +630,7 @@ cc_library(

cc_library(
name = "llvm_compiler",
srcs = ["llvm_compiler.cc"],
hdrs = ["llvm_compiler.h"],
deps = [
":compiler",
@@ -1358,6 +1359,7 @@ cc_library(
deps = [
":hlo",
":hlo_cost_analysis",
":hlo_profile_printer",
":human_readable_profile_builder",
"//tensorflow/compiler/xla:types",
"//tensorflow/compiler/xla:util",
Expand All @@ -1366,6 +1368,18 @@ cc_library(
],
)

tf_cc_test(
name = "hlo_execution_profile_test",
srcs = ["hlo_execution_profile_test.cc"],
deps = [
":cpu_plugin",
":hlo_cost_analysis",
":hlo_execution_profile",
"//tensorflow/compiler/xla/tests:hlo_test_base",
"//tensorflow/compiler/xla/tests:xla_internal_test_main",
],
)
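The new target follows the standard tf_cc_test pattern and should be runnable as bazel test //tensorflow/compiler/xla/service:hlo_execution_profile_test.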

tf_cc_test(
name = "hlo_computation_test",
srcs = ["hlo_computation_test.cc"],
@@ -1983,6 +1997,7 @@ cc_library(
":hlo",
"//tensorflow/compiler/xla:literal_util",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_proto",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
@@ -2156,6 +2171,16 @@ cc_library(
],
)

cc_library(
name = "hlo_profile_printer",
srcs = ["hlo_profile_printer.cc"],
hdrs = ["hlo_profile_printer.h"],
deps = [
":human_readable_profile_builder",
"//tensorflow/compiler/xla:types",
],
)

# -----------------------------------------------------------------------------

filegroup(
18 changes: 15 additions & 3 deletions tensorflow/compiler/xla/service/algebraic_simplifier.cc
@@ -46,9 +46,6 @@ limitations under the License.
namespace xla {
namespace {

using tensorflow::gtl::nullopt;
using tensorflow::gtl::optional;

// Returns whether operand is a literal with the given value.
bool IsLiteralWithValue(const HloInstruction* operand, int8 value) {
return operand->opcode() == HloOpcode::kConstant &&
@@ -135,7 +132,10 @@ class AlgebraicSimplifierVisitor : public DfsHloVisitorWithDefault {

Status HandleConvert(HloInstruction* convert) override;

Status HandleComplex(HloInstruction* complex) override;

Status HandleReal(HloInstruction* real) override;

Status HandleImag(HloInstruction* imag) override;

Status HandleConvolution(HloInstruction* convolution) override;
@@ -947,6 +947,18 @@ Status AlgebraicSimplifierVisitor::HandleConvert(HloInstruction* convert) {
return Status::OK();
}

// Complex(Real(c), Imag(c)) -> c
Status AlgebraicSimplifierVisitor::HandleComplex(HloInstruction* complex) {
auto real = complex->mutable_operand(0);
auto imag = complex->mutable_operand(1);
if (real->opcode() == HloOpcode::kReal &&
imag->opcode() == HloOpcode::kImag &&
real->operand(0) == imag->operand(0)) {
return ReplaceInstruction(complex, real->mutable_operand(0));
}
return Status::OK();
}

// Real(Complex(r, i)) -> r
Status AlgebraicSimplifierVisitor::HandleReal(HloInstruction* real) {
auto operand = real->mutable_operand(0);
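The new HandleComplex rewrite is sound because rebuilding a complex value from its own real and imaginary parts is lossless (HandleReal and HandleImag cover the inverse direction); in numpy terms:

import numpy as np

c = np.array([1 + 2j, -3 + 4j], dtype=np.complex64)
# Complex(Real(c), Imag(c)) -> c: reconstruction from components is exact.
np.testing.assert_array_equal(c.real + 1j * c.imag, c)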
