Skip to content

Commit c9fd0c5

Browse files
committedNov 2, 2021
Comment out gather and flatten, changes to Scala 3.0.1+ broke the corresponding match types; Set # iters in pytorch benchmark to match that in Scala
1 parent 99b9af5 commit c9fd0c5

File tree

4 files changed

+8
-4
lines changed

4 files changed

+8
-4
lines changed
 

‎ONNXScala/src/main/scala/ndscala/ONNXScalaOps.scala

+3
Original file line numberDiff line numberDiff line change
@@ -132,9 +132,12 @@ given NDArrayOps[Tensor] with {
132132

133133

134134
extension[DType <: Supported : ClassTag : IsSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: Tensor[DType, (Tt,Td,S)]) def slice[Tt2 <: TensorTypeDenotation, AxesStart <: Indices, AxesEnd <: Indices](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td], s2: ShapeOf[SlicedShape[AxesStart,AxesEnd]], i: IndicesOf[AxesStart], i2: IndicesOf[AxesEnd]): Tensor[DType, (Tt2,Td,SlicedShape[AxesStart,AxesEnd])] = onnx.SliceV13("slice", arr, indicesOf[AxesStart], indicesOf[AxesEnd]) //wrong denotations
135+
/*
135136
extension[DType <: NumericSupported : ClassTag : Numeric: IsNumericSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: Tensor[DType, (Tt,Td,S)]) def gather[Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, AxisIndex <: Index ::: INil, AxisIndices <: Indices](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s2: ShapeOf[GatheredShape[S, AxisIndex, AxisIndices]], i: IndicesOf[AxisIndex], i2: IndicesOf[AxisIndices]): Tensor[DType, (Tt2,Td2,GatheredShape[S, AxisIndex, AxisIndices])] = onnx.GatherV13("gather", indicesOf[AxisIndex], arr, indicesOf[AxisIndices])
136137
extension[DType <: NumericSupported : ClassTag : Numeric: IsNumericSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: Tensor[DType, (Tt,Td,S)]) def flatten[Tt2 <: TensorTypeDenotation, AxisIndex <: Index ::: INil](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td], s2: ShapeOf[FlattenedShape[S, AxisIndex]], i: IndicesOf[AxisIndex]): Tensor[DType, (Tt2,Td,FlattenedShape[S, AxisIndex])] = onnx.FlattenV13("flatten", indicesOf[AxisIndex], arr)
138+
*/
137139
//Note: currently fixed mode, constant value
140+
138141
extension[DType <: NumericSupported : ClassTag : Numeric: IsNumericSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: Tensor[DType, (Tt,Td,S)]) def pad[Tt2 <: TensorTypeDenotation, AxesBefore <: Shape, AxesAfter <: Shape](constantValue: DType)(using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td], s2: ShapeOf[PaddedShape[S,AxesBefore,AxesAfter]], i: ShapeOf[AxesBefore], i2: ShapeOf[AxesAfter]): Tensor[DType, (Tt2,Td,PaddedShape[S,AxesBefore,AxesAfter])] = onnx.PadV13("pad", mode = "constant", arr, shapeOf[AxesBefore], shapeOf[AxesAfter], Some(Tensor(Array(constantValue), tt.value, td.value, SNil)))
139142
extension[DType <: Supported : ClassTag : IsSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: Tensor[DType, (Tt,Td,S)]) def tile[Tt2 <: TensorTypeDenotation, AxisRepeats <: Indices](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td], s2: ShapeOf[TiledShape[S, AxisRepeats]], i: IndicesOf[AxisRepeats]): Tensor[DType, (Tt2,Td,TiledShape[S, AxisRepeats])] = onnx.TileV13("tile", arr, indicesOf[AxisRepeats])
140143

‎ONNXScala/src/test/scala/ndscala/ONNXScalaNDArraySpec.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -257,7 +257,7 @@ type TD = "TensorShapeDenotation" ##: TSNil
257257

