Skip to content

Commit

Permalink
expression: fix type infer when wrap cast decimal as string (pingcap#…
Browse files Browse the repository at this point in the history
  • Loading branch information
XuHuaiyu authored Aug 21, 2018
1 parent d455e6a commit d9ea38b
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 29 deletions.
59 changes: 31 additions & 28 deletions expression/builtin_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -1633,11 +1633,13 @@ func (i inCastContext) String() string {
return "__cast_ctx"
}

// inUnionCastContext is session key value that indicates whether executing in union cast context.
// inUnionCastContext is session key value that indicates whether executing in
// union cast context.
// @see BuildCastFunction4Union
const inUnionCastContext inCastContext = 0

// BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union Expression.
// BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union
// Expression.
func BuildCastFunction4Union(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) {
ctx.SetValue(inUnionCastContext, struct{}{})
defer func() {
Expand Down Expand Up @@ -1673,17 +1675,16 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT
Function: f,
}
// We do not fold CAST if the eval type of this scalar function is ETJson
// since we may reset the flag of the field type of CastAsJson later which would
// affect the evaluation of it.
// since we may reset the flag of the field type of CastAsJson later which
// would affect the evaluation of it.
if tp.EvalType() != types.ETJson {
res = FoldConstant(res)
}
return res
}

// WrapWithCastAsInt wraps `expr` with `cast` if the return type
// of expr is not type int,
// otherwise, returns `expr` directly.
// WrapWithCastAsInt wraps `expr` with `cast` if the return type of expr is not
// type int, otherwise, returns `expr` directly.
func WrapWithCastAsInt(ctx sessionctx.Context, expr Expression) Expression {
if expr.GetType().EvalType() == types.ETInt {
return expr
Expand All @@ -1695,9 +1696,8 @@ func WrapWithCastAsInt(ctx sessionctx.Context, expr Expression) Expression {
return BuildCastFunction(ctx, expr, tp)
}

// WrapWithCastAsReal wraps `expr` with `cast` if the return type
// of expr is not type real,
// otherwise, returns `expr` directly.
// WrapWithCastAsReal wraps `expr` with `cast` if the return type of expr is not
// type real, otherwise, returns `expr` directly.
func WrapWithCastAsReal(ctx sessionctx.Context, expr Expression) Expression {
if expr.GetType().EvalType() == types.ETReal {
return expr
Expand All @@ -1709,9 +1709,8 @@ func WrapWithCastAsReal(ctx sessionctx.Context, expr Expression) Expression {
return BuildCastFunction(ctx, expr, tp)
}

// WrapWithCastAsDecimal wraps `expr` with `cast` if the return type
// of expr is not type decimal,
// otherwise, returns `expr` directly.
// WrapWithCastAsDecimal wraps `expr` with `cast` if the return type of expr is
// not type decimal, otherwise, returns `expr` directly.
func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression {
if expr.GetType().EvalType() == types.ETDecimal {
return expr
Expand All @@ -1723,15 +1722,22 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression {
return BuildCastFunction(ctx, expr, tp)
}

// WrapWithCastAsString wraps `expr` with `cast` if the return type
// of expr is not type string,
// otherwise, returns `expr` directly.
// WrapWithCastAsString wraps `expr` with `cast` if the return type of expr is
// not type string, otherwise, returns `expr` directly.
func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression {
if expr.GetType().EvalType() == types.ETString {
exprTp := expr.GetType()
if exprTp.EvalType() == types.ETString {
return expr
}
argLen := expr.GetType().Flen
if expr.GetType().EvalType() == types.ETInt {
argLen := exprTp.Flen
// If expr is decimal, we should take the decimal point and negative sign
// into consideration, so we set `expr.GetType().Flen + 2` as the `argLen`.
// Since the length of float and double is not accurate, we do not handle
// them.
if exprTp.Tp == mysql.TypeNewDecimal && argLen != types.UnspecifiedFsp {
argLen += 2
}
if exprTp.EvalType() == types.ETInt {
argLen = mysql.MaxIntWidth
}
tp := types.NewFieldType(mysql.TypeVarString)
Expand All @@ -1740,9 +1746,8 @@ func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression {
return BuildCastFunction(ctx, expr, tp)
}

// WrapWithCastAsTime wraps `expr` with `cast` if the return type
// of expr is not same as type of the specified `tp` ,
// otherwise, returns `expr` directly.
// WrapWithCastAsTime wraps `expr` with `cast` if the return type of expr is not
// same as type of the specified `tp` , otherwise, returns `expr` directly.
func WrapWithCastAsTime(ctx sessionctx.Context, expr Expression, tp *types.FieldType) Expression {
exprTp := expr.GetType().Tp
if tp.Tp == exprTp {
Expand All @@ -1769,9 +1774,8 @@ func WrapWithCastAsTime(ctx sessionctx.Context, expr Expression, tp *types.Field
return BuildCastFunction(ctx, expr, tp)
}

// WrapWithCastAsDuration wraps `expr` with `cast` if the return type
// of expr is not type duration,
// otherwise, returns `expr` directly.
// WrapWithCastAsDuration wraps `expr` with `cast` if the return type of expr is
// not type duration, otherwise, returns `expr` directly.
func WrapWithCastAsDuration(ctx sessionctx.Context, expr Expression) Expression {
if expr.GetType().Tp == mysql.TypeDuration {
return expr
Expand All @@ -1790,9 +1794,8 @@ func WrapWithCastAsDuration(ctx sessionctx.Context, expr Expression) Expression
return BuildCastFunction(ctx, expr, tp)
}

// WrapWithCastAsJSON wraps `expr` with `cast` if the return type
// of expr is not type json,
// otherwise, returns `expr` directly.
// WrapWithCastAsJSON wraps `expr` with `cast` if the return type of expr is not
// type json, otherwise, returns `expr` directly.
func WrapWithCastAsJSON(ctx sessionctx.Context, expr Expression) Expression {
if expr.GetType().Tp == mysql.TypeJSON && !mysql.HasParseToJSONFlag(expr.GetType().Flag) {
return expr
Expand Down
7 changes: 7 additions & 0 deletions expression/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,13 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) {
result.Check(testkit.Rows("2 0 3 0"))
result = tk.MustQuery(`select field("abc", "a", 1), field(1.3, "1.3", 1.5);`)
result.Check(testkit.Rows("1 1"))

tk.MustExec("drop table if exists t")
tk.MustExec("create table t(a decimal(11, 8), b decimal(11,8))")
tk.MustExec("insert into t values('114.57011441','38.04620115'), ('-38.04620119', '38.04620115');")
result = tk.MustQuery("select a,b,concat_ws(',',a,b) from t")
result.Check(testkit.Rows("114.57011441 38.04620115 114.57011441,38.04620115",
"-38.04620119 38.04620115 -38.04620119,38.04620115"))
}

func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) {
Expand Down
2 changes: 1 addition & 1 deletion expression/typeinfer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -403,7 +403,7 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase {
{"reverse(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"reverse(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 12, types.UnspecifiedLength},
{"reverse(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 22, types.UnspecifiedLength},
{"reverse(c_decimal )", mysql.TypeVarString, charset.CharsetUTF8, 0, 6, types.UnspecifiedLength},
{"reverse(c_decimal )", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength},
{"reverse(c_char )", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"reverse(c_varchar )", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength},
{"reverse(c_text_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 65535, types.UnspecifiedLength},
Expand Down

0 comments on commit d9ea38b

Please sign in to comment.