Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed Aug 1, 2023
1 parent c5719da commit e06228d
Show file tree
Hide file tree
Showing 11 changed files with 155 additions and 160 deletions.
125 changes: 63 additions & 62 deletions ta_lib/core/src/bool.rs
Original file line number Diff line number Diff line change
@@ -1,102 +1,103 @@
use crate::series::Series;
use std::ops::{BitAnd, BitOr};

macro_rules! scalar_comparison {
($($name:ident, $op:tt);* $(;)?) => {
$(
pub fn $name(&self, scalar: f64) -> Series<bool> {
self.compare_scalar(scalar, |a, b| a $op b)
}
)*
};
}

macro_rules! series_comparison {
($($name:ident, $op:tt);* $(;)?) => {
$(
pub fn $name(&self, rhs: &Series<f64>) -> Series<bool> {
self.compare(rhs, |a, b| a $op b)
}
)*
};
}

macro_rules! logical_operation {
($($name:ident, $op:tt);* $(;)?) => {
$(
pub fn $name(&self, rhs: &Series<bool>) -> Series<bool> {
self.logical_op(rhs, |a, b| a $op b)
}
)*
};
}

impl Series<f64> {
fn compare_series<F>(&self, rhs: &Series<f64>, f: F) -> Series<bool>
fn compare_scalar<F>(&self, scalar: f64, comparator: F) -> Series<bool>
where
F: Fn(f64, f64) -> bool,
{
self.zip_with(rhs, |a, b| match (a, b) {
(Some(a_val), Some(b_val)) => Some(f(*a_val, *b_val)),
_ => None,
})
self.fmap(|x| x.map(|v| comparator(*v, scalar)))
}

fn compare<F>(&self, scalar: f64, f: F) -> Series<bool>
fn compare<F>(&self, rhs: &Series<f64>, comparator: F) -> Series<bool>
where
F: Fn(f64, f64) -> bool,
{
self.fmap(|x| x.map(|v| f(*v, scalar)))
}

pub fn eq(&self, scalar: f64) -> Series<bool> {
self.compare(scalar, |a, b| a == b)
}

pub fn ne(&self, scalar: f64) -> Series<bool> {
self.compare(scalar, |a, b| a != b)
}

pub fn gt(&self, scalar: f64) -> Series<bool> {
self.compare(scalar, |a, b| a > b)
}

pub fn gte(&self, scalar: f64) -> Series<bool> {
self.compare(scalar, |a, b| a >= b)
}

pub fn lt(&self, scalar: f64) -> Series<bool> {
self.compare(scalar, |a, b| a < b)
}

pub fn lte(&self, scalar: f64) -> Series<bool> {
self.compare(scalar, |a, b| a <= b)
}

pub fn eq_series(&self, rhs: &Series<f64>) -> Series<bool> {
self.compare_series(rhs, |a, b| a == b)
}

pub fn ne_series(&self, rhs: &Series<f64>) -> Series<bool> {
self.compare_series(rhs, |a, b| a != b)
}

pub fn gt_series(&self, rhs: &Series<f64>) -> Series<bool> {
self.compare_series(rhs, |a, b| a > b)
}

pub fn gte_series(&self, rhs: &Series<f64>) -> Series<bool> {
self.compare_series(rhs, |a, b| a >= b)
self.zip_with(rhs, |a, b| match (a, b) {
(Some(a_val), Some(b_val)) => Some(comparator(*a_val, *b_val)),
_ => None,
})
}

pub fn lt_series(&self, rhs: &Series<f64>) -> Series<bool> {
self.compare_series(rhs, |a, b| a < b)
scalar_comparison! {
seq, ==;
sne, !=;
sgt, >;
sgte, >=;
slt, <;
slte, <=;
}

pub fn lte_series(&self, rhs: &Series<f64>) -> Series<bool> {
self.compare_series(rhs, |a, b| a <= b)
series_comparison! {
eq, ==;
ne, !=;
gt, >;
gte, >=;
lt, <;
lte, <=;
}
}

impl Series<bool> {
pub fn and_series(&self, rhs: &Series<bool>) -> Series<bool> {
fn logical_op<F>(&self, rhs: &Series<bool>, operation: F) -> Series<bool>
where
F: Fn(bool, bool) -> bool,
{
self.zip_with(rhs, |a, b| match (a, b) {
(Some(a_val), Some(b_val)) => Some(*a_val & *b_val),
(Some(a_val), Some(b_val)) => Some(operation(*a_val, *b_val)),
_ => None,
})
}

pub fn or_series(&self, rhs: &Series<bool>) -> Series<bool> {
self.zip_with(rhs, |a, b| match (a, b) {
(Some(a_val), Some(b_val)) => Some(*a_val | *b_val),
_ => None,
})
logical_operation! {
and, &;
or, |;
}
}

impl BitAnd for Series<bool> {
type Output = Self;

fn bitand(self, rhs: Self) -> Self::Output {
self.and_series(&rhs)
self.and(&rhs)
}
}

impl BitOr for Series<bool> {
type Output = Self;

fn bitor(self, rhs: Self) -> Self::Output {
self.or_series(&rhs)
self.or(&rhs)
}
}

Expand All @@ -110,7 +111,7 @@ mod tests {
let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]);
let expected: Series<bool> = Series::from([0.0, 0.0, 0.0, 0.0, 0.0]).into();

let result = a.gt_series(&b) & a.lt_series(&b);
let result = a.gt(&b) & a.lt(&b);

assert_eq!(result, expected);
}
Expand All @@ -121,7 +122,7 @@ mod tests {
let b = Series::from([1.0, 1.0, 1.0, 1.0, 1.0]);
let expected: Series<bool> = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]).into();

let result = a.gt_series(&b) | a.lt_series(&b);
let result = a.gt(&b) | a.lt(&b);

assert_eq!(result, expected);
}
Expand Down
20 changes: 10 additions & 10 deletions ta_lib/patterns/src/barrier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ pub fn bullish(open: &[f64], low: &[f64], close: &[f64]) -> Vec<bool> {
let low = Series::from(low);
let close = Series::from(close);

(close.shift(1).gt_series(&open.shift(1))
& close.shift(2).lt_series(&open.shift(2))
& close.shift(3).lt_series(&open.shift(3))
& low.shift(1).eq_series(&low.shift(2))
& low.shift(2).eq_series(&low.shift(3)))
(close.shift(1).gt(&open.shift(1))
& close.shift(2).lt(&open.shift(2))
& close.shift(3).lt(&open.shift(3))
& low.shift(1).eq(&low.shift(2))
& low.shift(2).eq(&low.shift(3)))
.into()
}

