diff --git a/ta_lib/indicators/trend/src/supertrend.rs b/ta_lib/indicators/trend/src/supertrend.rs index 110dfd8b..d6456953 100644 --- a/ta_lib/indicators/trend/src/supertrend.rs +++ b/ta_lib/indicators/trend/src/supertrend.rs @@ -10,52 +10,36 @@ pub fn supertrend( let prev_close = close.shift(1); let len = close.len(); - let basic_up = source - &atr_mul; - let mut up = Series::zero(len); + let mut up = source - &atr_mul; + let mut dn = source + &atr_mul; + let mut direction = Series::empty(len); + let mut supertrend = Series::empty(len); - for _ in 0..len { - let prev_up = nz!(up.shift(1), basic_up); - up = iff!( - basic_up.sgt(&prev_up) | prev_close.slt(&prev_up), - basic_up, - prev_up - ); - } - - let basic_dn = source + &atr_mul; - let mut dn = Series::zero(len); - - for _ in 0..len { - let prev_dn = nz!(dn.shift(1), basic_dn); - dn = iff!( - basic_dn.slt(&prev_dn) | prev_close.sgt(&prev_dn), - basic_dn, - prev_dn - ); - } - - let mut direction = Series::one(len); let trend_bull = Series::one(len); let trend_bear = trend_bull.negate(); + direction = iff!(&atr.shift(1).na(), trend_bull, direction); for _ in 0..len { - direction = nz!(direction.shift(1), direction); + let prev_up = nz!(up.shift(1), up); + up = iff!(up.sgt(&prev_up) | prev_close.slt(&prev_up), up, prev_up); - let cond_bull = direction.seq(&MINUS_ONE) & close.sgt(&dn.shift(1)); - let cond_bear = direction.seq(&ONE) & close.slt(&up.shift(1)); + let prev_dn = nz!(dn.shift(1), dn); + dn = iff!(dn.slt(&prev_dn) | prev_close.sgt(&prev_dn), dn, prev_dn); + + let prev_supertrend = supertrend.shift(1); direction = iff!( - cond_bull, + &atr.shift(1).na(), trend_bull, - iff!(cond_bear, trend_bear, direction) + iff!( + prev_supertrend.seq(&prev_dn), + iff!(close.sgt(&dn), trend_bear, trend_bull), + iff!(close.slt(&up), trend_bull, trend_bear) + ) ); - } - let supertrend = iff!( - direction.seq(&MINUS_ONE), - dn, - iff!(direction.seq(&ONE), up, Series::zero(len)) - ); + supertrend = iff!(direction.seq(&MINUS_ONE), up, dn); + } (direction, supertrend) } @@ -69,16 +53,19 @@ mod tests { #[test] fn test_supertrend() { let high = Series::from([ - 6.5425, 6.5527, 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, - 6.5480, 6.5325, 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, 6.5543, 6.5563, + 6.622, 6.650, 6.664, 6.687, 6.695, 6.647, 6.624, 6.607, 6.609, 6.614, 6.609, 6.590, + 6.580, 6.580, 6.586, 6.587, 6.586, 6.574, 6.584, 6.577, 6.578, 6.583, 6.575, 6.577, + 6.578, 6.567, 6.575, 6.588, 6.596, 6.600, 6.587, 6.573, 6.566, 6.586, ]); let low = Series::from([ - 6.5156, 6.5195, 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, - 6.5229, 6.4982, 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, 6.5231, 6.5222, + 6.582, 6.614, 6.636, 6.637, 6.602, 6.606, 6.576, 6.579, 6.579, 6.587, 6.562, 6.566, + 6.559, 6.551, 6.567, 6.556, 6.560, 6.541, 6.543, 6.564, 6.560, 6.557, 6.557, 6.565, + 6.559, 6.552, 6.563, 6.567, 6.575, 6.570, 6.570, 6.541, 6.552, 6.555, ]); let close = Series::from([ - 6.5232, 6.5474, 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, - 6.5254, 6.5038, 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, 6.5527, 6.5316, + 6.617, 6.645, 6.641, 6.679, 6.627, 6.624, 6.593, 6.607, 6.588, 6.608, 6.581, 6.579, + 6.569, 6.574, 6.574, 6.578, 6.568, 6.543, 6.577, 6.571, 6.563, 6.575, 6.565, 6.573, + 6.567, 6.563, 6.571, 6.578, 6.596, 6.574, 6.572, 6.556, 6.555, 6.579, ]); let hl2 = median_price(&high, &low); let atr_period = 3; @@ -86,13 +73,15 @@ mod tests { let factor = 3.0; let expected_supertrend = vec![ - 6.4483504, 6.4491, 6.4747, 6.4747, 6.4747, 6.4747, 6.4747, 6.4747, 6.4747, 6.4747, - 6.4747, 6.4747, 6.4747, 6.5986347, 6.5774565, 6.5774565, 6.5774565, 6.5774565, - 6.5774565, 6.5774565, + 6.7220016, 6.7220016, 6.7220016, 6.7220016, 6.7220016, 6.7220016, 6.7220016, 6.7220016, + 6.71035, 6.7050667, 6.7022114, 6.679808, 6.658372, 6.653748, 6.653748, 6.653748, + 6.653748, 6.644672, 6.644672, 6.644672, 6.639718, 6.639718, 6.632763, 6.627509, + 6.6251726, 6.6122813, 6.6122813, 6.6122813, 6.6122813, 6.6122813, 6.6122813, 6.6122813, + 6.6122813, 6.6122813, ]; let expected_direction = vec![ - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, - -1.0, -1.0, -1.0, -1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ]; let (direction, supertrend) = supertrend(&hl2, &close, &atr, factor); diff --git a/ta_lib/strategies/signal/src/flip/supertrend_flip.rs b/ta_lib/strategies/signal/src/flip/supertrend_flip.rs index dc5686d1..7877fe8b 100644 --- a/ta_lib/strategies/signal/src/flip/supertrend_flip.rs +++ b/ta_lib/strategies/signal/src/flip/supertrend_flip.rs @@ -29,6 +29,291 @@ impl Signal for SupertrendFlipSignal { self.factor, ); - (direction.cross_over(&ZERO), direction.cross_under(&ZERO)) + (direction.cross_under(&ZERO), direction.cross_over(&ZERO)) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::collections::VecDeque; + + #[test] + fn test_supertrend_flip_signal() { + let signal = SupertrendFlipSignal::new(3.0, 3.0); + let data = VecDeque::from([ + OHLCV { + open: 6.161, + high: 6.161, + low: 6.136, + close: 6.146, + volume: 100.0, + }, + OHLCV { + open: 6.146, + high: 6.150, + low: 6.135, + close: 6.148, + volume: 100.0, + }, + OHLCV { + open: 6.148, + high: 6.157, + low: 6.143, + close: 6.155, + volume: 100.0, + }, + OHLCV { + open: 6.155, + high: 6.174, + low: 6.155, + close: 6.174, + volume: 100.0, + }, + OHLCV { + open: 6.174, + high: 6.179, + low: 6.163, + close: 6.173, + volume: 100.0, + }, + OHLCV { + open: 6.173, + high: 6.192, + low: 6.170, + close: 6.172, + volume: 100.0, + }, + OHLCV { + open: 6.172, + high: 6.184, + low: 6.167, + close: 6.182, + volume: 100.0, + }, + OHLCV { + open: 6.182, + high: 6.183, + low: 6.170, + close: 6.176, + volume: 100.0, + }, + OHLCV { + open: 6.176, + high: 6.185, + low: 6.161, + close: 6.167, + volume: 100.0, + }, + OHLCV { + open: 6.167, + high: 6.193, + low: 6.165, + close: 6.193, + volume: 100.0, + }, + OHLCV { + open: 6.193, + high: 6.213, + low: 6.188, + close: 6.201, + volume: 100.0, + }, + OHLCV { + open: 6.201, + high: 6.201, + low: 6.183, + close: 6.198, + volume: 100.0, + }, + OHLCV { + open: 6.198, + high: 6.205, + low: 6.186, + close: 6.188, + volume: 100.0, + }, + OHLCV { + open: 6.188, + high: 6.188, + low: 6.168, + close: 6.174, + volume: 100.0, + }, + OHLCV { + open: 6.174, + high: 6.180, + low: 6.164, + close: 6.176, + volume: 100.0, + }, + OHLCV { + open: 6.176, + high: 6.194, + low: 6.176, + close: 6.191, + volume: 100.0, + }, + OHLCV { + open: 6.191, + high: 6.191, + low: 6.169, + close: 6.175, + volume: 100.0, + }, + OHLCV { + open: 6.175, + high: 6.184, + low: 6.175, + close: 6.184, + volume: 100.0, + }, + OHLCV { + open: 6.184, + high: 6.194, + low: 6.176, + close: 6.188, + volume: 100.0, + }, + OHLCV { + open: 6.188, + high: 6.188, + low: 6.171, + close: 6.179, + volume: 100.0, + }, + OHLCV { + open: 6.179, + high: 6.188, + low: 6.171, + close: 6.184, + volume: 100.0, + }, + OHLCV { + open: 6.184, + high: 6.195, + low: 6.182, + close: 6.195, + volume: 100.0, + }, + OHLCV { + open: 6.195, + high: 6.212, + low: 6.193, + close: 6.210, + volume: 100.0, + }, + OHLCV { + open: 6.210, + high: 6.210, + low: 6.180, + close: 6.192, + volume: 100.0, + }, + OHLCV { + open: 6.192, + high: 6.193, + low: 6.152, + close: 6.173, + volume: 100.0, + }, + OHLCV { + open: 6.173, + high: 6.178, + low: 6.161, + close: 6.174, + volume: 100.0, + }, + OHLCV { + open: 6.174, + high: 6.189, + low: 6.161, + close: 6.189, + volume: 100.0, + }, + OHLCV { + open: 6.189, + high: 6.197, + low: 6.183, + close: 6.194, + volume: 100.0, + }, + OHLCV { + open: 6.194, + high: 6.205, + low: 6.189, + close: 6.202, + volume: 100.0, + }, + OHLCV { + open: 6.202, + high: 6.232, + low: 6.193, + close: 6.231, + volume: 100.0, + }, + OHLCV { + open: 6.231, + high: 6.236, + low: 6.215, + close: 6.218, + volume: 100.0, + }, + OHLCV { + open: 6.218, + high: 6.222, + low: 6.205, + close: 6.208, + volume: 100.0, + }, + OHLCV { + open: 6.208, + high: 6.233, + low: 6.208, + close: 6.224, + volume: 100.0, + }, + OHLCV { + open: 6.224, + high: 6.231, + low: 6.213, + close: 6.220, + volume: 100.0, + }, + OHLCV { + open: 6.220, + high: 6.224, + low: 6.196, + close: 6.208, + volume: 100.0, + }, + OHLCV { + open: 6.208, + high: 6.219, + low: 6.202, + close: 6.204, + volume: 100.0, + }, + ]); + let series = OHLCVSeries::from_data(&data); + + let (long_signal, short_signal) = signal.generate(&series); + + let expected_long_signal = vec![ + false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, true, false, + false, false, false, false, false, false, false, false, false, false, false, false, + ]; + let expected_short_signal = vec![ + false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, false, false, false, false, false, false, false, false, + ]; + + let result_long_signal: Vec = long_signal.into(); + let result_short_signal: Vec = short_signal.into(); + + assert_eq!(result_long_signal, expected_long_signal); + assert_eq!(result_short_signal, expected_short_signal); } }