Skip to content

Commit

Permalink
address some rnn comments from @zjgarvey
Browse files Browse the repository at this point in the history
  • Loading branch information
renxida committed Jul 2, 2024
1 parent a7622ee commit 8972965
Showing 1 changed file with 35 additions and 17 deletions.
52 changes: 35 additions & 17 deletions lib/Conversion/TorchOnnxToTorch/OnnxRnnExpander.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
Location loc = binder.getLoc();
mlir::ImplicitLocOpBuilder b(loc, rewriter);

std::string direction;
auto intType = b.getType<IntType>();
Value cstNone = b.create<ConstantNoneOp>();
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));

ValueTensorType yTy, Y_hType;
if (binder.tensorResultTypeAtIndex(yTy, 0) ||
Expand All @@ -167,9 +169,20 @@ LogicalResult OnnxRnnExpander(OpBinder binder,

auto xTy = cast<ValueTensorType>(X.getType());
auto wTy = cast<ValueTensorType>(W.getType());
auto WShape = wTy.getSizes();
assert(WShape.size() == 3);
// use unpacking to get the values of the 3 dimensions
int64_t num_directions, hidden_size, input_size;
num_directions = WShape[0];
hidden_size = WShape[1];
input_size = WShape[2];
SmallVector<int64_t, 2> BShape = {num_directions, 2 * hidden_size};
auto BType = b.getType<ValueTensorType>(BShape, wTy.getDtype());

Value B;
if (binder.tensorOperandAtIndex(B, 3)) {
B = b.create<AtenZerosOp>(W.getType(), W);
B = b.create<Torch::AtenZerosOp>(BType, BShape, wTy.getDtype(), cstNone,
cstNone, cstNone);
}

llvm::SmallVector<std::string> activationsList;
Expand All @@ -193,26 +206,41 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
}
}

std::string direction;
if (!binder.customOpNameStringAttr(direction, "direction", "forward") &&
direction != "forward")
return rewriter.notifyMatchFailure(binder.op,
"Unsupported direction attribute value. "
"Only 'forward' is supported but '" +
direction + "' is provided.");
int64_t num_directions = (direction == "bidirectional") ? 2 : 1;
int64_t num_directions_attr = (direction == "bidirectional") ? 2 : 1;
if (num_directions == num_directions_attr) {
return rewriter.notifyMatchFailure(
binder.op, "num_directions from shape of W (" +
std::to_string(num_directions) +
") does not match the direction attribute value (" +
std::to_string(num_directions_attr) + ")");
}

auto XShape = xTy.getSizes();
int64_t batch_size = XShape[1];
int64_t input_size = XShape[2];
if (input_size != XShape[2]) {
return rewriter.notifyMatchFailure(
binder.op, "input_size inferred from shape of W (" +
std::to_string(input_size) +
") does not match the third dimension of X (" +
std::to_string(XShape[2]) + ")");
}
if (num_directions != wTy.getSizes()[0])
return rewriter.notifyMatchFailure(
binder.op, "num_directions (" + std::to_string(num_directions) +
") does not match the first dimension of wTy (" +
std::to_string(wTy.getSizes()[0]) + ")");
if (num_directions != 1)
if (num_directions != 1) {
return rewriter.notifyMatchFailure(
binder.op, "num_directions (" + std::to_string(num_directions) +
") is not equal to 1");
binder.op, "Unsupported num_directions. Only 1 is supported but " +
std::to_string(num_directions) + " is provided.");
}
if (hidden_size != wTy.getSizes()[1])
return rewriter.notifyMatchFailure(
binder.op, "hidden_size (" + std::to_string(hidden_size) +
Expand Down Expand Up @@ -247,16 +275,12 @@ LogicalResult OnnxRnnExpander(OpBinder binder,
llvm::SmallVector<int64_t>{num_directions, batch_size, hidden_size},
xTy.getDtype());

auto intType = b.getType<IntType>();

Value cstNumDirections =
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(num_directions));
Value cstBatchSize =
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(batch_size));
Value cstHiddenSize =
b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(hidden_size));
Value cstNone = b.create<ConstantNoneOp>();
Value cstZero = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(0));

Value hShape = b.create<PrimListConstructOp>(
b.getType<ListType>(intType),
Expand All @@ -272,12 +296,6 @@ LogicalResult OnnxRnnExpander(OpBinder binder,

Value initial_h_forward = getDirection(0, initial_h);

if (num_directions != 1) {
return rewriter.notifyMatchFailure(
binder.op, "Unsupported num_directions. Only 1 is supported but " +
std::to_string(num_directions) + " is provided.");
}

Value cstOne = b.create<ConstantIntOp>(intType, b.getI64IntegerAttr(1));

RnnWeights weights;
Expand Down

0 comments on commit 8972965

Please sign in to comment.