Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed May 1, 2024
1 parent 41c209b commit 07ee5d6
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 67 deletions.
3 changes: 2 additions & 1 deletion strategy/generator/bootstrap/_trend_follow.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from strategy.generator.signal.bb.macd_bb import MacdBbSignal
from strategy.generator.signal.bb.vwap_bb import VwapBbSignal
from strategy.generator.signal.breakout.dch_two_ma import DchMa2BreakoutSignal
from strategy.generator.signal.flip.ce_flip import CeFlipSignal
from strategy.generator.signal.flip.supertrend_flip import SupertrendFlipSignal
from strategy.generator.signal.ma.ma2_rsi import Ma2RsiSignal
from strategy.generator.signal.ma.ma3_cross import Ma3CrossSignal
Expand Down Expand Up @@ -365,7 +366,7 @@ def _generate_signal(self, signal: TrendSignalType):
if signal == TrendSignalType.FLIP:
return np.random.choice(
[
# CeFlipSignal(),
CeFlipSignal(),
SupertrendFlipSignal(),
]
)
Expand Down
4 changes: 2 additions & 2 deletions ta_lib/benches/indicators.rs
Original file line number Diff line number Diff line change
Expand Up @@ -473,9 +473,9 @@ fn trend(c: &mut Criterion) {
let factor = 3.0;
let period = 20;

(high, low, close, atr, period, factor)
(close, atr, period, factor)
},
|(high, low, close, atr, period, factor)| ce(high, low, close, atr, *period, *factor),
|(close, atr, period, factor)| ce(close, atr, *period, *factor),
criterion::BatchSize::SmallInput,
)
});
Expand Down
137 changes: 75 additions & 62 deletions ta_lib/indicators/trend/src/ce.rs
Original file line number Diff line number Diff line change
@@ -1,66 +1,46 @@
use core::prelude::*;

