Skip to content

Commit

Permalink
upd
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed Apr 19, 2024
1 parent 0dc0433 commit 68663de
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 315 deletions.
344 changes: 122 additions & 222 deletions ta_lib/Cargo.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion ta_lib/core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,10 @@ mod traits;

pub mod prelude {
pub use crate::constants::*;
pub use crate::iff;
pub use crate::series::Series;
pub use crate::smoothing::Smooth;
pub use crate::traits::*;
pub use crate::{iff, nz};
}

pub use prelude::*;
26 changes: 26 additions & 0 deletions ta_lib/core/src/macros.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,20 @@ macro_rules! iff {
}};
}

#[macro_export]
macro_rules! nz {
($source:expr, $fill:expr) => {{
$source
.iter()
.zip($fill.iter())
.map(|(source, fill)| match source {
Some(val) => Some(*val),
_ => *fill,
})
.collect::<Series<_>>()
}};
}

#[cfg(test)]
mod tests {
use crate::{Comparator, Series};
Expand All @@ -31,4 +45,16 @@ mod tests {

assert_eq!(result, expected);
}

#[test]
fn test_nz() {
let source = Series::from([f32::NAN, 5.0, 4.0, 3.0, 5.0]);
let fill = Series::from([1.0, 0.5, 5.0, 2.0, 8.0]);

let expected = Series::from([1.0, 5.0, 4.0, 3.0, 5.0]);

let result = nz!(source, fill);

assert_eq!(result, expected);
}
}
9 changes: 5 additions & 4 deletions ta_lib/core/src/smoothing.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use crate::iff;
use crate::series::Series;
use crate::{iff, nz};

