From 07ee5d64eb2a679570589be37b38f66d28b65b98 Mon Sep 17 00:00:00 2001 From: m5l14i11 Date: Thu, 2 May 2024 02:20:52 +0300 Subject: [PATCH] upd --- strategy/generator/bootstrap/_trend_follow.py | 3 +- ta_lib/benches/indicators.rs | 4 +- ta_lib/indicators/trend/src/ce.rs | 137 ++++++++++-------- ta_lib/strategies/signal/src/flip/ce_flip.rs | 2 - 4 files changed, 79 insertions(+), 67 deletions(-) diff --git a/strategy/generator/bootstrap/_trend_follow.py b/strategy/generator/bootstrap/_trend_follow.py index 9c9615b5..b259a423 100644 --- a/strategy/generator/bootstrap/_trend_follow.py +++ b/strategy/generator/bootstrap/_trend_follow.py @@ -40,6 +40,7 @@ from strategy.generator.signal.bb.macd_bb import MacdBbSignal from strategy.generator.signal.bb.vwap_bb import VwapBbSignal from strategy.generator.signal.breakout.dch_two_ma import DchMa2BreakoutSignal +from strategy.generator.signal.flip.ce_flip import CeFlipSignal from strategy.generator.signal.flip.supertrend_flip import SupertrendFlipSignal from strategy.generator.signal.ma.ma2_rsi import Ma2RsiSignal from strategy.generator.signal.ma.ma3_cross import Ma3CrossSignal @@ -365,7 +366,7 @@ def _generate_signal(self, signal: TrendSignalType): if signal == TrendSignalType.FLIP: return np.random.choice( [ - # CeFlipSignal(), + CeFlipSignal(), SupertrendFlipSignal(), ] ) diff --git a/ta_lib/benches/indicators.rs b/ta_lib/benches/indicators.rs index 44f84034..227c1893 100644 --- a/ta_lib/benches/indicators.rs +++ b/ta_lib/benches/indicators.rs @@ -473,9 +473,9 @@ fn trend(c: &mut Criterion) { let factor = 3.0; let period = 20; - (high, low, close, atr, period, factor) + (close, atr, period, factor) }, - |(high, low, close, atr, period, factor)| ce(high, low, close, atr, *period, *factor), + |(close, atr, period, factor)| ce(close, atr, *period, *factor), criterion::BatchSize::SmallInput, ) }); diff --git a/ta_lib/indicators/trend/src/ce.rs b/ta_lib/indicators/trend/src/ce.rs index 19dc5a87..83542095 100644 --- a/ta_lib/indicators/trend/src/ce.rs +++ b/ta_lib/indicators/trend/src/ce.rs @@ -1,66 +1,46 @@ use core::prelude::*; pub fn ce( - high: &Series, - low: &Series, close: &Series, atr: &Series, period: usize, factor: f32, ) -> (Series, Series) { - let atr_multi = atr * factor; - - let short_stop = low.lowest(period) + &atr_multi; - let long_stop = high.highest(period) - &atr_multi; - let len = close.len(); - let mut short = Series::empty(len); - let mut long = Series::empty(len); + let atr_mul = atr * factor; - for _ in 0..len { - let prev_short = short.shift(1); - short = iff!( - close.sgt(&prev_short), - short_stop, - short_stop.min(&prev_short) - ); - short = iff!(prev_short.na(), short_stop, short); - - let prev_long = long.shift(1); - long = iff!(close.slt(&prev_long), long_stop, long.max(&prev_long)); - long = iff!(prev_long.na(), long_stop, long); - } + let basic_up = close.highest(period) - &atr_mul; + let mut up = Series::empty(len); + + let basic_dn = close.lowest(period) + &atr_mul; + let mut dn = Series::empty(len); - let mut direction = Series::empty(len); - let trend_up = Series::one(len); - let trend_dn = trend_up.negate(); - let trend_middle = Series::zero(len); let prev_close = close.shift(1); + let mut direction = Series::empty(len); - let long_switch = iff!( - close.sgte(&short.shift(1)) & prev_close.slt(&short.shift(1)), - trend_up, - trend_middle - ); - let short_switch = iff!( - close.slte(&long.shift(1)) & prev_close.sgt(&long.shift(1)), - trend_up, - trend_middle - ); + let trend_bull = Series::one(len); + let trend_bear = trend_bull.negate(); for _ in 0..len { - let prev_direction = direction.shift(1); - let cond_one = prev_direction.slte(&trend_middle) & long_switch.clone().into(); - let cond_two = prev_direction.sgte(&trend_middle) & short_switch.clone().into(); - - direction = iff!( - prev_direction.na(), - trend_middle, - iff!(cond_one, trend_up, iff!(cond_two, trend_dn, prev_direction)) - ); + let prev_up = up.shift(1); + up = iff!(prev_close.sgt(&prev_up), basic_up.max(&prev_up), basic_up); + + let prev_dn = dn.shift(1); + dn = iff!(prev_close.slt(&prev_dn), basic_dn.min(&prev_dn), basic_dn); + + direction = nz!(direction.shift(1), direction); + direction = iff!(close.sgte(&prev_dn), trend_bull, direction); + direction = iff!(close.slt(&prev_up), trend_bear, direction); } - let trend = iff!(direction.sgt(&trend_middle), long, short); + let first_non_empty = direction + .iter() + .find(|&&el| el.is_some()) + .map(|&el| -el.unwrap()); + + direction = direction.nz(first_non_empty); + + let trend = iff!(direction.seq(&ONE), up, dn); (direction, trend) } @@ -71,35 +51,68 @@ mod tests { use volatility::atr; #[test] - fn test_ce() { + fn test_ce_dn_up() { let high = Series::from([ - 2.0859, 2.0881, 2.0889, 2.0896, 2.0896, 2.0907, 2.0919, 2.1004, 2.0936, 2.0939, 2.0972, - 2.0974, 2.0997, 2.0982, 2.0982, 2.0974, 2.0942, 2.0924, 2.0924, + 4.8217, 4.8285, 4.8225, 4.8146, 4.8019, 4.8115, 4.8160, 4.8142, 4.8179, 4.8506, 4.8819, + 4.8833, 4.8466, 4.8448, 4.8586, ]); let low = Series::from([ - 2.0846, 2.0846, 2.0881, 2.0886, 2.0865, 2.0875, 2.0886, 2.0909, 2.0899, 2.0912, 2.0934, - 2.0947, 2.0946, 2.0973, 2.0942, 2.0920, 2.0908, 2.0917, 2.0869, + 4.7982, 4.8073, 4.8003, 4.7904, 4.7774, 4.7706, 4.7923, 4.7884, 4.7968, 4.8152, 4.8396, + 4.8393, 4.8275, 4.8300, 4.8397, ]); let close = Series::from([ - 2.0846, 2.0881, 2.0889, 2.0896, 2.0875, 2.0904, 2.0909, 2.0936, 2.0912, 2.0939, 2.0949, - 2.0952, 2.0973, 2.0982, 2.0974, 2.0942, 2.0917, 2.0924, 2.0869, + 4.8122, 4.8112, 4.8122, 4.7973, 4.7800, 4.8039, 4.8047, 4.8115, 4.8152, 4.8414, 4.8814, + 4.8423, 4.8439, 4.8429, 4.8535, ]); - let atr_period = 2; - let atr = atr(&high, &low, &close, Smooth::SMMA, atr_period); + let period = 3; + let atr = atr(&high, &low, &close, Smooth::SMMA, period); + + let factor = 2.0; + let expected_trend = vec![ + 4.8592, 4.8566666, 4.8563113, 4.8435073, 4.8271384, 4.8271384, 4.8271384, 4.8271384, + 4.8271384, 4.784503, 4.815269, 4.815269, 4.81972, 4.81972, 4.81972, + ]; + let expected_direction = vec![ + -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + ]; - let factor = 3.0; + let (direction, trend) = ce(&close, &atr, period, factor); + let result_direction: Vec = direction.into(); + let result_trend: Vec = trend.into(); + + assert_eq!(high.len(), low.len()); + assert_eq!(high.len(), close.len()); + assert_eq!(result_trend, expected_trend); + assert_eq!(result_direction, expected_direction); + } + + #[test] + fn test_ce_up_dn() { + let high = Series::from([ + 4.8565, 4.8791, 4.9177, 4.9199, 4.9614, 4.9570, 4.9486, 4.9010, 4.9085, 4.8713, 4.8591, + 4.8660, + ]); + let low = Series::from([ + 4.8113, 4.8447, 4.8696, 4.8858, 4.9128, 4.8955, 4.9005, 4.8551, 4.8447, 4.8325, 4.8333, + 4.8244, + ]); + let close = Series::from([ + 4.8558, 4.8706, 4.9108, 4.9128, 4.9565, 4.9013, 4.9005, 4.8868, 4.8527, 4.8553, 4.8532, + 4.8305, + ]); let period = 3; + let atr = atr(&high, &low, &close, Smooth::SMMA, period); + + let factor = 2.0; let expected_trend = vec![ - 2.0885003, 2.0885003, 2.0840998, 2.0856998, 2.0856998, 2.0856998, 2.0856998, 2.0856998, - 2.0856998, 2.0888877, 2.0888877, 2.0888877, 2.0888877, 2.0920804, 2.0920804, 2.0920804, - 2.100978, 2.0976512, 2.0976512, + 4.946201, 4.9390006, 4.9390006, 4.9390006, 4.8700404, 4.8700404, 4.8700404, 4.8700404, + 4.959112, 4.949508, 4.9344387, 4.912726, ]; let expected_direction = vec![ - 0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, - -1.0, -1.0, + -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, ]; - let (direction, trend) = ce(&high, &low, &close, &atr, period, factor); + let (direction, trend) = ce(&close, &atr, period, factor); let result_direction: Vec = direction.into(); let result_trend: Vec = trend.into(); diff --git a/ta_lib/strategies/signal/src/flip/ce_flip.rs b/ta_lib/strategies/signal/src/flip/ce_flip.rs index dcba581f..f1cac03d 100644 --- a/ta_lib/strategies/signal/src/flip/ce_flip.rs +++ b/ta_lib/strategies/signal/src/flip/ce_flip.rs @@ -25,8 +25,6 @@ impl Signal for CeFlipSignal { fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { let (direction, _) = ce( - data.high(), - data.low(), data.close(), &data.atr(self.atr_period), self.period,