Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
m5l14i11 committed Sep 22, 2023
1 parent 5256d07 commit bfec04b
Show file tree
Hide file tree
Showing 8 changed files with 156 additions and 12 deletions.
38 changes: 29 additions & 9 deletions strategy/generator/trend_follow.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from enum import Enum, auto
import numpy as np
from random import shuffle

Expand All @@ -7,19 +8,29 @@
from core.models.strategy import Strategy
from core.models.candle import TrendCandleType

from ..signal.cross_three_ma import Cross3xMovingAverageSignal
from ..filter.dumb import DumbFilter
from ..filter.ma import MovingAverageFilter
from ..signal.candle import TrendCandleSignal
from ..signal.cross_ma import CrossMovingAverageSignal
from ..signal.cross_two_ma import Cross2xMovingAverageSignal
from ..signal.snatr import SNATRSignal
from ..signal.testing_ground import TestingGroundSignal
from ..signal.rsi_ma import RSIMovingAverageSignal
from ..signal.rsi_two_ma import RSI2xMovingAverageSignal
from ..stop_loss.atr import ATRStopLoss

class StrategyTypes(Enum):
Cross2xMa = auto()
Cross3xMa = auto()
Ground = auto()
SnAtr = auto()
Candle = auto()
RsiMa = auto()
Rsi2xMa = auto()


class TrendFollowStrategyGenerator(AbstractStrategyGenerator):
STRATEGY_TYPES = ['crossma', 'ground', 'snatr', 'candle', 'rsima', 'rsi2xma']
STRATEGY_TYPES = list(StrategyTypes)

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -59,6 +70,7 @@ def _generate_strategy(self, strategy_type):
moving_avg_type = np.random.choice(list(MovingAverageType))
trend_candle_type = np.random.choice(list(TrendCandleType))
short_period = RandomParameter(20.0, 50.0, 5.0)
medium_period = RandomParameter(50.0, 100.0, 5.0)
long_period = RandomParameter(50.0, 200.0, 10.0)
rsi_lower_barrier = RandomParameter(5.0, 15.0, 1.0)
rsi_upper_barrier = RandomParameter(75.0, 95, 1.0)
Expand All @@ -67,38 +79,46 @@ def _generate_strategy(self, strategy_type):

filter = np.random.choice([MovingAverageFilter(moving_avg_type, long_period)])

if strategy_type == 'crossma':
if strategy_type == StrategyTypes.Cross2xMa:
return Strategy(
'cross2xma',
Cross2xMovingAverageSignal(moving_avg_type, short_period, long_period),
DumbFilter(),
ATRStopLoss(multi=atr_multi)
)

if strategy_type == StrategyTypes.Cross3xMa:
return Strategy(
'crossma',
CrossMovingAverageSignal(moving_avg_type, short_period, long_period),
'cross3xma',
Cross3xMovingAverageSignal(moving_avg_type, short_period, medium_period, long_period),
DumbFilter(),
ATRStopLoss(multi=atr_multi)
)
elif strategy_type == 'candle':
elif strategy_type == StrategyTypes.Candle:
return Strategy(
'candle',
TrendCandleSignal(trend_candle_type),
filter,
ATRStopLoss(multi=atr_multi)
)

elif strategy_type == 'snatr':
elif strategy_type == StrategyTypes.SnAtr:
return Strategy(
'snatr',
SNATRSignal(),
MovingAverageFilter(moving_avg_type, long_period),
ATRStopLoss(multi=atr_multi)
)

elif strategy_type == 'rsima':
elif strategy_type == StrategyTypes.RsiMa:
return Strategy(
'rsima',
RSIMovingAverageSignal(ma=moving_avg_type, period=short_period),
DumbFilter(),
ATRStopLoss(multi=atr_multi)
)

