diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs index 91a5a38..ad93529 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/disambiguation.rs @@ -387,7 +387,6 @@ mod tests { }, operations: sqs_operations, shapes: sqs_shapes, - waiters: HashMap::new(), }, ); @@ -478,7 +477,6 @@ mod tests { }, operations: s3_operations, shapes: s3_shapes, - waiters: HashMap::new(), }; services.insert("s3".to_string(), s3_service_def); @@ -538,7 +536,6 @@ mod tests { }, operations: s3control_operations, shapes: s3control_shapes, - waiters: HashMap::new(), }; services.insert("s3control".to_string(), s3control_service_def); @@ -577,7 +574,7 @@ mod tests { ServiceModelIndex { services, method_lookup, - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), } } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs index 2a85e69..70433fa 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/extractor.rs @@ -1051,7 +1051,6 @@ func main() { }, operations: s3_operations, shapes: s3_shapes, - waiters: HashMap::new(), }, ); @@ -1112,7 +1111,6 @@ func main() { }, operations: s3control_operations, shapes: s3control_shapes, - waiters: HashMap::new(), }, ); @@ -1141,7 +1139,7 @@ func main() { let service_index = ServiceModelIndex { services, method_lookup, - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), }; // Test Go code with S3 import but GetObject call that exists in both s3 and s3control diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs index 04712ca..8380d79 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/paginator_extractor.rs @@ -362,7 +362,6 @@ mod tests { }, operations: s3_operations, shapes: s3_shapes, - waiters: HashMap::new(), }, ); @@ -386,7 +385,6 @@ mod tests { }, operations: ec2_operations, shapes: HashMap::new(), - waiters: HashMap::new(), }, ); @@ -409,7 +407,6 @@ mod tests { }, operations: dynamodb_operations, shapes: HashMap::new(), - waiters: HashMap::new(), }, ); @@ -433,7 +430,6 @@ mod tests { }, operations: gamelift_operations, shapes: HashMap::new(), - waiters: HashMap::new(), }, ); @@ -462,7 +458,6 @@ mod tests { }, operations, shapes: HashMap::new(), - waiters: HashMap::new(), }, ); } @@ -501,7 +496,7 @@ mod tests { ServiceModelIndex { services, method_lookup, - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), } } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs index ceb7b7c..3a961bc 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/go/waiter_extractor.rs @@ -230,98 +230,71 @@ impl<'a> GoWaiterExtractor<'a> { best_match.map(|w| (w, best_idx)) } - /// Create synthetic SdkMethodCall objects from a matched waiter + wait - fn create_synthetic_call( + fn create_synthetic_call_internal( &self, - wait_call: &WaitCallInfo, + wait_call: Option<&WaitCallInfo>, waiter_info: &WaiterInfo, ) -> Vec { - let waiter_name = &waiter_info.waiter_type; - - // Look up all services that provide this waiter - let candidate_services = self - .service_index - .waiter_to_services - .get(waiter_name) - .cloned() - .unwrap_or_default(); - - if candidate_services.is_empty() { - return Vec::new(); - } - let mut synthetic_calls = Vec::new(); - // Create one call per service - for service_name in candidate_services { - if let Some(service_def) = self.service_index.services.get(&service_name) { - if let Some(operation) = service_def.waiters.get(waiter_name) { - // Filter out waiter-specific parameters - let filtered_params = - self.filter_waiter_parameters(wait_call.arguments.clone()); - - synthetic_calls.push(SdkMethodCall { - name: operation.name.clone(), - possible_services: vec![service_name.clone()], // Single service per call - metadata: Some(SdkMethodCallMetadata { - parameters: filtered_params, - return_type: None, - start_position: wait_call.start_position, - end_position: wait_call.end_position, - receiver: Some(waiter_info.client_receiver.clone()), - }), - }); - } + // waiter_type already contains the clean waiter name (e.g., "InstanceTerminated") + if let Some(service_defs) = self + .service_index + .waiter_lookup + .get(&waiter_info.waiter_type) + { + // Create one call per service + for service_def in service_defs { + let service_name = &service_def.service_name; + let operation_name = &service_def.operation_name; + + let (parameters, start_position, end_position) = match wait_call { + Some(wait_call) => ( + self.filter_waiter_parameters(wait_call.arguments.clone()), + wait_call.start_position, + wait_call.end_position, + ), + None => { + // Fallback: + ( + // Get required parameters for this operation + self.get_required_parameters(service_name, operation_name), + (waiter_info.creation_line, 1), + (waiter_info.creation_line, 1), + ) + } + }; + + synthetic_calls.push(SdkMethodCall { + name: operation_name.clone(), + possible_services: vec![service_name.clone()], // Single service per call + metadata: Some(SdkMethodCallMetadata { + parameters, + return_type: None, + start_position, + end_position, + receiver: Some(waiter_info.client_receiver.clone()), + }), + }); } } synthetic_calls } + /// Create synthetic SdkMethodCall objects from a matched waiter + wait + fn create_synthetic_call( + &self, + wait_call: &WaitCallInfo, + waiter_info: &WaiterInfo, + ) -> Vec { + self.create_synthetic_call_internal(Some(wait_call), waiter_info) + } + /// Create fallback synthetic calls for unmatched waiter creation /// Returns one call per service that has the waiter, matching Python behavior fn create_fallback_synthetic_call(&self, waiter_info: &WaiterInfo) -> Vec { - // waiter_type already contains the clean waiter name (e.g., "InstanceTerminated") - let waiter_name = &waiter_info.waiter_type; - - // Look up all services that provide this waiter - let candidate_services = self - .service_index - .waiter_to_services - .get(waiter_name) - .cloned() - .unwrap_or_default(); - - if candidate_services.is_empty() { - return Vec::new(); - } - - let mut synthetic_calls = Vec::new(); - - // Create one call per service - for service_name in candidate_services { - if let Some(service_def) = self.service_index.services.get(&service_name) { - if let Some(operation) = service_def.waiters.get(waiter_name) { - // Get required parameters for this operation - let required_params = - self.get_required_parameters(&service_name, &operation.name); - - synthetic_calls.push(SdkMethodCall { - name: operation.name.clone(), - possible_services: vec![service_name.clone()], // Single service per call - metadata: Some(SdkMethodCallMetadata { - parameters: required_params, - return_type: None, - start_position: (waiter_info.creation_line, 1), - end_position: (waiter_info.creation_line, 1), - receiver: Some(waiter_info.client_receiver.clone()), - }), - }); - } - } - } - - synthetic_calls + self.create_synthetic_call_internal(None, waiter_info) } /// Get required parameters for an operation from the service index @@ -362,6 +335,8 @@ impl<'a> GoWaiterExtractor<'a> { #[cfg(test)] mod tests { + use crate::extraction::sdk_model::ServiceMethodRef; + use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Go; @@ -379,7 +354,7 @@ mod tests { }; let mut services = HashMap::new(); - let mut waiter_to_services = HashMap::new(); + let mut waiter_lookup = HashMap::new(); // Create EC2 service with DescribeInstances operation let mut ec2_operations = HashMap::new(); @@ -431,18 +406,29 @@ mod tests { }, operations: ec2_operations, shapes: ec2_shapes, - waiters: ec2_waiters, }, ); // Use PascalCase for waiter_to_services index - waiter_to_services.insert("InstanceTerminated".to_string(), vec!["ec2".to_string()]); - waiter_to_services.insert("InstanceRunning".to_string(), vec!["ec2".to_string()]); + waiter_lookup.insert( + "InstanceTerminated".to_string(), + vec![ServiceMethodRef { + service_name: "ec2".to_string(), + operation_name: "DescribeInstances".to_string(), + }], + ); + waiter_lookup.insert( + "InstanceRunning".to_string(), + vec![ServiceMethodRef { + service_name: "ec2".to_string(), + operation_name: "DescribeInstances".to_string(), + }], + ); ServiceModelIndex { services, method_lookup: HashMap::new(), - waiter_to_services, + waiter_lookup, } } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs index 6e265d6..10bd3a3 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/javascript/extractor.rs @@ -79,15 +79,23 @@ impl Extractor for JavaScriptExtractor { // First: Resolve waiter names to actual operations // For each call, check if it's a waiter name and replace with the actual operation for call in method_calls.iter_mut() { - // Try to find this name in the waiter mappings for each possible service - for service_name in &call.possible_services.clone() { - if let Some(service_def) = service_index.services.get(service_name) { - // Check if this is a waiter name in the service's waiters map - if let Some(operation) = service_def.waiters.get(&call.name) { - // Replace waiter name with actual operation name - call.name = operation.name.clone(); - break; // Found the waiter, no need to check other services - } + if let Some(service_methods) = service_index.waiter_lookup.get(&call.name) { + let matching_method = service_methods + .iter() + .find(|sm| call.possible_services.contains(&sm.service_name)); + + if let Some(method) = matching_method { + call.name = method.operation_name.clone(); + } else { + log::warn!( + "Waiter '{}' found in services {:?} but imported from {:?}", + call.name, + service_methods + .iter() + .map(|sm| &sm.service_name) + .collect::>(), + call.possible_services + ); } } } @@ -118,6 +126,7 @@ impl Extractor for JavaScriptExtractor { true } else { // Method name doesn't exist in SDK - filter it out + log::warn!("Filtering out {}", call.name); false } }); @@ -438,4 +447,72 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); } } } + + #[tokio::test] + async fn test_waiter_disambiguation_with_multiple_services() { + use crate::extraction::sdk_model::ServiceDiscovery; + use crate::Language; + + let extractor = JavaScriptExtractor::new(); + + // Code that imports Neptune client and uses DBInstanceAvailable waiter + let code = r#" + import { NeptuneClient, waitUntilDBInstanceAvailable } from '@aws-sdk/client-neptune'; + + const client = new NeptuneClient({ region: 'us-east-1' }); + + async function waitForInstance() { + await waitUntilDBInstanceAvailable( + { client, maxWaitTime: 300 }, + { DBInstanceIdentifier: 'my-neptune-instance' } + ); + } + "#; + + // Parse the code + let mut results = vec![extractor.parse(code).await]; + + // Load service index + let service_index = ServiceDiscovery::load_service_index(Language::JavaScript) + .await + .expect("Failed to load service index"); + + // Apply filter_map which includes waiter resolution + extractor.filter_map(&mut results, &service_index); + + // Verify the results + match &results[0] { + ExtractorResult::JavaScript(_ast, method_calls) => { + // Find the DBInstanceAvailable call + let db_instance_call = method_calls + .iter() + .find(|call| call.name == "DescribeDBInstances") + .expect("Should find DescribeDBInstances operation after waiter resolution"); + + // CRITICAL: Should be associated with Neptune, not RDS or DocumentDB + assert!( + db_instance_call + .possible_services + .contains(&"neptune".to_string()), + "DBInstanceAvailable waiter should resolve to Neptune service, got: {:?}", + db_instance_call.possible_services + ); + + // Should NOT contain RDS or DocumentDB + assert!( + !db_instance_call + .possible_services + .contains(&"rds".to_string()), + "Should not incorrectly resolve to RDS" + ); + assert!( + !db_instance_call + .possible_services + .contains(&"docdb".to_string()), + "Should not incorrectly resolve to DocumentDB" + ); + } + _ => panic!("Should return JavaScript result"), + } + } } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs index 6214fdf..9019cbe 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation.rs @@ -268,7 +268,6 @@ mod tests { }, operations, shapes, - waiters: HashMap::new(), }; services.insert("apigatewayv2".to_string(), service_def); @@ -285,7 +284,7 @@ mod tests { ServiceModelIndex { services, method_lookup, - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), } } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs index 1858dbd..e8ca62d 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/disambiguation_tests.rs @@ -86,7 +86,6 @@ mod tests { }, operations: apigateway_operations, shapes: apigateway_shapes, - waiters: HashMap::new(), }; services.insert("apigatewayv2".to_string(), apigateway_service); @@ -144,7 +143,6 @@ mod tests { }, operations: s3_operations, shapes: s3_shapes, - waiters: HashMap::new(), }; services.insert("s3".to_string(), s3_service); @@ -169,7 +167,7 @@ mod tests { ServiceModelIndex { services, method_lookup, - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), } } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs index 4ea02f3..d413b15 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/extractor.rs @@ -108,8 +108,7 @@ impl Extractor for PythonExtractor { // Add waiters to extracted methods using the service model index directly let waiters_extractor = WaitersExtractor::new(service_index); - let waiter_calls = - waiters_extractor.extract_waiter_method_calls(ast, service_index); + let waiter_calls = waiters_extractor.extract_waiter_method_calls(ast); method_calls.extend(waiter_calls); // Add paginators to extracted methods using the service model index directly diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs index b3b7572..28b4b80 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/paginator_extractor.rs @@ -522,7 +522,6 @@ mod tests { }, operations: s3_operations, shapes: s3_shapes, - waiters: HashMap::new(), }, ); @@ -538,7 +537,7 @@ mod tests { ServiceModelIndex { services, method_lookup, - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), } } diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs index 7d32128..161d576 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/resource_direct_calls_extractor.rs @@ -702,22 +702,25 @@ impl<'a> ResourceDirectCallsExtractor<'a> { let resolved_operation = match &action_mapping.operation { OperationType::Waiter { waiter_name } => { // Resolve actual operation via ServiceModelIndex - if let Some(service_def) = - self.service_index.services.get(&constructor.service_name) - { - if let Some(operation) = service_def.waiters.get(waiter_name) { - operation.name.to_case(Case::Snake) - } else { - log::debug!( - "Waiter '{}' not found in service '{}' waiters", - waiter_name, - constructor.service_name - ); - return None; + if let Some(service_methods) = self.service_index.waiter_lookup.get(waiter_name) { + let service_methods_filtered = service_methods + .iter() + .filter(|x| x.service_name == constructor.service_name) + .collect::>(); + match service_methods_filtered.first() { + None => { + log::debug!( + "Service '{}' not found in ServiceModelIndex", + constructor.service_name + ); + return None; + } + Some(service_method) => service_method.operation_name.to_case(Case::Snake), } } else { log::debug!( - "Service '{}' not found in ServiceModelIndex", + "Waiter '{}' not found in service '{}' waiters", + waiter_name, constructor.service_name ); return None; diff --git a/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs index 304b5fb..3bd8fb8 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/python/waiters_extractor.rs @@ -8,7 +8,6 @@ use crate::extraction::python::common::{ArgumentExtractor, ParameterFilter}; use crate::extraction::{Parameter, ParameterValue, SdkMethodCall, SdkMethodCallMetadata}; use crate::ServiceModelIndex; use ast_grep_language::Python; -use convert_case::{Case, Casing}; /// Information about a discovered get_waiter call #[derive(Debug, Clone)] @@ -23,6 +22,23 @@ pub(crate) struct WaiterInfo { pub get_waiter_line: usize, } +// TODO: This should be refactored at a higher level, so this type can be removed. +// See https://github.com/awslabs/iam-policy-autopilot/issues/88. +enum CallInfo<'a> { + None(&'a WaiterInfo), + Simple(&'a WaiterInfo, &'a WaitCallInfo), + Chained(&'a ChainedWaiterCallInfo), +} + +impl<'a> CallInfo<'a> { + fn waiter_name(&self) -> &'a str { + match self { + Self::None(waiter_info) | Self::Simple(waiter_info, ..) => &waiter_info.waiter_name, + Self::Chained(waiter_call_info) => &waiter_call_info.waiter_name, + } + } +} + /// Information about a wait method call #[derive(Debug, Clone)] pub(crate) struct WaitCallInfo { @@ -92,7 +108,6 @@ impl<'a> WaitersExtractor<'a> { pub(crate) fn extract_waiter_method_calls( &self, ast: &ast_grep_core::AstGrep>, - service_index: &ServiceModelIndex, ) -> Vec { // Step 1: Find all get_waiter calls let waiters = self.find_get_waiter_calls(ast); @@ -126,7 +141,7 @@ impl<'a> WaitersExtractor<'a> { for (idx, waiter) in waiters.iter().enumerate() { if !matched_waiter_indices.contains(&idx) { // Create synthetic calls with required params for all candidate services - let unmatched_calls = self.create_unmatched_synthetic_calls(waiter, service_index); + let unmatched_calls = self.create_unmatched_synthetic_calls(waiter); synthetic_calls.extend(unmatched_calls); } } @@ -317,114 +332,88 @@ impl<'a> WaitersExtractor<'a> { /// Create synthetic SdkMethodCalls for a matched waiter + wait /// Creates one call per candidate service with the actual operation name - fn create_matched_synthetic_calls( + fn create_synthetic_calls_internal( &self, - wait_call: &WaitCallInfo, - waiter_info: &WaiterInfo, + wait_call: CallInfo, + receiver: Option, ) -> Vec { - // Convert Python snake_case waiter name to PascalCase for lookup - let waiter_name_pascal = waiter_info.waiter_name.to_case(Case::Pascal); - - // Find all services that provide this waiter using reverse index - let candidate_services = self - .service_index - .waiter_to_services - .get(&waiter_name_pascal) - .cloned() - .unwrap_or_default(); - - if candidate_services.is_empty() { - return Vec::new(); - } - let mut synthetic_calls = Vec::new(); - for service_name in candidate_services { - // Get the operation for this service+waiter combination from service definition - if let Some(service_def) = self.service_index.services.get(&service_name) { - if let Some(operation) = service_def.waiters.get(&waiter_name_pascal) { - // Convert operation name to Python snake_case for method name - let operation_snake = operation.name.to_case(Case::Snake); - - // Filter out WaiterConfig from arguments - it's waiter-specific, not operation-specific - let filtered_params = - ParameterFilter::filter_waiter_parameters(wait_call.arguments.clone()); - - // Create synthetic call with filtered wait() arguments - synthetic_calls.push(SdkMethodCall { - name: operation_snake, - possible_services: vec![service_name.clone()], - metadata: Some(SdkMethodCallMetadata { - parameters: filtered_params, - return_type: None, + // Get the operation for this service+waiter combination from service definition + if let Some(service_defs) = self + .service_index + .waiter_lookup + .get(wait_call.waiter_name()) + { + for service_method in service_defs { + let (parameters, start_position, end_position) = match wait_call { + CallInfo::Simple(_, wait_call) => { + // Filter out WaiterConfig from arguments - it's waiter-specific, not operation-specific + ( + ParameterFilter::filter_waiter_parameters(wait_call.arguments.clone()), + wait_call.start_position, + wait_call.end_position, + ) + } + CallInfo::Chained(chained_wait_call) => { + // Filter out WaiterConfig from arguments - it's waiter-specific, not operation-specific + ( + ParameterFilter::filter_waiter_parameters( + chained_wait_call.arguments.clone(), + ), // Use wait call position (most specific) - start_position: wait_call.start_position, - end_position: wait_call.end_position, - // Use client receiver from get_waiter call - receiver: Some(waiter_info.client_receiver.clone()), - }), - }); - } + chained_wait_call.start_position, + chained_wait_call.end_position, + ) + } + CallInfo::None(waiter_info) => { + let fallback_start_pos = (waiter_info.get_waiter_line, 1); + let fallback_end_pos = (waiter_info.get_waiter_line, 1); + let parameters = self.get_required_parameters( + &service_method.service_name, + &service_method.operation_name, + self.service_index, + ); + (parameters, fallback_start_pos, fallback_end_pos) + } + }; + + // Create synthetic call with filtered wait() arguments + synthetic_calls.push(SdkMethodCall { + name: wait_call.waiter_name().to_string(), + possible_services: vec![service_method.service_name.clone()], + metadata: Some(SdkMethodCallMetadata { + parameters, + return_type: None, + start_position, + end_position, + // Use client receiver from get_waiter call + receiver: receiver.clone(), + }), + }); } } synthetic_calls } - /// Create synthetic SdkMethodCalls for an unmatched get_waiter - fn create_unmatched_synthetic_calls( + fn create_matched_synthetic_calls( &self, + wait_call: &WaitCallInfo, waiter_info: &WaiterInfo, - service_index: &ServiceModelIndex, ) -> Vec { - // Convert Python snake_case waiter name to PascalCase for lookup - let waiter_name_pascal = waiter_info.waiter_name.to_case(Case::Pascal); - - // Find all services that provide this waiter using reverse index - let candidate_services = self - .service_index - .waiter_to_services - .get(&waiter_name_pascal) - .cloned() - .unwrap_or_default(); - - if candidate_services.is_empty() { - return Vec::new(); - } - - let mut synthetic_calls = Vec::new(); - - for service_name in candidate_services { - // Get the operation for this service+waiter combination from service definition - if let Some(service_def) = self.service_index.services.get(&service_name) { - if let Some(operation) = service_def.waiters.get(&waiter_name_pascal) { - // Convert operation name to Python snake_case for method name - let operation_snake = operation.name.to_case(Case::Snake); - - // Get required parameters for this operation - let required_params = self.get_required_parameters( - &service_name, - &operation_snake, - service_index, - ); - - // Create synthetic call with required parameters set to None - synthetic_calls.push(SdkMethodCall { - name: operation_snake, - possible_services: vec![service_name.clone()], - metadata: Some(SdkMethodCallMetadata { - parameters: required_params, - return_type: None, - start_position: (waiter_info.get_waiter_line, 1), - end_position: (waiter_info.get_waiter_line, 1), - receiver: Some(waiter_info.client_receiver.clone()), - }), - }); - } - } - } + self.create_synthetic_calls_internal( + CallInfo::Simple(waiter_info, wait_call), + Some(waiter_info.client_receiver.clone()), + ) + } - synthetic_calls + /// Create synthetic SdkMethodCalls for an unmatched get_waiter + fn create_unmatched_synthetic_calls(&self, waiter_info: &WaiterInfo) -> Vec { + self.create_synthetic_calls_internal( + CallInfo::None(waiter_info), + Some(waiter_info.client_receiver.clone()), + ) } /// Create synthetic SdkMethodCalls for a chained waiter call @@ -433,53 +422,10 @@ impl<'a> WaitersExtractor<'a> { &self, chained_call: &ChainedWaiterCallInfo, ) -> Vec { - // Convert Python snake_case waiter name to PascalCase for lookup - let waiter_name_pascal = chained_call.waiter_name.to_case(Case::Pascal); - - // Find all services that provide this waiter using reverse index - let candidate_services = self - .service_index - .waiter_to_services - .get(&waiter_name_pascal) - .cloned() - .unwrap_or_default(); - - if candidate_services.is_empty() { - return Vec::new(); - } - - let mut synthetic_calls = Vec::new(); - - for service_name in candidate_services { - // Get the operation for this service+waiter combination from service definition - if let Some(service_def) = self.service_index.services.get(&service_name) { - if let Some(operation) = service_def.waiters.get(&waiter_name_pascal) { - // Convert operation name to Python snake_case for method name - let operation_snake = operation.name.to_case(Case::Snake); - - // Filter out WaiterConfig from arguments - it's waiter-specific, not operation-specific - let filtered_params = - ParameterFilter::filter_waiter_parameters(chained_call.arguments.clone()); - - // Create synthetic call with filtered wait() arguments - synthetic_calls.push(SdkMethodCall { - name: operation_snake, - possible_services: vec![service_name.clone()], - metadata: Some(SdkMethodCallMetadata { - parameters: filtered_params, - return_type: None, - // Use chained call position - start_position: chained_call.start_position, - end_position: chained_call.end_position, - // Use client receiver from chained call - receiver: Some(chained_call.client_receiver.clone()), - }), - }); - } - } - } - - synthetic_calls + self.create_synthetic_calls_internal( + CallInfo::Chained(chained_call), + Some(chained_call.client_receiver.clone()), + ) } /// Get required parameters for an operation from the service index @@ -493,10 +439,7 @@ impl<'a> WaitersExtractor<'a> { // Look up the service and operation in the service index if let Some(service_def) = service_index.services.get(service_name) { - // Convert snake_case operation name to PascalCase for lookup - let operation_name_pascal = operation_name.to_case(Case::Pascal); - - if let Some(operation) = service_def.operations.get(&operation_name_pascal) { + if let Some(operation) = service_def.operations.get(operation_name) { // Get the input shape if it exists if let Some(input_ref) = &operation.input { if let Some(input_shape) = service_def.shapes.get(&input_ref.shape) { @@ -527,6 +470,8 @@ impl<'a> WaitersExtractor<'a> { #[cfg(test)] mod tests { + use crate::extraction::sdk_model::ServiceMethodRef; + use super::*; use ast_grep_core::tree_sitter::LanguageExt; use ast_grep_language::Python; @@ -546,8 +491,7 @@ mod tests { let mut services = HashMap::new(); let mut operations = HashMap::new(); let mut shapes = HashMap::new(); - let mut waiters = HashMap::new(); - let mut waiter_to_services = HashMap::new(); + let mut waiter_lookup = HashMap::new(); // Create a mock DescribeInstances operation with required params let mut input_shape_members = HashMap::new(); @@ -579,12 +523,13 @@ mod tests { describe_instances_op.clone(), ); - // Add waiter entry for InstanceTerminated - waiters.insert( - "InstanceTerminated".to_string(), - describe_instances_op.clone(), + waiter_lookup.insert( + "instance_terminated".to_string(), + vec![ServiceMethodRef { + service_name: "ec2".to_string(), + operation_name: "InstanceTerminated".to_string(), + }], ); - waiter_to_services.insert("InstanceTerminated".to_string(), vec!["ec2".to_string()]); // Create DynamoDB DescribeTables operation for table_exists waiter let mut describe_tables_members = HashMap::new(); @@ -634,7 +579,13 @@ mod tests { // Add TableExists waiter for DynamoDB dynamodb_waiters.insert("TableExists".to_string(), describe_tables_op); - waiter_to_services.insert("TableExists".to_string(), vec!["dynamodb".to_string()]); + waiter_lookup.insert( + "table_exists".to_string(), + vec![ServiceMethodRef { + service_name: "dynamodb".to_string(), + operation_name: "TableExists".to_string(), + }], + ); services.insert( "ec2".to_string(), @@ -646,7 +597,6 @@ mod tests { }, operations, shapes, - waiters, }, ); @@ -660,14 +610,13 @@ mod tests { }, operations: dynamodb_operations, shapes: dynamodb_shapes, - waiters: dynamodb_waiters, }, ); ServiceModelIndex { services, method_lookup: HashMap::new(), - waiter_to_services, + waiter_lookup, } } @@ -722,7 +671,7 @@ waiter.wait(InstanceIds=['i-1234567890abcdef0']) let service_index = create_test_service_index(); let extractor = WaitersExtractor::new(&service_index); - let calls = extractor.extract_waiter_method_calls(&ast, &service_index); + let calls = extractor.extract_waiter_method_calls(&ast); // Should extract at least one call assert!(!calls.is_empty()); @@ -779,7 +728,7 @@ dynamodb_client.get_waiter('table_exists').wait(TableName='test-table') let service_index = create_test_service_index(); let extractor = WaitersExtractor::new(&service_index); - let calls = extractor.extract_waiter_method_calls(&ast, &service_index); + let calls = extractor.extract_waiter_method_calls(&ast); // Should extract at least one call for the chained waiter assert!(!calls.is_empty()); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/sdk_model.rs b/iam-policy-autopilot-policy-generation/src/extraction/sdk_model.rs index 5089527..33a1055 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/sdk_model.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/sdk_model.rs @@ -60,11 +60,6 @@ pub(crate) struct SdkServiceDefinition { pub(crate) operations: HashMap, /// Map of shape name to shape definition pub(crate) shapes: HashMap, - /// Map of waiter name (PascalCase) to underlying operation definition - /// Example: "InstanceTerminated" -> Operation { name: "DescribeInstances", ... } - /// Provides direct access to operation details without secondary HashMap lookup - #[serde(default, skip_deserializing)] - pub(crate) waiters: HashMap, } /// Service metadata from AWS service definitions @@ -134,7 +129,7 @@ pub(crate) struct ServiceModelIndex { pub(crate) method_lookup: HashMap>, /// Reverse index: waiter name (PascalCase) to list of services that provide it /// Example: "InstanceTerminated" -> ["ec2"], "BucketExists" -> ["s3"] - pub(crate) waiter_to_services: HashMap>, + pub(crate) waiter_lookup: HashMap>, } /// Reference to a service method for lookup purposes @@ -229,6 +224,7 @@ impl ServiceDiscovery { let services = Self::discover_services()?; let mut service_models = HashMap::new(); let mut method_lookup = HashMap::new(); + let mut waiter_lookup = HashMap::new(); let mut load_errors = Vec::new(); log::debug!( @@ -240,6 +236,7 @@ impl ServiceDiscovery { services, &mut service_models, &mut method_lookup, + &mut waiter_lookup, &mut load_errors, language, ) @@ -265,21 +262,10 @@ impl ServiceDiscovery { ))); } - // Build waiter-to-services reverse index - let mut waiter_to_services = HashMap::new(); - for (service_name, service_def) in &service_models { - for waiter_name in service_def.waiters.keys() { - waiter_to_services - .entry(waiter_name.clone()) - .or_insert_with(Vec::new) - .push(service_name.clone()); - } - } - let index = Arc::new(ServiceModelIndex { services: service_models, method_lookup, - waiter_to_services, + waiter_lookup, }); // Cache the index for future use @@ -296,6 +282,7 @@ impl ServiceDiscovery { services: Vec, service_models: &mut HashMap, method_lookup: &mut HashMap>, + waiter_lookup: &mut HashMap>, load_errors: &mut Vec, language: Language, ) -> Result<()> { @@ -360,7 +347,7 @@ impl ServiceDiscovery { // Collect results from parallel tasks while let Some(result) = join_set.join_next().await { match result { - Ok(Ok((service_info, mut service_model, waiters))) => { + Ok(Ok((service_info, service_model, waiters))) => { let service_name = service_info.name; // Build method lookup index for this service for operation_name in service_model.operations.keys() { @@ -374,24 +361,17 @@ impl ServiceDiscovery { ); } - // Populate waiters directly in service definition in canonical PascalCase if let Some(waiters) = waiters { - // Add waiters to the service definition with full Operation structs - for (waiter_name_pascal, waiter_entry) in &waiters { - if let Some(operation) = - service_model.operations.get(&waiter_entry.operation) - { - service_model - .waiters - .insert(waiter_name_pascal.clone(), operation.clone()); - } else { - log::debug!( - "Waiter '{}' references unknown operation '{}' in service '{}'", - waiter_name_pascal, - waiter_entry.operation, - service_name - ); - } + for (waiter_name, waiter_entry) in waiters { + let method_name = + Self::operation_to_method_name(&waiter_name, language); + + waiter_lookup.entry(method_name.clone()).or_default().push( + ServiceMethodRef { + service_name: service_name.clone(), + operation_name: waiter_entry.operation.clone(), + }, + ); } } @@ -555,7 +535,6 @@ mod tests { }, operations: HashMap::new(), shapes: HashMap::new(), - waiters: HashMap::new(), }; assert_eq!(service.version, Some("2.0".to_string())); diff --git a/iam-policy-autopilot-policy-generation/src/extraction/service_hints.rs b/iam-policy-autopilot-policy-generation/src/extraction/service_hints.rs index 0679621..d4e7fb1 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/service_hints.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/service_hints.rs @@ -309,7 +309,6 @@ mod tests { }, operations: HashMap::new(), shapes: HashMap::new(), - waiters: HashMap::new(), }; // Create a minimal service index for testing @@ -324,7 +323,7 @@ mod tests { .cloned() .collect(), method_lookup: HashMap::new(), - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), }); let hints = ServiceHints { @@ -379,7 +378,6 @@ mod tests { }, operations: HashMap::new(), shapes: HashMap::new(), - waiters: HashMap::new(), }; // Create a service index with logs service @@ -393,7 +391,7 @@ mod tests { .cloned() .collect(), method_lookup: HashMap::new(), - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), }); let hints = ServiceHints { @@ -463,7 +461,6 @@ mod tests { }, operations: HashMap::new(), shapes: HashMap::new(), - waiters: HashMap::new(), }; // Create a service index with services A and B @@ -477,7 +474,7 @@ mod tests { .cloned() .collect(), method_lookup: HashMap::new(), - waiter_to_services: HashMap::new(), + waiter_lookup: HashMap::new(), }); // Service hints with both A and B diff --git a/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs b/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs index 71bd27e..85d53f4 100644 --- a/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs +++ b/iam-policy-autopilot-policy-generation/src/extraction/typescript/extractor.rs @@ -78,15 +78,23 @@ impl Extractor for TypeScriptExtractor { // For each call, check if it's a waiter name and replace with the actual operation for call in method_calls.iter_mut() { - // Try to find this name in the waiter mappings for each possible service - for service_name in &call.possible_services.clone() { - if let Some(service_def) = service_index.services.get(service_name) { - // Check if this is a waiter name in the service's waiters map - if let Some(operation) = service_def.waiters.get(&call.name) { - // Replace waiter name with actual operation name - call.name = operation.name.clone(); - break; // Found the waiter, no need to check other services - } + if let Some(service_methods) = service_index.waiter_lookup.get(&call.name) { + let matching_method = service_methods + .iter() + .find(|sm| call.possible_services.contains(&sm.service_name)); + + if let Some(method) = matching_method { + call.name = method.operation_name.clone(); + } else { + log::warn!( + "Waiter '{}' found in services {:?} but imported from {:?}", + call.name, + service_methods + .iter() + .map(|sm| &sm.service_name) + .collect::>(), + call.possible_services + ); } } } @@ -416,4 +424,72 @@ const command = new GetObjectCommand({ Bucket: 'test', Key: 'test.txt' }); } } } + + #[tokio::test] + async fn test_waiter_disambiguation_with_multiple_services() { + use crate::extraction::sdk_model::ServiceDiscovery; + use crate::Language; + + let extractor = TypeScriptExtractor::new(); + + // Code that imports Neptune client and uses DBInstanceAvailable waiter + let typescript_code = r#" + import { NeptuneClient, waitUntilDBInstanceAvailable } from '@aws-sdk/client-neptune'; + + const client = new NeptuneClient({ region: 'us-east-1' }); + + async function waitForInstance() { + await waitUntilDBInstanceAvailable( + { client, maxWaitTime: 300 }, + { DBInstanceIdentifier: 'my-neptune-instance' } + ); + } + "#; + + // Parse the code + let mut results = vec![extractor.parse(typescript_code).await]; + + // Load service index + let service_index = ServiceDiscovery::load_service_index(Language::TypeScript) + .await + .expect("Failed to load service index"); + + // Apply filter_map which includes waiter resolution + extractor.filter_map(&mut results, &service_index); + + // Verify the results + match &results[0] { + ExtractorResult::TypeScript(_ast, method_calls) => { + // Find the DBInstanceAvailable call + let db_instance_call = method_calls + .iter() + .find(|call| call.name == "DescribeDBInstances") + .expect("Should find DescribeDBInstances operation after waiter resolution"); + + // CRITICAL: Should be associated with Neptune, not RDS or DocumentDB + assert!( + db_instance_call + .possible_services + .contains(&"neptune".to_string()), + "DBInstanceAvailable waiter should resolve to Neptune service, got: {:?}", + db_instance_call.possible_services + ); + + // Should NOT contain RDS or DocumentDB + assert!( + !db_instance_call + .possible_services + .contains(&"rds".to_string()), + "Should not incorrectly resolve to RDS" + ); + assert!( + !db_instance_call + .possible_services + .contains(&"docdb".to_string()), + "Should not incorrectly resolve to DocumentDB" + ); + } + _ => panic!("Should return TypeScript result"), + } + } }