Skip to content

Commit a2fdc52

Browse files
committed
Add get clusters for CVE
1 parent 8c77a67 commit a2fdc52

File tree

4 files changed

+504
-4
lines changed

4 files changed

+504
-4
lines changed
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
package vulnerability
2+
3+
import (
4+
"context"
5+
"fmt"
6+
"sort"
7+
"strings"
8+
9+
"github.com/google/jsonschema-go/jsonschema"
10+
"github.com/modelcontextprotocol/go-sdk/mcp"
11+
"github.com/pkg/errors"
12+
v1 "github.com/stackrox/rox/generated/api/v1"
13+
"github.com/stackrox/stackrox-mcp/internal/client"
14+
"github.com/stackrox/stackrox-mcp/internal/client/auth"
15+
"github.com/stackrox/stackrox-mcp/internal/logging"
16+
"github.com/stackrox/stackrox-mcp/internal/toolsets"
17+
)
18+
19+
// getClustersForCVEInput defines the input parameters for get_clusters_for_cve tool.
20+
type getClustersForCVEInput struct {
21+
CVEName string `json:"cveName"`
22+
FilterClusterID string `json:"filterClusterId,omitempty"`
23+
}
24+
25+
func (input *getClustersForCVEInput) validate() error {
26+
if input.CVEName == "" {
27+
return errors.New("CVE name is required")
28+
}
29+
30+
return nil
31+
}
32+
33+
// ClusterResult contains cluster information.
34+
type ClusterResult struct {
35+
ClusterID string `json:"clusterId"`
36+
ClusterName string `json:"clusterName"`
37+
}
38+
39+
// getClustersForCVEOutput defines the output structure for get_clusters_for_cve tool.
40+
type getClustersForCVEOutput struct {
41+
Clusters []ClusterResult `json:"clusters"`
42+
}
43+
44+
// getClustersForCVETool implements the get_clusters_for_cve tool.
45+
type getClustersForCVETool struct {
46+
name string
47+
client *client.Client
48+
}
49+
50+
// NewGetClustersForCVETool creates a new get_clusters_for_cve tool.
51+
func NewGetClustersForCVETool(c *client.Client) toolsets.Tool {
52+
return &getClustersForCVETool{
53+
name: "get_clusters_for_cve",
54+
client: c,
55+
}
56+
}
57+
58+
// IsReadOnly returns true as this tool only reads data.
59+
func (t *getClustersForCVETool) IsReadOnly() bool {
60+
return true
61+
}
62+
63+
// GetName returns the tool name.
64+
func (t *getClustersForCVETool) GetName() string {
65+
return t.name
66+
}
67+
68+
// GetTool returns the MCP Tool definition.
69+
func (t *getClustersForCVETool) GetTool() *mcp.Tool {
70+
return &mcp.Tool{
71+
Name: t.name,
72+
Description: "Get list of clusters affected by a specific CVE",
73+
InputSchema: getClustersForCVEInputSchema(),
74+
}
75+
}
76+
77+
// getClustersForCVEInputSchema returns the JSON schema for input validation.
78+
func getClustersForCVEInputSchema() *jsonschema.Schema {
79+
schema, err := jsonschema.For[getClustersForCVEInput](nil)
80+
if err != nil {
81+
logging.Fatal("Could not get jsonschema for get_clusters_for_cve input", err)
82+
83+
return nil
84+
}
85+
86+
// CVE name is required.
87+
schema.Required = []string{"cveName"}
88+
89+
schema.Properties["cveName"].Description = "CVE name to filter clusters (e.g., CVE-2021-44228)"
90+
schema.Properties["filterClusterId"].Description = "Optional cluster ID to verify if a specific cluster is affected"
91+
92+
return schema
93+
}
94+
95+
// RegisterWith registers the get_clusters_for_cve tool handler with the MCP server.
96+
func (t *getClustersForCVETool) RegisterWith(server *mcp.Server) {
97+
mcp.AddTool(server, t.GetTool(), t.handle)
98+
}
99+
100+
// buildClusterQuery builds query string for filtering clusters by CVE.
101+
// We quote values for exact match (CVE-2025-10 won't match CVE-2025-101).
102+
func buildClusterQuery(input getClustersForCVEInput) string {
103+
queryParts := []string{fmt.Sprintf("CVE:%q", input.CVEName)}
104+
105+
if input.FilterClusterID != "" {
106+
queryParts = append(queryParts, fmt.Sprintf("Cluster ID:%q", input.FilterClusterID))
107+
}
108+
109+
return strings.Join(queryParts, "+")
110+
}
111+
112+
// handle is the handler for get_clusters_for_cve tool.
113+
func (t *getClustersForCVETool) handle(
114+
ctx context.Context,
115+
req *mcp.CallToolRequest,
116+
input getClustersForCVEInput,
117+
) (*mcp.CallToolResult, *getClustersForCVEOutput, error) {
118+
err := input.validate()
119+
if err != nil {
120+
return nil, nil, err
121+
}
122+
123+
conn, err := t.client.ReadyConn(ctx)
124+
if err != nil {
125+
return nil, nil, errors.Wrap(err, "unable to connect to server")
126+
}
127+
128+
callCtx := auth.WithMCPRequestContext(ctx, req)
129+
130+
clustersClient := v1.NewClustersServiceClient(conn)
131+
132+
query := buildClusterQuery(input)
133+
134+
resp, err := clustersClient.GetClusters(callCtx, &v1.GetClustersRequest{
135+
Query: query,
136+
})
137+
if err != nil {
138+
return nil, nil, client.NewError(err, "GetClusters")
139+
}
140+
141+
clusters := make([]ClusterResult, 0, len(resp.GetClusters()))
142+
for _, cluster := range resp.GetClusters() {
143+
clusters = append(clusters, ClusterResult{
144+
ClusterID: cluster.GetId(),
145+
ClusterName: cluster.GetName(),
146+
})
147+
}
148+
149+
// Sort by cluster ID for deterministic output.
150+
sort.Slice(clusters, func(i, j int) bool {
151+
return clusters[i].ClusterID < clusters[j].ClusterID
152+
})
153+
154+
output := &getClustersForCVEOutput{
155+
Clusters: clusters,
156+
}
157+
158+
return nil, output, nil
159+
}

0 commit comments

Comments
 (0)