Expand All @@ -18,11 +18,11 @@ pub fn bearish(open: &[f64], high: &[f64], close: &[f64]) -> Vec<bool> {
let high = Series::from(high);
let close = Series::from(close);

(close.shift(1).lt_series(&open.shift(1))
& close.shift(2).gt_series(&open.shift(2))
& close.shift(3).gt_series(&open.shift(3))
& high.shift(1).eq_series(&high.shift(2))
& high.shift(2).eq_series(&high.shift(3)))
(close.shift(1).lt(&open.shift(1))
& close.shift(2).gt(&open.shift(2))
& close.shift(3).gt(&open.shift(3))
& high.shift(1).eq(&high.shift(2))
& high.shift(2).eq(&high.shift(3)))
.into()
}

Expand Down
44 changes: 22 additions & 22 deletions ta_lib/patterns/src/blockade.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@ pub fn bullish(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> Vec<bo
let low = Series::from(low);
let close = Series::from(close);

(close.shift(1).gt_series(&open.shift(1))
& close.shift(4).lt_series(&open.shift(4))
& low.shift(1).gte_series(&low.shift(4))
& low.shift(1).lte_series(&close.shift(4))
& close.shift(1).gt_series(&high.shift(4))
& low.shift(2).gte_series(&low.shift(4))
& low.shift(2).lte_series(&close.shift(4))
& low.shift(3).gte_series(&low.shift(4))
& low.shift(3).lte_series(&close.shift(4))
& high.shift(2).lt_series(&high.shift(4))
& high.shift(3).lt_series(&high.shift(4)))
(close.shift(1).gt(&open.shift(1))
& close.shift(4).lt(&open.shift(4))
& low.shift(1).gte(&low.shift(4))
& low.shift(1).lte(&close.shift(4))
& close.shift(1).gt(&high.shift(4))
& low.shift(2).gte(&low.shift(4))
& low.shift(2).lte(&close.shift(4))
& low.shift(3).gte(&low.shift(4))
& low.shift(3).lte(&close.shift(4))
& high.shift(2).lt(&high.shift(4))
& high.shift(3).lt(&high.shift(4)))
.into()
}

Expand All @@ -26,17 +26,17 @@ pub fn bearish(open: &[f64], high: &[f64], low: &[f64], close: &[f64]) -> Vec<bo
let low = Series::from(low);
let close = Series::from(close);

(close.shift(1).lt_series(&open.shift(1))
& close.shift(4).gt_series(&open.shift(4))
& high.shift(1).lte_series(&high.shift(4))
& high.shift(1).gte_series(&close.shift(4))
& close.shift(1).lt_series(&low.shift(4))
& high.shift(2).lte_series(&high.shift(4))
& high.shift(2).gte_series(&close.shift(4))
& high.shift(3).lte_series(&high.shift(4))
& high.shift(3).gte_series(&close.shift(4))
& low.shift(2).gt_series(&low.shift(4))
& low.shift(3).gt_series(&low.shift(4)))
(close.shift(1).lt(&open.shift(1))
& close.shift(4).gt(&open.shift(4))
& high.shift(1).lte(&high.shift(4))
& high.shift(1).gte(&close.shift(4))
& close.shift(1).lt(&low.shift(4))
& high.shift(2).lte(&high.shift(4))
& high.shift(2).gte(&close.shift(4))
& high.shift(3).lte(&high.shift(4))
& high.shift(3).gte(&close.shift(4))
& low.shift(2).gt(&low.shift(4))
& low.shift(3).gt(&low.shift(4)))
.into()
}

