Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 90 additions & 39 deletions surfactant/cmd/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,47 +122,98 @@ def plugin_uninstall_cmd(plugin_name):
is_flag=True,
help="Update all plugins that implement the 'update_db' hook.",
)
def plugin_update_db_cmd(plugin_name, update_all):
@click.option(
"--allow-gpl",
"allow_gpl",
default=None,
is_flag=False,
flag_value="", # When used without =, it will be set to empty string
help="Allow GPL-licensed databases. Use '--allow-gpl' for one-time acceptance, or '--allow-gpl=always' or '--allow-gpl=never' to update the stored config setting.",
)
def plugin_update_db_cmd(plugin_name, update_all, allow_gpl):
"""Updates the database for a specified plugin or all plugins if --all is used."""
pm = get_plugin_manager()
call_init_hooks(pm, hook_filter=["update_db"], command_name="update-db")

if update_all:
# Update all plugins that implement the update_db hook
for plugin in pm.get_plugins():
if is_hook_implemented(pm, plugin, "update_db"):
plugin_name = pm.get_name(plugin) or pm.get_canonical_name(plugin)
click.echo(f"Updating {plugin_name} ...")
update_result = plugin.update_db()
if update_result:
click.echo(f"Update result for {plugin_name}: {update_result}")
else:
click.echo(f"No update operation performed for {plugin_name}.")
else:
if not plugin_name:
click.echo("Please specify a plugin name or use --all to update all plugins.", err=True)
return

plugin = find_plugin_by_name(pm, plugin_name) # Get an instance of the plugin

# Check if the plugin is registered
if not plugin:
click.echo(f"Plugin '{plugin_name}' not found.", err=True)
return

# Check if the plugin has implemented the update_db hook
has_update_db_hook = is_hook_implemented(pm, plugin, "update_db")

if not has_update_db_hook:
click.echo(f"Plugin '{plugin_name}' does not implement the 'update_db' hook.", err=True)
return
config_manager = ConfigManager()

# Call the update_db hook for the specified plugin
plugin_name = pm.get_name(plugin) or pm.get_canonical_name(plugin)
click.echo(f"Updating {plugin_name} ...")
update_result = plugin.update_db()
# Handle --allow-gpl flag
# When used as --allow-gpl (without =), Click sets it to empty string (from flag_value)
# When used as --allow-gpl=value, Click sets it to the value
# When not used, it's None
if allow_gpl is not None:
# If it's an empty string, treat as one-time acceptance
if allow_gpl == "":
# Set runtime override that won't be persisted to config file
config_manager.set_runtime_override("sources", "gpl_license_ok", "always")
allow_gpl = "once" # Track that this is a one-time override
else:
# Normalize the value
allow_gpl_lower = allow_gpl.lower()

if allow_gpl_lower in ("always", "a"):
# Permanently set to always accept GPL
config_manager.set("sources", "gpl_license_ok", "always")
click.echo("GPL license acceptance set to 'always'.")
allow_gpl = "always"
elif allow_gpl_lower in ("never", "n"):
# Permanently set to never accept GPL
config_manager.set("sources", "gpl_license_ok", "never")
click.echo("GPL license acceptance set to 'never'.")
allow_gpl = "never"
else:
# Unknown value, treat as error
click.echo(
f"Error: Invalid value for --allow-gpl: '{allow_gpl}'. Use 'always' or 'never', or use the flag without a value for one-time acceptance.",
err=True,
)
return

if update_result:
click.echo(f"Update result for {plugin_name}: {update_result}")
try:
call_init_hooks(pm, hook_filter=["update_db"], command_name="update-db")

if update_all:
# Update all plugins that implement the update_db hook
for plugin in pm.get_plugins():
if is_hook_implemented(pm, plugin, "update_db"):
plugin_name = pm.get_name(plugin) or pm.get_canonical_name(plugin)
click.echo(f"Updating {plugin_name} ...")
update_result = plugin.update_db()
if update_result:
click.echo(f"Update result for {plugin_name}: {update_result}")
else:
click.echo(f"No update operation performed for {plugin_name}.")
else:
click.echo(f"No update operation performed for {plugin_name}.")
if not plugin_name:
click.echo(
"Please specify a plugin name or use --all to update all plugins.", err=True
)
return

plugin = find_plugin_by_name(pm, plugin_name) # Get an instance of the plugin

# Check if the plugin is registered
if not plugin:
click.echo(f"Plugin '{plugin_name}' not found.", err=True)
return

# Check if the plugin has implemented the update_db hook
has_update_db_hook = is_hook_implemented(pm, plugin, "update_db")

if not has_update_db_hook:
click.echo(
f"Plugin '{plugin_name}' does not implement the 'update_db' hook.", err=True
)
return

# Call the update_db hook for the specified plugin
plugin_name = pm.get_name(plugin) or pm.get_canonical_name(plugin)
click.echo(f"Updating {plugin_name} ...")
update_result = plugin.update_db()

