diff --git a/CHANGELOG.md b/CHANGELOG.md index 8318de6..1f86d94 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,9 @@ +## [Unreleased] + +## Added + +- Add `--explain` feature to output the reasons why an action has been added to the policy. The explanations make it possible to review the operations that static analysis extracted from source code and, if necessary, to correct them using the `--service-hints` flag. + ## [0.1.2] - 2025-12-15 ## Fixed diff --git a/Cargo.toml b/Cargo.toml index ff7a8a5..cc5280f 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -19,7 +19,7 @@ description = "An open source Model Context Protocol (MCP) server and command-li # Shared dependency versions across workspace [workspace.dependencies] # Core serialization and utilities -serde = { version = "1.0", features = ["derive"] } +serde = { version = "1.0", features = ["derive", "rc"] } futures = "0.3" thiserror = "1.0" anyhow = "1.0" @@ -29,6 +29,7 @@ ast-grep-core = "0.39" schemars = { version = "^1", features = ["derive"] } rust-embed = { version = "8.9", features = ["compression", "include-exclude"] } reqwest = { version = "0.12.4", features = ["rustls-tls"], default-features = false } +derive-new = "0.7.0" openssl = { version = "0.10", features = ["vendored"] } # Native async runtime and parallel processing diff --git a/iam-policy-autopilot-cli/src/main.rs b/iam-policy-autopilot-cli/src/main.rs index c9158cc..47e3302 100644 --- a/iam-policy-autopilot-cli/src/main.rs +++ b/iam-policy-autopilot-cli/src/main.rs @@ -84,14 +84,14 @@ struct GeneratePolicyCliConfig { account: String, /// Output individual policies instead of merged policy individual_policies: bool, - /// Show method to action mappings alongside policies - show_action_mappings: bool, /// Upload policies to AWS with optional custom name prefix upload_policies: Option, /// Enable minimal policy size by allowing cross-service merging minimal_policy_size: bool, /// Disable file system caching for service references disable_cache: 
bool, + /// Generate explanations for why actions were added + explain: bool, } impl GeneratePolicyCliConfig { @@ -283,16 +283,6 @@ for each method call. Disables --upload-policy, if provided." )] individual_policies: bool, - /// Include method to action mappings alongside the generated policies - #[arg( - hide = true, - long = "show-action-mappings", - long_help = "When enabled, outputs detailed method to action \ -mappings alongside the generated policies. This provides granular visibility into which SDK method calls \ -require which specific IAM actions and their associated resources. Disables --upload-policy, if provided." - )] - show_action_mappings: bool, - /// Upload generated policies to AWS IAM with optional custom name prefix #[arg(long = "upload-policies", num_args = 0..=1, require_equals = false, default_missing_value = "", long_help = "Upload the generated policies to AWS IAM using the iam:CreatePolicy API. \ @@ -330,6 +320,15 @@ Use this flag to force fresh data retrieval on every run." long_help = SERVICE_HINTS_LONG_HELP, )] service_hints: Option>, + + /// Generate explanations for why actions were added + #[arg( + long = "explain", + long_help = "When enabled, generates detailed explanations for why each IAM action \ +was added to the policy. Explanations include the initial operation with location information, FAS (https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html) expansion chains. The output format \ +may change in future versions." 
+ )] + explain: bool, }, /// Start MCP server @@ -438,35 +437,24 @@ async fn handle_generate_policy(config: &GeneratePolicyCliConfig) -> Result<()> service_names: names.clone(), }); - let (policies, method_action_mappings) = generate_policies(&GeneratePolicyConfig { + let result = generate_policies(&GeneratePolicyConfig { extract_sdk_calls_config: ExtractSdkCallsConfig { source_files: config.shared.source_files.to_owned(), language: config.shared.language.to_owned(), service_hints, }, aws_context: AwsContext::new(config.region.clone(), config.account.clone()), - generate_action_mappings: config.show_action_mappings, individual_policies: config.individual_policies, minimize_policy_size: config.minimal_policy_size, disable_file_system_cache: config.disable_cache, + generate_explanations: config.explain, }) .await?; - // Handle policy output based on configuration - if config.show_action_mappings { - // Output combined format with mappings and policies - output::output_combined_policy_mappings( - method_action_mappings, - policies, - config.shared.pretty, - ) - .context("Failed to output combined policy and mappings")?; - - trace!("Combined policy and mappings output written to stdout"); - } else if config.individual_policies { + if config.individual_policies { // Output individual policies - trace!("Outputting {} individual policies", policies.len()); - output::output_iam_policies(policies, None, config.shared.pretty) + trace!("Outputting {} individual policies", result.policies.len()); + output::output_iam_policies(result, None, config.shared.pretty) .context("Failed to output individual IAM policies")?; } else { // Default behavior: output merged policy with optional upload @@ -479,7 +467,7 @@ async fn handle_generate_policy(config: &GeneratePolicyCliConfig) -> Result<()> let custom_name = config.upload_policies.as_deref().filter(|s| !s.is_empty()); let batch_response = uploader - .upload_policies(&policies, custom_name) + .upload_policies(&result.policies, 
custom_name) .await .context("Failed to upload policies to AWS IAM")?; @@ -505,7 +493,7 @@ async fn handle_generate_policy(config: &GeneratePolicyCliConfig) -> Result<()> None }; - output::output_iam_policies(policies, upload_result, config.shared.pretty) + output::output_iam_policies(result, upload_result, config.shared.pretty) .context("Failed to output merged IAM policy")? } @@ -577,11 +565,11 @@ async fn main() { region, account, individual_policies, - show_action_mappings, upload_policies, minimal_policy_size, disable_cache, service_hints, + explain, } => { // Initialize logging if let Err(e) = init_logging(debug) { @@ -600,10 +588,10 @@ async fn main() { region, account, individual_policies, - show_action_mappings, upload_policies, minimal_policy_size, disable_cache, + explain, }; match handle_generate_policy(&config).await { diff --git a/iam-policy-autopilot-cli/src/output.rs b/iam-policy-autopilot-cli/src/output.rs index fc136fa..959df50 100644 --- a/iam-policy-autopilot-cli/src/output.rs +++ b/iam-policy-autopilot-cli/src/output.rs @@ -1,8 +1,6 @@ use anyhow::{Context, Result}; use iam_policy_autopilot_access_denied::{DenialType, PlanResult}; -use iam_policy_autopilot_policy_generation::policy_generation::{ - MethodActionMapping, PolicyWithMetadata, -}; +use iam_policy_autopilot_policy_generation::api::model::GeneratePoliciesResult; use iam_policy_autopilot_tools::BatchUploadResponse; use log::debug; use std::io::{self, Write}; @@ -163,63 +161,17 @@ pub(crate) fn print_unsupported_denial(denial_type: &DenialType, reason: &str) { #[serde(rename_all = "PascalCase")] struct PolicyOutput { /// The generated policies with type information - policies: Vec, + /// and explanations for why actions were added + #[serde(flatten)] + result: GeneratePoliciesResult, /// Upload results (only present when --upload-policies is used) #[serde(skip_serializing_if = "Option::is_none")] upload_result: Option, } -/// Combined output structure when showing action mappings 
alongside policies -#[derive(Debug, Clone, serde::Serialize)] -#[serde(rename_all = "PascalCase")] -struct CombinedPolicyOutput { - /// Method to action mappings - method_action_mappings: Vec, - /// The generated policies with type information - policies: Vec, - /// Upload results (only present when --upload-policies is used) - #[serde(skip_serializing_if = "Option::is_none")] - upload_result: Option, -} - -/// Output combined policy and mappings as JSON to stdout -pub(crate) fn output_combined_policy_mappings( - method_action_mappings: Vec, - policies: Vec, - pretty: bool, -) -> Result<()> { - debug!( - "Formatting combined policy and mappings output as JSON (pretty: {})", - pretty - ); - - let combined_output = CombinedPolicyOutput { - method_action_mappings, - policies, - upload_result: None, - }; - - let json_output = if pretty { - iam_policy_autopilot_policy_generation::JsonProvider::stringify_pretty(&combined_output) - .context("Failed to serialize combined output to pretty JSON")? - } else { - iam_policy_autopilot_policy_generation::JsonProvider::stringify(&combined_output) - .context("Failed to serialize combined output to JSON")? - }; - - // Output to stdout (not using println! 
to avoid extra newline in compact mode) - print!("{}", json_output); - if pretty { - println!(); // Add newline for pretty output - } - - debug!("Combined policy and mappings JSON output written to stdout"); - Ok(()) -} - /// Output IAM policies as JSON to stdout pub(crate) fn output_iam_policies( - policies: Vec, + result: GeneratePoliciesResult, upload_result: Option, pretty: bool, ) -> Result<()> { @@ -229,7 +181,7 @@ pub(crate) fn output_iam_policies( ); let policy_output = PolicyOutput { - policies, + result, upload_result, }; diff --git a/iam-policy-autopilot-mcp-server/src/tools/generate_policy.rs b/iam-policy-autopilot-mcp-server/src/tools/generate_policy.rs index 4e373a2..eb7e371 100644 --- a/iam-policy-autopilot-mcp-server/src/tools/generate_policy.rs +++ b/iam-policy-autopilot-mcp-server/src/tools/generate_policy.rs @@ -52,7 +52,7 @@ pub async fn generate_application_policies( service_names: hints, }); - let (policies, _) = api::generate_policies(&GeneratePolicyConfig { + let result = api::generate_policies(&GeneratePolicyConfig { individual_policies: false, extract_sdk_calls_config: ExtractSdkCallsConfig { source_files: input.source_files.into_iter().map(|f| f.into()).collect(), @@ -61,16 +61,17 @@ pub async fn generate_application_policies( service_hints, }, aws_context: AwsContext::new(region, account), - generate_action_mappings: false, minimize_policy_size: false, // true by default, if we want to allow the user to change it we should // accept it as part of the cli input when starting the mcp server disable_file_system_cache: true, + generate_explanations: false, }) .await?; - let policies = policies + let policies = result + .policies .into_iter() .map(|policy| serde_json::to_string(&policy.policy).context("Failed to serialize policy")) .collect::, Error>>()?; @@ -82,26 +83,23 @@ pub async fn generate_application_policies( #[cfg(test)] mod api { use anyhow::Result; - use iam_policy_autopilot_policy_generation::{ - api::model::GeneratePolicyConfig, 
policy_generation::PolicyWithMetadata, - MethodActionMapping, + use iam_policy_autopilot_policy_generation::api::model::{ + GeneratePoliciesResult, GeneratePolicyConfig, }; // Static mutable return value - pub static mut MOCK_RETURN_VALUE: Option< - Result<(Vec, Vec)>, - > = None; + pub static mut MOCK_RETURN_VALUE: Option> = None; pub async fn generate_policies( _config: &GeneratePolicyConfig, - ) -> Result<(Vec, Vec)> { + ) -> Result { #[allow(static_mut_refs)] unsafe { MOCK_RETURN_VALUE.take().unwrap() } } - pub fn set_mock_return(value: Result<(Vec, Vec)>) { + pub fn set_mock_return(value: Result) { unsafe { MOCK_RETURN_VALUE = Some(value) } } } @@ -113,7 +111,7 @@ mod tests { use super::*; use iam_policy_autopilot_policy_generation::{ - IamPolicy, PolicyType, PolicyWithMetadata, Statement, + api::model::GeneratePoliciesResult, IamPolicy, PolicyType, PolicyWithMetadata, Statement, }; use anyhow::anyhow; @@ -143,7 +141,12 @@ mod tests { policy_type: PolicyType::Identity, }; - api::set_mock_return(Ok((vec![policy], vec![]))); + use iam_policy_autopilot_policy_generation::api::model::GeneratePoliciesResult; + + api::set_mock_return(Ok(GeneratePoliciesResult { + policies: vec![policy], + explanations: None, + })); let result = generate_application_policies(input).await; println!("{result:?}"); @@ -223,7 +226,10 @@ mod tests { policy_type: PolicyType::Identity, }; - api::set_mock_return(Ok((vec![policy], vec![]))); + api::set_mock_return(Ok(GeneratePoliciesResult { + policies: vec![policy], + explanations: None, + })); let result = generate_application_policies(input).await; assert!(result.is_ok()); diff --git a/iam-policy-autopilot-policy-generation/Cargo.toml b/iam-policy-autopilot-policy-generation/Cargo.toml index 0a6d469..ed28f7b 100644 --- a/iam-policy-autopilot-policy-generation/Cargo.toml +++ b/iam-policy-autopilot-policy-generation/Cargo.toml @@ -27,6 +27,7 @@ serde_json.workspace = true tokio.workspace = true async-trait.workspace = true strsim.workspace = 
true +derive-new.workspace = true # Build dependencies diff --git a/iam-policy-autopilot-policy-generation/src/api/generate_policies.rs b/iam-policy-autopilot-policy-generation/src/api/generate_policies.rs index c75d6ad..093f857 100644 --- a/iam-policy-autopilot-policy-generation/src/api/generate_policies.rs +++ b/iam-policy-autopilot-policy-generation/src/api/generate_policies.rs @@ -4,16 +4,17 @@ use std::time::Instant; use log::{debug, info, trace}; use crate::{ - api::{common::process_source_files, model::GeneratePolicyConfig}, + api::{ + common::process_source_files, + model::{GeneratePoliciesResult, GeneratePolicyConfig}, + }, extraction::SdkMethodCall, - policy_generation::{merge::PolicyMergerConfig, MethodActionMapping, PolicyWithMetadata}, + policy_generation::merge::PolicyMergerConfig, EnrichmentEngine, PolicyGenerationEngine, }; -/// Generate polcies for source files -pub async fn generate_policies( - config: &GeneratePolicyConfig, -) -> Result<(Vec, Vec)> { +/// Generate policies for source files +pub async fn generate_policies(config: &GeneratePolicyConfig) -> Result { let pipeline_start = Instant::now(); debug!( @@ -55,7 +56,10 @@ pub async fn generate_policies( // Handle empty method lists gracefully if extracted_methods.is_empty() { info!("No methods found to process, returning empty policy list"); - return Ok((vec![], vec![])); + return Ok(GeneratePoliciesResult { + policies: vec![], + explanations: None, + }); } let mut enrichment_engine = EnrichmentEngine::new(config.disable_file_system_cache)?; @@ -85,7 +89,7 @@ pub async fn generate_policies( "Generating IAM policies from {} enriched method calls", enriched_results.len() ); - let policies = policy_engine + let result = policy_engine .generate_policies(&enriched_results) .context("Failed to generate IAM policies")?; @@ -93,10 +97,15 @@ pub async fn generate_policies( debug!( "Policy generation completed in {:?}, generated {} policies", total_duration, - policies.len() + result.policies.len() ); 
- let mut final_policies = policies; + let mut final_policies = result.policies; + let explanations = if config.generate_explanations { + result.explanations + } else { + None + }; if !config.individual_policies { final_policies = policy_engine @@ -104,24 +113,8 @@ pub async fn generate_policies( .context("Failed to merge IAM policies")?; } - // Handle policy output based on configuration - if config.generate_action_mappings { - // Extract method to action mappings using the core method - debug!( - "Extracting method to action mappings from {} enriched method calls", - enriched_results.len() - ); - let method_action_mappings = policy_engine - .extract_action_mappings(&enriched_results) - .context("Failed to extract method to action mappings")?; - - debug!( - "Extracted {} method to action mappings", - method_action_mappings.len() - ); - - return Ok((final_policies, method_action_mappings)); - } - - Ok((final_policies, vec![])) + Ok(GeneratePoliciesResult { + policies: final_policies, + explanations, + }) } diff --git a/iam-policy-autopilot-policy-generation/src/api/model.rs b/iam-policy-autopilot-policy-generation/src/api/model.rs index 77c9bb0..60c2846 100644 --- a/iam-policy-autopilot-policy-generation/src/api/model.rs +++ b/iam-policy-autopilot-policy-generation/src/api/model.rs @@ -1,23 +1,35 @@ //! 
Defined model for API -use std::path::PathBuf; - use serde::{Deserialize, Serialize}; -/// Configuration for generate_policies Api +use crate::{enrichment::Explanations, policy_generation::PolicyWithMetadata}; +use std::path::PathBuf; + +/// Configuration for generate_policies API #[derive(Debug, Clone)] pub struct GeneratePolicyConfig { /// Config used to extract sdk calls for policy generation pub extract_sdk_calls_config: ExtractSdkCallsConfig, /// AWS Config pub aws_context: AwsContext, - /// Generate Action Mappings - pub generate_action_mappings: bool, /// Output individual policies pub individual_policies: bool, /// Enable policy size minimization pub minimize_policy_size: bool, /// Disable file system caching for service references pub disable_file_system_cache: bool, + /// Generate explanations for why actions were added + pub generate_explanations: bool, +} + +/// Result of policy generation: the generated policies and, if requested, explanations +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "PascalCase")] +pub struct GeneratePoliciesResult { + /// Generated IAM policies + pub policies: Vec, + /// Explanations for why actions were added (if requested) + #[serde(skip_serializing_if = "Option::is_none")] + pub explanations: Option, } /// Service hints for filtering SDK method calls diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs index bdc0a24..52181a4 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/mod.rs @@ -8,7 +8,18 @@ //! that represent method calls enriched with IAM metadata from operation //! action maps and Service Definition Files. 
-use crate::SdkMethodCall; +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, +}; + +use crate::{ + enrichment::operation_fas_map::{FasContext, FasOperation}, + extraction::SdkMethodCallMetadata, + service_configuration::ServiceConfiguration, + SdkMethodCall, SdkType, +}; +use convert_case::{Case, Casing}; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; @@ -22,6 +33,233 @@ pub(crate) use operation_fas_map::load_operation_fas_map; pub(crate) use resource_matcher::ResourceMatcher; pub(crate) use service_reference::RemoteServiceReferenceLoader as ServiceReferenceLoader; +/// Represents the reason why an action was added to a policy +#[derive(derive_new::new, Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub struct Reason { + /// The original operation that was extracted + pub operations: Vec>, +} + +#[derive(Debug, Clone, Serialize, Eq, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub struct Operation { + /// Name of the service + pub service: String, + /// Name of the operation + pub name: String, + /// Source of the operation, + pub source: OperationSource, + /// Disallow struct construction, need to use Self::from_call or Operation::from(FasOperation) + #[serde(skip)] + _private: (), +} + +impl Operation { + #[cfg(test)] + /// Convenience constructor for tests + pub(crate) fn new(service: String, name: String, source: OperationSource) -> Self { + Self { + service, + name, + source, + _private: (), + } + } + + pub(crate) fn service_operation_name(&self) -> String { + format!("{}:{}", self.service, self.name) + } + + pub(crate) fn context(&self) -> &[FasContext] { + match &self.source { + OperationSource::Fas(context) => context, + _ => &[], + } + } + + pub(crate) async fn from_call( + call: &SdkMethodCall, + original_service_name: &str, + service_cfg: &ServiceConfiguration, + sdk: SdkType, + service_reference_loader: &ServiceReferenceLoader, + ) -> crate::errors::Result { + let service 
= service_cfg + .rename_service_service_reference(original_service_name) + .to_string(); + let name = if sdk == SdkType::Boto3 { + // Try to load service reference and look up the boto3 method mapping + service_reference_loader + .load(&service) + .await? + .and_then(|service_ref| { + log::debug!("Looking up method {}", call.name); + service_ref + .boto3_method_to_operation + .get(&call.name) + .map(|op| { + log::debug!("got {:?}", op); + op.split(':').nth(1).unwrap_or(op).to_string() + }) + }) + // Fallback to PascalCase conversion if mapping not found + // This should not be reachable, but if for some reason we cannot use the SDF, + // we try converting to PascalCase, knowing that this is flawed in some cases: + // think `AddRoleToDBInstance` (actual name) + // vs. `AddRoleToDbInstance` (converted name) + .unwrap_or_else(|| call.name.to_case(Case::Pascal)) + } else { + // For non-Boto3 SDKs we use the extracted name as-is + call.name.clone() + }; + + Ok(match &call.metadata { + None => Self { + service, + name, + source: OperationSource::Provided, + _private: (), + }, + Some(metadata) => Self { + service, + name, + source: OperationSource::Extracted(metadata.clone()), + _private: (), + }, + }) + } +} + +impl From for Operation { + fn from(fas_op: FasOperation) -> Self { + Self { + service: fas_op.service, + name: fas_op.operation, + source: OperationSource::Fas(fas_op.context), + _private: (), + } + } +} + +// Custom PartialEq and Hash implementations for Operation: + +// We consider operations to be equal when they would produce the same action in a policy. +// I.e., same operation and same context used for the condition. Directly relevant to FAS expansion. 
+impl PartialEq for Operation { + fn eq(&self, other: &Self) -> bool { + self.service == other.service + && self.name == other.name + && self.context() == other.context() + } +} + +impl std::hash::Hash for Operation { + fn hash(&self, state: &mut H) { + self.service.hash(state); + self.name.hash(state); + self.context().hash(state); + } +} + +/// Custom serializer for extracted metadata that flattens the structure +fn serialize_extracted_metadata( + metadata: &SdkMethodCallMetadata, + serializer: S, +) -> Result +where + S: serde::Serializer, +{ + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(2))?; + map.serialize_entry("Expr", &metadata.expr)?; + map.serialize_entry("Location", &metadata.location)?; + map.end() +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, JsonSchema)] +#[serde(rename_all = "PascalCase")] +pub enum OperationSource { + /// Operation extracted from source files + Extracted(SdkMethodCallMetadata), + /// Operation provided (no metadata available) + Provided, + /// Operation comes from FAS expansion + Fas(Vec), +} + +impl Serialize for OperationSource { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + OperationSource::Extracted(metadata) => { + serialize_extracted_metadata(metadata, serializer) + } + OperationSource::Provided => serializer.serialize_str("Provided"), + OperationSource::Fas(_) => serializer.serialize_str("FAS"), + } + } +} + +/// Explanations for why actions have been included in a policy, with documentation for +/// concepts leading to inclusion (such as FAS expansion) +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "PascalCase")] +pub struct Explanations { + /// Explanation for inclusion of an action + pub explanation_for_action: BTreeMap, + /// Documentation of concepts used in the explanation for an action + #[serde(skip_serializing_if = "Vec::is_empty")] + pub documentation: Vec<&'static str>, +} + +impl Explanations { + const FAS: &str = 
+ "The explanation contains an operation added due to Forward Access Sessions (FAS). See https://docs.aws.amazon.com/IAM/latest/UserGuide/access_forward_access_sessions.html."; + + pub(crate) fn new(explanations: BTreeMap) -> Self { + let mut documentation: Vec<&'static str> = vec![]; + for explanation in explanations.values() { + for reason in &explanation.reasons { + for op in &reason.operations { + match op.source { + OperationSource::Extracted(_) | OperationSource::Provided => (), + OperationSource::Fas(_) => documentation.push(Self::FAS), + } + } + } + } + documentation.dedup(); + Self { + explanation_for_action: explanations, + documentation, + } + } +} + +/// Represents an explanation for why an action was added to a policy +#[derive(Debug, Clone, Serialize, PartialEq, Eq, Hash, JsonSchema, Default)] +// Don't print the `"Reasons":` key, treat this as just a JSON array. +#[serde(transparent)] +pub struct Explanation { + /// The reasons this action was added (can have multiple reasons for the same action) + pub reasons: Vec, +} + +impl Explanation { + pub(crate) fn merge(&mut self, other: Explanation) { + let reasons_set = self.reasons.iter().cloned().collect::>(); + for new_reason in other.reasons { + if reasons_set.contains(&new_reason) { + continue; + } + self.reasons.push(new_reason); + } + } +} + /// Represents an enriched method call with actions that need permissions #[derive(Debug, Clone, Serialize, PartialEq)] #[non_exhaustive] @@ -67,7 +305,7 @@ pub(crate) trait Context { /// /// This structure combines OperationAction action data with Service Reference resource information to provide /// complete IAM policy metadata for a single action. 
-#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +#[derive(Debug, Clone, Serialize, PartialEq)] pub(crate) struct Action { /// The IAM action name (e.g., "s3:GetObject") pub(crate) name: String, @@ -75,18 +313,8 @@ pub(crate) struct Action { pub(crate) resources: Vec, /// List of conditions we are adding pub(crate) conditions: Vec, -} - -/// Represents a resource enriched with ARN pattern and metadata -/// -/// This structure combines OperationAction resource data with Service Reference ARN patterns and additional -/// metadata to provide complete resource information for IAM policies. -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] -pub(crate) struct Resource { - /// The resource type name (e.g., "bucket", "object", "*") - pub(crate) name: String, - /// ARN patterns from Service Reference data, if available - pub(crate) arn_patterns: Option>, + /// Explanation of why this action has been added + pub(crate) explanation: Explanation, } impl Action { @@ -96,16 +324,35 @@ impl Action { /// * `name` - The IAM action name /// * `resources` - List of enriched resources /// * `conditions` - List of conditions + /// * `explanation` - Explanation why the action has been added #[must_use] - pub(crate) fn new(name: String, resources: Vec, conditions: Vec) -> Self { + pub(crate) fn new( + name: String, + resources: Vec, + conditions: Vec, + explanation: Explanation, + ) -> Self { Self { name, resources, conditions, + explanation, } } } +/// Represents a resource enriched with ARN pattern and metadata +/// +/// This structure combines OperationAction resource data with Service Reference ARN patterns and additional +/// metadata to provide complete resource information for IAM policies. 
+#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)] +pub(crate) struct Resource { + /// The resource type name (e.g., "bucket", "object", "*") + pub(crate) name: String, + /// ARN patterns from Service Reference data, if available + pub(crate) arn_patterns: Option>, +} + impl Resource { /// Create a new enriched resource #[must_use] @@ -117,6 +364,7 @@ impl Resource { #[cfg(test)] mod tests { use super::*; + use crate::enrichment::operation_fas_map::FasContext; #[test] fn test_enriched_resource_creation() { @@ -131,6 +379,191 @@ mod tests { Some(vec!["arn:aws:s3:::bucket/*".to_string()]) ); } + + #[test] + fn test_operation_custom_equality_same_operation_different_sources() { + // Test that operations with same service, name, and context are equal regardless of source + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(Vec::new()), // Empty context + ); + + // Should be equal because they have same service, name, and context (both empty) + assert_eq!(op1, op2); + assert_eq!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_equality_different_contexts() { + // Test that operations with different contexts are NOT equal + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, // Empty context + ); + + let context = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]; + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(context), + ); + + // Should NOT be equal because they have different contexts + assert_ne!(op1, op2); + assert_ne!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_equality_same_contexts() { + // Test that operations with same contexts are equal + let context1 = vec![FasContext::new( + 
"kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]; + let context2 = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]; + + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(context1), + ); + + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(context2), + ); + + // Should be equal because they have same service, name, and context + assert_eq!(op1, op2); + assert_eq!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_equality_different_services() { + // Test that operations with different services are NOT equal + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let op2 = Operation::new( + "kms".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + // Should NOT be equal because they have different services + assert_ne!(op1, op2); + assert_ne!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_equality_different_names() { + // Test that operations with different names are NOT equal + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let op2 = Operation::new( + "s3".to_string(), + "PutObject".to_string(), + OperationSource::Provided, + ); + + // Should NOT be equal because they have different operation names + assert_ne!(op1, op2); + assert_ne!(op2, op1); // Symmetric + } + + #[test] + fn test_operation_custom_hash_consistency() { + // Test that equal operations have the same hash + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(Vec::new()), // Empty context + ); + + // Equal operations should have the same hash + assert_eq!(op1, op2); + + use 
std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher1 = DefaultHasher::new(); + op1.hash(&mut hasher1); + let hash1 = hasher1.finish(); + + let mut hasher2 = DefaultHasher::new(); + op2.hash(&mut hasher2); + let hash2 = hasher2.finish(); + + assert_eq!(hash1, hash2, "Equal operations should have the same hash"); + } + + #[test] + fn test_operation_custom_hash_different_for_unequal() { + // Test that unequal operations typically have different hashes + let op1 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Provided, + ); + + let context = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]; + let op2 = Operation::new( + "s3".to_string(), + "GetObject".to_string(), + OperationSource::Fas(context), + ); + + // Unequal operations should typically have different hashes + assert_ne!(op1, op2); + + use std::collections::hash_map::DefaultHasher; + use std::hash::{Hash, Hasher}; + + let mut hasher1 = DefaultHasher::new(); + op1.hash(&mut hasher1); + let hash1 = hasher1.finish(); + + let mut hasher2 = DefaultHasher::new(); + op2.hash(&mut hasher2); + let hash2 = hasher2.finish(); + + // Note: Hash collisions are possible but unlikely for this test case + assert_ne!( + hash1, hash2, + "Unequal operations should typically have different hashes" + ); + } } #[cfg(test)] @@ -362,3 +795,202 @@ pub(crate) mod mock_remote_service_reference { (mock_server, loader) } } + +#[cfg(test)] +mod location_tests { + use super::*; + use crate::{ + enrichment::mock_remote_service_reference::setup_mock_server_with_loader_without_operation_to_action_mapping, + service_configuration::load_service_configuration, Location, + }; + use std::path::PathBuf; + + #[test] + fn test_location_to_gnu_string() { + let location = Location::new(PathBuf::from("src/main.rs"), (10, 5), (10, 79)); + + assert_eq!(location.to_gnu_format(), "src/main.rs:10.5-10.79"); + } + + #[test] + 
fn test_location_to_gnu_string_multiline() { + let location = Location::new(PathBuf::from("src/lib.rs"), (10, 5), (15, 20)); + + assert_eq!(location.to_gnu_format(), "src/lib.rs:10.5-15.20"); + } + + #[test] + fn test_location_serialization() { + let location = Location::new(PathBuf::from("test.py"), (42, 15), (42, 80)); + + let json = serde_json::to_string(&location).unwrap(); + assert_eq!(json, "\"test.py:42.15-42.80\""); + } + + #[test] + fn test_location_serialization_multiline() { + let location = Location::new(PathBuf::from("example.go"), (100, 1), (105, 50)); + + let json = serde_json::to_string(&location).unwrap(); + assert_eq!(json, "\"example.go:100.1-105.50\""); + } + + fn mock_sdk_method_call() -> SdkMethodCall { + SdkMethodCall { + name: "get_object".to_string(), + possible_services: vec!["s3".to_string()], + metadata: Some(SdkMethodCallMetadata { + parameters: vec![], + return_type: None, + expr: "s3.get_object(Bucket='my-bucket')".to_string(), + location: Location::new(PathBuf::from("test.py"), (10, 5), (10, 79)), + receiver: Some("s3".to_string()), + }), + } + } + + #[tokio::test] + async fn test_reason_extracted_with_location() { + let service_cfg = load_service_configuration().unwrap(); + let (_, service_reference_loader) = + setup_mock_server_with_loader_without_operation_to_action_mapping().await; + let call = mock_sdk_method_call(); + + let reason = Reason::new(vec![Arc::new( + Operation::from_call( + &call, + "s3", + &service_cfg, + SdkType::Boto3, + &service_reference_loader, + ) + .await + .unwrap(), + )]); + + assert_eq!(reason.operations[0].name, "GetObject"); + assert_eq!(reason.operations[0].service, "s3"); + match &reason.operations[0].source { + OperationSource::Extracted(metadata) => { + assert_eq!(metadata.expr, "s3.get_object(Bucket='my-bucket')"); + assert_eq!(metadata.location.to_gnu_format(), "test.py:10.5-10.79"); + } + _ => panic!("Expected Extracted variant"), + } + + let json = serde_json::to_string(&reason).unwrap(); + // 
Verify the location is serialized as a string in GNU format + assert!(json.contains("\"Location\":\"test.py:10.5-10.79\"")); + } + + #[test] + fn test_operation_source_extracted_serialization() { + let metadata = SdkMethodCallMetadata { + parameters: vec![], + return_type: None, + expr: "dynamodb.get_item(\n TableName='my-table',\n Key={'id': {'S': '123'}}\n )".to_string(), + location: Location::new(PathBuf::from("iam-policy-autopilot-cli/tests/resources/test_example.py"), (19, 5), (22, 5)), + receiver: Some("dynamodb".to_string()), + }; + + let source = OperationSource::Extracted(metadata); + let json = serde_json::to_string(&source).unwrap(); + + // Verify the custom serialization format + assert!(json.contains("\"Expr\":\"dynamodb.get_item(\\n TableName='my-table',\\n Key={'id': {'S': '123'}}\\n )\"")); + assert!(json.contains( + "\"Location\":\"iam-policy-autopilot-cli/tests/resources/test_example.py:19.5-22.5\"" + )); + // Should not contain nested "Source" key + assert!(!json.contains("\"Source\"")); + } + + #[test] + fn test_operation_source_provided_serialization() { + let source = OperationSource::Provided; + let json = serde_json::to_string(&source).unwrap(); + + // Verify the custom serialization format + assert_eq!(json, "\"Provided\""); + } + + #[test] + fn test_operation_source_fas_serialization() { + use crate::enrichment::operation_fas_map::FasContext; + + let fas_context = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["ssm.${region}.amazonaws.com".to_string()], + )]; + let source = OperationSource::Fas(fas_context); + let json = serde_json::to_string(&source).unwrap(); + + // Verify the custom serialization format - should be just "FAS", not nested + assert_eq!(json, "\"FAS\""); + } + + #[tokio::test] + async fn test_operation_methods() { + let service_cfg = load_service_configuration().unwrap(); + let (_, service_reference_loader) = + setup_mock_server_with_loader_without_operation_to_action_mapping().await; + + { + let call = 
SdkMethodCall { + name: "decrypt".to_string(), + possible_services: vec!["kms".to_string()], + metadata: None, + }; + let op = Operation::from_call( + &call, + "kms", + &service_cfg, + SdkType::Boto3, + &service_reference_loader, + ) + .await + .unwrap(); + assert_eq!(op.service_operation_name(), "kms:Decrypt"); + assert_eq!(op.context(), &[]); + } + + { + let expr = "kms.decrypt(...)".to_string(); + let metadata = SdkMethodCallMetadata { + parameters: vec![], + return_type: None, + expr: expr.clone(), + location: Location::new(PathBuf::new(), (1, 1), (1, expr.len() + 1)), + receiver: Some("kms".to_string()), + }; + let call = SdkMethodCall { + name: "decrypt".to_string(), + possible_services: vec!["kms".to_string()], + metadata: Some(metadata), + }; + let op = Operation::from_call( + &call, + "kms", + &service_cfg, + SdkType::Boto3, + &service_reference_loader, + ) + .await + .unwrap(); + assert_eq!(op.service_operation_name(), "kms:Decrypt"); + assert_eq!(op.context(), &[]); + } + + { + let context = vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["ssm.${region}.amazonaws.com".to_string()], + )]; + let fas_operation = + FasOperation::new("Decrypt".to_string(), "kms".to_string(), context.clone()); + let op = Operation::from(fas_operation); + assert_eq!(op.service_operation_name(), "kms:Decrypt"); + assert_eq!(op.context(), &context); + } + } +} diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/operation_fas_map.rs b/iam-policy-autopilot-policy-generation/src/enrichment/operation_fas_map.rs index f2065cd..be524d3 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/operation_fas_map.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/operation_fas_map.rs @@ -3,15 +3,14 @@ //! This module contains the data structures used to represent operation //! action maps that are loaded from embedded JSON files and used for IAM policy enrichment. 
-use std::borrow::Cow; use std::collections::HashMap; use std::sync::{Arc, OnceLock, RwLock}; use rust_embed::RustEmbed; +use schemars::JsonSchema; use serde::{Deserialize, Deserializer}; use crate::enrichment::Context; -use crate::service_configuration::ServiceConfiguration; type ServiceName = String; type OperationName = String; @@ -45,8 +44,8 @@ pub(crate) struct OperationFasMap { pub(crate) fas_operations: HashMap>, } -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub(crate) struct FasContext { +#[derive(Debug, Clone, PartialEq, Eq, Hash, JsonSchema)] +pub struct FasContext { pub(crate) key: String, pub(crate) values: Vec, } @@ -117,13 +116,14 @@ where #[derive(Debug, Clone, Deserialize, PartialEq, Eq, Hash)] pub(crate) struct FasOperation { #[serde(rename = "Operation")] - operation: String, + pub(crate) operation: String, #[serde(rename = "Service")] - service: String, + pub(crate) service: String, #[serde(rename = "Context", deserialize_with = "deserialize_context_map")] pub(crate) context: Vec, } +#[cfg(test)] impl FasOperation { pub(crate) fn new(operation: String, service: String, context: Vec) -> Self { FasOperation { @@ -132,26 +132,6 @@ impl FasOperation { context, } } - - // TODO: I think this should be removed once we use the service reference API - // The Operation -> Action map uses this format, so map lookups - // need to convert to it. 
- pub(crate) fn service_operation_name(&self, service_cfg: &ServiceConfiguration) -> String { - let service = self.service(service_cfg); - format!("{}:{}", service, self.operation(&service, service_cfg)) - } - - pub(crate) fn service<'a>(&'a self, service_cfg: &ServiceConfiguration) -> Cow<'a, str> { - service_cfg.rename_service_service_reference(&self.service) - } - - pub(crate) fn operation<'a>( - &'a self, - service: &str, - service_cfg: &ServiceConfiguration, - ) -> Cow<'a, str> { - service_cfg.rename_operation(service, &self.operation) - } } impl<'de> Deserialize<'de> for OperationFasMap { @@ -301,29 +281,6 @@ mod tests { } } - #[test] - fn test_fas_operation_methods() { - use crate::service_configuration::load_service_configuration; - - let service_cfg = load_service_configuration().unwrap(); - - let context = vec![FasContext::new( - "kms:ViaService".to_string(), - vec!["ssm.${region}.amazonaws.com".to_string()], - )]; - - let fas_op = FasOperation::new("Decrypt".to_string(), "kms".to_string(), context); - - // Test service method - assert_eq!(fas_op.service(&service_cfg), "kms"); - - // Test operation method - assert_eq!(fas_op.operation("kms", &service_cfg), "Decrypt"); - - // Test service_operation_name method - assert_eq!(fas_op.service_operation_name(&service_cfg), "kms:Decrypt"); - } - #[test] fn test_operation_fas_map_deserialization() { let json_content = r#"{ diff --git a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs index d9e21c3..3163159 100644 --- a/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs +++ b/iam-policy-autopilot-policy-generation/src/enrichment/resource_matcher.rs @@ -4,118 +4,45 @@ //! action maps with Service Definition Files to generate enriched method calls //! with complete IAM metadata. 
-use convert_case::{Case, Casing}; -use std::collections::HashSet; +use std::collections::HashMap; use std::sync::Arc; -use super::{Action, Context, EnrichedSdkMethodCall, Resource}; -use crate::enrichment::operation_fas_map::{FasOperation, OperationFasMap, OperationFasMaps}; +use super::{Action, Context, EnrichedSdkMethodCall, Explanation, Reason, Resource}; +use crate::enrichment::operation_fas_map::{OperationFasMap, OperationFasMaps}; use crate::enrichment::service_reference::ServiceReference; -use crate::enrichment::{Condition, ServiceReferenceLoader}; +use crate::enrichment::{Condition, Operation, ServiceReferenceLoader}; use crate::errors::{ExtractorError, Result}; use crate::service_configuration::ServiceConfiguration; use crate::{SdkMethodCall, SdkType}; -/// ResourceMatcher coordinates OperationAction maps and Service Reference data to generate enriched method calls -/// -/// This struct provides the core functionality for the 3-stage enrichment pipeline, -/// combining parsed method calls with operation action maps and Service -/// Definition Files to produce complete IAM metadata. 
-#[derive(Debug, Clone)] -pub(crate) struct ResourceMatcher { - service_cfg: Arc, - fas_maps: OperationFasMaps, - sdk: SdkType, +#[derive(Clone, Debug)] +struct FasExpansion { + dependency_graph: HashMap, Vec>>, } -// TODO: Make this configurable: https://github.com/awslabs/iam-policy-autopilot/issues/19 -const RESOURCE_CUTOFF: usize = 5; - -impl ResourceMatcher { - /// Create a new ResourceMatcher instance - #[must_use] - pub(crate) fn new( - service_cfg: Arc, - fas_maps: OperationFasMaps, - sdk: SdkType, +impl FasExpansion { + fn new( + service_cfg: &ServiceConfiguration, + fas_maps: &OperationFasMaps, + initial: Operation, ) -> Self { - Self { - service_cfg, - fas_maps, - sdk, - } - } - - /// Enrich a parsed method call with OperationAction maps, FAS maps, and Service - /// Reference data - pub(crate) async fn enrich_method_call<'b>( - &self, - parsed_call: &'b SdkMethodCall, - service_reference_loader: &ServiceReferenceLoader, - ) -> Result>> { - if parsed_call.possible_services.is_empty() { - return Err(ExtractorError::enrichment_error( - &parsed_call.name, - "No matching services found for method call", - )); - } - - let mut enriched_calls: Vec> = Vec::new(); - - // For each possible service in the parsed method call - for service_name in &parsed_call.possible_services { - // Create enriched method call for this service - if let Some(enriched_call) = self - .create_enriched_method_call(parsed_call, service_name, service_reference_loader) - .await? 
- { - enriched_calls.push(enriched_call); - } - } - - Ok(enriched_calls) - } + let mut dependency_graph: HashMap<Arc<Operation>, Vec<Arc<Operation>>> = HashMap::new(); + let initial_arc = Arc::new(initial); - /// Find OperationFas map for a specific service - fn find_operation_fas_map_for_service( - &self, - service_name: &str, - ) -> Option<Arc<OperationFasMap>> { - self.fas_maps - .get( - self.service_cfg - .rename_service_operation_action_map(service_name) - .as_ref(), - ) - .cloned() - } + dependency_graph.insert(Arc::clone(&initial_arc), Vec::new()); // Root has no dependencies - /// Expand FAS operations to a fixed point, avoiding infinite loops from cycles - /// - /// This method safely expands FAS operations by iteratively processing new operations - /// until no more new operations are discovered (fixed point reached). - /// It includes cycle detection to prevent infinite loops. - fn expand_fas_operations_to_fixed_point( - &self, - initial: FasOperation, - ) -> Result<Vec<FasOperation>> { - let mut operations = HashSet::<FasOperation>::new(); - operations.insert(initial); + let mut to_process = vec![Arc::clone(&initial_arc)]; - let mut to_process = operations.clone(); while !to_process.is_empty() { - let mut newly_discovered = HashSet::<FasOperation>::new(); + let mut newly_discovered = Vec::new(); - // Process all operations in the current batch - for operation in &to_process { - let service_name = operation.service(&self.service_cfg); - let operation_fas_map_option = - self.find_operation_fas_map_for_service(&service_name); + for current in &to_process { + let service_name = &current.service; - match operation_fas_map_option { + match Self::find_operation_fas_map_for_service(service_cfg, fas_maps, service_name) + { Some(operation_fas_map) => { - let service_operation_name = - operation.service_operation_name(&self.service_cfg); + let service_operation_name = current.service_operation_name(); log::debug!("Looking up operation {}", service_operation_name); if let Some(additional_operations) = operation_fas_map @@ -123,9 +50,16 @@ impl ResourceMatcher {
.get(&service_operation_name) { for additional_op in additional_operations { - // Only add if we haven't seen this operation before - if !operations.contains(additional_op) { - newly_discovered.insert(additional_op.clone()); + let new_op = Arc::new(Operation::from(additional_op.clone())); + + if let Some(existing_deps) = dependency_graph.get_mut(&new_op) { + // Operation already exists, add this dependency relationship + existing_deps.push(Arc::clone(current)); + } else { + // New operation + dependency_graph + .insert(Arc::clone(&new_op), vec![Arc::clone(current)]); + newly_discovered.push(Arc::clone(&new_op)); } } } else { @@ -138,12 +72,7 @@ impl ResourceMatcher { } } - // Add newly discovered operations to our complete set - operations.extend(newly_discovered.iter().cloned()); - let newly_discovered_count = newly_discovered.len(); - - // Set up next iteration to process only newly discovered operations to_process = newly_discovered; log::debug!( @@ -154,14 +83,87 @@ impl ResourceMatcher { log::debug!( "FAS expansion completed with {} total operations", - operations.len() + dependency_graph.len() ); + Self { dependency_graph } + } + + /// Find OperationFas map for a specific service + fn find_operation_fas_map_for_service( + service_cfg: &ServiceConfiguration, + fas_maps: &OperationFasMaps, + service_name: &str, + ) -> Option> { + fas_maps + .get( + service_cfg + .rename_service_operation_action_map(service_name) + .as_ref(), + ) + .cloned() + } + + fn operations(&self) -> impl Iterator> { + self.dependency_graph.keys() + } + + fn complete_provenance_chain(&self, op: Arc) -> Vec> { + let mut result = vec![]; + if let Some(deps) = self.dependency_graph.get(&op) { + for dep in deps { + result.push(Arc::clone(dep)); + } + } + // Add the initial operation + result.push(Arc::clone(&op)); + result + } +} + +/// ResourceMatcher coordinates OperationAction maps and Service Reference data to generate enriched method calls +/// +/// This struct provides the core 
functionality for the 3-stage enrichment pipeline, +/// combining parsed method calls with operation action maps and Service +/// Definition Files to produce complete IAM metadata. +#[derive(derive_new::new, Debug, Clone)] +pub(crate) struct ResourceMatcher { + service_cfg: Arc, + fas_maps: OperationFasMaps, + sdk: SdkType, +} + +// TODO: Make this configurable: https://github.com/awslabs/iam-policy-autopilot/issues/19 +const RESOURCE_CUTOFF: usize = 5; + +impl ResourceMatcher { + /// Enrich a parsed method call with OperationAction maps, FAS maps, and Service + /// Reference data + pub(crate) async fn enrich_method_call<'b>( + &self, + parsed_call: &'b SdkMethodCall, + service_reference_loader: &ServiceReferenceLoader, + ) -> Result>> { + if parsed_call.possible_services.is_empty() { + return Err(ExtractorError::enrichment_error( + &parsed_call.name, + "No matching services found for method call", + )); + } + + let mut enriched_calls: Vec> = Vec::new(); - // Convert HashSet to Vec and sort by service_operation_name for deterministic output - let mut operations_vec: Vec = operations.into_iter().collect(); - operations_vec.sort_by_key(|op| op.service_operation_name(&self.service_cfg)); + // For each possible service in the parsed method call + for service_name in &parsed_call.possible_services { + // Create enriched method call for this service + if let Some(enriched_call) = self + .create_enriched_method_call(parsed_call, service_name, service_reference_loader) + .await? 
+ { + enriched_calls.push(enriched_call); + } + } - Ok(operations_vec) + Ok(enriched_calls) } fn make_condition(context: &[T]) -> Vec { @@ -189,65 +191,52 @@ impl ResourceMatcher { parsed_call.name ); - let initial = { - let initial_service_name = self - .service_cfg - .rename_service_service_reference(service_name); - // Determine the initial operation name, with special handling for Python's boto3 method names - let initial_operation_name = if self.sdk == SdkType::Boto3 { - // Try to load service reference and look up the boto3 method mapping - service_reference_loader - .load(&initial_service_name) - .await? - .and_then(|service_ref| { - log::debug!("Looking up method {}", parsed_call.name); - service_ref - .boto3_method_to_operation - .get(&parsed_call.name) - .map(|op| { - log::debug!("got {:?}", op); - op.split(':').nth(1).unwrap_or(op).to_string() - }) - }) - // Fallback to PascalCase conversion if mapping not found - // This should not be reachable, but if for some reason we cannot use the SDF, - // we try converting to PascalCase, knowing that this is flawed in some cases: - // think `AddRoleToDBInstance` (actual name) - // vs. 
`AddRoleToDbInstance` (converted name) - .unwrap_or_else(|| parsed_call.name.to_case(Case::Pascal)) - } else { - // For non-Boto3 SDKs we use the extracted name as-is - parsed_call.name.clone() - }; - FasOperation::new(initial_operation_name, service_name.to_string(), Vec::new()) - }; + // Store the original service name from parsed_call for use in explanations + let original_service_name = service_name; + + let initial = Operation::from_call( + parsed_call, + service_name, + &self.service_cfg, + self.sdk, + service_reference_loader, + ) + .await?; + + log::debug!("Expanded {:?}", initial); // Use fixed-point algorithm to safely expand FAS operations until no new operations are found - let operations = self.expand_fas_operations_to_fixed_point(initial)?; + let fas_expansion = FasExpansion::new(&self.service_cfg, &self.fas_maps, initial); + + log::debug!("to\n{:?}", fas_expansion.dependency_graph); let mut enriched_actions = vec![]; - for operation in operations { - let service_name = operation.service(&self.service_cfg); - // Find the corresponding SDF using the cache - let service_reference = service_reference_loader.load(&service_name).await?; + for op in fas_expansion.operations() { + log::debug!( + "Creating actions for operation {:?}", + op.service_operation_name() + ); + log::debug!(" with context {:?}", op.context()); + + // Find the corresponding SDF using the cache + let service_reference = service_reference_loader.load(&op.service).await?; match service_reference { None => { + log::debug!("Skipping operation due to no service reference"); continue; } Some(service_reference) => { - log::debug!("Creating actions for {:?}", operation); - log::debug!(" with context {:?}", operation.context); if let Some(operation_to_authorized_actions) = &service_reference.operation_to_authorized_actions { - log::debug!( - "Looking up {}", - &operation.service_operation_name(&self.service_cfg) - ); + log::debug!("Looking up {}", &op.service_operation_name()); if let 
Some(operation_to_authorized_action) = - operation_to_authorized_actions - .get(&operation.service_operation_name(&self.service_cfg)) + operation_to_authorized_actions.get(&op.service_operation_name()) { + log::debug!( + "Found operation action map for {:?}", + operation_to_authorized_action.name + ); for action in &operation_to_authorized_action.authorized_actions { let enriched_resources = self .find_resources_for_action_in_service_reference( @@ -262,7 +251,7 @@ impl ResourceMatcher { }; // Combine conditions from FAS operation context and AuthorizedAction context - let mut conditions = Self::make_condition(&operation.context); + let mut conditions = Self::make_condition(op.context()); // Add conditions from AuthorizedAction context if present if let Some(auth_context) = &action.context { @@ -271,29 +260,38 @@ impl ResourceMatcher { ))); } + let ops = fas_expansion.complete_provenance_chain(Arc::clone(op)); + + // Create explanation for this action + let explanation = Explanation { + reasons: vec![Reason::new(ops)], + }; let enriched_action = Action::new( action.name.clone(), enriched_resources, conditions, + explanation, ); - + log::debug!("Created action: {:?}", enriched_action); enriched_actions.push(enriched_action); } } else { // Fallback: operation not found in operation action map, create basic action // This ensures we don't filter out operations, only ADD additional ones from the map if let Some(a) = - self.create_fallback_action(&parsed_call.name, &service_reference)? + self.create_fallback_action(op, &fas_expansion, &service_reference)? { - enriched_actions.push(a) + log::debug!("Created fallback action due to no entry in operation action map: {:?}", a); + enriched_actions.push(a); } } } else { // Fallback: operation action map does not exist, create basic action if let Some(a) = - self.create_fallback_action(&parsed_call.name, &service_reference)? + self.create_fallback_action(op, &fas_expansion, &service_reference)? 
{ - enriched_actions.push(a) + log::debug!("Created fallback action due to no operation action map for service: {:?}", a); + enriched_actions.push(a); } } } @@ -306,7 +304,7 @@ impl ResourceMatcher { Ok(Some(EnrichedSdkMethodCall { method_name: parsed_call.name.clone(), - service: service_name.to_string(), + service: original_service_name.to_string(), actions: enriched_actions, sdk_method_call: parsed_call, })) @@ -318,20 +316,18 @@ impl ResourceMatcher { /// corresponding resources in the SDF. fn create_fallback_action( &self, - method_name: &str, + op: &Arc, + fas_expansion_result: &FasExpansion, service_reference: &ServiceReference, ) -> Result> { - let renamed_service = self - .service_cfg - .rename_service_service_reference(&service_reference.service_name); - let renamed_action = &method_name.to_case(Case::Pascal); - let action_name = format!("{}:{}", renamed_service, renamed_action); + let action_name = op.service_operation_name(); // Sanity check that the action exists in the SDF - if !service_reference - .actions - .contains_key(renamed_action.as_str()) - { + if !service_reference.actions.contains_key(&op.name) { + log::debug!( + "Not creating fallback action: service reference doesn't contain key: {:?}", + action_name + ); return Ok(None); } @@ -339,10 +335,18 @@ impl ResourceMatcher { let resources = self.find_resources_for_action_in_service_reference(&action_name, service_reference)?; + // Create explanation for fallback action + let explanation = Explanation { + reasons: vec![Reason::new( + fas_expansion_result.complete_provenance_chain(Arc::clone(op)), + )], + }; + Ok(Some(Action::new( action_name.to_string(), resources, vec![], + explanation, ))) } @@ -404,8 +408,8 @@ mod tests { use std::collections::HashMap; use super::*; - use crate::enrichment::mock_remote_service_reference; use crate::enrichment::operation_fas_map::{FasContext, FasOperation, OperationFasMap}; + use crate::enrichment::{mock_remote_service_reference, OperationSource}; fn 
create_test_parsed_method_call() -> SdkMethodCall { SdkMethodCall { @@ -950,81 +954,80 @@ mod tests { let service_cfg = create_empty_service_config(); // Create a mock FAS map with no cycles: A -> B -> C (linear chain) - let mut fas_maps = HashMap::new(); - - // Service A: GetObject -> Service B: Decrypt - let mut service_a_operations = HashMap::new(); - service_a_operations.insert( - "service-a:GetObject".to_string(), - vec![FasOperation::new( - "Decrypt".to_string(), - "service-b".to_string(), - vec![FasContext::new( - "test".to_string(), - vec!["value".to_string()], - )], - )], - ); + let fas_maps = { + let mut fas_maps = HashMap::new(); - // Service B: Decrypt -> Service C: Log - let mut service_b_operations = HashMap::new(); - service_b_operations.insert( - "service-b:Decrypt".to_string(), - vec![FasOperation::new( - "Log".to_string(), - "service-c".to_string(), - vec![FasContext::new( - "test2".to_string(), - vec!["value2".to_string()], + // Service A: GetObject -> Service B: Decrypt + let mut service_a_operations = HashMap::new(); + service_a_operations.insert( + "service-a:GetObject".to_string(), + vec![FasOperation::new( + "Decrypt".to_string(), + "service-b".to_string(), + vec![FasContext::new( + "test".to_string(), + vec!["value".to_string()], + )], )], - )], - ); + ); - // Service C: Log -> nothing (terminal) - let service_c_operations = HashMap::new(); + // Service B: Decrypt -> Service C: Log + let mut service_b_operations = HashMap::new(); + service_b_operations.insert( + "service-b:Decrypt".to_string(), + vec![FasOperation::new( + "Log".to_string(), + "service-c".to_string(), + vec![FasContext::new( + "test2".to_string(), + vec!["value2".to_string()], + )], + )], + ); - fas_maps.insert( - "service-a".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_a_operations, - }), - ); - fas_maps.insert( - "service-b".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_b_operations, - }), - ); - fas_maps.insert( - 
"service-c".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_c_operations, - }), - ); + // Service C: Log -> nothing (terminal) + let service_c_operations = HashMap::new(); - let matcher = ResourceMatcher::new(service_cfg.clone(), fas_maps, SdkType::Other); + fas_maps.insert( + "service-a".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_a_operations, + }), + ); + fas_maps.insert( + "service-b".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_b_operations, + }), + ); + fas_maps.insert( + "service-c".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_c_operations, + }), + ); + fas_maps + }; // Test expansion starting from GetObject - let initial = - FasOperation::new("GetObject".to_string(), "service-a".to_string(), Vec::new()); - - let result = matcher.expand_fas_operations_to_fixed_point(initial); - assert!( - result.is_ok(), - "Fixed-point expansion should succeed for non-cyclic operations" + let initial = Operation::new( + "service-a".to_string(), + "GetObject".to_string(), + OperationSource::Provided, ); - let operations = result.unwrap(); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial); + assert_eq!( - operations.len(), + fas_expansion.dependency_graph.len(), 3, "Should have exactly 3 operations: GetObject, Decrypt, Log" ); // Verify all expected operations are present - let operation_names: std::collections::HashSet = operations - .iter() - .map(|op| op.service_operation_name(&service_cfg)) + let operation_names: std::collections::HashSet = fas_expansion + .operations() + .map(|op| op.service_operation_name()) .collect(); assert!(operation_names.contains("service-a:GetObject")); @@ -1044,72 +1047,72 @@ mod tests { let service_cfg = create_empty_service_config(); // Create a mock FAS map with a cycle: A -> B -> A - let mut fas_maps = HashMap::new(); - - // Service A: GetObject -> Service B: Decrypt - let mut service_a_operations = HashMap::new(); - 
service_a_operations.insert( - "service-a:GetObject".to_string(), - vec![FasOperation::new( - "Decrypt".to_string(), - "service-b".to_string(), - vec![FasContext::new( - "test".to_string(), - vec!["value".to_string()], - )], - )], - ); + let fas_maps = { + let mut fas_maps = HashMap::new(); - // Service B: Decrypt -> Service A: GetObject (creates cycle!) - let mut service_b_operations = HashMap::new(); - service_b_operations.insert( - "service-b:Decrypt".to_string(), - vec![FasOperation::new( - "GetObject".to_string(), - "service-a".to_string(), - vec![FasContext::new( - "test2".to_string(), - vec!["value2".to_string()], + // Service A: GetObject -> Service B: Decrypt + let mut service_a_operations = HashMap::new(); + service_a_operations.insert( + "service-a:GetObject".to_string(), + vec![FasOperation::new( + "Decrypt".to_string(), + "service-b".to_string(), + vec![FasContext::new( + "test".to_string(), + vec!["value".to_string()], + )], )], - )], - ); + ); - fas_maps.insert( - "service-a".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_a_operations, - }), - ); - fas_maps.insert( - "service-b".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_b_operations, - }), - ); + // Service B: Decrypt -> Service A: GetObject (creates cycle!) 
+ let mut service_b_operations = HashMap::new(); + service_b_operations.insert( + "service-b:Decrypt".to_string(), + vec![FasOperation::new( + "GetObject".to_string(), + "service-a".to_string(), + vec![FasContext::new( + "test2".to_string(), + vec!["value2".to_string()], + )], + )], + ); - let matcher = ResourceMatcher::new(service_cfg.clone(), fas_maps, SdkType::Other); + fas_maps.insert( + "service-a".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_a_operations, + }), + ); + fas_maps.insert( + "service-b".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_b_operations, + }), + ); + fas_maps + }; // Test expansion starting from GetObject - should detect cycle and terminate - let initial = - FasOperation::new("GetObject".to_string(), "service-a".to_string(), Vec::new()); - - let result = matcher.expand_fas_operations_to_fixed_point(initial); - - assert!( - result.is_ok(), - "Fixed-point expansion should handle cycles gracefully" + let initial = Operation::new( + "service-a".to_string(), + "GetObject".to_string(), + OperationSource::Provided, ); - let operations = result.unwrap(); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial); // Debug: print what operations we actually got - let operation_names: std::collections::HashSet = operations - .iter() - .map(|op| op.service_operation_name(&service_cfg)) + let operation_names: std::collections::HashSet = fas_expansion + .operations() + .map(|op| op.service_operation_name()) .collect(); // 3 operations, note that GetObject occurs twice, once with and once without context - assert!(operations.len() == 3, "Should have 3 operations"); + assert!( + fas_expansion.dependency_graph.len() == 3, + "Should have 3 operations" + ); // Verify expected operations are present assert!(operation_names.contains("service-a:GetObject")); @@ -1127,56 +1130,52 @@ mod tests { // Create a service configuration let service_cfg = create_empty_service_config(); - let mut fas_maps = 
HashMap::new(); - - // Create a chain that loops back: A -> B -> C -> D -> A - let operations_data = vec![ - ("service-a", "GetObject", "service-b", "Decrypt"), - ("service-b", "Decrypt", "service-c", "Validate"), - ("service-c", "Validate", "service-d", "Log"), - ("service-d", "Log", "service-a", "GetObject"), // Back to start - ]; - - for (from_service, from_op, to_service, to_op) in operations_data { - let mut operations = HashMap::new(); - operations.insert( - format!("{}:{}", from_service, from_op), - vec![FasOperation::new( - to_op.to_string(), - to_service.to_string(), - vec![FasContext::new( - "cycle".to_string(), - vec!["test".to_string()], + let fas_maps = { + let mut fas_maps = HashMap::new(); + + // Create a chain that loops back: A -> B -> C -> D -> A + let operations_data = vec![ + ("service-a", "GetObject", "service-b", "Decrypt"), + ("service-b", "Decrypt", "service-c", "Validate"), + ("service-c", "Validate", "service-d", "Log"), + ("service-d", "Log", "service-a", "GetObject"), // Back to start + ]; + + for (from_service, from_op, to_service, to_op) in operations_data { + let mut operations = HashMap::new(); + operations.insert( + format!("{}:{}", from_service, from_op), + vec![FasOperation::new( + to_op.to_string(), + to_service.to_string(), + vec![FasContext::new( + "cycle".to_string(), + vec!["test".to_string()], + )], )], - )], - ); - - fas_maps.insert( - from_service.to_string(), - Arc::new(OperationFasMap { - fas_operations: operations, - }), - ); - } - - let matcher = ResourceMatcher::new(service_cfg.clone(), fas_maps, SdkType::Other); - - let initial = - FasOperation::new("GetObject".to_string(), "service-a".to_string(), Vec::new()); - - let result = matcher.expand_fas_operations_to_fixed_point(initial); + ); + + fas_maps.insert( + from_service.to_string(), + Arc::new(OperationFasMap { + fas_operations: operations, + }), + ); + } + fas_maps + }; - // Should succeed and return operations for the cycle - assert!( - result.is_ok(), - "Should 
handle complex cycles without hitting max iterations" + let initial = Operation::new( + "service-a".to_string(), + "GetObject".to_string(), + OperationSource::Provided, ); - let operations = result.unwrap(); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial); // We have 5 operations, note that GetObject occurs twice, once with context and the initial one without assert!( - operations.len() == 5, + fas_expansion.dependency_graph.len() == 5, "Should have 5 operations in the cycle" ); } @@ -1186,28 +1185,30 @@ mod tests { use std::collections::HashMap; let service_cfg = create_empty_service_config(); + let fas_maps = HashMap::new(); - let matcher = ResourceMatcher::new(service_cfg.clone(), HashMap::new(), SdkType::Other); - - let initial = FasOperation::new( - "NonExistentOperation".to_string(), + let initial = Operation::new( "non-existent-service".to_string(), - Vec::new(), + "NonExistentOperation".to_string(), + OperationSource::Provided, ); - let result = matcher.expand_fas_operations_to_fixed_point(initial.clone()); - assert!(result.is_ok(), "Should succeed even with no FAS maps"); - - let operations = result.unwrap(); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial.clone()); assert_eq!( - operations.len(), + fas_expansion.dependency_graph.len(), 1, "Should contain only the initial operation" ); - assert!( - operations.contains(&initial), + + let operations: Vec<_> = fas_expansion.operations().collect(); + assert_eq!( + **operations[0], initial, "Should contain the initial operation" ); + assert!( + !matches!(operations[0].source, OperationSource::Fas(_)), + "Initial operation should not be from FAS expansion" + ); println!("✓ Test passed: Handles case with no additional FAS operations"); } @@ -1220,53 +1221,54 @@ mod tests { let service_cfg = create_empty_service_config(); // Create a FAS map where A -> A with empty context (self-referential) - let mut fas_maps = HashMap::new(); - - // Service A: GetObject -> 
Service A: GetObject (with empty context) - let mut service_a_operations = HashMap::new(); - service_a_operations.insert( - "service-a:GetObject".to_string(), - vec![FasOperation::new( - "GetObject".to_string(), - "service-a".to_string(), - Vec::new(), // Empty context - same as initial - )], - ); + let fas_maps = { + let mut fas_maps = HashMap::new(); - fas_maps.insert( - "service-a".to_string(), - Arc::new(OperationFasMap { - fas_operations: service_a_operations, - }), - ); + // Service A: GetObject -> Service A: GetObject (with empty context) + let mut service_a_operations = HashMap::new(); + service_a_operations.insert( + "service-a:GetObject".to_string(), + vec![FasOperation::new( + "GetObject".to_string(), + "service-a".to_string(), + Vec::new(), // Empty context - same as initial + )], + ); - let matcher = ResourceMatcher::new(service_cfg.clone(), fas_maps, SdkType::Other); + fas_maps.insert( + "service-a".to_string(), + Arc::new(OperationFasMap { + fas_operations: service_a_operations, + }), + ); + fas_maps + }; // Test expansion starting from GetObject with empty context - let initial = FasOperation::new( - "GetObject".to_string(), + let initial = Operation::new( "service-a".to_string(), - Vec::new(), // Empty context - ); - - let result = matcher.expand_fas_operations_to_fixed_point(initial.clone()); - assert!( - result.is_ok(), - "Self-cycle with empty context should be handled gracefully" + "GetObject".to_string(), + OperationSource::Provided, ); - let operations = result.unwrap(); + let fas_expansion = FasExpansion::new(&service_cfg, &fas_maps, initial.clone()); // Should have exactly 1 operation since A->A with same context creates no new operations assert_eq!( - operations.len(), + fas_expansion.dependency_graph.len(), 1, "Self-cycle with identical context should result in exactly 1 operation" ); - assert!( - operations.contains(&initial), + + let operations: Vec<_> = fas_expansion.operations().collect(); + assert_eq!( + **operations[0], initial, 
"Should contain the initial operation" ); + assert!( + !matches!(operations[0].source, OperationSource::Fas(_)), + "Initial operation should not be from FAS expansion" + ); println!("✓ Test passed: Self-cycle with empty context handled correctly"); } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/engine.rs b/iam-policy-autopilot-policy-generation/src/extraction/engine.rs index 912e1c9..c02a8b3 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/engine.rs @@ -94,7 +94,7 @@ impl Engine { for source_file in source_files { let extractor = extractor.clone(); - join_set.spawn(async move { extractor.parse(&source_file.content).await }); + join_set.spawn(async move { extractor.parse(&source_file).await }); } // Collect results from concurrent tasks diff --git a/iam-policy-autopilot-policy-generation/src/extraction/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/extractor.rs index 9b9c190..ab5d164 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/extractor.rs @@ -1,31 +1,18 @@ //! 
Result type alias for operations that can fail with `ExtractorError` -use ast_grep_core::AstGrep; use ast_grep_language::{Go, JavaScript, Python, TypeScript}; use async_trait::async_trait; use crate::extraction::go::types::GoImportInfo; -use crate::{SdkMethodCall, ServiceModelIndex}; +use crate::extraction::AstWithSourceFile; +use crate::{SdkMethodCall, ServiceModelIndex, SourceFile}; /// Enum to handle different AST types from different languages #[derive(Clone)] pub(crate) enum ExtractorResult { - Python( - AstGrep>, - Vec, - ), - Go( - AstGrep>, - Vec, - GoImportInfo, - ), - JavaScript( - AstGrep>, - Vec, - ), - TypeScript( - AstGrep>, - Vec, - ), + Python(AstWithSourceFile, Vec), + Go(AstWithSourceFile, Vec, GoImportInfo), + JavaScript(AstWithSourceFile, Vec), + TypeScript(AstWithSourceFile, Vec), } impl ExtractorResult { @@ -64,7 +51,7 @@ impl ExtractorResult { #[async_trait] pub(crate) trait Extractor: Send + Sync { /// Parse source code into method calls and return the AST - async fn parse(&self, source_code: &str) -> ExtractorResult; + async fn parse(&self, source_file: &SourceFile) -> ExtractorResult; fn filter_map( &self, diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs index ad93529..c0ebca2 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs @@ -334,7 +334,9 @@ mod tests { Shape, ShapeReference, }; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; + use crate::Location; use std::collections::HashMap; + use std::path::PathBuf; fn create_test_service_index() -> ServiceModelIndex { let mut services = HashMap::new(); @@ -587,6 +589,7 @@ mod tests { name: "ListObjectsV2".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "ListObjectsV2".to_string(), 
parameters: vec![ Parameter::Positional { value: ParameterValue::Unresolved("context.TODO()".to_string()), @@ -605,8 +608,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -648,6 +650,7 @@ mod tests { name: "NonAwsMethod".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "NonAwsMethod".to_string(), parameters: vec![Parameter::Positional { value: ParameterValue::Unresolved("someParam".to_string()), position: 0, @@ -655,8 +658,7 @@ mod tests { struct_fields: None, }], return_type: None, - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("client".to_string()), }), }; @@ -685,6 +687,7 @@ mod tests { name: "ListObjectsV2".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "ListObjectsV2".to_string(), parameters: vec![ Parameter::Positional { value: ParameterValue::Unresolved("context.TODO()".to_string()), @@ -703,8 +706,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -728,6 +730,7 @@ mod tests { name: "ListObjectsV2".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "ListObjectsV2".to_string(), parameters: vec![ Parameter::Positional { value: ParameterValue::Unresolved("context.TODO()".to_string()), @@ -746,8 +749,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -777,6 +779,7 @@ mod tests { name: "ListObjectsV2".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "ListObjectsV2".to_string(), 
parameters: vec![ Parameter::Positional { value: ParameterValue::Unresolved("context.TODO()".to_string()), @@ -795,8 +798,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -832,9 +834,9 @@ mod tests { struct_fields: Some(vec!["Bucket".to_string()]), }, ], + expr: "GetObject".to_string(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("client".to_string()), }), }; @@ -870,9 +872,9 @@ mod tests { struct_fields: Some(vec!["Bucket".to_string(), "Key".to_string()]), }, ], + expr: "GetObject".to_string(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("client".to_string()), }), }; @@ -907,9 +909,9 @@ mod tests { struct_fields: Some(vec!["Bucket".to_string(), "Key".to_string()]), }, ], + expr: "GetObject".to_string(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("client".to_string()), }), }; @@ -949,9 +951,9 @@ mod tests { struct_fields: Some(vec!["Bucket".to_string(), "Key".to_string()]), }, ], + expr: "GetObject".to_string(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("client".to_string()), }), }; @@ -1006,9 +1008,9 @@ mod tests { struct_fields: Some(vec!["QueueName".to_string(), "Attributes".to_string()]), }, ], + expr: "CreateQueue".to_string(), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), return_type: None, - start_position: (1, 1), - end_position: (1, 50), receiver: Some("sqsClient".to_string()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs 
b/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs index 70433fa..39c4a75 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs @@ -7,8 +7,10 @@ use crate::extraction::go::node_kinds; use crate::extraction::go::paginator_extractor::GoPaginatorExtractor; use crate::extraction::go::types::{GoImportInfo, ImportInfo}; use crate::extraction::go::waiter_extractor::GoWaiterExtractor; -use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; -use crate::ServiceModelIndex; +use crate::extraction::{ + AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, +}; +use crate::{Location, ServiceModelIndex, SourceFile}; use ast_grep_config::from_yaml_string; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; @@ -23,12 +25,9 @@ impl GoExtractor { } /// Extract import statements from Go source code using ast-grep - fn extract_imports( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> GoImportInfo { + fn extract_imports(&self, ast: &AstWithSourceFile) -> GoImportInfo { let mut import_info = GoImportInfo::new(); - let root = ast.root(); + let root = ast.ast.root(); // AST-grep configuration for extracting import statements let import_config = r#" @@ -142,6 +141,7 @@ rule: fn parse_method_call( &self, node_match: &ast_grep_core::NodeMatch>, + source_file: &SourceFile, ) -> Option { let env = node_match.get_env(); @@ -189,19 +189,14 @@ rule: vec![] }; - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - let method_call = SdkMethodCall { name: method_name.to_string(), possible_services: Vec::new(), // Will be determined later during service validation metadata: Some(SdkMethodCallMetadata { parameters: arguments, return_type: None, // We don't know the return type from the call site - start_position: 
(start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(source_file.path.clone(), node_match.get_node()), receiver, }), }; @@ -419,9 +414,13 @@ impl Default for GoExtractor { #[async_trait] impl Extractor for GoExtractor { - async fn parse(&self, source_code: &str) -> crate::extraction::extractor::ExtractorResult { - let ast_grep = Go.ast_grep(source_code); - let root = ast_grep.root(); + async fn parse( + &self, + source_file: &SourceFile, + ) -> crate::extraction::extractor::ExtractorResult { + let ast_grep = Go.ast_grep(&source_file.content); + let ast = AstWithSourceFile::new(ast_grep, source_file.clone()); + let root = ast.ast.root(); let mut method_calls = Vec::new(); @@ -453,15 +452,15 @@ rule: // Find all method calls with attribute access: receiver.method(args) for node_match in root.find_all(&config.matcher) { - if let Some(method_call) = self.parse_method_call(&node_match) { + if let Some(method_call) = self.parse_method_call(&node_match, source_file) { method_calls.push(method_call); } } // Extract import information - let import_info = self.extract_imports(&ast_grep); + let import_info = self.extract_imports(&ast); - crate::extraction::extractor::ExtractorResult::Go(ast_grep, method_calls, import_info) + crate::extraction::extractor::ExtractorResult::Go(ast, method_calls, import_info) } fn filter_map( @@ -573,6 +572,8 @@ impl Parameter { #[cfg(test)] mod tests { + use std::path::PathBuf; + use super::*; #[tokio::test] @@ -615,8 +616,10 @@ func main() { } } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), aws_code.to_string(), crate::Language::Go); - let result = extractor.parse(aws_code).await; + let result = extractor.parse(&source_file).await; let aws_method_calls = result.method_calls_ref(); println!("AWS test - Found {} method calls:", aws_method_calls.len()); @@ -655,8 +658,10 @@ func main() { result2, err2 := 
stsClient.GetCallerIdentity(ctx, &sts.GetCallerIdentityInput{}) } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -722,8 +727,10 @@ func main() { result3, err3 := client.GetObject(ctx, &s3.GetObjectInput{}) } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -813,8 +820,10 @@ func main() { result3, err3 := factory.createClient("s3").GetObject(ctx, &s3.GetObjectInput{}) } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -905,8 +914,10 @@ func main() { DescribeInstances(ctx, &ec2.DescribeInstancesInput{}) } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -1167,8 +1178,10 @@ func main() { _ = objectOutput } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; // Verify import extraction if let Some(import_info) = result.go_import_info() { @@ -1288,8 +1301,10 @@ func main() { _ = output } "#; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); - let result = 
extractor.parse(test_code).await; + let result = extractor.parse(&source_file).await; // Verify import extraction if let Some(import_info) = result.go_import_info() { @@ -1363,7 +1378,9 @@ func (basics BucketBasics) DownloadFile(ctx context.Context, bucketName string, } "#; - let result = extractor.parse(test_code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), test_code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); println!( @@ -1470,9 +1487,12 @@ func (basics BucketBasics) DownloadFile(ctx context.Context, bucketName string, } #[cfg(test)] mod test_struct_fields { + use std::path::PathBuf; + use crate::extraction::extractor::Extractor; use crate::extraction::go::extractor::GoExtractor; use crate::extraction::Parameter; + use crate::SourceFile; /// Test extraction of struct literals with multiple fields #[tokio::test] @@ -1497,7 +1517,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls @@ -1568,7 +1590,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls @@ -1621,7 +1645,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls @@ -1672,7 +1698,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + 
SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls @@ -1720,7 +1748,9 @@ func test() { } "#; - let result = extractor.parse(code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), code.to_string(), crate::Language::Go); + let result = extractor.parse(&source_file).await; let method_calls = result.method_calls_ref(); let call = method_calls diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs index bd58dab..a7928e4 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/features_extractor.rs @@ -6,7 +6,8 @@ use crate::extraction::go::features::{FeatureMethod, GoSdkV2Features}; use crate::extraction::go::types::GoImportInfo; use crate::extraction::go::utils; -use crate::extraction::{SdkMethodCall, SdkMethodCallMetadata}; +use crate::extraction::{AstWithSourceFile, SdkMethodCall, SdkMethodCallMetadata}; +use crate::Location; use ast_grep_config::from_yaml_string; use ast_grep_language::Go; @@ -19,10 +20,10 @@ pub(crate) struct FeatureCallInfo { pub(crate) receiver: Option, /// Extracted arguments pub(crate) arguments: Vec, - /// Start position of the call node - pub(crate) start_position: (usize, usize), - /// End position of the call node - pub(crate) end_position: (usize, usize), + /// Matched expression + pub(crate) expr: String, + /// Location of the call + pub(crate) location: Location, } /// Extractor for Go AWS SDK v2 feature methods @@ -43,7 +44,7 @@ impl GoFeaturesExtractor { /// Extract feature method calls from the AST pub(crate) fn extract_feature_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, import_info: &mut GoImportInfo, ) -> Vec { 
let mut synthetic_calls = Vec::new(); @@ -63,11 +64,8 @@ impl GoFeaturesExtractor { /// Find all method calls that might be feature methods /// This matches receiver.Method(...) patterns using proper ast-grep config - fn find_method_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_method_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut calls = Vec::new(); // Use the same pattern as the main extractor for method calls @@ -109,17 +107,15 @@ rule: let args_nodes = env.get_multiple_matches("ARGS"); let arguments = utils::extract_arguments(&args_nodes); - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - calls.push(FeatureCallInfo { method_name, receiver: Some(receiver), arguments, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node( + ast.source_file.path.clone(), + node_match.get_node(), + ), }); } } @@ -224,8 +220,8 @@ rule: metadata: Some(SdkMethodCallMetadata { parameters: parameters.clone(), return_type: None, - start_position: call_info.start_position, - end_position: call_info.end_position, + expr: call_info.expr.clone(), + location: call_info.location.clone(), receiver: call_info.receiver.clone(), }), } @@ -236,14 +232,19 @@ rule: #[cfg(test)] mod tests { + use std::path::PathBuf; + + use crate::{Language, SourceFile}; + use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Go.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = + SourceFile::with_language(PathBuf::new(), source_code.to_string(), Language::Go); + let ast_grep = Go.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file) } 
fn create_test_import_info() -> GoImportInfo { diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs index 8380d79..60145ff 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs @@ -3,11 +3,13 @@ //! This module handles extraction of Go AWS SDK v2 paginator patterns by detecting //! paginator creation calls, which contain the meaningful parameters for IAM policy generation. +use std::path::Path; + use crate::extraction::go::utils; use crate::extraction::sdk_model::ServiceDiscovery; -use crate::extraction::{Parameter, SdkMethodCall, SdkMethodCallMetadata}; -use crate::Language; +use crate::extraction::{AstWithSourceFile, Parameter, SdkMethodCall, SdkMethodCallMetadata}; use crate::ServiceModelIndex; +use crate::{Language, Location}; use ast_grep_language::Go; /// Information about a discovered paginator creation call @@ -23,8 +25,10 @@ pub(crate) struct PaginatorInfo { pub client_receiver: String, /// Extracted arguments from paginator creation (input struct) pub creation_arguments: Vec, - /// Line number where paginator was created - pub creation_line: usize, + /// Matched expression + pub expr: String, + /// Location of the paginator creation + pub location: Location, } /// Information about a chained paginator call @@ -36,13 +40,10 @@ pub(crate) struct ChainedPaginatorCallInfo { pub client_receiver: String, /// Extracted arguments from paginator creation (input struct) pub arguments: Vec, - /// Line number where chained call was made - #[allow(dead_code)] - pub line: usize, - /// Start position of the chained call node - pub start_position: (usize, usize), - /// End position of the chained call node - pub end_position: (usize, usize), + /// Matched expression + pub expr: String, + /// Location where the paginator was called
+ pub location: Location, } /// Extractor for Go AWS SDK paginator patterns @@ -66,7 +67,7 @@ impl<'a> GoPaginatorExtractor<'a> { /// Extract paginator method calls from the AST pub(crate) fn extract_paginator_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { let mut synthetic_calls = Vec::new(); @@ -88,18 +89,17 @@ impl<'a> GoPaginatorExtractor<'a> { } /// Find all paginator creation calls (NewXxxPaginator functions) - fn find_paginator_creation_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_paginator_creation_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut paginators = Vec::new(); // Pattern: $VAR := $PACKAGE.$FUNCTION($$$ARGS) where FUNCTION contains "New" and "Paginator" let paginator_pattern = "$VAR := $PACKAGE.$FUNCTION($$$ARGS)"; for node_match in root.find_all(paginator_pattern) { - if let Some(paginator_info) = self.parse_paginator_creation_call(&node_match) { + if let Some(paginator_info) = + self.parse_paginator_creation_call(&node_match, &ast.source_file.path) + { paginators.push(paginator_info); } } @@ -110,16 +110,18 @@ impl<'a> GoPaginatorExtractor<'a> { /// Find all chained paginator calls fn find_chained_paginator_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut chained_calls = Vec::new(); // Pattern: $PACKAGE.$FUNCTION($$$ARGS).NextPage($$$NEXT_ARGS) let chained_pattern = "$PACKAGE.$FUNCTION($$$ARGS).NextPage($$$NEXT_ARGS)"; for node_match in root.find_all(chained_pattern) { - if let Some(chained_info) = self.parse_chained_paginator_call(&node_match) { + if let Some(chained_info) = + self.parse_chained_paginator_call(&node_match, &ast.source_file.path) + { chained_calls.push(chained_info); } } @@ -131,6 +133,7 @@ impl<'a> GoPaginatorExtractor<'a> { fn parse_paginator_creation_call( &self, node_match: 
&ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -167,14 +170,13 @@ impl<'a> GoPaginatorExtractor<'a> { .and_then(|s| s.strip_suffix("Paginator")); if let Some(operation_name) = operation_name { - let creation_line = node_match.get_node().start_pos().line() + 1; - return Some(PaginatorInfo { variable_name, paginator_type: operation_name.to_string(), client_receiver, creation_arguments, - creation_line, + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }); } @@ -185,6 +187,7 @@ impl<'a> GoPaginatorExtractor<'a> { fn parse_chained_paginator_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -217,18 +220,12 @@ impl<'a> GoPaginatorExtractor<'a> { Vec::new() }; - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ChainedPaginatorCallInfo { paginator_type, client_receiver, arguments: creation_arguments, - line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -257,8 +254,8 @@ impl<'a> GoPaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: paginator_info.creation_arguments.clone(), return_type: None, - start_position: (paginator_info.creation_line, 1), - end_position: (paginator_info.creation_line, 1), + expr: paginator_info.expr.clone(), + location: paginator_info.location.clone(), receiver: Some(paginator_info.client_receiver.clone()), }), } @@ -292,8 +289,8 @@ impl<'a> GoPaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: chained_call.arguments.clone(), return_type: None, - start_position: chained_call.start_position, - end_position: 
chained_call.end_position, + expr: chained_call.expr.clone(), + location: chained_call.location.clone(), receiver: Some(chained_call.client_receiver.clone()), }), } @@ -302,15 +299,18 @@ impl<'a> GoPaginatorExtractor<'a> { #[cfg(test)] mod tests { + use crate::SourceFile; + use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; - use std::collections::HashMap; + use std::{collections::HashMap, path::PathBuf}; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Go.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = + SourceFile::with_language(PathBuf::new(), source_code.to_string(), Language::Go); + let ast_grep = Go.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file) } fn create_test_service_index() -> ServiceModelIndex { diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs index 3a961bc..a2dce8b 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs @@ -3,9 +3,13 @@ //! This module handles extraction of Go AWS SDK waiter patterns, which involve //! creating a waiter from a client, then calling Wait() on the waiter. 
+use std::path::Path; + use crate::extraction::go::utils; -use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; -use crate::ServiceModelIndex; +use crate::extraction::{ + AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, +}; +use crate::{Location, ServiceModelIndex}; use ast_grep_language::Go; /// Information about a discovered waiter creation call @@ -14,11 +18,19 @@ pub(crate) struct WaiterInfo { /// Variable name assigned to the waiter (e.g., "waiter", "instanceWaiter") pub variable_name: String, /// Clean waiter name (e.g., "InstanceTerminated") - pub waiter_type: String, + pub waiter_name: String, /// Client receiver variable name (e.g., "client", "ec2Client") pub client_receiver: String, - /// Line number where waiter was created - pub creation_line: usize, + /// Matched expression + pub expr: String, + /// Location where the waiter was created + pub location: Location, +} + +impl WaiterInfo { + pub(crate) fn start_line(&self) -> usize { + self.location.start_line() + } } /// Information about a Wait method call @@ -28,12 +40,51 @@ pub(crate) struct WaitCallInfo { pub waiter_var: String, /// Extracted arguments (context + input struct) pub arguments: Vec, - /// Line number where Wait was called - pub wait_line: usize, - /// Start position of the Wait call node - pub start_position: (usize, usize), - /// End position of the Wait call node - pub end_position: (usize, usize), + /// Matched expression + pub expr: String, + /// Location where the waiter was called + pub location: Location, +} + +impl WaitCallInfo { + pub(crate) fn start_line(&self) -> usize { + self.location.start_line() + } +} + +// TODO: This should be refactored at a higher level, so this type can be removed. +// See https://github.com/awslabs/iam-policy-autopilot/issues/88. 
+enum CallInfo<'a> { + None(&'a WaiterInfo), + Simple(&'a WaiterInfo, &'a WaitCallInfo), +} + +impl<'a> CallInfo<'a> { + fn waiter_name(&self) -> &'a str { + match self { + Self::None(waiter_info) | Self::Simple(waiter_info, ..) => &waiter_info.waiter_name, + } + } + + fn waiter_info(&self) -> &'a WaiterInfo { + match self { + CallInfo::None(waiter_info) | CallInfo::Simple(waiter_info, _) => waiter_info, + } + } + + fn expr(&self) -> &'a str { + match self { + CallInfo::None(waiter_info) => &waiter_info.expr, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.expr, + } + } + + fn location(&self) -> &'a Location { + match self { + CallInfo::None(waiter_info) => &waiter_info.location, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.location, + } + } } /// Extractor for Go AWS SDK waiter patterns @@ -57,7 +108,7 @@ impl<'a> GoWaiterExtractor<'a> { /// Extract waiter method calls from the AST pub(crate) fn extract_waiter_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { // Step 1: Find all waiter creation calls let waiters = self.find_waiter_creation_calls(ast); @@ -89,18 +140,17 @@ impl<'a> GoWaiterExtractor<'a> { } /// Find all waiter creation calls (NewXxxWaiter functions) - fn find_waiter_creation_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_waiter_creation_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut waiters = Vec::new(); // Pattern: $VAR := $PACKAGE.$FUNCTION($$$ARGS) where FUNCTION contains "New" and "Waiter" let waiter_pattern = "$VAR := $PACKAGE.$FUNCTION($$$ARGS)"; for node_match in root.find_all(waiter_pattern) { - if let Some(waiter_info) = self.parse_waiter_creation_call(&node_match) { + if let Some(waiter_info) = + self.parse_waiter_creation_call(&node_match, &ast.source_file.path) + { waiters.push(waiter_info); } } @@ -109,18 +159,15 @@ impl<'a> GoWaiterExtractor<'a> { } /// Find all Wait method calls - fn 
find_wait_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_wait_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut wait_calls = Vec::new(); // Pattern: $WAITER.Wait($$$ARGS) let wait_pattern = "$WAITER.Wait($$$ARGS)"; for node_match in root.find_all(wait_pattern) { - if let Some(wait_info) = self.parse_wait_call(&node_match) { + if let Some(wait_info) = self.parse_wait_call(&node_match, &ast.source_file.path) { wait_calls.push(wait_info); } } @@ -132,6 +179,7 @@ impl<'a> GoWaiterExtractor<'a> { fn parse_waiter_creation_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -161,13 +209,12 @@ impl<'a> GoWaiterExtractor<'a> { .and_then(|s| s.strip_suffix("Waiter")); if let Some(waiter_name) = waiter_name { - let creation_line = node_match.get_node().start_pos().line() + 1; - return Some(WaiterInfo { variable_name, - waiter_type: waiter_name.to_string(), + waiter_name: waiter_name.to_string(), client_receiver, - creation_line, + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }); } @@ -178,6 +225,7 @@ impl<'a> GoWaiterExtractor<'a> { fn parse_wait_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -188,17 +236,11 @@ impl<'a> GoWaiterExtractor<'a> { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = utils::extract_arguments(&args_nodes); - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(WaitCallInfo { waiter_var, arguments, - wait_line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), 
}) } @@ -216,8 +258,8 @@ impl<'a> GoWaiterExtractor<'a> { for (idx, waiter) in waiters.iter().enumerate() { if waiter.variable_name == wait_call.waiter_var { // Only consider waiters that come before the wait call - if waiter.creation_line < wait_call.wait_line { - let distance = wait_call.wait_line - waiter.creation_line; + if waiter.start_line() < wait_call.start_line() { + let distance = wait_call.start_line() - waiter.start_line(); if distance < best_distance { best_distance = distance; best_match = Some(waiter); @@ -232,36 +274,27 @@ impl<'a> GoWaiterExtractor<'a> { fn create_synthetic_call_internal( &self, - wait_call: Option<&WaitCallInfo>, - waiter_info: &WaiterInfo, + call: CallInfo, + // wait_call: Option<&WaitCallInfo>, + // waiter_info: &WaiterInfo, ) -> Vec { let mut synthetic_calls = Vec::new(); // waiter_type already contains the clean waiter name (e.g., "InstanceTerminated") - if let Some(service_defs) = self - .service_index - .waiter_lookup - .get(&waiter_info.waiter_type) - { + if let Some(service_defs) = self.service_index.waiter_lookup.get(call.waiter_name()) { // Create one call per service for service_def in service_defs { let service_name = &service_def.service_name; let operation_name = &service_def.operation_name; - let (parameters, start_position, end_position) = match wait_call { - Some(wait_call) => ( - self.filter_waiter_parameters(wait_call.arguments.clone()), - wait_call.start_position, - wait_call.end_position, - ), - None => { + let parameters = match call { + CallInfo::Simple(_, wait_call) => { + self.filter_waiter_parameters(wait_call.arguments.clone()) + } + CallInfo::None(_) => { // Fallback: - ( - // Get required parameters for this operation - self.get_required_parameters(service_name, operation_name), - (waiter_info.creation_line, 1), - (waiter_info.creation_line, 1), - ) + // Get required parameters for this operation + self.get_required_parameters(service_name, operation_name) } }; @@ -271,9 +304,9 @@ impl<'a> 
GoWaiterExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position, - end_position, - receiver: Some(waiter_info.client_receiver.clone()), + expr: call.expr().to_string(), + location: call.location().clone(), + receiver: Some(call.waiter_info().client_receiver.clone()), }), }); } @@ -288,13 +321,13 @@ impl<'a> GoWaiterExtractor<'a> { wait_call: &WaitCallInfo, waiter_info: &WaiterInfo, ) -> Vec { - self.create_synthetic_call_internal(Some(wait_call), waiter_info) + self.create_synthetic_call_internal(CallInfo::Simple(waiter_info, wait_call)) } /// Create fallback synthetic calls for unmatched waiter creation /// Returns one call per service that has the waiter, matching Python behavior fn create_fallback_synthetic_call(&self, waiter_info: &WaiterInfo) -> Vec { - self.create_synthetic_call_internal(None, waiter_info) + self.create_synthetic_call_internal(CallInfo::None(waiter_info)) } /// Get required parameters for an operation from the service index @@ -336,16 +369,18 @@ impl<'a> GoWaiterExtractor<'a> { #[cfg(test)] mod tests { use crate::extraction::sdk_model::ServiceMethodRef; + use crate::{Language, SourceFile}; use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; - use std::collections::HashMap; + use std::{collections::HashMap, path::PathBuf}; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Go.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = + SourceFile::with_language(PathBuf::new(), source_code.to_string(), Language::Go); + let ast_grep = Go.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file) } fn create_test_service_index() -> ServiceModelIndex { @@ -473,14 +508,14 @@ func main() { for waiter in &waiters { println!( " - {} := {}.{}", - waiter.variable_name, waiter.client_receiver, waiter.waiter_type + waiter.variable_name, waiter.client_receiver, waiter.waiter_name ); } 
// Should find waiter creation calls with correct Go SDK pattern assert_eq!(waiters.len(), 2); assert_eq!(waiters[0].variable_name, "instanceWaiter"); - assert_eq!(waiters[0].waiter_type, "InstanceRunning"); + assert_eq!(waiters[0].waiter_name, "InstanceRunning"); assert_eq!(waiters[0].client_receiver, "client"); // Client parameter, not package name } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/argument_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/argument_extractor.rs index c6ebda0..f59a64e 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/argument_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/argument_extractor.rs @@ -4,6 +4,8 @@ //! from ast-grep nodes, handling object literals, shorthand properties, //! spread operators, and proper value resolution (literal vs identifier). +use ast_grep_core::tree_sitter::StrDoc; + use crate::extraction::{Parameter, ParameterValue}; /// Utility for extracting arguments from JavaScript/TypeScript AST nodes @@ -19,10 +21,10 @@ impl ArgumentExtractor { /// /// Returns a vector of Parameters with proper Resolved/Unresolved classification pub fn extract_object_parameters( - args_node: Option<&ast_grep_core::Node>, + args_node: Option<&ast_grep_core::Node>>, ) -> Vec where - T: ast_grep_core::Doc, + T: ast_grep_language::LanguageExt, { let Some(node) = args_node else { return Vec::new(); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs index 10bd3a3..f30416e 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs @@ -10,7 +10,8 @@ use std::collections::HashSet; use crate::extraction::extractor::{Extractor, ExtractorResult}; use 
crate::extraction::javascript::scanner::ASTScanner; use crate::extraction::javascript::shared::ExtractionUtils; -use crate::ServiceModelIndex; +use crate::extraction::AstWithSourceFile; +use crate::{ServiceModelIndex, SourceFile}; /// JavaScript extractor for AWS SDK method calls pub(crate) struct JavaScriptExtractor; @@ -30,9 +31,10 @@ impl Default for JavaScriptExtractor { #[async_trait] impl Extractor for JavaScriptExtractor { - async fn parse(&self, source_code: &str) -> ExtractorResult { + async fn parse(&self, source_file: &SourceFile) -> ExtractorResult { // Create AST once and reuse it - let ast = JavaScript.ast_grep(source_code); + let ast_grep = JavaScript.ast_grep(&source_file.content); + let ast = AstWithSourceFile::new(ast_grep, source_file.clone()); // Create scanner with the pre-built AST let mut scanner = ASTScanner::new(ast.clone(), JavaScript.into()); @@ -209,6 +211,14 @@ mod tests { } } + fn create_source_file(source_code: &str) -> SourceFile { + SourceFile::with_language( + std::path::PathBuf::new(), + source_code.to_string(), + crate::Language::JavaScript, + ) + } + #[tokio::test] async fn test_parse_import_via_require() { let extractor = JavaScriptExtractor::new(); @@ -228,7 +238,7 @@ async function createMyBucket() { createMyBucket(); "#; - let result = extractor.parse(source_code).await; + let result = extractor.parse(&create_source_file(source_code)).await; let method_calls = result.method_calls_ref(); // Should infer CreateBucket operation from CreateBucketCommand import @@ -300,7 +310,7 @@ async function getAllDynamoDBTables() { getAllDynamoDBTables(); "#; - let result = extractor.parse(source_code).await; + let result = extractor.parse(&create_source_file(source_code)).await; let method_calls = result.method_calls_ref(); // Should infer ListTables operation from paginateListTables import @@ -353,7 +363,7 @@ const bodyAsString = await bodyStream.transformToString(); const __error__ = await bodyStream.transformToString(); "#; - let result 
= extractor.parse(source_code).await; + let result = extractor.parse(&create_source_file(source_code)).await; let method_calls = result.method_calls_ref(); // Should find GetObject operation from direct client method call @@ -398,7 +408,7 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); "#; // Parse the code - let mut results = vec![extractor.parse(javascript_code).await]; + let mut results = vec![extractor.parse(&create_source_file(javascript_code)).await]; // Build service index with all services for testing let service_index = ServiceDiscovery::load_service_index(Language::JavaScript) @@ -470,7 +480,7 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); "#; // Parse the code - let mut results = vec![extractor.parse(code).await]; + let mut results = vec![extractor.parse(&create_source_file(code)).await]; // Load service index let service_index = ServiceDiscovery::load_service_index(Language::JavaScript) diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs index 80dc01e..47f044a 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/scanner.rs @@ -1,33 +1,19 @@ //! 
Core JavaScript/TypeScript scanning logic for AWS SDK extraction +use crate::extraction::javascript::shared::CommandUsage; use crate::extraction::javascript::types::{ ClientInstantiation, ImportInfo, JavaScriptScanResults, MethodCall, SublibraryInfo, ValidClientTypes, }; +use crate::extraction::AstWithSourceFile; +use crate::Location; use ast_grep_core::matcher::Pattern; -use ast_grep_core::MatchStrictness; +use ast_grep_core::{tree_sitter, MatchStrictness, NodeMatch}; +use ast_grep_core::{Doc, Node}; use std::collections::HashMap; -fn parse_import_item_with_line(import_item: &str, line: usize) -> Option { - let import_item = import_item.trim(); - if import_item.is_empty() { - return None; - } - - // Check for rename syntax: "OriginalName as LocalName" - if let Some(as_pos) = import_item.find(" as ") { - let original_name = import_item[..as_pos].trim().to_string(); - let local_name = import_item[as_pos + 4..].trim().to_string(); - Some(ImportInfo::new(original_name, local_name, line)) - } else { - // No rename - original name is the same as local name - let import_name = import_item.trim().to_string(); - Some(ImportInfo::new(import_name.clone(), import_name, line)) - } -} - fn parse_object_literal(obj_text: &str) -> HashMap { let mut result = HashMap::new(); @@ -114,58 +100,91 @@ fn parse_key_value_pair(pair: &str, result: &mut HashMap) { } } -fn parse_and_add_imports_with_line( - imports_text: &str, - sublibrary_info: &mut SublibraryInfo, - line: usize, -) { - // Handle different import formats - if imports_text.starts_with('{') && imports_text.ends_with('}') { - // Destructuring - parse with rename support - let imports_content = &imports_text[1..imports_text.len() - 1]; // Remove braces - - // Split by comma and parse each import - for import_item in imports_content.split(',') { - if let Some(import_info) = parse_import_item_with_line(import_item, line) { - sublibrary_info.add_import(import_info); - } - } - } else { - // Default import - single identifier - if 
let Some(import_info) = parse_import_item_with_line(imports_text, line) { - sublibrary_info.add_import(import_info); - } - } -} - /// Core AST scanner for JavaScript/TypeScript AWS SDK usage patterns pub(crate) struct ASTScanner where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { /// Pre-built AST grep root passed from extractor - ast_grep: ast_grep_core::AstGrep, - language: ast_grep_language::SupportLang, + pub(crate) ast_grep: AstWithSourceFile, + pub(crate) language: ast_grep_language::SupportLang, } impl ASTScanner where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { /// Create a new scanner with pre-built AST from extractor pub(crate) fn new( - ast_grep: ast_grep_core::AstGrep, + ast_grep: AstWithSourceFile, language: ast_grep_language::SupportLang, ) -> Self { Self { ast_grep, language } } + fn parse_and_add_imports( + &self, + imports_text: &str, + sublibrary_info: &mut SublibraryInfo, + node: &Node<'_, tree_sitter::StrDoc>, + ) { + // Handle different import formats + if imports_text.starts_with('{') && imports_text.ends_with('}') { + // Destructuring - parse with rename support + let imports_content = &imports_text[1..imports_text.len() - 1]; // Remove braces + + // Split by comma and parse each import + for import_item in imports_content.split(',') { + if let Some(import_info) = self.parse_import_item(import_item, node) { + sublibrary_info.add_import(import_info); + } + } + } else { + // Default import - single identifier + if let Some(import_info) = self.parse_import_item(imports_text, node) { + sublibrary_info.add_import(import_info); + } + } + } + + fn parse_import_item( + &self, + import_item: &str, + node: &Node<'_, tree_sitter::StrDoc>, + ) -> Option { + let import_item = import_item.trim(); + if import_item.is_empty() { + return None; + } + + // Check for rename syntax: "OriginalName as LocalName" + if let Some(as_pos) = import_item.find(" as ") { + let original_name = 
import_item[..as_pos].trim().to_string(); + let local_name = import_item[as_pos + 4..].trim().to_string(); + Some(ImportInfo::new( + original_name, + local_name, + import_item, + Location::from_node(self.ast_grep.source_file.path.to_path_buf(), node), + )) + } else { + // No rename - original name is the same as local name + let import_name = import_item.trim().to_string(); + Some(ImportInfo::new( + import_name.clone(), + import_name, + import_item, + Location::from_node(self.ast_grep.source_file.path.to_path_buf(), node), + )) + } + } + /// Execute a pattern match against the AST using relaxed strictness to handle inline comments fn find_all_matches( &self, pattern: &str, - ) -> Result>, String> { - let root = self.ast_grep.root(); + ) -> Result>>, String> { + let root = self.ast_grep.ast.root(); // Build pattern with relaxed strictness to handle inline comments let pattern_obj = @@ -174,30 +193,20 @@ where Ok(root.find_all(pattern_obj).collect()) } - /// Extract 1-based (line, column) position from the first match - fn get_first_match_position(matches: &[ast_grep_core::NodeMatch]) -> Option<(usize, usize)> { - matches.first().map(|first_match| { - let node = first_match.get_node(); - let pos = node.start_pos(); - let line = pos.line() + 1; - let column = pos.column(node) + 1; - (line, column) - }) - } - /// Find Command instantiation and extract its arguments - /// Returns ((line_number, column_number), parameters) tuple + /// Returns a `CommandUsage` with the matched text, location, and parameters pub(crate) fn find_command_instantiation_with_args( &self, command_name: &str, - ) -> Option<((usize, usize), Vec)> { + ) -> Option> { use crate::extraction::javascript::argument_extractor::ArgumentExtractor; let pattern = format!("new {}($ARGS)", command_name); if let Ok(matches) = self.find_all_matches(&pattern) { - if let Some(position) = Self::get_first_match_position(&matches) { - let first_match = matches.first().unwrap(); + if let Some(first_match) = matches.first() { +
let location = + Location::from_node(self.ast_grep.source_file.path.to_path_buf(), first_match); let env = first_match.get_env(); // Extract arguments from the ARGS node @@ -205,7 +214,7 @@ where let args_node = env.get_match("ARGS"); let parameters = ArgumentExtractor::extract_object_parameters(args_node); - return Some((position, parameters)); + return Some(CommandUsage::new(first_match.text(), location, parameters)); } } None @@ -215,22 +224,23 @@ where pub(crate) fn find_paginate_function_with_args( &self, function_name: &str, - ) -> Option<((usize, usize), Vec)> { + ) -> Option> { use crate::extraction::javascript::argument_extractor::ArgumentExtractor; // Use explicit two-argument pattern let pattern = format!("{}($ARG1, $ARG2)", function_name); if let Ok(matches) = self.find_all_matches(&pattern) { - if let Some(position) = Self::get_first_match_position(&matches) { - let first_match = matches.first().unwrap(); + if let Some(first_match) = matches.first() { + let location = + Location::from_node(self.ast_grep.source_file.path.to_path_buf(), first_match); let env = first_match.get_env(); // Extract parameters from second argument (ARG2 = operation params) let second_arg = env.get_match("ARG2"); let parameters = ArgumentExtractor::extract_object_parameters(second_arg); - return Some((position, parameters)); + return Some(CommandUsage::new(first_match.text(), location, parameters)); } } None @@ -240,7 +250,7 @@ where pub(crate) fn find_waiter_function_with_args( &self, function_name: &str, - ) -> Option<((usize, usize), Vec)> { + ) -> Option> { use crate::extraction::javascript::argument_extractor::ArgumentExtractor; // Try patterns with and without await keyword using explicit two-argument pattern @@ -251,15 +261,18 @@ where for pattern in &patterns { if let Ok(matches) = self.find_all_matches(pattern) { - if let Some(position) = Self::get_first_match_position(&matches) { - let first_match = matches.first().unwrap(); + if let Some(first_match) = 
matches.first() { + let location = Location::from_node( + self.ast_grep.source_file.path.to_path_buf(), + first_match, + ); let env = first_match.get_env(); // Extract parameters from second argument (ARG2 = operation params) let second_arg = env.get_match("ARG2"); let parameters = ArgumentExtractor::extract_object_parameters(second_arg); - return Some((position, parameters)); + return Some(CommandUsage::new(first_match.text(), location, parameters)); } } } @@ -270,7 +283,7 @@ where pub(crate) fn find_command_input_usage_position( &self, type_name: &str, - ) -> Option<(usize, usize)> { + ) -> Option> { // Try multiple patterns for TypeScript type annotations let patterns = [ format!("const $VAR: {} = $VALUE", type_name), // const variable: Type = value @@ -280,8 +293,15 @@ where for pattern in &patterns { if let Ok(matches) = self.find_all_matches(pattern) { - if let Some(position) = Self::get_first_match_position(&matches) { - return Some(position); + if let Some(first_match) = matches.first() { + let location = Location::from_node( + self.ast_grep.source_file.path.to_path_buf(), + first_match, + ); + let expr_text = matches.first().unwrap().text(); + // TODO: Extract from variable assignments + let parameters = vec![]; + return Some(CommandUsage::new(expr_text, location, parameters)); } } } @@ -293,20 +313,17 @@ where let mut sublibrary_data: HashMap = HashMap::new(); let matches = self.find_all_matches(pattern)?; - Self::process_import_matches(matches, &mut sublibrary_data, true)?; + self.process_import_matches(matches, &mut sublibrary_data)?; Ok(sublibrary_data.into_values().collect()) } /// Generic processing for import/require matches - works for both JavaScript and TypeScript - fn process_import_matches( - matches: Vec>, + fn process_import_matches( + &self, + matches: Vec>>, sublibrary_data: &mut HashMap, - include_line_numbers: bool, - ) -> Result<(), String> - where - U: ast_grep_core::Doc + std::clone::Clone, - { + ) -> Result<(), String> { for 
node_match in matches { let env = node_match.get_env(); @@ -331,13 +348,11 @@ where .entry(sublibrary.clone()) .or_insert_with(|| SublibraryInfo::new(sublibrary)); - if include_line_numbers { - // Get line number from AST node - let line = node_match.get_node().start_pos().line() + 1; - parse_and_add_imports_with_line(imports_text_str, sublibrary_info, line); - } else { - parse_and_add_imports_with_line(imports_text_str, sublibrary_info, 1); - } + self.parse_and_add_imports( + imports_text_str, + sublibrary_info, + node_match.get_node(), + ); } } Ok(()) @@ -448,14 +463,14 @@ where /// Generic processing for client instantiation matches - works for both JavaScript and TypeScript fn process_client_instantiation_matches( - matches: Vec>, + matches: Vec>, valid_client_types: &[String], client_name_mappings: &HashMap, client_sublibrary_mappings: &HashMap, results: &mut Vec, ) -> Result<(), String> where - U: ast_grep_core::Doc + std::clone::Clone, + U: Doc + std::clone::Clone, { for node_match in matches { let env = node_match.get_env(); @@ -505,15 +520,13 @@ where } /// Generic processing for method call matches - works for both JavaScript and TypeScript - fn process_method_call_matches( - matches: Vec>, + fn process_method_call_matches( + &self, + matches: Vec>>, client_variables: &[String], client_info_map: &HashMap, results: &mut Vec, - ) -> Result<(), String> - where - U: ast_grep_core::Doc + std::clone::Clone, - { + ) -> Result<(), String> { for node_match in matches { let env = node_match.get_env(); @@ -538,17 +551,18 @@ where HashMap::new() }; - // Get line number - let line = node_match.get_node().start_pos().line() + 1; - results.push(MethodCall { client_variable: variable_name, client_type: client_type.clone(), original_client_type: original_client_type.clone(), client_sublibrary: client_sublibrary.clone(), + expr: node_match.text().to_string(), method_name, arguments, - line, + location: Location::from_node( + self.ast_grep.source_file.path.to_path_buf(), + 
node_match.get_node(), + ), }); } } @@ -585,7 +599,7 @@ where // Single pattern to match method calls (covers both awaited and non-awaited) let matches = self.find_all_matches("$VAR.$METHOD($ARGS)")?; - Self::process_method_call_matches( + self.process_method_call_matches( matches, &client_variables, &client_info_map, @@ -614,24 +628,13 @@ where #[cfg(test)] mod tests { - use super::*; - use ast_grep_core::tree_sitter::LanguageExt; - use ast_grep_language::{JavaScript, TypeScript}; + use std::path::PathBuf; - #[test] - fn test_parse_import_item() { - // Test regular import - let import_info = parse_import_item_with_line("S3Client", 1).unwrap(); - assert_eq!(import_info.original_name, "S3Client"); - assert_eq!(import_info.local_name, "S3Client"); - assert!(!import_info.is_renamed); + use crate::SourceFile; - // Test renamed import - let import_info = parse_import_item_with_line("S3Client as MyS3Client", 1).unwrap(); - assert_eq!(import_info.original_name, "S3Client"); - assert_eq!(import_info.local_name, "MyS3Client"); - assert!(import_info.is_renamed); - } + use super::*; + use ast_grep_language::{JavaScript, TypeScript}; + use tree_sitter::LanguageExt; #[test] fn test_parse_object_literal() { @@ -644,6 +647,51 @@ mod tests { assert!(result.is_empty()); } + fn create_js_ast(source_code: &str) -> AstWithSourceFile { + let source_file = SourceFile::with_language( + PathBuf::new(), + source_code.to_string(), + crate::Language::JavaScript, + ); + let ast_grep = JavaScript.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file.clone()) + } + + #[test] + fn test_parse_import_item() { + // Test regular import + let source = r#"import S3Client from "@aws-sdk/client-s3""#; + let ast = create_js_ast(source); + let mut scanner = ASTScanner::new(ast, JavaScript.into()); + + let (imports, _requires) = scanner.scan_all_aws_imports().unwrap(); + assert_eq!( + ImportInfo::new( + "S3Client".to_string(), + "S3Client".to_string(), + "S3Client", + 
Location::new(PathBuf::new(), (1, 1), (1, 42)), + ), + imports[0].imports[0] + ); + + // Test renamed import + let source = r#"import { S3Client as MyS3Client } from "@aws-sdk/client-s3";"#; + let ast = create_js_ast(source); + let mut scanner = ASTScanner::new(ast, JavaScript.into()); + + let (imports, _requires) = scanner.scan_all_aws_imports().unwrap(); + assert_eq!( + ImportInfo::new( + "S3Client".to_string(), + "MyS3Client".to_string(), + "S3Client as MyS3Client", + Location::new(PathBuf::new(), (1, 1), (1, 61)), + ), + imports[0].imports[0] + ); + } + #[test] fn test_import_require_scanning_comprehensive() { // Create comprehensive test case with multiple sublibrary patterns @@ -655,7 +703,7 @@ const { LambdaClient, InvokeCommand } = require("@aws-sdk/client-lambda"); const { SESClient } = require("@aws-sdk/client-ses"); "#; - let ast = JavaScript.ast_grep(source); + let ast = create_js_ast(source); let mut scanner = ASTScanner::new(ast, JavaScript.into()); let (imports, requires) = scanner.scan_all_aws_imports().unwrap(); @@ -837,7 +885,7 @@ async function uploadFile() { } "#; - let ast = JavaScript.ast_grep(source_with_usage); + let ast = create_js_ast(source_with_usage); let scanner = ASTScanner::new(ast, JavaScript.into()); // Should find CreateBucketCommand instantiation at line ~6 @@ -888,7 +936,7 @@ async function listAllTables() { } "#; - let ast = JavaScript.ast_grep(source_with_usage); + let ast = create_js_ast(source_with_usage); let scanner = ASTScanner::new(ast, JavaScript.into()); // Should find paginateQuery call at line ~7 @@ -909,6 +957,16 @@ async function listAllTables() { println!("✅ Paginate function position heuristics working correctly"); } + fn create_ts_ast(source_code: &str) -> AstWithSourceFile { + let source_file = SourceFile::with_language( + PathBuf::new(), + source_code.to_string(), + crate::Language::TypeScript, + ); + let ast_grep = TypeScript.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, 
source_file.clone()) + } + #[test] fn test_position_heuristics_command_input_typescript() { // Test CommandInput type usage position finding (TypeScript-specific) @@ -933,24 +991,28 @@ function createListParams(): ListTablesInput { } "#; - let ast = TypeScript.ast_grep(typescript_source); + let ast = create_ts_ast(typescript_source); let scanner = ASTScanner::new(ast, TypeScript.into()); - let query_input_pos = scanner.find_command_input_usage_position("QueryCommandInput"); - assert!( - query_input_pos.is_some(), - "Should find QueryCommandInput usage" - ); - let (line, _col) = query_input_pos.unwrap(); - assert_eq!(line, 9, "QueryCommandInput should be at line 9"); + if let Some(result) = scanner.find_command_input_usage_position("QueryCommandInput") { + assert_eq!( + result.location.start_line(), + 9, + "QueryCommandInput should be at line 9" + ); + } else { + panic!("Should find QueryCommandInput usage"); + } - let list_input_pos = scanner.find_command_input_usage_position("ListTablesInput"); - assert!( - list_input_pos.is_some(), - "Should find ListTablesInput usage" - ); - let (line, _col) = list_input_pos.unwrap(); - assert_eq!(line, 15, "ListTablesInput should be at line 15"); + if let Some(result) = scanner.find_command_input_usage_position("ListTablesInput") { + assert_eq!( + result.location.start_line(), + 15, + "ListTablesInput should be at line 15" + ); + } else { + panic!("Should find ListTablesInput usage"); + } // Should return None for type that wasn't used let missing_type_pos = scanner.find_command_input_usage_position("PutItemInput"); @@ -971,7 +1033,7 @@ const { CreateBucketCommand } = require("@aws-sdk/client-s3"); const command = new CreateBucketCommand({ Bucket: "test" }); "#; - let ast = JavaScript.ast_grep(javascript_source); + let ast = create_js_ast(javascript_source); let scanner = ASTScanner::new(ast, JavaScript.into()); // JavaScript should find command instantiation @@ -1010,7 +1072,7 @@ let dynamoSdk = 
require("@aws-sdk/lib-dynamodb"); var ec2Sdk = require("@aws-sdk/client-ec2"); "#; - let ast = JavaScript.ast_grep(source_with_mixed_requires); + let ast = create_js_ast(source_with_mixed_requires); let mut scanner = ASTScanner::new(ast, JavaScript.into()); let (imports, requires) = scanner.scan_all_aws_imports().unwrap(); @@ -1159,7 +1221,7 @@ async function testOperations() { } "#; - let ast = TypeScript.ast_grep(typescript_source); + let ast = create_ts_ast(typescript_source); let mut scanner = ASTScanner::new(ast, TypeScript.into()); let scan_results = scanner.scan_all().unwrap(); @@ -1282,7 +1344,7 @@ async function uploadLargeFile() { } "#; - let ast = TypeScript.ast_grep(typescript_source); + let ast = create_ts_ast(typescript_source); let mut scanner = ASTScanner::new(ast, TypeScript.into()); let scan_results = scanner.scan_all().unwrap(); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs index ce8e105..870bff8 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/shared.rs @@ -3,10 +3,12 @@ //! This module contains common functionality shared between JavaScript and TypeScript //! extractors. 
-use crate::extraction::javascript::types::JavaScriptScanResults; +use crate::extraction::javascript::types::{ImportInfo, JavaScriptScanResults}; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; +use crate::Location; use rust_embed::RustEmbed; use serde::Deserialize; +use std::borrow::Cow; use std::collections::HashMap; /// Embedded JavaScript SDK v3 libraries mapping @@ -48,6 +50,45 @@ fn load_libraries_mapping() -> Option { serde_json::from_str(content).ok() } +/// Result of finding a command/function instantiation with its arguments +#[derive(Debug, Clone)] +pub(crate) struct CommandUsage<'a> { + /// The matched text from the AST + pub(crate) text: Cow<'a, str>, + /// Location where the command usage was found + pub(crate) location: Location, + /// Extracted parameters from the command/function arguments + pub(crate) parameters: Vec, +} + +impl<'a> CommandUsage<'a> { + /// Create a new CommandInstantiationResult + pub(crate) fn new( + text: Cow<'a, str>, + location: Location, + parameters: Vec, + ) -> Self { + Self { + text, + location, + parameters, + } + } +} + +// Used when we cannot find a method call, and fall back to adding an operation purely based on an import statement +impl From<&ImportInfo> for CommandUsage<'_> { + fn from(value: &ImportInfo) -> Self { + Self { + text: Cow::Owned(value.statement.clone()), + location: value.location.clone(), + // TODO: parameters should be an Option, so we can distinguish + // the case where we fall back to an import statement + parameters: vec![], + } + } +} + /// Shared extraction utilities for JavaScript/TypeScript AWS SDK method calls pub(crate) struct ExtractionUtils; @@ -58,7 +99,7 @@ impl ExtractionUtils { scanner: &mut crate::extraction::javascript::scanner::ASTScanner, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut method_calls = Vec::new(); @@ -105,7 +146,7 @@ impl ExtractionUtils { lib_mappings: 
Option<&JsV3LibrariesMapping>, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -124,9 +165,9 @@ impl ExtractionUtils { if import_info.original_name.ends_with("Command") { // Try to find the actual constructor instantiation with arguments // Use the local name for the search (handles renames) - let (usage_position, parameters) = scanner + let result = scanner .find_command_instantiation_with_args(&import_info.local_name) - .unwrap_or_else(|| ((import_info.line, 1), Vec::new())); // Fallback to import position with no params + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Check if this needs library expansion (lib-* sublibraries) let expanded_command_names = @@ -160,10 +201,10 @@ impl ExtractionUtils { name: operation_name.to_string(), possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters: parameters.clone(), + parameters: result.parameters.clone(), return_type: None, - start_position: usage_position, - end_position: usage_position, + expr: result.text.to_string(), + location: result.location.clone(), receiver: None, // Commands are typically standalone }), }; @@ -185,7 +226,7 @@ impl ExtractionUtils { lib_mappings: Option<&JsV3LibrariesMapping>, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -204,9 +245,9 @@ impl ExtractionUtils { if import_info.original_name.starts_with("paginate") { // Try to find the actual paginate function call with arguments // Use the local name for the search (handles renames) - let (usage_position, parameters) = scanner + let result = scanner .find_paginate_function_with_args(&import_info.local_name) - .unwrap_or_else(|| ((import_info.line, 1), Vec::new())); // Fallback to import position with no params + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Check if this 
needs library expansion (lib-* sublibraries) let expanded_paginator_names = @@ -249,10 +290,10 @@ impl ExtractionUtils { name: operation_name, possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters: parameters.clone(), // extracted from 2nd argument! + parameters: result.parameters.clone(), // extracted from 2nd argument! return_type: None, - start_position: usage_position, - end_position: usage_position, + expr: result.text.to_string(), + location: result.location.clone(), receiver: None, }), }; @@ -274,7 +315,7 @@ impl ExtractionUtils { scanner: &mut crate::extraction::javascript::scanner::ASTScanner, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -297,9 +338,9 @@ impl ExtractionUtils { { // Try to find the actual waiter function call with arguments // Use the local name for the search (handles renames) - let (usage_position, parameters) = scanner + let result = scanner .find_waiter_function_with_args(&import_info.local_name) - .unwrap_or_else(|| ((import_info.line, 1), Vec::new())); // Fallback to import position with no params + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Keep PascalCase waiter name // e.g., "BucketExists" from "waitUntilBucketExists" @@ -308,10 +349,10 @@ impl ExtractionUtils { name: waiter_name.to_string(), possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters, // Extracted from 2nd argument (operation params) + parameters: result.parameters, // Extracted from 2nd argument (operation params) return_type: None, - start_position: usage_position, - end_position: usage_position, + expr: result.text.to_string(), + location: result.location.clone(), receiver: None, // Waiter functions are standalone }), }; @@ -331,7 +372,7 @@ impl ExtractionUtils { scanner: &mut crate::extraction::javascript::scanner::ASTScanner, ) -> Vec where - T: ast_grep_core::Doc 
+ Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -357,9 +398,9 @@ impl ExtractionUtils { if let Some(operation_name) = operation_name { // Try to find the actual CommandInput type usage position (TypeScript-specific) // Use the local name for the search (handles renames) - let usage_position = scanner + let result = scanner .find_command_input_usage_position(&import_info.local_name) - .unwrap_or((import_info.line, 1)); // Fallback to import position + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Keep PascalCase operation name to match service index // e.g., "Query" stays "Query" @@ -367,10 +408,10 @@ impl ExtractionUtils { name: operation_name.to_string(), possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters: Vec::new(), // TODO: Extract from variable assignments + parameters: Vec::new(), return_type: None, - start_position: usage_position, // Using enhanced position tracking - end_position: usage_position, // Using enhanced position tracking + expr: result.text.to_string(), + location: result.location.clone(), receiver: None, }), }; @@ -391,7 +432,7 @@ impl ExtractionUtils { lib_mappings: Option<&JsV3LibrariesMapping>, ) -> Vec where - T: ast_grep_core::Doc + Clone, + T: ast_grep_language::LanguageExt, { let mut operations = Vec::new(); @@ -427,9 +468,9 @@ impl ExtractionUtils { .and_then(|lib| lib.get(&import_info.original_name)) { // Try to find class instantiation, fallback to import position - let (usage_position, parameters) = scanner + let result = scanner .find_command_instantiation_with_args(&import_info.local_name) - .unwrap_or_else(|| ((import_info.line, 1), Vec::new())); + .unwrap_or_else(|| import_info.into()); // Fallback to import position with no params // Create operations for each expanded command for command_name in expanded_commands { @@ -439,10 +480,10 @@ impl ExtractionUtils { name: operation_name.to_string(), 
possible_services: vec![service.clone()], metadata: Some(SdkMethodCallMetadata { - parameters: parameters.clone(), + parameters: result.parameters.clone(), return_type: None, - start_position: usage_position, - end_position: usage_position, + expr: result.text.to_string(), + location: result.location.clone(), receiver: None, }), }; @@ -490,8 +531,8 @@ impl ExtractionUtils { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position: (method_call.line, 1), - end_position: (method_call.line, 1), + expr: method_call.expr.clone(), + location: method_call.location.clone(), receiver: Some(method_call.client_variable.clone()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs index 227c237..2f55f6c 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/types.rs @@ -3,34 +3,44 @@ use serde::{Deserialize, Serialize}; use std::collections::HashMap; +use crate::Location; + /// Information about a single import with rename support -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub(crate) struct ImportInfo { /// Original name in the AWS SDK (e.g., "S3Client", "PutObjectCommand") pub(crate) original_name: String, + /// Import statement + pub(crate) statement: String, /// Local name used in the code (e.g., "MyS3Client", "PutObject") pub(crate) local_name: String, /// Whether this import was renamed (original_name != local_name) pub(crate) is_renamed: bool, - /// Line number where this import appears - pub(crate) line: usize, + /// Location of the import + pub(crate) location: Location, } impl ImportInfo { /// Create a new ImportInfo with the given names and line position - pub(crate) fn new(original_name: String, local_name: String, line: usize) -> Self { + 
pub(crate) fn new( + original_name: String, + local_name: String, + statement: &str, + location: Location, + ) -> Self { let is_renamed = original_name != local_name; Self { original_name, + statement: statement.to_string(), local_name, is_renamed, - line, + location, } } } /// Information about imports from a specific AWS SDK sublibrary -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub(crate) struct SublibraryInfo { /// AWS SDK sublibrary name (e.g., "client-s3", "lib-dynamodb") pub(crate) sublibrary: String, @@ -109,7 +119,7 @@ impl ValidClientTypes { } /// Information about a method call (non-send) -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub(crate) struct MethodCall { /// Client variable name pub(crate) client_variable: String, @@ -123,12 +133,14 @@ pub(crate) struct MethodCall { pub(crate) method_name: String, /// Method arguments pub(crate) arguments: HashMap, - /// Line number where call occurs - pub(crate) line: usize, + /// Matched expression + pub(crate) expr: String, + /// Location where the method call was found + pub(crate) location: Location, } /// Combined results from all scanning operations -#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub(crate) struct JavaScriptScanResults { /// Import information pub(crate) imports: Vec, diff --git a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs index 91be7da..ff14ff0 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/mod.rs @@ -27,10 +27,32 @@ pub use self::{core::*, output::*}; /// Core data structures for source file parsing and method extraction pub mod core { - use crate::Language; + use std::sync::Arc; + + use schemars::JsonSchema; + + use 
crate::{Language, Location}; use super::{Deserialize, Path, PathBuf, Serialize}; + #[derive(Clone)] + pub(crate) struct AstWithSourceFile { + pub(crate) ast: Arc>>, + pub(crate) source_file: Arc, + } + + impl AstWithSourceFile { + pub(crate) fn new( + ast: ast_grep_core::AstGrep>, + source_file: SourceFile, + ) -> Self { + Self { + ast: Arc::new(ast), + source_file: Arc::new(source_file), + } + } + } + /// Represents a source file being analyzed /// /// Contains the file path, content, and detected programming language. @@ -82,7 +104,7 @@ pub mod core { /// Contains detailed information about a method call including parameters, /// position information, and parsing context. This is optional metadata /// that can be omitted when only basic method identification is needed. - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] + #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] #[serde(rename_all = "PascalCase")] pub struct SdkMethodCallMetadata { /// List of method parameters with their metadata @@ -90,11 +112,11 @@ pub mod core { /// Return type annotation if available pub(crate) return_type: Option, + /// The matched expression + pub(crate) expr: String, + // Position information - /// Starting position (line, column) - both 1-based - pub(crate) start_position: (usize, usize), - /// Ending position (line, column) - both 1-based - pub(crate) end_position: (usize, usize), + pub(crate) location: Location, // SDK method call context /// Receiver variable name (e.g., "`s3_client`", "ec2") @@ -167,7 +189,7 @@ pub mod core { } /// Parameter value that distinguishes between resolved literals and unresolved expressions - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] + #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] pub(crate) enum ParameterValue { /// Resolved string literal with quotes stripped (e.g., "my-bucket", "42", "true") Resolved(String), @@ -192,7 +214,7 @@ pub mod core { 
/// /// TODO: Refactor enum variant fields into separate structs to enable Default trait /// implementation and improve ergonomics. See: https://github.com/awslabs/iam-policy-autopilot/issues/61 - #[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] + #[derive(Debug, Clone, Hash, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] pub(crate) enum Parameter { /// Positional argument (e.g., first, second argument in call) Positional { @@ -303,7 +325,7 @@ pub mod output { #[cfg(test)] mod tests { - use crate::Language; + use crate::{Language, Location}; use super::*; use std::path::PathBuf; @@ -322,62 +344,10 @@ mod tests { } #[test] - fn test_parsed_method_creation() { - let method = SdkMethodCall { - name: "test_method".to_string(), - possible_services: Vec::new(), - metadata: Some(SdkMethodCallMetadata { - parameters: vec![Parameter::Keyword { - name: "param1".to_string(), - value: ParameterValue::Resolved("test_value".to_string()), - position: 0, - type_annotation: Some("str".to_string()), - }], - return_type: Some("bool".to_string()), - start_position: (10, 1), - end_position: (10, 25), - receiver: None, - }), - }; - - assert_eq!(method.name, "test_method"); - assert_eq!(method.metadata.as_ref().unwrap().parameters.len(), 1); - assert_eq!(method.metadata.as_ref().unwrap().start_position, (10, 1)); - assert_eq!(method.metadata.as_ref().unwrap().end_position, (10, 25)); - } - - #[test] - fn test_parsed_method_with_sdk_context() { - let method = SdkMethodCall { - name: "get_object".to_string(), - possible_services: vec!["s3".to_string()], - metadata: Some(SdkMethodCallMetadata { - parameters: vec![Parameter::Keyword { - name: "Bucket".to_string(), - value: ParameterValue::Resolved("my-bucket".to_string()), - position: 0, - type_annotation: Some("str".to_string()), - }], - return_type: Some("Dict[str, Any]".to_string()), - start_position: (15, 5), - end_position: (15, 45), - receiver: Some("s3_client".to_string()), - }), - }; - - assert_eq!(method.name, 
"get_object"); - assert_eq!( - method.metadata.as_ref().unwrap().receiver, - Some("s3_client".to_string()) - ); - assert_eq!(method.possible_services, vec!["s3".to_string()]); - if let Parameter::Keyword { - value, position, .. - } = &method.metadata.as_ref().unwrap().parameters[0] - { - assert_eq!(value.as_string(), "my-bucket"); - assert_eq!(*position, 0); - } + fn test_location_construction() { + let location = Location::new(PathBuf::new(), (10, 1), (10, 25)); + assert_eq!(location.start_position, (10, 1)); + assert_eq!(location.end_position, (10, 25)); } #[test] @@ -424,8 +394,8 @@ mod tests { let metadata = SdkMethodCallMetadata { parameters: vec![], return_type: Some("Dict[str, Any]".to_string()), - start_position: (10, 5), - end_position: (10, 30), + expr: "s3_client.foo_bar".to_string(), + location: Location::new(PathBuf::new(), (10, 5), (10, 30)), receiver: Some("s3_client".to_string()), }; @@ -434,8 +404,7 @@ mod tests { // Verify PascalCase field names assert!(json.contains("\"Parameters\"")); assert!(json.contains("\"ReturnType\"")); - assert!(json.contains("\"StartPosition\"")); - assert!(json.contains("\"EndPosition\"")); + assert!(json.contains("\"Location\"")); assert!(json.contains("\"Receiver\"")); } } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs index 9019cbe..f848930 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs @@ -199,7 +199,9 @@ mod tests { Shape, ShapeReference, }; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; + use crate::Location; use std::collections::HashMap; + use std::path::PathBuf; fn create_test_service_index() -> ServiceModelIndex { let mut services = HashMap::new(); @@ -297,6 +299,7 @@ mod tests { name: 
"create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -318,8 +321,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -338,6 +340,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -348,8 +351,7 @@ mod tests { // Missing required Stage and ApiId parameters ], return_type: None, - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("client".to_string()), }), }; @@ -367,13 +369,13 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![Parameter::DictionarySplat { expression: "**params".to_string(), position: 0, }], return_type: None, - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("client".to_string()), }), }; @@ -392,6 +394,7 @@ mod tests { name: "non_aws_method".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "non_aws_method".to_string(), parameters: vec![Parameter::Keyword { name: "custom_param".to_string(), value: ParameterValue::Resolved("value".to_string()), @@ -399,8 +402,7 @@ mod tests { type_annotation: None, }], return_type: None, - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: Some("client".to_string()), }), }; @@ -421,6 +423,7 @@ mod tests { name: 
"create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -440,8 +443,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; @@ -462,6 +464,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -477,8 +480,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("client".to_string()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs index e8ca62d..b060268 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs @@ -15,6 +15,7 @@ mod tests { use crate::extraction::{ Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, SourceFile, }; + use crate::Location; use std::collections::HashMap; use std::path::PathBuf; @@ -180,6 +181,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -201,8 +203,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 80), + location: Location::new(PathBuf::new(), (1, 1), (1, 80)), receiver: Some("apigateway_client".to_string()), }), }; @@ -222,6 
+223,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -232,8 +234,7 @@ mod tests { // Missing required Stage and ApiId parameters ], return_type: None, - start_position: (1, 1), - end_position: (1, 40), + location: Location::new(PathBuf::new(), (1, 1), (1, 40)), receiver: Some("apigateway_client".to_string()), }), }; @@ -251,6 +252,7 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![ Parameter::Keyword { name: "DomainName".to_string(), @@ -278,8 +280,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 100), + location: Location::new(PathBuf::new(), (1, 1), (1, 100)), receiver: Some("apigateway_client".to_string()), }), }; @@ -297,13 +298,13 @@ mod tests { name: "create_api_mapping".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "create_api_mapping".to_string(), parameters: vec![Parameter::DictionarySplat { expression: "**params".to_string(), position: 0, }], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("apigateway_client".to_string()), }), }; @@ -328,6 +329,7 @@ mod tests { name: "custom_method".to_string(), // Not an AWS SDK method possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "custom_method".to_string(), parameters: vec![Parameter::Keyword { name: "custom_param".to_string(), value: ParameterValue::Resolved("value".to_string()), @@ -335,8 +337,7 @@ mod tests { type_annotation: None, }], return_type: None, - start_position: (1, 1), - end_position: (1, 30), + location: Location::new(PathBuf::new(), (1, 1), (1, 30)), receiver: 
Some("custom_client".to_string()), }), }; @@ -356,6 +357,7 @@ mod tests { name: "get_object".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "get_object".to_string(), parameters: vec![ Parameter::Keyword { name: "Bucket".to_string(), @@ -371,8 +373,7 @@ mod tests { }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 50), + location: Location::new(PathBuf::new(), (1, 1), (1, 50)), receiver: Some("s3_client".to_string()), }), }, @@ -381,6 +382,7 @@ mod tests { name: "get_object".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "get_object".to_string(), parameters: vec![Parameter::Keyword { name: "custom_param".to_string(), // Invalid parameter for AWS S3 value: ParameterValue::Resolved("value".to_string()), @@ -388,8 +390,7 @@ mod tests { type_annotation: None, }], return_type: None, - start_position: (2, 1), - end_position: (2, 30), + location: Location::new(PathBuf::new(), (2, 1), (2, 30)), receiver: Some("custom_client".to_string()), }), }, @@ -398,6 +399,7 @@ mod tests { name: "custom_method".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "custom_method".to_string(), parameters: vec![Parameter::Keyword { name: "param".to_string(), value: ParameterValue::Resolved("value".to_string()), @@ -405,8 +407,7 @@ mod tests { type_annotation: None, }], return_type: None, - start_position: (3, 1), - end_position: (3, 25), + location: Location::new(PathBuf::new(), (3, 1), (3, 25)), receiver: Some("custom_client".to_string()), }), }, @@ -463,7 +464,7 @@ def example(): let extractor = PythonExtractor::new(); // Extract method calls using tree-sitter - let mut result = vec![extractor.parse(&source.content).await]; + let mut result = vec![extractor.parse(&source).await]; assert_eq!(result.first().unwrap().method_calls_ref().len(), 7); // Apply disambiguation @@ -510,6 +511,7 @@ def example(): name: "get_object".to_string(), 
possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "get_object".to_string(), parameters: vec![ Parameter::Keyword { name: "Bucket".to_string(), @@ -540,8 +542,7 @@ def example(): }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 80), + location: Location::new(PathBuf::new(), (1, 1), (1, 80)), receiver: Some("s3_client".to_string()), }), }; @@ -568,6 +569,7 @@ def example(): name: "get_object".to_string(), possible_services: Vec::new(), metadata: Some(SdkMethodCallMetadata { + expr: "get_object".to_string(), parameters: vec![ // Valid required parameters Parameter::Keyword { @@ -591,8 +593,7 @@ def example(): }, ], return_type: None, - start_position: (1, 1), - end_position: (1, 60), + location: Location::new(PathBuf::new(), (1, 1), (1, 60)), receiver: Some("s3_client".to_string()), }), }; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs index d413b15..0ad48c6 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs @@ -6,8 +6,8 @@ use crate::extraction::python::disambiguation::MethodDisambiguator; use crate::extraction::python::paginator_extractor::PaginatorExtractor; use crate::extraction::python::resource_direct_calls_extractor::ResourceDirectCallsExtractor; use crate::extraction::python::waiters_extractor::WaitersExtractor; -use crate::extraction::{SdkMethodCall, SdkMethodCallMetadata}; -use crate::ServiceModelIndex; +use crate::extraction::{AstWithSourceFile, SdkMethodCall, SdkMethodCallMetadata}; +use crate::{Location, ServiceModelIndex, SourceFile}; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; use async_trait::async_trait; @@ -24,6 +24,7 @@ impl PythonExtractor { fn parse_method_call( &self, node_match: &ast_grep_core::NodeMatch>, + source_file: &SourceFile, ) 
-> Option { let env = node_match.get_env(); @@ -43,19 +44,17 @@ impl PythonExtractor { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = ArgumentExtractor::extract_arguments(&args_nodes); - // Get position information - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - let method_call = SdkMethodCall { name: method_name.to_string(), possible_services: Vec::new(), // Will be determined later during service validation metadata: Some(SdkMethodCallMetadata { parameters: arguments, return_type: None, // We don't know the return type from the call site - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node( + source_file.path.to_path_buf(), + node_match.get_node(), + ), receiver, }), }; @@ -73,9 +72,13 @@ impl Default for PythonExtractor { #[async_trait] impl Extractor for PythonExtractor { - async fn parse(&self, source_code: &str) -> crate::extraction::extractor::ExtractorResult { - let ast_grep = Python.ast_grep(source_code); - let root = ast_grep.root(); + async fn parse( + &self, + source_file: &SourceFile, + ) -> crate::extraction::extractor::ExtractorResult { + let ast_grep = Python.ast_grep(&source_file.content); + let ast = AstWithSourceFile::new(ast_grep, source_file.clone()); + let root = ast.ast.root(); let mut method_calls = Vec::new(); @@ -83,12 +86,12 @@ impl Extractor for PythonExtractor { // Find all method calls with attribute access: obj.method(args) for node_match in root.find_all(pattern) { - if let Some(method_call) = self.parse_method_call(&node_match) { + if let Some(method_call) = self.parse_method_call(&node_match, source_file) { method_calls.push(method_call); } } - ExtractorResult::Python(ast_grep, method_calls) + ExtractorResult::Python(ast, method_calls) } fn filter_map( @@ -148,15 +151,25 @@ impl Extractor for PythonExtractor { #[cfg(test)] mod 
tests { + use std::path::PathBuf; + use super::*; - use crate::extraction::{Parameter, ParameterValue}; + use crate::{ + extraction::{Parameter, ParameterValue}, + Language, + }; #[tokio::test] async fn test_basic_method_call_extraction() { let extractor = PythonExtractor::new(); let source_code = "s3_client.get_object(Bucket='my-bucket', Key='my-key')"; - let result = extractor.parse(source_code).await; + let source_file = SourceFile::with_language( + std::path::PathBuf::new(), + source_code.to_string(), + crate::Language::Python, + ); + let result = extractor.parse(&source_file).await; assert_eq!(result.method_calls_ref().len(), 1); assert_eq!(result.method_calls_ref()[0].name, "get_object"); } @@ -177,8 +190,9 @@ cloudwatch_client.put_metric_alarm( Threshold=0.0 ) "#; - - let result = extractor.parse(source_code).await; + let source_file = + SourceFile::with_language(PathBuf::new(), source_code.to_string(), Language::Python); + let result = extractor.parse(&source_file).await; // Verify exactly one method call is extracted assert_eq!(result.method_calls_ref().len(), 1); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs index 28b4b80..a11dc1a 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs @@ -4,11 +4,13 @@ //! two-phase operations: creating a paginator from a client, then executing //! the paginator with operation arguments. 
+use std::path::Path; + use crate::extraction::python::common::{ArgumentExtractor, ParameterFilter}; use crate::extraction::sdk_model::ServiceDiscovery; -use crate::extraction::{Parameter, SdkMethodCall, SdkMethodCallMetadata}; -use crate::Language; +use crate::extraction::{AstWithSourceFile, Parameter, SdkMethodCall, SdkMethodCallMetadata}; use crate::ServiceModelIndex; +use crate::{Language, Location}; use ast_grep_language::Python; /// Information about a discovered get_paginator call @@ -20,8 +22,16 @@ pub(crate) struct PaginatorInfo { pub operation_name: String, /// Client receiver variable name (e.g., "client", "s3_client") pub client_receiver: String, - /// Line number where get_paginator was called - pub get_paginator_line: usize, + /// Matched expression + pub expr: String, + /// Location where get_paginator was called + pub location: Location, +} + +impl PaginatorInfo { + fn start_line(&self) -> usize { + self.location.start_line() + } } /// Information about a paginate method call @@ -31,12 +41,16 @@ pub(crate) struct PaginateCallInfo { pub paginator_var: String, /// Extracted arguments (excluding pagination-specific ones) pub arguments: Vec, - /// Line number where paginate was called (preferred for position reporting) - pub paginate_line: usize, - /// Start position of the paginate call node - pub start_position: (usize, usize), - /// End position of the paginate call node - pub end_position: (usize, usize), + /// Matched expression + pub expr: String, + /// Location where paginator was called + pub location: Location, +} + +impl PaginateCallInfo { + fn start_line(&self) -> usize { + self.location.start_line() + } } /// Information about a chained paginator call (client.get_paginator().paginate()) @@ -48,13 +62,10 @@ pub(crate) struct ChainedPaginatorCallInfo { pub operation_name: String, /// Extracted arguments from paginate call (excluding pagination-specific ones) pub arguments: Vec, - /// Line number where chained call was made - 
#[allow(dead_code)] - pub line: usize, - /// Start position of the chained call node - pub start_position: (usize, usize), - /// End position of the chained call node - pub end_position: (usize, usize), + /// Matched expression + pub expr: String, + /// Location where paginator was called + pub location: Location, } /// Extractor for boto3 paginate method patterns @@ -101,7 +112,7 @@ impl<'a> PaginatorExtractor<'a> { /// empty parameters, since paginators are often created but used elsewhere. pub(crate) fn extract_paginate_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { // Step 1: Find all get_paginator calls let paginators = self.find_get_paginator_calls(ast); @@ -141,18 +152,17 @@ impl<'a> PaginatorExtractor<'a> { } /// Find all get_paginator calls in the AST - fn find_get_paginator_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_get_paginator_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut paginators = Vec::new(); // Pattern: $PAGINATOR = $CLIENT.get_paginator($OPERATION $$$ARGS) let get_paginator_pattern = "$PAGINATOR = $CLIENT.get_paginator($OPERATION $$$ARGS)"; for node_match in root.find_all(get_paginator_pattern) { - if let Some(paginator_info) = self.parse_get_paginator_call(&node_match) { + if let Some(paginator_info) = + self.parse_get_paginator_call(&node_match, &ast.source_file.path) + { paginators.push(paginator_info); } } @@ -161,18 +171,17 @@ impl<'a> PaginatorExtractor<'a> { } /// Find all paginate calls in the AST - fn find_paginate_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_paginate_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut paginate_calls = Vec::new(); // Pattern: $PAGINATOR.paginate($$$ARGS) - flexible pattern without assignment requirement let paginate_pattern = "$PAGINATOR.paginate($$$ARGS)"; for node_match in 
root.find_all(paginate_pattern) { - if let Some(paginate_info) = self.parse_paginate_call(&node_match) { + if let Some(paginate_info) = + self.parse_paginate_call(&node_match, &ast.source_file.path) + { paginate_calls.push(paginate_info); } } @@ -183,9 +192,9 @@ impl<'a> PaginatorExtractor<'a> { /// Find all chained paginator calls in the AST fn find_chained_paginator_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut chained_calls = Vec::new(); // Pattern: $CLIENT.get_paginator($OPERATION $$$GET_ARGS).paginate($$$PAGINATE_ARGS) @@ -193,7 +202,9 @@ impl<'a> PaginatorExtractor<'a> { "$CLIENT.get_paginator($OPERATION $$$GET_ARGS).paginate($$$PAGINATE_ARGS)"; for node_match in root.find_all(chained_pattern) { - if let Some(chained_info) = self.parse_chained_paginator_call(&node_match) { + if let Some(chained_info) = + self.parse_chained_paginator_call(&node_match, &ast.source_file.path) + { chained_calls.push(chained_info); } } @@ -205,6 +216,7 @@ impl<'a> PaginatorExtractor<'a> { fn parse_get_paginator_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -219,14 +231,12 @@ impl<'a> PaginatorExtractor<'a> { let operation_text = operation_node.text(); let operation_name = self.extract_quoted_string(&operation_text)?; - // Get line number - let get_paginator_line = node_match.get_node().start_pos().line() + 1; - Some(PaginatorInfo { variable_name, operation_name, client_receiver, - get_paginator_line, + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -234,6 +244,7 @@ impl<'a> PaginatorExtractor<'a> { fn parse_paginate_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -245,17 +256,11 @@ impl<'a> PaginatorExtractor<'a> { let all_arguments = 
self.extract_arguments(&args_nodes); let filtered_arguments = self.filter_pagination_parameters(all_arguments); - // Get position information from the paginate call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(PaginateCallInfo { paginator_var, arguments: filtered_arguments, - paginate_line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -263,6 +268,7 @@ impl<'a> PaginatorExtractor<'a> { fn parse_chained_paginator_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -279,18 +285,12 @@ impl<'a> PaginatorExtractor<'a> { let all_arguments = self.extract_arguments(&paginate_args_nodes); let filtered_arguments = self.filter_pagination_parameters(all_arguments); - // Get position information from the chained call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ChainedPaginatorCallInfo { client_receiver, operation_name, arguments: filtered_arguments, - line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -317,8 +317,8 @@ impl<'a> PaginatorExtractor<'a> { for (idx, paginator) in paginators.iter().enumerate() { if paginator.variable_name == paginate_call.paginator_var { // Only consider paginators that come before the paginate call - if paginator.get_paginator_line < paginate_call.paginate_line { - let distance = paginate_call.paginate_line - paginator.get_paginator_line; + if paginator.start_line() < paginate_call.start_line() { + let distance = 
paginate_call.start_line() - paginator.start_line(); if distance < best_distance { best_distance = distance; best_match = Some(paginator); @@ -360,9 +360,9 @@ impl<'a> PaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: Vec::new(), // Empty parameters for unmatched paginators return_type: None, + expr: paginator_info.expr.clone(), // Use get_paginator call position - start_position: (paginator_info.get_paginator_line, 1), - end_position: (paginator_info.get_paginator_line, 1), + location: paginator_info.location.clone(), receiver: Some(paginator_info.client_receiver.clone()), }), } @@ -395,9 +395,9 @@ impl<'a> PaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: paginate_call.arguments.clone(), return_type: None, + expr: paginate_call.expr.clone(), // Use paginate call position (most specific) - start_position: paginate_call.start_position, - end_position: paginate_call.end_position, + location: paginate_call.location.clone(), // Use client receiver from get_paginator call receiver: Some(paginator_info.client_receiver.clone()), }), @@ -430,9 +430,9 @@ impl<'a> PaginatorExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: chained_call.arguments.clone(), return_type: None, + expr: chained_call.expr.clone(), // Use chained call position - start_position: chained_call.start_position, - end_position: chained_call.end_position, + location: chained_call.location.clone(), // Use client receiver from chained call receiver: Some(chained_call.client_receiver.clone()), }), @@ -452,17 +452,21 @@ impl<'a> PaginatorExtractor<'a> { #[cfg(test)] mod tests { - use crate::extraction::ParameterValue; + use crate::{extraction::ParameterValue, SourceFile}; use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; - use std::collections::HashMap; + use std::{collections::HashMap, path::PathBuf}; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - 
Python.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = SourceFile::with_language( + PathBuf::new(), + source_code.to_string(), + crate::Language::Python, + ); + let ast_grep = Python.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file.clone()) } fn create_test_service_index() -> ServiceModelIndex { @@ -700,7 +704,7 @@ page_iterator = paginator.paginate(Bucket='test-bucket') assert_eq!(call.name, "list_objects_v2"); // Position should be from the paginate call (line 5), not get_paginator call (line 4) - assert_eq!(call.metadata.as_ref().unwrap().start_position.0, 5); + assert_eq!(call.metadata.as_ref().unwrap().location.start_line(), 5); assert_eq!( call.metadata.as_ref().unwrap().receiver, Some("client".to_string()) diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs index 161d576..66fd240 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs @@ -38,18 +38,14 @@ use crate::extraction::python::boto3_resources_model::{ Boto3ResourcesModel, Boto3ResourcesRegistry, HasManySpec, OperationType, }; use crate::extraction::python::common::ArgumentExtractor; -use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; -use crate::ServiceModelIndex; +use crate::extraction::{ + AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, +}; +use crate::{Location, ServiceModelIndex}; use ast_grep_language::Python; use convert_case::{Case, Casing}; use std::collections::{HashMap, HashSet}; - -/// Position tracking for deduplication (Tier 3) -#[derive(Debug, Clone, Hash, Eq, PartialEq)] -struct MatchedPosition { - line: usize, - column: 
usize, -} +use std::path::Path; /// Information about a discovered resource constructor call #[derive(Debug, Clone)] @@ -70,9 +66,14 @@ struct ResourceMethodCallInfo { resource_var: String, method_name: String, arguments: Vec, - method_call_line: usize, - start_position: (usize, usize), - end_position: (usize, usize), + expr: String, + location: Location, +} + +impl ResourceMethodCallInfo { + fn start_line(&self) -> usize { + self.location.start_line() + } } /// Resource usage classification for two-tier approach (Tier 3 is now separate) @@ -124,7 +125,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// **Tier 3**: Constructor only → maximum conservation with constructor position as evidence pub(crate) fn extract_resource_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { // Step 1: Find all resource constructors using service-agnostic matching let constructors = self.find_resource_constructors(ast, &self.registry); @@ -173,18 +174,15 @@ impl<'a> ResourceDirectCallsExtractor<'a> { all_calls.extend(collection_synthetics); // Step 6: Collect matched positions for Tier 3 deduplication - let mut matched_positions = HashSet::new(); + let mut matched_locations = HashSet::new(); for call in &all_calls { if let Some(metadata) = &call.metadata { - matched_positions.insert(MatchedPosition { - line: metadata.start_position.0, - column: metadata.start_position.1, - }); + matched_locations.insert(metadata.location.clone()); } } // Step 7: New Tier 3 - service-agnostic fallback for unknown receivers - let tier3_calls = self.find_unmatched_utility_and_collection_calls(ast, &matched_positions); + let tier3_calls = self.find_unmatched_utility_and_collection_calls(ast, &matched_locations); all_calls.extend(tier3_calls); all_calls @@ -414,8 +412,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position: method_call.start_position, - end_position: 
method_call.end_position, + expr: method_call.expr.clone(), + location: method_call.location.clone(), receiver: Some(method_call.resource_var.clone()), }), }); @@ -439,9 +437,9 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let mut synthetic_calls = Vec::new(); // Extract position from evidence source - let (start_pos, end_pos) = match evidence { + let (expr, location) = match evidence { SyntheticEvidenceSource::UnmatchedMethod(ref method_call) => { - (method_call.start_position, method_call.end_position) + (method_call.expr.clone(), method_call.location.clone()) } }; @@ -496,8 +494,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position: start_pos, - end_position: end_pos, + expr: expr.clone(), + location: location.clone(), receiver: Some(constructor.variable_name.clone()), // Use actual variable name from constructor }), }); @@ -509,10 +507,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Find all resource constructor calls in the AST using service-agnostic matching fn find_resource_constructors( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, registry: &Boto3ResourcesRegistry, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut constructors = Vec::new(); // Service-agnostic pattern: $VAR = $ANY.$RESOURCE_TYPE($$$ARGS) @@ -582,15 +580,17 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Find all method calls on potential resource objects fn find_resource_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut method_calls = Vec::new(); let method_call_pattern = "$RESULT = $RESOURCE_VAR.$METHOD($$$ARGS)"; for node_match in root.find_all(method_call_pattern) { - if let Some(method_call_info) = self.parse_resource_method_call(&node_match) { + if let Some(method_call_info) = + self.parse_resource_method_call(&node_match, &ast.source_file.path) + { 
method_calls.push(method_call_info); } } @@ -599,7 +599,9 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let simple_method_pattern = "$RESOURCE_VAR.$METHOD($$$ARGS)"; for node_match in root.find_all(simple_method_pattern) { - if let Some(method_call_info) = self.parse_simple_resource_method_call(&node_match) { + if let Some(method_call_info) = + self.parse_simple_resource_method_call(&node_match, &ast.source_file.path) + { method_calls.push(method_call_info); } } @@ -609,12 +611,12 @@ impl<'a> ResourceDirectCallsExtractor<'a> { a.resource_var .cmp(&b.resource_var) .then(a.method_name.cmp(&b.method_name)) - .then(a.method_call_line.cmp(&b.method_call_line)) + .then(a.start_line().cmp(&b.start_line())) }); method_calls.dedup_by(|a, b| { a.resource_var == b.resource_var && a.method_name == b.method_name - && a.method_call_line == b.method_call_line + && a.start_line() == b.start_line() }); method_calls @@ -624,6 +626,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { fn parse_resource_method_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -637,18 +640,12 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = ArgumentExtractor::extract_arguments(&args_nodes); - // Get position information from the method call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ResourceMethodCallInfo { resource_var, method_name, arguments, - method_call_line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -656,6 +653,7 @@ impl<'a> ResourceDirectCallsExtractor<'a> { fn parse_simple_resource_method_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env 
= node_match.get_env(); @@ -669,18 +667,12 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = ArgumentExtractor::extract_arguments(&args_nodes); - // Get position information from the method call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ResourceMethodCallInfo { resource_var, method_name, arguments, - method_call_line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -804,8 +796,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters: combined_parameters, return_type: None, - start_position: method_call.start_position, - end_position: method_call.end_position, + expr: method_call.expr.clone(), + location: method_call.location.clone(), receiver: Some(method_call.resource_var.clone()), }), }) @@ -817,10 +809,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Generates synthetic SdkMethodCall for the collection's operation at the access point fn find_and_generate_collection_synthetics( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, constructors: &[ResourceConstructorInfo], ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut synthetic_calls = Vec::new(); // Pattern: $VAR = $RESOURCE_VAR.$ATTR_NAME (with optional assignment) @@ -862,22 +854,21 @@ impl<'a> ResourceDirectCallsExtractor<'a> { }; // Check if this attribute matches a hasMany collection (in snake_case) - if let Some(has_many_spec) = - boto3_model.get_has_many_spec(&constructor.resource_type, &attr_name) + if let Some(synthetic_call) = boto3_model + .get_has_many_spec(&constructor.resource_type, &attr_name) + .and_then(|has_many_spec| { + self.generate_synthetic_for_collection( + 
constructor, + has_many_spec, + node_match.text().to_string(), + Location::from_node( + ast.source_file.path.to_path_buf(), + node_match.get_node(), + ), + ) + }) { - // Generate synthetic call for the collection's operation - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - - if let Some(synthetic_call) = self.generate_synthetic_for_collection( - constructor, - has_many_spec, - (start.line() + 1, start.column(node) + 1), - (end.line() + 1, end.column(node) + 1), - ) { - synthetic_calls.push(synthetic_call); - } + synthetic_calls.push(synthetic_call); } } } @@ -890,8 +881,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &self, constructor: &ResourceConstructorInfo, has_many_spec: &HasManySpec, - start_position: (usize, usize), - end_position: (usize, usize), + expr: String, + location: Location, ) -> Option { let mut parameters = Vec::new(); @@ -925,8 +916,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position, - end_position, + expr: expr.clone(), + location, receiver: Some(constructor.variable_name.clone()), // Use actual variable name from constructor }), }) @@ -939,16 +930,16 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// all-synthetic parameters since we don't know the receiver. 
fn find_unmatched_utility_and_collection_calls( &self, - ast: &ast_grep_core::AstGrep>, - matched_positions: &HashSet, + ast: &AstWithSourceFile, + matched_locations: &HashSet, ) -> Vec { let mut tier3_calls = Vec::new(); // Search for utility method calls across all services - tier3_calls.extend(self.find_unmatched_utility_method_calls(ast, matched_positions)); + tier3_calls.extend(self.find_unmatched_utility_method_calls(ast, matched_locations)); // Search for collection accesses across all services - tier3_calls.extend(self.find_unmatched_collection_accesses(ast, matched_positions)); + tier3_calls.extend(self.find_unmatched_collection_accesses(ast, matched_locations)); tier3_calls } @@ -956,10 +947,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Find utility method calls with unknown receivers (Tier 3) fn find_unmatched_utility_method_calls( &self, - ast: &ast_grep_core::AstGrep>, - matched_positions: &HashSet, + ast: &AstWithSourceFile, + matched_locations: &HashSet, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut calls = Vec::new(); // Pattern for method calls @@ -981,16 +972,11 @@ impl<'a> ResourceDirectCallsExtractor<'a> { None => continue, }; - // Get position - let node = node_match.get_node(); - let start = node.start_pos(); - let position = MatchedPosition { - line: start.line() + 1, - column: start.column(node) + 1, - }; + let location = + Location::from_node(ast.source_file.path.clone(), node_match.get_node()); // Skip if already matched in Tier 1/2 - if matched_positions.contains(&position) { + if matched_locations.contains(&location) { continue; } @@ -1017,8 +1003,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &operation.operation, &arguments, &operation.required_params, - (start.line() + 1, start.column(node) + 1), - (node.end_pos().line() + 1, node.end_pos().column(node) + 1), + node_match.text().to_string(), + &location, &receiver_var, // Use actual receiver from code )); } @@ -1035,8 +1021,8 @@ impl<'a> 
ResourceDirectCallsExtractor<'a> { &operation.operation, &arguments, &operation.required_params, - (start.line() + 1, start.column(node) + 1), - (node.end_pos().line() + 1, node.end_pos().column(node) + 1), + node_match.text().to_string(), + &location, &receiver_var, // Use actual receiver from code )); } @@ -1052,10 +1038,10 @@ impl<'a> ResourceDirectCallsExtractor<'a> { /// Find collection accesses with unknown receivers (Tier 3) fn find_unmatched_collection_accesses( &self, - ast: &ast_grep_core::AstGrep>, - matched_positions: &HashSet, + ast: &AstWithSourceFile, + matched_locations: &HashSet, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut calls = Vec::new(); // Patterns for attribute access (including chained method calls) @@ -1082,16 +1068,11 @@ impl<'a> ResourceDirectCallsExtractor<'a> { None => continue, }; - // Get position - let node = node_match.get_node(); - let start = node.start_pos(); - let position = MatchedPosition { - line: start.line() + 1, - column: start.column(node) + 1, - }; + let location = + Location::from_node(ast.source_file.path.clone(), node_match.get_node()); // Skip if already matched in Tier 1/2 - if matched_positions.contains(&position) { + if matched_locations.contains(&location) { continue; } @@ -1109,11 +1090,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &has_many_spec.identifier_params, ), return_type: None, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: ( - node.end_pos().line() + 1, - node.end_pos().column(node) + 1, - ), + expr: node_match.text().to_string(), + location: location.clone(), receiver: Some(receiver_var.clone()), // Use actual receiver from code }), }); @@ -1134,11 +1112,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { &service_has_many_spec.identifier_params, ), return_type: None, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: ( - node.end_pos().line() + 1, - node.end_pos().column(node) + 1, - ), + expr: 
node_match.text().to_string(), + location: location.clone(), receiver: Some(receiver_var.clone()), // Use actual receiver from code }), }); @@ -1158,8 +1133,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { operation: &str, arguments: &[Parameter], required_params: &[String], - start_position: (usize, usize), - end_position: (usize, usize), + expr: String, + location: &Location, receiver_marker: &str, ) -> SdkMethodCall { let mut parameters = Vec::new(); @@ -1194,8 +1169,8 @@ impl<'a> ResourceDirectCallsExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position, - end_position, + expr: expr.clone(), + location: location.clone(), receiver: Some(receiver_marker.to_string()), }), } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs index 3bd8fb8..2ed755d 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs @@ -4,9 +4,13 @@ //! two-phase operations: creating a waiter from a client, then calling wait() //! on the waiter with operation arguments. 
+use std::path::Path; + use crate::extraction::python::common::{ArgumentExtractor, ParameterFilter}; -use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; -use crate::ServiceModelIndex; +use crate::extraction::{ + AstWithSourceFile, Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata, +}; +use crate::{Location, ServiceModelIndex}; use ast_grep_language::Python; /// Information about a discovered get_waiter call @@ -18,8 +22,16 @@ pub(crate) struct WaiterInfo { pub waiter_name: String, /// Client receiver variable name (e.g., "client", "ec2_client") pub client_receiver: String, - /// Line number where get_waiter was called - pub get_waiter_line: usize, + /// Matched expression + pub expr: String, + /// Location where we found the waiter + pub location: Location, +} + +impl WaiterInfo { + pub(crate) fn start_line(&self) -> usize { + self.location.start_line() + } } // TODO: This should be refactored at a higher level, so this type can be removed. 
@@ -37,6 +49,22 @@ impl<'a> CallInfo<'a> { Self::Chained(waiter_call_info) => &waiter_call_info.waiter_name, } } + + fn expr(&self) -> &'a str { + match self { + CallInfo::None(waiter_info) => &waiter_info.expr, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.expr, + CallInfo::Chained(chained_waiter_call_info) => &chained_waiter_call_info.expr, + } + } + + fn location(&self) -> &'a Location { + match self { + CallInfo::None(waiter_info) => &waiter_info.location, + CallInfo::Simple(_, wait_call_info) => &wait_call_info.location, + CallInfo::Chained(chained_waiter_call_info) => &chained_waiter_call_info.location, + } + } } /// Information about a wait method call @@ -46,12 +74,16 @@ pub(crate) struct WaitCallInfo { pub waiter_var: String, /// Extracted arguments (including WaiterConfig) pub arguments: Vec, - /// Line number where wait was called - pub wait_line: usize, - /// Start position of the wait call node - pub start_position: (usize, usize), - /// End position of the wait call node - pub end_position: (usize, usize), + /// Matched expression + pub expr: String, + /// Location where we found the waiter + pub location: Location, +} + +impl WaitCallInfo { + pub(crate) fn start_line(&self) -> usize { + self.location.start_line() + } } /// Information about a chained waiter call (client.get_waiter().wait()) @@ -63,13 +95,10 @@ pub(crate) struct ChainedWaiterCallInfo { pub waiter_name: String, /// Extracted arguments from wait call (including WaiterConfig) pub arguments: Vec, - /// Line number where chained call was made - #[allow(dead_code)] - pub line: usize, - /// Start position of the chained call node - pub start_position: (usize, usize), - /// End position of the chained call node - pub end_position: (usize, usize), + /// Matched expression + pub expr: String, + /// Location where we found the waiter call + pub location: Location, } /// Extractor for boto3 waiter patterns @@ -107,7 +136,7 @@ impl<'a> WaitersExtractor<'a> { /// 3. 
Unmatched wait: Ignored (no waiter context) pub(crate) fn extract_waiter_method_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { // Step 1: Find all get_waiter calls let waiters = self.find_get_waiter_calls(ast); @@ -150,18 +179,17 @@ impl<'a> WaitersExtractor<'a> { } /// Find all get_waiter calls in the AST - fn find_get_waiter_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_get_waiter_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut waiters = Vec::new(); // Pattern: $WAITER = $CLIENT.get_waiter($NAME $$$ARGS) let get_waiter_pattern = "$WAITER = $CLIENT.get_waiter($NAME $$$ARGS)"; for node_match in root.find_all(get_waiter_pattern) { - if let Some(waiter_info) = self.parse_get_waiter_call(&node_match) { + if let Some(waiter_info) = + self.parse_get_waiter_call(&node_match, &ast.source_file.path) + { waiters.push(waiter_info); } } @@ -170,11 +198,8 @@ impl<'a> WaitersExtractor<'a> { } /// Find all wait calls in the AST - fn find_wait_calls( - &self, - ast: &ast_grep_core::AstGrep>, - ) -> Vec { - let root = ast.root(); + fn find_wait_calls(&self, ast: &AstWithSourceFile) -> Vec { + let root = ast.ast.root(); let mut wait_calls = Vec::new(); // Pattern: $WAITER.wait($$$ARGS) @@ -182,7 +207,7 @@ impl<'a> WaitersExtractor<'a> { let wait_pattern = "$WAITER.wait($$$ARGS)"; for node_match in root.find_all(wait_pattern) { - if let Some(wait_info) = self.parse_wait_call(&node_match) { + if let Some(wait_info) = self.parse_wait_call(&node_match, &ast.source_file.path) { wait_calls.push(wait_info); } } @@ -193,16 +218,18 @@ impl<'a> WaitersExtractor<'a> { /// Find all chained waiter calls in the AST fn find_chained_waiter_calls( &self, - ast: &ast_grep_core::AstGrep>, + ast: &AstWithSourceFile, ) -> Vec { - let root = ast.root(); + let root = ast.ast.root(); let mut chained_calls = Vec::new(); // Pattern: $CLIENT.get_waiter($NAME 
$$$WAITER_ARGS).wait($$$WAIT_ARGS) let chained_pattern = "$CLIENT.get_waiter($NAME $$$WAITER_ARGS).wait($$$WAIT_ARGS)"; for node_match in root.find_all(chained_pattern) { - if let Some(chained_info) = self.parse_chained_waiter_call(&node_match) { + if let Some(chained_info) = + self.parse_chained_waiter_call(&node_match, &ast.source_file.path) + { chained_calls.push(chained_info); } } @@ -214,6 +241,7 @@ impl<'a> WaitersExtractor<'a> { fn parse_get_waiter_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -228,14 +256,12 @@ impl<'a> WaitersExtractor<'a> { let name_text = name_node.text(); let waiter_name = self.extract_quoted_string(&name_text)?; - // Get line number - let get_waiter_line = node_match.get_node().start_pos().line() + 1; - Some(WaiterInfo { variable_name, waiter_name, client_receiver, - get_waiter_line, + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -243,6 +269,7 @@ impl<'a> WaitersExtractor<'a> { fn parse_wait_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: &Path, ) -> Option { let env = node_match.get_env(); @@ -253,17 +280,11 @@ impl<'a> WaitersExtractor<'a> { let args_nodes = env.get_multiple_matches("ARGS"); let arguments = ArgumentExtractor::extract_arguments(&args_nodes); - // Get position information from the wait call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(WaitCallInfo { waiter_var, arguments, - wait_line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -271,6 +292,7 @@ impl<'a> WaitersExtractor<'a> { fn parse_chained_waiter_call( &self, node_match: &ast_grep_core::NodeMatch>, + file_path: 
&Path, ) -> Option { let env = node_match.get_env(); @@ -286,18 +308,12 @@ impl<'a> WaitersExtractor<'a> { let wait_args_nodes = env.get_multiple_matches("WAIT_ARGS"); let arguments = ArgumentExtractor::extract_arguments(&wait_args_nodes); - // Get position information from the chained call node - let node = node_match.get_node(); - let start = node.start_pos(); - let end = node.end_pos(); - Some(ChainedWaiterCallInfo { client_receiver, waiter_name, arguments, - line: start.line() + 1, - start_position: (start.line() + 1, start.column(node) + 1), - end_position: (end.line() + 1, end.column(node) + 1), + expr: node_match.text().to_string(), + location: Location::from_node(file_path.to_path_buf(), node_match.get_node()), }) } @@ -316,8 +332,8 @@ impl<'a> WaitersExtractor<'a> { for (idx, waiter) in waiters.iter().enumerate() { if waiter.variable_name == wait_call.waiter_var { // Only consider waiters that come before the wait call - if waiter.get_waiter_line < wait_call.wait_line { - let distance = wait_call.wait_line - waiter.get_waiter_line; + if waiter.start_line() < wait_call.start_line() { + let distance = wait_call.start_line() - waiter.start_line(); if distance < best_distance { best_distance = distance; best_match = Some(waiter); @@ -346,36 +362,22 @@ impl<'a> WaitersExtractor<'a> { .get(wait_call.waiter_name()) { for service_method in service_defs { - let (parameters, start_position, end_position) = match wait_call { + let parameters = match wait_call { CallInfo::Simple(_, wait_call) => { // Filter out WaiterConfig from arguments - it's waiter-specific, not operation-specific - ( - ParameterFilter::filter_waiter_parameters(wait_call.arguments.clone()), - wait_call.start_position, - wait_call.end_position, - ) + ParameterFilter::filter_waiter_parameters(wait_call.arguments.clone()) } CallInfo::Chained(chained_wait_call) => { // Filter out WaiterConfig from arguments - it's waiter-specific, not operation-specific - ( - ParameterFilter::filter_waiter_parameters( 
- chained_wait_call.arguments.clone(), - ), - // Use wait call position (most specific) - chained_wait_call.start_position, - chained_wait_call.end_position, + ParameterFilter::filter_waiter_parameters( + chained_wait_call.arguments.clone(), ) } - CallInfo::None(waiter_info) => { - let fallback_start_pos = (waiter_info.get_waiter_line, 1); - let fallback_end_pos = (waiter_info.get_waiter_line, 1); - let parameters = self.get_required_parameters( - &service_method.service_name, - &service_method.operation_name, - self.service_index, - ); - (parameters, fallback_start_pos, fallback_end_pos) - } + CallInfo::None(_) => self.get_required_parameters( + &service_method.service_name, + &service_method.operation_name, + self.service_index, + ), }; // Create synthetic call with filtered wait() arguments @@ -385,8 +387,8 @@ impl<'a> WaitersExtractor<'a> { metadata: Some(SdkMethodCallMetadata { parameters, return_type: None, - start_position, - end_position, + expr: wait_call.expr().to_string(), + location: wait_call.location().clone(), // Use client receiver from get_waiter call receiver: receiver.clone(), }), @@ -471,16 +473,22 @@ impl<'a> WaitersExtractor<'a> { #[cfg(test)] mod tests { use crate::extraction::sdk_model::ServiceMethodRef; + use crate::SourceFile; use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; use std::collections::HashMap; + use std::path::PathBuf; - fn create_test_ast( - source_code: &str, - ) -> ast_grep_core::AstGrep> { - Python.ast_grep(source_code) + fn create_test_ast(source_code: &str) -> AstWithSourceFile { + let source_file = SourceFile::with_language( + PathBuf::new(), + source_code.to_string(), + crate::Language::Python, + ); + let ast_grep = Python.ast_grep(&source_file.content); + AstWithSourceFile::new(ast_grep, source_file.clone()) } fn create_test_service_index() -> ServiceModelIndex { @@ -638,7 +646,7 @@ waiter = ec2_client.get_waiter('instance_terminated') assert_eq!(waiters[0].variable_name, 
"waiter"); assert_eq!(waiters[0].waiter_name, "instance_terminated"); assert_eq!(waiters[0].client_receiver, "ec2_client"); - assert_eq!(waiters[0].get_waiter_line, 4); + assert_eq!(waiters[0].start_line(), 4); } #[test] diff --git a/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs index 85d53f4..db27ac3 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs @@ -10,7 +10,8 @@ use std::collections::HashSet; use crate::extraction::extractor::{Extractor, ExtractorResult}; use crate::extraction::javascript::scanner::ASTScanner; use crate::extraction::javascript::shared::ExtractionUtils; -use crate::ServiceModelIndex; +use crate::extraction::AstWithSourceFile; +use crate::{ServiceModelIndex, SourceFile}; /// TypeScript extractor for AWS SDK method calls pub(crate) struct TypeScriptExtractor; @@ -30,9 +31,10 @@ impl Default for TypeScriptExtractor { #[async_trait] impl Extractor for TypeScriptExtractor { - async fn parse(&self, source_code: &str) -> ExtractorResult { + async fn parse(&self, source_file: &SourceFile) -> ExtractorResult { // Create AST once and reuse it - let ast = TypeScript.ast_grep(source_code); + let ast_grep = TypeScript.ast_grep(&source_file.content); + let ast = AstWithSourceFile::new(ast_grep, source_file.clone()); // Create scanner with the pre-built AST let mut scanner = ASTScanner::new(ast.clone(), TypeScript.into()); @@ -55,7 +57,7 @@ impl Extractor for TypeScriptExtractor { &scan_results, )); - // Return TypeScript variant with the same AST (no double construction) + // Return TypeScript variant with the same AST ExtractorResult::TypeScript(ast, method_calls) } @@ -207,6 +209,14 @@ mod tests { } } + fn create_source_file(source_code: &str) -> SourceFile { + SourceFile::with_language( + std::path::PathBuf::new(), + 
source_code.to_string(), + crate::Language::TypeScript, + ) + } + #[tokio::test] async fn test_parse_typescript_with_types() { let extractor = TypeScriptExtractor::new(); @@ -242,7 +252,7 @@ async function queryUsers(): Promise { } "#; - let result = extractor.parse(typescript_code).await; + let result = extractor.parse(&create_source_file(typescript_code)).await; // Verify TypeScript AST is returned match result { @@ -320,7 +330,7 @@ class MyS3Service implements S3Service { } "#; - let result = extractor.parse(typescript_code).await; + let result = extractor.parse(&create_source_file(typescript_code)).await; // Verify TypeScript extraction with generics match result { @@ -375,7 +385,7 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); "#; // Parse the code - let mut results = vec![extractor.parse(typescript_code).await]; + let mut results = vec![extractor.parse(&create_source_file(typescript_code)).await]; // Build service index with all services for testing let service_index = ServiceDiscovery::load_service_index(Language::TypeScript) @@ -447,7 +457,7 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); "#; // Parse the code - let mut results = vec![extractor.parse(typescript_code).await]; + let mut results = vec![extractor.parse(&create_source_file(typescript_code)).await]; // Load service index let service_index = ServiceDiscovery::load_service_index(Language::TypeScript) diff --git a/iam-policy-autopilot-policy-generation/src/lib.rs b/iam-policy-autopilot-policy-generation/src/lib.rs index 4a68c3f..e9fe721 100644 --- a/iam-policy-autopilot-policy-generation/src/lib.rs +++ b/iam-policy-autopilot-policy-generation/src/lib.rs @@ -32,12 +32,12 @@ pub mod policy_generation; pub mod api; use std::fmt::Display; +use std::path::PathBuf; -pub use enrichment::Engine as EnrichmentEngine; +pub use enrichment::{Engine as EnrichmentEngine, Explanation}; pub use extraction::{Engine as ExtractionEngine, ExtractedMethods, 
SdkMethodCall, SourceFile}; pub use policy_generation::{ - Effect, Engine as PolicyGenerationEngine, IamPolicy, MethodActionMapping, PolicyType, - PolicyWithMetadata, Statement, + Effect, Engine as PolicyGenerationEngine, IamPolicy, PolicyType, PolicyWithMetadata, Statement, }; // Re-export commonly used types for convenience @@ -45,6 +45,9 @@ pub(crate) use extraction::ServiceModelIndex; pub use providers::FileSystemProvider; pub use providers::JsonProvider; +use schemars::JsonSchema; +use serde::Deserialize; +use serde::Serialize; use crate::errors::ExtractorError; @@ -134,6 +137,155 @@ impl From for String { } } +/// Represents a location in a source file +/// +/// This struct stores file path and position information and serializes +/// to the GNU coding standard (https://www.gnu.org/prep/standards/html_node/Errors.html) +/// format: `filename:startLine.startCol-endLine.endCol` +#[derive(Debug, Clone, PartialEq, Eq, Hash, JsonSchema)] +#[schemars( + description = "File location in GNU coding standard format: filename:startLine.startCol-endLine.endCol" +)] +pub struct Location { + /// File path + pub file_path: PathBuf, + /// Starting position (line, column) - both 1-based + pub start_position: (usize, usize), + /// Ending position (line, column) - both 1-based + pub end_position: (usize, usize), +} + +impl Location { + /// Create a new Location + #[must_use] + pub fn new( + file_path: PathBuf, + start_position: (usize, usize), + end_position: (usize, usize), + ) -> Self { + Self { + file_path, + start_position, + end_position, + } + } + + /// Create a new Location from an AST node + #[must_use] + pub fn from_node( + file_path: PathBuf, + node: &ast_grep_core::Node>, + ) -> Self + where + T: ast_grep_language::LanguageExt, + { + let start = node.start_pos(); + let end = node.end_pos(); + Self { + file_path, + start_position: (start.line() + 1, start.column(node) + 1), + end_position: (end.line() + 1, end.column(node) + 1), + } + } + + /// Line where the finding 
starts + pub fn start_line(&self) -> usize { + self.start_position.0 + } + + /// Column where the finding starts + pub fn start_col(&self) -> usize { + self.start_position.1 + } + + /// Line where the finding ends + pub fn end_line(&self) -> usize { + self.end_position.0 + } + + /// Column where the finding ends + pub fn end_col(&self) -> usize { + self.end_position.1 + } + + /// Format as GNU coding standard: `filename:startLine.startCol-endLine.endCol` + #[must_use] + pub fn to_gnu_format(&self) -> String { + let path_str = self.file_path.display(); + let (start_line, start_col) = self.start_position; + let (end_line, end_col) = self.end_position; + + format!( + "{}:{}.{}-{}.{}", + path_str, start_line, start_col, end_line, end_col + ) + } +} + +impl Serialize for Location { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&self.to_gnu_format()) + } +} + +impl<'de> Deserialize<'de> for Location { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + let s = String::deserialize(deserializer)?; + Self::from_gnu_format(&s).map_err(serde::de::Error::custom) + } +} + +impl Location { + /// Parse a Location from GNU coding standard format: `filename:startLine.startCol-endLine.endCol` + pub fn from_gnu_format(s: &str) -> Result { + // Find the colon that separates filename from position + let colon_pos = s.rfind(':').ok_or("Missing colon separator")?; + let (file_path_str, position_str) = s.split_at(colon_pos); + let position_str = &position_str[1..]; // Remove the colon + + // Parse the position part: startLine.startCol-endLine.endCol + let dash_pos = position_str.find('-').ok_or("Missing dash separator")?; + let (start_str, end_str) = position_str.split_at(dash_pos); + let end_str = &end_str[1..]; // Remove the dash + + // Parse start position: startLine.startCol + let start_dot_pos = start_str.find('.').ok_or("Missing dot in start position")?; + let (start_line_str, 
start_col_str) = start_str.split_at(start_dot_pos); + let start_col_str = &start_col_str[1..]; // Remove the dot + + // Parse end position: endLine.endCol + let end_dot_pos = end_str.find('.').ok_or("Missing dot in end position")?; + let (end_line_str, end_col_str) = end_str.split_at(end_dot_pos); + let end_col_str = &end_col_str[1..]; // Remove the dot + + // Convert strings to numbers + let start_line = start_line_str + .parse::() + .map_err(|_| format!("Invalid start line: {}", start_line_str))?; + let start_col = start_col_str + .parse::() + .map_err(|_| format!("Invalid start column: {}", start_col_str))?; + let end_line = end_line_str + .parse::() + .map_err(|_| format!("Invalid end line: {}", end_line_str))?; + let end_col = end_col_str + .parse::() + .map_err(|_| format!("Invalid end column: {}", end_col_str))?; + + Ok(Location { + file_path: PathBuf::from(file_path_str), + start_position: (start_line, start_col), + end_position: (end_line, end_col), + }) + } +} + #[cfg(test)] mod tests { use super::*; @@ -173,4 +325,68 @@ mod tests { assert!(Language::try_from_str("java").is_err()); assert!(Language::try_from_str("").is_err()); } + + #[test] + fn test_location_gnu_format() { + let location = Location::new(PathBuf::from("src/main.rs"), (10, 5), (15, 20)); + + let gnu_format = location.to_gnu_format(); + assert_eq!(gnu_format, "src/main.rs:10.5-15.20"); + } + + #[test] + fn test_location_from_gnu_format() { + let gnu_str = "src/main.rs:10.5-15.20"; + let location = Location::from_gnu_format(gnu_str).unwrap(); + + assert_eq!(location.file_path, PathBuf::from("src/main.rs")); + assert_eq!(location.start_position, (10, 5)); + assert_eq!(location.end_position, (15, 20)); + } + + #[test] + fn test_location_from_gnu_format_with_complex_path() { + let gnu_str = "/home/user/project/src/lib.rs:1.1-100.50"; + let location = Location::from_gnu_format(gnu_str).unwrap(); + + assert_eq!( + location.file_path, + PathBuf::from("/home/user/project/src/lib.rs") + ); + 
assert_eq!(location.start_position, (1, 1)); + assert_eq!(location.end_position, (100, 50)); + } + + #[test] + fn test_location_serialize_deserialize_roundtrip() { + let original = Location::new(PathBuf::from("test/file.py"), (42, 13), (45, 7)); + + // Serialize to JSON + let json = serde_json::to_string(&original).unwrap(); + assert_eq!(json, "\"test/file.py:42.13-45.7\""); + + // Deserialize back + let deserialized: Location = serde_json::from_str(&json).unwrap(); + + assert_eq!(original, deserialized); + } + + #[test] + fn test_location_from_gnu_format_invalid_formats() { + // Missing colon + assert!(Location::from_gnu_format("src/main.rs10.5-15.20").is_err()); + + // Missing dash + assert!(Location::from_gnu_format("src/main.rs:10.515.20").is_err()); + + // Missing dots + assert!(Location::from_gnu_format("src/main.rs:105-1520").is_err()); + + // Invalid numbers + assert!(Location::from_gnu_format("src/main.rs:abc.5-15.20").is_err()); + assert!(Location::from_gnu_format("src/main.rs:10.xyz-15.20").is_err()); + + // Empty string + assert!(Location::from_gnu_format("").is_err()); + } } diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs index e19ce8b..c93c43f 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/engine.rs @@ -4,12 +4,16 @@ //! The engine processes EnrichedSdkMethodCall instances and creates corresponding IAM policies //! with proper ARN pattern replacement. 
+use std::collections::BTreeMap; + use super::merge::{PolicyMerger, PolicyMergerConfig}; use super::utils::{ArnParser, ConditionValueProcessor}; -use super::{ActionMapping, IamPolicy, MethodActionMapping, Statement}; -use crate::enrichment::{Action, Condition, EnrichedSdkMethodCall}; +use super::{IamPolicy, Statement}; +use crate::api::model::GeneratePoliciesResult; +use crate::enrichment::{Action, Condition, EnrichedSdkMethodCall, Explanations}; use crate::errors::{ExtractorError, Result}; use crate::policy_generation::{PolicyType, PolicyWithMetadata}; +use crate::Explanation; /// Policy generation engine that converts enriched method calls into IAM policies #[derive(Debug, Clone)] @@ -48,7 +52,7 @@ impl<'a> Engine<'a> { /// Creates one IAM policy per EnrichedSdkMethodCall, with each Action becoming /// a separate statement within the policy. ARN patterns are processed to replace /// placeholder variables with actual values or wildcards. - pub fn generate_policies( + fn generate_individual_policies( &self, enriched_calls: &[EnrichedSdkMethodCall], ) -> Result> { @@ -221,7 +225,7 @@ impl<'a> Engine<'a> { /// /// # Errors /// Returns an error if policy merging fails - pub fn merge_policies( + pub(crate) fn merge_policies( &self, policies: &[PolicyWithMetadata], ) -> Result> { @@ -247,60 +251,57 @@ impl<'a> Engine<'a> { } } - /// Extract method to action mappings from enriched method calls + /// Generate IAM policies with explanations from enriched method calls /// - /// This method processes enriched method calls to create detailed mappings - /// between SDK method calls and their required IAM actions with associated resources. - /// It provides granular visibility into which SDK method calls require which - /// specific IAM actions and their associated resources. + /// This method generates policies and collects explanations for why each action + /// was added. Explanations are deduplicated to avoid redundant entries. 
/// /// # Arguments /// * `enriched_calls` - Slice of enriched method calls to process /// /// # Returns - /// A vector of method action mappings showing the relationship between - /// method calls and their required IAM actions + /// A PolicyGenerationResult containing policies and deduplicated explanations /// /// # Errors - /// Returns an error if resource processing fails for any action - pub fn extract_action_mappings( + /// Returns an error if policy generation fails + pub fn generate_policies( &self, enriched_calls: &[EnrichedSdkMethodCall], - ) -> Result> { - let mut mappings = Vec::new(); - - for enriched_call in enriched_calls { - let mut action_mappings = Vec::new(); - - for action in &enriched_call.actions { - // Process resources to get ARN patterns using the existing method - let resources = self.process_action_resources(action)?; + ) -> Result { + let policies = self.generate_individual_policies(enriched_calls)?; - let action_mapping = ActionMapping { - action_name: action.name.clone(), - resources, - }; + // Collect explanations + let explanations = extract_explanations(enriched_calls); - action_mappings.push(action_mapping); - } - - let method_mapping = MethodActionMapping { - method_call: enriched_call.method_name.clone(), - service: enriched_call.service.clone(), - actions: action_mappings, - }; + Ok(GeneratePoliciesResult { + policies, + explanations: Some(explanations), + }) + } +} - mappings.push(method_mapping); +fn extract_explanations(enriched_calls: &[EnrichedSdkMethodCall<'_>]) -> Explanations { + let mut explanations: BTreeMap = BTreeMap::new(); + + // Collect and merge explanations for each action name + for call in enriched_calls { + for action in &call.actions { + explanations + .entry(action.name.clone()) + .and_modify(|existing_explanation| { + existing_explanation.merge(action.explanation.clone()); + }) + .or_insert_with(|| action.explanation.clone()); } - - Ok(mappings) } + + Explanations::new(explanations) } #[cfg(test)] mod 
tests { use super::*; - use crate::SdkMethodCall; + use crate::{Explanation, SdkMethodCall}; use super::super::Effect; use crate::enrichment::{Action, EnrichedSdkMethodCall, Resource}; @@ -335,14 +336,15 @@ mod tests { ]), )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(policies.len(), 1); + let result = engine.generate_policies(&[enriched_call]).unwrap(); + assert_eq!(result.policies.len(), 1); - let policy = &policies[0].policy; + let policy = &result.policies[0].policy; assert_eq!(policy.version, "2012-10-17"); assert_eq!(policy.statements.len(), 1); @@ -371,6 +373,7 @@ mod tests { ]), )], vec![], + Explanation::default(), ), Action::new( "s3:GetObjectVersion".to_string(), @@ -381,15 +384,16 @@ mod tests { ]), )], vec![], + Explanation::default(), ), ], sdk_method_call: &sdk_call, }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(policies.len(), 1); + let result = engine.generate_policies(&[enriched_call]).unwrap(); + assert_eq!(result.policies.len(), 1); - let policy = &policies[0].policy; + let policy = &result.policies[0].policy; assert_eq!(policy.statements.len(), 2); // Check first statement @@ -415,14 +419,15 @@ mod tests { "s3:ListAllMyBuckets".to_string(), vec![Resource::new("*".to_string(), None)], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - assert_eq!(policies.len(), 1); + let result = engine.generate_policies(&[enriched_call]).unwrap(); + assert_eq!(result.policies.len(), 1); - let policy = &policies[0].policy; + let policy = &result.policies[0].policy; let statement = &policy.statements[0]; assert_eq!(statement.resource, vec!["*"]); } @@ -449,15 +454,16 @@ mod tests { ) ], vec![], + Explanation::default(), ) ], sdk_method_call: &sdk_call, }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - 
assert_eq!(policies.len(), 1); + let result = engine.generate_policies(&[enriched_call]).unwrap(); + assert_eq!(result.policies.len(), 1); - let policy = &policies[0].policy; + let policy = &result.policies[0].policy; let statement = &policy.statements[0]; assert_eq!( statement.resource, @@ -480,7 +486,12 @@ mod tests { sdk_method_call: &sdk_call, }; - let action = Action::new("s3:GetObject".to_string(), vec![], vec![]); + let action = Action::new( + "s3:GetObject".to_string(), + vec![], + vec![], + Explanation::default(), + ); // Test first action (index 0) let sid1 = engine.generate_statement_id(&enriched_call, &action, 0); @@ -494,8 +505,8 @@ mod tests { #[test] fn test_generate_policies_empty_input() { let engine = create_test_engine(); - let policies = engine.generate_policies(&[]).unwrap(); - assert!(policies.is_empty()); + let result = engine.generate_policies(&[]).unwrap(); + assert!(result.policies.is_empty()); } #[test] @@ -588,6 +599,7 @@ mod tests { ]), )], vec![], + Explanation::default(), ); let processed_resources = engine.process_action_resources(&action).unwrap(); @@ -621,6 +633,7 @@ mod tests { ), ], vec![], + Explanation::default(), ); let processed_resources = engine.process_action_resources(&action).unwrap(); @@ -652,6 +665,7 @@ mod tests { values: vec!["${region}".to_string(), "us-west-${unknown}".to_string()], }, ], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -682,6 +696,7 @@ mod tests { key: "s3:ExistingObjectTag/Environment".to_string(), values: vec!["production".to_string(), "staging".to_string()], }], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -698,7 +713,12 @@ mod tests { let engine = create_test_engine(); // Create an action with no conditions - let action = Action::new("s3:GetObject".to_string(), vec![], vec![]); + let action = Action::new( + "s3:GetObject".to_string(), + vec![], + vec![], + 
Explanation::default(), + ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -718,6 +738,7 @@ mod tests { key: "s3:ExistingObjectTag/Environment".to_string(), values: vec!["s3.${unknown}.amazonaws.com".to_string()], }], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -751,6 +772,7 @@ mod tests { values: vec!["s3.${unknown}.amazonaws.com".to_string()], // Unknown placeholder, introduces wildcards }, ], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -789,6 +811,7 @@ mod tests { "us-west-${unknown}".to_string(), // Unknown placeholder, introduces wildcards ], }], + Explanation::default(), ); let processed_conditions = engine.process_action_conditions(&action).unwrap(); @@ -815,6 +838,7 @@ mod tests { key: "s3:ExistingObjectTag/Environment".to_string(), values: vec!["arn:${partition}:s3:${region}:${account}:bucket/test".to_string()], }], + Explanation::default(), ); let processed_conditions = engine_wildcard_partition @@ -874,4 +898,277 @@ mod tests { crate::enrichment::Operator::StringEquals ); } + + #[test] + fn test_generate_policies_with_explanations() { + use crate::enrichment::{Explanation, Operation, OperationSource, Reason}; + use std::sync::Arc; + + let engine = create_test_engine(); + let sdk_call = create_test_sdk_call(); + + // Create enriched call with explanations + let enriched_call = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + Explanation { + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "s3".to_string(), + "get_object".to_string(), + OperationSource::Provided, + ))])], + }, + )], + sdk_method_call: &sdk_call, + }; + + let result 
= engine.generate_policies(&[enriched_call]).unwrap(); + + // Verify policies were generated + assert_eq!(result.policies.len(), 1); + + // Verify explanations were collected + if let Some(explanation) = result + .explanations + .as_ref() + .and_then(|explanations| explanations.explanation_for_action.get("s3:GetObject")) + { + assert_eq!(explanation.reasons.len(), 1); + assert_eq!(explanation.reasons[0].operations.len(), 1); + } else { + panic!("Must have an explanation for s3:GetObject"); + } + } + + #[test] + fn test_explanation_deduplication() { + use crate::enrichment::{Explanation, Operation, OperationSource, Reason}; + use std::sync::Arc; + + let engine = create_test_engine(); + let sdk_call1 = create_test_sdk_call(); + let sdk_call2 = SdkMethodCall { + name: "get_object".to_string(), + possible_services: vec!["s3".to_string()], + metadata: None, + }; + + // Create two enriched calls with duplicate explanations + let enriched_call1 = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + Explanation { + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "s3".to_string(), + "get_object".to_string(), + OperationSource::Provided, + ))])], + }, + )], + sdk_method_call: &sdk_call1, + }; + + let enriched_call2 = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + Explanation { + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "s3".to_string(), + "get_object".to_string(), + OperationSource::Provided, + ))])], + }, + )], + sdk_method_call: &sdk_call2, + }; + + 
let result = engine + .generate_policies(&[enriched_call1, enriched_call2]) + .unwrap(); + + // Verify explanations were grouped by action with deduplicated reasons + if let Some(explanation) = result + .explanations + .as_ref() + .and_then(|explanations| explanations.explanation_for_action.get("s3:GetObject")) + { + assert_eq!( + explanation.reasons.len(), + 1, + "Duplicate reasons should be deduplicated" + ); + } else { + panic!("Must have an explanation for s3:GetObject"); + } + } + + #[test] + fn test_explanation_with_fas_expansion() { + use crate::enrichment::{ + operation_fas_map::FasContext, Explanation, Operation, OperationSource, Reason, + }; + use std::sync::Arc; + + let engine = create_test_engine(); + let sdk_call = create_test_sdk_call(); + + // Create enriched call with FAS expansion + let enriched_call = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![ + Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + Explanation { + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "s3".to_string(), + "get_object".to_string(), + OperationSource::Provided, + ))])], + }, + ), + Action::new( + "kms:Decrypt".to_string(), + vec![Resource::new( + "key".to_string(), + Some(vec![ + "arn:${Partition}:kms:${Region}:${Account}:key/${KeyId}".to_string(), + ]), + )], + vec![], + Explanation { + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "kms".to_string(), + "Decrypt".to_string(), + OperationSource::Fas(vec![FasContext::new( + "kms:ViaService".to_string(), + vec!["s3.us-east-1.amazonaws.com".to_string()], + )]), + ))])], + }, + ), + ], + sdk_method_call: &sdk_call, + }; + + let result = engine.generate_policies(&[enriched_call]).unwrap(); + + // Verify explanations include FAS expansion + assert_eq!( + result + .explanations + .as_ref() + .unwrap() + 
.explanation_for_action + .len(), + 2 + ); + + // Check the FAS-expanded action + let kms_explanation = result + .explanations + .as_ref() + .unwrap() + .explanation_for_action + .get("kms:Decrypt") + .expect("Should have kms:Decrypt explanation"); + assert_eq!(kms_explanation.reasons.len(), 1); + assert_eq!(kms_explanation.reasons[0].operations.len(), 1); + + // Check that the operation has FAS context + let operation = &kms_explanation.reasons[0].operations[0]; + assert_eq!(operation.name, "Decrypt"); + assert_eq!(operation.service, "kms"); + match &operation.source { + OperationSource::Fas(context) => { + assert_eq!(context.len(), 1); + assert_eq!(context[0].key, "kms:ViaService"); + assert_eq!(context[0].values, vec!["s3.us-east-1.amazonaws.com"]); + } + _ => panic!("Expected FAS operation source"), + } + } + + #[test] + fn test_explanation_with_possible_false_positive() { + use crate::enrichment::{Explanation, Operation, OperationSource, Reason}; + use std::sync::Arc; + + let engine = create_test_engine(); + let sdk_call = SdkMethodCall { + name: "get_object".to_string(), + possible_services: vec!["s3".to_string(), "s3-object-lambda".to_string()], + metadata: None, + }; + + // Create enriched call with multiple possible services (false positive flag) + let enriched_call = EnrichedSdkMethodCall { + method_name: "get_object".to_string(), + service: "s3".to_string(), + actions: vec![Action::new( + "s3:GetObject".to_string(), + vec![Resource::new( + "object".to_string(), + Some(vec![ + "arn:${Partition}:s3:::${BucketName}/${ObjectName}".to_string() + ]), + )], + vec![], + Explanation { + reasons: vec![Reason::new(vec![Arc::new(Operation::new( + "s3".to_string(), + "get_object".to_string(), + OperationSource::Provided, + ))])], + }, + )], + sdk_method_call: &sdk_call, + }; + + let result = engine.generate_policies(&[enriched_call]).unwrap(); + + assert_eq!( + result + .explanations + .as_ref() + .unwrap() + .explanation_for_action + .len(), + 1 + ); + } } diff 
--git a/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs index db4214c..9c8ea2b 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/integration_tests.rs @@ -8,7 +8,7 @@ mod tests { use super::super::{Effect, Engine}; use crate::enrichment::{Action, EnrichedSdkMethodCall, Resource}; use crate::errors::ExtractorError; - use crate::SdkMethodCall; + use crate::{Explanation, SdkMethodCall}; fn create_test_sdk_call() -> SdkMethodCall { SdkMethodCall { @@ -38,6 +38,7 @@ mod tests { ]), )], vec![], + Explanation::default(), ), Action::new( "s3:GetObjectVersion".to_string(), @@ -48,17 +49,18 @@ mod tests { ]), )], vec![], + Explanation::default(), ), ], sdk_method_call: &sdk_call, }; // Generate policies - let policies = engine.generate_policies(&[enriched_call]).unwrap(); + let result = engine.generate_policies(&[enriched_call]).unwrap(); // Verify results - assert_eq!(policies.len(), 1); - let policy = &policies[0].policy; + assert_eq!(result.policies.len(), 1); + let policy = &result.policies[0].policy; // Check policy structure assert_eq!(policy.version, "2012-10-17"); @@ -108,6 +110,7 @@ mod tests { ]), )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call1, }, @@ -123,24 +126,25 @@ mod tests { ]), )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call2, }, ]; - let policies = engine.generate_policies(&enriched_calls).unwrap(); + let result = engine.generate_policies(&enriched_calls).unwrap(); // Should generate one policy per enriched call - assert_eq!(policies.len(), 2); + assert_eq!(result.policies.len(), 2); // Check first policy - let policy1 = &policies[0].policy; + let policy1 = &result.policies[0].policy; assert_eq!(policy1.statements.len(), 1); assert_eq!(policy1.statements[0].action, vec!["s3:GetObject"]); 
assert_eq!(policy1.statements[0].resource, vec!["arn:aws:s3:::*/*"]); // Check second policy - let policy2 = &policies[1].policy; + let policy2 = &result.policies[1].policy; assert_eq!(policy2.statements.len(), 1); assert_eq!(policy2.statements[0].action, vec!["s3:PutObject"]); assert_eq!(policy2.statements[0].resource, vec!["arn:aws:s3:::*/*"]); @@ -172,14 +176,15 @@ mod tests { ]) ) ], - vec![] + vec![], + Explanation::default(), ) ], sdk_method_call: &sdk_call, }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - let policy = &policies[0].policy; + let result = engine.generate_policies(&[enriched_call]).unwrap(); + let policy = &result.policies[0].policy; let statement = &policy.statements[0]; // Verify ARN patterns are correctly processed for China partition @@ -209,12 +214,13 @@ mod tests { ]), )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - let policy = &policies[0]; + let result = engine.generate_policies(&[enriched_call]).unwrap(); + let policy = &result.policies[0]; // Test JSON serialization let json = serde_json::to_string_pretty(policy).unwrap(); @@ -244,6 +250,7 @@ mod tests { ]), // Invalid empty placeholder )], vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, }; @@ -271,12 +278,13 @@ mod tests { "s3:ListAllMyBuckets".to_string(), vec![Resource::new("*".to_string(), None)], // No ARN patterns vec![], + Explanation::default(), )], sdk_method_call: &sdk_call, }; - let policies = engine.generate_policies(&[enriched_call]).unwrap(); - let policy = &policies[0].policy; + let result = engine.generate_policies(&[enriched_call]).unwrap(); + let policy = &result.policies[0].policy; let statement = &policy.statements[0]; // Should fallback to wildcard resource diff --git a/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs b/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs index 
650b4a0..e8ae478 100644 --- a/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs +++ b/iam-policy-autopilot-policy-generation/src/policy_generation/mod.rs @@ -42,29 +42,6 @@ where condition_map.serialize(serializer) } -/// Represents a single IAM action with its associated resources -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "PascalCase")] -pub(crate) struct ActionMapping { - /// The IAM action name (e.g., "s3:GetObject") - pub(crate) action_name: String, - /// Resources this action applies to - pub(crate) resources: Vec, -} - -/// Represents the mapping between a method call and its required IAM actions -#[derive(Debug, Clone, Serialize, Deserialize)] -#[serde(rename_all = "PascalCase")] -#[non_exhaustive] -pub struct MethodActionMapping { - /// The method call name - pub(crate) method_call: String, - /// The AWS service this method belongs to - pub(crate) service: String, - /// List of IAM actions required by this method - pub(crate) actions: Vec, -} - /// Represents a complete IAM policy document #[derive(Debug, Clone, Serialize, PartialEq, Eq)] #[non_exhaustive] @@ -283,43 +260,6 @@ mod tests { assert!(json.contains("\"Condition\":{\"StringLike\":{\"s3:ExistingObjectTag/Environment\":[\"production-*\",\"staging-*\"]}}")); } - #[test] - fn test_method_action_mapping_serialization() { - let action_mapping = ActionMapping { - action_name: "s3:GetObject".to_string(), - resources: vec!["arn:aws:s3:::*/*".to_string()], - }; - - let method_mapping = MethodActionMapping { - method_call: "get_object".to_string(), - service: "s3".to_string(), - actions: vec![action_mapping], - }; - - let json = serde_json::to_string(&method_mapping).unwrap(); - - // Verify PascalCase field names - assert!(json.contains("\"MethodCall\":\"get_object\"")); - assert!(json.contains("\"Service\":\"s3\"")); - assert!(json.contains("\"Actions\":")); - assert!(json.contains("\"ActionName\":\"s3:GetObject\"")); - 
assert!(json.contains("\"Resources\":")); - } - - #[test] - fn test_action_mapping_serialization() { - let action_mapping = ActionMapping { - action_name: "events:PutRule".to_string(), - resources: vec!["arn:aws:events:us-east-1:123456789012:rule/*".to_string()], - }; - - let json = serde_json::to_string(&action_mapping).unwrap(); - - // Verify PascalCase field names - assert!(json.contains("\"ActionName\":\"events:PutRule\"")); - assert!(json.contains("\"Resources\":[\"arn:aws:events:us-east-1:123456789012:rule/*\"]")); - } - #[test] fn test_policy_with_metadata_serialization() { let mut policy = IamPolicy::new(); diff --git a/iam-policy-autopilot-policy-generation/src/service_configuration.rs b/iam-policy-autopilot-policy-generation/src/service_configuration.rs index d603244..2cc6337 100644 --- a/iam-policy-autopilot-policy-generation/src/service_configuration.rs +++ b/iam-policy-autopilot-policy-generation/src/service_configuration.rs @@ -14,6 +14,8 @@ use std::{ /// Operation rename configuration #[derive(Clone, Debug, Deserialize)] +// TODO: remove +#[allow(dead_code)] pub(crate) struct OperationRename { /// Target service name pub(crate) service: String, @@ -31,6 +33,8 @@ pub(crate) struct ServiceConfiguration { pub(crate) rename_services_service_reference: HashMap, /// Smithy to Botocore model: service renames pub(crate) smithy_botocore_service_name_mapping: HashMap, + // TODO: remove + #[allow(dead_code)] /// Operation renames pub(crate) rename_operations: HashMap, /// Resource overrides @@ -55,6 +59,8 @@ impl ServiceConfiguration { } } + // TODO: remove + #[allow(dead_code)] pub(crate) fn rename_operation<'a>(&self, service: &str, original: &'a str) -> Cow<'a, str> { let tmp = format!("{}:{}", service, original); match self.rename_operations.get(&tmp) { diff --git a/iam-policy-autopilot-policy-generation/tests/go_extraction_integration_test.rs b/iam-policy-autopilot-policy-generation/tests/go_extraction_integration_test.rs index bdd2ab4..3fb94d1 100644 --- 
a/iam-policy-autopilot-policy-generation/tests/go_extraction_integration_test.rs +++ b/iam-policy-autopilot-policy-generation/tests/go_extraction_integration_test.rs @@ -109,59 +109,15 @@ async fn test_go_extraction_to_policy_generation_integration() { PolicyGenerationEngine::new("aws", "us-east-1", "123456789012"); match policy_engine.generate_policies(&enriched_calls) { - Ok(policies) => { + Ok(result) => { + let policies = result.policies; println!("Generated {} IAM policies:", policies.len()); // Verify policy generation worked assert!(!policies.is_empty(), "Should generate at least one policy"); - // Step 5: Test policy merging - println!("\nStep 5: Testing policy merging..."); - - if policies.len() > 1 { - match policy_engine.merge_policies(&policies) { - Ok(merged_policy) => { - println!( - "Successfully merged {} policies into one", - policies.len() - ); - - // Test JSON serialization of merged policy - match serde_json::to_string_pretty(&merged_policy) { - Ok(json) => { - println!( - "Merged policy JSON ({} bytes)", - json.len() - ); - - // Verify JSON structure - assert!( - json.contains("\"Version\""), - "JSON should contain Version field" - ); - assert!( - json.contains("\"Statement\""), - "JSON should contain Statement field" - ); - } - Err(e) => { - panic!( - "Failed to serialize merged policy to JSON: {}", - e - ); - } - } - } - Err(e) => { - panic!("Policy merging failed: {}", e); - } - } - } else { - println!("Only one policy generated, skipping merge test"); - } - - // Step 6: Test JSON serialization of individual policies - println!("\nStep 6: Testing JSON serialization..."); + // Step 5: Test JSON serialization of individual policies + println!("\nStep 5: Testing JSON serialization..."); for (i, policy) in policies.iter().enumerate() { match serde_json::to_string_pretty(policy) { diff --git a/iam-policy-autopilot-policy-generation/tests/public_api_integration_test.rs b/iam-policy-autopilot-policy-generation/tests/public_api_integration_test.rs 
index 33195f6..c791ee8 100644 --- a/iam-policy-autopilot-policy-generation/tests/public_api_integration_test.rs +++ b/iam-policy-autopilot-policy-generation/tests/public_api_integration_test.rs @@ -102,10 +102,12 @@ def download_object(): // 4. Test Policy Generation Engine (Public API) let policy_engine = PolicyGenerationEngine::new("aws", "us-east-1", "123456789012"); - let policies = policy_engine + let result = policy_engine .generate_policies(&enriched_methods) .expect("Policy generation should succeed"); + let policies = result.policies; + // Verify policy generation results using serialization assert!(!policies.is_empty(), "Should generate policies"); println!("Generated {} policies", policies.len()); @@ -117,29 +119,6 @@ def download_object(): assert!(policy_json.contains("Statement")); println!(" Policy {}: serialized successfully", i + 1); } - - // 5. Test policy merging - let merged_policy = policy_engine - .merge_policies(&policies) - .expect("Policy merging should succeed"); - - let merged_json = - serde_json::to_string(&merged_policy).expect("Should serialize merged policy"); - assert!(merged_json.contains("2012-10-17")); - println!("Merged policy serialized successfully"); - - // 6. 
Test method action mapping extraction - let action_mappings = policy_engine - .extract_action_mappings(&enriched_methods) - .expect("Action mapping extraction should succeed"); - - assert!(!action_mappings.is_empty(), "Should have action mappings"); - println!("Generated {} action mappings", action_mappings.len()); - - // Test serialization of action mappings - let mappings_json = - serde_json::to_string(&action_mappings).expect("Should serialize action mappings"); - assert!(!mappings_json.is_empty()); } #[tokio::test] @@ -289,10 +268,12 @@ def multi_service_operations(): let policy_engine = PolicyGenerationEngine::new("aws", "us-west-2", "987654321098"); - let policies = policy_engine + let result = policy_engine .generate_policies(&enriched) .expect("Should generate policies"); + let policies = result.policies; + // Verify we got policies for multi-service operations println!( "Generated {} policies for multi-service operations", @@ -445,7 +426,8 @@ def start_policy_generation(): let policy_engine = PolicyGenerationEngine::new("aws", "us-east-1", "123456789012"); let policies = policy_engine .generate_policies(&enriched) - .expect("Policy generation should succeed"); + .expect("Policy generation should succeed") + .policies; assert!(!policies.is_empty(), "Should generate policies");