Skip to content

Commit 88adf38

Browse files
torch-mlir change for dense resource implementation (#2513)
Co-authored-by: Avinash Sharma <[email protected]>
1 parent 1b9fb1b commit 88adf38

File tree

1 file changed

+14
-9
lines changed

1 file changed

+14
-9
lines changed

lib/Conversion/TorchToArith/TorchToArith.cpp

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -171,14 +171,18 @@ class ConvertTorchTensorLiteralOp
171171
ConversionPatternRewriter &rewriter) const override {
172172
MLIRContext *context = op->getContext();
173173
if (auto elements = op.getValueAttr().dyn_cast<DenseIntElementsAttr>()) {
174-
Type elemTy = op.getValueAttr().getElementType();
175-
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
176-
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
177-
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
178-
op, elements.mapValues(builtinTensorElemTy, [&](const APInt &v) {
179-
return APInt(bitWidth, v.getSExtValue());
180-
}));
181-
return success();
174+
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
175+
Type elemTy = op.getValueAttr().getElementType();
176+
unsigned bitWidth = elemTy.getIntOrFloatBitWidth();
177+
Type builtinTensorElemTy = IntegerType::get(context, bitWidth);
178+
auto shapedType =
179+
RankedTensorType::get(type.getShape(), builtinTensorElemTy);
180+
auto rawData = elements.getRawData();
181+
DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer(
182+
shapedType, rawData);
183+
rewriter.replaceOpWithNewOp<arith::ConstantOp>(op, newAttr);
184+
return success();
185+
}
182186
}
183187
if (auto elements = op.getValueAttr().dyn_cast<DenseResourceElementsAttr>()) {
184188
if (auto type = elements.getType().dyn_cast<RankedTensorType>()) {
@@ -190,7 +194,8 @@ class ConvertTorchTensorLiteralOp
190194
AsmResourceBlob *blob = elements.getRawHandle().getBlob();
191195
assert(blob && "Expecting dense resource with a valid blob");
192196
rewriter.replaceOpWithNewOp<arith::ConstantOp>(
193-
op, DenseElementsAttr::get(shapedType, blob->getData()));
197+
op, DenseResourceElementsAttr::get(shapedType,
198+
elements.getRawHandle()));
194199
return success();
195200
}
196201
}

0 commit comments

Comments
 (0)