@@ -171,14 +171,18 @@ class ConvertTorchTensorLiteralOp
171171 ConversionPatternRewriter &rewriter) const override {
172172 MLIRContext *context = op->getContext ();
173173 if (auto elements = op.getValueAttr ().dyn_cast <DenseIntElementsAttr>()) {
174- Type elemTy = op.getValueAttr ().getElementType ();
175- unsigned bitWidth = elemTy.getIntOrFloatBitWidth ();
176- Type builtinTensorElemTy = IntegerType::get (context, bitWidth);
177- rewriter.replaceOpWithNewOp <arith::ConstantOp>(
178- op, elements.mapValues (builtinTensorElemTy, [&](const APInt &v) {
179- return APInt (bitWidth, v.getSExtValue ());
180- }));
181- return success ();
174+ if (auto type = elements.getType ().dyn_cast <RankedTensorType>()) {
175+ Type elemTy = op.getValueAttr ().getElementType ();
176+ unsigned bitWidth = elemTy.getIntOrFloatBitWidth ();
177+ Type builtinTensorElemTy = IntegerType::get (context, bitWidth);
178+ auto shapedType =
179+ RankedTensorType::get (type.getShape (), builtinTensorElemTy);
180+ auto rawData = elements.getRawData ();
181+ DenseElementsAttr newAttr = DenseElementsAttr::getFromRawBuffer (
182+ shapedType, rawData);
183+ rewriter.replaceOpWithNewOp <arith::ConstantOp>(op, newAttr);
184+ return success ();
185+ }
182186 }
183187 if (auto elements = op.getValueAttr ().dyn_cast <DenseResourceElementsAttr>()) {
184188 if (auto type = elements.getType ().dyn_cast <RankedTensorType>()) {
@@ -190,7 +194,8 @@ class ConvertTorchTensorLiteralOp
190194 AsmResourceBlob *blob = elements.getRawHandle ().getBlob ();
191195 assert (blob && " Expecting dense resource with a valid blob" );
192196 rewriter.replaceOpWithNewOp <arith::ConstantOp>(
193- op, DenseElementsAttr::get (shapedType, blob->getData ()));
197+ op, DenseResourceElementsAttr::get (shapedType,
198+ elements.getRawHandle ()));
194199 return success ();
195200 }
196201 }
0 commit comments