Skip to content

Commit 8c77a67

Browse files
authored
ROX-31479: Add getting nodes for CVE (#19)
Assisted-by: Claude Code
1 parent bf6db65 commit 8c77a67

File tree

4 files changed

+585
-1
lines changed

4 files changed

+585
-1
lines changed
Lines changed: 213 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,213 @@
1+
package vulnerability
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"io"
7+
"sort"
8+
"strings"
9+
10+
"github.com/google/jsonschema-go/jsonschema"
11+
"github.com/modelcontextprotocol/go-sdk/mcp"
12+
"github.com/pkg/errors"
13+
v1 "github.com/stackrox/rox/generated/api/v1"
14+
"github.com/stackrox/stackrox-mcp/internal/client"
15+
"github.com/stackrox/stackrox-mcp/internal/client/auth"
16+
"github.com/stackrox/stackrox-mcp/internal/logging"
17+
"github.com/stackrox/stackrox-mcp/internal/toolsets"
18+
"google.golang.org/grpc"
19+
)
20+
21+
// getNodesForCVEInput defines the input parameters for get_nodes_for_cve tool.
22+
type getNodesForCVEInput struct {
23+
CVEName string `json:"cveName"`
24+
FilterClusterID string `json:"filterClusterId,omitempty"`
25+
}
26+
27+
func (input *getNodesForCVEInput) validate() error {
28+
if input.CVEName == "" {
29+
return errors.New("CVE name is required")
30+
}
31+
32+
return nil
33+
}
34+
35+
// NodeGroupResult contains aggregated node information by cluster and OS.
36+
type NodeGroupResult struct {
37+
ClusterID string `json:"clusterId"`
38+
ClusterName string `json:"clusterName"`
39+
OperatingSystem string `json:"operatingSystem"`
40+
Count int `json:"count"`
41+
}
42+
43+
// getNodesForCVEOutput defines the output structure for get_nodes_for_cve tool.
44+
type getNodesForCVEOutput struct {
45+
NodeGroups []NodeGroupResult `json:"nodeGroups"`
46+
}
47+
48+
// getNodesForCVETool implements the get_nodes_for_cve tool.
49+
type getNodesForCVETool struct {
50+
name string
51+
client *client.Client
52+
}
53+
54+
// NewGetNodesForCVETool creates a new get_nodes_for_cve tool.
55+
func NewGetNodesForCVETool(c *client.Client) toolsets.Tool {
56+
return &getNodesForCVETool{
57+
name: "get_nodes_for_cve",
58+
client: c,
59+
}
60+
}
61+
62+
// IsReadOnly returns true as this tool only reads data.
63+
func (t *getNodesForCVETool) IsReadOnly() bool {
64+
return true
65+
}
66+
67+
// GetName returns the tool name.
68+
func (t *getNodesForCVETool) GetName() string {
69+
return t.name
70+
}
71+
72+
// GetTool returns the MCP Tool definition.
73+
func (t *getNodesForCVETool) GetTool() *mcp.Tool {
74+
return &mcp.Tool{
75+
Name: t.name,
76+
Description: "Get aggregated node groups affected by a specific CVE, grouped by cluster and operating system image",
77+
InputSchema: getNodesForCVEInputSchema(),
78+
}
79+
}
80+
81+
// getNodesForCVEInputSchema returns the JSON schema for input validation.
82+
func getNodesForCVEInputSchema() *jsonschema.Schema {
83+
schema, err := jsonschema.For[getNodesForCVEInput](nil)
84+
if err != nil {
85+
logging.Fatal("Could not get jsonschema for get_nodes_for_cve input", err)
86+
87+
return nil
88+
}
89+
90+
// CVE name is required.
91+
schema.Required = []string{"cveName"}
92+
93+
schema.Properties["cveName"].Description = "CVE name to filter nodes (e.g., CVE-2020-26159)"
94+
schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter nodes"
95+
96+
return schema
97+
}
98+
99+
// RegisterWith registers the get_nodes_for_cve tool handler with the MCP server.
100+
func (t *getNodesForCVETool) RegisterWith(server *mcp.Server) {
101+
mcp.AddTool(server, t.GetTool(), t.handle)
102+
}
103+
104+
// buildNodeQuery builds query used to search nodes in StackRox Central.
105+
// We will quote values to have strict match. Without quote: CVE-2025-10, would match CVE-2025-101.
106+
func buildNodeQuery(input getNodesForCVEInput) string {
107+
queryParts := []string{fmt.Sprintf("CVE:%q", input.CVEName)}
108+
109+
if input.FilterClusterID != "" {
110+
queryParts = append(queryParts, fmt.Sprintf("Cluster ID:%q", input.FilterClusterID))
111+
}
112+
113+
return strings.Join(queryParts, "+")
114+
}
115+
116+
// aggregateNodeGroups consumes entire stream and aggregates nodes by cluster and OS.
117+
func aggregateNodeGroups(
118+
stream grpc.ServerStreamingClient[v1.ExportNodeResponse],
119+
) ([]NodeGroupResult, error) {
120+
// Map key: "clusterId|osImage"
121+
// Map value: NodeGroupResult with count and clusterName.
122+
groups := make(map[string]*NodeGroupResult)
123+
124+
for {
125+
resp, err := stream.Recv()
126+
127+
// Stream ended - no more nodes.
128+
if errors.Is(err, io.EOF) {
129+
break
130+
}
131+
132+
if err != nil {
133+
return nil, errors.Wrap(err, "error receiving from stream")
134+
}
135+
136+
node := resp.GetNode()
137+
if node == nil {
138+
continue
139+
}
140+
141+
// Create unique key for this cluster+OS combination.
142+
key := fmt.Sprintf("%s|%s", node.GetClusterId(), node.GetOsImage())
143+
if group, exists := groups[key]; exists {
144+
group.Count++
145+
146+
continue
147+
}
148+
149+
groups[key] = &NodeGroupResult{
150+
ClusterID: node.GetClusterId(),
151+
ClusterName: node.GetClusterName(),
152+
OperatingSystem: node.GetOsImage(),
153+
Count: 1,
154+
}
155+
}
156+
157+
result := make([]NodeGroupResult, 0, len(groups))
158+
for _, group := range groups {
159+
result = append(result, *group)
160+
}
161+
162+
// Sort for consistent ordering (by clusterId, then OS).
163+
sort.Slice(result, func(i, j int) bool {
164+
if result[i].ClusterID != result[j].ClusterID {
165+
return result[i].ClusterID < result[j].ClusterID
166+
}
167+
168+
return result[i].OperatingSystem < result[j].OperatingSystem
169+
})
170+
171+
return result, nil
172+
}
173+
174+
// handle is the handler for get_nodes_for_cve tool.
175+
func (t *getNodesForCVETool) handle(
176+
ctx context.Context,
177+
req *mcp.CallToolRequest,
178+
input getNodesForCVEInput,
179+
) (*mcp.CallToolResult, *getNodesForCVEOutput, error) {
180+
err := input.validate()
181+
if err != nil {
182+
return nil, nil, err
183+
}
184+
185+
conn, err := t.client.ReadyConn(ctx)
186+
if err != nil {
187+
return nil, nil, errors.Wrap(err, "unable to connect to server")
188+
}
189+
190+
callCtx := auth.WithMCPRequestContext(ctx, req)
191+
nodeClient := v1.NewNodeServiceClient(conn)
192+
193+
query := buildNodeQuery(input)
194+
exportReq := &v1.ExportNodeRequest{
195+
Query: query,
196+
}
197+
198+
stream, err := nodeClient.ExportNodes(callCtx, exportReq)
199+
if err != nil {
200+
return nil, nil, client.NewError(err, "ExportNodes")
201+
}
202+
203+
nodeGroups, err := aggregateNodeGroups(stream)
204+
if err != nil {
205+
return nil, nil, err
206+
}
207+
208+
output := &getNodesForCVEOutput{
209+
NodeGroups: nodeGroups,
210+
}
211+
212+
return nil, output, nil
213+
}

0 commit comments

Comments
 (0)