-
Notifications
You must be signed in to change notification settings - Fork 1
Extract tf.function decorator keyword arguments #144
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 64 commits
39c7fb5
f21e423
747d5c5
8f8c0b1
4ffa27a
21a8004
89e4dd0
b6f4e9c
b9c8107
d6b110c
26fd8f1
9a4cb62
d0d0875
015965a
6102351
a62af55
f7a9373
c9ee308
0ab5892
db5147f
b9bd9be
9db5396
ec84857
74b4bc4
e12fb16
7aef65b
9f49fea
279089b
ef5d5bb
64bd2b6
f53e997
1df7219
1b2ac89
c4184d4
d2ae4c9
b48960d
1eb7274
9195db7
eba9736
dd82611
df97664
54d78be
10cb16b
303a117
fc90f50
ecffbc5
3577a39
0c72e7d
94733e6
41bfe8d
9003639
6fd42b2
1682088
02ef49b
ece2522
93f4e9c
af61fe5
247e1d6
9b764d9
bc3ff06
cb23e2f
79d0ff3
2c93c3f
e690e4f
bc534a4
9163368
09d52cb
d845f0a
6a13a5e
4c79a5a
720bba7
5d63094
0c345f0
44b8b7b
c5b576b
f2238cf
45bcfd2
a8ddc42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,82 @@ | ||
| package edu.cuny.hunter.hybridize.core.analysis; | ||
|
|
||
| import java.util.Objects; | ||
|
|
||
| /** | ||
| * A representation of a tf.Tensorspec which describes a tf.Tensor | ||
| */ | ||
| public class TensorSpec { | ||
|
|
||
| /** | ||
| * Shape of the tensor being described by {@link TensorSpec}. | ||
| */ | ||
| private String shape; | ||
|
|
||
| /** | ||
| * Type of the tensor being described by {@link TensorSpec}. | ||
| */ | ||
| private String dtype; | ||
|
|
||
| public TensorSpec() { | ||
| this.shape = ""; | ||
| this.dtype = ""; | ||
|
||
| } | ||
|
|
||
| public TensorSpec(String s, String d) { | ||
| this.shape = s; | ||
| this.dtype = d; | ||
| } | ||
|
|
||
| /** | ||
| * Shape of {@link TensorSpec}. | ||
| * | ||
| * @return String of this {@link TensorSpec} shape. | ||
| */ | ||
| public String getShape() { | ||
| return this.shape; | ||
| } | ||
|
|
||
| /** | ||
| * Dtype of {@link TensorSpec}. | ||
| * | ||
| * @return String of this {@link TensorSpec} dtype. | ||
| */ | ||
| public String getDType() { | ||
| return this.dtype; | ||
| } | ||
|
|
||
| /** | ||
| * Set shape of {@link TensorSpec}. | ||
| */ | ||
| public void setShape(String s) { | ||
| this.shape = s; | ||
| } | ||
|
|
||
| /** | ||
| * Set dtype of {@link TensorSpec}. | ||
| */ | ||
| public void setDType(String d) { | ||
| this.dtype = d; | ||
| } | ||
|
|
||
| @Override | ||
| public int hashCode() { | ||
| return Objects.hash(shape, dtype); | ||
| } | ||
|
|
||
| @Override | ||
| public boolean equals(Object tensorObject) { | ||
|
|
||
| if (tensorObject == this) { | ||
| return true; | ||
| } | ||
|
|
||
| if (!(tensorObject instanceof TensorSpec)) { | ||
| return false; | ||
| } | ||
|
|
||
| TensorSpec tensor = (TensorSpec) tensorObject; | ||
|
|
||
| return shape.equals(tensor.shape) && dtype.equals(tensor.dtype); | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why are There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It is overridden because, in our tests, I check am checking that the list of tensorspec that is generated is the same one as the one provided. |
||
| } | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(input_signature=None) | ||
| def func(x): | ||
| return x | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| number = tf.constant([1.0, 1.0]) | ||
| func(number) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(jit_compile=None) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(reduce_retracing=True) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(reduce_retracing=False) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_implements="google.matmul_low_rank_matrix") | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_implements="embedded_matmul") | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_implements=None) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_autograph_options=tf.autograph.experimental.Feature.EQUALITY_OPERATORS) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_autograph_options=tf.autograph.experimental.Feature.ALL) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_autograph_options=(tf.autograph.experimental.Feature.EQUALITY_OPERATORS, tf.autograph.experimental.Feature.BUILTIN_FUNCTIONS)) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_autograph_options=None) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),)) | ||
| def func(x): | ||
| return x | ||
|
|
||
| if __name__ == '__main__': | ||
| number = tf.constant([1.0, 1.0]) | ||
| func(number) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_follow_type_hints=True) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_follow_type_hints=False) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(experimental_follow_type_hints=None) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(input_signature=(tf.TensorSpec(shape=[None], dtype=tf.float32),), autograph=False) | ||
| def func(x): | ||
| print('Tracing with', x) | ||
| return x | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| number = tf.constant([1.0, 1.0]) | ||
| func(number) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| import custom | ||
|
|
||
|
|
||
| @custom.decorator(input_signature=None) | ||
| def func(x): | ||
| print('Tracing with', x) | ||
| return x | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func(1) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| def decorator(input_signature=None): | ||
|
|
||
| def decorated(inner_function): | ||
|
|
||
| def wrapper(*args, **kwargs): | ||
| result = function(*args, **kwargs) | ||
| return result | ||
|
|
||
| return decorated | ||
|
|
||
| return decorator |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
|
|
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,12 @@ | ||
| import custom | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @custom.decorator(input_signature=None) | ||
| @tf.function(autograph=False) | ||
| def func(): | ||
| pass | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func() |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| def decorator(input_signature=None): | ||
|
|
||
| def decorated(inner_function): | ||
|
|
||
| def wrapper(*args, **kwargs): | ||
| result = function(*args, **kwargs) | ||
| return result | ||
|
|
||
| return decorated | ||
|
|
||
| return decorator |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,11 @@ | ||
| import tensorflow as tf | ||
|
|
||
|
|
||
| @tf.function(autograph=False) | ||
| @tf.function(jit_compile=True) | ||
| def func(x): | ||
| return x | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| func(tf.constant(1)) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,13 @@ | ||
| import tensorflow as tf | ||
|
|
||
| var = (tf.TensorSpec(shape=[None], dtype=tf.float32),) | ||
|
|
||
|
|
||
| @tf.function(input_signature=var) | ||
| def func(x): | ||
| return x | ||
|
|
||
|
|
||
| if __name__ == '__main__': | ||
| number = tf.constant([1.0, 1.0]) | ||
| func(number) |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| tensorflow==2.9.3 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is there more useful representation for a shape than a
String?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a
List