diff --git a/ta_lib/strategy/src/base.rs b/ta_lib/strategy/src/base.rs index 371b6da5..1bb5f7b6 100644 --- a/ta_lib/strategy/src/base.rs +++ b/ta_lib/strategy/src/base.rs @@ -57,19 +57,25 @@ impl Price for OHLCVSeries { } } -#[repr(u32)] -pub enum Action { - GoLong = 1, - GoShort = 2, - ExitLong = 3, - ExitShort = 4, - DoNothing = 0, +#[derive(Debug, Clone, Copy)] +pub enum TradeAction { + GoLong(f64), + GoShort(f64), + ExitLong, + ExitShort, + DoNothing, +} + +impl Default for TradeAction { + fn default() -> Self { + TradeAction::DoNothing + } } pub trait Strategy { const DEFAULT_LOOKBACK: usize = 55; - fn next(&mut self, data: OHLCV) -> Action; + fn next(&mut self, data: OHLCV) -> TradeAction; fn can_process(&self) -> bool; fn params(&self) -> HashMap; fn entry(&self, data: &OHLCVSeries) -> (bool, bool); @@ -90,44 +96,42 @@ impl BaseStrategy { lookback_period, } } -} -impl Strategy for BaseStrategy { - fn next(&mut self, data: OHLCV) -> Action { + fn add_data(&mut self, data: OHLCV) { self.data.push_back(data); if self.data.len() > self.lookback_period { self.data.pop_front(); } + } +} - if self.can_process() { - let series = OHLCVSeries::new(&self.data); +impl Strategy for BaseStrategy { + fn next(&mut self, data: OHLCV) -> TradeAction { + self.add_data(data); - let (go_long, go_short) = self.entry(&series); - let (exit_long, exit_short) = self.exit(&series); + if !self.can_process() { + return TradeAction::default(); + } - if go_long { - return Action::GoLong; - } + let series = OHLCVSeries::new(&self.data); - if go_short { - return Action::GoShort; - } + let (go_long, go_short) = self.entry(&series); + let (exit_long, exit_short) = self.exit(&series); - if exit_long { - return Action::ExitLong; - } + let suggested_entry = series.hlc3().last().unwrap_or(&std::f64::NAN).clone(); - if exit_short { - return Action::ExitShort; - } + match (go_long, go_short, exit_long, exit_short) { + (true, _, _, _) => TradeAction::GoLong(suggested_entry), + (_, true, _, _) => TradeAction::GoShort(suggested_entry), + (_, _, true, _) => TradeAction::ExitLong, + (_, _, _, true) => TradeAction::ExitShort, + _ => TradeAction::default(), } - - Action::DoNothing } fn can_process(&self) -> bool { - self.data.len() == self.lookback_period + self.data.len() >= self.lookback_period } fn params(&self) -> HashMap {