From c0033eb1e55a7af9ec45e1a8408d3d0b840d189e Mon Sep 17 00:00:00 2001 From: Pranshu Pant <32600304+pranshupant@users.noreply.github.com> Date: Mon, 10 Jul 2023 15:32:50 -0400 Subject: [PATCH] Implement GELU as function op (#5277) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit ### Description These changes have been made to support the GELU operator as a function op. ### Motivation and Context Support for [GELU: Gaussian Error Linear Unit](https://paperswithcode.com/method/gelu) activation function, which was requested in #4933. #4423 also mentions this under the new ops section of `Contributions Welcome`. As per the discussion in #4933, I have added GELU as a context-dependent function-op, that uses the attribute `approximate` to return one of the two possible function-body definitions. The first function definition is the regular GELU: `GELU(x)=x∗Φ(x) = 0.5 * x * (1 + erf(x / sqrt(2)))` The second is the fast approximation based on `tanh`: `GELU(x)=0.5 ∗ x ∗ (1+Tanh( sqrt(2/π) ∗ (x + 0.044715 ∗ x^3)))` This implementation uses the [PyTorch docs for GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html?highlight=gelu#torch.nn.GELU) as a reference. PS: I also refactored `onnx/defs/math/defs.cc` to bring the operator implementation of `mish` right next to its doc string. --------- Signed-off-by: pranshupant Co-authored-by: G. Ramalingam --- docs/Changelog.md | 42 +++++++ docs/Operators.md | 96 ++++++++++++++++ docs/TestCoverage.md | 52 ++++++++- onnx/backend/test/case/node/gelu.py | 51 +++++++++ .../data/node/test_gelu_default_1/model.onnx | Bin 0 -> 93 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 21 bytes .../test_gelu_default_1_expanded/model.onnx | Bin 0 -> 1429 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 21 bytes .../data/node/test_gelu_default_2/model.onnx | Bin 0 -> 109 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | 3 + .../test_gelu_default_2_expanded/model.onnx | Bin 0 -> 1445 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | 3 + .../data/node/test_gelu_tanh_1/model.onnx | Bin 0 -> 114 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 21 bytes .../node/test_gelu_tanh_1_expanded/model.onnx | Bin 0 -> 2239 bytes .../test_data_set_0/input_0.pb | Bin 0 -> 21 bytes .../test_data_set_0/output_0.pb | Bin 0 -> 21 bytes .../data/node/test_gelu_tanh_2/model.onnx | Bin 0 -> 130 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | Bin 0 -> 254 bytes .../node/test_gelu_tanh_2_expanded/model.onnx | Bin 0 -> 2255 bytes .../test_data_set_0/input_0.pb | 1 + .../test_data_set_0/output_0.pb | Bin 0 -> 254 bytes onnx/defs/math/defs.cc | 108 ++++++++++++++++-- onnx/defs/operator_sets.h | 2 + onnx/test/automatic_upgrade_test.py | 6 + onnx/test/test_backend_onnxruntime.py | 1 + 32 files changed, 355 insertions(+), 13 deletions(-) create mode 100644 onnx/backend/test/case/node/gelu.py create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_1_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2_expanded/model.onnx create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb create mode 100644 onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb diff --git a/docs/Changelog.md b/docs/Changelog.md index afeb3331fd2..e7c9a05dee4 100644 --- a/docs/Changelog.md +++ b/docs/Changelog.md @@ -23919,6 +23919,48 @@ This version of the operator has been available since version 20 of the default
Constrain output types to be numerics.
+### **Gelu-20** + + Gelu takes one input data (Tensor) and produces one + output data (Tensor) where the gaussian error linear units function, + $y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. + If the attribute "approximate" is set to "tanh", the function estimation, + $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied + to the tensor elementwise. + + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Attributes + +
+
approximate : string (default is none)
+
Gelu approximation algorithm: `"tanh"`, `"none"`(default).`"none"`: do not use approximation.`"tanh"`: use tanh approximation.
+
+ +#### Inputs + +
+
X (differentiable) : T
+
Input tensor
+
+ +#### Outputs + +
+
Y (differentiable) : T
+
Output tensor
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
+ ### **GridSample-20** Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`. diff --git a/docs/Operators.md b/docs/Operators.md index 1e2eb08671c..0efd01da237 100644 --- a/docs/Operators.md +++ b/docs/Operators.md @@ -170,6 +170,7 @@ For an operator input/output's differentiability, it can be differentiable, |Clip|13, 12, 11, 6, 1|13| |DynamicQuantizeLinear|11|11| |Elu|6, 1|18| +|Gelu|20|20| |GreaterOrEqual|16, 12|16| |GroupNormalization|18|18| |HammingWindow|17|17| @@ -9410,6 +9411,101 @@ expect( +### **Gelu** + + Gelu takes one input data (Tensor) and produces one + output data (Tensor) where the gaussian error linear units function, + $y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. + If the attribute "approximate" is set to "tanh", the function estimation, + $y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied + to the tensor elementwise. + + +#### Version + +This version of the operator has been available since version 20 of the default ONNX operator set. + +#### Attributes + +
+
approximate : string (default is none)
+
Gelu approximation algorithm: `"tanh"`, `"none"`(default).`"none"`: do not use approximation.`"tanh"`: use tanh approximation.
+
+ +#### Inputs + +
+
X (differentiable) : T
+
Input tensor
+
+ +#### Outputs + +
+
Y (differentiable) : T
+
Output tensor
+
+ +#### Type Constraints + +
+
T : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)
+
Constrain input and output types to float tensors.
+
+ + +#### Examples + +
+gelu_default + +```python +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) + +x = np.array([-1, 0, 1]).astype(np.float32) +# expected output [-0.15865526, 0., 0.84134474] +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") + +x = np.random.randn(3, 4, 5).astype(np.float32) +# expected output [2.99595031, 3.99987331, 4.99999857] +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") +``` + +
+ + +
+gelu_tanh + +```python +node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" +) + +x = np.array([-1, 0, 1]).astype(np.float32) +# expected output [-0.158808, 0., 0.841192] +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") + +x = np.random.randn(3, 4, 5).astype(np.float32) +# expected output [2.9963627, 3.99993, 4.9999995] +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") +``` + +
+ + ### **Gemm** General Matrix multiplication: diff --git a/docs/TestCoverage.md b/docs/TestCoverage.md index f59159e2a4c..7b8d0cd2d9e 100644 --- a/docs/TestCoverage.md +++ b/docs/TestCoverage.md @@ -6,7 +6,7 @@ * [Overall Test Coverage](#overall-test-coverage) # Node Test Coverage ## Summary -Node tests have covered 173/186 (93.01%, 5 generators excluded) common operators. +Node tests have covered 174/187 (93.05%, 5 generators excluded) common operators. Node tests have covered 0/0 (N/A) experimental operators. @@ -6241,6 +6241,56 @@ expect( +### Gelu +There are 2 test cases, listed as following: +
+gelu_default + +```python +node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) + +x = np.array([-1, 0, 1]).astype(np.float32) +# expected output [-0.15865526, 0., 0.84134474] +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") + +x = np.random.randn(3, 4, 5).astype(np.float32) +# expected output [2.99595031, 3.99987331, 4.99999857] +y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) +expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") +``` + +
+
+gelu_tanh + +```python +node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" +) + +x = np.array([-1, 0, 1]).astype(np.float32) +# expected output [-0.158808, 0., 0.841192] +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") + +x = np.random.randn(3, 4, 5).astype(np.float32) +# expected output [2.9963627, 3.99993, 4.9999995] +y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) +).astype(np.float32) +expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") +``` + +
+ + ### Gemm There are 11 test cases, listed as following:
diff --git a/onnx/backend/test/case/node/gelu.py b/onnx/backend/test/case/node/gelu.py new file mode 100644 index 00000000000..cc93a4f5471 --- /dev/null +++ b/onnx/backend/test/case/node/gelu.py @@ -0,0 +1,51 @@ +# Copyright (c) ONNX Project Contributors +# +# SPDX-License-Identifier: Apache-2.0 + +import math + +import numpy as np + +import onnx +from onnx.backend.test.case.base import Base +from onnx.backend.test.case.node import expect + + +class Gelu(Base): + @staticmethod + def export_gelu_tanh() -> None: + node = onnx.helper.make_node( + "Gelu", inputs=["x"], outputs=["y"], approximate="tanh" + ) + + x = np.array([-1, 0, 1]).astype(np.float32) + # expected output [-0.158808, 0., 0.841192] + y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + ).astype(np.float32) + expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1") + + x = np.random.randn(3, 4, 5).astype(np.float32) + # expected output [2.9963627, 3.99993, 4.9999995] + y = ( + 0.5 + * x + * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + ).astype(np.float32) + expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2") + + @staticmethod + def export_gelu_default() -> None: + node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"]) + + x = np.array([-1, 0, 1]).astype(np.float32) + # expected output [-0.15865526, 0., 0.84134474] + y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) + expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1") + + x = np.random.randn(3, 4, 5).astype(np.float32) + # expected output [2.99595031, 3.99987331, 4.99999857] + y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32) + expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2") diff --git a/onnx/backend/test/data/node/test_gelu_default_1/model.onnx b/onnx/backend/test/data/node/test_gelu_default_1/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..ada8f652bed5fdba31049c4a5363d16c902a76fe GIT binary patch literal 93 zcmdKLDtPodJP$0G`FktuY3w$9ih*Em0vM2|?_t``v z?LZt8ARdiwF!T9irUD|W_i|1X@=q&+kD|c}yV}%+W(85HRxsp`l@ch)s1ZmB>`!Ob zry%qP?1K_zlSpd-aZ0G#&KLDtPoz?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..c55aea167f7 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +ByJ?K>Q?A @t? V$I?WuB>d=?VQ?V=q>~W>Q?xG>d+8^_$>o2?.ޓ@3ٽxD<+7?п?24=Iz>jL*'A=D?D? +| +BKR?h@:UZ?T;%;)>faJ>s=?>*S  \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx b/onnx/backend/test/data/node/test_gelu_default_2_expanded/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..1988c1b6297e78111e1b62265dd3b2f27301d617 GIT binary patch literal 1445 zcmbtUO-sW-5Z#y*;@Cr2L@%ClEcija^`f*EJtzvTP`oVLW}~5TlQz4xz4-(D5&fyo zmeh+<7(8q;%QElHo3C)6H^*`^RXRB}sxqCg19;}=w<@>7-Nmc35|v7u8_^bOSxTL# zM5IfqlmiXCyYaS~el#>`+qqfV*lG`}jl3oK%I4D`~b-9rso>8$a2#G+&P0 b=x~Tak2hg2^zj6E8e@D|{q@idfJKMjk9hki literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..bae0ffd6324 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +BxJx?h>z?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb new file mode 100644 index 00000000000..c55aea167f7 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_default_2_expanded/test_data_set_0/output_0.pb @@ -0,0 +1,3 @@ +ByJ?K>Q?A @t? V$I?WuB>d=?VQ?V=q>~W>Q?xG>d+8^_$>o2?.ޓ@3ٽxD<+7?п?24=Iz>jL*'A=D?D? +| +BKR?h@:UZ?T;%;)>faJ>s=?>*S  \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_1/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..cbb06f2b052747f6aa0a281ec0e25ef28a8fb20c GIT binary patch literal 114 zcmd|(fh^YCdYdwDo6|l&{8)xtvUN`V$dTX9!}!f*O=r)c z!>-Fb>N$;w)p;vm&V}nYEYl8HpSz6b)?j|%`Vq}RR<{gI(~eKIH{CeYAONMDpbtr2 z1%&PPxd0-bSFJV`1#ieEHonN{9~_<%4joWQDJkqKKu9sTBYW8J%41L*(6H2+$>aIc zN_dPwJ+Wl7uM8)}sYG(vBzAr#nU#%$qu5w&P{=4*?BD^U;1A$`8sMcX_d@>NO+$VP zGndIW7Imc=(!cCnD8$P3KrN{h4{sa^Bc)$QiS^2Y;FNP!-4aWuXGRXZPxQI0yI;Jc1D?=KhP8!>ov#<&t$H z9_zyJ<&s!@xwEiw((Q)9&@yDSkmXee)hUffn3TgKEHztGgqca-+IW}9+gP4Cc~a45 zTiB6TWA%OfOc=-ErAT-YeKf5`)@S}<+%HKE0H&!E1Ln+3O)QLlE}R;mWmmrdx0M{x literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/input_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..8a9445744b63f66e76c3ef4fce746606ffc6f47e GIT binary patch literal 21 Ycmd;J7GQK@tnlJtU})IS00s^A03TBWEC2ui literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_1_expanded/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..0f554cc42e247392ae38456c23643f18032a9088 GIT binary patch literal 21 bcmd;J7GQK@tn}iUFi&Y80}#YSgxdoED_sPK literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx b/onnx/backend/test/data/node/test_gelu_tanh_2/model.onnx new file mode 100644 index 0000000000000000000000000000000000000000..887e5c52023cb40139da847694442f1a46320594 GIT binary patch literal 130 zcmdz?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..d9ba5e0c6fbe5ba8123405c2bcd15ff1fa98a415 GIT binary patch literal 254 zcmVKZr_(KK^!5KOdwGKme%cKl^(mzPy%6KMPoh zy?+Bjz5JPqKIdelJ(G&8Kb%Y#KTN5JJ^YlIKCyjQK3Yn$KOXO|y|J!GJ_-CQzUk*z zyj{EaKGkS4KP4G1z7v@bK(wmay>~>yJYYAYz1iQlKX|yVKN&Q#J+P~KJ`-vxzF^Wo zz0Rruz6JiXJ=#2sKMg&IKemSozSjioy^6M=tP*@k&GvIsCvF3GIuN-O;1L%_!y8aD@}U@&pvwz>sBrWCav{87QB63^D*v)v z(_X+rT`0U<6bmnR;kV9vJ>MJanuzAIys1cbO2ZLG<=_a-c4v+-G3%QHcZv8Ki!&!q zD)?-3JF;4+z6+lT(>S;k5ihJx=GE}_!at11Ex`e!OiIORw315jOXF9v&KD+sEtE8* HWH!D5t*RZU literal 0 HcmV?d00001 diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb new file mode 100644 index 00000000000..bae0ffd6324 --- /dev/null +++ b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/input_0.pb @@ -0,0 +1 @@ +BxJx?h>z?j@$ ?.z8s?bhdӽ9>(>%?^B?0= B>]ת>=?RiJ>Z/d#S'?K]?=C@(Hm;= ?2??>>Ec! >*z??Oƾmǚ6&õgڿ?xFKྙ[ G?4οYL=e> kQN>.:=ݚ>b"6 \ No newline at end of file diff --git a/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb b/onnx/backend/test/data/node/test_gelu_tanh_2_expanded/test_data_set_0/output_0.pb new file mode 100644 index 0000000000000000000000000000000000000000..d9ba5e0c6fbe5ba8123405c2bcd15ff1fa98a415 GIT binary patch literal 254 zcmVKZr_(KK^!5KOdwGKme%cKl^(mzPy%6KMPoh zy?+Bjz5JPqKIdelJ(G&8Kb%Y#KTN5JJ^YlIKCyjQK3Yn$KOXO|y|J!GJ_-CQzUk*z zyj{EaKGkS4KP4G1z7v@bK(wmay>~>yJYYAYz1iQlKX|yVKN&Q#J+P~KJ`-vxzF^Wo zz0Rruz6JiXJ=#2sKMg&IKemSozSjioy) and produces one +output data (Tensor) where the gaussian error linear units function, +$y = 0.5 * x * (1 + erf(x/sqrt(2)))$ is applied to the tensor elementwise. +If the attribute "approximate" is set to "tanh", the function estimation, +$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$ is used and applied +to the tensor elementwise. + +)DOC"; + +static std::string gelu_default_approx = "none"; + +bool BuildContextDependentFunctionBodyGelu( + const FunctionBodyBuildContext& ctx, + const OpSchema& schema, + FunctionProto& functionProto) { + auto approx_attr_proto = ctx.getAttribute("approximate"); + std::string approximate = + approx_attr_proto != nullptr && approx_attr_proto->has_s() ? approx_attr_proto->s() : gelu_default_approx; + FunctionBuilder builder(functionProto); + + if (approximate == "tanh") { + builder.Add(R"( + Half = Constant () + HalfCast = CastLike (Half, X) + One = Constant () + OneCast = CastLike (One, X) + TwoOverPi = Constant () + TwoOverPiCast = CastLike (TwoOverPi, X) + C0 = Constant () + C0Cast = CastLike (C0, X) + SqrtTwoOverPi = Sqrt (TwoOverPiCast) + Three = Constant () + ThreeCast = CastLike (Three, X) + XCubed = Pow (X, ThreeCast) + XCubedC0 = Mul (C0Cast, XCubed) + XC0XCubed = Sum (X, XCubedC0) + TanhInput = Mul (SqrtTwoOverPi, XC0XCubed) + ErfApprox = Tanh (TanhInput) + PhiApprox = Sum (OneCast, ErfApprox) + MultX = Mul (HalfCast, X) + Y = Mul (MultX, PhiApprox) + )"); + } else { + builder.Add(R"( + Half = Constant () + HalfCast = CastLike (Half, X) + One = Constant () + OneCast = CastLike (One, X) + Two = Constant () + TwoCast = CastLike (Two, X) + SqrtTwo = Sqrt (TwoCast) + XSqrt = Div (X, SqrtTwo) + ErfXSqrt = Erf(XSqrt) + Phi = Sum (OneCast, ErfXSqrt) + MultX = Mul (HalfCast, X) + Y = Mul (MultX, Phi) + )"); + } + schema.BuildFunction(functionProto); + return true; +} + ONNX_OPERATOR_SET_SCHEMA( - Mish, - 18, + Gelu, + 20, OpSchema() - .SetDoc(mish_ver18_doc) + .SetDoc(gelu_ver20_doc) .Input(0, "X", "Input tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) .Output(0, "Y", "Output tensor", "T", OpSchema::Single, true, 1, OpSchema::Differentiable) + .Attr( + "approximate", + "Gelu approximation algorithm: `\"tanh\"`, `\"none\"`(default)." + "`\"none\"`: do not use approximation." + "`\"tanh\"`: use tanh approximation.", + AttributeProto::STRING, + gelu_default_approx) .TypeConstraint( "T", - {"tensor(float16)", "tensor(float)", "tensor(double)"}, - "Constrain input X and output types to float tensors.") - .FunctionBody(R"ONNX( - { - Softplus_X = Softplus (X) - TanHSoftplusX = Tanh (Softplus_X) - Y = Mul (X, TanHSoftplusX) - } - )ONNX") + {"tensor(float16)", "tensor(float)", "tensor(double)", "tensor(bfloat16)"}, + "Constrain input and output types to float tensors.") + .SetContextDependentFunctionBodyBuilder(BuildContextDependentFunctionBodyGelu) .TypeAndShapeInferenceFunction(propagateShapeAndTypeFromFirstInput)); static const char* Exp_ver13_doc = R"DOC( diff --git a/onnx/defs/operator_sets.h b/onnx/defs/operator_sets.h index 76473d2f9e0..a83adfd194f 100644 --- a/onnx/defs/operator_sets.h +++ b/onnx/defs/operator_sets.h @@ -1103,6 +1103,7 @@ class OpSet_Onnx_ver19 { // Forward declarations for ai.onnx version 20 class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, GridSample); +class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, Gelu); class ONNX_OPERATOR_SET_SCHEMA_CLASS_NAME(Onnx, 20, ConstantOfShape); // Iterate over schema from ai.onnx version 20 @@ -1110,6 +1111,7 @@ class OpSet_Onnx_ver20 { public: static void ForEachSchema(std::function fn) { fn(GetOpSchema()); + fn(GetOpSchema()); fn(GetOpSchema()); } }; diff --git a/onnx/test/automatic_upgrade_test.py b/onnx/test/automatic_upgrade_test.py index 248f338ce82..0277e068a79 100644 --- a/onnx/test/automatic_upgrade_test.py +++ b/onnx/test/automatic_upgrade_test.py @@ -464,6 +464,12 @@ def test_GatherElements(self) -> None: def test_GatherND(self) -> None: self._test_op_upgrade("GatherND", 11, [[1, 2, 3], [1, 2, 3]], [[1, 2]]) + def test_Gelu_approximate_tanh(self) -> None: + self._test_op_upgrade("Gelu", 20, attrs={"approximate": "tanh"}) + + def test_Gelu(self) -> None: + self._test_op_upgrade("Gelu", 20) + def test_Gemm(self) -> None: self._test_op_upgrade("Gemm", 1, [[5, 4], [4, 3], [3]], [[5, 3]]) diff --git a/onnx/test/test_backend_onnxruntime.py b/onnx/test/test_backend_onnxruntime.py index 06811d7c1d7..9a87309c15a 100644 --- a/onnx/test/test_backend_onnxruntime.py +++ b/onnx/test/test_backend_onnxruntime.py @@ -249,6 +249,7 @@ def run_node(cls, node, inputs, device=None, outputs_info=None, **kwargs): "|equal" "|identity" "|reshape" + "|gelu" ")" )