[stablehlo] fix type for compare scalar op #3702
Conversation
  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
}
// use lhs's element type as compute type
rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
It seems that this will cause torch.tensor([1.0]) < torch.tensor([2.0], dtype=torch.double) to be computed in fp32.
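The rule this comment appeals to can be sketched in plain Python (no torch; the function name and the width table are illustrative, not PyTorch API): when both operands are floats, the wider float should be the compute type, so fp32 < fp64 should compare in fp64, not fp32.

```python
# Hypothetical sketch of float-float promotion for a comparison op.
# Real PyTorch resolves this via torch.result_type / its promotion table.
FLOAT_WIDTH = {"float16": 16, "float32": 32, "float64": 64}

def compare_compute_type(lhs_dtype: str, rhs_dtype: str) -> str:
    """Pick the compute type for a float-float comparison: the wider float."""
    if FLOAT_WIDTH[lhs_dtype] >= FLOAT_WIDTH[rhs_dtype]:
        return lhs_dtype
    return rhs_dtype

# torch.tensor([1.0]) < torch.tensor([2.0], dtype=torch.double)
# should compare in float64, regardless of operand order:
assert compare_compute_type("float32", "float64") == "float64"
assert compare_compute_type("float64", "float32") == "float64"
```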
I think we need to discuss how to use stablehlo to describe torch's default compute type.
See another PR: #3673
In the previous code, the promotion was based on the number of bits:
torch-mlir/lib/Conversion/TorchToStablehlo/Basic.cpp
Lines 539 to 551 in bb69014
if (isa<mlir::IntegerType>(lhsElemTy) && isa<mlir::FloatType>(rhsElemTy)) {
  lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
} else if (isa<mlir::FloatType>(lhsElemTy) &&
           isa<mlir::IntegerType>(rhsElemTy)) {
  rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
} else {
  if (lhsElemTy.getIntOrFloatBitWidth() >
      rhsElemTy.getIntOrFloatBitWidth()) {
    rhs = hlo::promoteType(rewriter, op.getLoc(), rhs, lhsElemTy);
  } else {
    lhs = hlo::promoteType(rewriter, op.getLoc(), lhs, rhsElemTy);
  }
}
And then, #3518 promotes rhsType to lhsType in advance. I'm not quite sure about its purpose; this PR just ensures that when lhsType is integer and rhsType is float, the comparison is done in floating point.
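The fix described here can be sketched in plain Python (the function name and string-based dtypes are illustrative only): if exactly one side is a float type, the comparison is done in that float type.

```python
# Hypothetical sketch of this PR's rule for picking the compare dtype.
def promote_for_compare(lhs_dtype: str, rhs_dtype: str) -> str:
    """If exactly one operand is float, compare in that float type;
    otherwise keep the lhs type (mirroring the promotion done upstream)."""
    lhs_is_float = lhs_dtype.startswith("float")
    rhs_is_float = rhs_dtype.startswith("float")
    if not lhs_is_float and rhs_is_float:
        return rhs_dtype  # int vs float -> float
    if lhs_is_float and not rhs_is_float:
        return lhs_dtype  # float vs int -> float
    return lhs_dtype

assert promote_for_compare("int64", "float32") == "float32"
assert promote_for_compare("float32", "int64") == "float32"
```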
I know. In torch, tensor is first class and scalar is second class. When computing tensor < scalar, the tensor's dtype should usually be used as the compute type, except in cases like torch.tensor([1]) < 1.1.
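The tensor-vs-scalar rule described above can be sketched in plain Python (function name and dtype strings are hypothetical; "default_float" stands in for torch's default dtype, usually float32): the tensor's dtype wins, except when an integer tensor is compared with a float scalar.

```python
# Hypothetical sketch of the tensor-vs-scalar compute-type rule.
def tensor_scalar_compute_type(tensor_dtype: str, scalar_is_float: bool,
                               default_float: str = "float32") -> str:
    """Use the tensor's dtype, except promote an integer tensor to the
    default float dtype when the scalar is a float."""
    tensor_is_int = tensor_dtype.startswith("int")
    if tensor_is_int and scalar_is_float:
        return default_float
    return tensor_dtype

# torch.tensor([1.0], dtype=torch.float16) < 2 -> compute in float16
assert tensor_scalar_compute_type("float16", scalar_is_float=False) == "float16"
# torch.tensor([1]) < 1.1 -> compute in float32 (so the result is True,
# rather than the False an integer compare of 1 < int(1.1) would give)
assert tensor_scalar_compute_type("int64", scalar_is_float=True) == "float32"
```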
According to PyTorch's implementation, when comparing a tensor and a scalar, the scalar is first converted to a tensor using the scalar_to_tensor function before the comparison. Therefore, compare tensor scalar should exhibit the same behavior as compare tensor tensor.
https://github.com/pytorch/pytorch/blob/02169364e15932d886370d711482ef1cd5a5b137/aten/src/ATen/ScalarOps.h#L45-L51
Therefore, the correct semantics should be to maintain the approach used before #3518.