Expand Down
20 changes: 10 additions & 10 deletions ta_lib/patterns/src/bottle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,11 @@ pub fn bullish(open: &[f64], low: &[f64], close: &[f64]) -> Vec<bool> {
let low = Series::from(low);
let close = Series::from(close);

(close.shift(2).gt_series(&open.shift(2))
& close.shift(1).gt_series(&open.shift(1))
& open.shift(1).lt_series(&close.shift(2))
& open.shift(1).eq_series(&low.shift(1))
& close.shift(1).gt_series(&close.shift(2)))
(close.shift(2).gt(&open.shift(2))
& close.shift(1).gt(&open.shift(1))
& open.shift(1).lt(&close.shift(2))
& open.shift(1).eq(&low.shift(1))
& close.shift(1).gt(&close.shift(2)))
.into()
}

Expand All @@ -18,11 +18,11 @@ pub fn bearish(open: &[f64], high: &[f64], close: &[f64]) -> Vec<bool> {
let high = Series::from(high);
let close = Series::from(close);

(close.shift(2).lt_series(&open.shift(2))
& close.shift(1).lt_series(&open.shift(1))
& open.shift(1).gt_series(&close.shift(2))
& open.shift(1).eq_series(&high.shift(1))
& close.shift(1).lt_series(&close.shift(2)))
(close.shift(2).lt(&open.shift(2))
& close.shift(1).lt(&open.shift(1))
& open.shift(1).gt(&close.shift(2))
& open.shift(1).eq(&high.shift(1))
& close.shift(1).lt(&close.shift(2)))
.into()
}

Expand Down
28 changes: 14 additions & 14 deletions ta_lib/patterns/src/breakaway.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,27 @@ pub fn bullish(open: &[f64], close: &[f64]) -> Vec<bool> {
let open = Series::from(open);
let close = Series::from(close);

(close.gt_series(&open.shift(3))
& close.gt_series(&open)
& close.shift(1).lt_series(&open.shift(1))
& close.shift(2).lt_series(&open.shift(2))
& close.shift(3).lt_series(&open.shift(3))
& close.shift(4).lt_series(&open.shift(4))
& open.shift(3).lt_series(&close.shift(4)))
(close.gt(&open.shift(3))
& close.gt(&open)
& close.shift(1).lt(&open.shift(1))
& close.shift(2).lt(&open.shift(2))
& close.shift(3).lt(&open.shift(3))
& close.shift(4).lt(&open.shift(4))
& open.shift(3).lt(&close.shift(4)))
.into()
}

pub fn bearish(open: &[f64], close: &[f64]) -> Vec<bool> {
let open = Series::from(open);
let close = Series::from(close);

(close.lt_series(&open.shift(3))
& close.lt_series(&open)
& close.shift(1).gt_series(&open.shift(1))
& close.shift(2).gt_series(&open.shift(2))
& close.shift(3).gt_series(&open.shift(3))
& close.shift(4).gt_series(&open.shift(4))
& open.shift(3).gt_series(&close.shift(4)))
(close.lt(&open.shift(3))
& close.lt(&open)
& close.shift(1).gt(&open.shift(1))
& close.shift(2).gt(&open.shift(2))
& close.shift(3).gt(&open.shift(3))
& close.shift(4).gt(&open.shift(4))
& open.shift(3).gt(&close.shift(4)))
.into()
}

Expand Down
16 changes: 8 additions & 8 deletions ta_lib/patterns/src/counterattack.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,21 +4,21 @@ pub fn bullish(open: &[f64], close: &[f64]) -> Vec<bool> {
let open = Series::from(open);
let close = Series::from(close);

(open.lt_series(&close.shift(1))
& close.gt_series(&open)
& close.shift(1).lt_series(&open.shift(1))
& close.eq_series(&close.shift(1)))
(open.lt(&close.shift(1))
& close.gt(&open)
& close.shift(1).lt(&open.shift(1))
& close.eq(&close.shift(1)))
.into()
}

pub fn bearish(open: &[f64], close: &[f64]) -> Vec<bool> {
let open = Series::from(open);
let close = Series::from(close);

(open.gt_series(&close.shift(1))
& close.lt_series(&open)
& close.shift(1).gt_series(&open.shift(1))
& close.eq_series(&close.shift(1)))
(open.gt(&close.shift(1))
& close.lt(&open)
& close.shift(1).gt(&open.shift(1))
& close.eq(&close.shift(1)))
.into()
}

Expand Down
Loading

0 comments on commit e06228d

Please sign in to comment.