Skip to content

Commit

Permalink
g++ build fix (#2778)
Browse files Browse the repository at this point in the history
Introduced in 704cfda of @wu-s-john 

g++ compiler error: 

Pooling.cpp:177:13: error: explicit specialization in non-namespace
scope ‘class

Design looks good, g++ is just freaking out for no good reason.
Un-nesting the template classes fixes the error.

We don't have g++ CI. This hopefully happens infrequently enough that we
can just fix manually. My service to those folks who really like
building with g++... :)
  • Loading branch information
newling authored Jan 20, 2024
1 parent 2f49240 commit 50ac3b1
Showing 1 changed file with 21 additions and 15 deletions.
36 changes: 21 additions & 15 deletions lib/Conversion/TorchToLinalg/Pooling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,8 @@ checkAndGetPoolingParameters(OpTy op, ConversionPatternRewriter &rewriter,
return success();
}

static Value computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
static Value
computeOutputTensor(Operation *op, ConversionPatternRewriter &rewriter,
Value self, int64_t dimensionality, bool ceilMode,
SmallVectorImpl<int64_t> &strideInts,
SmallVectorImpl<int64_t> &paddingInts,
Expand Down Expand Up @@ -167,24 +168,29 @@ static LogicalResult createPoolingOp(
}

namespace {

template <typename T> struct DimensionTraits {};

template <> struct DimensionTraits<AtenMaxPool2dOp> {
static constexpr int64_t Dim = 2;
// unused const variable warning suppression:
static_assert(Dim == Dim);
};

template <> struct DimensionTraits<AtenMaxPool3dOp> {
static constexpr int64_t Dim = 3;
// unused const variable warning suppression:
static_assert(Dim == Dim);
};

template <typename OpTy>
class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
using OpConversionPattern<OpTy>::OpConversionPattern;

private:
template <typename T> struct DimensionTraits;

template <> struct DimensionTraits<AtenMaxPool2dOp> {
static const int64_t Dim = 2;
};

template <> struct DimensionTraits<AtenMaxPool3dOp> {
static const int64_t Dim = 3;
};

static const int64_t Dim = DimensionTraits<OpTy>::Dim;

LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op,
LogicalResult createPoolingMax3D(AtenMaxPool3dOp &op,
typename OpTy::Adaptor adaptor,
ConversionPatternRewriter &rewriter,
SmallVectorImpl<Value> &kernelSizeIntValues,
Expand Down Expand Up @@ -327,9 +333,9 @@ class ConvertAtenMaxPoolOp : public OpConversionPattern<OpTy> {
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, maxPool2d);
return success();
} else {
return createPoolingMax3D(op, adaptor, rewriter,
kernelSizeIntValues, strideInts, paddingInts,
dilationInts, ceilMode);
return createPoolingMax3D(op, adaptor, rewriter, kernelSizeIntValues,
strideInts, paddingInts, dilationInts,
ceilMode);
}
}
};
Expand Down

0 comments on commit 50ac3b1

Please sign in to comment.