From d1712c8fc40fd62af9259a2fc136ddb7112d80bc Mon Sep 17 00:00:00 2001 From: Muneyuki Noguchi Date: Wed, 28 Oct 2020 08:32:18 +0900 Subject: [PATCH] Don't create two function maps --- calculator.go | 108 +++++++++++++++++++++++++------------------------- parser.go | 78 +++++++++++------------------------- 2 files changed, 79 insertions(+), 107 deletions(-) diff --git a/calculator.go b/calculator.go index 4b5eda3..5d5c50c 100644 --- a/calculator.go +++ b/calculator.go @@ -5,61 +5,62 @@ import ( "math" ) +var functions = map[string]interface{}{ + "abs": math.Abs, + "acos": math.Acos, + "acosh": math.Acosh, + "asin": math.Asin, + "asinh": math.Asinh, + "atan": math.Atan, + "atan2": math.Atan2, + "atanh": math.Atanh, + "cbrt": math.Cbrt, + "ceil": math.Ceil, + "copysign": math.Copysign, + "cos": math.Cos, + "cosh": math.Cosh, + "dim": math.Dim, + "erf": math.Erf, + "erfc": math.Erfc, + "erfcinv": math.Erfcinv, // Go 1.10+ + "erfinv": math.Erfinv, // Go 1.10+ + "exp": math.Exp, + "exp2": math.Exp2, + "expm1": math.Expm1, + "fma": math.FMA, // Go 1.14+ + "floor": math.Floor, + "gamma": math.Gamma, + "hypot": math.Hypot, + "j0": math.J0, + "j1": math.J1, + "log": math.Log, + "log10": math.Log10, + "log1p": math.Log1p, + "log2": math.Log2, + "logb": math.Logb, + "max": math.Max, + "min": math.Min, + "mod": math.Mod, + "nan": math.NaN, + "nextafter": math.Nextafter, + "pow": math.Pow, + "remainder": math.Remainder, + "round": math.Round, // Go 1.10+ + "roundtoeven": math.RoundToEven, // Go 1.10+ + "sin": math.Sin, + "sinh": math.Sinh, + "sqrt": math.Sqrt, + "tan": math.Tan, + "tanh": math.Tanh, + "trunc": math.Trunc, + "y0": math.Y0, + "y1": math.Y1, +} + func call(funcName string, args []float64) (float64, error) { - functions := map[string]interface{}{ - "abs": math.Abs, - "acos": math.Acos, - "acosh": math.Acosh, - "asin": math.Asin, - "asinh": math.Asinh, - "atan": math.Atan, - "atan2": math.Atan2, - "atanh": math.Atanh, - "cbrt": math.Cbrt, - "ceil": math.Ceil, - "copysign": math.Copysign, - "cos": math.Cos, - "cosh": math.Cosh, - "dim": math.Dim, - "erf": math.Erf, - "erfc": math.Erfc, - "erfcinv": math.Erfcinv, - "erfinv": math.Erfinv, - "exp": math.Exp, - "exp2": math.Exp2, - "expm1": math.Expm1, - "fma": math.FMA, - "floor": math.Floor, - "gamma": math.Gamma, - "hypot": math.Hypot, - "j0": math.J0, - "j1": math.J1, - "log": math.Log, - "log10": math.Log10, - "log1p": math.Log1p, - "log2": math.Log2, - "logb": math.Logb, - "max": math.Max, - "min": math.Min, - "mod": math.Mod, - "nan": math.NaN, - "nextafter": math.Nextafter, - "pow": math.Pow, - "remainder": math.Remainder, - "round": math.Round, - "roundtoeven": math.RoundToEven, - "sin": math.Sin, - "sinh": math.Sinh, - "sqrt": math.Sqrt, - "tan": math.Tan, - "tanh": math.Tanh, - "trunc": math.Trunc, - "y0": math.Y0, - "y1": math.Y1, - } f, ok := functions[funcName] if !ok { - return 0, fmt.Errorf("function %s not found", funcName) + return 0, fmt.Errorf("unknown function %s", funcName) } switch f := f.(type) { case func() float64: @@ -70,8 +71,9 @@ func call(funcName string, args []float64) (float64, error) { return f(args[0], args[1]), nil case func(float64, float64, float64) float64: return f(args[0], args[1], args[2]), nil + default: + return 0, fmt.Errorf("invalid function %s", funcName) } - return 0, fmt.Errorf("unknown function %s", funcName) } func calculate(n *node) (float64, error) { diff --git a/parser.go b/parser.go index 46b0888..e64cb60 100644 --- a/parser.go +++ b/parser.go @@ -70,63 +70,33 @@ func (p *parser) constantNode(str string) (*node, error) { return &node{kind: numNode, val: val}, nil } -func (p *parser) functionNode(str string) (*node, error) { - functions := map[string]int{ - "abs": 1, - "acos": 1, - "acosh": 1, - "asin": 1, - "asinh": 1, - "atan": 1, - "atan2": 2, - "atanh": 1, - "cbrt": 1, - "ceil": 1, - "copysign": 2, - "cos": 1, - "cosh": 1, - "dim": 2, - "erf": 1, - "erfc": 1, - "erfcinv": 1, - "erfinv": 1, - "exp": 1, - "exp2": 1, - "expm1": 1, - "fma": 3, - "floor": 1, - "gamma": 1, - "hypot": 2, - "j0": 1, - "j1": 1, - "log": 1, - "log10": 1, - "log1p": 1, - "log2": 1, - "logb": 1, - "max": 2, - "min": 2, - "mod": 2, - "nan": 0, - "nextafter": 2, - "pow": 2, - "remainder": 2, - "round": 1, - "roundtoeven": 1, - "sin": 1, - "sinh": 1, - "sqrt": 1, - "tan": 1, - "tanh": 1, - "trunc": 1, - "y0": 1, - "y1": 1, +func argumentNumber(funcName string) (int, error) { + f, ok := functions[funcName] + if !ok { + return 0, fmt.Errorf("unknown function: %s", funcName) + } + + switch f.(type) { + case func() float64: + return 0, nil + case func(float64) float64: + return 1, nil + case func(float64, float64) float64: + return 2, nil + case func(float64, float64, float64) float64: + return 3, nil + default: + return 0, fmt.Errorf("invalid function: %s", funcName) } +} + +func (p *parser) functionNode(str string) (*node, error) { funcName := strings.ToLower(str) - num, ok := functions[funcName] - if !ok { - return nil, fmt.Errorf("unknown function: %s", funcName) + num, err := argumentNumber(funcName) + if err != nil { + return nil, err } + if p.consume(")") { if num != 0 { return nil, fmt.Errorf("%s should have argument(s)", funcName)