diff --git a/surfactant/cmd/plugin.py b/surfactant/cmd/plugin.py index 6be94e75..0acae294 100644 --- a/surfactant/cmd/plugin.py +++ b/surfactant/cmd/plugin.py @@ -122,47 +122,98 @@ def plugin_uninstall_cmd(plugin_name): is_flag=True, help="Update all plugins that implement the 'update_db' hook.", ) -def plugin_update_db_cmd(plugin_name, update_all): +@click.option( + "--allow-gpl", + "allow_gpl", + default=None, + is_flag=False, + flag_value="", # When used without =, it will be set to empty string + help="Allow GPL-licensed databases. Use '--allow-gpl' for one-time acceptance, or '--allow-gpl=always' or '--allow-gpl=never' to update the stored config setting.", +) +def plugin_update_db_cmd(plugin_name, update_all, allow_gpl): """Updates the database for a specified plugin or all plugins if --all is used.""" pm = get_plugin_manager() - call_init_hooks(pm, hook_filter=["update_db"], command_name="update-db") - - if update_all: - # Update all plugins that implement the update_db hook - for plugin in pm.get_plugins(): - if is_hook_implemented(pm, plugin, "update_db"): - plugin_name = pm.get_name(plugin) or pm.get_canonical_name(plugin) - click.echo(f"Updating {plugin_name} ...") - update_result = plugin.update_db() - if update_result: - click.echo(f"Update result for {plugin_name}: {update_result}") - else: - click.echo(f"No update operation performed for {plugin_name}.") - else: - if not plugin_name: - click.echo("Please specify a plugin name or use --all to update all plugins.", err=True) - return - - plugin = find_plugin_by_name(pm, plugin_name) # Get an instance of the plugin - - # Check if the plugin is registered - if not plugin: - click.echo(f"Plugin '{plugin_name}' not found.", err=True) - return - - # Check if the plugin has implemented the update_db hook - has_update_db_hook = is_hook_implemented(pm, plugin, "update_db") - - if not has_update_db_hook: - click.echo(f"Plugin '{plugin_name}' does not implement the 'update_db' hook.", err=True) - return + config_manager = ConfigManager() - # Call the update_db hook for the specified plugin - plugin_name = pm.get_name(plugin) or pm.get_canonical_name(plugin) - click.echo(f"Updating {plugin_name} ...") - update_result = plugin.update_db() + # Handle --allow-gpl flag + # When used as --allow-gpl (without =), Click sets it to empty string (from flag_value) + # When used as --allow-gpl=value, Click sets it to the value + # When not used, it's None + if allow_gpl is not None: + # If it's an empty string, treat as one-time acceptance + if allow_gpl == "": + # Set runtime override that won't be persisted to config file + config_manager.set_runtime_override("sources", "gpl_license_ok", "always") + allow_gpl = "once" # Track that this is a one-time override + else: + # Normalize the value + allow_gpl_lower = allow_gpl.lower() + + if allow_gpl_lower in ("always", "a"): + # Permanently set to always accept GPL + config_manager.set("sources", "gpl_license_ok", "always") + click.echo("GPL license acceptance set to 'always'.") + allow_gpl = "always" + elif allow_gpl_lower in ("never", "n"): + # Permanently set to never accept GPL + config_manager.set("sources", "gpl_license_ok", "never") + click.echo("GPL license acceptance set to 'never'.") + allow_gpl = "never" + else: + # Unknown value, treat as error + click.echo( + f"Error: Invalid value for --allow-gpl: '{allow_gpl}'. Use 'always' or 'never', or use the flag without a value for one-time acceptance.", + err=True, + ) + return - if update_result: - click.echo(f"Update result for {plugin_name}: {update_result}") + try: + call_init_hooks(pm, hook_filter=["update_db"], command_name="update-db") + + if update_all: + # Update all plugins that implement the update_db hook + for plugin in pm.get_plugins(): + if is_hook_implemented(pm, plugin, "update_db"): + plugin_name = pm.get_name(plugin) or pm.get_canonical_name(plugin) + click.echo(f"Updating {plugin_name} ...") + update_result = plugin.update_db() + if update_result: + click.echo(f"Update result for {plugin_name}: {update_result}") + else: + click.echo(f"No update operation performed for {plugin_name}.") else: - click.echo(f"No update operation performed for {plugin_name}.") + if not plugin_name: + click.echo( + "Please specify a plugin name or use --all to update all plugins.", err=True + ) + return + + plugin = find_plugin_by_name(pm, plugin_name) # Get an instance of the plugin + + # Check if the plugin is registered + if not plugin: + click.echo(f"Plugin '{plugin_name}' not found.", err=True) + return + + # Check if the plugin has implemented the update_db hook + has_update_db_hook = is_hook_implemented(pm, plugin, "update_db") + + if not has_update_db_hook: + click.echo( + f"Plugin '{plugin_name}' does not implement the 'update_db' hook.", err=True + ) + return + + # Call the update_db hook for the specified plugin + plugin_name = pm.get_name(plugin) or pm.get_canonical_name(plugin) + click.echo(f"Updating {plugin_name} ...") + update_result = plugin.update_db() + + if update_result: + click.echo(f"Update result for {plugin_name}: {update_result}") + else: + click.echo(f"No update operation performed for {plugin_name}.") + finally: + # Clean up runtime override if it was a one-time acceptance + if allow_gpl == "once": + config_manager.clear_runtime_override("sources", "gpl_license_ok") diff --git a/surfactant/configmanager.py b/surfactant/configmanager.py index 0c62ab22..806b945a 100644 --- a/surfactant/configmanager.py +++ b/surfactant/configmanager.py @@ -59,6 +59,9 @@ def __init__( self.config_dir = Path(config_dir) / app_name if config_dir else None self.config = tomlkit.document() self.config_file_path = self._get_config_file_path() + self._runtime_overrides: Dict[ + str, Dict[str, Any] + ] = {} # Runtime overlay for temporary values self._load_config() def _get_config_file_path(self) -> Path: @@ -94,6 +97,10 @@ def get(self, section: str, option: str, fallback: Optional[Any] = None) -> Any: Returns: Any: The configuration value or the fallback value. """ + # Check runtime overrides first (they take precedence) + if section in self._runtime_overrides and option in self._runtime_overrides[section]: + return self._runtime_overrides[section][option] + return self.config.get(section, {}).get(option, fallback) def set(self, section: str, option: str, value: Any) -> None: @@ -142,6 +149,45 @@ def delete_instance(cls, app_name: str) -> None: if app_name in cls._instances: del cls._instances[app_name] + def set_runtime_override(self, section: str, option: str, value: Any) -> None: + """Sets a runtime override value that takes precedence over config file values. + + Runtime overrides are not persisted to the config file and only exist in memory. + They take precedence over values loaded from the config file. + + Args: + section (str): The section within the configuration. + option (str): The option within the section. + value (Any): The value to set as a runtime override. + """ + if section not in self._runtime_overrides: + self._runtime_overrides[section] = {} + self._runtime_overrides[section][option] = value + + def clear_runtime_override(self, section: str, option: str) -> None: + """Clears a runtime override value. + + Args: + section (str): The section within the configuration. + option (str): The option within the section. + """ + if section in self._runtime_overrides and option in self._runtime_overrides[section]: + del self._runtime_overrides[section][option] + if not self._runtime_overrides[section]: + del self._runtime_overrides[section] + + def clear_all_runtime_overrides(self) -> None: + """Clears all runtime override values.""" + self._runtime_overrides.clear() + + def has_runtime_overrides(self) -> bool: + """Check if any runtime overrides are set. + + Returns: + bool: True if any runtime overrides exist, False otherwise. + """ + return bool(self._runtime_overrides) + def get_data_dir_path(self) -> Path: """Determines the path to the data directory, for storing things such as databases. diff --git a/surfactant/database_manager/database_utils.py b/surfactant/database_manager/database_utils.py index f7172fd2..05f220f7 100755 --- a/surfactant/database_manager/database_utils.py +++ b/surfactant/database_manager/database_utils.py @@ -19,11 +19,11 @@ from surfactant.configmanager import ConfigManager from surfactant.database_manager.utils import ( calculate_hash, + check_gpl_acceptance, download_content, get_source_for, load_db_version_metadata, save_db_version_metadata, - check_gpl_acceptance, ) @@ -82,9 +82,7 @@ def __init__(self, config: DatabaseConfig) -> None: self.config.source = url self.config.gpl = gpl self._overridden = overridden - logger.debug( - "Using external URL override for {}: {}", self.config.database_key, url - ) + logger.debug("Using external URL override for {}: {}", self.config.database_key, url) else: self._overridden = False logger.debug("Using hard-coded URL for {}", self.config.database_key) @@ -173,7 +171,12 @@ def parse_raw_data(self, raw_data: str) -> Dict[str, Any]: def download_and_update_database(self) -> str: # Check GPL acceptance before download if self.config.gpl: - if not check_gpl_acceptance(self.config.database_dir, self.config.database_key, self.config.gpl, getattr(self, '_overridden', False)): + if not check_gpl_acceptance( + self.config.database_dir, + self.config.database_key, + self.config.gpl, + getattr(self, "_overridden", False), + ): return f"Download aborted: '{self.config.database_key}' is GPL-licensed and user did not accept." raw_data = download_content(self.config.source) diff --git a/surfactant/database_manager/utils.py b/surfactant/database_manager/utils.py index c8755f66..e37a89ac 100644 --- a/surfactant/database_manager/utils.py +++ b/surfactant/database_manager/utils.py @@ -222,30 +222,47 @@ def check_gpl_acceptance(database_category: str, key: str, gpl: bool, overridden """ if not gpl or overridden: return True + config_manager = ConfigManager() + + # Check GPL setting (includes runtime overrides which take precedence) gpl_setting = config_manager.get("sources", "gpl_license_ok") if gpl_setting in ("always", "a", True): return True if gpl_setting in ("never", "n", False): return False - # Prompt user + + # Prompt user if no setting is configured + return _prompt_user_for_gpl_acceptance(config_manager, database_category, key) + + +def _prompt_user_for_gpl_acceptance( + config_manager: ConfigManager, database_category: str, key: str +) -> bool: + """ + Prompt the user for GPL acceptance and optionally save their preference. + + Returns: + bool: True if user accepts, False otherwise. + """ prompt = ( f"The pattern database '{key}' in category '{database_category}' is GPL-licensed. " "Do you want to download it? [y]es/[n]o/[a]lways/[N]ever: " ) try: user_input = input(prompt).strip() - user_input_lower = user_input.lower() - except Exception: + except (EOFError, KeyboardInterrupt): return False + + user_input_lower = user_input.lower() + + # Handle always/never options that update config if user_input_lower in ("a", "always"): config_manager.set("Settings", "gpl_license_ok", "always") return True - if user_input_lower in ("never") or user_input in ("N"): + if user_input_lower in ("never") or user_input == "N": config_manager.set("Settings", "gpl_license_ok", "never") return False - if user_input_lower in ("no") or user_input in ("n"): - return False - if user_input_lower in ("y", "yes"): - return True - return False + + # Handle yes/no for this time only + return user_input_lower in ("y", "yes") diff --git a/tests/cmd/test_plugin.py b/tests/cmd/test_plugin.py new file mode 100644 index 00000000..756857d6 --- /dev/null +++ b/tests/cmd/test_plugin.py @@ -0,0 +1,190 @@ +# Copyright 2025 Lawrence Livermore National Security, LLC +# See the top-level LICENSE file for details. +# +# SPDX-License-Identifier: MIT + +import tempfile +from unittest.mock import MagicMock, patch + +import pytest +from click.testing import CliRunner + +from surfactant.cmd.plugin import plugin_update_db_cmd +from surfactant.configmanager import ConfigManager +from surfactant.database_manager.utils import check_gpl_acceptance + + +@pytest.fixture +def temp_config_dir(tmp_path): # pylint: disable=redefined-outer-name + """Create a temporary config directory for testing.""" + config_dir = tmp_path / "test_config" + config_dir.mkdir() + return config_dir + + +@pytest.fixture +def isolated_config(temp_config_dir): # pylint: disable=redefined-outer-name + """Create an isolated ConfigManager instance for testing.""" + # Delete any existing instance + ConfigManager.delete_instance("surfactant") + # Create a new instance with temporary config directory + config_manager = ConfigManager(config_dir=str(temp_config_dir.parent)) + yield config_manager + # Clean up + ConfigManager.delete_instance("surfactant") + + +def test_allow_gpl_flag_once(isolated_config): # pylint: disable=redefined-outer-name + """Test --allow-gpl flag without a value (one-time acceptance).""" + runner = CliRunner() + + with patch("surfactant.cmd.plugin.get_plugin_manager") as mock_pm_getter: + # Mock plugin manager + mock_pm = MagicMock() + mock_pm_getter.return_value = mock_pm + mock_pm.get_plugins.return_value = [] + + with patch("surfactant.cmd.plugin.call_init_hooks"): + # Run command with --allow-gpl flag (no value) + result = runner.invoke(plugin_update_db_cmd, ["--allow-gpl", "--all"]) + + # Check that command succeeded + assert result.exit_code == 0 + + # Verify runtime override is NOT persisted (should be cleaned up) + gpl_setting = isolated_config.get("sources", "gpl_license_ok") + assert gpl_setting is None + + # Verify runtime override was cleared + assert not isolated_config.has_runtime_overrides() + + +def test_allow_gpl_flag_always(isolated_config): # pylint: disable=redefined-outer-name + """Test --allow-gpl=always to permanently set GPL acceptance.""" + runner = CliRunner() + + with patch("surfactant.cmd.plugin.get_plugin_manager") as mock_pm_getter: + # Mock plugin manager + mock_pm = MagicMock() + mock_pm_getter.return_value = mock_pm + mock_pm.get_plugins.return_value = [] + + with patch("surfactant.cmd.plugin.call_init_hooks"): + # Run command with --allow-gpl=always + result = runner.invoke(plugin_update_db_cmd, ["--allow-gpl=always", "--all"]) + + # Check that command succeeded + assert result.exit_code == 0 + assert "GPL license acceptance set to 'always'" in result.output + + # Verify permanent setting is stored + gpl_setting = isolated_config.get("sources", "gpl_license_ok") + assert gpl_setting == "always" + + +def test_allow_gpl_flag_never(isolated_config): # pylint: disable=redefined-outer-name + """Test --allow-gpl=never to permanently disable GPL acceptance.""" + runner = CliRunner() + + with patch("surfactant.cmd.plugin.get_plugin_manager") as mock_pm_getter: + # Mock plugin manager + mock_pm = MagicMock() + mock_pm_getter.return_value = mock_pm + mock_pm.get_plugins.return_value = [] + + with patch("surfactant.cmd.plugin.call_init_hooks"): + # Run command with --allow-gpl=never + result = runner.invoke(plugin_update_db_cmd, ["--allow-gpl=never", "--all"]) + + # Check that command succeeded + assert result.exit_code == 0 + assert "GPL license acceptance set to 'never'" in result.output + + # Verify permanent setting is stored + gpl_setting = isolated_config.get("sources", "gpl_license_ok") + assert gpl_setting == "never" + + +def test_check_gpl_acceptance_with_runtime_flag(): + """Test that check_gpl_acceptance respects the runtime allow_gpl flag.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create isolated config + ConfigManager.delete_instance("surfactant") + config_manager = ConfigManager(config_dir=tmpdir) + + try: + # Set runtime override + config_manager.set_runtime_override("sources", "gpl_license_ok", "always") + + # Test that GPL is accepted due to runtime override + result = check_gpl_acceptance( + database_category="test_category", key="test_key", gpl=True, overridden=False + ) + assert result is True + finally: + ConfigManager.delete_instance("surfactant") + + +def test_check_gpl_acceptance_with_permanent_always(): + """Test that check_gpl_acceptance respects permanent 'always' setting.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create isolated config + ConfigManager.delete_instance("surfactant") + config_manager = ConfigManager(config_dir=tmpdir) + + try: + # Set permanent always flag + config_manager.set("sources", "gpl_license_ok", "always") + + # Test that GPL is accepted + result = check_gpl_acceptance( + database_category="test_category", key="test_key", gpl=True, overridden=False + ) + assert result is True + finally: + ConfigManager.delete_instance("surfactant") + + +def test_check_gpl_acceptance_with_permanent_never(): + """Test that check_gpl_acceptance respects permanent 'never' setting.""" + with tempfile.TemporaryDirectory() as tmpdir: + # Create isolated config + ConfigManager.delete_instance("surfactant") + config_manager = ConfigManager(config_dir=tmpdir) + + try: + # Set permanent never flag + config_manager.set("sources", "gpl_license_ok", "never") + + # Test that GPL is rejected + result = check_gpl_acceptance( + database_category="test_category", key="test_key", gpl=True, overridden=False + ) + assert result is False + finally: + ConfigManager.delete_instance("surfactant") + + +def test_no_allow_gpl_flag(isolated_config): # pylint: disable=redefined-outer-name + """Test command without --allow-gpl flag (default behavior).""" + runner = CliRunner() + + with patch("surfactant.cmd.plugin.get_plugin_manager") as mock_pm_getter: + # Mock plugin manager + mock_pm = MagicMock() + mock_pm_getter.return_value = mock_pm + mock_pm.get_plugins.return_value = [] + + with patch("surfactant.cmd.plugin.call_init_hooks"): + # Run command without --allow-gpl flag + result = runner.invoke(plugin_update_db_cmd, ["--all"]) + + # Check that command succeeded + assert result.exit_code == 0 + + # Verify no GPL settings were changed + gpl_setting = isolated_config.get("sources", "gpl_license_ok") + assert gpl_setting is None + + # Verify no runtime overrides were set + assert not isolated_config.has_runtime_overrides()