[WIP] fix some issues on paraformer #1107

Open · wants to merge 2 commits into main

@@ -1438,6 +1438,12 @@ class ShapePropagator : public PropertyPropBase {
"aten::to.prim_Device(Tensor(a) self, Device? device, int? dtype=None, bool non_blocking=False, bool copy=False) -> Tensor(a|b)"},
[](Node* node) -> type_vec_t {
if (auto type = node->input(0)->type()->cast<TensorType>()) {
auto dtype = type->scalarType();
at::optional<IValue> maybe_dtype_option = node->get(attr::dtype);
if (maybe_dtype_option && maybe_dtype_option->isInt()) {
dtype = maybe_dtype_option->toScalarType();
}

auto device = getDeviceFromValue(node->namedInput(attr::device));
if (type->dim()) {
auto scalarType =
@@ -1446,7 +1452,7 @@ class ShapePropagator : public PropertyPropBase {
scalarType = type->scalarType();
}
return {TensorType::create(
scalarType,
dtype,
device,
type->dim(),
/*requires_grad=*/c10::nullopt)
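
For context on the hunk above: the shape propagator previously reused the input tensor's scalar type for aten::to.prim_Device even when an explicit dtype argument was passed, so the propagated output type could come out wrong (e.g. Bool instead of Half). The added lines read the dtype operand and fall back to the input's scalar type only when no dtype is given. Below is a minimal sketch of that fallback logic in isolation, using simplified stand-in types rather than the real c10/JIT APIs:

#include <optional>

// Stand-in for c10::ScalarType; the real code works with optional<ScalarType>.
enum class ScalarType { Bool, Half, Float };

// Hypothetical helper mirroring the patched logic: prefer an explicit dtype
// argument when the node carries one, otherwise keep the input's scalar type.
ScalarType resolveToDtype(ScalarType inputDtype,
                          std::optional<ScalarType> requestedDtype) {
  if (requestedDtype) {
    return *requestedDtype;  // e.g. x.to(device, dtype) -> requested dtype
  }
  return inputDtype;         // e.g. x.to(device) keeps the input dtype
}

In the graph test added below, %p1 is Bool and the dtype constant 5 (Half) is passed, so the propagated type should be Half(device=cuda:1).
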
12 changes: 12 additions & 0 deletions pytorch_blade/tests/torchscript/since_1_10.graph
@@ -120,3 +120,15 @@ graph(%p1 : Float(*, *, *, device=cpu)):
// CHECK: Float(*, *, *, device=cuda) = aten::to
%cuda_zeros : Tensor = aten::to(%new_zeros, %cuda, %none, %false, %false)
return (%cuda_zeros)

// aten::to.prim_Device with dtype
// CHECK-LABEL: graph
graph(%p1 : Bool(device=cuda:0)):
%1 : Device = prim::Constant[value="cuda:1"]()
%2 : int = prim::Constant[value=5]()
%3 : bool = prim::Constant[value=0]()
// CHECK: Half(device=cuda:1) = aten::to(%p1, %1, %2, %3, %3)
%5 : Tensor = aten::to(%p1, %1, %2, %3, %3)
return (%5)


18 changes: 18 additions & 0 deletions tao_compiler/mlir/disc/transforms/mhlo_decomp_rewriters.cc
@@ -112,6 +112,7 @@ LogicalResult PadOpConvert::matchAndRewrite(mhlo::PadOp op,
} // namespace

namespace {

struct SliceOpConvert : public OpRewritePattern<mhlo::SliceOp> {
explicit SliceOpConvert(MLIRContext* context) : OpRewritePattern(context) {}
LogicalResult matchAndRewrite(mhlo::SliceOp op,
@@ -129,6 +130,22 @@ LogicalResult SliceOpConvert::matchAndRewrite(mhlo::SliceOp op,
auto operand = op.getOperand();
rewriter.replaceOpWithNewOp<mhlo::RealDynamicSliceOp>(
op, op.getType(), operand, startIndices, limitIndices, strides);

return success();
}

struct ArithConstOpConvert : public OpRewritePattern<arith::ConstantOp> {
explicit ArithConstOpConvert(MLIRContext* context)
: OpRewritePattern(context) {}
LogicalResult matchAndRewrite(arith::ConstantOp op,
PatternRewriter& rewriter) const override;
};

LogicalResult ArithConstOpConvert::matchAndRewrite(
arith::ConstantOp op, PatternRewriter& rewriter) const {
auto resultType = op.getType().dyn_cast<RankedTensorType>();
if (!resultType || resultType.getRank() < 1) return failure();
rewriter.replaceOpWithNewOp<mhlo::ConstantOp>(op, op.getValue());
return success();
}
} // namespace
@@ -139,6 +156,7 @@ struct MhloDecompositionRewriterPass
func::FuncOp func = getOperation();
MLIRContext* ctx = func.getContext();
RewritePatternSet patterns(ctx);
patterns.insert<ArithConstOpConvert>(ctx);
patterns.insert<BatchNormInferenceOpConvert>(ctx);
patterns.insert<PadOpConvert>(ctx);
patterns.insert<SliceOpConvert>(ctx);
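
The new ArithConstOpConvert pattern above rewrites ranked-tensor arith.constant ops into mhlo.constant so downstream mhlo passes see a uniform constant form; rank-0 and non-tensor constants are left untouched. As a rough sketch of how such a pattern set is usually driven inside an MLIR pass, assuming the standard greedy rewrite driver (the function and variable names here are illustrative, not copied from the pass):

#include <utility>

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

// Illustrative sketch of the usual MLIR pattern-driver idiom; the actual pass
// in this PR also registers BatchNormInferenceOpConvert, PadOpConvert, and
// SliceOpConvert alongside ArithConstOpConvert.
void runDecompPatterns(mlir::func::FuncOp func) {
  mlir::MLIRContext* ctx = func.getContext();
  mlir::RewritePatternSet patterns(ctx);
  patterns.insert<ArithConstOpConvert>(ctx);  // defined in the diff above
  if (mlir::failed(
          mlir::applyPatternsAndFoldGreedily(func, std::move(patterns)))) {
    func.emitError("failed to apply mhlo decomposition patterns");
  }
}
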