diff --git a/src/main/java/com/contrast/labs/ai/mcp/contrast/sdkextension/ExtendedTraceFilterBody.java b/src/main/java/com/contrast/labs/ai/mcp/contrast/sdkextension/ExtendedTraceFilterBody.java index 3ae3397..604a83c 100644 --- a/src/main/java/com/contrast/labs/ai/mcp/contrast/sdkextension/ExtendedTraceFilterBody.java +++ b/src/main/java/com/contrast/labs/ai/mcp/contrast/sdkextension/ExtendedTraceFilterBody.java @@ -16,65 +16,19 @@ package com.contrast.labs.ai.mcp.contrast.sdkextension; import com.contrastsecurity.models.TraceFilterBody; -import com.contrastsecurity.models.TraceMetadataFilter; -import java.util.List; -import java.util.Objects; +import java.util.Set; +import lombok.Getter; +import lombok.Setter; /** - * Extended TraceFilterBody that provides helper methods for building filters with session - * parameters (agentSessionId and metadata filters). + * Extended TraceFilterBody that adds fields not present in the SDK's base class. * - *

This class exists because TeamServer requires numeric field IDs for metadata filtering, but - * users provide human-readable field names. The resolution happens in the tool layer, and this - * class accepts the already-resolved filters. + *

The SDK's TraceFilterBody lacks the 'status' field that TeamServer's API supports. This class + * adds it. The base class already has agentSessionId and metadataFilters fields. */ +@Getter +@Setter public class ExtendedTraceFilterBody extends TraceFilterBody { - /** - * Creates an ExtendedTraceFilterBody from a base filter body with session parameters. - * - * @param source Base filter body with standard filters (severities, environments, etc.) - * @param agentSessionId Agent session ID for latest session filtering (nullable) - * @param metadataFilters List of metadata filters with resolved field IDs (nullable) - * @return Extended filter body with all parameters set - */ - public static ExtendedTraceFilterBody withSessionFilters( - TraceFilterBody source, String agentSessionId, List metadataFilters) { - Objects.requireNonNull(source, "source TraceFilterBody must not be null"); - - var extended = new ExtendedTraceFilterBody(); - - // Copy base filters from source - if (source.getSeverities() != null) { - extended.setSeverities(source.getSeverities()); - } - if (source.getVulnTypes() != null) { - extended.setVulnTypes(source.getVulnTypes()); - } - if (source.getEnvironments() != null) { - extended.setEnvironments(source.getEnvironments()); - } - if (source.getStartDate() != null) { - extended.setStartDate(source.getStartDate()); - } - if (source.getEndDate() != null) { - extended.setEndDate(source.getEndDate()); - } - if (source.getFilterTags() != null) { - extended.setFilterTags(source.getFilterTags()); - } - extended.setTracked(source.isTracked()); - extended.setUntracked(source.isUntracked()); - - // Add session parameters - if (agentSessionId != null) { - extended.setAgentSessionId(agentSessionId); - } - - if (metadataFilters != null && !metadataFilters.isEmpty()) { - extended.setMetadataFilters(metadataFilters); - } - - return extended; - } + private Set status; } diff --git a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/SearchAppVulnerabilitiesTool.java b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/SearchAppVulnerabilitiesTool.java index 5fe1eab..3e393aa 100644 --- a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/SearchAppVulnerabilitiesTool.java +++ b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/SearchAppVulnerabilitiesTool.java @@ -18,17 +18,15 @@ import com.contrast.labs.ai.mcp.contrast.PaginationParams; import com.contrast.labs.ai.mcp.contrast.data.VulnLight; import com.contrast.labs.ai.mcp.contrast.mapper.VulnerabilityMapper; -import com.contrast.labs.ai.mcp.contrast.sdkextension.ExtendedTraceFilterBody; import com.contrast.labs.ai.mcp.contrast.sdkextension.SDKExtension; import com.contrast.labs.ai.mcp.contrast.tool.assess.params.SearchAppVulnerabilitiesParams; import com.contrast.labs.ai.mcp.contrast.tool.base.BasePaginatedTool; import com.contrast.labs.ai.mcp.contrast.tool.base.ExecutionResult; import com.contrast.labs.ai.mcp.contrast.tool.base.PaginatedToolResponse; +import com.contrast.labs.ai.mcp.contrast.tool.validation.UnresolvedMetadataFilter; import com.contrastsecurity.http.TraceFilterForm.TraceExpandValue; -import com.contrastsecurity.models.TraceFilterBody; import com.contrastsecurity.models.TraceMetadataFilter; import com.contrastsecurity.sdk.ContrastSDK; -import java.util.ArrayList; import java.util.EnumSet; import java.util.HashMap; import java.util.List; @@ -183,14 +181,13 @@ protected ExecutionResult doExecute( log.debug("Resolved {} session metadata filters", resolvedFilters.size()); } - // Build filter body - use ExtendedTraceFilterBody if session params present - TraceFilterBody filterBody; - if (agentSessionId != null || resolvedFilters != null) { - filterBody = - ExtendedTraceFilterBody.withSessionFilters( - params.toTraceFilterBody(), agentSessionId, resolvedFilters); - } else { - filterBody = params.toTraceFilterBody(); + // Build filter body and add session params if present + var filterBody = params.toTraceFilterBody(); + if (agentSessionId != null) { + filterBody.setAgentSessionId(agentSessionId); + } + if (resolvedFilters != null && !resolvedFilters.isEmpty()) { + filterBody.setMetadataFilters(resolvedFilters); } var expand = @@ -214,54 +211,27 @@ protected ExecutionResult doExecute( return ExecutionResult.of(vulnerabilities, traces.getCount()); } - /** - * Normalizes a filter value to a List of Strings. Handles both single String values and - * List values. - */ - @SuppressWarnings("unchecked") - private List normalizeFilterValue(Object value) { - if (value instanceof String) { - return List.of((String) value); - } else if (value instanceof List) { - return (List) value; - } - return List.of(); - } - /** * Resolves session metadata filter names to numeric IDs and builds TraceMetadataFilter list. * * @param sdk ContrastSDK instance * @param orgId Organization ID * @param appId Application ID - * @param filters Map of field names to values (from parsed JSON) + * @param filters List of unresolved metadata filters (from parsed JSON) * @return List of TraceMetadataFilter with resolved field IDs * @throws IllegalArgumentException if any field name is not found */ private List resolveSessionMetadataFilters( - ContrastSDK sdk, String orgId, String appId, Map filters) throws Exception { + ContrastSDK sdk, String orgId, String appId, List filters) + throws Exception { - // Build field name to ID mapping (case-insensitive) var fieldNameToId = buildFieldNameToIdMapping(sdk, orgId, appId); - var result = new ArrayList(); - var notFoundFields = new ArrayList(); - - for (var entry : filters.entrySet()) { - var fieldName = entry.getKey(); - var fieldId = fieldNameToId.get(fieldName.toLowerCase()); - - if (fieldId == null) { - notFoundFields.add(fieldName); - continue; - } - - // Convert value to List - var values = normalizeFilterValue(entry.getValue()); - result.add(new TraceMetadataFilter(fieldId, values)); - - log.debug("Resolved session metadata field '{}' to ID '{}'", fieldName, fieldId); - } + var notFoundFields = + filters.stream() + .map(UnresolvedMetadataFilter::fieldName) + .filter(name -> !fieldNameToId.containsKey(name.toLowerCase())) + .toList(); if (!notFoundFields.isEmpty()) { throw new IllegalArgumentException( @@ -271,7 +241,14 @@ private List resolveSessionMetadataFilters( appId, String.join(", ", notFoundFields))); } - return result; + return filters.stream() + .map( + f -> { + var fieldId = fieldNameToId.get(f.fieldName().toLowerCase()); + log.debug("Resolved session metadata field '{}' to ID '{}'", f.fieldName(), fieldId); + return new TraceMetadataFilter(fieldId, f.values()); + }) + .toList(); } /** diff --git a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/SearchAppVulnerabilitiesParams.java b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/SearchAppVulnerabilitiesParams.java index fe29fd5..85ac506 100644 --- a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/SearchAppVulnerabilitiesParams.java +++ b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/SearchAppVulnerabilitiesParams.java @@ -15,15 +15,16 @@ */ package com.contrast.labs.ai.mcp.contrast.tool.assess.params; +import com.contrast.labs.ai.mcp.contrast.sdkextension.ExtendedTraceFilterBody; import com.contrast.labs.ai.mcp.contrast.tool.base.BaseToolParams; import com.contrast.labs.ai.mcp.contrast.tool.validation.ToolValidationContext; +import com.contrast.labs.ai.mcp.contrast.tool.validation.UnresolvedMetadataFilter; +import com.contrast.labs.ai.mcp.contrast.tool.validation.ValidationConstants; import com.contrastsecurity.http.RuleSeverity; import com.contrastsecurity.http.ServerEnvironment; -import com.contrastsecurity.models.TraceFilterBody; import java.util.Date; import java.util.EnumSet; import java.util.List; -import java.util.Map; import java.util.Set; import org.springframework.util.StringUtils; @@ -45,14 +46,6 @@ */ public class SearchAppVulnerabilitiesParams extends BaseToolParams { - /** Valid status values for vulnerability filtering. */ - public static final Set VALID_STATUSES = - Set.of("Reported", "Suspicious", "Confirmed", "Remediated", "Fixed"); - - /** Default statuses - excludes Fixed and Remediated to focus on actionable items. */ - public static final List DEFAULT_STATUSES = - List.of("Reported", "Suspicious", "Confirmed"); - private String appId; private EnumSet severities; private List statuses; @@ -61,7 +54,7 @@ public class SearchAppVulnerabilitiesParams extends BaseToolParams { private Date lastSeenAfter; private Date lastSeenBefore; private List vulnTags; - private Map sessionMetadataFilters; + private List sessionMetadataFilters; private Boolean useLatestSession; /** Private constructor - use static factory method {@link #of}. */ @@ -107,9 +100,9 @@ public static SearchAppVulnerabilitiesParams of( params.statuses = ctx.stringListParam(statusesParam, "statuses") - .allowedValues(VALID_STATUSES) + .allowedValues(ValidationConstants.VALID_VULN_STATUSES) .defaultTo( - DEFAULT_STATUSES, + ValidationConstants.DEFAULT_VULN_STATUSES, "Showing actionable vulnerabilities only (excluding Fixed and Remediated). " + "To see all statuses, specify statuses parameter explicitly.") .get(); @@ -150,16 +143,19 @@ public static SearchAppVulnerabilitiesParams of( } /** - * Convert to SDK TraceFilterBody for POST endpoint API calls. + * Convert to ExtendedTraceFilterBody for POST endpoint API calls. * - * @return TraceFilterBody configured with all filters + * @return ExtendedTraceFilterBody configured with all filters including status */ - public TraceFilterBody toTraceFilterBody() { - var body = new TraceFilterBody(); + public ExtendedTraceFilterBody toTraceFilterBody() { + var body = new ExtendedTraceFilterBody(); // Note: tracked/untracked NOT set - primitive defaults (false) mean "return all" if (severities != null) { body.setSeverities(severities.stream().toList()); } + if (statuses != null) { + body.setStatus(Set.copyOf(statuses)); + } if (vulnTypes != null) { body.setVulnTypes(vulnTypes); } @@ -210,7 +206,7 @@ public List getVulnTags() { return vulnTags; } - public Map getSessionMetadataFilters() { + public List getSessionMetadataFilters() { return sessionMetadataFilters; } diff --git a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/VulnerabilityFilterParams.java b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/VulnerabilityFilterParams.java index f33054f..e8630cc 100644 --- a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/VulnerabilityFilterParams.java +++ b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/VulnerabilityFilterParams.java @@ -15,15 +15,16 @@ */ package com.contrast.labs.ai.mcp.contrast.tool.assess.params; +import com.contrast.labs.ai.mcp.contrast.sdkextension.ExtendedTraceFilterBody; import com.contrast.labs.ai.mcp.contrast.tool.base.BaseToolParams; import com.contrast.labs.ai.mcp.contrast.tool.validation.ToolValidationContext; import com.contrast.labs.ai.mcp.contrast.tool.validation.ValidationConstants; import com.contrastsecurity.http.RuleSeverity; import com.contrastsecurity.http.ServerEnvironment; -import com.contrastsecurity.models.TraceFilterBody; import java.util.Date; import java.util.EnumSet; import java.util.List; +import java.util.Set; /** * Vulnerability filter parameters using fluent validation API. Demonstrates the @@ -113,22 +114,23 @@ public static VulnerabilityFilterParams of( } /** - * Convert to SDK TraceFilterBody for POST endpoint API calls. + * Convert to ExtendedTraceFilterBody for POST endpoint API calls. * *

TraceFilterBody has primitive boolean tracked/untracked fields that default to false. When * both are false, the API returns ALL vulnerabilities (both tracked and untracked). This avoids * the TraceFilterForm default of tracked=true which filters out untracked vulns. * - * @return TraceFilterBody configured with all filters + * @return ExtendedTraceFilterBody configured with all filters including status */ - public TraceFilterBody toTraceFilterBody() { - var body = new TraceFilterBody(); + public ExtendedTraceFilterBody toTraceFilterBody() { + var body = new ExtendedTraceFilterBody(); // Note: tracked/untracked NOT set - primitive defaults (false, false) mean "return all" if (severities != null) { body.setSeverities(severities.stream().toList()); } - // TraceFilterBody doesn't have status field - it uses quickFilter instead - // Status filtering handled server-side with different mechanism + if (statuses != null) { + body.setStatus(Set.copyOf(statuses)); + } if (vulnTypes != null) { body.setVulnTypes(vulnTypes); } diff --git a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/MetadataJsonFilterSpec.java b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/MetadataJsonFilterSpec.java index d8bf1df..557513d 100644 --- a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/MetadataJsonFilterSpec.java +++ b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/MetadataJsonFilterSpec.java @@ -19,15 +19,14 @@ import com.google.gson.JsonSyntaxException; import com.google.gson.reflect.TypeToken; import java.util.ArrayList; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import org.springframework.util.StringUtils; /** - * Fluent validation spec for metadata filter JSON parameters. Parses JSON to Map - * where values are String or List. Used for session metadata filters and can be extended - * for application metadata. + * Fluent validation spec for metadata filter JSON parameters. Parses JSON to a list of {@link + * UnresolvedMetadataFilter} records where each filter has a field name and list of values. Used for + * session metadata filters and can be extended for application metadata. */ public class MetadataJsonFilterSpec { @@ -45,9 +44,9 @@ public class MetadataJsonFilterSpec { /** * Parses the JSON value and validates its structure. * - * @return Map of field names to values (String or List), or null if empty/invalid + * @return List of UnresolvedMetadataFilter records, or null if empty/invalid */ - public Map get() { + public List get() { if (!StringUtils.hasText(value)) { return null; } @@ -58,39 +57,14 @@ public Map get() { return null; } - Map result = new LinkedHashMap<>(); - List invalidEntries = new ArrayList<>(); + var result = new ArrayList(); + var invalidEntries = new ArrayList(); for (var entry : rawMap.entrySet()) { - var key = entry.getKey(); - var val = entry.getValue(); - - if (val instanceof String) { - result.put(key, val); - } else if (val instanceof Number) { - result.put(key, formatNumber((Number) val)); - } else if (val instanceof List) { - // Validate array contains only strings/numbers - List stringList = new ArrayList<>(); - boolean valid = true; - for (Object item : (List) val) { - if (item instanceof String) { - stringList.add((String) item); - } else if (item instanceof Number) { - stringList.add(formatNumber((Number) item)); - } else if (item != null) { - valid = false; - break; - } - } - if (valid) { - result.put(key, stringList); - } else { - invalidEntries.add(String.format("'%s' (array contains non-string values)", key)); - } - } else if (val != null) { - // Reject complex objects - invalidEntries.add(String.format("'%s' (expected string or array of strings)", key)); + var fieldName = entry.getKey(); + var values = parseValues(entry.getValue(), fieldName, invalidEntries); + if (values != null) { + result.add(new UnresolvedMetadataFilter(fieldName, values)); } } @@ -103,7 +77,7 @@ public Map get() { return null; } - return result; + return result.isEmpty() ? null : List.copyOf(result); } catch (JsonSyntaxException e) { ctx.addError( String.format( @@ -114,6 +88,30 @@ public Map get() { } } + private List parseValues(Object val, String fieldName, List invalidEntries) { + if (val instanceof String s) { + return List.of(s); + } else if (val instanceof Number n) { + return List.of(formatNumber(n)); + } else if (val instanceof List list) { + var strings = new ArrayList(); + for (Object item : list) { + if (item instanceof String s) { + strings.add(s); + } else if (item instanceof Number n) { + strings.add(formatNumber(n)); + } else if (item != null) { + invalidEntries.add(String.format("'%s' (array contains non-string values)", fieldName)); + return null; + } + } + return strings.isEmpty() ? null : List.copyOf(strings); + } else if (val != null) { + invalidEntries.add(String.format("'%s' (expected string or array of strings)", fieldName)); + } + return null; + } + /** * Formats a number as a string, using integer format when possible. Gson parses all JSON numbers * as doubles, so 42 becomes 42.0. This method converts back to integer format when the value is a diff --git a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/StringListSpec.java b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/StringListSpec.java index f4a22be..ceae00b 100644 --- a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/StringListSpec.java +++ b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/StringListSpec.java @@ -62,9 +62,10 @@ public StringListSpec defaultTo(List val, String reason) { } /** - * Sets allowed values. Invalid items are added as errors. + * Sets allowed values with case-insensitive matching. Input values are normalized to the + * canonical form (e.g., "reported" becomes "Reported"). Invalid items are added as errors. * - * @param values set of valid values + * @param values set of valid values in canonical form * @return this for fluent chaining */ public StringListSpec allowedValues(Set values) { @@ -104,14 +105,26 @@ public List get() { } if (allowedValues != null) { + // Build lowercase -> canonical mapping for case-insensitive matching + var canonicalMap = new java.util.HashMap(); + for (String allowed : allowedValues) { + canonicalMap.put(allowed.toLowerCase(), allowed); + } + + // Validate and normalize each item + var normalized = new java.util.ArrayList(); for (String item : parsed) { - if (!allowedValues.contains(item)) { + String canonical = canonicalMap.get(item.toLowerCase()); + if (canonical != null) { + normalized.add(canonical); + } else { ctx.addError( String.format( "Invalid %s: '%s'. Valid values: %s", name, item, String.join(", ", allowedValues))); } } + parsed = normalized; } return ctx.isValid() ? List.copyOf(parsed) : null; diff --git a/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/UnresolvedMetadataFilter.java b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/UnresolvedMetadataFilter.java new file mode 100644 index 0000000..1db0093 --- /dev/null +++ b/src/main/java/com/contrast/labs/ai/mcp/contrast/tool/validation/UnresolvedMetadataFilter.java @@ -0,0 +1,37 @@ +/* + * Copyright 2025 Contrast Security + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.contrast.labs.ai.mcp.contrast.tool.validation; + +import java.util.List; + +/** + * Represents a session metadata filter before field name resolution. + * + *

The field name is what the user provides (e.g., "branch", "developer"). This must be resolved + * to a field ID by looking up available session metadata for the application before it can be used + * in API calls. + * + * @param fieldName The user-provided field name (case-insensitive during resolution) + * @param values The filter values (OR logic - matches if any value matches) + */ +public record UnresolvedMetadataFilter(String fieldName, List values) { + public UnresolvedMetadataFilter { + if (fieldName == null || fieldName.isBlank()) { + throw new IllegalArgumentException("fieldName cannot be null or blank"); + } + values = values == null ? List.of() : List.copyOf(values); + } +} diff --git a/src/test/java/com/contrast/labs/ai/mcp/contrast/sdkextension/ExtendedTraceFilterBodyTest.java b/src/test/java/com/contrast/labs/ai/mcp/contrast/sdkextension/ExtendedTraceFilterBodyTest.java index 29eb724..4222027 100644 --- a/src/test/java/com/contrast/labs/ai/mcp/contrast/sdkextension/ExtendedTraceFilterBodyTest.java +++ b/src/test/java/com/contrast/labs/ai/mcp/contrast/sdkextension/ExtendedTraceFilterBodyTest.java @@ -16,165 +16,17 @@ package com.contrast.labs.ai.mcp.contrast.sdkextension; import static org.assertj.core.api.Assertions.assertThat; -import static org.assertj.core.api.Assertions.assertThatNullPointerException; -import com.contrastsecurity.http.RuleSeverity; -import com.contrastsecurity.http.ServerEnvironment; -import com.contrastsecurity.models.TraceFilterBody; -import com.contrastsecurity.models.TraceMetadataFilter; -import java.util.Date; -import java.util.List; +import java.util.Set; import org.junit.jupiter.api.Test; class ExtendedTraceFilterBodyTest { @Test - void withSessionFilters_should_copy_all_base_filters() { - // Given a fully populated source filter body - var source = new TraceFilterBody(); - source.setSeverities(List.of(RuleSeverity.CRITICAL, RuleSeverity.HIGH)); - source.setVulnTypes(List.of("sql-injection", "xss-reflected")); - source.setEnvironments(List.of(ServerEnvironment.PRODUCTION)); - var startDate = new Date(); - var endDate = new Date(); - source.setStartDate(startDate); - source.setEndDate(endDate); - source.setFilterTags(List.of("tag1", "tag2")); - source.setTracked(true); - source.setUntracked(false); + void status_field_should_be_settable_and_gettable() { + var body = new ExtendedTraceFilterBody(); + body.setStatus(Set.of("Reported", "Confirmed")); - // When creating extended filter with no session params - var result = ExtendedTraceFilterBody.withSessionFilters(source, null, null); - - // Then all base filters should be copied - assertThat(result.getSeverities()).containsExactly(RuleSeverity.CRITICAL, RuleSeverity.HIGH); - assertThat(result.getVulnTypes()).containsExactly("sql-injection", "xss-reflected"); - assertThat(result.getEnvironments()).containsExactly(ServerEnvironment.PRODUCTION); - assertThat(result.getStartDate()).isEqualTo(startDate); - assertThat(result.getEndDate()).isEqualTo(endDate); - assertThat(result.getFilterTags()).containsExactly("tag1", "tag2"); - assertThat(result.isTracked()).isTrue(); - assertThat(result.isUntracked()).isFalse(); - } - - @Test - void withSessionFilters_should_handle_null_source_fields() { - // Given a source with null fields - var source = new TraceFilterBody(); - - // When creating extended filter - var result = ExtendedTraceFilterBody.withSessionFilters(source, null, null); - - // Then null fields should remain null - assertThat(result.getSeverities()).isNull(); - assertThat(result.getVulnTypes()).isNull(); - assertThat(result.getEnvironments()).isNull(); - assertThat(result.getStartDate()).isNull(); - assertThat(result.getEndDate()).isNull(); - assertThat(result.getFilterTags()).isNull(); - } - - @Test - void withSessionFilters_should_set_agent_session_id() { - // Given - var source = new TraceFilterBody(); - var sessionId = "agent-session-123"; - - // When - var result = ExtendedTraceFilterBody.withSessionFilters(source, sessionId, null); - - // Then - assertThat(result.getAgentSessionId()).isEqualTo(sessionId); - } - - @Test - void withSessionFilters_should_set_single_metadata_filter() { - // Given - var source = new TraceFilterBody(); - var filter = new TraceMetadataFilter("87", "main"); - - // When - var result = ExtendedTraceFilterBody.withSessionFilters(source, null, List.of(filter)); - - // Then - assertThat(result.getMetadataFilters()).hasSize(1); - assertThat(result.getMetadataFilters().get(0).getFieldID()).isEqualTo("87"); - assertThat(result.getMetadataFilters().get(0).getValues()).containsExactly("main"); - } - - @Test - void withSessionFilters_should_set_multiple_metadata_filters() { - // Given - var source = new TraceFilterBody(); - var filters = - List.of( - new TraceMetadataFilter("87", "main"), - new TraceMetadataFilter("88", List.of("Ellen", "Sam"))); - - // When - var result = ExtendedTraceFilterBody.withSessionFilters(source, null, filters); - - // Then - assertThat(result.getMetadataFilters()).hasSize(2); - assertThat(result.getMetadataFilters().get(0).getFieldID()).isEqualTo("87"); - assertThat(result.getMetadataFilters().get(0).getValues()).containsExactly("main"); - assertThat(result.getMetadataFilters().get(1).getFieldID()).isEqualTo("88"); - assertThat(result.getMetadataFilters().get(1).getValues()).containsExactly("Ellen", "Sam"); - } - - @Test - void withSessionFilters_should_handle_null_metadata_filters() { - // Given - var source = new TraceFilterBody(); - - // When - var result = ExtendedTraceFilterBody.withSessionFilters(source, null, null); - - // Then - assertThat(result.getMetadataFilters()).isNull(); - } - - @Test - void withSessionFilters_should_handle_empty_metadata_filters() { - // Given - var source = new TraceFilterBody(); - - // When - var result = ExtendedTraceFilterBody.withSessionFilters(source, null, List.of()); - - // Then - assertThat(result.getMetadataFilters()).isNull(); - } - - @Test - void withSessionFilters_should_set_all_session_params_together() { - // Given - var source = new TraceFilterBody(); - source.setSeverities(List.of(RuleSeverity.CRITICAL)); - var sessionId = "agent-session-456"; - var filters = - List.of( - new TraceMetadataFilter("89", "Sam"), - new TraceMetadataFilter("90", List.of("feature-1", "feature-2"))); - - // When - var result = ExtendedTraceFilterBody.withSessionFilters(source, sessionId, filters); - - // Then all params should be set - assertThat(result.getSeverities()).containsExactly(RuleSeverity.CRITICAL); - assertThat(result.getAgentSessionId()).isEqualTo(sessionId); - assertThat(result.getMetadataFilters()).hasSize(2); - assertThat(result.getMetadataFilters().get(0).getFieldID()).isEqualTo("89"); - assertThat(result.getMetadataFilters().get(0).getValues()).containsExactly("Sam"); - assertThat(result.getMetadataFilters().get(1).getFieldID()).isEqualTo("90"); - assertThat(result.getMetadataFilters().get(1).getValues()) - .containsExactly("feature-1", "feature-2"); - } - - @Test - void withSessionFilters_should_reject_null_source() { - assertThatNullPointerException() - .isThrownBy(() -> ExtendedTraceFilterBody.withSessionFilters(null, null, null)) - .withMessage("source TraceFilterBody must not be null"); + assertThat(body.getStatus()).containsExactlyInAnyOrder("Reported", "Confirmed"); } } diff --git a/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/SearchAppVulnerabilitiesParamsTest.java b/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/SearchAppVulnerabilitiesParamsTest.java index 00a2891..009bd46 100644 --- a/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/SearchAppVulnerabilitiesParamsTest.java +++ b/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/SearchAppVulnerabilitiesParamsTest.java @@ -17,9 +17,10 @@ import static org.assertj.core.api.Assertions.assertThat; +import com.contrast.labs.ai.mcp.contrast.sdkextension.ExtendedTraceFilterBody; +import com.contrast.labs.ai.mcp.contrast.tool.validation.UnresolvedMetadataFilter; import com.contrastsecurity.http.RuleSeverity; import com.contrastsecurity.http.ServerEnvironment; -import java.util.List; import org.junit.jupiter.api.Test; class SearchAppVulnerabilitiesParamsTest { @@ -147,7 +148,9 @@ void of_should_parse_valid_sessionMetadataFilters() { VALID_APP_ID, null, null, null, null, null, null, null, "{\"branch\":\"main\"}", null); assertThat(params.isValid()).isTrue(); - assertThat(params.getSessionMetadataFilters()).containsEntry("branch", "main"); + assertThat(params.getSessionMetadataFilters()).hasSize(1); + assertThat(params.getSessionMetadataFilters().get(0).fieldName()).isEqualTo("branch"); + assertThat(params.getSessionMetadataFilters().get(0).values()).containsExactly("main"); } @Test @@ -166,9 +169,11 @@ void of_should_parse_sessionMetadataFilters_with_multiple_fields() { null); assertThat(params.isValid()).isTrue(); - assertThat(params.getSessionMetadataFilters()) - .containsEntry("developer", "Ellen") - .containsEntry("commit", "100"); + assertThat(params.getSessionMetadataFilters()).hasSize(2); + // LinkedHashMap preserves order + var filters = params.getSessionMetadataFilters(); + assertThat(filters.stream().map(UnresolvedMetadataFilter::fieldName)) + .containsExactly("developer", "commit"); } @Test @@ -187,9 +192,9 @@ void of_should_parse_sessionMetadataFilters_with_array_values() { null); assertThat(params.isValid()).isTrue(); - @SuppressWarnings("unchecked") - var developers = (List) params.getSessionMetadataFilters().get("developer"); - assertThat(developers).containsExactly("Ellen", "Sam"); + assertThat(params.getSessionMetadataFilters()).hasSize(1); + assertThat(params.getSessionMetadataFilters().get(0).fieldName()).isEqualTo("developer"); + assertThat(params.getSessionMetadataFilters().get(0).values()).containsExactly("Ellen", "Sam"); } @Test @@ -301,4 +306,16 @@ void toTraceFilterBody_empty_should_return_all_vulnerabilities() { assertThat(body.isTracked()).isFalse(); assertThat(body.isUntracked()).isFalse(); } + + @Test + void toTraceFilterBody_should_return_ExtendedTraceFilterBody_with_status() { + var params = + SearchAppVulnerabilitiesParams.of( + VALID_APP_ID, null, "Reported,Confirmed", null, null, null, null, null, null, null); + + var body = params.toTraceFilterBody(); + + assertThat(body).isInstanceOf(ExtendedTraceFilterBody.class); + assertThat(body.getStatus()).containsExactlyInAnyOrder("Reported", "Confirmed"); + } } diff --git a/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/VulnerabilityFilterParamsTest.java b/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/VulnerabilityFilterParamsTest.java index d7a3bfb..57d9021 100644 --- a/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/VulnerabilityFilterParamsTest.java +++ b/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/assess/params/VulnerabilityFilterParamsTest.java @@ -17,6 +17,7 @@ import static org.assertj.core.api.Assertions.assertThat; +import com.contrast.labs.ai.mcp.contrast.sdkextension.ExtendedTraceFilterBody; import com.contrastsecurity.http.RuleSeverity; import com.contrastsecurity.http.ServerEnvironment; import org.junit.jupiter.api.Test; @@ -80,6 +81,17 @@ void of_should_apply_default_statuses_with_warning() { assertThat(params.warnings()).anyMatch(w -> w.contains("excluding Fixed and Remediated")); } + @Test + void of_should_accept_statuses_case_insensitively_and_normalize() { + var params = + VulnerabilityFilterParams.of( + null, "reported,CONFIRMED,Fixed", null, null, null, null, null); + + assertThat(params.isValid()).isTrue(); + // Should normalize to canonical form + assertThat(params.getStatuses()).containsExactly("Reported", "Confirmed", "Fixed"); + } + // -- Environment tests -- @Test @@ -198,4 +210,15 @@ void toTraceFilterBody_empty_should_return_all_vulnerabilities() { assertThat(body.isTracked()).isFalse(); assertThat(body.isUntracked()).isFalse(); } + + @Test + void toTraceFilterBody_should_return_ExtendedTraceFilterBody_with_status() { + var params = + VulnerabilityFilterParams.of(null, "Reported,Confirmed", null, null, null, null, null); + + var body = params.toTraceFilterBody(); + + assertThat(body).isInstanceOf(ExtendedTraceFilterBody.class); + assertThat(body.getStatus()).containsExactlyInAnyOrder("Reported", "Confirmed"); + } } diff --git a/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/validation/MetadataJsonFilterSpecTest.java b/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/validation/MetadataJsonFilterSpecTest.java index f83600d..6acc0dd 100644 --- a/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/validation/MetadataJsonFilterSpecTest.java +++ b/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/validation/MetadataJsonFilterSpecTest.java @@ -17,7 +17,6 @@ import static org.assertj.core.api.Assertions.assertThat; -import java.util.List; import org.junit.jupiter.api.Test; class MetadataJsonFilterSpecTest { @@ -26,7 +25,10 @@ class MetadataJsonFilterSpecTest { void get_should_parse_simple_key_value() { var ctx = new ToolValidationContext(); var result = ctx.metadataJsonFilterParam("{\"key\":\"value\"}", "test").get(); - assertThat(result).containsEntry("key", "value"); + + assertThat(result).hasSize(1); + assertThat(result.get(0).fieldName()).isEqualTo("key"); + assertThat(result.get(0).values()).containsExactly("value"); assertThat(ctx.isValid()).isTrue(); } @@ -34,10 +36,11 @@ void get_should_parse_simple_key_value() { void get_should_parse_array_values() { var ctx = new ToolValidationContext(); var result = ctx.metadataJsonFilterParam("{\"key\":[\"a\",\"b\"]}", "test").get(); - assertThat(result).containsKey("key"); - @SuppressWarnings("unchecked") - var values = (List) result.get("key"); - assertThat(values).containsExactly("a", "b"); + + assertThat(result).hasSize(1); + assertThat(result.get(0).fieldName()).isEqualTo("key"); + assertThat(result.get(0).values()).containsExactly("a", "b"); + assertThat(ctx.isValid()).isTrue(); } @Test @@ -73,7 +76,10 @@ void get_should_add_error_for_invalid_json() { void get_should_convert_number_to_string() { var ctx = new ToolValidationContext(); var result = ctx.metadataJsonFilterParam("{\"count\":42}", "test").get(); - assertThat(result.get("count")).isEqualTo("42"); + + assertThat(result).hasSize(1); + assertThat(result.get(0).fieldName()).isEqualTo("count"); + assertThat(result.get(0).values()).containsExactly("42"); assertThat(ctx.isValid()).isTrue(); } @@ -81,9 +87,10 @@ void get_should_convert_number_to_string() { void get_should_convert_number_in_array_to_string() { var ctx = new ToolValidationContext(); var result = ctx.metadataJsonFilterParam("{\"ids\":[1,2,3]}", "test").get(); - @SuppressWarnings("unchecked") - var ids = (List) result.get("ids"); - assertThat(ids).containsExactly("1", "2", "3"); + + assertThat(result).hasSize(1); + assertThat(result.get(0).fieldName()).isEqualTo("ids"); + assertThat(result.get(0).values()).containsExactly("1", "2", "3"); assertThat(ctx.isValid()).isTrue(); } @@ -112,14 +119,29 @@ void get_should_handle_mixed_valid_values() { var result = ctx.metadataJsonFilterParam("{\"str\":\"val\",\"arr\":[\"a\",\"b\"],\"num\":123}", "test") .get(); - assertThat(result.get("str")).isEqualTo("val"); - @SuppressWarnings("unchecked") - var arr = (List) result.get("arr"); - assertThat(arr).containsExactly("a", "b"); - assertThat(result.get("num")).isEqualTo("123"); + + assertThat(result).hasSize(3); + // Note: LinkedHashMap preserves insertion order + assertThat(result.get(0).fieldName()).isEqualTo("str"); + assertThat(result.get(0).values()).containsExactly("val"); + assertThat(result.get(1).fieldName()).isEqualTo("arr"); + assertThat(result.get(1).values()).containsExactly("a", "b"); + assertThat(result.get(2).fieldName()).isEqualTo("num"); + assertThat(result.get(2).values()).containsExactly("123"); assertThat(ctx.isValid()).isTrue(); } + @Test + void get_should_preserve_insertion_order() { + var ctx = new ToolValidationContext(); + var result = + ctx.metadataJsonFilterParam("{\"branch\":\"main\",\"developer\":\"Ellen\"}", "test").get(); + + assertThat(result).hasSize(2); + assertThat(result.get(0).fieldName()).isEqualTo("branch"); + assertThat(result.get(1).fieldName()).isEqualTo("developer"); + } + // Tests for mutuallyExclusive method on ToolValidationContext @Test void mutuallyExclusive_should_add_error_when_both_present() { diff --git a/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/validation/UnresolvedMetadataFilterTest.java b/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/validation/UnresolvedMetadataFilterTest.java new file mode 100644 index 0000000..e3e4b7d --- /dev/null +++ b/src/test/java/com/contrast/labs/ai/mcp/contrast/tool/validation/UnresolvedMetadataFilterTest.java @@ -0,0 +1,64 @@ +/* + * Copyright 2025 Contrast Security + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.contrast.labs.ai.mcp.contrast.tool.validation; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +import java.util.ArrayList; +import java.util.List; +import org.junit.jupiter.api.Test; + +class UnresolvedMetadataFilterTest { + + @Test + void constructor_should_reject_null_field_name() { + assertThatThrownBy(() -> new UnresolvedMetadataFilter(null, List.of("value"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("fieldName"); + } + + @Test + void constructor_should_reject_blank_field_name() { + assertThatThrownBy(() -> new UnresolvedMetadataFilter(" ", List.of("value"))) + .isInstanceOf(IllegalArgumentException.class) + .hasMessageContaining("fieldName"); + } + + @Test + void constructor_should_make_defensive_copy_of_values() { + var mutableList = new ArrayList<>(List.of("a", "b")); + var filter = new UnresolvedMetadataFilter("field", mutableList); + mutableList.add("c"); + + assertThat(filter.values()).containsExactly("a", "b"); + } + + @Test + void constructor_should_handle_null_values_as_empty_list() { + var filter = new UnresolvedMetadataFilter("field", null); + + assertThat(filter.values()).isEmpty(); + } + + @Test + void record_should_store_field_name_and_values() { + var filter = new UnresolvedMetadataFilter("branch", List.of("main", "develop")); + + assertThat(filter.fieldName()).isEqualTo("branch"); + assertThat(filter.values()).containsExactly("main", "develop"); + } +}