Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions deepteam/vulnerabilities/custom/custom.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def __init__(
self.types = Enum(
f"CustomVulnerabilityType", {t.upper(): t for t in types}
)
else:
# Default to a single type derived from the vulnerability name
# so iteration over self.types always works
self.types = Enum(
f"CustomVulnerabilityType",
{name.upper().replace(" ", "_"): name},
)

self.custom_prompt = custom_prompt
self.criteria = criteria.strip()
Expand Down
30 changes: 30 additions & 0 deletions tests/test_core/test_vulnerabilities/test_custom_vulnerability.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,33 @@ def test_custom_vulnerability_initialize(self):
assert sorted(
type.value for type in custom_vulnerability.types
) == sorted(["type1", "type2"])

def test_custom_vulnerability_no_types(self):
"""CustomVulnerability should not raise when types is omitted."""
custom_vulnerability = CustomVulnerability(
criteria="Criteria",
name="Data Leakage",
)
assert custom_vulnerability.name == "Data Leakage"
assert custom_vulnerability.criteria == "Criteria"
# Should have a default type derived from the name
type_values = [t.value for t in custom_vulnerability.types]
assert type_values == ["Data Leakage"]

def test_custom_vulnerability_no_types_get_name(self):
"""get_name works when types is omitted."""
custom_vulnerability = CustomVulnerability(
criteria="Test criteria",
name="Test Vuln",
)
assert custom_vulnerability.get_name() == "Test Vuln"

def test_custom_vulnerability_single_type(self):
"""CustomVulnerability works with a single type."""
custom_vulnerability = CustomVulnerability(
criteria="Criteria",
name="Name",
types=["only_type"],
)
type_values = [t.value for t in custom_vulnerability.types]
assert type_values == ["only_type"]