# Implement GELU as function op (onnx#5277)
### Description
This change adds support for the GELU operator as a function
op.

### Motivation and Context
Support for [GELU: Gaussian Error Linear
Unit](https://paperswithcode.com/method/gelu) activation function, which
was requested in onnx#4933.
onnx#4423 also mentions this under the new ops section of `Contributions
Welcome`.

As per the discussion in onnx#4933, I have added GELU as a context-dependent
function op that uses the attribute `approximate` to select between the
two possible function-body definitions.

The first function definition is the regular GELU:
`GELU(x) = x * Φ(x) = 0.5 * x * (1 + erf(x / sqrt(2)))`

The second is the fast approximation based on `tanh`:
`GELU(x) = 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))`
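
The actual implementation lives in `onnx/defs/math/defs.cc` (C++). As a rough illustration only, here is a minimal Python sketch, using `onnx.helper`, of how the two candidate function bodies can be expressed as node lists; all node, tensor, and helper names here are illustrative, not taken from the PR:

```python
import numpy as np
from onnx import TensorProto, helper

def _scalar_const(name: str, value: float):
    # Emit a scalar float via a Constant node (illustrative helper).
    return helper.make_node(
        "Constant", inputs=[], outputs=[name],
        value=helper.make_tensor(name, TensorProto.FLOAT, [], [value]),
    )

def gelu_body(approximate: str = "none"):
    """Node list computing Y = GELU(X) for the chosen variant (sketch)."""
    if approximate == "tanh":
        return [
            _scalar_const("half", 0.5),
            _scalar_const("one", 1.0),
            _scalar_const("c", 0.044715),
            _scalar_const("s", float(np.sqrt(2.0 / np.pi))),
            _scalar_const("three", 3.0),
            helper.make_node("Pow", ["X", "three"], ["x_cubed"]),
            helper.make_node("Mul", ["c", "x_cubed"], ["c_x_cubed"]),
            helper.make_node("Add", ["X", "c_x_cubed"], ["inner"]),
            helper.make_node("Mul", ["s", "inner"], ["scaled"]),
            helper.make_node("Tanh", ["scaled"], ["tanh_val"]),
            helper.make_node("Add", ["one", "tanh_val"], ["one_plus"]),
            helper.make_node("Mul", ["X", "one_plus"], ["x_t"]),
            helper.make_node("Mul", ["half", "x_t"], ["Y"]),
        ]
    # Default: the exact erf-based definition.
    return [
        _scalar_const("half", 0.5),
        _scalar_const("one", 1.0),
        _scalar_const("sqrt2", float(np.sqrt(2.0))),
        helper.make_node("Div", ["X", "sqrt2"], ["x_scaled"]),
        helper.make_node("Erf", ["x_scaled"], ["erf_val"]),
        helper.make_node("Add", ["one", "erf_val"], ["one_plus"]),
        helper.make_node("Mul", ["X", "one_plus"], ["x_e"]),
        helper.make_node("Mul", ["half", "x_e"], ["Y"]),
    ]

# Either node list can then be wrapped into a FunctionProto, e.g.:
fn = helper.make_function(
    "local", "Gelu", ["X"], ["Y"], gelu_body("tanh"),
    opset_imports=[helper.make_opsetid("", 20)],
)
```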

This implementation uses the [PyTorch docs for
GELU](https://pytorch.org/docs/stable/generated/torch.nn.GELU.html?highlight=gelu#torch.nn.GELU)
as a reference.
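
For intuition about how close the tanh variant is to the exact definition, here is a small self-contained NumPy check mirroring the two formulas above (not part of the PR):

```python
import math

import numpy as np

def gelu_exact(x: np.ndarray) -> np.ndarray:
    # GELU(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))

def gelu_tanh(x: np.ndarray) -> np.ndarray:
    # GELU(x) ~= 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x**3)))
    return 0.5 * x * (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * x**3)))

x = np.linspace(-4.0, 4.0, 1001)
print(np.max(np.abs(gelu_exact(x) - gelu_tanh(x))))  # prints the largest gap
```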

PS: I also refactored `onnx/defs/math/defs.cc` to bring the operator
implementation of `mish` right next to its doc string.

---------

Signed-off-by: pranshupant <[email protected]>
Co-authored-by: G. Ramalingam <[email protected]>
pranshupant and gramalingam authored Jul 10, 2023
1 parent d8634f1 commit c0033eb
Showing 32 changed files with 355 additions and 13 deletions.
docs/Changelog.md: 42 additions & 0 deletions
@@ -23919,6 +23919,48 @@ This version of the operator has been available since version 20 of the default
<dd>Constrain output types to be numerics.</dd>
</dl>

### <a name="Gelu-20"></a>**Gelu-20**</a>

Gelu takes one input data (Tensor<T>) and produces one
output data (Tensor<T>) where the Gaussian error linear units function,
$y = 0.5 * x * (1 + erf(x/sqrt(2)))$, is applied to the tensor elementwise.
If the attribute "approximate" is set to "tanh", the tanh approximation,
$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$, is applied
to the tensor elementwise instead.


#### Version

This version of the operator has been available since version 20 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>approximate</tt> : string (default is 'none')</dt>
<dd>Gelu approximation algorithm: `"tanh"` or `"none"` (default). `"none"`: do not use approximation. `"tanh"`: use the tanh approximation.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>
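
As a quick usage illustration (not part of the changelog text itself), a minimal sketch that builds an opset-20 model around the new operator and evaluates it with `onnx.reference.ReferenceEvaluator`, assuming an onnx build that ships opset 20:

```python
import numpy as np
import onnx
from onnx import TensorProto, helper
from onnx.reference import ReferenceEvaluator

# One-node graph: Y = Gelu(X) with the tanh approximation.
X = helper.make_tensor_value_info("X", TensorProto.FLOAT, [None])
Y = helper.make_tensor_value_info("Y", TensorProto.FLOAT, [None])
node = helper.make_node("Gelu", ["X"], ["Y"], approximate="tanh")
graph = helper.make_graph([node], "gelu_demo", [X], [Y])
model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 20)])
onnx.checker.check_model(model)

x = np.array([-1.0, 0.0, 1.0], dtype=np.float32)
(y,) = ReferenceEvaluator(model).run(None, {"X": x})
print(y)  # approximately [-0.158808, 0., 0.841192], matching the tanh formula
```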

### <a name="GridSample-20"></a>**GridSample-20**</a>

Given an input `X` and a flow-field `grid`, computes the output `Y` using `X` values and pixel locations from the `grid`.
docs/Operators.md: 96 additions & 0 deletions
@@ -170,6 +170,7 @@ For an operator input/output's differentiability, it can be differentiable,
|<a href="#Clip">Clip</a>|<a href="Changelog.md#Clip-13">13</a>, <a href="Changelog.md#Clip-12">12</a>, <a href="Changelog.md#Clip-11">11</a>, <a href="Changelog.md#Clip-6">6</a>, <a href="Changelog.md#Clip-1">1</a>|13|
|<a href="#DynamicQuantizeLinear">DynamicQuantizeLinear</a>|<a href="Changelog.md#DynamicQuantizeLinear-11">11</a>|11|
|<a href="#Elu">Elu</a>|<a href="Changelog.md#Elu-6">6</a>, <a href="Changelog.md#Elu-1">1</a>|18|
|<a href="#Gelu">Gelu</a>|<a href="Changelog.md#Gelu-20">20</a>|20|
|<a href="#GreaterOrEqual">GreaterOrEqual</a>|<a href="Changelog.md#GreaterOrEqual-16">16</a>, <a href="Changelog.md#GreaterOrEqual-12">12</a>|16|
|<a href="#GroupNormalization">GroupNormalization</a>|<a href="Changelog.md#GroupNormalization-18">18</a>|18|
|<a href="#HammingWindow">HammingWindow</a>|<a href="Changelog.md#HammingWindow-17">17</a>|17|
@@ -9410,6 +9411,101 @@ expect(
</details>


### <a name="Gelu"></a><a name="gelu">**Gelu**</a>

Gelu takes one input data (Tensor<T>) and produces one
output data (Tensor<T>) where the Gaussian error linear units function,
$y = 0.5 * x * (1 + erf(x/sqrt(2)))$, is applied to the tensor elementwise.
If the attribute "approximate" is set to "tanh", the tanh approximation,
$y = 0.5 * x * (1 + Tanh(sqrt(2/\pi) * (x + 0.044715 * x^3)))$, is applied
to the tensor elementwise instead.


#### Version

This version of the operator has been available since version 20 of the default ONNX operator set.

#### Attributes

<dl>
<dt><tt>approximate</tt> : string (default is 'none')</dt>
<dd>Gelu approximation algorithm: `"tanh"` or `"none"` (default). `"none"`: do not use approximation. `"tanh"`: use the tanh approximation.</dd>
</dl>

#### Inputs

<dl>
<dt><tt>X</tt> (differentiable) : T</dt>
<dd>Input tensor</dd>
</dl>

#### Outputs

<dl>
<dt><tt>Y</tt> (differentiable) : T</dt>
<dd>Output tensor</dd>
</dl>

#### Type Constraints

<dl>
<dt><tt>T</tt> : tensor(float16), tensor(float), tensor(double), tensor(bfloat16)</dt>
<dd>Constrain input and output types to float tensors.</dd>
</dl>


#### Examples

<details>
<summary>gelu_default</summary>

```python
node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.15865526, 0., 0.84134474]
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: the exact erf-based GELU formula applied elementwise
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
```

</details>


<details>
<summary>gelu_tanh</summary>

```python
node = onnx.helper.make_node(
"Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
)

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.158808, 0., 0.841192]
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: the tanh-approximation formula applied elementwise
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
```

</details>


### <a name="Gemm"></a><a name="gemm">**Gemm**</a>

General Matrix multiplication:
docs/TestCoverage.md: 51 additions & 1 deletion
@@ -6,7 +6,7 @@
* [Overall Test Coverage](#overall-test-coverage)
# Node Test Coverage
## Summary
- Node tests have covered 173/186 (93.01%, 5 generators excluded) common operators.
+ Node tests have covered 174/187 (93.05%, 5 generators excluded) common operators.

Node tests have covered 0/0 (N/A) experimental operators.

@@ -6241,6 +6241,56 @@ expect(
</details>


### Gelu
There are 2 test cases, listed as follows:
<details>
<summary>gelu_default</summary>

```python
node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.15865526, 0., 0.84134474]
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: the exact erf-based GELU formula applied elementwise
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
```

</details>
<details>
<summary>gelu_tanh</summary>

```python
node = onnx.helper.make_node(
"Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
)

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.158808, 0., 0.841192]
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: the tanh-approximation formula applied elementwise
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")
```

</details>


### Gemm
There are 11 test cases, listed as follows:
<details>
onnx/backend/test/case/node/gelu.py: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
# Copyright (c) ONNX Project Contributors
#
# SPDX-License-Identifier: Apache-2.0

import math

import numpy as np

import onnx
from onnx.backend.test.case.base import Base
from onnx.backend.test.case.node import expect


class Gelu(Base):
@staticmethod
def export_gelu_tanh() -> None:
node = onnx.helper.make_node(
"Gelu", inputs=["x"], outputs=["y"], approximate="tanh"
)

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.158808, 0., 0.841192]
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: the tanh-approximation formula applied elementwise
y = (
0.5
* x
* (1 + np.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3))))
).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_tanh_2")

@staticmethod
def export_gelu_default() -> None:
node = onnx.helper.make_node("Gelu", inputs=["x"], outputs=["y"])

x = np.array([-1, 0, 1]).astype(np.float32)
# expected output [-0.15865526, 0., 0.84134474]
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_1")

x = np.random.randn(3, 4, 5).astype(np.float32)
# expected output: the exact erf-based GELU formula applied elementwise
y = (0.5 * x * (1 + np.vectorize(math.erf)(x / np.sqrt(2)))).astype(np.float32)
expect(node, inputs=[x], outputs=[y], name="test_gelu_default_2")
(The remaining changed files are binary test-data files, serialized input/output tensors for the new Gelu node tests; their contents are not shown.)