elif strategy_type == 'rsi2xma':
elif strategy_type == StrategyTypes.Rsi2xMa:
return Strategy(
'rsi2xma',
RSI2xMovingAverageSignal(ma=moving_avg_type, lower_barrier=rsi_lower_barrier, upper_barrier=rsi_upper_barrier),
Expand Down
23 changes: 23 additions & 0 deletions strategy/signal/cross_three_ma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
from dataclasses import dataclass

from core.models.indicator import Indicator
from core.models.moving_average import MovingAverageType
from core.models.parameter import Parameter, RandomParameter


@dataclass(frozen=True)
class Cross3xMovingAverageSignal(Indicator):
ma: MovingAverageType = MovingAverageType.SMA
short_period: Parameter = RandomParameter(5.0, 50.0, 5.0)
medium_period: Parameter = RandomParameter(50.0, 100.0, 5.0)
long_period: Parameter = RandomParameter(100.0, 200.0, 10.0)

@property
def parameters(self):
return [
self.ma,
self.short_period,
self.medium_period,
self.long_period,
]

Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@


@dataclass(frozen=True)
class CrossMovingAverageSignal(Indicator):
class Cross2xMovingAverageSignal(Indicator):
ma: MovingAverageType = MovingAverageType.SMA
short_period: Parameter = RandomParameter(5.0, 50.0, 5.0)
long_period: Parameter = RandomParameter(50.0, 200.0, 10.0)
Expand Down
55 changes: 55 additions & 0 deletions ta_lib/strategies/signal/src/cross_three_ma.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use base::{OHLCVSeries, Signal};
use core::Series;
use shared::{ma_indicator, MovingAverageType};

pub struct Cross3xMASignal {
smoothing: MovingAverageType,
short_period: usize,
medium_period: usize,
long_period: usize,
}

impl Cross3xMASignal {
pub fn new(
smoothing: MovingAverageType,
short_period: f32,
medium_period: f32,
long_period: f32,
) -> Self {
Self {
smoothing,
short_period: short_period as usize,
medium_period: medium_period as usize,
long_period: long_period as usize,
}
}
}

impl Signal for Cross3xMASignal {
fn id(&self) -> String {
format!(
"CROSS3xMA_{}:{}:{}:{}",
self.smoothing, self.short_period, self.medium_period, self.long_period
)
}

fn lookback(&self) -> usize {
let adjusted_lookback = std::cmp::max(self.short_period, self.long_period);
std::cmp::max(adjusted_lookback, self.medium_period)
}

fn entry(&self, data: &OHLCVSeries) -> (Series<bool>, Series<bool>) {
let short_ma = ma_indicator(&self.smoothing, data, self.short_period);
let medium_ma = ma_indicator(&self.smoothing, data, self.medium_period);
let long_ma = ma_indicator(&self.smoothing, data, self.long_period);

let long_signal = short_ma.cross_over(&medium_ma) & medium_ma.cross_over(&long_ma);
let short_signal = short_ma.cross_under(&medium_ma) & medium_ma.cross_under(&long_ma);

(long_signal, short_signal)
}

fn exit(&self, _data: &OHLCVSeries) -> (Series<bool>, Series<bool>) {
(Series::empty(1), Series::empty(1))
}
}
2 changes: 2 additions & 0 deletions ta_lib/strategies/signal/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
mod cross_three_ma;
mod cross_two_ma;
mod rsi_ma;
mod rsi_two_ma;
mod snatr;
mod testing_ground;
mod trend_candle;

pub use cross_three_ma::Cross3xMASignal;
pub use cross_two_ma::Cross2xMASignal;
pub use rsi_ma::RSIMASignal;
pub use rsi_two_ma::RSI2xMASignal;
Expand Down
29 changes: 28 additions & 1 deletion ta_lib/strategies/trend_follow/src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ fn map_to_candle(candle: usize) -> TrendCandleType {
}

#[no_mangle]
pub fn register_crossma(
pub fn register_cross2xma(
smoothing: f32,
short_period: f32,
long_period: f32,
Expand All @@ -71,6 +71,33 @@ pub fn register_crossma(
register_strategy(signal, filter, stoploss)
}

#[no_mangle]
pub fn register_cross3xma(
smoothing: f32,
short_period: f32,
medium_period: f32,
long_period: f32,
atr_period: f32,
atr_factor: f32,
) -> i32 {
let smoothing = map_to_ma(smoothing as usize);
let signal = map_to_signal(SignalConfig::Cross3xMa {
smoothing,
short_period,
medium_period,
long_period,
});
let filter = map_to_filter(FilterConfig::Dumb {
period: long_period,
});
let stoploss = map_to_stoploss(StopLossConfig::Atr {
period: atr_period,
multi: atr_factor,
});

register_strategy(signal, filter, stoploss)
}

#[no_mangle]
pub fn register_rsima(
rsi_period: f32,
Expand Down
19 changes: 18 additions & 1 deletion ta_lib/strategies/trend_follow/src/signal_mapper.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use base::Signal;
use shared::{MovingAverageType, RSIType, TrendCandleType};
use signal::{
Cross2xMASignal, RSI2xMASignal, RSIMASignal, SNATRSignal, TestingGroundSignal,
Cross2xMASignal, Cross3xMASignal, RSI2xMASignal, RSIMASignal, SNATRSignal, TestingGroundSignal,
TrendCandleSignal,
};

Expand All @@ -11,6 +11,12 @@ pub enum SignalConfig {
short_period: f32,
long_period: f32,
},
Cross3xMa {
smoothing: MovingAverageType,
short_period: f32,
medium_period: f32,
long_period: f32,
},
RsiMa {
rsi_type: RSIType,
rsi_period: f32,
Expand Down Expand Up @@ -50,6 +56,17 @@ pub fn map_to_signal(config: SignalConfig) -> Box<dyn Signal> {
short_period,
long_period,
} => Box::new(Cross2xMASignal::new(smoothing, short_period, long_period)),
SignalConfig::Cross3xMa {
smoothing,
short_period,
medium_period,
long_period,
} => Box::new(Cross3xMASignal::new(
smoothing,
short_period,
medium_period,
long_period,
)),
SignalConfig::RsiMa {
rsi_type,
rsi_period,
Expand Down
Binary file modified wasm/trend_follow.wasm
Binary file not shown.

0 comments on commit bfec04b

Please sign in to comment.