diff --git a/deepteam/vulnerabilities/custom/custom.py b/deepteam/vulnerabilities/custom/custom.py index 24ddd4fc..2fdb4e82 100644 --- a/deepteam/vulnerabilities/custom/custom.py +++ b/deepteam/vulnerabilities/custom/custom.py @@ -41,6 +41,13 @@ def __init__( self.types = Enum( f"CustomVulnerabilityType", {t.upper(): t for t in types} ) + else: + # Default to a single type derived from the vulnerability name + # so iteration over self.types always works + self.types = Enum( + f"CustomVulnerabilityType", + {name.upper().replace(" ", "_"): name}, + ) self.custom_prompt = custom_prompt self.criteria = criteria.strip() diff --git a/tests/test_core/test_vulnerabilities/test_custom_vulnerability.py b/tests/test_core/test_vulnerabilities/test_custom_vulnerability.py index 5b0f7e3b..c4d7338f 100644 --- a/tests/test_core/test_vulnerabilities/test_custom_vulnerability.py +++ b/tests/test_core/test_vulnerabilities/test_custom_vulnerability.py @@ -19,3 +19,33 @@ def test_custom_vulnerability_initialize(self): assert sorted( type.value for type in custom_vulnerability.types ) == sorted(["type1", "type2"]) + + def test_custom_vulnerability_no_types(self): + """CustomVulnerability should not raise when types is omitted.""" + custom_vulnerability = CustomVulnerability( + criteria="Criteria", + name="Data Leakage", + ) + assert custom_vulnerability.name == "Data Leakage" + assert custom_vulnerability.criteria == "Criteria" + # Should have a default type derived from the name + type_values = [t.value for t in custom_vulnerability.types] + assert type_values == ["Data Leakage"] + + def test_custom_vulnerability_no_types_get_name(self): + """get_name works when types is omitted.""" + custom_vulnerability = CustomVulnerability( + criteria="Test criteria", + name="Test Vuln", + ) + assert custom_vulnerability.get_name() == "Test Vuln" + + def test_custom_vulnerability_single_type(self): + """CustomVulnerability works with a single type.""" + custom_vulnerability = CustomVulnerability( + criteria="Criteria", + name="Name", + types=["only_type"], + ) + type_values = [t.value for t in custom_vulnerability.types] + assert type_values == ["only_type"]