@@ -934,3 +934,19 @@ func.func @channel_shuffle(%arg0: !torch.vtensor<[1,8,4,4],f32>) -> !torch.vtens
934934 %0 = torch.aten.channel_shuffle %arg0 , %int4 : !torch.vtensor <[1 ,8 ,4 ,4 ],f32 >, !torch.int -> !torch.vtensor <[1 ,8 ,4 ,4 ],f32 >
935935 return %0 : !torch.vtensor <[1 ,8 ,4 ,4 ],f32 >
936936}
937+
938+ // -----
939+
940+ // CHECK-LABEL: func.func @torch.aten.as_strided$static_shapes
941+ func.func @torch.aten.as_strided$static_shapes (%arg0: !torch.vtensor <[4 ,8 ],f32 >) -> !torch.vtensor <[2 ,3 ],f32 > {
942+ %int0 = torch.constant.int 0
943+ %int2 = torch.constant.int 2
944+ %int3 = torch.constant.int 3
945+ %int1 = torch.constant.int 1
946+ %size = torch.prim.ListConstruct %int2 , %int3 : (!torch.int , !torch.int ) -> !torch.list <int >
947+ %stride = torch.prim.ListConstruct %int1 , %int1 : (!torch.int , !torch.int ) -> !torch.list <int >
948+ // CHECK: torch.aten.view {{.*}} -> !torch.vtensor<[2,1],si64>
949+ // CHECK: torch.aten.view {{.*}} -> !torch.vtensor<[1,3],si64>
950+ %0 = torch.aten.as_strided %arg0 , %size , %stride , %int0 : !torch.vtensor <[4 ,8 ],f32 >, !torch.list <int >, !torch.list <int >, !torch.int -> !torch.vtensor <[2 ,3 ],f32 >
951+ return %0 : !torch.vtensor <[2 ,3 ],f32 >
952+ }
0 commit comments