From 3fb8c2bea7d3b12275cef28b2d2c1ac7d6f23e3c Mon Sep 17 00:00:00 2001 From: daywalker90 <8257956+daywalker90@users.noreply.github.com> Date: Sat, 29 Jun 2024 15:36:44 +0200 Subject: [PATCH] make column names and states case insensitive --- CHANGELOG.md | 4 ++ src/config.rs | 144 ++++++++++++------------------------------ src/structs.rs | 3 +- src/tables.rs | 4 +- tests/test_summars.py | 111 +++++++++++++++++++++++++++----- 5 files changed, 146 insertions(+), 120 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ac0ec2..1a12ec3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -13,6 +13,10 @@ - sats values are rounded to the closest integer instead of rounded down +### Fixed + +- column names and states are now case insensitive + ## [3.3.0] 2024-06-05 ### Added diff --git a/src/config.rs b/src/config.rs index ca72d77..90029d9 100644 --- a/src/config.rs +++ b/src/config.rs @@ -109,36 +109,11 @@ fn parse_option(name: &str, value: &serde_json::Value) -> Result Result, Error> { - let cleaned_input: String = input - .chars() - .filter(|&c| !c.is_whitespace()) - .collect::() - .to_ascii_uppercase(); - let split_input: Vec<&str> = cleaned_input.split(',').collect(); - - let mut uniq = HashSet::new(); - for i in &split_input { - if !uniq.insert(i) { - return Err(anyhow!( - "Duplicate entry detected in {}: {}", - OPT_COLUMNS, - i - )); - } - } - - for i in &split_input { - if !Summary::FIELD_NAMES_AS_ARRAY.contains(i) { - return Err(anyhow!("`{}` not found in valid column names!", i)); - } - } - - let cleaned_strings: Vec = split_input.into_iter().map(String::from).collect(); - Ok(cleaned_strings) -} - -fn validate_forwards_columns_input(input: &str) -> Result, Error> { +fn validate_columns_input( + input: &str, + column_name: &str, + columns: &[&'static str], +) -> Result, Error> { let cleaned_input: String = input .chars() .filter(|&c| !c.is_whitespace()) @@ -151,73 +126,15 @@ fn validate_forwards_columns_input(input: &str) -> Result, Error> { if !uniq.insert(i) { return Err(anyhow!( "Duplicate entry detected in {}: {}", - OPT_FORWARDS_COLUMNS, + column_name, i )); } } for i in &split_input { - if !Forwards::FIELD_NAMES_AS_ARRAY.contains(i) { - return Err(anyhow!("`{}` not found in valid forwards column names!", i)); - } - } - - let cleaned_strings: Vec = split_input.into_iter().map(String::from).collect(); - Ok(cleaned_strings) -} - -fn validate_pays_columns_input(input: &str) -> Result, Error> { - let cleaned_input: String = input - .chars() - .filter(|&c| !c.is_whitespace()) - .collect::() - .to_ascii_lowercase(); - let split_input: Vec<&str> = cleaned_input.split(',').collect(); - - let mut uniq = HashSet::new(); - for i in &split_input { - if !uniq.insert(i) { - return Err(anyhow!( - "Duplicate entry detected in {}: {}", - OPT_PAYS_COLUMNS, - i - )); - } - } - - for i in &split_input { - if !Pays::FIELD_NAMES_AS_ARRAY.contains(i) { - return Err(anyhow!("`{}` not found in valid pays column names!", i)); - } - } - - let cleaned_strings: Vec = split_input.into_iter().map(String::from).collect(); - Ok(cleaned_strings) -} - -fn validate_invoices_columns_input(input: &str) -> Result, Error> { - let cleaned_input: String = input - .chars() - .filter(|&c| !c.is_whitespace()) - .collect::() - .to_ascii_lowercase(); - let split_input: Vec<&str> = cleaned_input.split(',').collect(); - - let mut uniq = HashSet::new(); - for i in &split_input { - if !uniq.insert(i) { - return Err(anyhow!( - "Duplicate entry detected in {}: {}", - OPT_INVOICES_COLUMNS, - i - )); - } - } - - for i in &split_input { - if !Invoices::FIELD_NAMES_AS_ARRAY.contains(i) { - return Err(anyhow!("`{}` not found in valid invoices column names!", i)); + if !columns.contains(i) { + return Err(anyhow!("`{}` not found in valid {} names!", i, column_name)); } } @@ -230,11 +147,16 @@ fn validate_sort_input(input: &str) -> Result { let sortable_columns = Summary::FIELD_NAMES_AS_ARRAY .into_iter() - .filter(|t| t != &"GRAPH_SATS") - .collect::>(); + .filter(|t| t != &"graph_sats") + .map(|s| s.to_string()) + .collect::>(); - if reverse && sortable_columns.contains(&&input[1..]) || sortable_columns.contains(&input) { - Ok(input.to_string()) + if reverse && sortable_columns.contains(&(input[1..].to_ascii_lowercase())) + || sortable_columns.contains(&input.to_ascii_lowercase()) + { + Ok(input.to_ascii_uppercase()) + } else if input.to_ascii_lowercase().contains("graph_sats") { + Err(anyhow!("Can not sort by `GRAPH_SATS`!")) } else { Err(anyhow!( "Not a valid column name: `{}`. Must be one of: {}", @@ -245,7 +167,11 @@ fn validate_sort_input(input: &str) -> Result { } fn validate_exclude_states_input(input: &str) -> Result { - let cleaned_input: String = input.chars().filter(|&c| !c.is_whitespace()).collect(); + let cleaned_input: String = input + .chars() + .filter(|&c| !c.is_whitespace()) + .collect::() + .to_ascii_uppercase(); let split_input: Vec<&str> = cleaned_input.split(',').collect(); if split_input.contains(&"PUBLIC") && split_input.contains(&"PRIVATE") { return Err(anyhow!("Can only filter `PUBLIC` OR `PRIVATE`, not both.")); @@ -539,7 +465,11 @@ pub fn get_startup_options( fn check_option(config: &mut Config, name: &str, value: &options::Value) -> Result<(), Error> { match name { n if n.eq(OPT_COLUMNS) => { - config.columns.value = validate_columns_input(value.as_str().unwrap())?; + config.columns.value = validate_columns_input( + value.as_str().unwrap(), + OPT_COLUMNS, + &Summary::FIELD_NAMES_AS_ARRAY, + )?; } n if n.eq(OPT_SORT_BY) => { config.sort_by.value = validate_sort_input(value.as_str().unwrap())? @@ -553,8 +483,11 @@ fn check_option(config: &mut Config, name: &str, value: &options::Value) -> Resu options_value_to_u64(OPT_FORWARDS, value.as_i64().unwrap(), 0, true)?; } n if n.eq(OPT_FORWARDS_COLUMNS) => { - config.forwards_columns.value = - validate_forwards_columns_input(value.as_str().unwrap())?; + config.forwards_columns.value = validate_columns_input( + value.as_str().unwrap(), + OPT_FORWARDS_COLUMNS, + &Forwards::FIELD_NAMES_AS_ARRAY, + )?; } n if n.eq(OPT_FORWARDS_FILTER_AMT) => { config.forwards_filter_amt_msat.value = @@ -571,7 +504,11 @@ fn check_option(config: &mut Config, name: &str, value: &options::Value) -> Resu config.pays.value = options_value_to_u64(OPT_PAYS, value.as_i64().unwrap(), 0, true)? } n if n.eq(OPT_PAYS_COLUMNS) => { - config.pays_columns.value = validate_pays_columns_input(value.as_str().unwrap())?; + config.pays_columns.value = validate_columns_input( + value.as_str().unwrap(), + OPT_PAYS_COLUMNS, + &Pays::FIELD_NAMES_AS_ARRAY, + )?; } n if n.eq(OPT_MAX_DESC_LENGTH) => { config.max_desc_length.value = @@ -582,8 +519,11 @@ fn check_option(config: &mut Config, name: &str, value: &options::Value) -> Resu options_value_to_u64(OPT_INVOICES, value.as_i64().unwrap(), 0, true)? } n if n.eq(OPT_INVOICES_COLUMNS) => { - config.invoices_columns.value = - validate_invoices_columns_input(value.as_str().unwrap())?; + config.invoices_columns.value = validate_columns_input( + value.as_str().unwrap(), + OPT_INVOICES_COLUMNS, + &Invoices::FIELD_NAMES_AS_ARRAY, + )?; } n if n.eq(OPT_MAX_LABEL_LENGTH) => { config.max_label_length.value = diff --git a/src/structs.rs b/src/structs.rs index f8ec7eb..40bb020 100644 --- a/src/structs.rs +++ b/src/structs.rs @@ -64,7 +64,7 @@ impl Config { value: { Summary::FIELD_NAMES_AS_ARRAY .into_iter() - .filter(|t| t != &"GRAPH_SATS") + .filter(|t| t != &"graph_sats") .map(ToString::to_string) .collect::>() }, @@ -267,7 +267,6 @@ pub struct PeerAvailability { } #[derive(Debug, Tabled, FieldNamesAsArray, Serialize)] -#[field_names_as_array(rename_all = "SCREAMING_SNAKE_CASE")] #[tabled(rename_all = "SCREAMING_SNAKE_CASE")] pub struct Summary { #[serde(skip_serializing)] diff --git a/src/tables.rs b/src/tables.rs index 43fecdf..a99830d 100644 --- a/src/tables.rs +++ b/src/tables.rs @@ -1167,7 +1167,9 @@ fn format_summary(config: &Config, sumtable: &mut Table) -> Result<(), Error> { config.style.value.apply(sumtable); for head in Summary::FIELD_NAMES_AS_ARRAY { if !config.columns.value.contains(&head.to_string()) { - sumtable.with(Disable::column(ByColumnName::new(head))); + sumtable.with(Disable::column(ByColumnName::new( + head.to_ascii_uppercase(), + ))); } } diff --git a/tests/test_summars.py b/tests/test_summars.py index 49e818d..083c063 100644 --- a/tests/test_summars.py +++ b/tests/test_summars.py @@ -81,7 +81,7 @@ def test_basic(node_factory, get_plugin): # noqa: F811 node.rpc.call("summars", {"summars-columns": 1}) with pytest.raises( - RpcError, match="`TEST` not found in " "valid column names" + RpcError, match="`test` not found in valid summars-columns names" ): node.rpc.call("summars", {"summars-columns": "TEST"}) @@ -118,6 +118,11 @@ def test_options(node_factory, get_plugin): # noqa: F811 for col2 in columns: if col != col2: assert col2 not in result["result"] + result = node.rpc.call("summars", {"summars-columns": col.lower()}) + assert col in result["result"] + for col2 in columns: + if col != col2: + assert col2 not in result["result"] result = node.rpc.call( "summars", {"summars-columns": "PPM,PEER_ID,IN_SATS,SCID"}, @@ -140,6 +145,13 @@ def test_options(node_factory, get_plugin): # noqa: F811 for col2 in pay_columns: if col != col2: assert col2 not in result["result"] + result = node.rpc.call( + "summars", {"summars-pays": 1, "summars-pays-columns": col.upper()} + ) + assert col in result["result"] + for col2 in pay_columns: + if col != col2: + assert col2 not in result["result"] result = node.rpc.call( "summars", { @@ -169,6 +181,14 @@ def test_options(node_factory, get_plugin): # noqa: F811 for col2 in invoice_columns: if col != col2: assert col2 not in result["result"] + result = node.rpc.call( + "summars", + {"summars-invoices": 1, "summars-invoices-columns": col.upper()}, + ) + assert col in result["result"] + for col2 in invoice_columns: + if col != col2: + assert col2 not in result["result"] result = node.rpc.call( "summars", @@ -197,6 +217,14 @@ def test_options(node_factory, get_plugin): # noqa: F811 for col2 in forwards_columns: if col != col2: assert col2 not in result["result"] + result = node.rpc.call( + "summars", + {"summars-forwards": 1, "summars-forwards-columns": col.upper()}, + ) + assert col in result["result"] + for col2 in forwards_columns: + if col != col2: + assert col2 not in result["result"] result = node.rpc.call( "summars", @@ -224,14 +252,41 @@ def test_options(node_factory, get_plugin): # noqa: F811 for col in columns: if col == "GRAPH_SATS": - continue - result = node.rpc.call( - "summars", - {"summars-columns": ",".join(columns), "summars-sort-by": col}, - ) - assert col in result["result"] + with pytest.raises(RpcError, match="Can not sort by `GRAPH_SATS`!"): + node.rpc.call( + "summars", + { + "summars-columns": ",".join(columns), + "summars-sort-by": col, + }, + ) + else: + result = node.rpc.call( + "summars", + {"summars-columns": ",".join(columns), "summars-sort-by": col}, + ) + assert col in result["result"] + if col == "GRAPH_SATS": + with pytest.raises(RpcError, match="Can not sort by `GRAPH_SATS`!"): + node.rpc.call( + "summars", + { + "summars-columns": ",".join(columns), + "summars-sort-by": col.lower(), + }, + ) + else: + result = node.rpc.call( + "summars", + { + "summars-columns": ",".join(columns), + "summars-sort-by": col.lower(), + }, + ) + assert col in result["result"] result = node.rpc.call("summars", {"summars-exclude-states": "OK"}) + assert "OK" not in result["result"] result = node.rpc.call("summars", {"summars-forwards": 1}) assert "forwards" in result["result"] @@ -308,17 +363,23 @@ def test_options(node_factory, get_plugin): # noqa: F811 def test_option_errors(node_factory, get_plugin): # noqa: F811 node = node_factory.get_node(options={"plugin": get_plugin}) - with pytest.raises(RpcError, match="not found in valid column names"): + with pytest.raises( + RpcError, match="not found in valid summars-columns names" + ): node.rpc.call("summars", {"summars-columns": "test"}) with pytest.raises(RpcError, match="Duplicate entry"): node.rpc.call("summars", {"summars-columns": "IN_SATS,IN_SATS"}) - with pytest.raises(RpcError, match="not found in valid column names"): + with pytest.raises( + RpcError, match="not found in valid summars-columns names" + ): node.rpc.call("summars", {"summars-columns": "PRIVATE"}) - with pytest.raises(RpcError, match="not found in valid column names"): + with pytest.raises( + RpcError, match="not found in valid summars-columns names" + ): node.rpc.call("summars", {"summars-columns": "OFFLINE"}) with pytest.raises( - RpcError, match="not found in valid forwards column names" + RpcError, match="not found in valid summars-forwards-columns names" ): node.rpc.call("summars", {"summars-forwards-columns": "test"}) with pytest.raises(RpcError, match="Duplicate entry"): @@ -326,7 +387,9 @@ def test_option_errors(node_factory, get_plugin): # noqa: F811 "summars", {"summars-forwards-columns": "in_channel,in_channel"} ) - with pytest.raises(RpcError, match="not found in valid pays column names"): + with pytest.raises( + RpcError, match="not found in valid summars-pays-columns names" + ): node.rpc.call("summars", {"summars-pays-columns": "test"}) with pytest.raises(RpcError, match="Duplicate entry"): node.rpc.call( @@ -334,7 +397,7 @@ def test_option_errors(node_factory, get_plugin): # noqa: F811 ) with pytest.raises( - RpcError, match="not found in valid invoices column names" + RpcError, match="not found in valid summars-invoices-columns names" ): node.rpc.call("summars", {"summars-invoices-columns": "test"}) with pytest.raises(RpcError, match="Duplicate entry"): @@ -355,8 +418,6 @@ def test_option_errors(node_factory, get_plugin): # noqa: F811 node.rpc.call("summars", {"summars-sort-by": 1}) with pytest.raises(RpcError, match="Not a valid column name"): node.rpc.call("summars", {"summars-sort-by": "TEST"}) - with pytest.raises(RpcError, match="Not a valid column name"): - node.rpc.call("summars", {"summars-sort-by": "GRAPH_SATS"}) with pytest.raises(RpcError, match="not a valid integer"): node.rpc.call("summars", {"summars-forwards": "TEST"}) @@ -541,6 +602,22 @@ def test_chanstates(node_factory, bitcoind, get_plugin): # noqa: F811 assert "_]" not in result["result"] assert "2 channels filtered" in result["result"] + result = l1.rpc.call("summars", {"summars-exclude-states": "ok"}) + assert "OK" not in result["result"] + assert "2 channels filtered" in result["result"] + + result = l1.rpc.call("summars", {"summars-exclude-states": "private"}) + assert "[P" not in result["result"] + assert "1 channel filtered" in result["result"] + + result = l1.rpc.call("summars", {"summars-exclude-states": "public"}) + assert "[_" not in result["result"] + assert "1 channel filtered" in result["result"] + + result = l1.rpc.call("summars", {"summars-exclude-states": "online"}) + assert "_]" not in result["result"] + assert "2 channels filtered" in result["result"] + l3.stop() wait_for( @@ -552,6 +629,10 @@ def test_chanstates(node_factory, bitcoind, get_plugin): # noqa: F811 assert "O]" not in result["result"] assert "1 channel filtered" in result["result"] + result = l1.rpc.call("summars", {"summars-exclude-states": "offline"}) + assert "O]" not in result["result"] + assert "1 channel filtered" in result["result"] + l1.rpc.close(chans[0]["short_channel_id"]) wait_for(