Skip to content

Commit

Permalink
Merge pull request #2540 from dathere/luau_cum_helpers_name_optional
Browse files Browse the repository at this point in the history
refactor: `luau` all cumulative helper functions (cum_) now have name as an optional argument
  • Loading branch information
jqnatividad authored Feb 19, 2025
2 parents 02b70d3 + 07b49b3 commit ef8f0b1
Show file tree
Hide file tree
Showing 2 changed files with 133 additions and 69 deletions.
154 changes: 109 additions & 45 deletions src/cmd/luau.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2218,17 +2218,23 @@ fn setup_helpers(
// this is a helper function that calculates the cumulative sum of a numeric column.
// if the input cannot be converted to a number, it is treated as 0 for that row
// (i.e. the running sum is returned unchanged).
//
// qsv_cumsum(name, value)
// name: identifier for this cumulative sum (allows multiple sums to run in parallel)
// qsv_cumsum(value, name)
// value: the numeric value to add to the cumulative sum
// name: (optional) identifier for this cumulative sum
// (allows multiple sums to run in parallel)
// returns: the cumulative sum up to the current row for the named sum
//
let qsv_cumsum = luau.create_function(|_, (name, value): (String, mlua::Value)| {
let qsv_cumsum = luau.create_function(|_, (value, name): (mlua::Value, Option<String>)| {
// Get static cumulative sums using thread_local storage
thread_local! {
static CUMSUMS: RefCell<HashMap<String, f64>> = RefCell::new(HashMap::new());
}

let key = if let Some(name) = name {
format!("_qsv_cumsum_{name}")
} else {
"_qsv_cumsum".to_string()
};
// Convert input value to number, defaulting to 0.0 if conversion fails
let num = match value {
Value::Number(n) => n,
Expand All @@ -2240,7 +2246,7 @@ fn setup_helpers(
// Update cumulative sum for this name
CUMSUMS.with(|cs| {
let mut sums = cs.borrow_mut();
let sum = sums.entry(name).or_insert(0.0);
let sum = sums.entry(key).or_insert(0.0);
*sum += num;
Ok(*sum)
})
Expand All @@ -2250,17 +2256,23 @@ fn setup_helpers(
// this is a helper function that calculates the cumulative product of a numeric column.
// if the input cannot be converted to a number, it returns 1 for that row.
//
// qsv_cumprod(name, value)
// name: identifier for this cumulative product
// (allows multiple products to run in parallel)
// qsv_cumprod(value, name)
// value: the numeric value to multiply with the cumulative product
// name: (optional) identifier for this cumulative product
// (allows multiple products to run in parallel)
// returns: the cumulative product up to the current row for the named product
//
let qsv_cumprod = luau.create_function(|_, (name, value): (String, mlua::Value)| {
let qsv_cumprod = luau.create_function(|_, (value, name): (mlua::Value, Option<String>)| {
thread_local! {
static CUMPRODS: RefCell<HashMap<String, f64>> = RefCell::new(HashMap::new());
}

let key = if let Some(name) = name {
format!("_qsv_cumprod_{name}")
} else {
"_qsv_cumprod".to_string()
};

let num = match value {
Value::Number(n) => n,
Value::Integer(i) => i as f64,
Expand All @@ -2270,7 +2282,7 @@ fn setup_helpers(

CUMPRODS.with(|cp| {
let mut prods = cp.borrow_mut();
let prod = prods.entry(name).or_insert(1.0);
let prod = prods.entry(key).or_insert(1.0);
*prod *= num;
Ok(*prod)
})
Expand All @@ -2280,17 +2292,23 @@ fn setup_helpers(
// this is a helper function that calculates the cumulative maximum of a numeric column.
// if the input cannot be converted to a number, it returns negative infinity for that row.
//
// qsv_cummax(name, value)
// name: identifier for this cumulative maximum
// (allows multiple maximums to run in parallel)
// qsv_cummax(value, name)
// value: the numeric value to compare with the cumulative maximum
// name: (optional) identifier for this cumulative maximum
// (allows multiple maximums to run in parallel)
// returns: the cumulative maximum up to the current row for the named maximum
//
let qsv_cummax = luau.create_function(|_, (name, value): (String, mlua::Value)| {
let qsv_cummax = luau.create_function(|_, (value, name): (mlua::Value, Option<String>)| {
thread_local! {
static CUMMAXS: RefCell<HashMap<String, f64>> = RefCell::new(HashMap::new());
}

let key = if let Some(name) = name {
format!("_qsv_cummax_{name}")
} else {
"_qsv_cummax".to_string()
};

let num = match value {
Value::Number(n) => n,
Value::Integer(i) => i as f64,
Expand All @@ -2303,7 +2321,7 @@ fn setup_helpers(

CUMMAXS.with(|cm| {
let mut maxs = cm.borrow_mut();
let max = maxs.entry(name).or_insert(f64::NEG_INFINITY);
let max = maxs.entry(key).or_insert(f64::NEG_INFINITY);
*max = max.max(num);
Ok(*max)
})
Expand All @@ -2313,17 +2331,23 @@ fn setup_helpers(
// this is a helper function that calculates the cumulative minimum of a numeric column.
// if the input cannot be converted to a number, it returns positive infinity for that row.
//
// qsv_cummin(name, value)
// name: identifier for this cumulative minimum
// (allows multiple minimums to run in parallel)
// qsv_cummin(value, name)
// value: the numeric value to compare with the cumulative minimum
// name: (optional) identifier for this cumulative minimum
// (allows multiple minimums to run in parallel)
// returns: the cumulative minimum up to the current row for the named minimum
//
let qsv_cummin = luau.create_function(|_, (name, value): (String, mlua::Value)| {
let qsv_cummin = luau.create_function(|_, (value, name): (mlua::Value, Option<String>)| {
thread_local! {
static CUMMINS: RefCell<HashMap<String, f64>> = RefCell::new(HashMap::new());
}

let key = if let Some(name) = name {
format!("_qsv_cummin_{name}")
} else {
"_qsv_cummin".to_string()
};

let num = match value {
Value::Number(n) => n,
Value::Integer(i) => i as f64,
Expand All @@ -2333,7 +2357,7 @@ fn setup_helpers(

CUMMINS.with(|cm| {
let mut mins = cm.borrow_mut();
let min = mins.entry(name).or_insert(f64::INFINITY);
let min = mins.entry(key).or_insert(f64::INFINITY);
*min = min.min(num);
Ok(*min)
})
Expand All @@ -2342,19 +2366,43 @@ fn setup_helpers(

// qsv_lag - returns lagged value with optional default
//
// qsv_lag(name, value, lag, default)
// name: identifier for this lag (allows multiple lags to run in parallel)
// qsv_lag(value, lag, default, name)
// value: the value to lag
// lag: (optional) number of rows to lag by (default: 1)
// default: (optional) value to return for rows before lag is available (default: "0")
// name: (optional) identifier for this lag
// (allows multiple lags to run in parallel).
// Note that because name is the last positional argument, you need to
// specify both lag and default if you want to use a named lag.
// returns: the value from 'lag' rows ago, or default if not enough rows seen yet
let qsv_lag = luau.create_function(|luau, (name, value, lag, default): (String, mlua::Value, Option<i64>, Option<mlua::Value>)| {
let qsv_lag = luau.create_function(|luau, args: mlua::MultiValue| {
let args: Vec<mlua::Value> = args.into_iter().collect();

if args.is_empty() {
return helper_err!("qsv_lag", "requires at least 1 argument: value");
}

let value = args[0].clone();
let lag = if args.len() > 1 {
args[1].as_i64().unwrap_or(1)
} else {
1
};
let default = if args.len() > 2 {
args[2].clone()
} else {
mlua::Value::String(luau.create_string("0")?)
};
let name = if args.len() > 3 {
args[3].to_string()?
} else {
String::new()
};

thread_local! {
static LAGS: RefCell<HashMap<String, Vec<String>>> = RefCell::new(HashMap::new());
}

let lag = lag.unwrap_or(1);
let key = format!("{name}_{lag}");
let key = format!("_qsv_lag_{name}_{lag}");

// Convert the value to a string to store it
let value_str = match &value {
Expand All @@ -2373,9 +2421,7 @@ fn setup_helpers(

if values.len() as i64 <= lag {
// Return the default value when not enough history
Ok(default.unwrap_or_else(|| {
mlua::Value::String(luau.create_string("0").unwrap())
}))
Ok(default)
} else {
let lagged_value = &values[values.len() - 1 - lag as usize];
Ok(mlua::Value::String(luau.create_string(lagged_value)?))
Expand All @@ -2386,16 +2432,22 @@ fn setup_helpers(

// qsv_cumany - returns true if any value so far has been truthy
//
// qsv_cumany(name, value)
// name: identifier for this cumulative any
// (allows multiple cumany's to run in parallel)
// qsv_cumany(value, name)
// value: the value to check for truthiness
// name: (optional) identifier for this cumulative any
// (allows multiple cumany's to run in parallel)
// returns: true if any value seen so far has been truthy, false otherwise
let qsv_cumany = luau.create_function(|_, (name, value): (String, mlua::Value)| {
let qsv_cumany = luau.create_function(|_, (value, name): (mlua::Value, Option<String>)| {
thread_local! {
static CUMANYS: RefCell<HashMap<String, bool>> = RefCell::new(HashMap::new());
}

let key = if let Some(name) = name {
format!("_qsv_cumany_{name}")
} else {
"_qsv_cumany".to_string()
};

let is_truthy = match value {
Value::Boolean(b) => b,
Value::Number(n) => n != 0.0,
Expand All @@ -2407,7 +2459,7 @@ fn setup_helpers(

CUMANYS.with(|ca| {
let mut anys = ca.borrow_mut();
let any = anys.entry(name).or_insert(false);
let any = anys.entry(key).or_insert(false);
*any = *any || is_truthy;
Ok(*any)
})
Expand All @@ -2416,16 +2468,22 @@ fn setup_helpers(

// qsv_cumall - returns true if all values so far have been truthy
//
// qsv_cumall(name, value)
// name: identifier for this cumulative all
// (allows multiple cumall's to run in parallel)
// value: the value to check for truthiness
// returns: true if all values seen so far have been truthy, false otherwise
let qsv_cumall = luau.create_function(|_, (name, value): (String, mlua::Value)| {
// qsv_cumall(value, name)
// value: the value to check for truthiness
// name: (optional) identifier for this cumulative all
// (allows multiple cumall's to run in parallel)
// returns: true if all values seen so far have been truthy, false otherwise
let qsv_cumall = luau.create_function(|_, (value, name): (mlua::Value, Option<String>)| {
thread_local! {
static CUMALLS: RefCell<HashMap<String, bool>> = RefCell::new(HashMap::new());
}

let key = if let Some(name) = name {
format!("_qsv_cumall_{name}")
} else {
"_qsv_cumall".to_string()
};

let is_truthy = match value {
Value::Boolean(b) => b,
Value::Number(n) => n != 0.0,
Expand All @@ -2437,7 +2495,7 @@ fn setup_helpers(

CUMALLS.with(|ca| {
let mut all_vals = ca.borrow_mut();
let all = all_vals.entry(name).or_insert(true);
let all = all_vals.entry(key).or_insert(true);
*all = *all && is_truthy;
Ok(*all)
})
Expand Down Expand Up @@ -2487,7 +2545,7 @@ fn setup_helpers(
};

// Generate unique name for the accumulator state
let state_name = format!("_qsv_accumulate_state_{column_name}");
let state_name = format!("_qsv_accumulate_{column_name}");

// Get existing accumulator value or use initial value
let prev_acc = if let Ok(prev) = luau.globals().get::<f64>(&*state_name) {
Expand Down Expand Up @@ -2530,20 +2588,26 @@ fn setup_helpers(
// qsv_diff - returns difference between current and previous value
//
// qsv_diff(value[, periods[, name]])
// name: identifier for this diff (allows multiple diffs to run in parallel)
// value: the value to calculate difference for
// periods: optional number of periods to look back (default: 1)
// periods: (optional) number of periods to look back (default: 1)
// name: (optional) identifier for this diff
// (allows multiple diffs to run in parallel)
// Note that you need to specify periods if you want to use a named diff.
// returns: difference between current value and value 'periods' rows back
// returns 0 if not enough history available yet
let qsv_diff = luau.create_function(
|_, (name, value, periods): (String, mlua::Value, Option<i64>)| {
|_, (value, periods, name): (mlua::Value, Option<i64>, Option<String>)| {
thread_local! {
static DIFFS: RefCell<HashMap<String, Vec<f64>>> = RefCell::new(HashMap::new());
}

let periods = periods.unwrap_or(1);
// Create a unique key that includes both the name and periods
let key = format!("{name}_{periods}");
// Create a unique key that includes both periods and name
let key = if let Some(name) = name {
format!("_qsv_diff_{periods}_{name}")
} else {
format!("_qsv_diff_{periods}")
};

let num = match value {
Value::Number(n) => n,
Expand Down
Loading

0 comments on commit ef8f0b1

Please sign in to comment.