From 55776b2eb3c606203bc7e83481657007b5ccf24a Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Wed, 19 Nov 2025 20:50:48 +0000 Subject: [PATCH 01/11] feat: add `--explain` --- iam-policy-autopilot-cli/src/main.rs | 51 +-- iam-policy-autopilot-cli/src/output.rs | 55 +-- .../src/tools/generate_policy.rs | 34 +- .../src/api/generate_policies.rs | 53 ++- .../src/api/model.rs | 14 +- .../src/enrichment/mod.rs | 246 ++++++++++++ .../src/enrichment/resource_matcher.rs | 196 +++++++-- .../src/extraction/engine.rs | 2 +- .../src/extraction/extractor.rs | 27 +- .../src/extraction/go/disambiguation.rs | 21 + .../src/extraction/go/extractor.rs | 89 +++-- .../src/extraction/go/features_extractor.rs | 32 +- .../src/extraction/go/paginator_extractor.rs | 80 ++-- .../src/extraction/go/waiter_extractor.rs | 181 ++++++--- .../javascript/argument_extractor.rs | 6 +- .../src/extraction/javascript/extractor.rs | 26 +- .../src/extraction/javascript/scanner.rs | 256 ++++++++---- .../src/extraction/javascript/shared.rs | 127 ++++-- .../src/extraction/javascript/types.rs | 32 +- .../src/extraction/mod.rs | 31 ++ .../src/extraction/python/disambiguation.rs | 13 + .../extraction/python/disambiguation_tests.rs | 22 +- .../src/extraction/python/extractor.rs | 40 +- .../extraction/python/paginator_extractor.rs | 77 ++-- .../python/resource_direct_calls_extractor.rs | 80 +++- .../extraction/python/waiters_extractor.rs | 163 +++++--- .../src/extraction/typescript/extractor.rs | 26 +- .../src/lib.rs | 5 +- .../src/policy_generation/engine.rs | 371 +++++++++++++++--- .../policy_generation/integration_tests.rs | 33 +- .../src/policy_generation/mod.rs | 71 +--- .../tests/go_extraction_integration_test.rs | 52 +-- .../tests/public_api_integration_test.rs | 34 +- 33 files changed, 1803 insertions(+), 743 deletions(-) diff --git a/iam-policy-autopilot-cli/src/main.rs b/iam-policy-autopilot-cli/src/main.rs index 9082697..d24f4c1 100644 --- a/iam-policy-autopilot-cli/src/main.rs +++ b/iam-policy-autopilot-cli/src/main.rs @@ -82,14 +82,14 @@ struct GeneratePolicyCliConfig { account: String, /// Output individual policies instead of merged policy individual_policies: bool, - /// Show method to action mappings alongside policies - show_action_mappings: bool, /// Upload policies to AWS with optional custom name prefix upload_policies: Option, /// Enable minimal policy size by allowing cross-service merging minimal_policy_size: bool, /// Disable file system caching for service references disable_cache: bool, + /// Generate explanations for why actions were added + explain: bool, } impl GeneratePolicyCliConfig { @@ -280,16 +280,6 @@ for each method call. Disables --upload-policy, if provided." )] individual_policies: bool, - /// Include method to action mappings alongside the generated policies - #[arg( - hide = true, - long = "show-action-mappings", - long_help = "When enabled, outputs detailed method to action \ -mappings alongside the generated policies. This provides granular visibility into which SDK method calls \ -require which specific IAM actions and their associated resources. Disables --upload-policy, if provided." - )] - show_action_mappings: bool, - /// Upload generated policies to AWS IAM with optional custom name prefix #[arg(long = "upload-policies", num_args = 0..=1, require_equals = false, default_missing_value = "", long_help = "Upload the generated policies to AWS IAM using the iam:CreatePolicy API. \ @@ -327,6 +317,15 @@ Use this flag to force fresh data retrieval on every run." long_help = SERVICE_HINTS_LONG_HELP, )] service_hints: Option>, + + /// Generate explanations for why actions were added + #[arg( + long = "explain", + long_help = "When enabled, generates detailed explanations for why each IAM action \ +was added to the policy. Explanations include the initial operation with location information, FAS (https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html) expansion chains. The output format \ +may change in future versions." + )] + explain: bool, }, /// Start MCP server @@ -425,35 +424,27 @@ async fn handle_generate_policy(config: &GeneratePolicyCliConfig) -> Result<()> service_names: names.clone(), }); - let (policies, method_action_mappings) = generate_policies(&GeneratePolicyConfig { + let result = generate_policies(&GeneratePolicyConfig { extract_sdk_calls_config: ExtractSdkCallsConfig { source_files: config.shared.source_files.to_owned(), language: config.shared.language.to_owned(), service_hints, }, aws_context: AwsContext::new(config.region.clone(), config.account.clone()), - generate_action_mappings: config.show_action_mappings, individual_policies: config.individual_policies, minimize_policy_size: config.minimal_policy_size, disable_file_system_cache: config.disable_cache, + generate_explanations: config.explain, }) .await?; - // Handle policy output based on configuration - if config.show_action_mappings { - // Output combined format with mappings and policies - output::output_combined_policy_mappings( - method_action_mappings, - policies, - config.shared.pretty, - ) - .context("Failed to output combined policy and mappings")?; - - trace!("Combined policy and mappings output written to stdout"); - } else if config.individual_policies { + let policies = result.policies; + let explanations = result.explanations; + + if config.individual_policies { // Output individual policies trace!("Outputting {} individual policies", policies.len()); - output::output_iam_policies(policies, None, config.shared.pretty) + output::output_iam_policies(policies, explanations, None, config.shared.pretty) .context("Failed to output individual IAM policies")?; } else { // Default behavior: output merged policy with optional upload @@ -492,7 +483,7 @@ async fn handle_generate_policy(config: &GeneratePolicyCliConfig) -> Result<()> None }; - output::output_iam_policies(policies, upload_result, config.shared.pretty) + output::output_iam_policies(policies, explanations, upload_result, config.shared.pretty) .context("Failed to output merged IAM policy")? } @@ -564,11 +555,11 @@ async fn main() { region, account, individual_policies, - show_action_mappings, upload_policies, minimal_policy_size, disable_cache, service_hints, + explain, } => { // Initialize logging if let Err(e) = init_logging(debug) { @@ -587,10 +578,10 @@ async fn main() { region, account, individual_policies, - show_action_mappings, upload_policies, minimal_policy_size, disable_cache, + explain, }; match handle_generate_policy(&config).await { diff --git a/iam-policy-autopilot-cli/src/output.rs b/iam-policy-autopilot-cli/src/output.rs index fc136fa..ef4f776 100644 --- a/iam-policy-autopilot-cli/src/output.rs +++ b/iam-policy-autopilot-cli/src/output.rs @@ -1,8 +1,6 @@ use anyhow::{Context, Result}; use iam_policy_autopilot_access_denied::{DenialType, PlanResult}; -use iam_policy_autopilot_policy_generation::policy_generation::{ - MethodActionMapping, PolicyWithMetadata, -}; +use iam_policy_autopilot_policy_generation::{policy_generation::PolicyWithMetadata, Explanation}; use iam_policy_autopilot_tools::BatchUploadResponse; use log::debug; use std::io::{self, Write}; @@ -164,62 +162,18 @@ pub(crate) fn print_unsupported_denial(denial_type: &DenialType, reason: &str) { struct PolicyOutput { /// The generated policies with type information policies: Vec, - /// Upload results (only present when --upload-policies is used) + /// Explanations for why actions were added #[serde(skip_serializing_if = "Option::is_none")] - upload_result: Option, -} - -/// Combined output structure when showing action mappings alongside policies -#[derive(Debug, Clone, serde::Serialize)] -#[serde(rename_all = "PascalCase")] -struct CombinedPolicyOutput { - /// Method to action mappings - method_action_mappings: Vec, - /// The generated policies with type information - policies: Vec, + explanations: Option>, /// Upload results (only present when --upload-policies is used) #[serde(skip_serializing_if = "Option::is_none")] upload_result: Option, } -/// Output combined policy and mappings as JSON to stdout -pub(crate) fn output_combined_policy_mappings( - method_action_mappings: Vec, - policies: Vec, - pretty: bool, -) -> Result<()> { - debug!( - "Formatting combined policy and mappings output as JSON (pretty: {})", - pretty - ); - - let combined_output = CombinedPolicyOutput { - method_action_mappings, - policies, - upload_result: None, - }; - - let json_output = if pretty { - iam_policy_autopilot_policy_generation::JsonProvider::stringify_pretty(&combined_output) - .context("Failed to serialize combined output to pretty JSON")? - } else { - iam_policy_autopilot_policy_generation::JsonProvider::stringify(&combined_output) - .context("Failed to serialize combined output to JSON")? - }; - - // Output to stdout (not using println! to avoid extra newline in compact mode) - print!("{}", json_output); - if pretty { - println!(); // Add newline for pretty output - } - - debug!("Combined policy and mappings JSON output written to stdout"); - Ok(()) -} - /// Output IAM policies as JSON to stdout pub(crate) fn output_iam_policies( policies: Vec, + explanations: Option>, upload_result: Option, pretty: bool, ) -> Result<()> { @@ -230,6 +184,7 @@ pub(crate) fn output_iam_policies( let policy_output = PolicyOutput { policies, + explanations, upload_result, }; diff --git a/iam-policy-autopilot-mcp-server/src/tools/generate_policy.rs b/iam-policy-autopilot-mcp-server/src/tools/generate_policy.rs index 4e373a2..eb7e371 100644 --- a/iam-policy-autopilot-mcp-server/src/tools/generate_policy.rs +++ b/iam-policy-autopilot-mcp-server/src/tools/generate_policy.rs @@ -52,7 +52,7 @@ pub async fn generate_application_policies( service_names: hints, }); - let (policies, _) = api::generate_policies(&GeneratePolicyConfig { + let result = api::generate_policies(&GeneratePolicyConfig { individual_policies: false, extract_sdk_calls_config: ExtractSdkCallsConfig { source_files: input.source_files.into_iter().map(|f| f.into()).collect(), @@ -61,16 +61,17 @@ pub async fn generate_application_policies( service_hints, }, aws_context: AwsContext::new(region, account), - generate_action_mappings: false, minimize_policy_size: false, // true by default, if we want to allow the user to change it we should // accept it as part of the cli input when starting the mcp server disable_file_system_cache: true, + generate_explanations: false, }) .await?; - let policies = policies + let policies = result + .policies .into_iter() .map(|policy| serde_json::to_string(&policy.policy).context("Failed to serialize policy")) .collect::, Error>>()?; @@ -82,26 +83,23 @@ pub async fn generate_application_policies( #[cfg(test)] mod api { use anyhow::Result; - use iam_policy_autopilot_policy_generation::{ - api::model::GeneratePolicyConfig, policy_generation::PolicyWithMetadata, - MethodActionMapping, + use iam_policy_autopilot_policy_generation::api::model::{ + GeneratePoliciesResult, GeneratePolicyConfig, }; // Static mutable return value - pub static mut MOCK_RETURN_VALUE: Option< - Result<(Vec, Vec)>, - > = None; + pub static mut MOCK_RETURN_VALUE: Option> = None; pub async fn generate_policies( _config: &GeneratePolicyConfig, - ) -> Result<(Vec, Vec)> { + ) -> Result { #[allow(static_mut_refs)] unsafe { MOCK_RETURN_VALUE.take().unwrap() } } - pub fn set_mock_return(value: Result<(Vec, Vec)>) { + pub fn set_mock_return(value: Result) { unsafe { MOCK_RETURN_VALUE = Some(value) } } } @@ -113,7 +111,7 @@ mod tests { use super::*; use iam_policy_autopilot_policy_generation::{ - IamPolicy, PolicyType, PolicyWithMetadata, Statement, + api::model::GeneratePoliciesResult, IamPolicy, PolicyType, PolicyWithMetadata, Statement, }; use anyhow::anyhow; @@ -143,7 +141,12 @@ mod tests { policy_type: PolicyType::Identity, }; - api::set_mock_return(Ok((vec![policy], vec![]))); + use iam_policy_autopilot_policy_generation::api::model::GeneratePoliciesResult; + + api::set_mock_return(Ok(GeneratePoliciesResult { + policies: vec![policy], + explanations: None, + })); let result = generate_application_policies(input).await; println!("{result:?}"); @@ -223,7 +226,10 @@ mod tests { policy_type: PolicyType::Identity, }; - api::set_mock_return(Ok((vec![policy], vec![]))); + api::set_mock_return(Ok(GeneratePoliciesResult { + policies: vec![policy], + explanations: None, + })); let result = generate_application_policies(input).await; assert!(result.is_ok()); diff --git a/iam-policy-autopilot-policy-generation/src/api/generate_policies.rs b/iam-policy-autopilot-policy-generation/src/api/generate_policies.rs index c75d6ad..093f857 100644 --- a/iam-policy-autopilot-policy-generation/src/api/generate_policies.rs +++ b/iam-policy-autopilot-policy-generation/src/api/generate_policies.rs @@ -4,16 +4,17 @@ use std::time::Instant; use log::{debug, info, trace}; use crate::{ - api::{common::process_source_files, model::GeneratePolicyConfig}, + api::{ + common::process_source_files, + model::{GeneratePoliciesResult, GeneratePolicyConfig}, + }, extraction::SdkMethodCall, - policy_generation::{merge::PolicyMergerConfig, MethodActionMapping, PolicyWithMetadata}, + policy_generation::merge::PolicyMergerConfig, EnrichmentEngine, PolicyGenerationEngine, }; -/// Generate polcies for source files -pub async fn generate_policies( - config: &GeneratePolicyConfig, -) -> Result<(Vec, Vec)> { +/// Generate policies for source files +pub async fn generate_policies(config: &GeneratePolicyConfig) -> Result { let pipeline_start = Instant::now(); debug!( @@ -55,7 +56,10 @@ pub async fn generate_policies( // Handle empty method lists gracefully if extracted_methods.is_empty() { info!("No methods found to process, returning empty policy list"); - return Ok((vec![], vec![])); + return Ok(GeneratePoliciesResult { + policies: vec![], + explanations: None, + }); } let mut enrichment_engine = EnrichmentEngine::new(config.disable_file_system_cache)?; @@ -85,7 +89,7 @@ pub async fn generate_policies( "Generating IAM policies from {} enriched method calls", enriched_results.len() ); - let policies = policy_engine + let result = policy_engine .generate_policies(&enriched_results) .context("Failed to generate IAM policies")?; @@ -93,10 +97,15 @@ pub async fn generate_policies( debug!( "Policy generation completed in {:?}, generated {} policies", total_duration, - policies.len() + result.policies.len() ); - let mut final_policies = policies; + let mut final_policies = result.policies; + let explanations = if config.generate_explanations { + result.explanations + } else { + None + }; if !config.individual_policies { final_policies = policy_engine @@ -104,24 +113,8 @@ pub async fn generate_policies( .context("Failed to merge IAM policies")?; } - // Handle policy output based on configuration - if config.generate_action_mappings { - // Extract method to action mappings using the core method - debug!( - "Extracting method to action mappings from {} enriched method calls", - enriched_results.len() - ); - let method_action_mappings = policy_engine - .extract_action_mappings(&enriched_results) - .context("Failed to extract method to action mappings")?; - - debug!( - "Extracted {} method to action mappings", - method_action_mappings.len() - ); - - return Ok((final_policies, method_action_mappings)); - } - - Ok((final_policies, vec![])) + Ok(GeneratePoliciesResult { + policies: final_policies, + explanations, + }) } diff --git a/iam-policy-autopilot-policy-generation/src/api/model.rs b/iam-policy-autopilot-policy-generation/src/api/model.rs index ff40859..b25199a 100644 --- a/iam-policy-autopilot-policy-generation/src/api/model.rs +++ b/iam-policy-autopilot-policy-generation/src/api/model.rs @@ -1,4 +1,5 @@ //! Defined model for API +use crate::{enrichment::Explanation, policy_generation::PolicyWithMetadata}; use std::path::PathBuf; /// Configuration for generate_policies Api @@ -8,14 +9,23 @@ pub struct GeneratePolicyConfig { pub extract_sdk_calls_config: ExtractSdkCallsConfig, /// AWS Config pub aws_context: AwsContext, - /// Generate Action Mappings - pub generate_action_mappings: bool, /// Output individual policies pub individual_policies: bool, /// Enable policy size minimization pub minimize_policy_size: bool, /// Disable file system caching for service references pub disable_file_system_cache: bool, + /// Generate explanations for why actions were added + pub generate_explanations: bool, +} + +/// Result of policy generation including policies, action mappings, and explanations +#[derive(Debug, Clone)] +pub struct GeneratePoliciesResult { + /// Generated IAM policies + pub policies: Vec, + /// Explanations for why actions were added (if requested) + pub explanations: Option>, } /// Service hints for filtering SDK method calls diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index bdc0a24..78c740d 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -8,10 +8,68 @@ //! that represent method calls enriched with IAM metadata from operation //! action maps and Service Definition Files. +use std::path::PathBuf; + use crate::SdkMethodCall; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +/// Represents a location in a source file +/// +/// This struct stores file path and position information and serializes +/// to the GNU coding standard (https://www.gnu.org/prep/standards/html_node/Errors.html) +/// format: `filename:startLine.startCol-endLine.endCol` +#[derive(Debug, Clone, PartialEq, Eq, Hash, JsonSchema)] +#[schemars( + description = "File location in GNU coding standard format: filename:startLine.startCol-endLine.endCol" +)] +pub struct Location { + /// File path + pub file_path: PathBuf, + /// Starting position (line, column) - both 1-based + pub start_position: (usize, usize), + /// Ending position (line, column) - both 1-based + pub end_position: (usize, usize), +} + +impl Location { + /// Create a new Location + #[must_use] + pub fn new( + file_path: PathBuf, + start_position: (usize, usize), + end_position: (usize, usize), + ) -> Self { + Self { + file_path, + start_position, + end_position, + } + } + + /// Format as GNU coding standard: `filename:startLine.startCol-endLine.endCol` + #[must_use] + pub fn to_gnu_format(&self) -> String { + let path_str = self.file_path.display(); + let (start_line, start_col) = self.start_position; + let (end_line, end_col) = self.end_position; + + format!( + "{}:{}.{}-{}.{}", + path_str, start_line, start_col, end_line, end_col + ) + } +} + +impl Serialize for Location { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_gnu_format()) + } +} + pub(crate) mod engine; pub(crate) mod operation_fas_map; pub(crate) mod resource_matcher; @@ -22,6 +80,95 @@ pub(crate) use operation_fas_map::load_operation_fas_map; pub(crate) use resource_matcher::ResourceMatcher; pub(crate) use service_reference::RemoteServiceReferenceLoader as ServiceReferenceLoader; +/// Represents Forward Access Session (FAS) expansion information +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub struct FasInfo { + /// Explanation URL for Forward Access Sessions + pub explanation: String, + /// The chain of operations in the FAS expansion + pub expansion: Vec, +} + +impl FasInfo { + /// Create a new FasInfo with the standard AWS documentation URL + #[must_use] + pub fn new(expansion: Vec) -> Self { + Self { + explanation: "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html".to_string(), + expansion, + } + } +} + +/// Represents the reason why an action was added to a policy +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub struct Reason { + /// The original operation that was extracted + pub operation: OperationView, + /// FAS (Forward Access Sessions) expansion information if this action came from FAS expansion + #[serde(rename = "FAS", skip_serializing_if = "Option::is_none")] + pub fas: Option, +} + +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "PascalCase")] +#[serde(untagged)] +pub enum OperationView { + /// Operation extracted from source files + #[serde(rename_all = "PascalCase")] + Extracted { + /// Extracted name + name: String, + /// Detected service containing this operation + service: String, + /// Extracted expr + expr: String, + /// Location in source file + location: Location, + }, + /// Operation provided (no metadata available) + #[serde(rename_all = "PascalCase")] + Provided { + /// Provided name + name: String, + /// Provided service + service: String, + }, +} + +impl OperationView { + pub(crate) fn from_call(call: &SdkMethodCall, service: &str) -> Self { + match &call.metadata { + None => Self::Provided { + name: call.name.clone(), + service: service.to_string(), + }, + Some(metadata) => Self::Extracted { + name: call.name.clone(), + service: service.to_string(), + expr: metadata.expr.clone(), + location: Location::new( + metadata.file_path.clone(), + metadata.start_position, + metadata.end_position, + ), + }, + } + } +} + +/// Represents an explanation for why an action was added to a policy +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub struct Explanation { + /// The action that was added (authorized action from service reference) + pub action: String, + /// The reasons this action was added (can have multiple reasons for the same action) + pub reasons: Vec, +} + /// Represents an enriched method call with actions that need permissions #[derive(Debug, Clone, Serialize, PartialEq)] #[non_exhaustive] @@ -34,6 +181,8 @@ pub struct EnrichedSdkMethodCall<'a> { pub(crate) actions: Vec, /// The initial SDK method call pub(crate) sdk_method_call: &'a SdkMethodCall, + /// Explanations for why each action was added + pub(crate) explanations: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, JsonSchema)] @@ -362,3 +511,100 @@ pub(crate) mod mock_remote_service_reference { (mock_server, loader) } } + +#[cfg(test)] +mod location_tests { + use super::*; + use std::path::PathBuf; + + #[test] + fn test_location_to_gnu_string() { + let location = Location::new(PathBuf::from("src/main.rs"), (10, 5), (10, 79)); + + assert_eq!(location.to_gnu_format(), "src/main.rs:10.5-10.79"); + } + + #[test] + fn test_location_to_gnu_string_multiline() { + let location = Location::new(PathBuf::from("src/lib.rs"), (10, 5), (15, 20)); + + assert_eq!(location.to_gnu_format(), "src/lib.rs:10.5-15.20"); + } + + #[test] + fn test_location_serialization() { + let location = Location::new(PathBuf::from("test.py"), (42, 15), (42, 80)); + + let json = serde_json::to_string(&location).unwrap(); + assert_eq!(json, "\"test.py:42.15-42.80\""); + } + + #[test] + fn test_location_serialization_multiline() { + let location = Location::new(PathBuf::from("example.go"), (100, 1), (105, 50)); + + let json = serde_json::to_string(&location).unwrap(); + assert_eq!(json, "\"example.go:100.1-105.50\""); + } + + #[test] + fn test_operation_view_extracted_with_location() { + use crate::extraction::{SdkMethodCall, SdkMethodCallMetadata}; + + let call = SdkMethodCall { + name: "get_object".to_string(), + possible_services: vec!["s3".to_string()], + metadata: Some(SdkMethodCallMetadata { + parameters: vec![], + return_type: None, + expr: "s3.get_object(Bucket='my-bucket')".to_string(), + file_path: PathBuf::from("test.py"), + start_position: (10, 5), + end_position: (10, 79), + receiver: Some("s3".to_string()), + }), + }; + + let operation_view = OperationView::from_call(&call, "s3"); + + match operation_view { + OperationView::Extracted { + name, + service, + expr, + location, + } => { + assert_eq!(name, "get_object"); + assert_eq!(service, "s3"); + assert_eq!(expr, "s3.get_object(Bucket='my-bucket')"); + assert_eq!(location.to_gnu_format(), "test.py:10.5-10.79"); + } + _ => panic!("Expected Extracted variant"), + } + } + + #[test] + fn test_operation_view_extracted_serialization() { + use crate::extraction::{SdkMethodCall, SdkMethodCallMetadata}; + + let call = SdkMethodCall { + name: "list_buckets".to_string(), + possible_services: vec!["s3".to_string()], + metadata: Some(SdkMethodCallMetadata { + parameters: vec![], + return_type: None, + expr: "s3.list_buckets()".to_string(), + file_path: PathBuf::from("app.py"), + start_position: (5, 1), + end_position: (5, 20), + receiver: Some("s3".to_string()), + }), + }; + + let operation_view = OperationView::from_call(&call, "s3"); + let json = serde_json::to_string(&operation_view).unwrap(); + + // Verify the location is serialized as a string in GNU format + assert!(json.contains("\"Location\":\"app.py:5.1-5.20\"")); + } +} diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index d9e21c3..3d0508f 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -8,7 +8,9 @@ use convert_case::{Case, Casing}; use std::collections::HashSet; use std::sync::Arc; -use super::{Action, Context, EnrichedSdkMethodCall, Resource}; +use super::{ + Action, Context, EnrichedSdkMethodCall, Explanation, FasInfo, OperationView, Reason, Resource, +}; use crate::enrichment::operation_fas_map::{FasOperation, OperationFasMap, OperationFasMaps}; use crate::enrichment::service_reference::ServiceReference; use crate::enrichment::{Condition, ServiceReferenceLoader}; @@ -16,6 +18,30 @@ use crate::errors::{ExtractorError, Result}; use crate::service_configuration::ServiceConfiguration; use crate::{SdkMethodCall, SdkType}; +/// Represents an operation with its provenance chain (how we reached this operation via FAS expansion) +#[derive(Debug, Clone)] +struct OperationWithFasExpansion { + operation: FasOperation, + /// Chain of operation names leading to this operation (excludes current operation) + fas_chain: Vec, +} + +// Custom PartialEq and Hash that only consider the operation, not the FAS chain +// This prevents infinite loops in cycles where the same operation appears with different FAS chains +impl PartialEq for OperationWithFasExpansion { + fn eq(&self, other: &Self) -> bool { + self.operation == other.operation + } +} + +impl Eq for OperationWithFasExpansion {} + +impl std::hash::Hash for OperationWithFasExpansion { + fn hash(&self, state: &mut H) { + self.operation.hash(state); + } +} + /// ResourceMatcher coordinates OperationAction maps and Service Reference data to generate enriched method calls /// /// This struct provides the core functionality for the 3-stage enrichment pipeline, @@ -95,27 +121,36 @@ impl ResourceMatcher { /// This method safely expands FAS operations by iteratively processing new operations /// until no more new operations are discovered (fixed point reached). /// It includes cycle detection to prevent infinite loops. + /// + /// Returns operations with their FAS provenance chains showing how each operation was reached. fn expand_fas_operations_to_fixed_point( &self, initial: FasOperation, - ) -> Result> { - let mut operations = HashSet::::new(); - operations.insert(initial); + ) -> Result> { + let mut operations = HashSet::::new(); + + // Initial operation has empty FAS chain + let initial = OperationWithFasExpansion { + operation: initial.clone(), + fas_chain: vec![], + }; + operations.insert(initial.clone()); + + let mut to_process = vec![initial]; - let mut to_process = operations.clone(); while !to_process.is_empty() { - let mut newly_discovered = HashSet::::new(); + let mut newly_discovered = Vec::new(); // Process all operations in the current batch - for operation in &to_process { - let service_name = operation.service(&self.service_cfg); + for current in &to_process { + let service_name = current.operation.service(&self.service_cfg); let operation_fas_map_option = self.find_operation_fas_map_for_service(&service_name); match operation_fas_map_option { Some(operation_fas_map) => { let service_operation_name = - operation.service_operation_name(&self.service_cfg); + current.operation.service_operation_name(&self.service_cfg); log::debug!("Looking up operation {}", service_operation_name); if let Some(additional_operations) = operation_fas_map @@ -123,9 +158,21 @@ impl ResourceMatcher { .get(&service_operation_name) { for additional_op in additional_operations { + // Build new FAS hain: current chain + current operation + let mut new_chain = current.fas_chain.clone(); + new_chain.push( + current.operation.service_operation_name(&self.service_cfg), + ); + + let new_op_with_prov = OperationWithFasExpansion { + operation: additional_op.clone(), + fas_chain: new_chain, + }; + // Only add if we haven't seen this operation before - if !operations.contains(additional_op) { - newly_discovered.insert(additional_op.clone()); + if !operations.contains(&new_op_with_prov) { + operations.insert(new_op_with_prov.clone()); + newly_discovered.push(new_op_with_prov); } } } else { @@ -138,9 +185,6 @@ impl ResourceMatcher { } } - // Add newly discovered operations to our complete set - operations.extend(newly_discovered.iter().cloned()); - let newly_discovered_count = newly_discovered.len(); // Set up next iteration to process only newly discovered operations @@ -158,8 +202,8 @@ impl ResourceMatcher { ); // Convert HashSet to Vec and sort by service_operation_name for deterministic output - let mut operations_vec: Vec = operations.into_iter().collect(); - operations_vec.sort_by_key(|op| op.service_operation_name(&self.service_cfg)); + let mut operations_vec: Vec = operations.into_iter().collect(); + operations_vec.sort_by_key(|op| op.operation.service_operation_name(&self.service_cfg)); Ok(operations_vec) } @@ -189,6 +233,9 @@ impl ResourceMatcher { parsed_call.name ); + // Store the original service name from parsed_call for use in explanations + let original_service_name = service_name; + let initial = { let initial_service_name = self .service_cfg @@ -224,9 +271,11 @@ impl ResourceMatcher { // Use fixed-point algorithm to safely expand FAS operations until no new operations are found let operations = self.expand_fas_operations_to_fixed_point(initial)?; + let mut explanations = vec![]; let mut enriched_actions = vec![]; - for operation in operations { - let service_name = operation.service(&self.service_cfg); + + for op_with_fas_chain in operations { + let service_name = op_with_fas_chain.operation.service(&self.service_cfg); // Find the corresponding SDF using the cache let service_reference = service_reference_loader.load(&service_name).await?; @@ -235,18 +284,25 @@ impl ResourceMatcher { continue; } Some(service_reference) => { - log::debug!("Creating actions for {:?}", operation); - log::debug!(" with context {:?}", operation.context); + log::debug!("Creating actions for {:?}", op_with_fas_chain.operation); + log::debug!(" with context {:?}", op_with_fas_chain.operation.context); + log::debug!(" FAS chain: {:?}", op_with_fas_chain.fas_chain); + if let Some(operation_to_authorized_actions) = &service_reference.operation_to_authorized_actions { log::debug!( "Looking up {}", - &operation.service_operation_name(&self.service_cfg) + &op_with_fas_chain + .operation + .service_operation_name(&self.service_cfg) ); if let Some(operation_to_authorized_action) = - operation_to_authorized_actions - .get(&operation.service_operation_name(&self.service_cfg)) + operation_to_authorized_actions.get( + &op_with_fas_chain + .operation + .service_operation_name(&self.service_cfg), + ) { for action in &operation_to_authorized_action.authorized_actions { let enriched_resources = self @@ -262,7 +318,8 @@ impl ResourceMatcher { }; // Combine conditions from FAS operation context and AuthorizedAction context - let mut conditions = Self::make_condition(&operation.context); + let mut conditions = + Self::make_condition(&op_with_fas_chain.operation.context); // Add conditions from AuthorizedAction context if present if let Some(auth_context) = &action.context { @@ -278,6 +335,42 @@ impl ResourceMatcher { ); enriched_actions.push(enriched_action); + + // Include FAS chain only if chain is non-empty + let fas_info = if op_with_fas_chain.fas_chain.is_empty() { + log::debug!( + " Action '{}': Excluding FAS chain (initial operation)", + action.name + ); + None + } else { + // Build full chain: fas_chain + current operation + let mut chain = op_with_fas_chain.fas_chain.clone(); + chain.push( + op_with_fas_chain + .operation + .service_operation_name(&self.service_cfg), + ); + log::debug!( + " Action '{}': Including FAS chain: {:?}", + action.name, + chain + ); + Some(FasInfo::new(chain)) + }; + + // Create explanation for this action + let explanation = Explanation { + action: action.name.clone(), + reasons: vec![Reason { + operation: OperationView::from_call( + parsed_call, + original_service_name, + ), + fas: fas_info, + }], + }; + explanations.push(explanation); } } else { // Fallback: operation not found in operation action map, create basic action @@ -285,7 +378,21 @@ impl ResourceMatcher { if let Some(a) = self.create_fallback_action(&parsed_call.name, &service_reference)? { - enriched_actions.push(a) + let action_name = a.name.clone(); + enriched_actions.push(a); + + // Create explanation for fallback action + let explanation = Explanation { + action: action_name, + reasons: vec![Reason { + operation: OperationView::from_call( + parsed_call, + original_service_name, + ), + fas: None, + }], + }; + explanations.push(explanation); } } } else { @@ -293,7 +400,21 @@ impl ResourceMatcher { if let Some(a) = self.create_fallback_action(&parsed_call.name, &service_reference)? { - enriched_actions.push(a) + let action_name = a.name.clone(); + enriched_actions.push(a); + + // Create explanation for fallback action + let explanation = Explanation { + action: action_name, + reasons: vec![Reason { + operation: OperationView::from_call( + parsed_call, + original_service_name, + ), + fas: None, + }], + }; + explanations.push(explanation); } } } @@ -306,9 +427,10 @@ impl ResourceMatcher { Ok(Some(EnrichedSdkMethodCall { method_name: parsed_call.name.clone(), - service: service_name.to_string(), + service: original_service_name.to_string(), actions: enriched_actions, sdk_method_call: parsed_call, + explanations, })) } @@ -1024,7 +1146,7 @@ mod tests { // Verify all expected operations are present let operation_names: std::collections::HashSet = operations .iter() - .map(|op| op.service_operation_name(&service_cfg)) + .map(|op| op.operation.service_operation_name(&service_cfg)) .collect(); assert!(operation_names.contains("service-a:GetObject")); @@ -1105,7 +1227,7 @@ mod tests { // Debug: print what operations we actually got let operation_names: std::collections::HashSet = operations .iter() - .map(|op| op.service_operation_name(&service_cfg)) + .map(|op| op.operation.service_operation_name(&service_cfg)) .collect(); // 3 operations, note that GetObject occurs twice, once with and once without context @@ -1204,10 +1326,14 @@ mod tests { 1, "Should contain only the initial operation" ); - assert!( - operations.contains(&initial), + assert_eq!( + operations[0].operation, initial, "Should contain the initial operation" ); + assert!( + operations[0].fas_chain.is_empty(), + "Initial operation should have empty FAS chain" + ); println!("✓ Test passed: Handles case with no additional FAS operations"); } @@ -1263,10 +1389,14 @@ mod tests { 1, "Self-cycle with identical context should result in exactly 1 operation" ); - assert!( - operations.contains(&initial), + assert_eq!( + operations[0].operation, initial, "Should contain the initial operation" ); + assert!( + operations[0].fas_chain.is_empty(), + "Initial operation should have empty FAS chain" + ); println!("✓ Test passed: Self-cycle with empty context handled correctly"); } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/engine.rs b/iam-policy-autopilot-policy-generation/src/extraction/engine.rs index 912e1c9..c02a8b3 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/engine.rs @@ -94,7 +94,7 @@ impl Engine { for source_file in source_files { let extractor = extractor.clone(); - join_set.spawn(async move { extractor.parse(&source_file.content).await }); + join_set.spawn(async move { extractor.parse(&source_file).await }); } // Collect results from concurrent tasks diff --git a/iam-policy-autopilot-policy-generation/src/extraction/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/extractor.rs index 9b9c190..ab5d164 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/extractor.rs @@ -1,31 +1,18 @@ //! Result type alias for operations that can fail with `ExtractorError` -use ast_grep_core::AstGrep; use ast_grep_language::{Go, JavaScript, Python, TypeScript}; use async_trait::async_trait; use crate::extraction::go::types::GoImportInfo; -use crate::{SdkMethodCall, ServiceModelIndex}; +use crate::extraction::AstWithSourceFile; +use crate::{SdkMethodCall, ServiceModelIndex, SourceFile}; /// Enum to handle different AST types from different languages #[derive(Clone)] pub(crate) enum ExtractorResult { - Python( - AstGrep>, - Vec, - ), - Go( - AstGrep>, - Vec, - GoImportInfo, - ), - JavaScript( - AstGrep>, - Vec, - ), - TypeScript( - AstGrep>, - Vec, - ), + Python(AstWithSourceFile, Vec), + Go(AstWithSourceFile, Vec, GoImportInfo), + JavaScript(AstWithSourceFile, Vec), + TypeScript(AstWithSourceFile, Vec), } impl ExtractorResult { @@ -64,7 +51,7 @@ impl ExtractorResult { #[async_trait] pub(crate) trait Extractor: Send + Sync { /// Parse source code into method calls and return the AST - async fn parse(&self, source_code: &str) -> ExtractorResult; + async fn parse(&self, source_file: &SourceFile) -> ExtractorResult; fn filter_map( &self, diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs index ad93529..b205890 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs @@ -335,6 +335,7 @@ mod tests { }; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; use std::collections::HashMap; + use std::path::PathBuf; fn create_test_service_index() -> ServiceModelIndex { let mut services = HashMap::new(); @@ -587,6 +588,7 @@ mod tests { name: "ListObjectsV2".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "ListObjectsV2".to_string(), parameters: vec![ Parameter::Positional { value: ParameterValue::Unresolved("context.TODO()".to_string()), @@ -605,6 +607,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("client".to_string()), @@ -648,6 +651,7 @@ mod tests { name: "NonAwsMethod".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "NonAwsMethod".to_string(), parameters: vec![Parameter::Positional { value: ParameterValue::Unresolved("someParam".to_string()), position: 0, @@ -655,6 +659,7 @@ mod tests { struct_fields: None, }], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 30), receiver: Some("client".to_string()), @@ -685,6 +690,7 @@ mod tests { name: "ListObjectsV2".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "ListObjectsV2".to_string(), parameters: vec![ Parameter::Positional { value: ParameterValue::Unresolved("context.TODO()".to_string()), @@ -703,6 +709,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("client".to_string()), @@ -728,6 +735,7 @@ mod tests { name: "ListObjectsV2".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "ListObjectsV2".to_string(), parameters: vec![ Parameter::Positional { value: ParameterValue::Unresolved("context.TODO()".to_string()), @@ -746,6 +754,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("client".to_string()), @@ -777,6 +786,7 @@ mod tests { name: "ListObjectsV2".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "ListObjectsV2".to_string(), parameters: vec![ Parameter::Positional { value: ParameterValue::Unresolved("context.TODO()".to_string()), @@ -795,6 +805,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("client".to_string()), @@ -832,6 +843,8 @@ mod tests { struct_fields: Some(vec!["Bucket".to_string()]), }, ], + expr: "GetObject".to_string(), + file_path: PathBuf::new(), return_type: None, start_position: (1, 1), end_position: (1, 50), @@ -870,6 +883,8 @@ mod tests { struct_fields: Some(vec!["Bucket".to_string(), "Key".to_string()]), }, ], + expr: "GetObject".to_string(), + file_path: PathBuf::new(), return_type: None, start_position: (1, 1), end_position: (1, 50), @@ -907,6 +922,8 @@ mod tests { struct_fields: Some(vec!["Bucket".to_string(), "Key".to_string()]), }, ], + expr: "GetObject".to_string(), + file_path: PathBuf::new(), return_type: None, start_position: (1, 1), end_position: (1, 50), @@ -949,6 +966,8 @@ mod tests { struct_fields: Some(vec!["Bucket".to_string(), "Key".to_string()]), }, ], + expr: "GetObject".to_string(), + file_path: PathBuf::new(), return_type: None, start_position: (1, 1), end_position: (1, 50), @@ -1006,6 +1025,8 @@ mod tests { struct_fields: Some(vec!["QueueName".to_string(), "Attributes".to_string()]), }, ], + expr: "CreateQueue".to_string(), + file_path: PathBuf::new(), return_type: None, start_position: (1, 1), end_position: (1, 50), diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs index 70433fa..fa6de92 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs @@ -7,8 +7,10 @@ use crate::extraction::go::node_kinds; use crate::extraction::go::paginator_extractor::GoPaginatorExtractor; use crate::extraction::go::types::{GoImportInfo, ImportInfo}; use crate::extraction::go::waiter_extractor::GoWaiterExtractor; -use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; -use crate::ServiceModelIndex; +use crate::extraction::{ + AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, +}; +use crate::{ServiceModelIndex, SourceFile}; use ast_grep_config::from_yaml_string; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; @@ -23,12 +25,9 @@ impl GoExtractor { } /// Extract import statements from Go source code using ast-grep - fn extract_imports( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> GoImportInfo { + fn extract_imports(&self, ast: &AstWithSourceFile) -> GoImportInfo { let mut import_info = GoImportInfo::new(); - let root = ast.root(); + let root = ast.ast.root(); // AST-grep configuration for extracting import statements let import_config = r#" @@ -142,6 +141,7 @@ rule: fn parse_method_call( &self, node_match: &ast_grep_core::NodeMatch>, + source_file: &SourceFile, ) -> Option { let env = node_match.get_env(); @@ -200,6 +200,8 @@ rule: metadata: Some(SdkMethodCallMetadata { parameters: arguments, return_type: None, // We don't know the return type from the call site + expr: node_match.text().to_string(), + file_path: source_file.path.clone(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), receiver, @@ -419,9 +421,13 @@ impl Default for GoExtractor { #[async_trait] impl Extractor for GoExtractor { - async fn parse(&self, source_code: &str) -> crate::extraction::extractor::ExtractorResult { - let ast_grep = Go.ast_grep(source_code); - let root = ast_grep.root(); + async fn parse( + &self, + source_file: &SourceFile, + ) -> crate::extraction::extractor::ExtractorResult { + let ast_grep = Go.ast_grep(&source_file.content); + let ast = AstWithSourceFile::new(ast_grep, source_file.clone()); + let root = ast.ast.root(); let mut method_calls = Vec::new(); @@ -453,15 +459,15 @@ rule: // Find all method calls with attribute access: receiver.method(args) for node_match in root.find_all(&config.matcher) { - if let Some(method_call) = self.parse_method_call(&node_match) { + if let Some(method_call) = self.parse_method_call(&node_match, source_file) { method_calls.push(method_call); } } // Extract import information - let import_info = self.extract_imports(&ast_grep); + let import_info = self.extract_imports(&ast); - crate::extraction::extractor::ExtractorResult::Go(ast_grep, method_calls, import_info) + crate::extraction::extractor::ExtractorResult::Go(ast, method_calls, import_info) } fn filter_map( @@ -573,6 +579,8 @@ impl Parameter { #[cfg(test)] mod tests { + use std::path::PathBuf; + use super::*; #[tokio::test] @@ -615,8 +623,10 @@ func main() { } } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), aws_code.to_string(), crate::Language::Go); - let result = extractor.parse(aws_code).await; + let result = extractor.parse(&source_file).await; let aws_method_calls = result.method_calls_ref(); println!("AWS test - Found {} method calls:", aws_method_calls.len()); @@ -655,8 +665,10 @@ func main() { result2, err2 := stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -722,8 +734,10 @@ func main() { result3, err3 := client.GetObject(ctx, &s3.GetObjectInput{}) } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -813,8 +827,10 @@ func main() { result3, err3 := factory.createClient("s3").GetObject(ctx, &s3.GetObjectInput{}) } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -905,8 +921,10 @@ func main() { DescribeInstances(ctx, &ec2.DescribeInstancesInput{}) } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -1167,8 +1185,10 @@ func main() { _ = objectOutput } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; // Verify import extraction if let Some(import_info) = result.go_import_info() { @@ -1288,8 +1308,10 @@ func main() { _ = output } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; // Verify import extraction if let Some(import_info) = result.go_import_info() { @@ -1363,7 +1385,9 @@ func (basics BucketBasics) DownloadFile(ctx context.Context, bucketName string, } "#; - let result = extractor.parse(test_code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -1470,9 +1494,12 @@ func (basics BucketBasics) DownloadFile(ctx context.Context, bucketName string, } #[cfg(test)] mod test_struct_fields { + use std::path::PathBuf; + use crate::extraction::extractor::Extractor; use crate::extraction::go::extractor::GoExtractor; use crate::extraction::Parameter; + use crate::SourceFile; /// Test extraction of struct literals with multiple fields #[tokio::test] @@ -1497,7 +1524,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls @@ -1568,7 +1597,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls @@ -1621,7 +1652,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls @@ -1672,7 +1705,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls @@ -1720,7 +1755,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs index bd58dab..1c49ffb 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs @@ -3,10 +3,12 @@ //! This module handles extraction of Go AWS SDK v2 feature methods like S3 Upload/Download, //! and other specialized SDK features. +use std::path::PathBuf; + use crate::extraction::go::features::{FeatureMethod, GoSdkV2Features}; use crate::extraction::go::types::GoImportInfo; use crate::extraction::go::utils; -use crate::extraction::{SdkMethodCall, SdkMethodCallMetadata}; +use crate::extraction::{AstWithSourceFile, SdkMethodCall, SdkMethodCallMetadata}; use ast_grep_config::from_yaml_string; use ast_grep_language::Go; @@ -19,6 +21,10 @@ pub(crate) struct FeatureCallInfo { pub(crate) receiver: Option, /// Extracted arguments pub(crate) arguments: Vec, + /// File where we found the feature call + pub(crate) file_path: PathBuf, + /// Matched expression + pub(crate) expr: String, /// Start position of the call node pub(crate) start_position: (usize, usize), /// End position of the call node @@ -43,7 +49,7 @@ impl GoFeaturesExtractor { /// Extract feature method calls from the AST pub(crate) fn extract_feature_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, import_info: &mut GoImportInfo, ) -> Vec { let mut synthetic_calls = Vec::new(); @@ -63,11 +69,8 @@ impl GoFeaturesExtractor { /// Find all method calls that might be feature methods /// This matches receiver.Method(...) patterns using proper ast-grep config - fn find_method_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_method_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut calls = Vec::new(); // Use the same pattern as the main extractor for method calls @@ -118,6 +121,8 @@ rule: method_name, receiver: Some(receiver), arguments, + expr: node_match.text().to_string(), + file_path: ast.source_file.path.clone(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), }); @@ -224,6 +229,8 @@ rule: metadata: Some(SdkMethodCallMetadata { parameters: parameters.clone(), return_type: None, + expr: call_info.expr.clone(), + file_path: call_info.file_path.clone(), start_position: call_info.start_position, end_position: call_info.end_position, receiver: call_info.receiver.clone(), @@ -236,14 +243,17 @@ rule: #[cfg(test)] mod tests { + use crate::{Language, SourceFile}; + use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Go.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = + SourceFile::with_language(PathBuf::new(), source_code.to_string(), Language::Go); + let ast_grep = Go.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file) } fn create_test_import_info() -> GoImportInfo { diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs index 8380d79..7f418e9 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs @@ -3,9 +3,11 @@ //! This module handles extraction of Go AWS SDK v2 paginator patterns by detecting //! paginator creation calls, which contain the meaningful parameters for IAM policy generation. +use std::path::{Path, PathBuf}; + use crate::extraction::go::utils; use crate::extraction::sdk_model::ServiceDiscovery; -use crate::extraction::{Parameter, SdkMethodCall, SdkMethodCallMetadata}; +use crate::extraction::{AstWithSourceFile, Parameter, SdkMethodCall, SdkMethodCallMetadata}; use crate::Language; use crate::ServiceModelIndex; use ast_grep_language::Go; @@ -23,8 +25,14 @@ pub(crate) struct PaginatorInfo { pub client_receiver: String, /// Extracted arguments from paginator creation (input struct) pub creation_arguments: Vec, - /// Line number where paginator was created - pub creation_line: usize, + /// Matched expression + pub expr: String, + /// File where we found the paginator + pub file_path: PathBuf, + /// Start position where paginator was created + pub start_position: (usize, usize), + /// End position where paginator was created + pub end_position: (usize, usize), } /// Information about a chained paginator call @@ -36,9 +44,10 @@ pub(crate) struct ChainedPaginatorCallInfo { pub client_receiver: String, /// Extracted arguments from paginator creation (input struct) pub arguments: Vec, - /// Line number where chained call was made - #[allow(dead_code)] - pub line: usize, + /// Matched expression + pub expr: String, + /// File where the chained paginator call was found + pub file_path: PathBuf, /// Start position of the chained call node pub start_position: (usize, usize), /// End position of the chained call node @@ -66,7 +75,7 @@ impl<'a> GoPaginatorExtractor<'a> { /// Extract paginator method calls from the AST pub(crate) fn extract_paginator_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { let mut synthetic_calls = Vec::new(); @@ -88,18 +97,17 @@ impl<'a> GoPaginatorExtractor<'a> { } /// Find all paginator creation calls (NewXxxPaginator functions) - fn find_paginator_creation_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_paginator_creation_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut paginators = Vec::new(); // Pattern: $VAR := $PACKAGE.$FUNCTION($$$ARGS) where FUNCTION contains "New" and "Paginator" let paginator_pattern = "$VAR := $PACKAGE.$FUNCTION($$$ARGS)"; for node_match in root.find_all(paginator_pattern) { - if let Some(paginator_info) = self.parse_paginator_creation_call(&node_match) { + if let Some(paginator_info) = + self.parse_paginator_creation_call(&node_match, &ast.source_file.path) + { paginators.push(paginator_info); } } @@ -110,16 +118,18 @@ impl<'a> GoPaginatorExtractor<'a> { /// Find all chained paginator calls fn find_chained_paginator_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut chained_calls = Vec::new(); // Pattern: $PACKAGE.$FUNCTION($$$ARGS).NextPage($$$NEXT_ARGS) let chained_pattern = "$PACKAGE.$FUNCTION($$$ARGS).NextPage($$$NEXT_ARGS)"; for node_match in root.find_all(chained_pattern) { - if let Some(chained_info) = self.parse_chained_paginator_call(&node_match) { + if let Some(chained_info) = + self.parse_chained_paginator_call(&node_match, &ast.source_file.path) + { chained_calls.push(chained_info); } } @@ -131,6 +141,7 @@ impl<'a> GoPaginatorExtractor<'a> { fn parse_paginator_creation_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -167,14 +178,24 @@ impl<'a> GoPaginatorExtractor<'a> { .and_then(|s| s.strip_suffix("Paginator")); if let Some(operation_name) = operation_name { - let creation_line = node_match.get_node().start_pos().line() + 1; + let start_position = { + let pos = node_match.get_node().start_pos(); + (pos.line() + 1, pos.column(node_match.get_node())) + }; + let end_position = { + let pos = node_match.get_node().end_pos(); + (pos.line() + 1, pos.column(node_match.get_node())) + }; return Some(PaginatorInfo { variable_name, paginator_type: operation_name.to_string(), client_receiver, creation_arguments, - creation_line, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), + start_position, + end_position, }); } @@ -185,6 +206,7 @@ impl<'a> GoPaginatorExtractor<'a> { fn parse_chained_paginator_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -226,7 +248,8 @@ impl<'a> GoPaginatorExtractor<'a> { paginator_type, client_receiver, arguments: creation_arguments, - line: start.line() + 1, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), }) @@ -257,8 +280,10 @@ impl<'a> GoPaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: paginator_info.creation_arguments.clone(), return_type: None, - start_position: (paginator_info.creation_line, 1), - end_position: (paginator_info.creation_line, 1), + expr: paginator_info.expr.clone(), + file_path: paginator_info.file_path.clone(), + start_position: paginator_info.start_position, + end_position: paginator_info.end_position, receiver: Some(paginator_info.client_receiver.clone()), }), } @@ -292,6 +317,8 @@ impl<'a> GoPaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: chained_call.arguments.clone(), return_type: None, + expr: chained_call.expr.clone(), + file_path: chained_call.file_path.clone(), start_position: chained_call.start_position, end_position: chained_call.end_position, receiver: Some(chained_call.client_receiver.clone()), @@ -302,15 +329,18 @@ impl<'a> GoPaginatorExtractor<'a> { #[cfg(test)] mod tests { + use crate::SourceFile; + use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; use std::collections::HashMap; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Go.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = + SourceFile::with_language(PathBuf::new(), source_code.to_string(), Language::Go); + let ast_grep = Go.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file) } fn create_test_service_index() -> ServiceModelIndex { diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs index 3a961bc..a438cc1 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs @@ -3,8 +3,12 @@ //! This module handles extraction of Go AWS SDK waiter patterns, which involve //! creating a waiter from a client, then calling Wait() on the waiter. +use std::path::{Path, PathBuf}; + use crate::extraction::go::utils; -use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; +use crate::extraction::{ + AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, +}; use crate::ServiceModelIndex; use ast_grep_language::Go; @@ -14,11 +18,17 @@ pub(crate) struct WaiterInfo { /// Variable name assigned to the waiter (e.g., "waiter", "instanceWaiter") pub variable_name: String, /// Clean waiter name (e.g., "InstanceTerminated") - pub waiter_type: String, + pub waiter_name: String, /// Client receiver variable name (e.g., "client", "ec2Client") pub client_receiver: String, + /// Matched expression + pub expr: String, + /// File where the waiter was found + pub file_path: PathBuf, + /// Line number where waiter was created + pub start_position: (usize, usize), /// Line number where waiter was created - pub creation_line: usize, + pub end_position: (usize, usize), } /// Information about a Wait method call @@ -28,14 +38,65 @@ pub(crate) struct WaitCallInfo { pub waiter_var: String, /// Extracted arguments (context + input struct) pub arguments: Vec, - /// Line number where Wait was called - pub wait_line: usize, + /// Matched expression + pub expr: String, + /// File where the wait call was found + pub file_path: PathBuf, /// Start position of the Wait call node pub start_position: (usize, usize), /// End position of the Wait call node pub end_position: (usize, usize), } +// TODO: This should be refactored at a higher level, so this type can be removed. +// See https://github.com/awslabs/iam-policy-autopilot/issues/88. +enum CallInfo<'a> { + None(&'a WaiterInfo), + Simple(&'a WaiterInfo, &'a WaitCallInfo), +} + +impl<'a> CallInfo<'a> { + fn waiter_name(&self) -> &'a str { + match self { + Self::None(waiter_info) | Self::Simple(waiter_info, ..) => &waiter_info.waiter_name, + } + } + + fn waiter_info(&self) -> &'a WaiterInfo { + match self { + CallInfo::None(waiter_info) | CallInfo::Simple(waiter_info, _) => waiter_info, + } + } + + fn expr(&self) -> &'a str { + match self { + CallInfo::None(waiter_info) => &waiter_info.expr, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.expr, + } + } + + fn file_path(&self) -> &'a PathBuf { + match self { + CallInfo::None(waiter_info) => &waiter_info.file_path, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.file_path, + } + } + + fn start_position(&self) -> (usize, usize) { + match self { + CallInfo::None(waiter_info) => waiter_info.start_position, + CallInfo::Simple(_, wait_call_info) => wait_call_info.start_position, + } + } + + fn end_position(&self) -> (usize, usize) { + match self { + CallInfo::None(waiter_info) => waiter_info.end_position, + CallInfo::Simple(_, wait_call_info) => wait_call_info.end_position, + } + } +} + /// Extractor for Go AWS SDK waiter patterns /// /// This extractor discovers waiter patterns in Go code and creates synthetic @@ -57,7 +118,7 @@ impl<'a> GoWaiterExtractor<'a> { /// Extract waiter method calls from the AST pub(crate) fn extract_waiter_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { // Step 1: Find all waiter creation calls let waiters = self.find_waiter_creation_calls(ast); @@ -89,18 +150,17 @@ impl<'a> GoWaiterExtractor<'a> { } /// Find all waiter creation calls (NewXxxWaiter functions) - fn find_waiter_creation_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_waiter_creation_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut waiters = Vec::new(); // Pattern: $VAR := $PACKAGE.$FUNCTION($$$ARGS) where FUNCTION contains "New" and "Waiter" let waiter_pattern = "$VAR := $PACKAGE.$FUNCTION($$$ARGS)"; for node_match in root.find_all(waiter_pattern) { - if let Some(waiter_info) = self.parse_waiter_creation_call(&node_match) { + if let Some(waiter_info) = + self.parse_waiter_creation_call(&node_match, &ast.source_file.path) + { waiters.push(waiter_info); } } @@ -109,18 +169,15 @@ impl<'a> GoWaiterExtractor<'a> { } /// Find all Wait method calls - fn find_wait_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_wait_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut wait_calls = Vec::new(); // Pattern: $WAITER.Wait($$$ARGS) let wait_pattern = "$WAITER.Wait($$$ARGS)"; for node_match in root.find_all(wait_pattern) { - if let Some(wait_info) = self.parse_wait_call(&node_match) { + if let Some(wait_info) = self.parse_wait_call(&node_match, &ast.source_file.path) { wait_calls.push(wait_info); } } @@ -132,6 +189,7 @@ impl<'a> GoWaiterExtractor<'a> { fn parse_waiter_creation_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -161,13 +219,27 @@ impl<'a> GoWaiterExtractor<'a> { .and_then(|s| s.strip_suffix("Waiter")); if let Some(waiter_name) = waiter_name { - let creation_line = node_match.get_node().start_pos().line() + 1; + let start_position = { + let pos = node_match.get_node().start_pos(); + let line = pos.line() + 1; + let col = pos.column(node_match.get_node()) + 1; + (line, col) + }; + let end_position = { + let pos = node_match.get_node().end_pos(); + let line = pos.line() + 1; + let col = pos.column(node_match.get_node()) + 1; + (line, col) + }; return Some(WaiterInfo { variable_name, - waiter_type: waiter_name.to_string(), + waiter_name: waiter_name.to_string(), client_receiver, - creation_line, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), + start_position, + end_position, }); } @@ -178,6 +250,7 @@ impl<'a> GoWaiterExtractor<'a> { fn parse_wait_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -196,7 +269,8 @@ impl<'a> GoWaiterExtractor<'a> { Some(WaitCallInfo { waiter_var, arguments, - wait_line: start.line() + 1, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), }) @@ -216,8 +290,8 @@ impl<'a> GoWaiterExtractor<'a> { for (idx, waiter) in waiters.iter().enumerate() { if waiter.variable_name == wait_call.waiter_var { // Only consider waiters that come before the wait call - if waiter.creation_line < wait_call.wait_line { - let distance = wait_call.wait_line - waiter.creation_line; + if waiter.start_position.0 < wait_call.start_position.0 { + let distance = wait_call.start_position.0 - waiter.start_position.0; if distance < best_distance { best_distance = distance; best_match = Some(waiter); @@ -232,36 +306,27 @@ impl<'a> GoWaiterExtractor<'a> { fn create_synthetic_call_internal( &self, - wait_call: Option<&WaitCallInfo>, - waiter_info: &WaiterInfo, + call: CallInfo, + // wait_call: Option<&WaitCallInfo>, + // waiter_info: &WaiterInfo, ) -> Vec { let mut synthetic_calls = Vec::new(); // waiter_type already contains the clean waiter name (e.g., "InstanceTerminated") - if let Some(service_defs) = self - .service_index - .waiter_lookup - .get(&waiter_info.waiter_type) - { + if let Some(service_defs) = self.service_index.waiter_lookup.get(call.waiter_name()) { // Create one call per service for service_def in service_defs { let service_name = &service_def.service_name; let operation_name = &service_def.operation_name; - let (parameters, start_position, end_position) = match wait_call { - Some(wait_call) => ( - self.filter_waiter_parameters(wait_call.arguments.clone()), - wait_call.start_position, - wait_call.end_position, - ), - None => { + let parameters = match call { + CallInfo::Simple(_, wait_call) => { + self.filter_waiter_parameters(wait_call.arguments.clone()) + } + CallInfo::None(_) => { // Fallback: - ( - // Get required parameters for this operation - self.get_required_parameters(service_name, operation_name), - (waiter_info.creation_line, 1), - (waiter_info.creation_line, 1), - ) + // Get required parameters for this operation + self.get_required_parameters(service_name, operation_name) } }; @@ -271,9 +336,11 @@ impl<'a> GoWaiterExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position, - end_position, - receiver: Some(waiter_info.client_receiver.clone()), + expr: call.expr().to_string(), + file_path: call.file_path().clone(), + start_position: call.start_position(), + end_position: call.end_position(), + receiver: Some(call.waiter_info().client_receiver.clone()), }), }); } @@ -288,13 +355,13 @@ impl<'a> GoWaiterExtractor<'a> { wait_call: &WaitCallInfo, waiter_info: &WaiterInfo, ) -> Vec { - self.create_synthetic_call_internal(Some(wait_call), waiter_info) + self.create_synthetic_call_internal(CallInfo::Simple(waiter_info, wait_call)) } /// Create fallback synthetic calls for unmatched waiter creation /// Returns one call per service that has the waiter, matching Python behavior fn create_fallback_synthetic_call(&self, waiter_info: &WaiterInfo) -> Vec { - self.create_synthetic_call_internal(None, waiter_info) + self.create_synthetic_call_internal(CallInfo::None(waiter_info)) } /// Get required parameters for an operation from the service index @@ -336,16 +403,18 @@ impl<'a> GoWaiterExtractor<'a> { #[cfg(test)] mod tests { use crate::extraction::sdk_model::ServiceMethodRef; + use crate::{Language, SourceFile}; use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; - use std::collections::HashMap; + use std::{collections::HashMap, path::PathBuf}; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Go.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = + SourceFile::with_language(PathBuf::new(), source_code.to_string(), Language::Go); + let ast_grep = Go.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file) } fn create_test_service_index() -> ServiceModelIndex { @@ -473,14 +542,14 @@ func main() { for waiter in &waiters { println!( " - {} := {}.{}", - waiter.variable_name, waiter.client_receiver, waiter.waiter_type + waiter.variable_name, waiter.client_receiver, waiter.waiter_name ); } // Should find waiter creation calls with correct Go SDK pattern assert_eq!(waiters.len(), 2); assert_eq!(waiters[0].variable_name, "instanceWaiter"); - assert_eq!(waiters[0].waiter_type, "InstanceRunning"); + assert_eq!(waiters[0].waiter_name, "InstanceRunning"); assert_eq!(waiters[0].client_receiver, "client"); // Client parameter, not package name } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/argument_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/argument_extractor.rs index c6ebda0..f59a64e 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/argument_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/argument_extractor.rs @@ -4,6 +4,8 @@ //! from ast-grep nodes, handling object literals, shorthand properties, //! spread operators, and proper value resolution (literal vs identifier). +use ast_grep_core::tree_sitter::StrDoc; + use crate::extraction::{Parameter, ParameterValue}; /// Utility for extracting arguments from JavaScript/TypeScript AST nodes @@ -19,10 +21,10 @@ impl ArgumentExtractor { /// /// Returns a vector of Parameters with proper Resolved/Unresolved classification pub fn extract_object_parameters( - args_node: Option<&ast_grep_core::Node>, + args_node: Option<&ast_grep_core::Node>>, ) -> Vec where - T: ast_grep_core::Doc, + T: ast_grep_language::LanguageExt, { let Some(node) = args_node else { return Vec::new(); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs index 10bd3a3..f30416e 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs @@ -10,7 +10,8 @@ use std::collections::HashSet; use crate::extraction::extractor::{Extractor, ExtractorResult}; use crate::extraction::javascript::scanner::ASTScanner; use crate::extraction::javascript::shared::ExtractionUtils; -use crate::ServiceModelIndex; +use crate::extraction::AstWithSourceFile; +use crate::{ServiceModelIndex, SourceFile}; /// JavaScript extractor for AWS SDK method calls pub(crate) struct JavaScriptExtractor; @@ -30,9 +31,10 @@ impl Default for JavaScriptExtractor { #[async_trait] impl Extractor for JavaScriptExtractor { - async fn parse(&self, source_code: &str) -> ExtractorResult { + async fn parse(&self, source_file: &SourceFile) -> ExtractorResult { // Create AST once and reuse it - let ast = JavaScript.ast_grep(source_code); + let ast_grep = JavaScript.ast_grep(&source_file.content); + let ast = AstWithSourceFile::new(ast_grep, source_file.clone()); // Create scanner with the pre-built AST let mut scanner = ASTScanner::new(ast.clone(), JavaScript.into()); @@ -209,6 +211,14 @@ mod tests { } } + fn create_source_file(source_code: &str) -> SourceFile { + SourceFile::with_language( + std::path::PathBuf::new(), + source_code.to_string(), + crate::Language::JavaScript, + ) + } + #[tokio::test] async fn test_parse_import_via_require() { let extractor = JavaScriptExtractor::new(); @@ -228,7 +238,7 @@ async function createMyBucket() { createMyBucket(); "#; - let result = extractor.parse(source_code).await; + let result = extractor.parse(&create_source_file(source_code)).await; let method_calls = result.method_calls_ref(); // Should infer CreateBucket operation from CreateBucketCommand import @@ -300,7 +310,7 @@ async function getAllDynamoDBTables() { getAllDynamoDBTables(); "#; - let result = extractor.parse(source_code).await; + let result = extractor.parse(&create_source_file(source_code)).await; let method_calls = result.method_calls_ref(); // Should infer ListTables operation from paginateListTables import @@ -353,7 +363,7 @@ const bodyAsString = await bodyStream.transformToString(); const __error__ = await bodyStream.transformToString(); "#; - let result = extractor.parse(source_code).await; + let result = extractor.parse(&create_source_file(source_code)).await; let method_calls = result.method_calls_ref(); // Should find GetObject operation from direct client method call @@ -398,7 +408,7 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); "#; // Parse the code - let mut results = vec![extractor.parse(javascript_code).await]; + let mut results = vec![extractor.parse(&create_source_file(javascript_code)).await]; // Build service index with all services for testing let service_index = ServiceDiscovery::load_service_index(Language::JavaScript) @@ -470,7 +480,7 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); "#; // Parse the code - let mut results = vec![extractor.parse(code).await]; + let mut results = vec![extractor.parse(&create_source_file(code)).await]; // Load service index let service_index = ServiceDiscovery::load_service_index(Language::JavaScript) diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs index 80dc01e..81180b1 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs @@ -1,16 +1,24 @@ //! Core JavaScript/TypeScript scanning logic for AWS SDK extraction +use crate::extraction::javascript::shared::CommandUsage; use crate::extraction::javascript::types::{ ClientInstantiation, ImportInfo, JavaScriptScanResults, MethodCall, SublibraryInfo, ValidClientTypes, }; +use crate::extraction::AstWithSourceFile; use ast_grep_core::matcher::Pattern; -use ast_grep_core::MatchStrictness; +use ast_grep_core::tree_sitter::StrDoc; +use ast_grep_core::Doc; +use ast_grep_core::{tree_sitter, MatchStrictness, NodeMatch}; use std::collections::HashMap; -fn parse_import_item_with_line(import_item: &str, line: usize) -> Option { +fn parse_import_item_with_span( + import_item: &str, + start_position: (usize, usize), + end_position: (usize, usize), +) -> Option { let import_item = import_item.trim(); if import_item.is_empty() { return None; @@ -20,11 +28,23 @@ fn parse_import_item_with_line(import_item: &str, line: usize) -> Option) { } } -fn parse_and_add_imports_with_line( +fn parse_and_add_imports_with_span( imports_text: &str, sublibrary_info: &mut SublibraryInfo, - line: usize, + start_position: (usize, usize), + end_position: (usize, usize), ) { // Handle different import formats if imports_text.starts_with('{') && imports_text.ends_with('}') { @@ -126,13 +147,17 @@ fn parse_and_add_imports_with_line( // Split by comma and parse each import for import_item in imports_content.split(',') { - if let Some(import_info) = parse_import_item_with_line(import_item, line) { + if let Some(import_info) = + parse_import_item_with_span(import_item, start_position, end_position) + { sublibrary_info.add_import(import_info); } } } else { // Default import - single identifier - if let Some(import_info) = parse_import_item_with_line(imports_text, line) { + if let Some(import_info) = + parse_import_item_with_span(imports_text, start_position, end_position) + { sublibrary_info.add_import(import_info); } } @@ -141,20 +166,20 @@ fn parse_and_add_imports_with_line( /// Core AST scanner for JavaScript/TypeScript AWS SDK usage patterns pub(crate) struct ASTScanner where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { /// Pre-built AST grep root passed from extractor - ast_grep: ast_grep_core::AstGrep, - language: ast_grep_language::SupportLang, + pub(crate) ast_grep: AstWithSourceFile, + pub(crate) language: ast_grep_language::SupportLang, } impl ASTScanner where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { /// Create a new scanner with pre-built AST from extractor pub(crate) fn new( - ast_grep: ast_grep_core::AstGrep, + ast_grep: AstWithSourceFile, language: ast_grep_language::SupportLang, ) -> Self { Self { ast_grep, language } @@ -164,8 +189,8 @@ where fn find_all_matches( &self, pattern: &str, - ) -> Result>, String> { - let root = self.ast_grep.root(); + ) -> Result>>, String> { + let root = self.ast_grep.ast.root(); // Build pattern with relaxed strictness to handle inline comments let pattern_obj = @@ -175,29 +200,41 @@ where } /// Extract 1-based (line, column) position from the first match - fn get_first_match_position(matches: &[ast_grep_core::NodeMatch]) -> Option<(usize, usize)> { + fn get_first_match_span( + matches: &[ast_grep_core::NodeMatch>], + ) -> Option<((usize, usize), (usize, usize))> { matches.first().map(|first_match| { let node = first_match.get_node(); - let pos = node.start_pos(); - let line = pos.line() + 1; - let column = pos.column(node) + 1; - (line, column) + let start_pos = { + let pos = node.start_pos(); + let line = pos.line() + 1; + let column = pos.column(node) + 1; + (line, column) + }; + let end_pos = { + let pos = node.end_pos(); + let line = pos.line() + 1; + let column = pos.column(node) + 1; + (line, column) + }; + (start_pos, end_pos) }) } /// Find Command instantiation and extract its arguments - /// Returns ((line_number, column_number), parameters) tuple + /// Returns CommandInstantiationResult with position and parameters pub(crate) fn find_command_instantiation_with_args( &self, command_name: &str, - ) -> Option<((usize, usize), Vec)> { + ) -> Option> { use crate::extraction::javascript::argument_extractor::ArgumentExtractor; let pattern = format!("new {}($ARGS)", command_name); if let Ok(matches) = self.find_all_matches(&pattern) { - if let Some(position) = Self::get_first_match_position(&matches) { + if let Some(span) = Self::get_first_match_span(&matches) { let first_match = matches.first().unwrap(); + let env = first_match.get_env(); // Extract arguments from the ARGS node @@ -205,7 +242,12 @@ where let args_node = env.get_match("ARGS"); let parameters = ArgumentExtractor::extract_object_parameters(args_node); - return Some((position, parameters)); + return Some(CommandUsage::new( + first_match.text(), + span.0, + span.1, + parameters, + )); } } None @@ -215,14 +257,14 @@ where pub(crate) fn find_paginate_function_with_args( &self, function_name: &str, - ) -> Option<((usize, usize), Vec)> { + ) -> Option> { use crate::extraction::javascript::argument_extractor::ArgumentExtractor; // Use explicit two-argument pattern let pattern = format!("{}($ARG1, $ARG2)", function_name); if let Ok(matches) = self.find_all_matches(&pattern) { - if let Some(position) = Self::get_first_match_position(&matches) { + if let Some(span) = Self::get_first_match_span(&matches) { let first_match = matches.first().unwrap(); let env = first_match.get_env(); @@ -230,7 +272,12 @@ where let second_arg = env.get_match("ARG2"); let parameters = ArgumentExtractor::extract_object_parameters(second_arg); - return Some((position, parameters)); + return Some(CommandUsage::new( + first_match.text(), + span.0, + span.1, + parameters, + )); } } None @@ -240,7 +287,7 @@ where pub(crate) fn find_waiter_function_with_args( &self, function_name: &str, - ) -> Option<((usize, usize), Vec)> { + ) -> Option> { use crate::extraction::javascript::argument_extractor::ArgumentExtractor; // Try patterns with and without await keyword using explicit two-argument pattern @@ -251,7 +298,7 @@ where for pattern in &patterns { if let Ok(matches) = self.find_all_matches(pattern) { - if let Some(position) = Self::get_first_match_position(&matches) { + if let Some(span) = Self::get_first_match_span(&matches) { let first_match = matches.first().unwrap(); let env = first_match.get_env(); @@ -259,7 +306,12 @@ where let second_arg = env.get_match("ARG2"); let parameters = ArgumentExtractor::extract_object_parameters(second_arg); - return Some((position, parameters)); + return Some(CommandUsage::new( + first_match.text(), + span.0, + span.1, + parameters, + )); } } } @@ -270,7 +322,7 @@ where pub(crate) fn find_command_input_usage_position( &self, type_name: &str, - ) -> Option<(usize, usize)> { + ) -> Option> { // Try multiple patterns for TypeScript type annotations let patterns = [ format!("const $VAR: {} = $VALUE", type_name), // const variable: Type = value @@ -280,8 +332,11 @@ where for pattern in &patterns { if let Ok(matches) = self.find_all_matches(pattern) { - if let Some(position) = Self::get_first_match_position(&matches) { - return Some(position); + if let Some(span) = Self::get_first_match_span(&matches) { + let expr_text = matches.first().unwrap().text(); + // TODO: Extract from variable assignments + let parameters = vec![]; + return Some(CommandUsage::new(expr_text, span.0, span.1, parameters)); } } } @@ -293,19 +348,18 @@ where let mut sublibrary_data: HashMap = HashMap::new(); let matches = self.find_all_matches(pattern)?; - Self::process_import_matches(matches, &mut sublibrary_data, true)?; + Self::process_import_matches(matches, &mut sublibrary_data)?; Ok(sublibrary_data.into_values().collect()) } /// Generic processing for import/require matches - works for both JavaScript and TypeScript fn process_import_matches( - matches: Vec>, + matches: Vec>, sublibrary_data: &mut HashMap, - include_line_numbers: bool, ) -> Result<(), String> where - U: ast_grep_core::Doc + std::clone::Clone, + U: Doc + std::clone::Clone, { for node_match in matches { let env = node_match.get_env(); @@ -331,13 +385,24 @@ where .entry(sublibrary.clone()) .or_insert_with(|| SublibraryInfo::new(sublibrary)); - if include_line_numbers { - // Get line number from AST node - let line = node_match.get_node().start_pos().line() + 1; - parse_and_add_imports_with_line(imports_text_str, sublibrary_info, line); - } else { - parse_and_add_imports_with_line(imports_text_str, sublibrary_info, 1); - } + let start_position = { + let pos = node_match.get_node().start_pos(); + let line = pos.line() + 1; + let column = pos.column(imports_node) + 1; + (line, column) + }; + let end_position = { + let pos = node_match.get_node().end_pos(); + let line = pos.line() + 1; + let column = pos.column(imports_node) + 1; + (line, column) + }; + parse_and_add_imports_with_span( + imports_text_str, + sublibrary_info, + start_position, + end_position, + ); } } Ok(()) @@ -448,14 +513,14 @@ where /// Generic processing for client instantiation matches - works for both JavaScript and TypeScript fn process_client_instantiation_matches( - matches: Vec>, + matches: Vec>, valid_client_types: &[String], client_name_mappings: &HashMap, client_sublibrary_mappings: &HashMap, results: &mut Vec, ) -> Result<(), String> where - U: ast_grep_core::Doc + std::clone::Clone, + U: Doc + std::clone::Clone, { for node_match in matches { let env = node_match.get_env(); @@ -506,13 +571,14 @@ where /// Generic processing for method call matches - works for both JavaScript and TypeScript fn process_method_call_matches( - matches: Vec>, + &self, + matches: Vec>, client_variables: &[String], client_info_map: &HashMap, results: &mut Vec, ) -> Result<(), String> where - U: ast_grep_core::Doc + std::clone::Clone, + U: Doc + std::clone::Clone, { for node_match in matches { let env = node_match.get_env(); @@ -538,17 +604,26 @@ where HashMap::new() }; - // Get line number - let line = node_match.get_node().start_pos().line() + 1; + let start_position = { + let pos = node_match.get_node().start_pos(); + (pos.line() + 1, pos.column(node_match.get_node()) + 1) + }; + let end_position = { + let pos = node_match.get_node().end_pos(); + (pos.line() + 1, pos.column(node_match.get_node()) + 1) + }; results.push(MethodCall { client_variable: variable_name, client_type: client_type.clone(), original_client_type: original_client_type.clone(), client_sublibrary: client_sublibrary.clone(), + expr: node_match.text().to_string(), + file_path: self.ast_grep.source_file.path.clone(), method_name, arguments, - line, + start_position, + end_position, }); } } @@ -585,7 +660,7 @@ where // Single pattern to match method calls (covers both awaited and non-awaited) let matches = self.find_all_matches("$VAR.$METHOD($ARGS)")?; - Self::process_method_call_matches( + self.process_method_call_matches( matches, &client_variables, &client_info_map, @@ -614,20 +689,25 @@ where #[cfg(test)] mod tests { + use std::path::PathBuf; + + use crate::SourceFile; + use super::*; - use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::{JavaScript, TypeScript}; + use tree_sitter::LanguageExt; #[test] fn test_parse_import_item() { // Test regular import - let import_info = parse_import_item_with_line("S3Client", 1).unwrap(); + let import_info = parse_import_item_with_span("S3Client", (1, 1), (1, 1)).unwrap(); assert_eq!(import_info.original_name, "S3Client"); assert_eq!(import_info.local_name, "S3Client"); assert!(!import_info.is_renamed); // Test renamed import - let import_info = parse_import_item_with_line("S3Client as MyS3Client", 1).unwrap(); + let import_info = + parse_import_item_with_span("S3Client as MyS3Client", (1, 1), (1, 1)).unwrap(); assert_eq!(import_info.original_name, "S3Client"); assert_eq!(import_info.local_name, "MyS3Client"); assert!(import_info.is_renamed); @@ -644,6 +724,16 @@ mod tests { assert!(result.is_empty()); } + fn create_js_ast(source_code: &str) -> AstWithSourceFile { + let source_file = SourceFile::with_language( + PathBuf::new(), + source_code.to_string(), + crate::Language::JavaScript, + ); + let ast_grep = JavaScript.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file.clone()) + } + #[test] fn test_import_require_scanning_comprehensive() { // Create comprehensive test case with multiple sublibrary patterns @@ -655,7 +745,7 @@ const { LambdaClient, InvokeCommand } = require("@aws-sdk/client-lambda"); const { SESClient } = require("@aws-sdk/client-ses"); "#; - let ast = JavaScript.ast_grep(source); + let ast = create_js_ast(source); let mut scanner = ASTScanner::new(ast, JavaScript.into()); let (imports, requires) = scanner.scan_all_aws_imports().unwrap(); @@ -837,7 +927,7 @@ async function uploadFile() { } "#; - let ast = JavaScript.ast_grep(source_with_usage); + let ast = create_js_ast(source_with_usage); let scanner = ASTScanner::new(ast, JavaScript.into()); // Should find CreateBucketCommand instantiation at line ~6 @@ -888,7 +978,7 @@ async function listAllTables() { } "#; - let ast = JavaScript.ast_grep(source_with_usage); + let ast = create_js_ast(source_with_usage); let scanner = ASTScanner::new(ast, JavaScript.into()); // Should find paginateQuery call at line ~7 @@ -909,6 +999,16 @@ async function listAllTables() { println!("✅ Paginate function position heuristics working correctly"); } + fn create_ts_ast(source_code: &str) -> AstWithSourceFile { + let source_file = SourceFile::with_language( + PathBuf::new(), + source_code.to_string(), + crate::Language::TypeScript, + ); + let ast_grep = TypeScript.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file.clone()) + } + #[test] fn test_position_heuristics_command_input_typescript() { // Test CommandInput type usage position finding (TypeScript-specific) @@ -933,24 +1033,26 @@ function createListParams(): ListTablesInput { } "#; - let ast = TypeScript.ast_grep(typescript_source); + let ast = create_ts_ast(typescript_source); let scanner = ASTScanner::new(ast, TypeScript.into()); - let query_input_pos = scanner.find_command_input_usage_position("QueryCommandInput"); - assert!( - query_input_pos.is_some(), - "Should find QueryCommandInput usage" - ); - let (line, _col) = query_input_pos.unwrap(); - assert_eq!(line, 9, "QueryCommandInput should be at line 9"); + if let Some(result) = scanner.find_command_input_usage_position("QueryCommandInput") { + assert_eq!( + result.start_position.0, 9, + "QueryCommandInput should be at line 9" + ); + } else { + panic!("Should find QueryCommandInput usage"); + } - let list_input_pos = scanner.find_command_input_usage_position("ListTablesInput"); - assert!( - list_input_pos.is_some(), - "Should find ListTablesInput usage" - ); - let (line, _col) = list_input_pos.unwrap(); - assert_eq!(line, 15, "ListTablesInput should be at line 15"); + if let Some(result) = scanner.find_command_input_usage_position("ListTablesInput") { + assert_eq!( + result.start_position.0, 15, + "ListTablesInput should be at line 15" + ); + } else { + panic!("Should find ListTablesInput usage"); + } // Should return None for type that wasn't used let missing_type_pos = scanner.find_command_input_usage_position("PutItemInput"); @@ -971,7 +1073,7 @@ const { CreateBucketCommand } = require("@aws-sdk/client-s3"); const command = new CreateBucketCommand({ Bucket: "test" }); "#; - let ast = JavaScript.ast_grep(javascript_source); + let ast = create_js_ast(javascript_source); let scanner = ASTScanner::new(ast, JavaScript.into()); // JavaScript should find command instantiation @@ -1010,7 +1112,7 @@ let dynamoSdk = require("@aws-sdk/lib-dynamodb"); var ec2Sdk = require("@aws-sdk/client-ec2"); "#; - let ast = JavaScript.ast_grep(source_with_mixed_requires); + let ast = create_js_ast(source_with_mixed_requires); let mut scanner = ASTScanner::new(ast, JavaScript.into()); let (imports, requires) = scanner.scan_all_aws_imports().unwrap(); @@ -1159,7 +1261,7 @@ async function testOperations() { } "#; - let ast = TypeScript.ast_grep(typescript_source); + let ast = create_ts_ast(typescript_source); let mut scanner = ASTScanner::new(ast, TypeScript.into()); let scan_results = scanner.scan_all().unwrap(); @@ -1282,7 +1384,7 @@ async function uploadLargeFile() { } "#; - let ast = TypeScript.ast_grep(typescript_source); + let ast = create_ts_ast(typescript_source); let mut scanner = ASTScanner::new(ast, TypeScript.into()); let scan_results = scanner.scan_all().unwrap(); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs index ce8e105..b0223e1 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs @@ -3,10 +3,11 @@ //! This module contains common functionality shared between JavaScript and TypeScript //! extractors. -use crate::extraction::javascript::types::JavaScriptScanResults; +use crate::extraction::javascript::types::{ImportInfo, JavaScriptScanResults}; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; use rust_embed::RustEmbed; use serde::Deserialize; +use std::borrow::Cow; use std::collections::HashMap; /// Embedded JavaScript SDK v3 libraries mapping @@ -48,6 +49,50 @@ fn load_libraries_mapping() -> Option { serde_json::from_str(content).ok() } +/// Result of finding a command/function instantiation with its arguments +#[derive(Debug, Clone)] +pub(crate) struct CommandUsage<'a> { + /// The matched text from the AST + pub(crate) text: Cow<'a, str>, + /// Start position in the source file (line, column) - 1-based + pub(crate) start_position: (usize, usize), + /// End position in the source file (line, column) - 1-based + pub(crate) end_position: (usize, usize), + /// Extracted parameters from the command/function arguments + pub(crate) parameters: Vec, +} + +impl<'a> CommandUsage<'a> { + /// Create a new CommandInstantiationResult + pub(crate) fn new( + text: Cow<'a, str>, + start_position: (usize, usize), + end_position: (usize, usize), + parameters: Vec, + ) -> Self { + Self { + text, + start_position, + end_position, + parameters, + } + } +} + +// Used when we cannot find a method call, and fall back to adding an operation purely based on an import statement +impl From<&ImportInfo> for CommandUsage<'_> { + fn from(value: &ImportInfo) -> Self { + Self { + text: Cow::Owned(value.statement.clone()), + start_position: value.start_position, + end_position: value.end_position, + // TODO: parameters should be an Option, so we can distinguish + // the case where we fall back to an import statement + parameters: vec![], + } + } +} + /// Shared extraction utilities for JavaScript/TypeScript AWS SDK method calls pub(crate) struct ExtractionUtils; @@ -58,7 +103,7 @@ impl ExtractionUtils { scanner: &mut crate::extraction::javascript::scanner::ASTScanner, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut method_calls = Vec::new(); @@ -105,7 +150,7 @@ impl ExtractionUtils { lib_mappings: Option<&JsV3LibrariesMapping>, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -124,9 +169,9 @@ impl ExtractionUtils { if import_info.original_name.ends_with("Command") { // Try to find the actual constructor instantiation with arguments // Use the local name for the search (handles renames) - let (usage_position, parameters) = scanner + let result = scanner .find_command_instantiation_with_args(&import_info.local_name) - .unwrap_or_else(|| ((import_info.line, 1), Vec::new())); // Fallback to import position with no params + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Check if this needs library expansion (lib-* sublibraries) let expanded_command_names = @@ -155,15 +200,17 @@ impl ExtractionUtils { // Extract operation name by removing "Command" suffix if let Some(operation_name) = command_name.strip_suffix("Command") { // Keep PascalCase operation name to match service index - // e.g., "PutItem" from "PutItemCommand" + // e.g., "CreateBucket" stays "CreateBucket" let method_call = SdkMethodCall { name: operation_name.to_string(), possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters: parameters.clone(), + parameters: result.parameters.clone(), return_type: None, - start_position: usage_position, - end_position: usage_position, + expr: result.text.to_string(), + file_path: scanner.ast_grep.source_file.path.clone(), + start_position: result.start_position, + end_position: result.end_position, receiver: None, // Commands are typically standalone }), }; @@ -185,7 +232,7 @@ impl ExtractionUtils { lib_mappings: Option<&JsV3LibrariesMapping>, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -204,9 +251,9 @@ impl ExtractionUtils { if import_info.original_name.starts_with("paginate") { // Try to find the actual paginate function call with arguments // Use the local name for the search (handles renames) - let (usage_position, parameters) = scanner + let result = scanner .find_paginate_function_with_args(&import_info.local_name) - .unwrap_or_else(|| ((import_info.line, 1), Vec::new())); // Fallback to import position with no params + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Check if this needs library expansion (lib-* sublibraries) let expanded_paginator_names = @@ -249,10 +296,12 @@ impl ExtractionUtils { name: operation_name, possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters: parameters.clone(), // extracted from 2nd argument! + parameters: result.parameters.clone(), // extracted from 2nd argument! return_type: None, - start_position: usage_position, - end_position: usage_position, + expr: result.text.to_string(), + file_path: scanner.ast_grep.source_file.path.clone(), + start_position: result.start_position, + end_position: result.end_position, receiver: None, }), }; @@ -274,7 +323,7 @@ impl ExtractionUtils { scanner: &mut crate::extraction::javascript::scanner::ASTScanner, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -297,9 +346,9 @@ impl ExtractionUtils { { // Try to find the actual waiter function call with arguments // Use the local name for the search (handles renames) - let (usage_position, parameters) = scanner + let result = scanner .find_waiter_function_with_args(&import_info.local_name) - .unwrap_or_else(|| ((import_info.line, 1), Vec::new())); // Fallback to import position with no params + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Keep PascalCase waiter name // e.g., "BucketExists" from "waitUntilBucketExists" @@ -308,10 +357,12 @@ impl ExtractionUtils { name: waiter_name.to_string(), possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters, // Extracted from 2nd argument (operation params) + parameters: result.parameters, // Extracted from 2nd argument (operation params) return_type: None, - start_position: usage_position, - end_position: usage_position, + expr: result.text.to_string(), + file_path: scanner.ast_grep.source_file.path.clone(), + start_position: result.start_position, + end_position: result.end_position, receiver: None, // Waiter functions are standalone }), }; @@ -331,7 +382,7 @@ impl ExtractionUtils { scanner: &mut crate::extraction::javascript::scanner::ASTScanner, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -357,9 +408,9 @@ impl ExtractionUtils { if let Some(operation_name) = operation_name { // Try to find the actual CommandInput type usage position (TypeScript-specific) // Use the local name for the search (handles renames) - let usage_position = scanner + let result = scanner .find_command_input_usage_position(&import_info.local_name) - .unwrap_or((import_info.line, 1)); // Fallback to import position + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Keep PascalCase operation name to match service index // e.g., "Query" stays "Query" @@ -367,10 +418,12 @@ impl ExtractionUtils { name: operation_name.to_string(), possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters: Vec::new(), // TODO: Extract from variable assignments + parameters: Vec::new(), return_type: None, - start_position: usage_position, // Using enhanced position tracking - end_position: usage_position, // Using enhanced position tracking + expr: result.text.to_string(), + file_path: scanner.ast_grep.source_file.path.clone(), + start_position: result.start_position, + end_position: result.end_position, receiver: None, }), }; @@ -391,7 +444,7 @@ impl ExtractionUtils { lib_mappings: Option<&JsV3LibrariesMapping>, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -427,9 +480,9 @@ impl ExtractionUtils { .and_then(|lib| lib.get(&import_info.original_name)) { // Try to find class instantiation, fallback to import position - let (usage_position, parameters) = scanner + let result = scanner .find_command_instantiation_with_args(&import_info.local_name) - .unwrap_or_else(|| ((import_info.line, 1), Vec::new())); + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Create operations for each expanded command for command_name in expanded_commands { @@ -439,10 +492,12 @@ impl ExtractionUtils { name: operation_name.to_string(), possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters: parameters.clone(), + parameters: result.parameters.clone(), return_type: None, - start_position: usage_position, - end_position: usage_position, + expr: result.text.to_string(), + file_path: scanner.ast_grep.source_file.path.clone(), + start_position: result.start_position, + end_position: result.end_position, receiver: None, }), }; @@ -490,8 +545,10 @@ impl ExtractionUtils { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position: (method_call.line, 1), - end_position: (method_call.line, 1), + expr: method_call.expr.clone(), + file_path: method_call.file_path.clone(), + start_position: method_call.start_position, + end_position: method_call.end_position, receiver: Some(method_call.client_variable.clone()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs index 227c237..000063b 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs @@ -1,30 +1,42 @@ //! JavaScript/TypeScript specific data types for AWS SDK extraction use serde::{Deserialize, Serialize}; -use std::collections::HashMap; +use std::{collections::HashMap, path::PathBuf}; /// Information about a single import with rename support #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] pub(crate) struct ImportInfo { /// Original name in the AWS SDK (e.g., "S3Client", "PutObjectCommand") pub(crate) original_name: String, + /// Import statement + pub(crate) statement: String, /// Local name used in the code (e.g., "MyS3Client", "PutObject") pub(crate) local_name: String, /// Whether this import was renamed (original_name != local_name) pub(crate) is_renamed: bool, - /// Line number where this import appears - pub(crate) line: usize, + /// Start position of the import + pub(crate) start_position: (usize, usize), + /// End position of the import + pub(crate) end_position: (usize, usize), } impl ImportInfo { /// Create a new ImportInfo with the given names and line position - pub(crate) fn new(original_name: String, local_name: String, line: usize) -> Self { + pub(crate) fn new( + original_name: String, + local_name: String, + statement: &str, + start_position: (usize, usize), + end_position: (usize, usize), + ) -> Self { let is_renamed = original_name != local_name; Self { original_name, + statement: statement.to_string(), local_name, is_renamed, - line, + start_position, + end_position, } } } @@ -123,8 +135,14 @@ pub(crate) struct MethodCall { pub(crate) method_name: String, /// Method arguments pub(crate) arguments: HashMap, - /// Line number where call occurs - pub(crate) line: usize, + /// Matched expression + pub(crate) expr: String, + /// File where the method call was found + pub(crate) file_path: PathBuf, + /// Start position where call occurs + pub(crate) start_position: (usize, usize), + /// End position where call occurs + pub(crate) end_position: (usize, usize), } /// Combined results from all scanning operations diff --git a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs index 91be7da..3121ad6 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs @@ -27,10 +27,30 @@ pub use self::{core::*, output::*}; /// Core data structures for source file parsing and method extraction pub mod core { + use std::sync::Arc; + use crate::Language; use super::{Deserialize, Path, PathBuf, Serialize}; + #[derive(Clone)] + pub(crate) struct AstWithSourceFile { + pub(crate) ast: Arc>>, + pub(crate) source_file: Arc, + } + + impl AstWithSourceFile { + pub(crate) fn new( + ast: ast_grep_core::AstGrep>, + source_file: SourceFile, + ) -> Self { + Self { + ast: Arc::new(ast), + source_file: Arc::new(source_file), + } + } + } + /// Represents a source file being analyzed /// /// Contains the file path, content, and detected programming language. @@ -90,7 +110,12 @@ pub mod core { /// Return type annotation if available pub(crate) return_type: Option, + /// The matched expression + pub(crate) expr: String, + // Position information + /// File path + pub(crate) file_path: PathBuf, /// Starting position (line, column) - both 1-based pub(crate) start_position: (usize, usize), /// Ending position (line, column) - both 1-based @@ -334,6 +359,8 @@ mod tests { type_annotation: Some("str".to_string()), }], return_type: Some("bool".to_string()), + expr: "test_method".to_string(), + file_path: PathBuf::new(), start_position: (10, 1), end_position: (10, 25), receiver: None, @@ -359,6 +386,8 @@ mod tests { type_annotation: Some("str".to_string()), }], return_type: Some("Dict[str, Any]".to_string()), + expr: "get_object".to_string(), + file_path: PathBuf::new(), start_position: (15, 5), end_position: (15, 45), receiver: Some("s3_client".to_string()), @@ -424,6 +453,8 @@ mod tests { let metadata = SdkMethodCallMetadata { parameters: vec![], return_type: Some("Dict[str, Any]".to_string()), + expr: "s3_client.foo_bar".to_string(), + file_path: PathBuf::new(), start_position: (10, 5), end_position: (10, 30), receiver: Some("s3_client".to_string()), diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs index 9019cbe..3dc3559 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs @@ -200,6 +200,7 @@ mod tests { }; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; use std::collections::HashMap; + use std::path::PathBuf; fn create_test_service_index() -> ServiceModelIndex { let mut services = HashMap::new(); @@ -297,6 +298,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -318,6 +320,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("client".to_string()), @@ -338,6 +341,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -348,6 +352,7 @@ mod tests { // Missing required Stage and ApiId parameters ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 30), receiver: Some("client".to_string()), @@ -367,11 +372,13 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![Parameter::DictionarySplat { expression: "**params".to_string(), position: 0, }], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 30), receiver: Some("client".to_string()), @@ -392,6 +399,7 @@ mod tests { name: "non_aws_method".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "non_aws_method".to_string(), parameters: vec![Parameter::Keyword { name: "custom_param".to_string(), value: ParameterValue::Resolved("value".to_string()), @@ -399,6 +407,7 @@ mod tests { type_annotation: None, }], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 30), receiver: Some("client".to_string()), @@ -421,6 +430,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -440,6 +450,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("client".to_string()), @@ -462,6 +473,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -477,6 +489,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("client".to_string()), diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs index e8ca62d..1983225 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs @@ -180,6 +180,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -201,6 +202,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 80), receiver: Some("apigateway_client".to_string()), @@ -222,6 +224,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -232,6 +235,7 @@ mod tests { // Missing required Stage and ApiId parameters ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 40), receiver: Some("apigateway_client".to_string()), @@ -251,6 +255,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -278,6 +283,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 100), receiver: Some("apigateway_client".to_string()), @@ -297,11 +303,13 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![Parameter::DictionarySplat { expression: "**params".to_string(), position: 0, }], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("apigateway_client".to_string()), @@ -328,6 +336,7 @@ mod tests { name: "custom_method".to_string(), // Not an AWS SDK method possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "custom_method".to_string(), parameters: vec![Parameter::Keyword { name: "custom_param".to_string(), value: ParameterValue::Resolved("value".to_string()), @@ -335,6 +344,7 @@ mod tests { type_annotation: None, }], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 30), receiver: Some("custom_client".to_string()), @@ -356,6 +366,7 @@ mod tests { name: "get_object".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "get_object".to_string(), parameters: vec![ Parameter::Keyword { name: "Bucket".to_string(), @@ -371,6 +382,7 @@ mod tests { }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 50), receiver: Some("s3_client".to_string()), @@ -381,6 +393,7 @@ mod tests { name: "get_object".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "get_object".to_string(), parameters: vec![Parameter::Keyword { name: "custom_param".to_string(), // Invalid parameter for AWS S3 value: ParameterValue::Resolved("value".to_string()), @@ -388,6 +401,7 @@ mod tests { type_annotation: None, }], return_type: None, + file_path: PathBuf::new(), start_position: (2, 1), end_position: (2, 30), receiver: Some("custom_client".to_string()), @@ -398,6 +412,7 @@ mod tests { name: "custom_method".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "custom_method".to_string(), parameters: vec![Parameter::Keyword { name: "param".to_string(), value: ParameterValue::Resolved("value".to_string()), @@ -405,6 +420,7 @@ mod tests { type_annotation: None, }], return_type: None, + file_path: PathBuf::new(), start_position: (3, 1), end_position: (3, 25), receiver: Some("custom_client".to_string()), @@ -463,7 +479,7 @@ def example(): let extractor = PythonExtractor::new(); // Extract method calls using tree-sitter - let mut result = vec![extractor.parse(&source.content).await]; + let mut result = vec![extractor.parse(&source).await]; assert_eq!(result.first().unwrap().method_calls_ref().len(), 7); // Apply disambiguation @@ -510,6 +526,7 @@ def example(): name: "get_object".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "get_object".to_string(), parameters: vec![ Parameter::Keyword { name: "Bucket".to_string(), @@ -540,6 +557,7 @@ def example(): }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 80), receiver: Some("s3_client".to_string()), @@ -568,6 +586,7 @@ def example(): name: "get_object".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "get_object".to_string(), parameters: vec![ // Valid required parameters Parameter::Keyword { @@ -591,6 +610,7 @@ def example(): }, ], return_type: None, + file_path: PathBuf::new(), start_position: (1, 1), end_position: (1, 60), receiver: Some("s3_client".to_string()), diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs index d413b15..83160ad 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs @@ -6,8 +6,8 @@ use crate::extraction::python::disambiguation::MethodDisambiguator; use crate::extraction::python::paginator_extractor::PaginatorExtractor; use crate::extraction::python::resource_direct_calls_extractor::ResourceDirectCallsExtractor; use crate::extraction::python::waiters_extractor::WaitersExtractor; -use crate::extraction::{SdkMethodCall, SdkMethodCallMetadata}; -use crate::ServiceModelIndex; +use crate::extraction::{AstWithSourceFile, SdkMethodCall, SdkMethodCallMetadata}; +use crate::{ServiceModelIndex, SourceFile}; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; use async_trait::async_trait; @@ -24,6 +24,7 @@ impl PythonExtractor { fn parse_method_call( &self, node_match: &ast_grep_core::NodeMatch>, + source_file: &SourceFile, ) -> Option { let env = node_match.get_env(); @@ -54,6 +55,8 @@ impl PythonExtractor { metadata: Some(SdkMethodCallMetadata { parameters: arguments, return_type: None, // We don't know the return type from the call site + expr: node_match.text().to_string(), + file_path: source_file.path.clone(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), receiver, @@ -73,9 +76,13 @@ impl Default for PythonExtractor { #[async_trait] impl Extractor for PythonExtractor { - async fn parse(&self, source_code: &str) -> crate::extraction::extractor::ExtractorResult { - let ast_grep = Python.ast_grep(source_code); - let root = ast_grep.root(); + async fn parse( + &self, + source_file: &SourceFile, + ) -> crate::extraction::extractor::ExtractorResult { + let ast_grep = Python.ast_grep(&source_file.content); + let ast = AstWithSourceFile::new(ast_grep, source_file.clone()); + let root = ast.ast.root(); let mut method_calls = Vec::new(); @@ -83,12 +90,12 @@ impl Extractor for PythonExtractor { // Find all method calls with attribute access: obj.method(args) for node_match in root.find_all(pattern) { - if let Some(method_call) = self.parse_method_call(&node_match) { + if let Some(method_call) = self.parse_method_call(&node_match, source_file) { method_calls.push(method_call); } } - ExtractorResult::Python(ast_grep, method_calls) + ExtractorResult::Python(ast, method_calls) } fn filter_map( @@ -148,15 +155,25 @@ impl Extractor for PythonExtractor { #[cfg(test)] mod tests { + use std::path::PathBuf; + use super::*; - use crate::extraction::{Parameter, ParameterValue}; + use crate::{ + extraction::{Parameter, ParameterValue}, + Language, + }; #[tokio::test] async fn test_basic_method_call_extraction() { let extractor = PythonExtractor::new(); let source_code = "s3_client.get_object(Bucket='my-bucket', Key='my-key')"; - let result = extractor.parse(source_code).await; + let source_file = SourceFile::with_language( + std::path::PathBuf::new(), + source_code.to_string(), + crate::Language::Python, + ); + let result = extractor.parse(&source_file).await; assert_eq!(result.method_calls_ref().len(), 1); assert_eq!(result.method_calls_ref()[0].name, "get_object"); } @@ -177,8 +194,9 @@ cloudwatch_client.put_metric_alarm( Threshold=0.0 ) "#; - - let result = extractor.parse(source_code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), source_code.to_string(), Language::Python); + let result = extractor.parse(&source_file).await; // Verify exactly one method call is extracted assert_eq!(result.method_calls_ref().len(), 1); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs index 28b4b80..3723d6f 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs @@ -4,9 +4,11 @@ //! two-phase operations: creating a paginator from a client, then executing //! the paginator with operation arguments. +use std::path::{Path, PathBuf}; + use crate::extraction::python::common::{ArgumentExtractor, ParameterFilter}; use crate::extraction::sdk_model::ServiceDiscovery; -use crate::extraction::{Parameter, SdkMethodCall, SdkMethodCallMetadata}; +use crate::extraction::{AstWithSourceFile, Parameter, SdkMethodCall, SdkMethodCallMetadata}; use crate::Language; use crate::ServiceModelIndex; use ast_grep_language::Python; @@ -20,6 +22,10 @@ pub(crate) struct PaginatorInfo { pub operation_name: String, /// Client receiver variable name (e.g., "client", "s3_client") pub client_receiver: String, + /// Matched expression + pub expr: String, + /// File the paginator was found in + pub file_path: PathBuf, /// Line number where get_paginator was called pub get_paginator_line: usize, } @@ -33,6 +39,10 @@ pub(crate) struct PaginateCallInfo { pub arguments: Vec, /// Line number where paginate was called (preferred for position reporting) pub paginate_line: usize, + /// Matched expression + pub expr: String, + /// File the paginator was found in + pub file_path: PathBuf, /// Start position of the paginate call node pub start_position: (usize, usize), /// End position of the paginate call node @@ -51,6 +61,10 @@ pub(crate) struct ChainedPaginatorCallInfo { /// Line number where chained call was made #[allow(dead_code)] pub line: usize, + /// Matched expression + pub expr: String, + /// File the chained paginator call was found in + pub file_path: PathBuf, /// Start position of the chained call node pub start_position: (usize, usize), /// End position of the chained call node @@ -101,7 +115,7 @@ impl<'a> PaginatorExtractor<'a> { /// empty parameters, since paginators are often created but used elsewhere. pub(crate) fn extract_paginate_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { // Step 1: Find all get_paginator calls let paginators = self.find_get_paginator_calls(ast); @@ -141,18 +155,17 @@ impl<'a> PaginatorExtractor<'a> { } /// Find all get_paginator calls in the AST - fn find_get_paginator_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_get_paginator_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut paginators = Vec::new(); // Pattern: $PAGINATOR = $CLIENT.get_paginator($OPERATION $$$ARGS) let get_paginator_pattern = "$PAGINATOR = $CLIENT.get_paginator($OPERATION $$$ARGS)"; for node_match in root.find_all(get_paginator_pattern) { - if let Some(paginator_info) = self.parse_get_paginator_call(&node_match) { + if let Some(paginator_info) = + self.parse_get_paginator_call(&node_match, &ast.source_file.path) + { paginators.push(paginator_info); } } @@ -161,18 +174,17 @@ impl<'a> PaginatorExtractor<'a> { } /// Find all paginate calls in the AST - fn find_paginate_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_paginate_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut paginate_calls = Vec::new(); // Pattern: $PAGINATOR.paginate($$$ARGS) - flexible pattern without assignment requirement let paginate_pattern = "$PAGINATOR.paginate($$$ARGS)"; for node_match in root.find_all(paginate_pattern) { - if let Some(paginate_info) = self.parse_paginate_call(&node_match) { + if let Some(paginate_info) = + self.parse_paginate_call(&node_match, &ast.source_file.path) + { paginate_calls.push(paginate_info); } } @@ -183,9 +195,9 @@ impl<'a> PaginatorExtractor<'a> { /// Find all chained paginator calls in the AST fn find_chained_paginator_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut chained_calls = Vec::new(); // Pattern: $CLIENT.get_paginator($OPERATION $$$GET_ARGS).paginate($$$PAGINATE_ARGS) @@ -193,7 +205,9 @@ impl<'a> PaginatorExtractor<'a> { "$CLIENT.get_paginator($OPERATION $$$GET_ARGS).paginate($$$PAGINATE_ARGS)"; for node_match in root.find_all(chained_pattern) { - if let Some(chained_info) = self.parse_chained_paginator_call(&node_match) { + if let Some(chained_info) = + self.parse_chained_paginator_call(&node_match, &ast.source_file.path) + { chained_calls.push(chained_info); } } @@ -205,6 +219,7 @@ impl<'a> PaginatorExtractor<'a> { fn parse_get_paginator_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -227,6 +242,8 @@ impl<'a> PaginatorExtractor<'a> { operation_name, client_receiver, get_paginator_line, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), }) } @@ -234,6 +251,7 @@ impl<'a> PaginatorExtractor<'a> { fn parse_paginate_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -254,6 +272,8 @@ impl<'a> PaginatorExtractor<'a> { paginator_var, arguments: filtered_arguments, paginate_line: start.line() + 1, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), }) @@ -263,6 +283,7 @@ impl<'a> PaginatorExtractor<'a> { fn parse_chained_paginator_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -289,6 +310,8 @@ impl<'a> PaginatorExtractor<'a> { operation_name, arguments: filtered_arguments, line: start.line() + 1, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), }) @@ -360,7 +383,9 @@ impl<'a> PaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: Vec::new(), // Empty parameters for unmatched paginators return_type: None, + expr: paginator_info.expr.clone(), // Use get_paginator call position + file_path: paginator_info.file_path.clone(), start_position: (paginator_info.get_paginator_line, 1), end_position: (paginator_info.get_paginator_line, 1), receiver: Some(paginator_info.client_receiver.clone()), @@ -395,7 +420,9 @@ impl<'a> PaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: paginate_call.arguments.clone(), return_type: None, + expr: paginate_call.expr.clone(), // Use paginate call position (most specific) + file_path: paginate_call.file_path.clone(), start_position: paginate_call.start_position, end_position: paginate_call.end_position, // Use client receiver from get_paginator call @@ -430,7 +457,9 @@ impl<'a> PaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: chained_call.arguments.clone(), return_type: None, + expr: chained_call.expr.clone(), // Use chained call position + file_path: chained_call.file_path.clone(), start_position: chained_call.start_position, end_position: chained_call.end_position, // Use client receiver from chained call @@ -452,17 +481,21 @@ impl<'a> PaginatorExtractor<'a> { #[cfg(test)] mod tests { - use crate::extraction::ParameterValue; + use crate::{extraction::ParameterValue, SourceFile}; use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; use std::collections::HashMap; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Python.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = SourceFile::with_language( + PathBuf::new(), + source_code.to_string(), + crate::Language::Python, + ); + let ast_grep = Python.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file.clone()) } fn create_test_service_index() -> ServiceModelIndex { diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs index 161d576..d65b1de 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs @@ -38,11 +38,14 @@ use crate::extraction::python::boto3_resources_model::{ Boto3ResourcesModel, Boto3ResourcesRegistry, HasManySpec, OperationType, }; use crate::extraction::python::common::ArgumentExtractor; -use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; +use crate::extraction::{ + AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, +}; use crate::ServiceModelIndex; use ast_grep_language::Python; use convert_case::{Case, Casing}; use std::collections::{HashMap, HashSet}; +use std::path::{Path, PathBuf}; /// Position tracking for deduplication (Tier 3) #[derive(Debug, Clone, Hash, Eq, PartialEq)] @@ -71,6 +74,8 @@ struct ResourceMethodCallInfo { method_name: String, arguments: Vec, method_call_line: usize, + expr: String, + file_path: PathBuf, start_position: (usize, usize), end_position: (usize, usize), } @@ -124,7 +129,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// **Tier 3**: Constructor only → maximum conservation with constructor position as evidence pub(crate) fn extract_resource_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { // Step 1: Find all resource constructors using service-agnostic matching let constructors = self.find_resource_constructors(ast, &self.registry); @@ -414,6 +419,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, + expr: method_call.expr.clone(), + file_path: method_call.file_path.clone(), start_position: method_call.start_position, end_position: method_call.end_position, receiver: Some(method_call.resource_var.clone()), @@ -439,10 +446,13 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let mut synthetic_calls = Vec::new(); // Extract position from evidence source - let (start_pos, end_pos) = match evidence { - SyntheticEvidenceSource::UnmatchedMethod(ref method_call) => { - (method_call.start_position, method_call.end_position) - } + let (expr, file_path, start_pos, end_pos) = match evidence { + SyntheticEvidenceSource::UnmatchedMethod(ref method_call) => ( + method_call.expr.clone(), + method_call.file_path.clone(), + method_call.start_position, + method_call.end_position, + ), }; // Generate synthetic call for each action @@ -496,6 +506,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, + expr: expr.clone(), + file_path: file_path.clone(), start_position: start_pos, end_position: end_pos, receiver: Some(constructor.variable_name.clone()), // Use actual variable name from constructor @@ -509,10 +521,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Find all resource constructor calls in the AST using service-agnostic matching fn find_resource_constructors( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, registry: &Boto3ResourcesRegistry, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut constructors = Vec::new(); // Service-agnostic pattern: $VAR = $ANY.$RESOURCE_TYPE($$$ARGS) @@ -582,15 +594,17 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Find all method calls on potential resource objects fn find_resource_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut method_calls = Vec::new(); let method_call_pattern = "$RESULT = $RESOURCE_VAR.$METHOD($$$ARGS)"; for node_match in root.find_all(method_call_pattern) { - if let Some(method_call_info) = self.parse_resource_method_call(&node_match) { + if let Some(method_call_info) = + self.parse_resource_method_call(&node_match, &ast.source_file.path) + { method_calls.push(method_call_info); } } @@ -599,7 +613,9 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let simple_method_pattern = "$RESOURCE_VAR.$METHOD($$$ARGS)"; for node_match in root.find_all(simple_method_pattern) { - if let Some(method_call_info) = self.parse_simple_resource_method_call(&node_match) { + if let Some(method_call_info) = + self.parse_simple_resource_method_call(&node_match, &ast.source_file.path) + { method_calls.push(method_call_info); } } @@ -624,6 +640,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { fn parse_resource_method_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -646,6 +663,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { resource_var, method_name, arguments, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), method_call_line: start.line() + 1, start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), @@ -656,6 +675,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { fn parse_simple_resource_method_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -678,6 +698,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { resource_var, method_name, arguments, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), method_call_line: start.line() + 1, start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), @@ -804,6 +826,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: combined_parameters, return_type: None, + expr: method_call.expr.clone(), + file_path: method_call.file_path.clone(), start_position: method_call.start_position, end_position: method_call.end_position, receiver: Some(method_call.resource_var.clone()), @@ -817,10 +841,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Generates synthetic SdkMethodCall for the collection's operation at the access point fn find_and_generate_collection_synthetics( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, constructors: &[ResourceConstructorInfo], ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut synthetic_calls = Vec::new(); // Pattern: $VAR = $RESOURCE_VAR.$ATTR_NAME (with optional assignment) @@ -873,6 +897,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { if let Some(synthetic_call) = self.generate_synthetic_for_collection( constructor, has_many_spec, + node_match.text().to_string(), + &ast.source_file.path, (start.line() + 1, start.column(node) + 1), (end.line() + 1, end.column(node) + 1), ) { @@ -890,6 +916,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &self, constructor: &ResourceConstructorInfo, has_many_spec: &HasManySpec, + expr: String, + file_path: &Path, start_position: (usize, usize), end_position: (usize, usize), ) -> Option { @@ -925,6 +953,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, + expr: expr.clone(), + file_path: file_path.to_path_buf(), start_position, end_position, receiver: Some(constructor.variable_name.clone()), // Use actual variable name from constructor @@ -939,7 +969,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// all-synthetic parameters since we don't know the receiver. fn find_unmatched_utility_and_collection_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, matched_positions: &HashSet, ) -> Vec { let mut tier3_calls = Vec::new(); @@ -956,10 +986,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Find utility method calls with unknown receivers (Tier 3) fn find_unmatched_utility_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, matched_positions: &HashSet, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut calls = Vec::new(); // Pattern for method calls @@ -1017,6 +1047,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &operation.operation, &arguments, &operation.required_params, + node_match.text().to_string(), + &ast.source_file.path, (start.line() + 1, start.column(node) + 1), (node.end_pos().line() + 1, node.end_pos().column(node) + 1), &receiver_var, // Use actual receiver from code @@ -1035,6 +1067,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &operation.operation, &arguments, &operation.required_params, + node_match.text().to_string(), + &ast.source_file.path, (start.line() + 1, start.column(node) + 1), (node.end_pos().line() + 1, node.end_pos().column(node) + 1), &receiver_var, // Use actual receiver from code @@ -1052,10 +1086,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Find collection accesses with unknown receivers (Tier 3) fn find_unmatched_collection_accesses( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, matched_positions: &HashSet, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut calls = Vec::new(); // Patterns for attribute access (including chained method calls) @@ -1109,6 +1143,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &has_many_spec.identifier_params, ), return_type: None, + expr: node_match.text().to_string(), + file_path: ast.source_file.path.clone(), start_position: (start.line() + 1, start.column(node) + 1), end_position: ( node.end_pos().line() + 1, @@ -1134,6 +1170,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &service_has_many_spec.identifier_params, ), return_type: None, + expr: node_match.text().to_string(), + file_path: ast.source_file.path.clone(), start_position: (start.line() + 1, start.column(node) + 1), end_position: ( node.end_pos().line() + 1, @@ -1158,6 +1196,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { operation: &str, arguments: &[Parameter], required_params: &[String], + expr: String, + file_path: &Path, start_position: (usize, usize), end_position: (usize, usize), receiver_marker: &str, @@ -1194,6 +1234,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, + expr: expr.clone(), + file_path: file_path.to_path_buf(), start_position, end_position, receiver: Some(receiver_marker.to_string()), diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs index 3bd8fb8..d595932 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs @@ -4,8 +4,12 @@ //! two-phase operations: creating a waiter from a client, then calling wait() //! on the waiter with operation arguments. +use std::path::{Path, PathBuf}; + use crate::extraction::python::common::{ArgumentExtractor, ParameterFilter}; -use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; +use crate::extraction::{ + AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, +}; use crate::ServiceModelIndex; use ast_grep_language::Python; @@ -18,8 +22,14 @@ pub(crate) struct WaiterInfo { pub waiter_name: String, /// Client receiver variable name (e.g., "client", "ec2_client") pub client_receiver: String, - /// Line number where get_waiter was called - pub get_waiter_line: usize, + /// Matched expression + pub expr: String, + /// File where we found the waiter + pub file_path: PathBuf, + /// Start position where get_waiter was called + pub start_position: (usize, usize), + /// End position where get_waiter was called + pub end_position: (usize, usize), } // TODO: This should be refactored at a higher level, so this type can be removed. @@ -37,6 +47,38 @@ impl<'a> CallInfo<'a> { Self::Chained(waiter_call_info) => &waiter_call_info.waiter_name, } } + + fn expr(&self) -> &'a str { + match self { + CallInfo::None(waiter_info) => &waiter_info.expr, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.expr, + CallInfo::Chained(chained_waiter_call_info) => &chained_waiter_call_info.expr, + } + } + + fn file_path(&self) -> &'a PathBuf { + match self { + CallInfo::None(waiter_info) => &waiter_info.file_path, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.file_path, + CallInfo::Chained(chained_waiter_call_info) => &chained_waiter_call_info.file_path, + } + } + + fn start_position(&self) -> (usize, usize) { + match self { + CallInfo::None(waiter_info) => waiter_info.start_position, + CallInfo::Simple(_, wait_call_info) => wait_call_info.start_position, + CallInfo::Chained(chained_waiter_call_info) => chained_waiter_call_info.start_position, + } + } + + fn end_position(&self) -> (usize, usize) { + match self { + CallInfo::None(waiter_info) => waiter_info.end_position, + CallInfo::Simple(_, wait_call_info) => wait_call_info.end_position, + CallInfo::Chained(chained_waiter_call_info) => chained_waiter_call_info.end_position, + } + } } /// Information about a wait method call @@ -46,8 +88,10 @@ pub(crate) struct WaitCallInfo { pub waiter_var: String, /// Extracted arguments (including WaiterConfig) pub arguments: Vec, - /// Line number where wait was called - pub wait_line: usize, + /// Matched expression + pub expr: String, + /// File where we found the waiter + pub file_path: PathBuf, /// Start position of the wait call node pub start_position: (usize, usize), /// End position of the wait call node @@ -66,6 +110,10 @@ pub(crate) struct ChainedWaiterCallInfo { /// Line number where chained call was made #[allow(dead_code)] pub line: usize, + /// Matched expression + pub expr: String, + /// File we found the chained waiter call in + pub file_path: PathBuf, /// Start position of the chained call node pub start_position: (usize, usize), /// End position of the chained call node @@ -107,7 +155,7 @@ impl<'a> WaitersExtractor<'a> { /// 3. Unmatched wait: Ignored (no waiter context) pub(crate) fn extract_waiter_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { // Step 1: Find all get_waiter calls let waiters = self.find_get_waiter_calls(ast); @@ -150,18 +198,17 @@ impl<'a> WaitersExtractor<'a> { } /// Find all get_waiter calls in the AST - fn find_get_waiter_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_get_waiter_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut waiters = Vec::new(); // Pattern: $WAITER = $CLIENT.get_waiter($NAME $$$ARGS) let get_waiter_pattern = "$WAITER = $CLIENT.get_waiter($NAME $$$ARGS)"; for node_match in root.find_all(get_waiter_pattern) { - if let Some(waiter_info) = self.parse_get_waiter_call(&node_match) { + if let Some(waiter_info) = + self.parse_get_waiter_call(&node_match, &ast.source_file.path) + { waiters.push(waiter_info); } } @@ -170,11 +217,8 @@ impl<'a> WaitersExtractor<'a> { } /// Find all wait calls in the AST - fn find_wait_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_wait_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut wait_calls = Vec::new(); // Pattern: $WAITER.wait($$$ARGS) @@ -182,7 +226,7 @@ impl<'a> WaitersExtractor<'a> { let wait_pattern = "$WAITER.wait($$$ARGS)"; for node_match in root.find_all(wait_pattern) { - if let Some(wait_info) = self.parse_wait_call(&node_match) { + if let Some(wait_info) = self.parse_wait_call(&node_match, &ast.source_file.path) { wait_calls.push(wait_info); } } @@ -193,16 +237,18 @@ impl<'a> WaitersExtractor<'a> { /// Find all chained waiter calls in the AST fn find_chained_waiter_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut chained_calls = Vec::new(); // Pattern: $CLIENT.get_waiter($NAME $$$WAITER_ARGS).wait($$$WAIT_ARGS) let chained_pattern = "$CLIENT.get_waiter($NAME $$$WAITER_ARGS).wait($$$WAIT_ARGS)"; for node_match in root.find_all(chained_pattern) { - if let Some(chained_info) = self.parse_chained_waiter_call(&node_match) { + if let Some(chained_info) = + self.parse_chained_waiter_call(&node_match, &ast.source_file.path) + { chained_calls.push(chained_info); } } @@ -214,6 +260,7 @@ impl<'a> WaitersExtractor<'a> { fn parse_get_waiter_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -228,14 +275,18 @@ impl<'a> WaitersExtractor<'a> { let name_text = name_node.text(); let waiter_name = self.extract_quoted_string(&name_text)?; - // Get line number - let get_waiter_line = node_match.get_node().start_pos().line() + 1; + let node = node_match.get_node(); + let start = node.start_pos(); + let end = node.end_pos(); Some(WaiterInfo { variable_name, waiter_name, client_receiver, - get_waiter_line, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), + start_position: (start.line() + 1, start.column(node) + 1), + end_position: (end.line() + 1, end.column(node) + 1), }) } @@ -243,6 +294,7 @@ impl<'a> WaitersExtractor<'a> { fn parse_wait_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -261,7 +313,8 @@ impl<'a> WaitersExtractor<'a> { Some(WaitCallInfo { waiter_var, arguments, - wait_line: start.line() + 1, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), }) @@ -271,6 +324,7 @@ impl<'a> WaitersExtractor<'a> { fn parse_chained_waiter_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -296,6 +350,8 @@ impl<'a> WaitersExtractor<'a> { waiter_name, arguments, line: start.line() + 1, + expr: node_match.text().to_string(), + file_path: file_path.to_path_buf(), start_position: (start.line() + 1, start.column(node) + 1), end_position: (end.line() + 1, end.column(node) + 1), }) @@ -316,8 +372,8 @@ impl<'a> WaitersExtractor<'a> { for (idx, waiter) in waiters.iter().enumerate() { if waiter.variable_name == wait_call.waiter_var { // Only consider waiters that come before the wait call - if waiter.get_waiter_line < wait_call.wait_line { - let distance = wait_call.wait_line - waiter.get_waiter_line; + if waiter.start_position.0 < wait_call.start_position.0 { + let distance = wait_call.start_position.0 - waiter.start_position.0; if distance < best_distance { best_distance = distance; best_match = Some(waiter); @@ -346,36 +402,22 @@ impl<'a> WaitersExtractor<'a> { .get(wait_call.waiter_name()) { for service_method in service_defs { - let (parameters, start_position, end_position) = match wait_call { + let parameters = match wait_call { CallInfo::Simple(_, wait_call) => { // Filter out WaiterConfig from arguments - it's waiter-specific, not operation-specific - ( - ParameterFilter::filter_waiter_parameters(wait_call.arguments.clone()), - wait_call.start_position, - wait_call.end_position, - ) + ParameterFilter::filter_waiter_parameters(wait_call.arguments.clone()) } CallInfo::Chained(chained_wait_call) => { // Filter out WaiterConfig from arguments - it's waiter-specific, not operation-specific - ( - ParameterFilter::filter_waiter_parameters( - chained_wait_call.arguments.clone(), - ), - // Use wait call position (most specific) - chained_wait_call.start_position, - chained_wait_call.end_position, + ParameterFilter::filter_waiter_parameters( + chained_wait_call.arguments.clone(), ) } - CallInfo::None(waiter_info) => { - let fallback_start_pos = (waiter_info.get_waiter_line, 1); - let fallback_end_pos = (waiter_info.get_waiter_line, 1); - let parameters = self.get_required_parameters( - &service_method.service_name, - &service_method.operation_name, - self.service_index, - ); - (parameters, fallback_start_pos, fallback_end_pos) - } + CallInfo::None(_) => self.get_required_parameters( + &service_method.service_name, + &service_method.operation_name, + self.service_index, + ), }; // Create synthetic call with filtered wait() arguments @@ -385,8 +427,10 @@ impl<'a> WaitersExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position, - end_position, + expr: wait_call.expr().to_string(), + file_path: wait_call.file_path().clone(), + start_position: wait_call.start_position(), + end_position: wait_call.end_position(), // Use client receiver from get_waiter call receiver: receiver.clone(), }), @@ -471,16 +515,21 @@ impl<'a> WaitersExtractor<'a> { #[cfg(test)] mod tests { use crate::extraction::sdk_model::ServiceMethodRef; + use crate::SourceFile; use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; use std::collections::HashMap; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Python.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = SourceFile::with_language( + PathBuf::new(), + source_code.to_string(), + crate::Language::Python, + ); + let ast_grep = Python.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file.clone()) } fn create_test_service_index() -> ServiceModelIndex { @@ -638,7 +687,7 @@ waiter = ec2_client.get_waiter('instance_terminated') assert_eq!(waiters[0].variable_name, "waiter"); assert_eq!(waiters[0].waiter_name, "instance_terminated"); assert_eq!(waiters[0].client_receiver, "ec2_client"); - assert_eq!(waiters[0].get_waiter_line, 4); + assert_eq!(waiters[0].start_position.0, 4); } #[test] diff --git a/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs index 85d53f4..db27ac3 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs @@ -10,7 +10,8 @@ use std::collections::HashSet; use crate::extraction::extractor::{Extractor, ExtractorResult}; use crate::extraction::javascript::scanner::ASTScanner; use crate::extraction::javascript::shared::ExtractionUtils; -use crate::ServiceModelIndex; +use crate::extraction::AstWithSourceFile; +use crate::{ServiceModelIndex, SourceFile}; /// TypeScript extractor for AWS SDK method calls pub(crate) struct TypeScriptExtractor; @@ -30,9 +31,10 @@ impl Default for TypeScriptExtractor { #[async_trait] impl Extractor for TypeScriptExtractor { - async fn parse(&self, source_code: &str) -> ExtractorResult { + async fn parse(&self, source_file: &SourceFile) -> ExtractorResult { // Create AST once and reuse it - let ast = TypeScript.ast_grep(source_code); + let ast_grep = TypeScript.ast_grep(&source_file.content); + let ast = AstWithSourceFile::new(ast_grep, source_file.clone()); // Create scanner with the pre-built AST let mut scanner = ASTScanner::new(ast.clone(), TypeScript.into()); @@ -55,7 +57,7 @@ impl Extractor for TypeScriptExtractor { &scan_results, )); - // Return TypeScript variant with the same AST (no double construction) + // Return TypeScript variant with the same AST ExtractorResult::TypeScript(ast, method_calls) } @@ -207,6 +209,14 @@ mod tests { } } + fn create_source_file(source_code: &str) -> SourceFile { + SourceFile::with_language( + std::path::PathBuf::new(), + source_code.to_string(), + crate::Language::TypeScript, + ) + } + #[tokio::test] async fn test_parse_typescript_with_types() { let extractor = TypeScriptExtractor::new(); @@ -242,7 +252,7 @@ async function queryUsers(): Promise { } "#; - let result = extractor.parse(typescript_code).await; + let result = extractor.parse(&create_source_file(typescript_code)).await; // Verify TypeScript AST is returned match result { @@ -320,7 +330,7 @@ class MyS3Service implements S3Service { } "#; - let result = extractor.parse(typescript_code).await; + let result = extractor.parse(&create_source_file(typescript_code)).await; // Verify TypeScript extraction with generics match result { @@ -375,7 +385,7 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); "#; // Parse the code - let mut results = vec![extractor.parse(typescript_code).await]; + let mut results = vec![extractor.parse(&create_source_file(typescript_code)).await]; // Build service index with all services for testing let service_index = ServiceDiscovery::load_service_index(Language::TypeScript) @@ -447,7 +457,7 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); "#; // Parse the code - let mut results = vec![extractor.parse(typescript_code).await]; + let mut results = vec![extractor.parse(&create_source_file(typescript_code)).await]; // Load service index let service_index = ServiceDiscovery::load_service_index(Language::TypeScript) diff --git a/iam-policy-autopilot-policy-generation/src/lib.rs b/iam-policy-autopilot-policy-generation/src/lib.rs index 4a68c3f..2813fef 100644 --- a/iam-policy-autopilot-policy-generation/src/lib.rs +++ b/iam-policy-autopilot-policy-generation/src/lib.rs @@ -33,11 +33,10 @@ pub mod api; use std::fmt::Display; -pub use enrichment::Engine as EnrichmentEngine; +pub use enrichment::{Engine as EnrichmentEngine, Explanation, Location}; pub use extraction::{Engine as ExtractionEngine, ExtractedMethods, SdkMethodCall, SourceFile}; pub use policy_generation::{ - Effect, Engine as PolicyGenerationEngine, IamPolicy, MethodActionMapping, PolicyType, - PolicyWithMetadata, Statement, + Effect, Engine as PolicyGenerationEngine, IamPolicy, PolicyType, PolicyWithMetadata, Statement, }; // Re-export commonly used types for convenience diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs index e19ce8b..22139b7 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs @@ -6,8 +6,8 @@ use super::merge::{PolicyMerger, PolicyMergerConfig}; use super::utils::{ArnParser, ConditionValueProcessor}; -use super::{ActionMapping, IamPolicy, MethodActionMapping, Statement}; -use crate::enrichment::{Action, Condition, EnrichedSdkMethodCall}; +use super::{IamPolicy, PolicyGenerationResult, Statement}; +use crate::enrichment::{Action, Condition, EnrichedSdkMethodCall, Explanation, Reason}; use crate::errors::{ExtractorError, Result}; use crate::policy_generation::{PolicyType, PolicyWithMetadata}; @@ -48,7 +48,7 @@ impl<'a> Engine<'a> { /// Creates one IAM policy per EnrichedSdkMethodCall, with each Action becoming /// a separate statement within the policy. ARN patterns are processed to replace /// placeholder variables with actual values or wildcards. - pub fn generate_policies( + fn generate_individual_policies( &self, enriched_calls: &[EnrichedSdkMethodCall], ) -> Result> { @@ -221,7 +221,7 @@ impl<'a> Engine<'a> { /// /// # Errors /// Returns an error if policy merging fails - pub fn merge_policies( + pub(crate) fn merge_policies( &self, policies: &[PolicyWithMetadata], ) -> Result> { @@ -247,53 +247,76 @@ impl<'a> Engine<'a> { } } - /// Extract method to action mappings from enriched method calls + /// Generate IAM policies with explanations from enriched method calls /// - /// This method processes enriched method calls to create detailed mappings - /// between SDK method calls and their required IAM actions with associated resources. - /// It provides granular visibility into which SDK method calls require which - /// specific IAM actions and their associated resources. + /// This method generates policies and collects explanations for why each action + /// was added. Explanations are deduplicated to avoid redundant entries. /// /// # Arguments /// * `enriched_calls` - Slice of enriched method calls to process /// /// # Returns - /// A vector of method action mappings showing the relationship between - /// method calls and their required IAM actions + /// A PolicyGenerationResult containing policies and deduplicated explanations /// /// # Errors - /// Returns an error if resource processing fails for any action - pub fn extract_action_mappings( + /// Returns an error if policy generation fails + pub fn generate_policies( &self, enriched_calls: &[EnrichedSdkMethodCall], - ) -> Result> { - let mut mappings = Vec::new(); + ) -> Result { + let policies = self.generate_individual_policies(enriched_calls)?; - for enriched_call in enriched_calls { - let mut action_mappings = Vec::new(); + // Collect and deduplicate explanations + let explanations = self.group_explanations_by_action(enriched_calls); + + Ok(PolicyGenerationResult { + policies, + explanations: if explanations.is_empty() { + None + } else { + Some(explanations) + }, + }) + } - for action in &enriched_call.actions { - // Process resources to get ARN patterns using the existing method - let resources = self.process_action_resources(action)?; + /// Collect and deduplicate explanations from enriched method calls + /// + /// This method gathers all explanations from the enriched calls and groups + /// reasons by action name, removing duplicate reasons. + fn group_explanations_by_action( + &self, + enriched_calls: &[EnrichedSdkMethodCall], + ) -> Vec { + use std::collections::HashMap; - let action_mapping = ActionMapping { - action_name: action.name.clone(), - resources, - }; + // Group reasons by action name + let mut action_to_reasons: HashMap> = HashMap::new(); - action_mappings.push(action_mapping); + for enriched_call in enriched_calls { + for explanation in &enriched_call.explanations { + let reasons = action_to_reasons + .entry(explanation.action.clone()) + .or_default(); + + // Add all reasons from this explanation, deduplicating as we go + for reason in &explanation.reasons { + if !reasons.contains(reason) { + reasons.push(reason.clone()); + } + } } + } - let method_mapping = MethodActionMapping { - method_call: enriched_call.method_name.clone(), - service: enriched_call.service.clone(), - actions: action_mappings, - }; + // Convert back to Vec and sort + let mut result: Vec = action_to_reasons + .into_iter() + .map(|(action, reasons)| Explanation { action, reasons }) + .collect(); - mappings.push(method_mapping); - } + // Sort by action name for consistent output + result.sort_by(|a, b| a.action.cmp(&b.action)); - Ok(mappings) + result } } @@ -303,7 +326,7 @@ mod tests { use crate::SdkMethodCall; use super::super::Effect; - use crate::enrichment::{Action, EnrichedSdkMethodCall, Resource}; + use crate::enrichment::{Action, EnrichedSdkMethodCall, OperationView, Resource}; use crate::errors::ExtractorError; fn create_test_engine() -> Engine<'static> { @@ -337,12 +360,13 @@ mod tests { vec![], )], sdk_method_call: &sdk_call, + explanations: vec![], }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(policies.len(), 1); + let result = engine.generate_policies(&[enriched_call]).unwrap(); + assert_eq!(result.policies.len(), 1); - let policy = &policies[0].policy; + let policy = &result.policies[0].policy; assert_eq!(policy.version, "2012-10-17"); assert_eq!(policy.statements.len(), 1); @@ -384,12 +408,13 @@ mod tests { ), ], sdk_method_call: &sdk_call, + explanations: vec![], }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(policies.len(), 1); + let result = engine.generate_policies(&[enriched_call]).unwrap(); + assert_eq!(result.policies.len(), 1); - let policy = &policies[0].policy; + let policy = &result.policies[0].policy; assert_eq!(policy.statements.len(), 2); // Check first statement @@ -417,12 +442,13 @@ mod tests { vec![], )], sdk_method_call: &sdk_call, + explanations: vec![], }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(policies.len(), 1); + let result = engine.generate_policies(&[enriched_call]).unwrap(); + assert_eq!(result.policies.len(), 1); - let policy = &policies[0].policy; + let policy = &result.policies[0].policy; let statement = &policy.statements[0]; assert_eq!(statement.resource, vec!["*"]); } @@ -452,12 +478,13 @@ mod tests { ) ], sdk_method_call: &sdk_call, + explanations: vec![], }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(policies.len(), 1); + let result = engine.generate_policies(&[enriched_call]).unwrap(); + assert_eq!(result.policies.len(), 1); - let policy = &policies[0].policy; + let policy = &result.policies[0].policy; let statement = &policy.statements[0]; assert_eq!( statement.resource, @@ -478,6 +505,7 @@ mod tests { service: "s3".to_string(), actions: vec![], sdk_method_call: &sdk_call, + explanations: vec![], }; let action = Action::new("s3:GetObject".to_string(), vec![], vec![]); @@ -494,8 +522,8 @@ mod tests { #[test] fn test_generate_policies_empty_input() { let engine = create_test_engine(); - let policies = engine.generate_policies(&[]).unwrap(); - assert!(policies.is_empty()); + let result = engine.generate_policies(&[]).unwrap(); + assert!(result.policies.is_empty()); } #[test] @@ -508,6 +536,7 @@ mod tests { service: "s3".to_string(), actions: vec![], sdk_method_call: &sdk_call, + explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]); @@ -874,4 +903,250 @@ mod tests { crate::enrichment::Operator::StringEquals ); } + + #[test] + fn test_generate_policies_with_explanations() { + use crate::enrichment::{Explanation, Reason}; + + let engine = create_test_engine(); + let sdk_call = create_test_sdk_call(); + + // Create enriched call with explanations + let enriched_call = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + )], + sdk_method_call: &sdk_call, + explanations: vec![Explanation { + action: "s3:GetObject".to_string(), + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call, "s3"), + fas: None, + }], + }], + }; + + let result = engine.generate_policies(&[enriched_call]).unwrap(); + + // Verify policies were generated + assert_eq!(result.policies.len(), 1); + + // Verify explanations were collected + assert!(result.explanations.is_some()); + let explanations = result.explanations.unwrap(); + assert_eq!(explanations.len(), 1); + assert_eq!(explanations[0].action, "s3:GetObject"); + assert_eq!(explanations[0].reasons.len(), 1); + assert!(explanations[0].reasons[0].fas.is_none()); + } + + #[test] + fn test_explanation_deduplication() { + use crate::enrichment::{Explanation, Reason}; + + let engine = create_test_engine(); + let sdk_call1 = create_test_sdk_call(); + let sdk_call2 = SdkMethodCall { + name: "get_object".to_string(), + possible_services: vec!["s3".to_string()], + metadata: None, + }; + + // Create two enriched calls with duplicate explanations + let enriched_call1 = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + )], + sdk_method_call: &sdk_call1, + explanations: vec![Explanation { + action: "s3:GetObject".to_string(), + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call1, "s3"), + fas: None, + }], + }], + }; + + let enriched_call2 = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + )], + sdk_method_call: &sdk_call2, + explanations: vec![Explanation { + action: "s3:GetObject".to_string(), + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call2, "s3"), + fas: None, + }], + }], + }; + + let result = engine + .generate_policies(&[enriched_call1, enriched_call2]) + .unwrap(); + + // Verify explanations were grouped by action with deduplicated reasons + assert!(result.explanations.is_some()); + let explanations = result.explanations.unwrap(); + assert_eq!( + explanations.len(), + 1, + "Should have one action with grouped reasons" + ); + assert_eq!( + explanations[0].reasons.len(), + 1, + "Duplicate reasons should be deduplicated" + ); + } + + #[test] + fn test_explanation_with_fas_expansion() { + use crate::enrichment::{Explanation, Reason}; + + let engine = create_test_engine(); + let sdk_call = create_test_sdk_call(); + + // Create enriched call with FAS expansion + let enriched_call = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![ + Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + ), + Action::new( + "kms:Decrypt".to_string(), + vec![Resource::new( + "key".to_string(), + Some(vec![ + "arn:${Partition}:kms:${Region}:${Account}:key/${KeyId}".to_string(), + ]), + )], + vec![], + ), + ], + sdk_method_call: &sdk_call, + explanations: vec![ + Explanation { + action: "s3:GetObject".to_string(), + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call, "s3"), + fas: None, + }], + }, + Explanation { + action: "kms:Decrypt".to_string(), + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call, "s3"), + fas: Some(crate::enrichment::FasInfo::new(vec![ + "s3:GetObject".to_string(), + "kms:Decrypt".to_string(), + ])), + }], + }, + ], + }; + + let result = engine.generate_policies(&[enriched_call]).unwrap(); + + // Verify explanations include FAS expansion + assert!(result.explanations.is_some()); + let explanations = result.explanations.unwrap(); + assert_eq!(explanations.len(), 2); + + // Check the FAS-expanded action + let kms_explanation = explanations + .iter() + .find(|e| e.action == "kms:Decrypt") + .expect("Should have kms:Decrypt explanation"); + assert_eq!(kms_explanation.reasons.len(), 1); + assert!(kms_explanation.reasons[0].fas.is_some()); + let fas_info = kms_explanation.reasons[0].fas.as_ref().unwrap(); + assert_eq!(fas_info.expansion.len(), 2); + assert!(fas_info.expansion.contains(&"s3:GetObject".to_string())); + assert!(fas_info.expansion.contains(&"kms:Decrypt".to_string())); + // Verify the explanation URL is set + assert_eq!( + fas_info.explanation, + "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html" + ); + } + + #[test] + fn test_explanation_with_possible_false_positive() { + use crate::enrichment::{Explanation, Reason}; + + let engine = create_test_engine(); + let sdk_call = SdkMethodCall { + name: "get_object".to_string(), + possible_services: vec!["s3".to_string(), "s3-object-lambda".to_string()], + metadata: None, + }; + + // Create enriched call with multiple possible services (false positive flag) + let enriched_call = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + )], + sdk_method_call: &sdk_call, + explanations: vec![Explanation { + action: "s3:GetObject".to_string(), + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call, "s3"), + fas: None, + }], + }], + }; + + let result = engine.generate_policies(&[enriched_call]).unwrap(); + + // Verify possible_false_positive flag is set + assert!(result.explanations.is_some()); + let explanations = result.explanations.unwrap(); + assert_eq!(explanations.len(), 1); + } } diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs index db4214c..0c44638 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs @@ -51,14 +51,15 @@ mod tests { ), ], sdk_method_call: &sdk_call, + explanations: vec![], }; // Generate policies - let policies = engine.generate_policies(&[enriched_call]).unwrap(); + let result = engine.generate_policies(&[enriched_call]).unwrap(); // Verify results - assert_eq!(policies.len(), 1); - let policy = &policies[0].policy; + assert_eq!(result.policies.len(), 1); + let policy = &result.policies[0].policy; // Check policy structure assert_eq!(policy.version, "2012-10-17"); @@ -110,6 +111,7 @@ mod tests { vec![], )], sdk_method_call: &sdk_call1, + explanations: vec![], }, EnrichedSdkMethodCall { method_name: "put_object".to_string(), @@ -125,22 +127,23 @@ mod tests { vec![], )], sdk_method_call: &sdk_call2, + explanations: vec![], }, ]; - let policies = engine.generate_policies(&enriched_calls).unwrap(); + let result = engine.generate_policies(&enriched_calls).unwrap(); // Should generate one policy per enriched call - assert_eq!(policies.len(), 2); + assert_eq!(result.policies.len(), 2); // Check first policy - let policy1 = &policies[0].policy; + let policy1 = &result.policies[0].policy; assert_eq!(policy1.statements.len(), 1); assert_eq!(policy1.statements[0].action, vec!["s3:GetObject"]); assert_eq!(policy1.statements[0].resource, vec!["arn:aws:s3:::*/*"]); // Check second policy - let policy2 = &policies[1].policy; + let policy2 = &result.policies[1].policy; assert_eq!(policy2.statements.len(), 1); assert_eq!(policy2.statements[0].action, vec!["s3:PutObject"]); assert_eq!(policy2.statements[0].resource, vec!["arn:aws:s3:::*/*"]); @@ -176,10 +179,11 @@ mod tests { ) ], sdk_method_call: &sdk_call, + explanations: vec![], }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - let policy = &policies[0].policy; + let result = engine.generate_policies(&[enriched_call]).unwrap(); + let policy = &result.policies[0].policy; let statement = &policy.statements[0]; // Verify ARN patterns are correctly processed for China partition @@ -211,10 +215,11 @@ mod tests { vec![], )], sdk_method_call: &sdk_call, + explanations: vec![], }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - let policy = &policies[0]; + let result = engine.generate_policies(&[enriched_call]).unwrap(); + let policy = &result.policies[0]; // Test JSON serialization let json = serde_json::to_string_pretty(policy).unwrap(); @@ -246,6 +251,7 @@ mod tests { vec![], )], sdk_method_call: &sdk_call, + explanations: vec![], }; // Should fail due to empty placeholder @@ -273,10 +279,11 @@ mod tests { vec![], )], sdk_method_call: &sdk_call, + explanations: vec![], }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - let policy = &policies[0].policy; + let result = engine.generate_policies(&[enriched_call]).unwrap(); + let policy = &result.policies[0].policy; let statement = &policy.statements[0]; // Should fallback to wildcard resource diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs index 650b4a0..fc7a7ef 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs @@ -42,29 +42,6 @@ where condition_map.serialize(serializer) } -/// Represents a single IAM action with its associated resources -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "PascalCase")] -pub(crate) struct ActionMapping { - /// The IAM action name (e.g., "s3:GetObject") - pub(crate) action_name: String, - /// Resources this action applies to - pub(crate) resources: Vec, -} - -/// Represents the mapping between a method call and its required IAM actions -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "PascalCase")] -#[non_exhaustive] -pub struct MethodActionMapping { - /// The method call name - pub(crate) method_call: String, - /// The AWS service this method belongs to - pub(crate) service: String, - /// List of IAM actions required by this method - pub(crate) actions: Vec, -} - /// Represents a complete IAM policy document #[derive(Debug, Clone, Serialize, PartialEq, Eq)] #[non_exhaustive] @@ -132,6 +109,17 @@ pub struct PolicyWithMetadata { pub policy_type: PolicyType, } +/// Result of policy generation including policies and explanations +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PolicyGenerationResult { + /// Generated policies + pub policies: Vec, + /// Explanations for why actions were added + #[serde(skip_serializing_if = "Option::is_none")] + pub explanations: Option>, +} + impl IamPolicy { /// Create a new IAM policy with the standard version pub fn new() -> Self { @@ -283,43 +271,6 @@ mod tests { assert!(json.contains("\"Condition\":{\"StringLike\":{\"s3:ExistingObjectTag/Environment\":[\"production-*\",\"staging-*\"]}}")); } - #[test] - fn test_method_action_mapping_serialization() { - let action_mapping = ActionMapping { - action_name: "s3:GetObject".to_string(), - resources: vec!["arn:aws:s3:::*/*".to_string()], - }; - - let method_mapping = MethodActionMapping { - method_call: "get_object".to_string(), - service: "s3".to_string(), - actions: vec![action_mapping], - }; - - let json = serde_json::to_string(&method_mapping).unwrap(); - - // Verify PascalCase field names - assert!(json.contains("\"MethodCall\":\"get_object\"")); - assert!(json.contains("\"Service\":\"s3\"")); - assert!(json.contains("\"Actions\":")); - assert!(json.contains("\"ActionName\":\"s3:GetObject\"")); - assert!(json.contains("\"Resources\":")); - } - - #[test] - fn test_action_mapping_serialization() { - let action_mapping = ActionMapping { - action_name: "events:PutRule".to_string(), - resources: vec!["arn:aws:events:us-east-1:123456789012:rule/*".to_string()], - }; - - let json = serde_json::to_string(&action_mapping).unwrap(); - - // Verify PascalCase field names - assert!(json.contains("\"ActionName\":\"events:PutRule\"")); - assert!(json.contains("\"Resources\":[\"arn:aws:events:us-east-1:123456789012:rule/*\"]")); - } - #[test] fn test_policy_with_metadata_serialization() { let mut policy = IamPolicy::new(); diff --git a/iam-policy-autopilot-policy-generation/tests/go_extraction_integration_test.rs b/iam-policy-autopilot-policy-generation/tests/go_extraction_integration_test.rs index bdd2ab4..3fb94d1 100644 --- a/iam-policy-autopilot-policy-generation/tests/go_extraction_integration_test.rs +++ b/iam-policy-autopilot-policy-generation/tests/go_extraction_integration_test.rs @@ -109,59 +109,15 @@ async fn test_go_extraction_to_policy_generation_integration() { PolicyGenerationEngine::new("aws", "us-east-1", "123456789012"); match policy_engine.generate_policies(&enriched_calls) { - Ok(policies) => { + Ok(result) => { + let policies = result.policies; println!("Generated {} IAM policies:", policies.len()); // Verify policy generation worked assert!(!policies.is_empty(), "Should generate at least one policy"); - // Step 5: Test policy merging - println!("\nStep 5: Testing policy merging..."); - - if policies.len() > 1 { - match policy_engine.merge_policies(&policies) { - Ok(merged_policy) => { - println!( - "Successfully merged {} policies into one", - policies.len() - ); - - // Test JSON serialization of merged policy - match serde_json::to_string_pretty(&merged_policy) { - Ok(json) => { - println!( - "Merged policy JSON ({} bytes)", - json.len() - ); - - // Verify JSON structure - assert!( - json.contains("\"Version\""), - "JSON should contain Version field" - ); - assert!( - json.contains("\"Statement\""), - "JSON should contain Statement field" - ); - } - Err(e) => { - panic!( - "Failed to serialize merged policy to JSON: {}", - e - ); - } - } - } - Err(e) => { - panic!("Policy merging failed: {}", e); - } - } - } else { - println!("Only one policy generated, skipping merge test"); - } - - // Step 6: Test JSON serialization of individual policies - println!("\nStep 6: Testing JSON serialization..."); + // Step 5: Test JSON serialization of individual policies + println!("\nStep 5: Testing JSON serialization..."); for (i, policy) in policies.iter().enumerate() { match serde_json::to_string_pretty(policy) { diff --git a/iam-policy-autopilot-policy-generation/tests/public_api_integration_test.rs b/iam-policy-autopilot-policy-generation/tests/public_api_integration_test.rs index 33195f6..c791ee8 100644 --- a/iam-policy-autopilot-policy-generation/tests/public_api_integration_test.rs +++ b/iam-policy-autopilot-policy-generation/tests/public_api_integration_test.rs @@ -102,10 +102,12 @@ def download_object(): // 4. Test Policy Generation Engine (Public API) let policy_engine = PolicyGenerationEngine::new("aws", "us-east-1", "123456789012"); - let policies = policy_engine + let result = policy_engine .generate_policies(&enriched_methods) .expect("Policy generation should succeed"); + let policies = result.policies; + // Verify policy generation results using serialization assert!(!policies.is_empty(), "Should generate policies"); println!("Generated {} policies", policies.len()); @@ -117,29 +119,6 @@ def download_object(): assert!(policy_json.contains("Statement")); println!(" Policy {}: serialized successfully", i + 1); } - - // 5. Test policy merging - let merged_policy = policy_engine - .merge_policies(&policies) - .expect("Policy merging should succeed"); - - let merged_json = - serde_json::to_string(&merged_policy).expect("Should serialize merged policy"); - assert!(merged_json.contains("2012-10-17")); - println!("Merged policy serialized successfully"); - - // 6. Test method action mapping extraction - let action_mappings = policy_engine - .extract_action_mappings(&enriched_methods) - .expect("Action mapping extraction should succeed"); - - assert!(!action_mappings.is_empty(), "Should have action mappings"); - println!("Generated {} action mappings", action_mappings.len()); - - // Test serialization of action mappings - let mappings_json = - serde_json::to_string(&action_mappings).expect("Should serialize action mappings"); - assert!(!mappings_json.is_empty()); } #[tokio::test] @@ -289,10 +268,12 @@ def multi_service_operations(): let policy_engine = PolicyGenerationEngine::new("aws", "us-west-2", "987654321098"); - let policies = policy_engine + let result = policy_engine .generate_policies(&enriched) .expect("Should generate policies"); + let policies = result.policies; + // Verify we got policies for multi-service operations println!( "Generated {} policies for multi-service operations", @@ -445,7 +426,8 @@ def start_policy_generation(): let policy_engine = PolicyGenerationEngine::new("aws", "us-east-1", "123456789012"); let policies = policy_engine .generate_policies(&enriched) - .expect("Policy generation should succeed"); + .expect("Policy generation should succeed") + .policies; assert!(!policies.is_empty(), "Should generate policies"); From 33fae3b546016d85d68bef06fd9279b1a6a605a5 Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Fri, 19 Dec 2025 20:20:55 +0000 Subject: [PATCH 02/11] feat: refactor based on PR comments --- iam-policy-autopilot-cli/src/main.rs | 11 +- iam-policy-autopilot-cli/src/output.rs | 15 +- .../src/api/model.rs | 12 +- .../src/enrichment/mod.rs | 62 +++-- .../src/enrichment/resource_matcher.rs | 165 +++++------- .../src/policy_generation/engine.rs | 255 +++++++++--------- .../policy_generation/integration_tests.rs | 19 +- .../src/policy_generation/mod.rs | 11 - 8 files changed, 259 insertions(+), 291 deletions(-) diff --git a/iam-policy-autopilot-cli/src/main.rs b/iam-policy-autopilot-cli/src/main.rs index d24f4c1..07f0b6d 100644 --- a/iam-policy-autopilot-cli/src/main.rs +++ b/iam-policy-autopilot-cli/src/main.rs @@ -438,13 +438,10 @@ async fn handle_generate_policy(config: &GeneratePolicyCliConfig) -> Result<()> }) .await?; - let policies = result.policies; - let explanations = result.explanations; - if config.individual_policies { // Output individual policies - trace!("Outputting {} individual policies", policies.len()); - output::output_iam_policies(policies, explanations, None, config.shared.pretty) + trace!("Outputting {} individual policies", result.policies.len()); + output::output_iam_policies(result, None, config.shared.pretty) .context("Failed to output individual IAM policies")?; } else { // Default behavior: output merged policy with optional upload @@ -457,7 +454,7 @@ async fn handle_generate_policy(config: &GeneratePolicyCliConfig) -> Result<()> let custom_name = config.upload_policies.as_deref().filter(|s| !s.is_empty()); let batch_response = uploader - .upload_policies(&policies, custom_name) + .upload_policies(&result.policies, custom_name) .await .context("Failed to upload policies to AWS IAM")?; @@ -483,7 +480,7 @@ async fn handle_generate_policy(config: &GeneratePolicyCliConfig) -> Result<()> None }; - output::output_iam_policies(policies, explanations, upload_result, config.shared.pretty) + output::output_iam_policies(result, upload_result, config.shared.pretty) .context("Failed to output merged IAM policy")? } diff --git a/iam-policy-autopilot-cli/src/output.rs b/iam-policy-autopilot-cli/src/output.rs index ef4f776..959df50 100644 --- a/iam-policy-autopilot-cli/src/output.rs +++ b/iam-policy-autopilot-cli/src/output.rs @@ -1,6 +1,6 @@ use anyhow::{Context, Result}; use iam_policy_autopilot_access_denied::{DenialType, PlanResult}; -use iam_policy_autopilot_policy_generation::{policy_generation::PolicyWithMetadata, Explanation}; +use iam_policy_autopilot_policy_generation::api::model::GeneratePoliciesResult; use iam_policy_autopilot_tools::BatchUploadResponse; use log::debug; use std::io::{self, Write}; @@ -161,10 +161,9 @@ pub(crate) fn print_unsupported_denial(denial_type: &DenialType, reason: &str) { #[serde(rename_all = "PascalCase")] struct PolicyOutput { /// The generated policies with type information - policies: Vec, - /// Explanations for why actions were added - #[serde(skip_serializing_if = "Option::is_none")] - explanations: Option>, + /// and explanations for why actions were added + #[serde(flatten)] + result: GeneratePoliciesResult, /// Upload results (only present when --upload-policies is used) #[serde(skip_serializing_if = "Option::is_none")] upload_result: Option, @@ -172,8 +171,7 @@ struct PolicyOutput { /// Output IAM policies as JSON to stdout pub(crate) fn output_iam_policies( - policies: Vec, - explanations: Option>, + result: GeneratePoliciesResult, upload_result: Option, pretty: bool, ) -> Result<()> { @@ -183,8 +181,7 @@ pub(crate) fn output_iam_policies( ); let policy_output = PolicyOutput { - policies, - explanations, + result, upload_result, }; diff --git a/iam-policy-autopilot-policy-generation/src/api/model.rs b/iam-policy-autopilot-policy-generation/src/api/model.rs index b25199a..2dc4e75 100644 --- a/iam-policy-autopilot-policy-generation/src/api/model.rs +++ b/iam-policy-autopilot-policy-generation/src/api/model.rs @@ -1,8 +1,10 @@ //! Defined model for API +use serde::Serialize; + use crate::{enrichment::Explanation, policy_generation::PolicyWithMetadata}; -use std::path::PathBuf; +use std::{collections::BTreeMap, path::PathBuf}; -/// Configuration for generate_policies Api +/// Configuration for generate_policies API #[derive(Debug, Clone)] pub struct GeneratePolicyConfig { /// Config used to extract sdk calls for policy generation @@ -20,12 +22,14 @@ pub struct GeneratePolicyConfig { } /// Result of policy generation including policies, action mappings, and explanations -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "PascalCase")] pub struct GeneratePoliciesResult { /// Generated IAM policies pub policies: Vec, /// Explanations for why actions were added (if requested) - pub explanations: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub explanations: Option>, } /// Service hints for filtering SDK method calls diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index 78c740d..f53becf 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -8,7 +8,7 @@ //! that represent method calls enriched with IAM metadata from operation //! action maps and Service Definition Files. -use std::path::PathBuf; +use std::{collections::HashSet, path::PathBuf}; use crate::SdkMethodCall; use schemars::JsonSchema; @@ -80,6 +80,9 @@ pub(crate) use operation_fas_map::load_operation_fas_map; pub(crate) use resource_matcher::ResourceMatcher; pub(crate) use service_reference::RemoteServiceReferenceLoader as ServiceReferenceLoader; +const FAS_URL: &str = + "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html"; + /// Represents Forward Access Session (FAS) expansion information #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] #[serde(rename_all = "PascalCase")] @@ -95,7 +98,7 @@ impl FasInfo { #[must_use] pub fn new(expansion: Vec) -> Self { Self { - explanation: "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html".to_string(), + explanation: FAS_URL.to_string(), expansion, } } @@ -160,15 +163,25 @@ impl OperationView { } /// Represents an explanation for why an action was added to a policy -#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema, Default)] #[serde(rename_all = "PascalCase")] pub struct Explanation { - /// The action that was added (authorized action from service reference) - pub action: String, /// The reasons this action was added (can have multiple reasons for the same action) pub reasons: Vec, } +impl Explanation { + pub(crate) fn merge(&mut self, other: Explanation) { + let reasons_set = self.reasons.iter().cloned().collect::>(); + for new_reason in other.reasons { + if reasons_set.contains(&new_reason) { + continue; + } + self.reasons.push(new_reason); + } + } +} + /// Represents an enriched method call with actions that need permissions #[derive(Debug, Clone, Serialize, PartialEq)] #[non_exhaustive] @@ -181,8 +194,6 @@ pub struct EnrichedSdkMethodCall<'a> { pub(crate) actions: Vec, /// The initial SDK method call pub(crate) sdk_method_call: &'a SdkMethodCall, - /// Explanations for why each action was added - pub(crate) explanations: Vec, } #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash, JsonSchema)] @@ -216,7 +227,7 @@ pub(crate) trait Context { /// /// This structure combines OperationAction action data with Service Reference resource information to provide /// complete IAM policy metadata for a single action. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, PartialEq)] pub(crate) struct Action { /// The IAM action name (e.g., "s3:GetObject") pub(crate) name: String, @@ -224,18 +235,8 @@ pub(crate) struct Action { pub(crate) resources: Vec, /// List of conditions we are adding pub(crate) conditions: Vec, -} - -/// Represents a resource enriched with ARN pattern and metadata -/// -/// This structure combines OperationAction resource data with Service Reference ARN patterns and additional -/// metadata to provide complete resource information for IAM policies. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub(crate) struct Resource { - /// The resource type name (e.g., "bucket", "object", "*") - pub(crate) name: String, - /// ARN patterns from Service Reference data, if available - pub(crate) arn_patterns: Option>, + /// Optional explanation why this action has been added + pub(crate) explanation: Explanation, } impl Action { @@ -245,16 +246,35 @@ impl Action { /// * `name` - The IAM action name /// * `resources` - List of enriched resources /// * `conditions` - List of conditions + /// * `explanation` - Explanation why the action has been added #[must_use] - pub(crate) fn new(name: String, resources: Vec, conditions: Vec) -> Self { + pub(crate) fn new( + name: String, + resources: Vec, + conditions: Vec, + explanation: Explanation, + ) -> Self { Self { name, resources, conditions, + explanation, } } } +/// Represents a resource enriched with ARN pattern and metadata +/// +/// This structure combines OperationAction resource data with Service Reference ARN patterns and additional +/// metadata to provide complete resource information for IAM policies. +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub(crate) struct Resource { + /// The resource type name (e.g., "bucket", "object", "*") + pub(crate) name: String, + /// ARN patterns from Service Reference data, if available + pub(crate) arn_patterns: Option>, +} + impl Resource { /// Create a new enriched resource #[must_use] diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index 3d0508f..ea061e9 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -42,6 +42,21 @@ impl std::hash::Hash for OperationWithFasExpansion { } } +impl OperationWithFasExpansion { + // Build the FasInfo, which we eventually output, from the operation. + // None if there is no FAS expansion + fn to_fas_info(&self, service_cfg: &ServiceConfiguration) -> Option { + if self.fas_chain.is_empty() { + None + } else { + // Build full chain: fas_chain + current operation + let mut chain = self.fas_chain.clone(); + chain.push(self.operation.service_operation_name(service_cfg)); + Some(FasInfo::new(chain)) + } + } +} + /// ResourceMatcher coordinates OperationAction maps and Service Reference data to generate enriched method calls /// /// This struct provides the core functionality for the 3-stage enrichment pipeline, @@ -127,14 +142,14 @@ impl ResourceMatcher { &self, initial: FasOperation, ) -> Result> { - let mut operations = HashSet::::new(); + let mut processed_operations = HashSet::::new(); // Initial operation has empty FAS chain let initial = OperationWithFasExpansion { operation: initial.clone(), fas_chain: vec![], }; - operations.insert(initial.clone()); + processed_operations.insert(initial.clone()); let mut to_process = vec![initial]; @@ -144,10 +159,7 @@ impl ResourceMatcher { // Process all operations in the current batch for current in &to_process { let service_name = current.operation.service(&self.service_cfg); - let operation_fas_map_option = - self.find_operation_fas_map_for_service(&service_name); - - match operation_fas_map_option { + match self.find_operation_fas_map_for_service(&service_name) { Some(operation_fas_map) => { let service_operation_name = current.operation.service_operation_name(&self.service_cfg); @@ -158,21 +170,21 @@ impl ResourceMatcher { .get(&service_operation_name) { for additional_op in additional_operations { - // Build new FAS hain: current chain + current operation + // Build new FAS chain: current chain + current operation let mut new_chain = current.fas_chain.clone(); new_chain.push( current.operation.service_operation_name(&self.service_cfg), ); - let new_op_with_prov = OperationWithFasExpansion { + let new_op = OperationWithFasExpansion { operation: additional_op.clone(), fas_chain: new_chain, }; // Only add if we haven't seen this operation before - if !operations.contains(&new_op_with_prov) { - operations.insert(new_op_with_prov.clone()); - newly_discovered.push(new_op_with_prov); + if !processed_operations.contains(&new_op) { + processed_operations.insert(new_op.clone()); + newly_discovered.push(new_op); } } } else { @@ -198,11 +210,12 @@ impl ResourceMatcher { log::debug!( "FAS expansion completed with {} total operations", - operations.len() + processed_operations.len() ); // Convert HashSet to Vec and sort by service_operation_name for deterministic output - let mut operations_vec: Vec = operations.into_iter().collect(); + let mut operations_vec: Vec = + processed_operations.into_iter().collect(); operations_vec.sort_by_key(|op| op.operation.service_operation_name(&self.service_cfg)); Ok(operations_vec) @@ -271,11 +284,10 @@ impl ResourceMatcher { // Use fixed-point algorithm to safely expand FAS operations until no new operations are found let operations = self.expand_fas_operations_to_fixed_point(initial)?; - let mut explanations = vec![]; let mut enriched_actions = vec![]; - for op_with_fas_chain in operations { - let service_name = op_with_fas_chain.operation.service(&self.service_cfg); + for op in operations { + let service_name = op.operation.service(&self.service_cfg); // Find the corresponding SDF using the cache let service_reference = service_reference_loader.load(&service_name).await?; @@ -284,25 +296,20 @@ impl ResourceMatcher { continue; } Some(service_reference) => { - log::debug!("Creating actions for {:?}", op_with_fas_chain.operation); - log::debug!(" with context {:?}", op_with_fas_chain.operation.context); - log::debug!(" FAS chain: {:?}", op_with_fas_chain.fas_chain); + log::debug!("Creating actions for {:?}", op.operation); + log::debug!(" with context {:?}", op.operation.context); + log::debug!(" FAS chain: {:?}", op.fas_chain); if let Some(operation_to_authorized_actions) = &service_reference.operation_to_authorized_actions { log::debug!( "Looking up {}", - &op_with_fas_chain - .operation - .service_operation_name(&self.service_cfg) + &op.operation.service_operation_name(&self.service_cfg) ); if let Some(operation_to_authorized_action) = - operation_to_authorized_actions.get( - &op_with_fas_chain - .operation - .service_operation_name(&self.service_cfg), - ) + operation_to_authorized_actions + .get(&op.operation.service_operation_name(&self.service_cfg)) { for action in &operation_to_authorized_action.authorized_actions { let enriched_resources = self @@ -318,8 +325,7 @@ impl ResourceMatcher { }; // Combine conditions from FAS operation context and AuthorizedAction context - let mut conditions = - Self::make_condition(&op_with_fas_chain.operation.context); + let mut conditions = Self::make_condition(&op.operation.context); // Add conditions from AuthorizedAction context if present if let Some(auth_context) = &action.context { @@ -328,93 +334,46 @@ impl ResourceMatcher { ))); } - let enriched_action = Action::new( - action.name.clone(), - enriched_resources, - conditions, - ); - - enriched_actions.push(enriched_action); - - // Include FAS chain only if chain is non-empty - let fas_info = if op_with_fas_chain.fas_chain.is_empty() { - log::debug!( - " Action '{}': Excluding FAS chain (initial operation)", - action.name - ); - None - } else { - // Build full chain: fas_chain + current operation - let mut chain = op_with_fas_chain.fas_chain.clone(); - chain.push( - op_with_fas_chain - .operation - .service_operation_name(&self.service_cfg), - ); - log::debug!( - " Action '{}': Including FAS chain: {:?}", - action.name, - chain - ); - Some(FasInfo::new(chain)) - }; - // Create explanation for this action let explanation = Explanation { - action: action.name.clone(), reasons: vec![Reason { operation: OperationView::from_call( parsed_call, original_service_name, ), - fas: fas_info, + fas: op.to_fas_info(&self.service_cfg), }], }; - explanations.push(explanation); + let enriched_action = Action::new( + action.name.clone(), + enriched_resources, + conditions, + explanation, + ); + + enriched_actions.push(enriched_action); } } else { // Fallback: operation not found in operation action map, create basic action // This ensures we don't filter out operations, only ADD additional ones from the map - if let Some(a) = - self.create_fallback_action(&parsed_call.name, &service_reference)? - { - let action_name = a.name.clone(); + if let Some(a) = self.create_fallback_action( + &parsed_call.name, + &service_reference, + parsed_call, + original_service_name, + )? { enriched_actions.push(a); - - // Create explanation for fallback action - let explanation = Explanation { - action: action_name, - reasons: vec![Reason { - operation: OperationView::from_call( - parsed_call, - original_service_name, - ), - fas: None, - }], - }; - explanations.push(explanation); } } } else { // Fallback: operation action map does not exist, create basic action - if let Some(a) = - self.create_fallback_action(&parsed_call.name, &service_reference)? - { - let action_name = a.name.clone(); + if let Some(a) = self.create_fallback_action( + &parsed_call.name, + &service_reference, + parsed_call, + original_service_name, + )? { enriched_actions.push(a); - - // Create explanation for fallback action - let explanation = Explanation { - action: action_name, - reasons: vec![Reason { - operation: OperationView::from_call( - parsed_call, - original_service_name, - ), - fas: None, - }], - }; - explanations.push(explanation); } } } @@ -430,7 +389,6 @@ impl ResourceMatcher { service: original_service_name.to_string(), actions: enriched_actions, sdk_method_call: parsed_call, - explanations, })) } @@ -442,6 +400,8 @@ impl ResourceMatcher { &self, method_name: &str, service_reference: &ServiceReference, + parsed_call: &SdkMethodCall, + original_service_name: &str, ) -> Result> { let renamed_service = self .service_cfg @@ -461,10 +421,19 @@ impl ResourceMatcher { let resources = self.find_resources_for_action_in_service_reference(&action_name, service_reference)?; + // Create explanation for fallback action + let explanation = Explanation { + reasons: vec![Reason { + operation: OperationView::from_call(parsed_call, original_service_name), + fas: None, + }], + }; + Ok(Some(Action::new( action_name.to_string(), resources, vec![], + explanation, ))) } diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs index 22139b7..b439068 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs @@ -4,12 +4,16 @@ //! The engine processes EnrichedSdkMethodCall instances and creates corresponding IAM policies //! with proper ARN pattern replacement. +use std::collections::BTreeMap; + use super::merge::{PolicyMerger, PolicyMergerConfig}; use super::utils::{ArnParser, ConditionValueProcessor}; -use super::{IamPolicy, PolicyGenerationResult, Statement}; -use crate::enrichment::{Action, Condition, EnrichedSdkMethodCall, Explanation, Reason}; +use super::{IamPolicy, Statement}; +use crate::api::model::GeneratePoliciesResult; +use crate::enrichment::{Action, Condition, EnrichedSdkMethodCall}; use crate::errors::{ExtractorError, Result}; use crate::policy_generation::{PolicyType, PolicyWithMetadata}; +use crate::Explanation; /// Policy generation engine that converts enriched method calls into IAM policies #[derive(Debug, Clone)] @@ -263,67 +267,43 @@ impl<'a> Engine<'a> { pub fn generate_policies( &self, enriched_calls: &[EnrichedSdkMethodCall], - ) -> Result { + ) -> Result { let policies = self.generate_individual_policies(enriched_calls)?; - // Collect and deduplicate explanations - let explanations = self.group_explanations_by_action(enriched_calls); + // Collect explanations + let explanations = extract_explanations(enriched_calls); - Ok(PolicyGenerationResult { + Ok(GeneratePoliciesResult { policies, - explanations: if explanations.is_empty() { - None - } else { - Some(explanations) - }, + explanations: Some(explanations), }) } +} - /// Collect and deduplicate explanations from enriched method calls - /// - /// This method gathers all explanations from the enriched calls and groups - /// reasons by action name, removing duplicate reasons. - fn group_explanations_by_action( - &self, - enriched_calls: &[EnrichedSdkMethodCall], - ) -> Vec { - use std::collections::HashMap; - - // Group reasons by action name - let mut action_to_reasons: HashMap> = HashMap::new(); - - for enriched_call in enriched_calls { - for explanation in &enriched_call.explanations { - let reasons = action_to_reasons - .entry(explanation.action.clone()) - .or_default(); - - // Add all reasons from this explanation, deduplicating as we go - for reason in &explanation.reasons { - if !reasons.contains(reason) { - reasons.push(reason.clone()); - } - } - } +fn extract_explanations( + enriched_calls: &[EnrichedSdkMethodCall<'_>], +) -> BTreeMap { + let mut explanations: BTreeMap = BTreeMap::new(); + + // Collect and merge explanations for each action name + for call in enriched_calls { + for action in &call.actions { + explanations + .entry(action.name.clone()) + .and_modify(|existing_explanation| { + existing_explanation.merge(action.explanation.clone()); + }) + .or_insert_with(|| action.explanation.clone()); } - - // Convert back to Vec and sort - let mut result: Vec = action_to_reasons - .into_iter() - .map(|(action, reasons)| Explanation { action, reasons }) - .collect(); - - // Sort by action name for consistent output - result.sort_by(|a, b| a.action.cmp(&b.action)); - - result } + + explanations } #[cfg(test)] mod tests { use super::*; - use crate::SdkMethodCall; + use crate::{Explanation, SdkMethodCall}; use super::super::Effect; use crate::enrichment::{Action, EnrichedSdkMethodCall, OperationView, Resource}; @@ -358,9 +338,9 @@ mod tests { ]), )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, - explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); @@ -395,6 +375,7 @@ mod tests { ]), )], vec![], + Explanation::default(), ), Action::new( "s3:GetObjectVersion".to_string(), @@ -405,10 +386,10 @@ mod tests { ]), )], vec![], + Explanation::default(), ), ], sdk_method_call: &sdk_call, - explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); @@ -440,9 +421,9 @@ mod tests { "s3:ListAllMyBuckets".to_string(), vec![Resource::new("*".to_string(), None)], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, - explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); @@ -475,10 +456,10 @@ mod tests { ) ], vec![], + Explanation::default(), ) ], sdk_method_call: &sdk_call, - explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); @@ -505,10 +486,14 @@ mod tests { service: "s3".to_string(), actions: vec![], sdk_method_call: &sdk_call, - explanations: vec![], }; - let action = Action::new("s3:GetObject".to_string(), vec![], vec![]); + let action = Action::new( + "s3:GetObject".to_string(), + vec![], + vec![], + Explanation::default(), + ); // Test first action (index 0) let sid1 = engine.generate_statement_id(&enriched_call, &action, 0); @@ -536,7 +521,6 @@ mod tests { service: "s3".to_string(), actions: vec![], sdk_method_call: &sdk_call, - explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]); @@ -617,6 +601,7 @@ mod tests { ]), )], vec![], + Explanation::default(), ); let processed_resources = engine.process_action_resources(&action).unwrap(); @@ -650,6 +635,7 @@ mod tests { ), ], vec![], + Explanation::default(), ); let processed_resources = engine.process_action_resources(&action).unwrap(); @@ -681,6 +667,7 @@ mod tests { values: vec!["${region}".to_string(), "us-west-${unknown}".to_string()], }, ], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -711,6 +698,7 @@ mod tests { key: "s3:ExistingObjectTag/Environment".to_string(), values: vec!["production".to_string(), "staging".to_string()], }], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -727,7 +715,12 @@ mod tests { let engine = create_test_engine(); // Create an action with no conditions - let action = Action::new("s3:GetObject".to_string(), vec![], vec![]); + let action = Action::new( + "s3:GetObject".to_string(), + vec![], + vec![], + Explanation::default(), + ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -747,6 +740,7 @@ mod tests { key: "s3:ExistingObjectTag/Environment".to_string(), values: vec!["s3.${unknown}.amazonaws.com".to_string()], }], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -780,6 +774,7 @@ mod tests { values: vec!["s3.${unknown}.amazonaws.com".to_string()], // Unknown placeholder, introduces wildcards }, ], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -818,6 +813,7 @@ mod tests { "us-west-${unknown}".to_string(), // Unknown placeholder, introduces wildcards ], }], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -844,6 +840,7 @@ mod tests { key: "s3:ExistingObjectTag/Environment".to_string(), values: vec!["arn:${partition}:s3:${region}:${account}:bucket/test".to_string()], }], + Explanation::default(), ); let processed_conditions = engine_wildcard_partition @@ -924,15 +921,14 @@ mod tests { ]), )], vec![], + Explanation { + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call, "s3"), + fas: None, + }], + }, )], sdk_method_call: &sdk_call, - explanations: vec![Explanation { - action: "s3:GetObject".to_string(), - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call, "s3"), - fas: None, - }], - }], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); @@ -941,12 +937,16 @@ mod tests { assert_eq!(result.policies.len(), 1); // Verify explanations were collected - assert!(result.explanations.is_some()); - let explanations = result.explanations.unwrap(); - assert_eq!(explanations.len(), 1); - assert_eq!(explanations[0].action, "s3:GetObject"); - assert_eq!(explanations[0].reasons.len(), 1); - assert!(explanations[0].reasons[0].fas.is_none()); + if let Some(explanation) = result + .explanations + .as_ref() + .and_then(|explanations| explanations.get("s3:GetObject")) + { + assert_eq!(explanation.reasons.len(), 1); + assert!(explanation.reasons[0].fas.is_none()); + } else { + panic!("Must have an explanation for s3:GetObject"); + } } #[test] @@ -974,15 +974,14 @@ mod tests { ]), )], vec![], + Explanation { + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call1, "s3"), + fas: None, + }], + }, )], sdk_method_call: &sdk_call1, - explanations: vec![Explanation { - action: "s3:GetObject".to_string(), - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call1, "s3"), - fas: None, - }], - }], }; let enriched_call2 = EnrichedSdkMethodCall { @@ -997,15 +996,14 @@ mod tests { ]), )], vec![], + Explanation { + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call1, "s3"), + fas: None, + }], + }, )], sdk_method_call: &sdk_call2, - explanations: vec![Explanation { - action: "s3:GetObject".to_string(), - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call2, "s3"), - fas: None, - }], - }], }; let result = engine @@ -1013,18 +1011,19 @@ mod tests { .unwrap(); // Verify explanations were grouped by action with deduplicated reasons - assert!(result.explanations.is_some()); - let explanations = result.explanations.unwrap(); - assert_eq!( - explanations.len(), - 1, - "Should have one action with grouped reasons" - ); - assert_eq!( - explanations[0].reasons.len(), - 1, - "Duplicate reasons should be deduplicated" - ); + if let Some(explanation) = result + .explanations + .as_ref() + .and_then(|explanations| explanations.get("s3:GetObject")) + { + assert_eq!( + explanation.reasons.len(), + 1, + "Duplicate reasons should be deduplicated" + ); + } else { + panic!("Must have an explanation for s3:GetObject"); + } } #[test] @@ -1048,6 +1047,12 @@ mod tests { ]), )], vec![], + Explanation { + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call, "s3"), + fas: None, + }], + }, ), Action::new( "kms:Decrypt".to_string(), @@ -1058,41 +1063,31 @@ mod tests { ]), )], vec![], + Explanation { + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call, "s3"), + fas: Some(crate::enrichment::FasInfo::new(vec![ + "s3:GetObject".to_string(), + "kms:Decrypt".to_string(), + ])), + }], + }, ), ], sdk_method_call: &sdk_call, - explanations: vec![ - Explanation { - action: "s3:GetObject".to_string(), - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call, "s3"), - fas: None, - }], - }, - Explanation { - action: "kms:Decrypt".to_string(), - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call, "s3"), - fas: Some(crate::enrichment::FasInfo::new(vec![ - "s3:GetObject".to_string(), - "kms:Decrypt".to_string(), - ])), - }], - }, - ], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); // Verify explanations include FAS expansion - assert!(result.explanations.is_some()); - let explanations = result.explanations.unwrap(); - assert_eq!(explanations.len(), 2); + assert_eq!(result.explanations.as_ref().unwrap().len(), 2); // Check the FAS-expanded action - let kms_explanation = explanations - .iter() - .find(|e| e.action == "kms:Decrypt") + let kms_explanation = result + .explanations + .as_ref() + .unwrap() + .get("kms:Decrypt") .expect("Should have kms:Decrypt explanation"); assert_eq!(kms_explanation.reasons.len(), 1); assert!(kms_explanation.reasons[0].fas.is_some()); @@ -1131,22 +1126,18 @@ mod tests { ]), )], vec![], + Explanation { + reasons: vec![Reason { + operation: OperationView::from_call(&sdk_call, "s3"), + fas: None, + }], + }, )], sdk_method_call: &sdk_call, - explanations: vec![Explanation { - action: "s3:GetObject".to_string(), - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call, "s3"), - fas: None, - }], - }], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); - // Verify possible_false_positive flag is set - assert!(result.explanations.is_some()); - let explanations = result.explanations.unwrap(); - assert_eq!(explanations.len(), 1); + assert_eq!(result.explanations.as_ref().unwrap().len(), 1); } } diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs index 0c44638..9c8ea2b 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs @@ -8,7 +8,7 @@ mod tests { use super::super::{Effect, Engine}; use crate::enrichment::{Action, EnrichedSdkMethodCall, Resource}; use crate::errors::ExtractorError; - use crate::SdkMethodCall; + use crate::{Explanation, SdkMethodCall}; fn create_test_sdk_call() -> SdkMethodCall { SdkMethodCall { @@ -38,6 +38,7 @@ mod tests { ]), )], vec![], + Explanation::default(), ), Action::new( "s3:GetObjectVersion".to_string(), @@ -48,10 +49,10 @@ mod tests { ]), )], vec![], + Explanation::default(), ), ], sdk_method_call: &sdk_call, - explanations: vec![], }; // Generate policies @@ -109,9 +110,9 @@ mod tests { ]), )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call1, - explanations: vec![], }, EnrichedSdkMethodCall { method_name: "put_object".to_string(), @@ -125,9 +126,9 @@ mod tests { ]), )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call2, - explanations: vec![], }, ]; @@ -175,11 +176,11 @@ mod tests { ]) ) ], - vec![] + vec![], + Explanation::default(), ) ], sdk_method_call: &sdk_call, - explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); @@ -213,9 +214,9 @@ mod tests { ]), )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, - explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); @@ -249,9 +250,9 @@ mod tests { ]), // Invalid empty placeholder )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, - explanations: vec![], }; // Should fail due to empty placeholder @@ -277,9 +278,9 @@ mod tests { "s3:ListAllMyBuckets".to_string(), vec![Resource::new("*".to_string(), None)], // No ARN patterns vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, - explanations: vec![], }; let result = engine.generate_policies(&[enriched_call]).unwrap(); diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs index fc7a7ef..e8ae478 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs @@ -109,17 +109,6 @@ pub struct PolicyWithMetadata { pub policy_type: PolicyType, } -/// Result of policy generation including policies and explanations -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct PolicyGenerationResult { - /// Generated policies - pub policies: Vec, - /// Explanations for why actions were added - #[serde(skip_serializing_if = "Option::is_none")] - pub explanations: Option>, -} - impl IamPolicy { /// Create a new IAM policy with the standard version pub fn new() -> Self { From 07e2bb9547d6578ce7279d68ae6d65450f754952 Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Wed, 7 Jan 2026 12:08:19 +0000 Subject: [PATCH 03/11] refactor: refactor Location --- Cargo.toml | 1 + .../Cargo.toml | 1 + .../src/enrichment/mod.rs | 229 +++++-------- .../src/enrichment/resource_matcher.rs | 21 +- .../src/extraction/go/disambiguation.rs | 41 +-- .../src/extraction/go/extractor.rs | 11 +- .../src/extraction/go/features_extractor.rs | 29 +- .../src/extraction/go/paginator_extractor.rs | 52 +-- .../src/extraction/go/waiter_extractor.rs | 86 ++--- .../src/extraction/javascript/scanner.rs | 304 ++++++++---------- .../src/extraction/javascript/shared.rs | 40 +-- .../src/extraction/javascript/types.rs | 32 +- .../src/extraction/mod.rs | 41 ++- .../src/extraction/python/disambiguation.rs | 25 +- .../extraction/python/disambiguation_tests.rs | 41 +-- .../src/extraction/python/extractor.rs | 14 +- .../extraction/python/paginator_extractor.rs | 89 ++--- .../python/resource_direct_calls_extractor.rs | 179 ++++------- .../extraction/python/waiters_extractor.rs | 105 ++---- .../src/lib.rs | 219 ++++++++++++- .../src/policy_generation/engine.rs | 36 +-- 21 files changed, 702 insertions(+), 894 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 1ad87c4..78fca62 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -29,6 +29,7 @@ ast-grep-core = "0.39" schemars = { version = "^1", features = ["derive"] } rust-embed = { version = "8.9", features = ["compression", "include-exclude"] } reqwest = { version = "0.12.4", features = ["rustls-tls"], default-features = false } +derive-new = "0.7.0" # Native async runtime and parallel processing tokio = { version = "1.0", features = ["fs", "rt", "rt-multi-thread", "macros", "signal"] } diff --git a/iam-policy-autopilot-policy-generation/Cargo.toml b/iam-policy-autopilot-policy-generation/Cargo.toml index 4a9e5ca..1f1abb4 100644 --- a/iam-policy-autopilot-policy-generation/Cargo.toml +++ b/iam-policy-autopilot-policy-generation/Cargo.toml @@ -27,6 +27,7 @@ serde_json.workspace = true tokio.workspace = true async-trait.workspace = true strsim.workspace = true +derive-new.workspace = true # Build dependencies [build-dependencies] diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index f53becf..07d1b24 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -8,68 +8,12 @@ //! that represent method calls enriched with IAM metadata from operation //! action maps and Service Definition Files. -use std::{collections::HashSet, path::PathBuf}; +use std::collections::HashSet; -use crate::SdkMethodCall; +use crate::{extraction::SdkMethodCallMetadata, SdkMethodCall}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; -/// Represents a location in a source file -/// -/// This struct stores file path and position information and serializes -/// to the GNU coding standard (https://www.gnu.org/prep/standards/html_node/Errors.html) -/// format: `filename:startLine.startCol-endLine.endCol` -#[derive(Debug, Clone, PartialEq, Eq, Hash, JsonSchema)] -#[schemars( - description = "File location in GNU coding standard format: filename:startLine.startCol-endLine.endCol" -)] -pub struct Location { - /// File path - pub file_path: PathBuf, - /// Starting position (line, column) - both 1-based - pub start_position: (usize, usize), - /// Ending position (line, column) - both 1-based - pub end_position: (usize, usize), -} - -impl Location { - /// Create a new Location - #[must_use] - pub fn new( - file_path: PathBuf, - start_position: (usize, usize), - end_position: (usize, usize), - ) -> Self { - Self { - file_path, - start_position, - end_position, - } - } - - /// Format as GNU coding standard: `filename:startLine.startCol-endLine.endCol` - #[must_use] - pub fn to_gnu_format(&self) -> String { - let path_str = self.file_path.display(); - let (start_line, start_col) = self.start_position; - let (end_line, end_col) = self.end_position; - - format!( - "{}:{}.{}-{}.{}", - path_str, start_line, start_col, end_line, end_col - ) - } -} - -impl Serialize for Location { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&self.to_gnu_format()) - } -} - pub(crate) mod engine; pub(crate) mod operation_fas_map; pub(crate) mod resource_matcher; @@ -88,7 +32,7 @@ const FAS_URL: &str = #[serde(rename_all = "PascalCase")] pub struct FasInfo { /// Explanation URL for Forward Access Sessions - pub explanation: String, + pub explanation: &'static str, /// The chain of operations in the FAS expansion pub expansion: Vec, } @@ -98,7 +42,7 @@ impl FasInfo { #[must_use] pub fn new(expansion: Vec) -> Self { Self { - explanation: FAS_URL.to_string(), + explanation: FAS_URL, expansion, } } @@ -109,58 +53,77 @@ impl FasInfo { #[serde(rename_all = "PascalCase")] pub struct Reason { /// The original operation that was extracted - pub operation: OperationView, - /// FAS (Forward Access Sessions) expansion information if this action came from FAS expansion - #[serde(rename = "FAS", skip_serializing_if = "Option::is_none")] + pub initial_operation: Operation, + /// Source of the operation + pub source: OperationSource, + /// Optional FAS expansion information pub fas: Option, } +impl Reason { + pub(crate) fn new( + call: &SdkMethodCall, + original_service_name: &str, + fas: Option, + ) -> Self { + let initial_operation = + Operation::new(call.name.clone(), original_service_name.to_string()); + match &call.metadata { + None => Self { + initial_operation, + source: OperationSource::Provided, + fas, + }, + Some(metadata) => Self { + initial_operation, + source: OperationSource::Extracted(metadata.clone()), + fas, + }, + } + } +} + +#[derive(derive_new::new, Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub struct Operation { + /// Name of the operation + pub name: String, + /// Name of the service + pub service: String, +} + #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] #[serde(rename_all = "PascalCase")] #[serde(untagged)] -pub enum OperationView { +pub enum OperationSource { /// Operation extracted from source files #[serde(rename_all = "PascalCase")] - Extracted { - /// Extracted name - name: String, - /// Detected service containing this operation - service: String, - /// Extracted expr - expr: String, - /// Location in source file - location: Location, - }, + Extracted(SdkMethodCallMetadata), /// Operation provided (no metadata available) #[serde(rename_all = "PascalCase")] - Provided { - /// Provided name - name: String, - /// Provided service - service: String, - }, + Provided, } -impl OperationView { - pub(crate) fn from_call(call: &SdkMethodCall, service: &str) -> Self { - match &call.metadata { - None => Self::Provided { - name: call.name.clone(), - service: service.to_string(), - }, - Some(metadata) => Self::Extracted { - name: call.name.clone(), - service: service.to_string(), - expr: metadata.expr.clone(), - location: Location::new( - metadata.file_path.clone(), - metadata.start_position, - metadata.end_position, - ), - }, - } - } -} +// impl OperationSource { +// pub(crate) fn from_call(call: &SdkMethodCall, service: &str) -> Self { +// match &call.metadata { +// None => Self::Provided { +// name: call.name.clone(), +// service: service.to_string(), +// }, +// Some(metadata) => Self::Extracted { +// name: call.name.clone(), +// service: service.to_string(), +// expr: metadata.expr.clone(), +// location: Location::new( +// metadata.file_path.clone(), +// metadata.start_position, +// metadata.end_position, +// ), +// }, +// } +// } +// } /// Represents an explanation for why an action was added to a policy #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema, Default)] @@ -535,6 +498,7 @@ pub(crate) mod mock_remote_service_reference { #[cfg(test)] mod location_tests { use super::*; + use crate::Location; use std::path::PathBuf; #[test] @@ -567,64 +531,45 @@ mod location_tests { assert_eq!(json, "\"example.go:100.1-105.50\""); } - #[test] - fn test_operation_view_extracted_with_location() { - use crate::extraction::{SdkMethodCall, SdkMethodCallMetadata}; - - let call = SdkMethodCall { + fn mock_sdk_method_call() -> SdkMethodCall { + SdkMethodCall { name: "get_object".to_string(), possible_services: vec!["s3".to_string()], metadata: Some(SdkMethodCallMetadata { parameters: vec![], return_type: None, expr: "s3.get_object(Bucket='my-bucket')".to_string(), - file_path: PathBuf::from("test.py"), - start_position: (10, 5), - end_position: (10, 79), + location: Location::new(PathBuf::from("test.py"), (10, 5), (10, 79)), receiver: Some("s3".to_string()), }), - }; - - let operation_view = OperationView::from_call(&call, "s3"); - - match operation_view { - OperationView::Extracted { - name, - service, - expr, - location, - } => { - assert_eq!(name, "get_object"); - assert_eq!(service, "s3"); - assert_eq!(expr, "s3.get_object(Bucket='my-bucket')"); - assert_eq!(location.to_gnu_format(), "test.py:10.5-10.79"); + } + } + + #[test] + fn test_reason_extracted_with_location() { + let call = mock_sdk_method_call(); + + let reason = Reason::new(&call, "s3", None); + + match reason.source { + OperationSource::Extracted(metadata) => { + assert_eq!(reason.initial_operation.name, "get_object"); + assert_eq!(reason.initial_operation.service, "s3"); + assert_eq!(metadata.expr, "s3.get_object(Bucket='my-bucket')"); + assert_eq!(metadata.location.to_gnu_format(), "test.py:10.5-10.79"); } _ => panic!("Expected Extracted variant"), } } #[test] - fn test_operation_view_extracted_serialization() { - use crate::extraction::{SdkMethodCall, SdkMethodCallMetadata}; - - let call = SdkMethodCall { - name: "list_buckets".to_string(), - possible_services: vec!["s3".to_string()], - metadata: Some(SdkMethodCallMetadata { - parameters: vec![], - return_type: None, - expr: "s3.list_buckets()".to_string(), - file_path: PathBuf::from("app.py"), - start_position: (5, 1), - end_position: (5, 20), - receiver: Some("s3".to_string()), - }), - }; + fn test_reason_extracted_serialization() { + let call = mock_sdk_method_call(); - let operation_view = OperationView::from_call(&call, "s3"); - let json = serde_json::to_string(&operation_view).unwrap(); + let reason = Reason::new(&call, "s3", None); + let json = serde_json::to_string(&reason).unwrap(); // Verify the location is serialized as a string in GNU format - assert!(json.contains("\"Location\":\"app.py:5.1-5.20\"")); + assert!(json.contains("\"Location\":\"test.py:10.5-10.79\"")); } } diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index ea061e9..13b7eb9 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -8,9 +8,7 @@ use convert_case::{Case, Casing}; use std::collections::HashSet; use std::sync::Arc; -use super::{ - Action, Context, EnrichedSdkMethodCall, Explanation, FasInfo, OperationView, Reason, Resource, -}; +use super::{Action, Context, EnrichedSdkMethodCall, Explanation, FasInfo, Reason, Resource}; use crate::enrichment::operation_fas_map::{FasOperation, OperationFasMap, OperationFasMaps}; use crate::enrichment::service_reference::ServiceReference; use crate::enrichment::{Condition, ServiceReferenceLoader}; @@ -336,13 +334,11 @@ impl ResourceMatcher { // Create explanation for this action let explanation = Explanation { - reasons: vec![Reason { - operation: OperationView::from_call( - parsed_call, - original_service_name, - ), - fas: op.to_fas_info(&self.service_cfg), - }], + reasons: vec![Reason::new( + parsed_call, + original_service_name, + op.to_fas_info(&self.service_cfg), + )], }; let enriched_action = Action::new( action.name.clone(), @@ -423,10 +419,7 @@ impl ResourceMatcher { // Create explanation for fallback action let explanation = Explanation { - reasons: vec![Reason { - operation: OperationView::from_call(parsed_call, original_service_name), - fas: None, - }], + reasons: vec![Reason::new(parsed_call, original_service_name, None)], }; Ok(Some(Action::new( diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs index b205890..c0ebca2 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs @@ -334,6 +334,7 @@ mod tests { Shape, ShapeReference, }; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; + use crate::Location; use std::collections::HashMap; use std::path::PathBuf; @@ -607,9 +608,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -659,9 +658,7 @@ mod tests { struct_fields: None, }], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("client".to_string()), }), }; @@ -709,9 +706,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -754,9 +749,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -805,9 +798,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -844,10 +835,8 @@ mod tests { }, ], expr: "GetObject".to_string(), - file_path: PathBuf::new(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("client".to_string()), }), }; @@ -884,10 +873,8 @@ mod tests { }, ], expr: "GetObject".to_string(), - file_path: PathBuf::new(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("client".to_string()), }), }; @@ -923,10 +910,8 @@ mod tests { }, ], expr: "GetObject".to_string(), - file_path: PathBuf::new(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("client".to_string()), }), }; @@ -967,10 +952,8 @@ mod tests { }, ], expr: "GetObject".to_string(), - file_path: PathBuf::new(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("client".to_string()), }), }; @@ -1026,10 +1009,8 @@ mod tests { }, ], expr: "CreateQueue".to_string(), - file_path: PathBuf::new(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("sqsClient".to_string()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs index fa6de92..39c4a75 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs @@ -10,7 +10,7 @@ use crate::extraction::go::waiter_extractor::GoWaiterExtractor; use crate::extraction::{ AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, }; -use crate::{ServiceModelIndex, SourceFile}; +use crate::{Location, ServiceModelIndex, SourceFile}; use ast_grep_config::from_yaml_string; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; @@ -189,11 +189,6 @@ rule: vec![] }; - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - let method_call = SdkMethodCall { name: method_name.to_string(), possible_services: Vec::new(), // Will be determined later during service validation @@ -201,9 +196,7 @@ rule: parameters: arguments, return_type: None, // We don't know the return type from the call site expr: node_match.text().to_string(), - file_path: source_file.path.clone(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(source_file.path.clone(), node_match.get_node()), receiver, }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs index 1c49ffb..a7928e4 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs @@ -3,12 +3,11 @@ //! This module handles extraction of Go AWS SDK v2 feature methods like S3 Upload/Download, //! and other specialized SDK features. -use std::path::PathBuf; - use crate::extraction::go::features::{FeatureMethod, GoSdkV2Features}; use crate::extraction::go::types::GoImportInfo; use crate::extraction::go::utils; use crate::extraction::{AstWithSourceFile, SdkMethodCall, SdkMethodCallMetadata}; +use crate::Location; use ast_grep_config::from_yaml_string; use ast_grep_language::Go; @@ -21,14 +20,10 @@ pub(crate) struct FeatureCallInfo { pub(crate) receiver: Option, /// Extracted arguments pub(crate) arguments: Vec, - /// File where we found the feature call - pub(crate) file_path: PathBuf, /// Matched expression pub(crate) expr: String, - /// Start position of the call node - pub(crate) start_position: (usize, usize), - /// End position of the call node - pub(crate) end_position: (usize, usize), + /// Location of the call + pub(crate) location: Location, } /// Extractor for Go AWS SDK v2 feature methods @@ -112,19 +107,15 @@ rule: let args_nodes = env.get_multiple_matches("ARGS"); let arguments = utils::extract_arguments(&args_nodes); - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - calls.push(FeatureCallInfo { method_name, receiver: Some(receiver), arguments, expr: node_match.text().to_string(), - file_path: ast.source_file.path.clone(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node( + ast.source_file.path.clone(), + node_match.get_node(), + ), }); } } @@ -230,9 +221,7 @@ rule: parameters: parameters.clone(), return_type: None, expr: call_info.expr.clone(), - file_path: call_info.file_path.clone(), - start_position: call_info.start_position, - end_position: call_info.end_position, + location: call_info.location.clone(), receiver: call_info.receiver.clone(), }), } @@ -243,6 +232,8 @@ rule: #[cfg(test)] mod tests { + use std::path::PathBuf; + use crate::{Language, SourceFile}; use super::*; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs index 7f418e9..60145ff 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs @@ -3,13 +3,13 @@ //! This module handles extraction of Go AWS SDK v2 paginator patterns by detecting //! paginator creation calls, which contain the meaningful parameters for IAM policy generation. -use std::path::{Path, PathBuf}; +use std::path::Path; use crate::extraction::go::utils; use crate::extraction::sdk_model::ServiceDiscovery; use crate::extraction::{AstWithSourceFile, Parameter, SdkMethodCall, SdkMethodCallMetadata}; -use crate::Language; use crate::ServiceModelIndex; +use crate::{Language, Location}; use ast_grep_language::Go; /// Information about a discovered paginator creation call @@ -27,12 +27,8 @@ pub(crate) struct PaginatorInfo { pub creation_arguments: Vec, /// Matched expression pub expr: String, - /// File where we found the paginator - pub file_path: PathBuf, - /// Start position where paginator was created - pub start_position: (usize, usize), - /// End position where paginator was created - pub end_position: (usize, usize), + /// Location of the paginator creation + pub location: Location, } /// Information about a chained paginator call @@ -46,12 +42,8 @@ pub(crate) struct ChainedPaginatorCallInfo { pub arguments: Vec, /// Matched expression pub expr: String, - /// File where the chained paginator call was found - pub file_path: PathBuf, - /// Start position of the chained call node - pub start_position: (usize, usize), - /// End position of the chained call node - pub end_position: (usize, usize), + /// Location of the paginator was called + pub location: Location, } /// Extractor for Go AWS SDK paginator patterns @@ -178,24 +170,13 @@ impl<'a> GoPaginatorExtractor<'a> { .and_then(|s| s.strip_suffix("Paginator")); if let Some(operation_name) = operation_name { - let start_position = { - let pos = node_match.get_node().start_pos(); - (pos.line() + 1, pos.column(node_match.get_node())) - }; - let end_position = { - let pos = node_match.get_node().end_pos(); - (pos.line() + 1, pos.column(node_match.get_node())) - }; - return Some(PaginatorInfo { variable_name, paginator_type: operation_name.to_string(), client_receiver, creation_arguments, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position, - end_position, + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }); } @@ -239,19 +220,12 @@ impl<'a> GoPaginatorExtractor<'a> { Vec::new() }; - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ChainedPaginatorCallInfo { paginator_type, client_receiver, arguments: creation_arguments, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -281,9 +255,7 @@ impl<'a> GoPaginatorExtractor<'a> { parameters: paginator_info.creation_arguments.clone(), return_type: None, expr: paginator_info.expr.clone(), - file_path: paginator_info.file_path.clone(), - start_position: paginator_info.start_position, - end_position: paginator_info.end_position, + location: paginator_info.location.clone(), receiver: Some(paginator_info.client_receiver.clone()), }), } @@ -318,9 +290,7 @@ impl<'a> GoPaginatorExtractor<'a> { parameters: chained_call.arguments.clone(), return_type: None, expr: chained_call.expr.clone(), - file_path: chained_call.file_path.clone(), - start_position: chained_call.start_position, - end_position: chained_call.end_position, + location: chained_call.location.clone(), receiver: Some(chained_call.client_receiver.clone()), }), } @@ -334,7 +304,7 @@ mod tests { use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; - use std::collections::HashMap; + use std::{collections::HashMap, path::PathBuf}; fn create_test_ast(source_code: &str) -> AstWithSourceFile { let source_file = diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs index a438cc1..a2dce8b 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs @@ -3,13 +3,13 @@ //! This module handles extraction of Go AWS SDK waiter patterns, which involve //! creating a waiter from a client, then calling Wait() on the waiter. -use std::path::{Path, PathBuf}; +use std::path::Path; use crate::extraction::go::utils; use crate::extraction::{ AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, }; -use crate::ServiceModelIndex; +use crate::{Location, ServiceModelIndex}; use ast_grep_language::Go; /// Information about a discovered waiter creation call @@ -23,12 +23,14 @@ pub(crate) struct WaiterInfo { pub client_receiver: String, /// Matched expression pub expr: String, - /// File where the waiter was found - pub file_path: PathBuf, - /// Line number where waiter was created - pub start_position: (usize, usize), - /// Line number where waiter was created - pub end_position: (usize, usize), + /// Location where the waiter was created + pub location: Location, +} + +impl WaiterInfo { + pub(crate) fn start_line(&self) -> usize { + self.location.start_line() + } } /// Information about a Wait method call @@ -40,12 +42,14 @@ pub(crate) struct WaitCallInfo { pub arguments: Vec, /// Matched expression pub expr: String, - /// File where the wait call was found - pub file_path: PathBuf, - /// Start position of the Wait call node - pub start_position: (usize, usize), - /// End position of the Wait call node - pub end_position: (usize, usize), + /// Location where the waiter was called + pub location: Location, +} + +impl WaitCallInfo { + pub(crate) fn start_line(&self) -> usize { + self.location.start_line() + } } // TODO: This should be refactored at a higher level, so this type can be removed. @@ -75,24 +79,10 @@ impl<'a> CallInfo<'a> { } } - fn file_path(&self) -> &'a PathBuf { - match self { - CallInfo::None(waiter_info) => &waiter_info.file_path, - CallInfo::Simple(_, wait_call_info) => &wait_call_info.file_path, - } - } - - fn start_position(&self) -> (usize, usize) { - match self { - CallInfo::None(waiter_info) => waiter_info.start_position, - CallInfo::Simple(_, wait_call_info) => wait_call_info.start_position, - } - } - - fn end_position(&self) -> (usize, usize) { + fn location(&self) -> &'a Location { match self { - CallInfo::None(waiter_info) => waiter_info.end_position, - CallInfo::Simple(_, wait_call_info) => wait_call_info.end_position, + CallInfo::None(waiter_info) => &waiter_info.location, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.location, } } } @@ -219,27 +209,12 @@ impl<'a> GoWaiterExtractor<'a> { .and_then(|s| s.strip_suffix("Waiter")); if let Some(waiter_name) = waiter_name { - let start_position = { - let pos = node_match.get_node().start_pos(); - let line = pos.line() + 1; - let col = pos.column(node_match.get_node()) + 1; - (line, col) - }; - let end_position = { - let pos = node_match.get_node().end_pos(); - let line = pos.line() + 1; - let col = pos.column(node_match.get_node()) + 1; - (line, col) - }; - return Some(WaiterInfo { variable_name, waiter_name: waiter_name.to_string(), client_receiver, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position, - end_position, + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }); } @@ -261,18 +236,11 @@ impl<'a> GoWaiterExtractor<'a> { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = utils::extract_arguments(&args_nodes); - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(WaitCallInfo { waiter_var, arguments, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -290,8 +258,8 @@ impl<'a> GoWaiterExtractor<'a> { for (idx, waiter) in waiters.iter().enumerate() { if waiter.variable_name == wait_call.waiter_var { // Only consider waiters that come before the wait call - if waiter.start_position.0 < wait_call.start_position.0 { - let distance = wait_call.start_position.0 - waiter.start_position.0; + if waiter.start_line() < wait_call.start_line() { + let distance = wait_call.start_line() - waiter.start_line(); if distance < best_distance { best_distance = distance; best_match = Some(waiter); @@ -337,9 +305,7 @@ impl<'a> GoWaiterExtractor<'a> { parameters, return_type: None, expr: call.expr().to_string(), - file_path: call.file_path().clone(), - start_position: call.start_position(), - end_position: call.end_position(), + location: call.location().clone(), receiver: Some(call.waiter_info().client_receiver.clone()), }), }); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs index 81180b1..47f044a 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs @@ -6,48 +6,14 @@ use crate::extraction::javascript::types::{ ValidClientTypes, }; use crate::extraction::AstWithSourceFile; +use crate::Location; use ast_grep_core::matcher::Pattern; -use ast_grep_core::tree_sitter::StrDoc; -use ast_grep_core::Doc; use ast_grep_core::{tree_sitter, MatchStrictness, NodeMatch}; +use ast_grep_core::{Doc, Node}; use std::collections::HashMap; -fn parse_import_item_with_span( - import_item: &str, - start_position: (usize, usize), - end_position: (usize, usize), -) -> Option { - let import_item = import_item.trim(); - if import_item.is_empty() { - return None; - } - - // Check for rename syntax: "OriginalName as LocalName" - if let Some(as_pos) = import_item.find(" as ") { - let original_name = import_item[..as_pos].trim().to_string(); - let local_name = import_item[as_pos + 4..].trim().to_string(); - Some(ImportInfo::new( - original_name, - local_name, - import_item, - start_position, - end_position, - )) - } else { - // No rename - original name is the same as local name - let import_name = import_item.trim().to_string(); - Some(ImportInfo::new( - import_name.clone(), - import_name, - import_item, - start_position, - end_position, - )) - } -} - fn parse_object_literal(obj_text: &str) -> HashMap { let mut result = HashMap::new(); @@ -134,35 +100,6 @@ fn parse_key_value_pair(pair: &str, result: &mut HashMap) { } } -fn parse_and_add_imports_with_span( - imports_text: &str, - sublibrary_info: &mut SublibraryInfo, - start_position: (usize, usize), - end_position: (usize, usize), -) { - // Handle different import formats - if imports_text.starts_with('{') && imports_text.ends_with('}') { - // Destructuring - parse with rename support - let imports_content = &imports_text[1..imports_text.len() - 1]; // Remove braces - - // Split by comma and parse each import - for import_item in imports_content.split(',') { - if let Some(import_info) = - parse_import_item_with_span(import_item, start_position, end_position) - { - sublibrary_info.add_import(import_info); - } - } - } else { - // Default import - single identifier - if let Some(import_info) = - parse_import_item_with_span(imports_text, start_position, end_position) - { - sublibrary_info.add_import(import_info); - } - } -} - /// Core AST scanner for JavaScript/TypeScript AWS SDK usage patterns pub(crate) struct ASTScanner where @@ -185,6 +122,63 @@ where Self { ast_grep, language } } + fn parse_and_add_imports( + &self, + imports_text: &str, + sublibrary_info: &mut SublibraryInfo, + node: &Node<'_, tree_sitter::StrDoc>, + ) { + // Handle different import formats + if imports_text.starts_with('{') && imports_text.ends_with('}') { + // Destructuring - parse with rename support + let imports_content = &imports_text[1..imports_text.len() - 1]; // Remove braces + + // Split by comma and parse each import + for import_item in imports_content.split(',') { + if let Some(import_info) = self.parse_import_item(import_item, node) { + sublibrary_info.add_import(import_info); + } + } + } else { + // Default import - single identifier + if let Some(import_info) = self.parse_import_item(imports_text, node) { + sublibrary_info.add_import(import_info); + } + } + } + + fn parse_import_item( + &self, + import_item: &str, + node: &Node<'_, tree_sitter::StrDoc>, + ) -> Option { + let import_item = import_item.trim(); + if import_item.is_empty() { + return None; + } + + // Check for rename syntax: "OriginalName as LocalName" + if let Some(as_pos) = import_item.find(" as ") { + let original_name = import_item[..as_pos].trim().to_string(); + let local_name = import_item[as_pos + 4..].trim().to_string(); + Some(ImportInfo::new( + original_name, + local_name, + import_item, + Location::from_node(self.ast_grep.source_file.path.to_path_buf(), node), + )) + } else { + // No rename - original name is the same as local name + let import_name = import_item.trim().to_string(); + Some(ImportInfo::new( + import_name.clone(), + import_name, + import_item, + Location::from_node(self.ast_grep.source_file.path.to_path_buf(), node), + )) + } + } + /// Execute a pattern match against the AST using relaxed strictness to handle inline comments fn find_all_matches( &self, @@ -199,28 +193,6 @@ where Ok(root.find_all(pattern_obj).collect()) } - /// Extract 1-based (line, column) position from the first match - fn get_first_match_span( - matches: &[ast_grep_core::NodeMatch>], - ) -> Option<((usize, usize), (usize, usize))> { - matches.first().map(|first_match| { - let node = first_match.get_node(); - let start_pos = { - let pos = node.start_pos(); - let line = pos.line() + 1; - let column = pos.column(node) + 1; - (line, column) - }; - let end_pos = { - let pos = node.end_pos(); - let line = pos.line() + 1; - let column = pos.column(node) + 1; - (line, column) - }; - (start_pos, end_pos) - }) - } - /// Find Command instantiation and extract its arguments /// Returns CommandInstantiationResult with position and parameters pub(crate) fn find_command_instantiation_with_args( @@ -232,9 +204,9 @@ where let pattern = format!("new {}($ARGS)", command_name); if let Ok(matches) = self.find_all_matches(&pattern) { - if let Some(span) = Self::get_first_match_span(&matches) { - let first_match = matches.first().unwrap(); - + if let Some(first_match) = matches.first() { + let location = + Location::from_node(self.ast_grep.source_file.path.to_path_buf(), first_match); let env = first_match.get_env(); // Extract arguments from the ARGS node @@ -242,12 +214,7 @@ where let args_node = env.get_match("ARGS"); let parameters = ArgumentExtractor::extract_object_parameters(args_node); - return Some(CommandUsage::new( - first_match.text(), - span.0, - span.1, - parameters, - )); + return Some(CommandUsage::new(first_match.text(), location, parameters)); } } None @@ -264,20 +231,16 @@ where let pattern = format!("{}($ARG1, $ARG2)", function_name); if let Ok(matches) = self.find_all_matches(&pattern) { - if let Some(span) = Self::get_first_match_span(&matches) { - let first_match = matches.first().unwrap(); + if let Some(first_match) = matches.first() { + let location = + Location::from_node(self.ast_grep.source_file.path.to_path_buf(), first_match); let env = first_match.get_env(); // Extract parameters from second argument (ARG2 = operation params) let second_arg = env.get_match("ARG2"); let parameters = ArgumentExtractor::extract_object_parameters(second_arg); - return Some(CommandUsage::new( - first_match.text(), - span.0, - span.1, - parameters, - )); + return Some(CommandUsage::new(first_match.text(), location, parameters)); } } None @@ -298,20 +261,18 @@ where for pattern in &patterns { if let Ok(matches) = self.find_all_matches(pattern) { - if let Some(span) = Self::get_first_match_span(&matches) { - let first_match = matches.first().unwrap(); + if let Some(first_match) = matches.first() { + let location = Location::from_node( + self.ast_grep.source_file.path.to_path_buf(), + first_match, + ); let env = first_match.get_env(); // Extract parameters from second argument (ARG2 = operation params) let second_arg = env.get_match("ARG2"); let parameters = ArgumentExtractor::extract_object_parameters(second_arg); - return Some(CommandUsage::new( - first_match.text(), - span.0, - span.1, - parameters, - )); + return Some(CommandUsage::new(first_match.text(), location, parameters)); } } } @@ -332,11 +293,15 @@ where for pattern in &patterns { if let Ok(matches) = self.find_all_matches(pattern) { - if let Some(span) = Self::get_first_match_span(&matches) { + if let Some(first_match) = matches.first() { + let location = Location::from_node( + self.ast_grep.source_file.path.to_path_buf(), + first_match, + ); let expr_text = matches.first().unwrap().text(); // TODO: Extract from variable assignments let parameters = vec![]; - return Some(CommandUsage::new(expr_text, span.0, span.1, parameters)); + return Some(CommandUsage::new(expr_text, location, parameters)); } } } @@ -348,19 +313,17 @@ where let mut sublibrary_data: HashMap = HashMap::new(); let matches = self.find_all_matches(pattern)?; - Self::process_import_matches(matches, &mut sublibrary_data)?; + self.process_import_matches(matches, &mut sublibrary_data)?; Ok(sublibrary_data.into_values().collect()) } /// Generic processing for import/require matches - works for both JavaScript and TypeScript - fn process_import_matches( - matches: Vec>, + fn process_import_matches( + &self, + matches: Vec>>, sublibrary_data: &mut HashMap, - ) -> Result<(), String> - where - U: Doc + std::clone::Clone, - { + ) -> Result<(), String> { for node_match in matches { let env = node_match.get_env(); @@ -385,23 +348,10 @@ where .entry(sublibrary.clone()) .or_insert_with(|| SublibraryInfo::new(sublibrary)); - let start_position = { - let pos = node_match.get_node().start_pos(); - let line = pos.line() + 1; - let column = pos.column(imports_node) + 1; - (line, column) - }; - let end_position = { - let pos = node_match.get_node().end_pos(); - let line = pos.line() + 1; - let column = pos.column(imports_node) + 1; - (line, column) - }; - parse_and_add_imports_with_span( + self.parse_and_add_imports( imports_text_str, sublibrary_info, - start_position, - end_position, + node_match.get_node(), ); } } @@ -570,16 +520,13 @@ where } /// Generic processing for method call matches - works for both JavaScript and TypeScript - fn process_method_call_matches( + fn process_method_call_matches( &self, - matches: Vec>, + matches: Vec>>, client_variables: &[String], client_info_map: &HashMap, results: &mut Vec, - ) -> Result<(), String> - where - U: Doc + std::clone::Clone, - { + ) -> Result<(), String> { for node_match in matches { let env = node_match.get_env(); @@ -604,26 +551,18 @@ where HashMap::new() }; - let start_position = { - let pos = node_match.get_node().start_pos(); - (pos.line() + 1, pos.column(node_match.get_node()) + 1) - }; - let end_position = { - let pos = node_match.get_node().end_pos(); - (pos.line() + 1, pos.column(node_match.get_node()) + 1) - }; - results.push(MethodCall { client_variable: variable_name, client_type: client_type.clone(), original_client_type: original_client_type.clone(), client_sublibrary: client_sublibrary.clone(), expr: node_match.text().to_string(), - file_path: self.ast_grep.source_file.path.clone(), method_name, arguments, - start_position, - end_position, + location: Location::from_node( + self.ast_grep.source_file.path.to_path_buf(), + node_match.get_node(), + ), }); } } @@ -697,22 +636,6 @@ mod tests { use ast_grep_language::{JavaScript, TypeScript}; use tree_sitter::LanguageExt; - #[test] - fn test_parse_import_item() { - // Test regular import - let import_info = parse_import_item_with_span("S3Client", (1, 1), (1, 1)).unwrap(); - assert_eq!(import_info.original_name, "S3Client"); - assert_eq!(import_info.local_name, "S3Client"); - assert!(!import_info.is_renamed); - - // Test renamed import - let import_info = - parse_import_item_with_span("S3Client as MyS3Client", (1, 1), (1, 1)).unwrap(); - assert_eq!(import_info.original_name, "S3Client"); - assert_eq!(import_info.local_name, "MyS3Client"); - assert!(import_info.is_renamed); - } - #[test] fn test_parse_object_literal() { let result = parse_object_literal("{region: 'us-east-1', timeout: 5000}"); @@ -734,6 +657,41 @@ mod tests { AstWithSourceFile::new(ast_grep, source_file.clone()) } + #[test] + fn test_parse_import_item() { + // Test regular import + let source = r#"import S3Client from "@aws-sdk/client-s3""#; + let ast = create_js_ast(source); + let mut scanner = ASTScanner::new(ast, JavaScript.into()); + + let (imports, _requires) = scanner.scan_all_aws_imports().unwrap(); + assert_eq!( + ImportInfo::new( + "S3Client".to_string(), + "S3Client".to_string(), + "S3Client", + Location::new(PathBuf::new(), (1, 1), (1, 42)), + ), + imports[0].imports[0] + ); + + // Test renamed import + let source = r#"import { S3Client as MyS3Client } from "@aws-sdk/client-s3";"#; + let ast = create_js_ast(source); + let mut scanner = ASTScanner::new(ast, JavaScript.into()); + + let (imports, _requires) = scanner.scan_all_aws_imports().unwrap(); + assert_eq!( + ImportInfo::new( + "S3Client".to_string(), + "MyS3Client".to_string(), + "S3Client as MyS3Client", + Location::new(PathBuf::new(), (1, 1), (1, 61)), + ), + imports[0].imports[0] + ); + } + #[test] fn test_import_require_scanning_comprehensive() { // Create comprehensive test case with multiple sublibrary patterns @@ -1038,7 +996,8 @@ function createListParams(): ListTablesInput { if let Some(result) = scanner.find_command_input_usage_position("QueryCommandInput") { assert_eq!( - result.start_position.0, 9, + result.location.start_line(), + 9, "QueryCommandInput should be at line 9" ); } else { @@ -1047,7 +1006,8 @@ function createListParams(): ListTablesInput { if let Some(result) = scanner.find_command_input_usage_position("ListTablesInput") { assert_eq!( - result.start_position.0, 15, + result.location.start_line(), + 15, "ListTablesInput should be at line 15" ); } else { diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs index b0223e1..56c0669 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs @@ -5,6 +5,7 @@ use crate::extraction::javascript::types::{ImportInfo, JavaScriptScanResults}; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; +use crate::Location; use rust_embed::RustEmbed; use serde::Deserialize; use std::borrow::Cow; @@ -54,10 +55,8 @@ fn load_libraries_mapping() -> Option { pub(crate) struct CommandUsage<'a> { /// The matched text from the AST pub(crate) text: Cow<'a, str>, - /// Start position in the source file (line, column) - 1-based - pub(crate) start_position: (usize, usize), - /// End position in the source file (line, column) - 1-based - pub(crate) end_position: (usize, usize), + /// Location where the command usage was found + pub(crate) location: Location, /// Extracted parameters from the command/function arguments pub(crate) parameters: Vec, } @@ -66,14 +65,12 @@ impl<'a> CommandUsage<'a> { /// Create a new CommandInstantiationResult pub(crate) fn new( text: Cow<'a, str>, - start_position: (usize, usize), - end_position: (usize, usize), + location: Location, parameters: Vec, ) -> Self { Self { text, - start_position, - end_position, + location, parameters, } } @@ -84,8 +81,7 @@ impl From<&ImportInfo> for CommandUsage<'_> { fn from(value: &ImportInfo) -> Self { Self { text: Cow::Owned(value.statement.clone()), - start_position: value.start_position, - end_position: value.end_position, + location: value.location.clone(), // TODO: parameters should be an Option, so we can distinguish // the case where we fall back to an import statement parameters: vec![], @@ -208,9 +204,7 @@ impl ExtractionUtils { parameters: result.parameters.clone(), return_type: None, expr: result.text.to_string(), - file_path: scanner.ast_grep.source_file.path.clone(), - start_position: result.start_position, - end_position: result.end_position, + location: result.location.clone(), receiver: None, // Commands are typically standalone }), }; @@ -299,9 +293,7 @@ impl ExtractionUtils { parameters: result.parameters.clone(), // extracted from 2nd argument! return_type: None, expr: result.text.to_string(), - file_path: scanner.ast_grep.source_file.path.clone(), - start_position: result.start_position, - end_position: result.end_position, + location: result.location.clone(), receiver: None, }), }; @@ -360,9 +352,7 @@ impl ExtractionUtils { parameters: result.parameters, // Extracted from 2nd argument (operation params) return_type: None, expr: result.text.to_string(), - file_path: scanner.ast_grep.source_file.path.clone(), - start_position: result.start_position, - end_position: result.end_position, + location: result.location.clone(), receiver: None, // Waiter functions are standalone }), }; @@ -421,9 +411,7 @@ impl ExtractionUtils { parameters: Vec::new(), return_type: None, expr: result.text.to_string(), - file_path: scanner.ast_grep.source_file.path.clone(), - start_position: result.start_position, - end_position: result.end_position, + location: result.location.clone(), receiver: None, }), }; @@ -495,9 +483,7 @@ impl ExtractionUtils { parameters: result.parameters.clone(), return_type: None, expr: result.text.to_string(), - file_path: scanner.ast_grep.source_file.path.clone(), - start_position: result.start_position, - end_position: result.end_position, + location: result.location.clone(), receiver: None, }), }; @@ -546,9 +532,7 @@ impl ExtractionUtils { parameters, return_type: None, expr: method_call.expr.clone(), - file_path: method_call.file_path.clone(), - start_position: method_call.start_position, - end_position: method_call.end_position, + location: method_call.location.clone(), receiver: Some(method_call.client_variable.clone()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs index 000063b..2f55f6c 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs @@ -1,10 +1,12 @@ //! JavaScript/TypeScript specific data types for AWS SDK extraction use serde::{Deserialize, Serialize}; -use std::{collections::HashMap, path::PathBuf}; +use std::collections::HashMap; + +use crate::Location; /// Information about a single import with rename support -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub(crate) struct ImportInfo { /// Original name in the AWS SDK (e.g., "S3Client", "PutObjectCommand") pub(crate) original_name: String, @@ -14,10 +16,8 @@ pub(crate) struct ImportInfo { pub(crate) local_name: String, /// Whether this import was renamed (original_name != local_name) pub(crate) is_renamed: bool, - /// Start position of the import - pub(crate) start_position: (usize, usize), - /// End position of the import - pub(crate) end_position: (usize, usize), + /// Location of the import + pub(crate) location: Location, } impl ImportInfo { @@ -26,8 +26,7 @@ impl ImportInfo { original_name: String, local_name: String, statement: &str, - start_position: (usize, usize), - end_position: (usize, usize), + location: Location, ) -> Self { let is_renamed = original_name != local_name; Self { @@ -35,14 +34,13 @@ impl ImportInfo { statement: statement.to_string(), local_name, is_renamed, - start_position, - end_position, + location, } } } /// Information about imports from a specific AWS SDK sublibrary -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub(crate) struct SublibraryInfo { /// AWS SDK sublibrary name (e.g., "client-s3", "lib-dynamodb") pub(crate) sublibrary: String, @@ -121,7 +119,7 @@ impl ValidClientTypes { } /// Information about a method call (non-send) -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub(crate) struct MethodCall { /// Client variable name pub(crate) client_variable: String, @@ -137,16 +135,12 @@ pub(crate) struct MethodCall { pub(crate) arguments: HashMap, /// Matched expression pub(crate) expr: String, - /// File where the method call was found - pub(crate) file_path: PathBuf, - /// Start position where call occurs - pub(crate) start_position: (usize, usize), - /// End position where call occurs - pub(crate) end_position: (usize, usize), + /// Location where the method call was found + pub(crate) location: Location, } /// Combined results from all scanning operations -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub(crate) struct JavaScriptScanResults { /// Import information pub(crate) imports: Vec, diff --git a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs index 3121ad6..8484520 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs @@ -29,7 +29,9 @@ pub use self::{core::*, output::*}; pub mod core { use std::sync::Arc; - use crate::Language; + use schemars::JsonSchema; + + use crate::{Language, Location}; use super::{Deserialize, Path, PathBuf, Serialize}; @@ -102,7 +104,7 @@ pub mod core { /// Contains detailed information about a method call including parameters, /// position information, and parsing context. This is optional metadata /// that can be omitted when only basic method identification is needed. - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] + #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] #[serde(rename_all = "PascalCase")] pub struct SdkMethodCallMetadata { /// List of method parameters with their metadata @@ -114,12 +116,7 @@ pub mod core { pub(crate) expr: String, // Position information - /// File path - pub(crate) file_path: PathBuf, - /// Starting position (line, column) - both 1-based - pub(crate) start_position: (usize, usize), - /// Ending position (line, column) - both 1-based - pub(crate) end_position: (usize, usize), + pub(crate) location: Location, // SDK method call context /// Receiver variable name (e.g., "`s3_client`", "ec2") @@ -192,7 +189,7 @@ pub mod core { } /// Parameter value that distinguishes between resolved literals and unresolved expressions - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] + #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] pub(crate) enum ParameterValue { /// Resolved string literal with quotes stripped (e.g., "my-bucket", "42", "true") Resolved(String), @@ -217,7 +214,7 @@ pub mod core { /// /// TODO: Refactor enum variant fields into separate structs to enable Default trait /// implementation and improve ergonomics. See: https://github.com/awslabs/iam-policy-autopilot/issues/61 - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] + #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] pub(crate) enum Parameter { /// Positional argument (e.g., first, second argument in call) Positional { @@ -328,7 +325,7 @@ pub mod output { #[cfg(test)] mod tests { - use crate::Language; + use crate::{Language, Location}; use super::*; use std::path::PathBuf; @@ -360,17 +357,21 @@ mod tests { }], return_type: Some("bool".to_string()), expr: "test_method".to_string(), - file_path: PathBuf::new(), - start_position: (10, 1), - end_position: (10, 25), + location: Location::new(PathBuf::new(), (10, 1), (10, 25)), receiver: None, }), }; assert_eq!(method.name, "test_method"); assert_eq!(method.metadata.as_ref().unwrap().parameters.len(), 1); - assert_eq!(method.metadata.as_ref().unwrap().start_position, (10, 1)); - assert_eq!(method.metadata.as_ref().unwrap().end_position, (10, 25)); + assert_eq!( + method.metadata.as_ref().unwrap().location.start_position, + (10, 1) + ); + assert_eq!( + method.metadata.as_ref().unwrap().location.end_position, + (10, 25) + ); } #[test] @@ -387,9 +388,7 @@ mod tests { }], return_type: Some("Dict[str, Any]".to_string()), expr: "get_object".to_string(), - file_path: PathBuf::new(), - start_position: (15, 5), - end_position: (15, 45), + location: Location::new(PathBuf::new(), (15, 5), (15, 45)), receiver: Some("s3_client".to_string()), }), }; @@ -454,9 +453,7 @@ mod tests { parameters: vec![], return_type: Some("Dict[str, Any]".to_string()), expr: "s3_client.foo_bar".to_string(), - file_path: PathBuf::new(), - start_position: (10, 5), - end_position: (10, 30), + location: Location::new(PathBuf::new(), (10, 5), (10, 30)), receiver: Some("s3_client".to_string()), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs index 3dc3559..f848930 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs @@ -199,6 +199,7 @@ mod tests { Shape, ShapeReference, }; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; + use crate::Location; use std::collections::HashMap; use std::path::PathBuf; @@ -320,9 +321,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -352,9 +351,7 @@ mod tests { // Missing required Stage and ApiId parameters ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("client".to_string()), }), }; @@ -378,9 +375,7 @@ mod tests { position: 0, }], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("client".to_string()), }), }; @@ -407,9 +402,7 @@ mod tests { type_annotation: None, }], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("client".to_string()), }), }; @@ -450,9 +443,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -489,9 +480,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs index 1983225..b060268 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs @@ -15,6 +15,7 @@ mod tests { use crate::extraction::{ Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, SourceFile, }; + use crate::Location; use std::collections::HashMap; use std::path::PathBuf; @@ -202,9 +203,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 80), + location: Location::new(PathBuf::new(), (1, 1), (1, 80)), receiver: Some("apigateway_client".to_string()), }), }; @@ -235,9 +234,7 @@ mod tests { // Missing required Stage and ApiId parameters ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 40), + location: Location::new(PathBuf::new(), (1, 1), (1, 40)), receiver: Some("apigateway_client".to_string()), }), }; @@ -283,9 +280,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 100), + location: Location::new(PathBuf::new(), (1, 1), (1, 100)), receiver: Some("apigateway_client".to_string()), }), }; @@ -309,9 +304,7 @@ mod tests { position: 0, }], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("apigateway_client".to_string()), }), }; @@ -344,9 +337,7 @@ mod tests { type_annotation: None, }], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("custom_client".to_string()), }), }; @@ -382,9 +373,7 @@ mod tests { }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("s3_client".to_string()), }), }, @@ -401,9 +390,7 @@ mod tests { type_annotation: None, }], return_type: None, - file_path: PathBuf::new(), - start_position: (2, 1), - end_position: (2, 30), + location: Location::new(PathBuf::new(), (2, 1), (2, 30)), receiver: Some("custom_client".to_string()), }), }, @@ -420,9 +407,7 @@ mod tests { type_annotation: None, }], return_type: None, - file_path: PathBuf::new(), - start_position: (3, 1), - end_position: (3, 25), + location: Location::new(PathBuf::new(), (3, 1), (3, 25)), receiver: Some("custom_client".to_string()), }), }, @@ -557,9 +542,7 @@ def example(): }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 80), + location: Location::new(PathBuf::new(), (1, 1), (1, 80)), receiver: Some("s3_client".to_string()), }), }; @@ -610,9 +593,7 @@ def example(): }, ], return_type: None, - file_path: PathBuf::new(), - start_position: (1, 1), - end_position: (1, 60), + location: Location::new(PathBuf::new(), (1, 1), (1, 60)), receiver: Some("s3_client".to_string()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs index 83160ad..0ad48c6 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs @@ -7,7 +7,7 @@ use crate::extraction::python::paginator_extractor::PaginatorExtractor; use crate::extraction::python::resource_direct_calls_extractor::ResourceDirectCallsExtractor; use crate::extraction::python::waiters_extractor::WaitersExtractor; use crate::extraction::{AstWithSourceFile, SdkMethodCall, SdkMethodCallMetadata}; -use crate::{ServiceModelIndex, SourceFile}; +use crate::{Location, ServiceModelIndex, SourceFile}; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; use async_trait::async_trait; @@ -44,11 +44,6 @@ impl PythonExtractor { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = ArgumentExtractor::extract_arguments(&args_nodes); - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - let method_call = SdkMethodCall { name: method_name.to_string(), possible_services: Vec::new(), // Will be determined later during service validation @@ -56,9 +51,10 @@ impl PythonExtractor { parameters: arguments, return_type: None, // We don't know the return type from the call site expr: node_match.text().to_string(), - file_path: source_file.path.clone(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node( + source_file.path.to_path_buf(), + node_match.get_node(), + ), receiver, }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs index 3723d6f..a11dc1a 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs @@ -4,13 +4,13 @@ //! two-phase operations: creating a paginator from a client, then executing //! the paginator with operation arguments. -use std::path::{Path, PathBuf}; +use std::path::Path; use crate::extraction::python::common::{ArgumentExtractor, ParameterFilter}; use crate::extraction::sdk_model::ServiceDiscovery; use crate::extraction::{AstWithSourceFile, Parameter, SdkMethodCall, SdkMethodCallMetadata}; -use crate::Language; use crate::ServiceModelIndex; +use crate::{Language, Location}; use ast_grep_language::Python; /// Information about a discovered get_paginator call @@ -24,10 +24,14 @@ pub(crate) struct PaginatorInfo { pub client_receiver: String, /// Matched expression pub expr: String, - /// File the paginator was found in - pub file_path: PathBuf, - /// Line number where get_paginator was called - pub get_paginator_line: usize, + /// Location where get_paginator was called + pub location: Location, +} + +impl PaginatorInfo { + fn start_line(&self) -> usize { + self.location.start_line() + } } /// Information about a paginate method call @@ -37,16 +41,16 @@ pub(crate) struct PaginateCallInfo { pub paginator_var: String, /// Extracted arguments (excluding pagination-specific ones) pub arguments: Vec, - /// Line number where paginate was called (preferred for position reporting) - pub paginate_line: usize, /// Matched expression pub expr: String, - /// File the paginator was found in - pub file_path: PathBuf, - /// Start position of the paginate call node - pub start_position: (usize, usize), - /// End position of the paginate call node - pub end_position: (usize, usize), + /// Location where paginator was called + pub location: Location, +} + +impl PaginateCallInfo { + fn start_line(&self) -> usize { + self.location.start_line() + } } /// Information about a chained paginator call (client.get_paginator().paginate()) @@ -58,17 +62,10 @@ pub(crate) struct ChainedPaginatorCallInfo { pub operation_name: String, /// Extracted arguments from paginate call (excluding pagination-specific ones) pub arguments: Vec, - /// Line number where chained call was made - #[allow(dead_code)] - pub line: usize, /// Matched expression pub expr: String, - /// File the chained paginator call was found in - pub file_path: PathBuf, - /// Start position of the chained call node - pub start_position: (usize, usize), - /// End position of the chained call node - pub end_position: (usize, usize), + /// Location where paginator was called + pub location: Location, } /// Extractor for boto3 paginate method patterns @@ -234,16 +231,12 @@ impl<'a> PaginatorExtractor<'a> { let operation_text = operation_node.text(); let operation_name = self.extract_quoted_string(&operation_text)?; - // Get line number - let get_paginator_line = node_match.get_node().start_pos().line() + 1; - Some(PaginatorInfo { variable_name, operation_name, client_receiver, - get_paginator_line, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -263,19 +256,11 @@ impl<'a> PaginatorExtractor<'a> { let all_arguments = self.extract_arguments(&args_nodes); let filtered_arguments = self.filter_pagination_parameters(all_arguments); - // Get position information from the paginate call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(PaginateCallInfo { paginator_var, arguments: filtered_arguments, - paginate_line: start.line() + 1, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -300,20 +285,12 @@ impl<'a> PaginatorExtractor<'a> { let all_arguments = self.extract_arguments(&paginate_args_nodes); let filtered_arguments = self.filter_pagination_parameters(all_arguments); - // Get position information from the chained call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ChainedPaginatorCallInfo { client_receiver, operation_name, arguments: filtered_arguments, - line: start.line() + 1, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -340,8 +317,8 @@ impl<'a> PaginatorExtractor<'a> { for (idx, paginator) in paginators.iter().enumerate() { if paginator.variable_name == paginate_call.paginator_var { // Only consider paginators that come before the paginate call - if paginator.get_paginator_line < paginate_call.paginate_line { - let distance = paginate_call.paginate_line - paginator.get_paginator_line; + if paginator.start_line() < paginate_call.start_line() { + let distance = paginate_call.start_line() - paginator.start_line(); if distance < best_distance { best_distance = distance; best_match = Some(paginator); @@ -385,9 +362,7 @@ impl<'a> PaginatorExtractor<'a> { return_type: None, expr: paginator_info.expr.clone(), // Use get_paginator call position - file_path: paginator_info.file_path.clone(), - start_position: (paginator_info.get_paginator_line, 1), - end_position: (paginator_info.get_paginator_line, 1), + location: paginator_info.location.clone(), receiver: Some(paginator_info.client_receiver.clone()), }), } @@ -422,9 +397,7 @@ impl<'a> PaginatorExtractor<'a> { return_type: None, expr: paginate_call.expr.clone(), // Use paginate call position (most specific) - file_path: paginate_call.file_path.clone(), - start_position: paginate_call.start_position, - end_position: paginate_call.end_position, + location: paginate_call.location.clone(), // Use client receiver from get_paginator call receiver: Some(paginator_info.client_receiver.clone()), }), @@ -459,9 +432,7 @@ impl<'a> PaginatorExtractor<'a> { return_type: None, expr: chained_call.expr.clone(), // Use chained call position - file_path: chained_call.file_path.clone(), - start_position: chained_call.start_position, - end_position: chained_call.end_position, + location: chained_call.location.clone(), // Use client receiver from chained call receiver: Some(chained_call.client_receiver.clone()), }), @@ -486,7 +457,7 @@ mod tests { use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; - use std::collections::HashMap; + use std::{collections::HashMap, path::PathBuf}; fn create_test_ast(source_code: &str) -> AstWithSourceFile { let source_file = SourceFile::with_language( @@ -733,7 +704,7 @@ page_iterator = paginator.paginate(Bucket='test-bucket') assert_eq!(call.name, "list_objects_v2"); // Position should be from the paginate call (line 5), not get_paginator call (line 4) - assert_eq!(call.metadata.as_ref().unwrap().start_position.0, 5); + assert_eq!(call.metadata.as_ref().unwrap().location.start_line(), 5); assert_eq!( call.metadata.as_ref().unwrap().receiver, Some("client".to_string()) diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs index d65b1de..66fd240 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs @@ -41,18 +41,11 @@ use crate::extraction::python::common::ArgumentExtractor; use crate::extraction::{ AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, }; -use crate::ServiceModelIndex; +use crate::{Location, ServiceModelIndex}; use ast_grep_language::Python; use convert_case::{Case, Casing}; use std::collections::{HashMap, HashSet}; -use std::path::{Path, PathBuf}; - -/// Position tracking for deduplication (Tier 3) -#[derive(Debug, Clone, Hash, Eq, PartialEq)] -struct MatchedPosition { - line: usize, - column: usize, -} +use std::path::Path; /// Information about a discovered resource constructor call #[derive(Debug, Clone)] @@ -73,11 +66,14 @@ struct ResourceMethodCallInfo { resource_var: String, method_name: String, arguments: Vec, - method_call_line: usize, expr: String, - file_path: PathBuf, - start_position: (usize, usize), - end_position: (usize, usize), + location: Location, +} + +impl ResourceMethodCallInfo { + fn start_line(&self) -> usize { + self.location.start_line() + } } /// Resource usage classification for two-tier approach (Tier 3 is now separate) @@ -178,18 +174,15 @@ impl<'a> ResourceDirectCallsExtractor<'a> { all_calls.extend(collection_synthetics); // Step 6: Collect matched positions for Tier 3 deduplication - let mut matched_positions = HashSet::new(); + let mut matched_locations = HashSet::new(); for call in &all_calls { if let Some(metadata) = &call.metadata { - matched_positions.insert(MatchedPosition { - line: metadata.start_position.0, - column: metadata.start_position.1, - }); + matched_locations.insert(metadata.location.clone()); } } // Step 7: New Tier 3 - service-agnostic fallback for unknown receivers - let tier3_calls = self.find_unmatched_utility_and_collection_calls(ast, &matched_positions); + let tier3_calls = self.find_unmatched_utility_and_collection_calls(ast, &matched_locations); all_calls.extend(tier3_calls); all_calls @@ -420,9 +413,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { parameters, return_type: None, expr: method_call.expr.clone(), - file_path: method_call.file_path.clone(), - start_position: method_call.start_position, - end_position: method_call.end_position, + location: method_call.location.clone(), receiver: Some(method_call.resource_var.clone()), }), }); @@ -446,13 +437,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let mut synthetic_calls = Vec::new(); // Extract position from evidence source - let (expr, file_path, start_pos, end_pos) = match evidence { - SyntheticEvidenceSource::UnmatchedMethod(ref method_call) => ( - method_call.expr.clone(), - method_call.file_path.clone(), - method_call.start_position, - method_call.end_position, - ), + let (expr, location) = match evidence { + SyntheticEvidenceSource::UnmatchedMethod(ref method_call) => { + (method_call.expr.clone(), method_call.location.clone()) + } }; // Generate synthetic call for each action @@ -507,9 +495,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { parameters, return_type: None, expr: expr.clone(), - file_path: file_path.clone(), - start_position: start_pos, - end_position: end_pos, + location: location.clone(), receiver: Some(constructor.variable_name.clone()), // Use actual variable name from constructor }), }); @@ -625,12 +611,12 @@ impl<'a> ResourceDirectCallsExtractor<'a> { a.resource_var .cmp(&b.resource_var) .then(a.method_name.cmp(&b.method_name)) - .then(a.method_call_line.cmp(&b.method_call_line)) + .then(a.start_line().cmp(&b.start_line())) }); method_calls.dedup_by(|a, b| { a.resource_var == b.resource_var && a.method_name == b.method_name - && a.method_call_line == b.method_call_line + && a.start_line() == b.start_line() }); method_calls @@ -654,20 +640,12 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = ArgumentExtractor::extract_arguments(&args_nodes); - // Get position information from the method call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ResourceMethodCallInfo { resource_var, method_name, arguments, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - method_call_line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -689,20 +667,12 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = ArgumentExtractor::extract_arguments(&args_nodes); - // Get position information from the method call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ResourceMethodCallInfo { resource_var, method_name, arguments, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - method_call_line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -827,9 +797,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { parameters: combined_parameters, return_type: None, expr: method_call.expr.clone(), - file_path: method_call.file_path.clone(), - start_position: method_call.start_position, - end_position: method_call.end_position, + location: method_call.location.clone(), receiver: Some(method_call.resource_var.clone()), }), }) @@ -886,24 +854,21 @@ impl<'a> ResourceDirectCallsExtractor<'a> { }; // Check if this attribute matches a hasMany collection (in snake_case) - if let Some(has_many_spec) = - boto3_model.get_has_many_spec(&constructor.resource_type, &attr_name) + if let Some(synthetic_call) = boto3_model + .get_has_many_spec(&constructor.resource_type, &attr_name) + .and_then(|has_many_spec| { + self.generate_synthetic_for_collection( + constructor, + has_many_spec, + node_match.text().to_string(), + Location::from_node( + ast.source_file.path.to_path_buf(), + node_match.get_node(), + ), + ) + }) { - // Generate synthetic call for the collection's operation - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - - if let Some(synthetic_call) = self.generate_synthetic_for_collection( - constructor, - has_many_spec, - node_match.text().to_string(), - &ast.source_file.path, - (start.line() + 1, start.column(node) + 1), - (end.line() + 1, end.column(node) + 1), - ) { - synthetic_calls.push(synthetic_call); - } + synthetic_calls.push(synthetic_call); } } } @@ -917,9 +882,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { constructor: &ResourceConstructorInfo, has_many_spec: &HasManySpec, expr: String, - file_path: &Path, - start_position: (usize, usize), - end_position: (usize, usize), + location: Location, ) -> Option { let mut parameters = Vec::new(); @@ -954,9 +917,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { parameters, return_type: None, expr: expr.clone(), - file_path: file_path.to_path_buf(), - start_position, - end_position, + location, receiver: Some(constructor.variable_name.clone()), // Use actual variable name from constructor }), }) @@ -970,15 +931,15 @@ impl<'a> ResourceDirectCallsExtractor<'a> { fn find_unmatched_utility_and_collection_calls( &self, ast: &AstWithSourceFile, - matched_positions: &HashSet, + matched_locations: &HashSet, ) -> Vec { let mut tier3_calls = Vec::new(); // Search for utility method calls across all services - tier3_calls.extend(self.find_unmatched_utility_method_calls(ast, matched_positions)); + tier3_calls.extend(self.find_unmatched_utility_method_calls(ast, matched_locations)); // Search for collection accesses across all services - tier3_calls.extend(self.find_unmatched_collection_accesses(ast, matched_positions)); + tier3_calls.extend(self.find_unmatched_collection_accesses(ast, matched_locations)); tier3_calls } @@ -987,7 +948,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { fn find_unmatched_utility_method_calls( &self, ast: &AstWithSourceFile, - matched_positions: &HashSet, + matched_locations: &HashSet, ) -> Vec { let root = ast.ast.root(); let mut calls = Vec::new(); @@ -1011,16 +972,11 @@ impl<'a> ResourceDirectCallsExtractor<'a> { None => continue, }; - // Get position - let node = node_match.get_node(); - let start = node.start_pos(); - let position = MatchedPosition { - line: start.line() + 1, - column: start.column(node) + 1, - }; + let location = + Location::from_node(ast.source_file.path.clone(), node_match.get_node()); // Skip if already matched in Tier 1/2 - if matched_positions.contains(&position) { + if matched_locations.contains(&location) { continue; } @@ -1048,9 +1004,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &arguments, &operation.required_params, node_match.text().to_string(), - &ast.source_file.path, - (start.line() + 1, start.column(node) + 1), - (node.end_pos().line() + 1, node.end_pos().column(node) + 1), + &location, &receiver_var, // Use actual receiver from code )); } @@ -1068,9 +1022,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &arguments, &operation.required_params, node_match.text().to_string(), - &ast.source_file.path, - (start.line() + 1, start.column(node) + 1), - (node.end_pos().line() + 1, node.end_pos().column(node) + 1), + &location, &receiver_var, // Use actual receiver from code )); } @@ -1087,7 +1039,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { fn find_unmatched_collection_accesses( &self, ast: &AstWithSourceFile, - matched_positions: &HashSet, + matched_locations: &HashSet, ) -> Vec { let root = ast.ast.root(); let mut calls = Vec::new(); @@ -1116,16 +1068,11 @@ impl<'a> ResourceDirectCallsExtractor<'a> { None => continue, }; - // Get position - let node = node_match.get_node(); - let start = node.start_pos(); - let position = MatchedPosition { - line: start.line() + 1, - column: start.column(node) + 1, - }; + let location = + Location::from_node(ast.source_file.path.clone(), node_match.get_node()); // Skip if already matched in Tier 1/2 - if matched_positions.contains(&position) { + if matched_locations.contains(&location) { continue; } @@ -1144,12 +1091,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { ), return_type: None, expr: node_match.text().to_string(), - file_path: ast.source_file.path.clone(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: ( - node.end_pos().line() + 1, - node.end_pos().column(node) + 1, - ), + location: location.clone(), receiver: Some(receiver_var.clone()), // Use actual receiver from code }), }); @@ -1171,12 +1113,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { ), return_type: None, expr: node_match.text().to_string(), - file_path: ast.source_file.path.clone(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: ( - node.end_pos().line() + 1, - node.end_pos().column(node) + 1, - ), + location: location.clone(), receiver: Some(receiver_var.clone()), // Use actual receiver from code }), }); @@ -1197,9 +1134,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { arguments: &[Parameter], required_params: &[String], expr: String, - file_path: &Path, - start_position: (usize, usize), - end_position: (usize, usize), + location: &Location, receiver_marker: &str, ) -> SdkMethodCall { let mut parameters = Vec::new(); @@ -1235,9 +1170,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { parameters, return_type: None, expr: expr.clone(), - file_path: file_path.to_path_buf(), - start_position, - end_position, + location: location.clone(), receiver: Some(receiver_marker.to_string()), }), } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs index d595932..2ed755d 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs @@ -4,13 +4,13 @@ //! two-phase operations: creating a waiter from a client, then calling wait() //! on the waiter with operation arguments. -use std::path::{Path, PathBuf}; +use std::path::Path; use crate::extraction::python::common::{ArgumentExtractor, ParameterFilter}; use crate::extraction::{ AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, }; -use crate::ServiceModelIndex; +use crate::{Location, ServiceModelIndex}; use ast_grep_language::Python; /// Information about a discovered get_waiter call @@ -24,12 +24,14 @@ pub(crate) struct WaiterInfo { pub client_receiver: String, /// Matched expression pub expr: String, - /// File where we found the waiter - pub file_path: PathBuf, - /// Start position where get_waiter was called - pub start_position: (usize, usize), - /// End position where get_waiter was called - pub end_position: (usize, usize), + /// Location where we found the waiter + pub location: Location, +} + +impl WaiterInfo { + pub(crate) fn start_line(&self) -> usize { + self.location.start_line() + } } // TODO: This should be refactored at a higher level, so this type can be removed. @@ -56,27 +58,11 @@ impl<'a> CallInfo<'a> { } } - fn file_path(&self) -> &'a PathBuf { + fn location(&self) -> &'a Location { match self { - CallInfo::None(waiter_info) => &waiter_info.file_path, - CallInfo::Simple(_, wait_call_info) => &wait_call_info.file_path, - CallInfo::Chained(chained_waiter_call_info) => &chained_waiter_call_info.file_path, - } - } - - fn start_position(&self) -> (usize, usize) { - match self { - CallInfo::None(waiter_info) => waiter_info.start_position, - CallInfo::Simple(_, wait_call_info) => wait_call_info.start_position, - CallInfo::Chained(chained_waiter_call_info) => chained_waiter_call_info.start_position, - } - } - - fn end_position(&self) -> (usize, usize) { - match self { - CallInfo::None(waiter_info) => waiter_info.end_position, - CallInfo::Simple(_, wait_call_info) => wait_call_info.end_position, - CallInfo::Chained(chained_waiter_call_info) => chained_waiter_call_info.end_position, + CallInfo::None(waiter_info) => &waiter_info.location, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.location, + CallInfo::Chained(chained_waiter_call_info) => &chained_waiter_call_info.location, } } } @@ -90,12 +76,14 @@ pub(crate) struct WaitCallInfo { pub arguments: Vec, /// Matched expression pub expr: String, - /// File where we found the waiter - pub file_path: PathBuf, - /// Start position of the wait call node - pub start_position: (usize, usize), - /// End position of the wait call node - pub end_position: (usize, usize), + /// Location where we found the waiter + pub location: Location, +} + +impl WaitCallInfo { + pub(crate) fn start_line(&self) -> usize { + self.location.start_line() + } } /// Information about a chained waiter call (client.get_waiter().wait()) @@ -107,17 +95,10 @@ pub(crate) struct ChainedWaiterCallInfo { pub waiter_name: String, /// Extracted arguments from wait call (including WaiterConfig) pub arguments: Vec, - /// Line number where chained call was made - #[allow(dead_code)] - pub line: usize, /// Matched expression pub expr: String, - /// File we found the chained waiter call in - pub file_path: PathBuf, - /// Start position of the chained call node - pub start_position: (usize, usize), - /// End position of the chained call node - pub end_position: (usize, usize), + /// Location where we found the waiter call + pub location: Location, } /// Extractor for boto3 waiter patterns @@ -275,18 +256,12 @@ impl<'a> WaitersExtractor<'a> { let name_text = name_node.text(); let waiter_name = self.extract_quoted_string(&name_text)?; - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(WaiterInfo { variable_name, waiter_name, client_receiver, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -305,18 +280,11 @@ impl<'a> WaitersExtractor<'a> { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = ArgumentExtractor::extract_arguments(&args_nodes); - // Get position information from the wait call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(WaitCallInfo { waiter_var, arguments, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -340,20 +308,12 @@ impl<'a> WaitersExtractor<'a> { let wait_args_nodes = env.get_multiple_matches("WAIT_ARGS"); let arguments = ArgumentExtractor::extract_arguments(&wait_args_nodes); - // Get position information from the chained call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ChainedWaiterCallInfo { client_receiver, waiter_name, arguments, - line: start.line() + 1, expr: node_match.text().to_string(), - file_path: file_path.to_path_buf(), - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -372,8 +332,8 @@ impl<'a> WaitersExtractor<'a> { for (idx, waiter) in waiters.iter().enumerate() { if waiter.variable_name == wait_call.waiter_var { // Only consider waiters that come before the wait call - if waiter.start_position.0 < wait_call.start_position.0 { - let distance = wait_call.start_position.0 - waiter.start_position.0; + if waiter.start_line() < wait_call.start_line() { + let distance = wait_call.start_line() - waiter.start_line(); if distance < best_distance { best_distance = distance; best_match = Some(waiter); @@ -428,9 +388,7 @@ impl<'a> WaitersExtractor<'a> { parameters, return_type: None, expr: wait_call.expr().to_string(), - file_path: wait_call.file_path().clone(), - start_position: wait_call.start_position(), - end_position: wait_call.end_position(), + location: wait_call.location().clone(), // Use client receiver from get_waiter call receiver: receiver.clone(), }), @@ -521,6 +479,7 @@ mod tests { use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; use std::collections::HashMap; + use std::path::PathBuf; fn create_test_ast(source_code: &str) -> AstWithSourceFile { let source_file = SourceFile::with_language( @@ -687,7 +646,7 @@ waiter = ec2_client.get_waiter('instance_terminated') assert_eq!(waiters[0].variable_name, "waiter"); assert_eq!(waiters[0].waiter_name, "instance_terminated"); assert_eq!(waiters[0].client_receiver, "ec2_client"); - assert_eq!(waiters[0].start_position.0, 4); + assert_eq!(waiters[0].start_line(), 4); } #[test] diff --git a/iam-policy-autopilot-policy-generation/src/lib.rs b/iam-policy-autopilot-policy-generation/src/lib.rs index 2813fef..e9fe721 100644 --- a/iam-policy-autopilot-policy-generation/src/lib.rs +++ b/iam-policy-autopilot-policy-generation/src/lib.rs @@ -32,8 +32,9 @@ pub mod policy_generation; pub mod api; use std::fmt::Display; +use std::path::PathBuf; -pub use enrichment::{Engine as EnrichmentEngine, Explanation, Location}; +pub use enrichment::{Engine as EnrichmentEngine, Explanation}; pub use extraction::{Engine as ExtractionEngine, ExtractedMethods, SdkMethodCall, SourceFile}; pub use policy_generation::{ Effect, Engine as PolicyGenerationEngine, IamPolicy, PolicyType, PolicyWithMetadata, Statement, @@ -44,6 +45,9 @@ pub(crate) use extraction::ServiceModelIndex; pub use providers::FileSystemProvider; pub use providers::JsonProvider; +use schemars::JsonSchema; +use serde::Deserialize; +use serde::Serialize; use crate::errors::ExtractorError; @@ -133,6 +137,155 @@ impl From for String { } } +/// Represents a location in a source file +/// +/// This struct stores file path and position information and serializes +/// to the GNU coding standard (https://www.gnu.org/prep/standards/html_node/Errors.html) +/// format: `filename:startLine.startCol-endLine.endCol` +#[derive(Debug, Clone, PartialEq, Eq, Hash, JsonSchema)] +#[schemars( + description = "File location in GNU coding standard format: filename:startLine.startCol-endLine.endCol" +)] +pub struct Location { + /// File path + pub file_path: PathBuf, + /// Starting position (line, column) - both 1-based + pub start_position: (usize, usize), + /// Ending position (line, column) - both 1-based + pub end_position: (usize, usize), +} + +impl Location { + /// Create a new Location + #[must_use] + pub fn new( + file_path: PathBuf, + start_position: (usize, usize), + end_position: (usize, usize), + ) -> Self { + Self { + file_path, + start_position, + end_position, + } + } + + /// Create a new Location from an AST node + #[must_use] + pub fn from_node( + file_path: PathBuf, + node: &ast_grep_core::Node>, + ) -> Self + where + T: ast_grep_language::LanguageExt, + { + let start = node.start_pos(); + let end = node.end_pos(); + Self { + file_path, + start_position: (start.line() + 1, start.column(node) + 1), + end_position: (end.line() + 1, end.column(node) + 1), + } + } + + /// Line where the finding starts + pub fn start_line(&self) -> usize { + self.start_position.0 + } + + /// Column where the finding starts + pub fn start_col(&self) -> usize { + self.start_position.1 + } + + /// Line where the finding ends + pub fn end_line(&self) -> usize { + self.end_position.0 + } + + /// Column where the finding ends + pub fn end_col(&self) -> usize { + self.end_position.1 + } + + /// Format as GNU coding standard: `filename:startLine.startCol-endLine.endCol` + #[must_use] + pub fn to_gnu_format(&self) -> String { + let path_str = self.file_path.display(); + let (start_line, start_col) = self.start_position; + let (end_line, end_col) = self.end_position; + + format!( + "{}:{}.{}-{}.{}", + path_str, start_line, start_col, end_line, end_col + ) + } +} + +impl Serialize for Location { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_gnu_format()) + } +} + +impl<'de> Deserialize<'de> for Location { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::from_gnu_format(&s).map_err(serde::de::Error::custom) + } +} + +impl Location { + /// Parse a Location from GNU coding standard format: `filename:startLine.startCol-endLine.endCol` + pub fn from_gnu_format(s: &str) -> Result { + // Find the colon that separates filename from position + let colon_pos = s.rfind(':').ok_or("Missing colon separator")?; + let (file_path_str, position_str) = s.split_at(colon_pos); + let position_str = &position_str[1..]; // Remove the colon + + // Parse the position part: startLine.startCol-endLine.endCol + let dash_pos = position_str.find('-').ok_or("Missing dash separator")?; + let (start_str, end_str) = position_str.split_at(dash_pos); + let end_str = &end_str[1..]; // Remove the dash + + // Parse start position: startLine.startCol + let start_dot_pos = start_str.find('.').ok_or("Missing dot in start position")?; + let (start_line_str, start_col_str) = start_str.split_at(start_dot_pos); + let start_col_str = &start_col_str[1..]; // Remove the dot + + // Parse end position: endLine.endCol + let end_dot_pos = end_str.find('.').ok_or("Missing dot in end position")?; + let (end_line_str, end_col_str) = end_str.split_at(end_dot_pos); + let end_col_str = &end_col_str[1..]; // Remove the dot + + // Convert strings to numbers + let start_line = start_line_str + .parse::() + .map_err(|_| format!("Invalid start line: {}", start_line_str))?; + let start_col = start_col_str + .parse::() + .map_err(|_| format!("Invalid start column: {}", start_col_str))?; + let end_line = end_line_str + .parse::() + .map_err(|_| format!("Invalid end line: {}", end_line_str))?; + let end_col = end_col_str + .parse::() + .map_err(|_| format!("Invalid end column: {}", end_col_str))?; + + Ok(Location { + file_path: PathBuf::from(file_path_str), + start_position: (start_line, start_col), + end_position: (end_line, end_col), + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -172,4 +325,68 @@ mod tests { assert!(Language::try_from_str("java").is_err()); assert!(Language::try_from_str("").is_err()); } + + #[test] + fn test_location_gnu_format() { + let location = Location::new(PathBuf::from("src/main.rs"), (10, 5), (15, 20)); + + let gnu_format = location.to_gnu_format(); + assert_eq!(gnu_format, "src/main.rs:10.5-15.20"); + } + + #[test] + fn test_location_from_gnu_format() { + let gnu_str = "src/main.rs:10.5-15.20"; + let location = Location::from_gnu_format(gnu_str).unwrap(); + + assert_eq!(location.file_path, PathBuf::from("src/main.rs")); + assert_eq!(location.start_position, (10, 5)); + assert_eq!(location.end_position, (15, 20)); + } + + #[test] + fn test_location_from_gnu_format_with_complex_path() { + let gnu_str = "/home/user/project/src/lib.rs:1.1-100.50"; + let location = Location::from_gnu_format(gnu_str).unwrap(); + + assert_eq!( + location.file_path, + PathBuf::from("/home/user/project/src/lib.rs") + ); + assert_eq!(location.start_position, (1, 1)); + assert_eq!(location.end_position, (100, 50)); + } + + #[test] + fn test_location_serialize_deserialize_roundtrip() { + let original = Location::new(PathBuf::from("test/file.py"), (42, 13), (45, 7)); + + // Serialize to JSON + let json = serde_json::to_string(&original).unwrap(); + assert_eq!(json, "\"test/file.py:42.13-45.7\""); + + // Deserialize back + let deserialized: Location = serde_json::from_str(&json).unwrap(); + + assert_eq!(original, deserialized); + } + + #[test] + fn test_location_from_gnu_format_invalid_formats() { + // Missing colon + assert!(Location::from_gnu_format("src/main.rs10.5-15.20").is_err()); + + // Missing dash + assert!(Location::from_gnu_format("src/main.rs:10.515.20").is_err()); + + // Missing dots + assert!(Location::from_gnu_format("src/main.rs:105-1520").is_err()); + + // Invalid numbers + assert!(Location::from_gnu_format("src/main.rs:abc.5-15.20").is_err()); + assert!(Location::from_gnu_format("src/main.rs:10.xyz-15.20").is_err()); + + // Empty string + assert!(Location::from_gnu_format("").is_err()); + } } diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs index b439068..16ddd34 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs @@ -306,7 +306,7 @@ mod tests { use crate::{Explanation, SdkMethodCall}; use super::super::Effect; - use crate::enrichment::{Action, EnrichedSdkMethodCall, OperationView, Resource}; + use crate::enrichment::{Action, EnrichedSdkMethodCall, Resource}; use crate::errors::ExtractorError; fn create_test_engine() -> Engine<'static> { @@ -922,10 +922,7 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call, "s3"), - fas: None, - }], + reasons: vec![Reason::new(&sdk_call, "s3", None)], }, )], sdk_method_call: &sdk_call, @@ -975,10 +972,7 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call1, "s3"), - fas: None, - }], + reasons: vec![Reason::new(&sdk_call1, "s3", None)], }, )], sdk_method_call: &sdk_call1, @@ -997,10 +991,7 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call1, "s3"), - fas: None, - }], + reasons: vec![Reason::new(&sdk_call1, "s3", None)], }, )], sdk_method_call: &sdk_call2, @@ -1048,10 +1039,7 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call, "s3"), - fas: None, - }], + reasons: vec![Reason::new(&sdk_call, "s3", None)], }, ), Action::new( @@ -1064,13 +1052,14 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call, "s3"), - fas: Some(crate::enrichment::FasInfo::new(vec![ + reasons: vec![Reason::new( + &sdk_call, + "s3", + Some(crate::enrichment::FasInfo::new(vec![ "s3:GetObject".to_string(), "kms:Decrypt".to_string(), ])), - }], + )], }, ), ], @@ -1127,10 +1116,7 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason { - operation: OperationView::from_call(&sdk_call, "s3"), - fas: None, - }], + reasons: vec![Reason::new(&sdk_call, "s3", None)], }, )], sdk_method_call: &sdk_call, From fc36e6dbec0e0380457d92eec0cfeb780ac7f482 Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Fri, 9 Jan 2026 20:20:33 +0000 Subject: [PATCH 04/11] Refactor resource_matcher and operation --- Cargo.toml | 2 +- .../src/enrichment/mod.rs | 290 +++++++---- .../src/enrichment/operation_fas_map.rs | 55 +-- .../src/enrichment/resource_matcher.rs | 457 +++++++----------- .../src/extraction/mod.rs | 3 +- .../src/policy_generation/engine.rs | 86 ++-- .../src/service_configuration.rs | 6 + 7 files changed, 452 insertions(+), 447 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 78fca62..c2f8068 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ description = "An open source Model Context Protocol (MCP) server and command-li # Shared dependency versions across workspace [workspace.dependencies] # Core serialization and utilities -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } futures = "0.3" thiserror = "1.0" anyhow = "1.0" diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index 07d1b24..b1301ae 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -8,9 +8,10 @@ //! that represent method calls enriched with IAM metadata from operation //! action maps and Service Definition Files. -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; -use crate::{extraction::SdkMethodCallMetadata, SdkMethodCall}; +use crate::{SdkMethodCall, SdkType, enrichment::operation_fas_map::{FasContext, FasOperation}, extraction::SdkMethodCallMetadata, service_configuration::ServiceConfiguration}; +use convert_case::{Case, Casing}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -27,103 +28,144 @@ pub(crate) use service_reference::RemoteServiceReferenceLoader as ServiceReferen const FAS_URL: &str = "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html"; -/// Represents Forward Access Session (FAS) expansion information +/// Represents the reason why an action was added to a policy +#[derive(derive_new::new, Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub struct Reason { + /// The original operation that was extracted + pub operations: Vec>, +} + #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] #[serde(rename_all = "PascalCase")] -pub struct FasInfo { - /// Explanation URL for Forward Access Sessions - pub explanation: &'static str, - /// The chain of operations in the FAS expansion - pub expansion: Vec, +pub struct Operation { + /// Name of the operation + pub name: String, + /// Name of the service + pub service: String, + /// Source of the operation, + pub source: OperationSource, + /// Disallow struct construction, need to use Self::from_call or Operation::from(FasOperation) + _private: () } -impl FasInfo { - /// Create a new FasInfo with the standard AWS documentation URL - #[must_use] - pub fn new(expansion: Vec) -> Self { +impl Operation { + #[cfg(test)] + /// Convenience constructor for tests + pub(crate) fn new(name: String, service: String, source: OperationSource) -> Self { Self { - explanation: FAS_URL, - expansion, + name, service, source, _private: () } } -} -/// Represents the reason why an action was added to a policy -#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] -#[serde(rename_all = "PascalCase")] -pub struct Reason { - /// The original operation that was extracted - pub initial_operation: Operation, - /// Source of the operation - pub source: OperationSource, - /// Optional FAS expansion information - pub fas: Option, -} + pub(crate) fn service_operation_name(&self) -> String { + format!("{}:{}", self.service, self.name) + } -impl Reason { - pub(crate) fn new( + pub(crate) fn context(&self) -> &[FasContext] { + match &self.source { + OperationSource::Fas(context) => context, + _ => &[], + } + } + + pub(crate) async fn from_call( call: &SdkMethodCall, original_service_name: &str, - fas: Option, - ) -> Self { - let initial_operation = - Operation::new(call.name.clone(), original_service_name.to_string()); - match &call.metadata { + service_cfg: &ServiceConfiguration, + sdk: SdkType, + service_reference_loader: &ServiceReferenceLoader, + ) -> crate::errors::Result { + let service = service_cfg.rename_service_service_reference(original_service_name).to_string(); + let name = if sdk == SdkType::Boto3 { + // Try to load service reference and look up the boto3 method mapping + service_reference_loader + .load(&service) + .await? + .and_then(|service_ref| { + log::debug!("Looking up method {}", call.name); + service_ref + .boto3_method_to_operation + .get(&call.name) + .map(|op| { + log::debug!("got {:?}", op); + op.split(':').nth(1).unwrap_or(op).to_string() + }) + }) + // Fallback to PascalCase conversion if mapping not found + // This should not be reachable, but if for some reason we cannot use the SDF, + // we try converting to PascalCase, knowing that this is flawed in some cases: + // think `AddRoleToDBInstance` (actual name) + // vs. `AddRoleToDbInstance` (converted name) + .unwrap_or_else(|| call.name.to_case(Case::Pascal)) + } else { + // For non-Boto3 SDKs we use the extracted name as-is + call.name.clone() + }; + + Ok(match &call.metadata { None => Self { - initial_operation, + name, + service, source: OperationSource::Provided, - fas, + _private: (), }, Some(metadata) => Self { - initial_operation, + name, + service, source: OperationSource::Extracted(metadata.clone()), - fas, + _private: (), }, + }) + } +} + +impl From for Operation { + fn from(fas_op: FasOperation) -> Self { + Self { + name: fas_op.operation, + service: fas_op.service, + source: OperationSource::Fas(fas_op.context), + _private: (), } } } -#[derive(derive_new::new, Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] -#[serde(rename_all = "PascalCase")] -pub struct Operation { - /// Name of the operation - pub name: String, - /// Name of the service - pub service: String, +/// Custom serializer for extracted metadata that flattens the structure +fn serialize_extracted_metadata(metadata: &SdkMethodCallMetadata, serializer: S) -> Result +where + S: serde::Serializer, +{ + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("Expr", &metadata.expr)?; + map.serialize_entry("Location", &metadata.location)?; + map.end() } -#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, JsonSchema)] #[serde(rename_all = "PascalCase")] -#[serde(untagged)] pub enum OperationSource { /// Operation extracted from source files - #[serde(rename_all = "PascalCase")] Extracted(SdkMethodCallMetadata), /// Operation provided (no metadata available) - #[serde(rename_all = "PascalCase")] Provided, + /// Operation comes from FAS expansion + Fas(Vec), } -// impl OperationSource { -// pub(crate) fn from_call(call: &SdkMethodCall, service: &str) -> Self { -// match &call.metadata { -// None => Self::Provided { -// name: call.name.clone(), -// service: service.to_string(), -// }, -// Some(metadata) => Self::Extracted { -// name: call.name.clone(), -// service: service.to_string(), -// expr: metadata.expr.clone(), -// location: Location::new( -// metadata.file_path.clone(), -// metadata.start_position, -// metadata.end_position, -// ), -// }, -// } -// } -// } +impl Serialize for OperationSource { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + OperationSource::Extracted(metadata) => serialize_extracted_metadata(metadata, serializer), + OperationSource::Provided => serializer.serialize_str("Provided"), + OperationSource::Fas(_) => serializer.serialize_str("FAS"), + } + } +} /// Represents an explanation for why an action was added to a policy #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema, Default)] @@ -498,7 +540,7 @@ pub(crate) mod mock_remote_service_reference { #[cfg(test)] mod location_tests { use super::*; - use crate::Location; + use crate::{Location, enrichment::mock_remote_service_reference::setup_mock_server_with_loader_without_operation_to_action_mapping, service_configuration::load_service_configuration}; use std::path::PathBuf; #[test] @@ -545,31 +587,117 @@ mod location_tests { } } - #[test] - fn test_reason_extracted_with_location() { + #[tokio::test] + async fn test_reason_extracted_with_location() { + let service_cfg = load_service_configuration().unwrap(); + let (_, service_reference_loader) = setup_mock_server_with_loader_without_operation_to_action_mapping().await; let call = mock_sdk_method_call(); - let reason = Reason::new(&call, "s3", None); + let reason = Reason::new(vec![Arc::new(Operation::from_call(&call, "s3", &service_cfg, SdkType::Boto3, &service_reference_loader).await.unwrap())]); - match reason.source { + assert_eq!(reason.operations[0].name, "GetObject"); + assert_eq!(reason.operations[0].service, "s3"); + match &reason.operations[0].source { OperationSource::Extracted(metadata) => { - assert_eq!(reason.initial_operation.name, "get_object"); - assert_eq!(reason.initial_operation.service, "s3"); assert_eq!(metadata.expr, "s3.get_object(Bucket='my-bucket')"); assert_eq!(metadata.location.to_gnu_format(), "test.py:10.5-10.79"); } _ => panic!("Expected Extracted variant"), } + + let json = serde_json::to_string(&reason).unwrap(); + // Verify the location is serialized as a string in GNU format + assert!(json.contains("\"Location\":\"test.py:10.5-10.79\"")); } #[test] - fn test_reason_extracted_serialization() { - let call = mock_sdk_method_call(); + fn test_operation_source_extracted_serialization() { + let metadata = SdkMethodCallMetadata { + parameters: vec![], + return_type: None, + expr: "dynamodb.get_item(\n TableName='my-table',\n Key={'id': {'S': '123'}}\n )".to_string(), + location: Location::new(PathBuf::from("iam-policy-autopilot-cli/tests/resources/test_example.py"), (19, 5), (22, 5)), + receiver: Some("dynamodb".to_string()), + }; + + let source = OperationSource::Extracted(metadata); + let json = serde_json::to_string(&source).unwrap(); + + // Verify the custom serialization format + assert!(json.contains("\"Expr\":\"dynamodb.get_item(\\n TableName='my-table',\\n Key={'id': {'S': '123'}}\\n )\"")); + assert!(json.contains("\"Location\":\"iam-policy-autopilot-cli/tests/resources/test_example.py:19.5-22.5\"")); + // Should not contain nested "Source" key + assert!(!json.contains("\"Source\"")); + } - let reason = Reason::new(&call, "s3", None); - let json = serde_json::to_string(&reason).unwrap(); + #[test] + fn test_operation_source_provided_serialization() { + let source = OperationSource::Provided; + let json = serde_json::to_string(&source).unwrap(); + + // Verify the custom serialization format + assert_eq!(json, "\"Provided\""); + } - // Verify the location is serialized as a string in GNU format - assert!(json.contains("\"Location\":\"test.py:10.5-10.79\"")); + #[test] + fn test_operation_source_fas_serialization() { + use crate::enrichment::operation_fas_map::FasContext; + + let fas_context = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["ssm.${region}.amazonaws.com".to_string()], + )]; + let source = OperationSource::Fas(fas_context); + let json = serde_json::to_string(&source).unwrap(); + + // Verify the custom serialization format - should be just "FAS", not nested + assert_eq!(json, "\"FAS\""); + } + + #[tokio::test] + async fn test_operation_methods() { + let service_cfg = load_service_configuration().unwrap(); + let (_, service_reference_loader) = setup_mock_server_with_loader_without_operation_to_action_mapping().await; + + { + let call = SdkMethodCall { + name: "decrypt".to_string(), + possible_services: vec!["kms".to_string()], + metadata: None, + }; + let op = Operation::from_call(&call, "kms", &service_cfg, SdkType::Boto3, &service_reference_loader).await.unwrap(); + assert_eq!(op.service_operation_name(), "kms:Decrypt"); + assert_eq!(op.context(), &[]); + } + + { + let expr = "kms.decrypt(...)".to_string(); + let metadata = SdkMethodCallMetadata { + parameters: vec![], + return_type: None, + expr: expr.clone(), + location: Location::new(PathBuf::new(), (1, 1), (1, expr.len() + 1)), + receiver: Some("kms".to_string()), + }; + let call = SdkMethodCall { + name: "decrypt".to_string(), + possible_services: vec!["kms".to_string()], + metadata: Some(metadata), + }; + let op = Operation::from_call(&call, "kms", &service_cfg, SdkType::Boto3, &service_reference_loader).await.unwrap(); + assert_eq!(op.service_operation_name(), "kms:Decrypt"); + assert_eq!(op.context(), &[]); + } + + { + let context = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["ssm.${region}.amazonaws.com".to_string()], + )]; + let fas_operation = FasOperation::new("Decrypt".to_string(), "kms".to_string(), context.clone()); + let op = Operation::from(fas_operation); + assert_eq!(op.service_operation_name(), "kms:Decrypt"); + assert_eq!(op.context(), &context); + } } } diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/operation_fas_map.rs b/iam-policy-autopilot-policy-generation/src/enrichment/operation_fas_map.rs index f2065cd..be524d3 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/operation_fas_map.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/operation_fas_map.rs @@ -3,15 +3,14 @@ //! This module contains the data structures used to represent operation //! action maps that are loaded from embedded JSON files and used for IAM policy enrichment. -use std::borrow::Cow; use std::collections::HashMap; use std::sync::{Arc, OnceLock, RwLock}; use rust_embed::RustEmbed; +use schemars::JsonSchema; use serde::{Deserialize, Deserializer}; use crate::enrichment::Context; -use crate::service_configuration::ServiceConfiguration; type ServiceName = String; type OperationName = String; @@ -45,8 +44,8 @@ pub(crate) struct OperationFasMap { pub(crate) fas_operations: HashMap>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub(crate) struct FasContext { +#[derive(Debug, Clone, PartialEq, Eq, Hash, JsonSchema)] +pub struct FasContext { pub(crate) key: String, pub(crate) values: Vec, } @@ -117,13 +116,14 @@ where #[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] pub(crate) struct FasOperation { #[serde(rename = "Operation")] - operation: String, + pub(crate) operation: String, #[serde(rename = "Service")] - service: String, + pub(crate) service: String, #[serde(rename = "Context", deserialize_with = "deserialize_context_map")] pub(crate) context: Vec, } +#[cfg(test)] impl FasOperation { pub(crate) fn new(operation: String, service: String, context: Vec) -> Self { FasOperation { @@ -132,26 +132,6 @@ impl FasOperation { context, } } - - // TODO: I think this should be removed once we use the service reference API - // The Operation -> Action map uses this format, so map lookups - // need to convert to it. - pub(crate) fn service_operation_name(&self, service_cfg: &ServiceConfiguration) -> String { - let service = self.service(service_cfg); - format!("{}:{}", service, self.operation(&service, service_cfg)) - } - - pub(crate) fn service<'a>(&'a self, service_cfg: &ServiceConfiguration) -> Cow<'a, str> { - service_cfg.rename_service_service_reference(&self.service) - } - - pub(crate) fn operation<'a>( - &'a self, - service: &str, - service_cfg: &ServiceConfiguration, - ) -> Cow<'a, str> { - service_cfg.rename_operation(service, &self.operation) - } } impl<'de> Deserialize<'de> for OperationFasMap { @@ -301,29 +281,6 @@ mod tests { } } - #[test] - fn test_fas_operation_methods() { - use crate::service_configuration::load_service_configuration; - - let service_cfg = load_service_configuration().unwrap(); - - let context = vec![FasContext::new( - "kms:ViaService".to_string(), - vec!["ssm.${region}.amazonaws.com".to_string()], - )]; - - let fas_op = FasOperation::new("Decrypt".to_string(), "kms".to_string(), context); - - // Test service method - assert_eq!(fas_op.service(&service_cfg), "kms"); - - // Test operation method - assert_eq!(fas_op.operation("kms", &service_cfg), "Decrypt"); - - // Test service_operation_name method - assert_eq!(fas_op.service_operation_name(&service_cfg), "kms:Decrypt"); - } - #[test] fn test_operation_fas_map_deserialization() { let json_content = r#"{ diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index 13b7eb9..e72dffa 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -4,55 +4,116 @@ //! action maps with Service Definition Files to generate enriched method calls //! with complete IAM metadata. -use convert_case::{Case, Casing}; -use std::collections::HashSet; +use std::collections::HashMap; use std::sync::Arc; -use super::{Action, Context, EnrichedSdkMethodCall, Explanation, FasInfo, Reason, Resource}; -use crate::enrichment::operation_fas_map::{FasOperation, OperationFasMap, OperationFasMaps}; +use super::{Action, Context, EnrichedSdkMethodCall, Explanation, Reason, Resource}; +use crate::enrichment::operation_fas_map::{OperationFasMap, OperationFasMaps}; use crate::enrichment::service_reference::ServiceReference; -use crate::enrichment::{Condition, ServiceReferenceLoader}; +use crate::enrichment::{Condition, Operation, ServiceReferenceLoader}; use crate::errors::{ExtractorError, Result}; use crate::service_configuration::ServiceConfiguration; use crate::{SdkMethodCall, SdkType}; -/// Represents an operation with its provenance chain (how we reached this operation via FAS expansion) -#[derive(Debug, Clone)] -struct OperationWithFasExpansion { - operation: FasOperation, - /// Chain of operation names leading to this operation (excludes current operation) - fas_chain: Vec, +#[derive(derive_new::new, Clone, Debug)] +struct FasExpansion<'a> { + service_cfg: &'a ServiceConfiguration, + fas_maps: &'a OperationFasMaps, + #[new(default)] + dependency_graph: HashMap, Vec>>, } -// Custom PartialEq and Hash that only consider the operation, not the FAS chain -// This prevents infinite loops in cycles where the same operation appears with different FAS chains -impl PartialEq for OperationWithFasExpansion { - fn eq(&self, other: &Self) -> bool { - self.operation == other.operation +impl<'a> FasExpansion<'a> { + fn operations(&self) -> impl Iterator> { + self.dependency_graph.keys() } -} - -impl Eq for OperationWithFasExpansion {} -impl std::hash::Hash for OperationWithFasExpansion { - fn hash(&self, state: &mut H) { - self.operation.hash(state); + fn complete_provenance_chain(&self, op: Arc) -> Vec> { + let mut result = vec![]; + if let Some(deps) = self.dependency_graph.get(&op) { + for dep in deps { + result.push(Arc::clone(dep)); + } + } + // Add the initial operation + result.push(Arc::clone(&op)); + result } -} - -impl OperationWithFasExpansion { - // Build the FasInfo, which we eventually output, from the operation. - // None if there is no FAS expansion - fn to_fas_info(&self, service_cfg: &ServiceConfiguration) -> Option { - if self.fas_chain.is_empty() { - None - } else { - // Build full chain: fas_chain + current operation - let mut chain = self.fas_chain.clone(); - chain.push(self.operation.service_operation_name(service_cfg)); - Some(FasInfo::new(chain)) + + /// Find OperationFas map for a specific service + fn find_operation_fas_map_for_service( + &self, + service_name: &str, + ) -> Option> { + self.fas_maps + .get( + self.service_cfg + .rename_service_operation_action_map(service_name) + .as_ref(), + ) + .cloned() + } + + fn expand_to_fixed_point(&mut self, initial: Operation) { + let initial_arc = Arc::new(initial); + + self.dependency_graph.insert(Arc::clone(&initial_arc), Vec::new()); // Root has no dependencies + + let mut to_process = vec![Arc::clone(&initial_arc)]; + + while !to_process.is_empty() { + let mut newly_discovered = Vec::new(); + + for current in &to_process { + let service_name = ¤t.service; + + match self.find_operation_fas_map_for_service(service_name) { + Some(operation_fas_map) => { + let service_operation_name = current.service_operation_name(); + log::debug!("Looking up operation {}", service_operation_name); + + if let Some(additional_operations) = operation_fas_map + .fas_operations + .get(&service_operation_name) + { + for additional_op in additional_operations { + let new_op = Arc::new(Operation::from(additional_op.clone())); + + if let Some(existing_deps) = self.dependency_graph.get_mut(&new_op) { + // Operation already exists, add this dependency relationship + existing_deps.push(Arc::clone(current)); + } else { + // New operation + self.dependency_graph.insert(Arc::clone(&new_op), vec![Arc::clone(current)]); + newly_discovered.push(Arc::clone(&new_op)); + } + } + } else { + log::debug!("Did not find {}", service_operation_name); + } + } + None => { + log::debug!("No FAS map found for service: {}", service_name); + } + } + } + + let newly_discovered_count = newly_discovered.len(); + to_process = newly_discovered; + + log::debug!( + "FAS expansion discovered {} new operations", + newly_discovered_count + ); } + + log::debug!( + "FAS expansion completed with {} total operations", + self.dependency_graph.len() + ); } + + } /// ResourceMatcher coordinates OperationAction maps and Service Reference data to generate enriched method calls @@ -60,7 +121,7 @@ impl OperationWithFasExpansion { /// This struct provides the core functionality for the 3-stage enrichment pipeline, /// combining parsed method calls with operation action maps and Service /// Definition Files to produce complete IAM metadata. -#[derive(Debug, Clone)] +#[derive(derive_new::new, Debug, Clone)] pub(crate) struct ResourceMatcher { service_cfg: Arc, fas_maps: OperationFasMaps, @@ -71,20 +132,6 @@ pub(crate) struct ResourceMatcher { const RESOURCE_CUTOFF: usize = 5; impl ResourceMatcher { - /// Create a new ResourceMatcher instance - #[must_use] - pub(crate) fn new( - service_cfg: Arc, - fas_maps: OperationFasMaps, - sdk: SdkType, - ) -> Self { - Self { - service_cfg, - fas_maps, - sdk, - } - } - /// Enrich a parsed method call with OperationAction maps, FAS maps, and Service /// Reference data pub(crate) async fn enrich_method_call<'b>( @@ -115,109 +162,7 @@ impl ResourceMatcher { Ok(enriched_calls) } - /// Find OperationFas map for a specific service - fn find_operation_fas_map_for_service( - &self, - service_name: &str, - ) -> Option> { - self.fas_maps - .get( - self.service_cfg - .rename_service_operation_action_map(service_name) - .as_ref(), - ) - .cloned() - } - - /// Expand FAS operations to a fixed point, avoiding infinite loops from cycles - /// - /// This method safely expands FAS operations by iteratively processing new operations - /// until no more new operations are discovered (fixed point reached). - /// It includes cycle detection to prevent infinite loops. - /// - /// Returns operations with their FAS provenance chains showing how each operation was reached. - fn expand_fas_operations_to_fixed_point( - &self, - initial: FasOperation, - ) -> Result> { - let mut processed_operations = HashSet::::new(); - - // Initial operation has empty FAS chain - let initial = OperationWithFasExpansion { - operation: initial.clone(), - fas_chain: vec![], - }; - processed_operations.insert(initial.clone()); - let mut to_process = vec![initial]; - - while !to_process.is_empty() { - let mut newly_discovered = Vec::new(); - - // Process all operations in the current batch - for current in &to_process { - let service_name = current.operation.service(&self.service_cfg); - match self.find_operation_fas_map_for_service(&service_name) { - Some(operation_fas_map) => { - let service_operation_name = - current.operation.service_operation_name(&self.service_cfg); - log::debug!("Looking up operation {}", service_operation_name); - - if let Some(additional_operations) = operation_fas_map - .fas_operations - .get(&service_operation_name) - { - for additional_op in additional_operations { - // Build new FAS chain: current chain + current operation - let mut new_chain = current.fas_chain.clone(); - new_chain.push( - current.operation.service_operation_name(&self.service_cfg), - ); - - let new_op = OperationWithFasExpansion { - operation: additional_op.clone(), - fas_chain: new_chain, - }; - - // Only add if we haven't seen this operation before - if !processed_operations.contains(&new_op) { - processed_operations.insert(new_op.clone()); - newly_discovered.push(new_op); - } - } - } else { - log::debug!("Did not find {}", service_operation_name); - } - } - None => { - log::debug!("No FAS map found for service: {}", service_name); - } - } - } - - let newly_discovered_count = newly_discovered.len(); - - // Set up next iteration to process only newly discovered operations - to_process = newly_discovered; - - log::debug!( - "FAS expansion discovered {} new operations", - newly_discovered_count - ); - } - - log::debug!( - "FAS expansion completed with {} total operations", - processed_operations.len() - ); - - // Convert HashSet to Vec and sort by service_operation_name for deterministic output - let mut operations_vec: Vec = - processed_operations.into_iter().collect(); - operations_vec.sort_by_key(|op| op.operation.service_operation_name(&self.service_cfg)); - - Ok(operations_vec) - } fn make_condition(context: &[T]) -> Vec { let mut result = vec![]; @@ -243,72 +188,46 @@ impl ResourceMatcher { service_name, parsed_call.name ); + + let mut fas_expansion = FasExpansion::new(&self.service_cfg, &self.fas_maps); // Store the original service name from parsed_call for use in explanations let original_service_name = service_name; - let initial = { - let initial_service_name = self - .service_cfg - .rename_service_service_reference(service_name); - // Determine the initial operation name, with special handling for Python's boto3 method names - let initial_operation_name = if self.sdk == SdkType::Boto3 { - // Try to load service reference and look up the boto3 method mapping - service_reference_loader - .load(&initial_service_name) - .await? - .and_then(|service_ref| { - log::debug!("Looking up method {}", parsed_call.name); - service_ref - .boto3_method_to_operation - .get(&parsed_call.name) - .map(|op| { - log::debug!("got {:?}", op); - op.split(':').nth(1).unwrap_or(op).to_string() - }) - }) - // Fallback to PascalCase conversion if mapping not found - // This should not be reachable, but if for some reason we cannot use the SDF, - // we try converting to PascalCase, knowing that this is flawed in some cases: - // think `AddRoleToDBInstance` (actual name) - // vs. `AddRoleToDbInstance` (converted name) - .unwrap_or_else(|| parsed_call.name.to_case(Case::Pascal)) - } else { - // For non-Boto3 SDKs we use the extracted name as-is - parsed_call.name.clone() - }; - FasOperation::new(initial_operation_name, service_name.to_string(), Vec::new()) - }; + let initial = Operation::from_call(parsed_call, service_name, &self.service_cfg, self.sdk, service_reference_loader).await?; + + log::debug!("Expanded {:?}", initial); // Use fixed-point algorithm to safely expand FAS operations until no new operations are found - let operations = self.expand_fas_operations_to_fixed_point(initial)?; + fas_expansion.expand_to_fixed_point(initial); + log::debug!("to\n{:?}", fas_expansion.dependency_graph); let mut enriched_actions = vec![]; - for op in operations { - let service_name = op.operation.service(&self.service_cfg); + for op in fas_expansion.operations() { + log::debug!("Creating actions for operation {:?}", op.service_operation_name()); + log::debug!(" with context {:?}", op.context()); + // Find the corresponding SDF using the cache - let service_reference = service_reference_loader.load(&service_name).await?; - + let service_reference = service_reference_loader.load(&op.service).await?; match service_reference { None => { + log::debug!("Skipping operation due to no service reference"); continue; } Some(service_reference) => { - log::debug!("Creating actions for {:?}", op.operation); - log::debug!(" with context {:?}", op.operation.context); - log::debug!(" FAS chain: {:?}", op.fas_chain); if let Some(operation_to_authorized_actions) = &service_reference.operation_to_authorized_actions { log::debug!( "Looking up {}", - &op.operation.service_operation_name(&self.service_cfg) + &op.service_operation_name() ); if let Some(operation_to_authorized_action) = operation_to_authorized_actions - .get(&op.operation.service_operation_name(&self.service_cfg)) + .get(&op.service_operation_name()) { + log::debug!("Found operation action map for {:?}", operation_to_authorized_action.name); for action in &operation_to_authorized_action.authorized_actions { let enriched_resources = self .find_resources_for_action_in_service_reference( @@ -323,7 +242,7 @@ impl ResourceMatcher { }; // Combine conditions from FAS operation context and AuthorizedAction context - let mut conditions = Self::make_condition(&op.operation.context); + let mut conditions = Self::make_condition(op.context()); // Add conditions from AuthorizedAction context if present if let Some(auth_context) = &action.context { @@ -331,14 +250,12 @@ impl ResourceMatcher { auth_context, ))); } + + let ops = fas_expansion.complete_provenance_chain(Arc::clone(op)); // Create explanation for this action let explanation = Explanation { - reasons: vec![Reason::new( - parsed_call, - original_service_name, - op.to_fas_info(&self.service_cfg), - )], + reasons: vec![Reason::new(ops)], }; let enriched_action = Action::new( action.name.clone(), @@ -346,29 +263,25 @@ impl ResourceMatcher { conditions, explanation, ); - + log::debug!("Created action: {:?}", enriched_action); enriched_actions.push(enriched_action); } } else { // Fallback: operation not found in operation action map, create basic action // This ensures we don't filter out operations, only ADD additional ones from the map if let Some(a) = self.create_fallback_action( - &parsed_call.name, - &service_reference, - parsed_call, - original_service_name, + op, &fas_expansion, &service_reference, )? { + log::debug!("Created fallback action due to no entry in operation action map: {:?}", a); enriched_actions.push(a); } } } else { // Fallback: operation action map does not exist, create basic action if let Some(a) = self.create_fallback_action( - &parsed_call.name, - &service_reference, - parsed_call, - original_service_name, + op, &fas_expansion, &service_reference, )? { + log::debug!("Created fallback action due to no operation action map for service: {:?}", a); enriched_actions.push(a); } } @@ -394,32 +307,28 @@ impl ResourceMatcher { /// corresponding resources in the SDF. fn create_fallback_action( &self, - method_name: &str, + op: &Arc, + fas_expansion_result: &FasExpansion, service_reference: &ServiceReference, - parsed_call: &SdkMethodCall, - original_service_name: &str, ) -> Result> { - let renamed_service = self - .service_cfg - .rename_service_service_reference(&service_reference.service_name); - let renamed_action = &method_name.to_case(Case::Pascal); - let action_name = format!("{}:{}", renamed_service, renamed_action); + let action_name = op.service_operation_name(); // Sanity check that the action exists in the SDF if !service_reference .actions - .contains_key(renamed_action.as_str()) + .contains_key(&op.name) { + log::debug!("Not creating fallback action: service reference doesn't contain key: {:?}", action_name); return Ok(None); } // Look up the action in the Service Reference to find associated resources let resources = self.find_resources_for_action_in_service_reference(&action_name, service_reference)?; - + // Create explanation for fallback action let explanation = Explanation { - reasons: vec![Reason::new(parsed_call, original_service_name, None)], + reasons: vec![Reason::new(fas_expansion_result.complete_provenance_chain(Arc::clone(op)))], }; Ok(Some(Action::new( @@ -488,7 +397,7 @@ mod tests { use std::collections::HashMap; use super::*; - use crate::enrichment::mock_remote_service_reference; + use crate::enrichment::{OperationSource, mock_remote_service_reference}; use crate::enrichment::operation_fas_map::{FasContext, FasOperation, OperationFasMap}; fn create_test_parsed_method_call() -> SdkMethodCall { @@ -1086,29 +995,23 @@ mod tests { }), ); - let matcher = ResourceMatcher::new(service_cfg.clone(), fas_maps, SdkType::Other); - // Test expansion starting from GetObject let initial = - FasOperation::new("GetObject".to_string(), "service-a".to_string(), Vec::new()); + Operation::new("GetObject".to_string(), "service-a".to_string(), OperationSource::Provided); - let result = matcher.expand_fas_operations_to_fixed_point(initial); - assert!( - result.is_ok(), - "Fixed-point expansion should succeed for non-cyclic operations" - ); + let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); + fas_expansion.expand_to_fixed_point(initial); - let operations = result.unwrap(); assert_eq!( - operations.len(), + fas_expansion.dependency_graph.len(), 3, "Should have exactly 3 operations: GetObject, Decrypt, Log" ); // Verify all expected operations are present - let operation_names: std::collections::HashSet = operations - .iter() - .map(|op| op.operation.service_operation_name(&service_cfg)) + let operation_names: std::collections::HashSet = fas_expansion + .operations() + .map(|op| op.service_operation_name()) .collect(); assert!(operation_names.contains("service-a:GetObject")); @@ -1171,29 +1074,22 @@ mod tests { }), ); - let matcher = ResourceMatcher::new(service_cfg.clone(), fas_maps, SdkType::Other); + let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); // Test expansion starting from GetObject - should detect cycle and terminate let initial = - FasOperation::new("GetObject".to_string(), "service-a".to_string(), Vec::new()); + Operation::new("GetObject".to_string(), "service-a".to_string(), OperationSource::Provided); - let result = matcher.expand_fas_operations_to_fixed_point(initial); - - assert!( - result.is_ok(), - "Fixed-point expansion should handle cycles gracefully" - ); - - let operations = result.unwrap(); + fas_expansion.expand_to_fixed_point(initial); // Debug: print what operations we actually got - let operation_names: std::collections::HashSet = operations - .iter() - .map(|op| op.operation.service_operation_name(&service_cfg)) + let operation_names: std::collections::HashSet = fas_expansion + .operations() + .map(|op| op.service_operation_name()) .collect(); // 3 operations, note that GetObject occurs twice, once with and once without context - assert!(operations.len() == 3, "Should have 3 operations"); + assert!(fas_expansion.dependency_graph.len() == 3, "Should have 3 operations"); // Verify expected operations are present assert!(operation_names.contains("service-a:GetObject")); @@ -1243,24 +1139,16 @@ mod tests { ); } - let matcher = ResourceMatcher::new(service_cfg.clone(), fas_maps, SdkType::Other); - + let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); + let initial = - FasOperation::new("GetObject".to_string(), "service-a".to_string(), Vec::new()); - - let result = matcher.expand_fas_operations_to_fixed_point(initial); - - // Should succeed and return operations for the cycle - assert!( - result.is_ok(), - "Should handle complex cycles without hitting max iterations" - ); + Operation::new("GetObject".to_string(), "service-a".to_string(), OperationSource::Provided); - let operations = result.unwrap(); + fas_expansion.expand_to_fixed_point(initial); // We have 5 operations, note that GetObject occurs twice, once with context and the initial one without assert!( - operations.len() == 5, + fas_expansion.dependency_graph.len() == 5, "Should have 5 operations in the cycle" ); } @@ -1270,31 +1158,31 @@ mod tests { use std::collections::HashMap; let service_cfg = create_empty_service_config(); + let fas_maps = HashMap::new(); - let matcher = ResourceMatcher::new(service_cfg.clone(), HashMap::new(), SdkType::Other); + let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); - let initial = FasOperation::new( + let initial = Operation::new( "NonExistentOperation".to_string(), "non-existent-service".to_string(), - Vec::new(), + OperationSource::Provided, ); - let result = matcher.expand_fas_operations_to_fixed_point(initial.clone()); - assert!(result.is_ok(), "Should succeed even with no FAS maps"); - - let operations = result.unwrap(); + fas_expansion.expand_to_fixed_point(initial.clone()); assert_eq!( - operations.len(), + fas_expansion.dependency_graph.len(), 1, "Should contain only the initial operation" ); + + let operations: Vec<_> = fas_expansion.operations().collect(); assert_eq!( - operations[0].operation, initial, + **operations[0], initial, "Should contain the initial operation" ); assert!( - operations[0].fas_chain.is_empty(), - "Initial operation should have empty FAS chain" + !matches!(operations[0].source, OperationSource::Fas(_)), + "Initial operation should not be from FAS expansion" ); println!("✓ Test passed: Handles case with no additional FAS operations"); @@ -1328,36 +1216,32 @@ mod tests { }), ); - let matcher = ResourceMatcher::new(service_cfg.clone(), fas_maps, SdkType::Other); + let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); // Test expansion starting from GetObject with empty context - let initial = FasOperation::new( + let initial = Operation::new( "GetObject".to_string(), "service-a".to_string(), - Vec::new(), // Empty context + OperationSource::Provided, ); - let result = matcher.expand_fas_operations_to_fixed_point(initial.clone()); - assert!( - result.is_ok(), - "Self-cycle with empty context should be handled gracefully" - ); - - let operations = result.unwrap(); + fas_expansion.expand_to_fixed_point(initial.clone()); // Should have exactly 1 operation since A->A with same context creates no new operations assert_eq!( - operations.len(), + fas_expansion.dependency_graph.len(), 1, "Self-cycle with identical context should result in exactly 1 operation" ); + + let operations: Vec<_> = fas_expansion.operations().collect(); assert_eq!( - operations[0].operation, initial, + **operations[0], initial, "Should contain the initial operation" ); assert!( - operations[0].fas_chain.is_empty(), - "Initial operation should have empty FAS chain" + !matches!(operations[0].source, OperationSource::Fas(_)), + "Initial operation should not be from FAS expansion" ); println!("✓ Test passed: Self-cycle with empty context handled correctly"); @@ -1451,3 +1335,4 @@ mod tests { assert_eq!(enriched_calls[0].actions[0].name, "rds:ModifyDBCluster"); } } + diff --git a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs index 8484520..d635a50 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs @@ -462,8 +462,7 @@ mod tests { // Verify PascalCase field names assert!(json.contains("\"Parameters\"")); assert!(json.contains("\"ReturnType\"")); - assert!(json.contains("\"StartPosition\"")); - assert!(json.contains("\"EndPosition\"")); + assert!(json.contains("\"Location\"")); assert!(json.contains("\"Receiver\"")); } } diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs index 16ddd34..dc3ae49 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs @@ -903,7 +903,8 @@ mod tests { #[test] fn test_generate_policies_with_explanations() { - use crate::enrichment::{Explanation, Reason}; + use crate::enrichment::{Explanation, Reason, Operation, OperationSource}; + use std::sync::Arc; let engine = create_test_engine(); let sdk_call = create_test_sdk_call(); @@ -922,7 +923,11 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason::new(&sdk_call, "s3", None)], + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "get_object".to_string(), + "s3".to_string(), + OperationSource::Provided, + ))])], }, )], sdk_method_call: &sdk_call, @@ -940,7 +945,7 @@ mod tests { .and_then(|explanations| explanations.get("s3:GetObject")) { assert_eq!(explanation.reasons.len(), 1); - assert!(explanation.reasons[0].fas.is_none()); + assert_eq!(explanation.reasons[0].operations.len(), 1); } else { panic!("Must have an explanation for s3:GetObject"); } @@ -948,7 +953,8 @@ mod tests { #[test] fn test_explanation_deduplication() { - use crate::enrichment::{Explanation, Reason}; + use crate::enrichment::{Explanation, Reason, Operation, OperationSource}; + use std::sync::Arc; let engine = create_test_engine(); let sdk_call1 = create_test_sdk_call(); @@ -972,7 +978,11 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason::new(&sdk_call1, "s3", None)], + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "get_object".to_string(), + "s3".to_string(), + OperationSource::Provided, + ))])], }, )], sdk_method_call: &sdk_call1, @@ -991,7 +1001,11 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason::new(&sdk_call1, "s3", None)], + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "get_object".to_string(), + "s3".to_string(), + OperationSource::Provided, + ))])], }, )], sdk_method_call: &sdk_call2, @@ -1019,7 +1033,8 @@ mod tests { #[test] fn test_explanation_with_fas_expansion() { - use crate::enrichment::{Explanation, Reason}; + use crate::enrichment::{Explanation, Reason, Operation, OperationSource, operation_fas_map::FasContext}; + use std::sync::Arc; let engine = create_test_engine(); let sdk_call = create_test_sdk_call(); @@ -1039,7 +1054,11 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason::new(&sdk_call, "s3", None)], + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "get_object".to_string(), + "s3".to_string(), + OperationSource::Provided, + ))])], }, ), Action::new( @@ -1052,14 +1071,14 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason::new( - &sdk_call, - "s3", - Some(crate::enrichment::FasInfo::new(vec![ - "s3:GetObject".to_string(), - "kms:Decrypt".to_string(), - ])), - )], + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "Decrypt".to_string(), + "kms".to_string(), + OperationSource::Fas(vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]), + ))])], }, ), ], @@ -1079,21 +1098,26 @@ mod tests { .get("kms:Decrypt") .expect("Should have kms:Decrypt explanation"); assert_eq!(kms_explanation.reasons.len(), 1); - assert!(kms_explanation.reasons[0].fas.is_some()); - let fas_info = kms_explanation.reasons[0].fas.as_ref().unwrap(); - assert_eq!(fas_info.expansion.len(), 2); - assert!(fas_info.expansion.contains(&"s3:GetObject".to_string())); - assert!(fas_info.expansion.contains(&"kms:Decrypt".to_string())); - // Verify the explanation URL is set - assert_eq!( - fas_info.explanation, - "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html" - ); + assert_eq!(kms_explanation.reasons[0].operations.len(), 1); + + // Check that the operation has FAS context + let operation = &kms_explanation.reasons[0].operations[0]; + assert_eq!(operation.name, "Decrypt"); + assert_eq!(operation.service, "kms"); + match &operation.source { + OperationSource::Fas(context) => { + assert_eq!(context.len(), 1); + assert_eq!(context[0].key, "kms:ViaService"); + assert_eq!(context[0].values, vec!["s3.us-east-1.amazonaws.com"]); + } + _ => panic!("Expected FAS operation source"), + } } #[test] fn test_explanation_with_possible_false_positive() { - use crate::enrichment::{Explanation, Reason}; + use crate::enrichment::{Explanation, Reason, Operation, OperationSource}; + use std::sync::Arc; let engine = create_test_engine(); let sdk_call = SdkMethodCall { @@ -1116,7 +1140,11 @@ mod tests { )], vec![], Explanation { - reasons: vec![Reason::new(&sdk_call, "s3", None)], + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "get_object".to_string(), + "s3".to_string(), + OperationSource::Provided, + ))])], }, )], sdk_method_call: &sdk_call, @@ -1127,3 +1155,5 @@ mod tests { assert_eq!(result.explanations.as_ref().unwrap().len(), 1); } } + + diff --git a/iam-policy-autopilot-policy-generation/src/service_configuration.rs b/iam-policy-autopilot-policy-generation/src/service_configuration.rs index d603244..2cc6337 100644 --- a/iam-policy-autopilot-policy-generation/src/service_configuration.rs +++ b/iam-policy-autopilot-policy-generation/src/service_configuration.rs @@ -14,6 +14,8 @@ use std::{ /// Operation rename configuration #[derive(Clone, Debug, Deserialize)] +// TODO: remove +#[allow(dead_code)] pub(crate) struct OperationRename { /// Target service name pub(crate) service: String, @@ -31,6 +33,8 @@ pub(crate) struct ServiceConfiguration { pub(crate) rename_services_service_reference: HashMap, /// Smithy to Botocore model: service renames pub(crate) smithy_botocore_service_name_mapping: HashMap, + // TODO: remove + #[allow(dead_code)] /// Operation renames pub(crate) rename_operations: HashMap, /// Resource overrides @@ -55,6 +59,8 @@ impl ServiceConfiguration { } } + // TODO: remove + #[allow(dead_code)] pub(crate) fn rename_operation<'a>(&self, service: &str, original: &'a str) -> Cow<'a, str> { let tmp = format!("{}:{}", service, original); match self.rename_operations.get(&tmp) { From 4cf972eba6d3e256d37d82ff32c47e10691372bd Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Tue, 13 Jan 2026 12:27:02 +0000 Subject: [PATCH 05/11] Output explanation documentation --- .../src/api/model.rs | 6 +- .../src/enrichment/mod.rs | 64 +++++++++--- .../src/enrichment/resource_matcher.rs | 97 +++++++++---------- .../src/policy_generation/engine.rs | 31 +++--- 4 files changed, 118 insertions(+), 80 deletions(-) diff --git a/iam-policy-autopilot-policy-generation/src/api/model.rs b/iam-policy-autopilot-policy-generation/src/api/model.rs index 2dc4e75..9c10baa 100644 --- a/iam-policy-autopilot-policy-generation/src/api/model.rs +++ b/iam-policy-autopilot-policy-generation/src/api/model.rs @@ -1,8 +1,8 @@ //! Defined model for API use serde::Serialize; -use crate::{enrichment::Explanation, policy_generation::PolicyWithMetadata}; -use std::{collections::BTreeMap, path::PathBuf}; +use crate::{enrichment::Explanations, policy_generation::PolicyWithMetadata}; +use std::path::PathBuf; /// Configuration for generate_policies API #[derive(Debug, Clone)] @@ -29,7 +29,7 @@ pub struct GeneratePoliciesResult { pub policies: Vec, /// Explanations for why actions were added (if requested) #[serde(skip_serializing_if = "Option::is_none")] - pub explanations: Option>, + pub explanations: Option, } /// Service hints for filtering SDK method calls diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index b1301ae..9b0921f 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -8,7 +8,7 @@ //! that represent method calls enriched with IAM metadata from operation //! action maps and Service Definition Files. -use std::{collections::HashSet, sync::Arc}; +use std::{collections::{BTreeMap, HashSet}, sync::Arc}; use crate::{SdkMethodCall, SdkType, enrichment::operation_fas_map::{FasContext, FasOperation}, extraction::SdkMethodCallMetadata, service_configuration::ServiceConfiguration}; use convert_case::{Case, Casing}; @@ -25,9 +25,6 @@ pub(crate) use operation_fas_map::load_operation_fas_map; pub(crate) use resource_matcher::ResourceMatcher; pub(crate) use service_reference::RemoteServiceReferenceLoader as ServiceReferenceLoader; -const FAS_URL: &str = - "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html"; - /// Represents the reason why an action was added to a policy #[derive(derive_new::new, Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] #[serde(rename_all = "PascalCase")] @@ -39,22 +36,23 @@ pub struct Reason { #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] #[serde(rename_all = "PascalCase")] pub struct Operation { - /// Name of the operation - pub name: String, /// Name of the service pub service: String, + /// Name of the operation + pub name: String, /// Source of the operation, pub source: OperationSource, /// Disallow struct construction, need to use Self::from_call or Operation::from(FasOperation) + #[serde(skip)] _private: () } impl Operation { #[cfg(test)] /// Convenience constructor for tests - pub(crate) fn new(name: String, service: String, source: OperationSource) -> Self { + pub(crate) fn new(service: String, name: String, source: OperationSource) -> Self { Self { - name, service, source, _private: () + service, name, source, _private: () } } @@ -105,14 +103,14 @@ impl Operation { Ok(match &call.metadata { None => Self { - name, service, + name, source: OperationSource::Provided, _private: (), }, Some(metadata) => Self { - name, service, + name, source: OperationSource::Extracted(metadata.clone()), _private: (), }, @@ -123,8 +121,8 @@ impl Operation { impl From for Operation { fn from(fas_op: FasOperation) -> Self { Self { - name: fas_op.operation, service: fas_op.service, + name: fas_op.operation, source: OperationSource::Fas(fas_op.context), _private: (), } @@ -167,9 +165,51 @@ impl Serialize for OperationSource { } } +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "PascalCase")] +pub struct Explanations { + pub explanation_for_action: BTreeMap, + pub documentation: BTreeMap<&'static str, Documentation>, +} + +impl Explanations { + const FAS_URL: &str = + "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html"; + + pub(crate) fn new(explanations: BTreeMap) -> Self { + let mut documentation: Vec<(&'static str, Documentation)> = vec![]; + for (_, explanation) in &explanations { + for reason in &explanation.reasons { + for op in &reason.operations { + match op.source { + OperationSource::Extracted(_) | OperationSource::Provided => (), + OperationSource::Fas(_) => documentation.push(("FAS", Documentation { + plain: "The explanation contains an operation added due to Forward Access Sessions.", + url: Self::FAS_URL, + })) + } + } + } + } + Self { + explanation_for_action: explanations, + documentation: BTreeMap::from_iter(documentation.into_iter()) + } + } +} + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "PascalCase")] +pub struct Documentation { + plain: &'static str, + #[serde(rename = "URL")] + url: &'static str, +} + /// Represents an explanation for why an action was added to a policy #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema, Default)] -#[serde(rename_all = "PascalCase")] +// Don't print the `"Reasons":` key, treat this as just a JSON array. +#[serde(transparent)] pub struct Explanation { /// The reasons this action was added (can have multiple reasons for the same action) pub reasons: Vec, diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index e72dffa..a62154e 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -15,31 +15,12 @@ use crate::errors::{ExtractorError, Result}; use crate::service_configuration::ServiceConfiguration; use crate::{SdkMethodCall, SdkType}; -#[derive(derive_new::new, Clone, Debug)] -struct FasExpansion<'a> { +struct FasExpansionBuilder<'a> { service_cfg: &'a ServiceConfiguration, fas_maps: &'a OperationFasMaps, - #[new(default)] - dependency_graph: HashMap, Vec>>, } -impl<'a> FasExpansion<'a> { - fn operations(&self) -> impl Iterator> { - self.dependency_graph.keys() - } - - fn complete_provenance_chain(&self, op: Arc) -> Vec> { - let mut result = vec![]; - if let Some(deps) = self.dependency_graph.get(&op) { - for dep in deps { - result.push(Arc::clone(dep)); - } - } - // Add the initial operation - result.push(Arc::clone(&op)); - result - } - +impl<'a> FasExpansionBuilder<'a> { /// Find OperationFas map for a specific service fn find_operation_fas_map_for_service( &self, @@ -53,11 +34,12 @@ impl<'a> FasExpansion<'a> { ) .cloned() } - - fn expand_to_fixed_point(&mut self, initial: Operation) { + + fn expand_to_fixed_point(&self, initial: Operation) -> FasExpansion { + let mut dependency_graph: HashMap, Vec>> = HashMap::new(); let initial_arc = Arc::new(initial); - self.dependency_graph.insert(Arc::clone(&initial_arc), Vec::new()); // Root has no dependencies + dependency_graph.insert(Arc::clone(&initial_arc), Vec::new()); // Root has no dependencies let mut to_process = vec![Arc::clone(&initial_arc)]; @@ -79,12 +61,12 @@ impl<'a> FasExpansion<'a> { for additional_op in additional_operations { let new_op = Arc::new(Operation::from(additional_op.clone())); - if let Some(existing_deps) = self.dependency_graph.get_mut(&new_op) { + if let Some(existing_deps) = dependency_graph.get_mut(&new_op) { // Operation already exists, add this dependency relationship existing_deps.push(Arc::clone(current)); } else { // New operation - self.dependency_graph.insert(Arc::clone(&new_op), vec![Arc::clone(current)]); + dependency_graph.insert(Arc::clone(&new_op), vec![Arc::clone(current)]); newly_discovered.push(Arc::clone(&new_op)); } } @@ -109,11 +91,38 @@ impl<'a> FasExpansion<'a> { log::debug!( "FAS expansion completed with {} total operations", - self.dependency_graph.len() + dependency_graph.len() ); + FasExpansion { dependency_graph } + } + +} + +#[derive(Clone, Debug)] +struct FasExpansion { + dependency_graph: HashMap, Vec>>, +} + +impl FasExpansion { + fn builder<'a>(service_cfg: &'a ServiceConfiguration, fas_maps: &'a OperationFasMaps) -> FasExpansionBuilder<'a> { + FasExpansionBuilder { service_cfg, fas_maps } } + fn operations(&self) -> impl Iterator> { + self.dependency_graph.keys() + } + fn complete_provenance_chain(&self, op: Arc) -> Vec> { + let mut result = vec![]; + if let Some(deps) = self.dependency_graph.get(&op) { + for dep in deps { + result.push(Arc::clone(dep)); + } + } + // Add the initial operation + result.push(Arc::clone(&op)); + result + } } /// ResourceMatcher coordinates OperationAction maps and Service Reference data to generate enriched method calls @@ -189,8 +198,6 @@ impl ResourceMatcher { parsed_call.name ); - let mut fas_expansion = FasExpansion::new(&self.service_cfg, &self.fas_maps); - // Store the original service name from parsed_call for use in explanations let original_service_name = service_name; @@ -198,7 +205,8 @@ impl ResourceMatcher { log::debug!("Expanded {:?}", initial); // Use fixed-point algorithm to safely expand FAS operations until no new operations are found - fas_expansion.expand_to_fixed_point(initial); + let fas_expansion = FasExpansion::builder(&self.service_cfg, &self.fas_maps).expand_to_fixed_point(initial); + log::debug!("to\n{:?}", fas_expansion.dependency_graph); let mut enriched_actions = vec![]; @@ -997,10 +1005,9 @@ mod tests { // Test expansion starting from GetObject let initial = - Operation::new("GetObject".to_string(), "service-a".to_string(), OperationSource::Provided); + Operation::new("service-a".to_string(), "GetObject".to_string(), OperationSource::Provided); - let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); - fas_expansion.expand_to_fixed_point(initial); + let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); assert_eq!( fas_expansion.dependency_graph.len(), @@ -1074,13 +1081,11 @@ mod tests { }), ); - let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); - // Test expansion starting from GetObject - should detect cycle and terminate let initial = - Operation::new("GetObject".to_string(), "service-a".to_string(), OperationSource::Provided); + Operation::new("service-a".to_string(), "GetObject".to_string(), OperationSource::Provided); - fas_expansion.expand_to_fixed_point(initial); + let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); // Debug: print what operations we actually got let operation_names: std::collections::HashSet = fas_expansion @@ -1139,12 +1144,10 @@ mod tests { ); } - let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); - let initial = - Operation::new("GetObject".to_string(), "service-a".to_string(), OperationSource::Provided); + Operation::new("service-a".to_string(),"GetObject".to_string() , OperationSource::Provided); - fas_expansion.expand_to_fixed_point(initial); + let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); // We have 5 operations, note that GetObject occurs twice, once with context and the initial one without assert!( @@ -1160,15 +1163,13 @@ mod tests { let service_cfg = create_empty_service_config(); let fas_maps = HashMap::new(); - let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); - let initial = Operation::new( - "NonExistentOperation".to_string(), "non-existent-service".to_string(), + "NonExistentOperation".to_string(), OperationSource::Provided, ); - fas_expansion.expand_to_fixed_point(initial.clone()); + let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial.clone()); assert_eq!( fas_expansion.dependency_graph.len(), 1, @@ -1216,16 +1217,14 @@ mod tests { }), ); - let mut fas_expansion = FasExpansion::new(&service_cfg, &fas_maps); - // Test expansion starting from GetObject with empty context let initial = Operation::new( - "GetObject".to_string(), "service-a".to_string(), + "GetObject".to_string(), OperationSource::Provided, ); - fas_expansion.expand_to_fixed_point(initial.clone()); + let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial.clone()); // Should have exactly 1 operation since A->A with same context creates no new operations assert_eq!( diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs index dc3ae49..b303e9b 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs @@ -10,7 +10,7 @@ use super::merge::{PolicyMerger, PolicyMergerConfig}; use super::utils::{ArnParser, ConditionValueProcessor}; use super::{IamPolicy, Statement}; use crate::api::model::GeneratePoliciesResult; -use crate::enrichment::{Action, Condition, EnrichedSdkMethodCall}; +use crate::enrichment::{Action, Condition, EnrichedSdkMethodCall, Explanations}; use crate::errors::{ExtractorError, Result}; use crate::policy_generation::{PolicyType, PolicyWithMetadata}; use crate::Explanation; @@ -280,9 +280,7 @@ impl<'a> Engine<'a> { } } -fn extract_explanations( - enriched_calls: &[EnrichedSdkMethodCall<'_>], -) -> BTreeMap { +fn extract_explanations(enriched_calls: &[EnrichedSdkMethodCall<'_>]) -> Explanations { let mut explanations: BTreeMap = BTreeMap::new(); // Collect and merge explanations for each action name @@ -296,8 +294,8 @@ fn extract_explanations( .or_insert_with(|| action.explanation.clone()); } } - - explanations + + Explanations::new(explanations) } #[cfg(test)] @@ -924,8 +922,8 @@ mod tests { vec![], Explanation { reasons: vec![Reason::new(vec![Arc::new(Operation::new( - "get_object".to_string(), "s3".to_string(), + "get_object".to_string(), OperationSource::Provided, ))])], }, @@ -942,7 +940,7 @@ mod tests { if let Some(explanation) = result .explanations .as_ref() - .and_then(|explanations| explanations.get("s3:GetObject")) + .and_then(|explanations| explanations.explanation_for_action.get("s3:GetObject")) { assert_eq!(explanation.reasons.len(), 1); assert_eq!(explanation.reasons[0].operations.len(), 1); @@ -979,8 +977,8 @@ mod tests { vec![], Explanation { reasons: vec![Reason::new(vec![Arc::new(Operation::new( - "get_object".to_string(), "s3".to_string(), + "get_object".to_string(), OperationSource::Provided, ))])], }, @@ -1002,8 +1000,8 @@ mod tests { vec![], Explanation { reasons: vec![Reason::new(vec![Arc::new(Operation::new( - "get_object".to_string(), "s3".to_string(), + "get_object".to_string(), OperationSource::Provided, ))])], }, @@ -1019,7 +1017,7 @@ mod tests { if let Some(explanation) = result .explanations .as_ref() - .and_then(|explanations| explanations.get("s3:GetObject")) + .and_then(|explanations| explanations.explanation_for_action.get("s3:GetObject")) { assert_eq!( explanation.reasons.len(), @@ -1055,8 +1053,8 @@ mod tests { vec![], Explanation { reasons: vec![Reason::new(vec![Arc::new(Operation::new( - "get_object".to_string(), "s3".to_string(), + "get_object".to_string(), OperationSource::Provided, ))])], }, @@ -1072,8 +1070,8 @@ mod tests { vec![], Explanation { reasons: vec![Reason::new(vec![Arc::new(Operation::new( - "Decrypt".to_string(), "kms".to_string(), + "Decrypt".to_string(), OperationSource::Fas(vec![FasContext::new( "kms:ViaService".to_string(), vec!["s3.us-east-1.amazonaws.com".to_string()], @@ -1088,13 +1086,14 @@ mod tests { let result = engine.generate_policies(&[enriched_call]).unwrap(); // Verify explanations include FAS expansion - assert_eq!(result.explanations.as_ref().unwrap().len(), 2); + assert_eq!(result.explanations.as_ref().unwrap().explanation_for_action.len(), 2); // Check the FAS-expanded action let kms_explanation = result .explanations .as_ref() .unwrap() + .explanation_for_action .get("kms:Decrypt") .expect("Should have kms:Decrypt explanation"); assert_eq!(kms_explanation.reasons.len(), 1); @@ -1141,8 +1140,8 @@ mod tests { vec![], Explanation { reasons: vec![Reason::new(vec![Arc::new(Operation::new( - "get_object".to_string(), "s3".to_string(), + "get_object".to_string(), OperationSource::Provided, ))])], }, @@ -1152,7 +1151,7 @@ mod tests { let result = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(result.explanations.as_ref().unwrap().len(), 1); + assert_eq!(result.explanations.as_ref().unwrap().explanation_for_action.len(), 1); } } From f59160ff5f2ac8f50a5e68b74d1890d02545dc6b Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Tue, 13 Jan 2026 12:27:19 +0000 Subject: [PATCH 06/11] cargo fmt --- .../src/enrichment/mod.rs | 98 +++++++++--- .../src/enrichment/resource_matcher.rs | 149 +++++++++++------- .../src/policy_generation/engine.rs | 36 +++-- 3 files changed, 191 insertions(+), 92 deletions(-) diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index 9b0921f..c666599 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -8,9 +8,17 @@ //! that represent method calls enriched with IAM metadata from operation //! action maps and Service Definition Files. -use std::{collections::{BTreeMap, HashSet}, sync::Arc}; - -use crate::{SdkMethodCall, SdkType, enrichment::operation_fas_map::{FasContext, FasOperation}, extraction::SdkMethodCallMetadata, service_configuration::ServiceConfiguration}; +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, +}; + +use crate::{ + enrichment::operation_fas_map::{FasContext, FasOperation}, + extraction::SdkMethodCallMetadata, + service_configuration::ServiceConfiguration, + SdkMethodCall, SdkType, +}; use convert_case::{Case, Casing}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -44,7 +52,7 @@ pub struct Operation { pub source: OperationSource, /// Disallow struct construction, need to use Self::from_call or Operation::from(FasOperation) #[serde(skip)] - _private: () + _private: (), } impl Operation { @@ -52,7 +60,10 @@ impl Operation { /// Convenience constructor for tests pub(crate) fn new(service: String, name: String, source: OperationSource) -> Self { Self { - service, name, source, _private: () + service, + name, + source, + _private: (), } } @@ -66,7 +77,7 @@ impl Operation { _ => &[], } } - + pub(crate) async fn from_call( call: &SdkMethodCall, original_service_name: &str, @@ -74,7 +85,9 @@ impl Operation { sdk: SdkType, service_reference_loader: &ServiceReferenceLoader, ) -> crate::errors::Result { - let service = service_cfg.rename_service_service_reference(original_service_name).to_string(); + let service = service_cfg + .rename_service_service_reference(original_service_name) + .to_string(); let name = if sdk == SdkType::Boto3 { // Try to load service reference and look up the boto3 method mapping service_reference_loader @@ -130,7 +143,10 @@ impl From for Operation { } /// Custom serializer for extracted metadata that flattens the structure -fn serialize_extracted_metadata(metadata: &SdkMethodCallMetadata, serializer: S) -> Result +fn serialize_extracted_metadata( + metadata: &SdkMethodCallMetadata, + serializer: S, +) -> Result where S: serde::Serializer, { @@ -158,7 +174,9 @@ impl Serialize for OperationSource { S: serde::Serializer, { match self { - OperationSource::Extracted(metadata) => serialize_extracted_metadata(metadata, serializer), + OperationSource::Extracted(metadata) => { + serialize_extracted_metadata(metadata, serializer) + } OperationSource::Provided => serializer.serialize_str("Provided"), OperationSource::Fas(_) => serializer.serialize_str("FAS"), } @@ -193,7 +211,7 @@ impl Explanations { } Self { explanation_for_action: explanations, - documentation: BTreeMap::from_iter(documentation.into_iter()) + documentation: BTreeMap::from_iter(documentation.into_iter()), } } } @@ -580,7 +598,10 @@ pub(crate) mod mock_remote_service_reference { #[cfg(test)] mod location_tests { use super::*; - use crate::{Location, enrichment::mock_remote_service_reference::setup_mock_server_with_loader_without_operation_to_action_mapping, service_configuration::load_service_configuration}; + use crate::{ + enrichment::mock_remote_service_reference::setup_mock_server_with_loader_without_operation_to_action_mapping, + service_configuration::load_service_configuration, Location, + }; use std::path::PathBuf; #[test] @@ -630,10 +651,21 @@ mod location_tests { #[tokio::test] async fn test_reason_extracted_with_location() { let service_cfg = load_service_configuration().unwrap(); - let (_, service_reference_loader) = setup_mock_server_with_loader_without_operation_to_action_mapping().await; + let (_, service_reference_loader) = + setup_mock_server_with_loader_without_operation_to_action_mapping().await; let call = mock_sdk_method_call(); - let reason = Reason::new(vec![Arc::new(Operation::from_call(&call, "s3", &service_cfg, SdkType::Boto3, &service_reference_loader).await.unwrap())]); + let reason = Reason::new(vec![Arc::new( + Operation::from_call( + &call, + "s3", + &service_cfg, + SdkType::Boto3, + &service_reference_loader, + ) + .await + .unwrap(), + )]); assert_eq!(reason.operations[0].name, "GetObject"); assert_eq!(reason.operations[0].service, "s3"); @@ -659,13 +691,15 @@ mod location_tests { location: Location::new(PathBuf::from("iam-policy-autopilot-cli/tests/resources/test_example.py"), (19, 5), (22, 5)), receiver: Some("dynamodb".to_string()), }; - + let source = OperationSource::Extracted(metadata); let json = serde_json::to_string(&source).unwrap(); - + // Verify the custom serialization format assert!(json.contains("\"Expr\":\"dynamodb.get_item(\\n TableName='my-table',\\n Key={'id': {'S': '123'}}\\n )\"")); - assert!(json.contains("\"Location\":\"iam-policy-autopilot-cli/tests/resources/test_example.py:19.5-22.5\"")); + assert!(json.contains( + "\"Location\":\"iam-policy-autopilot-cli/tests/resources/test_example.py:19.5-22.5\"" + )); // Should not contain nested "Source" key assert!(!json.contains("\"Source\"")); } @@ -674,7 +708,7 @@ mod location_tests { fn test_operation_source_provided_serialization() { let source = OperationSource::Provided; let json = serde_json::to_string(&source).unwrap(); - + // Verify the custom serialization format assert_eq!(json, "\"Provided\""); } @@ -682,14 +716,14 @@ mod location_tests { #[test] fn test_operation_source_fas_serialization() { use crate::enrichment::operation_fas_map::FasContext; - + let fas_context = vec![FasContext::new( "kms:ViaService".to_string(), vec!["ssm.${region}.amazonaws.com".to_string()], )]; let source = OperationSource::Fas(fas_context); let json = serde_json::to_string(&source).unwrap(); - + // Verify the custom serialization format - should be just "FAS", not nested assert_eq!(json, "\"FAS\""); } @@ -697,7 +731,8 @@ mod location_tests { #[tokio::test] async fn test_operation_methods() { let service_cfg = load_service_configuration().unwrap(); - let (_, service_reference_loader) = setup_mock_server_with_loader_without_operation_to_action_mapping().await; + let (_, service_reference_loader) = + setup_mock_server_with_loader_without_operation_to_action_mapping().await; { let call = SdkMethodCall { @@ -705,7 +740,15 @@ mod location_tests { possible_services: vec!["kms".to_string()], metadata: None, }; - let op = Operation::from_call(&call, "kms", &service_cfg, SdkType::Boto3, &service_reference_loader).await.unwrap(); + let op = Operation::from_call( + &call, + "kms", + &service_cfg, + SdkType::Boto3, + &service_reference_loader, + ) + .await + .unwrap(); assert_eq!(op.service_operation_name(), "kms:Decrypt"); assert_eq!(op.context(), &[]); } @@ -724,7 +767,15 @@ mod location_tests { possible_services: vec!["kms".to_string()], metadata: Some(metadata), }; - let op = Operation::from_call(&call, "kms", &service_cfg, SdkType::Boto3, &service_reference_loader).await.unwrap(); + let op = Operation::from_call( + &call, + "kms", + &service_cfg, + SdkType::Boto3, + &service_reference_loader, + ) + .await + .unwrap(); assert_eq!(op.service_operation_name(), "kms:Decrypt"); assert_eq!(op.context(), &[]); } @@ -734,7 +785,8 @@ mod location_tests { "kms:ViaService".to_string(), vec!["ssm.${region}.amazonaws.com".to_string()], )]; - let fas_operation = FasOperation::new("Decrypt".to_string(), "kms".to_string(), context.clone()); + let fas_operation = + FasOperation::new("Decrypt".to_string(), "kms".to_string(), context.clone()); let op = Operation::from(fas_operation); assert_eq!(op.service_operation_name(), "kms:Decrypt"); assert_eq!(op.context(), &context); diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index a62154e..d26c987 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -38,35 +38,36 @@ impl<'a> FasExpansionBuilder<'a> { fn expand_to_fixed_point(&self, initial: Operation) -> FasExpansion { let mut dependency_graph: HashMap, Vec>> = HashMap::new(); let initial_arc = Arc::new(initial); - + dependency_graph.insert(Arc::clone(&initial_arc), Vec::new()); // Root has no dependencies - + let mut to_process = vec![Arc::clone(&initial_arc)]; - + while !to_process.is_empty() { let mut newly_discovered = Vec::new(); - + for current in &to_process { let service_name = ¤t.service; - + match self.find_operation_fas_map_for_service(service_name) { Some(operation_fas_map) => { let service_operation_name = current.service_operation_name(); log::debug!("Looking up operation {}", service_operation_name); - + if let Some(additional_operations) = operation_fas_map .fas_operations .get(&service_operation_name) { for additional_op in additional_operations { let new_op = Arc::new(Operation::from(additional_op.clone())); - + if let Some(existing_deps) = dependency_graph.get_mut(&new_op) { // Operation already exists, add this dependency relationship existing_deps.push(Arc::clone(current)); } else { // New operation - dependency_graph.insert(Arc::clone(&new_op), vec![Arc::clone(current)]); + dependency_graph + .insert(Arc::clone(&new_op), vec![Arc::clone(current)]); newly_discovered.push(Arc::clone(&new_op)); } } @@ -79,23 +80,22 @@ impl<'a> FasExpansionBuilder<'a> { } } } - + let newly_discovered_count = newly_discovered.len(); to_process = newly_discovered; - + log::debug!( "FAS expansion discovered {} new operations", newly_discovered_count ); } - + log::debug!( "FAS expansion completed with {} total operations", dependency_graph.len() ); FasExpansion { dependency_graph } } - } #[derive(Clone, Debug)] @@ -104,8 +104,14 @@ struct FasExpansion { } impl FasExpansion { - fn builder<'a>(service_cfg: &'a ServiceConfiguration, fas_maps: &'a OperationFasMaps) -> FasExpansionBuilder<'a> { - FasExpansionBuilder { service_cfg, fas_maps } + fn builder<'a>( + service_cfg: &'a ServiceConfiguration, + fas_maps: &'a OperationFasMaps, + ) -> FasExpansionBuilder<'a> { + FasExpansionBuilder { + service_cfg, + fas_maps, + } } fn operations(&self) -> impl Iterator> { @@ -171,8 +177,6 @@ impl ResourceMatcher { Ok(enriched_calls) } - - fn make_condition(context: &[T]) -> Vec { let mut result = vec![]; for ctx in context { @@ -197,24 +201,35 @@ impl ResourceMatcher { service_name, parsed_call.name ); - + // Store the original service name from parsed_call for use in explanations let original_service_name = service_name; - let initial = Operation::from_call(parsed_call, service_name, &self.service_cfg, self.sdk, service_reference_loader).await?; + let initial = Operation::from_call( + parsed_call, + service_name, + &self.service_cfg, + self.sdk, + service_reference_loader, + ) + .await?; log::debug!("Expanded {:?}", initial); // Use fixed-point algorithm to safely expand FAS operations until no new operations are found - let fas_expansion = FasExpansion::builder(&self.service_cfg, &self.fas_maps).expand_to_fixed_point(initial); + let fas_expansion = + FasExpansion::builder(&self.service_cfg, &self.fas_maps).expand_to_fixed_point(initial); log::debug!("to\n{:?}", fas_expansion.dependency_graph); let mut enriched_actions = vec![]; for op in fas_expansion.operations() { - log::debug!("Creating actions for operation {:?}", op.service_operation_name()); + log::debug!( + "Creating actions for operation {:?}", + op.service_operation_name() + ); log::debug!(" with context {:?}", op.context()); - + // Find the corresponding SDF using the cache let service_reference = service_reference_loader.load(&op.service).await?; match service_reference { @@ -223,19 +238,17 @@ impl ResourceMatcher { continue; } Some(service_reference) => { - if let Some(operation_to_authorized_actions) = &service_reference.operation_to_authorized_actions { - log::debug!( - "Looking up {}", - &op.service_operation_name() - ); + log::debug!("Looking up {}", &op.service_operation_name()); if let Some(operation_to_authorized_action) = - operation_to_authorized_actions - .get(&op.service_operation_name()) + operation_to_authorized_actions.get(&op.service_operation_name()) { - log::debug!("Found operation action map for {:?}", operation_to_authorized_action.name); + log::debug!( + "Found operation action map for {:?}", + operation_to_authorized_action.name + ); for action in &operation_to_authorized_action.authorized_actions { let enriched_resources = self .find_resources_for_action_in_service_reference( @@ -258,7 +271,7 @@ impl ResourceMatcher { auth_context, ))); } - + let ops = fas_expansion.complete_provenance_chain(Arc::clone(op)); // Create explanation for this action @@ -277,18 +290,18 @@ impl ResourceMatcher { } else { // Fallback: operation not found in operation action map, create basic action // This ensures we don't filter out operations, only ADD additional ones from the map - if let Some(a) = self.create_fallback_action( - op, &fas_expansion, &service_reference, - )? { + if let Some(a) = + self.create_fallback_action(op, &fas_expansion, &service_reference)? + { log::debug!("Created fallback action due to no entry in operation action map: {:?}", a); enriched_actions.push(a); } } } else { // Fallback: operation action map does not exist, create basic action - if let Some(a) = self.create_fallback_action( - op, &fas_expansion, &service_reference, - )? { + if let Some(a) = + self.create_fallback_action(op, &fas_expansion, &service_reference)? + { log::debug!("Created fallback action due to no operation action map for service: {:?}", a); enriched_actions.push(a); } @@ -322,21 +335,23 @@ impl ResourceMatcher { let action_name = op.service_operation_name(); // Sanity check that the action exists in the SDF - if !service_reference - .actions - .contains_key(&op.name) - { - log::debug!("Not creating fallback action: service reference doesn't contain key: {:?}", action_name); + if !service_reference.actions.contains_key(&op.name) { + log::debug!( + "Not creating fallback action: service reference doesn't contain key: {:?}", + action_name + ); return Ok(None); } // Look up the action in the Service Reference to find associated resources let resources = self.find_resources_for_action_in_service_reference(&action_name, service_reference)?; - + // Create explanation for fallback action let explanation = Explanation { - reasons: vec![Reason::new(fas_expansion_result.complete_provenance_chain(Arc::clone(op)))], + reasons: vec![Reason::new( + fas_expansion_result.complete_provenance_chain(Arc::clone(op)), + )], }; Ok(Some(Action::new( @@ -405,8 +420,8 @@ mod tests { use std::collections::HashMap; use super::*; - use crate::enrichment::{OperationSource, mock_remote_service_reference}; use crate::enrichment::operation_fas_map::{FasContext, FasOperation, OperationFasMap}; + use crate::enrichment::{mock_remote_service_reference, OperationSource}; fn create_test_parsed_method_call() -> SdkMethodCall { SdkMethodCall { @@ -1004,10 +1019,14 @@ mod tests { ); // Test expansion starting from GetObject - let initial = - Operation::new("service-a".to_string(), "GetObject".to_string(), OperationSource::Provided); + let initial = Operation::new( + "service-a".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); - let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); + let fas_expansion = + FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); assert_eq!( fas_expansion.dependency_graph.len(), @@ -1082,10 +1101,14 @@ mod tests { ); // Test expansion starting from GetObject - should detect cycle and terminate - let initial = - Operation::new("service-a".to_string(), "GetObject".to_string(), OperationSource::Provided); + let initial = Operation::new( + "service-a".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); - let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); + let fas_expansion = + FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); // Debug: print what operations we actually got let operation_names: std::collections::HashSet = fas_expansion @@ -1094,7 +1117,10 @@ mod tests { .collect(); // 3 operations, note that GetObject occurs twice, once with and once without context - assert!(fas_expansion.dependency_graph.len() == 3, "Should have 3 operations"); + assert!( + fas_expansion.dependency_graph.len() == 3, + "Should have 3 operations" + ); // Verify expected operations are present assert!(operation_names.contains("service-a:GetObject")); @@ -1144,10 +1170,14 @@ mod tests { ); } - let initial = - Operation::new("service-a".to_string(),"GetObject".to_string() , OperationSource::Provided); + let initial = Operation::new( + "service-a".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); - let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); + let fas_expansion = + FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); // We have 5 operations, note that GetObject occurs twice, once with context and the initial one without assert!( @@ -1169,13 +1199,14 @@ mod tests { OperationSource::Provided, ); - let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial.clone()); + let fas_expansion = + FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial.clone()); assert_eq!( fas_expansion.dependency_graph.len(), 1, "Should contain only the initial operation" ); - + let operations: Vec<_> = fas_expansion.operations().collect(); assert_eq!( **operations[0], initial, @@ -1224,7 +1255,8 @@ mod tests { OperationSource::Provided, ); - let fas_expansion = FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial.clone()); + let fas_expansion = + FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial.clone()); // Should have exactly 1 operation since A->A with same context creates no new operations assert_eq!( @@ -1232,7 +1264,7 @@ mod tests { 1, "Self-cycle with identical context should result in exactly 1 operation" ); - + let operations: Vec<_> = fas_expansion.operations().collect(); assert_eq!( **operations[0], initial, @@ -1334,4 +1366,3 @@ mod tests { assert_eq!(enriched_calls[0].actions[0].name, "rds:ModifyDBCluster"); } } - diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs index b303e9b..c93c43f 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs @@ -294,7 +294,7 @@ fn extract_explanations(enriched_calls: &[EnrichedSdkMethodCall<'_>]) -> Explana .or_insert_with(|| action.explanation.clone()); } } - + Explanations::new(explanations) } @@ -901,7 +901,7 @@ mod tests { #[test] fn test_generate_policies_with_explanations() { - use crate::enrichment::{Explanation, Reason, Operation, OperationSource}; + use crate::enrichment::{Explanation, Operation, OperationSource, Reason}; use std::sync::Arc; let engine = create_test_engine(); @@ -951,7 +951,7 @@ mod tests { #[test] fn test_explanation_deduplication() { - use crate::enrichment::{Explanation, Reason, Operation, OperationSource}; + use crate::enrichment::{Explanation, Operation, OperationSource, Reason}; use std::sync::Arc; let engine = create_test_engine(); @@ -1031,7 +1031,9 @@ mod tests { #[test] fn test_explanation_with_fas_expansion() { - use crate::enrichment::{Explanation, Reason, Operation, OperationSource, operation_fas_map::FasContext}; + use crate::enrichment::{ + operation_fas_map::FasContext, Explanation, Operation, OperationSource, Reason, + }; use std::sync::Arc; let engine = create_test_engine(); @@ -1086,7 +1088,15 @@ mod tests { let result = engine.generate_policies(&[enriched_call]).unwrap(); // Verify explanations include FAS expansion - assert_eq!(result.explanations.as_ref().unwrap().explanation_for_action.len(), 2); + assert_eq!( + result + .explanations + .as_ref() + .unwrap() + .explanation_for_action + .len(), + 2 + ); // Check the FAS-expanded action let kms_explanation = result @@ -1098,7 +1108,7 @@ mod tests { .expect("Should have kms:Decrypt explanation"); assert_eq!(kms_explanation.reasons.len(), 1); assert_eq!(kms_explanation.reasons[0].operations.len(), 1); - + // Check that the operation has FAS context let operation = &kms_explanation.reasons[0].operations[0]; assert_eq!(operation.name, "Decrypt"); @@ -1115,7 +1125,7 @@ mod tests { #[test] fn test_explanation_with_possible_false_positive() { - use crate::enrichment::{Explanation, Reason, Operation, OperationSource}; + use crate::enrichment::{Explanation, Operation, OperationSource, Reason}; use std::sync::Arc; let engine = create_test_engine(); @@ -1151,8 +1161,14 @@ mod tests { let result = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(result.explanations.as_ref().unwrap().explanation_for_action.len(), 1); + assert_eq!( + result + .explanations + .as_ref() + .unwrap() + .explanation_for_action + .len(), + 1 + ); } } - - From 645cdbf91ed1ab6ed122d476dbdab4988b6796ae Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Tue, 13 Jan 2026 12:32:09 +0000 Subject: [PATCH 07/11] Documentation comments and minor refactoring --- .../src/enrichment/mod.rs | 22 ++++++++++++++----- 1 file changed, 16 insertions(+), 6 deletions(-) diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index c666599..57f7e67 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -183,39 +183,49 @@ impl Serialize for OperationSource { } } +/// Explanations for why actions have been included in a policy, with documentation for +/// concepts leading to inclusion (such as FAS expansion) #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "PascalCase")] pub struct Explanations { + /// Explanation for inclusion of an action pub explanation_for_action: BTreeMap, + /// Documentation of concepts used in the explanation for an action pub documentation: BTreeMap<&'static str, Documentation>, } impl Explanations { + const FAS_PLAIN: &str = + "The explanation contains an operation added due to Forward Access Sessions."; const FAS_URL: &str = "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html"; pub(crate) fn new(explanations: BTreeMap) -> Self { let mut documentation: Vec<(&'static str, Documentation)> = vec![]; - for (_, explanation) in &explanations { + for explanation in explanations.values() { for reason in &explanation.reasons { for op in &reason.operations { match op.source { OperationSource::Extracted(_) | OperationSource::Provided => (), - OperationSource::Fas(_) => documentation.push(("FAS", Documentation { - plain: "The explanation contains an operation added due to Forward Access Sessions.", - url: Self::FAS_URL, - })) + OperationSource::Fas(_) => documentation.push(( + "FAS", + Documentation { + plain: Self::FAS_PLAIN, + url: Self::FAS_URL, + }, + )), } } } } Self { explanation_for_action: explanations, - documentation: BTreeMap::from_iter(documentation.into_iter()), + documentation: BTreeMap::from_iter(documentation), } } } +/// Documentation of concepts appearing in explanations #[derive(Debug, Clone, Serialize)] #[serde(rename_all = "PascalCase")] pub struct Documentation { From e098cdd5ecc4dccdf7bcc1e687f40dbbd9c2c897 Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Tue, 13 Jan 2026 14:56:48 +0000 Subject: [PATCH 08/11] Update changelog --- CHANGELOG.md | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 8318de6..210e273 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## [0.1.3] + +## Added + +- Add `--explain` feature to output the reasons for why an action has been added to the policy. The explanations allow to review the operations which static analysis extracted from source code, and to correct them using the `--service-hints` flag, if necessary. + ## [0.1.2] - 2025-12-15 ## Fixed From ed39391c7689638fd83cf2c94772822705a7be29 Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Tue, 13 Jan 2026 15:07:39 +0000 Subject: [PATCH 09/11] Remove FasExpansionBuilder --- .../src/enrichment/resource_matcher.rs | 364 +++++++++--------- 1 file changed, 184 insertions(+), 180 deletions(-) diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index d26c987..6eb3a54 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -15,27 +15,17 @@ use crate::errors::{ExtractorError, Result}; use crate::service_configuration::ServiceConfiguration; use crate::{SdkMethodCall, SdkType}; -struct FasExpansionBuilder<'a> { - service_cfg: &'a ServiceConfiguration, - fas_maps: &'a OperationFasMaps, +#[derive(Clone, Debug)] +struct FasExpansion { + dependency_graph: HashMap, Vec>>, } -impl<'a> FasExpansionBuilder<'a> { - /// Find OperationFas map for a specific service - fn find_operation_fas_map_for_service( - &self, - service_name: &str, - ) -> Option> { - self.fas_maps - .get( - self.service_cfg - .rename_service_operation_action_map(service_name) - .as_ref(), - ) - .cloned() - } - - fn expand_to_fixed_point(&self, initial: Operation) -> FasExpansion { +impl FasExpansion { + fn new( + service_cfg: &ServiceConfiguration, + fas_maps: &OperationFasMaps, + initial: Operation, + ) -> Self { let mut dependency_graph: HashMap, Vec>> = HashMap::new(); let initial_arc = Arc::new(initial); @@ -49,7 +39,8 @@ impl<'a> FasExpansionBuilder<'a> { for current in &to_process { let service_name = ¤t.service; - match self.find_operation_fas_map_for_service(service_name) { + match Self::find_operation_fas_map_for_service(service_cfg, fas_maps, service_name) + { Some(operation_fas_map) => { let service_operation_name = current.service_operation_name(); log::debug!("Looking up operation {}", service_operation_name); @@ -61,6 +52,15 @@ impl<'a> FasExpansionBuilder<'a> { for additional_op in additional_operations { let new_op = Arc::new(Operation::from(additional_op.clone())); + if dependency_graph.contains_key(&new_op) { + // Skip adding this operation as it's logically equivalent to an existing one + log::debug!( + "Skipping logically equivalent operation: {}", + new_op.service_operation_name() + ); + continue; + } + if let Some(existing_deps) = dependency_graph.get_mut(&new_op) { // Operation already exists, add this dependency relationship existing_deps.push(Arc::clone(current)); @@ -94,24 +94,22 @@ impl<'a> FasExpansionBuilder<'a> { "FAS expansion completed with {} total operations", dependency_graph.len() ); - FasExpansion { dependency_graph } + Self { dependency_graph } } -} -#[derive(Clone, Debug)] -struct FasExpansion { - dependency_graph: HashMap, Vec>>, -} - -impl FasExpansion { - fn builder<'a>( - service_cfg: &'a ServiceConfiguration, - fas_maps: &'a OperationFasMaps, - ) -> FasExpansionBuilder<'a> { - FasExpansionBuilder { - service_cfg, - fas_maps, - } + /// Find OperationFas map for a specific service + fn find_operation_fas_map_for_service( + service_cfg: &ServiceConfiguration, + fas_maps: &OperationFasMaps, + service_name: &str, + ) -> Option> { + fas_maps + .get( + service_cfg + .rename_service_operation_action_map(service_name) + .as_ref(), + ) + .cloned() } fn operations(&self) -> impl Iterator> { @@ -216,8 +214,7 @@ impl ResourceMatcher { log::debug!("Expanded {:?}", initial); // Use fixed-point algorithm to safely expand FAS operations until no new operations are found - let fas_expansion = - FasExpansion::builder(&self.service_cfg, &self.fas_maps).expand_to_fixed_point(initial); + let fas_expansion = FasExpansion::new(&self.service_cfg, &self.fas_maps, initial); log::debug!("to\n{:?}", fas_expansion.dependency_graph); @@ -966,57 +963,60 @@ mod tests { let service_cfg = create_empty_service_config(); // Create a mock FAS map with no cycles: A -> B -> C (linear chain) - let mut fas_maps = HashMap::new(); - - // Service A: GetObject -> Service B: Decrypt - let mut service_a_operations = HashMap::new(); - service_a_operations.insert( - "service-a:GetObject".to_string(), - vec![FasOperation::new( - "Decrypt".to_string(), - "service-b".to_string(), - vec![FasContext::new( - "test".to_string(), - vec!["value".to_string()], + let fas_maps = { + let mut fas_maps = HashMap::new(); + + // Service A: GetObject -> Service B: Decrypt + let mut service_a_operations = HashMap::new(); + service_a_operations.insert( + "service-a:GetObject".to_string(), + vec![FasOperation::new( + "Decrypt".to_string(), + "service-b".to_string(), + vec![FasContext::new( + "test".to_string(), + vec!["value".to_string()], + )], )], - )], - ); + ); - // Service B: Decrypt -> Service C: Log - let mut service_b_operations = HashMap::new(); - service_b_operations.insert( - "service-b:Decrypt".to_string(), - vec![FasOperation::new( - "Log".to_string(), - "service-c".to_string(), - vec![FasContext::new( - "test2".to_string(), - vec!["value2".to_string()], + // Service B: Decrypt -> Service C: Log + let mut service_b_operations = HashMap::new(); + service_b_operations.insert( + "service-b:Decrypt".to_string(), + vec![FasOperation::new( + "Log".to_string(), + "service-c".to_string(), + vec![FasContext::new( + "test2".to_string(), + vec!["value2".to_string()], + )], )], - )], - ); + ); - // Service C: Log -> nothing (terminal) - let service_c_operations = HashMap::new(); + // Service C: Log -> nothing (terminal) + let service_c_operations = HashMap::new(); - fas_maps.insert( - "service-a".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_a_operations, - }), - ); - fas_maps.insert( - "service-b".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_b_operations, - }), - ); - fas_maps.insert( - "service-c".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_c_operations, - }), - ); + fas_maps.insert( + "service-a".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_a_operations, + }), + ); + fas_maps.insert( + "service-b".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_b_operations, + }), + ); + fas_maps.insert( + "service-c".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_c_operations, + }), + ); + fas_maps + }; // Test expansion starting from GetObject let initial = Operation::new( @@ -1025,8 +1025,7 @@ mod tests { OperationSource::Provided, ); - let fas_expansion = - FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial); assert_eq!( fas_expansion.dependency_graph.len(), @@ -1057,48 +1056,51 @@ mod tests { let service_cfg = create_empty_service_config(); // Create a mock FAS map with a cycle: A -> B -> A - let mut fas_maps = HashMap::new(); - - // Service A: GetObject -> Service B: Decrypt - let mut service_a_operations = HashMap::new(); - service_a_operations.insert( - "service-a:GetObject".to_string(), - vec![FasOperation::new( - "Decrypt".to_string(), - "service-b".to_string(), - vec![FasContext::new( - "test".to_string(), - vec!["value".to_string()], + let fas_maps = { + let mut fas_maps = HashMap::new(); + + // Service A: GetObject -> Service B: Decrypt + let mut service_a_operations = HashMap::new(); + service_a_operations.insert( + "service-a:GetObject".to_string(), + vec![FasOperation::new( + "Decrypt".to_string(), + "service-b".to_string(), + vec![FasContext::new( + "test".to_string(), + vec!["value".to_string()], + )], )], - )], - ); + ); - // Service B: Decrypt -> Service A: GetObject (creates cycle!) - let mut service_b_operations = HashMap::new(); - service_b_operations.insert( - "service-b:Decrypt".to_string(), - vec![FasOperation::new( - "GetObject".to_string(), - "service-a".to_string(), - vec![FasContext::new( - "test2".to_string(), - vec!["value2".to_string()], + // Service B: Decrypt -> Service A: GetObject (creates cycle!) + let mut service_b_operations = HashMap::new(); + service_b_operations.insert( + "service-b:Decrypt".to_string(), + vec![FasOperation::new( + "GetObject".to_string(), + "service-a".to_string(), + vec![FasContext::new( + "test2".to_string(), + vec!["value2".to_string()], + )], )], - )], - ); + ); - fas_maps.insert( - "service-a".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_a_operations, - }), - ); - fas_maps.insert( - "service-b".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_b_operations, - }), - ); + fas_maps.insert( + "service-a".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_a_operations, + }), + ); + fas_maps.insert( + "service-b".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_b_operations, + }), + ); + fas_maps + }; // Test expansion starting from GetObject - should detect cycle and terminate let initial = Operation::new( @@ -1107,8 +1109,7 @@ mod tests { OperationSource::Provided, ); - let fas_expansion = - FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial); // Debug: print what operations we actually got let operation_names: std::collections::HashSet = fas_expansion @@ -1138,37 +1139,40 @@ mod tests { // Create a service configuration let service_cfg = create_empty_service_config(); - let mut fas_maps = HashMap::new(); - - // Create a chain that loops back: A -> B -> C -> D -> A - let operations_data = vec![ - ("service-a", "GetObject", "service-b", "Decrypt"), - ("service-b", "Decrypt", "service-c", "Validate"), - ("service-c", "Validate", "service-d", "Log"), - ("service-d", "Log", "service-a", "GetObject"), // Back to start - ]; - - for (from_service, from_op, to_service, to_op) in operations_data { - let mut operations = HashMap::new(); - operations.insert( - format!("{}:{}", from_service, from_op), - vec![FasOperation::new( - to_op.to_string(), - to_service.to_string(), - vec![FasContext::new( - "cycle".to_string(), - vec!["test".to_string()], + let fas_maps = { + let mut fas_maps = HashMap::new(); + + // Create a chain that loops back: A -> B -> C -> D -> A + let operations_data = vec![ + ("service-a", "GetObject", "service-b", "Decrypt"), + ("service-b", "Decrypt", "service-c", "Validate"), + ("service-c", "Validate", "service-d", "Log"), + ("service-d", "Log", "service-a", "GetObject"), // Back to start + ]; + + for (from_service, from_op, to_service, to_op) in operations_data { + let mut operations = HashMap::new(); + operations.insert( + format!("{}:{}", from_service, from_op), + vec![FasOperation::new( + to_op.to_string(), + to_service.to_string(), + vec![FasContext::new( + "cycle".to_string(), + vec!["test".to_string()], + )], )], - )], - ); - - fas_maps.insert( - from_service.to_string(), - Arc::new(OperationFasMap { - fas_operations: operations, - }), - ); - } + ); + + fas_maps.insert( + from_service.to_string(), + Arc::new(OperationFasMap { + fas_operations: operations, + }), + ); + } + fas_maps + }; let initial = Operation::new( "service-a".to_string(), @@ -1176,8 +1180,7 @@ mod tests { OperationSource::Provided, ); - let fas_expansion = - FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial); // We have 5 operations, note that GetObject occurs twice, once with context and the initial one without assert!( @@ -1199,8 +1202,7 @@ mod tests { OperationSource::Provided, ); - let fas_expansion = - FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial.clone()); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial.clone()); assert_eq!( fas_expansion.dependency_graph.len(), 1, @@ -1228,25 +1230,28 @@ mod tests { let service_cfg = create_empty_service_config(); // Create a FAS map where A -> A with empty context (self-referential) - let mut fas_maps = HashMap::new(); - - // Service A: GetObject -> Service A: GetObject (with empty context) - let mut service_a_operations = HashMap::new(); - service_a_operations.insert( - "service-a:GetObject".to_string(), - vec![FasOperation::new( - "GetObject".to_string(), - "service-a".to_string(), - Vec::new(), // Empty context - same as initial - )], - ); + let fas_maps = { + let mut fas_maps = HashMap::new(); - fas_maps.insert( - "service-a".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_a_operations, - }), - ); + // Service A: GetObject -> Service A: GetObject (with empty context) + let mut service_a_operations = HashMap::new(); + service_a_operations.insert( + "service-a:GetObject".to_string(), + vec![FasOperation::new( + "GetObject".to_string(), + "service-a".to_string(), + Vec::new(), // Empty context - same as initial + )], + ); + + fas_maps.insert( + "service-a".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_a_operations, + }), + ); + fas_maps + }; // Test expansion starting from GetObject with empty context let initial = Operation::new( @@ -1255,8 +1260,7 @@ mod tests { OperationSource::Provided, ); - let fas_expansion = - FasExpansion::builder(&service_cfg, &fas_maps).expand_to_fixed_point(initial.clone()); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial.clone()); // Should have exactly 1 operation since A->A with same context creates no new operations assert_eq!( From dbea1487b01f510cf2f0db79821f2ede016b85e6 Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Tue, 13 Jan 2026 16:13:00 +0000 Subject: [PATCH 10/11] Custom PartialEq, Hash for Operation to fix test --- .../src/enrichment/mod.rs | 208 +++++++++++++++++- .../src/enrichment/resource_matcher.rs | 9 - 2 files changed, 207 insertions(+), 10 deletions(-) diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index 57f7e67..6a8f702 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -41,7 +41,7 @@ pub struct Reason { pub operations: Vec>, } -#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[derive(Debug, Clone, Serialize, Eq, JsonSchema)] #[serde(rename_all = "PascalCase")] pub struct Operation { /// Name of the service @@ -142,6 +142,26 @@ impl From for Operation { } } +// Custom PartialEq and Hash implementations for Operation: + +// We consider operations to be equal when they would produce the same action in a policy. +// I.e., same operation and same context used for the condition. Directly relevant to FAS expansion. +impl PartialEq for Operation { + fn eq(&self, other: &Self) -> bool { + self.service == other.service + && self.name == other.name + && self.context() == other.context() + } +} + +impl std::hash::Hash for Operation { + fn hash(&self, state: &mut H) { + self.service.hash(state); + self.name.hash(state); + self.context().hash(state); + } +} + /// Custom serializer for extracted metadata that flattens the structure fn serialize_extracted_metadata( metadata: &SdkMethodCallMetadata, @@ -359,6 +379,7 @@ impl Resource { #[cfg(test)] mod tests { use super::*; + use crate::enrichment::operation_fas_map::FasContext; #[test] fn test_enriched_resource_creation() { @@ -373,6 +394,191 @@ mod tests { Some(vec!["arn:aws:s3:::bucket/*".to_string()]) ); } + + #[test] + fn test_operation_custom_equality_same_operation_different_sources() { + // Test that operations with same service, name, and context are equal regardless of source + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(Vec::new()), // Empty context + ); + + // Should be equal because they have same service, name, and context (both empty) + assert_eq!(op1, op2); + assert_eq!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_equality_different_contexts() { + // Test that operations with different contexts are NOT equal + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, // Empty context + ); + + let context = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]; + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(context), + ); + + // Should NOT be equal because they have different contexts + assert_ne!(op1, op2); + assert_ne!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_equality_same_contexts() { + // Test that operations with same contexts are equal + let context1 = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]; + let context2 = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]; + + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(context1), + ); + + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(context2), + ); + + // Should be equal because they have same service, name, and context + assert_eq!(op1, op2); + assert_eq!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_equality_different_services() { + // Test that operations with different services are NOT equal + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let op2 = Operation::new( + "kms".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + // Should NOT be equal because they have different services + assert_ne!(op1, op2); + assert_ne!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_equality_different_names() { + // Test that operations with different names are NOT equal + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let op2 = Operation::new( + "s3".to_string(), + "PutObject".to_string(), + OperationSource::Provided, + ); + + // Should NOT be equal because they have different operation names + assert_ne!(op1, op2); + assert_ne!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_hash_consistency() { + // Test that equal operations have the same hash + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(Vec::new()), // Empty context + ); + + // Equal operations should have the same hash + assert_eq!(op1, op2); + + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher1 = DefaultHasher::new(); + op1.hash(&mut hasher1); + let hash1 = hasher1.finish(); + + let mut hasher2 = DefaultHasher::new(); + op2.hash(&mut hasher2); + let hash2 = hasher2.finish(); + + assert_eq!(hash1, hash2, "Equal operations should have the same hash"); + } + + #[test] + fn test_operation_custom_hash_different_for_unequal() { + // Test that unequal operations typically have different hashes + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let context = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]; + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(context), + ); + + // Unequal operations should typically have different hashes + assert_ne!(op1, op2); + + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher1 = DefaultHasher::new(); + op1.hash(&mut hasher1); + let hash1 = hasher1.finish(); + + let mut hasher2 = DefaultHasher::new(); + op2.hash(&mut hasher2); + let hash2 = hasher2.finish(); + + // Note: Hash collisions are possible but unlikely for this test case + assert_ne!( + hash1, hash2, + "Unequal operations should typically have different hashes" + ); + } } #[cfg(test)] diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index 6eb3a54..3163159 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -52,15 +52,6 @@ impl FasExpansion { for additional_op in additional_operations { let new_op = Arc::new(Operation::from(additional_op.clone())); - if dependency_graph.contains_key(&new_op) { - // Skip adding this operation as it's logically equivalent to an existing one - log::debug!( - "Skipping logically equivalent operation: {}", - new_op.service_operation_name() - ); - continue; - } - if let Some(existing_deps) = dependency_graph.get_mut(&new_op) { // Operation already exists, add this dependency relationship existing_deps.push(Arc::clone(current)); From b6dafa0c4cdd0973f8ccffab2d64c2bf46bc478c Mon Sep 17 00:00:00 2001 From: Matthias Schlaipfer Date: Wed, 14 Jan 2026 13:56:28 +0000 Subject: [PATCH 11/11] Address PR comments, simpler explanation documentation --- CHANGELOG.md | 2 +- .../src/enrichment/mod.rs | 31 +++------ .../src/extraction/javascript/shared.rs | 2 +- .../src/extraction/mod.rs | 66 ++----------------- 4 files changed, 14 insertions(+), 87 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 210e273..1f86d94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,4 +1,4 @@ -## [0.1.3] +## [Unreleased] ## Added diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index 6a8f702..52181a4 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -211,49 +211,34 @@ pub struct Explanations { /// Explanation for inclusion of an action pub explanation_for_action: BTreeMap, /// Documentation of concepts used in the explanation for an action - pub documentation: BTreeMap<&'static str, Documentation>, + #[serde(skip_serializing_if = "Vec::is_empty")] + pub documentation: Vec<&'static str>, } impl Explanations { - const FAS_PLAIN: &str = - "The explanation contains an operation added due to Forward Access Sessions."; - const FAS_URL: &str = - "https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html"; + const FAS: &str = + "The explanation contains an operation added due to Forward Access Sessions (FAS). See https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html."; pub(crate) fn new(explanations: BTreeMap) -> Self { - let mut documentation: Vec<(&'static str, Documentation)> = vec![]; + let mut documentation: Vec<&'static str> = vec![]; for explanation in explanations.values() { for reason in &explanation.reasons { for op in &reason.operations { match op.source { OperationSource::Extracted(_) | OperationSource::Provided => (), - OperationSource::Fas(_) => documentation.push(( - "FAS", - Documentation { - plain: Self::FAS_PLAIN, - url: Self::FAS_URL, - }, - )), + OperationSource::Fas(_) => documentation.push(Self::FAS), } } } } + documentation.dedup(); Self { explanation_for_action: explanations, - documentation: BTreeMap::from_iter(documentation), + documentation, } } } -/// Documentation of concepts appearing in explanations -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "PascalCase")] -pub struct Documentation { - plain: &'static str, - #[serde(rename = "URL")] - url: &'static str, -} - /// Represents an explanation for why an action was added to a policy #[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema, Default)] // Don't print the `"Reasons":` key, treat this as just a JSON array. diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs index 56c0669..870bff8 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs @@ -196,7 +196,7 @@ impl ExtractionUtils { // Extract operation name by removing "Command" suffix if let Some(operation_name) = command_name.strip_suffix("Command") { // Keep PascalCase operation name to match service index - // e.g., "CreateBucket" stays "CreateBucket" + // e.g., "PutItem" from "PutItemCommand" let method_call = SdkMethodCall { name: operation_name.to_string(), possible_services: vec![service.clone()], diff --git a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs index d635a50..ff14ff0 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs @@ -344,68 +344,10 @@ mod tests { } #[test] - fn test_parsed_method_creation() { - let method = SdkMethodCall { - name: "test_method".to_string(), - possible_services: Vec::new(), - metadata: Some(SdkMethodCallMetadata { - parameters: vec![Parameter::Keyword { - name: "param1".to_string(), - value: ParameterValue::Resolved("test_value".to_string()), - position: 0, - type_annotation: Some("str".to_string()), - }], - return_type: Some("bool".to_string()), - expr: "test_method".to_string(), - location: Location::new(PathBuf::new(), (10, 1), (10, 25)), - receiver: None, - }), - }; - - assert_eq!(method.name, "test_method"); - assert_eq!(method.metadata.as_ref().unwrap().parameters.len(), 1); - assert_eq!( - method.metadata.as_ref().unwrap().location.start_position, - (10, 1) - ); - assert_eq!( - method.metadata.as_ref().unwrap().location.end_position, - (10, 25) - ); - } - - #[test] - fn test_parsed_method_with_sdk_context() { - let method = SdkMethodCall { - name: "get_object".to_string(), - possible_services: vec!["s3".to_string()], - metadata: Some(SdkMethodCallMetadata { - parameters: vec![Parameter::Keyword { - name: "Bucket".to_string(), - value: ParameterValue::Resolved("my-bucket".to_string()), - position: 0, - type_annotation: Some("str".to_string()), - }], - return_type: Some("Dict[str, Any]".to_string()), - expr: "get_object".to_string(), - location: Location::new(PathBuf::new(), (15, 5), (15, 45)), - receiver: Some("s3_client".to_string()), - }), - }; - - assert_eq!(method.name, "get_object"); - assert_eq!( - method.metadata.as_ref().unwrap().receiver, - Some("s3_client".to_string()) - ); - assert_eq!(method.possible_services, vec!["s3".to_string()]); - if let Parameter::Keyword { - value, position, .. - } = &method.metadata.as_ref().unwrap().parameters[0] - { - assert_eq!(value.as_string(), "my-bucket"); - assert_eq!(*position, 0); - } + fn test_location_construction() { + let location = Location::new(PathBuf::new(), (10, 1), (10, 25)); + assert_eq!(location.start_position, (10, 1)); + assert_eq!(location.end_position, (10, 25)); } #[test]