From d3b6f5da1a3f2e774c67e203cfd374d789ef0ed3 Mon Sep 17 00:00:00 2001 From: xadupre Date: Tue, 11 Jun 2024 09:43:30 +0000 Subject: [PATCH] Add missing documentation for fused kernels --- operators/cuda/negxplus1.h | 3 +++ operators/cuda/scatter_nd_of_shape.h | 9 +++++++++ operators/cuda/transpose_cast.h | 3 +++ 3 files changed, 15 insertions(+) diff --git a/operators/cuda/negxplus1.h b/operators/cuda/negxplus1.h index 5460c37a2..5ff53d357 100644 --- a/operators/cuda/negxplus1.h +++ b/operators/cuda/negxplus1.h @@ -8,6 +8,9 @@ namespace contrib { +/** +* NegXPlus1(X) = 1 - X +*/ template struct NegXPlus1 { template diff --git a/operators/cuda/scatter_nd_of_shape.h b/operators/cuda/scatter_nd_of_shape.h index 239c2b5e6..610454d42 100644 --- a/operators/cuda/scatter_nd_of_shape.h +++ b/operators/cuda/scatter_nd_of_shape.h @@ -8,6 +8,9 @@ namespace contrib { +/** +* ScatterNDOfShape(shape, indices, updates) = ScatterND(ConstantOfShape(shape, value=0), indices, updates) +*/ template struct ScatterNDOfShape { OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { @@ -71,6 +74,12 @@ struct ScatterNDOfShape { }; +/** +* MaskedScatterNDOfShape(shape, indices, updates) = ScatterND(ConstantOfShape(shape, value=0), +* indices[indices != maskedValue], +* updates[indices != maskedValue]) +* +*/ template struct MaskedScatterNDOfShape { OrtStatusPtr OnModelAttach(const OrtApi& api, const OrtKernelInfo& info) { diff --git a/operators/cuda/transpose_cast.h b/operators/cuda/transpose_cast.h index 6ffae51c2..92e1f8a23 100644 --- a/operators/cuda/transpose_cast.h +++ b/operators/cuda/transpose_cast.h @@ -8,6 +8,9 @@ namespace contrib { +/** +* Transpose2DCast(X, to=to) = Cast(Transpose(X, perm=[1, 0]), to=to) +*/ template struct Transpose2DCast { template