diff --git a/operators/cuda/rotary.h b/operators/cuda/rotary.h index a365c2470..b5e27b877 100644 --- a/operators/cuda/rotary.h +++ b/operators/cuda/rotary.h @@ -8,6 +8,16 @@ namespace contrib { +/** +* If side == LEFT, Y = Rotary(X) is equivalent to: +* +* N = X.shape[-1] +* Y = X.copy() +* Y[...,:N/2] = X[...,N/2:] +* Y[...,N/2:] = -X[...,:N/2] +* +* And the opposite if side == RIGHT. +*/ template struct Rotary { template