258258
doAssert((arr.pad[TT, 1 #: SNil, 1 #: SNil](42)) ==== expectedResult)
259259
}
260-
260+
/*
261261
"Tensor" should "gather" in {
262262
val arr = Tensor(Array(1, 2, 3, 4),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 4 #: SNil )
263263
val expectedResult = Tensor(Array(2, 4),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 2 #: SNil)
@@ -272,7 +272,7 @@ type TD = "TensorShapeDenotation" ##: TSNil
272272
273273
doAssert((res) ==== expectedResult)
274274
}
275-
275+
*/
276276
"Tensor" should "tile" in {
277277
val arr = Tensor(Array(1, 2, 3, 4),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 4 #: SNil )
278278
val expectedResult = Tensor(Array(1, 2, 3, 4, 1, 2, 3, 4),"TensorTypeDenotation", "TensorShapeDenotation" ##: TSNil, 8 #: SNil)

‎core/src/main/scala/ndscala/NDArrayOps.scala

+2-1
Original file line numberDiff line numberDiff line change
@@ -60,10 +60,11 @@ trait NDArrayOps[SomeNDArray[_ <: AllSupported, _ <: Axes]] {
6060
// extension[DType <: Supported : ClassTag : IsSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: SomeNDArray[DType, (Tt,Td,S)]) def slice[Tt1 <: TensorTypeDenotation, Td1 <: TensorShapeDenotation, S1 <: Shape](start: Int, end: Int)(using tt: ValueOf[Tt], td: TensorShapeDenotationOf[Td], s: ShapeOf[S],tt1: ValueOf[Tt1], td1: TensorShapeDenotationOf[Td1], s1: ShapeOf[S1]): SomeNDArray[DType, (Tt1,Td1,S1)]
6161

6262
extension[DType <: Supported : ClassTag : IsSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: SomeNDArray[DType, (Tt,Td,S)]) def slice[Tt2 <: TensorTypeDenotation, AxesStart <: Indices, AxesEnd <: Indices](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td], s2: ShapeOf[SlicedShape[AxesStart,AxesEnd]], i: IndicesOf[AxesStart], i2: IndicesOf[AxesEnd]): SomeNDArray[DType, (Tt2,Td,SlicedShape[AxesStart,AxesEnd])]
63-
63+
/*
6464
extension[DType <: NumericSupported : ClassTag : Numeric: IsNumericSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: SomeNDArray[DType, (Tt,Td,S)]) def gather[Tt2 <: TensorTypeDenotation, Td2 <: TensorShapeDenotation, AxisIndex <: Index ::: INil, AxisIndices <: Indices](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td2], s2: ShapeOf[GatheredShape[S, AxisIndex, AxisIndices]], i: IndicesOf[AxisIndex], i2: IndicesOf[AxisIndices]): SomeNDArray[DType, (Tt2,Td2,GatheredShape[S, AxisIndex, AxisIndices])]
6565
6666
extension[DType <: NumericSupported : ClassTag : Numeric: IsNumericSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: SomeNDArray[DType, (Tt,Td,S)]) def flatten[Tt2 <: TensorTypeDenotation, AxisIndex <: Index ::: INil](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td], s2: ShapeOf[FlattenedShape[S, AxisIndex]], i: IndicesOf[AxisIndex]): SomeNDArray[DType, (Tt2,Td,FlattenedShape[S, AxisIndex])]
67+
*/
6768
extension[DType <: Supported : ClassTag : IsSupported, Tt <: TensorTypeDenotation, Td <: TensorShapeDenotation, S <: Shape] (arr: SomeNDArray[DType, (Tt,Td,S)]) def tile[Tt2 <: TensorTypeDenotation, AxisRepeats <: Indices](using tt: ValueOf[Tt2], td: TensorShapeDenotationOf[Td], s2: ShapeOf[TiledShape[S, AxisRepeats]], i: IndicesOf[AxisRepeats]): SomeNDArray[DType, (Tt2,Td,TiledShape[S, AxisRepeats])]
6869

6970

‎elevenLinesBenchmarkPyTorch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88

99
print(X.shape)
1010
print(X.dtype)
11-
iter = 1
11+
iter = 5
1212

1313
def elevenlines():
1414
syn0 = 2*torch.randn((10000,10000)).float() - 1

0 commit comments

Comments
 (0)
Please sign in to comment.