@@ -132,9 +132,12 @@ given NDArrayOps[Tensor] with {
132
132
133
133
134
134
extension[DType <: Supported : ClassTag : IsSupported , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape ] (arr : Tensor [DType , (Tt ,Td ,S )]) def slice [Tt2 <: TensorTypeDenotation , AxesStart <: Indices , AxesEnd <: Indices ](using tt : ValueOf [Tt2 ], td : TensorShapeDenotationOf [Td ], s2 : ShapeOf [SlicedShape [AxesStart ,AxesEnd ]], i : IndicesOf [AxesStart ], i2 : IndicesOf [AxesEnd ]): Tensor [DType , (Tt2 ,Td ,SlicedShape [AxesStart ,AxesEnd ])] = onnx.SliceV13 (" slice" , arr, indicesOf[AxesStart ], indicesOf[AxesEnd ]) // wrong denotations
135
+ /*
135
136
extension[DType <: NumericSupported : ClassTag : Numeric: IsNumericSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: Tensor[DType, (Tt,Td,S)]) def gather[Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, AxisIndex <: Index ::: INil, AxisIndices <: Indices](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s2: ShapeOf[GatheredShape[S, AxisIndex, AxisIndices]], i: IndicesOf[AxisIndex], i2: IndicesOf[AxisIndices]): Tensor[DType, (Tt2,Td2,GatheredShape[S, AxisIndex, AxisIndices])] = onnx.GatherV13("gather", indicesOf[AxisIndex], arr, indicesOf[AxisIndices])
136
137
extension[DType <: NumericSupported : ClassTag : Numeric: IsNumericSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: Tensor[DType, (Tt,Td,S)]) def flatten[Tt2 <: TensorTypeDenotation, AxisIndex <: Index ::: INil](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td], s2: ShapeOf[FlattenedShape[S, AxisIndex]], i: IndicesOf[AxisIndex]): Tensor[DType, (Tt2,Td,FlattenedShape[S, AxisIndex])] = onnx.FlattenV13("flatten", indicesOf[AxisIndex], arr)
138
+ */
137
139
// Note: currently fixed mode, constant value
140
+
138
141
extension[DType <: NumericSupported : ClassTag : Numeric : IsNumericSupported , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape ] (arr : Tensor [DType , (Tt ,Td ,S )]) def pad [Tt2 <: TensorTypeDenotation , AxesBefore <: Shape , AxesAfter <: Shape ](constantValue : DType )(using tt : ValueOf [Tt2 ], td : TensorShapeDenotationOf [Td ], s2 : ShapeOf [PaddedShape [S ,AxesBefore ,AxesAfter ]], i : ShapeOf [AxesBefore ], i2 : ShapeOf [AxesAfter ]): Tensor [DType , (Tt2 ,Td ,PaddedShape [S ,AxesBefore ,AxesAfter ])] = onnx.PadV13 (" pad" , mode = " constant" , arr, shapeOf[AxesBefore ], shapeOf[AxesAfter ], Some (Tensor (Array (constantValue), tt.value, td.value, SNil )))
139
142
extension[DType <: Supported : ClassTag : IsSupported , Tt <: TensorTypeDenotation , Td <: TensorShapeDenotation , S <: Shape ] (arr : Tensor [DType , (Tt ,Td ,S )]) def tile [Tt2 <: TensorTypeDenotation , AxisRepeats <: Indices ](using tt : ValueOf [Tt2 ], td : TensorShapeDenotationOf [Td ], s2 : ShapeOf [TiledShape [S , AxisRepeats ]], i : IndicesOf [AxisRepeats ]): Tensor [DType , (Tt2 ,Td ,TiledShape [S , AxisRepeats ])] = onnx.TileV13 (" tile" , arr, indicesOf[AxisRepeats ])
140
143
0 commit comments