@@ -449,6 +449,14 @@ class OnnxFrontend(
449
449
emitGlobalPool(_, nodeProto),
450
450
emitters
451
451
)
452
+ case " Gather" =>
453
+ rewriteSimple(remainingProtos, emitGather(_, nodeProto), emitters)
454
+ case " Unsqueeze" =>
455
+ rewriteSimple(
456
+ remainingProtos,
457
+ emitUnsqueeze(_, nodeProto),
458
+ emitters
459
+ )
452
460
case op =>
453
461
throw new CompilerException (
454
462
s " Unsupported op ${op} ( ${nodeProto.name.get}) "
@@ -917,6 +925,69 @@ class OnnxFrontend(
917
925
)
918
926
}
919
927
928
/** Constant-folds an ONNX `Gather` node.
  *
  * Only the degenerate 1-D case is supported: `data` must be a rank-1
  * pending Long constant, `indices` a rank-0 (scalar) Long constant, and
  * `axis` must be 0. The single gathered element is registered as a new
  * scalar pending constant under the node's output name, so downstream
  * shape-computation nodes (e.g. Reshape) can consume it at compile time.
  *
  * @param context     emit context holding the memory manager with pending constants
  * @param gatherProto the ONNX `Gather` node to fold
  * @throws CompilerException if the gather is not 1-D or the index is out of range
  */
private def emitGather(
    context: EmitContext,
    gatherProto: NodeProto
): Unit = {
  // ONNX Gather carries its axis as a required INT attribute.
  val axisAttr = getAttr(gatherProto, "axis").get

  require(axisAttr.`type`.get.isInt)

  val axis = axisAttr.i.get

  // Both inputs must already be known at compile time as Long constants.
  val data =
    context.mm
      .getPendingLongConst(gatherProto.input(0))
      .asInstanceOf[TensorData[Long]]

  val indices = context.mm
    .getPendingLongConst(gatherProto.input(1))
    .asInstanceOf[TensorData[Long]]

  // Restrict to: rank-1 data, scalar index, axis 0.
  if (axis != 0 || data.shape.size != 1 || indices.shape.size != 0)
    throw new CompilerException("Only 1D gather is supported")

  // Bounds-check the (single) index against the data length.
  if (indices.as1D(0) < 0 || indices.as1D(0) >= data.shape(0))
    throw new CompilerException("Gather index is outside of data shape")

  // Result is a scalar (empty shape) holding the selected element.
  context.mm.addPendingConst(
    gatherProto.output(0),
    new TensorData(
      Shape(),
      Seq(data.as1D(indices.as1D(0).toInt)),
      org.tensorflow.framework.types.DataType.DT_INT64
    )
  )
}
962
+
963
/** Constant-folds an ONNX `Unsqueeze` node.
  *
  * Only the scalar case is supported: the input must be a rank-0 pending
  * Long constant and the `axes` attribute must be exactly `[0]`. The value
  * is re-registered as a rank-1 constant of shape (1) under the node's
  * output name, enabling compile-time shape computations.
  *
  * @param context        emit context holding the memory manager with pending constants
  * @param unsqueezeProto the ONNX `Unsqueeze` node to fold
  * @throws CompilerException if the input is not a scalar or axes is not [0]
  */
private def emitUnsqueeze(
    context: EmitContext,
    unsqueezeProto: NodeProto
): Unit = {
  // Pre-opset-13 ONNX Unsqueeze carries its axes as an INTS attribute.
  val axesAttr = getAttr(unsqueezeProto, "axes").get

  require(axesAttr.`type`.get.isInts)

  val axes = axesAttr.ints

  // Input must already be known at compile time as a Long constant.
  val data =
    context.mm
      .getPendingLongConst(unsqueezeProto.input(0))
      .asInstanceOf[TensorData[Long]]

  // Restrict to: scalar input, single axis, axis 0.
  if (axes.size != 1 || axes(0) != 0 || data.shape.size != 0)
    throw new CompilerException("Only scalar unsqueeze is supported")

  // Result is the same value wrapped in a rank-1 tensor of shape (1).
  context.mm.addPendingConst(
    unsqueezeProto.output(0),
    new TensorData(
      Shape(1),
      data.as1D,
      org.tensorflow.framework.types.DataType.DT_INT64
    )
  )
}
990
+
920
991
private def emitConstant (
921
992
context : EmitContext ,
922
993
constantProto : NodeProto
@@ -1058,7 +1129,10 @@ class OnnxFrontend(
1058
1129
context : EmitContext ,
1059
1130
reshapeProto : NodeProto
1060
1131
): Unit = {
1061
- val shape = getTensorData(tensorProtos(reshapeProto.input(1 )))
1132
+ val shapeInputName = reshapeProto.input(1 )
1133
+ val shape = (if (tensorProtos.contains(shapeInputName))
1134
+ getTensorData(tensorProtos(shapeInputName))
1135
+ else context.mm.getPendingLongConst(shapeInputName))
1062
1136
.asInstanceOf [TensorData [Long ]]
1063
1137
.as1D
1064
1138
.map(_.toInt)
@@ -1439,6 +1513,27 @@ class OnnxFrontend(
1439
1513
org.tensorflow.framework.types.DataType .DT_FLOAT
1440
1514
)
1441
1515
)
1516
+ } else if (
1517
+ concatProto.input.forall(name =>
1518
+ context.mm.hasPendingLongConst(name) || tensorProtos.contains(name)
1519
+ )
1520
+ ) {
1521
+ val output = concatProto.input
1522
+ .map(name =>
1523
+ (if (context.mm.hasPendingLongConst(name))
1524
+ context.mm.getPendingLongConst(name)
1525
+ else getTensorData(tensorProtos(name))).as1D
1526
+ )
1527
+ .flatten
1528
+
1529
+ context.mm.addPendingConst(
1530
+ concatProto.output(0 ),
1531
+ new TensorData (
1532
+ Shape (output.size),
1533
+ output,
1534
+ org.tensorflow.framework.types.DataType .DT_INT64
1535
+ )
1536
+ )
1442
1537
} else {
1443
1538
1444
1539
if (axis != 1 )
0 commit comments