Skip to content
Open
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions op_acc_stable_run.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,10 +92,10 @@ def check_tensor_aadiff(x, y):

assert x.dtype == y.dtype
assert x.shape == y.shape
if x.dtype == paddle.bfloat16:
if x.dtype in [paddle.bool, paddle.bfloat16]:
x = x.astype(paddle.float32)
y = y.astype(paddle.float32)
assert paddle.max(paddle.abs(x - y)).numpy()[0] == 0, "aadiff check failed"
assert paddle.max(paddle.abs(x - y)).numpy() == 0, "aadiff check failed"


def check_aadiff(x, y):
Expand Down Expand Up @@ -128,7 +128,7 @@ def op_acc_stable_run(test_obj, stable_num=100):
ret = []
tmp_cache_path = getattr(test_obj, "tmp_cache_path", None)
if not tmp_cache_path:
tmp_cache_path = os.getenv("TMP_CACHE_PATH", "/dev/shm")
tmp_cache_path = os.getenv("TMP_CACHE_PATH", "/home")
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不要改这个,这个具体设置可以在你的单测里改。
比如单测里写self.tmp_cache_path = "."

with tempfile.TemporaryDirectory(dir=tmp_cache_path) as path:
input_pickle_path = os.path.join(path, "inputs.bin")
with open(input_pickle_path, "wb") as f:
Expand Down Expand Up @@ -168,6 +168,7 @@ def op_acc_stable_run(test_obj, stable_num=100):
else:
with {framework}.no_grad():
check_aadiff(prev_ret, outputs)
print(i)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删除无用代码。


if stable_num > 1:
print(f'AAdiff check passed after {stable_num} runs')
Expand Down
51 changes: 51 additions & 0 deletions tests/gelu_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from op_acc_stable_run import check_tensor_diff, op_acc_stable_run

class GeluTest:
def __init__(self, shape, dtype):
self.shape = shape
self.dtype = dtype

def set_configs(self, paddle):
self.inputs = {
"x": paddle.randn(self.shape, dtype=self.dtype),
"y_grad": paddle.randn(self.shape, dtype=self.dtype),
}

def run_paddle(self, paddle):
x = self.inputs["x"]
y = paddle.nn.functional.gelu(x)
y.backward(self.inputs["y_grad"])
return y, x.grad

def run_torch(self, torch):
x = self.inputs["x"]
y = torch.nn.functional.gelu(x)
y.backward(self.inputs["y_grad"])
return y, x.grad

def check_diff(self, paddle, pd_ret, th_ret):
assert len(pd_ret) == len(th_ret)
for pd, th in zip(pd_ret, th_ret):
check_tensor_diff(pd, th, atol=1e-6, rtol=1e-6)

if __name__ == "__main__":
op_acc_stable_run(GeluTest(shape=[1, 12288], dtype="float32"))
op_acc_stable_run(GeluTest(shape=[1, 12288], dtype="float16"))
op_acc_stable_run(GeluTest(shape=[1, 12288], dtype="bfloat16"))
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="float32"))
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="float16"))
op_acc_stable_run(GeluTest(shape=[1, 4096, 24576], dtype="bfloat16"))
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修复最后一行的这个问题。

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个最后的一行的问题指什么?我不太确定,指最后一行缺少空行吗?

2 changes: 1 addition & 1 deletion tests/softmax_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,4 @@ def check_diff(self, paddle, pd_ret, th_ret):


if __name__ == "__main__":
op_acc_stable_run(SoftmaxTest)
op_acc_stable_run(SoftmaxTest)
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修复最后一行的这个问题。