Skip to content

Commit

Permalink
cross_ma
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed Aug 8, 2023
1 parent 4ee320d commit 243a2d1
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 20 deletions.
8 changes: 8 additions & 0 deletions ta_lib/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions ta_lib/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ members = [
"patterns/candlestick",
"price",
"strategies/base",
"strategies/trend_follow",
"utils"
]

Expand Down
48 changes: 28 additions & 20 deletions ta_lib/strategies/base/src/base.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use price::{average::average_price, median::median_price, typical::typical_price, wcl::wcl};
use std::{
cmp::min,
cmp::max,
collections::{HashMap, VecDeque},
};

Expand Down Expand Up @@ -90,7 +90,7 @@ pub struct BaseStrategy {

impl BaseStrategy {
pub fn new(lookback_period: usize) -> BaseStrategy {
let lookback_period = min(lookback_period, Self::DEFAULT_LOOKBACK);
let lookback_period = max(lookback_period, Self::DEFAULT_LOOKBACK);

BaseStrategy {
data: VecDeque::with_capacity(lookback_period),
Expand Down Expand Up @@ -158,7 +158,7 @@ mod tests {
#[test]
fn test_base_strategy_creation() {
let strategy = BaseStrategy::new(20);
assert_eq!(strategy.lookback_period, 20);
assert_eq!(strategy.lookback_period, 55);
}

#[test]
Expand All @@ -175,21 +175,24 @@ mod tests {
});
assert_eq!(strategy.can_process(), false);

strategy.next(OHLCV {
open: 2.0,
high: 3.0,
low: 2.0,
close: 3.0,
volume: 2000.0,
});
for _i in 0..54 {
strategy.next(OHLCV {
open: 2.0,
high: 3.0,
low: 2.0,
close: 3.0,
volume: 2000.0,
});
}

assert_eq!(strategy.can_process(), true);
}

#[test]
fn test_base_strategy_params() {
let strategy = BaseStrategy::new(20);
let params = strategy.params();
assert_eq!(params.get("lookback_period"), Some(&20));
assert_eq!(params.get("lookback_period"), Some(&55));
}

#[test]
Expand Down Expand Up @@ -232,18 +235,23 @@ mod tests {

let series = OHLCVSeries::new(&strategy.data);

assert_eq!(series.open, vec![2.0, 3.0, 4.0]);
assert_eq!(series.high, vec![3.0, 4.0, 5.0]);
assert_eq!(series.low, vec![1.5, 2.5, 3.5]);
assert_eq!(series.close, vec![2.5, 3.5, 4.5]);
assert_eq!(series.volume, vec![200.0, 300.0, 400.0]);
assert_eq!(series.open, vec![1.0, 2.0, 3.0, 4.0]);
assert_eq!(series.high, vec![2.0, 3.0, 4.0, 5.0]);
assert_eq!(series.low, vec![0.5, 1.5, 2.5, 3.5]);
assert_eq!(series.close, vec![1.5, 2.5, 3.5, 4.5]);
assert_eq!(series.volume, vec![100.0, 200.0, 300.0, 400.0]);

assert_eq!(series.hl2(), vec![2.25, 3.25, 4.25]);
assert_eq!(series.hl2(), vec![1.25, 2.25, 3.25, 4.25]);
assert_eq!(
series.hlc3(),
vec![2.3333333333333335, 3.3333333333333335, 4.333333333333333]
vec![
1.3333333333333333,
2.3333333333333335,
3.3333333333333335,
4.333333333333333
]
);
assert_eq!(series.hlcc4(), vec![2.375, 3.375, 4.375]);
assert_eq!(series.ohlc4(), vec![2.25, 3.25, 4.25]);
assert_eq!(series.hlcc4(), vec![1.375, 2.375, 3.375, 4.375]);
assert_eq!(series.ohlc4(), vec![1.25, 2.25, 3.25, 4.25]);
}
}
8 changes: 8 additions & 0 deletions ta_lib/strategies/trend_follow/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
[package]
name = "trend_follow"
version = "0.1.0"
edition = "2021"

[dependencies]
base = { path = "../base" }
trend = { path = "../../indicators/trend" }
77 changes: 77 additions & 0 deletions ta_lib/strategies/trend_follow/src/cross_ma.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
use base::base::{BaseStrategy, OHLCVSeries, Strategy, TradeAction, OHLCV};
use std::cmp::max;
use std::collections::HashMap;
use trend::sma::sma;

pub struct MACrossStrategy {
base: BaseStrategy,
short_period: usize,
long_period: usize,
}

impl MACrossStrategy {
pub fn new(short_period: usize, long_period: usize) -> Self {
let lookback_period = max(short_period, long_period);

MACrossStrategy {
base: BaseStrategy::new(lookback_period),
short_period,
long_period,
}
}
}

impl Strategy for MACrossStrategy {
fn next(&mut self, data: OHLCV) -> TradeAction {
self.base.next(data)
}

fn can_process(&self) -> bool {
self.base.can_process()
}

fn params(&self) -> HashMap<String, usize> {
let mut map = self.base.params();
map.insert(String::from("short_period"), self.short_period);
map.insert(String::from("long_period"), self.long_period);
map
}

fn entry(&self, data: &OHLCVSeries) -> (bool, bool) {
let short_ma = sma(&data.close, self.short_period);
let long_ma = sma(&data.close, self.long_period);

let long_signal: Vec<bool> = short_ma.cross_over(&long_ma).into();
let short_signal: Vec<bool> = short_ma.cross_under(&long_ma).into();

(
long_signal.last().cloned().unwrap_or_default(),
short_signal.last().cloned().unwrap_or_default(),
)
}

fn exit(&self, data: &OHLCVSeries) -> (bool, bool) {
self.base.exit(data)
}
}

#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_macrossstrategy_new() {
let strat = MACrossStrategy::new(50, 100);
assert_eq!(strat.short_period, 50);
assert_eq!(strat.long_period, 100);
}

#[test]
fn test_macrossstrategy_params() {
let strat = MACrossStrategy::new(50, 100);
let params = strat.params();
assert_eq!(params.get("lookback_period"), Some(&100));
assert_eq!(params.get("short_period"), Some(&50));
assert_eq!(params.get("long_period"), Some(&100));
}
}
1 change: 1 addition & 0 deletions ta_lib/strategies/trend_follow/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
mod cross_ma;

0 comments on commit 243a2d1

Please sign in to comment.