Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -16,65 +16,19 @@
package com.contrast.labs.ai.mcp.contrast.sdkextension;

import com.contrastsecurity.models.TraceFilterBody;
import com.contrastsecurity.models.TraceMetadataFilter;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import lombok.Getter;
import lombok.Setter;

/**
* Extended TraceFilterBody that provides helper methods for building filters with session
* parameters (agentSessionId and metadata filters).
* Extended TraceFilterBody that adds fields not present in the SDK's base class.
*
* <p>This class exists because TeamServer requires numeric field IDs for metadata filtering, but
* users provide human-readable field names. The resolution happens in the tool layer, and this
* class accepts the already-resolved filters.
* <p>The SDK's TraceFilterBody lacks the 'status' field that TeamServer's API supports. This class
* adds it. The base class already has agentSessionId and metadataFilters fields.
*/
@Getter
@Setter
public class ExtendedTraceFilterBody extends TraceFilterBody {

/**
* Creates an ExtendedTraceFilterBody from a base filter body with session parameters.
*
* @param source Base filter body with standard filters (severities, environments, etc.)
* @param agentSessionId Agent session ID for latest session filtering (nullable)
* @param metadataFilters List of metadata filters with resolved field IDs (nullable)
* @return Extended filter body with all parameters set
*/
public static ExtendedTraceFilterBody withSessionFilters(
TraceFilterBody source, String agentSessionId, List<TraceMetadataFilter> metadataFilters) {
Objects.requireNonNull(source, "source TraceFilterBody must not be null");

var extended = new ExtendedTraceFilterBody();

// Copy base filters from source
if (source.getSeverities() != null) {
extended.setSeverities(source.getSeverities());
}
if (source.getVulnTypes() != null) {
extended.setVulnTypes(source.getVulnTypes());
}
if (source.getEnvironments() != null) {
extended.setEnvironments(source.getEnvironments());
}
if (source.getStartDate() != null) {
extended.setStartDate(source.getStartDate());
}
if (source.getEndDate() != null) {
extended.setEndDate(source.getEndDate());
}
if (source.getFilterTags() != null) {
extended.setFilterTags(source.getFilterTags());
}
extended.setTracked(source.isTracked());
extended.setUntracked(source.isUntracked());

// Add session parameters
if (agentSessionId != null) {
extended.setAgentSessionId(agentSessionId);
}

if (metadataFilters != null && !metadataFilters.isEmpty()) {
extended.setMetadataFilters(metadataFilters);
}

return extended;
}
private Set<String> status;
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,15 @@
import com.contrast.labs.ai.mcp.contrast.PaginationParams;
import com.contrast.labs.ai.mcp.contrast.data.VulnLight;
import com.contrast.labs.ai.mcp.contrast.mapper.VulnerabilityMapper;
import com.contrast.labs.ai.mcp.contrast.sdkextension.ExtendedTraceFilterBody;
import com.contrast.labs.ai.mcp.contrast.sdkextension.SDKExtension;
import com.contrast.labs.ai.mcp.contrast.tool.assess.params.SearchAppVulnerabilitiesParams;
import com.contrast.labs.ai.mcp.contrast.tool.base.BasePaginatedTool;
import com.contrast.labs.ai.mcp.contrast.tool.base.ExecutionResult;
import com.contrast.labs.ai.mcp.contrast.tool.base.PaginatedToolResponse;
import com.contrast.labs.ai.mcp.contrast.tool.validation.UnresolvedMetadataFilter;
import com.contrastsecurity.http.TraceFilterForm.TraceExpandValue;
import com.contrastsecurity.models.TraceFilterBody;
import com.contrastsecurity.models.TraceMetadataFilter;
import com.contrastsecurity.sdk.ContrastSDK;
import java.util.ArrayList;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
Expand Down Expand Up @@ -183,14 +181,13 @@ protected ExecutionResult<VulnLight> doExecute(
log.debug("Resolved {} session metadata filters", resolvedFilters.size());
}

// Build filter body - use ExtendedTraceFilterBody if session params present
TraceFilterBody filterBody;
if (agentSessionId != null || resolvedFilters != null) {
filterBody =
ExtendedTraceFilterBody.withSessionFilters(
params.toTraceFilterBody(), agentSessionId, resolvedFilters);
} else {
filterBody = params.toTraceFilterBody();
// Build filter body and add session params if present
var filterBody = params.toTraceFilterBody();
if (agentSessionId != null) {
filterBody.setAgentSessionId(agentSessionId);
}
if (resolvedFilters != null && !resolvedFilters.isEmpty()) {
filterBody.setMetadataFilters(resolvedFilters);
}

var expand =
Expand All @@ -214,54 +211,27 @@ protected ExecutionResult<VulnLight> doExecute(
return ExecutionResult.of(vulnerabilities, traces.getCount());
}

/**
* Normalizes a filter value to a List of Strings. Handles both single String values and
* List<String> values.
*/
@SuppressWarnings("unchecked")
private List<String> normalizeFilterValue(Object value) {
if (value instanceof String) {
return List.of((String) value);
} else if (value instanceof List) {
return (List<String>) value;
}
return List.of();
}

/**
* Resolves session metadata filter names to numeric IDs and builds TraceMetadataFilter list.
*
* @param sdk ContrastSDK instance
* @param orgId Organization ID
* @param appId Application ID
* @param filters Map of field names to values (from parsed JSON)
* @param filters List of unresolved metadata filters (from parsed JSON)
* @return List of TraceMetadataFilter with resolved field IDs
* @throws IllegalArgumentException if any field name is not found
*/
private List<TraceMetadataFilter> resolveSessionMetadataFilters(
ContrastSDK sdk, String orgId, String appId, Map<String, Object> filters) throws Exception {
ContrastSDK sdk, String orgId, String appId, List<UnresolvedMetadataFilter> filters)
throws Exception {

// Build field name to ID mapping (case-insensitive)
var fieldNameToId = buildFieldNameToIdMapping(sdk, orgId, appId);

var result = new ArrayList<TraceMetadataFilter>();
var notFoundFields = new ArrayList<String>();

for (var entry : filters.entrySet()) {
var fieldName = entry.getKey();
var fieldId = fieldNameToId.get(fieldName.toLowerCase());

if (fieldId == null) {
notFoundFields.add(fieldName);
continue;
}

// Convert value to List<String>
var values = normalizeFilterValue(entry.getValue());
result.add(new TraceMetadataFilter(fieldId, values));

log.debug("Resolved session metadata field '{}' to ID '{}'", fieldName, fieldId);
}
var notFoundFields =
filters.stream()
.map(UnresolvedMetadataFilter::fieldName)
.filter(name -> !fieldNameToId.containsKey(name.toLowerCase()))
.toList();

if (!notFoundFields.isEmpty()) {
throw new IllegalArgumentException(
Expand All @@ -271,7 +241,14 @@ private List<TraceMetadataFilter> resolveSessionMetadataFilters(
appId, String.join(", ", notFoundFields)));
}

return result;
return filters.stream()
.map(
f -> {
var fieldId = fieldNameToId.get(f.fieldName().toLowerCase());
log.debug("Resolved session metadata field '{}' to ID '{}'", f.fieldName(), fieldId);
return new TraceMetadataFilter(fieldId, f.values());
})
.toList();
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
*/
package com.contrast.labs.ai.mcp.contrast.tool.assess.params;

import com.contrast.labs.ai.mcp.contrast.sdkextension.ExtendedTraceFilterBody;
import com.contrast.labs.ai.mcp.contrast.tool.base.BaseToolParams;
import com.contrast.labs.ai.mcp.contrast.tool.validation.ToolValidationContext;
import com.contrast.labs.ai.mcp.contrast.tool.validation.UnresolvedMetadataFilter;
import com.contrast.labs.ai.mcp.contrast.tool.validation.ValidationConstants;
import com.contrastsecurity.http.RuleSeverity;
import com.contrastsecurity.http.ServerEnvironment;
import com.contrastsecurity.models.TraceFilterBody;
import java.util.Date;
import java.util.EnumSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.springframework.util.StringUtils;

Expand All @@ -45,14 +46,6 @@
*/
public class SearchAppVulnerabilitiesParams extends BaseToolParams {

/** Valid status values for vulnerability filtering. */
public static final Set<String> VALID_STATUSES =
Set.of("Reported", "Suspicious", "Confirmed", "Remediated", "Fixed");

/** Default statuses - excludes Fixed and Remediated to focus on actionable items. */
public static final List<String> DEFAULT_STATUSES =
List.of("Reported", "Suspicious", "Confirmed");

private String appId;
private EnumSet<RuleSeverity> severities;
private List<String> statuses;
Expand All @@ -61,7 +54,7 @@ public class SearchAppVulnerabilitiesParams extends BaseToolParams {
private Date lastSeenAfter;
private Date lastSeenBefore;
private List<String> vulnTags;
private Map<String, Object> sessionMetadataFilters;
private List<UnresolvedMetadataFilter> sessionMetadataFilters;
private Boolean useLatestSession;

/** Private constructor - use static factory method {@link #of}. */
Expand Down Expand Up @@ -107,9 +100,9 @@ public static SearchAppVulnerabilitiesParams of(

params.statuses =
ctx.stringListParam(statusesParam, "statuses")
.allowedValues(VALID_STATUSES)
.allowedValues(ValidationConstants.VALID_VULN_STATUSES)
.defaultTo(
DEFAULT_STATUSES,
ValidationConstants.DEFAULT_VULN_STATUSES,
"Showing actionable vulnerabilities only (excluding Fixed and Remediated). "
+ "To see all statuses, specify statuses parameter explicitly.")
.get();
Expand Down Expand Up @@ -150,16 +143,19 @@ public static SearchAppVulnerabilitiesParams of(
}

/**
* Convert to SDK TraceFilterBody for POST endpoint API calls.
* Convert to ExtendedTraceFilterBody for POST endpoint API calls.
*
* @return TraceFilterBody configured with all filters
* @return ExtendedTraceFilterBody configured with all filters including status
*/
public TraceFilterBody toTraceFilterBody() {
var body = new TraceFilterBody();
public ExtendedTraceFilterBody toTraceFilterBody() {
var body = new ExtendedTraceFilterBody();
// Note: tracked/untracked NOT set - primitive defaults (false) mean "return all"
if (severities != null) {
body.setSeverities(severities.stream().toList());
}
if (statuses != null) {
body.setStatus(Set.copyOf(statuses));
}
if (vulnTypes != null) {
body.setVulnTypes(vulnTypes);
}
Expand Down Expand Up @@ -210,7 +206,7 @@ public List<String> getVulnTags() {
return vulnTags;
}

public Map<String, Object> getSessionMetadataFilters() {
public List<UnresolvedMetadataFilter> getSessionMetadataFilters() {
return sessionMetadataFilters;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,16 @@
*/
package com.contrast.labs.ai.mcp.contrast.tool.assess.params;

import com.contrast.labs.ai.mcp.contrast.sdkextension.ExtendedTraceFilterBody;
import com.contrast.labs.ai.mcp.contrast.tool.base.BaseToolParams;
import com.contrast.labs.ai.mcp.contrast.tool.validation.ToolValidationContext;
import com.contrast.labs.ai.mcp.contrast.tool.validation.ValidationConstants;
import com.contrastsecurity.http.RuleSeverity;
import com.contrastsecurity.http.ServerEnvironment;
import com.contrastsecurity.models.TraceFilterBody;
import java.util.Date;
import java.util.EnumSet;
import java.util.List;
import java.util.Set;

/**
* Vulnerability filter parameters using fluent validation API. Demonstrates the
Expand Down Expand Up @@ -113,22 +114,23 @@ public static VulnerabilityFilterParams of(
}

/**
* Convert to SDK TraceFilterBody for POST endpoint API calls.
* Convert to ExtendedTraceFilterBody for POST endpoint API calls.
*
* <p>TraceFilterBody has primitive boolean tracked/untracked fields that default to false. When
* both are false, the API returns ALL vulnerabilities (both tracked and untracked). This avoids
* the TraceFilterForm default of tracked=true which filters out untracked vulns.
*
* @return TraceFilterBody configured with all filters
* @return ExtendedTraceFilterBody configured with all filters including status
*/
public TraceFilterBody toTraceFilterBody() {
var body = new TraceFilterBody();
public ExtendedTraceFilterBody toTraceFilterBody() {
var body = new ExtendedTraceFilterBody();
// Note: tracked/untracked NOT set - primitive defaults (false, false) mean "return all"
if (severities != null) {
body.setSeverities(severities.stream().toList());
}
// TraceFilterBody doesn't have status field - it uses quickFilter instead
// Status filtering handled server-side with different mechanism
if (statuses != null) {
body.setStatus(Set.copyOf(statuses));
}
if (vulnTypes != null) {
body.setVulnTypes(vulnTypes);
}
Expand Down
Loading