if update_result:
click.echo(f"Update result for {plugin_name}: {update_result}")
else:
click.echo(f"No update operation performed for {plugin_name}.")
finally:
# Clean up runtime override if it was a one-time acceptance
if allow_gpl == "once":
config_manager.clear_runtime_override("sources", "gpl_license_ok")
46 changes: 46 additions & 0 deletions surfactant/configmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def __init__(
self.config_dir = Path(config_dir) / app_name if config_dir else None
self.config = tomlkit.document()
self.config_file_path = self._get_config_file_path()
self._runtime_overrides: Dict[
str, Dict[str, Any]
] = {} # Runtime overlay for temporary values
self._load_config()

def _get_config_file_path(self) -> Path:
Expand Down Expand Up @@ -94,6 +97,10 @@ def get(self, section: str, option: str, fallback: Optional[Any] = None) -> Any:
Returns:
Any: The configuration value or the fallback value.
"""
# Check runtime overrides first (they take precedence)
if section in self._runtime_overrides and option in self._runtime_overrides[section]:
return self._runtime_overrides[section][option]

return self.config.get(section, {}).get(option, fallback)

def set(self, section: str, option: str, value: Any) -> None:
Expand Down Expand Up @@ -142,6 +149,45 @@ def delete_instance(cls, app_name: str) -> None:
if app_name in cls._instances:
del cls._instances[app_name]

def set_runtime_override(self, section: str, option: str, value: Any) -> None:
"""Sets a runtime override value that takes precedence over config file values.

Runtime overrides are not persisted to the config file and only exist in memory.
They take precedence over values loaded from the config file.

Args:
section (str): The section within the configuration.
option (str): The option within the section.
value (Any): The value to set as a runtime override.
"""
if section not in self._runtime_overrides:
self._runtime_overrides[section] = {}
self._runtime_overrides[section][option] = value

def clear_runtime_override(self, section: str, option: str) -> None:
"""Clears a runtime override value.

Args:
section (str): The section within the configuration.
option (str): The option within the section.
"""
if section in self._runtime_overrides and option in self._runtime_overrides[section]:
del self._runtime_overrides[section][option]
if not self._runtime_overrides[section]:
del self._runtime_overrides[section]

def clear_all_runtime_overrides(self) -> None:
"""Clears all runtime override values."""
self._runtime_overrides.clear()

def has_runtime_overrides(self) -> bool:
"""Check if any runtime overrides are set.

Returns:
bool: True if any runtime overrides exist, False otherwise.
"""
return bool(self._runtime_overrides)

def get_data_dir_path(self) -> Path:
"""Determines the path to the data directory, for storing things such as databases.

Expand Down
13 changes: 8 additions & 5 deletions surfactant/database_manager/database_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,11 @@
from surfactant.configmanager import ConfigManager
from surfactant.database_manager.utils import (
calculate_hash,
check_gpl_acceptance,
download_content,
get_source_for,
load_db_version_metadata,
save_db_version_metadata,
check_gpl_acceptance,
)


Expand Down Expand Up @@ -82,9 +82,7 @@ def __init__(self, config: DatabaseConfig) -> None:
self.config.source = url
self.config.gpl = gpl
self._overridden = overridden
logger.debug(
"Using external URL override for {}: {}", self.config.database_key, url
)
logger.debug("Using external URL override for {}: {}", self.config.database_key, url)
else:
self._overridden = False
logger.debug("Using hard-coded URL for {}", self.config.database_key)
Expand Down Expand Up @@ -173,7 +171,12 @@ def parse_raw_data(self, raw_data: str) -> Dict[str, Any]:
def download_and_update_database(self) -> str:
# Check GPL acceptance before download
if self.config.gpl:
if not check_gpl_acceptance(self.config.database_dir, self.config.database_key, self.config.gpl, getattr(self, '_overridden', False)):
if not check_gpl_acceptance(
self.config.database_dir,
self.config.database_key,
self.config.gpl,
getattr(self, "_overridden", False),
):
return f"Download aborted: '{self.config.database_key}' is GPL-licensed and user did not accept."

raw_data = download_content(self.config.source)
Expand Down
35 changes: 26 additions & 9 deletions surfactant/database_manager/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,30 +222,47 @@ def check_gpl_acceptance(database_category: str, key: str, gpl: bool, overridden
"""
if not gpl or overridden:
return True

config_manager = ConfigManager()

# Check GPL setting (includes runtime overrides which take precedence)
gpl_setting = config_manager.get("sources", "gpl_license_ok")
if gpl_setting in ("always", "a", True):
return True
if gpl_setting in ("never", "n", False):
return False
# Prompt user

# Prompt user if no setting is configured
return _prompt_user_for_gpl_acceptance(config_manager, database_category, key)


def _prompt_user_for_gpl_acceptance(
config_manager: ConfigManager, database_category: str, key: str
) -> bool:
"""
Prompt the user for GPL acceptance and optionally save their preference.

Returns:
bool: True if user accepts, False otherwise.
"""
prompt = (
f"The pattern database '{key}' in category '{database_category}' is GPL-licensed. "
"Do you want to download it? [y]es/[n]o/[a]lways/[N]ever: "
)
try:
user_input = input(prompt).strip()
user_input_lower = user_input.lower()
except Exception:
except (EOFError, KeyboardInterrupt):
return False

user_input_lower = user_input.lower()

# Handle always/never options that update config
if user_input_lower in ("a", "always"):
config_manager.set("Settings", "gpl_license_ok", "always")
return True
if user_input_lower in ("never") or user_input in ("N"):
if user_input_lower in ("never") or user_input == "N":
config_manager.set("Settings", "gpl_license_ok", "never")
return False
if user_input_lower in ("no") or user_input in ("n"):
return False
if user_input_lower in ("y", "yes"):
return True
return False

# Handle yes/no for this time only
return user_input_lower in ("y", "yes")
Loading