
[stablehlo] fix type for compare scalar op #3702

Closed
penguin-wwy wants to merge 2 commits from the fix_stablehlo branch

Conversation

penguin-wwy (Collaborator)

No description provided.

  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
}
// use lhs's element type as compute type
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
Collaborator

It seems this will cause torch.tensor([1.0]) < torch.tensor([2.0], dtype=torch.double) to be computed in fp32.
I think we need to discuss how to use StableHLO to describe torch's default compute type.
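
For reference, torch's own promotion for this case can be checked with torch.result_type (an illustrative sketch, assuming a recent PyTorch):

import torch

lhs = torch.tensor([1.0])                      # float32
rhs = torch.tensor([2.0], dtype=torch.double)  # float64
# torch promotes a float32/float64 tensor pair to the wider float type,
# so the expected compute type is float64, not float32.
print(torch.result_type(lhs, rhs))  # torch.float64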

Collaborator

See another PR: #3673.

penguin-wwy (Collaborator, author)

In the previous code, the promotion was decided based on the bit width:

if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
  // int lhs vs float rhs: promote lhs to the float type
  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
} else if (isa<mlir::FloatType>(lhsElemTy) &&
           isa<mlir::IntegerType>(rhsElemTy)) {
  // float lhs vs int rhs: promote rhs to the float type
  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
} else {
  // same category on both sides: promote the narrower side to the wider width
  if (lhsElemTy.getIntOrFloatBitWidth() >
      rhsElemTy.getIntOrFloatBitWidth()) {
    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
  } else {
    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
  }
}

Then #3518 started promoting rhsType to lhsType up front. I'm not quite sure about its purpose; this PR just ensures that when lhsType is an integer type and rhsType is a float type, the comparison is done in floating point.
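
To illustrate the int-vs-float case this PR targets, torch's promotion can be checked directly (a sketch, assuming a recent PyTorch):

import torch

# An int tensor compared against a float tensor is computed in float,
# matching the int-lhs / float-rhs branch of the code above.
print(torch.result_type(torch.tensor([1]), torch.tensor([0.5])))  # torch.float32
print((torch.tensor([1]) < torch.tensor([0.5])).dtype)            # torch.bool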

Collaborator

I know. In torch, tensors are first class and scalars are second class for type promotion. When computing tensor < scalar, the tensor's dtype should usually be the compute type, except for cases like torch.tensor([1]) < 1.1, where an int tensor meets a float scalar.
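
A sketch of that rule in terms of torch.result_type (assuming a recent PyTorch):

import torch

# A scalar of the same category does not widen the tensor's dtype:
print(torch.result_type(torch.tensor([1.0], dtype=torch.float16), 2.0))  # torch.float16
print(torch.result_type(torch.tensor([1], dtype=torch.int32), 5))        # torch.int32
# But a float scalar against an int tensor moves the compute type to the
# default float dtype, as in torch.tensor([1]) < 1.1:
print(torch.result_type(torch.tensor([1]), 1.1))                         # torch.float32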

penguin-wwy (Collaborator, author)

According to PyTorch's implementation, when comparing a tensor with a scalar, the scalar is first converted to a tensor via the scalar_to_tensor function before the comparison, so compare tensor-scalar should behave the same as compare tensor-tensor:
https://github.com/pytorch/pytorch/blob/02169364e15932d886370d711482ef1cd5a5b137/aten/src/ATen/ScalarOps.h#L45-L51
Therefore, the correct semantics are to keep the approach used before #3518.
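
As an illustrative check (assuming a recent PyTorch), wrapping the scalar as a 0-dim tensor, roughly what scalar_to_tensor does, promotes the same way as the raw scalar:

import torch

# The raw scalar and its 0-dim tensor form promote identically, so
# tensor-vs-scalar and tensor-vs-tensor comparisons agree on compute type.
print(torch.result_type(torch.tensor([1]), 1.1))                # torch.float32
print(torch.result_type(torch.tensor([1]), torch.tensor(1.1)))  # torch.float32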

qingyunqu added a commit that referenced this pull request Sep 13, 2024
qingyunqu added a commit that referenced this pull request Sep 13, 2024
penguin-wwy deleted the fix_stablehlo branch on September 14, 2024 at 08:45.