#[derive(Copy, Clone)]
pub enum Smooth {
Expand All @@ -16,15 +16,15 @@ pub enum Smooth {
impl Series<f32> {
pub fn ew(&self, alpha: &Series<f32>, seed: &Series<f32>) -> Self {
let len = self.len();
let mut sum = Series::empty(len);
let mut sum = Series::zero(len);

for _ in 0..len {
let prev = sum.shift(1);

sum = iff!(
prev.na(),
seed,
alpha * self + (1. - alpha) * prev.nz(Some(0.))
alpha * self + (1. - alpha) * nz!(prev, sum)
)
}

Expand Down Expand Up @@ -56,8 +56,9 @@ impl Series<f32> {

fn smma(&self, period: usize) -> Self {
let alpha = Series::fill(1. / (period as f32), self.len());
let seed = self.ma(period);

self.ew(&alpha, &self.ma(period))
self.ew(&alpha, &seed)
}

fn wma(&self, period: usize) -> Self {
Expand Down
9 changes: 4 additions & 5 deletions ta_lib/indicators/trend/src/ema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,12 @@ mod tests {
#[test]
fn test_ema() {
let source = Series::from([
6.8575, 6.855, 6.858, 6.86, 6.8480, 6.8575, 6.864, 6.8565, 6.8455, 6.8450, 6.8365,
6.8310, 6.8355, 6.8360, 6.8345, 6.8285, 6.8395,
6.5232, 6.5474, 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452,
6.5254, 6.5038, 6.4614, 6.4854, 6.4966,
]);
let expected = vec![
6.8575, 6.85625, 6.857125, 6.8585625, 6.853281, 6.8553905, 6.8596954, 6.858098,
6.851799, 6.848399, 6.8424497, 6.8367248, 6.836112, 6.8360558, 6.8352776, 6.8318887,
6.8356943,
6.5232, 6.5353003, 6.5447, 6.5694504, 6.552125, 6.5235624, 6.526681, 6.544141,
6.5332203, 6.5316105, 6.5384054, 6.531903, 6.5178514, 6.489626, 6.487513, 6.492057,
];

let result: Vec<f32> = ema(&source, 3).into();
Expand Down
80 changes: 50 additions & 30 deletions ta_lib/indicators/trend/src/supertrend.rs
Original file line number Diff line number Diff line change
@@ -1,53 +1,61 @@
use core::prelude::*;

pub fn supertrend(
hl2: &Series<f32>,
source: &Series<f32>,
close: &Series<f32>,
atr: &Series<f32>,
factor: f32,
) -> (Series<f32>, Series<f32>) {
let atr_mul = atr * factor;
let prev_close = close.shift(1);
let len = close.len();

let mut up = hl2 - &atr_mul;
let mut dn = hl2 + &atr_mul;
let basic_up = source - &atr_mul;
let mut up = Series::zero(len);

let len = hl2.len();
for _ in 0..len {
let prev_up = nz!(up.shift(1), basic_up);
up = iff!(
basic_up.sgt(&prev_up) | prev_close.slt(&prev_up),
basic_up,
prev_up
);
}

let mut prev_up = up.shift(1);
let mut prev_dn = dn.shift(1);

prev_up = iff!(prev_up.na(), up, prev_up);
prev_dn = iff!(prev_dn.na(), dn, prev_dn);
let basic_dn = source + &atr_mul;
let mut dn = Series::zero(len);

for _ in 0..len {
let prev_close = close.shift(1);
up = iff!(prev_close.sgt(&prev_up), up.max(&prev_up), up);
dn = iff!(prev_close.slt(&prev_dn), dn.min(&prev_dn), dn);
let prev_dn = nz!(dn.shift(1), basic_dn);
dn = iff!(
basic_dn.slt(&prev_dn) | prev_close.sgt(&prev_dn),
basic_dn,
prev_dn
);
}

let mut direction = Series::one(len);
let trend_bull = Series::one(len);
let trend_bear = trend_bull.negate();

for _ in 0..len {
let prev_direction = direction.shift(1);
direction = nz!(direction.shift(1), direction);

direction = iff!(prev_direction.na(), direction, prev_direction);
let cond_bull = direction.seq(&MINUS_ONE) & close.sgt(&dn.shift(1));
let cond_bear = direction.seq(&ONE) & close.slt(&up.shift(1));

direction = iff!(
direction.seq(&ONE) & close.slt(&prev_up),
trend_bear,
direction
);
direction = iff!(
direction.seq(&MINUS_ONE) & close.sgt(&prev_dn),
cond_bull,
trend_bull,
direction
iff!(cond_bear, trend_bear, direction)
);
}

let mut supertrend = iff!(direction.seq(&MINUS_ONE), dn, Series::zero(len));
supertrend = iff!(direction.seq(&ONE), up, supertrend);
let supertrend = iff!(
direction.seq(&MINUS_ONE),
dn,
iff!(direction.seq(&ONE), up, Series::zero(len))
);

(direction, supertrend)
}
Expand All @@ -60,19 +68,31 @@ mod tests {

#[test]
fn test_supertrend() {
let high = Series::from([6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630]);
let low = Series::from([6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126]);
let close = Series::from([6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223]);
let high = Series::from([
6.5425, 6.5527, 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497,
6.5480, 6.5325, 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, 6.5543, 6.5563,
]);
let low = Series::from([
6.5156, 6.5195, 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206,
6.5229, 6.4982, 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, 6.5231, 6.5222,
]);
let close = Series::from([
6.5232, 6.5474, 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452,
6.5254, 6.5038, 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, 6.5527, 6.5316,
]);
let hl2 = median_price(&high, &low);
let atr_period = 2;
let atr_period = 3;
let atr = atr(&high, &low, &close, Smooth::SMMA, atr_period);

let factor = 3.0;
let expected_supertrend = vec![
6.4963, 6.4963, 6.446601, 6.403225, 6.3497434, 6.376522, 6.3796854,
6.4483504, 6.4491, 6.4747, 6.4747, 6.4747, 6.4747, 6.4747, 6.4747, 6.4747, 6.4747,
6.4747, 6.4747, 6.4747, 6.5986347, 6.5774565, 6.5774565, 6.5774565, 6.5774565,
6.5774565, 6.5774565,
];
let expected_direction = vec![
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0,
1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0,
-1.0, -1.0, -1.0, -1.0,
];

let (direction, supertrend) = supertrend(&hl2, &close, &atr, factor);
Expand All @@ -81,7 +101,7 @@ mod tests {

assert_eq!(high.len(), low.len());
assert_eq!(high.len(), close.len());
assert_eq!(result_supertrend, expected_supertrend);
assert_eq!(result_direction, expected_direction);
assert_eq!(result_supertrend, expected_supertrend);
}
}
22 changes: 10 additions & 12 deletions ta_lib/indicators/trend/src/wma.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,21 +10,19 @@ mod tests {

#[test]
fn test_wma() {
let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]);
let source = Series::from([
6.5232, 6.5474, 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452,
6.5254, 6.5038, 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, 6.5527, 6.5316,
]);
let period = 3;
let epsilon = 0.001;
let expected = [0.0, 0.0, 2.333333, 3.333333, 4.333333];
let expected = [
0.0, 0.0, 6.5467167, 6.573034, 6.557817, 6.5248, 6.5190334, 6.5399, 6.5366497,
6.5326996, 6.5363164, 6.5327663, 6.5179005, 6.4862, 6.4804664, 6.487, 6.5022836,
6.516834, 6.5373, 6.5378666,
];

let result: Vec<f32> = wma(&source, period).into();

for i in 0..source.len() {
assert!(
(result[i] - expected[i]).abs() < epsilon,
"at position {}: {} != {}",
i,
result[i],
expected[i]
)
}
assert_eq!(result, expected);
}
}
57 changes: 22 additions & 35 deletions ta_lib/indicators/volatility/src/atr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,48 +18,35 @@ mod tests {
#[test]
fn test_atr_smma() {
let high = Series::from([
6.8430, 6.8660, 6.8685, 6.8690, 6.865, 6.8595, 6.8565, 6.862, 6.859, 6.86, 6.8580,
6.8605, 6.8620, 6.86, 6.859, 6.8670, 6.8640, 6.8575, 6.8485, 6.8450, 6.8365, 6.84,
6.8385, 6.8365, 6.8345, 6.8395,
6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325,
6.5065, 6.4866, 6.5536, 6.5142, 6.5294,
]);
let low = Series::from([
6.8380, 6.8430, 6.8595, 6.8640, 6.8435, 6.8445, 6.8510, 6.8560, 6.8520, 6.8530, 6.8550,
6.8550, 6.8565, 6.8475, 6.8480, 6.8535, 6.8565, 6.8455, 6.8445, 6.8365, 6.8310, 6.8310,
6.8345, 6.8325, 6.8275, 6.8285,
6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982,
6.4560, 6.4614, 6.4798, 6.4903, 6.5066,
]);
let close = Series::from([
6.6430, 6.8595, 6.8680, 6.8650, 6.8445, 6.8560, 6.8565, 6.8590, 6.8530, 6.8575, 6.855,
6.858, 6.86, 6.8480, 6.8575, 6.864, 6.8565, 6.8455, 6.8450, 6.8365, 6.8310, 6.8355,
6.8360, 6.8345, 6.8285, 6.8395,
6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038,
6.4614, 6.4854, 6.4966, 6.5117, 6.5270,
]);
let period = 3;
let expected = [
0.0050001144,
0.07766677,
0.054777943,
0.038185332,
0.032623433,
0.02674891,
0.019666044,
0.01511071,
0.01240713,
0.010604743,
0.0080697555,
0.0072131166,
0.006642024,
0.008594777,
0.00939657,
0.010764452,
0.009676199,
0.010450827,
0.008300454,
0.008366844,
0.007411334,
0.00794099,
0.0066273883,
0.0057516545,
0.006167759,
0.007778558,
0.01819992,
0.03396654,
0.044011116,
0.05464077,
0.05036052,
0.05254035,
0.051826984,
0.040484603,
0.036689714,
0.032826394,
0.033317544,
0.039045,
0.03442996,
0.047553174,
0.03966879,
0.03404585,
];

let result: Vec<f32> = atr(&high, &low, &close, Smooth::SMMA, period).into();
Expand Down
36 changes: 31 additions & 5 deletions ta_lib/indicators/volatility/src/tr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ pub fn tr(high: &Series<f32>, low: &Series<f32>, close: &Series<f32>) -> Series<
high.shift(1).na(),
diff,
diff.max(&(high - &prev_close).abs())
.max(&(low.negate() + &prev_close).abs())
.max(&(low - &prev_close).abs())
)
}

Expand All @@ -18,10 +18,36 @@ mod tests {

#[test]
fn test_true_range() {
let high = Series::from([50.0, 60.0, 55.0, 70.0]);
let low = Series::from([40.0, 50.0, 45.0, 60.0]);
let close = Series::from([45.0, 55.0, 50.0, 65.0]);
let expected = vec![10.0, 15.0, 10.0, 20.0];
let high = Series::from([
6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325,
6.5065, 6.4866, 6.5536, 6.5142, 6.5294,
]);
let low = Series::from([
6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982,
6.4560, 6.4614, 6.4798, 6.4903, 6.5066,
]);
let close = Series::from([
6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038,
6.4614, 6.4854, 6.4966, 6.5117, 6.5270,
]);
let expected = vec![
0.01819992,
0.06549978,
0.064100266,
0.07590008,
0.041800022,
0.056900024,
0.050400257,
0.017799854,
0.029099941,
0.025099754,
0.03429985,
0.050499916,
0.02519989,
0.07379961,
0.023900032,
0.022799969,
];

let result: Vec<f32> = tr(&high, &low, &close).into();

Expand Down
2 changes: 1 addition & 1 deletion ta_lib/strategies/signal/src/flip/supertrend_flip.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ impl Signal for SupertrendFlipSignal {
self.factor,
);

(direction.cross_under(&ZERO), direction.cross_over(&ZERO))
(direction.cross_over(&ZERO), direction.cross_under(&ZERO))
}
}

0 comments on commit 68663de

Please sign in to comment.