diff --git a/pkg/ast/ast.go b/pkg/ast/ast.go index afdfeb9..da127c0 100644 --- a/pkg/ast/ast.go +++ b/pkg/ast/ast.go @@ -10,12 +10,20 @@ import ( "fmt" "sort" "strings" + "sync" "github.com/GuanceCloud/platypus/pkg/token" ) type NodeType uint +var ( + dagCache = struct { + sync.RWMutex + m map[string]*DAGNode + }{m: make(map[string]*DAGNode)} +) + const ( // expr. TypeInvalid NodeType = iota @@ -146,6 +154,7 @@ type Node struct { // node type NodeType NodeType elem AstNode + DagNode *DAGNode } func (n *Node) String() string { @@ -233,107 +242,296 @@ func (n *Node) StartPos() token.LnColPos { } func WrapIdentifier(node *Identifier) *Node { + fingerprint := fmt.Sprintf("id %s", node.Name) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeIdentifier, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.Name, node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeIdentifier, elem: node, + DagNode: dagNode, } } func WrapStringLiteral(node *StringLiteral) *Node { + fingerprint := fmt.Sprintf("str %s", node.Val) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeStringLiteral, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.Val, node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeStringLiteral, elem: node, + DagNode: dagNode, } } func WrapIntegerLiteral(node *IntegerLiteral) *Node { + fingerprint := fmt.Sprintf("int %d", node.Val) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeIntegerLiteral, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.String(), node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeIntegerLiteral, elem: node, + DagNode: dagNode, } } func WrapFloatLiteral(node *FloatLiteral) *Node { + fingerprint := fmt.Sprintf("float %f", node.Val) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeFloatLiteral, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.String(), node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeFloatLiteral, elem: node, + DagNode: dagNode, } } func WrapBoolLiteral(node *BoolLiteral) *Node { + fingerprint := fmt.Sprintf("bool %t", node.Val) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeBoolLiteral, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.String(), node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeBoolLiteral, elem: node, + DagNode: dagNode, } } func WrapNilLiteral(node *NilLiteral) *Node { + fingerprint := "nil" + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeNilLiteral, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode("nil", node) return &Node{ NodeType: TypeNilLiteral, elem: node, + DagNode: dagNode, } } func WrapListInitExpr(node *ListLiteral) *Node { + fingerprint := fmt.Sprintf("list %s", node.String()) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeListLiteral, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.String(), node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeListLiteral, elem: node, + DagNode: dagNode, } } func WrapMapLiteral(node *MapLiteral) *Node { + fingerprint := fmt.Sprintf("map %s", node.String()) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeMapLiteral, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.String(), node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeMapLiteral, elem: node, + DagNode: dagNode, } } func WrapParenExpr(node *ParenExpr) *Node { + fingerprint := fmt.Sprintf("paren %s", node.String()) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeParenExpr, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.String(), node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeParenExpr, elem: node, + DagNode: dagNode, } } func WrapAttrExpr(node *AttrExpr) *Node { + fingerprint := fmt.Sprintf("attr %s", node.String()) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeAttrExpr, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.String(), node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeAttrExpr, elem: node, + DagNode: dagNode, } } func WrapIndexExpr(node *IndexExpr) *Node { + fingerprint := fmt.Sprintf("index %s", node.String()) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeIndexExpr, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(fingerprint, node) + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeIndexExpr, elem: node, + DagNode: dagNode, } } func WrapArithmeticExpr(node *ArithmeticExpr) *Node { + lhsHash := node.LHS.Hash() + rhsHash := node.RHS.Hash() + fingerprint := fmt.Sprintf("%s %s %s", lhsHash, node.Op, rhsHash) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeArithmeticExpr, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(fingerprint, node) + if node.LHS.DagNode != nil { + dagNode.AddChild(node.LHS.DagNode) + } + if node.RHS.DagNode != nil { + dagNode.AddChild(node.RHS.DagNode) + } + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeArithmeticExpr, elem: node, + DagNode: dagNode, } } func WrapConditionExpr(node *ConditionalExpr) *Node { + lhsHash := node.LHS.Hash() + rhsHash := node.RHS.Hash() + fingerprint := fmt.Sprintf("%s %s %s", lhsHash, node.Op, rhsHash) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeConditionalExpr, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(fingerprint, node) + if node.LHS.DagNode != nil { + dagNode.AddChild(node.LHS.DagNode) + } + if node.RHS.DagNode != nil { + dagNode.AddChild(node.RHS.DagNode) + } + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeConditionalExpr, elem: node, + DagNode: dagNode, } } func WrapInExpr(node *InExpr) *Node { + lhsHash := node.LHS.Hash() + rhsHash := node.RHS.Hash() + fingerprint := fmt.Sprintf("%s %s %s", lhsHash, node.Op, rhsHash) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeInExpr, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(fingerprint, node) + if node.LHS.DagNode != nil { + dagNode.AddChild(node.LHS.DagNode) + } + if node.RHS.DagNode != nil { + dagNode.AddChild(node.RHS.DagNode) + } + cacheNode(fingerprint, dagNode) return &Node{ NodeType: TypeInExpr, elem: node, + DagNode: dagNode, } } func WrapUnaryExpr(node *UnaryExpr) *Node { + rhsHash := node.RHS.Hash() + fingerprint := fmt.Sprintf("%s %s", node.Op, rhsHash) + if cached, exists := getCachedNode(fingerprint); exists { + return &Node{ + NodeType: TypeUnaryExpr, + elem: node, + DagNode: cached, + } + } + dagNode := NewDAGNode(node.String(), node) + return &Node{ NodeType: TypeUnaryExpr, elem: node, + DagNode: dagNode, } } @@ -468,3 +666,55 @@ func NodeStartPos(node *Node) token.LnColPos { } return token.InvalidLnColPos } + +func getCachedNode(fingerprint string) (*DAGNode, bool) { + dagCache.RLock() + defer dagCache.RUnlock() + node, exists := dagCache.m[fingerprint] + return node, exists +} + +func cacheNode(fingerprint string, node *DAGNode) { + dagCache.Lock() + defer dagCache.Unlock() + dagCache.m[fingerprint] = node +} + +func (n *Node) Hash() string { + if n.DagNode != nil { + return n.DagNode.ID + } + switch n.NodeType { + case TypeIdentifier: + return n.Identifier().Name + case TypeStringLiteral: + return n.StringLiteral().Val + case TypeIntegerLiteral: + return n.IntegerLiteral().String() + case TypeFloatLiteral: + return n.FloatLiteral().String() + case TypeBoolLiteral: + return n.BoolLiteral().String() + case TypeNilLiteral: + return "nil" + case TypeListLiteral: + return n.ListLiteral().String() + case TypeMapLiteral: + return n.MapLiteral().String() + case TypeParenExpr: + return n.ParenExpr().String() + case TypeAttrExpr: + return n.AttrExpr().String() + case TypeIndexExpr: + return n.IndexExpr().String() + case TypeUnaryExpr: + return n.UnaryExpr().String() + case TypeArithmeticExpr: + return n.ArithmeticExpr().String() + case TypeConditionalExpr: + return n.ConditionalExpr().String() + case TypeAssignmentExpr: + return n.AssignmentExpr().String() + } + return "" +} diff --git a/pkg/ast/dag.go b/pkg/ast/dag.go new file mode 100644 index 0000000..4876807 --- /dev/null +++ b/pkg/ast/dag.go @@ -0,0 +1,49 @@ +package ast + +import "fmt" + +type DAGNode struct { + ID string + Children []*DAGNode + Parents []*DAGNode + Data interface{} +} + +func NewDAGNode(id string, data interface{}) *DAGNode { + return &DAGNode{ + ID: id, + Data: data, + } +} + +func (n *DAGNode) AddChild(child *DAGNode) error { + if n == child || createsCycle(n, child) { + return fmt.Errorf("cycle detected") + } + + n.Children = append(n.Children, child) + child.Parents = append(child.Parents, n) + return nil +} +func createsCycle(a, b *DAGNode) bool { + visited := make(map[string]bool) + var check func(*DAGNode) bool + + check = func(node *DAGNode) bool { + if node.ID == a.ID { + return true + } + if visited[node.ID] { + return false + } + visited[node.ID] = true + for _, p := range node.Parents { + if check(p) { + return true + } + } + return false + } + + return check(b) +} diff --git a/pkg/parser/parser_test.go b/pkg/parser/parser_test.go index bd74ced..ba71846 100644 --- a/pkg/parser/parser_test.go +++ b/pkg/parser/parser_test.go @@ -1596,7 +1596,19 @@ multiline-string { name: "invalid slice with invalid object type", in: `true[1:3]`, - fail: true, + expected: ast.Stmts{ + ast.WrapSliceExpr(&ast.SliceExpr{ + Obj: ast.WrapBoolLiteral(&ast.BoolLiteral{ + Val: true, + }), + Start: ast.WrapIntegerLiteral(&ast.IntegerLiteral{ + Val: 1, + }), + End: ast.WrapIntegerLiteral(&ast.IntegerLiteral{ + Val: 3, + }), + }), + }, }, { name: "invalid slice with invalid start type", @@ -1742,3 +1754,91 @@ func(a, b, c)[func(a, b, c):b] t.Log(stmts) } } +func TestArithDAGReuse(t *testing.T) { + // 定义AST遍历函数 + var collectArithmeticNodes func(*ast.Node, *[]*ast.Node) + collectArithmeticNodes = func(n *ast.Node, nodes *[]*ast.Node) { + if n == nil { + return + } + if n.NodeType == ast.TypeArithmeticExpr { + *nodes = append(*nodes, n) + } + // 递归遍历子节点 + switch n.NodeType { + case ast.TypeArithmeticExpr: + expr := n.ArithmeticExpr() + collectArithmeticNodes(expr.LHS, nodes) + collectArithmeticNodes(expr.RHS, nodes) + case ast.TypeAssignmentExpr: + collectArithmeticNodes(n.AssignmentExpr().RHS, nodes) + case ast.TypeParenExpr: + collectArithmeticNodes(n.ParenExpr().Param, nodes) + case ast.TypeCallExpr: + for _, p := range n.CallExpr().Param { + collectArithmeticNodes(p, nodes) + } + } + } + + t.Run("basic-reuse", func(t *testing.T) { + input := `(2+2)+(2+2)` + stmts, err := ParsePipeline("", input) + assert.NoError(t, err) + + // 收集所有算术表达式节点 + var arithNodes []*ast.Node + collectArithmeticNodes(stmts[0], &arithNodes) + + assert.Len(t, arithNodes, 3, "应该有3个算术表达式节点") + + mulNode := arithNodes[0] + assert.Equal(t, ast.ADD, mulNode.ArithmeticExpr().Op, "应为加法操作") + + // 获取两个加法表达式节点 + addNode1 := arithNodes[1] + addNode2 := arithNodes[2] + assert.Equal(t, ast.ADD, addNode1.ArithmeticExpr().Op, "应为加法操作") + assert.Equal(t, ast.ADD, addNode2.ArithmeticExpr().Op, "应为加法操作") + + // 验证DAG节点复用 + assert.Equal(t, addNode1.DagNode.ID, addNode2.DagNode.ID, + "两个加法表达式应共享相同DAG节点") + assert.NotEqual(t, addNode1.DagNode.ID, mulNode.DagNode.ID, + "不同操作数的表达式应有不同DAG节点") + }) + + t.Run("different-expressions", func(t *testing.T) { + input := `y = (c + d) * (e + f)` // 完全不同的表达式 + stmts, err := ParsePipeline("", input) + assert.NoError(t, err) + + var arithNodes []*ast.Node + collectArithmeticNodes(stmts[0], &arithNodes) + assert.Len(t, arithNodes, 3, "应该有三个算术表达式节点") + + // 验证所有节点都有不同的DAG ID + assert.NotEqual(t, arithNodes[1].DagNode.ID, arithNodes[2].DagNode.ID, + "不同操作数的表达式应有不同DAG节点") + }) + + t.Run("nested-reuse", func(t *testing.T) { + input := `z = ((a + b) + c) + (a + b)` // 嵌套复用 + stmts, err := ParsePipeline("", input) + assert.NoError(t, err) + + var arithNodes []*ast.Node + collectArithmeticNodes(stmts[0], &arithNodes) + + // 节点结构: + // 0: 最外层加法 ((a+b)+c) + (a+b) + // 1: 中间层加法 (a+b)+c + // 2: 基础加法 a+b + // 3: 基础加法 a+b + assert.Len(t, arithNodes, 4, "应该有四个算术表达式节点") + + // 验证最底层两个a+b复用 + assert.Equal(t, arithNodes[2].DagNode.ID, arithNodes[3].DagNode.ID, + "相同基础表达式应复用DAG节点") + }) +}