[tensorflow] Introduce non-strict NodeDef names #94

Open: wants to merge 1 commit into base: master
106 changes: 53 additions & 53 deletions symbolic_pymc/tensorflow/meta.py
@@ -48,6 +48,12 @@
 tf_metatize_cache = Cache(50)


+class DefaultTensorName(str):
+    """A type used to indicate a default tensor name."""
+
+    pass
+
+
 class MetaOpDefLibrary(object):
     """A singleton-like object that holds correspondences between TF Python API functions and the `OpDef`s they construct.

@@ -366,10 +372,16 @@ def _protobuf_convert(cls, k, v):
             raise TypeError(f"Could not convert {k}")

     def __init__(self, op, name, attr, obj=None):
+        """Create a TF meta NodeDef.
+
+        XXX: Meta NodeDefs with `name == None` have a special meaning;
+        their names are uniquely generated. We still consider them equal
+        (when every other property is equal, of course).
+        """
         super().__init__(obj=obj)
         self.op = metatize(op)
         assert name is not None
-        self.name = name if isvar(name) else str(name)
+        self.name = name if isvar(name) else name

         if not isvar(attr):
             opdef_sig, _ = op_def_lib.get_op_info(self.op)
@@ -600,6 +612,11 @@ def reify(self):
         # An operation with this name might already exist in the graph
         #
         try:
+            # FIXME: Lame hack
+            if isinstance(self.name, DefaultTensorName):
+                # Use a unique version of the default name.
+                raise KeyError()
+
             existing_op = ops.get_default_graph().get_operation_by_name(self.name)
         except KeyError:
             #
@@ -613,7 +630,15 @@
             # An `Operation` with this name exists, let's make sure it's
             # equivalent to this meta `Operation`
             #
-            if self != mt(existing_op):
+            existing_op_mt = mt(existing_op)
+
+            # # Since we can't exactly reproduce all NodeDef.attr information
+            # # (e.g. dtypes), we need to remove any unnecessary NodeDef.attr
+            # # fields from comparisons with same-named nodes in the graph.
+            # if op_attrs.keys() != node_attr.keys():
+            #     existing_op_mt.node_def.attr = node_attr
+
+            if self != existing_op_mt:
                 raise MetaReificationError(
                     f"An Operation with the name {self.name}"
                     " already exists in the graph and is not"
@@ -725,40 +750,40 @@ def reify(self):

     def __truediv__(self, y):
         # TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
-        return mt.realdiv(self, y, name="truediv")
+        return mt.realdiv(self, y, name=DefaultTensorName("truediv"))

     def __rtruediv__(self, x):
         # TODO: TF performs some dtype logic (using `dtype.base_dtype`) and casting here.
-        return mt.realdiv(x, self, name="truediv")
+        return mt.realdiv(x, self, name=DefaultTensorName("truediv"))

     def __add__(self, y):
         # TODO: If `self.dtype == tf.dtypes.string`, use `mt.add`
-        return mt.addv2(self, y, name="add")
+        return mt.addv2(self, y, name=DefaultTensorName("add"))

     def __radd__(self, x):
         # TODO: If `x.dtype == tf.dtypes.string`, use `mt.add`
-        return mt.addv2(x, self, name="add")
+        return mt.addv2(x, self, name=DefaultTensorName("add"))

     def __sub__(self, y):
-        return mt.sub(self, y, name="sub")
+        return mt.sub(self, y, name=DefaultTensorName("sub"))

     def __rsub__(self, x):
-        return mt.sub(x, self, name="sub")
+        return mt.sub(x, self, name=DefaultTensorName("sub"))

     def __mul__(self, y):
-        return mt.mul(self, y, name="mul")
+        return mt.mul(self, y, name=DefaultTensorName("mul"))

     def __rmul__(self, x):
-        return mt.mul(x, self, name="mul")
+        return mt.mul(x, self, name=DefaultTensorName("mul"))

     def __abs__(self):
-        return mt.abs(self, name="Abs")
+        return mt.abs(self, name=DefaultTensorName("Abs"))

     def __pow__(self, y):
-        return mt.pow(self, y, name="pow")
+        return mt.pow(self, y, name=DefaultTensorName("pow"))

     def __neg__(self):
-        return mt.neg(self, name="Neg")
+        return mt.neg(self, name=DefaultTensorName("Neg"))


 class TFlowMetaTensorShape(TFlowMetaSymbol):
@@ -987,48 +1012,22 @@ def __api_call__(self, *args, **kwargs):

         if not op_args_unreified:

-            res_var = None
-            # name = op_args.get("name", None)
-            #
-            # if name is not None:
-            #     #
-            #     # An operation with this name might already exist in the graph
-            #     #
-            #
-            #     from tensorflow.python.framework import ops
-            #
-            #     try:
-            #         this_op = ops.get_default_graph().get_operation_by_name(name)
-            #     except KeyError:
-            #         pass
-            #     else:
-            #         # TODO: Make sure the existing `Operation` matches our arguments
-            #         assert this_op.type == self.op_def.obj.name
-            #
-            #         this_op = mt(this_op)
-            #         op_inputs, op_node_def = self.op_args_to_operation_inputs(op_args)
-            #         assert op_inputs == this_op.inputs
-            #         assert op_node_def == this_op.node_def
-            #         res_var = this_op.default_output
-
-            if res_var is None:
-                #
-                # We create the `Operation` in the graph
-                #
-
-                tf_out = self._apply_func(**op_args)
-
-                # Ensure that the original meta objects will be available
-                # for use in the `metatize` that follows
-                tf_metatize_cache.update(
-                    {
-                        k: v
-                        for k, v in zip(op_args.values(), apply_arguments.values())
-                        if isinstance(k, tf.Tensor)
-                    }
-                )
-
-                res_var = metatize(tf_out)
+            #
+            # We create the `Operation` in the graph
+            #
+
+            tf_out = self._apply_func(**op_args)
+
+            # Ensure that the original meta objects will be available
+            # for use in the `metatize` that follows
+            tf_metatize_cache.update(
+                {
+                    k: v
+                    for k, v in zip(op_args.values(), apply_arguments.values())
+                    if isinstance(k, tf.Tensor)
+                }
+            )
+
+            res_var = metatize(tf_out)

             if "names" in meta._lvar_defaults_enabled:
                 # This should also reset the NodeDef's `obj`
@@ -1073,7 +1072,8 @@ def op_args_to_operation_inputs(self, apply_arguments):
             node_attr = var()

         if "names" not in meta._lvar_defaults_enabled:
-            op_name = apply_arguments.get("name", op_def_tf.name) or op_def_tf.name
+            default_name = DefaultTensorName(op_def_tf.name)
+            op_name = apply_arguments.get("name", default_name) or default_name
         else:
             op_name = var()

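A note on the pattern above: `DefaultTensorName` works because subclassing `str` preserves equality and hashing with plain strings, while `isinstance` checks can still single out defaulted names. A minimal standalone sketch (plain Python, independent of symbolic_pymc's meta classes):

class DefaultTensorName(str):
    """Marks a tensor name as a default that may be uniquified on reify."""

    pass


name = DefaultTensorName("truediv")

# It behaves exactly like the plain string it wraps...
assert name == "truediv"
assert isinstance(name, str)

# ...so meta-object equality is unaffected, but reification logic can
# still detect that the name was never user-specified and request a
# fresh, unique graph name instead of colliding with an existing node.
if isinstance(name, DefaultTensorName):
    print(f"'{name}' is a default name; a unique variant will be generated")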
19 changes: 17 additions & 2 deletions tests/tensorflow/test_meta.py
@@ -25,6 +25,7 @@
     TFlowMetaOperator,
     MetaOpDefLibrary,
     MetaReificationError,
+    DefaultTensorName,
     mt)

 from tests.tensorflow import run_in_graph_mode
@@ -636,7 +637,7 @@ def test_global_options():
     with tf.Graph().as_default(), disable_auto_reification():
         y_mt = mt.Placeholder('float')
         assert y_mt.obj is None
-        assert y_mt.name == 'Placeholder:0'
+        assert isinstance(y_mt.op.name, DefaultTensorName)
         assert isinstance(y_mt.op.node_def.attr, dict)

     with tf.Graph().as_default(), enable_lvar_defaults('names', 'node_attrs'):
@@ -706,7 +707,7 @@ def test_meta_const():
 @run_in_graph_mode
 def test_meta_existing_names():

-    with tf.Graph().as_default():
+    with tf.Graph().as_default() as test_graph:
         one_mt = mt(1)
         assert one_mt.op.name == 'Const'

@@ -723,6 +724,7 @@ def test_meta_existing_names():
         # Make sure it's the first base variable we created
         assert orig_one_tf is one_tf

+        # FYI: This implicitly creates 'Const_1'
         two_mt = mt(2)
         two_mt.op.node_def.name = 'Const'

@@ -736,3 +738,16 @@

         with pytest.raises(MetaReificationError):
             two_mt.reify()
+
+        another_one_mt = TFlowMetaOperator('Const', None)(3, var())
+        # The following is something that would happen as a result of
+        # reification (of the lvar in the meta object, not the meta object
+        # itself).
+        another_one_mt.op.node_def.attr['dtype'] = tf.int32
+
+        assert another_one_mt.op.name == 'Const'
+        assert isinstance(another_one_mt.op.name, DefaultTensorName)
+        # We need to make sure that the reified meta object actually uses a
+        # unique name.
+        assert isinstance(another_one_mt.reify(), tf.Tensor)
+        assert another_one_mt.reify().op.name == 'Const_2'
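For context on the 'Const_2' expectation above: it leans on TensorFlow's per-graph uniquification of op names. A small sketch in plain TF (no symbolic_pymc involved); the exact suffixes assume a fresh graph:

import tensorflow as tf

# Repeated default-named ops get numeric suffixes within a graph, which
# is what makes 'Const', 'Const_1', 'Const_2' the expected sequence.
with tf.Graph().as_default() as g:
    tf.constant(1)
    tf.constant(2)
    tf.constant(3)

print([op.name for op in g.get_operations()])
# ['Const', 'Const_1', 'Const_2']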
33 changes: 30 additions & 3 deletions tests/tensorflow/test_unify.py
@@ -114,6 +114,7 @@ def test_etuple_term():
     # TODO FIXME: Because of the above two, this errs
     # add_lvar_et = etuplize(add_lvar_mt)

+
 @run_in_graph_mode
 def test_basic_unify_reify():
     # Test reification with manually constructed replacements
@@ -127,8 +128,11 @@

     test_expr = mt.add(tf.constant(1, dtype=tf.float64),
                        mt.mul(tf.constant(2, dtype=tf.float64),
-                              x_l))
-    test_reify_res = reify(test_expr, {x_l: a})
+                              x_l, name=var('mul_name')),
+                       name=var('add_name'))
+    test_reify_res = reify(test_expr, {x_l: a,
+                                       var('add_name'): 'Add_10',
+                                       var('mul_name'): 'Mul_10'})
     test_base_res = test_reify_res.reify()
     assert isinstance(test_base_res, tf.Tensor)

@@ -141,7 +145,7 @@
     # Simply make sure that unification succeeds
     meta_expected_res = mt(expected_res)
     s_test = unify(test_expr, meta_expected_res, {})
-    assert len(s_test) == 3
+    assert len(s_test) == 5

     assert reify(test_expr, s_test) == meta_expected_res

@@ -199,3 +203,26 @@ def test_sexp_unify_reify():
     # Now, the second, `A . y`
     assert z_dist_tf.op.inputs[1].op.inputs[0] == A
     assert z_dist_tf.op.inputs[1].op.inputs[1] == y
+
+
+@run_in_graph_mode
+@pytest.mark.xfail(strict=True)
+def test_unique_names():
+
+    first_div_mt = mt(1) / mt(2)
+
+    assert first_div_mt.op.name == 'truediv'
+    assert first_div_mt.reify().op.name
+
+    div_lv = mt.realdiv(var('b'), var('c'), name=var('name'))
+    # Unify with the TF graph, then reify
+    s = unify(first_div_mt.reify(), div_lv)
+
+    s[var('b')] = 1
+    s[var('c')] = 3
+
+    div_mt = reify(div_lv, s)
+
+    assert div_mt.op.name == 'truediv'
+    assert isinstance(div_mt.reify(), tf.Tensor)
+    assert first_div_mt.reify() != div_mt.reify()
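To make the name-unification idea above concrete: once names are ordinary unifiable slots, a logic variable can stand in for an op's name and get bound like any other argument. A hedged sketch using the standalone `unification` package, with plain tuples standing in for the meta graph terms the tests actually unify:

from unification import var, unify, reify

# A pattern with logic variables for both operands and the op name.
pattern = ("RealDiv", var("b"), var("c"), var("name"))
target = ("RealDiv", 1.0, 2.0, "truediv")

# Unification binds the name just like the operands...
s = unify(pattern, target, {})
assert s[var("name")] == "truediv"

# ...and reification substitutes the bindings back in.
assert reify(pattern, s) == target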