Skip to content

Commit 1c9675c

Browse files
committed
[TOSA] Conv3d legalization
- Extend Torch to TOSA conversion with full aten.conv3d coverage, mapping padding/stride/dilation configs and bias handling into canonical TOSA conv ops. - Refactor TorchToTosa lowering utilities to share more code between conv variants. - Update conversion tests plus PT1 xfail lists to reflect the newly supported 3D convolution paths. Signed-off-by: Cathal Corbett <[email protected]> Change-Id: Iaf1261e121dec1ddb84b814e9058bb7ebadd4de7
1 parent 42c3c29 commit 1c9675c

File tree

6 files changed

+720
-253
lines changed

6 files changed

+720
-253
lines changed

include/torch-mlir/Conversion/TorchToTosa/TosaLegalizeUtils.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ LogicalResult getConvOpsAccType(PatternRewriter &rewriter,
105105
FailureOr<Value> getConvBiasForNoneType(Operation *op,
106106
PatternRewriter &rewriter,
107107
Type inputElemTy, Type outputElemTy,
108-
ArrayRef<int64_t> weightShape);
108+
int64_t numOutputChannels);
109109

110110
// Emit an explicit zero-valued `tosa.pad` around an NHWC tensor so that later
111111
// avg_pool lowering can run with `pad = 0`. `padExtents` is ordered as

0 commit comments

Comments
 (0)