diff --git a/src/config.rs b/src/config.rs index 3b66f28b2..8e1b025ce 100644 --- a/src/config.rs +++ b/src/config.rs @@ -166,6 +166,18 @@ pub struct DefGuardConfig { #[command(subcommand)] #[serde(skip_serializing)] pub cmd: Option, + + #[arg(long, env = "DEFGUARD_CHECK_PERIOD", default_value = "12h")] + #[serde(skip_serializing)] + pub check_period: Duration, + + #[arg(long, env = "DEFGUARD_CHECK_PERIOD_NO_LICENSE", default_value = "24h")] + #[serde(skip_serializing)] + pub check_period_no_license: Duration, + + #[arg(long, env = "DEFGUARD_CHECK_RENEWAL_WINDOW", default_value = "1h")] + #[serde(skip_serializing)] + pub check_period_renewal_window: Duration, } #[derive(Clone, Debug, Subcommand)] diff --git a/src/enterprise/license.rs b/src/enterprise/license.rs index 74b23f28a..cd22a869a 100644 --- a/src/enterprise/license.rs +++ b/src/enterprise/license.rs @@ -13,7 +13,7 @@ use sqlx::{error::Error as SqlxError, PgPool}; use thiserror::Error; use tokio::time::sleep; -use crate::{db::Settings, VERSION}; +use crate::{db::Settings, server_config, VERSION}; const LICENSE_SERVER_URL: &str = "https://pkgs.defguard.net/api/license/renew"; @@ -532,24 +532,13 @@ pub fn update_cached_license(key: Option<&str>) -> Result<(), LicenseError> { Ok(()) } - /// Amount of time before the license expiry date we should start the renewal attempts. const RENEWAL_TIME: TimeDelta = TimeDelta::hours(24); - -/// Maximum amount of time a license can be over its expiry date. const MAX_OVERDUE_TIME: TimeDelta = TimeDelta::days(14); -/// Periodic license check task -const CHECK_PERIOD: Duration = Duration::from_secs(12 * 60 * 60); - -/// Periodic license check task for the case when no license is present -const CHECK_PERIOD_NO_LICENSE: Duration = Duration::from_secs(24 * 60 * 60); - -/// Periodic license check task for the case when the license is about to expire -const CHECK_PERIOD_RENEWAL_WINDOW: Duration = Duration::from_secs(60 * 60); - pub async fn run_periodic_license_check(pool: PgPool) -> Result<(), LicenseError> { - let mut check_period: Duration = CHECK_PERIOD; + let config = server_config(); + let mut check_period: Duration = *config.check_period; info!( "Starting periodic license renewal check every {}", format_duration(check_period) @@ -559,7 +548,7 @@ pub async fn run_periodic_license_check(pool: PgPool) -> Result<(), LicenseError // Check if the license is present in the mutex, if not skip the check if get_cached_license().is_none() { debug!("No license found, skipping license check"); - sleep(CHECK_PERIOD_NO_LICENSE).await; + sleep(*config.check_period_no_license).await; continue; } @@ -578,7 +567,7 @@ pub async fn run_periodic_license_check(pool: PgPool) -> Result<(), LicenseError // check if we are pass the maximum expiration date, after which we don't // want to try to renew the license anymore if license.is_max_overdue() { - check_period = CHECK_PERIOD; + check_period = *config.check_period; warn!("Your license has expired and reached its maximum overdue date, please contact sales at salesdefguard.net"); debug!("Changing check period to {}", format_duration(check_period)); false @@ -607,13 +596,13 @@ pub async fn run_periodic_license_check(pool: PgPool) -> Result<(), LicenseError if requires_renewal { info!("License requires renewal, renewing license..."); - check_period = CHECK_PERIOD_RENEWAL_WINDOW; + check_period = *config.check_period_renewal_window; debug!("Changing check period to {}", format_duration(check_period)); match renew_license(&pool).await { Ok(new_license_key) => match save_license_key(&pool, &new_license_key).await { Ok(()) => { update_cached_license(Some(&new_license_key))?; - check_period = CHECK_PERIOD; + check_period = *config.check_period; debug!("Changing check period to {}", format_duration(check_period)); info!("Successfully renewed the license"); }