diff --git a/ta_lib/core/src/math.rs b/ta_lib/core/src/math.rs index 8d568bde..091f30e1 100644 --- a/ta_lib/core/src/math.rs +++ b/ta_lib/core/src/math.rs @@ -1,113 +1,68 @@ use crate::series::Series; impl Series { - pub fn max(&self, scalar: f64) -> Self { - self.fmap(|val| match val { - Some(v) => Some(v.max(scalar)), - None => Some(scalar), - }) - } + pub fn window(&self, period: usize, f: F) -> Self + where + F: Fn(&[f64], usize, usize) -> f64, + { + let len = self.len(); + let mut result = Self::empty(len); + let mut window = vec![0.0; period]; + let mut pos = 0; - pub fn min(&self, scalar: f64) -> Self { - self.fmap(|val| match val { - Some(v) => Some(v.min(scalar)), - None => Some(scalar), - }) - } + for i in 0..len { + if let Some(value) = self[i] { + window[pos] = value; - pub fn cumsum(&self) -> Self { - let len = self.len(); - let mut cumsum = Self::empty(len); + let size = (i + 1).min(period); - let mut sum = 0.0; + result[i] = Some(f(&window[0..size], size, i)); - for i in 0..len { - if let Some(val) = self[i] { - sum += val; - cumsum[i] = Some(sum); + pos = (pos + 1) % period; } } - cumsum + result } - pub fn sum(&self, period: usize) -> Self { - let len = self.len(); - let mut sum = Self::empty(len); - let mut window_sum = 0.0; + pub fn max(&self, scalar: f64) -> Self { + self.fmap(|val| val.map(|v| v.max(scalar)).or(Some(scalar))) + } - for i in 0..len { - if let Some(value) = self[i] { - window_sum += value; + pub fn min(&self, scalar: f64) -> Self { + self.fmap(|val| val.map(|v| v.min(scalar)).or(Some(scalar))) + } - if i >= period { - if let Some(old_value) = self[i - period] { - window_sum -= old_value; - } - } + pub fn cumsum(&self) -> Self { + let mut sum = 0.0; - sum[i] = Some(window_sum); - } - } + self.fmap(|val| { + val.map(|v| { + sum += v; + sum + }) + }) + } - sum + pub fn sum(&self, period: usize) -> Self { + self.window(period, |window, _, _| window.iter().sum()) } pub fn mean(&self, period: usize) -> Self { - let len = self.len(); - let mut mean = Self::empty(len); - let mut window_sum = 0.0; - let mut count = 0; - - for i in 0..len { - if let Some(value) = self[i] { - window_sum += value; - count += 1; - - if i >= period { - if let Some(old_value) = self[i - period] { - window_sum -= old_value; - count -= 1; - } - } - - if count > 0 { - mean[i] = Some(window_sum / count as f64); - } - } - } - - mean + self.window(period, |window, size, _| { + window.iter().sum::() / size as f64 + }) } pub fn std(&self, period: usize) -> Self { - let len = self.len(); - let mut std = Self::empty(len); let mean = self.mean(period); - let mut window = Vec::with_capacity(period); - - for i in 0..len { - let value = self[i]; - - if let Some(v) = value { - window.push(v); - - if window.len() > period { - window.remove(0); - } - } - - let count = window.len(); - - if count > 0 { - let mean_val = mean[i].unwrap_or(0.0); - let variance = - window.iter().map(|&v| (v - mean_val).powi(2)).sum::() / count as f64; - std[i] = Some(variance.sqrt()); - } - } - std + self.window(period, |window, size, i| { + let mean_val = mean[i].unwrap_or(0.0); + let variance = + window.iter().map(|&v| (v - mean_val).powi(2)).sum::() / size as f64; + variance.sqrt() + }) } } @@ -181,6 +136,17 @@ mod tests { } } + #[test] + fn test_cumsum() { + let source = vec![1.0, 2.0, 3.0, 4.0, 5.0]; + let expected = vec![Some(1.0), Some(3.0), Some(6.0), Some(10.0), Some(15.0)]; + let series = Series::from(&source); + + let result = series.cumsum(); + + assert_eq!(result, expected); + } + #[test] fn test_sum() { let source = vec![1.0, 2.0, 3.0, 4.0, 5.0]; diff --git a/ta_lib/core/src/series.rs b/ta_lib/core/src/series.rs index 8c81dcaf..42a1c450 100644 --- a/ta_lib/core/src/series.rs +++ b/ta_lib/core/src/series.rs @@ -10,7 +10,6 @@ impl Series { pub fn fmap(&self, mut f: F) -> Series where F: FnMut(Option<&T>) -> Option, - T: Clone, { Series { data: self.data.iter().map(|x| f(x.as_ref())).collect(), @@ -20,7 +19,6 @@ impl Series { pub fn zip_with(self, other: &Series, mut f: F) -> Series where F: FnMut(Option, Option) -> Option, - T: Clone, U: Clone, { let data = self @@ -68,6 +66,30 @@ impl IndexMut for Series { } impl Series { + fn extreme_value(&self, period: usize, comparison: F) -> Self + where + F: Fn(&f64, &f64) -> bool, + { + let len = self.len(); + let mut extreme_values = Self::empty(len); + let mut indices: Vec = Vec::new(); + + for i in 0..len { + let start = if i >= period { i - period + 1 } else { 0 }; + + indices.retain(|&j| j >= start); + indices.retain(|&j| { + comparison(&self[j].unwrap_or(f64::NAN), &self[i].unwrap_or(f64::NAN)) + }); + + indices.push(i); + + extreme_values[i] = self[indices[0]].clone(); + } + + extreme_values + } + pub fn nz(&self, replacement: Option) -> Self { let replacement = replacement.unwrap_or(0.0); @@ -91,43 +113,11 @@ impl Series { } pub fn highest(&self, period: usize) -> Self { - let len = self.len(); - let mut highest_values = Self::empty(len); - let mut indices: Vec = Vec::new(); - - for i in 0..len { - let start = if i >= period { i - period + 1 } else { 0 }; - - indices.retain(|&j| j >= start); - - indices.retain(|&j| self[j] > self[i]); - - indices.push(i); - - highest_values[i] = self[indices[0]]; - } - - highest_values + self.extreme_value(period, |a, b| a >= b) } pub fn lowest(&self, period: usize) -> Self { - let len = self.len(); - let mut lowest_values = Self::empty(len); - let mut indices: Vec = Vec::new(); - - for i in 0..len { - let start = if i >= period { i - period + 1 } else { 0 }; - - indices.retain(|&j| j >= start); - - indices.retain(|&j| self[j] < self[i]); - - indices.push(i); - - lowest_values[i] = self[indices[0]]; - } - - lowest_values + self.extreme_value(period, |a, b| a <= b) } }