Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
178 changes: 40 additions & 138 deletions lib/Conversion/TorchToTosa/TorchToTosa.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9334,178 +9334,80 @@ class ConvertUpsampleNearest2dForward
LogicalResult
matchAndRewriteImpl(AtenOpT op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
// aten.upsample_nearest2d lowering process:
// 1. Reshape input: (N, C, H, W) -> (N, C, H x W)
// 2. Calculate PyTorch-styled gather op indices based on the following
// formula (based on Torch to Linalg UpsampleNearest2d lowering formula):
// for i in range(N x C):
// for heightIndex in range(scaledHeight):
// for widthIndex in range(scaledWidth):
// indices.append(int(heightIndex // scalesH * selfWidth +
// widthIndex // scalesW))
// 3. Convert PyTorch-styled indices to TensorFlow-styled indices
// 4. Apply TensorFlow-styled convertGatherNdOp to retrieve the output
// 5. Reshape output to desired output shape
Value self;
Value input;
if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dOp>()) {
self = adaptor.getSelf();
input = adaptor.getSelf();
} else if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dVecOp>()) {
self = adaptor.getInput();
input = adaptor.getInput();
} else {
return rewriter.notifyMatchFailure(
op, "Expected either AtenUpsampleNearest2dOp or "
"AtenUpsampleNearest2dVecOp");
}

auto selfType = dyn_cast<TensorType>(self.getType());
if (!selfType)
auto inputTy = dyn_cast<RankedTensorType>(input.getType());
if (!inputTy) {
return rewriter.notifyMatchFailure(op, "Only tensor types are supported");
}
if (inputTy.getRank() != 4) {
return rewriter.notifyMatchFailure(op, "TOSA resize() requires rank 4");
}

auto selfShape = selfType.getShape();
auto selfRank = selfType.getRank();
auto selfElemTy = selfType.getElementType();

auto selfHeight = selfShape[selfRank - 2];
auto selfWidth = selfShape[selfRank - 1];

auto resultType = dyn_cast<TensorType>(
OpConversionPattern<AtenOpT>::getTypeConverter()->convertType(
op.getType()));
auto resultShape = resultType.getShape();
auto resultElemTy = resultType.getElementType();
auto inputShape = inputTy.getShape();

// Get op's parameters
SmallVector<int64_t> outputSize;
SmallVector<double> scaleFactors;
double scalesH;
double scalesW;
int64_t outputHeight;
int64_t outputWidth;

if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dOp>()) {
SmallVector<int64_t> outputSize;
if (!matchPattern(op.getOutputSize(),
m_TorchListOfConstantInts(outputSize)))
m_TorchListOfConstantInts(outputSize))) {
return rewriter.notifyMatchFailure(
op, "Non-constant output size not supported");
}

outputHeight = outputSize[0];
outputWidth = outputSize[1];

if (isa<Torch::NoneType>(op.getScalesH().getType())) {
scalesH =
static_cast<double>(outputHeight) / static_cast<double>(selfHeight);
} else {
if (!matchPattern(op.getScalesH(), m_TorchConstantFloat(&scalesH)))
return rewriter.notifyMatchFailure(
op, "Non-constant height scales not supported");

scalesH = std::ceil(scalesH);
}

if (isa<Torch::NoneType>(op.getScalesW().getType())) {
scalesW =
static_cast<double>(outputWidth) / static_cast<double>(selfWidth);
} else {
if (!matchPattern(op.getScalesW(), m_TorchConstantFloat(&scalesW)))
return rewriter.notifyMatchFailure(
op, "Non-constant width scales not supported");

scalesW = std::ceil(scalesW);
}
} else if constexpr (std::is_same<AtenOpT, AtenUpsampleNearest2dVecOp>()) {
auto isOutputSizeNone =
isa<Torch::NoneType>(op.getOutputSize().getType());
auto isScaleFactorsNone =
isa<Torch::NoneType>(op.getScaleFactors().getType());

if ((isOutputSizeNone && isScaleFactorsNone) ||
(!isOutputSizeNone && !isScaleFactorsNone))
return rewriter.notifyMatchFailure(
op, "Must specify exactly one of output size and scale factors");

if (!isOutputSizeNone) {
if (!isa<Torch::NoneType>(op.getOutputSize().getType())) {
SmallVector<int64_t> outputSize;
if (!matchPattern(op.getOutputSize(),
m_TorchListOfConstantInts(outputSize)))
m_TorchListOfConstantInts(outputSize))) {
return rewriter.notifyMatchFailure(
op, "Non-constant output size not supported");
}

outputHeight = outputSize[0];
outputWidth = outputSize[1];

// Output size values being provided implies that scale values are not
// provided
scalesH =
static_cast<double>(outputHeight) / static_cast<double>(selfHeight);
scalesW =
static_cast<double>(outputWidth) / static_cast<double>(selfWidth);
} else {
if (!matchPattern(op.getScaleFactors(),
m_TorchListOfConstantFloats(scaleFactors)))
if (isa<Torch::NoneType>(op.getScaleFactors().getType())) {
return rewriter.notifyMatchFailure(
op, "Non-constant output size not supported");

scalesH = std::ceil(scaleFactors[0]);
scalesW = std::ceil(scaleFactors[1]);

// Scale values being provided implies that output size values are not
// provided
outputHeight = static_cast<int64_t>(scalesH * selfHeight);
outputWidth = static_cast<int64_t>(scalesW * selfWidth);
}
}

