|  | 
|  | 1 | +#!/usr/bin/env python3 | 
|  | 2 | +""" | 
|  | 3 | +MSAL Feature Test Runner | 
|  | 4 | +Interprets testcase file(s) to create and execute test cases using MSAL. | 
|  | 5 | +
 | 
|  | 6 | +Initially created by the following prompt: | 
|  | 7 | +Write a python implementation that can read content from feature.yml, create variables whose names are defined in the "arrange" mapping's keys, and the variables' value are derived from the "arrange" mapping's value; interpret those value as if they are python snippet using MSAL library. | 
|  | 8 | +""" | 
|  | 9 | +import os | 
|  | 10 | +import sys | 
|  | 11 | +import logging | 
|  | 12 | +from contextlib import contextmanager | 
|  | 13 | +from typing import Dict, Any, List, Optional | 
|  | 14 | + | 
|  | 15 | +import yaml | 
|  | 16 | +import msal | 
|  | 17 | +import requests | 
|  | 18 | + | 
|  | 19 | + | 
|  | 20 | +logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | 
|  | 21 | +logger = logging.getLogger(__name__) | 
|  | 22 | + | 
|  | 23 | +class SmileTestRunner: | 
|  | 24 | + | 
|  | 25 | +    def __init__(self, testcase_url: str): | 
|  | 26 | +        self.testcase_url = testcase_url | 
|  | 27 | +        self.test_spec = None | 
|  | 28 | +        self.variables = {} | 
|  | 29 | + | 
|  | 30 | +    def load_feature(self) -> Dict[str, Any]: | 
|  | 31 | +        """Load and validate the feature file.""" | 
|  | 32 | +        try: | 
|  | 33 | +            with requests.get(self.testcase_url) as response: | 
|  | 34 | +                response.raise_for_status() | 
|  | 35 | +                self.test_spec = yaml.safe_load(response.text) | 
|  | 36 | + | 
|  | 37 | +            # Basic validation | 
|  | 38 | +            if not isinstance(self.test_spec, dict): | 
|  | 39 | +                raise ValueError("Feature file must contain a valid YAML dictionary") | 
|  | 40 | + | 
|  | 41 | +            if self.test_spec.get('type') != 'MSAL Test': | 
|  | 42 | +                raise ValueError("Feature file must have type 'MSAL Test'") | 
|  | 43 | + | 
|  | 44 | +            return self.test_spec | 
|  | 45 | +        except Exception as e: | 
|  | 46 | +            logger.error(f"Error loading feature file: {str(e)}") | 
|  | 47 | +            sys.exit(1) | 
|  | 48 | + | 
|  | 49 | +    @contextmanager | 
|  | 50 | +    def setup_environment(self): | 
|  | 51 | +        """Set up the environment variables specified in the feature file.""" | 
|  | 52 | +        original_env = os.environ.copy() | 
|  | 53 | + | 
|  | 54 | +        try: | 
|  | 55 | +            # Set environment variables | 
|  | 56 | +            if 'env' in self.test_spec and isinstance(self.test_spec['env'], dict): | 
|  | 57 | +                for key, value in self.test_spec['env'].items(): | 
|  | 58 | +                    os.environ[key] = str(value) | 
|  | 59 | +                    logger.debug(f"Set environment variable {key}={value}") | 
|  | 60 | +            yield | 
|  | 61 | +        finally: | 
|  | 62 | +            # Restore original environment | 
|  | 63 | +            os.environ.clear() | 
|  | 64 | +            os.environ.update(original_env) | 
|  | 65 | + | 
|  | 66 | +    def arrange(self): | 
|  | 67 | +        """Create variables based on the arrange section.""" | 
|  | 68 | +        arrange_spec = self.test_spec.get('arrange', {}) | 
|  | 69 | +        if not isinstance(arrange_spec, dict): | 
|  | 70 | +            raise ValueError("Arrange section must be a dictionary") | 
|  | 71 | +        for var_name, value_spec in arrange_spec.items(): | 
|  | 72 | +            logger.debug(f"Creating variable '{var_name}' with {value_spec}") | 
|  | 73 | +            self.variables[var_name] = self._create_instance(value_spec) | 
|  | 74 | + | 
|  | 75 | +    def _create_instance(self, spec: Dict[str, Any]) -> Any: | 
|  | 76 | +        """Create an instance based on the specification.""" | 
|  | 77 | +        if not isinstance(spec, dict) or len(spec) != 1: | 
|  | 78 | +            raise ValueError(f"Invalid specification format: {spec}") | 
|  | 79 | + | 
|  | 80 | +        class_name, params = next(iter(spec.items())) | 
|  | 81 | + | 
|  | 82 | +        # Handle different MSAL classes | 
|  | 83 | +        if class_name == "ManagedIdentityClient": | 
|  | 84 | +            return msal.ManagedIdentityClient(http_client=requests.Session(), **params) | 
|  | 85 | +        elif class_name == "PublicClientApplication": | 
|  | 86 | +            return self._create_public_client_app(params) | 
|  | 87 | +        elif class_name == "ConfidentialClientApplication": | 
|  | 88 | +            return self._create_confidential_client_app(params) | 
|  | 89 | +        else: | 
|  | 90 | +            raise ValueError(f"Unsupported class: {class_name}") | 
|  | 91 | + | 
|  | 92 | +    def _create_public_client_app(self, params: Dict[str, Any]) -> Any: | 
|  | 93 | +        """Create a PublicClientApplication instance.""" | 
|  | 94 | +        if not params or 'client_id' not in params: | 
|  | 95 | +            raise ValueError("PublicClientApplication requires client_id") | 
|  | 96 | + | 
|  | 97 | +        client_id = params.get('client_id') | 
|  | 98 | +        authority = params.get('authority') | 
|  | 99 | +        logger.debug(f"Creating PublicClientApplication with client_id: {client_id}, authority: {authority}") | 
|  | 100 | + | 
|  | 101 | +        kwargs = {'client_id': client_id} | 
|  | 102 | +        if authority: | 
|  | 103 | +            kwargs['authority'] = authority | 
|  | 104 | + | 
|  | 105 | +        return msal.PublicClientApplication(**kwargs) | 
|  | 106 | + | 
|  | 107 | +    def _create_confidential_client_app(self, params: Dict[str, Any]) -> Any: | 
|  | 108 | +        """Create a ConfidentialClientApplication instance.""" | 
|  | 109 | +        if not params or 'client_id' not in params or 'client_credential' not in params: | 
|  | 110 | +            raise ValueError("ConfidentialClientApplication requires client_id and client_credential") | 
|  | 111 | + | 
|  | 112 | +        client_id = params.get('client_id') | 
|  | 113 | +        client_credential = params.get('client_credential') | 
|  | 114 | +        authority = params.get('authority') | 
|  | 115 | +        logger.debug(f"Creating ConfidentialClientApplication with client_id: {client_id}, authority: {authority}") | 
|  | 116 | + | 
|  | 117 | +        kwargs = {'client_id': client_id, 'client_credential': client_credential} | 
|  | 118 | +        if authority: | 
|  | 119 | +            kwargs['authority'] = authority | 
|  | 120 | + | 
|  | 121 | +        return msal.ConfidentialClientApplication(**kwargs) | 
|  | 122 | + | 
|  | 123 | +    def execute_steps(self) -> bool: | 
|  | 124 | +        """Execute the test steps, returns whether all steps passed.""" | 
|  | 125 | +        steps = self.test_spec.get('steps', []) | 
|  | 126 | +        passed = 0 | 
|  | 127 | +        for i, step in enumerate(steps): | 
|  | 128 | +            logger.debug(f"Executing step {i+1}/{len(steps)}") | 
|  | 129 | +            if 'act' in step: | 
|  | 130 | +                result = self._execute_action(step['act']) | 
|  | 131 | +                if 'assert' in step: | 
|  | 132 | +                    if self._validate_assertions(result, step['assert']): | 
|  | 133 | +                        passed += 1 | 
|  | 134 | +        logger.info(f"{passed} of {len(steps)} step(s) passed") | 
|  | 135 | +        return passed == len(steps) | 
|  | 136 | + | 
|  | 137 | +    def _execute_action(self, act_spec: Dict[str, Any]) -> Any: | 
|  | 138 | +        """Execute an action based on the specification.""" | 
|  | 139 | +        if not isinstance(act_spec, dict) or len(act_spec) != 1: | 
|  | 140 | +            raise ValueError(f"Invalid action specification: {act_spec}") | 
|  | 141 | + | 
|  | 142 | +        action_str, params = next(iter(act_spec.items())) | 
|  | 143 | + | 
|  | 144 | +        # Parse the action string (e.g., "app1.AcquireToken") | 
|  | 145 | +        parts = action_str.split('.') | 
|  | 146 | +        if len(parts) != 2: | 
|  | 147 | +            raise ValueError(f"Invalid action format: {action_str}") | 
|  | 148 | + | 
|  | 149 | +        var_name = parts[0] | 
|  | 150 | +        method_name = {  # Map the method names in yml to actual method names | 
|  | 151 | +            "AcquireTokenForManagedIdentity": "acquire_token_for_client", | 
|  | 152 | +            }.get(parts[1]) | 
|  | 153 | + | 
|  | 154 | +        if method_name is None: | 
|  | 155 | +            raise ValueError(f"Unsupported method: {parts[1]}") | 
|  | 156 | + | 
|  | 157 | +        if var_name not in self.variables: | 
|  | 158 | +            raise ValueError(f"Variable '{var_name}' not found") | 
|  | 159 | + | 
|  | 160 | +        instance = self.variables[var_name] | 
|  | 161 | +        if not hasattr(instance, method_name): | 
|  | 162 | +            raise ValueError(f"Method '{method_name}' not found on {var_name}") | 
|  | 163 | + | 
|  | 164 | +        method = getattr(instance, method_name) | 
|  | 165 | + | 
|  | 166 | +        # Convert parameters to kwargs | 
|  | 167 | +        kwargs = params if params else {} | 
|  | 168 | + | 
|  | 169 | +        # Execute the method with parameters | 
|  | 170 | +        logger.info(f"Calling {var_name}.{method_name} with {kwargs}") | 
|  | 171 | +        return method(**kwargs) | 
|  | 172 | + | 
|  | 173 | +    def _validate_assertions(self, result: Any, assertions: Dict[str, Any]) -> bool: | 
|  | 174 | +        """Validate the assertions against the result.""" | 
|  | 175 | +        logger.info(f"Validating assertions: {assertions}") | 
|  | 176 | +        for key, expected_value in assertions.items(): | 
|  | 177 | +            if key not in result: | 
|  | 178 | +                logger.error(f"Assertion failed: '{key}' not found in result {result}") | 
|  | 179 | +                return False  # Failed | 
|  | 180 | +            actual_value = result[key] | 
|  | 181 | +            if actual_value != expected_value: | 
|  | 182 | +                logger.error(f"Assertion failed: expected {key}='{expected_value}', got '{actual_value}'") | 
|  | 183 | +                return False  # Failed | 
|  | 184 | +            else: | 
|  | 185 | +                logger.debug(f"Assertion passed: {key}='{actual_value}'") | 
|  | 186 | +        return True  # Passed | 
|  | 187 | + | 
|  | 188 | +    def run(self) -> bool: | 
|  | 189 | +        """Run the entire test, returns whether it passed.""" | 
|  | 190 | +        self.load_feature() | 
|  | 191 | + | 
|  | 192 | +        with self.setup_environment(): | 
|  | 193 | +            self.arrange() | 
|  | 194 | +            result = self.execute_steps() | 
|  | 195 | +            if result: | 
|  | 196 | +                logger.info(f"Test case {self.testcase_url} passed") | 
|  | 197 | +            else: | 
|  | 198 | +                logger.error(f"Test case {self.testcase_url} failed") | 
|  | 199 | +        return result | 
|  | 200 | + | 
|  | 201 | + | 
|  | 202 | +def run_testcases(testcases_url: str) -> bool: | 
|  | 203 | +    try: | 
|  | 204 | +        response = requests.get(testcases_url) | 
|  | 205 | +        response.raise_for_status() | 
|  | 206 | +        passed = 0 | 
|  | 207 | +        testcases = response.json().get("testcases", []) | 
|  | 208 | +        for url in testcases: | 
|  | 209 | +            try: | 
|  | 210 | +                if SmileTestRunner(url).run(): | 
|  | 211 | +                    passed += 1 | 
|  | 212 | +            except Exception as e: | 
|  | 213 | +                logger.error(f"Test case {url} failed: {e}") | 
|  | 214 | +        (logger.info if passed == len(testcases) else logger.error)( | 
|  | 215 | +            f"Passed {passed} of {len(testcases)} test cases" | 
|  | 216 | +        ) | 
|  | 217 | +        return passed == len(testcases) | 
|  | 218 | +    except requests.RequestException as e: | 
|  | 219 | +        logger.error(f"Failed to fetch test cases from {testcases_url}: {e}") | 
|  | 220 | +        raise | 
|  | 221 | + | 
|  | 222 | + | 
|  | 223 | +def main(): | 
|  | 224 | +    import argparse | 
|  | 225 | +    parser = argparse.ArgumentParser(description="MSAL Feature Test Runner") | 
|  | 226 | +    group = parser.add_mutually_exclusive_group(required=True) | 
|  | 227 | +    group.add_argument("--testcase", help="URL for a single test case") | 
|  | 228 | +    group.add_argument("--batch", help="URL for a batch of test cases in JSON format") | 
|  | 229 | +    args = parser.parse_args() | 
|  | 230 | + | 
|  | 231 | +    if args.testcase: | 
|  | 232 | +        logger.setLevel(logging.DEBUG) | 
|  | 233 | +        success = SmileTestRunner(args.testcase).run() | 
|  | 234 | +    elif args.batch: | 
|  | 235 | +        logger.setLevel(logging.INFO) | 
|  | 236 | +        success = run_testcases(args.batch) | 
|  | 237 | + | 
|  | 238 | +    sys.exit(0 if success else 1) | 
|  | 239 | + | 
|  | 240 | +if __name__ == "__main__": | 
|  | 241 | +    main() | 
0 commit comments