From d9ea38b2adc99e1ab3672c1307d9a06d98037f35 Mon Sep 17 00:00:00 2001 From: HuaiyuXu <391585975@qq.com> Date: Tue, 21 Aug 2018 19:40:26 +0800 Subject: [PATCH] expression: fix type infer when wrap cast decimal as string (#7451) --- expression/builtin_cast.go | 59 ++++++++++++++++++---------------- expression/integration_test.go | 7 ++++ expression/typeinfer_test.go | 2 +- 3 files changed, 39 insertions(+), 29 deletions(-) diff --git a/expression/builtin_cast.go b/expression/builtin_cast.go index fe0694293e46b..50b88e0bff3dd 100644 --- a/expression/builtin_cast.go +++ b/expression/builtin_cast.go @@ -1633,11 +1633,13 @@ func (i inCastContext) String() string { return "__cast_ctx" } -// inUnionCastContext is session key value that indicates whether executing in union cast context. +// inUnionCastContext is session key value that indicates whether executing in +// union cast context. // @see BuildCastFunction4Union const inUnionCastContext inCastContext = 0 -// BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union Expression. +// BuildCastFunction4Union build a implicitly CAST ScalarFunction from the Union +// Expression. func BuildCastFunction4Union(ctx sessionctx.Context, expr Expression, tp *types.FieldType) (res Expression) { ctx.SetValue(inUnionCastContext, struct{}{}) defer func() { @@ -1673,17 +1675,16 @@ func BuildCastFunction(ctx sessionctx.Context, expr Expression, tp *types.FieldT Function: f, } // We do not fold CAST if the eval type of this scalar function is ETJson - // since we may reset the flag of the field type of CastAsJson later which would - // affect the evaluation of it. + // since we may reset the flag of the field type of CastAsJson later which + // would affect the evaluation of it. if tp.EvalType() != types.ETJson { res = FoldConstant(res) } return res } -// WrapWithCastAsInt wraps `expr` with `cast` if the return type -// of expr is not type int, -// otherwise, returns `expr` directly. +// WrapWithCastAsInt wraps `expr` with `cast` if the return type of expr is not +// type int, otherwise, returns `expr` directly. func WrapWithCastAsInt(ctx sessionctx.Context, expr Expression) Expression { if expr.GetType().EvalType() == types.ETInt { return expr @@ -1695,9 +1696,8 @@ func WrapWithCastAsInt(ctx sessionctx.Context, expr Expression) Expression { return BuildCastFunction(ctx, expr, tp) } -// WrapWithCastAsReal wraps `expr` with `cast` if the return type -// of expr is not type real, -// otherwise, returns `expr` directly. +// WrapWithCastAsReal wraps `expr` with `cast` if the return type of expr is not +// type real, otherwise, returns `expr` directly. func WrapWithCastAsReal(ctx sessionctx.Context, expr Expression) Expression { if expr.GetType().EvalType() == types.ETReal { return expr @@ -1709,9 +1709,8 @@ func WrapWithCastAsReal(ctx sessionctx.Context, expr Expression) Expression { return BuildCastFunction(ctx, expr, tp) } -// WrapWithCastAsDecimal wraps `expr` with `cast` if the return type -// of expr is not type decimal, -// otherwise, returns `expr` directly. +// WrapWithCastAsDecimal wraps `expr` with `cast` if the return type of expr is +// not type decimal, otherwise, returns `expr` directly. func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression { if expr.GetType().EvalType() == types.ETDecimal { return expr @@ -1723,15 +1722,22 @@ func WrapWithCastAsDecimal(ctx sessionctx.Context, expr Expression) Expression { return BuildCastFunction(ctx, expr, tp) } -// WrapWithCastAsString wraps `expr` with `cast` if the return type -// of expr is not type string, -// otherwise, returns `expr` directly. +// WrapWithCastAsString wraps `expr` with `cast` if the return type of expr is +// not type string, otherwise, returns `expr` directly. func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression { - if expr.GetType().EvalType() == types.ETString { + exprTp := expr.GetType() + if exprTp.EvalType() == types.ETString { return expr } - argLen := expr.GetType().Flen - if expr.GetType().EvalType() == types.ETInt { + argLen := exprTp.Flen + // If expr is decimal, we should take the decimal point and negative sign + // into consideration, so we set `expr.GetType().Flen + 2` as the `argLen`. + // Since the length of float and double is not accurate, we do not handle + // them. + if exprTp.Tp == mysql.TypeNewDecimal && argLen != types.UnspecifiedFsp { + argLen += 2 + } + if exprTp.EvalType() == types.ETInt { argLen = mysql.MaxIntWidth } tp := types.NewFieldType(mysql.TypeVarString) @@ -1740,9 +1746,8 @@ func WrapWithCastAsString(ctx sessionctx.Context, expr Expression) Expression { return BuildCastFunction(ctx, expr, tp) } -// WrapWithCastAsTime wraps `expr` with `cast` if the return type -// of expr is not same as type of the specified `tp` , -// otherwise, returns `expr` directly. +// WrapWithCastAsTime wraps `expr` with `cast` if the return type of expr is not +// same as type of the specified `tp` , otherwise, returns `expr` directly. func WrapWithCastAsTime(ctx sessionctx.Context, expr Expression, tp *types.FieldType) Expression { exprTp := expr.GetType().Tp if tp.Tp == exprTp { @@ -1769,9 +1774,8 @@ func WrapWithCastAsTime(ctx sessionctx.Context, expr Expression, tp *types.Field return BuildCastFunction(ctx, expr, tp) } -// WrapWithCastAsDuration wraps `expr` with `cast` if the return type -// of expr is not type duration, -// otherwise, returns `expr` directly. +// WrapWithCastAsDuration wraps `expr` with `cast` if the return type of expr is +// not type duration, otherwise, returns `expr` directly. func WrapWithCastAsDuration(ctx sessionctx.Context, expr Expression) Expression { if expr.GetType().Tp == mysql.TypeDuration { return expr @@ -1790,9 +1794,8 @@ func WrapWithCastAsDuration(ctx sessionctx.Context, expr Expression) Expression return BuildCastFunction(ctx, expr, tp) } -// WrapWithCastAsJSON wraps `expr` with `cast` if the return type -// of expr is not type json, -// otherwise, returns `expr` directly. +// WrapWithCastAsJSON wraps `expr` with `cast` if the return type of expr is not +// type json, otherwise, returns `expr` directly. func WrapWithCastAsJSON(ctx sessionctx.Context, expr Expression) Expression { if expr.GetType().Tp == mysql.TypeJSON && !mysql.HasParseToJSONFlag(expr.GetType().Flag) { return expr diff --git a/expression/integration_test.go b/expression/integration_test.go index e57b5ef431b3f..20369d6038ab3 100644 --- a/expression/integration_test.go +++ b/expression/integration_test.go @@ -925,6 +925,13 @@ func (s *testIntegrationSuite) TestStringBuiltin(c *C) { result.Check(testkit.Rows("2 0 3 0")) result = tk.MustQuery(`select field("abc", "a", 1), field(1.3, "1.3", 1.5);`) result.Check(testkit.Rows("1 1")) + + tk.MustExec("drop table if exists t") + tk.MustExec("create table t(a decimal(11, 8), b decimal(11,8))") + tk.MustExec("insert into t values('114.57011441','38.04620115'), ('-38.04620119', '38.04620115');") + result = tk.MustQuery("select a,b,concat_ws(',',a,b) from t") + result.Check(testkit.Rows("114.57011441 38.04620115 114.57011441,38.04620115", + "-38.04620119 38.04620115 -38.04620119,38.04620115")) } func (s *testIntegrationSuite) TestEncryptionBuiltin(c *C) { diff --git a/expression/typeinfer_test.go b/expression/typeinfer_test.go index 20832ba88d373..2aac3d967a5da 100644 --- a/expression/typeinfer_test.go +++ b/expression/typeinfer_test.go @@ -403,7 +403,7 @@ func (s *testInferTypeSuite) createTestCase4StrFuncs() []typeInferTestCase { {"reverse(c_bigint_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength}, {"reverse(c_float_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 12, types.UnspecifiedLength}, {"reverse(c_double_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 22, types.UnspecifiedLength}, - {"reverse(c_decimal )", mysql.TypeVarString, charset.CharsetUTF8, 0, 6, types.UnspecifiedLength}, + {"reverse(c_decimal )", mysql.TypeVarString, charset.CharsetUTF8, 0, 8, types.UnspecifiedLength}, {"reverse(c_char )", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength}, {"reverse(c_varchar )", mysql.TypeVarString, charset.CharsetUTF8, 0, 20, types.UnspecifiedLength}, {"reverse(c_text_d )", mysql.TypeVarString, charset.CharsetUTF8, 0, 65535, types.UnspecifiedLength},