Skip to content

Commit

Permalink
rename op
Browse files Browse the repository at this point in the history
  • Loading branch information
BolinSNLHM committed Mar 23, 2024
1 parent c53af56 commit 2bd45ef
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# pylint: disable=redefined-builtin
from .matmul import batch_matmul, matmul, matmul_x86, matmul_cublas
from .matmul import batch_matmul, matmul, batch_matmul_x86, matmul_cublas
from .conv1d import conv1d, conv1d_gemm
from .conv1d_transpose import conv1d_transpose
from .conv2d import conv2d, conv2d_channel_last, conv2d_winograd, conv2d_gemm, conv2d_gemm_fp16
Expand Down
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/matmul/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,4 @@


from .matmul_f32_x86 import Matmulx86Op, MatmulF32Taskx86
from .matmul_f32_x86 import matmul_x86
from .matmul_f32_x86 import batch_matmul_x86
2 changes: 1 addition & 1 deletion python/hidet/graph/ops/matmul/matmul_f32_x86.py
Original file line number Diff line number Diff line change
Expand Up @@ -858,5 +858,5 @@ def __init__(self, a: Tensor, b: Tensor):
super().__init__(inputs=[a, b], attributes={}, task=task)


def matmul_x86(a: Tensor, b: Tensor) -> Tensor:
def batch_matmul_x86(a: Tensor, b: Tensor) -> Tensor:
return Matmulx86Op(a, b).outputs[0]
2 changes: 1 addition & 1 deletion tests/operators/test_matmul.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_matmul_x86(a_shape, b_shape):
a_shape,
b_shape,
lambda x, y: np.matmul(x, y),
lambda x, y: ops.matmul_x86(x, y) - ops.matmul_x86(x, y) + ops.matmul_x86(x, y),
lambda x, y: ops.batch_matmul_x86(x, y) - ops.batch_matmul_x86(x, y) + ops.batch_matmul_x86(x, y),
dtype="float32",
atol=1e-4,
rtol=1e-4,
Expand Down

0 comments on commit 2bd45ef

Please sign in to comment.