Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions op_acc_stable_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def check_tensor_aadiff(x, y):

assert x.dtype == y.dtype
assert x.shape == y.shape
if x.dtype == paddle.bfloat16:
if x.dtype in [paddle.bool, paddle.bfloat16]:
x = x.astype(paddle.float32)
y = y.astype(paddle.float32)
assert paddle.max(paddle.abs(x - y)).numpy()[0] == 0, "aadiff check failed"
assert paddle.max(paddle.abs(x - y)).numpy() == 0, "aadiff check failed"


def check_aadiff(x, y):
Expand Down
52 changes: 52 additions & 0 deletions tests/gelu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from op_acc_stable_run import check_tensor_diff, op_acc_stable_run

class GeluTest:
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype

def set_configs(self, paddle):
self.tmp_cache_path = "."
self.inputs = {
"x": paddle.randn(self.shape, dtype=self.dtype),
"y_grad": paddle.randn(self.shape, dtype=self.dtype),
}

def run_paddle(self, paddle):
x = self.inputs["x"]
y = paddle.nn.functional.gelu(x)
y.backward(self.inputs["y_grad"])
return y, x.grad

def run_torch(self, torch):
x = self.inputs["x"]
y = torch.nn.functional.gelu(x)
y.backward(self.inputs["y_grad"])
return y, x.grad

def check_diff(self, paddle, pd_ret, th_ret):
assert len(pd_ret) == len(th_ret)
for pd, th in zip(pd_ret, th_ret):
check_tensor_diff(pd, th, atol=1e-6, rtol=1e-6)

if __name__ == "__main__":
op_acc_stable_run(GeluTest(shape=[1, 12288], dtype="float32"))
op_acc_stable_run(GeluTest(shape=[1, 12288], dtype="float16"))
op_acc_stable_run(GeluTest(shape=[1, 12288], dtype="bfloat16"))
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="float32"))
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="float16"))
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="bfloat16"))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修复最后一行的这个问题。

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个最后的一行的问题指什么?我不太确定,指最后一行缺少空行吗?

2 changes: 1 addition & 1 deletion tests/softmax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def check_diff(self, paddle, pd_ret, th_ret):


if __name__ == "__main__":
op_acc_stable_run(SoftmaxTest)
op_acc_stable_run(SoftmaxTest)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修复最后一行的这个问题。