diff --git a/src/cmd/joinp.rs b/src/cmd/joinp.rs index 0ee0eb97a..fa4de1b2b 100644 --- a/src/cmd/joinp.rs +++ b/src/cmd/joinp.rs @@ -143,6 +143,11 @@ joinp options: (e.g. 2022-02-29 -> 2022-02-28) instead of erroring. OUTPUT FORMAT OPTIONS: + --sql-filter The SQL expression to apply against the join result. + Ordinarily used to select columns and filter rows from + the join result. Be sure to select from the "join_result" + table when formulating the SQL expression. + (e.g. "select c1, c2 as colname from join_result where c2 > 20") --datetime-format The datetime format to use writing datetimes. See https://docs.rs/chrono/latest/chrono/format/strftime/index.html for the list of valid format specifiers. @@ -151,7 +156,7 @@ joinp options: --float-precision The number of digits of precision to use when writing floats. (default: 6) --null-value The string to use when writing null values. - (default: ) + (default: ) Common options: -h, --help Display this message @@ -172,9 +177,10 @@ use polars::{ datatypes::AnyValue, frame::hash_join::{JoinType, JoinValidation}, prelude::{ - AsOfOptions, AsofStrategy, CsvWriter, LazyCsvReader, LazyFileListReader, LazyFrame, - SerWriter, SortOptions, + AsOfOptions, AsofStrategy, CsvWriter, IntoLazy, LazyCsvReader, LazyFileListReader, + LazyFrame, SerWriter, SortOptions, }, + sql::SQLContext, }; use serde::Deserialize; use smartstring; @@ -207,6 +213,7 @@ struct Args { flag_right_by: Option, flag_strategy: Option, flag_tolerance: Option, + flag_sql_filter: Option, flag_datetime_format: Option, flag_date_format: Option, flag_time_format: Option, @@ -328,6 +335,7 @@ struct JoinStruct { delim: u8, streaming: bool, no_optimizations: bool, + sql_filter: Option, datetime_format: Option, date_format: Option, time_format: Option, @@ -384,7 +392,7 @@ impl JoinStruct { }; log::debug!("Optimization state: {optimization_state:?}"); - let mut join_results = if jointype == JoinType::Cross { + let join_results = if jointype == JoinType::Cross { self.left_lf .with_optimizations(optimization_state) .join_builder() @@ -414,6 +422,15 @@ impl JoinStruct { .collect()? }; + let mut results_df = if let Some(sql_filter) = &self.sql_filter { + let mut ctx = SQLContext::new(); + ctx.register("join_result", join_results.lazy()); + ctx.execute(sql_filter) + .and_then(polars::prelude::LazyFrame::collect)? + } else { + join_results + }; + // no need to use buffered writer here, as CsvWriter already does that let mut out_writer = match self.output { Some(output_file) => { @@ -424,7 +441,7 @@ impl JoinStruct { }; // shape is the number of rows and columns - let join_shape = join_results.shape(); + let join_shape = results_df.shape(); CsvWriter::new(&mut out_writer) .has_header(true) @@ -434,7 +451,7 @@ impl JoinStruct { .with_time_format(self.time_format) .with_float_precision(self.float_precision) .with_null_value(self.null_value) - .finish(&mut join_results)?; + .finish(&mut results_df)?; Ok(join_shape) } @@ -499,6 +516,7 @@ impl Args { delim, streaming: self.flag_streaming, no_optimizations: self.flag_no_optimizations, + sql_filter: self.flag_sql_filter.clone(), datetime_format: self.flag_datetime_format.clone(), date_format: self.flag_date_format.clone(), time_format: self.flag_time_format.clone(), diff --git a/tests/test_joinp.rs b/tests/test_joinp.rs index 036959a4f..eb68a4f64 100644 --- a/tests/test_joinp.rs +++ b/tests/test_joinp.rs @@ -567,3 +567,44 @@ fn joinp_asof_date_diffcolnames() { ]; assert_eq!(got, expected); } + +#[test] +fn joinp_asof_date_diffcolnames_sqlfilter() { + let wrk = Workdir::new("join_asof_date_diffcolnames_sqlfilter"); + wrk.create( + "gdp.csv", + vec![ + svec!["gdp_date", "gdp"], + svec!["2016-01-01", "4164"], + svec!["2017-01-01", "4411"], + svec!["2018-01-01", "4566"], + svec!["2019-01-01", "4696"], + ], + ); + wrk.create( + "population.csv", + vec![ + svec!["pop_date", "population"], + svec!["2016-05-12", "82.19"], + svec!["2017-05-12", "82.66"], + svec!["2018-05-12", "83.12"], + svec!["2019-05-12", "83.52"], + ], + ); + + let mut cmd = wrk.command("joinp"); + cmd.arg("--asof") + .args(["pop_date", "population.csv", "gdp_date", "gdp.csv"]) + .args([ + "--sql-filter", + "select pop_date, gdp from join_result where gdp > 4500", + ]); + + let got: Vec> = wrk.read_stdout(&mut cmd); + let expected = vec![ + svec!["pop_date", "gdp"], + svec!["2018-05-12", "4566"], + svec!["2019-05-12", "4696"], + ]; + assert_eq!(got, expected); +}