矩阵乘法/GEMM 前端,封装 InfiniCore/python/infinicore/ops/matmul.py 中的 pybind11 绑定。
def matmul(input: Tensor, other: Tensor, *, out: Optional[Tensor] = None) -> Tensorinput:左乘矩阵,形状需满足 GEMM 要求,可包含批维。other:右乘矩阵,与input的维度兼容。out:可选输出张量;若提供需与结果形状、dtype、device完全一致。
默认返回新张量;当提供 out 时调用 _infinicore.matmul_ 原地写入。底层会复用 InfiniOP GEMM 描述符完成计算。
- 支持常见数据类型(如
float16、float32、bfloat16)。 - 支持批量维度;所有批维需两输入对齐。
- 当
out与输入不处于同一设备或数据类型不匹配时,底层会抛出异常。
import infinicore as ic
device = ic.device("cuda:0")
a = ic.ones((4, 8), dtype=ic.float16, device=device)
b = ic.ones((8, 16), dtype=ic.float16, device=device)
c = ic.matmul(a, b) # 创建新张量
ic.matmul(a, b, out=c) # 原位复用输出缓冲