// Reshape input
SmallVector<int64_t> reshapedSelfShape(selfShape.begin(),
selfShape.end() - 2);
reshapedSelfShape.push_back(selfHeight * selfWidth);
op, "Missing output size and scale factors");
}

auto reshapedSelf = tosa::ReshapeOp::create(
rewriter, op->getLoc(),
RankedTensorType::get(reshapedSelfShape, selfElemTy), self,
tosa::getTosaConstShape(rewriter, op->getLoc(), reshapedSelfShape));

// Calculate PyTorch-styled gather indices
SmallVector<int32_t> targetIndicesVec;
int64_t indexRepeat = std::accumulate(
selfShape.begin(), selfShape.end() - 2, 1, std::multiplies<int64_t>());
for (int64_t i = 0; i < indexRepeat; i++) {
for (int64_t heightIndex = 0; heightIndex < outputHeight; heightIndex++) {
for (int64_t widthIndex = 0; widthIndex < outputWidth; widthIndex++) {
targetIndicesVec.push_back(static_cast<int32_t>(
std::floor(heightIndex / scalesH) * selfWidth +
std::floor(widthIndex / scalesW)));
SmallVector<double, 2> scaleFactors;
if (!matchPattern(op.getScaleFactors(),
m_TorchListOfConstantFloats(scaleFactors))) {
return rewriter.notifyMatchFailure(
op, "Non-constant scale_factors not supported");
}

// PyTorch uses floor after the scale multiplication
// https://docs.pytorch.org/docs/stable/generated/torch.nn.UpsamplingNearest2d.html
outputHeight =
static_cast<int64_t>(std::floor(inputShape[2] * scaleFactors[0]));
outputWidth =
static_cast<int64_t>(std::floor(inputShape[3] * scaleFactors[1]));
}
}

SmallVector<int64_t> targetIndicesShape(selfShape.begin(),
selfShape.end() - 2);
targetIndicesShape.push_back(outputHeight * outputWidth);
auto targetIndicesTorch =
tosa::getConstTensor<int32_t>(rewriter, op, targetIndicesVec,
targetIndicesShape)
.value();

// Convert PyTorch-styled indices to TensorFlow-styled indices
auto targetIndicesTF = tosa::convertTorchIndexToTfIndices(
rewriter, op, reshapedSelf.getResult(), targetIndicesTorch,
selfRank - 2);
if (!targetIndicesTF)
return rewriter.notifyMatchFailure(
op, "Convert PyTorch-styled indices and dim "
"to TensorFlow-styled indices failed");
// Apply TensorFlow GatherNdOp with TensorFlow-style indices to retrieve
// target elements
auto gatherOp = tosa::convertGatherNdOp(
rewriter, op, RankedTensorType::get(targetIndicesShape, resultElemTy),
reshapedSelf.getResult(), targetIndicesTF.value());
if (!gatherOp)
return rewriter.notifyMatchFailure(op, "Convert GatherNdOp failed");

auto result = tosa::ReshapeOp::create(
rewriter, op->getLoc(), resultType, gatherOp.value(),
tosa::getTosaConstShape(rewriter, op->getLoc(), resultShape));

rewriter.replaceOp(op, {result.getResult()});
auto resultTy = cast<RankedTensorType>(
this->getTypeConverter()->convertType(op.getType()));
Value resizeOp = convertResizeOp(rewriter, op, this->getTypeConverter(),
input, inputTy, resultTy, outputHeight,
outputWidth, /*alignCorners=*/false,
tosa::ResizeMode::NEAREST_NEIGHBOR);
rewriter.replaceOp(op, {resizeOp});

return success();
}
Expand Down
Loading
Loading