Skip to content

Commit

Permalink
fix: luau qsv_accumulate arg processing
Browse files Browse the repository at this point in the history
actually take the value, not the column name
  • Loading branch information
jqnatividad committed Feb 20, 2025
1 parent def3f1d commit 4db04ac
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 38 deletions.
57 changes: 23 additions & 34 deletions src/cmd/luau.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2509,47 +2509,37 @@ fn setup_helpers(

// qsv_accumulate - accumulates values using a custom function
//
// qsv_accumulate(column: string, func: function, init: number, name: string)
// column: the name of the NUMERIC column to accumulate over
// IMPORTANT: Be sure to enclose the name in double quotes,
// Otherwise, the current value of the column is passed, and not the
// name, resulting in a runtime error.
// func: function that takes two arguments (prev_acc, curr_val) and returns
// the new accumulated value. prev_acc is the previously accumulated value,
// curr_val is the current value from the column.
// init: (optional) initial value. If not provided, defaults to 0.0
// If the column value is not a number, 0.0 is used as the initial value.
// name: (optional) identifier for this accumulator.
// Note that you need to specify name if you want to use a named accumulator.
// (allows multiple accumulators to run in parallel)
// returns: the accumulated value for the current row
// or Luau runtime error if invalid arguments
// qsv_accumulate(value: number, func: function, init: number, name: string)
// value: the numeric value to accumulate over.
// If the value is not a number, 0.0 is used.
// func: function that takes two arguments (prev_acc, curr_val) and returns
// the new accumulated value. prev_acc is the previously accumulated value,
// curr_val is the current value.
// init: (optional) initial value. If not provided, defaults to the first column value.
// If the column value is not a number, 0.0 is used as the initial value.
// name: (optional) identifier for this accumulator.
// Note that you need to specify name if you want to use a named accumulator.
// (allows multiple accumulators to run in parallel)
// returns: the accumulated value for the current row
// or Luau runtime error if invalid arguments
let qsv_accumulate = luau.create_function(
|luau,
(column_name, func, init, name): (
String,
(value, func, init, name): (
mlua::Value,
mlua::Function,
Option<mlua::Value>,
Option<String>,
)| {
// Get the current value from the column
let Ok(column) = luau.globals().raw_get::<String>(&*column_name) else {
return helper_err!(
"qsv_accumulate",
"'{column_name}' not found. Be sure to enclose the column name in double \
quotes."
);
let curr_value = match value {
Value::Number(n) => n,
Value::Integer(i) => i as f64,
Value::String(s) => fast_float2::parse(s.as_bytes()).unwrap_or(0.0),
_ => 0.0,
};
let curr_value = column.parse::<f64>().unwrap_or(0.0);

// Generate unique name for the accumulator state
let state_name = if let Some(name) = name {
// if a name is provided, use it as part of the state name
format!("_qsv_accumulate_{column_name}_{name}")
} else {
// otherwise, just use the column name
format!("_qsv_accumulate_{column_name}")
};
// Generate name for the accumulator state
let state_name = format!("_qsv_accumulate_{nm}", nm = name.unwrap_or_default());

// Get existing accumulator value or use initial value
let prev_acc = if let Ok(prev) = luau.globals().raw_get::<f64>(&*state_name) {
Expand All @@ -2566,7 +2556,6 @@ fn setup_helpers(
} else {
// By default, the first column value is used as the initial value
// unless the optional initial value is provided.
// If the first column value is not a number, 0.0 is used as the initial value.
curr_value
};
luau.globals().raw_set(&*state_name, init_value)?;
Expand All @@ -2583,7 +2572,7 @@ fn setup_helpers(
};

// Store the new accumulated value
luau.globals().raw_set(&*state_name, result)?;
luau.globals().raw_set(state_name, result)?;

Ok(result)
},
Expand Down
8 changes: 4 additions & 4 deletions tests/test_luau.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3086,7 +3086,7 @@ function weighted_sum(acc, x)
return acc + x * _IDX
end
return qsv_accumulate("value", weighted_sum, 0)
return qsv_accumulate(value, weighted_sum, 0)
"#,
)
.arg("data.csv");
Expand Down Expand Up @@ -3131,7 +3131,7 @@ function weighted_sum(acc, x)
return acc + x * _IDX
end
return qsv_accumulate("value", weighted_sum)
return qsv_accumulate(value, weighted_sum)
"#,
)
.arg("data.csv");
Expand Down Expand Up @@ -3177,7 +3177,7 @@ BEGIN {
}!
-- This is the MAIN LOOP
accumulated = qsv_accumulate("value", udf_sum, 100)
accumulated = qsv_accumulate(value, udf_sum, 100)
-- return the accumulated value for the current row
return accumulated
Expand Down Expand Up @@ -3231,7 +3231,7 @@ BEGIN {
}!
-- This is the MAIN LOOP
return qsv_accumulate("value", func_with_reset, 0)
return qsv_accumulate(value, func_with_reset, 0)
"#,
)
.arg("data.csv");
Expand Down

0 comments on commit 4db04ac

Please sign in to comment.