diff --git a/ta_lib/core/src/bool.rs b/ta_lib/core/src/bool.rs index fd0124c3..93797039 100644 --- a/ta_lib/core/src/bool.rs +++ b/ta_lib/core/src/bool.rs @@ -1,86 +1,87 @@ use crate::series::Series; use std::ops::{BitAnd, BitOr}; +macro_rules! scalar_comparison { + ($($name:ident, $op:tt);* $(;)?) => { + $( + pub fn $name(&self, scalar: f64) -> Series { + self.compare_scalar(scalar, |a, b| a $op b) + } + )* + }; +} + +macro_rules! series_comparison { + ($($name:ident, $op:tt);* $(;)?) => { + $( + pub fn $name(&self, rhs: &Series) -> Series { + self.compare(rhs, |a, b| a $op b) + } + )* + }; +} + +macro_rules! logical_operation { + ($($name:ident, $op:tt);* $(;)?) => { + $( + pub fn $name(&self, rhs: &Series) -> Series { + self.logical_op(rhs, |a, b| a $op b) + } + )* + }; +} + impl Series { - fn compare_series(&self, rhs: &Series, f: F) -> Series + fn compare_scalar(&self, scalar: f64, comparator: F) -> Series where F: Fn(f64, f64) -> bool, { - self.zip_with(rhs, |a, b| match (a, b) { - (Some(a_val), Some(b_val)) => Some(f(*a_val, *b_val)), - _ => None, - }) + self.fmap(|x| x.map(|v| comparator(*v, scalar))) } - fn compare(&self, scalar: f64, f: F) -> Series + fn compare(&self, rhs: &Series, comparator: F) -> Series where F: Fn(f64, f64) -> bool, { - self.fmap(|x| x.map(|v| f(*v, scalar))) - } - - pub fn eq(&self, scalar: f64) -> Series { - self.compare(scalar, |a, b| a == b) - } - - pub fn ne(&self, scalar: f64) -> Series { - self.compare(scalar, |a, b| a != b) - } - - pub fn gt(&self, scalar: f64) -> Series { - self.compare(scalar, |a, b| a > b) - } - - pub fn gte(&self, scalar: f64) -> Series { - self.compare(scalar, |a, b| a >= b) - } - - pub fn lt(&self, scalar: f64) -> Series { - self.compare(scalar, |a, b| a < b) - } - - pub fn lte(&self, scalar: f64) -> Series { - self.compare(scalar, |a, b| a <= b) - } - - pub fn eq_series(&self, rhs: &Series) -> Series { - self.compare_series(rhs, |a, b| a == b) - } - - pub fn ne_series(&self, rhs: &Series) -> Series { - self.compare_series(rhs, |a, b| a != b) - } - - pub fn gt_series(&self, rhs: &Series) -> Series { - self.compare_series(rhs, |a, b| a > b) - } - - pub fn gte_series(&self, rhs: &Series) -> Series { - self.compare_series(rhs, |a, b| a >= b) + self.zip_with(rhs, |a, b| match (a, b) { + (Some(a_val), Some(b_val)) => Some(comparator(*a_val, *b_val)), + _ => None, + }) } - pub fn lt_series(&self, rhs: &Series) -> Series { - self.compare_series(rhs, |a, b| a < b) + scalar_comparison! { + seq, ==; + sne, !=; + sgt, >; + sgte, >=; + slt, <; + slte, <=; } - pub fn lte_series(&self, rhs: &Series) -> Series { - self.compare_series(rhs, |a, b| a <= b) + series_comparison! { + eq, ==; + ne, !=; + gt, >; + gte, >=; + lt, <; + lte, <=; } } impl Series { - pub fn and_series(&self, rhs: &Series) -> Series { + fn logical_op(&self, rhs: &Series, operation: F) -> Series + where + F: Fn(bool, bool) -> bool, + { self.zip_with(rhs, |a, b| match (a, b) { - (Some(a_val), Some(b_val)) => Some(*a_val & *b_val), + (Some(a_val), Some(b_val)) => Some(operation(*a_val, *b_val)), _ => None, }) } - pub fn or_series(&self, rhs: &Series) -> Series { - self.zip_with(rhs, |a, b| match (a, b) { - (Some(a_val), Some(b_val)) => Some(*a_val | *b_val), - _ => None, - }) + logical_operation! { + and, &; + or, |; } } @@ -88,7 +89,7 @@ impl BitAnd for Series { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { - self.and_series(&rhs) + self.and(&rhs) } } @@ -96,7 +97,7 @@ impl BitOr for Series { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { - self.or_series(&rhs) + self.or(&rhs) } } @@ -110,7 +111,7 @@ mod tests { let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]); let expected: Series = Series::from([0.0, 0.0, 0.0, 0.0, 0.0]).into(); - let result = a.gt_series(&b) & a.lt_series(&b); + let result = a.gt(&b) & a.lt(&b); assert_eq!(result, expected); } @@ -121,7 +122,7 @@ mod tests { let b = Series::from([1.0, 1.0, 1.0, 1.0, 1.0]); let expected: Series = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]).into(); - let result = a.gt_series(&b) | a.lt_series(&b); + let result = a.gt(&b) | a.lt(&b); assert_eq!(result, expected); } diff --git a/ta_lib/patterns/src/barrier.rs b/ta_lib/patterns/src/barrier.rs index 9b78bea7..f33a0902 100644 --- a/ta_lib/patterns/src/barrier.rs +++ b/ta_lib/patterns/src/barrier.rs @@ -5,11 +5,11 @@ pub fn bullish(open: &[f64], low: &[f64], close: &[f64]) -> Vec { let low = Series::from(low); let close = Series::from(close); - (close.shift(1).gt_series(&open.shift(1)) - & close.shift(2).lt_series(&open.shift(2)) - & close.shift(3).lt_series(&open.shift(3)) - & low.shift(1).eq_series(&low.shift(2)) - & low.shift(2).eq_series(&low.shift(3))) + (close.shift(1).gt(&open.shift(1)) + & close.shift(2).lt(&open.shift(2)) + & close.shift(3).lt(&open.shift(3)) + & low.shift(1).eq(&low.shift(2)) + & low.shift(2).eq(&low.shift(3))) .into() } @@ -18,11 +18,11 @@ pub fn bearish(open: &[f64], high: &[f64], close: &[f64]) -> Vec { let high = Series::from(high); let close = Series::from(close); - (close.shift(1).lt_series(&open.shift(1)) - & close.shift(2).gt_series(&open.shift(2)) - & close.shift(3).gt_series(&open.shift(3)) - & high.shift(1).eq_series(&high.shift(2)) - & high.shift(2).eq_series(&high.shift(3))) + (close.shift(1).lt(&open.shift(1)) + & close.shift(2).gt(&open.shift(2)) + & close.shift(3).gt(&open.shift(3)) + & high.shift(1).eq(&high.shift(2)) + & high.shift(2).eq(&high.shift(3))) .into() } diff --git a/ta_lib/patterns/src/blockade.rs b/ta_lib/patterns/src/blockade.rs index 3cbb6336..07d2d175 100644 --- a/ta_lib/patterns/src/blockade.rs +++ b/ta_lib/patterns/src/blockade.rs @@ -6,17 +6,17 @@ pub fn bullish(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> Vec Vec Vec { let low = Series::from(low); let close = Series::from(close); - (close.shift(2).gt_series(&open.shift(2)) - & close.shift(1).gt_series(&open.shift(1)) - & open.shift(1).lt_series(&close.shift(2)) - & open.shift(1).eq_series(&low.shift(1)) - & close.shift(1).gt_series(&close.shift(2))) + (close.shift(2).gt(&open.shift(2)) + & close.shift(1).gt(&open.shift(1)) + & open.shift(1).lt(&close.shift(2)) + & open.shift(1).eq(&low.shift(1)) + & close.shift(1).gt(&close.shift(2))) .into() } @@ -18,11 +18,11 @@ pub fn bearish(open: &[f64], high: &[f64], close: &[f64]) -> Vec { let high = Series::from(high); let close = Series::from(close); - (close.shift(2).lt_series(&open.shift(2)) - & close.shift(1).lt_series(&open.shift(1)) - & open.shift(1).gt_series(&close.shift(2)) - & open.shift(1).eq_series(&high.shift(1)) - & close.shift(1).lt_series(&close.shift(2))) + (close.shift(2).lt(&open.shift(2)) + & close.shift(1).lt(&open.shift(1)) + & open.shift(1).gt(&close.shift(2)) + & open.shift(1).eq(&high.shift(1)) + & close.shift(1).lt(&close.shift(2))) .into() } diff --git a/ta_lib/patterns/src/breakaway.rs b/ta_lib/patterns/src/breakaway.rs index 0948bc6d..4c9490be 100644 --- a/ta_lib/patterns/src/breakaway.rs +++ b/ta_lib/patterns/src/breakaway.rs @@ -4,13 +4,13 @@ pub fn bullish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (close.gt_series(&open.shift(3)) - & close.gt_series(&open) - & close.shift(1).lt_series(&open.shift(1)) - & close.shift(2).lt_series(&open.shift(2)) - & close.shift(3).lt_series(&open.shift(3)) - & close.shift(4).lt_series(&open.shift(4)) - & open.shift(3).lt_series(&close.shift(4))) + (close.gt(&open.shift(3)) + & close.gt(&open) + & close.shift(1).lt(&open.shift(1)) + & close.shift(2).lt(&open.shift(2)) + & close.shift(3).lt(&open.shift(3)) + & close.shift(4).lt(&open.shift(4)) + & open.shift(3).lt(&close.shift(4))) .into() } @@ -18,13 +18,13 @@ pub fn bearish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (close.lt_series(&open.shift(3)) - & close.lt_series(&open) - & close.shift(1).gt_series(&open.shift(1)) - & close.shift(2).gt_series(&open.shift(2)) - & close.shift(3).gt_series(&open.shift(3)) - & close.shift(4).gt_series(&open.shift(4)) - & open.shift(3).gt_series(&close.shift(4))) + (close.lt(&open.shift(3)) + & close.lt(&open) + & close.shift(1).gt(&open.shift(1)) + & close.shift(2).gt(&open.shift(2)) + & close.shift(3).gt(&open.shift(3)) + & close.shift(4).gt(&open.shift(4)) + & open.shift(3).gt(&close.shift(4))) .into() } diff --git a/ta_lib/patterns/src/counterattack.rs b/ta_lib/patterns/src/counterattack.rs index 4086f986..e5ba03d6 100644 --- a/ta_lib/patterns/src/counterattack.rs +++ b/ta_lib/patterns/src/counterattack.rs @@ -4,10 +4,10 @@ pub fn bullish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (open.lt_series(&close.shift(1)) - & close.gt_series(&open) - & close.shift(1).lt_series(&open.shift(1)) - & close.eq_series(&close.shift(1))) + (open.lt(&close.shift(1)) + & close.gt(&open) + & close.shift(1).lt(&open.shift(1)) + & close.eq(&close.shift(1))) .into() } @@ -15,10 +15,10 @@ pub fn bearish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (open.gt_series(&close.shift(1)) - & close.lt_series(&open) - & close.shift(1).gt_series(&open.shift(1)) - & close.eq_series(&close.shift(1))) + (open.gt(&close.shift(1)) + & close.lt(&open) + & close.shift(1).gt(&open.shift(1)) + & close.eq(&close.shift(1))) .into() } diff --git a/ta_lib/patterns/src/doji.rs b/ta_lib/patterns/src/doji.rs index 2b1bb9df..6751ba2a 100644 --- a/ta_lib/patterns/src/doji.rs +++ b/ta_lib/patterns/src/doji.rs @@ -4,20 +4,14 @@ pub fn bullish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (close.gt_series(&open) - & close.shift(1).eq_series(&open.shift(1)) - & close.shift(2).lt_series(&open.shift(2))) - .into() + (close.gt(&open) & close.shift(1).eq(&open.shift(1)) & close.shift(2).lt(&open.shift(2))).into() } pub fn bearish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (close.lt_series(&open) - & close.shift(1).eq_series(&open.shift(1)) - & close.shift(2).gt_series(&open.shift(2))) - .into() + (close.lt(&open) & close.shift(1).eq(&open.shift(1)) & close.shift(2).gt(&open.shift(2))).into() } #[cfg(test)] diff --git a/ta_lib/patterns/src/double_doji.rs b/ta_lib/patterns/src/double_doji.rs index bf758be2..8afe3c02 100644 --- a/ta_lib/patterns/src/double_doji.rs +++ b/ta_lib/patterns/src/double_doji.rs @@ -4,9 +4,9 @@ pub fn bullish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (close.shift(1).eq_series(&open.shift(1)) - & close.shift(2).eq_series(&open.shift(2)) - & close.shift(3).lt_series(&open.shift(3))) + (close.shift(1).eq(&open.shift(1)) + & close.shift(2).eq(&open.shift(2)) + & close.shift(3).lt(&open.shift(3))) .into() } @@ -14,9 +14,9 @@ pub fn bearish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (close.shift(1).eq_series(&open.shift(1)) - & close.shift(2).eq_series(&open.shift(2)) - & close.shift(3).gt_series(&open.shift(3))) + (close.shift(1).eq(&open.shift(1)) + & close.shift(2).eq(&open.shift(2)) + & close.shift(3).gt(&open.shift(3))) .into() } diff --git a/ta_lib/patterns/src/marubozu.rs b/ta_lib/patterns/src/marubozu.rs index 295009f2..823a9dd2 100644 --- a/ta_lib/patterns/src/marubozu.rs +++ b/ta_lib/patterns/src/marubozu.rs @@ -6,9 +6,9 @@ pub fn bullish(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> Vec Vec Vec { let open = Series::from(open); let close = Series::from(close); - (close.lt_series(&open) - & close.shift(1).gt_series(&open.shift(1)) - & close.shift(2).gt_series(&open.shift(2)) - & close.lt_series(&open.shift(1)) - & close.gt_series(&close.shift(2)) - & open.shift(1).gt_series(&close.shift(2))) + (close.lt(&open) + & close.shift(1).gt(&open.shift(1)) + & close.shift(2).gt(&open.shift(2)) + & close.lt(&open.shift(1)) + & close.gt(&close.shift(2)) + & open.shift(1).gt(&close.shift(2))) .into() } @@ -17,12 +17,12 @@ pub fn bearish(open: &[f64], close: &[f64]) -> Vec { let open = Series::from(open); let close = Series::from(close); - (close.gt_series(&open) - & close.shift(1).lt_series(&open.shift(1)) - & close.shift(2).lt_series(&open.shift(2)) - & close.gt_series(&open.shift(1)) - & close.lt_series(&close.shift(2)) - & open.shift(1).lt_series(&close.shift(2))) + (close.gt(&open) + & close.shift(1).lt(&open.shift(1)) + & close.shift(2).lt(&open.shift(2)) + & close.gt(&open.shift(1)) + & close.lt(&close.shift(2)) + & open.shift(1).lt(&close.shift(2))) .into() } diff --git a/ta_lib/volume/src/mfi.rs b/ta_lib/volume/src/mfi.rs index 9e7ba8a2..b22f264b 100644 --- a/ta_lib/volume/src/mfi.rs +++ b/ta_lib/volume/src/mfi.rs @@ -8,8 +8,8 @@ pub fn mfi(hlc3: &[f64], volume: &[f64], period: usize) -> Vec { let volume_hlc3 = volume * hlc3; - let positive_volume = changes.gt(0.0) * &volume_hlc3; - let negative_volume = changes.lt(0.0) * &volume_hlc3; + let positive_volume = changes.sgt(0.0) * &volume_hlc3; + let negative_volume = changes.slt(0.0) * &volume_hlc3; let upper = positive_volume.sum(period); let lower = negative_volume.sum(period);