pub fn ce(
high: &Series<f32>,
low: &Series<f32>,
close: &Series<f32>,
atr: &Series<f32>,
period: usize,
factor: f32,
) -> (Series<f32>, Series<f32>) {
let atr_multi = atr * factor;

let short_stop = low.lowest(period) + &atr_multi;
let long_stop = high.highest(period) - &atr_multi;

let len = close.len();
let mut short = Series::empty(len);
let mut long = Series::empty(len);
let atr_mul = atr * factor;

for _ in 0..len {
let prev_short = short.shift(1);
short = iff!(
close.sgt(&prev_short),
short_stop,
short_stop.min(&prev_short)
);
short = iff!(prev_short.na(), short_stop, short);

let prev_long = long.shift(1);
long = iff!(close.slt(&prev_long), long_stop, long.max(&prev_long));
long = iff!(prev_long.na(), long_stop, long);
}
let basic_up = close.highest(period) - &atr_mul;
let mut up = Series::empty(len);

let basic_dn = close.lowest(period) + &atr_mul;
let mut dn = Series::empty(len);

let mut direction = Series::empty(len);
let trend_up = Series::one(len);
let trend_dn = trend_up.negate();
let trend_middle = Series::zero(len);
let prev_close = close.shift(1);
let mut direction = Series::empty(len);

let long_switch = iff!(
close.sgte(&short.shift(1)) & prev_close.slt(&short.shift(1)),
trend_up,
trend_middle
);
let short_switch = iff!(
close.slte(&long.shift(1)) & prev_close.sgt(&long.shift(1)),
trend_up,
trend_middle
);
let trend_bull = Series::one(len);
let trend_bear = trend_bull.negate();

for _ in 0..len {
let prev_direction = direction.shift(1);
let cond_one = prev_direction.slte(&trend_middle) & long_switch.clone().into();
let cond_two = prev_direction.sgte(&trend_middle) & short_switch.clone().into();

direction = iff!(
prev_direction.na(),
trend_middle,
iff!(cond_one, trend_up, iff!(cond_two, trend_dn, prev_direction))
);
let prev_up = up.shift(1);
up = iff!(prev_close.sgt(&prev_up), basic_up.max(&prev_up), basic_up);

let prev_dn = dn.shift(1);
dn = iff!(prev_close.slt(&prev_dn), basic_dn.min(&prev_dn), basic_dn);

direction = nz!(direction.shift(1), direction);
direction = iff!(close.sgte(&prev_dn), trend_bull, direction);
direction = iff!(close.slt(&prev_up), trend_bear, direction);
}

let trend = iff!(direction.sgt(&trend_middle), long, short);
let first_non_empty = direction
.iter()
.find(|&&el| el.is_some())
.map(|&el| -el.unwrap());

direction = direction.nz(first_non_empty);

let trend = iff!(direction.seq(&ONE), up, dn);

(direction, trend)
}
Expand All @@ -71,35 +51,68 @@ mod tests {
use volatility::atr;

#[test]
fn test_ce() {
fn test_ce_dn_up() {
let high = Series::from([
2.0859, 2.0881, 2.0889, 2.0896, 2.0896, 2.0907, 2.0919, 2.1004, 2.0936, 2.0939, 2.0972,
2.0974, 2.0997, 2.0982, 2.0982, 2.0974, 2.0942, 2.0924, 2.0924,
4.8217, 4.8285, 4.8225, 4.8146, 4.8019, 4.8115, 4.8160, 4.8142, 4.8179, 4.8506, 4.8819,
4.8833, 4.8466, 4.8448, 4.8586,
]);
let low = Series::from([
2.0846, 2.0846, 2.0881, 2.0886, 2.0865, 2.0875, 2.0886, 2.0909, 2.0899, 2.0912, 2.0934,
2.0947, 2.0946, 2.0973, 2.0942, 2.0920, 2.0908, 2.0917, 2.0869,
4.7982, 4.8073, 4.8003, 4.7904, 4.7774, 4.7706, 4.7923, 4.7884, 4.7968, 4.8152, 4.8396,
4.8393, 4.8275, 4.8300, 4.8397,
]);
let close = Series::from([
2.0846, 2.0881, 2.0889, 2.0896, 2.0875, 2.0904, 2.0909, 2.0936, 2.0912, 2.0939, 2.0949,
2.0952, 2.0973, 2.0982, 2.0974, 2.0942, 2.0917, 2.0924, 2.0869,
4.8122, 4.8112, 4.8122, 4.7973, 4.7800, 4.8039, 4.8047, 4.8115, 4.8152, 4.8414, 4.8814,
4.8423, 4.8439, 4.8429, 4.8535,
]);
let atr_period = 2;
let atr = atr(&high, &low, &close, Smooth::SMMA, atr_period);
let period = 3;
let atr = atr(&high, &low, &close, Smooth::SMMA, period);

let factor = 2.0;
let expected_trend = vec![
4.8592, 4.8566666, 4.8563113, 4.8435073, 4.8271384, 4.8271384, 4.8271384, 4.8271384,
4.8271384, 4.784503, 4.815269, 4.815269, 4.81972, 4.81972, 4.81972,
];
let expected_direction = vec![
-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
];

let factor = 3.0;
let (direction, trend) = ce(&close, &atr, period, factor);
let result_direction: Vec<f32> = direction.into();
let result_trend: Vec<f32> = trend.into();

assert_eq!(high.len(), low.len());
assert_eq!(high.len(), close.len());
assert_eq!(result_trend, expected_trend);
assert_eq!(result_direction, expected_direction);
}

#[test]
fn test_ce_up_dn() {
let high = Series::from([
4.8565, 4.8791, 4.9177, 4.9199, 4.9614, 4.9570, 4.9486, 4.9010, 4.9085, 4.8713, 4.8591,
4.8660,
]);
let low = Series::from([
4.8113, 4.8447, 4.8696, 4.8858, 4.9128, 4.8955, 4.9005, 4.8551, 4.8447, 4.8325, 4.8333,
4.8244,
]);
let close = Series::from([
4.8558, 4.8706, 4.9108, 4.9128, 4.9565, 4.9013, 4.9005, 4.8868, 4.8527, 4.8553, 4.8532,
4.8305,
]);
let period = 3;
let atr = atr(&high, &low, &close, Smooth::SMMA, period);

let factor = 2.0;
let expected_trend = vec![
2.0885003, 2.0885003, 2.0840998, 2.0856998, 2.0856998, 2.0856998, 2.0856998, 2.0856998,
2.0856998, 2.0888877, 2.0888877, 2.0888877, 2.0888877, 2.0920804, 2.0920804, 2.0920804,
2.100978, 2.0976512, 2.0976512,
4.946201, 4.9390006, 4.9390006, 4.9390006, 4.8700404, 4.8700404, 4.8700404, 4.8700404,
4.959112, 4.949508, 4.9344387, 4.912726,
];
let expected_direction = vec![
0.0, 0.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0,
-1.0, -1.0,
-1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0,
];

let (direction, trend) = ce(&high, &low, &close, &atr, period, factor);
let (direction, trend) = ce(&close, &atr, period, factor);
let result_direction: Vec<f32> = direction.into();
let result_trend: Vec<f32> = trend.into();

Expand Down
2 changes: 0 additions & 2 deletions ta_lib/strategies/signal/src/flip/ce_flip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,6 @@ impl Signal for CeFlipSignal {

fn generate(&self, data: &OHLCVSeries) -> (Series<bool>, Series<bool>) {
let (direction, _) = ce(
data.high(),
data.low(),
data.close(),
&data.atr(self.atr_period),
self.period,
Expand Down

0 comments on commit 07ee5d6

Please sign in to comment.