diff --git a/terrasafe/application/scanner.py b/terrasafe/application/scanner.py index 01f813b..1e8bf31 100644 --- a/terrasafe/application/scanner.py +++ b/terrasafe/application/scanner.py @@ -219,9 +219,10 @@ def _validate_features(self, features: np.ndarray) -> np.ndarray: Validated feature array with values clipped to acceptable bounds """ # Define acceptable bounds for each feature - # [open_ports, hardcoded_secrets, public_access, unencrypted_storage, total_resources] - min_bounds = np.array([0, 0, 0, 0, 0], dtype=np.int32) - max_bounds = np.array([100, 100, 100, 100, 10000], dtype=np.int32) + # [open_ports, hardcoded_secrets, public_access, unencrypted_storage, + # missing_logging, missing_flow_logs, total_resources] + min_bounds = np.array([0, 0, 0, 0, 0, 0, 0], dtype=np.int32) + max_bounds = np.array([100, 100, 100, 100, 100, 100, 10000], dtype=np.int32) # Clip features to acceptable ranges validated = np.clip(features, min_bounds, max_bounds) @@ -244,11 +245,11 @@ def _extract_features(self, vulnerabilities: List[Vulnerability]) -> np.ndarray: vulnerabilities: List of detected vulnerabilities Returns: - Numpy array of features (shape: 1x5) + Numpy array of features (shape: 1x7) """ if not vulnerabilities: # Return default feature vector for empty vulnerability list - return np.array([[0, 0, 0, 0, 1]], dtype=np.int32) + return np.array([[0, 0, 0, 0, 0, 0, 1]], dtype=np.int32) # Count unique resources unique_resources = len(set(v.resource for v in vulnerabilities)) @@ -270,12 +271,18 @@ def _extract_features(self, vulnerabilities: List[Vulnerability]) -> np.ndarray: unencrypted_mask = np.char.find(messages, 'unencrypted') >= 0 + missing_logging_mask = np.char.find(messages, 'missing logging') >= 0 + + missing_flow_logs_mask = np.char.find(messages, 'missing vpc flow logs') >= 0 + # Count matches using numpy sum (faster than Python loops) features = np.array([ np.sum(open_ports_mask), np.sum(hardcoded_mask), np.sum(public_access_mask), 
np.sum(unencrypted_mask), + np.sum(missing_logging_mask), + np.sum(missing_flow_logs_mask), unique_resources ], dtype=np.int32).reshape(1, -1) @@ -288,7 +295,10 @@ def _summarize_vulns(self, vulns: List[Vulnerability]) -> Dict[str, int]: return summary def _format_features(self, features: np.ndarray) -> Dict[str, int]: - feature_names = ['open_ports', 'hardcoded_secrets', 'public_access', 'unencrypted_storage', 'total_resources'] + feature_names = [ + 'open_ports', 'hardcoded_secrets', 'public_access', 'unencrypted_storage', + 'missing_logging', 'missing_flow_logs', 'total_resources' + ] return {name: int(val) for name, val in zip(feature_names, features[0])} def _vulnerability_to_dict(self, vuln: Vulnerability) -> Dict[str, Any]: diff --git a/terrasafe/config/settings.py b/terrasafe/config/settings.py index 8c3b5a4..970552a 100644 --- a/terrasafe/config/settings.py +++ b/terrasafe/config/settings.py @@ -67,6 +67,10 @@ class Settings(BaseSettings): default="models/isolation_forest.pkl", description="Path to ML model file" ) + severity_overrides: Dict[str, str] = Field( + default={}, + description="Override severity for specific rules, e.g. {'missing_logging': 'MEDIUM'}" + ) # Security Configuration max_file_size_mb: int = Field( diff --git a/terrasafe/domain/security_rules.py b/terrasafe/domain/security_rules.py index 82f5483..2b32c06 100644 --- a/terrasafe/domain/security_rules.py +++ b/terrasafe/domain/security_rules.py @@ -2,6 +2,7 @@ import re from typing import List, Dict from .models import Vulnerability, Severity +from ..config.settings import get_settings # Constants for severity points (Clean Code: No magic numbers) @@ -258,6 +259,67 @@ def check_iam_policies(self, tf_content: Dict) -> List[Vulnerability]: return vulns + def check_missing_logging(self, tf_content: Dict) -> List[Vulnerability]: + """Check for missing CloudTrail/CloudWatch logging resources. + + If infrastructure resources exist but no logging resources are present, + flag as HIGH severity. 
+ """ + vulns: List[Vulnerability] = [] + + if 'resource' not in tf_content: + return vulns + + resources = tf_content.get('resource', []) + all_resource_types = set() + for resource_block in resources: + all_resource_types.update(resource_block.keys()) + + # Only flag if there are infrastructure resources to log + infra_types = all_resource_types - {'aws_cloudtrail', 'aws_cloudwatch_log_group'} + has_infra = bool(infra_types) + has_logging = 'aws_cloudtrail' in all_resource_types or 'aws_cloudwatch_log_group' in all_resource_types + + if has_infra and not has_logging: + vulns.append(Vulnerability( + severity=Severity.HIGH, + points=POINTS_HIGH, + message="[HIGH] Missing logging - no CloudTrail or CloudWatch log group detected", + resource="Logging", + remediation="Add aws_cloudtrail or aws_cloudwatch_log_group to enable audit logging" + )) + + return vulns + + def check_missing_vpc_flow_logs(self, tf_content: Dict) -> List[Vulnerability]: + """Check for VPC resources without corresponding flow logs. + + If an aws_vpc resource exists but no aws_flow_log is found, flag as MEDIUM. 
+ """ + vulns: List[Vulnerability] = [] + + if 'resource' not in tf_content: + return vulns + + resources = tf_content.get('resource', []) + all_resource_types = set() + for resource_block in resources: + all_resource_types.update(resource_block.keys()) + + has_vpc = 'aws_vpc' in all_resource_types + has_flow_log = 'aws_flow_log' in all_resource_types + + if has_vpc and not has_flow_log: + vulns.append(Vulnerability( + severity=Severity.MEDIUM, + points=POINTS_MEDIUM, + message="[MEDIUM] Missing VPC flow logs - aws_vpc present but no aws_flow_log detected", + resource="VPC", + remediation="Add an aws_flow_log resource to enable VPC traffic logging" + )) + + return vulns + def analyze(self, tf_content: Dict, raw_content: str) -> List[Vulnerability]: """Run all security checks""" all_vulns = [] @@ -268,5 +330,23 @@ def analyze(self, tf_content: Dict, raw_content: str) -> List[Vulnerability]: all_vulns.extend(self.check_encryption(tf_content)) all_vulns.extend(self.check_public_s3(tf_content)) all_vulns.extend(self.check_iam_policies(tf_content)) + all_vulns.extend(self.check_missing_logging(tf_content)) + all_vulns.extend(self.check_missing_vpc_flow_logs(tf_content)) + + # Apply severity overrides from config + overrides = get_settings().severity_overrides + if overrides: + severity_map = {s.value: s for s in Severity} + rule_key_map = { + 'missing_logging': '[HIGH] Missing logging', + 'missing_flow_logs': '[MEDIUM] Missing VPC flow logs', + } + for vuln in all_vulns: + for rule_name, override_level in overrides.items(): + fragment = rule_key_map.get(rule_name) + if fragment and fragment in vuln.message: + new_severity = severity_map.get(override_level.upper()) + if new_severity: + vuln.severity = new_severity return all_vulns diff --git a/terrasafe/infrastructure/ml_model.py b/terrasafe/infrastructure/ml_model.py index cf93c4e..8d37651 100644 --- a/terrasafe/infrastructure/ml_model.py +++ b/terrasafe/infrastructure/ml_model.py @@ -364,33 +364,33 @@ def 
_train_baseline_model(self): rng = np.random.default_rng(42) # Enhanced baseline patterns representing secure configurations - # Features: [open_ports, secrets, public_access, unencrypted, resource_count] + # Features: [open_ports, secrets, public_access, unencrypted, missing_logging, missing_flow_logs, resource_count] baseline_patterns = [ # Fully secure configurations - [0, 0, 0, 0, 5], # Small secure microservice - [0, 0, 0, 0, 10], # Medium secure application - [0, 0, 0, 0, 15], # Large secure infrastructure - [0, 0, 0, 0, 25], # Enterprise secure setup - [0, 0, 0, 0, 3], # Minimal secure Lambda function + [0, 0, 0, 0, 0, 0, 5], # Small secure microservice + [0, 0, 0, 0, 0, 0, 10], # Medium secure application + [0, 0, 0, 0, 0, 0, 15], # Large secure infrastructure + [0, 0, 0, 0, 0, 0, 25], # Enterprise secure setup + [0, 0, 0, 0, 0, 0, 3], # Minimal secure Lambda function # Web applications (acceptable public exposure) - [1, 0, 0, 0, 8], # Simple web app with HTTP - [2, 0, 0, 0, 12], # Web app with HTTP/HTTPS - [2, 0, 1, 0, 20], # E-commerce with CDN (public S3) - [1, 0, 1, 0, 15], # Static site with S3 hosting - [2, 0, 2, 0, 30], # Multi-region web platform + [1, 0, 0, 0, 0, 0, 8], # Simple web app with HTTP + [2, 0, 0, 0, 0, 0, 12], # Web app with HTTP/HTTPS + [2, 0, 1, 0, 0, 0, 20], # E-commerce with CDN (public S3) + [1, 0, 1, 0, 0, 0, 15], # Static site with S3 hosting + [2, 0, 2, 0, 0, 0, 30], # Multi-region web platform # Development environments (slightly relaxed) - [1, 0, 0, 1, 6], # Dev env with one unencrypted volume - [2, 0, 0, 1, 10], # Staging with test data - [1, 0, 1, 1, 8], # QA environment - [0, 0, 0, 2, 12], # Test cluster with temp storage + [1, 0, 0, 1, 0, 0, 6], # Dev env with one unencrypted volume + [2, 0, 0, 1, 0, 0, 10], # Staging with test data + [1, 0, 1, 1, 0, 0, 8], # QA environment + [0, 0, 0, 2, 0, 0, 12], # Test cluster with temp storage # Microservices architectures - [3, 0, 0, 0, 40], # Service mesh with multiple endpoints - 
[4, 0, 1, 0, 50], # Kubernetes cluster with ingress - [2, 0, 0, 0, 35], # Docker swarm setup - [3, 0, 2, 0, 45], # Multi-service with CDN + [3, 0, 0, 0, 0, 0, 40], # Service mesh with multiple endpoints + [4, 0, 1, 0, 0, 0, 50], # Kubernetes cluster with ingress + [2, 0, 0, 0, 0, 0, 35], # Docker swarm setup + [3, 0, 2, 0, 0, 0, 45], # Multi-service with CDN ] baseline_features = np.array(baseline_patterns) @@ -401,7 +401,7 @@ def _train_baseline_model(self): # Add noise variations for each pattern for pattern in baseline_features: for _ in range(3): # Create 3 variations per pattern - noise = rng.normal(0, 0.15, 5) + noise = rng.normal(0, 0.15, 7) augmented = pattern + noise augmented = np.maximum(augmented, 0) # Ensure non-negative # Round discrete features @@ -410,11 +410,11 @@ def _train_baseline_model(self): # Add edge cases representing acceptable boundaries edge_cases = np.array([ - [5, 0, 0, 0, 60], # Large microservices - [0, 0, 5, 0, 40], # Content delivery network - [3, 0, 3, 2, 50], # Legacy migration - [0, 0, 0, 3, 25], # Development cluster - [6, 0, 2, 0, 70], # API gateway with multiple services + [5, 0, 0, 0, 0, 0, 60], # Large microservices + [0, 0, 5, 0, 0, 0, 40], # Content delivery network + [3, 0, 3, 2, 0, 0, 50], # Legacy migration + [0, 0, 0, 3, 0, 0, 25], # Development cluster + [6, 0, 2, 0, 0, 0, 70], # API gateway with multiple services ]) augmented_data = np.vstack([augmented_data, edge_cases]) @@ -446,7 +446,9 @@ def _train_baseline_model(self): 'hardcoded_secrets': {'min': int(augmented_data[:, 1].min()), 'max': int(augmented_data[:, 1].max())}, 'public_access': {'min': int(augmented_data[:, 2].min()), 'max': int(augmented_data[:, 2].max())}, 'unencrypted_storage': {'min': int(augmented_data[:, 3].min()), 'max': int(augmented_data[:, 3].max())}, - 'total_resources': {'min': int(augmented_data[:, 4].min()), 'max': int(augmented_data[:, 4].max())}, + 'missing_logging': {'min': int(augmented_data[:, 4].min()), 'max': int(augmented_data[:, 
4].max())}, + 'missing_flow_logs': {'min': int(augmented_data[:, 5].min()), 'max': int(augmented_data[:, 5].max())}, + 'total_resources': {'min': int(augmented_data[:, 6].min()), 'max': int(augmented_data[:, 6].max())}, }, 'model_parameters': { 'contamination': 0.1, diff --git a/test_files/mixed.tf b/test_files/mixed.tf index 139a3df..29bc53d 100644 --- a/test_files/mixed.tf +++ b/test_files/mixed.tf @@ -63,3 +63,18 @@ variable "db_password" { type = string sensitive = true } + +resource "aws_vpc" "app_vpc" { + cidr_block = "10.0.0.0/16" # MEDIUM: No aws_flow_log present + tags = { + Name = "mixed-vpc" + } +} + +resource "aws_cloudtrail" "app_trail" { + name = "app-trail" + s3_bucket_name = aws_s3_bucket.app_bucket.bucket + # CloudTrail present — satisfies missing_logging rule +} +# NOTE: aws_cloudtrail present → no missing_logging vuln +# NOTE: no aws_flow_log → triggers missing_vpc_flow_logs rule only diff --git a/test_files/secure.tf b/test_files/secure.tf index 9175ee9..9ec5a69 100644 --- a/test_files/secure.tf +++ b/test_files/secure.tf @@ -72,4 +72,36 @@ variable "db_password" { description = "Database password" type = string sensitive = true +} + +resource "aws_vpc" "main" { + cidr_block = "10.0.0.0/16" + tags = { + Name = "secure-vpc" + } +} + +resource "aws_flow_log" "main" { + vpc_id = aws_vpc.main.id + traffic_type = "ALL" + iam_role_arn = var.flow_log_role_arn + log_destination = aws_cloudwatch_log_group.flow_logs.arn +} + +resource "aws_cloudwatch_log_group" "flow_logs" { + name = "/aws/vpc/flow-logs" + retention_in_days = 90 +} + +resource "aws_cloudtrail" "main" { + name = "secure-trail" + s3_bucket_name = aws_s3_bucket.main_bucket.bucket + include_global_service_events = true + is_multi_region_trail = true + enable_log_file_validation = true +} + +variable "flow_log_role_arn" { + description = "IAM role ARN for VPC flow logs" + type = string } \ No newline at end of file diff --git a/test_files/vulnerable.tf b/test_files/vulnerable.tf index 
7cf9723..f031701 100644 --- a/test_files/vulnerable.tf +++ b/test_files/vulnerable.tf @@ -59,4 +59,13 @@ resource "aws_s3_bucket" "main_bucket" { tags = { Environment = "test" } -} \ No newline at end of file +} + +resource "aws_vpc" "main" { + cidr_block = "10.0.0.0/16" # MEDIUM: No aws_flow_log present — VPC traffic is unmonitored + tags = { + Name = "vulnerable-vpc" + } +} +# NOTE: no aws_cloudtrail, no aws_cloudwatch_log_group → triggers missing_logging rule +# NOTE: no aws_flow_log → triggers missing_vpc_flow_logs rule \ No newline at end of file diff --git a/tests/test_security_rules_logging.py b/tests/test_security_rules_logging.py new file mode 100644 index 0000000..4cb49ff --- /dev/null +++ b/tests/test_security_rules_logging.py @@ -0,0 +1,182 @@ +#!/usr/bin/env python3 +""" +Unit tests for the new logging and VPC flow log security rules. + +Covers: +- check_missing_logging() — various configurations +- check_missing_vpc_flow_logs() — various configurations +- Severity overrides applied through analyze() +- Boundary cases (empty tf_content) +""" + +import pytest +from unittest.mock import patch + +from terrasafe.domain.models import Severity +from terrasafe.domain.security_rules import SecurityRuleEngine + + +@pytest.fixture +def engine(): + """Provide a SecurityRuleEngine for each test.""" + return SecurityRuleEngine() + + +# --------------------------------------------------------------------------- +# check_missing_logging +# --------------------------------------------------------------------------- + +def test_missing_logging_no_logging_resources(engine): + """Infrastructure resources but no CloudTrail/CloudWatch → HIGH vuln.""" + tf_content = { + 'resource': [ + {'aws_vpc': [{'main': {'cidr_block': '10.0.0.0/16'}}]}, + {'aws_db_instance': [{'db': {'engine': 'mysql', 'storage_encrypted': True}}]}, + ] + } + vulns = engine.check_missing_logging(tf_content) + assert len(vulns) == 1 + assert vulns[0].severity == Severity.HIGH + assert 'missing logging' in
vulns[0].message.lower() + + +def test_missing_logging_with_cloudtrail_present(engine): + """CloudTrail present → no vuln.""" + tf_content = { + 'resource': [ + {'aws_vpc': [{'main': {'cidr_block': '10.0.0.0/16'}}]}, + {'aws_cloudtrail': [{'trail': {'name': 'my-trail', 's3_bucket_name': 'my-bucket'}}]}, + ] + } + vulns = engine.check_missing_logging(tf_content) + assert vulns == [] + + +def test_missing_logging_with_cloudwatch_log_group_present(engine): + """CloudWatch log group present → no vuln.""" + tf_content = { + 'resource': [ + {'aws_security_group': [{'sg': {'ingress': []}}]}, + {'aws_cloudwatch_log_group': [{'lg': {'name': '/app/logs'}}]}, + ] + } + vulns = engine.check_missing_logging(tf_content) + assert vulns == [] + + +def test_missing_logging_empty_tf_content(engine): + """Empty tf_content (no 'resource' key) → no vuln.""" + vulns = engine.check_missing_logging({}) + assert vulns == [] + + +def test_missing_logging_resources_are_only_logging(engine): + """Only logging resources present (no infra to monitor) → no vuln.""" + tf_content = { + 'resource': [ + {'aws_cloudtrail': [{'trail': {'name': 'trail', 's3_bucket_name': 'b'}}]}, + {'aws_cloudwatch_log_group': [{'lg': {'name': '/logs'}}]}, + ] + } + vulns = engine.check_missing_logging(tf_content) + assert vulns == [] + + +# --------------------------------------------------------------------------- +# check_missing_vpc_flow_logs +# --------------------------------------------------------------------------- + +def test_missing_vpc_flow_logs_vpc_without_flow_log(engine): + """aws_vpc present but no aws_flow_log → MEDIUM vuln.""" + tf_content = { + 'resource': [ + {'aws_vpc': [{'main': {'cidr_block': '10.0.0.0/16'}}]}, + ] + } + vulns = engine.check_missing_vpc_flow_logs(tf_content) + assert len(vulns) == 1 + assert vulns[0].severity == Severity.MEDIUM + assert 'flow log' in vulns[0].message.lower() + + +def test_missing_vpc_flow_logs_vpc_with_flow_log(engine): + """aws_vpc + aws_flow_log present → no 
vuln.""" + tf_content = { + 'resource': [ + {'aws_vpc': [{'main': {'cidr_block': '10.0.0.0/16'}}]}, + {'aws_flow_log': [{'fl': {'vpc_id': 'aws_vpc.main.id', 'traffic_type': 'ALL'}}]}, + ] + } + vulns = engine.check_missing_vpc_flow_logs(tf_content) + assert vulns == [] + + +def test_missing_vpc_flow_logs_no_vpc_at_all(engine): + """No aws_vpc resource → no vuln (rule does not apply).""" + tf_content = { + 'resource': [ + {'aws_db_instance': [{'db': {'engine': 'mysql', 'storage_encrypted': True}}]}, + ] + } + vulns = engine.check_missing_vpc_flow_logs(tf_content) + assert vulns == [] + + +def test_missing_vpc_flow_logs_empty_tf_content(engine): + """Empty tf_content → no vuln.""" + vulns = engine.check_missing_vpc_flow_logs({}) + assert vulns == [] + + +# --------------------------------------------------------------------------- +# Severity overrides via analyze() +# --------------------------------------------------------------------------- + +def test_severity_override_missing_logging_to_medium(engine): + """severity_overrides remaps missing_logging from HIGH → MEDIUM.""" + tf_content = { + 'resource': [ + {'aws_vpc': [{'main': {'cidr_block': '10.0.0.0/16'}}]}, + ] + } + mock_settings = type('S', (), {'severity_overrides': {'missing_logging': 'MEDIUM'}})() + with patch('terrasafe.domain.security_rules.get_settings', return_value=mock_settings): + vulns = engine.analyze(tf_content, "") + + logging_vulns = [v for v in vulns if 'missing logging' in v.message.lower()] + assert len(logging_vulns) == 1 + assert logging_vulns[0].severity == Severity.MEDIUM + + +def test_severity_override_missing_flow_logs_to_high(engine): + """severity_overrides remaps missing_flow_logs from MEDIUM → HIGH.""" + tf_content = { + 'resource': [ + {'aws_vpc': [{'main': {'cidr_block': '10.0.0.0/16'}}]}, + {'aws_cloudtrail': [{'trail': {'name': 'trail', 's3_bucket_name': 'b'}}]}, + ] + } + mock_settings = type('S', (), {'severity_overrides': {'missing_flow_logs': 'HIGH'}})() + with 
patch('terrasafe.domain.security_rules.get_settings', return_value=mock_settings): + vulns = engine.analyze(tf_content, "") + + flow_vulns = [v for v in vulns if 'flow log' in v.message.lower()] + assert len(flow_vulns) == 1 + assert flow_vulns[0].severity == Severity.HIGH + + +def test_no_severity_override_applied_when_empty(engine): + """Empty severity_overrides dict → original severities unchanged.""" + tf_content = { + 'resource': [ + {'aws_vpc': [{'main': {'cidr_block': '10.0.0.0/16'}}]}, + ] + } + mock_settings = type('S', (), {'severity_overrides': {}})() + with patch('terrasafe.domain.security_rules.get_settings', return_value=mock_settings): + vulns = engine.analyze(tf_content, "") + + logging_vulns = [v for v in vulns if 'missing logging' in v.message.lower()] + flow_vulns = [v for v in vulns if 'flow log' in v.message.lower()] + assert all(v.severity == Severity.HIGH for v in logging_vulns) + assert all(v.severity == Severity.MEDIUM for v in flow_vulns) diff --git a/tests/test_security_scanner.py b/tests/test_security_scanner.py index 623dc47..8ed0e6b 100644 --- a/tests/test_security_scanner.py +++ b/tests/test_security_scanner.py @@ -196,9 +196,9 @@ def test_detect_open_ssh_port(rule_engine): 'resource': [{'aws_security_group': [{'test_sg': {'ingress': [{'from_port': 22, 'to_port': 22, 'protocol': 'tcp', 'cidr_blocks': ['0.0.0.0/0']}]}}]}] } vulnerabilities = rule_engine.analyze(tf_content, "") - assert len(vulnerabilities) == 1 - assert vulnerabilities[0].severity == Severity.CRITICAL - assert 'SSH' in vulnerabilities[0].message.upper() + ssh_vulns = [v for v in vulnerabilities if 'SSH' in v.message.upper()] + assert len(ssh_vulns) == 1 + assert ssh_vulns[0].severity == Severity.CRITICAL def test_detect_hardcoded_password(rule_engine): @@ -215,9 +215,10 @@ def test_detect_unencrypted_rds(rule_engine): 'resource': [{'aws_db_instance': [{'test_db': {'engine': 'mysql', 'storage_encrypted': False}}]}] } vulnerabilities = rule_engine.analyze(tf_content, "") 
- assert len(vulnerabilities) == 1 - assert vulnerabilities[0].severity == Severity.HIGH - assert 'unencrypted' in vulnerabilities[0].message.lower() + rds_vulns = [v for v in vulnerabilities if 'unencrypted' in v.message.lower() and 'rds' in v.message.lower()] + assert len(rds_vulns) == 1 + assert rds_vulns[0].severity == Severity.HIGH + assert 'unencrypted' in rds_vulns[0].message.lower() def test_detect_public_s3_bucket(rule_engine): @@ -231,17 +232,24 @@ def test_detect_public_s3_bucket(rule_engine): }}]}] } vulnerabilities = rule_engine.analyze(tf_content, "") - assert len(vulnerabilities) == 1 - assert vulnerabilities[0].severity == Severity.HIGH + s3_vulns = [v for v in vulnerabilities if 's3' in v.message.lower()] + assert len(s3_vulns) == 1 + assert s3_vulns[0].severity == Severity.HIGH def test_no_vulnerabilities_secure_config(rule_engine): - """Test that secure configurations don't trigger false positives""" + """Test that secure configurations don't trigger encryption false positives""" tf_content = { - 'resource': [{'aws_db_instance': [{'secure_db': {'engine': 'mysql', 'storage_encrypted': True}}]}] + 'resource': [ + {'aws_db_instance': [{'secure_db': {'engine': 'mysql', 'storage_encrypted': True}}]}, + {'aws_cloudtrail': [{'trail': {'name': 'audit-trail', 's3_bucket_name': 'audit-bucket'}}]}, + ] } vulnerabilities = rule_engine.analyze(tf_content, "") - assert len(vulnerabilities) == 0 + # No encryption, SSH, S3, IAM, or logging vulnerabilities expected + unwanted_vulns = [v for v in vulnerabilities if 'missing logging' not in v.message.lower() + and 'flow log' not in v.message.lower()] + assert len(unwanted_vulns) == 0 # ============================================================================ @@ -387,8 +395,9 @@ def test_feature_extraction(scanner_with_mocks): features = scanner_with_mocks._extract_features(vulnerabilities) - # Expected: [1 open_port, 1 secret, 1 public_access, 1 unencrypted, 4 resources] - expected = np.array([[1, 1, 1, 1, 4]]) 
+ # Expected: [1 open_port, 1 secret, 1 public_access, 1 unencrypted, 0 missing_logging, + # 0 missing_flow_logs, 4 resources] + expected = np.array([[1, 1, 1, 1, 0, 0, 4]]) np.testing.assert_array_equal(features, expected) @@ -427,7 +436,7 @@ def test_vulnerability_to_dict(scanner_with_mocks): def test_format_features(scanner_with_mocks): """Test feature formatting""" - features = np.array([[2, 1, 0, 3, 10]]) + features = np.array([[2, 1, 0, 3, 1, 0, 10]]) formatted = scanner_with_mocks._format_features(features) expected = { @@ -435,6 +444,8 @@ def test_format_features(scanner_with_mocks): 'hardcoded_secrets': 1, 'public_access': 0, 'unencrypted_storage': 3, + 'missing_logging': 1, + 'missing_flow_logs': 0, 'total_resources': 10 } @@ -664,8 +675,8 @@ def test_predict_risk_with_features(ml_predictor): IMPROVED: Now includes relative score comparisons. """ # Test with different feature patterns - low_risk_features = np.array([[0, 0, 0, 0, 5]]) - high_risk_features = np.array([[3, 2, 2, 3, 20]]) + low_risk_features = np.array([[0, 0, 0, 0, 0, 0, 5]]) + high_risk_features = np.array([[3, 2, 2, 3, 1, 1, 20]]) low_score, low_conf = ml_predictor.predict_risk(low_risk_features) high_score, high_conf = ml_predictor.predict_risk(high_risk_features) @@ -692,8 +703,8 @@ def test_predict_risk_anomaly_detected(ml_predictor): IMPROVED: Strengthened assertions for anomaly detection. 
""" # Create a pattern that should be flagged as high risk - very_high_risk_features = np.array([[5, 5, 5, 5, 50]]) # Many vulnerabilities - moderate_risk_features = np.array([[1, 0, 1, 0, 10]]) # Few vulnerabilities + very_high_risk_features = np.array([[5, 5, 5, 5, 1, 1, 50]]) # Many vulnerabilities + moderate_risk_features = np.array([[1, 0, 1, 0, 0, 0, 10]]) # Few vulnerabilities very_high_score, very_high_conf = ml_predictor.predict_risk(very_high_risk_features) moderate_score, moderate_conf = ml_predictor.predict_risk(moderate_risk_features) @@ -710,7 +721,7 @@ def test_predict_risk_anomaly_detected(ml_predictor): def test_predict_risk_edge_cases(ml_predictor): """Test risk prediction with edge case inputs""" # Empty features (no vulnerabilities) - empty_features = np.array([[0, 0, 0, 0, 0]]) + empty_features = np.array([[0, 0, 0, 0, 0, 0, 0]]) score, confidence = ml_predictor.predict_risk(empty_features) assert score >= 0