
Consider using TensorFlow operations to avoid side-effects #406

@khatchad

Description


Consider the example found here:

```python
import tensorflow as tf

class Model(tf.Module):
  def __init__(self):
    self.v = tf.Variable(0)
    self.counter = 0

  def __call__(self):
    if self.counter == 0:
      # A Python side-effect
      self.counter += 1
      self.v.assign_add(1)

    return self.v

m = Model()
for n in range(3):
  print(m().numpy())  # prints 1, 1, 1; if __call__ is decorated with @tf.function, it prints 1, 2, 3.
```
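For comparison, here is a sketch of the @tf.function variant that the comment above describes, assuming TensorFlow 2.x. The class name TracedModel is hypothetical. The Python if is evaluated only while TensorFlow traces the function, so assign_add is recorded in the graph once and then runs on every call, while the counter check is never repeated.

```python
import tensorflow as tf


class TracedModel(tf.Module):  # hypothetical name for the @tf.function variant
  def __init__(self):
    super().__init__()
    self.v = tf.Variable(0)
    self.counter = 0  # plain Python state, invisible to the graph

  @tf.function
  def __call__(self):
    # This Python `if` runs only during tracing. On the tracing call
    # counter == 0, so assign_add is baked into the graph; later calls
    # replay the graph without re-checking counter.
    if self.counter == 0:
      self.counter += 1
      self.v.assign_add(1)
    return self.v.read_value()


m = TracedModel()
for _ in range(3):
  print(m().numpy())  # prints 1, 2, 3
```

Returning self.v.read_value() rather than the variable itself just makes the returned tensor explicit; the behavior difference comes entirely from the traced Python if.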

Because __call__() writes to self.counter (a Python side-effect), it would not be refactored to graph execution, so that semantics are preserved. However, it may be interesting to explore automatically rewriting such code using TensorFlow operations, such as tf.cond() (in this case) and tf.while_loop(), so that the computation becomes a graph node and semantics are preserved during the refactoring to graph execution.
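A minimal sketch of what such a rewrite might produce, assuming TensorFlow 2.x and assuming the refactoring is allowed to promote the Python counter to a tf.Variable. The class name CondModel and that promotion step are hypothetical, not part of any existing tool. The Python if becomes tf.cond(), whose predicate is evaluated each time the graph runs, so the program prints 1, 1, 1 with or without @tf.function, matching the original eager behavior.

```python
import tensorflow as tf


class CondModel(tf.Module):  # hypothetical name for the rewritten model
  def __init__(self):
    super().__init__()
    self.v = tf.Variable(0)
    # Assumption: the Python int counter is promoted to a tf.Variable so
    # that both the check and the update become graph operations.
    self.counter = tf.Variable(0)

  @tf.function
  def __call__(self):
    def first_call():
      # Executes whenever the predicate is true at run time, not at trace time.
      self.counter.assign_add(1)
      self.v.assign_add(1)
      return tf.constant(True)

    def later_call():
      return tf.constant(False)

    # tf.cond replaces the Python `if`; both branches return a scalar bool.
    tf.cond(tf.equal(self.counter, 0), first_call, later_call)
    return self.v.read_value()


m = CondModel()
for _ in range(3):
  print(m().numpy())  # prints 1, 1, 1
```

Once counter is a tensor, AutoGraph would also convert a plain if on it into a tf.cond() automatically; writing tf.cond() explicitly just makes the graph node visible. A side-effect inside a Python loop could presumably be handled analogously with tf.while_loop().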

Suggested by @ansariahmad.

Metadata


Assignees: No one assigned
Labels: enhancement (New feature or request), java (Pull requests that update Java code), preconditions (Refactoring preconditions need to be added), question (Further information is requested), side-effects
Projects: No projects
Milestone: No milestone
Relationships: None yet
Development: No branches or pull requests
