diff --git a/operators/cuda/add_mul.h b/operators/cuda/add_mul.h index b51282f94..1aea73e1f 100644 --- a/operators/cuda/add_mul.h +++ b/operators/cuda/add_mul.h @@ -29,6 +29,14 @@ inline void _FillOutputShape3Op(std::vector& dimsA, } } +/** +* AddOrMulSharedInput(A, B, C) = A + B, A + C if addition is true +* AddOrMulSharedInput(A, B, C) = A * B, A * C if addition is false +* +* The operator supports broadcasting on the first dimensions. +* A[1, J] + B[I, J] is supported, +* A[1, J, 1] + B[I, J, K] is not supported. +*/ template struct AddOrMulSharedInput { template @@ -61,6 +69,14 @@ struct AddOrMulSharedInput { } }; +/** +* AddOrMulTwice(A, B, C) = A + B + C if addition is true +* AddOrMulTwice(A, B, C) = A * B * C if addition is false +* +* The operator supports broadcasting on the first dimensions. +* A[1, J] + B[I, J] is supported, +* A[1, J, 1] + B[I, J, K] is not supported. +*/ template struct AddOrMulTwice { template @@ -97,6 +113,17 @@ struct AddOrMulTwice { } }; +/** +* AddAndMul(A, B, C) = (A + B) * C if addition_first is true +* AddAndMul(A, B, C) = A * B + C if addition_first is false +* +* The operator supports broadcasting on the first dimensions. +* A[1, J] + B[I, J] is supported, +* A[1, J, 1] + B[I, J, K] is not supported. +* +* If switchMiddleAxis is true, the output is transposed: +* AddAndMul(A, B, C, switchMiddleAxis=1) = Transpose((A + B) * C, perm=[0, 2, 1, 3]) +*/ template struct AddAndMul { template @@ -154,6 +181,17 @@ struct AddAndMul { bool switchMiddelAxis_; }; +/** +* SubAndMul(A, B, C) = (A - B) * C if subtract_first is true +* SubAndMul(A, B, C) = A * B - C if subtract_first is false +* +* The operator supports broadcasting on the first dimensions. +* A[1, J] - B[I, J] is supported, +* A[1, J, 1] - B[I, J, K] is not supported. +* +* If negative is true, the operands of the subtraction are swapped: +* SubAndMul(A, B, C, negative=1) = (B - A) * C +*/ template struct SubAndMul { template