From ef9a9d36b8d2c04642a1ae82844af049cd43f6b3 Mon Sep 17 00:00:00 2001 From: Siarhei Melnik <57074509+m5l14i11@users.noreply.github.com> Date: Fri, 6 Sep 2024 19:29:27 +0300 Subject: [PATCH] Tp (#47) * upd --- .env.example | 4 +- Makefile | 14 +- Pipfile | 2 + config.default.ini | 48 +- copilot/__init__.py | 3 + copilot/_actor.py | 358 +++++++++++ copilot/_prompt.py | 166 +++++ core/actors/__init__.py | 5 +- core/actors/{_actor.py => _base_actor.py} | 40 +- core/actors/_strategy_actor.py | 30 + .../risk => core/actors/policy}/__init__.py | 0 core/actors/policy/event.py | 10 + core/actors/policy/signal.py | 13 + core/actors/policy/strategy.py | 28 + core/commands/base.py | 12 +- core/commands/broker.py | 7 - core/events/base.py | 22 +- core/events/ohlcv.py | 12 + core/events/position.py | 2 +- core/events/risk.py | 4 +- core/events/signal.py | 34 +- core/interfaces/abstract_actor.py | 12 - .../abstract_executor_actor_factory.py | 7 +- core/interfaces/abstract_llm_service.py | 7 + core/interfaces/abstract_market_repository.py | 20 + ...egy.py => abstract_order_size_strategy.py} | 5 +- core/interfaces/abstract_position_factory.py | 11 +- .../abstract_strategy_generator_factory.py | 5 +- core/interfaces/abstract_timeseries.py | 34 ++ core/interfaces/abstract_wasm_manager.py | 12 + core/interfaces/abstract_wasm_service.py | 9 - core/mixins/__init__.py | 3 + core/mixins/_event_handler.py | 15 + core/models/candle.py | 16 + core/models/moving_average.py | 29 +- core/models/ohlcv.py | 155 ++++- core/models/order.py | 2 +- core/models/portfolio.py | 567 ++++++++---------- core/models/position.py | 471 ++++++++++++--- core/models/position_risk.py | 445 ++++++++++++++ core/models/profit_target.py | 136 +++++ core/models/risk_type.py | 37 +- core/models/side.py | 14 +- core/models/signal.py | 22 +- core/models/signal_risk.py | 24 + core/models/smooth.py | 11 + core/models/strategy.py | 36 +- core/models/strategy_ref.py | 86 ++- core/models/strategy_type.py | 11 + core/models/symbol.py | 9 +- core/models/ta.py | 147 +++++ core/models/timeframe.py | 22 +- core/models/timeseries_ref.py | 117 ++++ core/models/wasm_type.py | 6 + core/queries/base.py | 15 +- core/queries/copilot.py | 34 ++ core/queries/ohlcv.py | 55 ++ exchange/_bybit.py | 7 +- exchange/_bybit_ws.py | 27 +- executor/_factory.py | 4 +- executor/_market_actor.py | 69 +++ executor/_market_order_actor.py | 95 --- executor/_paper_actor.py | 162 +++++ executor/_paper_order_actor.py | 213 ------- feed/_factory.py | 5 + feed/_historical.py | 75 ++- feed/_realtime.py | 22 +- .../event_dispatcher/event_dedup.py | 25 + .../event_dispatcher/event_dispatcher.py | 21 +- .../event_dispatcher/event_handler.py | 22 +- .../event_dispatcher/event_worker.py | 17 +- .../event_dispatcher/load_balancer.py | 86 ++- .../event_dispatcher/worker_pool.py | 17 +- infrastructure/event_store/event_encoder.py | 12 +- market/__init__.py | 3 + market/_actor.py | 43 ++ optimization/_genetic.py | 10 +- portfolio/_portfolio.py | 85 +-- portfolio/_service.py | 30 +- portfolio/_strategy.py | 90 ++- position/_actor.py | 79 ++- position/_position_factory.py | 121 ++-- position/_sm.py | 15 +- position/risk/break_even.py | 103 ---- position/risk/simple.py | 20 - .../{take_profit/__init__.py => size/base.py} | 0 position/size/fixed.py | 14 +- position/size/kelly.py | 14 +- position/size/optimal_f.py | 16 +- position/take_profit/risk_reward.py | 23 - quant.py | 56 +- risk/_actor.py | 505 ++++++---------- service/__init__.py | 12 +- service/_env_secret.py | 29 +- service/_llm.py | 47 ++ service/_signal.py | 56 +- service/_timeseries.py | 58 ++ service/_wasm.py | 58 ++ service/_wasm_file.py | 20 - sor/_router.py | 146 +---- sor/_twap.py | 41 +- strategy/_actor.py | 26 +- strategy/generator/_factory.py | 16 +- strategy/generator/baseline/ma.py | 5 +- strategy/generator/bootstrap/_trend_follow.py | 401 ++++++++----- strategy/generator/confirm/base.py | 8 +- strategy/generator/confirm/bb.py | 16 + strategy/generator/confirm/braid.py | 22 + strategy/generator/confirm/cc.py | 19 + strategy/generator/confirm/cci.py | 8 +- strategy/generator/confirm/didi.py | 18 + strategy/generator/confirm/dpo.py | 2 +- strategy/generator/confirm/dso.py | 17 - strategy/generator/confirm/eom.py | 3 +- strategy/generator/confirm/roc.py | 13 - strategy/generator/confirm/rsi_neutrality.py | 2 +- strategy/generator/confirm/rsi_signalline.py | 6 +- strategy/generator/confirm/vi.py | 11 - strategy/generator/confirm/wpr.py | 16 + strategy/generator/exit/ast.py | 5 +- strategy/generator/exit/base.py | 3 +- strategy/generator/exit/cci.py | 17 - strategy/generator/exit/mad.py | 17 + strategy/generator/exit/rex.py | 20 + strategy/generator/pulse/adx.py | 6 +- strategy/generator/pulse/base.py | 3 +- strategy/generator/pulse/braid.py | 20 - strategy/generator/pulse/chop.py | 13 +- strategy/generator/pulse/nvol.py | 4 +- strategy/generator/pulse/sqz.py | 19 + strategy/generator/pulse/tdfi.py | 8 +- strategy/generator/pulse/vo.py | 8 +- strategy/generator/pulse/wae.py | 20 +- strategy/generator/pulse/yz.py | 14 + strategy/generator/signal/base.py | 23 +- .../signal/bb/{macd_bb.py => macd.py} | 2 +- .../signal/bb/{vwap_bb.py => vwap.py} | 0 .../{reversal => colorswitch}/__init__.py | 0 .../macd.py} | 5 +- .../generator/signal/contrarian/__init__.py | 0 strategy/generator/signal/contrarian/kch_a.py | 17 + strategy/generator/signal/contrarian/kch_c.py | 17 + strategy/generator/signal/contrarian/rsi_c.py | 19 + strategy/generator/signal/contrarian/rsi_d.py | 19 + .../generator/signal/contrarian/rsi_nt.py | 19 + strategy/generator/signal/contrarian/rsi_u.py | 19 + .../signal/{pattern => contrarian}/rsi_v.py | 6 +- .../snatr_reversal.py => contrarian/snatr.py} | 4 +- .../generator/signal/contrarian/stoch_e.py | 21 + .../signal/{pattern => contrarian}/tii_v.py | 4 +- .../signal/flip/{ce_flip.py => ce.py} | 6 +- .../{supertrend_flip.py => supertrend.py} | 6 +- strategy/generator/signal/ma/ma_cross.py | 2 +- strategy/generator/signal/ma/ma_surpass.py | 2 +- .../generator/signal/ma/ma_testing_ground.py | 9 +- .../{dso_neutrality_cross.py => dso_cross.py} | 0 .../{rsi_neutrality_cross.py => rsi_cross.py} | 0 ...neutrality_pullback.py => rsi_pullback.py} | 0 ...utrality_rejection.py => rsi_rejection.py} | 0 .../{tii_neutrality_cross.py => tii_cross.py} | 0 .../signal/pattern/candle_reversal.py | 11 + strategy/generator/signal/pattern/spread.py | 18 + .../generator/signal/pullback/__init__.py | 0 .../generator/signal/pullback/supertrend.py | 15 + .../generator/signal/reversal/vi_reversal.py | 11 - .../signalline/{di_signalline.py => di.py} | 0 .../signalline/{dso_signalline.py => dso.py} | 0 .../signalline/{kst_signalline.py => kst.py} | 0 .../{macd_signalline.py => macd.py} | 0 .../{qstick_signalline.py => qstick.py} | 0 .../signalline/{rsi_signalline.py => rsi.py} | 0 .../{stoch_signalline.py => stoch.py} | 0 .../{trix_signalline.py => trix.py} | 0 .../signalline/{tsi_signalline.py => tsi.py} | 0 .../signal/twolinescross/__init__.py | 0 .../dmi_reversal.py => twolinescross/dmi.py} | 6 +- strategy/generator/signal/twolinescross/vi.py | 13 + .../zerocross/{ao_zerocross.py => ao.py} | 0 .../zerocross/{bop_zerocross.py => bop.py} | 0 .../zerocross/{cc_zerocross.py => cc.py} | 0 .../zerocross/{cfo_zerocross.py => cfo.py} | 0 .../zerocross/{di_zerocross.py => di.py} | 0 .../zerocross/{macd_zerocross.py => macd.py} | 0 strategy/generator/signal/zerocross/mad.py | 18 + .../{qstick_zerocross.py => qstick.py} | 0 .../zerocross/{roc_zerocross.py => roc.py} | 0 .../zerocross/{trix_zerocross.py => trix.py} | 0 .../zerocross/{tsi_zerocross.py => tsi.py} | 0 strategy/generator/stop_loss/atr.py | 6 +- strategy/generator/stop_loss/dch.py | 2 +- system/backtest.py | 22 +- system/context.py | 2 - system/trading.py | 4 +- ta_lib/Cargo.lock | 265 ++++---- ta_lib/Cargo.toml | 10 +- ta_lib/benches/Cargo.toml | 5 - ta_lib/benches/indicators.rs | 58 +- ta_lib/benches/strategy.rs | 96 --- ta_lib/core/src/bitwise.rs | 25 +- ta_lib/core/src/cmp.rs | 102 ++-- ta_lib/core/src/constants.rs | 14 +- ta_lib/core/src/cross.rs | 29 +- ta_lib/core/src/extremum.rs | 121 ++-- ta_lib/core/src/fmt.rs | 18 + ta_lib/core/src/from.rs | 29 +- ta_lib/core/src/lib.rs | 3 + ta_lib/core/src/math.rs | 274 +++++++-- ta_lib/core/src/ops.rs | 179 +++--- ta_lib/core/src/series.rs | 156 ++--- ta_lib/core/src/smoothing.rs | 224 +++++-- ta_lib/core/src/traits.rs | 11 +- ta_lib/core/src/types.rs | 6 + ta_lib/ffi/Cargo.toml | 17 + ta_lib/ffi/src/lib.rs | 3 + ta_lib/ffi/src/timeseries.rs | 203 +++++++ ta_lib/indicators/momentum/src/ao.rs | 30 - ta_lib/indicators/momentum/src/bop.rs | 18 +- ta_lib/indicators/momentum/src/cc.rs | 18 +- ta_lib/indicators/momentum/src/cci.rs | 6 +- ta_lib/indicators/momentum/src/cfo.rs | 4 +- ta_lib/indicators/momentum/src/cmo.rs | 4 +- ta_lib/indicators/momentum/src/di.rs | 10 +- ta_lib/indicators/momentum/src/dmi.rs | 30 +- ta_lib/indicators/momentum/src/dso.rs | 46 -- ta_lib/indicators/momentum/src/kst.rs | 32 +- ta_lib/indicators/momentum/src/lib.rs | 20 +- ta_lib/indicators/momentum/src/macd.rs | 23 +- .../{trend => momentum}/src/qstick.rs | 11 +- ta_lib/indicators/momentum/src/rex.rs | 39 ++ ta_lib/indicators/momentum/src/roc.rs | 4 +- ta_lib/indicators/momentum/src/rsi.rs | 22 +- ta_lib/indicators/momentum/src/sso.rs | 45 -- ta_lib/indicators/momentum/src/stc.rs | 35 +- ta_lib/indicators/momentum/src/stoch.rs | 32 - ta_lib/indicators/momentum/src/stochosc.rs | 121 +++- ta_lib/indicators/momentum/src/tdfi.rs | 10 +- ta_lib/indicators/momentum/src/tii.rs | 18 +- ta_lib/indicators/momentum/src/trix.rs | 10 +- ta_lib/indicators/momentum/src/tsi.rs | 17 +- ta_lib/indicators/momentum/src/uo.rs | 69 +++ .../indicators/momentum/src/{pr.rs => wpr.rs} | 13 +- ta_lib/indicators/trend/src/alma.rs | 38 +- ta_lib/indicators/trend/src/ast.rs | 6 +- ta_lib/indicators/trend/src/cama.rs | 14 +- ta_lib/indicators/trend/src/ce.rs | 61 +- ta_lib/indicators/trend/src/chop.rs | 11 +- ta_lib/indicators/trend/src/dema.rs | 6 +- ta_lib/indicators/trend/src/dpo.rs | 6 +- ta_lib/indicators/trend/src/ema.rs | 4 +- ta_lib/indicators/trend/src/frama.rs | 21 +- ta_lib/indicators/trend/src/gma.rs | 2 +- ta_lib/indicators/trend/src/hema.rs | 8 +- ta_lib/indicators/trend/src/hma.rs | 6 +- ta_lib/indicators/trend/src/kama.rs | 8 +- ta_lib/indicators/trend/src/lib.rs | 16 +- ta_lib/indicators/trend/src/lsma.rs | 13 +- ta_lib/indicators/trend/src/md.rs | 6 +- .../trend/src/{kjs.rs => midpoint.rs} | 8 +- ta_lib/indicators/trend/src/pp.rs | 270 +++++++++ ta_lib/indicators/trend/src/rmsma.rs | 4 +- ta_lib/indicators/trend/src/sinwma.rs | 24 +- ta_lib/indicators/trend/src/slsma.rs | 27 + ta_lib/indicators/trend/src/sma.rs | 4 +- ta_lib/indicators/trend/src/smma.rs | 4 +- ta_lib/indicators/trend/src/supertrend.rs | 59 +- ta_lib/indicators/trend/src/t3.rs | 4 +- ta_lib/indicators/trend/src/tema.rs | 4 +- ta_lib/indicators/trend/src/tma.rs | 23 - ta_lib/indicators/trend/src/trima.rs | 25 + ta_lib/indicators/trend/src/ults.rs | 21 + ta_lib/indicators/trend/src/vi.rs | 15 +- ta_lib/indicators/trend/src/vidya.rs | 8 +- ta_lib/indicators/trend/src/vwema.rs | 4 +- ta_lib/indicators/trend/src/vwma.rs | 2 +- ta_lib/indicators/trend/src/wma.rs | 10 +- ta_lib/indicators/trend/src/zlema.rs | 6 +- ta_lib/indicators/trend/src/zlhma.rs | 20 +- ta_lib/indicators/trend/src/zlsma.rs | 21 +- ta_lib/indicators/trend/src/zltema.rs | 4 +- ta_lib/indicators/volatility/src/atr.rs | 56 -- ta_lib/indicators/volatility/src/bb.rs | 67 ++- ta_lib/indicators/volatility/src/bbw.rs | 30 - ta_lib/indicators/volatility/src/dch.rs | 39 +- ta_lib/indicators/volatility/src/gkyz.rs | 34 ++ ta_lib/indicators/volatility/src/kb.rs | 68 ++- ta_lib/indicators/volatility/src/kch.rs | 98 ++- ta_lib/indicators/volatility/src/lib.rs | 26 +- ta_lib/indicators/volatility/src/pk.rs | 29 + ta_lib/indicators/volatility/src/ppb.rs | 204 +++++-- ta_lib/indicators/volatility/src/rs.rs | 32 + ta_lib/indicators/volatility/src/snatr.rs | 45 -- ta_lib/indicators/volatility/src/tr.rs | 128 +++- ta_lib/indicators/volatility/src/yz.rs | 40 ++ ta_lib/indicators/volume/src/cmf.rs | 10 +- ta_lib/indicators/volume/src/eom.rs | 22 +- ta_lib/indicators/volume/src/lib.rs | 2 - ta_lib/indicators/volume/src/mfi.rs | 6 +- ta_lib/indicators/volume/src/nvol.rs | 6 +- ta_lib/indicators/volume/src/obv.rs | 4 +- ta_lib/indicators/volume/src/vo.rs | 37 -- ta_lib/indicators/volume/src/vwap.rs | 4 +- ta_lib/patterns/bands/Cargo.toml | 12 + ta_lib/patterns/bands/src/lib.rs | 1 + ta_lib/patterns/bands/src/macros.rs | 53 ++ ta_lib/patterns/candlestick/src/barrier.rs | 4 +- ta_lib/patterns/candlestick/src/blockade.rs | 14 +- ta_lib/patterns/candlestick/src/bottle.rs | 4 +- ta_lib/patterns/candlestick/src/breakaway.rs | 4 +- .../patterns/candlestick/src/counterattack.rs | 4 +- ta_lib/patterns/candlestick/src/doji.rs | 4 +- .../patterns/candlestick/src/doji_double.rs | 4 +- .../patterns/candlestick/src/doppelganger.rs | 14 +- .../candlestick/src/double_trouble.rs | 14 +- ta_lib/patterns/candlestick/src/engulfing.rs | 14 +- ta_lib/patterns/candlestick/src/euphoria.rs | 4 +- .../candlestick/src/euphoria_extreme.rs | 4 +- ta_lib/patterns/candlestick/src/golden.rs | 18 +- ta_lib/patterns/candlestick/src/h.rs | 14 +- ta_lib/patterns/candlestick/src/hammer.rs | 4 +- .../candlestick/src/harami_flexible.rs | 14 +- .../patterns/candlestick/src/harami_strict.rs | 14 +- ta_lib/patterns/candlestick/src/hexad.rs | 4 +- ta_lib/patterns/candlestick/src/hikkake.rs | 14 +- .../patterns/candlestick/src/kangaroo_tail.rs | 16 +- ta_lib/patterns/candlestick/src/lib.rs | 2 + ta_lib/patterns/candlestick/src/marubozu.rs | 14 +- .../patterns/candlestick/src/master_candle.rs | 14 +- ta_lib/patterns/candlestick/src/on_neck.rs | 4 +- ta_lib/patterns/candlestick/src/piercing.rs | 4 +- .../patterns/candlestick/src/quintuplets.rs | 4 +- ta_lib/patterns/candlestick/src/r.rs | 56 ++ ta_lib/patterns/candlestick/src/shrinking.rs | 14 +- ta_lib/patterns/candlestick/src/slingshot.rs | 14 +- ta_lib/patterns/candlestick/src/split.rs | 14 +- ta_lib/patterns/candlestick/src/tasuki.rs | 6 +- .../patterns/candlestick/src/three_candles.rs | 4 +- .../patterns/candlestick/src/three_methods.rs | 14 +- .../patterns/candlestick/src/three_one_two.rs | 14 +- ta_lib/patterns/candlestick/src/tweezers.rs | 60 ++ ta_lib/patterns/channel/Cargo.toml | 12 + ta_lib/patterns/channel/src/lib.rs | 1 + ta_lib/patterns/channel/src/macros.rs | 24 + ta_lib/patterns/osc/Cargo.toml | 12 + ta_lib/patterns/osc/src/lib.rs | 1 + ta_lib/patterns/osc/src/macros.rs | 68 +++ ta_lib/patterns/trail/Cargo.toml | 13 + ta_lib/patterns/trail/src/lib.rs | 1 + ta_lib/patterns/trail/src/macros.rs | 18 + ta_lib/price/src/average.rs | 9 +- ta_lib/price/src/median.rs | 6 +- ta_lib/price/src/typical.rs | 4 +- ta_lib/price/src/wcl.rs | 4 +- ta_lib/strategies/base/Cargo.toml | 1 + ta_lib/strategies/base/src/constants.rs | 2 - ta_lib/strategies/base/src/ffi.rs | 400 +++++++++++- ta_lib/strategies/base/src/lib.rs | 4 - ta_lib/strategies/base/src/model.rs | 88 --- ta_lib/strategies/base/src/source.rs | 6 +- ta_lib/strategies/base/src/strategy.rs | 198 +++--- ta_lib/strategies/base/src/traits.rs | 21 +- ta_lib/strategies/base/src/volatility.rs | 20 +- ta_lib/strategies/baseline/Cargo.toml | 3 +- ta_lib/strategies/baseline/src/ma.rs | 55 +- ta_lib/strategies/confirm/Cargo.toml | 4 +- ta_lib/strategies/confirm/src/bb.rs | 37 ++ .../{pulse => confirm}/src/braid.rs | 47 +- ta_lib/strategies/confirm/src/cc.rs | 62 ++ ta_lib/strategies/confirm/src/cci.rs | 42 +- ta_lib/strategies/confirm/src/didi.rs | 54 ++ ta_lib/strategies/confirm/src/dpo.rs | 24 +- ta_lib/strategies/confirm/src/dso.rs | 199 ------ ta_lib/strategies/confirm/src/dumb.rs | 3 +- ta_lib/strategies/confirm/src/eom.rs | 22 +- ta_lib/strategies/confirm/src/lib.rs | 16 +- ta_lib/strategies/confirm/src/roc.rs | 29 - .../strategies/confirm/src/rsi_neutrality.rs | 11 +- .../strategies/confirm/src/rsi_signalline.rs | 7 +- ta_lib/strategies/confirm/src/stc.rs | 9 +- ta_lib/strategies/confirm/src/vi.rs | 143 ----- ta_lib/strategies/confirm/src/wpr.rs | 41 ++ ta_lib/strategies/exit/Cargo.toml | 3 +- ta_lib/strategies/exit/src/ast.rs | 15 +- ta_lib/strategies/exit/src/cci.rs | 51 -- ta_lib/strategies/exit/src/dumb.rs | 3 +- ta_lib/strategies/exit/src/highlow.rs | 3 +- ta_lib/strategies/exit/src/lib.rs | 6 +- ta_lib/strategies/exit/src/ma.rs | 3 +- ta_lib/strategies/exit/src/mad.rs | 33 + ta_lib/strategies/exit/src/mfi.rs | 3 +- ta_lib/strategies/exit/src/rex.rs | 51 ++ ta_lib/strategies/exit/src/rsi.rs | 3 +- ta_lib/strategies/exit/src/trix.rs | 3 +- ta_lib/strategies/indicator/Cargo.toml | 3 +- ta_lib/strategies/indicator/src/candle.rs | 74 ++- ta_lib/strategies/indicator/src/ma.rs | 81 +-- ta_lib/strategies/pulse/Cargo.toml | 1 + ta_lib/strategies/pulse/src/adx.rs | 35 +- ta_lib/strategies/pulse/src/chop.rs | 28 +- ta_lib/strategies/pulse/src/dumb.rs | 1 + ta_lib/strategies/pulse/src/lib.rs | 6 +- ta_lib/strategies/pulse/src/nvol.rs | 9 +- ta_lib/strategies/pulse/src/sqz.rs | 55 ++ ta_lib/strategies/pulse/src/tdfi.rs | 25 +- ta_lib/strategies/pulse/src/vo.rs | 29 +- ta_lib/strategies/pulse/src/wae.rs | 61 +- ta_lib/strategies/pulse/src/yz.rs | 39 ++ ta_lib/strategies/signal/Cargo.toml | 7 +- .../signal/src/bb/{macd_bb.rs => macd.rs} | 3 +- ta_lib/strategies/signal/src/bb/mod.rs | 8 +- .../signal/src/bb/{vwap_bb.rs => vwap.rs} | 3 +- .../{dch_ma2_breakout.rs => dch_ma2.rs} | 18 +- ta_lib/strategies/signal/src/breakout/mod.rs | 4 +- .../macd.rs} | 7 +- .../strategies/signal/src/colorswitch/mod.rs | 3 + .../strategies/signal/src/contrarian/kch_a.rs | 207 +++++++ .../strategies/signal/src/contrarian/kch_c.rs | 48 ++ .../strategies/signal/src/contrarian/mod.rs | 21 + .../strategies/signal/src/contrarian/rsi_c.rs | 40 ++ .../strategies/signal/src/contrarian/rsi_d.rs | 55 ++ .../signal/src/contrarian/rsi_nt.rs | 53 ++ .../strategies/signal/src/contrarian/rsi_u.rs | 46 ++ .../strategies/signal/src/contrarian/rsi_v.rs | 40 ++ .../snatr_reversal.rs => contrarian/snatr.rs} | 15 +- .../signal/src/contrarian/stoch_e.rs | 64 ++ .../src/{pattern => contrarian}/tii_v.rs | 22 +- ta_lib/strategies/signal/src/flip/ce.rs | 48 ++ ta_lib/strategies/signal/src/flip/ce_flip.rs | 36 -- ta_lib/strategies/signal/src/flip/mod.rs | 8 +- .../strategies/signal/src/flip/supertrend.rs | 40 ++ .../signal/src/flip/supertrend_flip.rs | 357 ----------- ta_lib/strategies/signal/src/lib.rs | 10 +- ta_lib/strategies/signal/src/ma/ma2_rsi.rs | 3 +- ta_lib/strategies/signal/src/ma/ma3_cross.rs | 3 +- ta_lib/strategies/signal/src/ma/ma_cross.rs | 3 +- .../strategies/signal/src/ma/ma_quadruple.rs | 3 +- ta_lib/strategies/signal/src/ma/ma_surpass.rs | 3 +- .../signal/src/ma/ma_testing_ground.rs | 3 +- ta_lib/strategies/signal/src/ma/vwap_cross.rs | 3 +- .../{dso_neutrality_cross.rs => dso_cross.rs} | 8 +- .../strategies/signal/src/neutrality/mod.rs | 20 +- .../{rsi_neutrality_cross.rs => rsi_cross.rs} | 23 +- ...neutrality_pullback.rs => rsi_pullback.rs} | 19 +- ...utrality_rejection.rs => rsi_rejection.rs} | 19 +- .../{tii_neutrality_cross.rs => tii_cross.rs} | 8 +- .../signal/src/pattern/ao_saucer.rs | 54 +- .../src/pattern/candlestick_reversal.rs | 26 + .../signal/src/pattern/candlestick_trend.rs | 3 +- ta_lib/strategies/signal/src/pattern/hl.rs | 3 +- ta_lib/strategies/signal/src/pattern/mod.rs | 10 +- ta_lib/strategies/signal/src/pattern/rsi_v.rs | 57 -- .../strategies/signal/src/pattern/spread.rs | 40 ++ ta_lib/strategies/signal/src/pullback/mod.rs | 3 + .../signal/src/pullback/supertrend.rs | 39 ++ ta_lib/strategies/signal/src/reversal/mod.rs | 7 - .../signalline/{di_signalline.rs => di.rs} | 3 +- .../signalline/{dso_signalline.rs => dso.rs} | 3 +- .../signalline/{kst_signalline.rs => kst.rs} | 3 +- .../{macd_signalline.rs => macd.rs} | 3 +- .../strategies/signal/src/signalline/mod.rs | 36 +- .../{qstick_signalline.rs => qstick.rs} | 5 +- .../signalline/{rsi_signalline.rs => rsi.rs} | 7 +- .../{stoch_signalline.rs => stoch.rs} | 3 +- .../{trix_signalline.rs => trix.rs} | 3 +- .../signalline/{tsi_signalline.rs => tsi.rs} | 3 +- .../dmi_reversal.rs => twolinescross/dmi.rs} | 54 +- .../signal/src/twolinescross/mod.rs | 5 + .../vi_reversal.rs => twolinescross/vi.rs} | 34 +- ta_lib/strategies/signal/src/zerocross/ao.rs | 35 ++ .../signal/src/zerocross/ao_zerocross.rs | 43 -- .../zerocross/{bop_zerocross.rs => bop.rs} | 5 +- .../src/zerocross/{cc_zerocross.rs => cc.rs} | 5 +- .../zerocross/{cfo_zerocross.rs => cfo.rs} | 5 +- .../src/zerocross/{di_zerocross.rs => di.rs} | 5 +- .../zerocross/{macd_zerocross.rs => macd.rs} | 8 +- ta_lib/strategies/signal/src/zerocross/mad.rs | 35 ++ ta_lib/strategies/signal/src/zerocross/mod.rs | 42 +- .../{qstick_zerocross.rs => qstick.rs} | 10 +- ta_lib/strategies/signal/src/zerocross/roc.rs | 30 + .../signal/src/zerocross/roc_zerocross.rs | 29 - .../strategies/signal/src/zerocross/trix.rs | 32 + .../signal/src/zerocross/trix_zerocross.rs | 35 -- .../zerocross/{tsi_zerocross.rs => tsi.rs} | 5 +- ta_lib/strategies/stop_loss/Cargo.toml | 1 + ta_lib/strategies/stop_loss/src/atr.rs | 9 +- ta_lib/strategies/stop_loss/src/dch.rs | 3 +- ta_lib/strategies/trend_follow/Cargo.toml | 1 + .../src/config/baseline_config.rs | 6 +- .../trend_follow/src/config/confirm_config.rs | 52 +- .../trend_follow/src/config/exit_config.rs | 22 +- .../trend_follow/src/config/pulse_config.rs | 54 +- .../trend_follow/src/config/signal_config.rs | 110 +++- .../src/config/stoploss_config.rs | 11 +- .../src/deserialize/candle_deserialize.rs | 19 +- .../src/deserialize/ma_deserialize.rs | 29 +- .../trend_follow/src/deserialize/mod.rs | 2 +- .../src/deserialize/smooth_deserialize.rs | 2 + ta_lib/strategies/trend_follow/src/ffi.rs | 2 + .../src/mapper/baseline_mapper.rs | 8 +- .../trend_follow/src/mapper/confirm_mapper.rs | 102 +++- .../trend_follow/src/mapper/exit_mapper.rs | 41 +- .../trend_follow/src/mapper/pulse_mapper.rs | 113 ++-- .../trend_follow/src/mapper/signal_mapper.rs | 217 +++++-- .../src/mapper/stoploss_mapper.rs | 11 +- ta_lib/timeseries/Cargo.toml | 18 + ta_lib/timeseries/src/lib.rs | 13 + ta_lib/timeseries/src/model.rs | 386 ++++++++++++ ta_lib/timeseries/src/ohlcv.rs | 141 +++++ ta_lib/timeseries/src/ta.rs | 36 ++ ta_lib/timeseries/src/traits.rs | 11 + tr.rs | 191 ++++++ 510 files changed, 11595 insertions(+), 6283 deletions(-) create mode 100644 copilot/__init__.py create mode 100644 copilot/_actor.py create mode 100644 copilot/_prompt.py rename core/actors/{_actor.py => _base_actor.py} (74%) create mode 100644 core/actors/_strategy_actor.py rename {position/risk => core/actors/policy}/__init__.py (100%) create mode 100644 core/actors/policy/event.py create mode 100644 core/actors/policy/signal.py create mode 100644 core/actors/policy/strategy.py create mode 100644 core/interfaces/abstract_llm_service.py create mode 100644 core/interfaces/abstract_market_repository.py rename core/interfaces/{abstract_position_size_strategy.py => abstract_order_size_strategy.py} (60%) create mode 100644 core/interfaces/abstract_timeseries.py create mode 100644 core/interfaces/abstract_wasm_manager.py delete mode 100644 core/interfaces/abstract_wasm_service.py create mode 100644 core/mixins/__init__.py create mode 100644 core/mixins/_event_handler.py create mode 100644 core/models/position_risk.py create mode 100644 core/models/profit_target.py create mode 100644 core/models/signal_risk.py create mode 100644 core/models/strategy_type.py create mode 100644 core/models/ta.py create mode 100644 core/models/timeseries_ref.py create mode 100644 core/models/wasm_type.py create mode 100644 core/queries/copilot.py create mode 100644 core/queries/ohlcv.py create mode 100644 executor/_market_actor.py delete mode 100644 executor/_market_order_actor.py create mode 100644 executor/_paper_actor.py delete mode 100644 executor/_paper_order_actor.py create mode 100644 infrastructure/event_dispatcher/event_dedup.py create mode 100644 market/__init__.py create mode 100644 market/_actor.py delete mode 100644 position/risk/break_even.py delete mode 100644 position/risk/simple.py rename position/{take_profit/__init__.py => size/base.py} (100%) delete mode 100644 position/take_profit/risk_reward.py create mode 100644 service/_llm.py create mode 100644 service/_timeseries.py create mode 100644 service/_wasm.py delete mode 100644 service/_wasm_file.py create mode 100644 strategy/generator/confirm/bb.py create mode 100644 strategy/generator/confirm/braid.py create mode 100644 strategy/generator/confirm/cc.py create mode 100644 strategy/generator/confirm/didi.py delete mode 100644 strategy/generator/confirm/dso.py delete mode 100644 strategy/generator/confirm/roc.py delete mode 100644 strategy/generator/confirm/vi.py create mode 100644 strategy/generator/confirm/wpr.py delete mode 100644 strategy/generator/exit/cci.py create mode 100644 strategy/generator/exit/mad.py create mode 100644 strategy/generator/exit/rex.py delete mode 100644 strategy/generator/pulse/braid.py create mode 100644 strategy/generator/pulse/sqz.py create mode 100644 strategy/generator/pulse/yz.py rename strategy/generator/signal/bb/{macd_bb.py => macd.py} (94%) rename strategy/generator/signal/bb/{vwap_bb.py => vwap.py} (100%) rename strategy/generator/signal/{reversal => colorswitch}/__init__.py (100%) rename strategy/generator/signal/{pattern/macd_colorswitch.py => colorswitch/macd.py} (78%) rename ta_lib/core/src/distance.rs => strategy/generator/signal/contrarian/__init__.py (100%) create mode 100644 strategy/generator/signal/contrarian/kch_a.py create mode 100644 strategy/generator/signal/contrarian/kch_c.py create mode 100644 strategy/generator/signal/contrarian/rsi_c.py create mode 100644 strategy/generator/signal/contrarian/rsi_d.py create mode 100644 strategy/generator/signal/contrarian/rsi_nt.py create mode 100644 strategy/generator/signal/contrarian/rsi_u.py rename strategy/generator/signal/{pattern => contrarian}/rsi_v.py (71%) rename strategy/generator/signal/{reversal/snatr_reversal.py => contrarian/snatr.py} (84%) create mode 100644 strategy/generator/signal/contrarian/stoch_e.py rename strategy/generator/signal/{pattern => contrarian}/tii_v.py (78%) rename strategy/generator/signal/flip/{ce_flip.py => ce.py} (58%) rename strategy/generator/signal/flip/{supertrend_flip.py => supertrend.py} (65%) rename strategy/generator/signal/neutrality/{dso_neutrality_cross.py => dso_cross.py} (100%) rename strategy/generator/signal/neutrality/{rsi_neutrality_cross.py => rsi_cross.py} (100%) rename strategy/generator/signal/neutrality/{rsi_neutrality_pullback.py => rsi_pullback.py} (100%) rename strategy/generator/signal/neutrality/{rsi_neutrality_rejection.py => rsi_rejection.py} (100%) rename strategy/generator/signal/neutrality/{tii_neutrality_cross.py => tii_cross.py} (100%) create mode 100644 strategy/generator/signal/pattern/candle_reversal.py create mode 100644 strategy/generator/signal/pattern/spread.py create mode 100644 strategy/generator/signal/pullback/__init__.py create mode 100644 strategy/generator/signal/pullback/supertrend.py delete mode 100644 strategy/generator/signal/reversal/vi_reversal.py rename strategy/generator/signal/signalline/{di_signalline.py => di.py} (100%) rename strategy/generator/signal/signalline/{dso_signalline.py => dso.py} (100%) rename strategy/generator/signal/signalline/{kst_signalline.py => kst.py} (100%) rename strategy/generator/signal/signalline/{macd_signalline.py => macd.py} (100%) rename strategy/generator/signal/signalline/{qstick_signalline.py => qstick.py} (100%) rename strategy/generator/signal/signalline/{rsi_signalline.py => rsi.py} (100%) rename strategy/generator/signal/signalline/{stoch_signalline.py => stoch.py} (100%) rename strategy/generator/signal/signalline/{trix_signalline.py => trix.py} (100%) rename strategy/generator/signal/signalline/{tsi_signalline.py => tsi.py} (100%) create mode 100644 strategy/generator/signal/twolinescross/__init__.py rename strategy/generator/signal/{reversal/dmi_reversal.py => twolinescross/dmi.py} (72%) create mode 100644 strategy/generator/signal/twolinescross/vi.py rename strategy/generator/signal/zerocross/{ao_zerocross.py => ao.py} (100%) rename strategy/generator/signal/zerocross/{bop_zerocross.py => bop.py} (100%) rename strategy/generator/signal/zerocross/{cc_zerocross.py => cc.py} (100%) rename strategy/generator/signal/zerocross/{cfo_zerocross.py => cfo.py} (100%) rename strategy/generator/signal/zerocross/{di_zerocross.py => di.py} (100%) rename strategy/generator/signal/zerocross/{macd_zerocross.py => macd.py} (100%) create mode 100644 strategy/generator/signal/zerocross/mad.py rename strategy/generator/signal/zerocross/{qstick_zerocross.py => qstick.py} (100%) rename strategy/generator/signal/zerocross/{roc_zerocross.py => roc.py} (100%) rename strategy/generator/signal/zerocross/{trix_zerocross.py => trix.py} (100%) rename strategy/generator/signal/zerocross/{tsi_zerocross.py => tsi.py} (100%) delete mode 100644 ta_lib/benches/strategy.rs create mode 100644 ta_lib/core/src/fmt.rs create mode 100644 ta_lib/core/src/types.rs create mode 100644 ta_lib/ffi/Cargo.toml create mode 100644 ta_lib/ffi/src/lib.rs create mode 100644 ta_lib/ffi/src/timeseries.rs delete mode 100644 ta_lib/indicators/momentum/src/ao.rs delete mode 100644 ta_lib/indicators/momentum/src/dso.rs rename ta_lib/indicators/{trend => momentum}/src/qstick.rs (61%) create mode 100644 ta_lib/indicators/momentum/src/rex.rs delete mode 100644 ta_lib/indicators/momentum/src/sso.rs delete mode 100644 ta_lib/indicators/momentum/src/stoch.rs create mode 100644 ta_lib/indicators/momentum/src/uo.rs rename ta_lib/indicators/momentum/src/{pr.rs => wpr.rs} (67%) rename ta_lib/indicators/trend/src/{kjs.rs => midpoint.rs} (63%) create mode 100644 ta_lib/indicators/trend/src/pp.rs create mode 100644 ta_lib/indicators/trend/src/slsma.rs delete mode 100644 ta_lib/indicators/trend/src/tma.rs create mode 100644 ta_lib/indicators/trend/src/trima.rs create mode 100644 ta_lib/indicators/trend/src/ults.rs delete mode 100644 ta_lib/indicators/volatility/src/atr.rs delete mode 100644 ta_lib/indicators/volatility/src/bbw.rs create mode 100644 ta_lib/indicators/volatility/src/gkyz.rs create mode 100644 ta_lib/indicators/volatility/src/pk.rs create mode 100644 ta_lib/indicators/volatility/src/rs.rs delete mode 100644 ta_lib/indicators/volatility/src/snatr.rs create mode 100644 ta_lib/indicators/volatility/src/yz.rs delete mode 100644 ta_lib/indicators/volume/src/vo.rs create mode 100644 ta_lib/patterns/bands/Cargo.toml create mode 100644 ta_lib/patterns/bands/src/lib.rs create mode 100644 ta_lib/patterns/bands/src/macros.rs create mode 100644 ta_lib/patterns/candlestick/src/r.rs create mode 100644 ta_lib/patterns/candlestick/src/tweezers.rs create mode 100644 ta_lib/patterns/channel/Cargo.toml create mode 100644 ta_lib/patterns/channel/src/lib.rs create mode 100644 ta_lib/patterns/channel/src/macros.rs create mode 100644 ta_lib/patterns/osc/Cargo.toml create mode 100644 ta_lib/patterns/osc/src/lib.rs create mode 100644 ta_lib/patterns/osc/src/macros.rs create mode 100644 ta_lib/patterns/trail/Cargo.toml create mode 100644 ta_lib/patterns/trail/src/lib.rs create mode 100644 ta_lib/patterns/trail/src/macros.rs delete mode 100644 ta_lib/strategies/base/src/constants.rs delete mode 100644 ta_lib/strategies/base/src/model.rs create mode 100644 ta_lib/strategies/confirm/src/bb.rs rename ta_lib/strategies/{pulse => confirm}/src/braid.rs (80%) create mode 100644 ta_lib/strategies/confirm/src/cc.rs create mode 100644 ta_lib/strategies/confirm/src/didi.rs delete mode 100644 ta_lib/strategies/confirm/src/dso.rs delete mode 100644 ta_lib/strategies/confirm/src/roc.rs delete mode 100644 ta_lib/strategies/confirm/src/vi.rs create mode 100644 ta_lib/strategies/confirm/src/wpr.rs delete mode 100644 ta_lib/strategies/exit/src/cci.rs create mode 100644 ta_lib/strategies/exit/src/mad.rs create mode 100644 ta_lib/strategies/exit/src/rex.rs create mode 100644 ta_lib/strategies/pulse/src/sqz.rs create mode 100644 ta_lib/strategies/pulse/src/yz.rs rename ta_lib/strategies/signal/src/bb/{macd_bb.rs => macd.rs} (93%) rename ta_lib/strategies/signal/src/bb/{vwap_bb.rs => vwap.rs} (90%) rename ta_lib/strategies/signal/src/breakout/{dch_ma2_breakout.rs => dch_ma2.rs} (64%) rename ta_lib/strategies/signal/src/{pattern/macd_colorswitch.rs => colorswitch/macd.rs} (90%) create mode 100644 ta_lib/strategies/signal/src/colorswitch/mod.rs create mode 100644 ta_lib/strategies/signal/src/contrarian/kch_a.rs create mode 100644 ta_lib/strategies/signal/src/contrarian/kch_c.rs create mode 100644 ta_lib/strategies/signal/src/contrarian/mod.rs create mode 100644 ta_lib/strategies/signal/src/contrarian/rsi_c.rs create mode 100644 ta_lib/strategies/signal/src/contrarian/rsi_d.rs create mode 100644 ta_lib/strategies/signal/src/contrarian/rsi_nt.rs create mode 100644 ta_lib/strategies/signal/src/contrarian/rsi_u.rs create mode 100644 ta_lib/strategies/signal/src/contrarian/rsi_v.rs rename ta_lib/strategies/signal/src/{reversal/snatr_reversal.rs => contrarian/snatr.rs} (76%) create mode 100644 ta_lib/strategies/signal/src/contrarian/stoch_e.rs rename ta_lib/strategies/signal/src/{pattern => contrarian}/tii_v.rs (71%) create mode 100644 ta_lib/strategies/signal/src/flip/ce.rs delete mode 100644 ta_lib/strategies/signal/src/flip/ce_flip.rs create mode 100644 ta_lib/strategies/signal/src/flip/supertrend.rs delete mode 100644 ta_lib/strategies/signal/src/flip/supertrend_flip.rs rename ta_lib/strategies/signal/src/neutrality/{dso_neutrality_cross.rs => dso_cross.rs} (84%) rename ta_lib/strategies/signal/src/neutrality/{rsi_neutrality_cross.rs => rsi_cross.rs} (62%) rename ta_lib/strategies/signal/src/neutrality/{rsi_neutrality_pullback.rs => rsi_pullback.rs} (69%) rename ta_lib/strategies/signal/src/neutrality/{rsi_neutrality_rejection.rs => rsi_rejection.rs} (66%) rename ta_lib/strategies/signal/src/neutrality/{tii_neutrality_cross.rs => tii_cross.rs} (82%) create mode 100644 ta_lib/strategies/signal/src/pattern/candlestick_reversal.rs delete mode 100644 ta_lib/strategies/signal/src/pattern/rsi_v.rs create mode 100644 ta_lib/strategies/signal/src/pattern/spread.rs create mode 100644 ta_lib/strategies/signal/src/pullback/mod.rs create mode 100644 ta_lib/strategies/signal/src/pullback/supertrend.rs delete mode 100644 ta_lib/strategies/signal/src/reversal/mod.rs rename ta_lib/strategies/signal/src/signalline/{di_signalline.rs => di.rs} (90%) rename ta_lib/strategies/signal/src/signalline/{dso_signalline.rs => dso.rs} (91%) rename ta_lib/strategies/signal/src/signalline/{kst_signalline.rs => kst.rs} (96%) rename ta_lib/strategies/signal/src/signalline/{macd_signalline.rs => macd.rs} (92%) rename ta_lib/strategies/signal/src/signalline/{qstick_signalline.rs => qstick.rs} (86%) rename ta_lib/strategies/signal/src/signalline/{rsi_signalline.rs => rsi.rs} (88%) rename ta_lib/strategies/signal/src/signalline/{stoch_signalline.rs => stoch.rs} (90%) rename ta_lib/strategies/signal/src/signalline/{trix_signalline.rs => trix.rs} (90%) rename ta_lib/strategies/signal/src/signalline/{tsi_signalline.rs => tsi.rs} (91%) rename ta_lib/strategies/signal/src/{reversal/dmi_reversal.rs => twolinescross/dmi.rs} (70%) create mode 100644 ta_lib/strategies/signal/src/twolinescross/mod.rs rename ta_lib/strategies/signal/src/{reversal/vi_reversal.rs => twolinescross/vi.rs} (85%) create mode 100644 ta_lib/strategies/signal/src/zerocross/ao.rs delete mode 100644 ta_lib/strategies/signal/src/zerocross/ao_zerocross.rs rename ta_lib/strategies/signal/src/zerocross/{bop_zerocross.rs => bop.rs} (80%) rename ta_lib/strategies/signal/src/zerocross/{cc_zerocross.rs => cc.rs} (87%) rename ta_lib/strategies/signal/src/zerocross/{cfo_zerocross.rs => cfo.rs} (76%) rename ta_lib/strategies/signal/src/zerocross/{di_zerocross.rs => di.rs} (80%) rename ta_lib/strategies/signal/src/zerocross/{macd_zerocross.rs => macd.rs} (85%) create mode 100644 ta_lib/strategies/signal/src/zerocross/mad.rs rename ta_lib/strategies/signal/src/zerocross/{qstick_zerocross.rs => qstick.rs} (71%) create mode 100644 ta_lib/strategies/signal/src/zerocross/roc.rs delete mode 100644 ta_lib/strategies/signal/src/zerocross/roc_zerocross.rs create mode 100644 ta_lib/strategies/signal/src/zerocross/trix.rs delete mode 100644 ta_lib/strategies/signal/src/zerocross/trix_zerocross.rs rename ta_lib/strategies/signal/src/zerocross/{tsi_zerocross.rs => tsi.rs} (84%) create mode 100644 ta_lib/timeseries/Cargo.toml create mode 100644 ta_lib/timeseries/src/lib.rs create mode 100644 ta_lib/timeseries/src/model.rs create mode 100644 ta_lib/timeseries/src/ohlcv.rs create mode 100644 ta_lib/timeseries/src/ta.rs create mode 100644 ta_lib/timeseries/src/traits.rs create mode 100644 tr.rs diff --git a/.env.example b/.env.example index 240a8158..bf665853 100644 --- a/.env.example +++ b/.env.example @@ -6,4 +6,6 @@ LOG_DIR= LOG_LEVEL=INFO WASM_FOLDER=wasm -REGIME=default \ No newline at end of file +REGIME=default + +COPILOT_MODEL_PATH= \ No newline at end of file diff --git a/Makefile b/Makefile index 40780cdc..45cd2097 100644 --- a/Makefile +++ b/Makefile @@ -14,14 +14,24 @@ check: cargo clippy --all-features --all-targets --workspace --manifest-path=$(TA_LIB_PATH) cargo fmt --all --check --manifest-path=$(TA_LIB_PATH) -build: +build: build-timeseries build-strategy + +build-strategy: RUSTFLAGS="-C target-feature=+multivalue,+simd128 -C link-arg=-s" cargo build --release --manifest-path=$(TA_LIB_PATH) --package trend_follow --target wasm32-wasi cp $(TA_LIB_DIR)/target/wasm32-wasi/release/trend_follow.wasm $(WASM_DIR)/trend_follow.wasm +build-timeseries: + RUSTFLAGS="-C target-feature=+multivalue,+simd128 -C link-arg=-s" cargo build --release --manifest-path=$(TA_LIB_PATH) --package ffi --target wasm32-wasi + cp $(TA_LIB_DIR)/target/wasm32-wasi/release/ffi.wasm $(WASM_DIR)/timeseries.wasm + run: pipenv run python3 quant.py format: cargo fmt --all --manifest-path=$(TA_LIB_PATH) pipenv run black . - pipenv run ruff . --fix \ No newline at end of file + pipenv run ruff . --fix + +update: + cargo update --manifest-path=$(TA_LIB_PATH) + pipenv update \ No newline at end of file diff --git a/Pipfile b/Pipfile index a8be478a..0301b5a8 100644 --- a/Pipfile +++ b/Pipfile @@ -22,6 +22,8 @@ cachetools = "*" orjson = "*" scipy = "*" cbor2 = "*" +llama-cpp-python = "0.2.82" +umap-learn = "*" [dev-packages] mypy = "*" diff --git a/config.default.ini b/config.default.ini index 845ff0d1..bd9d9f71 100644 --- a/config.default.ini +++ b/config.default.ini @@ -1,53 +1,59 @@ [store] -buf_size = 100 +buf_size = 50 base_dir = tmp [bus] -piority_groups = 168 -num_workers = 1 +piority_groups = 8 +num_workers = 5 [backtest] -batch_size = 1597 -window_size = 2 +batch_size = 350 +buff_size = 8 +window_size = 1 [position] -risk_reward_ratio = 1.618 -trade_duration = 900 -twap_duration = 180 -max_order_slice = 13 +trade_duration = 26800 +twap_duration = 80 +max_order_slice = 8 order_expiration_time = 8 max_order_breach = 3 stop_loss_threshold = 0.5 -risk_factor = 0.618 -tp_factor = 2.236 -sl_factor = 2.236 -trl_factor = 1.236 -depth = 80 +dom = 15 +max_scale_in = 0 [portfolio] -risk_per_trade = 0.005 +risk_per_trade = 0.0004 account_size = 1000 cagr_threshold = 0.01 -sharpe_ratio_threshold = 0.236 -total_trades_threshold = 8 +sharpe_ratio_threshold = 0.23 +total_trades_threshold = 6 [system] active_strategy_num = 3 verify_strategy_num = 21 mode = 1 -leverage = 1 +leverage = 75 reevaluate_timeout = 3600 [generator] -n_samples = 21 +n_samples = 8 blacklist = [USDCUSDT] timeframes = [5m] [optimization] -max_generations = 5 +max_generations = 3 elite_count = 3 mutation_rate = 0.0236 crossover_rate = 0.8 tournament_size = 5 reset_percentage = 0.236 -stability_percentage = 0.382 \ No newline at end of file +stability_percentage = 0.382 + +[copilot] +model_path = '' +n_ctx = 4096 +n_threads = 7 +n_gpu_layers = 2 +n_batch = 256 +max_tokens = 66 +temperature = 0.52 \ No newline at end of file diff --git a/copilot/__init__.py b/copilot/__init__.py new file mode 100644 index 00000000..9463a94b --- /dev/null +++ b/copilot/__init__.py @@ -0,0 +1,3 @@ +from ._actor import CopilotActor + +__all__ = [CopilotActor] diff --git a/copilot/_actor.py b/copilot/_actor.py new file mode 100644 index 00000000..90f7667a --- /dev/null +++ b/copilot/_actor.py @@ -0,0 +1,358 @@ +import asyncio +import logging +from typing import Union + +import numpy as np +from core.models.strategy_type import StrategyType +from scipy.spatial.distance import cdist +from sklearn.cluster import KMeans +from sklearn.decomposition import PCA, KernelPCA +from sklearn.ensemble import IsolationForest +from sklearn.metrics import ( + calinski_harabasz_score, + davies_bouldin_score, + silhouette_score, +) +from sklearn.mixture import BayesianGaussianMixture +from sklearn.neighbors import LocalOutlierFactor +from sklearn.preprocessing import MinMaxScaler, StandardScaler +from sklearn.svm import OneClassSVM +from sklearn.utils import check_random_state + +from core.actors import BaseActor +from core.interfaces.abstract_llm_service import AbstractLLMService +from core.mixins import EventHandlerMixin +from core.models.risk_type import SessionRiskType, SignalRiskType +from core.models.side import PositionSide, SignalSide +from core.models.signal_risk import SignalRisk +from core.queries.copilot import EvaluateSession, EvaluateSignal + +from ._prompt import ( + signal_contrarian_risk_prompt, + signal_trend_risk_prompt, +) + +CopilotEvent = Union[EvaluateSignal, EvaluateSession] + +logger = logging.getLogger(__name__) +LOOKBACK = 8 +N_CLUSTERS = 5 + + +def binary_strings(n): + def backtrack(current): + if len(current) == n: + results.append(current) + return + + backtrack(current + "0") + backtrack(current + "1") + + results = [] + backtrack("") + return results + + +def pad_bars(bars, length): + if len(bars) < length: + padding = [None] * (length - len(bars)) + return padding + bars + else: + return bars[-length:] + + +def lorentzian_distance(u, v): + return np.log(1 + np.sum(np.abs(u - v))) + + +class CustomKMeans(KMeans): + def __init__(self, n_clusters=3, max_iter=300, tol=1e-4, random_state=None): + super().__init__( + n_clusters=n_clusters, max_iter=max_iter, tol=tol, random_state=random_state + ) + + def _e_step(self, X): + distances = cdist(X, self.cluster_centers_, metric=lorentzian_distance) + labels = distances.argmin(axis=1) + return labels, distances + + def fit(self, X, y=None): + random_state = check_random_state(self.random_state) + X = self._validate_data(X, accept_sparse="csr", reset=True) + + self.cluster_centers_ = self._init_centroids(X, random_state) + + for i in range(self.max_iter): + self.labels_, distances = self._e_step(X) + + new_centers = np.array( + [ + X[self.labels_ == j].mean(axis=0) + if len(X[self.labels_ == j]) > 0 + else self.cluster_centers_[j] + for j in range(self.n_clusters) + ] + ) + + if np.all(np.abs(new_centers - self.cluster_centers_) <= self.tol): + break + + self.cluster_centers_ = new_centers + + self.inertia_ = np.sum((distances.min(axis=1)) ** 2) + self.n_iter_ = i + 1 + return self + + def _init_centroids(self, X, random_state): + n_samples, n_features = X.shape + centers = np.empty((self.n_clusters, n_features), dtype=X.dtype) + + center_id = random_state.randint(n_samples) + centers[0] = X[center_id] + + closest_dist_sq = np.full(n_samples, np.inf) + closest_dist_sq = np.minimum( + closest_dist_sq, np.sum((X - centers[0]) ** 2, axis=1) + ) + + for c in range(1, self.n_clusters): + probabilities = closest_dist_sq / closest_dist_sq.sum() + new_center_id = random_state.choice(n_samples, p=probabilities) + centers[c] = X[new_center_id] + + new_dist_sq = np.sum((X - centers[c]) ** 2, axis=1) + closest_dist_sq = np.minimum(closest_dist_sq, new_dist_sq) + + return centers + + +class CopilotActor(BaseActor, EventHandlerMixin): + _EVENTS = [EvaluateSignal, EvaluateSession] + + def __init__(self, llm: AbstractLLMService): + super().__init__() + EventHandlerMixin.__init__(self) + self._register_event_handlers() + + self.llm = llm + self.prev_txn = (None, None) + self._lock = asyncio.Lock() + self.anomaly = set(binary_strings(8)) + self.bars_n = 3 + self.horizon = 3 + + async def on_receive(self, event: CopilotEvent): + return await self.handle_event(event) + + def _register_event_handlers(self): + self.register_handler(EvaluateSignal, self._evaluate_signal) + self.register_handler(EvaluateSession, self._evaluate_session) + + async def _evaluate_signal(self, msg: EvaluateSignal) -> SignalRisk: + signal = msg.signal + curr_bar = signal.ohlcv + + prev_bar = msg.prev_bar + ta = msg.ta + + trend = ta.trend + volume = ta.volume + osc = ta.oscillator + momentum = ta.momentum + volatility = ta.volatility + + side = ( + PositionSide.LONG if signal.side == SignalSide.BUY else PositionSide.SHORT + ) + risk_type = SignalRiskType.NONE + + risk = SignalRisk( + type=risk_type, + ) + + bar = sorted(prev_bar + [curr_bar], key=lambda x: x.timestamp) + strategy_type = StrategyType.CONTRARIAN if "SUP" not in str(signal.strategy) else StrategyType.TREND_FOLLOW + + template = ( + signal_contrarian_risk_prompt + if strategy_type == StrategyType.CONTRARIAN + else signal_trend_risk_prompt + ) + + prompt = template.format( + side=side, + strategy_type=strategy_type, + entry=curr_bar.close, + horizon=self.horizon, + timeframe=signal.timeframe, + bar=bar[-self.bars_n :], + trend=trend.sma[-self.bars_n :], + macd=trend.macd[-self.bars_n :], + rsi=osc.srsi[-self.bars_n :], + cci=momentum.cci[-self.bars_n :], + roc=momentum.sroc[-self.bars_n :], + nvol=volume.nvol[-self.bars_n :], + support=trend.support[-self.bars_n :] + if side == PositionSide.SHORT + else trend.resistance[-self.bars_n :], + resistance=trend.resistance[-self.bars_n :] + if side == PositionSide.SHORT + else trend.support[-self.bars_n :], + vwap=volume.vwap[-self.bars_n :], + upper_bb=volatility.upb[-self.bars_n :], + lower_bb=volatility.lwb[-self.bars_n :], + true_range=volatility.tr[-self.bars_n :], + ) + + # logger.info(f"Signal Prompt: {prompt}") + + # answer = await self.llm.call(system_prompt, prompt) + + # logger.info(f"LLM Answer: {answer}") + + # match = re.search(signal_risk_pattern, answer) + match = None + + if not match: + risk = SignalRisk(type=risk_type) + else: + risk_type = SignalRiskType.from_string(match.group(1)) + _tp, _sl = match.group(2).split("."), match.group(3).split(".") + + tp, sl = float(f"{_tp[0]}.{_tp[1]}"), float(f"{_sl[0]}.{_sl[1]}") + + unknow_risk = (tp > curr_bar.close and side == PositionSide.SHORT) or ( + tp < curr_bar.close and side == PositionSide.LONG + ) + + if unknow_risk: + logger.warn("Risk with unknown position management") + + risk = SignalRisk(type=risk_type, tp=tp, sl=sl) + + logger.info(f"Entry: {curr_bar.close}, Signal Risk: {risk}") + + return risk + + async def _evaluate_session(self, msg: EvaluateSession) -> SessionRiskType: + async with self._lock: + ta = msg.ta + + ema = np.array(ta.trend.sma[-LOOKBACK:]) + support = np.array(ta.trend.support[-LOOKBACK:]) + resistance = np.array(ta.trend.resistance[-LOOKBACK:]) + dmi = np.array(ta.trend.dmi[-LOOKBACK:]) + macd = np.array(ta.trend.macd[-LOOKBACK:]) + hlcc4 = np.array(ta.trend.hlcc4[-LOOKBACK:]) + cci = np.array(ta.momentum.cci[-LOOKBACK:]) + roc = np.array(ta.momentum.sroc[-LOOKBACK:]) + ebb = np.array(ta.volatility.ebb[-LOOKBACK:]) + ekch = np.array(ta.volatility.ekch[-LOOKBACK:]) + rsi = np.array(ta.oscillator.srsi[-LOOKBACK:]) + stoch_k = np.array(ta.oscillator.k[-LOOKBACK:]) + mfi = np.array(ta.volume.mfi[-LOOKBACK:]) + vwap = np.array(ta.volume.vwap[-LOOKBACK:]) + nvol = np.array(ta.volume.nvol[-LOOKBACK:]) + yz = np.array(ta.volatility.yz[-LOOKBACK:]) + tr = np.array(ta.volatility.tr[-LOOKBACK:]) + + features = np.column_stack( + ( + ema, + support, + resistance, + dmi, + macd, + cci, + rsi, + stoch_k, + mfi, + ebb, + ekch, + yz, + roc, + vwap, + tr, + nvol, + hlcc4, + ) + ) + + features = StandardScaler().fit_transform(features) + features = MinMaxScaler(feature_range=(-1, 1)).fit_transform(features) + + features = PCA(n_components=5).fit_transform(features) + features = KernelPCA(n_components=2, kernel="rbf").fit_transform(features) + + n_neighbors = len(features) - 1 + max_clusters = min(n_neighbors, 10) + min_clusters = min(2, max_clusters) + k_best_score = float("-inf") + k_best_labels = None + + for k in range(min_clusters, max_clusters + 1): + kmeans = CustomKMeans(n_clusters=k, random_state=None).fit(features) + + if len(np.unique(kmeans.labels_)) < k: + continue + + score = calinski_harabasz_score(features, kmeans.labels_) + sil_score = silhouette_score(features, kmeans.labels_) + db_score = davies_bouldin_score(features, kmeans.labels_) + + combined_score = (score + sil_score - db_score) / 3 + + if combined_score > k_best_score: + k_best_score = combined_score + k_best_labels = kmeans.labels_ + + k_cluster_labels = k_best_labels.reshape(-1, 1) + + features_with_clusters = np.hstack((features, k_cluster_labels)) + + iso_forest = IsolationForest(contamination=0.01, random_state=1337).fit( + features_with_clusters + ) + iso_anomaly = iso_forest.predict(features_with_clusters) == -1 + + lof = LocalOutlierFactor(n_neighbors=n_neighbors, contamination=0.01) + lof_anomaly = lof.fit_predict(features_with_clusters) == -1 + + one_class_svm = OneClassSVM(kernel="rbf", gamma="scale", nu=0.01) + svm_anomaly = one_class_svm.fit_predict(features_with_clusters) == -1 + + iso_scores = iso_forest.decision_function(features_with_clusters) + lof_scores = -lof.negative_outlier_factor_ + svm_score = one_class_svm.decision_function(features_with_clusters) + + anomaly_scores = 0.3 * iso_scores + 0.2 * lof_scores + 0.5 * svm_score + + bgmm = BayesianGaussianMixture( + n_components=2, covariance_type="full", random_state=1337 + ) + bgmm.fit(anomaly_scores.reshape(-1, 1)) + + dynamic_threshold = np.percentile(bgmm.means_, 5) + knn_transaction = "".join(map(str, kmeans.labels_)) + + should_exit = False + + confidence_scores = { + "knn_transaction": 0.4 if knn_transaction in self.anomaly else 0, + "anomaly_score": 0.2 if anomaly_scores[-1] < dynamic_threshold else 0, + "iso_anomaly": 0.3 if iso_anomaly[-1] else 0, + "lof_anomaly": 0.05 if lof_anomaly[-1] else 0, + "svm_anomaly": 0.05 if svm_anomaly[-1] else 0, + } + + if sum(confidence_scores.values()) > 0.5: + should_exit = True + + logger.info( + f"SIDE: {msg.side}, " f"HLCC4: {hlcc4[-1]}, " f"Exit: {should_exit}" + ) + + if should_exit: + return SessionRiskType.EXIT + + return SessionRiskType.CONTINUE diff --git a/copilot/_prompt.py b/copilot/_prompt.py new file mode 100644 index 00000000..a02db404 --- /dev/null +++ b/copilot/_prompt.py @@ -0,0 +1,166 @@ +system_prompt = """ +You are act as an effective quantitative analysis assistant. Your job is to help interpret data, perform statistical analyses, technical analyses, forecast trend and provide insights based on numerical information. +""" +risk_intro = """ +[Position Risk Evaluation Framework] + +[Position Details] +- Side: {side} +- Timeframe: {timeframe} +- Horizon: Next {horizon} Candlesticks +- Entry Price: {entry} +- Strategy Type: {strategy_type} +""" +risk_outro = """ +[Final Output] +- RL: [Risk Level Value:ENUM] +- TP: [Take Profit Value:.6f] +- SL: [Stop Loss Value:.6f] + +[Example] +RL: MODERATE, TP: 7.4499, SL: 8.444 + +Return the result as raw string only. +""" +risk_data = """ +[Input Data] +- Candlestick Data: {bar} +- EMA (Exponential Moving Average): {trend} +- MACD (Moving Average Convergence/Divergence) Histogram: {macd} +- RSI (Relative Strength Index): {rsi} +- CCI (Commodity Channel Index): {cci} +- ROC (Rate of Change): {roc} +- Normalized Volume: {nvol} +- VWAP (Volume Weighted Average Price): {vwap} +- Support/Resistance Levels: + - Support: {support} + - Resistance: {resistance} +- Bollinger Bands: + - Upper: {upper_bb} + - Lower: {lower_bb} +- Volatility (True Range): {true_range} +""" +trend_risk_framework = """ +[Input Data Analysis] + +[Step 1: Candlestick Data Analysis] +- Price Movement: + - Upward: Higher risk for SHORT, lower risk for LONG. + - Downward: Lower risk for SHORT, higher risk for LONG. +- Price Range: + - Wide: Higher risk due to potential price swings. + - Narrow: Lower risk but may suggest a potential breakout. +- Real Body Normalization: + - High: Strong movement, higher risk if against the position. + - Low: Weak movement, lower risk if against the position. +- Body Range Ratio: + - High: Higher risk if against the position. + - Low: Lower risk if against the position. +- Body Shadow Ratio: + - High: Higher risk if against the position. + - Low: Lower risk if against the position. + +[Step 2: Technical Analysis] +- EMA: + - Upward: Lower risk for LONG, higher risk for SHORT. + - Downward: Higher risk for LONG, lower risk for SHORT. +- MACD Histogram: + - Positive: Bullish momentum, lower risk for LONG, higher risk for SHORT. + - Negative: Bearish momentum, higher risk for LONG, lower risk for SHORT. +- RSI: + - Above 70: Overbought, higher risk for LONG. + - Below 30: Oversold, higher risk for SHORT. + - Between 30 and 70: Neutral, moderate risk for LONG and SHORT. +- CCI: + - Above 100: Overbought, higher risk for LONG. + - Below -100: Oversold, higher risk for SHORT. + - Between -100 and 100: Neutral, moderate risk for LONG and SHORT. +- ROC: + - Positive: Lower risk for LONG, higher risk for SHORT. + - Negative: Higher risk for LONG, lower risk for SHORT. +- Normalized Volume: + - High: Higher risk if against the position. + - Low: Lower risk. +- VWAP: + - Above: Favorable for LONG, increased risk for SHORT. + - Below: Favorable for SHORT, increased risk for LONG. +- Support/Resistance Levels: + - Near Resistance: Increased risk for LONG, favorable for SHORT. + - Near Support: Favorable for LONG, increased risk for SHORT. +- Bollinger Bands: + - Above Upper: Higher risk for LONG, potential reversal or correction. + - Below Lower: Higher risk for SHORT, potential reversal or bounce. +- Volatility (True Range): + - High: Higher risk due to potential price swings, tighter stops recommended for both LONG and SHORT. + - Low: Suggests consolidation, with increased breakout potential, adjust risk management accordingly. +""" +contrarian_risk_framework = """ +[Input Data Analysis] + +[Step 1: Candlestick Data Analysis] +- Price Movement: + - Upward: Reversal risk higher for LONG, lower for SHORT. + - Downward: Rebound risk higher for SHORT, lower for LONG. +- Price Range: + - Wide: Higher reversal risk if aligned with the trend. + - Narrow: Consolidation with potential reversal or breakout. +- Real Body Normalization: + - High: Strong trend, potential exhaustion. + - Low: Weak trend, potential reversal. +- Body Range Ratio: + - High: Significant movement, potential exhaustion. + - Low: Insignificant movement, potential reversal. +- Body Shadow Ratio: + - High: Pressure with potential reversal. + - Low: Weak pressure, lower reversal risk. + +[Step 2: Technical Analysis] +- EMA: + - Upward: Indicates an overall upward trend, reducing risk for LONG and increasing risk for SHORT. + - Downward: Suggests trend reversal, favorable for SHORT and riskier for LONG. +- MACD Histogram: + - Positive: Potential exhaustion. + - Negative: Potential rebound. +- RSI: + - Above 70: Overbought, increased risk for LONG, favorable for SHORT. + - Below 30: Oversold, favorable for LONG, increased risk for SHORT. +- CCI: + - Above 100: Overbought, increased risk for LONG, favorable for SHORT. + - Below -100: Oversold, favorable for LONG, increased risk for SHORT. +- ROC: + - Positive: Potential exhaustion if combined with overbought signals, higher reversal risk. + - Negative: Potential rebound if combined with oversold signals, higher rebound risk. +- Normalized Volume: + - High: Strong sentiment, potential exhaustion. + - Low: Weak sentiment, potential reversal. +- VWAP: + - Above: Favorable for LONG, increased risk for SHORT. + - Below: Favorable for SHORT, increased risk for LONG. +- Support/Resistance Levels: + - Near Resistance: Increased risk for LONG, favorable for SHORT. + - Near Support: Favorable for LONG, increased risk for SHORT. +- Bollinger Bands: + - Above Upper: Overbought, reversal risk for LONG. + - Below Lower: Oversold, rebound risk for SHORT. + - Wide (Significant gap between Upper and Lower): Indicates high volatility, potential breakout risk. + - Narrow (Small gap between Upper and Lower): Indicates low volatility, potential for explosive movement if bands expand. +- Volatility (True Range): + - High: Indicates increased risk of sharp movements; tighter stops recommended for both LONG and SHORT. + - Low: Suggests consolidation, with increased breakout potential, adjust risk management accordingly. +""" +risk_eval = """ +[Step 3: Risk Level Management] +- NONE: No significant risk factors. +- VERY_LOW: Minor risk factors, generally favorable. +- LOW: Some risk factors, not significant enough to deter. +- MODERATE: Noticeable risk factors, caution advised. +- HIGH: Significant risk factors, high caution or avoidance advised. +- VERY_HIGH: Major risk factors, generally unfavorable. +""" +signal_trend_risk_prompt = ( + f"{risk_intro}{risk_data}{trend_risk_framework}{risk_eval}{risk_outro}" +) +signal_contrarian_risk_prompt = ( + f"{risk_intro}{risk_data}{contrarian_risk_framework}{risk_eval}{risk_outro}" +) +signal_risk_pattern = r"RL:\s*(NONE|VERY_LOW|LOW|MODERATE|HIGH|VERY_HIGH)\s*,\s*TP:\s*([\d.]+)\s*,\s*SL:\s*([\d.]+)\s*\.*" diff --git a/core/actors/__init__.py b/core/actors/__init__.py index f3c8e5f4..83e534dd 100644 --- a/core/actors/__init__.py +++ b/core/actors/__init__.py @@ -1,3 +1,4 @@ -from ._actor import Actor +from ._base_actor import BaseActor +from ._strategy_actor import StrategyActor -__all__ = [Actor] +__all__ = [BaseActor, StrategyActor] diff --git a/core/actors/_actor.py b/core/actors/_base_actor.py similarity index 74% rename from core/actors/_actor.py rename to core/actors/_base_actor.py index 0a28fef8..641170ab 100644 --- a/core/actors/_actor.py +++ b/core/actors/_base_actor.py @@ -1,32 +1,23 @@ +import uuid + from core.commands.base import Command from core.interfaces.abstract_actor import AbstractActor, Ask, Message -from core.models.symbol import Symbol -from core.models.timeframe import Timeframe from core.queries.base import Query from infrastructure.event_dispatcher.event_dispatcher import EventDispatcher -class Actor(AbstractActor): +class BaseActor(AbstractActor): _EVENTS = [] - def __init__(self, symbol: Symbol, timeframe: Timeframe): + def __init__(self): super().__init__() - self._symbol = symbol - self._timeframe = timeframe self._running = False self._mailbox = EventDispatcher() + self._id = str(uuid.uuid4()) @property def id(self): - return f"{self._symbol}_{self._timeframe}" - - @property - def symbol(self): - return self._symbol - - @property - def timeframe(self): - return self._timeframe + return self._id @property def running(self): @@ -46,11 +37,9 @@ def on_receive(self, _msg: Message): def start(self): if self.running: - raise RuntimeError(f"Start: {self.__class__.__name__} is running") - - for event in self._EVENTS: - self._mailbox.register(event, self.on_receive, self._pre_receive) + raise RuntimeError(f"Start: {self.__class__.__name__} is already running") + self._register_events() self.on_start() self._running = True @@ -58,9 +47,7 @@ def stop(self): if not self.running: raise RuntimeError(f"Stop: {self.__class__.__name__} is not started") - for event in self._EVENTS: - self._mailbox.unregister(event, self.on_receive) - + self._unregister_events() self.on_stop() self._running = False @@ -73,5 +60,10 @@ async def ask(self, msg: Ask, *args, **kwrgs): if isinstance(msg, Command): await self._mailbox.execute(msg, *args, **kwrgs) - def _pre_receive(self, _msg: Message): - return self.pre_receive(_msg) + def _register_events(self): + for event in self._EVENTS: + self._mailbox.register(event, self.on_receive, self.pre_receive) + + def _unregister_events(self): + for event in self._EVENTS: + self._mailbox.unregister(event, self.on_receive) diff --git a/core/actors/_strategy_actor.py b/core/actors/_strategy_actor.py new file mode 100644 index 00000000..9847f9e4 --- /dev/null +++ b/core/actors/_strategy_actor.py @@ -0,0 +1,30 @@ +from core.models.symbol import Symbol +from core.models.timeframe import Timeframe + +from ._base_actor import BaseActor +from .policy.strategy import StrategyPolicy + + +class StrategyActor(BaseActor): + _EVENTS = [] + + def __init__(self, symbol: Symbol, timeframe: Timeframe): + super().__init__() + self._symbol = symbol + self._timeframe = timeframe + self._id = f"{self.symbol}_{self.timeframe}" + + @property + def id(self) -> str: + return self._id + + @property + def symbol(self) -> "Symbol": + return self._symbol + + @property + def timeframe(self) -> "Timeframe": + return self._timeframe + + def pre_receive(self, msg) -> bool: + return StrategyPolicy.should_process(self, msg) diff --git a/position/risk/__init__.py b/core/actors/policy/__init__.py similarity index 100% rename from position/risk/__init__.py rename to core/actors/policy/__init__.py diff --git a/core/actors/policy/event.py b/core/actors/policy/event.py new file mode 100644 index 00000000..536781fc --- /dev/null +++ b/core/actors/policy/event.py @@ -0,0 +1,10 @@ +from abc import ABC, abstractmethod + +from core.events.base import Event +from core.interfaces.abstract_actor import AbstractActor + + +class EventPolicy(ABC): + @abstractmethod + def should_process(actor: AbstractActor, event: Event) -> bool: + pass diff --git a/core/actors/policy/signal.py b/core/actors/policy/signal.py new file mode 100644 index 00000000..133cf6d2 --- /dev/null +++ b/core/actors/policy/signal.py @@ -0,0 +1,13 @@ +from core.events.ohlcv import NewMarketDataReceived + +from .event import EventPolicy + + +class SignalPolicy(EventPolicy): + @classmethod + def should_process(cls, actor, event: NewMarketDataReceived) -> bool: + return ( + event.symbol == actor.symbol + and event.timeframe == actor.timeframe + and event.closed + ) diff --git a/core/actors/policy/strategy.py b/core/actors/policy/strategy.py new file mode 100644 index 00000000..ae0194b5 --- /dev/null +++ b/core/actors/policy/strategy.py @@ -0,0 +1,28 @@ +from typing import Any + +from .event import EventPolicy + + +class StrategyPolicy(EventPolicy): + @classmethod + def should_process(cls, actor, event) -> bool: + symbol, timeframe = cls._get_event_key(event) + return actor.symbol == symbol and actor.timeframe == timeframe + + @classmethod + def _get_event_key(cls, event: Any): + key = cls._extract_key(event) + + if not all(hasattr(key, attr) for attr in ("symbol", "timeframe")): + raise AttributeError("Key does not have 'symbol' or 'timeframe' attributes") + + return key.symbol, key.timeframe + + @staticmethod + def _extract_key(event: Any): + if hasattr(event, "signal"): + return event.signal + elif hasattr(event, "position") and hasattr(event.position, "signal"): + return event.position.signal + + return event diff --git a/core/commands/base.py b/core/commands/base.py index 732fabe3..aeb972f8 100644 --- a/core/commands/base.py +++ b/core/commands/base.py @@ -2,19 +2,19 @@ import hashlib from dataclasses import dataclass, field, fields from datetime import datetime, timedelta -from enum import Enum +from enum import Enum, auto from core.events.base import Event, EventMeta class CommandGroup(Enum): - account = "account" - broker = "broker" - portfolio = "portfolio" - feed = "feed" + account = auto() + broker = auto() + portfolio = auto() + feed = auto() def __str__(self): - return self.value + return self.name @dataclass(frozen=True) diff --git a/core/commands/broker.py b/core/commands/broker.py index fc36aafb..a4655f68 100644 --- a/core/commands/broker.py +++ b/core/commands/broker.py @@ -32,10 +32,3 @@ class OpenPosition(BrokerCommand): @dataclass(frozen=True) class ClosePosition(BrokerCommand): position: Position - exit_price: float - - -@dataclass(frozen=True) -class AdjustPosition(BrokerCommand): - position: Position - adjust_price: float diff --git a/core/events/base.py b/core/events/base.py index b0801f0e..4d1e64bc 100644 --- a/core/events/base.py +++ b/core/events/base.py @@ -1,22 +1,22 @@ import uuid from dataclasses import asdict, dataclass, field from datetime import datetime -from enum import Enum +from enum import Enum, auto class EventGroup(Enum): - account = "account" - backtest = "backtest" - market = "market" - portfolio = "portfolio" - position = "position" - risk = "risk" - service = "service" - signal = "signal" - system = "system" + account = auto() + backtest = auto() + market = auto() + portfolio = auto() + position = auto() + risk = auto() + service = auto() + signal = auto() + system = auto() def __str__(self): - return self.value + return self.name @dataclass diff --git a/core/events/ohlcv.py b/core/events/ohlcv.py index d3e49ddf..1f91b7c1 100644 --- a/core/events/ohlcv.py +++ b/core/events/ohlcv.py @@ -21,3 +21,15 @@ class NewMarketDataReceived(MarketEvent): timeframe: Timeframe ohlcv: OHLCV closed: bool + + def to_dict(self): + parent_dict = super().to_dict() + + current_dict = { + "symbol": str(self.symbol), + "timeframe": str(self.timeframe), + "ohlcv": self.ohlcv.to_dict(), + "closed": self.closed, + } + + return {**parent_dict, **current_dict} diff --git a/core/events/position.py b/core/events/position.py index da15a826..c2813f8f 100644 --- a/core/events/position.py +++ b/core/events/position.py @@ -40,7 +40,7 @@ class PositionOpened(PositionEvent): @dataclass(frozen=True) class PositionCloseRequested(PositionEvent): - exit_price: float + pass @dataclass(frozen=True) diff --git a/core/events/risk.py b/core/events/risk.py index 02a93886..6ca91c53 100644 --- a/core/events/risk.py +++ b/core/events/risk.py @@ -1,7 +1,6 @@ from dataclasses import dataclass, field from core.models.position import Position -from core.models.risk_type import RiskType from .base import Event, EventGroup, EventMeta @@ -16,8 +15,7 @@ class RiskEvent(Event): @dataclass(frozen=True) class RiskThresholdBreached(RiskEvent): - exit_price: float - reason: RiskType + pass @dataclass(frozen=True) diff --git a/core/events/signal.py b/core/events/signal.py index 44ea0ca0..e9653ab4 100644 --- a/core/events/signal.py +++ b/core/events/signal.py @@ -1,6 +1,5 @@ -from dataclasses import asdict, dataclass, field +from dataclasses import dataclass, field -from core.models.ohlcv import OHLCV from core.models.signal import Signal from .base import Event, EventGroup, EventMeta @@ -9,59 +8,36 @@ @dataclass(frozen=True) class SignalEvent(Event): signal: Signal - ohlcv: OHLCV meta: EventMeta = field( default_factory=lambda: EventMeta(priority=5, group=EventGroup.signal), init=False, ) - -@dataclass(frozen=True) -class SignalEntryEvent(SignalEvent): - entry_price: float - stop_loss: float - def to_dict(self): parent_dict = super().to_dict() current_dict = { "signal": self.signal.to_dict(), - "entry_price": self.entry_price, - "stop_loss": self.stop_loss, - "ohlcv": asdict(self.ohlcv), } return {**parent_dict, **current_dict} @dataclass(frozen=True) -class SignalExitEvent(SignalEvent): - exit_price: float - - def to_dict(self): - return { - "signal": self.signal.to_dict(), - "exit_price": self.exit_price, - "ohlcv": asdict(self.ohlcv), - "meta": asdict(self.meta), - } - - -@dataclass(frozen=True) -class GoLongSignalReceived(SignalEntryEvent): +class GoLongSignalReceived(SignalEvent): pass @dataclass(frozen=True) -class GoShortSignalReceived(SignalEntryEvent): +class GoShortSignalReceived(SignalEvent): pass @dataclass(frozen=True) -class ExitLongSignalReceived(SignalExitEvent): +class ExitLongSignalReceived(SignalEvent): pass @dataclass(frozen=True) -class ExitShortSignalReceived(SignalExitEvent): +class ExitShortSignalReceived(SignalEvent): pass diff --git a/core/interfaces/abstract_actor.py b/core/interfaces/abstract_actor.py index 648c1757..f0042bef 100644 --- a/core/interfaces/abstract_actor.py +++ b/core/interfaces/abstract_actor.py @@ -19,8 +19,6 @@ GoLongSignalReceived, GoShortSignalReceived, ) -from core.models.symbol import Symbol -from core.models.timeframe import Timeframe from core.queries.base import Query Message = Union[ @@ -48,16 +46,6 @@ class AbstractActor(ABC): def id(self) -> str: pass - @property - @abstractmethod - def symbol(self) -> Symbol: - pass - - @property - @abstractmethod - def timeframe(self) -> Timeframe: - pass - @property @abstractmethod def running(self) -> bool: diff --git a/core/interfaces/abstract_executor_actor_factory.py b/core/interfaces/abstract_executor_actor_factory.py index ef6c0344..14243813 100644 --- a/core/interfaces/abstract_executor_actor_factory.py +++ b/core/interfaces/abstract_executor_actor_factory.py @@ -1,5 +1,6 @@ from abc import ABC, abstractmethod +from core.interfaces.abstract_market_repository import AbstractMarketRepository from core.models.order import OrderType from core.models.symbol import Symbol from core.models.timeframe import Timeframe @@ -10,6 +11,10 @@ class AbstractExecutorActorFactory(ABC): @abstractmethod def create_actor( - self, type: OrderType, symbol: Symbol, timeframe: Timeframe + self, + type: OrderType, + symbol: Symbol, + timeframe: Timeframe, + repository: AbstractMarketRepository, ) -> AbstractActor: pass diff --git a/core/interfaces/abstract_llm_service.py b/core/interfaces/abstract_llm_service.py new file mode 100644 index 00000000..dc51e20b --- /dev/null +++ b/core/interfaces/abstract_llm_service.py @@ -0,0 +1,7 @@ +from abc import ABC, abstractmethod + + +class AbstractLLMService(ABC): + @abstractmethod + def call(self, system_prompt: str, user_prompt: str) -> str: + pass diff --git a/core/interfaces/abstract_market_repository.py b/core/interfaces/abstract_market_repository.py new file mode 100644 index 00000000..7124d19b --- /dev/null +++ b/core/interfaces/abstract_market_repository.py @@ -0,0 +1,20 @@ +from abc import abstractmethod + +from core.interfaces.abstract_event_manager import AbstractEventManager +from core.models.ohlcv import OHLCV +from core.models.symbol import Symbol +from core.models.timeframe import Timeframe + + +class AbstractMarketRepository(AbstractEventManager): + @abstractmethod + def upsert(self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV): + pass + + @abstractmethod + def find_next_bar(self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV): + pass + + @abstractmethod + def ta(self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV): + pass diff --git a/core/interfaces/abstract_position_size_strategy.py b/core/interfaces/abstract_order_size_strategy.py similarity index 60% rename from core/interfaces/abstract_position_size_strategy.py rename to core/interfaces/abstract_order_size_strategy.py index 75b7dd88..d87f29bf 100644 --- a/core/interfaces/abstract_position_size_strategy.py +++ b/core/interfaces/abstract_order_size_strategy.py @@ -1,16 +1,13 @@ from abc import abstractmethod -from typing import Optional from core.interfaces.abstract_event_manager import AbstractEventManager from core.models.signal import Signal -class AbstractPositionSizeStrategy(AbstractEventManager): +class AbstractOrderSizeStrategy(AbstractEventManager): @abstractmethod def calculate( self, signal: Signal, - entry_price: float, - stop_loss_price: Optional[float] = None, ) -> float: pass diff --git a/core/interfaces/abstract_position_factory.py b/core/interfaces/abstract_position_factory.py index 39524b3c..96302a40 100644 --- a/core/interfaces/abstract_position_factory.py +++ b/core/interfaces/abstract_position_factory.py @@ -1,17 +1,14 @@ from abc import ABC, abstractmethod -from core.models.ohlcv import OHLCV from core.models.position import Position +from core.models.risk_type import SignalRiskType from core.models.signal import Signal +from core.models.ta import TechAnalysis class AbstractPositionFactory(ABC): @abstractmethod - def create_position( - self, - signal: Signal, - ohlcv: OHLCV, - entry_price: float, - stop_loss: float, + def create( + self, signal: Signal, signal_risk_type: SignalRiskType, ta: TechAnalysis ) -> Position: pass diff --git a/core/interfaces/abstract_strategy_generator_factory.py b/core/interfaces/abstract_strategy_generator_factory.py index 9ee4062d..5b493303 100644 --- a/core/interfaces/abstract_strategy_generator_factory.py +++ b/core/interfaces/abstract_strategy_generator_factory.py @@ -1,13 +1,10 @@ from abc import ABC, abstractmethod from core.interfaces.abstract_strategy_generator import AbstractStrategyGenerator -from core.models.strategy import StrategyType from core.models.symbol import Symbol class AbstractStrategyGeneratorFactory(ABC): @abstractmethod - def create( - self, type: StrategyType, symbols: list[Symbol] - ) -> AbstractStrategyGenerator: + def create(self, symbols: list[Symbol]) -> AbstractStrategyGenerator: pass diff --git a/core/interfaces/abstract_timeseries.py b/core/interfaces/abstract_timeseries.py new file mode 100644 index 00000000..56f1b382 --- /dev/null +++ b/core/interfaces/abstract_timeseries.py @@ -0,0 +1,34 @@ +from abc import ABC, abstractmethod +from typing import Optional + +from core.models.ohlcv import OHLCV +from core.models.symbol import Symbol +from core.models.timeframe import Timeframe + + +class AbstractTimeSeriesService(ABC): + @abstractmethod + async def upsert(self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV): + pass + + @abstractmethod + async def next_bar( + self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV + ) -> Optional[OHLCV]: + pass + + @abstractmethod + async def prev_bar( + self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV + ) -> Optional[OHLCV]: + pass + + @abstractmethod + async def back_n_bars( + self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV, n: int + ) -> Optional[OHLCV]: + pass + + @abstractmethod + async def ta(self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV): + pass diff --git a/core/interfaces/abstract_wasm_manager.py b/core/interfaces/abstract_wasm_manager.py new file mode 100644 index 00000000..0c5cb620 --- /dev/null +++ b/core/interfaces/abstract_wasm_manager.py @@ -0,0 +1,12 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +from wasmtime import Instance, Store + +from core.models.wasm_type import WasmType + + +class AbstractWasmManager(ABC): + @abstractmethod + def get_instance(self, wasm_type: WasmType) -> Tuple[Instance, Store]: + pass diff --git a/core/interfaces/abstract_wasm_service.py b/core/interfaces/abstract_wasm_service.py deleted file mode 100644 index 0b31f82f..00000000 --- a/core/interfaces/abstract_wasm_service.py +++ /dev/null @@ -1,9 +0,0 @@ -from abc import ABC, abstractmethod - -from wasmtime import Engine, Module - - -class AbstractWasmService(ABC): - @abstractmethod - def get_module(self, identifier: str, engine: Engine) -> Module: - pass diff --git a/core/mixins/__init__.py b/core/mixins/__init__.py new file mode 100644 index 00000000..9ce3597f --- /dev/null +++ b/core/mixins/__init__.py @@ -0,0 +1,3 @@ +from ._event_handler import EventHandlerMixin + +__all__ = [EventHandlerMixin] diff --git a/core/mixins/_event_handler.py b/core/mixins/_event_handler.py new file mode 100644 index 00000000..cc0e7e95 --- /dev/null +++ b/core/mixins/_event_handler.py @@ -0,0 +1,15 @@ +from typing import Any, Callable, Dict, Type + + +class EventHandlerMixin: + def __init__(self): + self._handlers: Dict[Type[Any], Callable] = {} + + def register_handler(self, event_type: Type[Any], handler: Callable): + self._handlers[event_type] = handler + + async def handle_event(self, event: Any) -> Any: + handler = self._handlers.get(type(event)) + if handler: + return await handler(event) + return None diff --git a/core/models/candle.py b/core/models/candle.py index 48e634da..a2634292 100644 --- a/core/models/candle.py +++ b/core/models/candle.py @@ -19,3 +19,19 @@ class CandleTrendType(Enum): def __str__(self): return self.name.upper() + + +class CandleReversalType(Enum): + DOJI = 1 + ENGULFING = 2 + EUPHORIA = 3 + HAMMER = 4 + HARAMIF = 5 + HARAMIS = 6 + KANGAROO = 7 + R = 8 + SPLIT = 9 + TWEEZERS = 10 + + def __str__(self): + return self.name.upper() diff --git a/core/models/moving_average.py b/core/models/moving_average.py index 4d595bfb..1e43be64 100644 --- a/core/models/moving_average.py +++ b/core/models/moving_average.py @@ -16,19 +16,22 @@ class MovingAverageType(Enum): MD = 12 RMSMA = 13 SINWMA = 14 - SMA = 15 - SMMA = 16 - T3 = 17 - TEMA = 18 - TMA = 19 - VIDYA = 20 - VWMA = 21 - VWEMA = 22 - WMA = 23 - ZLEMA = 24 - ZLSMA = 25 - ZLTEMA = 26 - ZLHMA = 27 + SLSMA = 15 + SMA = 16 + SMMA = 17 + T3 = 18 + TEMA = 19 + TL = 20 + TRIMA = 21 + ULTS = 22 + VIDYA = 23 + VWMA = 24 + VWEMA = 25 + WMA = 26 + ZLEMA = 27 + ZLSMA = 28 + ZLTEMA = 29 + ZLHMA = 30 def __str__(self): return self.name.upper() diff --git a/core/models/ohlcv.py b/core/models/ohlcv.py index 96cd012f..b9a61416 100644 --- a/core/models/ohlcv.py +++ b/core/models/ohlcv.py @@ -1,7 +1,17 @@ -from dataclasses import asdict, dataclass +from dataclasses import dataclass +from enum import Enum, auto from typing import Any, Dict, List +class CandleType(Enum): + BULLISH = auto() + BEARISH = auto() + NEUTRAL = auto() + + def __str__(self): + return self.name.upper() + + @dataclass(frozen=True) class OHLCV: timestamp: int @@ -11,6 +21,16 @@ class OHLCV: close: float volume: float + def __post_init__(self): + if not ( + self.low <= self.open <= self.high and self.low <= self.close <= self.high + ): + raise ValueError( + "Open and Close prices must be between Low and High prices" + ) + if self.low > self.high: + raise ValueError("Low price cannot be higher than High price") + @classmethod def from_list(cls, data: List[Any]) -> "OHLCV": timestamp, open, high, low, close, volume = data @@ -26,12 +46,133 @@ def from_list(cls, data: List[Any]) -> "OHLCV": @classmethod def from_dict(cls, data: Dict) -> "OHLCV": - return cls.from_list( - [ - data[key] - for key in ["timestamp", "open", "high", "low", "close", "volume"] - ] + keys = [ + "start", + "timestamp", + "open", + "high", + "low", + "close", + "volume", + "confirm", + ] + + if any(key not in data for key in keys): + raise ValueError(f"Data dictionary must contain the keys: {keys}") + + confirmed = ["start", "open", "high", "low", "close", "volume"] + not_confirmed = ["timestamp", "open", "high", "low", "close", "volume"] + + ohlcv_keys = not_confirmed if not data["confirm"] else confirmed + + return cls.from_list([data[key] for key in ohlcv_keys]) + + @property + def real_body(self) -> float: + return abs(self.open - self.close) + + @property + def upper_shadow(self) -> float: + return self.high - max(self.open, self.close) + + @property + def lower_shadow(self) -> float: + return min(self.open, self.close) - self.low + + @property + def price_range(self) -> float: + return self.high - self.low + + @property + def price_movement(self) -> float: + return self.close - self.open + + @property + def body_range_ratio(self) -> float: + return self.real_body / self.price_range if self.price_range != 0 else 0 + + @property + def body_shadow_ratio(self) -> float: + total_shadow = self.upper_shadow + self.lower_shadow + return self.real_body / total_shadow if total_shadow != 0 else 0 + + @property + def shadow_range_ratio(self) -> float: + total_shadow = self.upper_shadow + self.lower_shadow + return total_shadow / self.price_range if self.price_range != 0 else 0 + + @property + def real_body_normalized(self) -> float: + return self.real_body / self.price_range if self.price_range != 0 else 0 + + @property + def type(self) -> CandleType: + if self.price_movement > 0: + return CandleType.BULLISH + + if self.price_movement < 0: + return CandleType.BEARISH + + return CandleType.NEUTRAL + + def __lt__(self, other: object): + if not isinstance(other, OHLCV): + return NotImplemented + + return self.timestamp < other.timestamp + + def __eq__(self, other): + if not isinstance(other, OHLCV): + return False + + return ( + self.timestamp == other.timestamp + and self.open == other.open + and self.high == other.high + and self.low == other.low + and self.close == other.close + and self.volume == other.volume ) def to_dict(self): - return asdict(self) + return { + "timestamp": self.timestamp, + "open": self.open, + "high": self.high, + "low": self.low, + "close": self.close, + "volume": self.volume, + "real_body": self.real_body, + "upper_shadow": self.upper_shadow, + "lower_shadow": self.lower_shadow, + "price_range": self.price_range, + "price_movement": self.price_movement, + "body_range_ratio": self.body_range_ratio, + "body_shadow_ratio": self.body_shadow_ratio, + "shadow_range_ratio": self.shadow_range_ratio, + "real_body_normalized": self.real_body_normalized, + "type": str(self.type), + } + + def __str__(self) -> str: + return ( + f"timestamp={self.timestamp}, " + f"open={self.open}, " + f"high={self.high}, " + f"low={self.low}, " + f"close={self.close}, " + f"volume={self.volume}, " + f"real_body={self.real_body:.8f}, " + f"upper_shadow={self.upper_shadow:.8f}, " + f"lower_shadow={self.lower_shadow:.8f}, " + f"price_range={self.price_range:.8f}, " + f"price_movement={self.price_movement:.8f}, " + f"body_range_ratio={self.body_range_ratio:.8f}, " + f"body_shadow_ratio={self.body_shadow_ratio:.8f}, " + f"shadow_range_ratio={self.shadow_range_ratio:.8f}, " + f"real_body_normalized={self.real_body_normalized:.8f}, " + f"type={self.type}" + ) + + def __repr__(self) -> str: + return f"OHLCV({self})" diff --git a/core/models/order.py b/core/models/order.py index 1efd740b..6de0ef99 100644 --- a/core/models/order.py +++ b/core/models/order.py @@ -21,10 +21,10 @@ class Order: status: OrderStatus price: float size: float - fee: float = field(default_factory=lambda: 0.0) type: OrderType = field(default=OrderType.MARKET) id: str = field(default_factory=lambda: str(uuid.uuid4())) timestamp: float = field(default_factory=lambda: int(datetime.now().timestamp())) + fee: float = field(default_factory=lambda: 0.0) def to_dict(self): return asdict(self) diff --git a/core/models/portfolio.py b/core/models/portfolio.py index c10f6b0c..8039b579 100644 --- a/core/models/portfolio.py +++ b/core/models/portfolio.py @@ -1,10 +1,12 @@ from dataclasses import dataclass, field, replace from datetime import datetime +from functools import cached_property import numpy as np from scipy.stats import kurtosis, norm, skew TOTAL_TRADES_THRESHOLD = 3 +GAMMA = 0.57721566 @dataclass(frozen=True) @@ -17,115 +19,133 @@ class Performance: updated_at: float = field(default_factory=lambda: datetime.now().timestamp()) @property + def equity(self): + return np.array([self._account_size]) + np.cumsum(self._pnl) + + @cached_property def total_trades(self) -> int: - return self._pnl.size + return len(self._pnl) - @property + @cached_property def total_pnl(self) -> float: return np.sum(self._pnl) - @property - def total_fee(self) -> float: - return np.sum(self._fee) - - @property + @cached_property def average_pnl(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 + + if len(self._pnl) < 2: + return 0.0 return np.mean(self._pnl) @property - def average_win(self) -> float: - win = self._pnl[self._pnl > 0] - - if len(win) < 2: - return 0 + def profit(self): + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return np.array([0.0]) - return np.mean(win) + return self._pnl[self._pnl > 0] @property - def average_loss(self) -> float: - loss = self._pnl[self._pnl < 0] + def loss(self): + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return np.array([0.0]) - if len(loss) < 2: - return 0 + return self._pnl[self._pnl < 0] - return np.mean(loss) + @cached_property + def total_profit(self): + return np.sum(self.profit) - @property - def max_consecutive_wins(self) -> int: - return self._max_streak(self._pnl, True) + @cached_property + def total_loss(self): + return np.sum(self.loss) - @property - def max_consecutive_losses(self) -> int: - return self._max_streak(self._pnl, False) + @cached_property + def total_fee(self) -> float: + return np.sum(self._fee) - @property + @cached_property def hit_ratio(self) -> float: - total_trades = self.total_trades + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 - if total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return np.sum(self._pnl > 0) / self.total_trades - return np.divide(np.sum(self._pnl > 0), total_trades) + @cached_property + def average_profit(self) -> float: + if len(self.profit) < 2: + return 0.0 - @property - def equity(self): - return [self._account_size] + self._pnl.cumsum() + return np.mean(self.profit) - @property - def drawdown(self): - equity_curve = self.equity + @cached_property + def average_loss(self) -> float: + if len(self.loss) < 2: + return 0.0 - if len(equity_curve) < 2: - return np.array([0, 0]) + return np.mean(self.loss) - peak = np.maximum.accumulate(equity_curve) + @cached_property + def max_consecutive_wins(self) -> int: + return self._max_streak(self._pnl, True) - return np.divide(peak - equity_curve, peak) + @cached_property + def max_consecutive_losses(self) -> int: + return self._max_streak(self._pnl, False) @property - def runup(self) -> float: - equity_curve = self.equity + def drawdown(self): + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return np.array([0.0]) - if len(equity_curve) < 2: - return np.array([0, 0]) + peak = np.maximum.accumulate(self.equity) + return (peak - self.equity) / peak - trough = np.minimum.accumulate(equity_curve) + @property + def runup(self): + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return np.array([0.0]) - return np.divide(equity_curve - trough, trough) + trough = np.minimum.accumulate(self.equity) + return (self.equity - trough) / trough - @property + @cached_property def max_runup(self) -> float: + if len(self.runup) == 0: + return 0.0 + return np.max(self.runup) - @property + @cached_property def max_drawdown(self) -> float: + if len(self.drawdown) == 0: + return 0.0 + return np.max(self.drawdown) - @property + @cached_property def calmar_ratio(self) -> float: - max_drawdown = self.max_drawdown + denom = abs(self.max_drawdown) - if max_drawdown == 0: - return 0 + if denom == 0: + return 0.0 - return np.divide(self.cagr, np.abs(max_drawdown)) + return self.cagr / denom - @property + @cached_property def sharpe_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 std_return = np.std(self._pnl, ddof=1) - if std_return == 0: - return 0 + return 0.0 - return np.divide(self.average_pnl, std_return) + return self.average_pnl / std_return - @property + @cached_property def smart_sharpe_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: return 0 @@ -133,180 +153,144 @@ def smart_sharpe_ratio(self) -> float: std_return = np.std(self._pnl, ddof=1) penalty = self._penalty(self._pnl) - if std_return == 0 or penalty is None or np.isnan(penalty): - return 0 + denom = std_return * penalty - return np.divide(self.average_pnl, std_return * penalty) + if denom == 0: + return 0.0 - @property - def deflated_sharpe_ratio(self) -> float: - total_trades = self.total_trades + return self.average_pnl / denom - if total_trades < TOTAL_TRADES_THRESHOLD: + @cached_property + def deflated_sharpe_ratio(self) -> float: + if self.total_trades < TOTAL_TRADES_THRESHOLD: return 0 - sharpe_ratio = self.sharpe_ratio - skewness = self.skew - kurtosis = self.kurtosis - - gamma = 0.57721566 e = np.exp(1) sharpe_ratio_star = np.sqrt(0.5 / self._periods_per_year) * ( - (1 - gamma) * norm.ppf(1 - 1 / total_trades) - + gamma * norm.ppf(1 - 1 / (total_trades * e)) + (1 - GAMMA) * norm.ppf(1 - 1 / self.total_trades) + + GAMMA * norm.ppf(1 - 1 / (self.total_trades * e)) ) - denom = 1 - skewness * sharpe_ratio + ((kurtosis - 1) / 4) * sharpe_ratio**2 + denom = ( + 1 + - self.skew * self.sharpe_ratio + + ((self.kurtosis - 1) / 4) * self.sharpe_ratio**2 + ) if denom <= 0: return 0 return norm.cdf( - np.divide( - (sharpe_ratio - sharpe_ratio_star) * np.sqrt(total_trades - 1), - np.sqrt(denom), - ) + (self.sharpe_ratio - sharpe_ratio_star) + * np.sqrt(self.total_trades - 1) + / np.sqrt(denom) ) - @property + @cached_property def sortino_ratio(self) -> float: - total_trades = self.total_trades - - if total_trades < TOTAL_TRADES_THRESHOLD: - return 0 - - downside_returns = self._pnl[self._pnl < 0] - - if len(downside_returns) < 2: - return 0 + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 - downside = np.sqrt(np.sum(downside_returns**2) / total_trades) + downside = np.sqrt(np.sum(self.loss**2) / self.total_trades) if downside == 0: - return 0 + return 0.0 - return np.divide(self.average_pnl, downside) + return self.average_pnl / downside - @property + @cached_property def smart_sortino_ratio(self) -> float: - total_trades = self.total_trades - - if total_trades < TOTAL_TRADES_THRESHOLD: - return 0 - - downside_returns = self._pnl[self._pnl < 0] - - if len(downside_returns) < 2: - return 0 + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 - downside = np.sqrt(np.sum(downside_returns**2) / total_trades) + downside = np.sqrt(np.sum(self.loss**2) / self.total_trades) penalty = self._penalty(self._pnl) - if downside == 0 or penalty is None or np.isnan(penalty): - return 0 + denom = downside * penalty - return np.divide(self.average_pnl, downside * penalty) + if denom == 0: + return 0.0 - @property + return self.average_pnl / denom + + @cached_property def payoff_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 - denom = np.abs(self.average_loss) + denom = abs(self.average_loss) if denom == 0: - return 0 + return 0.0 - return np.divide(self.average_win, denom) + return self.average_profit / denom - @property + @cached_property def cagr(self) -> float: - periods = self.total_trades - - if periods < TOTAL_TRADES_THRESHOLD: - return 0 - - equity = self.equity - - if len(equity) < 2: - return 0 + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 - final_value = equity[-1] + final_value = self.equity[-1] initial_value = self._account_size if initial_value == 0: - return 0 - + return 0.0 if final_value < initial_value: - return -1 + return -1.0 compound_factor = final_value / initial_value - time_factor = 1 / (periods / self._periods_per_year) + time_factor = 1 / (self.total_trades / self._periods_per_year) return np.power(compound_factor, time_factor) - 1 - @property - def optimal_f(self) -> float: - total_trades = self.total_trades - - if total_trades < TOTAL_TRADES_THRESHOLD: - return self._risk_per_trade - - equity = self.equity - - max_loss = np.abs(np.min(self._pnl)) - initial_value = self._account_size - growth_factor = equity[-1] / initial_value - - if growth_factor <= 0: - return self._risk_per_trade - - return np.divide(max_loss, np.abs(initial_value)) * np.sqrt(growth_factor) - - @property + @cached_property def kelly(self) -> float: - total_trades = self.total_trades - - if total_trades < TOTAL_TRADES_THRESHOLD: + if self.total_trades < TOTAL_TRADES_THRESHOLD: return self._risk_per_trade - wl_ratio = self.payoff_ratio - - if wl_ratio == 0: + if self.payoff_ratio == 0: return self._risk_per_trade - win_prob = self.hit_ratio - - return win_prob - np.divide(1 - win_prob, wl_ratio) + return self.hit_ratio - (1 - self.hit_ratio) / self.payoff_ratio - @property + @cached_property def ann_sharpe_ratio(self) -> float: return self.sharpe_ratio * np.sqrt(self._periods_per_year) - @property - def expected_return(self) -> float: - total_trades = self.total_trades + @cached_property + def time_weighted_return(self) -> float: + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 - if total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + if len(self.equity) < 2: + return 0.0 - pnl_positive = self._pnl[self._pnl > 0] + return (self.equity[-1] / self.equity[0]) - 1 - if len(pnl_positive) == 0: - return 0 + @cached_property + def geometric_holding_period_return(self) -> float: + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 - log_prod = np.sum(np.log(1.0 + pnl_positive)) + return (1 + self.time_weighted_return) ** (1 / self.total_trades) - 1 + + @cached_property + def expected_return(self) -> float: + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 + + log_prod = np.sum(np.log(1.0 + self.profit)) if log_prod <= 0: - return 0 + return 0.0 - return np.exp(log_prod / total_trades) - 1.0 + return np.exp(log_prod / self.total_trades) - 1.0 - @property + @cached_property def ann_volatility(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 daily_returns = self._pnl / self._account_size @@ -314,74 +298,60 @@ def ann_volatility(self) -> float: return volatility * np.sqrt(self._periods_per_year) - @property + @cached_property def recovery_factor(self) -> float: - max_drawdown = self.max_drawdown - - if max_drawdown == 0: - return 0 - - total_profit = np.sum(self._pnl[self._pnl > 0]) + if self.max_drawdown == 0: + return 0.0 - return np.divide(total_profit, max_drawdown) + return self.total_profit / self.max_drawdown - @property + @cached_property def profit_factor(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 - pnl_positive = self._pnl > 0 - profit, loss = self._pnl[pnl_positive], self._pnl[~pnl_positive] - - if len(profit) < 2 or len(loss) < 2: - return 0 - - gross_profit, gross_loss = np.sum(profit), np.abs(np.sum(loss)) + gross_loss = abs(self.total_loss) if gross_loss == 0: - return 0 + return 0.0 - return np.divide(gross_profit, gross_loss) + return self.total_profit / gross_loss - @property + @cached_property def risk_of_ruin(self) -> float: - total_trades = self.total_trades - - if total_trades < TOTAL_TRADES_THRESHOLD: - return 0 - - win_rate = self.hit_ratio + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 - return np.divide(1 - win_rate, 1 + win_rate) ** total_trades + return ((1 - self.hit_ratio) / (1 + self.hit_ratio)) ** self.total_trades - @property + @cached_property def skew(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 return skew(self._pnl, bias=False) - @property + @cached_property def kurtosis(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 return kurtosis(self._pnl, bias=False) - @property + @cached_property def var(self, confidence_level=0.95) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 mu = self.average_pnl sigma = np.std(self._pnl, ddof=1) return norm.ppf(1.0 - confidence_level, mu, sigma) - @property + @cached_property def cvar(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 var = self.var pnl = self._pnl[self._pnl < var] @@ -391,162 +361,143 @@ def cvar(self) -> float: return np.mean(pnl) - @property + @cached_property def ulcer_index(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 - - drawdown = self.drawdown + return 0.0 - if len(drawdown) < 2: - return 0 + if len(self.drawdown) < 2: + return 0.0 - return np.sqrt(np.mean(drawdown**2)) + return np.sqrt(np.mean(self.drawdown**2)) - @property + @cached_property def upi(self) -> float: - ulcer_index = self.ulcer_index - - if ulcer_index == 0: - return 0 + if self.ulcer_index == 0: + return 0.0 - return np.divide(self.expected_return, ulcer_index) + return self.expected_return / self.ulcer_index - @property + @cached_property def common_sense_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 return self.profit_factor * self.tail_ratio - @property + @cached_property def cpc_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 return self.profit_factor * self.hit_ratio * self.payoff_ratio - @property + @cached_property def lake_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 - - equity = self.equity - - if len(equity) < 2: - return 0 + return 0.0 - peaks = np.maximum.accumulate(equity) - drawdowns = (peaks - equity) / peaks - underwater_time = np.divide(np.sum(drawdowns < 0), self._periods_per_year) + underwater_time = np.sum(self.drawdown < 0) / self._periods_per_year return 1 - underwater_time - @property + @cached_property def burke_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 downside_deviation = np.std(np.minimum(self._pnl, 0), ddof=1) if downside_deviation == 0: - return 0 + return 0.0 - return np.divide(self.cagr, downside_deviation) + return self.cagr / downside_deviation - @property + @cached_property def rachev_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 pnl_sorted = np.sort(self._pnl)[::-1] - var_95 = np.percentile(pnl_sorted, 5) shortfall = pnl_sorted[pnl_sorted <= var_95] - if len(shortfall) < 2: - return 0 + return 0.0 expected_shortfall = np.abs(np.mean(shortfall)) - if expected_shortfall == 0: - return 0 + return 0.0 - return np.divide(np.abs(self.average_pnl), expected_shortfall) + return abs(self.average_pnl) / expected_shortfall - @property + @cached_property def sterling_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 - - gains, losses = self._pnl[self._pnl > 0], self._pnl[self._pnl < 0] + return 0.0 - if len(losses) < 2 or len(gains) < 2: - return 0 + if len(self.loss) < 2: + return 0.0 - upside_potential = np.mean(gains) - downside_risk = np.sqrt(np.mean(losses**2)) + downside_risk = np.sqrt(np.mean(self.loss**2)) if downside_risk == 0: - return 0 + return 0.0 - return np.divide(upside_potential, downside_risk) + return self.average_profit / downside_risk - @property + @cached_property def tail_ratio(self, cutoff=95) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 denom = np.percentile(self._pnl, 100 - cutoff) - if denom == 0: - return 0 + return 0.0 - return np.abs(np.divide(np.percentile(self._pnl, cutoff), denom)) + return abs(np.percentile(self._pnl, cutoff) / denom) - @property + @cached_property def omega_ratio(self) -> float: if self.total_trades < TOTAL_TRADES_THRESHOLD: - return 0 + return 0.0 - gains, losses = self._pnl[self._pnl > 0], self._pnl[self._pnl < 0] + gross_loss = abs(self.total_loss) - if len(losses) < 2 or len(gains) < 2: - return 0 + if gross_loss == 0: + return 0.0 - sum_losses = np.sum(np.abs(losses)) + return self.total_profit / gross_loss - if sum_losses == 0: - return 0 + @cached_property + def martin_ratio(self) -> float: + if self.ulcer_index == 0: + return 0.0 - return np.divide(np.sum(gains), sum_losses) + return self.average_pnl / self.ulcer_index - @property + @cached_property def kappa_three_ratio(self) -> float: - total_trades = self.total_trades - - if total_trades < TOTAL_TRADES_THRESHOLD: - return 0 - - gains, losses = self._pnl[self._pnl > 0], self._pnl[self._pnl < 0] + if self.total_trades < TOTAL_TRADES_THRESHOLD: + return 0.0 - if len(losses) < 2 or len(gains) < 2: - return 0 + threshold = self.average_profit - self.average_loss - avg_gain, avg_loss = np.mean(gains), np.mean(losses) + if threshold == 0: + return 0.0 - threshold = avg_gain - avg_loss + up_proportion = np.sum(self.profit > threshold) / self.total_trades + down_proportion = np.sum(self.loss < threshold) / self.total_trades - up_proportion = np.sum(gains > threshold) / total_trades - down_proportion = np.sum(losses < threshold) / total_trades + if len(self._pnl) < 2: + return 0.0 denom = np.sqrt(np.mean(self._pnl**2)) if denom == 0: - return 0 + return 0.0 - return np.divide((up_proportion**3 - down_proportion), denom) + return (up_proportion**3 - down_proportion) / denom def next(self, pnl: float, fee: float) -> "Performance": _pnl, _fee = np.append(self._pnl, pnl), np.append(self._fee, fee) @@ -583,20 +534,6 @@ def _penalty(pnl) -> float: return np.sqrt(1 + 2 * np.sum(corr)) - def __repr__(self): - return ( - f"Performance(total_trades={self.total_trades}, hit_ratio={self.hit_ratio}, profit_factor={self.profit_factor}, " - + f"max_runup={self.max_runup}, max_drawdown={self.max_drawdown}, sortino_ratio={self.sortino_ratio}, smart_sortino_ratio={self.smart_sortino_ratio}, calmar_ratio={self.calmar_ratio}, " - + f"risk_of_ruin={self.risk_of_ruin}, recovery_factor={self.recovery_factor}, optimal_f={self.optimal_f}, " - + f"total_pnl={self.total_pnl}, average_pnl={self.average_pnl}, total_fee={self.total_fee}, sharpe_ratio={self.sharpe_ratio}, smart_sharpe_ratio={self.smart_sharpe_ratio}, deflated_sharpe_ratio={self.deflated_sharpe_ratio}, " - + f"max_consecutive_wins={self.max_consecutive_wins}, max_consecutive_losses={self.max_consecutive_losses}, average_win={self.average_win}, average_loss={self.average_loss}, " - + f"cagr={self.cagr}, expected_return={self.expected_return}, annualized_volatility={self.ann_volatility}, annualized_sharpe_ratio={self.ann_sharpe_ratio}, payoff_ratio={self.payoff_ratio}, " - + f"var={self.var}, cvar={self.cvar}, ulcer_index={self.ulcer_index}, upi={self.upi}, kelly={self.kelly}, " - + f"lake_ratio={self.lake_ratio}, burke_ratio={self.burke_ratio}, rachev_ratio={self.rachev_ratio}, kappa_three_ratio={self.kappa_three_ratio}, " - + f"sterling_ratio={self.sterling_ratio}, tail_ratio={self.tail_ratio}, omega_ratio={self.omega_ratio}, cpc_ratio={self.cpc_ratio}, common_sense_ratio={self.common_sense_ratio}, " - + f"skew={self.skew}, kurtosis={self.kurtosis})" - ) - def to_dict(self): return { "account_size": self._account_size, @@ -604,11 +541,12 @@ def to_dict(self): "total_pnl": self.total_pnl, "total_fee": self.total_fee, "average_pnl": self.average_pnl, - "average_win": self.average_win, + "average_profit": self.average_profit, "average_loss": self.average_loss, - "max_consecutive_wins": self.max_consecutive_wins, - "max_consecutive_losses": self.max_consecutive_losses, - "hit_ratio": self.hit_ratio, + "profit": self.profit, + "loss": self.loss, + "total_profit": self.total_profit, + "total_loss": self.total_loss, "equity": self.equity, "runup": self.runup, "max_runup": self.max_runup, @@ -617,16 +555,15 @@ def to_dict(self): "sharpe_ratio": self.sharpe_ratio, "smart_sharpe_ratio": self.smart_sharpe_ratio, "deflated_sharpe_ratio": self.deflated_sharpe_ratio, - "calmar_ratio": self.calmar_ratio, - "cpc_ratio": self.cpc_ratio, - "common_sense_ratio": self.common_sense_ratio, "sortino_ratio": self.sortino_ratio, "smart_sortino_ratio": self.smart_sortino_ratio, + "calmar_ratio": self.calmar_ratio, "payoff_ratio": self.payoff_ratio, "cagr": self.cagr, - "optimal_f": self.optimal_f, "kelly": self.kelly, "expected_return": self.expected_return, + "time_weighted_return": self.time_weighted_return, + "geometric_holding_period_return": self.geometric_holding_period_return, "annualized_volatility": self.ann_volatility, "annualized_sharpe_ratio": self.ann_sharpe_ratio, "recovery_factor": self.recovery_factor, @@ -644,5 +581,31 @@ def to_dict(self): "sterling_ratio": self.sterling_ratio, "tail_ratio": self.tail_ratio, "omega_ratio": self.omega_ratio, + "martin_ratio": self.martin_ratio, "kappa_three_ratio": self.kappa_three_ratio, + "max_consecutive_wins": self.max_consecutive_wins, + "max_consecutive_losses": self.max_consecutive_losses, + "hit_ratio": self.hit_ratio, + "cpc_ratio": self.cpc_ratio, + "common_sense_ratio": self.common_sense_ratio, } + + def __str__(self): + return ( + f"total_trades={self.total_trades}, hit_ratio={self.hit_ratio}, profit_factor={self.profit_factor}, profit={self.profit}, loss={self.loss}, " + f"total_profit={self.total_profit}, total_loss={self.total_loss}, total_pnl={self.total_pnl}, average_pnl={self.average_pnl}, average_profit={self.average_profit}, average_loss={self.average_loss}, " + f"total_fee={self.total_fee}, max_consecutive_wins={self.max_consecutive_wins}, max_consecutive_losses={self.max_consecutive_losses}, " + f"equity={self.equity}, max_runup={self.max_runup}, max_drawdown={self.max_drawdown}, " + f"sharpe_ratio={self.sharpe_ratio}, smart_sharpe_ratio={self.smart_sharpe_ratio}, deflated_sharpe_ratio={self.deflated_sharpe_ratio}, sortino_ratio={self.sortino_ratio}, smart_sortino_ratio={self.smart_sortino_ratio}, calmar_ratio={self.calmar_ratio}, " + f"expected_return={self.expected_return}, cagr={self.cagr}, time_weighted_return={self.time_weighted_return}, geometric_holding_period_return={self.geometric_holding_period_return}, " + f"annualized_volatility={self.ann_volatility}, annualized_sharpe_ratio={self.ann_sharpe_ratio}, " + f"recovery_factor={self.recovery_factor}, risk_of_ruin={self.risk_of_ruin}, " + f"skew={self.skew}, kurtosis={self.kurtosis}, var={self.var}, cvar={self.cvar}, " + f"ulcer_index={self.ulcer_index}, upi={self.upi}, " + f"kelly={self.kelly}, lake_ratio={self.lake_ratio}, burke_ratio={self.burke_ratio}, rachev_ratio={self.rachev_ratio}, kappa_three_ratio={self.kappa_three_ratio}, " + f"payoff_ratio={self.payoff_ratio}, sterling_ratio={self.sterling_ratio}, tail_ratio={self.tail_ratio}, omega_ratio={self.omega_ratio}, " + f"cpc_ratio={self.cpc_ratio}, common_sense_ratio={self.common_sense_ratio}" + ) + + def __repr__(self): + return f"Performance({self})" diff --git a/core/models/position.py b/core/models/position.py index 8ddf4d6b..6932db4c 100644 --- a/core/models/position.py +++ b/core/models/position.py @@ -1,90 +1,141 @@ +import logging +import uuid from dataclasses import dataclass, field, replace from datetime import datetime -from typing import List, Tuple +from typing import List, Optional, Tuple -from core.interfaces.abstract_position_risk_strategy import AbstractPositionRiskStrategy -from core.interfaces.abstract_position_take_profit_strategy import ( - AbstractPositionTakeProfitStrategy, -) +import numpy as np from .ohlcv import OHLCV from .order import Order, OrderStatus -from .side import PositionSide +from .position_risk import PositionRisk +from .profit_target import ProfitTarget +from .risk_type import PositionRiskType, SessionRiskType, SignalRiskType +from .side import PositionSide, SignalSide from .signal import Signal +from .signal_risk import SignalRisk +from .ta import TechAnalysis + +logger = logging.getLogger(__name__) + +DEFAULT_TARGET_IDX = 2 +LATENCY_GAP_THRESHOLD = 1.8 @dataclass(frozen=True) class Position: + initial_size: float signal: Signal - side: PositionSide - risk_strategy: AbstractPositionRiskStrategy - take_profit_strategy: AbstractPositionTakeProfitStrategy + signal_risk: SignalRisk + position_risk: PositionRisk + profit_target: ProfitTarget orders: Tuple[Order] = () - stop_loss_price: float = field(default_factory=lambda: 0.0000001) - take_profit_price: float = field(default_factory=lambda: 0.0000001) - open_timestamp: float = field(default_factory=lambda: 0) - closed_timestamp: float = field(default_factory=lambda: 0) + expiration: int = field(default_factory=lambda: 900000) # 15min last_modified: float = field(default_factory=lambda: datetime.now().timestamp()) + id: str = field(default_factory=lambda: str(uuid.uuid4())) + _tp: Optional[float] = None + _sl: Optional[float] = None + + @property + def side(self) -> PositionSide: + if self.signal.side == SignalSide.BUY: + return PositionSide.LONG + + if self.signal.side == SignalSide.SELL: + return PositionSide.SHORT + + @property + def take_profit(self) -> float: + if self._tp: + return self._tp + + return self.profit_target.last + + @property + def stop_loss(self) -> float: + if self._sl: + return self._sl + + return self.signal.stop_loss + + @property + def open_timestamp(self) -> int: + return self.signal.ohlcv.timestamp + + @property + def close_timestamp(self) -> int: + return self.position_risk.curr_bar.timestamp + + @property + def signal_bar(self) -> OHLCV: + return self.signal.ohlcv + + @property + def risk_bar(self) -> OHLCV: + return self.position_risk.curr_bar @property def trade_time(self) -> int: - return abs(int(self.closed_timestamp - self.open_timestamp)) + return abs(self.close_timestamp - self.open_timestamp) @property def closed(self) -> bool: - closed_orders = [ - order.size for order in self.orders if order.status == OrderStatus.CLOSED - ] - closed_size = sum(closed_orders) - - failed_orders = [ - order for order in self.orders if order.status == OrderStatus.FAILED - ] + if not self.orders: + return False - pending_orders = [ - order for order in self.orders if order.status == OrderStatus.PENDING - ] + if self.rejected_orders: + return True - if not closed_orders: + if not self.closed_orders: return False - return closed_size >= self.filled_size or len(failed_orders) == len( - pending_orders + order_diff = self._average_size(self.open_orders) - self._average_size( + self.closed_orders ) + return order_diff <= 0 + + @property + def has_break_even(self) -> bool: + if self.side == PositionSide.LONG: + return self.stop_loss > self.entry_price + if self.side == PositionSide.SHORT: + return self.stop_loss < self.entry_price + + return False + + @property + def has_risk(self) -> bool: + return self.position_risk.type != PositionRiskType.NONE + @property def adj_count(self) -> int: - executed_orders = [ - order for order in self.orders if order.status == OrderStatus.EXECUTED - ] return max( 0, - len(executed_orders) - 1, + len(self.open_orders) - 1, ) @property - def pending_size(self) -> int: - pending_orders = [ - order.size for order in self.orders if order.status == OrderStatus.PENDING - ] + def size(self) -> float: + if self.closed_orders: + return self._average_size(self.closed_orders) - return sum(pending_orders) + if self.open_orders: + return self._average_size(self.open_orders) - @property - def pending_price(self) -> int: - pending_orders = [ - order.price for order in self.orders if order.status == OrderStatus.PENDING - ] + return 0.0 - return sum(pending_orders) / len(pending_orders) if pending_orders else 0.0 + @property + def open_orders(self) -> List[Order]: + return [order for order in self.orders if order.status == OrderStatus.EXECUTED] @property - def filled_size(self) -> int: - executed_orders = [ - order.size for order in self.orders if order.status == OrderStatus.EXECUTED - ] + def closed_orders(self) -> List[Order]: + return [order for order in self.orders if order.status == OrderStatus.CLOSED] - return sum(executed_orders) + @property + def rejected_orders(self) -> List[Order]: + return [order for order in self.orders if order.status == OrderStatus.FAILED] @property def pnl(self) -> float: @@ -93,96 +144,328 @@ def pnl(self) -> float: if not self.closed: return pnl - factor = -1 if self.side == PositionSide.SHORT else 1 + factor = -1.0 if self.side == PositionSide.SHORT else 1 + pnl = factor * (self.exit_price - self.entry_price) * self.size - return factor * (self.exit_price - self.entry_price) * self.filled_size + return pnl @property - def fee(self) -> float: - executed_orders = [ - order.fee for order in self.orders if order.status == OrderStatus.EXECUTED - ] - open_fee = sum(executed_orders) + def curr_pnl(self) -> float: + factor = -1.0 if self.side == PositionSide.SHORT else 1 + pnl = factor * (self.curr_price - self.entry_price) * self.size - closed_orders = [ - order.fee for order in self.orders if order.status == OrderStatus.CLOSED - ] - closed_fee = sum(closed_orders) + return pnl - return open_fee + closed_fee + @property + def fee(self) -> float: + return sum([order.fee for order in self.open_orders]) + sum( + [order.fee for order in self.closed_orders] + ) @property def entry_price(self) -> float: - executed_orders = [ - order.price for order in self.orders if order.status == OrderStatus.EXECUTED - ] - return sum(executed_orders) / len(executed_orders) if executed_orders else 0.0 + return self._average_price(self.open_orders) @property def exit_price(self) -> float: - closed_orders = [ - order.price for order in self.orders if order.status == OrderStatus.CLOSED - ] + return self._average_price(self.closed_orders) - return sum(closed_orders) / len(closed_orders) if closed_orders else 0.0 + @property + def curr_price(self) -> float: + last_bar = self.risk_bar - def add_order(self, order: Order) -> "Position": + return (2 * last_bar.close + last_bar.high + last_bar.low) / 4.0 + + @property + def curr_target(self) -> float: + targets = self.profit_target.targets + curr_price = self.curr_price + idx = 0 + + if self.side == PositionSide.LONG: + for i, target in enumerate(targets): + if curr_price > target: + idx = i + break + + elif self.side == PositionSide.SHORT: + for i, target in enumerate(targets): + if curr_price < target: + idx = i + break + + return targets[idx] + + @property + def is_valid(self) -> bool: if self.closed: + return self.size != 0 and self.open_timestamp < self.close_timestamp + + if self.side == PositionSide.LONG: + return self.take_profit > self.stop_loss + + if self.side == PositionSide.SHORT: + return self.take_profit < self.stop_loss + + return False + + def entry_order(self) -> Order: + price = round(self.signal.entry, self.signal.symbol.price_precision) + size = round( + max(self.initial_size, self.signal.symbol.min_position_size), + self.signal.symbol.position_precision, + ) + + return Order( + status=OrderStatus.PENDING, + price=price, + size=size, + ) + + def exit_order(self) -> Order: + size = self._average_size(self.open_orders) - self._average_size( + self.closed_orders + ) + price = self.position_risk.exit_price( + self.side, self.take_profit, self.stop_loss + ) + + return Order( + status=OrderStatus.PENDING, + price=price, + size=size, + ) + + def fill_order(self, order: Order) -> "Position": + if self.closed: + return self + + if order.status == OrderStatus.PENDING: return self - last_modified = datetime.now().timestamp() + execution_time = datetime.now().timestamp() + orders = (*self.orders, order) - if order.status == OrderStatus.PENDING or order.status == OrderStatus.EXECUTED: - take_profit_price = self.take_profit_strategy.next( - self.side, order.price, self.stop_loss_price + if order.status == OrderStatus.EXECUTED: + return replace( + self, + orders=orders, + profit_target=replace(self.profit_target, entry=order.price), + last_modified=execution_time, ) + if order.status == OrderStatus.CLOSED: return replace( self, orders=orders, - last_modified=last_modified, - take_profit_price=take_profit_price, + last_modified=execution_time, ) - if order.status == OrderStatus.CLOSED or order.status == OrderStatus.FAILED: + if order.status == OrderStatus.FAILED: return replace( self, orders=orders, - closed_timestamp=last_modified, - last_modified=last_modified, + last_modified=execution_time, ) - def next(self, ohlcvs: List[Tuple[OHLCV]]) -> "Position": - next_stop_loss_price, next_take_profit_price = self.risk_strategy.next( - self.side, - self.entry_price, - self.take_profit_price, - self.stop_loss_price, - ohlcvs, + def next( + self, ohlcv: OHLCV, ta: TechAnalysis, session_risk: SessionRiskType + ) -> "Position": + if self.closed or ohlcv.timestamp <= self.risk_bar.timestamp: + logger.warning("Position update ignored due to stale data.") + return self + + gap = ohlcv.timestamp - self.risk_bar.timestamp + + if gap > LATENCY_GAP_THRESHOLD * self.signal.timeframe.to_milliseconds(): + logger.warning(f"Position update ignored due to large latency gap: {gap}") + return self + + next_risk = self.position_risk.next(ohlcv) + next_position = replace(self, position_risk=next_risk) + + curr_price = next_position.curr_price + entry_price = next_position.entry_price + curr_pnl = next_position.curr_pnl + targets = next_position.profit_target.targets[1:] + raw_forecast = next_position.position_risk.forecast(steps=3) + + forecast = None + long = next_position.side == PositionSide.LONG + + if raw_forecast: + forecast = raw_forecast[-1] + + stp = ( + next_position.signal_risk.tp + if next_position.signal_risk.type != SignalRiskType.NONE + else targets[DEFAULT_TARGET_IDX + 1] ) + stp = np.clip(stp, targets[0], targets[-1]) + ftp = forecast if forecast else targets[DEFAULT_TARGET_IDX + 1] + ftp = np.clip(ftp, targets[0], targets[-1]) + sstp = ta.trend.resistance[-1] if long else ta.trend.support[-1] + sstp = np.clip(sstp, targets[0], targets[-1]) + + w_stp, w_sstp, w_ftp = 0.6, 0.1, 0.3 + + ttp = (w_stp * stp + w_ftp * ftp + w_sstp * sstp) / (w_stp + w_sstp + w_ftp) + + def target_filter(target, tp): + sl = next_position.stop_loss + curr_price = next_position.curr_price + + return ( + target > tp and target > sl and target > curr_price + if long + else target < tp and target < sl and target < curr_price + ) + + idx_rr = 0 + risk = abs(entry_price - next_position.stop_loss) + rr_factor = 1.5 + rr = rr_factor * risk + + for i, target in enumerate(targets): + reward = abs(target - entry_price) + + if reward > rr if long else reward < rr: + idx_rr = i + break + + idx_tg = 0 + for i, target in enumerate(targets): + if target_filter(target, ttp): + idx_tg = i + break + + tidx = max(DEFAULT_TARGET_IDX, idx_tg) + idx = idx_rr if tidx > len(targets) - 1 else tidx + trail_target = targets[max(0, idx - 1)] + exit_target = targets[max(0, idx - 2)] + next_tp = targets[max(0, idx)] + + pnl_perc = (curr_pnl / curr_price) * 100 + trl_dist = abs(curr_price - trail_target) + exit_dist = abs(curr_price - exit_target) + dist = abs(curr_price - entry_price) + + trl_ratio = trl_dist / entry_price + exit_ratio = exit_dist / entry_price + dist_ratio = dist / entry_price + is_exit = session_risk == SessionRiskType.EXIT + + trail_threshold = 0.0008 + + if dist > trl_dist and trl_ratio > trail_threshold: + logger.info("Activating trailing stop mechanism") + next_position = next_position.trail(ta) + + if is_exit: + exit_ratio = exit_dist / entry_price + dist_ratio = dist / entry_price + + if exit_ratio > 0.005: + next_position = next_position.trail(ta) + + logger.info( + f"TRAIL NEXT SL: {next_position.stop_loss:.6f}, " + f"CURR PRICE: {next_position.risk_bar.close:.6f}" + ) + else: + logger.info( + f"Exit condition not met: " + f"CURR_DIST: {dist:.6f} ({dist_ratio:.2%}), " + f"EXIT_DIST: {exit_dist:.6f} ({exit_ratio:.2%})" + ) + + next_sl = next_position.stop_loss + next_risk = next_position.position_risk + + if dist_ratio > 0.007: + half = 0.382 * dist + sl = curr_price + half if long else curr_price - half + next_sl = max(sl, next_sl) if long else min(sl, next_sl) + + next_risk = next_risk.assess( + next_position.side, + next_tp, + next_sl, + next_position.open_timestamp, + next_position.expiration, + ) + + if next_risk.type == PositionRiskType.TP: + next_risk = next_risk.reset() + index = 0 + + for i, target in enumerate(targets): + if target_filter(target, next_tp): + index = i + break - return replace( + idx = min(max(DEFAULT_TARGET_IDX, index + 1), len(targets) - 1) + next_tp = targets[idx] + + next_position = replace( self, - stop_loss_price=next_stop_loss_price, - take_profit_price=next_take_profit_price, + position_risk=next_risk, + _tp=next_tp, + _sl=next_sl, + last_modified=datetime.now().timestamp(), + ) + + logger.info( + f"SYMBOL: {next_position.signal.symbol.name}, SIDE: {next_position.side}, SIGNAL_RISK: {next_position.signal_risk.type}, TS: {ohlcv.timestamp}, GAP: {gap}ms, ENTRY: {next_position.entry_price}, CURR: {next_position.curr_price}, HIGH: {next_position.risk_bar.high}, LOW: {next_position.risk_bar.low}, CLOSE: {next_position.risk_bar.close}, PT: {next_position.curr_target}, SL: {next_position.stop_loss}, TP: {next_position.take_profit}, LLM_TP: {next_position.signal_risk.tp}, PnL%: {pnl_perc}, BREAK EVEN: {next_position.has_break_even}, RISK: {next_position.has_risk}" ) + return next_position + + def trail(self, ta: TechAnalysis) -> "Position": + prev_sl = self.stop_loss + next_sl = self.position_risk.sl_ats(self.side, ta, prev_sl) + + return replace(self, _sl=next_sl, last_modified=datetime.now().timestamp()) + + def theo_taker_fee(self, size: float, price: float) -> float: + return size * price * self.signal.symbol.taker_fee + + def theo_maker_fee(self, size: float, price: float) -> float: + return size * price * self.signal.symbol.maker_fee + + @staticmethod + def _average_size(orders: List[Order]) -> float: + total_size = sum(order.size for order in orders) + return total_size / len(orders) if orders else 0.0 + + @staticmethod + def _average_price(orders: List[Order]) -> float: + total_price = sum(order.price for order in orders) + return total_price / len(orders) if orders else 0.0 + def to_dict(self): return { "signal": self.signal.to_dict(), + "signal_risk": self.signal_risk.to_dict(), + "position_risk": self.position_risk.to_dict(), + "curr_target": self.curr_target, "side": str(self.side), - "pending_size": self.pending_size, - "filled_size": self.filled_size, + "size": self.size, "entry_price": self.entry_price, "exit_price": self.exit_price, "closed": self.closed, - "stop_loss_price": self.stop_loss_price, - "take_profit_price": self.take_profit_price, + "valid": self.is_valid, "pnl": self.pnl, - "open_timestamp": self.open_timestamp, + "fee": self.fee, + "take_profit": self.take_profit, + "stop_loss": self.stop_loss, "trade_time": self.trade_time, + "break_even": self.has_break_even, } def __str__(self): - return f"Position(signal={self.signal}, side={self.side}, pending_size={self.pending_size}, filled_size={self.filled_size}, entry_price={self.entry_price}, exit_price={self.exit_price}, take_profit_price={self.take_profit_price}, stop_loss_price={self.stop_loss_price}, trade_time={self.trade_time}, closed={self.closed})" + return f"signal={self.signal}, signal_risk={self.signal_risk.type}, position_risk={self.position_risk.type}, open_ohlcv={self.signal_bar}, close_ohlcv={self.risk_bar}, side={self.side}, size={self.size}, entry_price={self.entry_price}, exit_price={self.exit_price}, tp={self.take_profit}, sl={self.stop_loss}, pnl={self.pnl}, trade_time={self.trade_time}, closed={self.closed}, valid={self.is_valid}, break_even={self.has_break_even}" + + def __repr__(self): + return f"Position({self})" diff --git a/core/models/position_risk.py b/core/models/position_risk.py new file mode 100644 index 00000000..a0139e3c --- /dev/null +++ b/core/models/position_risk.py @@ -0,0 +1,445 @@ +from dataclasses import dataclass, field, replace +from typing import List, Tuple + +import numpy as np +from scipy.interpolate import UnivariateSpline +from scipy.signal import savgol_filter +from sklearn.cluster import KMeans +from sklearn.linear_model import SGDRegressor +from sklearn.metrics import ( + calinski_harabasz_score, + davies_bouldin_score, + silhouette_score, +) +from sklearn.preprocessing import MinMaxScaler, StandardScaler + +from .ohlcv import OHLCV +from .risk_type import PositionRiskType +from .side import PositionSide +from .ta import TechAnalysis + +TIME_THRESHOLD = 15000 +LOOKBACK = 6 + + +def optimize_params( + data: np.ndarray, n_clusters_range: Tuple[int, int] = (2, 10) +) -> int: + if data.ndim == 1: + data = data.reshape(-1, 1) + + scaler = MinMaxScaler(feature_range=(0, 1)) + X = scaler.fit_transform(data) + + best_score = float("-inf") + best_centroids = [] + + for n_clusters in range(*n_clusters_range): + kmeans = KMeans(n_clusters=n_clusters, n_init="auto", random_state=None) + kmeans.fit(X) + + if len(np.unique(kmeans.labels_)) < n_clusters: + continue + + sscore = silhouette_score(X, kmeans.labels_) + cscore = calinski_harabasz_score(X, kmeans.labels_) + db_score = davies_bouldin_score(X, kmeans.labels_) + + score = (sscore + cscore - db_score) / 3 + + if score > best_score: + best_score = score + best_centroids = scaler.inverse_transform(kmeans.cluster_centers_).flatten() + + return int(round(np.mean(best_centroids))) if len(best_centroids) > 1 else 2 + + +def optimize_window_polyorder(data: np.ndarray) -> Tuple[int, int]: + window_length = optimize_params(data) + + if window_length % 2 == 0: + window_length += 1 + + window_length = min(window_length, len(data)) + + polyorder_range = (2, window_length - 1 if window_length > 2 else 2) + polyorder = optimize_params(data, n_clusters_range=polyorder_range) + + polyorder = min(polyorder, window_length - 1) + + return window_length, polyorder + + +def smooth_savgol(*arrays: np.ndarray) -> List[np.ndarray]: + all_data = np.concatenate(arrays) + window_length, polyorder = optimize_window_polyorder(all_data) + return [ + savgol_filter( + array, + min(window_length, len(array)), + min(polyorder, min(window_length, len(array)) - 1), + ) + for array in arrays + ] + + +def smooth_spline(*arrays: np.ndarray, s: float = 1.0, k: int = 3) -> List[np.ndarray]: + return [ + UnivariateSpline(np.arange(len(array)), array, s=s, k=min(k, len(array) - 1))( + np.arange(len(array)) + ) + if len(array) > k + else array + for array in arrays + ] + + +class TaMixin: + @staticmethod + def _ats(closes: List[float], atr: List[float]) -> List[float]: + period = min(len(closes), len(atr)) + stop_prices = np.zeros(period) + + stop_prices[0] = closes[0] - atr[0] + + for i in range(1, period): + stop = atr[i] + + long_stop = closes[i] - stop + short_stop = closes[i] + stop + + prev_stop = stop_prices[i - 1] + prev_close = closes[i - 1] + + if closes[i] > prev_stop and prev_close > prev_stop: + stop_prices[i] = max(prev_stop, long_stop) + elif closes[i] < prev_stop and prev_close < prev_stop: + stop_prices[i] = min(prev_stop, short_stop) + elif closes[i] > prev_stop: + stop_prices[i] = long_stop + else: + stop_prices[i] = short_stop + + return stop_prices + + +@dataclass(frozen=True) +class PositionRisk(TaMixin): + model: SGDRegressor + scaler: StandardScaler + ohlcv: List[OHLCV] = field(default_factory=list) + type: PositionRiskType = PositionRiskType.NONE + trail_factor: float = field(default_factory=lambda: np.random.uniform(1.8, 2.2)) + + @property + def curr_bar(self): + return self.ohlcv[-1] + + def update_model(self): + if len(self.ohlcv) < 3: + return + + last_ohlcv = self.ohlcv[-3:] + + close = [ohlcv.close for ohlcv in last_ohlcv] + high = [ohlcv.high for ohlcv in last_ohlcv] + low = [ohlcv.low for ohlcv in last_ohlcv] + + hlcc4 = [(high[i] + low[i] + 2 * close[i]) / 4.0 for i in range(3)] + hlcc4_lagged_1 = [hlcc4[0]] + hlcc4[:-1] + hlcc4_lagged_2 = [hlcc4[0], hlcc4[1]] + hlcc4[:-2] + + true_range = [ + max( + high[i] - low[i], + abs(high[i] - close[i - 1]), + abs(low[i] - close[i - 1]), + ) + for i in range(1, 3) + ] + true_range.insert(0, true_range[0]) + + true_range_lagged_1 = [true_range[0]] + true_range[:-1] + true_range_lagged_2 = [true_range[0], true_range[1]] + true_range[:-2] + + hlcc4_diff = hlcc4[-1] - hlcc4_lagged_1[-1] + true_range_diff = true_range[-1] - true_range_lagged_1[-1] + + features = np.array( + [ + [ + hlcc4[-1], + hlcc4_lagged_1[-1], + hlcc4_lagged_2[-1], + true_range[-1], + true_range_lagged_1[-1], + true_range_lagged_2[-1], + hlcc4_diff, + true_range_diff, + ] + ] + ) + target = np.array([close[-1]]) + + features_scaled = self.scaler.transform(features) + + self.model.partial_fit(features_scaled, target) + + def forecast(self, steps: int = 3): + if len(self.ohlcv) < 1: + return [] + + self.update_model() + + last_ohlcv = self.ohlcv[-1] + + last_hlcc4 = (last_ohlcv.high + last_ohlcv.low + 2 * last_ohlcv.close) / 4.0 + last_true_range = max( + last_ohlcv.high - last_ohlcv.low, + abs(last_ohlcv.high - self.ohlcv[-2].close), + abs(last_ohlcv.low - self.ohlcv[-2].close), + ) + + last_hlcc4_lagged_1 = last_hlcc4 + last_hlcc4_lagged_2 = last_hlcc4 + last_true_range_lagged_1 = last_true_range + last_true_range_lagged_2 = last_true_range + + hlcc4_diff = last_hlcc4 - last_hlcc4_lagged_1 + true_range_diff = last_true_range - last_true_range_lagged_1 + + predictions = [] + + for _ in range(steps): + X = np.array( + [ + [ + last_hlcc4, + last_hlcc4_lagged_1, + last_hlcc4_lagged_2, + last_true_range, + last_true_range_lagged_1, + last_true_range_lagged_2, + hlcc4_diff, + true_range_diff, + ] + ] + ) + + X_scaled = self.scaler.transform(X) + + forecast = self.model.predict(X_scaled)[0] + predictions.append(forecast) + + last_hlcc4_lagged_2 = last_hlcc4_lagged_1 + last_hlcc4_lagged_1 = last_hlcc4 + last_hlcc4 = forecast + + last_true_range_lagged_2 = last_true_range_lagged_1 + last_true_range_lagged_1 = last_true_range + last_true_range = max( + forecast - last_ohlcv.close, + abs(forecast - self.ohlcv[-2].close), + abs(last_ohlcv.low - self.ohlcv[-2].close), + ) + + return predictions + + def next(self, bar: OHLCV): + ohlcv = self.ohlcv + [bar] + ohlcv.sort(key=lambda x: x.timestamp) + return replace(self, ohlcv=ohlcv) + + def reset(self): + return replace(self, type=PositionRiskType.NONE) + + def assess( + self, + side: PositionSide, + tp: float, + sl: float, + open_timestamp: float, + expiration: float, + ) -> "PositionRisk": + high, low = self.curr_bar.high, self.curr_bar.low + expiration = self.curr_bar.timestamp - open_timestamp - expiration + + if expiration >= 0: + return replace(self, type=PositionRiskType.TIME) + + if side == PositionSide.LONG: + if low < sl: + return replace(self, type=PositionRiskType.SL) + if high > tp: + return replace(self, type=PositionRiskType.TP) + + if side == PositionSide.SHORT: + if high > sl: + return replace(self, type=PositionRiskType.SL) + if low < tp: + return replace(self, type=PositionRiskType.TP) + + return replace(self, type=PositionRiskType.NONE) + + def exit_price(self, side: PositionSide, tp: float, sl: float) -> "float": + high, low, close = self.curr_bar.high, self.curr_bar.low, self.curr_bar.close + + if self.type == PositionRiskType.TP: + return min(tp, high) if side == PositionSide.LONG else max(tp, low) + + elif self.type == PositionRiskType.SL: + return ( + max(min(sl, high), low) + if side == PositionSide.LONG + else min(max(sl, low), high) + ) + + return close + + def sl_low(self, side: PositionSide, ta: TechAnalysis, sl: float) -> "float": + timestamps = np.array([candle.timestamp for candle in self.ohlcv]) + ts_diff = np.diff(timestamps) + + if ts_diff.sum() < TIME_THRESHOLD: + return sl + + max_lookback = max(len(timestamps), LOOKBACK) + + trend = ta.trend + + ll = np.array(trend.ll)[-max_lookback:] + hh = np.array(trend.hh)[-max_lookback:] + volatility = np.array(ta.volatility.yz)[-max_lookback:] + res = np.array(trend.resistance)[-max_lookback:] + sup = np.array(trend.support)[-max_lookback:] + + min_length = min(len(ll), len(hh), len(volatility), len(timestamps)) + + if min_length < 1: + return sl + + ll_smooth, hh_smooth, volatility_smooth = smooth_savgol(ll, hh, volatility) + + ll_smooth = ll_smooth[-min_length:] + hh_smooth = hh_smooth[-min_length:] + volatility_smooth = self.trail_factor * volatility_smooth[-min_length:] + + ll_atr = ll_smooth - volatility_smooth + hh_atr = hh_smooth + volatility_smooth + + if side == PositionSide.LONG: + return max(sl, np.max(sup), np.max(ll_atr)) + + if side == PositionSide.SHORT: + return min(sl, np.min(res), np.min(hh_atr)) + + return sl + + def tp_low(self, side: PositionSide, ta: TechAnalysis, tp: float) -> "float": + timestamps = np.array([candle.timestamp for candle in self.ohlcv]) + ts_diff = np.diff(timestamps) + + if ts_diff.sum() < TIME_THRESHOLD: + return tp + + max_lookback = max(len(timestamps), LOOKBACK) + + trend = ta.trend + + ll = np.array(trend.ll)[-max_lookback:] + hh = np.array(trend.hh)[-max_lookback:] + volatility = np.array(ta.volatility.yz)[-max_lookback:] + res = np.array(trend.resistance)[-max_lookback:] + sup = np.array(trend.support)[-max_lookback:] + + min_length = min(len(ll), len(hh), len(volatility), len(timestamps)) + + if min_length < 1: + return tp + + ll_smooth, hh_smooth, volatility_smooth = smooth_savgol(ll, hh, volatility) + + ll_smooth = ll_smooth[-min_length:] + hh_smooth = hh_smooth[-min_length:] + volatility_smooth = self.trail_factor * volatility_smooth[-min_length:] + + ll_atr = ll_smooth - volatility_smooth + hh_atr = hh_smooth + volatility_smooth + + if side == PositionSide.LONG: + return min(np.min(res), np.min(hh_atr)) + elif side == PositionSide.SHORT: + return max(np.max(sup), np.max(ll_atr)) + + return tp + + def sl_ats(self, side: PositionSide, ta: TechAnalysis, sl: float) -> "float": + timestamps = np.array([candle.timestamp for candle in self.ohlcv]) + ts_diff = np.diff(timestamps) + + if ts_diff.sum() < TIME_THRESHOLD: + return sl + + max_lookback = max(len(timestamps), LOOKBACK) + + close = np.array([candle.close for candle in self.ohlcv]) + low = np.array([candle.low for candle in self.ohlcv]) + high = np.array([candle.high for candle in self.ohlcv]) + volatility = np.array(ta.volatility.gkyz)[-max_lookback:] + + min_length = min(len(close), len(volatility), len(high), len(low)) + + if min_length < 3: + return sl + + close_smooth, volatility_smooth, high_smooth, low_smooth = smooth_savgol( + close, volatility, high, low + ) + + volatility_smooth = self.trail_factor * volatility_smooth[-min_length:] + close_smooth = close_smooth[-min_length:] + + ats = self._ats(close_smooth, volatility_smooth) + + rising_low = low_smooth[-1] > low_smooth[-2] and low_smooth[-2] > low_smooth[-3] + failing_high = ( + high_smooth[-1] < high_smooth[-2] and high_smooth[-2] < high_smooth[-3] + ) + + bullish = rising_low and (ta.trend.dmi[-1] > 0.0 or ta.momentum.cci[-1] > 100.0) + bearish = failing_high and ( + ta.trend.dmi[-1] <= 0.0 or ta.momentum.cci[-1] < -100.0 + ) + + if side == PositionSide.LONG: + adjusted_sl = ( + min(low_smooth[-1], np.max(ats)) + if bullish + else min(low_smooth[-1], ats[-1]) + ) + + return max(sl, adjusted_sl) + + if side == PositionSide.SHORT: + adjusted_sl = ( + max(high_smooth[-1], np.min(ats)) + if bearish + else max(high_smooth[-1], ats[-1]) + ) + + return min(sl, adjusted_sl) + + return sl + + def to_dict(self): + return { + "type": self.type, + "trail_factor": self.trail_factor, + "ohlcv": self.curr_bar.to_dict(), + } + + def __str__(self): + return f"type={self.type}, ohlcv={self.curr_bar}" + + def __repr__(self): + return f"PositionRisk({self})" diff --git a/core/models/profit_target.py b/core/models/profit_target.py new file mode 100644 index 00000000..72c5fe04 --- /dev/null +++ b/core/models/profit_target.py @@ -0,0 +1,136 @@ +from dataclasses import dataclass +from functools import cached_property + +import numpy as np + +from core.models.side import SignalSide + + +@dataclass(frozen=True) +class ProfitTarget: + side: SignalSide + entry: float + volatility: float + noise_sigma: float = 0.001 + + @cached_property + def context_factor(self): + return 1.0 if self.side == SignalSide.BUY else -1.0 + + @cached_property + def targets(self): + ratios = [ + 0.236, + 0.272, + 0.382, + 0.414, + 0.618, + 0.786, + 0.886, + 1.0, + 1.118, + 1.236, + 1.382, + 1.618, + 1.786, + 1.886, + 2.0, + 2.236, + 2.382, + 2.618, + 2.786, + 3.0, + 3.236, + 3.382, + 3.618, + 4.0, + 4.236, + 4.382, + 4.618, + 5.0, + 5.236, + 5.382, + 5.618, + 6.0, + 6.236, + 6.382, + 6.618, + 7.0, + 7.236, + 7.382, + 7.618, + 8.0, + 8.236, + 8.382, + 8.618, + 9.0, + 9.236, + 9.382, + 9.618, + 10.0, + 10.236, + 10.382, + 10.618, + 11.0, + 11.236, + 11.382, + 11.618, + 12.0, + 12.236, + 12.382, + 12.618, + 13.0, + 13.236, + 13.382, + 13.618, + 14.0, + 14.236, + 14.382, + 14.618, + 15.0, + 15.236, + 15.382, + 15.618, + 16.0, + 16.236, + 16.382, + 16.618, + 17.0, + 17.236, + 17.382, + 17.618, + 18.0, + ] + + levels = sorted( + { + self._pt(ratios[i], ratios[j]) + for i in range(len(ratios)) + for j in range(i, len(ratios)) + } + ) + + reverse = self.context_factor == -1 + return levels if not reverse else levels[::-1] + + @cached_property + def last(self): + return self.targets[-1] + + def _pt(self, min_scale: float, max_scale: float) -> float: + scale = np.random.uniform(min_scale, max_scale) + noise = np.random.lognormal(mean=0, sigma=self.noise_sigma) - 1 + target_price = self.entry * ( + 1 + self.volatility * self.context_factor * scale + noise + ) + + return ( + max(target_price, self.entry) + if self.context_factor == 1 + else min(target_price, self.entry) + ) + + def to_dict(self): + return { + "targets": self.targets, + } diff --git a/core/models/risk_type.py b/core/models/risk_type.py index a546275c..14c4c4f3 100644 --- a/core/models/risk_type.py +++ b/core/models/risk_type.py @@ -1,12 +1,43 @@ -from enum import Enum +import re +from enum import Enum, auto -class RiskType(Enum): +class PositionRiskType(Enum): + NONE = 0 TIME = 1 SIGNAL = 2 SL = 3 TP = 4 - REVERSE = 5 + + def __str__(self): + return self.name.upper() + + +class SignalRiskType(Enum): + NONE = auto() + LOW = auto() + VERY_LOW = auto() + MODERATE = auto() + HIGH = auto() + VERY_HIGH = auto() + UNKNOWN = auto() + + @classmethod + def from_string(cls, risk_string): + match = re.search( + r"\b(NONE|LOW|VERY_LOW|MODERATE|HIGH|VERY_HIGH)\b(?![a-zA-Z0-9])", + risk_string, + ) + + return cls[match.group()] if match else cls.NONE + + def __str__(self): + return self.name.upper() + + +class SessionRiskType(Enum): + CONTINUE = auto() + EXIT = auto() def __str__(self): return self.name.upper() diff --git a/core/models/side.py b/core/models/side.py index b2c4b911..435ee40f 100644 --- a/core/models/side.py +++ b/core/models/side.py @@ -1,17 +1,17 @@ -from enum import Enum +from enum import Enum, auto class PositionSide(Enum): - LONG = "long" - SHORT = "short" + LONG = auto() + SHORT = auto() def __str__(self): - return self.value + return self.name.upper() class SignalSide(Enum): - BUY = "buy" - SELL = "sell" + BUY = auto() + SELL = auto() def __str__(self): - return self.value.upper() + return self.name.upper() diff --git a/core/models/signal.py b/core/models/signal.py index 73084185..d77f5082 100644 --- a/core/models/signal.py +++ b/core/models/signal.py @@ -1,5 +1,6 @@ -from dataclasses import dataclass +from dataclasses import dataclass, field +from .ohlcv import OHLCV from .side import SignalSide from .strategy import Strategy from .symbol import Symbol @@ -12,12 +13,10 @@ class Signal: timeframe: Timeframe strategy: Strategy side: SignalSide - - def __str__(self) -> str: - return f"{self.symbol.name}_{self.timeframe}_{self.side}{self.strategy}" - - def __repr__(self) -> str: - return f"Signal(symbol={self.symbol}, timeframe={self.timeframe}, side={self.side}, strategy={self.strategy})" + ohlcv: OHLCV + entry: float = field(default_factory=lambda: 0.0) + exit: float = field(default_factory=lambda: 0.0) + stop_loss: float = field(default_factory=lambda: 0.0) def __hash__(self) -> int: return hash((self.symbol, self.timeframe, self.strategy, self.side)) @@ -39,4 +38,13 @@ def to_dict(self): "timeframe": str(self.timeframe), "strategy": str(self.strategy), "side": str(self.side), + "ohlcv": self.ohlcv.to_dict(), + "entry": self.entry, + "stop_loss": self.stop_loss, } + + def __str__(self) -> str: + return f"{self.symbol.name}_{self.timeframe}_{self.side}{self.strategy}" + + def __repr__(self) -> str: + return f"Signal({self})" diff --git a/core/models/signal_risk.py b/core/models/signal_risk.py new file mode 100644 index 00000000..ceb04f7b --- /dev/null +++ b/core/models/signal_risk.py @@ -0,0 +1,24 @@ +from dataclasses import dataclass +from typing import Optional + +from core.models.risk_type import SignalRiskType + + +@dataclass(frozen=True) +class SignalRisk: + type: SignalRiskType = SignalRiskType.NONE + tp: Optional[float] = None + sl: Optional[float] = None + + def to_dict(self): + return { + "type": self.type, + "tp": self.tp, + "sl": self.sl, + } + + def __str__(self): + return f"type={self.type}, tp={self.tp}, sl={self.sl}" + + def __repr__(self): + return f"SignalRisk({self})" diff --git a/core/models/smooth.py b/core/models/smooth.py index eab0de98..8ff00afc 100644 --- a/core/models/smooth.py +++ b/core/models/smooth.py @@ -11,6 +11,17 @@ class Smooth(Enum): ZLEMA = 7 LSMA = 8 TEMA = 9 + DEMA = 10 + UTLS = 11 + + def __str__(self): + return self.name.upper() + + +class SmoothATR(Enum): + EMA = 1 + SMMA = 3 + UTLS = 11 def __str__(self): return self.name.upper() diff --git a/core/models/strategy.py b/core/models/strategy.py index 44294643..213d0958 100644 --- a/core/models/strategy.py +++ b/core/models/strategy.py @@ -1,24 +1,17 @@ from dataclasses import dataclass from enum import Enum, auto -import orjson as json - from .indicator import Indicator from .parameter import Parameter -class StrategyType(Enum): - TREND = auto() - - class StrategyOptimizationType(Enum): GENETIC = auto() @dataclass(frozen=True) class Strategy: - type: StrategyType - entry: Indicator + signal: Indicator confirm: Indicator pulse: Indicator baseline: Indicator @@ -27,20 +20,13 @@ class Strategy: @property def parameters(self): - signal_data = json.dumps(self.entry.to_dict()) - confirmation_data = json.dumps(self.confirm.to_dict()) - pulse_data = json.dumps(self.pulse.to_dict()) - baseline_data = json.dumps(self.baseline.to_dict()) - stoploss_data = json.dumps(self.stop_loss.to_dict()) - exit_data = json.dumps(self.exit.to_dict()) - return ( - signal_data, - confirmation_data, - pulse_data, - baseline_data, - stoploss_data, - exit_data, + self.signal.to_dict(), + self.confirm.to_dict(), + self.pulse.to_dict(), + self.baseline.to_dict(), + self.stop_loss.to_dict(), + self.exit.to_dict(), ) def _format_parameters(self, indicator): @@ -59,10 +45,8 @@ def _format_parameters(self, indicator): return parameters if parameters else "NONE" def __str__(self) -> str: - entry_ = f"_SGNL{self.entry.type}:{self._format_parameters(self.entry)}" - confirmation_ = ( - f"_CNFRM{self.confirm.type}:{self._format_parameters(self.confirm)}" - ) + signal_ = f"_SGNL{self.signal.type}:{self._format_parameters(self.signal)}" + confirm_ = f"_CNFRM{self.confirm.type}:{self._format_parameters(self.confirm)}" pulse_ = f"_PLS{self.pulse.type}:{self._format_parameters(self.pulse)}" baseline_ = ( f"_BSLN{self.baseline.type}:{self._format_parameters(self.baseline)}" @@ -72,7 +56,7 @@ def __str__(self) -> str: ) exit_ = f"_EXT{self.exit.type}:{self._format_parameters(self.exit)}" - return entry_ + confirmation_ + pulse_ + baseline_ + stop_loss + exit_ + return f"{signal_}{confirm_}{pulse_}{baseline_}{stop_loss}{exit_}" def __hash__(self) -> int: return hash(str(self)) diff --git a/core/models/strategy_ref.py b/core/models/strategy_ref.py index 87af78ca..9dae9a0a 100644 --- a/core/models/strategy_ref.py +++ b/core/models/strategy_ref.py @@ -1,8 +1,11 @@ import logging +import typing from dataclasses import dataclass +from functools import cached_property from typing import Optional, Union -from wasmtime import Instance, Store +if typing.TYPE_CHECKING: + from wasmtime import Instance, Store from core.events.signal import ( ExitLongSignalReceived, @@ -12,7 +15,8 @@ ) from core.models.action import Action from core.models.ohlcv import OHLCV -from core.models.signal import Signal, SignalSide +from core.models.side import SignalSide +from core.models.signal import Signal from core.models.strategy import Strategy from core.models.symbol import Symbol from core.models.timeframe import Timeframe @@ -30,20 +34,21 @@ @dataclass(frozen=True) class StrategyRef: id: int - instance_ref: Instance - store_ref: Store + instance_ref: "Instance" + store_ref: "Store" + + @cached_property + def exports(self): + return self.instance_ref.exports(self.store_ref) def unregister(self): - exports = self.instance_ref.exports(self.store_ref) - exports["unregister_strategy"](self.store_ref, self.id) + self.exports["strategy_unregister"](self.store_ref, self.id) self.store_ref.gc() def next( self, symbol: Symbol, timeframe: Timeframe, strategy: Strategy, ohlcv: OHLCV ) -> Optional[SignalEvent]: - exports = self.instance_ref.exports(self.store_ref) - - [raw_action, price] = exports["strategy_next"]( + strategy_args = [ self.store_ref, self.id, ohlcv.timestamp, @@ -52,48 +57,67 @@ def next( ohlcv.low, ohlcv.close, ohlcv.volume, - ) + ] + + raw_action, price = self.exports["strategy_next"](*strategy_args) action = Action.from_raw(raw_action) long_stop_loss, short_stop_loss = 0.0, 0.0 if action in (Action.GO_LONG, Action.GO_SHORT): - [long_stop_loss, short_stop_loss] = exports["strategy_stop_loss"]( - self.store_ref, self.id + long_stop_loss, short_stop_loss = self.exports["strategy_stop_loss"]( + *strategy_args ) - signal = Signal( - symbol, - timeframe, - strategy, + side = ( SignalSide.BUY if action in (Action.GO_LONG, Action.EXIT_SHORT) - else SignalSide.SELL, + else SignalSide.SELL ) action_event_map = { Action.GO_LONG: GoLongSignalReceived( - signal=signal, - ohlcv=ohlcv, - entry_price=price, - stop_loss=long_stop_loss, + signal=Signal( + symbol, + timeframe, + strategy, + side, + ohlcv, + entry=price, + stop_loss=long_stop_loss, + ), ), Action.GO_SHORT: GoShortSignalReceived( - signal=signal, - ohlcv=ohlcv, - entry_price=price, - stop_loss=short_stop_loss, + signal=Signal( + symbol, + timeframe, + strategy, + side, + ohlcv, + entry=price, + stop_loss=short_stop_loss, + ), ), Action.EXIT_LONG: ExitLongSignalReceived( - signal=signal, - ohlcv=ohlcv, - exit_price=price, + signal=Signal( + symbol, + timeframe, + strategy, + side, + ohlcv, + exit=price, + ), ), Action.EXIT_SHORT: ExitShortSignalReceived( - signal=signal, - ohlcv=ohlcv, - exit_price=price, + signal=Signal( + symbol, + timeframe, + strategy, + side, + ohlcv, + exit=price, + ), ), } diff --git a/core/models/strategy_type.py b/core/models/strategy_type.py new file mode 100644 index 00000000..d0f1947c --- /dev/null +++ b/core/models/strategy_type.py @@ -0,0 +1,11 @@ + + +from enum import Enum, auto + + +class StrategyType(Enum): + TREND_FOLLOW = "Trend Follow" + CONTRARIAN = "Contrarian" + + def __str__(self): + return self.value \ No newline at end of file diff --git a/core/models/symbol.py b/core/models/symbol.py index 3a43a1c7..ac31de69 100644 --- a/core/models/symbol.py +++ b/core/models/symbol.py @@ -10,12 +10,13 @@ class Symbol: min_price_size: float position_precision: int price_precision: int + max_leverage: float + + def __hash__(self) -> int: + return hash(self.name) def __str__(self): return self.name def __repr__(self) -> str: - return f"Symbol({self.name})" - - def __hash__(self) -> int: - return hash(self.name) + return f"Symbol({self})" diff --git a/core/models/ta.py b/core/models/ta.py new file mode 100644 index 00000000..2b2c7cad --- /dev/null +++ b/core/models/ta.py @@ -0,0 +1,147 @@ +from dataclasses import dataclass +from typing import Any, List + + +@dataclass(frozen=True) +class VolumeAnalysis: + obv: List[float] + vo: List[float] + nvol: List[float] + mfi: List[float] + vwap: List[float] + + def __str__(self) -> str: + return f"obv={self.obv}, vo={self.vo}, nvol={self.nvol}, mfi={self.mfi}, vwap={self.vwap}" + + def __repr__(self) -> str: + return f"VolumeAnalysis({self})" + + +@dataclass(frozen=True) +class VolatilityAnalysis: + tr: List[float] + gkyz: List[float] + yz: List[float] + upb: List[float] + lwb: List[float] + ebb: List[float] + ekch: List[float] + + def __str__(self) -> str: + return f"tr={self.tr}, yz={self.yz}, upb={self.upb}, lwb={self.lwb}, ebb={self.ebb}, ekch={self.ekch}" + + def __repr__(self) -> str: + return f"VolatilityAnalysis({self})" + + +@dataclass(frozen=True) +class TrendAnalysis: + fma: List[float] + sma: List[float] + macd: List[float] + ppo: List[float] + hh: List[float] + ll: List[float] + support: List[float] + resistance: List[float] + dmi: List[float] + close: List[float] + hlc3: List[float] + hlcc4: List[float] + + def __str__(self) -> str: + return f"fma={self.fma}, sma={self.sma}, macd={self.macd}, ppo={self.ppo}, hh={self.hh}, ll={self.ll}, support={self.support}, resistance={self.resistance}, dmi={self.dmi}, close={self.close}, hlc3={self.hlc3}, hlcc4={self.hlcc4}" + + def __repr__(self) -> str: + return f"TrendAnalysis({self})" + + +@dataclass(frozen=True) +class MomentumAnalysis: + froc: List[float] + sroc: List[float] + cci: List[float] + + def __str__(self) -> str: + return f"froc={self.froc}, sroc={self.sroc}, cci={self.cci}" + + def __repr__(self) -> str: + return f"MomentumAnalysis({self})" + + +@dataclass(frozen=True) +class OscillatorAnalysis: + frsi: List[float] + srsi: List[float] + k: List[float] + d: List[float] + + def __str__(self) -> str: + return f"frsi={self.frsi}, srsi={self.srsi}, k={self.k}, d={self.d}" + + def __repr__(self) -> str: + return f"OscillatorAnalysis({self})" + + +@dataclass(frozen=True) +class TechAnalysis: + trend: TrendAnalysis + momentum: MomentumAnalysis + oscillator: OscillatorAnalysis + volume: VolumeAnalysis + volatility: VolatilityAnalysis + + @classmethod + def from_list(cls, data: List[Any]) -> "TechAnalysis": + ( + frsi, + srsi, + fma, + sma, + froc, + sroc, + macd, + ppo, + cci, + obv, + vo, + nvol, + mfi, + tr, + gkyz, + yz, + upb, + lwb, + ebb, + ekch, + k, + d, + hh, + ll, + support, + resistance, + dmi, + vwap, + close, + hlc3, + hlcc4, + ) = data + + trend = TrendAnalysis( + fma, sma, macd, ppo, hh, ll, support, resistance, dmi, close, hlc3, hlcc4 + ) + momentum = MomentumAnalysis(froc, sroc, cci) + oscillator = OscillatorAnalysis(frsi, srsi, k, d) + volume = VolumeAnalysis(obv, vo, nvol, mfi, vwap) + volatility = VolatilityAnalysis(tr, gkyz, yz, upb, lwb, ebb, ekch) + + return cls(trend, momentum, oscillator, volume, volatility) + + def __str__(self) -> str: + return ( + f"trend={self.trend}, momentum={self.momentum}, oscillator={self.oscillator}, " + f"volume={self.volume}, volatility={self.volatility}" + ) + + def __repr__(self) -> str: + return f"TechAnalysis({self})" diff --git a/core/models/timeframe.py b/core/models/timeframe.py index 0981e395..d6c032b4 100644 --- a/core/models/timeframe.py +++ b/core/models/timeframe.py @@ -17,11 +17,17 @@ def from_raw(cls, value: str) -> "Timeframe": raise ValueError(f"No matching Timeframe for value: {value}") - def __str__(self): - return self.value - - def __repr__(self) -> str: - return f"Timeframe({self.value})" + def to_milliseconds(self) -> int: + value = self.value + + if value.endswith("m"): + minutes = int(value[:-1]) + return minutes * 60 * 1000 + elif value.endswith("h"): + hours = int(value[:-1]) + return hours * 60 * 60 * 1000 + else: + raise ValueError(f"Unsupported timeframe value: {value}") def __lt__(self, other): if not isinstance(other, Timeframe): @@ -33,3 +39,9 @@ def __lt__(self, other): def __hash__(self) -> int: return hash(self.value) + + def __str__(self): + return self.value + + def __repr__(self) -> str: + return f"Timeframe({self})" diff --git a/core/models/timeseries_ref.py b/core/models/timeseries_ref.py new file mode 100644 index 00000000..f95eed30 --- /dev/null +++ b/core/models/timeseries_ref.py @@ -0,0 +1,117 @@ +import typing +from dataclasses import dataclass +from functools import cached_property +from typing import Any, List, Optional, Type + +import orjson as json + +if typing.TYPE_CHECKING: + from wasmtime import Instance, Store + +from .ohlcv import OHLCV +from .ta import TechAnalysis + + +@dataclass(frozen=True) +class TimeSeriesRef: + id: int + instance_ref: "Instance" + store_ref: "Store" + + @cached_property + def exports(self): + return self.instance_ref.exports(self.store_ref) + + def unregister(self): + self.exports["timeseries_unregister"](self.store_ref, self.id) + self.store_ref.gc() + + def add(self, bar: OHLCV): + res, _ = self.exports["timeseries_add"]( + self.store_ref, + self.id, + bar.timestamp, + bar.open, + bar.high, + bar.low, + bar.close, + bar.volume, + ) + + if res == -1: + raise ValueError("Can't add new market bar") + + def next_bar(self, bar: OHLCV) -> Optional[OHLCV]: + return self._get_bar("next_bar", bar) + + def prev_bar(self, bar: OHLCV) -> Optional[OHLCV]: + return self._get_bar("prev_bar", bar) + + def back_n_bars(self, bar: OHLCV, n: int) -> List[OHLCV]: + ptr, length = self.exports["timeseries_back_n_bars"]( + self.store_ref, + self.id, + bar.timestamp, + bar.open, + bar.high, + bar.low, + bar.close, + bar.volume, + n, + ) + + buff = self._read_from_memory(ptr, length) + + return self._deserialize(buff, OHLCV) or [] + + def ta(self, bar: OHLCV) -> Optional[TechAnalysis]: + ptr, length = self.exports["timeseries_ta"]( + self.store_ref, + self.id, + bar.timestamp, + bar.open, + bar.high, + bar.low, + bar.close, + bar.volume, + ) + + buff = self._read_from_memory(ptr, length) + + return self._deserialize(buff, TechAnalysis) + + def _get_bar(self, method: str, bar: OHLCV) -> Optional[OHLCV]: + ptr, length = self.exports[f"timeseries_{method}"]( + self.store_ref, + self.id, + bar.timestamp, + bar.open, + bar.high, + bar.low, + bar.close, + bar.volume, + ) + + buff = self._read_from_memory(ptr, length) + + return self._deserialize(buff, OHLCV) + + def _read_from_memory(self, ptr: int, length: int) -> bytes: + if ptr == -1 and length == 0: + return None + + return self.exports["memory"].data_ptr(self.store_ref)[ptr : ptr + length] + + def _deserialize(self, buff: bytes, data_class: Type[Any]) -> Optional[Any]: + try: + raw_data = json.loads("".join(chr(val) for val in buff)) + + if isinstance(raw_data, dict): + return data_class.from_list(raw_data.values()) + elif isinstance(raw_data, list): + return [data_class.from_list(d.values()) for d in raw_data] + else: + raise ValueError("Unexpected data format") + + except Exception: + return None diff --git a/core/models/wasm_type.py b/core/models/wasm_type.py new file mode 100644 index 00000000..fd390c4a --- /dev/null +++ b/core/models/wasm_type.py @@ -0,0 +1,6 @@ +from enum import Enum, auto + + +class WasmType(Enum): + TREND = auto() + TIMESERIES = auto() diff --git a/core/queries/base.py b/core/queries/base.py index cd8141a5..ffd9dbad 100644 --- a/core/queries/base.py +++ b/core/queries/base.py @@ -1,6 +1,6 @@ import asyncio from dataclasses import dataclass, field -from enum import Enum +from enum import Enum, auto from typing import Generic, TypeVar from core.events.base import Event, EventMeta @@ -9,13 +9,16 @@ class QueryGroup(Enum): - account = "account" - broker = "broker" - position = "position" - portfolio = "portfolio" + account = auto() + broker = auto() + position = auto() + portfolio = auto() + copilot = auto() + market = auto() + ta = auto() def __str__(self): - return self.value + return self.name @dataclass(frozen=True) diff --git a/core/queries/copilot.py b/core/queries/copilot.py new file mode 100644 index 00000000..fc846b80 --- /dev/null +++ b/core/queries/copilot.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass, field +from typing import List + +from core.events.base import EventMeta +from core.models.ohlcv import OHLCV +from core.models.risk_type import SessionRiskType +from core.models.side import PositionSide +from core.models.signal import Signal +from core.models.signal_risk import SignalRisk +from core.models.ta import TechAnalysis + +from .base import Query, QueryGroup + + +@dataclass(frozen=True) +class EvaluateSignal(Query[SignalRisk]): + signal: Signal + prev_bar: List[OHLCV] + ta: TechAnalysis + meta: EventMeta = field( + default_factory=lambda: EventMeta(priority=5, group=QueryGroup.copilot), + init=False, + ) + + +@dataclass(frozen=True) +class EvaluateSession(Query[SessionRiskType]): + side: PositionSide + session: List[OHLCV] + ta: TechAnalysis + meta: EventMeta = field( + default_factory=lambda: EventMeta(priority=1, group=QueryGroup.copilot), + init=False, + ) diff --git a/core/queries/ohlcv.py b/core/queries/ohlcv.py new file mode 100644 index 00000000..0c05945b --- /dev/null +++ b/core/queries/ohlcv.py @@ -0,0 +1,55 @@ +from dataclasses import dataclass, field +from typing import List + +from core.events.base import EventMeta +from core.models.ohlcv import OHLCV +from core.models.symbol import Symbol +from core.models.ta import TechAnalysis +from core.models.timeframe import Timeframe + +from .base import Query, QueryGroup + + +@dataclass(frozen=True) +class NextBar(Query[OHLCV]): + symbol: Symbol + timeframe: Timeframe + ohlcv: OHLCV + meta: EventMeta = field( + default_factory=lambda: EventMeta(priority=3, group=QueryGroup.market), + init=False, + ) + + +@dataclass(frozen=True) +class PrevBar(Query[OHLCV]): + symbol: Symbol + timeframe: Timeframe + ohlcv: OHLCV + meta: EventMeta = field( + default_factory=lambda: EventMeta(priority=4, group=QueryGroup.market), + init=False, + ) + + +@dataclass(frozen=True) +class BackNBars(Query[List[OHLCV]]): + symbol: Symbol + timeframe: Timeframe + ohlcv: OHLCV + n: int + meta: EventMeta = field( + default_factory=lambda: EventMeta(priority=4, group=QueryGroup.market), + init=False, + ) + + +@dataclass(frozen=True) +class TA(Query[TechAnalysis]): + symbol: Symbol + timeframe: Timeframe + ohlcv: OHLCV + meta: EventMeta = field( + default_factory=lambda: EventMeta(priority=2, group=QueryGroup.ta), + init=False, + ) diff --git a/exchange/_bybit.py b/exchange/_bybit.py index fe0f55ae..51e0f3b9 100644 --- a/exchange/_bybit.py +++ b/exchange/_bybit.py @@ -221,8 +221,9 @@ def close_half_position(self, symbol: Symbol, side: PositionSide): @retry(max_retries=MAX_RETRIES, handled_exceptions=EXCEPTIONS) def fetch_position(self, symbol: Symbol, side: PositionSide): positions = self.connector.fetch_positions([symbol.name]) + side = str(side).lower() position = next( - iter([position for position in positions if position["side"] == str(side)]), + iter([position for position in positions if position["side"] == side]), None, ) @@ -324,6 +325,7 @@ def _create_symbol(self, market): min_price_size, position_precision, price_precision, + max_leverage, ) = self._get_symbol_meta(market) return Symbol( @@ -334,6 +336,7 @@ def _create_symbol(self, market): min_price_size, position_precision, price_precision, + max_leverage, ) def _get_symbol_meta(self, market): @@ -343,6 +346,7 @@ def _get_symbol_meta(self, market): limits = market.get("limits", {}) min_position_size = limits.get("amount", {}).get("min", 0) min_position_price = limits.get("price", {}).get("min", 0) + max_leverage = limits.get("leverage", {}).get("max", 1) precision = market.get("precision", {}) position_precision = precision.get("amount", 0) @@ -355,6 +359,7 @@ def _get_symbol_meta(self, market): min_position_price, int(abs(math.log10(position_precision))), int(abs(math.log10(price_precision))), + max_leverage, ) def _create_order( diff --git a/exchange/_bybit_ws.py b/exchange/_bybit_ws.py index 79d858f2..4cd19efe 100644 --- a/exchange/_bybit_ws.py +++ b/exchange/_bybit_ws.py @@ -52,7 +52,6 @@ def __new__(cls, wss: str): cls._instance._channels = set() cls._instance._receive_semaphore = asyncio.Semaphore(1) cls._instance._lock = asyncio.Lock() - cls.ping_task = None return cls._instance @@ -64,7 +63,6 @@ async def _connect_to_websocket(self): ping_timeout=15, close_timeout=None, ) - await self._resubscribe() @retry( @@ -82,10 +80,8 @@ async def run(self): await self._connect_to_websocket() async def close(self): - if not self.ws or not self.ws.open: - return - - await self.ws.close() + if self.ws and self.ws.open: + await self.ws.close() @retry( max_retries=13, @@ -101,14 +97,17 @@ async def receive(self, symbol, timeframe): data = json.loads(message) if self.TOPIC_KEY not in data: - return - - topic = data["topic"].split(".") + continue - if symbol.name == topic[2] and timeframe == self.TIMEFRAMES[topic[1]]: - ohlcv = data[self.DATA_KEY][0] + topic = data[self.TOPIC_KEY].split(".") - return Bar(OHLCV.from_dict(ohlcv), ohlcv[self.CONFIRM_KEY]) + if symbol.name == topic[2] and timeframe == self.TIMEFRAMES.get( + topic[1] + ): + return [ + Bar(OHLCV.from_dict(ohlcv), ohlcv.get(self.CONFIRM_KEY)) + for ohlcv in data.get(self.DATA_KEY, {}) + ] async def subscribe(self, symbol, timeframe): async with self._lock: @@ -133,7 +132,7 @@ async def _subscribe(self, symbol, timeframe): logger.info(f"Subscribe to: {subscribe_message}") await self.ws.send(json.dumps(subscribe_message)) except Exception as e: - logger.error(e) + logger.error(f"Failed to send subscribe message: {e}") async def _unsubscribe(self, symbol, timeframe): if not self.ws or not self.ws.open: @@ -146,7 +145,7 @@ async def _unsubscribe(self, symbol, timeframe): logger.info(f"Unsubscribe from: {unsubscribe_message}") await self.ws.send(json.dumps(unsubscribe_message)) except Exception as e: - logger.error(e) + logger.error(f"Failed to send unsubscribe message: {e}") async def _resubscribe(self): async with self._lock: diff --git a/executor/_factory.py b/executor/_factory.py index 8c66493e..d5ebeb96 100644 --- a/executor/_factory.py +++ b/executor/_factory.py @@ -4,8 +4,8 @@ from core.models.symbol import Symbol from core.models.timeframe import Timeframe -from ._market_order_actor import MarketOrderActor -from ._paper_order_actor import PaperOrderActor +from ._market_actor import MarketOrderActor +from ._paper_actor import PaperOrderActor class OrderExecutorActorFactory(AbstractExecutorActorFactory): diff --git a/executor/_market_actor.py b/executor/_market_actor.py new file mode 100644 index 00000000..808a7104 --- /dev/null +++ b/executor/_market_actor.py @@ -0,0 +1,69 @@ +import logging +from typing import Union + +from core.actors import StrategyActor +from core.commands.broker import ClosePosition, OpenPosition +from core.events.position import ( + BrokerPositionClosed, + BrokerPositionOpened, + PositionCloseRequested, + PositionInitialized, +) +from core.mixins import EventHandlerMixin +from core.models.symbol import Symbol +from core.models.timeframe import Timeframe +from core.queries.position import GetClosePosition, GetOpenPosition + +logger = logging.getLogger(__name__) + + +PositionEventType = Union[PositionInitialized, PositionCloseRequested] + + +class MarketOrderActor(StrategyActor, EventHandlerMixin): + _EVENTS = [PositionInitialized, PositionCloseRequested] + + def __init__(self, symbol: Symbol, timeframe: Timeframe): + super().__init__(symbol, timeframe) + EventHandlerMixin.__init__(self) + self._register_event_handlers() + + async def on_receive(self, event: PositionEventType): + return await self.handle_event(event) + + def _register_event_handlers(self): + self.register_handler(PositionInitialized, self._execute_order) + self.register_handler(PositionCloseRequested, self._close_position) + + async def _execute_order(self, event: PositionInitialized): + current_position = event.position + + logger.debug(f"New Position: {current_position}") + + await self.ask(OpenPosition(current_position)) + + filled_order = await self.ask(GetOpenPosition(current_position)) + + current_position = current_position.fill_order(filled_order) + + logger.info(f"Position to Open: {current_position}") + + if current_position.closed: + await self.tell(BrokerPositionClosed(current_position)) + else: + await self.tell(BrokerPositionOpened(current_position)) + + async def _close_position(self, event: PositionCloseRequested): + current_position = event.position + + logger.debug(f"To Close Position: {current_position}") + + await self.ask(ClosePosition(current_position)) + + order = await self.ask(GetClosePosition(current_position)) + + current_position = current_position.fill_order(order) + + logger.info(f"Closed Position: {current_position}") + + await self.tell(BrokerPositionClosed(current_position)) diff --git a/executor/_market_order_actor.py b/executor/_market_order_actor.py deleted file mode 100644 index f18addd0..00000000 --- a/executor/_market_order_actor.py +++ /dev/null @@ -1,95 +0,0 @@ -import logging -from typing import Union - -from core.actors import Actor -from core.commands.broker import AdjustPosition, ClosePosition, OpenPosition -from core.events.position import ( - BrokerPositionAdjusted, - BrokerPositionClosed, - BrokerPositionOpened, - PositionCloseRequested, - PositionInitialized, -) -from core.events.risk import RiskAdjustRequested -from core.models.symbol import Symbol -from core.models.timeframe import Timeframe -from core.queries.position import GetClosePosition, GetOpenPosition - -logger = logging.getLogger(__name__) - - -PositionEventType = Union[PositionInitialized, PositionCloseRequested] - - -class MarketOrderActor(Actor): - _EVENTS = [PositionInitialized, RiskAdjustRequested, PositionCloseRequested] - - def __init__(self, symbol: Symbol, timeframe: Timeframe): - super().__init__(symbol, timeframe) - - def pre_receive(self, event: PositionEventType): - event = event.position.signal if hasattr(event, "position") else event - return event.symbol == self._symbol and event.timeframe == self._timeframe - - async def on_receive(self, event: PositionEventType): - handlers = { - PositionInitialized: self._execute_order, - RiskAdjustRequested: self._adjust_position, - PositionCloseRequested: self._close_position, - } - - handler = handlers.get(type(event)) - - if handler: - await handler(event) - - async def _execute_order(self, event: PositionInitialized): - current_position = event.position - - logger.debug(f"New Position: {current_position}") - - await self.ask(OpenPosition(current_position)) - - order = await self.ask(GetOpenPosition(current_position)) - - current_position = current_position.add_order(order) - - logger.info(f"Position to Open: {current_position}") - - if current_position.closed: - await self.tell(BrokerPositionClosed(current_position)) - else: - await self.tell(BrokerPositionOpened(current_position)) - - async def _adjust_position(self, event: RiskAdjustRequested): - current_position, entry_price = event.position, event.adjust_price - - logger.debug(f"To Adjust Position: {current_position}, adjust: {entry_price}") - - await self.ask(AdjustPosition(current_position, entry_price)) - - order = await self.ask(GetOpenPosition(current_position)) - - current_position = current_position.add_order(order) - - logger.info(f"Adjusted Position: {current_position}") - - if current_position.closed: - await self.tell(BrokerPositionClosed(current_position)) - else: - await self.tell(BrokerPositionAdjusted(current_position)) - - async def _close_position(self, event: PositionCloseRequested): - current_position = event.position - - logger.debug(f"To Close Position: {current_position}") - - await self.ask(ClosePosition(current_position, event.exit_price)) - - order = await self.ask(GetClosePosition(current_position)) - - current_position = current_position.add_order(order) - - logger.info(f"Closed Position: {current_position}") - - await self.tell(BrokerPositionClosed(current_position)) diff --git a/executor/_paper_actor.py b/executor/_paper_actor.py new file mode 100644 index 00000000..4db9a55e --- /dev/null +++ b/executor/_paper_actor.py @@ -0,0 +1,162 @@ +import logging +from enum import Enum, auto +from typing import Optional, Union + +from core.actors import StrategyActor +from core.events.position import ( + BrokerPositionClosed, + BrokerPositionOpened, + PositionCloseRequested, + PositionInitialized, +) +from core.mixins import EventHandlerMixin +from core.models.ohlcv import OHLCV +from core.models.order import Order, OrderStatus, OrderType +from core.models.position import Position +from core.models.side import PositionSide +from core.models.symbol import Symbol +from core.models.timeframe import Timeframe +from core.queries.ohlcv import NextBar + +OrderEventType = Union[PositionInitialized, PositionCloseRequested] + +logger = logging.getLogger(__name__) + + +class PriceDirection(Enum): + OHLC = auto() + OLHC = auto() + + +class PaperOrderActor(StrategyActor, EventHandlerMixin): + _EVENTS = [ + PositionInitialized, + PositionCloseRequested, + ] + + def __init__(self, symbol: Symbol, timeframe: Timeframe): + super().__init__(symbol, timeframe) + EventHandlerMixin.__init__(self) + self._register_event_handlers() + + async def on_receive(self, event: OrderEventType): + return await self.handle_event(event) + + def _register_event_handlers(self): + self.register_handler(PositionInitialized, self._execute_order) + self.register_handler(PositionCloseRequested, self._close_position) + + async def _execute_order(self, event: PositionInitialized): + current_position = event.position + + logger.debug(f"New Position: {current_position}") + + entry_order = current_position.entry_order() + + price = self._find_open_price(current_position, entry_order) + size = entry_order.size + fee = current_position.theo_taker_fee(size, price) + + order = Order( + status=OrderStatus.EXECUTED, + type=OrderType.PAPER, + price=price, + size=size, + fee=fee, + ) + + current_position = current_position.fill_order(order) + + if not current_position.is_valid: + order = Order( + status=OrderStatus.FAILED, type=OrderType.PAPER, price=0, size=0 + ) + + current_position = current_position.fill_order(order) + + logger.debug(f"Position to Open: {current_position}") + + if current_position.closed: + await self.tell(BrokerPositionClosed(current_position)) + else: + await self.tell(BrokerPositionOpened(current_position)) + + async def _close_position(self, event: PositionCloseRequested): + current_position = event.position + + logger.debug(f"To Close Position: {current_position}") + + exit_order = current_position.exit_order() + + next_bar = await self.ask( + NextBar(self.symbol, self.timeframe, current_position.risk_bar) + ) + + price = self._find_close_price(current_position, exit_order, next_bar) + size = exit_order.size + fee = current_position.theo_taker_fee(size, price) + + order = Order( + status=OrderStatus.CLOSED, + type=OrderType.PAPER, + price=price, + size=size, + fee=fee, + ) + + next_position = current_position.fill_order(order) + + logger.debug(f"Closed Position: {next_position}") + + await self.tell(BrokerPositionClosed(next_position)) + + def _find_fill_price(self, side: PositionSide, bar: OHLCV, price: float) -> float: + direction = self._intrabar_price_movement(bar) + + high, low = bar.high, bar.low + in_bar = low <= price <= high + + if direction == PriceDirection.OHLC: + if side == PositionSide.LONG: + return price if in_bar else high + elif direction == PriceDirection.OLHC: + if side == PositionSide.SHORT: + return price if in_bar else low + + return bar.close + + def _find_open_price( + self, position: Position, order: Order, bar: Optional[OHLCV] = None + ) -> float: + if bar is None: + bar = position.signal_bar + + diff = bar.timestamp - position.signal_bar.timestamp + + if diff > position.signal.timeframe.to_milliseconds(): + logger.warn("The open of the next bar is too far from the previous one.") + bar = position.signal_bar + + return self._find_fill_price(position.side, bar, order.price) + + def _find_close_price( + self, position: Position, order: Order, bar: Optional[OHLCV] = None + ) -> float: + if bar is None: + bar = position.risk_bar + + diff = bar.timestamp - position.risk_bar.timestamp + + if diff > position.signal.timeframe.to_milliseconds(): + logger.warn("The close of the next bar is too far from the previous one.") + bar = position.risk_bar + + return self._find_fill_price(position.side, bar, order.price) + + @staticmethod + def _intrabar_price_movement(bar: OHLCV) -> PriceDirection: + return ( + PriceDirection.OHLC + if abs(bar.open - bar.high) < abs(bar.open - bar.low) + else PriceDirection.OLHC + ) diff --git a/executor/_paper_order_actor.py b/executor/_paper_order_actor.py deleted file mode 100644 index a48c9353..00000000 --- a/executor/_paper_order_actor.py +++ /dev/null @@ -1,213 +0,0 @@ -import asyncio -import logging -from collections import deque -from enum import Enum, auto -from typing import Union - -from core.actors import Actor -from core.events.ohlcv import NewMarketDataReceived -from core.events.position import ( - BrokerPositionAdjusted, - BrokerPositionClosed, - BrokerPositionOpened, - PositionCloseRequested, - PositionInitialized, -) -from core.events.risk import RiskAdjustRequested -from core.models.ohlcv import OHLCV -from core.models.order import Order, OrderStatus, OrderType -from core.models.position import Position -from core.models.side import PositionSide -from core.models.symbol import Symbol -from core.models.timeframe import Timeframe - -OrderEventType = Union[ - NewMarketDataReceived, PositionInitialized, PositionCloseRequested -] - -logger = logging.getLogger(__name__) - - -class PriceDirection(Enum): - OHLC = auto() - OLHC = auto() - - -class PaperOrderActor(Actor): - _EVENTS = [ - NewMarketDataReceived, - PositionInitialized, - RiskAdjustRequested, - PositionCloseRequested, - ] - - def __init__(self, symbol: Symbol, timeframe: Timeframe): - super().__init__(symbol, timeframe) - self.lock = asyncio.Lock() - self.tick_buffer = deque(maxlen=13) - - def pre_receive(self, event: OrderEventType): - event = event.position.signal if hasattr(event, "position") else event - return event.symbol == self._symbol and event.timeframe == self._timeframe - - async def on_receive(self, event: OrderEventType): - handlers = { - PositionInitialized: self._execute_order, - RiskAdjustRequested: self._adjust_position, - PositionCloseRequested: self._close_position, - NewMarketDataReceived: self._update_tick, - } - - handler = handlers.get(type(event)) - - if handler: - await handler(event) - - async def _execute_order(self, event: PositionInitialized): - current_position = event.position - - logger.debug(f"New Position: {current_position}") - - size = current_position.pending_size - side = current_position.side - price = current_position.pending_price - fill_price = await self._determine_fill_price(side, price) - - if ( - side == PositionSide.LONG and current_position.stop_loss_price > fill_price - ) or ( - side == PositionSide.SHORT and current_position.stop_loss_price < fill_price - ): - order = Order( - status=OrderStatus.FAILED, - type=OrderType.PAPER, - price=0, - size=0, - ) - else: - order = Order( - status=OrderStatus.EXECUTED, - type=OrderType.PAPER, - fee=fill_price * size * current_position.signal.symbol.taker_fee, - price=fill_price, - size=size, - ) - - current_position = current_position.add_order(order) - - logger.debug(f"Position to Open: {current_position}") - - if current_position.closed: - await self.tell(BrokerPositionClosed(current_position)) - else: - await self.tell(BrokerPositionOpened(current_position)) - - async def _adjust_position(self, event: RiskAdjustRequested): - current_position = event.position - - logger.debug(f"To Adjust Position: {current_position}") - - total_value = (current_position.filled_size * current_position.entry_price) + ( - current_position.filled_size * event.adjust_price - ) - - size = round( - 1.3 * current_position.filled_size, - current_position.signal.symbol.position_precision, - ) - fill_price = round( - total_value / size, current_position.signal.symbol.price_precision - ) - - if ( - current_position.side == PositionSide.LONG - and current_position.stop_loss_price > current_position.take_profit_price - ) or ( - current_position.side == PositionSide.SHORT - and current_position.stop_loss_price < current_position.take_profit_price - ): - logger.error(f"Wrong Adjust: {current_position}") - return - else: - order = Order( - status=OrderStatus.EXECUTED, - type=OrderType.PAPER, - fee=fill_price * size * current_position.signal.symbol.taker_fee, - price=fill_price, - size=size, - ) - - current_position = current_position.add_order(order) - - logger.debug(f"Adjusted Position: {current_position}") - - await self.tell(BrokerPositionAdjusted(current_position)) - - async def _close_position(self, event: PositionCloseRequested): - current_position = event.position - - logger.debug(f"To Close Position: {current_position}") - - fill_price = await self._determine_fill_price( - current_position.side, event.exit_price - ) - price = self._calculate_closing_price(current_position, fill_price) - size = current_position.filled_size - - order = Order( - status=OrderStatus.CLOSED, - type=OrderType.PAPER, - fee=price * size * current_position.signal.symbol.taker_fee, - price=price, - size=size, - ) - - next_position = current_position.add_order(order) - - logger.debug(f"Closed Position: {next_position}") - - await self.tell(BrokerPositionClosed(next_position)) - - async def _update_tick(self, event: NewMarketDataReceived): - async with self.lock: - self.tick_buffer.append(event.ohlcv) - - async def _determine_fill_price( - self, side: PositionSide, event_price: float - ) -> float: - async with self.lock: - tick_buffer = sorted(self.tick_buffer, key=lambda x: x.timestamp) - last_tick = tick_buffer[-1] - - direction = self._intrabar_price_movement(last_tick) - high, low = last_tick.high, last_tick.low - - in_bar = low <= event_price <= high - - if side == PositionSide.LONG and direction == PriceDirection.OHLC: - return event_price if in_bar else high - elif side == PositionSide.SHORT and direction == PriceDirection.OLHC: - return event_price if in_bar else low - else: - return last_tick.close - - @staticmethod - def _intrabar_price_movement(tick: OHLCV) -> PriceDirection: - return ( - PriceDirection.OHLC - if abs(tick.open - tick.high) < abs(tick.open - tick.low) - else PriceDirection.OLHC - ) - - @staticmethod - def _calculate_closing_price(position: Position, fill_price: float) -> float: - if position.side == PositionSide.LONG: - return max( - min(fill_price, position.take_profit_price), - position.stop_loss_price, - ) - else: - return min( - max(fill_price, position.take_profit_price), - position.stop_loss_price, - ) diff --git a/feed/_factory.py b/feed/_factory.py index e5debf79..e9631325 100644 --- a/feed/_factory.py +++ b/feed/_factory.py @@ -1,6 +1,7 @@ from core.interfaces.abstract_config import AbstractConfig from core.interfaces.abstract_exhange_factory import AbstractExchangeFactory from core.interfaces.abstract_feed_actor_factory import AbstractFeedActorFactory +from core.interfaces.abstract_timeseries import AbstractTimeSeriesService from core.models.exchange import ExchangeType from core.models.feed import FeedType from core.models.symbol import Symbol @@ -15,11 +16,13 @@ def __init__( self, exchange_factory: AbstractExchangeFactory, ws_factory: AbstractExchangeFactory, + ts_service: AbstractTimeSeriesService, config_service: AbstractConfig, ): self.config_service = config_service self.exchange_factory = exchange_factory self.ws_factory = ws_factory + self.ts_service = ts_service def create_actor( self, @@ -33,6 +36,7 @@ def create_actor( symbol, timeframe, self.exchange_factory.create(exchange_type), + self.ts_service, self.config_service, ) if feed_type == FeedType.HISTORICAL @@ -40,6 +44,7 @@ def create_actor( symbol, timeframe, self.ws_factory.create(exchange_type), + self.ts_service, ) ) actor.start() diff --git a/feed/_historical.py b/feed/_historical.py index b5cdc32b..25d1cd39 100644 --- a/feed/_historical.py +++ b/feed/_historical.py @@ -1,10 +1,13 @@ import asyncio +import bisect +from typing import AsyncIterator, List -from core.actors import Actor +from core.actors import StrategyActor from core.commands.feed import StartHistoricalFeed from core.events.ohlcv import NewMarketDataReceived from core.interfaces.abstract_config import AbstractConfig from core.interfaces.abstract_exchange import AbstractExchange +from core.interfaces.abstract_timeseries import AbstractTimeSeriesService from core.models.bar import Bar from core.models.lookback import Lookback from core.models.ohlcv import OHLCV @@ -44,7 +47,6 @@ async def __aenter__(self): async def __aexit__(self, exc_type, exc_value, traceback): self.iterator = None - return self def __aiter__(self): return self @@ -67,11 +69,8 @@ def _next_item_or_end(self): except StopIteration: return self.sentinel - def get_last_bar(self): - return self.last_row - -class HistoricalActor(Actor): +class HistoricalActor(StrategyActor): _EVENTS = [StartHistoricalFeed] def __init__( @@ -79,15 +78,14 @@ def __init__( symbol: Symbol, timeframe: Timeframe, exchange: AbstractExchange, + ts: AbstractTimeSeriesService, config_service: AbstractConfig, ): super().__init__(symbol, timeframe) self.exchange = exchange + self.ts = ts self.config_service = config_service.get("backtest") - self.last_bar = None - - def pre_receive(self, msg: StartHistoricalFeed): - return self._symbol == msg.symbol and self._timeframe == msg.timeframe + self.buffer: List[Bar] = [] async def on_receive(self, msg: StartHistoricalFeed): symbol, timeframe = msg.symbol, msg.timeframe @@ -100,9 +98,56 @@ async def on_receive(self, msg: StartHistoricalFeed): msg.out_sample, self.config_service["batch_size"], ) as stream: - async for bar in stream: - await self.tell( - NewMarketDataReceived(symbol, timeframe, bar.ohlcv, bar.closed) - ) + async for bars in self.batched(stream, self.config_service["buff_size"]): + self._update_buffer(bars) + await self._process_buffer() + + await self._process_remaining_buffer() + + def _update_buffer(self, batch: List[Bar]): + for bar in batch: + bisect.insort(self.buffer, bar, key=lambda x: x.ohlcv.timestamp) - self.last_bar = stream.get_last_bar() + async def _process_buffer(self): + buff_size = self.config_service["buff_size"] + + while len(self.buffer) >= buff_size: + bars = [self.buffer.pop(0) for _ in range(buff_size)] + await self._outbox(bars) + await self._handle_market(bars) + + async def _process_remaining_buffer(self): + buff_size = self.config_service["buff_size"] + + while self.buffer: + bars = [self.buffer.pop(0) for _ in range(min(len(self.buffer), buff_size))] + await self._outbox(bars) + await self._handle_market(bars) + + async def _handle_market(self, bars: List[Bar]) -> None: + for bar in bars: + await self.tell( + NewMarketDataReceived( + self.symbol, self.timeframe, bar.ohlcv, bar.closed + ) + ) + await asyncio.sleep(0.0001) + + async def _outbox(self, bars: List[Bar]) -> None: + ts = [] + for bar in bars: + if bar.closed: + ts.append(self.ts.upsert(self.symbol, self.timeframe, bar.ohlcv)) + + await asyncio.gather(*ts) + + @staticmethod + async def batched(stream: AsyncIterator[Bar], batch_size: int): + batch = [] + async for bar in stream: + batch.append(bar) + if len(batch) >= batch_size: + yield batch + batch = [] + if batch: + yield batch diff --git a/feed/_realtime.py b/feed/_realtime.py index 337c53be..0f7274e0 100644 --- a/feed/_realtime.py +++ b/feed/_realtime.py @@ -1,9 +1,10 @@ import asyncio import logging -from core.actors import Actor +from core.actors import StrategyActor from core.commands.feed import StartRealtimeFeed from core.events.ohlcv import NewMarketDataReceived +from core.interfaces.abstract_timeseries import AbstractTimeSeriesService from core.interfaces.abstract_ws import AbstractWS from core.models.symbol import Symbol from core.models.timeframe import Timeframe @@ -44,7 +45,7 @@ async def __anext__(self): raise -class RealtimeActor(Actor): +class RealtimeActor(StrategyActor): _EVENTS = [StartRealtimeFeed] def __init__( @@ -52,13 +53,12 @@ def __init__( symbol: Symbol, timeframe: Timeframe, ws: AbstractWS, + ts_service: AbstractTimeSeriesService, ): super().__init__(symbol, timeframe) self.ws = ws self.task = None - - def pre_receive(self, msg: StartRealtimeFeed): - return self._symbol == msg.symbol and self._timeframe == msg.timeframe + self.ts_service = ts_service def on_stop(self): if self.task: @@ -73,11 +73,13 @@ async def _run_realtime_feed(self, msg: StartRealtimeFeed): symbol, timeframe = msg.symbol, msg.timeframe async with AsyncRealTimeData(self.ws, symbol, timeframe) as stream: - async for bar in stream: - if bar: + async for bars in stream: + for bar in bars: + if bar.closed: + logger.info(f"{symbol}_{timeframe}:{bar}") + + await self.ts_service.upsert(symbol, timeframe, bar.ohlcv) + await self.tell( NewMarketDataReceived(symbol, timeframe, bar.ohlcv, bar.closed) ) - - if bar and bar.closed: - logger.info(f"Tick: {symbol}_{timeframe}:{bar}") diff --git a/infrastructure/event_dispatcher/event_dedup.py b/infrastructure/event_dispatcher/event_dedup.py new file mode 100644 index 00000000..f1c6f12e --- /dev/null +++ b/infrastructure/event_dispatcher/event_dedup.py @@ -0,0 +1,25 @@ +import asyncio +from typing import Set + +from core.events.base import Event + + +class EventDedup: + def __init__(self): + self._events_in_queue: Set[int] = set() + self._lock = asyncio.Lock() + + async def add_event(self, event: Event) -> bool: + async with self._lock: + key = event.meta.key + + if key in self._events_in_queue: + return False + + self._events_in_queue.add(key) + + return True + + async def remove_event(self, event: Event) -> None: + async with self._lock: + self._events_in_queue.discard(event.meta.key) diff --git a/infrastructure/event_dispatcher/event_dispatcher.py b/infrastructure/event_dispatcher/event_dispatcher.py index be1f103b..515e9629 100644 --- a/infrastructure/event_dispatcher/event_dispatcher.py +++ b/infrastructure/event_dispatcher/event_dispatcher.py @@ -26,8 +26,8 @@ def __init__(self, config_service: AbstractConfig): self.cancel_event = asyncio.Event() self.config = config_service.get("bus") - self._store = EventStore(config_service) + self._store = EventStore(config_service) self._command_worker_pool = None self._query_worker_pool = None self._event_worker_pool = None @@ -62,18 +62,12 @@ def unregister(self, event_class: Type[Event], handler: Callable) -> None: self.event_handler.unregister(event_class, handler) async def execute(self, command: Command, *args, **kwargs) -> None: - await asyncio.gather( - self._dispatch_to_poll(command, self.command_worker_pool, *args, **kwargs), - command.wait_for_execution(), - ) + await self._dispatch_to_poll(command, self.command_worker_pool, *args, **kwargs) + await command.wait_for_execution() async def query(self, query: Query, *args, **kwargs) -> Any: - _, result = await asyncio.gather( - self._dispatch_to_poll(query, self.query_worker_pool, *args, **kwargs), - query.wait_for_response(), - ) - - return result + await self._dispatch_to_poll(query, self.query_worker_pool, *args, **kwargs) + return await query.wait_for_response() async def dispatch(self, event: Event, *args, **kwargs) -> None: await self._dispatch_to_poll(event, self.event_worker_pool, *args, **kwargs) @@ -103,9 +97,8 @@ async def _dispatch_to_poll( ) -> None: if isinstance(event, EventEnded): self.cancel_event.set() - return - - await worker_pool.dispatch_to_worker(event, *args, **kwargs) + else: + await worker_pool.dispatch_to_worker(event, *args, **kwargs) def _create_worker_pool(self) -> WorkerPool: return WorkerPool( diff --git a/infrastructure/event_dispatcher/event_handler.py b/infrastructure/event_dispatcher/event_handler.py index a17e3995..f9ffcb74 100644 --- a/infrastructure/event_dispatcher/event_handler.py +++ b/infrastructure/event_dispatcher/event_handler.py @@ -17,11 +17,11 @@ class EventHandler: def __init__(self): self._event_handlers: Dict[Type[Event], List[HandlerType]] = defaultdict(list) - self._dead_letter_queue: Deque[Tuple[Event, Exception]] = deque(maxlen=100) + self._dlq: Deque[Tuple[Event, Exception]] = deque(maxlen=100) @property def dlq(self): - return self._dead_letter_queue + return self._dlq def register( self, @@ -39,23 +39,17 @@ def unregister(self, event_class: Type[Event], handler: HandlerType) -> None: ] async def handle_event(self, event: Event, *args, **kwargs) -> None: - event_type = type(event) - handlers = self._event_handlers.get(event_type, []) + handlers = self._event_handlers.get(type(event), []) for handler, filter_fn in handlers: if not filter_fn or filter_fn(event): - await self._call_handler(handler, event, *args, **kwargs) + try: + await self._call_handler(handler, event, *args, **kwargs) + except Exception as e: + self._handle_exception(handler, event, e) async def _call_handler( self, handler: HandlerType, event: Event, *args, **kwargs - ) -> None: - try: - await self._execute_handler(handler, event, *args, **kwargs) - except Exception as e: - self._handle_exception(handler, event, e) - - async def _execute_handler( - self, handler: HandlerType, event: Event, *args, **kwargs ) -> None: if asyncio.iscoroutinefunction(handler): response = await handler(event, *args, **kwargs) @@ -79,4 +73,4 @@ def _handle_exception( elif isinstance(event, Query): event.set_response(None) - self._dead_letter_queue.append((event, exception)) + self._dlq.append((event, exception)) diff --git a/infrastructure/event_dispatcher/event_worker.py b/infrastructure/event_dispatcher/event_worker.py index 9920f088..1bada320 100644 --- a/infrastructure/event_dispatcher/event_worker.py +++ b/infrastructure/event_dispatcher/event_worker.py @@ -2,6 +2,7 @@ from typing import Any, AsyncIterable, Dict, Tuple from core.events.base import Event +from infrastructure.event_dispatcher.event_dedup import EventDedup from .event_handler import EventHandler @@ -11,11 +12,11 @@ def __init__( self, event_handler: EventHandler, cancel_event: asyncio.Event, - events_in_queue: set, + dedup: EventDedup, ): self.event_handler = event_handler self.cancel_event = cancel_event - self.events_in_queue = events_in_queue + self.dedup = dedup self.queue = asyncio.Queue() self.tasks = asyncio.create_task(self._process_events()) @@ -32,17 +33,13 @@ async def _get_event_stream( yield event, args, kwargs - self.events_in_queue.remove(event.meta.key) + await self.dedup.remove_event(event) + self.queue.task_done() async def dispatch(self, event: Event, *args, **kwargs) -> None: - event_key = event.meta.key - - if event_key in self.events_in_queue: - return - - self.events_in_queue.add(event_key) - await self.queue.put((event, args, kwargs)) + if await self.dedup.add_event(event): + await self.queue.put((event, args, kwargs)) async def wait(self) -> None: await self.queue.join() diff --git a/infrastructure/event_dispatcher/load_balancer.py b/infrastructure/event_dispatcher/load_balancer.py index fcfd4afa..9129603f 100644 --- a/infrastructure/event_dispatcher/load_balancer.py +++ b/infrastructure/event_dispatcher/load_balancer.py @@ -7,29 +7,50 @@ def softmax(x): class LoadBalancer: - def __init__(self, priority_groups: int, learning_rate: float = 0.001): + def __init__( + self, + priority_groups: int, + initial_kp: float = 0.3, + initial_ki: float = 0.6, + initial_kd: float = 0.1, + learning_rate: float = 0.001, + decay_rate: float = 0.99, + ): self._group_event_counts = np.zeros(priority_groups) - self._initialize_load_balancer(priority_groups) + self._initialize_load_balancer( + priority_groups, initial_kp, initial_ki, initial_kd + ) self._group_event_counts_threshold = 1e4 self._learning_rate = learning_rate - - def _initialize_load_balancer(self, priority_groups: int): - self._kp = np.ones(priority_groups) * 0.3 - self._ki = np.ones(priority_groups) * 0.6 - self._kd = np.ones(priority_groups) * 0.1 + self._decay_rate = decay_rate + + def _initialize_load_balancer( + self, + priority_groups: int, + initial_kp: float, + initial_ki: float, + initial_kd: float, + ): + self._kp = np.ones(priority_groups) * initial_kp + self._ki = np.ones(priority_groups) * initial_ki + self._kd = np.ones(priority_groups) * initial_kd self._integral_errors = np.zeros(priority_groups) self._previous_errors = np.zeros(priority_groups) self._target_ratios = 1 / (np.arange(priority_groups) + 1) def register_event(self, priority_group: int): - if 0 <= priority_group < len(self._group_event_counts): - self._group_event_counts[priority_group] += 1 + if not 0 <= priority_group < len(self._group_event_counts): + raise ValueError(f"Invalid priority group: {priority_group}") + + self._group_event_counts[priority_group] += 1 - if self._group_event_counts.max() > self._group_event_counts_threshold: - self._group_event_counts *= 0.5 - else: - raise ValueError("Invalid priority group!") + if self._group_event_counts.max() > self._group_event_counts_threshold: + self._group_event_counts *= 0.5 + + self._group_event_counts_threshold = max( + self._group_event_counts_threshold * 1.1, 1e4 + ) def determine_priority_group(self, priority: int) -> int: total_group = self._group_event_counts.sum() @@ -38,32 +59,31 @@ def determine_priority_group(self, priority: int) -> int: return np.clip(priority - 1, 0, len(self._group_event_counts) - 1) processed_ratios = self._group_event_counts / total_group - errors = self._target_ratios - processed_ratios - - for i, error in enumerate(errors): - self._integral_errors[i] += error - self._update_pid_parameters(i, error) - - derivative_errors = errors - self._previous_errors - - self._previous_errors = errors.copy() + self._update_pid(errors) control_outputs = ( self._kp * errors + self._ki * self._integral_errors - + self._kd * derivative_errors + + self._kd * (errors - self._previous_errors) ) - weights = softmax(control_outputs) - - return np.random.choice(np.arange(len(control_outputs)), p=weights) + self._previous_errors = errors.copy() + self._learning_rate *= self._decay_rate - def _update_pid_parameters(self, priority_group: int, error: float): - self._kp[priority_group] += self._learning_rate * error - self._ki[priority_group] += ( - self._learning_rate * self._integral_errors[priority_group] - ) - self._kd[priority_group] += self._learning_rate * ( - error - self._previous_errors[priority_group] + return np.random.choice( + np.arange(len(control_outputs)), p=softmax(control_outputs) ) + + def _update_pid(self, errors: np.ndarray): + for i, error in enumerate(errors): + self._integral_errors[i] += error + self._kp[i] = np.clip(self._kp[i] + self._learning_rate * error, 0, 1) + self._ki[i] = np.clip( + self._ki[i] + self._learning_rate * self._integral_errors[i], 0, 1 + ) + self._kd[i] = np.clip( + self._kd[i] + self._learning_rate * (error - self._previous_errors[i]), + 0, + 1, + ) diff --git a/infrastructure/event_dispatcher/worker_pool.py b/infrastructure/event_dispatcher/worker_pool.py index 312a2ac5..8f8f7b9e 100644 --- a/infrastructure/event_dispatcher/worker_pool.py +++ b/infrastructure/event_dispatcher/worker_pool.py @@ -2,6 +2,7 @@ from core.events.base import Event +from .event_dedup import EventDedup from .event_handler import EventHandler from .event_worker import EventWorker from .load_balancer import LoadBalancer @@ -11,26 +12,30 @@ class WorkerPool: def __init__( self, num_workers: int, - num_priority_groups: int, + num_piority_groups: int, event_handler: EventHandler, cancel_event: asyncio.Event, ): - self.events_in_queue = set() + self.workers = [] + self.load_balancer = LoadBalancer(num_piority_groups) + self.dedup = EventDedup() + self.event_handler = event_handler + self.cancel_event = cancel_event + self._initialize_workers(num_workers) + + def _initialize_workers(self, num_workers): self.workers = [ - EventWorker(event_handler, cancel_event, self.events_in_queue) + EventWorker(self.event_handler, self.cancel_event, self.dedup) for _ in range(num_workers) ] - self.load_balancer = LoadBalancer(num_priority_groups) async def dispatch_to_worker(self, event: Event, *args, **kwargs) -> None: priority_group = self.load_balancer.determine_priority_group( event.meta.priority ) - worker = self.workers[priority_group % len(self.workers)] await worker.dispatch(event, *args, **kwargs) - self.load_balancer.register_event(priority_group) async def wait(self) -> None: diff --git a/infrastructure/event_store/event_encoder.py b/infrastructure/event_store/event_encoder.py index f510e123..9b7385dc 100644 --- a/infrastructure/event_store/event_encoder.py +++ b/infrastructure/event_store/event_encoder.py @@ -1,14 +1,12 @@ +import asyncio import json +from abc import ABC from enum import Enum from typing import Any import numpy as np from core.events.base import Event -from core.interfaces.abstract_position_risk_strategy import AbstractPositionRiskStrategy -from core.interfaces.abstract_position_take_profit_strategy import ( - AbstractPositionTakeProfitStrategy, -) from core.models.indicator import Indicator @@ -16,9 +14,7 @@ class Encoder(json.JSONEncoder): def default(self, obj): if isinstance(obj, Enum): return obj.value - if isinstance(obj, AbstractPositionRiskStrategy): - return obj.__class__.__name__ - if isinstance(obj, AbstractPositionTakeProfitStrategy): + if isinstance(obj, ABC): return obj.__class__.__name__ if isinstance(obj, np.ndarray): return obj.tolist() @@ -28,5 +24,7 @@ def default(self, obj): return obj.to_dict() if isinstance(obj, type(Any)): return "Any" + if isinstance(obj, asyncio.Future): + return None return str(obj) diff --git a/market/__init__.py b/market/__init__.py new file mode 100644 index 00000000..28eaa73e --- /dev/null +++ b/market/__init__.py @@ -0,0 +1,3 @@ +from ._actor import MarketActor + +__all__ = [MarketActor] diff --git a/market/_actor.py b/market/_actor.py new file mode 100644 index 00000000..1b3960e2 --- /dev/null +++ b/market/_actor.py @@ -0,0 +1,43 @@ +from typing import Union + +from core.actors import BaseActor +from core.interfaces.abstract_timeseries import AbstractTimeSeriesService +from core.mixins import EventHandlerMixin +from core.models.ohlcv import OHLCV +from core.models.ta import TechAnalysis +from core.queries.ohlcv import TA, BackNBars, NextBar, PrevBar + +MarketEvent = Union[NextBar, PrevBar, TA, BackNBars] + + +class MarketActor(BaseActor, EventHandlerMixin): + _EVENTS = [NextBar, PrevBar, TA, BackNBars] + + def __init__(self, ts: AbstractTimeSeriesService): + super().__init__() + EventHandlerMixin.__init__(self) + self._register_event_handlers() + self.ts = ts + + async def on_receive(self, event: MarketEvent): + return await self.handle_event(event) + + def _register_event_handlers(self): + self.register_handler(NextBar, self._handle_next_bar) + self.register_handler(PrevBar, self._handle_prev_bar) + self.register_handler(BackNBars, self._handle_back_n_bars) + self.register_handler(TA, self._handle_ta) + + async def _handle_next_bar(self, event: NextBar) -> OHLCV: + return await self.ts.next_bar(event.symbol, event.timeframe, event.ohlcv) + + async def _handle_prev_bar(self, event: PrevBar) -> OHLCV: + return await self.ts.prev_bar(event.symbol, event.timeframe, event.ohlcv) + + async def _handle_back_n_bars(self, event: BackNBars) -> OHLCV: + return await self.ts.back_n_bars( + event.symbol, event.timeframe, event.ohlcv, event.n + ) + + async def _handle_ta(self, event: TA) -> TechAnalysis: + return await self.ts.ta(event.symbol, event.timeframe, event.ohlcv) diff --git a/optimization/_genetic.py b/optimization/_genetic.py index 38e3d5e1..db7648e1 100644 --- a/optimization/_genetic.py +++ b/optimization/_genetic.py @@ -153,10 +153,9 @@ def _crossover(self, parent1, parent2): GeneticAttributes.EXIT, ]: child1_strategy = Strategy( - parent1.strategy.type, - parent2.strategy.entry + parent2.strategy.signal if chosen_attr == GeneticAttributes.SIGNAL - else parent1.strategy.entry, + else parent1.strategy.signal, parent1.strategy.confirm if chosen_attr == GeneticAttributes.CONFIRM else parent2.strategy.confirm, @@ -174,10 +173,9 @@ def _crossover(self, parent1, parent2): else parent2.strategy.exit, ) child2_strategy = Strategy( - parent2.strategy.type, - parent1.strategy.entry + parent1.strategy.signal if chosen_attr == GeneticAttributes.SIGNAL - else parent2.strategy.entry, + else parent2.strategy.signal, parent2.strategy.confirm if chosen_attr == GeneticAttributes.CONFIRM else parent1.strategy.confirm, diff --git a/portfolio/_portfolio.py b/portfolio/_portfolio.py index 80fa4149..cfcd9605 100644 --- a/portfolio/_portfolio.py +++ b/portfolio/_portfolio.py @@ -1,4 +1,5 @@ import asyncio +from contextlib import asynccontextmanager from typing import Dict, Tuple from core.models.portfolio import Performance @@ -10,84 +11,56 @@ class PortfolioStorage: def __init__(self): - self.data: Dict[Tuple[Symbol, Timeframe, Strategy], Performance] = {} + self._data: Dict[Tuple[Symbol, Timeframe, Strategy], Performance] = {} self._lock = asyncio.Lock() async def next(self, position: Position, account_size: int, risk_per_trade: float): - async with self._lock: - key = self._get_key( - position.signal.symbol, - position.signal.timeframe, - position.signal.strategy, - ) - - performance = self.data.get(key, None) - - if not performance: - performance = Performance(account_size, risk_per_trade) - - if position.pnl != 0: - performance = performance.next(position.pnl, position.fee) - - self.data[key] = performance - + key = self._get_key( + position.signal.symbol, position.signal.timeframe, position.signal.strategy + ) + async with self._state() as state: + performance = state[key] or Performance(account_size, risk_per_trade) + performance = performance.next(position.pnl, position.fee) + state[key] = performance return performance async def get(self, symbol: Symbol, timeframe: Timeframe, strategy: Strategy): - async with self._lock: - key = self._get_key(symbol, timeframe, strategy) - return self.data.get(key, None) + key = self._get_key(symbol, timeframe, strategy) + async with self._state() as state: + return state.get(key) async def reset( self, symbol, timeframe, strategy, account_size: int, risk_per_trade: float ): - async with self._lock: - key = self._get_key(symbol, timeframe, strategy) - self.data[key] = Performance(account_size, risk_per_trade) + key = self._get_key(symbol, timeframe, strategy) + async with self._state() as state: + state[key] = Performance(account_size, risk_per_trade) async def reset_all(self): - async with self._lock: - self.data = {} + async with self._state() as state: + state.clear() async def get_equity( self, symbol: Symbol, timeframe: Timeframe, strategy: Strategy ): - async with self._lock: - key = self._get_key(symbol, timeframe, strategy) - performance = self.data.get(key) - - if performance and len(performance.equity) > 2: - return performance.equity[-1] - - return 0 + performance = await self.get(symbol, timeframe, strategy) + return performance.equity[-1] if performance and len(performance.equity) else 1 async def get_kelly(self, symbol: Symbol, timeframe: Timeframe, strategy: Strategy): - async with self._lock: - key = self._get_key(symbol, timeframe, strategy) - performance = self.data.get(key) - - return performance.kelly if performance else 0 - - async def get_optimalf( - self, symbol: Symbol, timeframe: Timeframe, strategy: Strategy - ): - async with self._lock: - key = self._get_key(symbol, timeframe, strategy) - performance = self.data.get(key) - - return performance.optimal_f if performance else 0 + performance = await self.get(symbol, timeframe, strategy) + return performance.kelly if performance else 0 async def get_fitness( self, symbol: Symbol, timeframe: Timeframe, strategy: Strategy ): - async with self._lock: - key = self._get_key(symbol, timeframe, strategy) - performance = self.data.get(key) + performance = await self.get(symbol, timeframe, strategy) + return performance.deflated_sharpe_ratio if performance else 0 - if not performance: - return 0 - - return performance.deflated_sharpe_ratio + @asynccontextmanager + async def _state(self): + async with self._lock: + yield self._data - def _get_key(self, symbol, timeframe, strategy): + @staticmethod + def _get_key(symbol, timeframe, strategy): return (symbol, timeframe, strategy) diff --git a/portfolio/_service.py b/portfolio/_service.py index 4e9a9283..3aaaae7a 100644 --- a/portfolio/_service.py +++ b/portfolio/_service.py @@ -44,7 +44,7 @@ async def handle_backtest_started(self, event: BacktestStarted): ) @event_handler(TradeStarted) - async def trade_started(self, event: TradeStarted): + async def handle_trade_started(self, event: TradeStarted): await asyncio.gather( *[ self.state.reset( @@ -61,6 +61,11 @@ async def trade_started(self, event: TradeStarted): @event_handler(PositionClosed) async def handle_close_positon(self, event: PositionClosed): position = event.position + + if not position.is_valid: + logger.warn(f"Wrong position: {position}") + return + signal = position.signal symbol = signal.symbol timeframe = signal.timeframe @@ -75,11 +80,13 @@ async def handle_close_positon(self, event: PositionClosed): logger.info( f"Performance: strategy={symbol}_{timeframe}{strategy}, side={position.side}, " - + f"trades={performance.total_trades}, hit_ratio={round(performance.hit_ratio * 100)}%, " - + f"cagr={round(performance.cagr * 100, 2)}%, return={round(performance.expected_return * 100, 2)}%, volatility={round(performance.ann_volatility * 100, 2)}%, " - + f"smart_sharpe={round(performance.smart_sharpe_ratio, 4)}, smart_sortino={round(performance.smart_sortino_ratio, 4)}, " - + f"skew={round(performance.skew, 2)}, kurtosis={round(performance.kurtosis, 2)}, omega={round(performance.omega_ratio, 2)}, upi={round(performance.upi, 2)}, " - + f"max_dd={round(performance.max_drawdown * 100, 2)}%, pnl={round(performance.total_pnl, 4)}, fee={round(performance.total_fee, 4)}" + f"trades={performance.total_trades}, hit_ratio={performance.hit_ratio * 100:.0f}%, " + f"cagr={performance.cagr * 100:.2f}%, return={performance.expected_return * 100:.2f}%, " + f"volatility={performance.ann_volatility * 100:.2f}%, smart_sharpe={performance.smart_sharpe_ratio:.4f}, " + f"smart_sortino={performance.smart_sortino_ratio:.4f}, skew={performance.skew:.2f}, " + f"kurtosis={performance.kurtosis:.2f}, omega={performance.omega_ratio:.2f}, " + f"upi={performance.upi:.2f}, mdd={performance.max_drawdown * 100:.4f}%, " + f"pnl={performance.total_pnl:.4f}, fee={performance.total_fee:.4f}" ) await self.dispatch( @@ -139,19 +146,16 @@ async def position_risk(self, query: GetPositionRisk): risk_per_trade = self.config["risk_per_trade"] equity = await self.state.get_equity(symbol, timeframe, strategy) + fixed_size = equity * risk_per_trade - if equity == 0: + if query.type == PositionSizeType.Fixed: return risk_per_trade if query.type == PositionSizeType.Kelly: kelly = await self.state.get_kelly(symbol, timeframe, strategy) - return equity * kelly if kelly > 0 else equity * risk_per_trade - - if query.type == PositionSizeType.Optimalf: - optimalf = await self.state.get_optimalf(symbol, timeframe, strategy) - return optimalf * equity if optimalf > 0 else equity * risk_per_trade + return equity * kelly if kelly > 0 else fixed_size - return equity * risk_per_trade + return fixed_size @query_handler(GetFitness) async def fitness(self, query: GetFitness): diff --git a/portfolio/_strategy.py b/portfolio/_strategy.py index d6c5f682..64928beb 100644 --- a/portfolio/_strategy.py +++ b/portfolio/_strategy.py @@ -1,9 +1,11 @@ import asyncio -from typing import Dict, Tuple +from contextlib import asynccontextmanager +from typing import Dict, List, Tuple import numpy as np from sklearn.cluster import KMeans from sklearn.impute import KNNImputer +from sklearn.metrics import silhouette_score from sklearn.preprocessing import MinMaxScaler from core.models.strategy import Strategy @@ -12,12 +14,12 @@ class StrategyStorage: - def __init__(self, n_clusters=3): - self.kmeans = KMeans(n_clusters=n_clusters, n_init="auto") + def __init__(self, n_neighbors=3, max_data_size=1000): self.scaler = MinMaxScaler() - self.imputer = KNNImputer(n_neighbors=2) - self.data: Dict[Tuple[Symbol, Timeframe, Strategy], Tuple[np.array, int]] = {} - self.lock = asyncio.Lock() + self.imputer = KNNImputer(n_neighbors=n_neighbors) + self._data: Dict[Tuple[Symbol, Timeframe, Strategy], Tuple[np.array, int]] = {} + self._lock = asyncio.Lock() + self.max_data_size = max_data_size async def next( self, @@ -26,41 +28,77 @@ async def next( strategy: Strategy, metrics: np.array, ): - async with self.lock: + async with self._state() as state: key = (symbol, timeframe, strategy) - self.data[key] = (metrics, -1) + state[key] = (metrics, -1) - if len(self.data.keys()) >= self.kmeans.n_clusters: - self._update_clusters() + if len(state) > self.max_data_size: + oldest_key = next(iter(state)) + state.pop(oldest_key) + + if len(state) >= 2: + self._update_clusters(state) async def reset(self, symbol: Symbol, timeframe: Timeframe, strategy: Strategy): - async with self.lock: - self.data.pop((symbol, timeframe, strategy), None) + async with self._state() as state: + state.pop((symbol, timeframe, strategy), None) async def reset_all(self): - async with self.lock: - self.data = {} + async with self._state() as state: + state.clear() - async def get_top(self, num: int = 10): - async with self.lock: + async def get_top(self, num: int = 10) -> List[Tuple[Symbol, Timeframe, Strategy]]: + async with self._state() as state: sorted_strategies = sorted( - self.data.keys(), key=self._sorting_key, reverse=True + state.keys(), key=self._sorting_key, reverse=True ) + return sorted_strategies[:num] - def _update_clusters(self): - data_matrix = np.array([item[0] for item in self.data.values()]) + def _update_clusters(self, state): + if len(state) < 3: + return + + data_keys = list(state.keys()) + data_matrix = np.array([state[key][0] for key in data_keys]) + imputed_data = self.imputer.fit_transform(data_matrix) normalized_data = self.scaler.fit_transform(imputed_data) - cluster_indices = self.kmeans.fit_predict(normalized_data) + optimal_clusters = self._determine_optimal_clusters(normalized_data) + kmeans = KMeans(n_clusters=optimal_clusters, n_init="auto", random_state=1337) + cluster_indices = kmeans.fit_predict(normalized_data) - for (symbol, timeframe, strategy), idx in zip( - self.data.keys(), cluster_indices - ): - self.data[(symbol, timeframe, strategy)] = ( - self.data[(symbol, timeframe, strategy)][0], + for key, idx in zip(data_keys, cluster_indices): + state[key] = ( + state[key][0], idx, ) + def _determine_optimal_clusters(self, data: np.array) -> int: + max_clusters = min(len(data) - 1, 10) + min_clusters = min(2, max_clusters) + best_score = float("-inf") + optimal_clusters = min_clusters + + for k in range(min_clusters, max_clusters + 1): + kmeans = KMeans(n_clusters=k, n_init="auto", random_state=1337) + cluster_labels = kmeans.fit_predict(data) + + if len(np.unique(cluster_labels)) < k: + continue + + score = silhouette_score(data, cluster_labels) + + if score > best_score: + best_score = score + optimal_clusters = k + + return optimal_clusters + + @asynccontextmanager + async def _state(self): + async with self._lock: + yield self._data + def _sorting_key(self, key): - return self.data[key][1], self.data[key][0][0] + return self._data[key][1], self._data[key][0][0] diff --git a/position/_actor.py b/position/_actor.py index 4f5f8ea3..c42be477 100644 --- a/position/_actor.py +++ b/position/_actor.py @@ -3,13 +3,12 @@ import time from typing import Union -from core.actors import Actor +from core.actors import StrategyActor from core.events.backtest import BacktestEnded +from core.events.base import EventMeta from core.events.position import ( - BrokerPositionAdjusted, BrokerPositionClosed, BrokerPositionOpened, - PositionAdjusted, PositionClosed, PositionCloseRequested, PositionInitialized, @@ -22,17 +21,18 @@ ) from core.interfaces.abstract_config import AbstractConfig from core.interfaces.abstract_position_factory import AbstractPositionFactory +from core.models.risk_type import SignalRiskType from core.models.side import PositionSide from core.models.symbol import Symbol from core.models.timeframe import Timeframe +from core.queries.copilot import EvaluateSignal +from core.queries.ohlcv import TA, BackNBars from ._sm import LONG_TRANSITIONS, SHORT_TRANSITIONS, PositionStateMachine from ._state import PositionStorage SignalEvent = Union[GoLongSignalReceived, GoShortSignalReceived] -BrokerPositionEvent = Union[ - BrokerPositionOpened, BrokerPositionAdjusted, BrokerPositionClosed -] +BrokerPositionEvent = Union[BrokerPositionOpened, BrokerPositionClosed] ExitSignal = RiskThresholdBreached BacktestSignal = BacktestEnded @@ -41,14 +41,14 @@ logger = logging.getLogger(__name__) TIME_BUFF = 3 +N_BACK_BARS = 4 -class PositionActor(Actor): +class PositionActor(StrategyActor): _EVENTS = [ GoLongSignalReceived, GoShortSignalReceived, BrokerPositionOpened, - BrokerPositionAdjusted, BrokerPositionClosed, RiskThresholdBreached, BacktestEnded, @@ -69,10 +69,6 @@ def __init__( self.state = PositionStorage() self.config = config_service.get("position") - def pre_receive(self, event: PositionEvent) -> bool: - symbol, timeframe = self._get_event_key(event) - return self._symbol == symbol and self._timeframe == timeframe - async def on_receive(self, event): symbol, _ = self._get_event_key(event) @@ -90,13 +86,28 @@ async def on_receive(self, event): ) async def handle_signal_received(self, event: SignalEvent) -> bool: - if int(event.meta.timestamp) < int(time.time()) - TIME_BUFF: + if self._is_stale_signal(event.meta): logger.warn(f"Stale Signal: {event}, {time.time()}") return False async def create_and_store_position(event: SignalEvent): - position = await self.position_factory.create_position( - event.signal, event.ohlcv, event.entry_price, event.stop_loss + symbol, timeframe, ohlcv = ( + event.signal.symbol, + event.signal.timeframe, + event.signal.ohlcv, + ) + + back_bars = await self.ask(BackNBars(symbol, timeframe, ohlcv, N_BACK_BARS)) + ta = await self.ask(TA(symbol, timeframe, ohlcv)) + signal_risk_level = await self.ask( + EvaluateSignal(event.signal, back_bars, ta) + ) + + if signal_risk_level.type in {SignalRiskType.VERY_HIGH}: + return False + + position = await self.position_factory.create( + event.signal, signal_risk_level, ta ) await self.state.store_position(position) @@ -137,27 +148,6 @@ async def handle_position_opened(self, event: BrokerPositionOpened) -> bool: return False - async def handle_position_adjusted(self, event: BrokerPositionAdjusted) -> bool: - symbol, timeframe = self._get_event_key(event) - long_position, short_position = await self.state.retrieve_position( - symbol, timeframe - ) - - if ( - event.position.side == PositionSide.LONG - and long_position - and long_position.last_modified < event.meta.timestamp - ) or ( - event.position.side == PositionSide.SHORT - and short_position - and short_position.last_modified < event.meta.timestamp - ): - next_position = await self.state.update_stored_position(event.position) - await self.tell(PositionAdjusted(next_position)) - return True - - return False - async def handle_position_closed(self, event: BrokerPositionClosed) -> bool: symbol, timeframe = self._get_event_key(event) long_position, short_position = await self.state.retrieve_position( @@ -174,6 +164,9 @@ async def handle_position_closed(self, event: BrokerPositionClosed) -> bool: return False async def handle_exit_received(self, event: ExitSignal) -> bool: + if not event.position.has_risk: + logger.warn(f"Attempt to close not risky position: {event.position}") + symbol, timeframe = self._get_event_key(event) long_position, short_position = await self.state.retrieve_position( symbol, timeframe @@ -189,7 +182,7 @@ async def handle_exit_received(self, event: ExitSignal) -> bool: and short_position.last_modified < event.meta.timestamp ): next_position = await self.state.update_stored_position(event.position) - await self.tell(PositionCloseRequested(next_position, event.exit_price)) + await self.tell(PositionCloseRequested(next_position)) return True return False @@ -201,17 +194,17 @@ async def handle_backtest(self, event: BacktestSignal) -> bool: ) if long_position: - await self.tell( - PositionCloseRequested(long_position, long_position.entry_price) - ) + await self.tell(PositionCloseRequested(long_position)) if short_position: - await self.tell( - PositionCloseRequested(short_position, short_position.entry_price) - ) + await self.tell(PositionCloseRequested(short_position)) return True + @staticmethod + def _is_stale_signal(meta: EventMeta) -> bool: + return int(meta.timestamp) < int(time.time()) - TIME_BUFF + @staticmethod def _get_event_key(event: PositionEvent): signal = ( diff --git a/position/_position_factory.py b/position/_position_factory.py index 8348dc09..3903948f 100644 --- a/position/_position_factory.py +++ b/position/_position_factory.py @@ -1,60 +1,109 @@ +import numpy as np +from sklearn.linear_model import SGDRegressor +from sklearn.preprocessing import StandardScaler + +from core.interfaces.abstract_config import AbstractConfig +from core.interfaces.abstract_order_size_strategy import AbstractOrderSizeStrategy from core.interfaces.abstract_position_factory import AbstractPositionFactory -from core.interfaces.abstract_position_risk_strategy import AbstractPositionRiskStrategy -from core.interfaces.abstract_position_size_strategy import AbstractPositionSizeStrategy -from core.interfaces.abstract_position_take_profit_strategy import ( - AbstractPositionTakeProfitStrategy, -) from core.models.ohlcv import OHLCV -from core.models.order import Order, OrderStatus from core.models.position import Position -from core.models.side import PositionSide -from core.models.signal import Signal, SignalSide +from core.models.position_risk import PositionRisk +from core.models.profit_target import ProfitTarget +from core.models.signal import Signal +from core.models.signal_risk import SignalRisk +from core.models.ta import TechAnalysis class PositionFactory(AbstractPositionFactory): def __init__( self, - position_size_strategy: AbstractPositionSizeStrategy, - risk_strategy: AbstractPositionRiskStrategy, - take_profit_strategy: AbstractPositionTakeProfitStrategy, + config_service: AbstractConfig, + size_strategy: AbstractOrderSizeStrategy, ): super().__init__() - self.position_size_strategy = position_size_strategy - self.risk_strategy = risk_strategy - self.take_profit_strategy = take_profit_strategy + self.size_strategy = size_strategy + self.config = config_service.get("position") - async def create_position( + async def create( self, signal: Signal, - ohlcv: OHLCV, - entry_price: float, - stop_loss_price: float, + signal_risk: SignalRisk, + ta: TechAnalysis, ) -> Position: - symbol = signal.symbol - entry_price = round(entry_price, symbol.price_precision) + size = await self.size_strategy.calculate(signal) + + model, scaler = self._create_model(ta, signal.ohlcv) + + position_risk = PositionRisk(model=model, scaler=scaler).next(signal.ohlcv) + profit_target = ProfitTarget( + signal.side, signal.ohlcv.close, ta.volatility.yz[-1] + ) - position_side = ( - PositionSide.LONG if signal.side == SignalSide.BUY else PositionSide.SHORT + return Position( + signal=signal, + signal_risk=signal_risk, + position_risk=position_risk, + initial_size=size, + profit_target=profit_target, + expiration=self.config["trade_duration"] * 1000, ) - order_size = await self.position_size_strategy.calculate( - signal, entry_price, stop_loss_price + @staticmethod + def _create_model(ta: TechAnalysis, ohlcv: OHLCV): + model = SGDRegressor( + max_iter=1984, + tol=None, + warm_start=True, + alpha=0.001, + penalty="elasticnet", + l1_ratio=0.69, ) - adjusted_order_size = max(order_size, symbol.min_position_size) - rounded_order_size = round(adjusted_order_size, symbol.position_precision) + scaler = StandardScaler() - order = Order( - status=OrderStatus.PENDING, price=entry_price, size=rounded_order_size + hlcc4 = np.array( + ta.trend.hlcc4 + [(ohlcv.high + ohlcv.low + 2 * ohlcv.close) / 4.0] ) - position = Position( - signal, - position_side, - self.risk_strategy, - self.take_profit_strategy, - open_timestamp=ohlcv.timestamp, - stop_loss_price=stop_loss_price, + hlcc4_lagged_1 = np.roll(hlcc4, 1) + hlcc4_lagged_1[0] = hlcc4[0] + + hlcc4_lagged_2 = np.roll(hlcc4, 2) + hlcc4_lagged_2[:2] = hlcc4[:2] + + close = np.array(ta.trend.close + [ohlcv.close]) + + current_tr = max( + ohlcv.high - ohlcv.low, + abs(ohlcv.high - ta.trend.close[-1]), + abs(ohlcv.low - ta.trend.close[-1]), ) - return position.add_order(order) + true_range = np.array(ta.volatility.tr + [current_tr]) + + true_range_lagged_1 = np.roll(true_range, 1) + true_range_lagged_1[0] = true_range[0] + + true_range_lagged_2 = np.roll(true_range, 2) + true_range_lagged_2[:2] = true_range[:2] + + features = np.column_stack( + ( + hlcc4[:-2], + hlcc4_lagged_1[:-2], + hlcc4_lagged_2[:-2], + true_range[:-2], + true_range_lagged_1[:-2], + true_range_lagged_2[:-2], + hlcc4[2:] - hlcc4_lagged_1[2:], + true_range[2:] - true_range_lagged_1[2:], + ) + ) + + target = close[2:] + + features_scaled = scaler.fit_transform(features) + + model.fit(features_scaled, target) + + return model, scaler diff --git a/position/_sm.py b/position/_sm.py index f8ed230b..e5d793ce 100644 --- a/position/_sm.py +++ b/position/_sm.py @@ -118,20 +118,20 @@ def __init__( self._state: Dict[str, PositionState] = {} self._position_manager = position_manager self._transitions = transitions - self._state_lock = asyncio.Lock() + self._lock = asyncio.Lock() async def _get_state(self, symbol: Symbol) -> PositionState: - async with self._state_lock: + async with self._lock: return self._state.get(symbol, PositionState.IDLE) async def _set_state(self, symbol: Symbol, state: PositionState) -> None: - async with self._state_lock: + async with self._lock: self._state[symbol] = state async def process_event(self, symbol: Symbol, event: PositionEvent): current_state = await self._get_state(symbol) - if not self._is_valid_state(current_state, event): + if not self._is_valid_state(self._transitions, current_state, event): return next_state, handler_name = self._transitions[(current_state, type(event))] @@ -147,5 +147,8 @@ async def process_event(self, symbol: Symbol, event: PositionEvent): f"SM: symbol={symbol}, event={event}, curr_state={current_state}, next_state={next_state}" ) - def _is_valid_state(self, state: PositionState, event: PositionEvent) -> bool: - return (state, type(event)) in self._transitions + @staticmethod + def _is_valid_state( + transitions: Transitions, state: PositionState, event: PositionEvent + ) -> bool: + return (state, type(event)) in transitions diff --git a/position/risk/break_even.py b/position/risk/break_even.py deleted file mode 100644 index bd9c1ce0..00000000 --- a/position/risk/break_even.py +++ /dev/null @@ -1,103 +0,0 @@ -from typing import List, Tuple - -import numpy as np - -from core.interfaces.abstract_config import AbstractConfig -from core.interfaces.abstract_position_risk_strategy import AbstractPositionRiskStrategy -from core.models.ohlcv import OHLCV -from core.models.side import PositionSide - - -class PositionRiskBreakEvenStrategy(AbstractPositionRiskStrategy): - def __init__(self, config_service: AbstractConfig): - super().__init__() - self.config = config_service.get("position") - - def next( - self, - side: PositionSide, - entry_price: float, - take_profit_price: float, - stop_loss_price: float, - ohlcvs: List[Tuple[OHLCV]], - ) -> float: - ohlcvs = ohlcvs[:] - lookback = 14 - factor = 2.0 - - if len(ohlcvs) < lookback: - return stop_loss_price, take_profit_price - - atr = self._atr(ohlcvs, lookback) - price = self._price(ohlcvs) - - risk_value = atr[-1] * self.config["risk_factor"] - tp_threshold = atr[-1] * self.config["tp_factor"] - sl_threshold = atr[-1] * self.config["sl_factor"] - - dist = abs(entry_price - take_profit_price) * self.config["trl_factor"] - curr_dist = abs(entry_price - price) - - high = min(ohlcvs[-lookback:], key=lambda x: abs(x.high - price)).high - low = min(ohlcvs[-lookback:], key=lambda x: abs(x.low - price)).low - - upper_bb, lower_bb, middle_bb = self._bb(ohlcvs, lookback, factor) - bbw = (upper_bb - lower_bb) / middle_bb - - squeeze = bbw <= np.min(bbw[-lookback:]) - - next_stop_loss = stop_loss_price - - if side == PositionSide.LONG: - if squeeze[-1]: - next_take_profit = max(entry_price + risk_value, upper_bb[-1]) - else: - next_take_profit = max(entry_price + risk_value, high + tp_threshold) - - if curr_dist > dist and price > entry_price: - next_stop_loss = max(entry_price - risk_value, low - sl_threshold) - - elif side == PositionSide.SHORT: - if squeeze[-1]: - next_take_profit = min(entry_price - risk_value, lower_bb[-1]) - else: - next_take_profit = min(entry_price - risk_value, low - tp_threshold) - - if curr_dist > dist and price < entry_price: - next_stop_loss = min(entry_price + risk_value, high + sl_threshold) - - return next_stop_loss, next_take_profit - - @staticmethod - def _atr(ohlcvs: List[OHLCV], period: int) -> List[float]: - highs, lows, closes = ( - np.array([ohlcv.high for ohlcv in ohlcvs]), - np.array([ohlcv.low for ohlcv in ohlcvs]), - np.array([ohlcv.close for ohlcv in ohlcvs]), - ) - - prev_closes = np.roll(closes, 1) - - true_ranges = np.maximum( - highs - lows, np.abs(highs - prev_closes), np.abs(lows - prev_closes) - ) - - atr = np.zeros_like(true_ranges, dtype=float) - atr[period - 1] = np.mean(true_ranges[:period]) - - for i in range(period, len(true_ranges)): - atr[i] = np.divide((atr[i - 1] * (period - 1) + true_ranges[i]), period) - - return atr - - @staticmethod - def _price(ohlcvs: List[OHLCV]) -> float: - return (ohlcvs[-1].high + ohlcvs[-1].low + ohlcvs[-1].close) / 3.0 - - @staticmethod - def _bb(ohlcvs: List[OHLCV], period: int, factor: float) -> Tuple[List[float], List[float], List[float]]: - closes = np.array([ohlcv.close for ohlcv in ohlcvs]) - rolling_mean = np.convolve(closes, np.ones(period) / period, mode='valid') - rolling_std = factor * np.std([closes[i:i+period] for i in range(len(closes) - period + 1)], axis=1) - - return rolling_mean + rolling_std, rolling_mean - rolling_std, rolling_mean \ No newline at end of file diff --git a/position/risk/simple.py b/position/risk/simple.py deleted file mode 100644 index 902adb3e..00000000 --- a/position/risk/simple.py +++ /dev/null @@ -1,20 +0,0 @@ -from typing import List - -from core.interfaces.abstract_position_risk_strategy import AbstractPositionRiskStrategy -from core.models.ohlcv import OHLCV -from core.models.position import PositionSide - - -class PositionRiskSimpleStrategy(AbstractPositionRiskStrategy): - def __init__(self): - super().__init__() - - def next( - self, - _side: PositionSide, - _entry_price: float, - take_profit_price: float, - stop_loss_price: float, - _ohlcvs: List[OHLCV], - ) -> float: - return stop_loss_price, take_profit_price diff --git a/position/take_profit/__init__.py b/position/size/base.py similarity index 100% rename from position/take_profit/__init__.py rename to position/size/base.py diff --git a/position/size/fixed.py b/position/size/fixed.py index 0fe917b3..f35c2839 100644 --- a/position/size/fixed.py +++ b/position/size/fixed.py @@ -1,31 +1,27 @@ -from typing import Optional - -from core.interfaces.abstract_position_size_strategy import AbstractPositionSizeStrategy +from core.interfaces.abstract_order_size_strategy import AbstractOrderSizeStrategy from core.models.signal import Signal from core.models.size import PositionSizeType from core.queries.portfolio import GetPositionRisk -class PositionFixedSizeStrategy(AbstractPositionSizeStrategy): +class PositionFixedSizeStrategy(AbstractOrderSizeStrategy): def __init__(self): super().__init__() async def calculate( self, signal: Signal, - entry_price: float, - stop_loss_price: Optional[float] = None, ) -> float: risk_amount = await self.query(GetPositionRisk(signal, PositionSizeType.Fixed)) - if stop_loss_price is not None and entry_price is not None: - price_difference = abs(entry_price - stop_loss_price) + if signal.stop_loss is not None and signal.entry is not None: + price_difference = abs(signal.entry - signal.stop_loss) else: raise ValueError("Both entry_price and stop_loss_price must be provided.") if price_difference == 0: raise ValueError( - f"Price difference cannot be zero. For entry price {entry_price} and for stoploss {stop_loss_price}" + f"Price difference cannot be zero. For entry price {signal.entry} and for stoploss {signal.stop_loss}" ) position_size = risk_amount / price_difference diff --git a/position/size/kelly.py b/position/size/kelly.py index 41b56b50..f8260db6 100644 --- a/position/size/kelly.py +++ b/position/size/kelly.py @@ -1,12 +1,10 @@ -from typing import Optional - -from core.interfaces.abstract_position_size_strategy import AbstractPositionSizeStrategy +from core.interfaces.abstract_order_size_strategy import AbstractOrderSizeStrategy from core.models.signal import Signal from core.models.size import PositionSizeType from core.queries.portfolio import GetPositionRisk -class PositionKellySizeStrategy(AbstractPositionSizeStrategy): +class PositionKellySizeStrategy(AbstractOrderSizeStrategy): def __init__(self, kelly_factor: float = 0.033): super().__init__() self.kelly_factor = kelly_factor @@ -14,21 +12,19 @@ def __init__(self, kelly_factor: float = 0.033): async def calculate( self, signal: Signal, - entry_price: float, - stop_loss_price: Optional[float] = None, ) -> float: risk_amount = ( await self.query(GetPositionRisk(signal, PositionSizeType.Kelly)) ) * self.kelly_factor - if stop_loss_price is not None and entry_price is not None: - price_difference = abs(entry_price - stop_loss_price) + if signal.stop_loss is not None and signal.entry is not None: + price_difference = abs(signal.entry - signal.stop_loss) else: raise ValueError("Both entry_price and stop_loss_price must be provided.") if price_difference == 0: raise ValueError( - f"Price difference cannot be zero. For entry price {entry_price} and for stoploss {stop_loss_price}" + f"Price difference cannot be zero. For entry price {signal.entry} and for stoploss {signal.stop_loss}" ) position_size = risk_amount / price_difference diff --git a/position/size/optimal_f.py b/position/size/optimal_f.py index d8f8f990..0da960ff 100644 --- a/position/size/optimal_f.py +++ b/position/size/optimal_f.py @@ -1,37 +1,31 @@ -from typing import Optional - -from core.interfaces.abstract_position_size_strategy import AbstractPositionSizeStrategy +from core.interfaces.abstract_order_size_strategy import AbstractOrderSizeStrategy from core.models.signal import Signal from core.models.size import PositionSizeType from core.queries.portfolio import GetPositionRisk -class PositionOptimalFSizeStrategy(AbstractPositionSizeStrategy): +class PositionOptimalFSizeStrategy(AbstractOrderSizeStrategy): def __init__(self): super().__init__() async def calculate( self, signal: Signal, - entry_price: float, - stop_loss_price: Optional[float] = None, ) -> float: risk_amount = await self.query( GetPositionRisk(signal, PositionSizeType.Optimalf) ) - if stop_loss_price is not None and entry_price is not None: - price_difference = abs(entry_price - stop_loss_price) + if signal.stop_loss is not None and signal.entry is not None: + price_difference = abs(signal.entry - signal.stop_loss) else: raise ValueError("Both entry_price and stop_loss_price must be provided.") if price_difference == 0: raise ValueError( - f"Price difference cannot be zero. For entry price {entry_price} and for stoploss {stop_loss_price}" + f"Price difference cannot be zero. For entry price {signal.entry} and for stoploss {signal.stop_loss}" ) position_size = risk_amount / price_difference - print(f"Risk {risk_amount}, Size {position_size}") - return position_size diff --git a/position/take_profit/risk_reward.py b/position/take_profit/risk_reward.py deleted file mode 100644 index b92ac01b..00000000 --- a/position/take_profit/risk_reward.py +++ /dev/null @@ -1,23 +0,0 @@ -from core.interfaces.abstract_config import AbstractConfig -from core.interfaces.abstract_position_take_profit_strategy import ( - AbstractPositionTakeProfitStrategy, -) -from core.models.side import PositionSide - - -class PositionRiskRewardTakeProfitStrategy(AbstractPositionTakeProfitStrategy): - def __init__(self, config_service: AbstractConfig): - super().__init__() - self.config = config_service.get("position") - - def next(self, side: PositionSide, entry_price: float, stop_loss_price: float): - if side == PositionSide.LONG: - return ( - entry_price - + (entry_price - stop_loss_price) * self.config["risk_reward_ratio"] - ) - - return ( - entry_price - - (stop_loss_price - entry_price) * self.config["risk_reward_ratio"] - ) diff --git a/quant.py b/quant.py index f904b934..79088940 100644 --- a/quant.py +++ b/quant.py @@ -6,8 +6,8 @@ import uvloop from dotenv import load_dotenv +from copilot import CopilotActor from core.models.exchange import ExchangeType -from core.models.strategy import StrategyType from exchange import ExchangeFactory, WSFactory from executor import OrderExecutorActorFactory from feed import FeedActorFactory @@ -15,14 +15,19 @@ from infrastructure.event_dispatcher.event_dispatcher import EventDispatcher from infrastructure.logger import configure_logging from infrastructure.shutdown import GracefulShutdown +from market import MarketActor from optimization import StrategyOptimizerFactory from portfolio import Portfolio from position import PositionActorFactory, PositionFactory -from position.risk.break_even import PositionRiskBreakEvenStrategy from position.size.fixed import PositionFixedSizeStrategy -from position.take_profit.risk_reward import PositionRiskRewardTakeProfitStrategy from risk import RiskActorFactory -from service import EnvironmentSecretService, SignalService, WasmFileService +from service import ( + EnvironmentSecretService, + LLMService, + SignalService, + TimeSeriesService, + WasmManager, +) from sor import SmartRouter from strategy import SignalActorFactory from strategy.generator import StrategyGeneratorFactory @@ -39,6 +44,7 @@ LOG_DIR = os.getenv("LOG_DIR") LOG_LEVEL = os.getenv("LOG_LEVEL", "INFO").upper() WASM_FOLDER = os.getenv("WASM_FOLDER") +COPILOT_MODEL_PATH = os.getenv("COPILOT_MODEL_PATH") configure_logging(LOG_LEVEL) @@ -55,7 +61,11 @@ async def main(): config_service = ConfigService() config_service.load(config_path=f"config.{REGIME}.ini") - config = {"bus": {"num_workers": os.cpu_count()}, "store": {"base_dir": LOG_DIR}} + config = { + "bus": {"num_workers": os.cpu_count()}, + "store": {"base_dir": LOG_DIR}, + "copilot": {"model_path": COPILOT_MODEL_PATH}, + } config_service.update(config) @@ -63,23 +73,24 @@ async def main(): exchange_factory = ExchangeFactory(EnvironmentSecretService()) ws_factory = WSFactory(EnvironmentSecretService()) - - Portfolio(config_service) - SmartRouter(exchange_factory, config_service) - + wasm = WasmManager(WASM_FOLDER) position_factory = PositionFactory( + config_service, PositionFixedSizeStrategy(), - PositionRiskBreakEvenStrategy(config_service), - PositionRiskRewardTakeProfitStrategy(config_service), ) + ts_service = TimeSeriesService(wasm) + MarketActor(ts_service).start() + CopilotActor(LLMService(config_service)).start() + Portfolio(config_service) + SmartRouter(exchange_factory, config_service) - signal_actor_factory = SignalActorFactory( - SignalService(WasmFileService(WASM_FOLDER)) - ) + signal_actor_factory = SignalActorFactory(SignalService(wasm)) position_actor_factory = PositionActorFactory(position_factory, config_service) risk_actor_factory = RiskActorFactory(config_service) executor_actor_factory = OrderExecutorActorFactory() - feed_actor_factory = FeedActorFactory(exchange_factory, ws_factory, config_service) + feed_actor_factory = FeedActorFactory( + exchange_factory, ws_factory, ts_service, config_service + ) trend_context = SystemContext( signal_actor_factory, @@ -89,12 +100,11 @@ async def main(): feed_actor_factory, StrategyGeneratorFactory(config_service), StrategyOptimizerFactory(config_service), - strategy_type=StrategyType.TREND, exchange_type=ExchangeType.BYBIT, config_service=config_service, ) - trend_system_a = BacktestSystem(trend_context) + backtest_system = BacktestSystem(trend_context) trading_system = TradingSystem( signal_actor_factory, @@ -106,21 +116,23 @@ async def main(): exchange_type=ExchangeType.BYBIT, ) - trend_system_a_task = asyncio.create_task(trend_system_a.start()) + backtest_system_task = asyncio.create_task(backtest_system.start()) trading_system_task = asyncio.create_task(trading_system.start()) shutdown_task = asyncio.create_task(graceful_shutdown.wait_for_exit_signal()) try: logging.info("Started") - await asyncio.gather(*[trend_system_a_task, shutdown_task]) + await asyncio.gather( + *[backtest_system_task, trading_system_task, shutdown_task] + ) finally: logging.info("Closing...") shutdown_task.cancel() - trend_system_a_task.cancel() - + backtest_system_task.cancel() trading_system_task.cancel() trading_system.stop() + backtest_system.stop() await event_bus.stop() await event_bus.wait() @@ -128,5 +140,5 @@ async def main(): logging.info("Finished.") -with asyncio.Runner() as runner: +with asyncio.Runner(loop_factory=uvloop.new_event_loop) as runner: runner.run(main()) diff --git a/risk/_actor.py b/risk/_actor.py index ef4054bf..308b9681 100644 --- a/risk/_actor.py +++ b/risk/_actor.py @@ -1,15 +1,19 @@ import asyncio -from collections import deque -from typing import List, Optional, Union +import logging +import random +from typing import Optional, Union -from core.actors import Actor +import numpy as np + +from core.actors import StrategyActor +from core.events.base import EventMeta from core.events.ohlcv import NewMarketDataReceived from core.events.position import ( PositionAdjusted, PositionClosed, PositionOpened, ) -from core.events.risk import RiskAdjustRequested, RiskThresholdBreached, RiskType +from core.events.risk import RiskThresholdBreached from core.events.signal import ( ExitLongSignalReceived, ExitShortSignalReceived, @@ -17,29 +21,51 @@ GoShortSignalReceived, ) from core.interfaces.abstract_config import AbstractConfig +from core.mixins import EventHandlerMixin from core.models.ohlcv import OHLCV from core.models.position import Position from core.models.side import PositionSide from core.models.symbol import Symbol from core.models.timeframe import Timeframe +from core.queries.copilot import EvaluateSession +from core.queries.ohlcv import TA, NextBar, PrevBar + +TrailEvent = Union[ + GoLongSignalReceived, + GoShortSignalReceived, + ExitLongSignalReceived, + ExitShortSignalReceived, +] RiskEvent = Union[ NewMarketDataReceived, PositionOpened, PositionAdjusted, PositionClosed, - ExitLongSignalReceived, - ExitShortSignalReceived, - GoLongSignalReceived, - GoShortSignalReceived, + TrailEvent, ] +MAX_ATTEMPTS = 16 +DEFAULT_MAX_BARS = 16 +MAX_CONSECUTIVE_ANOMALIES = 3 +DYNAMIC_THRESHOLD_MULTIPLIER = 8.0 +DEFAULT_ANOMALY_THRESHOLD = 6.0 -class RiskActor(Actor): + +def _ema(values, alpha=0.1): + ema = [values[0]] + for value in values[1:]: + ema.append(ema[-1] * (1 - alpha) + value * alpha) + return np.array(ema) + + +logger = logging.getLogger(__name__) + + +class RiskActor(StrategyActor, EventHandlerMixin): _EVENTS = [ NewMarketDataReceived, PositionOpened, - PositionAdjusted, PositionClosed, ExitLongSignalReceived, ExitShortSignalReceived, @@ -48,350 +74,177 @@ class RiskActor(Actor): ] def __init__( - self, - symbol: Symbol, - timeframe: Timeframe, - config_service: AbstractConfig, + self, symbol: Symbol, timeframe: Timeframe, config_service: AbstractConfig ): super().__init__(symbol, timeframe) - self.lock = asyncio.Lock() + EventHandlerMixin.__init__(self) + self._register_event_handlers() + self._lock = asyncio.Lock() self._position = (None, None) - self._ohlcv = deque(maxlen=120) self.config = config_service.get("position") - - def pre_receive(self, event: RiskEvent): - symbol, timeframe = self._get_event_key(event) - return self._symbol == symbol and self._timeframe == timeframe + self.max_bars = DEFAULT_MAX_BARS + self.anomaly_threshold = DEFAULT_ANOMALY_THRESHOLD + self.consc_anomaly_counter = 1 async def on_receive(self, event: RiskEvent): - handlers = { - NewMarketDataReceived: self._handle_market_risk, - PositionOpened: self._open_position, - PositionClosed: self._close_position, - PositionAdjusted: self._adjust_position, - GoLongSignalReceived: [self._handle_reverse, self._handle_scale_in], - GoShortSignalReceived: [self._handle_reverse, self._handle_scale_in], - ExitLongSignalReceived: self._handle_signal_exit, - ExitShortSignalReceived: self._handle_signal_exit, - } - - handler = handlers.get(type(event)) - - if handler: - if isinstance(handler, list): - for h in handler: - await h(event) - else: - await handler(event) + return await self.handle_event(event) - async def _open_position(self, event: PositionOpened): - async with self.lock: - long_position, short_position = self._position + def _register_event_handlers(self): + self.register_handler(NewMarketDataReceived, self._handle_position_risk) + self.register_handler(PositionOpened, self._open_position) + self.register_handler(PositionClosed, self._close_position) + self.register_handler(ExitLongSignalReceived, self._trail_position) + self.register_handler(ExitShortSignalReceived, self._trail_position) + self.register_handler(GoLongSignalReceived, self._trail_position) + self.register_handler(GoShortSignalReceived, self._trail_position) - self._position = ( - event.position - if event.position.side == PositionSide.LONG - else long_position, - event.position - if event.position.side == PositionSide.SHORT - else short_position, - ) - - async def _adjust_position(self, event: PositionAdjusted): - async with self.lock: - long_position, short_position = self._position - - self._position = ( - event.position - if event.position.side == PositionSide.LONG - else long_position, - event.position - if event.position.side == PositionSide.SHORT - else short_position, - ) + async def _open_position(self, event: PositionOpened): + async with self._lock: + match event.position.side: + case PositionSide.LONG: + self._position = (event.position, self._position[1]) + case PositionSide.SHORT: + self._position = (self._position[0], event.position) async def _close_position(self, event: PositionClosed): - async with self.lock: - long_position, short_position = self._position - - self._position = ( - None if event.position.side == PositionSide.LONG else long_position, - None if event.position.side == PositionSide.SHORT else short_position, - ) + async with self._lock: + match event.position.side: + case PositionSide.LONG: + self._position = (None, self._position[1]) + case PositionSide.SHORT: + self._position = (self._position[0], None) + + async def _handle_position_risk(self, event: NewMarketDataReceived): + async with self._lock: + processed_positions = list(self._position) + num_positions = len(self._position) + + indexes = list(range(num_positions)) + random.shuffle(indexes) + + current_index = 0 + + for _ in range(num_positions): + shuffled_index = indexes[current_index] + processed_positions[shuffled_index] = await self._process_market( + event, self._position[shuffled_index] + ) - async def _handle_market_risk(self, event: NewMarketDataReceived): - async with self.lock: - self._ohlcv.append(event.ohlcv) - visited = set() - ohlcvs = [] + current_index = (current_index + 1) % num_positions - for i in range(len(self._ohlcv)): - if self._ohlcv[i].timestamp not in visited: - ohlcvs.append(self._ohlcv[i]) - visited.add(self._ohlcv[i].timestamp) + self._position = tuple(processed_positions) - ohlcvs = sorted(ohlcvs, key=lambda x: x.timestamp) + async def _trail_position(self, event: TrailEvent): + async with self._lock: - long_position, short_position = self._position + async def handle_trail(position: Position, risk_bar: OHLCV): + logger.info("Trail event") - if long_position or short_position: - long_position, short_position = await asyncio.gather( - *[ - self._process_position(long_position, ohlcvs), - self._process_position(short_position, ohlcvs), - ] - ) + ta = await self.ask(TA(self.symbol, self.timeframe, risk_bar)) + return position.trail(ta) - self._position = (long_position, short_position) + async def process_trail(position: Position, event_meta: EventMeta): + if ( + position + and not position.has_risk + and position.last_modified < event_meta.timestamp + ): + return await handle_trail(position, position.risk_bar) + return position - async def _handle_reverse( - self, event: Union[GoLongSignalReceived, GoShortSignalReceived] - ): - async with self.lock: long_position, short_position = self._position - if ( - isinstance(event, GoShortSignalReceived) - and long_position - and not short_position - ): - await self._process_reverse_exit(long_position, event.entry_price) - if ( - isinstance(event, GoLongSignalReceived) - and short_position - and not long_position - ): - await self._process_reverse_exit(short_position, event.entry_price) - - async def _handle_scale_in( - self, event: Union[GoLongSignalReceived, GoShortSignalReceived] - ): - async with self.lock: - long_position, short_position = self._position + if isinstance(event, (ExitLongSignalReceived, GoShortSignalReceived)): + long_position = await process_trail(long_position, event.meta) - if ( - isinstance(event, GoLongSignalReceived) - and long_position - and long_position.adj_count < self.config["max_scale_in"] - and long_position.entry_price < event.entry_price - ): - await self.tell(RiskAdjustRequested(long_position, event.entry_price)) - if ( - isinstance(event, GoShortSignalReceived) - and short_position - and short_position.adj_count < self.config["max_scale_in"] - and short_position.entry_price > event.entry_price - ): - await self.tell(RiskAdjustRequested(short_position, event.entry_price)) - - async def _handle_signal_exit( - self, event: Union[ExitLongSignalReceived, ExitShortSignalReceived] - ): - async with self.lock: - long_position, short_position = self._position + elif isinstance(event, (ExitShortSignalReceived, GoLongSignalReceived)): + short_position = await process_trail(short_position, event.meta) - if isinstance(event, ExitLongSignalReceived) and long_position: - await self._process_signal_exit( - long_position, - event.exit_price, - ) - if isinstance(event, ExitShortSignalReceived) and short_position: - await self._process_signal_exit( - short_position, - event.exit_price, - ) + self._position = (long_position, short_position) - async def _process_position( - self, position: Optional[Position], ohlcvs: List[OHLCV] + async def _process_market( + self, event: NewMarketDataReceived, position: Optional[Position] ): next_position = position - if position and len(ohlcvs) > 1: - next_position = position.next(ohlcvs) - exit_event = self._create_exit_event(next_position, ohlcvs[-1]) + if position and not position.has_risk: + prev_bar = next_position.risk_bar + next_bar = await self.ask(NextBar(self.symbol, self.timeframe, prev_bar)) - if exit_event: - await self.tell(exit_event) + if not next_bar: + next_bar = event.ohlcv - return next_position + diff = event.ohlcv.timestamp - next_bar.timestamp + attempts = 0 - def _create_exit_event(self, position: Position, ohlcv: OHLCV): - expiration = ( - position.open_timestamp + self.config["trade_duration"] * 1000 - ) - ohlcv.timestamp - - risk_type = None - - if position.side == PositionSide.LONG: - if self._is_long_expires(position, expiration, ohlcv): - risk_type = RiskType.TIME - elif self._is_long_meets_tp(position, ohlcv): - risk_type = RiskType.TP - elif self._is_long_meets_sl(position, ohlcv) and not self._position[1]: - risk_type = RiskType.SL - elif position.side == PositionSide.SHORT: - if self._is_short_expires(position, expiration, ohlcv): - risk_type = RiskType.TIME - elif self._is_short_meets_tp(position, ohlcv): - risk_type = RiskType.TP - elif self._is_short_meets_sl(position, ohlcv) and not self._position[0]: - risk_type = RiskType.SL - - if risk_type: - exit_price = ( - self._long_exit_price - if position.side == PositionSide.LONG - else self._short_exit_price - )(position, ohlcv) - - return RiskThresholdBreached(position, exit_price, risk_type) - - return None - - async def _process_reverse_exit( - self, - position: Position, - price: float, - ): - take_profit_price = position.take_profit_price - stop_loss_price = position.stop_loss_price + while diff < 0 and attempts < MAX_ATTEMPTS: + new_prev_bar = await self.ask( + PrevBar(self.symbol, self.timeframe, prev_bar) + ) + attempts += 1 + + if new_prev_bar: + diff = event.ohlcv.timestamp - new_prev_bar.timestamp + prev_bar = new_prev_bar + + bars = [next_bar] + + if diff > 0: + for _ in range(int(self.max_bars)): + next_bar = await self.ask( + NextBar(self.symbol, self.timeframe, prev_bar) + ) + + if not next_bar: + break + + bars.append(next_bar) + prev_bar = next_bar + + for bar in sorted(bars, key=lambda x: x.timestamp): + ohlcv = next_position.position_risk.ohlcv + ts = np.array([o.timestamp for o in ohlcv]) + + if len(ts) > 2: + ts_diff = _ema(np.diff(ts)) + mean, std = np.mean(ts_diff), max( + np.std(ts_diff), np.finfo(float).eps + ) + + current_diff = abs(bar.timestamp - ts[-1]) + anomaly = (current_diff - mean) / std + anomaly = np.clip( + anomaly, + -9.0 * DEFAULT_ANOMALY_THRESHOLD, + 9.0 * DEFAULT_ANOMALY_THRESHOLD, + ) + + if abs(anomaly) > self.anomaly_threshold: + self.consc_anomaly_counter += 1 + + if self.consc_anomaly_counter > MAX_CONSECUTIVE_ANOMALIES: + logger.warn( + "Too many consecutive anomalies, increasing threshold temporarily" + ) + self.anomaly_threshold *= DYNAMIC_THRESHOLD_MULTIPLIER + self.max_bars *= DYNAMIC_THRESHOLD_MULTIPLIER + self.consc_anomaly_counter = 1 + await asyncio.sleep(0.00001) + continue + else: + self.anomaly_threshold = DEFAULT_ANOMALY_THRESHOLD + self.max_bars = DEFAULT_MAX_BARS + self.consc_anomaly_counter = 1 + + ta = await self.ask(TA(self.symbol, self.timeframe, bar)) + session_risk = await self.ask( + EvaluateSession(next_position.side, ohlcv, ta) + ) - distance_to_take_profit = abs(price - take_profit_price) - distance_to_stop_loss = abs(price - stop_loss_price) + next_position = next_position.next(bar, ta, session_risk) - if distance_to_take_profit < distance_to_stop_loss: - await self.tell(RiskThresholdBreached(position, price, RiskType.REVERSE)) + if next_position.has_risk: + await self.tell(RiskThresholdBreached(next_position)) + break - async def _process_signal_exit( - self, - position: Position, - price: float, - ): - side = position.side - take_profit_price = position.take_profit_price - stop_loss_price = position.stop_loss_price - entry_price = position.entry_price - - price_exceeds_take_profit = ( - side == PositionSide.LONG and price > take_profit_price - ) or (side == PositionSide.SHORT and price < take_profit_price) - - price_exceeds_stop_loss = ( - side == PositionSide.LONG and price < stop_loss_price - ) or (side == PositionSide.SHORT and price > stop_loss_price) - - if price_exceeds_take_profit or price_exceeds_stop_loss: - return - - distance_to_take_profit = abs(price - take_profit_price) - distance_to_stop_loss = abs(price - stop_loss_price) - trailing_dist = abs(price - entry_price) - - ttp = distance_to_take_profit * self.config["trl_factor"] - - if distance_to_take_profit < distance_to_stop_loss and trailing_dist > ttp: - await self.tell(RiskThresholdBreached(position, price, RiskType.SIGNAL)) - return position - - @staticmethod - def _long_exit_price(position: Position, ohlcv: OHLCV): - if ( - position.stop_loss_price is not None - and ohlcv.low <= position.stop_loss_price - ): - return ohlcv.low - if ( - position.take_profit_price is not None - and ohlcv.high >= position.take_profit_price - ): - return ohlcv.high - - return ohlcv.close - - @staticmethod - def _is_long_expires( - position: Position, - expiration: int, - ohlcv: OHLCV, - ) -> bool: - return ( - expiration <= 0 - and position.entry_price > max(ohlcv.close, ohlcv.high) * 1.1 - ) - - @staticmethod - def _is_long_meets_tp( - position: Position, - ohlcv: OHLCV, - ): - return ( - position.take_profit_price is not None - and ohlcv.high > position.take_profit_price - ) - - @staticmethod - def _is_long_meets_sl( - position: Position, - ohlcv: OHLCV, - ): - return ( - position.stop_loss_price is not None - and ohlcv.low < position.stop_loss_price - ) - - @staticmethod - def _short_exit_price(position: Position, ohlcv: OHLCV): - if ( - position.stop_loss_price is not None - and ohlcv.high >= position.stop_loss_price - ): - return ohlcv.high - if ( - position.take_profit_price is not None - and ohlcv.low <= position.take_profit_price - ): - return ohlcv.low - - return ohlcv.close - - @staticmethod - def _is_short_expires( - position: Position, - expiration: int, - ohlcv: OHLCV, - ) -> bool: - return expiration <= 0 and position.entry_price * 1.1 < min( - ohlcv.close, ohlcv.low - ) - - @staticmethod - def _is_short_meets_tp( - position: Position, - ohlcv: OHLCV, - ): - return ( - position.take_profit_price is not None - and ohlcv.low < position.take_profit_price - ) - - @staticmethod - def _is_short_meets_sl( - position: Position, - ohlcv: OHLCV, - ): - return ( - position.stop_loss_price is not None - and ohlcv.high > position.stop_loss_price - ) - - @staticmethod - def _get_event_key(event: RiskEvent): - signal = ( - event.signal - if hasattr(event, "signal") - else event.position.signal - if hasattr(event, "position") - else event - ) - - return (signal.symbol, signal.timeframe) + return next_position diff --git a/service/__init__.py b/service/__init__.py index 5f6ffd16..7dc0c43f 100644 --- a/service/__init__.py +++ b/service/__init__.py @@ -1,5 +1,13 @@ from ._env_secret import EnvironmentSecretService +from ._llm import LLMService from ._signal import SignalService -from ._wasm_file import WasmFileService +from ._timeseries import TimeSeriesService +from ._wasm import WasmManager -__all__ = [EnvironmentSecretService, SignalService, WasmFileService] +__all__ = [ + EnvironmentSecretService, + SignalService, + LLMService, + WasmManager, + TimeSeriesService, +] diff --git a/service/_env_secret.py b/service/_env_secret.py index afe6e8f5..dc04f908 100644 --- a/service/_env_secret.py +++ b/service/_env_secret.py @@ -1,17 +1,34 @@ +import logging import os +from typing import Optional from core.interfaces.abstract_secret_service import AbstractSecretService +logger = logging.getLogger(__name__) + class EnvironmentSecretService(AbstractSecretService): def __init__(self): super().__init__() - def get_api_key(self, identifier: str) -> str: - return os.environ.get(identifier + "_API_KEY") + def get_api_key(self, identifier: str) -> Optional[str]: + return self._get_env_variable(self._format_key(identifier, "API_KEY")) + + def get_secret(self, identifier: str) -> Optional[str]: + return self._get_env_variable(self._format_key(identifier, "API_SECRET")) + + def get_wss(self, identifier: str) -> Optional[str]: + return self._get_env_variable(self._format_key(identifier, "WSS")) + + @staticmethod + def _format_key(identifier: str, key_type: str) -> str: + return f"{identifier.upper()}_{key_type}" + + @staticmethod + def _get_env_variable(key: str) -> Optional[str]: + value = os.environ.get(key) - def get_secret(self, identifier: str) -> str: - return os.environ.get(identifier + "_API_SECRET") + if value is None: + logger.warning(f"Environment variable '{key}' is not set.") - def get_wss(self, identifier: str) -> str: - return os.environ.get(identifier + "_WSS") + return value diff --git a/service/_llm.py b/service/_llm.py new file mode 100644 index 00000000..263f995b --- /dev/null +++ b/service/_llm.py @@ -0,0 +1,47 @@ +import asyncio +from typing import Any, Dict + +from llama_cpp import Llama + +from core.interfaces.abstract_config import AbstractConfig +from core.interfaces.abstract_llm_service import AbstractLLMService + + +class LLMService(AbstractLLMService): + def __init__(self, config_service: AbstractConfig): + super().__init__() + self.config = config_service.get("copilot") + self._llm = self._initialize_llm(self.config) + self._lock = asyncio.Semaphore(3) + + async def call( + self, system_prompt: str, user_prompt: str, stop_words: tuple[str] = ("<|end|>") + ) -> str: + async with self._lock: + llama_input = { + "prompt": f"<|system|>{system_prompt}<|end|><|user|>\n{user_prompt}<|end|>\n<|assistant|>", + "max_tokens": self.config["max_tokens"], + "temperature": self.config["temperature"], + "stop": list(stop_words), + "stream": True, + "echo": False, + } + + answer = "" + + for output in self._llm(**llama_input): + answer += output["choices"][0]["text"] + + return answer + + @staticmethod + def _initialize_llm(config: Dict[str, Any]) -> Llama: + return Llama( + model_path=config["model_path"], + n_ctx=config["n_ctx"], + n_threads=config["n_threads"], + n_gpu_layers=config["n_gpu_layers"], + n_batch=config["n_batch"], + seed=1337, + verbose=False, + ) diff --git a/service/_signal.py b/service/_signal.py index 48fbd739..f81b2b32 100644 --- a/service/_signal.py +++ b/service/_signal.py @@ -1,57 +1,37 @@ from ctypes import addressof, c_ubyte -from typing import Optional +from typing import Tuple -from wasmtime import Instance, Linker, Store, WasiConfig +import orjson as json from core.interfaces.abstract_signal_service import AbstractSignalService -from core.interfaces.abstract_wasm_service import AbstractWasmService -from core.models.strategy import Strategy, StrategyType +from core.interfaces.abstract_wasm_manager import AbstractWasmManager +from core.models.strategy import Strategy from core.models.strategy_ref import StrategyRef +from core.models.wasm_type import WasmType class SignalService(AbstractSignalService): - def __init__(self, wasm_service: AbstractWasmService): - self.wasm_service = wasm_service - self.store = Store() - wasi_config = WasiConfig() - wasi_config.wasm_multi_value = True - self.store.set_wasi(wasi_config) - self.linker = Linker(self.store.engine) - self.linker.define_wasi() - self.instance: Optional[Instance] = None - - def _load(self, type: StrategyType): - module = self.wasm_service.get_module(type, self.store.engine) - self.instance = self.linker.instantiate(self.store, module) + def __init__(self, wasm_manager: AbstractWasmManager): + super().__init__() + self._wasm_manager = wasm_manager + self._wasm = WasmType.TREND def register(self, strategy: Strategy) -> StrategyRef: - if not self.instance: - self._load(strategy.type) - - exports = self.instance.exports(self.store) - - data = { - "signal": strategy.parameters[0], - "filter": strategy.parameters[1], - "pulse": strategy.parameters[2], - "baseline": strategy.parameters[3], - "stoploss": strategy.parameters[4], - "exit": strategy.parameters[5], - } - - allocation_data = { - key: self._allocate_and_write(self.store, exports, data) - for key, data in data.items() - } + instance, store = self._wasm_manager.get_instance(self._wasm) + exports = instance.exports(store) + allocation_data = [ + self._write(store, exports, json.dumps(param)) + for param in strategy.parameters + ] id = exports["register"]( - self.store, *[item for pair in allocation_data.values() for item in pair] + store, *[item for pair in allocation_data for item in pair] ) - return StrategyRef(id=id, instance_ref=self.instance, store_ref=self.store) + return StrategyRef(id=id, instance_ref=instance, store_ref=store) @staticmethod - def _allocate_and_write(store, exports, data: bytes) -> (int, int): + def _write(store, exports, data: bytes) -> Tuple[int]: ptr = exports["allocate"](store, len(data)) memory = exports["memory"] diff --git a/service/_timeseries.py b/service/_timeseries.py new file mode 100644 index 00000000..aab477e1 --- /dev/null +++ b/service/_timeseries.py @@ -0,0 +1,58 @@ +import asyncio +from contextlib import asynccontextmanager +from typing import Optional + +from core.interfaces.abstract_timeseries import AbstractTimeSeriesService +from core.interfaces.abstract_wasm_manager import AbstractWasmManager +from core.models.ohlcv import OHLCV +from core.models.symbol import Symbol +from core.models.timeframe import Timeframe +from core.models.timeseries_ref import TimeSeriesRef +from core.models.wasm_type import WasmType + + +class TimeSeriesService(AbstractTimeSeriesService): + def __init__(self, wasm_manager: AbstractWasmManager): + self._bucket = {} + self._lock = asyncio.Lock() + self._wasm_manager = wasm_manager + self._wasm = WasmType.TIMESERIES + + async def upsert(self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV): + async with self._get_timeseries(symbol, timeframe) as timeseries: + timeseries.add(bar) + + async def next_bar( + self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV + ) -> Optional[OHLCV]: + async with self._get_timeseries(symbol, timeframe) as timeseries: + return timeseries.next_bar(bar) + + async def prev_bar( + self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV + ) -> Optional[OHLCV]: + async with self._get_timeseries(symbol, timeframe) as timeseries: + return timeseries.prev_bar(bar) + + async def back_n_bars( + self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV, n: int + ) -> Optional[OHLCV]: + async with self._get_timeseries(symbol, timeframe) as timeseries: + return timeseries.back_n_bars(bar, n) + + async def ta(self, symbol: Symbol, timeframe: Timeframe, bar: OHLCV): + async with self._get_timeseries(symbol, timeframe) as timeseries: + return timeseries.ta(bar) + + @asynccontextmanager + async def _get_timeseries(self, symbol: Symbol, timeframe: Timeframe): + async with self._lock: + key = (symbol, timeframe) + if key not in self._bucket: + instance, store = self._wasm_manager.get_instance(self._wasm) + exports = instance.exports(store) + id = exports["timeseries_register"](store) + self._bucket[key] = TimeSeriesRef( + id=id, instance_ref=instance, store_ref=store + ) + yield self._bucket[key] diff --git a/service/_wasm.py b/service/_wasm.py new file mode 100644 index 00000000..5c9d8341 --- /dev/null +++ b/service/_wasm.py @@ -0,0 +1,58 @@ +import os +from functools import lru_cache + +from wasmtime import Engine, Linker, Module, Store, WasiConfig + +from core.interfaces.abstract_wasm_manager import AbstractWasmManager +from core.models.wasm_type import WasmType + + +class WasmManager(AbstractWasmManager): + _type = { + WasmType.TREND: "trend_follow.wasm", + WasmType.TIMESERIES: "timeseries.wasm", + } + + def __init__(self, dir="wasm"): + super().__init__() + self.dir = dir + + @lru_cache(maxsize=None) + def get_instance(self, wasm_type: WasmType): + return self._load_instance(wasm_type) + + def _load_instance(self, wasm_type: WasmType): + store = Store() + + wasi = self._configure_wasi() + + store.set_wasi(wasi) + + module = self._get_module(wasm_type, store.engine) + linker = self._configure_linker(store.engine) + + instance = linker.instantiate(store, module) + + return (instance, store) + + def _configure_wasi(self) -> WasiConfig: + wasi_config = WasiConfig() + wasi_config.wasm_multi_value = True + wasi_config.inherit_stdout() + return wasi_config + + def _configure_linker(self, engine: Engine) -> Linker: + linker = Linker(engine) + linker.define_wasi() + return linker + + def _get_module(self, type: WasmType, engine: Engine) -> Module: + if type not in WasmType: + raise ValueError(f"Unknown Strategy: {type}") + + wasm_path = f"./{self.dir}/{self._type.get(type)}" + + if not os.path.exists(wasm_path): + raise FileNotFoundError(f"WASM file not found: {wasm_path}") + + return Module.from_file(engine, wasm_path) diff --git a/service/_wasm_file.py b/service/_wasm_file.py deleted file mode 100644 index eaef1805..00000000 --- a/service/_wasm_file.py +++ /dev/null @@ -1,20 +0,0 @@ -from wasmtime import Engine, Module - -from core.interfaces.abstract_wasm_service import AbstractWasmService -from core.models.strategy import StrategyType - - -class WasmFileService(AbstractWasmService): - _type = {StrategyType.TREND: "trend_follow.wasm"} - - def __init__(self, dir="wasm"): - super().__init__() - self.dir = dir - - def get_module(self, type: StrategyType, engine: Engine) -> Module: - if type not in StrategyType: - raise ValueError(f"Unknown Strategy: {type}") - - wasm_path = f"./{self.dir}/{self._type.get(type)}" - - return Module.from_file(engine, wasm_path) diff --git a/sor/_router.py b/sor/_router.py index caefd077..583aef9f 100644 --- a/sor/_router.py +++ b/sor/_router.py @@ -2,7 +2,6 @@ import time from core.commands.broker import ( - AdjustPosition, ClosePosition, OpenPosition, UpdateSettings, @@ -98,9 +97,11 @@ async def open_position(self, command: OpenPosition): logger.info(f"Try to open position: {position}") symbol = position.signal.symbol - position_size = position.pending_size - stop_loss = position.stop_loss_price - entry_price = position.pending_price + stop_loss = position.stop_loss + pending_order = position.entry_order() + + entry_price = pending_order.price + size = pending_order.size if self.exchange.fetch_position(symbol, position.side): logging.info("Position already exists") @@ -108,103 +109,17 @@ async def open_position(self, command: OpenPosition): distance_to_stop_loss = abs(entry_price - stop_loss) - min_size = symbol.min_position_size num_orders = min( - max(1, int(position_size / min_size)), self.config["max_order_slice"] + max(1, int(size / symbol.min_position_size)), self.config["max_order_slice"] ) - size = round(position_size / num_orders, symbol.position_precision) + size = round(size / num_orders, symbol.position_precision) order_counter = 0 num_order_breach = 0 order_timestamps = {} - for price in self.algo_price.calculate(symbol, self.exchange): - current_distance_to_stop_loss = abs(stop_loss - price) - - threshold_breach = ( - self.config["stop_loss_threshold"] * distance_to_stop_loss - > current_distance_to_stop_loss - ) - - if threshold_breach: - logging.info( - f"Order risk breached: ENTR: {entry_price}, STPLS: {stop_loss}, THEO_DSTNC: {distance_to_stop_loss}, ALG_DSTNC: {current_distance_to_stop_loss}" - ) - - num_order_breach += 1 - - if num_order_breach >= self.config["max_order_breach"]: - break - - spread = ( - price - entry_price - if position.side == PositionSide.LONG - else entry_price - price - ) - - spread_percentage = (spread / entry_price) * 100 - - logging.info( - f"Trying to open order -> algo price: {price}, theo price: {entry_price}, spread: {spread_percentage}%" - ) - - if spread_percentage > 1.5: - break - - curr_time = time.time() - expired_orders = [ - order_id - for order_id, timestamp in order_timestamps.items() - if curr_time - timestamp > self.config["order_expiration_time"] - ] - - for order_id in expired_orders: - self.exchange.cancel_order(order_id, symbol) - order_timestamps.pop(order_id) - - for order_id in list(order_timestamps.keys()): - if self.exchange.has_filled_order(order_id, symbol): - order_timestamps.pop(order_id) - order_counter += 1 - - if order_counter >= num_orders: - logging.info(f"All orders are filled: {order_counter}") - break - - if not self.exchange.has_open_orders(symbol, position.side) and not len( - order_timestamps.keys() - ): - order_id = self.exchange.create_limit_order( - symbol, position.side, size, price - ) - if order_id: - order_timestamps[order_id] = time.time() - - for order_id in list(order_timestamps.keys()): - self.exchange.cancel_order(order_id, symbol) - - @command_handler(AdjustPosition) - async def adjust_position(self, command: AdjustPosition): - position = command.position - - logger.info(f"Try to adjust position: {position}") - - symbol = position.signal.symbol - position_size = position.filled_size - stop_loss = position.stop_loss_price - entry_price = command.adjust_price - - distance_to_stop_loss = abs(entry_price - stop_loss) - - min_size = symbol.min_position_size - num_orders = min( - max(1, int(position_size / min_size)), self.config["max_order_slice"] - ) - size = round(position_size / num_orders, symbol.position_precision) - order_counter = 0 - num_order_breach = 0 - order_timestamps = {} + async for bid, ask in self.algo_price.next_value(symbol, self.exchange): + price = ask if position.side == PositionSide.LONG else bid - for price in self.algo_price.calculate(symbol, self.exchange): current_distance_to_stop_loss = abs(stop_loss - price) threshold_breach = ( @@ -228,13 +143,13 @@ async def adjust_position(self, command: AdjustPosition): else entry_price - price ) - spread_percentage = (spread / entry_price) * 100 + spread_percentage = 100 * (spread / entry_price) logging.info( f"Trying to open order -> algo price: {price}, theo price: {entry_price}, spread: {spread_percentage}%" ) - if spread_percentage > 1.5: + if spread_percentage > 0.35: break curr_time = time.time() @@ -278,38 +193,35 @@ async def close_position(self, command: ClosePosition): logging.info("Position is not existed") return - position_size = position.filled_size + exit_order = position.exit_order() position_side = position.side - exit_price = command.exit_price - min_size = symbol.min_position_size num_orders = min( - max(1, int(position_size / min_size)), self.config["max_order_slice"] + max(1, int(exit_order.size / symbol.min_position_size)), + self.config["max_order_slice"], ) - size = round(position_size / num_orders, symbol.position_precision) + size = round(exit_order.size / num_orders, symbol.position_precision) order_counter = 0 order_timestamps = {} max_spread = float("-inf") - for price in self.algo_price.calculate(symbol, self.exchange): + async for bid, ask in self.algo_price.next_value(symbol, self.exchange): if not self.exchange.fetch_position(symbol, position_side): break + price = bid if position.side == PositionSide.LONG else ask + spread = ( - price - exit_price + price - exit_order.price if position_side == PositionSide.LONG - else exit_price - price + else exit_order.price - price ) - max_spread = max(spread, max_spread) + max_spread = max(max_spread, spread) logging.info( - f"Trying to reduce order -> algo price: {price}, theo price: {exit_price}, spread: {spread}, max spread: {max_spread}" + f"Trying to reduce order -> algo price: {price}, theo price: {exit_order.price}, spread: {spread}, max spread: {max_spread}" ) - if max_spread < 0: - self.exchange.close_full_position(symbol, position_side) - break - curr_time = time.time() expired_orders = [ order_id @@ -330,16 +242,22 @@ async def close_position(self, command: ClosePosition): logging.info(f"All orders are filled: {order_counter}") break - if ( - not self.exchange.has_open_orders(symbol, position_side, True) - or not len(order_timestamps.keys()) - ) and (spread < max_spread or spread < 0): + if not ( + self.exchange.has_open_orders(symbol, position_side, True) + or len(order_timestamps.keys()) + ): order_id = self.exchange.create_reduce_order( symbol, position_side, size, price ) if order_id: order_timestamps[order_id] = time.time() + if spread < max_spread and not len(order_timestamps.keys()): + if num_orders > 2: + self.exchange.close_half_position(symbol, position_side) + else: + self.exchange.close_full_position(symbol, position_side) + for order_id in list(order_timestamps.keys()): self.exchange.cancel_order(order_id, symbol) diff --git a/sor/_twap.py b/sor/_twap.py index e3514e95..d4f8c452 100644 --- a/sor/_twap.py +++ b/sor/_twap.py @@ -1,4 +1,4 @@ -import time +from asyncio import sleep import numpy as np @@ -11,7 +11,7 @@ class TWAP: def __init__(self, config_service: AbstractConfig): self.config = config_service.get("position") - def calculate(self, symbol: Symbol, exchange: AbstractExchange): + async def next_value(self, symbol: Symbol, exchange: AbstractExchange): current_time = 0 timepoints = [] twap_duration = self.config["twap_duration"] @@ -24,28 +24,37 @@ def calculate(self, symbol: Symbol, exchange: AbstractExchange): time_interval = self._volatility_time_interval(timepoints) current_time += time_interval - twap_value = self._twap(timepoints) + yield self._twap(timepoints) - yield twap_value - - time.sleep(time_interval) + await sleep(time_interval) def _fetch_book(self, symbol: Symbol, exchange: AbstractExchange): - bids, asks = exchange.fetch_order_book(symbol, depth=self.config["depth"]) + bids, asks = exchange.fetch_order_book(symbol, depth=self.config["dom"]) return np.array(bids), np.array(asks) @staticmethod def _twap(order_book): bid_prices, ask_prices, bid_volume, ask_volume = zip(*order_book) - bid_weighted_average = np.sum(np.multiply(bid_prices, bid_volume)) / np.sum( - bid_volume - ) - ask_weighted_average = np.sum(np.multiply(ask_prices, ask_volume)) / np.sum( - ask_volume - ) + bid_prices, ask_prices = np.array(bid_prices), np.array(ask_prices) + bid_volume, ask_volume = np.array(bid_volume), np.array(ask_volume) + + total_bid_volume, total_ask_volume = np.sum(bid_volume), np.sum(ask_volume) + + bid_weighted_average = np.sum(bid_prices * bid_volume) / total_bid_volume + ask_weighted_average = np.sum(ask_prices * ask_volume) / total_ask_volume + + diff = ask_prices - bid_prices + + mid_price = (bid_weighted_average + ask_weighted_average) / 2.0 + spread, volatility = np.mean(diff), np.std(diff) + + adj_spread = spread * volatility + + bid_price = mid_price - adj_spread / 2.0 + ask_price = mid_price + adj_spread / 2.0 - return (bid_weighted_average + ask_weighted_average) / 2 + return bid_price, ask_price @staticmethod def _volatility_time_interval(timepoints): @@ -55,7 +64,7 @@ def _volatility_time_interval(timepoints): high_low = np.log(high_prices / low_prices) volatility = np.sqrt((1 / (4 * np.log(2))) * np.mean(high_low**2)) - base_interval = 1.0 + base_interval = 1.236 volatility_factor = 30.0 - return base_interval + volatility_factor * volatility + return base_interval + np.tanh(volatility_factor * volatility) diff --git a/strategy/_actor.py b/strategy/_actor.py index 01c8a1d0..674b9b33 100644 --- a/strategy/_actor.py +++ b/strategy/_actor.py @@ -1,7 +1,8 @@ import logging from typing import TYPE_CHECKING, Optional -from core.actors import Actor +from core.actors import StrategyActor +from core.actors.policy.signal import SignalPolicy from core.events.ohlcv import NewMarketDataReceived from core.interfaces.abstract_signal_service import AbstractSignalService from core.models.strategy import Strategy @@ -14,7 +15,7 @@ logger = logging.getLogger(__name__) -class SignalActor(Actor): +class SignalActor(StrategyActor): _EVENTS = [NewMarketDataReceived] def __init__( @@ -25,29 +26,30 @@ def __init__( service: AbstractSignalService, ): super().__init__(symbol, timeframe) - - self.strategy_ref: Optional[StrategyRef] = None - self.service = service self._strategy = strategy + self.service = service + self.strategy_ref: Optional[StrategyRef] = None + + @property + def strategy(self): + return self._strategy def on_start(self): - self.strategy_ref = self.service.register(self._strategy) + self.strategy_ref = self.service.register(self.strategy) def on_stop(self): self.strategy_ref.unregister() self.strategy_ref = None def pre_receive(self, event: NewMarketDataReceived): - return ( - event.symbol == self._symbol - and event.timeframe == self._timeframe - and event.closed - ) + return SignalPolicy.should_process(self, event) async def on_receive(self, event: NewMarketDataReceived): signal_event = self.strategy_ref.next( - self._symbol, self._timeframe, self._strategy, event.ohlcv + self.symbol, self.timeframe, self.strategy, event.ohlcv ) if signal_event: + logger.debug(signal_event) + await self.tell(signal_event) diff --git a/strategy/generator/_factory.py b/strategy/generator/_factory.py index e5ec018e..ca856e55 100644 --- a/strategy/generator/_factory.py +++ b/strategy/generator/_factory.py @@ -3,7 +3,6 @@ from core.interfaces.abstract_strategy_generator_factory import ( AbstractStrategyGeneratorFactory, ) -from core.models.strategy import StrategyType from core.models.symbol import Symbol from core.models.timeframe import Timeframe @@ -11,8 +10,6 @@ class StrategyGeneratorFactory(AbstractStrategyGeneratorFactory): - _type = {StrategyType.TREND: TrendFollowStrategyGenerator} - def __init__( self, config_service: AbstractConfig, @@ -20,14 +17,7 @@ def __init__( super().__init__() self.config = config_service.get("generator") - def create( - self, type: StrategyType, symbols: list[Symbol] - ) -> AbstractStrategyGenerator: - if type not in self._type: - raise ValueError(f"Unknown StrategyType: {type}") - - generator_class = self._type.get(type) - + def create(self, symbols: list[Symbol]) -> AbstractStrategyGenerator: _symbols = [ symbol for symbol in symbols if symbol.name not in self.config["blacklist"] ] @@ -36,4 +26,6 @@ def create( Timeframe.from_raw(timeframe) for timeframe in self.config["timeframes"] ] - return generator_class(self.config["n_samples"], _symbols, _timeframes) + return TrendFollowStrategyGenerator( + self.config["n_samples"], _symbols, _timeframes + ) diff --git a/strategy/generator/baseline/ma.py b/strategy/generator/baseline/ma.py index 161e66f4..908c011b 100644 --- a/strategy/generator/baseline/ma.py +++ b/strategy/generator/baseline/ma.py @@ -4,6 +4,7 @@ from core.models.parameter import ( CategoricalParameter, Parameter, + RandomParameter, StaticParameter, ) from core.models.source import SourceType @@ -14,6 +15,6 @@ @dataclass(frozen=True) class MaBaseLine(BaseLine): type: BaseLineType = BaseLineType.Ma - source_type: Parameter = StaticParameter(SourceType.CLOSE) + source: Parameter = StaticParameter(SourceType.CLOSE) ma: Parameter = CategoricalParameter(MovingAverageType) - period: Parameter = StaticParameter(14.0) + period: Parameter = RandomParameter(12.0, 16.0) diff --git a/strategy/generator/bootstrap/_trend_follow.py b/strategy/generator/bootstrap/_trend_follow.py index b259a423..003b838d 100644 --- a/strategy/generator/bootstrap/_trend_follow.py +++ b/strategy/generator/bootstrap/_trend_follow.py @@ -1,4 +1,3 @@ -from dataclasses import replace from enum import Enum, auto from itertools import product from random import shuffle @@ -7,103 +6,105 @@ import numpy as np from core.interfaces.abstract_strategy_generator import AbstractStrategyGenerator -from core.models.candle import CandleTrendType -from core.models.moving_average import MovingAverageType -from core.models.parameter import CategoricalParameter, RandomParameter -from core.models.smooth import Smooth -from core.models.strategy import Strategy, StrategyType +from core.models.strategy import Strategy from core.models.symbol import Symbol from core.models.timeframe import Timeframe from strategy.generator.baseline.ma import MaBaseLine +from strategy.generator.confirm.bb import BbConfirm +from strategy.generator.confirm.braid import BraidConfirm +from strategy.generator.confirm.cc import CcConfirm from strategy.generator.confirm.cci import CciConfirm +from strategy.generator.confirm.didi import DidiConfirm from strategy.generator.confirm.dpo import DpoConfirm -from strategy.generator.confirm.dso import DsoConfirm +from strategy.generator.confirm.dumb import DumbConfirm from strategy.generator.confirm.eom import EomConfirm -from strategy.generator.confirm.roc import RocConfirm -from strategy.generator.confirm.rsi_neutrality import RsiNeutralityConfirm from strategy.generator.confirm.rsi_signalline import RsiSignalLineConfirm from strategy.generator.confirm.stc import StcConfirm -from strategy.generator.confirm.vi import ViConfirm -from strategy.generator.exit.cci import CciExit +from strategy.generator.confirm.wpr import WprConfirm from strategy.generator.exit.highlow import HighLowExit -from strategy.generator.exit.ma import MaExit -from strategy.generator.exit.mfi import MfiExit -from strategy.generator.exit.rsi import RsiExit +from strategy.generator.exit.mad import MadExit +from strategy.generator.exit.rex import RexExit from strategy.generator.exit.trix import TrixExit from strategy.generator.pulse.adx import AdxPulse -from strategy.generator.pulse.braid import BraidPulse from strategy.generator.pulse.chop import ChopPulse +from strategy.generator.pulse.dumb import DumbPulse from strategy.generator.pulse.nvol import NvolPulse +from strategy.generator.pulse.sqz import SqzPulse from strategy.generator.pulse.tdfi import TdfiPulse from strategy.generator.pulse.vo import VoPulse from strategy.generator.pulse.wae import WaePulse -from strategy.generator.signal.bb.macd_bb import MacdBbSignal -from strategy.generator.signal.bb.vwap_bb import VwapBbSignal +from strategy.generator.pulse.yz import YzPulse +from strategy.generator.signal.bb.macd import MacdBbSignal +from strategy.generator.signal.bb.vwap import VwapBbSignal from strategy.generator.signal.breakout.dch_two_ma import DchMa2BreakoutSignal -from strategy.generator.signal.flip.ce_flip import CeFlipSignal -from strategy.generator.signal.flip.supertrend_flip import SupertrendFlipSignal +from strategy.generator.signal.colorswitch.macd import MacdColorSwitchSignal +from strategy.generator.signal.contrarian.kch_a import KchASignal +from strategy.generator.signal.contrarian.kch_c import KchCSignal +from strategy.generator.signal.contrarian.rsi_c import RsiCSignal +from strategy.generator.signal.contrarian.rsi_d import RsiDSignal +from strategy.generator.signal.contrarian.rsi_nt import RsiNtSignal +from strategy.generator.signal.contrarian.rsi_u import RsiUSignal +from strategy.generator.signal.contrarian.rsi_v import RsiVSignal +from strategy.generator.signal.contrarian.snatr import SnatrSignal +from strategy.generator.signal.contrarian.stoch_e import StochESignal +from strategy.generator.signal.contrarian.tii_v import TiiVSignal +from strategy.generator.signal.flip.ce import CeFlipSignal +from strategy.generator.signal.flip.supertrend import SupertrendFlipSignal from strategy.generator.signal.ma.ma2_rsi import Ma2RsiSignal from strategy.generator.signal.ma.ma3_cross import Ma3CrossSignal from strategy.generator.signal.ma.ma_cross import MaCrossSignal +from strategy.generator.signal.ma.ma_quadruple import MaQuadrupleSignal +from strategy.generator.signal.ma.ma_surpass import MaSurpassSignal from strategy.generator.signal.ma.ma_testing_ground import MaTestingGroundSignal from strategy.generator.signal.ma.vwap_cross import VwapCrossSignal -from strategy.generator.signal.neutrality.dso_neutrality_cross import ( - DsoNeutralityCrossSignal, -) -from strategy.generator.signal.neutrality.rsi_neutrality_cross import ( - RsiNautralityCrossSignal, -) -from strategy.generator.signal.neutrality.rsi_neutrality_pullback import ( +from strategy.generator.signal.neutrality.dso_cross import DsoNeutralityCrossSignal +from strategy.generator.signal.neutrality.rsi_cross import RsiNautralityCrossSignal +from strategy.generator.signal.neutrality.rsi_pullback import ( RsiNautralityPullbackSignal, ) -from strategy.generator.signal.neutrality.rsi_neutrality_rejection import ( +from strategy.generator.signal.neutrality.rsi_rejection import ( RsiNautralityRejectionSignal, ) -from strategy.generator.signal.neutrality.tii_neutrality_cross import ( - TiiNeutralityCrossSignal, -) +from strategy.generator.signal.neutrality.tii_cross import TiiNeutralityCrossSignal from strategy.generator.signal.pattern.ao_saucer import AoSaucerSignal +from strategy.generator.signal.pattern.candle_reversal import CandlestickReversalSignal from strategy.generator.signal.pattern.candle_trend import CandlestickTrendSignal from strategy.generator.signal.pattern.hl import HighLowSignal -from strategy.generator.signal.pattern.macd_colorswitch import MacdColorSwitchSignal -from strategy.generator.signal.pattern.rsi_v import RsiVSignal -from strategy.generator.signal.pattern.tii_v import TiiVSignal -from strategy.generator.signal.reversal.dmi_reversal import DmiReversalSignal -from strategy.generator.signal.reversal.snatr_reversal import SnatrReversalSignal -from strategy.generator.signal.reversal.vi_reversal import ViReversalSignal -from strategy.generator.signal.signalline.di_signalline import DiSignalLineSignal -from strategy.generator.signal.signalline.dso_signalline import DsoSignalLineSignal -from strategy.generator.signal.signalline.kst_signalline import KstSignalLineSignal -from strategy.generator.signal.signalline.macd_signalline import MacdSignalLineSignal -from strategy.generator.signal.signalline.qstick_signalline import ( - QstickSignalLineSignal, -) -from strategy.generator.signal.signalline.rsi_signalline import RsiSignalLineSignal -from strategy.generator.signal.signalline.stoch_signalline import StochSignalLineSignal -from strategy.generator.signal.signalline.trix_signalline import TrixSignalLineSignal -from strategy.generator.signal.signalline.tsi_signalline import TsiSignalLineSignal -from strategy.generator.signal.zerocross.ao_zerocross import AoZeroCrossSignal -from strategy.generator.signal.zerocross.bop_zerocross import BopZeroCrossSignal -from strategy.generator.signal.zerocross.cc_zerocross import CcZeroCrossSignal -from strategy.generator.signal.zerocross.cfo_zerocross import CfoZeroCrossSignal -from strategy.generator.signal.zerocross.macd_zerocross import MacdZeroCrossSignal -from strategy.generator.signal.zerocross.qstick_zerocross import QstickZeroCrossSignal -from strategy.generator.signal.zerocross.roc_zerocross import RocZeroCrossSignal -from strategy.generator.signal.zerocross.trix_zerocross import TrixZeroCrossSignal -from strategy.generator.signal.zerocross.tsi_zerocross import TsiZeroCrossSignal +from strategy.generator.signal.signalline.di import DiSignalLineSignal +from strategy.generator.signal.signalline.dso import DsoSignalLineSignal +from strategy.generator.signal.signalline.kst import KstSignalLineSignal +from strategy.generator.signal.signalline.macd import MacdSignalLineSignal +from strategy.generator.signal.signalline.qstick import QstickSignalLineSignal +from strategy.generator.signal.signalline.rsi import RsiSignalLineSignal +from strategy.generator.signal.signalline.stoch import StochSignalLineSignal +from strategy.generator.signal.signalline.trix import TrixSignalLineSignal +from strategy.generator.signal.signalline.tsi import TsiSignalLineSignal +from strategy.generator.signal.twolinescross.dmi import Dmi2LinesCrossSignal +from strategy.generator.signal.twolinescross.vi import Vi2LinesCrossSignal +from strategy.generator.signal.zerocross.ao import AoZeroCrossSignal +from strategy.generator.signal.zerocross.bop import BopZeroCrossSignal +from strategy.generator.signal.zerocross.cc import CcZeroCrossSignal +from strategy.generator.signal.zerocross.cfo import CfoZeroCrossSignal +from strategy.generator.signal.zerocross.macd import MacdZeroCrossSignal +from strategy.generator.signal.zerocross.mad import MadZeroCrossSignal +from strategy.generator.signal.zerocross.qstick import QstickZeroCrossSignal +from strategy.generator.signal.zerocross.roc import RocZeroCrossSignal +from strategy.generator.signal.zerocross.trix import TrixZeroCrossSignal +from strategy.generator.signal.zerocross.tsi import TsiZeroCrossSignal from strategy.generator.stop_loss.atr import AtrStopLoss -from strategy.generator.stop_loss.dch import DchStopLoss class TrendSignalType(Enum): ZERO_CROSS = auto() SIGNAL_LINE = auto() + LINES_TWO_CROSS = auto() + CONTRARIAN = auto() BB = auto() PATTERN = auto() + COLOR_SWITCH = auto() FLIP = auto() MA = auto() BREAKOUT = auto() - REVERSAL = auto() NEUTRALITY = auto() @@ -175,47 +176,51 @@ def add_strategy(): def _generate_strategy(self): signal_groups = list(TrendSignalType) entry_signal = self._generate_signal(np.random.choice(signal_groups)) - baseline = np.random.choice([MaBaseLine()]) + baseline = np.random.choice( + [ + MaBaseLine(), + ] + ) confirm = np.random.choice( [ DpoConfirm(), EomConfirm(), - RocConfirm(), + WprConfirm(), + CciConfirm(), + BraidConfirm(), RsiSignalLineConfirm(), - RsiNeutralityConfirm(), + CcConfirm(), + DumbConfirm(), + BbConfirm(), + DidiConfirm(), StcConfirm(), - DsoConfirm(), - CciConfirm(), - ViConfirm(), ] ) pulse = np.random.choice( [ AdxPulse(), ChopPulse(), - BraidPulse(), VoPulse(), NvolPulse(), TdfiPulse(), WaePulse(), + YzPulse(), + SqzPulse(), + DumbPulse(), ] ) - stop_loss = np.random.choice([AtrStopLoss(), DchStopLoss()]) + stop_loss = np.random.choice([AtrStopLoss()]) exit_signal = np.random.choice( [ - # AstExit(), HighLowExit(), - MaExit(), - RsiExit(), - MfiExit(), - CciExit(), TrixExit(), + RexExit(), + MadExit(), ] ) return Strategy( *( - StrategyType.TREND, entry_signal, confirm, pulse, @@ -227,91 +232,139 @@ def _generate_strategy(self): def _generate_invariants(self, base_strategy: Strategy) -> List[Strategy]: result = [base_strategy] - attributes = [] - - def smooth_invariants(strategy_part): - if not hasattr(strategy_part, "smooth_type") or not hasattr( - strategy_part, "smooth_signal" - ): - return [] - - return [ - replace(strategy_part, smooth_type=CategoricalParameter(Smooth)) - for _ in range(5) - ] + [ - replace(strategy_part, smooth_signal=CategoricalParameter(Smooth)) - for _ in range(5) - ] - - def candle_invariants(strategy_part): - if not hasattr(strategy_part, "candle"): - return [] - - return [ - replace(strategy_part, candle=CategoricalParameter(CandleTrendType)) - for _ in range(5) - ] - - def period_invariants(strategy_part): - if not hasattr(strategy_part, "period"): - return [] - - return ( - [ - replace(strategy_part, period=RandomParameter(8.0, 20.0, 5.0)) - for _ in range(2) - ] - + [ - replace(strategy_part, period=RandomParameter(25.0, 50.0, 8.0)) - for _ in range(3) - ] - + [ - replace(strategy_part, period=RandomParameter(58.0, 100.0, 10.0)) - for _ in range(2) - ] - ) - - def ma_invariants(strategy_part): - if not hasattr(strategy_part, "ma"): - return [] - - return [ - replace(strategy_part, ma=CategoricalParameter(MovingAverageType)) - for _ in range(3) - ] - - def factor_invariants(strategy_part): - if not hasattr(strategy_part, "factor"): - return [] - - return [ - replace(strategy_part, factor=RandomParameter(1.0, 8.0, 0.5)) - for _ in range(3) - ] - - for attr in attributes: - for strategy in result[:]: - strategy_attr = getattr(strategy, attr) - - smoothed_parts = smooth_invariants(strategy_attr) - for part in smoothed_parts: - result.append(replace(strategy, **{attr: part})) - - ma_parts = ma_invariants(strategy_attr) - for part in ma_parts: - result.append(replace(strategy, **{attr: part})) - - candle_parts = candle_invariants(strategy_attr) - for part in candle_parts: - result.append(replace(strategy, **{attr: part})) - - factor_parts = factor_invariants(strategy_attr) - for part in factor_parts: - result.append(replace(strategy, **{attr: part})) - - period_parts = period_invariants(strategy_attr) - for part in period_parts: - result.append(replace(strategy, **{attr: part})) + # strategy_attributes = [] + + # def smooth_invariants(strategy_part, nums=8): + # smooth_attr = ["smooth_type", "smooth_signal", "smooth_bb"] + # replacements = [] + + # for attr in smooth_attr: + # if hasattr(strategy_part, attr): + # replacements.extend( + # [ + # replace( + # strategy_part, **{attr: CategoricalParameter(Smooth)} + # ) + # for _ in range(nums) + # ] + # ) + + # return replacements + + # def candle_invariants(strategy_part, nums=3): + # smooth_attr = ["candle"] + # replacements = [] + + # for attr in smooth_attr: + # if hasattr(strategy_part, attr): + # replacements.extend( + # [ + # replace( + # strategy_part, + # **{attr: CategoricalParameter(CandleTrendType)} + # ) + # for _ in range(nums) + # ] + # ) + + # return replacements + + # def period_invariants(strategy_part): + # replacements = [] + # period_replacement_ranges = [ + # ( + # "period", + # [ + # (RandomParameter(6.0, 20.0, 5.0), 8), + # (RandomParameter(25.0, 50.0, 8.0), 6), + # (RandomParameter(58.0, 100.0, 10.0), 3), + # ], + # ), + # ("atr_period", [(RandomParameter(0.2, 10.0, 0.1), 5)]), + # ] + + # for attr, replacement_ranges in period_replacement_ranges: + # if hasattr(strategy_part, attr): + # for range_params, num_replacements in replacement_ranges: + # replacements.extend( + # [ + # replace(strategy_part, **{attr: range_params}) + # for _ in range(num_replacements) + # ] + # ) + + # return replacements + + # def ma_invariants(strategy_part, nums=3): + # replacements = [] + + # if hasattr(strategy_part, "ma"): + # replacements.extend( + # [ + # replace( + # strategy_part, ma=CategoricalParameter(MovingAverageType) + # ) + # for _ in range(nums) + # ] + # ) + + # return replacements + + # def factor_invariants(strategy_part, nums=3): + # replacements = [] + + # if hasattr(strategy_part, "factor"): + # replacements.extend( + # [ + # replace(strategy_part, factor=RandomParameter(1.0, 8.0, 0.5)) + # for _ in range(nums) + # ] + # ) + + # return replacements + + # def source_invariants(strategy_part, nums=3): + # replacements = [] + + # if hasattr(strategy_part, "source_type"): + # replacements.extend( + # [ + # replace( + # strategy_part, source_type=CategoricalParameter(SourceType) + # ) + # for _ in range(nums) + # ] + # ) + + # return replacements + + # for attr in strategy_attributes: + # for strategy in result[:]: + # strategy_attr = getattr(strategy, attr) + + # source_parts = source_invariants(strategy_attr) + # for part in source_parts: + # result.append(replace(strategy, **{attr: part})) + + # smoothed_parts = smooth_invariants(strategy_attr) + # for part in smoothed_parts: + # result.append(replace(strategy, **{attr: part})) + + # ma_parts = ma_invariants(strategy_attr) + # for part in ma_parts: + # result.append(replace(strategy, **{attr: part})) + + # candle_parts = candle_invariants(strategy_attr) + # for part in candle_parts: + # result.append(replace(strategy, **{attr: part})) + + # period_parts = period_invariants(strategy_attr) + # for part in period_parts: + # result.append(replace(strategy, **{attr: part})) + + # factor_parts = factor_invariants(strategy_attr) + # for part in factor_parts: + # result.append(replace(strategy, **{attr: part})) return result @@ -320,7 +373,6 @@ def _generate_signal(self, signal: TrendSignalType): return np.random.choice( [ AoZeroCrossSignal(), - BopZeroCrossSignal(), MacdZeroCrossSignal(), RocZeroCrossSignal(), TsiZeroCrossSignal(), @@ -329,6 +381,7 @@ def _generate_signal(self, signal: TrendSignalType): CcZeroCrossSignal(), BopZeroCrossSignal(), CfoZeroCrossSignal(), + MadZeroCrossSignal(), ] ) if signal == TrendSignalType.SIGNAL_LINE: @@ -351,9 +404,29 @@ def _generate_signal(self, signal: TrendSignalType): AoSaucerSignal(), CandlestickTrendSignal(), HighLowSignal(), + ] + ) + if signal == TrendSignalType.COLOR_SWITCH: + return np.random.choice( + [ MacdColorSwitchSignal(), + ] + ) + + if signal == TrendSignalType.CONTRARIAN: + return np.random.choice( + [ TiiVSignal(), RsiVSignal(), + StochESignal(), + RsiDSignal(), + RsiCSignal(), + RsiNtSignal(), + RsiUSignal(), + SnatrSignal(), + CandlestickReversalSignal(), + KchCSignal(), + KchASignal(), ] ) if signal == TrendSignalType.BB: @@ -377,8 +450,8 @@ def _generate_signal(self, signal: TrendSignalType): VwapCrossSignal(), Ma2RsiSignal(), MaTestingGroundSignal(), - # MaQuadrupleSignal(), - # MaSurpassSignal(), + MaQuadrupleSignal(), + MaSurpassSignal(), MaCrossSignal(), ] ) @@ -388,12 +461,11 @@ def _generate_signal(self, signal: TrendSignalType): DchMa2BreakoutSignal(), ] ) - if signal == TrendSignalType.REVERSAL: + if signal == TrendSignalType.LINES_TWO_CROSS: return np.random.choice( [ - DmiReversalSignal(), - SnatrReversalSignal(), - ViReversalSignal(), + Dmi2LinesCrossSignal(), + Vi2LinesCrossSignal(), ] ) if signal == TrendSignalType.NEUTRALITY: @@ -409,6 +481,5 @@ def _generate_signal(self, signal: TrendSignalType): return np.random.choice( [ RsiNautralityRejectionSignal(), - MacdZeroCrossSignal(), ] ) diff --git a/strategy/generator/confirm/base.py b/strategy/generator/confirm/base.py index bb76f271..4728e816 100644 --- a/strategy/generator/confirm/base.py +++ b/strategy/generator/confirm/base.py @@ -5,16 +5,18 @@ class ConfirmType(Enum): + BbC = "BbC" + Braid = "Braid" Dumb = "Dumb" Dpo = "Dpo" - Dso = "Dso" Cci = "Cci" + Cc = "Cc" Eom = "Eom" - Roc = "Roc" RsiSignalLine = "RsiSignalLine" RsiNeutrality = "RsiNeutrality" Stc = "Stc" - Vi = "Vi" + Wpr = "Wpr" + Didi = "Didi" def __str__(self): return self.value.upper() diff --git a/strategy/generator/confirm/bb.py b/strategy/generator/confirm/bb.py new file mode 100644 index 00000000..80cff0c6 --- /dev/null +++ b/strategy/generator/confirm/bb.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import Smooth +from core.models.source import SourceType + +from .base import Confirm, ConfirmType + + +@dataclass(frozen=True) +class BbConfirm(Confirm): + type: Confirm = ConfirmType.BbC + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMA) + period: Parameter = StaticParameter(20.0) + factor: Parameter = StaticParameter(2.0) diff --git a/strategy/generator/confirm/braid.py b/strategy/generator/confirm/braid.py new file mode 100644 index 00000000..d908c9d0 --- /dev/null +++ b/strategy/generator/confirm/braid.py @@ -0,0 +1,22 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + CategoricalParameter, + Parameter, + StaticParameter, +) +from core.models.smooth import Smooth, SmoothATR + +from .base import Confirm, ConfirmType + + +@dataclass(frozen=True) +class BraidConfirm(Confirm): + type: Confirm = ConfirmType.Braid + smooth_type: Parameter = StaticParameter(Smooth.DEMA) + fast_period: Parameter = StaticParameter(3.0) + slow_period: Parameter = StaticParameter(14.0) + open_period: Parameter = StaticParameter(7.0) + strength: Parameter = StaticParameter(40.0) + smooth_atr: Parameter = CategoricalParameter(SmoothATR) + period_atr: Parameter = StaticParameter(14.0) diff --git a/strategy/generator/confirm/cc.py b/strategy/generator/confirm/cc.py new file mode 100644 index 00000000..3f87d255 --- /dev/null +++ b/strategy/generator/confirm/cc.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import Smooth +from core.models.source import SourceType + +from .base import Confirm, ConfirmType + + +@dataclass(frozen=True) +class CcConfirm(Confirm): + type: Confirm = ConfirmType.Cc + source: Parameter = StaticParameter(SourceType.CLOSE) + period_fast: Parameter = StaticParameter(14.0) + period_slow: Parameter = StaticParameter(28.0) + smooth: Parameter = StaticParameter(Smooth.WMA) + period_smooth: Parameter = StaticParameter(14.0) + smooth_signal: Parameter = StaticParameter(Smooth.SMA) + period_signal: Parameter = StaticParameter(14.0) diff --git a/strategy/generator/confirm/cci.py b/strategy/generator/confirm/cci.py index dbec2a28..ea54c599 100644 --- a/strategy/generator/confirm/cci.py +++ b/strategy/generator/confirm/cci.py @@ -3,13 +3,15 @@ from core.models.parameter import Parameter, StaticParameter from core.models.smooth import Smooth from core.models.source import SourceType -from strategy.generator.confirm.base import Confirm, ConfirmType + +from .base import Confirm, ConfirmType @dataclass(frozen=True) class CciConfirm(Confirm): type: Confirm = ConfirmType.Cci - source_type: Parameter = StaticParameter(SourceType.HLC3) - smooth_type: Parameter = StaticParameter(Smooth.SMA) + source: Parameter = StaticParameter(SourceType.HLC3) period: Parameter = StaticParameter(100.0) factor: Parameter = StaticParameter(0.015) + smooth: Parameter = StaticParameter(Smooth.EMA) + period_smooth: Parameter = StaticParameter(14.0) diff --git a/strategy/generator/confirm/didi.py b/strategy/generator/confirm/didi.py new file mode 100644 index 00000000..2ae5f229 --- /dev/null +++ b/strategy/generator/confirm/didi.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import Smooth +from core.models.source import SourceType + +from .base import Confirm, ConfirmType + + +@dataclass(frozen=True) +class DidiConfirm(Confirm): + type: Confirm = ConfirmType.Didi + source: Parameter = StaticParameter(SourceType.HL2) + smooth: Parameter = StaticParameter(Smooth.SMA) + period_medium: Parameter = StaticParameter(8.0) + period_slow: Parameter = StaticParameter(40.0) + smooth_signal: Parameter = StaticParameter(Smooth.EMA) + period_signal: Parameter = StaticParameter(3.0) diff --git a/strategy/generator/confirm/dpo.py b/strategy/generator/confirm/dpo.py index f05196ea..e7012f5d 100644 --- a/strategy/generator/confirm/dpo.py +++ b/strategy/generator/confirm/dpo.py @@ -12,4 +12,4 @@ class DpoConfirm(Confirm): type: Confirm = ConfirmType.Dpo source_type: Parameter = StaticParameter(SourceType.CLOSE) smooth_type: Parameter = StaticParameter(Smooth.SMA) - period: Parameter = StaticParameter(18.0) + period: Parameter = StaticParameter(27.0) diff --git a/strategy/generator/confirm/dso.py b/strategy/generator/confirm/dso.py deleted file mode 100644 index de20cdd9..00000000 --- a/strategy/generator/confirm/dso.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass - -from core.models.parameter import Parameter, StaticParameter -from core.models.smooth import Smooth -from core.models.source import SourceType - -from .base import Confirm, ConfirmType - - -@dataclass(frozen=True) -class DsoConfirm(Confirm): - type: Confirm = ConfirmType.Dso - source_type: Parameter = StaticParameter(SourceType.CLOSE) - smooth_type: Parameter = StaticParameter(Smooth.EMA) - smooth_period: Parameter = StaticParameter(10.0) - k_period: Parameter = StaticParameter(5.0) - d_period: Parameter = StaticParameter(7.0) diff --git a/strategy/generator/confirm/eom.py b/strategy/generator/confirm/eom.py index 35d705c9..bc578efb 100644 --- a/strategy/generator/confirm/eom.py +++ b/strategy/generator/confirm/eom.py @@ -12,5 +12,4 @@ class EomConfirm(Confirm): type: ConfirmType = ConfirmType.Eom source_type: Parameter = StaticParameter(SourceType.HL2) smooth_type: Parameter = StaticParameter(Smooth.SMA) - period: Parameter = StaticParameter(14.0) - divisor: Parameter = StaticParameter(10000.0) + period: Parameter = StaticParameter(16.0) diff --git a/strategy/generator/confirm/roc.py b/strategy/generator/confirm/roc.py deleted file mode 100644 index ff0aa988..00000000 --- a/strategy/generator/confirm/roc.py +++ /dev/null @@ -1,13 +0,0 @@ -from dataclasses import dataclass - -from core.models.parameter import Parameter, StaticParameter -from core.models.source import SourceType - -from .base import Confirm, ConfirmType - - -@dataclass(frozen=True) -class RocConfirm(Confirm): - type: Confirm = ConfirmType.Roc - source_type: Parameter = StaticParameter(SourceType.CLOSE) - period: Parameter = StaticParameter(21.0) diff --git a/strategy/generator/confirm/rsi_neutrality.py b/strategy/generator/confirm/rsi_neutrality.py index d04e4306..7a7cc3d7 100644 --- a/strategy/generator/confirm/rsi_neutrality.py +++ b/strategy/generator/confirm/rsi_neutrality.py @@ -12,5 +12,5 @@ class RsiNeutralityConfirm(Confirm): type: Confirm = ConfirmType.RsiNeutrality source_type: Parameter = StaticParameter(SourceType.CLOSE) smooth_type: Parameter = StaticParameter(Smooth.SMMA) - period: Parameter = StaticParameter(14.0) + period: Parameter = StaticParameter(28.0) threshold: Parameter = RandomParameter(0.0, 3.0, 1.0) diff --git a/strategy/generator/confirm/rsi_signalline.py b/strategy/generator/confirm/rsi_signalline.py index 8624c72d..b9315c10 100644 --- a/strategy/generator/confirm/rsi_signalline.py +++ b/strategy/generator/confirm/rsi_signalline.py @@ -12,7 +12,7 @@ class RsiSignalLineConfirm(Confirm): type: Confirm = ConfirmType.RsiSignalLine source_type: Parameter = StaticParameter(SourceType.CLOSE) smooth_type: Parameter = StaticParameter(Smooth.SMMA) - period: Parameter = StaticParameter(14.0) - smooth_signal: Parameter = StaticParameter(Smooth.WMA) - smooth_period: Parameter = RandomParameter(7.0, 10.0, 1.0) + period: Parameter = StaticParameter(18.0) + smooth_signal: Parameter = StaticParameter(Smooth.HMA) + smooth_period: Parameter = StaticParameter(10.0) threshold: Parameter = RandomParameter(0.0, 3.0, 1.0) diff --git a/strategy/generator/confirm/vi.py b/strategy/generator/confirm/vi.py deleted file mode 100644 index e77f7bf8..00000000 --- a/strategy/generator/confirm/vi.py +++ /dev/null @@ -1,11 +0,0 @@ -from dataclasses import dataclass - -from core.models.parameter import Parameter, StaticParameter -from strategy.generator.confirm.base import Confirm, ConfirmType - - -@dataclass(frozen=True) -class ViConfirm(Confirm): - type: ConfirmType = ConfirmType.Vi - atr_period: Parameter = StaticParameter(1.0) - period: Parameter = StaticParameter(21.0) diff --git a/strategy/generator/confirm/wpr.py b/strategy/generator/confirm/wpr.py new file mode 100644 index 00000000..fc5506d1 --- /dev/null +++ b/strategy/generator/confirm/wpr.py @@ -0,0 +1,16 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import Smooth +from core.models.source import SourceType + +from .base import Confirm, ConfirmType + + +@dataclass(frozen=True) +class WprConfirm(Confirm): + type: Confirm = ConfirmType.Wpr + source: Parameter = StaticParameter(SourceType.CLOSE) + period: Parameter = StaticParameter(28.0) + smooth_signal: Parameter = StaticParameter(Smooth.SMA) + period_signal: Parameter = StaticParameter(14.0) diff --git a/strategy/generator/exit/ast.py b/strategy/generator/exit/ast.py index 1d64fa7e..43d21078 100644 --- a/strategy/generator/exit/ast.py +++ b/strategy/generator/exit/ast.py @@ -1,9 +1,11 @@ from dataclasses import dataclass from core.models.parameter import ( + CategoricalParameter, Parameter, StaticParameter, ) +from core.models.smooth import SmoothATR from core.models.source import SourceType from .base import Exit, ExitType @@ -13,5 +15,6 @@ class AstExit(Exit): type: ExitType = ExitType.Ast source_type: Parameter = StaticParameter(SourceType.CLOSE) - atr_period: Parameter = StaticParameter(12.0) + smooth_atr: Parameter = CategoricalParameter(SmoothATR) + period_atr: Parameter = StaticParameter(12.0) factor: Parameter = StaticParameter(3.0) diff --git a/strategy/generator/exit/base.py b/strategy/generator/exit/base.py index 4688696b..3ead4471 100644 --- a/strategy/generator/exit/base.py +++ b/strategy/generator/exit/base.py @@ -6,13 +6,14 @@ class ExitType(Enum): Ast = "Ast" + Mad = "Mad" Dumb = "Dumb" HighLow = "HighLow" Rsi = "Rsi" Ma = "Ma" Mfi = "Mfi" - Cci = "Cci" Trix = "Trix" + Rex = "Rex" def __str__(self): return self.value.upper() diff --git a/strategy/generator/exit/cci.py b/strategy/generator/exit/cci.py deleted file mode 100644 index 78197bf6..00000000 --- a/strategy/generator/exit/cci.py +++ /dev/null @@ -1,17 +0,0 @@ -from dataclasses import dataclass - -from core.models.parameter import Parameter, RandomParameter, StaticParameter -from core.models.smooth import Smooth -from core.models.source import SourceType - -from .base import Exit, ExitType - - -@dataclass(frozen=True) -class CciExit(Exit): - type: ExitType = ExitType.Cci - source_type: Parameter = StaticParameter(SourceType.HLC3) - smooth_type: Parameter = StaticParameter(Smooth.SMA) - period: Parameter = StaticParameter(8.0) - factor: Parameter = StaticParameter(0.015) - threshold: Parameter = RandomParameter(0.0, 5.0, 1.0) diff --git a/strategy/generator/exit/mad.py b/strategy/generator/exit/mad.py new file mode 100644 index 00000000..a1f95a0b --- /dev/null +++ b/strategy/generator/exit/mad.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + StaticParameter, +) +from core.models.source import SourceType + +from .base import Exit, ExitType + + +@dataclass(frozen=True) +class MadExit(Exit): + type: ExitType = ExitType.Mad + source: Parameter = StaticParameter(SourceType.CLOSE) + period_fast: Parameter = StaticParameter(8.0) + period_slow: Parameter = StaticParameter(23.0) diff --git a/strategy/generator/exit/rex.py b/strategy/generator/exit/rex.py new file mode 100644 index 00000000..68cb939c --- /dev/null +++ b/strategy/generator/exit/rex.py @@ -0,0 +1,20 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + StaticParameter, +) +from core.models.smooth import Smooth +from core.models.source import SourceType + +from .base import Exit, ExitType + + +@dataclass(frozen=True) +class RexExit(Exit): + type: ExitType = ExitType.Rex + source: Parameter = StaticParameter(SourceType.HL2) + smooth: Parameter = StaticParameter(Smooth.LSMA) + period: Parameter = StaticParameter(10.0) + smooth_signal: Parameter = StaticParameter(Smooth.TEMA) + period_signal: Parameter = StaticParameter(5.0) diff --git a/strategy/generator/pulse/adx.py b/strategy/generator/pulse/adx.py index 2f57a9c7..20f89ef8 100644 --- a/strategy/generator/pulse/adx.py +++ b/strategy/generator/pulse/adx.py @@ -9,7 +9,7 @@ @dataclass(frozen=True) class AdxPulse(Pulse): type: PulseType = PulseType.Adx - smooth_type: Parameter = StaticParameter(Smooth.SMMA) - adx_period: Parameter = StaticParameter(15.0) - di_period: Parameter = StaticParameter(15.0) + smooth: Parameter = StaticParameter(Smooth.SMMA) + period: Parameter = StaticParameter(15.0) + period_di: Parameter = StaticParameter(15.0) threshold: Parameter = RandomParameter(0.0, 5.0, 1.0) diff --git a/strategy/generator/pulse/base.py b/strategy/generator/pulse/base.py index 3ee61397..251441c7 100644 --- a/strategy/generator/pulse/base.py +++ b/strategy/generator/pulse/base.py @@ -6,13 +6,14 @@ class PulseType(Enum): Adx = "Adx" - Braid = "Braid" Dumb = "Dumb" Chop = "Chop" Nvol = "Nvol" Vo = "Vo" Tdfi = "Tdfi" Wae = "Wae" + Yz = "Yz" + Sqz = "Sqz" def __str__(self): return self.value.upper() diff --git a/strategy/generator/pulse/braid.py b/strategy/generator/pulse/braid.py deleted file mode 100644 index cc81bb26..00000000 --- a/strategy/generator/pulse/braid.py +++ /dev/null @@ -1,20 +0,0 @@ -from dataclasses import dataclass - -from core.models.parameter import ( - Parameter, - StaticParameter, -) -from core.models.smooth import Smooth - -from .base import Pulse, PulseType - - -@dataclass(frozen=True) -class BraidPulse(Pulse): - type: PulseType = PulseType.Braid - smooth_type: Parameter = StaticParameter(Smooth.WMA) - fast_period: Parameter = StaticParameter(3.0) - slow_period: Parameter = StaticParameter(14.0) - open_period: Parameter = StaticParameter(7.0) - strength: Parameter = StaticParameter(40.0) - atr_period: Parameter = StaticParameter(14.0) diff --git a/strategy/generator/pulse/chop.py b/strategy/generator/pulse/chop.py index 64537028..3725698b 100644 --- a/strategy/generator/pulse/chop.py +++ b/strategy/generator/pulse/chop.py @@ -1,6 +1,12 @@ from dataclasses import dataclass -from core.models.parameter import Parameter, RandomParameter, StaticParameter +from core.models.parameter import ( + CategoricalParameter, + Parameter, + RandomParameter, + StaticParameter, +) +from core.models.smooth import SmoothATR from .base import Pulse, PulseType @@ -8,6 +14,7 @@ @dataclass(frozen=True) class ChopPulse(Pulse): type: PulseType = PulseType.Chop - atr_period: Parameter = StaticParameter(1.0) - period: Parameter = StaticParameter(14.0) + period: Parameter = StaticParameter(9.0) + smooth_atr: Parameter = CategoricalParameter(SmoothATR) + period_atr: Parameter = StaticParameter(1.0) threshold: Parameter = RandomParameter(0.0, 5.0, 1.0) diff --git a/strategy/generator/pulse/nvol.py b/strategy/generator/pulse/nvol.py index eff1659d..0701c9a8 100644 --- a/strategy/generator/pulse/nvol.py +++ b/strategy/generator/pulse/nvol.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from core.models.parameter import Parameter, StaticParameter +from core.models.parameter import CategoricalParameter, Parameter, StaticParameter from core.models.smooth import Smooth from .base import Pulse, PulseType @@ -9,5 +9,5 @@ @dataclass(frozen=True) class NvolPulse(Pulse): type: PulseType = PulseType.Nvol - smooth_type: Parameter = StaticParameter(Smooth.SMA) + smooth: Parameter = CategoricalParameter(Smooth) period: Parameter = StaticParameter(14.0) diff --git a/strategy/generator/pulse/sqz.py b/strategy/generator/pulse/sqz.py new file mode 100644 index 00000000..d404cd2f --- /dev/null +++ b/strategy/generator/pulse/sqz.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from core.models.parameter import CategoricalParameter, Parameter, StaticParameter +from core.models.smooth import Smooth, SmoothATR +from core.models.source import SourceType + +from .base import Pulse, PulseType + + +@dataclass(frozen=True) +class SqzPulse(Pulse): + type: PulseType = PulseType.Sqz + source: Parameter = CategoricalParameter(SourceType) + smooth: Parameter = StaticParameter(Smooth.SMA) + period: Parameter = StaticParameter(20.0) + smooth_atr: Parameter = CategoricalParameter(SmoothATR) + period_atr: Parameter = StaticParameter(20.0) + factor_bb: Parameter = StaticParameter(2.0) + factor_kch: Parameter = StaticParameter(1.2) diff --git a/strategy/generator/pulse/tdfi.py b/strategy/generator/pulse/tdfi.py index ec6b7bf7..19223675 100644 --- a/strategy/generator/pulse/tdfi.py +++ b/strategy/generator/pulse/tdfi.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from core.models.parameter import Parameter, StaticParameter +from core.models.parameter import CategoricalParameter, Parameter, StaticParameter from core.models.smooth import Smooth from core.models.source import SourceType @@ -10,7 +10,7 @@ @dataclass(frozen=True) class TdfiPulse(Pulse): type: PulseType = PulseType.Tdfi - source_type: Parameter = StaticParameter(SourceType.CLOSE) - smooth_type: Parameter = StaticParameter(Smooth.EMA) - period: Parameter = StaticParameter(14.0) + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = CategoricalParameter(Smooth) + period: Parameter = StaticParameter(8.0) n: Parameter = StaticParameter(3.0) diff --git a/strategy/generator/pulse/vo.py b/strategy/generator/pulse/vo.py index 1d0ff209..4cd56432 100644 --- a/strategy/generator/pulse/vo.py +++ b/strategy/generator/pulse/vo.py @@ -1,6 +1,6 @@ from dataclasses import dataclass -from core.models.parameter import Parameter, StaticParameter +from core.models.parameter import CategoricalParameter, Parameter, StaticParameter from core.models.smooth import Smooth from .base import Pulse, PulseType @@ -9,6 +9,6 @@ @dataclass(frozen=True) class VoPulse(Pulse): type: PulseType = PulseType.Vo - smooth_type: Parameter = StaticParameter(Smooth.EMA) - fast_period: Parameter = StaticParameter(7.0) - slow_period: Parameter = StaticParameter(13.0) + smooth: Parameter = CategoricalParameter(Smooth) + period_fast: Parameter = StaticParameter(7.0) + period_slow: Parameter = StaticParameter(13.0) diff --git a/strategy/generator/pulse/wae.py b/strategy/generator/pulse/wae.py index 3638ef2a..c015ca5f 100644 --- a/strategy/generator/pulse/wae.py +++ b/strategy/generator/pulse/wae.py @@ -1,7 +1,8 @@ from dataclasses import dataclass -from core.models.parameter import Parameter, StaticParameter +from core.models.parameter import CategoricalParameter, Parameter, StaticParameter from core.models.smooth import Smooth +from core.models.source import SourceType from .base import Pulse, PulseType @@ -9,12 +10,11 @@ @dataclass(frozen=True) class WaePulse(Pulse): type: PulseType = PulseType.Wae - smooth_type: Parameter = StaticParameter(Smooth.EMA) - fast_period: Parameter = StaticParameter(15.0) - slow_period: Parameter = StaticParameter(30.0) - smooth_bb: Parameter = StaticParameter(Smooth.SMA) - bb_period: Parameter = StaticParameter(15.0) - factor: Parameter = StaticParameter(2.0) - strength: Parameter = StaticParameter(150.0) - atr_period: Parameter = StaticParameter(100.0) - dz_factor: Parameter = StaticParameter(3.7) + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = CategoricalParameter(Smooth) + period_fast: Parameter = StaticParameter(10.0) + period_slow: Parameter = StaticParameter(29.0) + smooth_bb: Parameter = CategoricalParameter(Smooth) + period_bb: Parameter = StaticParameter(13.0) + factor: Parameter = StaticParameter(1.2) + strength: Parameter = StaticParameter(69.0) diff --git a/strategy/generator/pulse/yz.py b/strategy/generator/pulse/yz.py new file mode 100644 index 00000000..549ac379 --- /dev/null +++ b/strategy/generator/pulse/yz.py @@ -0,0 +1,14 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import Smooth + +from .base import Pulse, PulseType + + +@dataclass(frozen=True) +class YzPulse(Pulse): + type: PulseType = PulseType.Yz + period: Parameter = StaticParameter(60.0) + smooth_signal: Parameter = StaticParameter(Smooth.HMA) + period_signal: Parameter = StaticParameter(12.0) diff --git a/strategy/generator/signal/base.py b/strategy/generator/signal/base.py index 09d4969c..332f228a 100644 --- a/strategy/generator/signal/base.py +++ b/strategy/generator/signal/base.py @@ -13,6 +13,7 @@ class SignalType(Enum): CfoZeroCross = "CfoZeroCross" DiZeroCross = "DiZeroCross" MacdZeroCross = "MacdZeroCross" + MadZeroCross = "MadZeroCross" QstickZeroCross = "QstickZeroCross" RocZeroCross = "RocZeroCross" TrixZeroCross = "TrixZeroCross" @@ -32,18 +33,30 @@ class SignalType(Enum): VwapBb = "VwapBb" # Pattern AoSaucer = "AoSaucer" + Spread = "Spread" HighLow = "HighLow" MacdColorSwitch = "MacdColorSwitch" + CandlestickTrend = "CandlestickTrend" + CandlestickReversal = "CandlestickReversal" + # Contrarian + Snatr = "Snatr" + RsiC = "RsiC" + RsiD = "RsiD" RsiV = "RsiV" + RsiU = "RsiU" + RsiNt = "RsiNt" TiiV = "TiiV" - CandlestickTrend = "CandlestickTrend" + StochE = "StochE" + KchA = "KchA" + KchC = "KchC" # Flip CeFlip = "CeFlip" SupFlip = "SupFlip" - # Reversal - DmiReversal = "DmiReversal" - SnatrReversal = "SnatrReversal" - ViReversal = "ViReversal" + # Pullback + SupPullback = "SupPullback" + # Two Lines Cross + Dmi2LinesCross = "Dmi2LinesCross" + Vi2LinesCross = "Vi2LinesCross" # Ma Ma3Cross = "Ma3Cross" MaTestingGround = "MaTestingGround" diff --git a/strategy/generator/signal/bb/macd_bb.py b/strategy/generator/signal/bb/macd.py similarity index 94% rename from strategy/generator/signal/bb/macd_bb.py rename to strategy/generator/signal/bb/macd.py index 95ff8302..0f3e4080 100644 --- a/strategy/generator/signal/bb/macd_bb.py +++ b/strategy/generator/signal/bb/macd.py @@ -19,4 +19,4 @@ class MacdBbSignal(Signal): signal_period: Parameter = StaticParameter(9.0) bb_smooth: Parameter = StaticParameter(Smooth.SMA) bb_period: Parameter = StaticParameter(9.0) - factor: Parameter = StaticParameter(0.8) + factor: Parameter = StaticParameter(0.6) diff --git a/strategy/generator/signal/bb/vwap_bb.py b/strategy/generator/signal/bb/vwap.py similarity index 100% rename from strategy/generator/signal/bb/vwap_bb.py rename to strategy/generator/signal/bb/vwap.py diff --git a/strategy/generator/signal/reversal/__init__.py b/strategy/generator/signal/colorswitch/__init__.py similarity index 100% rename from strategy/generator/signal/reversal/__init__.py rename to strategy/generator/signal/colorswitch/__init__.py diff --git a/strategy/generator/signal/pattern/macd_colorswitch.py b/strategy/generator/signal/colorswitch/macd.py similarity index 78% rename from strategy/generator/signal/pattern/macd_colorswitch.py rename to strategy/generator/signal/colorswitch/macd.py index 423f458c..768e9794 100644 --- a/strategy/generator/signal/pattern/macd_colorswitch.py +++ b/strategy/generator/signal/colorswitch/macd.py @@ -1,7 +1,6 @@ from dataclasses import dataclass from core.models.parameter import ( - CategoricalParameter, Parameter, StaticParameter, ) @@ -13,8 +12,8 @@ @dataclass(frozen=True) class MacdColorSwitchSignal(Signal): type: SignalType = SignalType.MacdColorSwitch - source_type: Parameter = CategoricalParameter(SourceType) - smooth_type: Parameter = CategoricalParameter(Smooth) + source_type: Parameter = StaticParameter(SourceType.HLC3) + smooth_type: Parameter = StaticParameter(Smooth.SMA) fast_period: Parameter = StaticParameter(12.0) slow_period: Parameter = StaticParameter(26.0) signal_period: Parameter = StaticParameter(9.0) diff --git a/ta_lib/core/src/distance.rs b/strategy/generator/signal/contrarian/__init__.py similarity index 100% rename from ta_lib/core/src/distance.rs rename to strategy/generator/signal/contrarian/__init__.py diff --git a/strategy/generator/signal/contrarian/kch_a.py b/strategy/generator/signal/contrarian/kch_a.py new file mode 100644 index 00000000..b2fc3360 --- /dev/null +++ b/strategy/generator/signal/contrarian/kch_a.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, RandomParameter, StaticParameter +from core.models.smooth import Smooth, SmoothATR +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class KchASignal(Signal): + type: SignalType = SignalType.KchA + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.UTLS) + period: Parameter = RandomParameter(20.0, 60.0, 10.0) + smooth_atr: Parameter = StaticParameter(SmoothATR.UTLS) + period_atr: Parameter = RandomParameter(20.0, 80.0, 10.0) + factor: Parameter = RandomParameter(0.3, 2.0, 0.1) diff --git a/strategy/generator/signal/contrarian/kch_c.py b/strategy/generator/signal/contrarian/kch_c.py new file mode 100644 index 00000000..110216fa --- /dev/null +++ b/strategy/generator/signal/contrarian/kch_c.py @@ -0,0 +1,17 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import Smooth, SmoothATR +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class KchCSignal(Signal): + type: SignalType = SignalType.KchC + source: Parameter = StaticParameter(SourceType.HLC3) + smooth: Parameter = StaticParameter(Smooth.EMA) + period: Parameter = StaticParameter(20.0) + smooth_atr: Parameter = StaticParameter(SmoothATR.SMMA) + period_atr: Parameter = StaticParameter(20.0) + factor: Parameter = StaticParameter(1.0) diff --git a/strategy/generator/signal/contrarian/rsi_c.py b/strategy/generator/signal/contrarian/rsi_c.py new file mode 100644 index 00000000..95ab0ca0 --- /dev/null +++ b/strategy/generator/signal/contrarian/rsi_c.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + RandomParameter, + StaticParameter, +) +from core.models.smooth import Smooth +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class RsiCSignal(Signal): + type: SignalType = SignalType.RsiC + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMMA) + period: Parameter = StaticParameter(14.0) + threshold: Parameter = RandomParameter(0.0, 3.0, 1.0) diff --git a/strategy/generator/signal/contrarian/rsi_d.py b/strategy/generator/signal/contrarian/rsi_d.py new file mode 100644 index 00000000..9cdc22b3 --- /dev/null +++ b/strategy/generator/signal/contrarian/rsi_d.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + RandomParameter, + StaticParameter, +) +from core.models.smooth import Smooth +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class RsiDSignal(Signal): + type: SignalType = SignalType.RsiD + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMMA) + period: Parameter = StaticParameter(8.0) + threshold: Parameter = RandomParameter(0.0, 3.0, 1.0) diff --git a/strategy/generator/signal/contrarian/rsi_nt.py b/strategy/generator/signal/contrarian/rsi_nt.py new file mode 100644 index 00000000..5111af3f --- /dev/null +++ b/strategy/generator/signal/contrarian/rsi_nt.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + RandomParameter, + StaticParameter, +) +from core.models.smooth import Smooth +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class RsiNtSignal(Signal): + type: SignalType = SignalType.RsiNt + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMMA) + period: Parameter = StaticParameter(8.0) + threshold: Parameter = RandomParameter(0.0, 1.0, 1.0) diff --git a/strategy/generator/signal/contrarian/rsi_u.py b/strategy/generator/signal/contrarian/rsi_u.py new file mode 100644 index 00000000..ad61dfcf --- /dev/null +++ b/strategy/generator/signal/contrarian/rsi_u.py @@ -0,0 +1,19 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + RandomParameter, + StaticParameter, +) +from core.models.smooth import Smooth +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class RsiUSignal(Signal): + type: SignalType = SignalType.RsiU + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMMA) + period: Parameter = StaticParameter(8.0) + threshold: Parameter = RandomParameter(0.0, 3.0, 1.0) diff --git a/strategy/generator/signal/pattern/rsi_v.py b/strategy/generator/signal/contrarian/rsi_v.py similarity index 71% rename from strategy/generator/signal/pattern/rsi_v.py rename to strategy/generator/signal/contrarian/rsi_v.py index 96e7fb5f..7dd8d6e5 100644 --- a/strategy/generator/signal/pattern/rsi_v.py +++ b/strategy/generator/signal/contrarian/rsi_v.py @@ -13,7 +13,7 @@ @dataclass(frozen=True) class RsiVSignal(Signal): type: SignalType = SignalType.RsiV - source_type: Parameter = StaticParameter(SourceType.CLOSE) - smooth_type: Parameter = StaticParameter(Smooth.SMMA) - rsi_period: Parameter = StaticParameter(8.0) + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMMA) + period: Parameter = StaticParameter(8.0) threshold: Parameter = RandomParameter(0.0, 3.0, 1.0) diff --git a/strategy/generator/signal/reversal/snatr_reversal.py b/strategy/generator/signal/contrarian/snatr.py similarity index 84% rename from strategy/generator/signal/reversal/snatr_reversal.py rename to strategy/generator/signal/contrarian/snatr.py index 37ec036b..64514ac8 100644 --- a/strategy/generator/signal/reversal/snatr_reversal.py +++ b/strategy/generator/signal/contrarian/snatr.py @@ -6,8 +6,8 @@ @dataclass(frozen=True) -class SnatrReversalSignal(Signal): - type: SignalType = SignalType.SnatrReversal +class SnatrSignal(Signal): + type: SignalType = SignalType.Snatr smooth_type: Parameter = StaticParameter(Smooth.WMA) atr_period: Parameter = StaticParameter(60.0) atr_smooth_period: Parameter = StaticParameter(13.0) diff --git a/strategy/generator/signal/contrarian/stoch_e.py b/strategy/generator/signal/contrarian/stoch_e.py new file mode 100644 index 00000000..e813fd60 --- /dev/null +++ b/strategy/generator/signal/contrarian/stoch_e.py @@ -0,0 +1,21 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + RandomParameter, + StaticParameter, +) +from core.models.smooth import Smooth +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class StochESignal(Signal): + type: SignalType = SignalType.StochE + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMA) + period: Parameter = StaticParameter(34.0) + period_k: Parameter = StaticParameter(5.0) + period_d: Parameter = StaticParameter(3.0) + threshold: Parameter = RandomParameter(0.0, 1.0, 1.0) diff --git a/strategy/generator/signal/pattern/tii_v.py b/strategy/generator/signal/contrarian/tii_v.py similarity index 78% rename from strategy/generator/signal/pattern/tii_v.py rename to strategy/generator/signal/contrarian/tii_v.py index da83a759..e7fdcd50 100644 --- a/strategy/generator/signal/pattern/tii_v.py +++ b/strategy/generator/signal/contrarian/tii_v.py @@ -9,7 +9,7 @@ @dataclass(frozen=True) class TiiVSignal(Signal): type: SignalType = SignalType.TiiV - source_type: Parameter = StaticParameter(SourceType.CLOSE) - smooth_type: Parameter = StaticParameter(Smooth.SMA) + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMA) major_period: Parameter = StaticParameter(8.0) minor_period: Parameter = StaticParameter(2.0) diff --git a/strategy/generator/signal/flip/ce_flip.py b/strategy/generator/signal/flip/ce.py similarity index 58% rename from strategy/generator/signal/flip/ce_flip.py rename to strategy/generator/signal/flip/ce.py index 5eddb204..c9bd786c 100644 --- a/strategy/generator/signal/flip/ce_flip.py +++ b/strategy/generator/signal/flip/ce.py @@ -4,12 +4,16 @@ Parameter, StaticParameter, ) +from core.models.smooth import SmoothATR +from core.models.source import SourceType from strategy.generator.signal.base import Signal, SignalType @dataclass(frozen=True) class CeFlipSignal(Signal): type: SignalType = SignalType.CeFlip + source_type: Parameter = StaticParameter(SourceType.CLOSE) period: Parameter = StaticParameter(22.0) - atr_period: Parameter = StaticParameter(22.0) + smooth_atr: Parameter = StaticParameter(SmoothATR.SMMA) + period_atr: Parameter = StaticParameter(22.0) factor: Parameter = StaticParameter(3.0) diff --git a/strategy/generator/signal/flip/supertrend_flip.py b/strategy/generator/signal/flip/supertrend.py similarity index 65% rename from strategy/generator/signal/flip/supertrend_flip.py rename to strategy/generator/signal/flip/supertrend.py index 950d4903..a3b7ebab 100644 --- a/strategy/generator/signal/flip/supertrend_flip.py +++ b/strategy/generator/signal/flip/supertrend.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import SmoothATR from core.models.source import SourceType from strategy.generator.signal.base import Signal, SignalType @@ -9,5 +10,6 @@ class SupertrendFlipSignal(Signal): type: SignalType = SignalType.SupFlip source_type: Parameter = StaticParameter(SourceType.HL2) - atr_period: Parameter = StaticParameter(10.0) - factor: Parameter = StaticParameter(2.0) + smooth_atr: Parameter = StaticParameter(SmoothATR.SMMA) + period_atr: Parameter = StaticParameter(8.0) + factor: Parameter = StaticParameter(0.86) diff --git a/strategy/generator/signal/ma/ma_cross.py b/strategy/generator/signal/ma/ma_cross.py index f49c7c01..2f80f9bc 100644 --- a/strategy/generator/signal/ma/ma_cross.py +++ b/strategy/generator/signal/ma/ma_cross.py @@ -12,4 +12,4 @@ class MaCrossSignal(Signal): type: SignalType = SignalType.MaCross source_type: Parameter = StaticParameter(SourceType.CLOSE) ma: Parameter = CategoricalParameter(MovingAverageType) - period: Parameter = RandomParameter(100.0, 150.0, 10.0) + period: Parameter = RandomParameter(10.0, 20.0, 5.0) diff --git a/strategy/generator/signal/ma/ma_surpass.py b/strategy/generator/signal/ma/ma_surpass.py index 3dc98d1e..9b199983 100644 --- a/strategy/generator/signal/ma/ma_surpass.py +++ b/strategy/generator/signal/ma/ma_surpass.py @@ -12,4 +12,4 @@ class MaSurpassSignal(Signal): type: SignalType = SignalType.MaSurpass source_type: Parameter = StaticParameter(SourceType.CLOSE) ma: Parameter = CategoricalParameter(MovingAverageType) - period: Parameter = RandomParameter(150.0, 200.0, 10.0) + period: Parameter = RandomParameter(10.0, 60.0, 5.0) diff --git a/strategy/generator/signal/ma/ma_testing_ground.py b/strategy/generator/signal/ma/ma_testing_ground.py index e1f08513..18384caf 100644 --- a/strategy/generator/signal/ma/ma_testing_ground.py +++ b/strategy/generator/signal/ma/ma_testing_ground.py @@ -1,7 +1,12 @@ from dataclasses import dataclass from core.models.moving_average import MovingAverageType -from core.models.parameter import CategoricalParameter, Parameter, StaticParameter +from core.models.parameter import ( + CategoricalParameter, + Parameter, + RandomParameter, + StaticParameter, +) from core.models.source import SourceType from strategy.generator.signal.base import Signal, SignalType @@ -11,4 +16,4 @@ class MaTestingGroundSignal(Signal): type: SignalType = SignalType.MaTestingGround source_type: Parameter = StaticParameter(SourceType.CLOSE) ma: Parameter = CategoricalParameter(MovingAverageType) - period: Parameter = StaticParameter(100.0) + period: Parameter = RandomParameter(10.0, 60.0, 5.0) diff --git a/strategy/generator/signal/neutrality/dso_neutrality_cross.py b/strategy/generator/signal/neutrality/dso_cross.py similarity index 100% rename from strategy/generator/signal/neutrality/dso_neutrality_cross.py rename to strategy/generator/signal/neutrality/dso_cross.py diff --git a/strategy/generator/signal/neutrality/rsi_neutrality_cross.py b/strategy/generator/signal/neutrality/rsi_cross.py similarity index 100% rename from strategy/generator/signal/neutrality/rsi_neutrality_cross.py rename to strategy/generator/signal/neutrality/rsi_cross.py diff --git a/strategy/generator/signal/neutrality/rsi_neutrality_pullback.py b/strategy/generator/signal/neutrality/rsi_pullback.py similarity index 100% rename from strategy/generator/signal/neutrality/rsi_neutrality_pullback.py rename to strategy/generator/signal/neutrality/rsi_pullback.py diff --git a/strategy/generator/signal/neutrality/rsi_neutrality_rejection.py b/strategy/generator/signal/neutrality/rsi_rejection.py similarity index 100% rename from strategy/generator/signal/neutrality/rsi_neutrality_rejection.py rename to strategy/generator/signal/neutrality/rsi_rejection.py diff --git a/strategy/generator/signal/neutrality/tii_neutrality_cross.py b/strategy/generator/signal/neutrality/tii_cross.py similarity index 100% rename from strategy/generator/signal/neutrality/tii_neutrality_cross.py rename to strategy/generator/signal/neutrality/tii_cross.py diff --git a/strategy/generator/signal/pattern/candle_reversal.py b/strategy/generator/signal/pattern/candle_reversal.py new file mode 100644 index 00000000..0508da7c --- /dev/null +++ b/strategy/generator/signal/pattern/candle_reversal.py @@ -0,0 +1,11 @@ +from dataclasses import dataclass + +from core.models.candle import CandleReversalType +from core.models.parameter import CategoricalParameter, Parameter +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class CandlestickReversalSignal(Signal): + type: SignalType = SignalType.CandlestickReversal + candle: Parameter = CategoricalParameter(CandleReversalType) diff --git a/strategy/generator/signal/pattern/spread.py b/strategy/generator/signal/pattern/spread.py new file mode 100644 index 00000000..73b02951 --- /dev/null +++ b/strategy/generator/signal/pattern/spread.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + StaticParameter, +) +from core.models.smooth import Smooth +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class SpreadSignal(Signal): + type: SignalType = SignalType.Spread + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.EMA) + period_fast: Parameter = StaticParameter(6.0) + period_slow: Parameter = StaticParameter(8.0) diff --git a/strategy/generator/signal/pullback/__init__.py b/strategy/generator/signal/pullback/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/strategy/generator/signal/pullback/supertrend.py b/strategy/generator/signal/pullback/supertrend.py new file mode 100644 index 00000000..a5d5eb60 --- /dev/null +++ b/strategy/generator/signal/pullback/supertrend.py @@ -0,0 +1,15 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import SmoothATR +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class SupertrendPullbackSignal(Signal): + type: SignalType = SignalType.SupPullback + source: Parameter = StaticParameter(SourceType.HL2) + smooth_atr: Parameter = StaticParameter(SmoothATR.SMMA) + period_atr: Parameter = StaticParameter(8.0) + factor: Parameter = StaticParameter(0.86) diff --git a/strategy/generator/signal/reversal/vi_reversal.py b/strategy/generator/signal/reversal/vi_reversal.py deleted file mode 100644 index 5e12c20c..00000000 --- a/strategy/generator/signal/reversal/vi_reversal.py +++ /dev/null @@ -1,11 +0,0 @@ -from dataclasses import dataclass - -from core.models.parameter import Parameter, StaticParameter -from strategy.generator.signal.base import Signal, SignalType - - -@dataclass(frozen=True) -class ViReversalSignal(Signal): - type: SignalType = SignalType.ViReversal - atr_period: Parameter = StaticParameter(1.0) - period: Parameter = StaticParameter(4.0) diff --git a/strategy/generator/signal/signalline/di_signalline.py b/strategy/generator/signal/signalline/di.py similarity index 100% rename from strategy/generator/signal/signalline/di_signalline.py rename to strategy/generator/signal/signalline/di.py diff --git a/strategy/generator/signal/signalline/dso_signalline.py b/strategy/generator/signal/signalline/dso.py similarity index 100% rename from strategy/generator/signal/signalline/dso_signalline.py rename to strategy/generator/signal/signalline/dso.py diff --git a/strategy/generator/signal/signalline/kst_signalline.py b/strategy/generator/signal/signalline/kst.py similarity index 100% rename from strategy/generator/signal/signalline/kst_signalline.py rename to strategy/generator/signal/signalline/kst.py diff --git a/strategy/generator/signal/signalline/macd_signalline.py b/strategy/generator/signal/signalline/macd.py similarity index 100% rename from strategy/generator/signal/signalline/macd_signalline.py rename to strategy/generator/signal/signalline/macd.py diff --git a/strategy/generator/signal/signalline/qstick_signalline.py b/strategy/generator/signal/signalline/qstick.py similarity index 100% rename from strategy/generator/signal/signalline/qstick_signalline.py rename to strategy/generator/signal/signalline/qstick.py diff --git a/strategy/generator/signal/signalline/rsi_signalline.py b/strategy/generator/signal/signalline/rsi.py similarity index 100% rename from strategy/generator/signal/signalline/rsi_signalline.py rename to strategy/generator/signal/signalline/rsi.py diff --git a/strategy/generator/signal/signalline/stoch_signalline.py b/strategy/generator/signal/signalline/stoch.py similarity index 100% rename from strategy/generator/signal/signalline/stoch_signalline.py rename to strategy/generator/signal/signalline/stoch.py diff --git a/strategy/generator/signal/signalline/trix_signalline.py b/strategy/generator/signal/signalline/trix.py similarity index 100% rename from strategy/generator/signal/signalline/trix_signalline.py rename to strategy/generator/signal/signalline/trix.py diff --git a/strategy/generator/signal/signalline/tsi_signalline.py b/strategy/generator/signal/signalline/tsi.py similarity index 100% rename from strategy/generator/signal/signalline/tsi_signalline.py rename to strategy/generator/signal/signalline/tsi.py diff --git a/strategy/generator/signal/twolinescross/__init__.py b/strategy/generator/signal/twolinescross/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/strategy/generator/signal/reversal/dmi_reversal.py b/strategy/generator/signal/twolinescross/dmi.py similarity index 72% rename from strategy/generator/signal/reversal/dmi_reversal.py rename to strategy/generator/signal/twolinescross/dmi.py index 2a06a615..990b691a 100644 --- a/strategy/generator/signal/reversal/dmi_reversal.py +++ b/strategy/generator/signal/twolinescross/dmi.py @@ -9,8 +9,8 @@ @dataclass(frozen=True) -class DmiReversalSignal(Signal): - type: SignalType = SignalType.DmiReversal +class Dmi2LinesCrossSignal(Signal): + type: SignalType = SignalType.Dmi2LinesCross smooth_type: Parameter = StaticParameter(Smooth.SMMA) adx_period: Parameter = StaticParameter(8.0) - di_period: Parameter = StaticParameter(4.0) + di_period: Parameter = StaticParameter(8.0) diff --git a/strategy/generator/signal/twolinescross/vi.py b/strategy/generator/signal/twolinescross/vi.py new file mode 100644 index 00000000..a9dbc83b --- /dev/null +++ b/strategy/generator/signal/twolinescross/vi.py @@ -0,0 +1,13 @@ +from dataclasses import dataclass + +from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import SmoothATR +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class Vi2LinesCrossSignal(Signal): + type: SignalType = SignalType.Vi2LinesCross + period: Parameter = StaticParameter(6.0) + smooth_atr: Parameter = StaticParameter(SmoothATR.UTLS) + period_atr: Parameter = StaticParameter(1.0) diff --git a/strategy/generator/signal/zerocross/ao_zerocross.py b/strategy/generator/signal/zerocross/ao.py similarity index 100% rename from strategy/generator/signal/zerocross/ao_zerocross.py rename to strategy/generator/signal/zerocross/ao.py diff --git a/strategy/generator/signal/zerocross/bop_zerocross.py b/strategy/generator/signal/zerocross/bop.py similarity index 100% rename from strategy/generator/signal/zerocross/bop_zerocross.py rename to strategy/generator/signal/zerocross/bop.py diff --git a/strategy/generator/signal/zerocross/cc_zerocross.py b/strategy/generator/signal/zerocross/cc.py similarity index 100% rename from strategy/generator/signal/zerocross/cc_zerocross.py rename to strategy/generator/signal/zerocross/cc.py diff --git a/strategy/generator/signal/zerocross/cfo_zerocross.py b/strategy/generator/signal/zerocross/cfo.py similarity index 100% rename from strategy/generator/signal/zerocross/cfo_zerocross.py rename to strategy/generator/signal/zerocross/cfo.py diff --git a/strategy/generator/signal/zerocross/di_zerocross.py b/strategy/generator/signal/zerocross/di.py similarity index 100% rename from strategy/generator/signal/zerocross/di_zerocross.py rename to strategy/generator/signal/zerocross/di.py diff --git a/strategy/generator/signal/zerocross/macd_zerocross.py b/strategy/generator/signal/zerocross/macd.py similarity index 100% rename from strategy/generator/signal/zerocross/macd_zerocross.py rename to strategy/generator/signal/zerocross/macd.py diff --git a/strategy/generator/signal/zerocross/mad.py b/strategy/generator/signal/zerocross/mad.py new file mode 100644 index 00000000..9eb41b91 --- /dev/null +++ b/strategy/generator/signal/zerocross/mad.py @@ -0,0 +1,18 @@ +from dataclasses import dataclass + +from core.models.parameter import ( + Parameter, + StaticParameter, +) +from core.models.smooth import Smooth +from core.models.source import SourceType +from strategy.generator.signal.base import Signal, SignalType + + +@dataclass(frozen=True) +class MadZeroCrossSignal(Signal): + type: SignalType = SignalType.MadZeroCross + source: Parameter = StaticParameter(SourceType.CLOSE) + smooth: Parameter = StaticParameter(Smooth.SMA) + period_fast: Parameter = StaticParameter(8.0) + period_slow: Parameter = StaticParameter(23.0) diff --git a/strategy/generator/signal/zerocross/qstick_zerocross.py b/strategy/generator/signal/zerocross/qstick.py similarity index 100% rename from strategy/generator/signal/zerocross/qstick_zerocross.py rename to strategy/generator/signal/zerocross/qstick.py diff --git a/strategy/generator/signal/zerocross/roc_zerocross.py b/strategy/generator/signal/zerocross/roc.py similarity index 100% rename from strategy/generator/signal/zerocross/roc_zerocross.py rename to strategy/generator/signal/zerocross/roc.py diff --git a/strategy/generator/signal/zerocross/trix_zerocross.py b/strategy/generator/signal/zerocross/trix.py similarity index 100% rename from strategy/generator/signal/zerocross/trix_zerocross.py rename to strategy/generator/signal/zerocross/trix.py diff --git a/strategy/generator/signal/zerocross/tsi_zerocross.py b/strategy/generator/signal/zerocross/tsi.py similarity index 100% rename from strategy/generator/signal/zerocross/tsi_zerocross.py rename to strategy/generator/signal/zerocross/tsi.py diff --git a/strategy/generator/stop_loss/atr.py b/strategy/generator/stop_loss/atr.py index 31ed603f..756207d3 100644 --- a/strategy/generator/stop_loss/atr.py +++ b/strategy/generator/stop_loss/atr.py @@ -1,6 +1,7 @@ from dataclasses import dataclass from core.models.parameter import Parameter, StaticParameter +from core.models.smooth import SmoothATR from .base import StopLoss, StopLossType @@ -8,5 +9,6 @@ @dataclass(frozen=True) class AtrStopLoss(StopLoss): type: StopLossType = StopLossType.Atr - period: Parameter = StaticParameter(14.0) - factor: Parameter = StaticParameter(1.618) + smooth: Parameter = StaticParameter(SmoothATR.SMMA) + period: Parameter = StaticParameter(6.0) + factor: Parameter = StaticParameter(0.68) diff --git a/strategy/generator/stop_loss/dch.py b/strategy/generator/stop_loss/dch.py index 0d69f161..28adb312 100644 --- a/strategy/generator/stop_loss/dch.py +++ b/strategy/generator/stop_loss/dch.py @@ -8,5 +8,5 @@ @dataclass(frozen=True) class DchStopLoss(StopLoss): type: StopLossType = StopLossType.Dch - period: Parameter = StaticParameter(21.0) + period: Parameter = StaticParameter(8.0) factor: Parameter = StaticParameter(0.2) diff --git a/system/backtest.py b/system/backtest.py index 9d6ea450..a8f4d8bf 100644 --- a/system/backtest.py +++ b/system/backtest.py @@ -116,28 +116,8 @@ async def _generate(self): futures_symbols = await self.query(GetSymbols()) - scalp = [ - # "LUNA2USDT", - # "ALGOUSDT", - # "WAVESUSDT", - # "NEARUSDT", - "SOLUSDT", - # "FILUSDT" - # "TONUSDT" - # "ATOMUSDT", - # "SCUSDT" - # "FTMUSDT", - # "BOMEUSDT" - # "APEUSDT" - # "BTCUSDT" - # "ADAUSDT" - ] - - futures_symbols = [symbol for symbol in futures_symbols if symbol.name in scalp] + generator = self.context.strategy_generator_factory.create(futures_symbols) - generator = self.context.strategy_generator_factory.create( - self.context.strategy_type, futures_symbols - ) self.optimizer = self.context.strategy_optimizer_factory.create( Optimizer.GENETIC, generator, diff --git a/system/context.py b/system/context.py index 5e064b43..21e2a17e 100644 --- a/system/context.py +++ b/system/context.py @@ -10,7 +10,6 @@ AbstractStrategyGeneratorFactory, ) from core.models.exchange import ExchangeType -from core.models.strategy import StrategyType from infrastructure.config import ConfigService @@ -24,5 +23,4 @@ class SystemContext: strategy_generator_factory: AbstractStrategyGeneratorFactory strategy_optimizer_factory: AbstractStrategyOptimizerFactory exchange_type: ExchangeType - strategy_type: StrategyType config_service: ConfigService diff --git a/system/trading.py b/system/trading.py index 00fd8e0a..a026e0fd 100644 --- a/system/trading.py +++ b/system/trading.py @@ -137,13 +137,15 @@ async def _run_pretrading(self): feed_actor.stop() + await asyncio.sleep(1.0) + signal_actors[(symbol, timeframe)].append(signal_actor) for (symbol, timeframe), _ in self.next_strategy.items(): await self.execute( UpdateSettings( symbol, - self.config["leverage"], + min(symbol.max_leverage, self.config["leverage"]), PositionMode.HEDGED, MarginMode.CROSS, ) diff --git a/ta_lib/Cargo.lock b/ta_lib/Cargo.lock index e6add65b..f8ecebc7 100644 --- a/ta_lib/Cargo.lock +++ b/ta_lib/Cargo.lock @@ -19,15 +19,19 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "anstyle" -version = "1.0.6" +version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8901269c6307e8d93993578286ac0edf7f195079ffff5ebdeea6a59ffb7e36bc" +checksum = "1bec1de6f59aedf83baf9ff929c98f2ad654b97c9510f4e70cf6f661d49fd5b1" [[package]] name = "autocfg" -version = "1.2.0" +version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +checksum = "0c4b4d0bd25bd0b74681c0ad21497610ce1b7c91b1022cd21c80c6fbdd9476b0" + +[[package]] +name = "bands" +version = "0.1.0" [[package]] name = "base" @@ -36,6 +40,7 @@ dependencies = [ "core", "once_cell", "price", + "timeseries", "volatility", ] @@ -47,6 +52,7 @@ dependencies = [ "core", "indicator", "signal", + "timeseries", "trend", ] @@ -91,6 +97,10 @@ version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +[[package]] +name = "channel" +version = "0.1.0" + [[package]] name = "ciborium" version = "0.2.2" @@ -120,18 +130,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.4" +version = "4.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +checksum = "3e5a21b8495e732f1b3c364c9949b201ca7bae518c502c80256c96ad79eaf6ac" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.2" +version = "4.5.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +checksum = "8cf2dd12af7a047ad9d6da2b6b249759a22a7abc0f474c1dae1777afa4b21a73" dependencies = [ "anstyle", "clap_lex", @@ -139,9 +149,9 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.0" +version = "0.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +checksum = "1462739cb27611015575c0c11df5df7601141071f07518d56fcc1be504cbec97" [[package]] name = "confirm" @@ -150,7 +160,9 @@ dependencies = [ "base", "core", "momentum", + "timeseries", "trend", + "volatility", "volume", ] @@ -215,9 +227,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.19" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crunchy" @@ -227,9 +239,9 @@ checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" [[package]] name = "either" -version = "1.11.0" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] name = "exit" @@ -240,10 +252,21 @@ dependencies = [ "core", "indicator", "momentum", + "timeseries", "trend", "volume", ] +[[package]] +name = "ffi" +version = "0.1.0" +dependencies = [ + "once_cell", + "serde", + "serde_json", + "timeseries", +] + [[package]] name = "half" version = "2.4.1" @@ -256,9 +279,9 @@ dependencies = [ [[package]] name = "hermit-abi" -version = "0.3.9" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" +checksum = "fbf6a919d6cf397374f7dfeeea91d974c7c0a7221d0d0f4f20d859d329e53fcc" [[package]] name = "indicator" @@ -267,18 +290,19 @@ dependencies = [ "base", "candlestick", "core", + "timeseries", "trend", ] [[package]] name = "is-terminal" -version = "0.4.12" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f23ff5ef2b80d608d61efee834934d862cd92461afc0560dedf493e4c033738b" +checksum = "261f68e344040fbd0edea105bef17c66edf46f984ddb1115b775ce31be948f4b" dependencies = [ "hermit-abi", "libc", - "windows-sys", + "windows-sys 0.52.0", ] [[package]] @@ -298,30 +322,30 @@ checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" [[package]] name = "js-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "29c15563dc2726973df627357ce0c9ddddbea194836909d655df6a75d2cf296d" +checksum = "1868808506b929d7b0cfa8f75951347aa71bb21144b7791bae35d9bccfcfe37a" dependencies = [ "wasm-bindgen", ] [[package]] name = "libc" -version = "0.2.153" +version = "0.2.158" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +checksum = "d8adc4bb1803a324070e64a98ae98f38934d91957a99cfb3a43dcbc01bc56439" [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "momentum" @@ -334,9 +358,9 @@ dependencies = [ [[package]] name = "num-traits" -version = "0.2.18" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", ] @@ -349,15 +373,19 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "oorandom" -version = "11.1.3" +version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" + +[[package]] +name = "osc" +version = "0.1.0" [[package]] name = "plotters" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" dependencies = [ "num-traits", "plotters-backend", @@ -368,15 +396,15 @@ dependencies = [ [[package]] name = "plotters-backend" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" [[package]] name = "plotters-svg" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" dependencies = [ "plotters-backend", ] @@ -390,9 +418,9 @@ dependencies = [ [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] @@ -404,6 +432,7 @@ dependencies = [ "base", "core", "momentum", + "timeseries", "trend", "volatility", "volume", @@ -411,9 +440,9 @@ dependencies = [ [[package]] name = "quote" -version = "1.0.36" +version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] @@ -440,9 +469,9 @@ dependencies = [ [[package]] name = "regex" -version = "1.10.4" +version = "1.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "4219d74c6b67a3654a9fbebc4b419e22126d13d2f3c4a07ee0cb61ff79a79619" dependencies = [ "aho-corasick", "memchr", @@ -452,9 +481,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" dependencies = [ "aho-corasick", "memchr", @@ -463,15 +492,15 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "ryu" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "same-file" @@ -484,18 +513,18 @@ dependencies = [ [[package]] name = "serde" -version = "1.0.198" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +checksum = "99fce0ffe7310761ca6bf9faf5115afbc19688edd00171d81b1bb1b116c63e09" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.198" +version = "1.0.209" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +checksum = "a5831b979fd7b5439637af1752d535ff49f4860c0f341d1baeb6faf0f4242170" dependencies = [ "proc-macro2", "quote", @@ -504,11 +533,12 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.116" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -517,11 +547,16 @@ dependencies = [ name = "signal" version = "0.1.0" dependencies = [ + "bands", "base", "candlestick", + "channel", "core", "indicator", "momentum", + "osc", + "timeseries", + "trail", "trend", "volatility", "volume", @@ -533,20 +568,34 @@ version = "0.1.0" dependencies = [ "base", "core", + "timeseries", "volatility", ] [[package]] name = "syn" -version = "2.0.60" +version = "2.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" dependencies = [ "proc-macro2", "quote", "unicode-ident", ] +[[package]] +name = "timeseries" +version = "0.1.0" +dependencies = [ + "core", + "momentum", + "price", + "serde", + "trend", + "volatility", + "volume", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -557,6 +606,13 @@ dependencies = [ "serde_json", ] +[[package]] +name = "trail" +version = "0.1.0" +dependencies = [ + "core", +] + [[package]] name = "trend" version = "0.1.0" @@ -581,6 +637,7 @@ dependencies = [ "serde_json", "signal", "stop_loss", + "timeseries", ] [[package]] @@ -617,19 +674,20 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" +checksum = "a82edfc16a6c469f5f44dc7b571814045d60404b55a0ee849f9bcfa2e63dd9b5" dependencies = [ "cfg-if", + "once_cell", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" +checksum = "9de396da306523044d3302746f1208fa71d7532227f15e347e2d93e4145dd77b" dependencies = [ "bumpalo", "log", @@ -642,9 +700,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" +checksum = "585c4c91a46b072c92e908d99cb1dcdf95c5218eeb6f3bf1efa991ee7a68cccf" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -652,9 +710,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" +checksum = "afc340c74d9005395cf9dd098506f7f44e38f2b4a21c6aaacf9a105ea5e1e836" dependencies = [ "proc-macro2", "quote", @@ -665,65 +723,52 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.92" +version = "0.2.93" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +checksum = "c62a0a307cb4a311d3a07867860911ca130c3494e8c2719593806c08bc5d0484" [[package]] name = "web-sys" -version = "0.3.69" +version = "0.3.70" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77afa9a11836342370f4817622a2f0f418b134426d91a82dfb48f532d2ec13ef" +checksum = "26fdeaafd9bd129f65e7c031593c24d62186301e0c72c8978fa1678be7d532c0" dependencies = [ "js-sys", "wasm-bindgen", ] [[package]] -name = "winapi" -version = "0.3.9" +name = "winapi-util" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +checksum = "cf221c93e13a30d793f7645a0e7762c55d169dbb0a49671918a2319d289b10bb" dependencies = [ - "winapi-i686-pc-windows-gnu", - "winapi-x86_64-pc-windows-gnu", + "windows-sys 0.59.0", ] [[package]] -name = "winapi-i686-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" - -[[package]] -name = "winapi-util" -version = "0.1.6" +name = "windows-sys" +version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "winapi", + "windows-targets", ] -[[package]] -name = "winapi-x86_64-pc-windows-gnu" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" - [[package]] name = "windows-sys" -version = "0.52.0" +version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ "windows-targets", ] [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -737,48 +782,48 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/ta_lib/Cargo.toml b/ta_lib/Cargo.toml index 41f345b8..6993405b 100644 --- a/ta_lib/Cargo.toml +++ b/ta_lib/Cargo.toml @@ -4,11 +4,17 @@ resolver = "2" members = [ "benches", "core", +"timeseries", +"ffi", "indicators/momentum", "indicators/trend", "indicators/volatility", "indicators/volume", "patterns/candlestick", +"patterns/osc", +"patterns/bands", +"patterns/channel", +"patterns/trail", "price", "strategies/base", "strategies/trend_follow", @@ -17,7 +23,7 @@ members = [ "strategies/confirm", "strategies/pulse", "strategies/baseline", -"strategies/indicator", +"strategies/indicator", "timeseries", "ffi", "patterns/osc", ] [workspace.package] @@ -29,6 +35,6 @@ repository = "https://github.com/YieldLabs/quant" version = "0.1.0" [profile.release] -opt-level = 'z' +opt-level = 3 lto = true panic = "abort" \ No newline at end of file diff --git a/ta_lib/benches/Cargo.toml b/ta_lib/benches/Cargo.toml index 39e4ab65..7eed6486 100644 --- a/ta_lib/benches/Cargo.toml +++ b/ta_lib/benches/Cargo.toml @@ -24,11 +24,6 @@ name = "patterns" harness = false path = "patterns.rs" -[[bench]] -name = "strategy" -harness = false -path = "strategy.rs" - [dependencies] core = { path = "../core" } candlestick = { path = "../patterns/candlestick" } diff --git a/ta_lib/benches/indicators.rs b/ta_lib/benches/indicators.rs index 227c1893..61951799 100644 --- a/ta_lib/benches/indicators.rs +++ b/ta_lib/benches/indicators.rs @@ -40,25 +40,6 @@ fn momentum(c: &mut Criterion) { 7.1230, 7.1225, 7.1180, 7.1250, ]; - group.bench_function("ao", |b| { - b.iter_batched_ref( - || { - let high = Series::from(&high); - let low = Series::from(&low); - let source = median_price(&high, &low); - let smooth_type = Smooth::SMA; - let fast_period = 5; - let slow_period = 34; - - (source, smooth_type, fast_period, slow_period) - }, - |(source, smooth_type, fast_period, slow_period)| { - ao(source, *smooth_type, *fast_period, *slow_period) - }, - criterion::BatchSize::SmallInput, - ) - }); - group.bench_function("bop", |b| { b.iter_batched_ref( || { @@ -113,12 +94,11 @@ fn momentum(c: &mut Criterion) { let low = Series::from(&low); let close = Series::from(&close); let hlc3 = typical_price(&high, &low, &close); - let smooth_type = Smooth::SMA; let period = 14; let factor = 0.015; - (hlc3, smooth_type, period, factor) + (hlc3, period, factor) }, - |(hlc3, smooth_type, period, factor)| cci(hlc3, *smooth_type, *period, *factor), + |(hlc3, period, factor)| cci(hlc3, *period, *factor), criterion::BatchSize::SmallInput, ) }); @@ -591,7 +571,7 @@ fn trend(c: &mut Criterion) { ) }); - group.bench_function("kjs", |b| { + group.bench_function("midpoint", |b| { b.iter_batched_ref( || { let high = Series::from(&high); @@ -600,7 +580,7 @@ fn trend(c: &mut Criterion) { (high, low, period) }, - |(high, low, period)| kjs(high, low, *period), + |(high, low, period)| midpoint(high, low, *period), criterion::BatchSize::SmallInput, ) }); @@ -742,7 +722,7 @@ fn trend(c: &mut Criterion) { ) }); - group.bench_function("tma", |b| { + group.bench_function("trima", |b| { b.iter_batched_ref( || { let source = Series::from(&close); @@ -750,7 +730,7 @@ fn trend(c: &mut Criterion) { (source, period) }, - |(source, period)| tma(source, *period), + |(source, period)| trima(source, *period), criterion::BatchSize::SmallInput, ) }); @@ -913,7 +893,7 @@ fn volatility(c: &mut Criterion) { (hlc3, atr, smooth_type, period, factor) }, |(hlc3, atr, smooth_type, period, factor)| { - kch(hlc3, atr, *smooth_type, *period, *factor) + kch(hlc3, *smooth_type, atr, *period, *factor) }, criterion::BatchSize::SmallInput, ) @@ -1036,12 +1016,11 @@ fn volume(c: &mut Criterion) { let hl2 = median_price(&high, &low); let smooth_type = Smooth::SMA; let period = 14; - let divisor = 10000.0; - (hl2, high, low, volume, smooth_type, period, divisor) + (hl2, high, low, volume, smooth_type, period) }, - |(hl2, high, low, volume, smooth_type, period, divisor)| { - eom(hl2, high, low, volume, *smooth_type, *period, *divisor) + |(hl2, high, low, volume, smooth_type, period)| { + eom(hl2, high, low, volume, *smooth_type, *period) }, criterion::BatchSize::SmallInput, ) @@ -1077,23 +1056,6 @@ fn volume(c: &mut Criterion) { ) }); - group.bench_function("vo", |b| { - b.iter_batched_ref( - || { - let volume = Series::from(&volume); - let fast_period = 5; - let slow_period = 10; - let smooth_type = Smooth::EMA; - - (volume, smooth_type, fast_period, slow_period) - }, - |(volume, smooth_type, fast_period, slow_period)| { - vo(volume, *smooth_type, *fast_period, *slow_period) - }, - criterion::BatchSize::SmallInput, - ) - }); - group.bench_function("vwap", |b| { b.iter_batched_ref( || { diff --git a/ta_lib/benches/strategy.rs b/ta_lib/benches/strategy.rs deleted file mode 100644 index af766597..00000000 --- a/ta_lib/benches/strategy.rs +++ /dev/null @@ -1,96 +0,0 @@ -use base::*; -use criterion::{criterion_group, criterion_main, Criterion}; -use std::collections::VecDeque; - -fn model(c: &mut Criterion) { - let mut group = c.benchmark_group("model"); - let ts: Vec = vec![ - 1679827200, 1679827500, 1679827800, 1679828100, 1679828400, 1679828700, 1679829000, - 1679829300, 1679829600, 1679829900, 1679830200, 1679830500, 1679830800, 1679831100, - 1679831400, 1679831700, 1679832000, 1679832300, 1679832600, 1679832900, 1679833200, - 1679833500, 1679833800, 1679834100, 1679834400, 1679834700, 1679835000, 1679835300, - 1679835600, 1679835900, 1679836200, 1679836500, 1679836800, 1679837100, 1679837400, - 1679837700, 1679838000, 1679838300, 1679838600, 1679838900, 1679839200, 1679839500, - 1679839800, 1679840100, 1679840400, 1679840700, 1679841000, 1679841300, 1679841600, - 1679841900, 1679842200, - ]; - - let open: Vec = vec![ - 6.8430, 6.8660, 6.8635, 6.8610, 6.865, 6.8595, 6.8565, 6.852, 6.859, 6.86, 6.8580, 6.8605, - 6.8620, 6.867, 6.859, 6.8670, 6.8640, 6.8575, 6.8485, 6.8350, 7.1195, 7.136, 7.1405, 7.112, - 7.1095, 7.1520, 7.1310, 7.1550, 7.1480, 7.1435, 7.1405, 7.1440, 7.1495, 7.1515, 7.1415, - 7.1445, 7.1525, 7.1440, 7.1370, 7.1305, 7.1375, 7.1250, 7.1190, 7.1135, 7.1280, 7.1220, - 7.1330, 7.1225, 7.1180, 7.1250, - ]; - - let high: Vec = vec![ - 6.8430, 6.8660, 6.8685, 6.8690, 6.865, 6.8595, 6.8565, 6.862, 6.859, 6.86, 6.8580, 6.8605, - 6.8620, 6.86, 6.859, 6.8670, 6.8640, 6.8575, 6.8485, 6.8450, 7.1195, 7.136, 7.1405, 7.112, - 7.1095, 7.1220, 7.1310, 7.1550, 7.1480, 7.1435, 7.1405, 7.1440, 7.1495, 7.1515, 7.1415, - 7.1445, 7.1525, 7.1440, 7.1370, 7.1305, 7.1375, 7.1250, 7.1190, 7.1135, 7.1280, 7.1220, - 7.1230, 7.1225, 7.1180, 7.1250, - ]; - - let low: Vec = vec![ - 6.8380, 6.8430, 6.8595, 6.8640, 6.8435, 6.8445, 6.8510, 6.8560, 6.8520, 6.8530, 6.8550, - 6.8550, 6.8565, 6.8475, 6.8480, 6.8535, 6.8565, 6.8455, 6.8445, 6.8365, 7.1195, 7.136, - 7.1405, 7.112, 7.1095, 7.1220, 7.1310, 7.1550, 7.1480, 7.1435, 7.1405, 7.1440, 7.1495, - 7.1515, 7.1415, 7.1445, 7.1525, 7.1440, 7.1370, 7.1305, 7.1375, 7.1250, 7.1190, 7.1135, - 7.1280, 7.1220, 7.1230, 7.1225, 7.1180, 7.1250, - ]; - - let close: Vec = vec![ - 6.855, 6.858, 6.86, 6.8480, 6.8575, 6.864, 6.8565, 6.8455, 6.8450, 6.8365, 6.8310, 6.8355, - 6.8360, 6.8345, 6.8285, 6.8395, 7.1135, 7.088, 7.112, 7.1205, 7.1195, 7.136, 7.1405, 7.112, - 7.1095, 7.1220, 7.1310, 7.1550, 7.1480, 7.1435, 7.1405, 7.1440, 7.1495, 7.1515, 7.1415, - 7.1445, 7.1525, 7.1440, 7.1370, 7.1305, 7.1375, 7.1250, 7.1190, 7.1135, 7.1280, 7.1220, - 7.1230, 7.1225, 7.1180, 7.1250, - ]; - - let volume: Vec = vec![ - 60.855, 600.858, 60.86, 600.848, 60.8575, 60.864, 600.8565, 60.8455, 600.845, 600.8365, - 60.8310, 60.8355, 600.836, 60.8345, 600.8285, 60.8395, 700.1135, 70.088, 700.112, 70.1205, - 700.1195, 70.136, 70.1405, 70.112, 700.1095, 70.1220, 70.1310, 700.155, 70.1480, 70.1435, - 700.1405, 70.1440, 70.1495, 70.1515, 70.1415, 700.1445, 70.1525, 700.144, 70.1370, - 700.1305, 70.1375, 700.125, 700.119, 70.1135, 70.128, 700.122, 70.123, 700.1225, 70.118, - 70.125, - ]; - - group.bench_function("from_data", |b| { - b.iter_batched_ref( - || { - let mut data: VecDeque = VecDeque::with_capacity(200); - - let ohlcvs = open - .iter() - .zip(high.iter()) - .zip(low.iter()) - .zip(close.iter()) - .zip(volume.iter()) - .zip(ts.iter()) - .map(|(((((&o, &h), &l), &c), &v), &t)| OHLCV { - ts: t, - open: o, - high: h, - low: l, - close: c, - volume: v, - }) - .collect::>(); - - for ohlcv in ohlcvs { - data.push_back(ohlcv) - } - - data - }, - |data| OHLCVSeries::from_data(data), - criterion::BatchSize::SmallInput, - ) - }); - - group.finish(); -} - -criterion_group!(strategy, model); -criterion_main!(strategy); diff --git a/ta_lib/core/src/bitwise.rs b/ta_lib/core/src/bitwise.rs index 6c1b6221..5baeb0ac 100644 --- a/ta_lib/core/src/bitwise.rs +++ b/ta_lib/core/src/bitwise.rs @@ -1,30 +1,30 @@ -use crate::series::Series; use crate::traits::Bitwise; +use crate::types::Rule; use std::ops::{BitAnd, BitOr}; -impl Bitwise> for Series { - type Output = Series; +impl Bitwise for Rule { + type Output = Rule; - fn op(&self, rhs: &Series, operation: F) -> Self::Output + fn op(&self, rhs: &Rule, operation: F) -> Self::Output where - F: Fn(&bool, &bool) -> bool, + F: Fn(bool, bool) -> bool, { self.zip_with(rhs, |a, b| match (a, b) { - (Some(a_val), Some(b_val)) => Some(operation(a_val, b_val)), + (Some(a_val), Some(b_val)) => Some(operation(*a_val, *b_val)), (Some(_), None) | (None, Some(_)) | (None, None) => Some(false), }) } - fn sand(&self, rhs: &Series) -> Self::Output { + fn sand(&self, rhs: &Rule) -> Self::Output { self.op(rhs, |a, b| a & b) } - fn sor(&self, rhs: &Series) -> Self::Output { + fn sor(&self, rhs: &Rule) -> Self::Output { self.op(rhs, |a, b| a | b) } } -impl BitAnd for Series { +impl BitAnd for Rule { type Output = Self; fn bitand(self, rhs: Self) -> Self::Output { @@ -32,7 +32,7 @@ impl BitAnd for Series { } } -impl BitOr for Series { +impl BitOr for Rule { type Output = Self; fn bitor(self, rhs: Self) -> Self::Output { @@ -43,13 +43,14 @@ impl BitOr for Series { #[cfg(test)] mod tests { use super::*; + use crate::series::Series; use crate::traits::Comparator; #[test] fn test_bitand() { let a = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]); - let expected: Series = Series::from([0.0, 0.0, 0.0, 0.0, 0.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 0.0, 0.0, 0.0]).into(); let result = a.sgt(&b) & a.slt(&b); @@ -60,7 +61,7 @@ mod tests { fn test_bitor() { let a = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); let b = Series::from([1.0, 1.0, 1.0, 1.0, 1.0]); - let expected: Series = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]).into(); + let expected: Rule = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]).into(); let result = a.sgt(&b) | a.slt(&b); diff --git a/ta_lib/core/src/cmp.rs b/ta_lib/core/src/cmp.rs index 5ecd73d6..fdd84240 100644 --- a/ta_lib/core/src/cmp.rs +++ b/ta_lib/core/src/cmp.rs @@ -1,81 +1,82 @@ -use crate::series::Series; +use crate::constants::NAN; use crate::traits::Comparator; +use crate::types::{Price, Rule, Scalar}; -impl Comparator for Series { - type Output = Series; +impl Comparator for Price { + type Output = Rule; - fn compare(&self, scalar: &f32, comparator: F) -> Self::Output + fn compare(&self, scalar: &Scalar, comparator: F) -> Self::Output where - F: Fn(&f32, &f32) -> bool, + F: Fn(Scalar, Scalar) -> bool, { self.fmap(|x| { - x.map_or(Some(comparator(&f32::NAN, scalar)), |val| { - Some(comparator(val, scalar)) + x.map_or(Some(comparator(NAN, *scalar)), |val| { + Some(comparator(*val, *scalar)) }) }) } - fn seq(&self, rhs: &f32) -> Self::Output { + fn seq(&self, rhs: &Scalar) -> Self::Output { self.compare(rhs, |a, b| a == b) } - fn sne(&self, rhs: &f32) -> Self::Output { + fn sne(&self, rhs: &Scalar) -> Self::Output { self.compare(rhs, |a, b| a != b) } - fn sgt(&self, rhs: &f32) -> Self::Output { + fn sgt(&self, rhs: &Scalar) -> Self::Output { self.compare(rhs, |a, b| a > b) } - fn sgte(&self, rhs: &f32) -> Self::Output { + fn sgte(&self, rhs: &Scalar) -> Self::Output { self.compare(rhs, |a, b| a >= b) } - fn slt(&self, rhs: &f32) -> Self::Output { + fn slt(&self, rhs: &Scalar) -> Self::Output { self.compare(rhs, |a, b| a < b) } - fn slte(&self, rhs: &f32) -> Self::Output { + fn slte(&self, rhs: &Scalar) -> Self::Output { self.compare(rhs, |a, b| a <= b) } } -impl Comparator> for Series { - type Output = Series; +impl Comparator for Price { + type Output = Rule; - fn compare(&self, rhs: &Series, comparator: F) -> Self::Output + fn compare(&self, rhs: &Price, comparator: F) -> Self::Output where - F: Fn(&f32, &f32) -> bool, + F: Fn(Scalar, Scalar) -> bool, { self.zip_with(rhs, |a, b| match (a, b) { - (Some(a_val), Some(b_val)) => Some(comparator(a_val, b_val)), - (None, Some(b_val)) => Some(comparator(&f32::NAN, b_val)), - (Some(a_val), None) => Some(comparator(a_val, &f32::NAN)), + (Some(a_val), Some(b_val)) => Some(comparator(*a_val, *b_val)), + (None, Some(b_val)) => Some(comparator(NAN, *b_val)), + (Some(a_val), None) => Some(comparator(*a_val, NAN)), _ => None, }) } - fn seq(&self, rhs: &Series) -> Self::Output { + fn seq(&self, rhs: &Price) -> Self::Output { self.compare(rhs, |a, b| a == b) } - fn sne(&self, rhs: &Series) -> Self::Output { + fn sne(&self, rhs: &Price) -> Self::Output { self.compare(rhs, |a, b| a != b) } - fn sgt(&self, rhs: &Series) -> Self::Output { + fn sgt(&self, rhs: &Price) -> Self::Output { self.compare(rhs, |a, b| a > b) } - fn sgte(&self, rhs: &Series) -> Self::Output { + fn sgte(&self, rhs: &Price) -> Self::Output { self.compare(rhs, |a, b| a >= b) } - fn slt(&self, rhs: &Series) -> Self::Output { + fn slt(&self, rhs: &Price) -> Self::Output { self.compare(rhs, |a, b| a < b) } - fn slte(&self, rhs: &Series) -> Self::Output { + fn slte(&self, rhs: &Price) -> Self::Output { self.compare(rhs, |a, b| a <= b) } } @@ -83,12 +84,13 @@ impl Comparator> for Series { #[cfg(test)] mod tests { use super::*; + use crate::series::Series; #[test] fn test_scalar_eq() { - let a = Series::from([f32::NAN, 2.0, 3.0, 1.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 1.0, 5.0]); let b = 1.0; - let expected: Series = Series::from([0.0, 0.0, 0.0, 1.0, 0.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 0.0, 1.0, 0.0]).into(); let result = a.seq(&b); @@ -97,9 +99,9 @@ mod tests { #[test] fn test_scalar_ne() { - let a = Series::from([f32::NAN, 2.0, 3.0, 1.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 1.0, 5.0]); let b = 1.0; - let expected: Series = Series::from([1.0, 1.0, 1.0, 0.0, 1.0]).into(); + let expected: Rule = Series::from([1.0, 1.0, 1.0, 0.0, 1.0]).into(); let result = a.sne(&b); @@ -108,9 +110,9 @@ mod tests { #[test] fn test_scalar_gt() { - let a = Series::from([f32::NAN, 2.0, 3.0, 4.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 4.0, 5.0]); let b = 1.0; - let expected: Series = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]).into(); + let expected: Rule = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]).into(); let result = a.sgt(&b); @@ -119,9 +121,9 @@ mod tests { #[test] fn test_scalar_gte() { - let a = Series::from([f32::NAN, 2.0, 1.0, 1.0, 5.0]); + let a = Series::from([NAN, 2.0, 1.0, 1.0, 5.0]); let b = 1.0; - let expected: Series = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]).into(); + let expected: Rule = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]).into(); let result = a.sgte(&b); @@ -130,9 +132,9 @@ mod tests { #[test] fn test_scalar_lt() { - let a = Series::from([f32::NAN, 2.0, 3.0, 4.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 4.0, 5.0]); let b = 1.0; - let expected: Series = Series::from([0.0, 0.0, 0.0, 0.0, 0.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 0.0, 0.0, 0.0]).into(); let result = a.slt(&b); @@ -141,9 +143,9 @@ mod tests { #[test] fn test_scalar_lte() { - let a = Series::from([f32::NAN, 2.0, 3.0, 1.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 1.0, 5.0]); let b = 1.0; - let expected: Series = Series::from([0.0, 0.0, 0.0, 1.0, 0.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 0.0, 1.0, 0.0]).into(); let result = a.slte(&b); @@ -152,9 +154,9 @@ mod tests { #[test] fn test_series_eq() { - let a = Series::from([f32::NAN, 2.0, 3.0, 1.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 1.0, 5.0]); let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]); - let expected: Series = Series::from([0.0, 0.0, 0.0, 1.0, 0.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 0.0, 1.0, 0.0]).into(); let result = a.seq(&b); @@ -163,9 +165,9 @@ mod tests { #[test] fn test_series_ne() { - let a = Series::from([f32::NAN, 2.0, 3.0, 1.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 1.0, 5.0]); let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]); - let expected: Series = Series::from([1.0, 1.0, 1.0, 0.0, 1.0]).into(); + let expected: Rule = Series::from([1.0, 1.0, 1.0, 0.0, 1.0]).into(); let result = a.sne(&b); @@ -174,9 +176,9 @@ mod tests { #[test] fn test_series_gt() { - let a = Series::from([f32::NAN, 2.0, 3.0, 4.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 4.0, 5.0]); let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]); - let expected: Series = Series::from([0.0, 1.0, 0.0, 1.0, 1.0]).into(); + let expected: Rule = Series::from([0.0, 1.0, 0.0, 1.0, 1.0]).into(); let result = a.sgt(&b); @@ -185,9 +187,9 @@ mod tests { #[test] fn test_series_gte() { - let a = Series::from([f32::NAN, 2.0, 1.0, 1.0, 5.0]); + let a = Series::from([NAN, 2.0, 1.0, 1.0, 5.0]); let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]); - let expected: Series = Series::from([0.0, 1.0, 0.0, 1.0, 1.0]).into(); + let expected: Rule = Series::from([0.0, 1.0, 0.0, 1.0, 1.0]).into(); let result = a.sgte(&b); @@ -196,9 +198,9 @@ mod tests { #[test] fn test_series_lt() { - let a = Series::from([f32::NAN, 2.0, 3.0, 4.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 4.0, 5.0]); let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]); - let expected: Series = Series::from([0.0, 0.0, 1.0, 0.0, 0.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 1.0, 0.0, 0.0]).into(); let result = a.slt(&b); @@ -207,9 +209,9 @@ mod tests { #[test] fn test_series_lte() { - let a = Series::from([f32::NAN, 2.0, 3.0, 1.0, 5.0]); + let a = Series::from([NAN, 2.0, 3.0, 1.0, 5.0]); let b = Series::from([1.0, 1.0, 6.0, 1.0, 1.0]); - let expected: Series = Series::from([0.0, 0.0, 1.0, 1.0, 0.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 1.0, 1.0, 0.0]).into(); let result = a.slte(&b); diff --git a/ta_lib/core/src/constants.rs b/ta_lib/core/src/constants.rs index 0e51375c..6791dd7c 100644 --- a/ta_lib/core/src/constants.rs +++ b/ta_lib/core/src/constants.rs @@ -1,4 +1,10 @@ -pub const SCALE: f32 = 100.; -pub const ZERO: f32 = 0.; -pub const ONE: f32 = 1.; -pub const MINUS_ONE: f32 = -1.; +use crate::types::Scalar; + +pub const SCALE: Scalar = 100.; +pub const NEUTRALITY: Scalar = 50.; +pub const ZERO: Scalar = 0.; +pub const ONE: Scalar = 1.; +pub const HALF: Scalar = 0.5; +pub const MINUS_ONE: Scalar = -1.; +pub const NAN: Scalar = Scalar::NAN; +pub const PI: Scalar = std::f32::consts::PI; diff --git a/ta_lib/core/src/cross.rs b/ta_lib/core/src/cross.rs index dda5eef6..e96bafc6 100644 --- a/ta_lib/core/src/cross.rs +++ b/ta_lib/core/src/cross.rs @@ -1,34 +1,34 @@ -use crate::series::Series; use crate::traits::{Comparator, Cross}; +use crate::types::{Price, Rule, Scalar}; -impl Cross for Series { - type Output = Series; +impl Cross for Price { + type Output = Rule; - fn cross_over(&self, line: &f32) -> Self::Output { + fn cross_over(&self, line: &Scalar) -> Self::Output { self.sgt(line) & self.shift(1).slt(line) } - fn cross_under(&self, line: &f32) -> Self::Output { + fn cross_under(&self, line: &Scalar) -> Self::Output { self.slt(line) & self.shift(1).sgt(line) } - fn cross(&self, line: &f32) -> Self::Output { + fn cross(&self, line: &Scalar) -> Self::Output { self.cross_over(line) | self.cross_under(line) } } -impl Cross> for Series { - type Output = Series; +impl Cross for Price { + type Output = Rule; - fn cross_over(&self, rhs: &Series) -> Self::Output { + fn cross_over(&self, rhs: &Price) -> Self::Output { self.sgt(rhs) & self.shift(1).slt(&rhs.shift(1)) } - fn cross_under(&self, rhs: &Series) -> Self::Output { + fn cross_under(&self, rhs: &Price) -> Self::Output { self.slt(rhs) & self.shift(1).sgt(&rhs.shift(1)) } - fn cross(&self, rhs: &Series) -> Self::Output { + fn cross(&self, rhs: &Price) -> Self::Output { self.cross_over(rhs) | self.cross_under(rhs) } } @@ -36,12 +36,13 @@ impl Cross> for Series { #[cfg(test)] mod tests { use super::*; + use crate::series::Series; #[test] fn test_cross_over() { let a = Series::from([5.5, 5.0, 4.5, 3.0, 2.5]); let b = Series::from([4.5, 2.0, 3.0, 3.5, 2.0]); - let expected: Series = Series::from([0.0, 0.0, 0.0, 0.0, 1.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 0.0, 0.0, 1.0]).into(); let result = a.cross_over(&b); @@ -52,7 +53,7 @@ mod tests { fn test_cross_under() { let a = Series::from([5.5, 5.0, 4.5, 3.0, 2.5]); let b = Series::from([4.5, 2.0, 3.0, 3.5, 2.0]); - let expected: Series = Series::from([0.0, 0.0, 0.0, 1.0, 0.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 0.0, 1.0, 0.0]).into(); let result = a.cross_under(&b); @@ -63,7 +64,7 @@ mod tests { fn test_cross() { let a = Series::from([5.5, 5.0, 4.5, 3.0, 2.5]); let b = Series::from([4.5, 2.0, 3.0, 3.5, 2.0]); - let expected: Series = Series::from([0.0, 0.0, 0.0, 1.0, 1.0]).into(); + let expected: Rule = Series::from([0.0, 0.0, 0.0, 1.0, 1.0]).into(); let result = a.cross(&b); diff --git a/ta_lib/core/src/extremum.rs b/ta_lib/core/src/extremum.rs index dbc0b8ff..7427086c 100644 --- a/ta_lib/core/src/extremum.rs +++ b/ta_lib/core/src/extremum.rs @@ -1,31 +1,35 @@ -use crate::series::Series; use crate::traits::Extremum; +use crate::types::{Price, Scalar}; -impl Extremum for Series { - type Output = Series; +impl Extremum for Price { + type Output = Price; - fn extremum(&self, scalar: &f32, f: F) -> Self::Output + fn extremum(&self, scalar: &Scalar, f: F) -> Self::Output where - F: Fn(f32, f32) -> f32, + F: Fn(Scalar, Scalar) -> Scalar, { self.fmap(|val| val.map(|v| f(*v, *scalar)).or(Some(*scalar))) } - fn max(&self, scalar: &f32) -> Self::Output { - self.extremum(scalar, f32::max) + fn max(&self, scalar: &Scalar) -> Self::Output { + self.extremum(scalar, Scalar::max) } - fn min(&self, scalar: &f32) -> Self::Output { - self.extremum(scalar, f32::min) + fn min(&self, scalar: &Scalar) -> Self::Output { + self.extremum(scalar, Scalar::min) + } + + fn clip(&self, lhs: &Scalar, rhs: &Scalar) -> Self::Output { + self.min(rhs).max(lhs) } } -impl Extremum> for Series { - type Output = Series; +impl Extremum for Price { + type Output = Price; - fn extremum(&self, rhs: &Series, f: F) -> Self::Output + fn extremum(&self, rhs: &Price, f: F) -> Self::Output where - F: Fn(f32, f32) -> f32, + F: Fn(Scalar, Scalar) -> Scalar, { self.zip_with(rhs, |a, b| match (a, b) { (Some(a_val), Some(b_val)) => Some(f(*a_val, *b_val)), @@ -34,50 +38,45 @@ impl Extremum> for Series { }) } - fn max(&self, rhs: &Series) -> Self::Output { - self.extremum(rhs, f32::max) + fn max(&self, rhs: &Price) -> Self::Output { + self.extremum(rhs, Scalar::max) + } + + fn min(&self, rhs: &Price) -> Self::Output { + self.extremum(rhs, Scalar::min) } - fn min(&self, rhs: &Series) -> Self::Output { - self.extremum(rhs, f32::min) + fn clip(&self, lhs: &Price, rhs: &Price) -> Self::Output { + self.min(rhs).max(lhs) } } #[cfg(test)] mod tests { use super::*; + use crate::series::Series; #[test] fn test_smax() { - let source = vec![ + let series = Series::from([ 44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, - ]; + ]); let length = 1; - let epsilon = 0.001; - let expected = [ - Some(0.0), - Some(0.0), - Some(0.0599), - Some(0.0), - Some(0.7199), - Some(0.5), - Some(0.2700), - Some(0.3200), - Some(0.4200), - ]; - let series = Series::from(&source); + let expected = Series::from([ + 0., + 0., + 0.060001373, + 0., + 0.7200012, + 0.5, + 0.26999664, + 0.3199997, + 0.42000198, + ]); let result = series.change(length).max(&0.0); - for i in 0..result.len() { - match (result[i], expected[i]) { - (Some(a), Some(b)) => { - assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b) - } - (None, None) => {} - _ => panic!("at position {}: {:?} != {:?}", i, result[i], expected[i]), - } - } + assert_eq!(result, expected); } #[test] @@ -120,35 +119,15 @@ mod tests { #[test] fn test_smin() { - let source = vec![ + let series = Series::from([ 44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, - ]; + ]); let length = 1; - let epsilon = 0.001; - let expected = [ - Some(0.0), - Some(-0.25), - Some(0.0), - Some(-0.5399), - Some(0.0), - Some(0.0), - Some(0.0), - Some(0.0), - Some(0.0), - ]; - let series = Series::from(&source); + let expected = Series::from([0., -0.25, 0., -0.5400009, 0., 0., 0., 0., 0.]); let result = series.change(length).min(&0.0); - for i in 0..result.len() { - match (result[i], expected[i]) { - (Some(a), Some(b)) => { - assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b) - } - (None, None) => {} - _ => panic!("at position {}: {:?} != {:?}", i, result[i], expected[i]), - } - } + assert_eq!(result, expected); } #[test] @@ -186,4 +165,16 @@ mod tests { assert_eq!(result, expected); } + + #[test] + fn test_clip() { + let source = Series::from([-1.0, 0.0, 1.0, 3.0, 5.0]); + let expected = Series::from([0.0, 0.0, 1.0, 3.0, 3.0]); + let min = 0.; + let max = 3.; + + let result = source.clip(&min, &max); + + assert_eq!(result, expected); + } } diff --git a/ta_lib/core/src/fmt.rs b/ta_lib/core/src/fmt.rs new file mode 100644 index 00000000..819cdaf6 --- /dev/null +++ b/ta_lib/core/src/fmt.rs @@ -0,0 +1,18 @@ +use crate::series::Series; +use std::fmt; + +impl fmt::Display for Series { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "[")?; + for (i, item) in self.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + match item { + Some(value) => write!(f, "{}", value)?, + None => write!(f, "None")?, + } + } + write!(f, "]") + } +} diff --git a/ta_lib/core/src/from.rs b/ta_lib/core/src/from.rs index ca665167..273d40ac 100644 --- a/ta_lib/core/src/from.rs +++ b/ta_lib/core/src/from.rs @@ -1,6 +1,7 @@ -use crate::series::Series; +use crate::constants::ZERO; +use crate::types::{Price, Rule, Scalar}; -impl> From for Series { +impl> From for Price { fn from(item: T) -> Self { item.as_ref() .iter() @@ -9,34 +10,34 @@ impl> From for Series { } } -impl<'a> FromIterator> for Series { - fn from_iter>>(iter: I) -> Self { +impl<'a> FromIterator> for Price { + fn from_iter>>(iter: I) -> Self { iter.into_iter().map(|opt| opt.copied()).collect() } } -impl FromIterator for Series { - fn from_iter>(iter: I) -> Self { +impl FromIterator for Price { + fn from_iter>(iter: I) -> Self { iter.into_iter() .map(|x| if x.is_nan() { None } else { Some(x) }) .collect() } } -impl From> for Vec { - fn from(val: Series) -> Self { - val.into_iter().map(|x| x.unwrap_or(0.0)).collect() +impl From for Vec { + fn from(val: Price) -> Self { + val.into_iter().map(|x| x.unwrap_or(ZERO)).collect() } } -impl From> for Vec { - fn from(val: Series) -> Self { +impl From for Vec { + fn from(val: Rule) -> Self { val.into_iter().map(|x| x.unwrap_or(false)).collect() } } -impl From> for Series { - fn from(val: Series) -> Self { - val.fmap(|opt| opt.map(|f| f.is_finite() && *f != 0.)) +impl From for Rule { + fn from(val: Price) -> Self { + val.fmap(|opt| opt.map(|f| f.is_finite() && *f != ZERO)) } } diff --git a/ta_lib/core/src/lib.rs b/ta_lib/core/src/lib.rs index 8dd187ff..51d3a618 100644 --- a/ta_lib/core/src/lib.rs +++ b/ta_lib/core/src/lib.rs @@ -3,6 +3,7 @@ mod cmp; mod constants; mod cross; mod extremum; +mod fmt; mod from; mod macros; mod math; @@ -10,12 +11,14 @@ mod ops; mod series; mod smoothing; mod traits; +mod types; pub mod prelude { pub use crate::constants::*; pub use crate::series::Series; pub use crate::smoothing::Smooth; pub use crate::traits::*; + pub use crate::types::*; pub use crate::{iff, nz}; } diff --git a/ta_lib/core/src/math.rs b/ta_lib/core/src/math.rs index 82a62d12..a0a7a63d 100644 --- a/ta_lib/core/src/math.rs +++ b/ta_lib/core/src/math.rs @@ -1,9 +1,7 @@ -use crate::series::Series; -use crate::smoothing::Smooth; -use crate::ZERO; -use std::ops::Neg; +use crate::types::{Period, Price, Scalar}; +use crate::{NEUTRALITY, SCALE, ZERO}; -impl Series { +impl Price { pub fn abs(&self) -> Self { self.fmap(|val| val.map(|v| v.abs())) } @@ -20,7 +18,7 @@ impl Series { self.fmap(|val| val.map(|v| v.exp())) } - pub fn pow(&self, period: usize) -> Self { + pub fn pow(&self, period: Period) -> Self { self.fmap(|val| val.map(|v| v.powi(period as i32))) } @@ -29,7 +27,7 @@ impl Series { } pub fn negate(&self) -> Self { - self.fmap(|val| val.map(|v| v.neg())) + self.fmap(|val| val.map(|v| -v)) } pub fn sqrt(&self) -> Self { @@ -53,53 +51,139 @@ impl Series { } } -impl Series { - pub fn sum(&self, period: usize) -> Self { +impl Price { + fn all_none(window: &[Option]) -> bool { + window.iter().all(|&x| x.is_none()) + } + + fn wsum(&self, window: &[Option]) -> Option { + if Self::all_none(window) { + return None; + } + + Some(window.iter().flatten().sum()) + } + + fn wmean(&self, window: &[Option]) -> Option { + self.wsum(window).map(|sum| sum / window.len() as Scalar) + } + + fn wpercentile(&self, window: &[Option], percentile: Scalar) -> Option { + if Self::all_none(window) { + return None; + } + + let mut values: Vec = window.iter().flatten().copied().collect(); + values.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)); + + let len = values.len(); + let idx = (percentile / SCALE) * (len - 1) as Scalar; + let idx_lower = idx.floor() as usize; + let idx_upper = idx.ceil() as usize; + + if idx_upper >= len { + Some(values[len - 1]); + } + + let value_lower = values[idx_lower]; + let value_upper = values[idx_upper]; + + if idx_lower == idx_upper { + Some(values[idx_lower]); + } + + let fraction = idx.fract(); + + Some(value_lower + fraction * (value_upper - value_lower)) + } + + pub fn sum(&self, period: Period) -> Self { + self.window(period).map(|w| self.wsum(w)).collect() + } + + pub fn ma(&self, period: Period) -> Self { + self.window(period).map(|w| self.wmean(w)).collect() + } + + pub fn percentile(&self, period: Period, percentage: Scalar) -> Self { + self.window(period) + .map(|w| self.wpercentile(w, percentage)) + .collect() + } + + pub fn median(&self, period: Period) -> Self { + self.percentile(period, NEUTRALITY) + } + + pub fn mad(&self, period: Period) -> Self { self.window(period) .map(|w| { - if w.iter().all(|&x| x.is_none()) { - None - } else { - Some(w.iter().flatten().sum::()) - } + self.wmean(w).map(|mean| { + w.iter() + .flatten() + .map(|value| (value - mean).abs()) + .sum::() + / w.len() as Scalar + }) }) .collect() } - pub fn var(&self, period: usize) -> Self { - self.pow(2).smooth(Smooth::SMA, period) - self.smooth(Smooth::SMA, period).pow(2) + pub fn var(&self, period: Period) -> Self { + self.pow(2).ma(period) - self.ma(period).pow(2) } - pub fn std(&self, period: usize) -> Self { + pub fn std(&self, period: Period) -> Self { self.var(period).sqrt() } - pub fn mad(&self, period: usize) -> Self { + pub fn zscore(&self, period: Period) -> Self { + (self - self.ma(period)) / self.std(period) + } + + pub fn slope(&self, period: Period) -> Self { + self.change(period) / period as Scalar + } + + pub fn change(&self, period: Period) -> Self { + self - self.shift(period) + } + + pub fn highest(&self, period: Period) -> Self { self.window(period) .map(|w| { - if w.iter().all(|&x| x.is_none()) { - None - } else { - let len = w.len() as f32; - let mean = w.iter().flatten().sum::() / len; - - let mad = w - .iter() - .flatten() - .map(|value| (value - mean).abs()) - .sum::() - / len; + w.iter() + .flatten() + .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) + }) + .collect() + } - Some(mad) - } + pub fn lowest(&self, period: Period) -> Self { + self.window(period) + .map(|w| { + w.iter() + .flatten() + .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) }) .collect() } + + pub fn range(&self, period: Period) -> Self { + self.highest(period) - self.lowest(period) + } + + pub fn normalize(&self, period: Period, scale: Scalar) -> Self { + let l = self.lowest(period); + let h = self.highest(period); + + scale * (self - &l) / (h - l) + } } #[cfg(test)] mod tests { - use super::*; + use crate::series::Series; #[test] fn test_abs() { @@ -226,26 +310,38 @@ mod tests { assert_eq!(result, expected); } + #[test] + fn test_ma() { + let source = Series::from([f32::NAN, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([f32::NAN, 1.0, 1.6666666, 3.0, 4.0]); + + let result = source.ma(3); + + assert_eq!(result, expected); + } + + #[test] + fn test_median() { + let source = Series::from([3.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([3.0, 2.5, 3.0, 3.0, 4.0]); + + let result = source.median(3); + + assert_eq!(result, expected); + } + #[test] fn test_std() { let source = Series::from([2.0, 4.0, 6.0, 8.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0]); let expected = Series::from([ - 0.0, 1.0, 1.6329, 1.6329, 1.6329, 0.8164, 0.8164, 0.8164, 0.8164, 0.8164, 0.8164, + 0.0, 1.0, 1.632993, 1.6329936, 1.6329924, 0.816495, 0.816495, 0.816495, 0.8164974, + 0.8164974, ]); let period = 3; - let epsilon = 0.001; let result = source.std(period); - for i in 0..result.len() { - match (result[i], expected[i]) { - (Some(a), Some(b)) => { - assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b) - } - (None, None) => {} - _ => panic!("at position {}: {:?} != {:?}", i, result[i], expected[i]), - } - } + assert_eq!(result, expected); } #[test] @@ -258,4 +354,94 @@ mod tests { assert_eq!(result, expected); } + + #[test] + fn test_zscore() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([0.0, 1.0, 1.224745, 1.2247446, 1.2247455]); + let n = 3; + + let result = source.zscore(n); + + assert_eq!(result, expected); + } + + #[test] + fn test_slope() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([f32::NAN, f32::NAN, f32::NAN, 1.0, 1.0]); + let n = 3; + + let result = source.slope(n); + + assert_eq!(result, expected); + } + + #[test] + fn test_range() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([0.0, 1.0, 2.0, 2.0, 2.0]); + let n = 3; + + let result = source.range(n); + + assert_eq!(result, expected); + } + + #[test] + fn test_change() { + let source = Series::from([ + 44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, + ]); + let length = 1; + let expected = Series::from([ + f32::NAN, + -0.25, + 0.060001373, + -0.5400009, + 0.7200012, + 0.5, + 0.26999664, + 0.3199997, + 0.42000198, + ]); + + let result = source.change(length); + + assert_eq!(result, expected); + } + + #[test] + fn test_highest() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let period = 3; + + let result = source.highest(period); + + assert_eq!(result, expected); + } + + #[test] + fn test_lowest() { + let source = Series::from([f32::NAN, 2.0, 3.0, 1.0, 5.0]); + let expected = Series::from([f32::NAN, 2.0, 2.0, 1.0, 1.0]); + let period = 3; + + let result = source.lowest(period); + + assert_eq!(result, expected); + } + + #[test] + fn test_normalize() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([0.0, 1.0, 1.0, 1.0, 1.0]); + let n = 3; + let scale = 1.; + + let result = source.normalize(n, scale); + + assert_eq!(result, expected); + } } diff --git a/ta_lib/core/src/ops.rs b/ta_lib/core/src/ops.rs index 02c8368a..f8c319b8 100644 --- a/ta_lib/core/src/ops.rs +++ b/ta_lib/core/src/ops.rs @@ -1,124 +1,129 @@ use crate::constants::ZERO; use crate::series::Series; use crate::traits::Operation; +use crate::types::{Price, Rule, Scalar}; use std::ops::{Add, Div, Mul, Sub}; -impl Operation for Series { - type Output = Series; +impl Operation for Price { + type Output = Price; - fn ops(&self, scalar: &f32, op: F) -> Self::Output + fn ops(&self, scalar: &Scalar, op: F) -> Self::Output where - F: Fn(&f32, &f32) -> f32, + F: Fn(Scalar, Scalar) -> Scalar, { - self.fmap(|val| val.map(|v| op(v, scalar))) + self.fmap(|val| val.map(|v| op(*v, *scalar))) } - fn sadd(&self, scalar: &f32) -> Series { + fn sadd(&self, scalar: &Scalar) -> Self::Output { self.ops(scalar, |v, s| v + s) } - fn smul(&self, scalar: &f32) -> Series { + fn smul(&self, scalar: &Scalar) -> Self::Output { self.ops(scalar, |v, s| v * s) } - fn sdiv(&self, scalar: &f32) -> Series { - self.ops(scalar, |v, s| if *s != ZERO { v / s } else { ZERO }) + fn sdiv(&self, scalar: &Scalar) -> Self::Output { + self.ops(scalar, |v, s| if s != ZERO { v / s } else { ZERO }) } - fn ssub(&self, scalar: &f32) -> Series { + fn ssub(&self, scalar: &Scalar) -> Self::Output { self.ops(scalar, |v, s| v - s) } } -impl Operation, f32, f32> for Series { - type Output = Series; +impl Operation for Price { + type Output = Price; - fn ops(&self, rhs: &Series, op: F) -> Self::Output + fn ops(&self, rhs: &Price, op: F) -> Self::Output where - F: Fn(&f32, &f32) -> f32, + F: Fn(Scalar, Scalar) -> Scalar, { self.zip_with(rhs, |a, b| { - a.and_then(|a_val| b.map(|b_val| op(a_val, b_val))) + a.and_then(|a_val| b.map(|b_val| op(*a_val, *b_val))) }) } - fn sadd(&self, rhs: &Series) -> Series { + fn sadd(&self, rhs: &Price) -> Self::Output { self.ops(rhs, |v, s| v + s) } - fn smul(&self, rhs: &Series) -> Series { + fn smul(&self, rhs: &Price) -> Self::Output { self.ops(rhs, |v, s| v * s) } - fn sdiv(&self, rhs: &Series) -> Series { + fn sdiv(&self, rhs: &Price) -> Self::Output { self.zip_with(rhs, |a, b| match (a, b) { - (Some(a_val), Some(b_val)) if *b_val == ZERO => Some(ZERO), + (Some(_a_val), Some(b_val)) if *b_val == ZERO => Some(ZERO), (Some(a_val), Some(b_val)) if *b_val != ZERO => Some(a_val / b_val), _ => None, }) } - fn ssub(&self, rhs: &Series) -> Series { + fn ssub(&self, rhs: &Price) -> Self::Output { self.ops(rhs, |v, s| v - s) } } -impl Operation, bool, f32> for Series { - type Output = Series; +impl Operation for Rule { + type Output = Price; - fn ops(&self, rhs: &Series, op: F) -> Series + fn ops(&self, rhs: &Price, op: F) -> Self::Output where - F: Fn(&bool, &f32) -> f32, + F: Fn(bool, Scalar) -> Scalar, { self.zip_with(rhs, |b, a| match (b, a) { - (Some(b_val), Some(a_val)) => Some(op(b_val, a_val)), + (Some(b_val), Some(a_val)) => Some(op(*b_val, *a_val)), _ => None, }) } - fn sadd(&self, _rhs: &Series) -> Series { + fn sadd(&self, _rhs: &Price) -> Self::Output { unimplemented!() } - fn smul(&self, rhs: &Series) -> Series { - self.ops(rhs, |b, a| if *b { *a } else { 0.0 }) + fn smul(&self, rhs: &Price) -> Self::Output { + self.ops(rhs, |b, a| if b { a } else { ZERO }) } - fn sdiv(&self, _rhs: &Series) -> Series { + fn sdiv(&self, _rhs: &Price) -> Self::Output { unimplemented!() } - fn ssub(&self, _rhs: &Series) -> Series { + fn ssub(&self, _rhs: &Price) -> Self::Output { unimplemented!() } } macro_rules! impl_series_ops { ($trait_name:ident, $trait_method:ident, $method:ident) => { - impl $trait_name> for &Series { - type Output = Series; - fn $trait_method(self, rhs: Series) -> Self::Output { + impl $trait_name for &Price { + type Output = Price; + + fn $trait_method(self, rhs: Price) -> Self::Output { self.$method(&rhs) } } - impl $trait_name<&Series> for Series { - type Output = Series; - fn $trait_method(self, rhs: &Series) -> Self::Output { + impl $trait_name<&Price> for Price { + type Output = Price; + + fn $trait_method(self, rhs: &Price) -> Self::Output { self.$method(rhs) } } - impl $trait_name<&Series> for &Series { - type Output = Series; - fn $trait_method(self, rhs: &Series) -> Self::Output { + impl $trait_name<&Price> for &Price { + type Output = Price; + + fn $trait_method(self, rhs: &Price) -> Self::Output { self.$method(rhs) } } - impl $trait_name> for Series { - type Output = Series; - fn $trait_method(self, rhs: Series) -> Self::Output { + impl $trait_name for Price { + type Output = Price; + + fn $trait_method(self, rhs: Price) -> Self::Output { self.$method(&rhs) } } @@ -132,30 +137,34 @@ impl_series_ops!(Sub, sub, ssub); macro_rules! impl_scalar_ops { ($trait_name:ident, $trait_method:ident, $method:ident) => { - impl $trait_name<&Series> for f32 { - type Output = Series; - fn $trait_method(self, rhs: &Series) -> Self::Output { + impl $trait_name<&Price> for Scalar { + type Output = Price; + + fn $trait_method(self, rhs: &Price) -> Self::Output { rhs.$method(&self) } } - impl $trait_name> for f32 { - type Output = Series; - fn $trait_method(self, rhs: Series) -> Self::Output { + impl $trait_name for Scalar { + type Output = Price; + + fn $trait_method(self, rhs: Price) -> Self::Output { rhs.$method(&self) } } - impl $trait_name for &Series { - type Output = Series; - fn $trait_method(self, scalar: f32) -> Self::Output { + impl $trait_name for &Price { + type Output = Price; + + fn $trait_method(self, scalar: Scalar) -> Self::Output { self.$method(&scalar) } } - impl $trait_name for Series { - type Output = Series; - fn $trait_method(self, scalar: f32) -> Self::Output { + impl $trait_name for Price { + type Output = Price; + + fn $trait_method(self, scalar: Scalar) -> Self::Output { self.$method(&scalar) } } @@ -165,82 +174,84 @@ macro_rules! impl_scalar_ops { impl_scalar_ops!(Add, add, sadd); impl_scalar_ops!(Mul, mul, smul); -impl Div for &Series { - type Output = Series; +impl Div for &Price { + type Output = Price; - fn div(self, scalar: f32) -> Self::Output { + fn div(self, scalar: Scalar) -> Self::Output { self.sdiv(&scalar) } } -impl Div for Series { - type Output = Series; +impl Div for Price { + type Output = Price; - fn div(self, scalar: f32) -> Self::Output { + fn div(self, scalar: Scalar) -> Self::Output { self.sdiv(&scalar) } } -impl Div<&Series> for f32 { - type Output = Series; +impl Div<&Price> for Scalar { + type Output = Price; - fn div(self, rhs: &Series) -> Self::Output { + fn div(self, rhs: &Price) -> Self::Output { Series::fill(self, rhs.len()).sdiv(rhs) } } -impl Div> for f32 { - type Output = Series; +impl Div for Scalar { + type Output = Price; - fn div(self, rhs: Series) -> Self::Output { + fn div(self, rhs: Price) -> Self::Output { Series::fill(self, rhs.len()).sdiv(&rhs) } } -impl Sub for &Series { - type Output = Series; +impl Sub for &Price { + type Output = Price; - fn sub(self, scalar: f32) -> Self::Output { + fn sub(self, scalar: Scalar) -> Self::Output { self.ssub(&scalar) } } -impl Sub for Series { - type Output = Series; +impl Sub for Price { + type Output = Price; - fn sub(self, scalar: f32) -> Self::Output { + fn sub(self, scalar: Scalar) -> Self::Output { self.ssub(&scalar) } } -impl Sub<&Series> for f32 { - type Output = Series; +impl Sub<&Price> for Scalar { + type Output = Price; - fn sub(self, rhs: &Series) -> Self::Output { + fn sub(self, rhs: &Price) -> Self::Output { rhs.negate().ssub(&-self) } } -impl Sub> for f32 { - type Output = Series; +impl Sub for Scalar { + type Output = Price; - fn sub(self, rhs: Series) -> Self::Output { + fn sub(self, rhs: Price) -> Self::Output { rhs.negate().ssub(&-self) } } macro_rules! impl_bool_ops { ($trait_name:ident, $trait_method:ident, $method:ident) => { - impl $trait_name<&Series> for &Series { - type Output = Series; - fn $trait_method(self, rhs: &Series) -> Self::Output { + impl $trait_name<&Price> for &Rule { + type Output = Price; + + fn $trait_method(self, rhs: &Price) -> Self::Output { self.$method(rhs) } } - impl $trait_name<&Series> for Series { - type Output = Series; - fn $trait_method(self, rhs: &Series) -> Self::Output { + impl $trait_name<&Price> for Rule { + type Output = Price; + + fn $trait_method(self, rhs: &Price) -> Self::Output { self.$method(rhs) } } diff --git a/ta_lib/core/src/series.rs b/ta_lib/core/src/series.rs index 0d78d39f..a27de3c0 100644 --- a/ta_lib/core/src/series.rs +++ b/ta_lib/core/src/series.rs @@ -1,4 +1,5 @@ -use std::ops::{Index, IndexMut}; +use crate::types::{Period, Price, Rule, Scalar}; +use crate::{ONE, ZERO}; #[derive(Debug, Clone, PartialEq)] pub struct Series { @@ -22,26 +23,12 @@ impl FromIterator> for Series { } } -impl Index for Series { - type Output = Option; - - fn index(&self, idx: usize) -> &Self::Output { - &self.data[idx] - } -} - -impl IndexMut for Series { - fn index_mut(&mut self, index: usize) -> &mut Self::Output { - &mut self.data[index] - } -} - -impl Series { +impl Series { pub fn iter(&self) -> impl Iterator> { self.data.iter() } - pub fn window(&self, period: usize) -> impl Iterator]> + '_ { + pub fn window(&self, period: Period) -> impl Iterator]> + '_ { (0..self.len()).map(move |i| &self.data[i.saturating_sub(period - 1)..=i]) } @@ -63,81 +50,67 @@ impl Series { .collect() } - pub fn empty(length: usize) -> Self { - std::iter::repeat(None).take(length).collect() - } - - #[inline] + #[inline(always)] pub fn len(&self) -> usize { self.data.len() } +} - pub fn shift(&self, n: usize) -> Self { - let shifted_len = self.len() - n; +impl Series { + pub fn shift(&self, period: Period) -> Self { + let shifted_len = self.len().saturating_sub(period); - std::iter::repeat(None) - .take(n) + core::iter::repeat(None) + .take(period) .chain(self.iter().take(shifted_len).cloned()) .collect() } + pub fn empty(length: usize) -> Self { + core::iter::repeat(None).take(length).collect() + } + pub fn last(&self) -> Option { self.iter().last().cloned().flatten() } + + pub fn get(&self, index: usize) -> Option { + if index < self.len() { + self.data[index].clone() + } else { + None + } + } } -impl Series { - pub fn nz(&self, replacement: Option) -> Self { +impl Price { + pub fn nz(&self, replacement: Option) -> Self { self.fmap(|opt| match opt { Some(v) => Some(*v), - None => Some(replacement.unwrap_or(0.0)), + None => Some(replacement.unwrap_or(ZERO)), }) } - pub fn na(&self) -> Series { + pub fn na(&self) -> Rule { self.fmap(|val| Some(val.is_none())) } - pub fn fill(scalar: f32, len: usize) -> Series { - Series::empty(len).nz(Some(scalar)) - } - - pub fn zero(len: usize) -> Series { - Series::fill(0., len) - } - - pub fn one(len: usize) -> Series { - Series::fill(1., len) - } - - pub fn change(&self, length: usize) -> Self { - self - self.shift(length) + pub fn fill(scalar: Scalar, n: usize) -> Price { + core::iter::repeat(scalar).take(n).collect() } - pub fn highest(&self, period: usize) -> Self { - self.window(period) - .map(|w| { - w.iter() - .flatten() - .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - }) - .collect() + pub fn zero(n: usize) -> Price { + Series::fill(ZERO, n) } - pub fn lowest(&self, period: usize) -> Self { - self.window(period) - .map(|w| { - w.iter() - .flatten() - .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal)) - }) - .collect() + pub fn one(n: usize) -> Price { + Series::fill(ONE, n) } } #[cfg(test)] mod tests { - use super::*; + use crate::series::Series; #[test] fn test_len() { @@ -167,6 +140,16 @@ mod tests { assert_eq!(result, expected); } + #[test] + fn test_nz() { + let source = Series::from([f32::NAN, f32::NAN, 3.0, 4.0, 5.0]); + let expected = Series::from([0.0, 0.0, 3.0, 4.0, 5.0]); + + let result = source.nz(Some(0.0)); + + assert_eq!(result, expected); + } + #[test] fn test_shift() { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); @@ -175,6 +158,7 @@ mod tests { let result = source.shift(n); + assert_eq!(source.len(), result.len()); assert_eq!(result, expected); } @@ -199,55 +183,11 @@ mod tests { } #[test] - fn test_change() { - let source = Series::from([ - 44.34, 44.09, 44.15, 43.61, 44.33, 44.83, 45.10, 45.42, 45.84, - ]); - let length = 1; - let epsilon = 0.001; - let expected = Series::from([ - f32::NAN, - -0.25, - 0.0599, - -0.540, - 0.7199, - 0.5, - 0.2700, - 0.3200, - 0.4200, - ]); - - let result = source.change(length); - - for i in 0..result.len() { - match (result[i], expected[i]) { - (Some(a), Some(b)) => { - assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b) - } - (None, None) => {} - _ => panic!("at position {}: {:?} != {:?}", i, result[i], expected[i]), - } - } - } - - #[test] - fn test_highest() { + fn test_get() { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let expected = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let period = 3; - - let result = source.highest(period); - - assert_eq!(result, expected); - } - - #[test] - fn test_lowest() { - let source = Series::from([f32::NAN, 2.0, 3.0, 1.0, 5.0]); - let expected = Series::from([f32::NAN, 2.0, 2.0, 1.0, 1.0]); - let period = 3; + let expected = Some(5.0); - let result = source.lowest(period); + let result = source.get(4); assert_eq!(result, expected); } diff --git a/ta_lib/core/src/smoothing.rs b/ta_lib/core/src/smoothing.rs index aebcbf9a..7cdbdf44 100644 --- a/ta_lib/core/src/smoothing.rs +++ b/ta_lib/core/src/smoothing.rs @@ -1,6 +1,7 @@ +use crate::constants::{ONE, PI, SCALE, ZERO}; use crate::series::Series; use crate::traits::Comparator; -use crate::ZERO; +use crate::types::{Period, Price, Scalar}; use crate::{iff, nz}; #[derive(Copy, Clone)] @@ -14,57 +15,55 @@ pub enum Smooth { ZLEMA, LSMA, TEMA, + DEMA, + ULTS, } -impl Series { - pub fn ew(&self, alpha: &Series, seed: &Series) -> Self { +impl Price { + pub fn ew(&self, alpha: &Price, seed: &Price) -> Self { let len = self.len(); let mut sum = Series::zero(len); + let a = alpha * self; + let b = ONE - alpha; for _ in 0..len { - sum = alpha * self + (1. - alpha) * nz!(sum.shift(1), seed) + sum = &a + &b * nz!(sum.shift(1), seed) } sum } - pub fn wg(&self, weights: &[f32]) -> Self { + pub fn wg(&self, weights: &[Scalar]) -> Self { let mut sum = Series::zero(self.len()); - let norm = weights.iter().sum::(); + let norm = weights.iter().sum::(); for (i, &weight) in weights.iter().enumerate() { - sum = sum + self.shift(i) * weight; + sum = sum + nz!(self.shift(i), self) * weight; } sum / norm } - fn ma(&self, period: usize) -> Self { - self.window(period) - .map(|w| { - if w.iter().all(|&x| x.is_none()) { - None - } else { - Some(w.iter().flatten().sum::() / w.len() as f32) - } - }) - .collect() - } - - fn ema(&self, period: usize) -> Self { - let alpha = Series::fill(2. / (period as f32 + 1.), self.len()); + fn ema(&self, period: Period) -> Self { + let alpha = Series::fill(2. / (period + 1) as Scalar, self.len()); self.ew(&alpha, self) } - fn smma(&self, period: usize) -> Self { - let alpha = Series::fill(1. / (period as f32), self.len()); + fn smma(&self, period: Period) -> Self { + let alpha = Series::fill(ONE / (period as Scalar), self.len()); let seed = self.ma(period); self.ew(&alpha, &seed) } - fn tema(&self, period: usize) -> Self { + fn dema(&self, period: Period) -> Self { + let ema = self.ema(period); + + 2. * &ema - ema.ema(period) + } + + fn tema(&self, period: Period) -> Self { let ema1 = self.ema(period); let ema2 = ema1.ema(period); let ema3 = ema2.ema(period); @@ -72,29 +71,31 @@ impl Series { 3. * (ema1 - ema2) + ema3 } - fn wma(&self, period: usize) -> Self { - let weights = (0..period).map(|i| (period - i) as f32).collect::>(); + fn wma(&self, period: Period) -> Self { + let weights = (0..period) + .map(|i| (period - i) as Scalar) + .collect::>(); self.wg(&weights) } fn swma(&self) -> Self { - let x1 = self.shift(1); - let x2 = self.shift(2); - let x3 = self.shift(3); + let x1 = nz!(self.shift(1), self); + let x2 = nz!(self.shift(2), self); + let x3 = nz!(self.shift(3), self); - x3 * 1. / 6. + x2 * 2. / 6. + x1 * 2. / 6. + self * 1. / 6. + x3 * ONE / 6. + x2 * 2. / 6. + x1 * 2. / 6. + self * ONE / 6. } - fn hma(&self, period: usize) -> Self { - let lag = (0.5 * period as f32).round() as usize; - let sqrt_period = (period as f32).sqrt() as usize; + fn hma(&self, period: Period) -> Self { + let lag = (0.5 * period as Scalar) as Period; + let sqrt_period = (period as Scalar).sqrt().floor() as Period; (2. * self.wma(lag) - self.wma(period)).wma(sqrt_period) } - fn linreg(&self, period: usize) -> Self { - let x = (0..self.len()).map(|i| i as f32).collect::>(); + fn linreg(&self, period: Period) -> Self { + let x = (0..self.len()).map(|i| i as Scalar).collect::>(); let x_mean = x.ma(period); let y_mean = self.ma(period); @@ -112,25 +113,54 @@ impl Series { &intercept + &slope * &x } - fn kama(&self, period: usize) -> Series { + fn kama(&self, period: Period) -> Self { let len = self.len(); let mom = self.change(period).abs(); let volatility = self.change(1).abs().sum(period); let er = iff!(volatility.seq(&ZERO), Series::zero(len), mom / volatility); - let alpha = (er * 0.666_666_7).pow(2); + let alpha = (er.nz(Some(ZERO)) * 0.6015 + 0.0645).pow(2); self.ew(&alpha, self) } - fn zlema(&self, period: usize) -> Series { - let lag = (0.5 * (period as f32 - 1.)) as usize; + fn zlema(&self, period: Period) -> Self { + let lag = (0.5 * (period - 1) as Scalar).floor() as Period; + + (self + (self - nz!(self.shift(lag), self))).ema(period) + } + + fn ults(&self, period: Period) -> Self { + let a1 = (-1.414 * PI / period as Scalar).exp(); + let c2 = 2. * a1 * (1.414 * PI / period as Scalar).cos(); + let c3 = -a1 * a1; + let c1 = 0.25 * (ONE + c2 - c3); + + let len = self.len(); + + let mut us = Series::zero(len); + + let src1 = nz!(self.shift(1), self); + let src2 = nz!(self.shift(2), src1); + + let a = (ONE - c1) * self; + let b = (2. * c1 - c2) * &src1; + let c = (c1 + c3) * &src2; + + let abc = a + b - c; + + for _ in 0..len { + let d = c2 * nz!(us.shift(1), src1); + let e = c3 * nz!(us.shift(2), src2); + + us = &abc + d + e; + } - (self + (self - self.shift(lag))).ema(period) + us } - pub fn smooth(&self, smooth_type: Smooth, period: usize) -> Self { - match smooth_type { + pub fn smooth(&self, smooth: Smooth, period: Period) -> Self { + match smooth { Smooth::EMA => self.ema(period), Smooth::SMA => self.ma(period), Smooth::SMMA => self.smma(period), @@ -140,23 +170,41 @@ impl Series { Smooth::ZLEMA => self.zlema(period), Smooth::LSMA => self.linreg(period), Smooth::TEMA => self.tema(period), + Smooth::DEMA => self.dema(period), + Smooth::ULTS => self.ults(period), } } -} -#[cfg(test)] -mod tests { - use super::*; + pub fn smooth_dst(&self, smooth: Smooth, period: Period) -> Self { + self - self.smooth(smooth, period) + } - #[test] - fn test_ma() { - let source = Series::from([f32::NAN, 2.0, 3.0, 4.0, 5.0]); - let expected = Series::from([f32::NAN, 1.0, 1.6666666, 3.0, 4.0]); + pub fn spread(&self, smooth: Smooth, period_fast: Period, period_slow: Period) -> Self { + self.smooth(smooth, period_fast) - self.smooth(smooth, period_slow) + } - let result = source.ma(3); + pub fn spread_pct(&self, smooth: Smooth, period_fast: Period, period_slow: Period) -> Self { + let fsm = self.smooth(smooth, period_fast); + let ssm = self.smooth(smooth, period_slow); - assert_eq!(result, expected); + SCALE * (fsm - &ssm) / &ssm + } + + pub fn spread_diff( + &self, + smooth: Smooth, + period_fast: Period, + period_slow: Period, + n: usize, + ) -> Self { + self.spread(smooth, period_fast, period_slow) + - self.shift(n).spread(smooth, period_fast, period_slow) } +} + +#[cfg(test)] +mod tests { + use super::*; #[test] fn test_ema() { @@ -181,7 +229,7 @@ mod tests { #[test] fn test_wma() { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let expected = Series::from([f32::NAN, f32::NAN, 2.3333333, 3.3333333, 4.3333335]); + let expected = Series::from([1.0, 1.6666666, 2.3333333, 3.3333333, 4.3333335]); let result = source.wma(3); @@ -191,7 +239,7 @@ mod tests { #[test] fn test_swma() { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let expected = Series::from([f32::NAN, f32::NAN, f32::NAN, 2.5, 3.5]); + let expected = Series::from([1.0, 1.6666667, 2.0, 2.5, 3.5]); let result = source.swma(); @@ -201,7 +249,7 @@ mod tests { #[test] fn test_kama() { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let expected = Series::from([f32::NAN, f32::NAN, f32::NAN, 4.0, 4.4444447]); + let expected = Series::from([1.0, 1.0041603, 1.0124636, 2.337603, 3.5185251]); let result = source.kama(3); @@ -232,4 +280,70 @@ mod tests { assert_eq!(result.len(), expected.len()); assert_eq!(result, expected); } + + #[test] + fn test_ults() { + let source = Series::from([ + 0.3847, 0.3863, 0.3885, 0.3839, 0.3834, 0.3843, 0.3840, 0.3834, 0.3832, + ]); + let expected = Series::from([ + 0.38469997, 0.38586292, 0.3883182, 0.3857727, 0.38236603, 0.38377836, 0.38435996, + 0.38352367, 0.38307717, + ]); + + let result = source.ults(3); + + assert_eq!(result, expected); + } + + #[test] + fn test_average_distance() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([0.0, 0.5, 1.0, 1.0, 1.0]); + let period = 3; + + let result = source.smooth_dst(Smooth::SMA, period); + + assert_eq!(result, expected); + } + + #[test] + fn test_spread() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([0.0, 0.16666675, 0.30555558, 0.39351845, 0.44367313]); + let period_fast = 2; + let period_slow = 3; + let smooth = Smooth::EMA; + + let result = source.spread(smooth, period_fast, period_slow); + + assert_eq!(result, expected); + } + + #[test] + fn test_percent_of_spread() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([0.0, 11.111117, 13.580248, 12.59259, 10.921185]); + let period_fast = 2; + let period_slow = 3; + let smooth = Smooth::EMA; + + let result = source.spread_pct(smooth, period_fast, period_slow); + + assert_eq!(result, expected); + } + + #[test] + fn test_spread_diff() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = Series::from([f32::NAN, 0.16666675, 0.13888884, 0.087962866, 0.050154686]); + let period_fast = 2; + let period_slow = 3; + let n = 1; + let smooth = Smooth::EMA; + + let result = source.spread_diff(smooth, period_fast, period_slow, n); + + assert_eq!(result, expected); + } } diff --git a/ta_lib/core/src/traits.rs b/ta_lib/core/src/traits.rs index fea58d7f..949a9854 100644 --- a/ta_lib/core/src/traits.rs +++ b/ta_lib/core/src/traits.rs @@ -1,3 +1,5 @@ +use crate::types::Scalar; + pub trait Cross { type Output; @@ -11,10 +13,11 @@ pub trait Extremum { fn extremum(&self, rhs: &T, f: F) -> Self::Output where - F: Fn(f32, f32) -> f32; + F: Fn(Scalar, Scalar) -> Scalar; fn max(&self, rhs: &T) -> Self::Output; fn min(&self, rhs: &T) -> Self::Output; + fn clip(&self, lhs: &T, rhs: &T) -> Self::Output; } pub trait Comparator { @@ -22,7 +25,7 @@ pub trait Comparator { fn compare(&self, rhs: &T, comparator: F) -> Self::Output where - F: Fn(&f32, &f32) -> bool; + F: Fn(Scalar, Scalar) -> bool; fn seq(&self, rhs: &T) -> Self::Output; fn sne(&self, rhs: &T) -> Self::Output; @@ -37,7 +40,7 @@ pub trait Operation { fn ops(&self, rhs: &T, op: F) -> Self::Output where - F: Fn(&U, &V) -> f32; + F: Fn(U, V) -> Scalar; fn sadd(&self, rhs: &T) -> Self::Output; fn ssub(&self, rhs: &T) -> Self::Output; @@ -50,7 +53,7 @@ pub trait Bitwise { fn op(&self, rhs: &T, op: F) -> Self::Output where - F: Fn(&bool, &bool) -> bool; + F: Fn(bool, bool) -> bool; fn sand(&self, rhs: &T) -> Self::Output; fn sor(&self, rhs: &T) -> Self::Output; diff --git a/ta_lib/core/src/types.rs b/ta_lib/core/src/types.rs new file mode 100644 index 00000000..4302c377 --- /dev/null +++ b/ta_lib/core/src/types.rs @@ -0,0 +1,6 @@ +use crate::series::Series; + +pub type Scalar = f32; +pub type Price = Series; +pub type Rule = Series; +pub type Period = usize; diff --git a/ta_lib/ffi/Cargo.toml b/ta_lib/ffi/Cargo.toml new file mode 100644 index 00000000..de09744b --- /dev/null +++ b/ta_lib/ffi/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "ffi" +authors.workspace = true +edition.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +version.workspace = true + +[lib] +crate-type = ["cdylib"] + +[dependencies] +timeseries = { path = "../timeseries" } +serde = { version = "1.0", default-features = false, features = ["derive"] } +serde_json = { version = "1.0", default-features = false, features = ["alloc"] } +once_cell = "1.19" \ No newline at end of file diff --git a/ta_lib/ffi/src/lib.rs b/ta_lib/ffi/src/lib.rs new file mode 100644 index 00000000..3640ec07 --- /dev/null +++ b/ta_lib/ffi/src/lib.rs @@ -0,0 +1,3 @@ +mod timeseries; + +pub use timeseries::*; diff --git a/ta_lib/ffi/src/timeseries.rs b/ta_lib/ffi/src/timeseries.rs new file mode 100644 index 00000000..ca7fae41 --- /dev/null +++ b/ta_lib/ffi/src/timeseries.rs @@ -0,0 +1,203 @@ +use once_cell::sync::Lazy; +use serde::Serialize; +use serde_json::to_string; +use std::collections::HashMap; +use std::sync::atomic::{AtomicI32, Ordering}; +use std::sync::{Arc, RwLock}; +use timeseries::prelude::*; + +type TsTableType = Lazy>>>>; +type Result = (i32, i32); + +static TIMESERIES: TsTableType = Lazy::new(|| Arc::new(RwLock::new(HashMap::new()))); +static TIMESERIES_ID_COUNTER: Lazy = Lazy::new(|| AtomicI32::new(0)); + +fn generate_timeseries_id() -> i32 { + TIMESERIES_ID_COUNTER.fetch_add(1, Ordering::SeqCst) +} + +const ERROR: Result = (-1, 0); +const NOT_FOUND: Result = (0, 0); + +fn serialize(data: &T) -> Result { + match to_string(data) { + Ok(json) => { + let bytes = json.as_bytes(); + (bytes.as_ptr() as i32, bytes.len() as i32) + } + Err(_) => ERROR, + } +} + +#[no_mangle] +pub fn timeseries_register() -> i32 { + let timeseries_id = generate_timeseries_id(); + + let mut timeseries = TIMESERIES.write().unwrap(); + + timeseries.insert(timeseries_id, Box::::default()); + + timeseries_id +} + +#[no_mangle] +pub fn timeseries_add( + timeseries_id: i32, + ts: i64, + open: f32, + high: f32, + low: f32, + close: f32, + volume: f32, +) -> Result { + let mut timeseries = TIMESERIES.write().unwrap(); + + if let Some(timeseries) = timeseries.get_mut(×eries_id) { + let bar = OHLCV { + ts, + open, + high, + low, + close, + volume, + }; + + timeseries.add(&bar); + + NOT_FOUND + } else { + ERROR + } +} + +#[no_mangle] +pub fn timeseries_next_bar( + timeseries_id: i32, + ts: i64, + open: f32, + high: f32, + low: f32, + close: f32, + volume: f32, +) -> Result { + let timeseries = TIMESERIES.read().unwrap(); + + if let Some(timeseries) = timeseries.get(×eries_id) { + let curr_bar = OHLCV { + ts, + open, + high, + low, + close, + volume, + }; + + if let Some(next_bar) = timeseries.next_bar(&curr_bar) { + serialize(&next_bar) + } else { + NOT_FOUND + } + } else { + ERROR + } +} + +#[no_mangle] +pub fn timeseries_prev_bar( + timeseries_id: i32, + ts: i64, + open: f32, + high: f32, + low: f32, + close: f32, + volume: f32, +) -> Result { + let timeseries = TIMESERIES.read().unwrap(); + + if let Some(timeseries) = timeseries.get(×eries_id) { + let curr_bar = OHLCV { + ts, + open, + high, + low, + close, + volume, + }; + + if let Some(prev_bar) = timeseries.prev_bar(&curr_bar) { + serialize(&prev_bar) + } else { + NOT_FOUND + } + } else { + ERROR + } +} + +#[no_mangle] +pub fn timeseries_back_n_bars( + timeseries_id: i32, + ts: i64, + open: f32, + high: f32, + low: f32, + close: f32, + volume: f32, + n: usize, +) -> Result { + let timeseries = TIMESERIES.read().unwrap(); + + if let Some(timeseries) = timeseries.get(×eries_id) { + let curr_bar = OHLCV { + ts, + open, + high, + low, + close, + volume, + }; + + let bars = timeseries.back_n_bars(&curr_bar, n); + + serialize(&bars) + } else { + ERROR + } +} + +#[no_mangle] +pub fn timeseries_ta( + timeseries_id: i32, + ts: i64, + open: f32, + high: f32, + low: f32, + close: f32, + volume: f32, +) -> Result { + let timeseries = TIMESERIES.read().unwrap(); + + if let Some(timeseries) = timeseries.get(×eries_id) { + let curr_bar = OHLCV { + ts, + open, + high, + low, + close, + volume, + }; + + let ta = timeseries.ta(&curr_bar); + + serialize(&ta) + } else { + ERROR + } +} + +#[no_mangle] +pub fn timeseries_unregister(timeseries_id: i32) -> i32 { + let mut timeseries = TIMESERIES.write().unwrap(); + + timeseries.remove(×eries_id).is_some() as i32 +} diff --git a/ta_lib/indicators/momentum/src/ao.rs b/ta_lib/indicators/momentum/src/ao.rs deleted file mode 100644 index b9ed7e10..00000000 --- a/ta_lib/indicators/momentum/src/ao.rs +++ /dev/null @@ -1,30 +0,0 @@ -use core::prelude::*; - -pub fn ao( - source: &Series, - smooth_type: Smooth, - fast_period: usize, - slow_period: usize, -) -> Series { - source.smooth(smooth_type, fast_period) - source.smooth(smooth_type, slow_period) -} - -#[cfg(test)] -mod tests { - use super::*; - use price::prelude::*; - - #[test] - fn test_ao() { - let high = Series::from([3.0, 4.0, 5.0, 6.0, 7.0]); - let low = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let hl2 = median_price(&high, &low); - let fast_period = 2; - let slow_period = 4; - let expected = vec![0.0, 0.0, 0.5, 1.0, 1.0]; - - let result: Vec = ao(&hl2, Smooth::SMA, fast_period, slow_period).into(); - - assert_eq!(result, expected); - } -} diff --git a/ta_lib/indicators/momentum/src/bop.rs b/ta_lib/indicators/momentum/src/bop.rs index 4faf484b..5cb76c46 100644 --- a/ta_lib/indicators/momentum/src/bop.rs +++ b/ta_lib/indicators/momentum/src/bop.rs @@ -1,14 +1,14 @@ use core::prelude::*; pub fn bop( - open: &Series, - high: &Series, - low: &Series, - close: &Series, - smooth_type: Smooth, - smoothing_period: usize, -) -> Series { - ((close - open) / (high - low)).smooth(smooth_type, smoothing_period) + open: &Price, + high: &Price, + low: &Price, + close: &Price, + smooth: Smooth, + period_smooth: Period, +) -> Price { + ((close - open) / (high - low)).smooth(smooth, period_smooth) } #[cfg(test)] @@ -23,7 +23,7 @@ mod tests { let close = Series::from([2.0310, 2.0282, 1.9937, 1.9795, 1.9632]); let expected = vec![-0.58558744, -0.4300509, -0.6022142, -0.8487407, -0.77561265]; - let result: Vec = bop(&open, &high, &low, &close, Smooth::SMA, 2).into(); + let result: Vec = bop(&open, &high, &low, &close, Smooth::SMA, 2).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/cc.rs b/ta_lib/indicators/momentum/src/cc.rs index 0fbd8d9f..95f80226 100644 --- a/ta_lib/indicators/momentum/src/cc.rs +++ b/ta_lib/indicators/momentum/src/cc.rs @@ -2,13 +2,13 @@ use crate::roc; use core::prelude::*; pub fn cc( - source: &Series, - fast_period: usize, - slow_period: usize, - smooth_type: Smooth, - smoothing_period: usize, -) -> Series { - (roc(source, fast_period) + roc(source, slow_period)).smooth(smooth_type, smoothing_period) + source: &Price, + period_fast: Period, + period_slow: Period, + smooth: Smooth, + period_smooth: Period, +) -> Price { + (roc(source, period_fast) + roc(source, period_slow)).smooth(smooth, period_smooth) } #[cfg(test)] @@ -18,9 +18,9 @@ mod tests { #[test] fn test_cc() { let close = Series::from([19.299, 19.305, 19.310, 19.316, 19.347, 19.355, 19.386]); - let expected = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.6957161]; + let expected = vec![0.0, 0.0, 0.0, 0.0, 0.0, 0.52320945, 0.6957161]; - let result: Vec = cc(&close, 3, 5, Smooth::WMA, 2).into(); + let result: Vec = cc(&close, 3, 5, Smooth::WMA, 2).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/cci.rs b/ta_lib/indicators/momentum/src/cci.rs index 52545a05..f42f3d97 100644 --- a/ta_lib/indicators/momentum/src/cci.rs +++ b/ta_lib/indicators/momentum/src/cci.rs @@ -1,7 +1,7 @@ use core::prelude::*; -pub fn cci(source: &Series, smooth_type: Smooth, period: usize, factor: f32) -> Series { - (source - source.smooth(smooth_type, period)) / (factor * source.mad(period)) +pub fn cci(source: &Price, period: Period, factor: Scalar) -> Price { + source.smooth_dst(Smooth::SMA, period) / (factor * source.mad(period)) } #[cfg(test)] @@ -17,7 +17,7 @@ mod tests { let hlc3 = typical_price(&high, &low, &close); let expected = vec![0.0, 66.66667, 100.0, 100.0, 100.0]; - let result: Vec = cci(&hlc3, Smooth::SMA, 3, 0.015).into(); + let result: Vec = cci(&hlc3, 3, 0.015).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/cfo.rs b/ta_lib/indicators/momentum/src/cfo.rs index 1dd35781..de34ffb0 100644 --- a/ta_lib/indicators/momentum/src/cfo.rs +++ b/ta_lib/indicators/momentum/src/cfo.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn cfo(source: &Series, period: usize) -> Series { +pub fn cfo(source: &Price, period: Period) -> Price { SCALE * (source - source.smooth(Smooth::LSMA, period)) / source } @@ -24,7 +24,7 @@ mod tests { 0.017605804, ]; - let result: Vec = cfo(&source, 3).into(); + let result: Vec = cfo(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/cmo.rs b/ta_lib/indicators/momentum/src/cmo.rs index ed65535d..1f671887 100644 --- a/ta_lib/indicators/momentum/src/cmo.rs +++ b/ta_lib/indicators/momentum/src/cmo.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn cmo(source: &Series, period: usize) -> Series { +pub fn cmo(source: &Price, period: Period) -> Price { let mom = source.change(1); let zero = Series::zero(source.len()); @@ -22,7 +22,7 @@ mod tests { let close = Series::from([19.571, 19.606, 19.594, 19.575, 19.612, 19.631, 19.634]); let expected = vec![0.0, 100.0, 48.934788, 6.062883, 8.821632, 49.33496, 100.0]; - let result: Vec = cmo(&close, 3).into(); + let result: Vec = cmo(&close, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/di.rs b/ta_lib/indicators/momentum/src/di.rs index df7abd5a..731e0845 100644 --- a/ta_lib/indicators/momentum/src/di.rs +++ b/ta_lib/indicators/momentum/src/di.rs @@ -1,7 +1,7 @@ use core::prelude::*; -pub fn di(source: &Series, smooth_type: Smooth, period: usize) -> Series { - let ma = source.smooth(smooth_type, period); +pub fn di(source: &Price, smooth: Smooth, period: Period) -> Price { + let ma = source.smooth(smooth, period); SCALE * (source - &ma) / ma } @@ -17,8 +17,8 @@ mod tests { 6.8360, 6.8345, 6.8285, 6.8395, ]); let expected = vec![ - 0.0, - 0.0, + -0.0000069530056, + 0.009725365, -0.08268177, 0.040116996, 0.07046368, @@ -34,7 +34,7 @@ mod tests { 0.0658433, ]; - let result: Vec = di(&source, Smooth::WMA, 3).into(); + let result: Vec = di(&source, Smooth::WMA, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/dmi.rs b/ta_lib/indicators/momentum/src/dmi.rs index eabaeca3..5a70b230 100644 --- a/ta_lib/indicators/momentum/src/dmi.rs +++ b/ta_lib/indicators/momentum/src/dmi.rs @@ -1,13 +1,13 @@ use core::prelude::*; pub fn dmi( - high: &Series, - low: &Series, - atr: &Series, - smooth_type: Smooth, - adx_period: usize, - di_period: usize, -) -> (Series, Series, Series) { + high: &Price, + low: &Price, + atr: &Price, + smooth: Smooth, + period_adx: Period, + period_di: Period, +) -> (Price, Price, Price) { let len = high.len(); let up = high.change(1); let down = low.change(1).negate(); @@ -18,16 +18,16 @@ pub fn dmi( let dm_plus = iff!(up.sgt(&down) & up.sgt(&ZERO), up, zero); let dm_minus = iff!(down.sgt(&up) & down.sgt(&ZERO), down, zero); - let di_plus = SCALE * dm_plus.smooth(smooth_type, di_period) / atr; - let di_minus = SCALE * dm_minus.smooth(smooth_type, di_period) / atr; + let di_plus = SCALE * dm_plus.smooth(smooth, period_di) / atr; + let di_minus = SCALE * dm_minus.smooth(smooth, period_di) / atr; let sum = &di_plus + &di_minus; let adx = SCALE * ((&di_plus - &di_minus).abs() / iff!(sum.seq(&ZERO), one, sum)) - .smooth(smooth_type, adx_period); + .smooth(smooth, period_adx); - (adx, di_plus, di_minus) + (di_plus, di_minus, adx) } #[cfg(test)] @@ -74,12 +74,12 @@ mod tests { 38.785717, 57.523396, 68.031395, 42.329178, 33.812767, 37.56404, 50.37606, 26.629393, ]; - let (result_adx, result_di_plus, result_di_minus) = + let (result_di_plus, result_di_minus, result_adx) = dmi(&high, &low, &atr, Smooth::SMMA, adx_period, di_period); - let adx: Vec = result_adx.into(); - let di_plus: Vec = result_di_plus.into(); - let di_minus: Vec = result_di_minus.into(); + let adx: Vec = result_adx.into(); + let di_plus: Vec = result_di_plus.into(); + let di_minus: Vec = result_di_minus.into(); assert_eq!(adx, expected_adx); assert_eq!(di_plus, expected_di_plus); diff --git a/ta_lib/indicators/momentum/src/dso.rs b/ta_lib/indicators/momentum/src/dso.rs deleted file mode 100644 index 8f4d0785..00000000 --- a/ta_lib/indicators/momentum/src/dso.rs +++ /dev/null @@ -1,46 +0,0 @@ -use crate::stoch; -use core::prelude::*; - -pub fn dso( - close: &Series, - smooth_type: Smooth, - smooth_period: usize, - k_period: usize, - d_period: usize, -) -> (Series, Series) { - let close_smooth = close.smooth(smooth_type, k_period); - - let k = stoch(&close_smooth, &close_smooth, &close_smooth, smooth_period) - .smooth(smooth_type, k_period); - let d = k.smooth(smooth_type, d_period); - - (k, d) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_dso() { - let close = Series::from([4.9112, 4.9140, 4.9135, 4.9151, 4.9233, 4.9313, 4.9357]); - let period = 3; - let k_period = 2; - let d_period = 2; - - let expected_k = vec![ - 0.0, 66.66667, 88.88889, 96.2963, 98.76544, 99.588486, 99.86283, - ]; - let expected_d = vec![ - 0.0, 44.44445, 74.07408, 88.8889, 95.47326, 98.21674, 99.31413, - ]; - - let (k, d) = dso(&close, Smooth::EMA, period, k_period, d_period); - - let result_k: Vec = k.into(); - let result_d: Vec = d.into(); - - assert_eq!(result_k, expected_k); - assert_eq!(result_d, expected_d); - } -} diff --git a/ta_lib/indicators/momentum/src/kst.rs b/ta_lib/indicators/momentum/src/kst.rs index a865498c..262052e8 100644 --- a/ta_lib/indicators/momentum/src/kst.rs +++ b/ta_lib/indicators/momentum/src/kst.rs @@ -2,21 +2,21 @@ use crate::roc; use core::prelude::*; pub fn kst( - source: &Series, - smooth_type: Smooth, - roc_period_first: usize, - roc_period_second: usize, - roc_period_third: usize, - roc_period_fouth: usize, - period_first: usize, - period_second: usize, - period_third: usize, - period_fouth: usize, -) -> Series { - roc(source, roc_period_first).smooth(smooth_type, period_first) - + (2. * roc(source, roc_period_second).smooth(smooth_type, period_second)) - + (3. * roc(source, roc_period_third).smooth(smooth_type, period_third)) - + (4. * roc(source, roc_period_fouth).smooth(smooth_type, period_fouth)) + source: &Price, + smooth: Smooth, + period_roc_first: Period, + period_roc_second: Period, + period_roc_third: Period, + period_roc_fouth: Period, + period_first: Period, + period_second: Period, + period_third: Period, + period_fouth: Period, +) -> Price { + roc(source, period_roc_first).smooth(smooth, period_first) + + (2. * roc(source, period_roc_second).smooth(smooth, period_second)) + + (3. * roc(source, period_roc_third).smooth(smooth, period_third)) + + (4. * roc(source, period_roc_fouth).smooth(smooth, period_fouth)) } #[cfg(test)] @@ -43,7 +43,7 @@ mod tests { 8.414183, ]; - let result: Vec = kst( + let result: Vec = kst( &source, Smooth::SMA, roc_period_one, diff --git a/ta_lib/indicators/momentum/src/lib.rs b/ta_lib/indicators/momentum/src/lib.rs index 4610ba34..269d909c 100644 --- a/ta_lib/indicators/momentum/src/lib.rs +++ b/ta_lib/indicators/momentum/src/lib.rs @@ -1,4 +1,3 @@ -mod ao; mod bop; mod cc; mod cci; @@ -6,22 +5,21 @@ mod cfo; mod cmo; mod di; mod dmi; -mod dso; mod kst; mod macd; -mod pr; +mod qstick; +mod rex; mod roc; mod rsi; -mod sso; mod stc; -mod stoch; mod stochosc; mod tdfi; mod tii; mod trix; mod tsi; +mod uo; +mod wpr; -pub use ao::ao; pub use bop::bop; pub use cc::cc; pub use cci::cci; @@ -29,17 +27,17 @@ pub use cfo::cfo; pub use cmo::cmo; pub use di::di; pub use dmi::dmi; -pub use dso::dso; pub use kst::kst; pub use macd::macd; -pub use pr::pr; +pub use qstick::qstick; +pub use rex::rex; pub use roc::roc; pub use rsi::rsi; -pub use sso::sso; pub use stc::stc; -use stoch::stoch; -pub use stochosc::stochosc; +pub use stochosc::{dso, sso, stochosc}; pub use tdfi::tdfi; pub use tii::tii; pub use trix::trix; pub use tsi::tsi; +pub use uo::uo; +pub use wpr::wpr; diff --git a/ta_lib/indicators/momentum/src/macd.rs b/ta_lib/indicators/momentum/src/macd.rs index 7e6c8d9e..da1eae1f 100644 --- a/ta_lib/indicators/momentum/src/macd.rs +++ b/ta_lib/indicators/momentum/src/macd.rs @@ -1,16 +1,15 @@ use core::prelude::*; pub fn macd( - source: &Series, - smooth_type: Smooth, - fast_period: usize, - slow_period: usize, - signal_period: usize, -) -> (Series, Series, Series) { - let macd_line = - source.smooth(smooth_type, fast_period) - source.smooth(smooth_type, slow_period); + source: &Price, + smooth: Smooth, + period_fast: Period, + period_slow: Period, + period_signal: Period, +) -> (Price, Price, Price) { + let macd_line = source.spread(smooth, period_fast, period_slow); - let signal_line = macd_line.smooth(smooth_type, signal_period); + let signal_line = macd_line.smooth(smooth, period_signal); let histogram = &macd_line - &signal_line; @@ -47,9 +46,9 @@ mod tests { signal_period, ); - let result_macd_line: Vec = macd_line.into(); - let result_signal_line: Vec = signal_line.into(); - let result_histogram: Vec = histogram.into(); + let result_macd_line: Vec = macd_line.into(); + let result_signal_line: Vec = signal_line.into(); + let result_histogram: Vec = histogram.into(); for i in 0..source.len() { assert!( diff --git a/ta_lib/indicators/trend/src/qstick.rs b/ta_lib/indicators/momentum/src/qstick.rs similarity index 61% rename from ta_lib/indicators/trend/src/qstick.rs rename to ta_lib/indicators/momentum/src/qstick.rs index 75c57600..d3fe4c58 100644 --- a/ta_lib/indicators/trend/src/qstick.rs +++ b/ta_lib/indicators/momentum/src/qstick.rs @@ -1,12 +1,7 @@ use core::prelude::*; -pub fn qstick( - open: &Series, - close: &Series, - smooth_type: Smooth, - period: usize, -) -> Series { - (close - open).smooth(smooth_type, period) +pub fn qstick(open: &Price, close: &Price, smooth: Smooth, period: Period) -> Price { + (close - open).smooth(smooth, period) } #[cfg(test)] @@ -20,7 +15,7 @@ mod tests { let period = 3; let expected = vec![-12.4655, -12.4627495, -12.4766245, -12.486312, -12.509655]; - let result: Vec = qstick(&open, &close, Smooth::EMA, period).into(); + let result: Vec = qstick(&open, &close, Smooth::EMA, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/rex.rs b/ta_lib/indicators/momentum/src/rex.rs new file mode 100644 index 00000000..fbad7af5 --- /dev/null +++ b/ta_lib/indicators/momentum/src/rex.rs @@ -0,0 +1,39 @@ +use core::prelude::*; + +pub fn rex( + source: &Price, + open: &Price, + high: &Price, + low: &Price, + smooth: Smooth, + period: Period, +) -> Price { + (3. * source - (open + high + low)).smooth(smooth, period) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rex() { + let source = Series::from([2.0310, 2.0282, 1.9937, 1.9795, 1.9632]); + let open = Series::from([2.0505, 2.0310, 2.0282, 1.9937, 1.9795]); + let high = Series::from([2.0507, 2.0310, 2.0299, 1.9977, 1.9824]); + let low = Series::from([2.0174, 2.0208, 1.9928, 1.9792, 1.9616]); + + let expected = vec![ + -0.025600433, + -0.011900425, + -0.040849924, + -0.036474824, + -0.035187542, + ]; + + let period = 3; + + let result: Vec = rex(&source, &open, &high, &low, Smooth::EMA, period).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/momentum/src/roc.rs b/ta_lib/indicators/momentum/src/roc.rs index 4efb81e8..37786c01 100644 --- a/ta_lib/indicators/momentum/src/roc.rs +++ b/ta_lib/indicators/momentum/src/roc.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn roc(source: &Series, period: usize) -> Series { +pub fn roc(source: &Price, period: Period) -> Price { SCALE * source.change(period) / source.shift(period) } @@ -14,7 +14,7 @@ mod tests { let period = 3; let expected = vec![0.0, 0.0, 0.0, 0.23304027, 0.36239228]; - let result: Vec = roc(&source, period).into(); + let result: Vec = roc(&source, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/rsi.rs b/ta_lib/indicators/momentum/src/rsi.rs index 37795457..27622b65 100644 --- a/ta_lib/indicators/momentum/src/rsi.rs +++ b/ta_lib/indicators/momentum/src/rsi.rs @@ -1,15 +1,19 @@ use core::prelude::*; -pub fn rsi(source: &Series, smooth_type: Smooth, period: usize) -> Series { - let len = source.len(); - +pub fn rsi(source: &Price, smooth: Smooth, period: Period) -> Price { let mom = source.change(1); - let up = mom.max(&ZERO).smooth(smooth_type, period); - let down = mom.min(&ZERO).negate().smooth(smooth_type, period); + let up = mom.max(&ZERO).smooth(smooth, period); + let down = mom.min(&ZERO).negate().smooth(smooth, period); + + let len = source.len(); - let oneh = Series::fill(SCALE, len); + let rsi = iff!( + down.seq(&ZERO), + Series::fill(SCALE, len), + SCALE - (SCALE / (1. + &up / down)) + ); - iff!(down.seq(&ZERO), oneh, SCALE - (SCALE / (1. + up / down))) + iff!(up.seq(&ZERO), Series::fill(ZERO, len), rsi) } #[cfg(test)] @@ -24,12 +28,12 @@ mod test { ]); let period = 3; let expected = [ - 100.0, 100.0, 100.0, 100.0, 46.885506, 66.75195, 50.889442, 65.60162, 73.53246, + 0.0, 100.0, 100.0, 100.0, 46.885506, 66.75195, 50.889442, 65.60162, 73.53246, 23.915344, 57.76078, 71.00006, 46.02974, 25.950226, 25.200401, 14.512299, 10.280083, 33.926575, 36.707954, 30.863396, 15.785042, 64.06485, ]; - let result: Vec = rsi(&source, Smooth::SMMA, period).into(); + let result: Vec = rsi(&source, Smooth::SMMA, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/sso.rs b/ta_lib/indicators/momentum/src/sso.rs deleted file mode 100644 index 8d98c634..00000000 --- a/ta_lib/indicators/momentum/src/sso.rs +++ /dev/null @@ -1,45 +0,0 @@ -use crate::stoch; -use core::prelude::*; - -pub fn sso( - source: &Series, - high: &Series, - low: &Series, - smooth_type: Smooth, - k_period: usize, - d_period: usize, -) -> (Series, Series) { - let high_smooth = high.smooth(smooth_type, k_period); - let low_smooth = low.smooth(smooth_type, k_period); - let source = source.smooth(smooth_type, k_period); - - let k = stoch(&source, &high_smooth, &low_smooth, k_period); - let d = k.smooth(smooth_type, d_period); - - (k, d) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_sso() { - let high = Series::from([3.0, 3.0, 3.0, 3.0, 3.0]); - let low = Series::from([1.0, 1.0, 1.0, 1.0, 1.0]); - let close = Series::from([2.0, 2.5, 2.0, 1.5, 2.0]); - let k_period = 3; - let d_period = 3; - - let expected_k = vec![0.0, 0.0, 58.333336, 41.666668, 41.666668]; - let expected_d = vec![0.0, 0.0, 0.0, 0.0, 44.444447]; - - let (k, d) = sso(&close, &high, &low, Smooth::WMA, k_period, d_period); - - let result_k: Vec = k.into(); - let result_d: Vec = d.into(); - - assert_eq!(result_k, expected_k); - assert_eq!(result_d, expected_d); - } -} diff --git a/ta_lib/indicators/momentum/src/stc.rs b/ta_lib/indicators/momentum/src/stc.rs index df74baba..d05e3571 100644 --- a/ta_lib/indicators/momentum/src/stc.rs +++ b/ta_lib/indicators/momentum/src/stc.rs @@ -1,24 +1,21 @@ -use crate::stoch; use core::prelude::*; pub fn stc( - source: &Series, - smooth_type: Smooth, - fast_period: usize, - slow_period: usize, - cycle: usize, - d_first: usize, - d_second: usize, -) -> Series { - let macd_line = - source.smooth(smooth_type, fast_period) - source.smooth(smooth_type, slow_period); - let k = stoch(&macd_line, &macd_line, &macd_line, cycle); - let d = k.smooth(smooth_type, d_first); - let kd = stoch(&d, &d, &d, cycle); - - let stc = kd.smooth(smooth_type, d_second); - - stc.min(&SCALE).max(&ZERO) + source: &Price, + smooth: Smooth, + period_fast: Period, + period_slow: Period, + cycle: Period, + d_first: Period, + d_second: Period, +) -> Price { + source + .spread(smooth, period_fast, period_slow) + .normalize(cycle, SCALE) + .smooth(smooth, d_first) + .normalize(cycle, SCALE) + .smooth(smooth, d_second) + .clip(&ZERO, &SCALE) } #[cfg(test)] @@ -41,7 +38,7 @@ mod tests { 67.08984, 83.54492, ]; - let result: Vec = stc( + let result: Vec = stc( &source, Smooth::EMA, fast_period, diff --git a/ta_lib/indicators/momentum/src/stoch.rs b/ta_lib/indicators/momentum/src/stoch.rs deleted file mode 100644 index a990031e..00000000 --- a/ta_lib/indicators/momentum/src/stoch.rs +++ /dev/null @@ -1,32 +0,0 @@ -use core::prelude::*; - -pub fn stoch( - source: &Series, - high: &Series, - low: &Series, - period: usize, -) -> Series { - let hh = high.highest(period); - let ll = low.lowest(period); - - SCALE * (source - &ll) / (&hh - &ll) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_stoch() { - let high = Series::from([3.0, 3.0, 3.0, 3.0, 3.0]); - let low = Series::from([1.0, 1.0, 1.0, 1.0, 1.0]); - let source = Series::from([2.0, 2.5, 2.0, 1.5, 2.0]); - let period = 3; - - let expected = vec![50.0, 75.0, 50.0, 25.0, 50.0]; - - let result: Vec = stoch(&source, &high, &low, period).into(); - - assert_eq!(result, expected); - } -} diff --git a/ta_lib/indicators/momentum/src/stochosc.rs b/ta_lib/indicators/momentum/src/stochosc.rs index 4f581476..2d9ca823 100644 --- a/ta_lib/indicators/momentum/src/stochosc.rs +++ b/ta_lib/indicators/momentum/src/stochosc.rs @@ -1,20 +1,56 @@ -use crate::stoch; use core::prelude::*; +pub fn stoch(source: &Price, high: &Price, low: &Price, period: Period) -> Price { + let hh = high.highest(period); + let ll = low.lowest(period); + + SCALE * (source - &ll) / (hh - ll) +} + pub fn stochosc( - source: &Series, - high: &Series, - low: &Series, - smooth_type: Smooth, - period: usize, - k_period: usize, - d_period: usize, -) -> (Series, Series) { - let stoch = stoch(source, high, low, period); + source: &Price, + high: &Price, + low: &Price, + smooth: Smooth, + period: Period, + period_k: Period, + period_d: Period, +) -> (Price, Price) { + let k = stoch(source, high, low, period).smooth(smooth, period_k); + let d = k.smooth(smooth, period_d); + + (k, d) +} + +pub fn sso( + source: &Price, + high: &Price, + low: &Price, + smooth: Smooth, + period_k: Period, + period_d: Period, +) -> (Price, Price) { + let high_smooth = high.smooth(smooth, period_k); + let low_smooth = low.smooth(smooth, period_k); + let source = source.smooth(smooth, period_k); + + let k = stoch(&source, &high_smooth, &low_smooth, period_k); + let d = k.smooth(smooth, period_d); + + (k, d) +} - let k = stoch.smooth(smooth_type, k_period); +pub fn dso( + source: &Price, + smooth: Smooth, + period: Period, + period_k: Period, + period_d: Period, +) -> (Price, Price) { + let source = source.smooth(smooth, period_k); - let d = k.smooth(smooth_type, d_period); + let k = source.normalize(period, SCALE).smooth(smooth, period_k); + let d = k.smooth(smooth, period_d); (k, d) } @@ -23,6 +59,20 @@ pub fn stochosc( mod tests { use super::*; + #[test] + fn test_stoch() { + let high = Series::from([3.0, 3.0, 3.0, 3.0, 3.0]); + let low = Series::from([1.0, 1.0, 1.0, 1.0, 1.0]); + let source = Series::from([2.0, 2.5, 2.0, 1.5, 2.0]); + let period = 3; + + let expected = vec![50.0, 75.0, 50.0, 25.0, 50.0]; + + let result: Vec = stoch(&source, &high, &low, period).into(); + + assert_eq!(result, expected); + } + #[test] fn test_stochosc() { let high = Series::from([3.0, 3.0, 3.0, 3.0, 3.0]); @@ -38,8 +88,8 @@ mod tests { let (k, d) = stochosc(&close, &high, &low, Smooth::SMA, period, k_period, d_period); - let result_k: Vec = k.into(); - let result_d: Vec = d.into(); + let result_k: Vec = k.into(); + let result_d: Vec = d.into(); for i in 0..result_k.len() { assert!( @@ -58,4 +108,47 @@ mod tests { ); } } + + #[test] + fn test_sso() { + let high = Series::from([3.0, 3.0, 3.0, 3.0, 3.0]); + let low = Series::from([1.0, 1.0, 1.0, 1.0, 1.0]); + let close = Series::from([2.0, 2.5, 2.0, 1.5, 2.0]); + let k_period = 3; + let d_period = 3; + + let expected_k = vec![50.0, 66.666664, 58.333336, 41.666668, 41.666668]; + let expected_d = vec![50.0, 61.11111, 59.722218, 51.38889, 44.444447]; + + let (k, d) = sso(&close, &high, &low, Smooth::WMA, k_period, d_period); + + let result_k: Vec = k.into(); + let result_d: Vec = d.into(); + + assert_eq!(result_k, expected_k); + assert_eq!(result_d, expected_d); + } + + #[test] + fn test_dso() { + let close = Series::from([4.9112, 4.9140, 4.9135, 4.9151, 4.9233, 4.9313, 4.9357]); + let period = 3; + let k_period = 2; + let d_period = 2; + + let expected_k = vec![ + 0.0, 66.66667, 88.88889, 96.2963, 98.76544, 99.588486, 99.86283, + ]; + let expected_d = vec![ + 0.0, 44.44445, 74.07408, 88.8889, 95.47326, 98.21674, 99.31413, + ]; + + let (k, d) = dso(&close, Smooth::EMA, period, k_period, d_period); + + let result_k: Vec = k.into(); + let result_d: Vec = d.into(); + + assert_eq!(result_k, expected_k); + assert_eq!(result_d, expected_d); + } } diff --git a/ta_lib/indicators/momentum/src/tdfi.rs b/ta_lib/indicators/momentum/src/tdfi.rs index ddfddee5..5a341230 100644 --- a/ta_lib/indicators/momentum/src/tdfi.rs +++ b/ta_lib/indicators/momentum/src/tdfi.rs @@ -1,10 +1,10 @@ use core::prelude::*; -pub fn tdfi(source: &Series, smooth_type: Smooth, period: usize, n: usize) -> Series { - let ma = (SCALE * 10. * source).smooth(smooth_type, period); - let sma = ma.smooth(smooth_type, period); +pub fn tdfi(source: &Price, smooth: Smooth, period: Period, n: usize) -> Price { + let ma = (SCALE * 10. * source).smooth(smooth, period); + let sma = ma.smooth(smooth, period); - let tdf = (&ma - &sma).abs().pow(1) * (0.5 * (ma.change(1) + sma.change(1))).pow(n); + let tdf = ((&ma - &sma).abs().pow(1)) * ((HALF * (ma.change(1) + sma.change(1))).pow(n)); &tdf / tdf.abs().highest(period * n) } @@ -41,7 +41,7 @@ mod tests { 0.0041924296, ]; - let result: Vec = tdfi(&source, smooth_type, period, n).into(); + let result: Vec = tdfi(&source, smooth_type, period, n).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/tii.rs b/ta_lib/indicators/momentum/src/tii.rs index 6b739ccb..7d8d5146 100644 --- a/ta_lib/indicators/momentum/src/tii.rs +++ b/ta_lib/indicators/momentum/src/tii.rs @@ -1,18 +1,10 @@ use core::prelude::*; -pub fn tii( - source: &Series, - smooth_type: Smooth, - major_period: usize, - minor_period: usize, -) -> Series { - let price_diff = source - source.smooth(smooth_type, major_period); +pub fn tii(source: &Price, smooth: Smooth, period_major: Period, period_minor: Period) -> Price { + let price_diff = source - source.smooth(smooth, period_major); - let positive_sum = price_diff.max(&ZERO).smooth(smooth_type, minor_period); - let negative_sum = price_diff - .min(&ZERO) - .abs() - .smooth(smooth_type, minor_period); + let positive_sum = price_diff.max(&ZERO).smooth(smooth, period_minor); + let negative_sum = price_diff.min(&ZERO).abs().smooth(smooth, period_minor); SCALE * &positive_sum / (positive_sum + negative_sum) } @@ -34,7 +26,7 @@ mod tests { 100.0, 4.648687, 48.748272, ]; - let result: Vec = tii(&source, Smooth::SMA, major_period, minor_period).into(); + let result: Vec = tii(&source, Smooth::SMA, major_period, minor_period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/trix.rs b/ta_lib/indicators/momentum/src/trix.rs index 04d3cc56..65abb6e8 100644 --- a/ta_lib/indicators/momentum/src/trix.rs +++ b/ta_lib/indicators/momentum/src/trix.rs @@ -1,11 +1,11 @@ use crate::roc; use core::prelude::*; -pub fn trix(source: &Series, smooth_type: Smooth, period: usize) -> Series { +pub fn trix(source: &Price, smooth: Smooth, period: Period) -> Price { let ema3 = source - .smooth(smooth_type, period) - .smooth(smooth_type, period) - .smooth(smooth_type, period); + .smooth(smooth, period) + .smooth(smooth, period) + .smooth(smooth, period); SCALE * roc(&ema3, 1) } @@ -26,7 +26,7 @@ mod tests { -5.0126247, -6.273622, -5.246739, -3.7003598, -2.672774, -2.9872787, -0.9277002, ]; - let result: Vec = trix(&source, Smooth::EMA, period).into(); + let result: Vec = trix(&source, Smooth::EMA, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/tsi.rs b/ta_lib/indicators/momentum/src/tsi.rs index 361e319a..0fbc5496 100644 --- a/ta_lib/indicators/momentum/src/tsi.rs +++ b/ta_lib/indicators/momentum/src/tsi.rs @@ -1,20 +1,13 @@ use core::prelude::*; -pub fn tsi( - source: &Series, - smooth_type: Smooth, - slow_period: usize, - fast_period: usize, -) -> Series { +pub fn tsi(source: &Price, smooth: Smooth, period_slow: Period, period_fast: Period) -> Price { let pc = source.change(1); - let pcds = pc - .smooth(smooth_type, slow_period) - .smooth(smooth_type, fast_period); + let pcds = pc.smooth(smooth, period_slow).smooth(smooth, period_fast); let apcds = pc .abs() - .smooth(smooth_type, slow_period) - .smooth(smooth_type, fast_period); + .smooth(smooth, period_slow) + .smooth(smooth, period_fast); SCALE * pcds / apcds } @@ -49,7 +42,7 @@ mod tests { 27.16367, ]; - let result: Vec = tsi(&source, Smooth::EMA, slow_period, fast_period).into(); + let result: Vec = tsi(&source, Smooth::EMA, slow_period, fast_period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/momentum/src/uo.rs b/ta_lib/indicators/momentum/src/uo.rs new file mode 100644 index 00000000..8d6a1bc9 --- /dev/null +++ b/ta_lib/indicators/momentum/src/uo.rs @@ -0,0 +1,69 @@ +use core::prelude::*; + +pub fn uo( + source: &Price, + high: &Price, + low: &Price, + period_fast: Period, + period_medium: Period, + period_slow: Period, +) -> Price { + let prev_source = source.shift(1).nz(Some(ZERO)); + + let high = prev_source.max(high); + let low = prev_source.min(low); + + let bp = source - &low; + let tr = high - &low; + + SCALE + * (4. * bp.sum(period_fast) / tr.sum(period_fast) + + 2. * bp.sum(period_medium) / tr.sum(period_medium) + + bp.sum(period_slow) / tr.sum(period_slow)) + / 7. +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_uo() { + let source = Series::from([ + 6.6430, 6.8595, 6.8680, 6.8650, 6.8445, 6.8560, 6.8565, 6.8590, 6.8530, 6.8575, 6.855, + 6.858, 6.86, 6.8480, 6.8575, 6.864, 6.8565, 6.8455, 6.8450, 6.8365, 6.8310, 6.8355, + 6.8360, 6.8345, 6.8285, 6.8395, + ]); + let high = Series::from([ + 6.8430, 6.8660, 6.8685, 6.8690, 6.865, 6.8595, 6.8565, 6.862, 6.859, 6.86, 6.8580, + 6.8605, 6.8620, 6.86, 6.859, 6.8670, 6.8640, 6.8575, 6.8485, 6.8450, 6.8365, 6.84, + 6.8385, 6.8365, 6.8345, 6.8395, + ]); + let low = Series::from([ + 6.8380, 6.8430, 6.8595, 6.8640, 6.8435, 6.8445, 6.8510, 6.8560, 6.8520, 6.8530, 6.8550, + 6.8550, 6.8565, 6.8475, 6.8480, 6.8535, 6.8565, 6.8455, 6.8445, 6.8365, 6.8310, 6.8310, + 6.8345, 6.8325, 6.8275, 6.8285, + ]); + let expected = vec![ + 97.07731, 97.07755, 97.02184, 79.879684, 25.307936, 35.087112, 65.40703, 70.05135, + 41.210888, 42.434715, 40.23762, 39.388878, 54.514324, 24.983133, 44.44906, 70.54365, + 53.009586, 15.58435, 6.6393533, 3.0904624, 1.0311968, 25.968098, 38.813465, 43.52247, + 29.511023, 64.793076, + ]; + let period_fast = 2; + let period_medium = 3; + let period_slow = 4; + + let result: Vec = uo( + &source, + &high, + &low, + period_fast, + period_medium, + period_slow, + ) + .into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/momentum/src/pr.rs b/ta_lib/indicators/momentum/src/wpr.rs similarity index 67% rename from ta_lib/indicators/momentum/src/pr.rs rename to ta_lib/indicators/momentum/src/wpr.rs index 333afb4e..cd08a3d6 100644 --- a/ta_lib/indicators/momentum/src/pr.rs +++ b/ta_lib/indicators/momentum/src/wpr.rs @@ -1,15 +1,10 @@ use core::prelude::*; -pub fn pr( - source: &Series, - high: &Series, - low: &Series, - period: usize, -) -> Series { +pub fn wpr(source: &Price, high: &Price, low: &Price, period: Period) -> Price { let hh = high.highest(period); let ll = low.lowest(period); - SCALE * (source - &hh) / (&hh - &ll) + SCALE * (source - &hh) / (hh - ll) } #[cfg(test)] @@ -17,7 +12,7 @@ mod tests { use super::*; #[test] - fn test_pr() { + fn test_wpr() { let high = Series::from([6.748, 6.838, 6.804, 6.782, 6.786]); let low = Series::from([6.655, 6.728, 6.729, 6.718, 6.732]); let source = Series::from([6.738, 6.780, 6.751, 6.766, 6.735]); @@ -25,7 +20,7 @@ mod tests { let expected = vec![-10.752942, -31.693844, -47.541027, -60.00008, -80.23232]; - let result: Vec = pr(&source, &high, &low, period).into(); + let result: Vec = wpr(&source, &high, &low, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/alma.rs b/ta_lib/indicators/trend/src/alma.rs index ec6d5fd6..1c4aef7e 100644 --- a/ta_lib/indicators/trend/src/alma.rs +++ b/ta_lib/indicators/trend/src/alma.rs @@ -1,11 +1,12 @@ use core::prelude::*; -pub fn alma(source: &Series, period: usize, offset: f32, sigma: f32) -> Series { - let m = offset * (period as f32 - 1.); - let s = period as f32 / sigma; +pub fn alma(source: &Price, period: Period, offset: Scalar, sigma: Scalar) -> Price { + let m = (offset * (period - 1) as Scalar).floor(); + let s = period as Scalar / sigma; let weights = (0..period) - .map(|i| ((-1. * (i as f32 - m).powi(2)) / (2. * s.powi(2))).exp()) + .rev() + .map(|i| (-1. * (i as Scalar - m).powi(2) / (2. * s.powi(2))).exp()) .collect::>(); source.wg(&weights) @@ -21,27 +22,18 @@ mod tests { 0.01707, 0.01706, 0.01707, 0.01705, 0.01710, 0.01705, 0.01704, 0.01709, ]); let expected = [ - 0.0, - 0.0, - 0.017066907, - 0.01705621, - 0.017084463, - 0.017065462, - 0.017043246, - 0.017074436, + 0.01707, + 0.017067866, + 0.01706213, + 0.017066803, + 0.017057454, + 0.01708935, + 0.01705426, + 0.01704639, ]; - let epsilon = 0.001; - let result: Vec = alma(&source, 3, 0.85, 6.0).into(); + let result: Vec = alma(&source, 3, 0.85, 6.0).into(); - for i in 0..source.len() { - assert!( - (result[i] - expected[i]).abs() < epsilon, - "at position {}: {} != {}", - i, - result[i], - expected[i] - ); - } + assert_eq!(result, expected); } } diff --git a/ta_lib/indicators/trend/src/ast.rs b/ta_lib/indicators/trend/src/ast.rs index 2ecfec17..e32ce235 100644 --- a/ta_lib/indicators/trend/src/ast.rs +++ b/ta_lib/indicators/trend/src/ast.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn ast(close: &Series, atr: &Series, factor: f32) -> (Series, Series) { +pub fn ast(close: &Price, atr: &Price, factor: Scalar) -> (Price, Price) { let atr_multi = atr * factor; let len = close.len(); @@ -82,8 +82,8 @@ mod tests { ]; let (direction, trend) = ast(&close, &atr, factor); - let result_direction: Vec = direction.into(); - let result_trend: Vec = trend.into(); + let result_direction: Vec = direction.into(); + let result_trend: Vec = trend.into(); assert_eq!(high.len(), low.len()); assert_eq!(high.len(), close.len()); diff --git a/ta_lib/indicators/trend/src/cama.rs b/ta_lib/indicators/trend/src/cama.rs index 224b23f6..d7cd40dc 100644 --- a/ta_lib/indicators/trend/src/cama.rs +++ b/ta_lib/indicators/trend/src/cama.rs @@ -1,12 +1,6 @@ use core::prelude::*; -pub fn cama( - source: &Series, - high: &Series, - low: &Series, - tr: &Series, - period: usize, -) -> Series { +pub fn cama(source: &Price, high: &Price, low: &Price, tr: &Price, period: Period) -> Price { let hh = high.highest(period); let ll = low.lowest(period); @@ -18,7 +12,7 @@ pub fn cama( #[cfg(test)] mod tests { use super::*; - use volatility::tr; + use volatility::wtr; #[test] fn test_cama() { @@ -52,9 +46,9 @@ mod tests { 7.1226425, 7.1630764, 7.157433, 7.156123, ]; let period = 2; - let tr = tr(&high, &low, &close); + let tr = wtr(&high, &low, &close); - let result: Vec = cama(&close, &high, &low, &tr, period).into(); + let result: Vec = cama(&close, &high, &low, &tr, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/ce.rs b/ta_lib/indicators/trend/src/ce.rs index 83542095..e55e2340 100644 --- a/ta_lib/indicators/trend/src/ce.rs +++ b/ta_lib/indicators/trend/src/ce.rs @@ -1,45 +1,38 @@ use core::prelude::*; -pub fn ce( - close: &Series, - atr: &Series, - period: usize, - factor: f32, -) -> (Series, Series) { +pub fn ce(close: &Price, atr: &Price, period: Period, factor: Scalar) -> (Price, Price) { let len = close.len(); let atr_mul = atr * factor; - let basic_up = close.highest(period) - &atr_mul; - let mut up = Series::empty(len); - - let basic_dn = close.lowest(period) + &atr_mul; - let mut dn = Series::empty(len); + let mut up = close.highest(period) - &atr_mul; + let mut dn = close.lowest(period) + &atr_mul; let prev_close = close.shift(1); - let mut direction = Series::empty(len); + let mut direction = Series::one(len); let trend_bull = Series::one(len); let trend_bear = trend_bull.negate(); for _ in 0..len { - let prev_up = up.shift(1); - up = iff!(prev_close.sgt(&prev_up), basic_up.max(&prev_up), basic_up); + let prev_up = nz!(up.shift(1), up); + up = iff!(prev_close.sgt(&prev_up), up.max(&prev_up), up); - let prev_dn = dn.shift(1); - dn = iff!(prev_close.slt(&prev_dn), basic_dn.min(&prev_dn), basic_dn); + let prev_dn = nz!(dn.shift(1), dn); + dn = iff!(prev_close.slt(&prev_dn), dn.min(&prev_dn), dn); direction = nz!(direction.shift(1), direction); - direction = iff!(close.sgte(&prev_dn), trend_bull, direction); - direction = iff!(close.slt(&prev_up), trend_bear, direction); + direction = iff!( + direction.seq(&MINUS_ONE) & close.sgt(&prev_dn), + trend_bull, + direction + ); + direction = iff!( + direction.seq(&ONE) & close.slt(&prev_up), + trend_bear, + direction + ); } - let first_non_empty = direction - .iter() - .find(|&&el| el.is_some()) - .map(|&el| -el.unwrap()); - - direction = direction.nz(first_non_empty); - let trend = iff!(direction.seq(&ONE), up, dn); (direction, trend) @@ -69,16 +62,16 @@ mod tests { let factor = 2.0; let expected_trend = vec![ - 4.8592, 4.8566666, 4.8563113, 4.8435073, 4.8271384, 4.8271384, 4.8271384, 4.8271384, - 4.8271384, 4.784503, 4.815269, 4.815269, 4.81972, 4.81972, 4.81972, + 4.7652, 4.7667336, 4.767089, 4.767089, 4.767089, 4.767089, 4.767089, 4.767089, + 4.767089, 4.784503, 4.815269, 4.815269, 4.81972, 4.81972, 4.81972, ]; let expected_direction = vec![ - -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, ]; let (direction, trend) = ce(&close, &atr, period, factor); - let result_direction: Vec = direction.into(); - let result_trend: Vec = trend.into(); + let result_direction: Vec = direction.into(); + let result_trend: Vec = trend.into(); assert_eq!(high.len(), low.len()); assert_eq!(high.len(), close.len()); @@ -105,16 +98,16 @@ mod tests { let factor = 2.0; let expected_trend = vec![ - 4.946201, 4.9390006, 4.9390006, 4.9390006, 4.8700404, 4.8700404, 4.8700404, 4.8700404, + 4.7653995, 4.7874, 4.8232665, 4.831711, 4.8700404, 4.8700404, 4.8700404, 4.8700404, 4.959112, 4.949508, 4.9344387, 4.912726, ]; let expected_direction = vec![ - -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, -1.0, -1.0, -1.0, -1.0, ]; let (direction, trend) = ce(&close, &atr, period, factor); - let result_direction: Vec = direction.into(); - let result_trend: Vec = trend.into(); + let result_direction: Vec = direction.into(); + let result_trend: Vec = trend.into(); assert_eq!(high.len(), low.len()); assert_eq!(high.len(), close.len()); diff --git a/ta_lib/indicators/trend/src/chop.rs b/ta_lib/indicators/trend/src/chop.rs index 6533b991..32528343 100644 --- a/ta_lib/indicators/trend/src/chop.rs +++ b/ta_lib/indicators/trend/src/chop.rs @@ -1,13 +1,8 @@ use core::prelude::*; -pub fn chop( - high: &Series, - low: &Series, - atr: &Series, - period: usize, -) -> Series { +pub fn chop(high: &Price, low: &Price, atr: &Price, period: Period) -> Price { SCALE * (atr.sum(period) / (high.highest(period) - low.lowest(period))).log10() - / (period as f32).log10() + / (period as Scalar).log10() } #[cfg(test)] @@ -26,7 +21,7 @@ mod tests { let expected = [0.0, 45.571022, 0.0, 26.31491, 40.33963, 58.496246]; let epsilon = 0.0001; - let result: Vec = chop(&high, &low, &atr, period).into(); + let result: Vec = chop(&high, &low, &atr, period).into(); for i in 0..result.len() { assert!( diff --git a/ta_lib/indicators/trend/src/dema.rs b/ta_lib/indicators/trend/src/dema.rs index 475fc3f5..b05cc713 100644 --- a/ta_lib/indicators/trend/src/dema.rs +++ b/ta_lib/indicators/trend/src/dema.rs @@ -1,9 +1,7 @@ use core::prelude::*; -pub fn dema(source: &Series, period: usize) -> Series { - let ema = source.smooth(Smooth::EMA, period); - - 2. * &ema - ema.smooth(Smooth::EMA, period) +pub fn dema(source: &Price, period: Period) -> Price { + source.smooth(Smooth::DEMA, period) } #[cfg(test)] diff --git a/ta_lib/indicators/trend/src/dpo.rs b/ta_lib/indicators/trend/src/dpo.rs index dddcf934..587a092b 100644 --- a/ta_lib/indicators/trend/src/dpo.rs +++ b/ta_lib/indicators/trend/src/dpo.rs @@ -1,9 +1,9 @@ use core::prelude::*; -pub fn dpo(source: &Series, smooth_type: Smooth, period: usize) -> Series { +pub fn dpo(source: &Price, smooth: Smooth, period: Period) -> Price { let k = period / 2 + 1; - source - source.smooth(smooth_type, period).shift(k) + source - source.smooth(smooth, period).shift(k) } #[cfg(test)] @@ -24,7 +24,7 @@ mod tests { 0.0022332668, ]; - let result: Vec = dpo(&source, Smooth::SMA, period).into(); + let result: Vec = dpo(&source, Smooth::SMA, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/ema.rs b/ta_lib/indicators/trend/src/ema.rs index 16de440b..b20b0c63 100644 --- a/ta_lib/indicators/trend/src/ema.rs +++ b/ta_lib/indicators/trend/src/ema.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn ema(source: &Series, period: usize) -> Series { +pub fn ema(source: &Price, period: Period) -> Price { source.smooth(Smooth::EMA, period) } @@ -19,7 +19,7 @@ mod tests { 6.5332203, 6.5316105, 6.5384054, 6.531903, 6.5178514, 6.489626, 6.487513, 6.492057, ]; - let result: Vec = ema(&source, 3).into(); + let result: Vec = ema(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/frama.rs b/ta_lib/indicators/trend/src/frama.rs index 723df8d0..20d39e4f 100644 --- a/ta_lib/indicators/trend/src/frama.rs +++ b/ta_lib/indicators/trend/src/frama.rs @@ -1,18 +1,13 @@ use core::prelude::*; -pub fn frama( - source: &Series, - high: &Series, - low: &Series, - period: usize, -) -> Series { - let len = (0.5 * period as f32).round() as usize; +pub fn frama(source: &Price, high: &Price, low: &Price, period: Period) -> Price { + let len = (HALF * period as Scalar).floor() as Period; let hh = high.highest(len); let ll = low.lowest(len); - let n1 = (&hh - &ll) / len as f32; - let n2 = (hh.shift(len) - ll.shift(len)) / len as f32; - let n3 = (high.highest(period) - low.lowest(period)) / period as f32; + let n1 = (&hh - &ll) / len as Scalar; + let n2 = (hh.shift(len) - ll.shift(len)) / len as Scalar; + let n3 = (high.highest(period) - low.lowest(period)) / period as Scalar; let d = ((n1 + n2) / n3).log() / 2.0_f32.ln(); @@ -40,11 +35,11 @@ mod tests { 5.2000, 5.2169, ]); let expected = vec![ - 0.0, 0.0, 5.0958, 5.09975, 5.0997515, 5.100709, 5.10147, 5.1022835, 5.130958, - 5.2076283, 5.172987, 5.1907825, 5.2242994, + 0.0, 5.0896997, 5.090174, 5.0918617, 5.0919185, 5.092221, 5.092286, 5.092319, 5.094049, + 5.105295, 5.122919, 5.1347446, 5.152087, ]; - let result: Vec = frama(&source, &high, &low, 3).into(); + let result: Vec = frama(&source, &high, &low, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/gma.rs b/ta_lib/indicators/trend/src/gma.rs index 9fbc4711..17ee709d 100644 --- a/ta_lib/indicators/trend/src/gma.rs +++ b/ta_lib/indicators/trend/src/gma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn gma(source: &Series, period: usize) -> Series { +pub fn gma(source: &Price, period: Period) -> Price { source.log().smooth(Smooth::SMA, period).exp() } diff --git a/ta_lib/indicators/trend/src/hema.rs b/ta_lib/indicators/trend/src/hema.rs index 2a1aea68..c1a365ad 100644 --- a/ta_lib/indicators/trend/src/hema.rs +++ b/ta_lib/indicators/trend/src/hema.rs @@ -1,7 +1,7 @@ use core::prelude::*; -pub fn hema(source: &Series, period: usize) -> Series { - let period = (period as f32 / 2.) as usize; +pub fn hema(source: &Price, period: usize) -> Price { + let period = (period as Scalar / 2.) as Period; 3. * source.smooth(Smooth::WMA, period) - 2. * source.smooth(Smooth::EMA, period) } @@ -13,10 +13,10 @@ mod tests { #[test] fn test_hema() { let source = Series::from([19.099, 19.079, 19.074, 19.139, 19.191]); - let expected = vec![0.0, 19.08567, 19.07122, 19.114738, 19.18724]; + let expected = vec![19.099003, 19.08567, 19.07122, 19.114738, 19.18724]; let period = 4; - let result: Vec = hema(&source, period).into(); + let result: Vec = hema(&source, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/hma.rs b/ta_lib/indicators/trend/src/hma.rs index c713ad9c..7b2fc915 100644 --- a/ta_lib/indicators/trend/src/hma.rs +++ b/ta_lib/indicators/trend/src/hma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn hma(source: &Series, period: usize) -> Series { +pub fn hma(source: &Price, period: Period) -> Price { source.smooth(Smooth::HMA, period) } @@ -11,9 +11,9 @@ mod tests { #[test] fn test_hma() { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let expected = vec![0.0, 0.0, 3.0000002, 4.0, 4.9999995]; + let expected = vec![1.0, 2.3333335, 3.6666667, 4.666667, 5.6666665]; - let result: Vec = hma(&source, 3).into(); + let result: Vec = hma(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/kama.rs b/ta_lib/indicators/trend/src/kama.rs index 01a670d2..3554110b 100644 --- a/ta_lib/indicators/trend/src/kama.rs +++ b/ta_lib/indicators/trend/src/kama.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn kama(source: &Series, period: usize) -> Series { +pub fn kama(source: &Price, period: Period) -> Price { source.smooth(Smooth::KAMA, period) } @@ -15,11 +15,11 @@ mod tests { 5.2000, 5.2169, ]); let expected = vec![ - 0.0, 0.0, 0.0, 5.1023, 5.1018033, 5.1023088, 5.102306, 5.104807, 5.1141686, 5.129071, - 5.1461506, 5.1700835, 5.190891, + 5.0788, 5.0788455, 5.078916, 5.0892878, 5.0915785, 5.0941925, 5.09429, 5.0988545, + 5.110461, 5.1269784, 5.144952, 5.1693687, 5.190451, ]; - let result: Vec = kama(&source, 3).into(); + let result: Vec = kama(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/lib.rs b/ta_lib/indicators/trend/src/lib.rs index cb419fa1..91643251 100644 --- a/ta_lib/indicators/trend/src/lib.rs +++ b/ta_lib/indicators/trend/src/lib.rs @@ -11,18 +11,20 @@ mod gma; mod hema; mod hma; mod kama; -mod kjs; mod lsma; mod md; -mod qstick; +mod midpoint; +mod pp; mod rmsma; mod sinwma; +mod slsma; mod sma; mod smma; mod supertrend; mod t3; mod tema; -mod tma; +mod trima; +mod ults; mod vi; mod vidya; mod vwema; @@ -46,18 +48,20 @@ pub use gma::gma; pub use hema::hema; pub use hma::hma; pub use kama::kama; -pub use kjs::kjs; pub use lsma::lsma; pub use md::md; -pub use qstick::qstick; +pub use midpoint::midpoint; +pub use pp::{cpp, dpp, fpp, pp, spp, wpp}; pub use rmsma::rmsma; pub use sinwma::sinwma; +pub use slsma::slsma; pub use sma::sma; pub use smma::smma; pub use supertrend::supertrend; pub use t3::t3; pub use tema::tema; -pub use tma::tma; +pub use trima::trima; +pub use ults::ults; pub use vi::vi; pub use vidya::vidya; pub use vwema::vwema; diff --git a/ta_lib/indicators/trend/src/lsma.rs b/ta_lib/indicators/trend/src/lsma.rs index 6b2b2729..93289057 100644 --- a/ta_lib/indicators/trend/src/lsma.rs +++ b/ta_lib/indicators/trend/src/lsma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn lsma(source: &Series, period: usize) -> Series { +pub fn lsma(source: &Price, period: Period) -> Price { source.smooth(Smooth::LSMA, period) } @@ -10,10 +10,15 @@ mod tests { #[test] fn test_lsma() { - let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let expected = vec![1.0, 2.0, 3.0000002, 4.0, 5.0]; + let source = Series::from([ + 12.529, 12.504, 12.517, 12.542, 12.547, 12.577, 12.539, 12.577, 12.490, 12.490, + ]); + let expected = vec![ + 12.529, 12.504, 12.510668, 12.54, 12.550336, 12.572831, 12.550328, 12.564322, + 12.510831, 12.475489, + ]; - let result: Vec = lsma(&source, 3).into(); + let result: Vec = lsma(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/md.rs b/ta_lib/indicators/trend/src/md.rs index a7aeea93..82225501 100644 --- a/ta_lib/indicators/trend/src/md.rs +++ b/ta_lib/indicators/trend/src/md.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn md(source: &Series, period: usize) -> Series { +pub fn md(source: &Price, period: Period) -> Price { let len = source.len(); let mut mg = Series::empty(len); @@ -8,7 +8,7 @@ pub fn md(source: &Series, period: usize) -> Series { for _ in 0..len { let prev_mg = nz!(mg.shift(1), seed); - mg = &prev_mg + (source - &prev_mg) / ((source / &prev_mg).pow(4) * period as f32); + mg = &prev_mg + (source - &prev_mg) / ((source / &prev_mg).pow(4) * period as Scalar); } mg @@ -28,7 +28,7 @@ mod tests { 19.576805, 19.576_204, ]; - let result: Vec = md(&source, 3).into(); + let result: Vec = md(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/kjs.rs b/ta_lib/indicators/trend/src/midpoint.rs similarity index 63% rename from ta_lib/indicators/trend/src/kjs.rs rename to ta_lib/indicators/trend/src/midpoint.rs index ec6b0779..6b5d5bf0 100644 --- a/ta_lib/indicators/trend/src/kjs.rs +++ b/ta_lib/indicators/trend/src/midpoint.rs @@ -1,7 +1,7 @@ use core::prelude::*; -pub fn kjs(high: &Series, low: &Series, period: usize) -> Series { - 0.5 * (low.lowest(period) + high.highest(period)) +pub fn midpoint(high: &Price, low: &Price, period: Period) -> Price { + HALF * (high.highest(period) + low.lowest(period)) } #[cfg(test)] @@ -9,12 +9,12 @@ mod tests { use super::*; #[test] - fn test_kjs() { + fn test_midpoint() { let high = Series::from([2.0859, 2.0881, 2.0889, 2.0896, 2.0896, 2.0907]); let low = Series::from([2.0846, 2.0846, 2.0881, 2.0886, 2.0865, 2.0875]); let expected = vec![2.08525, 2.08635, 2.08675, 2.0871, 2.08805, 2.0886]; - let result: Vec = kjs(&high, &low, 3).into(); + let result: Vec = midpoint(&high, &low, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/pp.rs b/ta_lib/indicators/trend/src/pp.rs new file mode 100644 index 00000000..82d826fe --- /dev/null +++ b/ta_lib/indicators/trend/src/pp.rs @@ -0,0 +1,270 @@ +use core::prelude::*; + +pub fn pp(high: &Price, low: &Price, close: &Price) -> (Price, Price) { + let pp = (high + low + close) / 3.; + + let support = 2. * &pp - high; + let resistance = 2. * &pp - low; + + (support, resistance) +} + +pub fn fpp(high: &Price, low: &Price, close: &Price) -> (Price, Price) { + let pp = (high + low + close) / 3.; + + let hl = 0.382 * (high - low); + + let support = &pp - &hl; + let resistance = &pp + &hl; + + (support, resistance) +} + +pub fn wpp(open: &Price, high: &Price, low: &Price) -> (Price, Price) { + let pp = (high + low + 2. * open) / 4.; + + let support = 2. * &pp - high; + let resistance = 2. * &pp - low; + + (support, resistance) +} + +pub fn cpp(high: &Price, low: &Price, close: &Price) -> (Price, Price) { + let hl = 1.1 * (high - low) / 12.; + + let support = close - &hl; + let resistance = close + &hl; + + (support, resistance) +} + +pub fn dpp(open: &Price, high: &Price, low: &Price, close: &Price) -> (Price, Price) { + let mut pp = iff!( + close.sgt(open), + 2. * high + low + close, + high + low + 2. * close + ); + pp = iff!(close.slt(open), high + 2. * low + close, pp); + + let support = 0.5 * &pp - high; + let resistance = 0.5 * &pp - low; + + (support, resistance) +} + +pub fn spp( + high: &Price, + low: &Price, + close: &Price, + smooth: Smooth, + period: Period, +) -> (Price, Price) { + let hh = high.highest(period); + let ll = low.lowest(period); + let close = close.smooth(smooth, period); + + let pp = (&hh + &ll + close) / 3.; + + let support = 2. * pp.lowest(period) - hh; + let resistance = 2. * pp.highest(period) - ll; + + (support, resistance) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_pivot_points() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let expected_support = vec![ + 6.5439324, 6.5541005, 6.5118666, 6.4641, 6.5020003, 6.526467, 6.5022664, 6.5202003, + 6.5273, 6.5162005, 6.4905, 6.4427676, 6.469, 6.4663997, 6.4966, 6.5126, + ]; + let expected_resistance = vec![ + 6.5621324, 6.6196003, 6.575967, 6.54, 6.5438004, 6.583367, 6.5526667, 6.538, 6.5564, + 6.5413003, 6.5248, 6.4932675, 6.4941998, 6.5401993, 6.5205, 6.5354, + ]; + + let (support, resistance) = pp(&high, &low, &close); + let result_support: Vec = support.into(); + let result_resistance: Vec = resistance.into(); + + assert_eq!(result_support, expected_support); + assert_eq!(result_resistance, expected_resistance); + } + + #[test] + fn test_fibonacci_pivot_points() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let expected_support = vec![ + 6.545014, 6.554479, 6.5285473, 6.4801064, 6.5000324, 6.526498, 6.5133805, 6.5214005, + 6.527384, 6.522512, 6.4983974, 6.455343, 6.4681735, 6.481808, 6.49627, 6.5122905, + ]; + let expected_resistance = vec![ + 6.5589185, 6.6045213, 6.5775194, 6.5380936, 6.531968, 6.569969, 6.551886, 6.535, + 6.549616, 6.5416884, 6.5246024, 6.4939246, 6.4874263, 6.5381913, 6.51453, 6.5297093, + ]; + + let (support, resistance) = fpp(&high, &low, &close); + let result_support: Vec = support.into(); + let result_resistance: Vec = resistance.into(); + + assert_eq!(result_support, expected_support); + assert_eq!(result_resistance, expected_resistance); + } + + #[test] + fn test_woodie_pivot_points() { + let open = Series::from([ + 6.5541, 6.5942, 6.5345, 6.4950, 6.5298, 6.5619, 6.5223, 6.5300, 6.5451, 6.5254, 6.5038, + 6.4614, 6.4853, 6.4966, 6.5117, 6.5272, + ]); + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let expected_support = vec![ + 6.5449996, 6.5614505, 6.50245, 6.4570503, 6.5089, 6.5334506, 6.4970994, 6.5211005, + 6.53055, 6.5128503, 6.48665, 6.43615, 6.4727, 6.4597, 6.49975, 6.5158005, + ]; + let expected_resistance = vec![ + 6.5631995, 6.6269503, 6.5665503, 6.5329504, 6.5507, 6.5903506, 6.5474997, 6.5389004, + 6.55965, 6.53795, 6.52095, 6.48665, 6.4979, 6.5334997, 6.52365, 6.5386004, + ]; + + let (support, resistance) = wpp(&open, &high, &low); + let result_support: Vec = support.into(); + let result_resistance: Vec = resistance.into(); + + assert_eq!(result_support, expected_support); + assert_eq!(result_resistance, expected_resistance); + } + + #[test] + fn test_camarilla_pivot_points() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let expected_support = vec![ + 6.5524316, 6.588196, 6.528924, 6.4880424, 6.525968, 6.5563846, 6.5176797, 6.5283685, + 6.5425324, 6.5230994, 6.5006557, 6.456771, 6.4830904, 6.4898353, 6.509509, 6.52491, + ]; + let expected_resistance = vec![ + 6.5557685, 6.6002045, 6.540676, 6.5019574, 6.533632, 6.566816, 6.52692, 6.531632, + 6.5478673, 6.527701, 6.506944, 6.466029, 6.48771, 6.503365, 6.513891, 6.52909, + ]; + + let (support, resistance) = cpp(&high, &low, &close); + let result_support: Vec = support.into(); + let result_resistance: Vec = resistance.into(); + + assert_eq!(result_support, expected_support); + assert_eq!(result_resistance, expected_resistance); + } + + #[test] + fn test_demark_pivot_points() { + let open = Series::from([ + 6.5541, 6.5942, 6.5345, 6.4950, 6.5298, 6.5619, 6.5223, 6.5300, 6.5451, 6.5254, 6.5038, + 6.4614, 6.4853, 6.4866, 6.5117, 6.5272, + ]); + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let expected_support = vec![ + 6.5449996, 6.5614505, 6.5324497, 6.4570503, 6.5089, 6.5089, 6.4970994, 6.5211005, + 6.5329, 6.5128503, 6.48665, 6.43615, 6.473401, 6.4881997, 6.49975, 6.5053997, + ]; + let expected_resistance = vec![ + 6.5631995, 6.6269503, 6.59655, 6.5329504, 6.5507, 6.5658, 6.5474997, 6.5389004, 6.562, + 6.53795, 6.52095, 6.48665, 6.498601, 6.5619993, 6.52365, 6.5281997, + ]; + + let (support, resistance) = dpp(&open, &high, &low, &close); + let result_support: Vec = support.into(); + let result_resistance: Vec = resistance.into(); + + assert_eq!(result_support, expected_support); + assert_eq!(result_resistance, expected_resistance); + } + + #[test] + fn test_smooth_pivot_points() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + + let period = 3; + + let expected_support = vec![ + 6.5439324, 6.4990325, 6.4990325, 6.4780555, 6.467311, 6.4813333, 6.4813333, 6.4813333, + 6.5010676, 6.518056, 6.498766, 6.452577, 6.448855, 6.427755, 6.427755, 6.440223, + ]; + let expected_resistance = vec![ + 6.5621324, 6.6062336, 6.615534, 6.6674337, 6.6524887, 6.6047554, 6.5758677, 6.5677786, + 6.5677786, 6.5619783, 6.5738664, 6.611756, 6.592466, 6.544577, 6.5471992, 6.55031, + ]; + + let (support, resistance) = spp(&high, &low, &close, Smooth::SMA, period); + let result_support: Vec = support.into(); + let result_resistance: Vec = resistance.into(); + + assert_eq!(result_support, expected_support); + assert_eq!(result_resistance, expected_resistance); + } +} diff --git a/ta_lib/indicators/trend/src/rmsma.rs b/ta_lib/indicators/trend/src/rmsma.rs index ca54a6f2..abc37cc1 100644 --- a/ta_lib/indicators/trend/src/rmsma.rs +++ b/ta_lib/indicators/trend/src/rmsma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn rmsma(source: &Series, period: usize) -> Series { +pub fn rmsma(source: &Price, period: Period) -> Price { (source * source).smooth(Smooth::SMA, period).sqrt() } @@ -13,7 +13,7 @@ mod tests { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); let expected = vec![1.0, 1.5811388, 2.1602468, 3.1091263, 4.082483]; - let result: Vec = rmsma(&source, 3).into(); + let result: Vec = rmsma(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/sinwma.rs b/ta_lib/indicators/trend/src/sinwma.rs index 6ad39bb3..3683b6a6 100644 --- a/ta_lib/indicators/trend/src/sinwma.rs +++ b/ta_lib/indicators/trend/src/sinwma.rs @@ -1,8 +1,9 @@ use core::prelude::*; -pub fn sinwma(source: &Series, period: usize) -> Series { +pub fn sinwma(source: &Price, period: Period) -> Price { let weights = (0..period) - .map(|i| ((i as f32 + 1.) * std::f32::consts::PI / (period as f32 + 1.)).sin()) + .rev() + .map(|i| ((i as Scalar + 1.) * PI / (period + 1) as Scalar).sin()) .collect::>(); source.wg(&weights) @@ -18,27 +19,18 @@ mod tests { 0.01707, 0.01706, 0.01707, 0.01705, 0.01710, 0.01705, 0.01704, 0.01709, ]); let expected = [ - 0.0, - 0.0, + 0.01707, + 0.017064141, 0.017065858, - 0.017061213, + 0.017061211, 0.017070502, 0.01707071, 0.017061714, 0.017057573, ]; - let epsilon = 0.0001; - let result: Vec = sinwma(&source, 3).into(); + let result: Vec = sinwma(&source, 3).into(); - for i in 0..source.len() { - assert!( - (result[i] - expected[i]).abs() < epsilon, - "at position {}: {} != {}", - i, - result[i], - expected[i] - ); - } + assert_eq!(result, expected); } } diff --git a/ta_lib/indicators/trend/src/slsma.rs b/ta_lib/indicators/trend/src/slsma.rs new file mode 100644 index 00000000..07b165d2 --- /dev/null +++ b/ta_lib/indicators/trend/src/slsma.rs @@ -0,0 +1,27 @@ +use core::prelude::*; + +pub fn slsma(source: &Price, period: Period) -> Price { + let lsma = source.smooth(Smooth::LSMA, period); + + &lsma - (&lsma - lsma.smooth(Smooth::LSMA, period)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_slsma() { + let source = Series::from([ + 12.529, 12.504, 12.517, 12.542, 12.547, 12.577, 12.539, 12.577, 12.490, 12.490, + ]); + let expected = vec![ + 12.529, 12.504, 12.50539, 12.536224, 12.553506, 12.570806, 12.557826, 12.558236, + 12.522075, 12.472454, + ]; + + let result: Vec = slsma(&source, 3).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/trend/src/sma.rs b/ta_lib/indicators/trend/src/sma.rs index 07fba77e..3aa44726 100644 --- a/ta_lib/indicators/trend/src/sma.rs +++ b/ta_lib/indicators/trend/src/sma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn sma(source: &Series, period: usize) -> Series { +pub fn sma(source: &Price, period: Period) -> Price { source.smooth(Smooth::SMA, period) } @@ -13,7 +13,7 @@ mod tests { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); let expected = vec![1.0, 1.5, 2.0, 3.0, 4.0]; - let result: Vec = sma(&source, 3).into(); + let result: Vec = sma(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/smma.rs b/ta_lib/indicators/trend/src/smma.rs index 9708ad82..b7ace4a2 100644 --- a/ta_lib/indicators/trend/src/smma.rs +++ b/ta_lib/indicators/trend/src/smma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn smma(source: &Series, period: usize) -> Series { +pub fn smma(source: &Price, period: Period) -> Price { source.smooth(Smooth::SMMA, period) } @@ -20,7 +20,7 @@ mod tests { 6.835905, ]; - let result: Vec = smma(&source, 3).into(); + let result: Vec = smma(&source, 3).into(); assert_eq!(result, expected) } diff --git a/ta_lib/indicators/trend/src/supertrend.rs b/ta_lib/indicators/trend/src/supertrend.rs index 5a3d11f1..1d34bad2 100644 --- a/ta_lib/indicators/trend/src/supertrend.rs +++ b/ta_lib/indicators/trend/src/supertrend.rs @@ -1,46 +1,39 @@ use core::prelude::*; -pub fn supertrend( - source: &Series, - close: &Series, - atr: &Series, - factor: f32, -) -> (Series, Series) { +pub fn supertrend(source: &Price, close: &Price, atr: &Price, factor: Scalar) -> (Price, Price) { let len = source.len(); let atr_mul = atr * factor; - let basic_up = source - &atr_mul; - let mut up = Series::empty(len); - - let basic_dn = source + &atr_mul; - let mut dn = Series::empty(len); + let mut up = source - &atr_mul; + let mut dn = source + &atr_mul; let prev_close = close.shift(1); - let mut direction = Series::empty(len); + let mut direction = Series::one(len); let trend_bull = Series::one(len); let trend_bear = trend_bull.negate(); for _ in 0..len { - let prev_up = up.shift(1); - up = iff!(prev_close.sgt(&prev_up), basic_up.max(&prev_up), basic_up); + let prev_up = nz!(up.shift(1), up); + up = iff!(prev_close.sgt(&prev_up), up.max(&prev_up), up); - let prev_dn = dn.shift(1); - dn = iff!(prev_close.slt(&prev_dn), basic_dn.min(&prev_dn), basic_dn); + let prev_dn = nz!(dn.shift(1), dn); + dn = iff!(prev_close.slt(&prev_dn), dn.min(&prev_dn), dn); direction = nz!(direction.shift(1), direction); - direction = iff!(close.sgte(&prev_dn), trend_bull, direction); - direction = iff!(close.slt(&prev_up), trend_bear, direction); + direction = iff!( + direction.seq(&MINUS_ONE) & close.sgt(&prev_dn), + trend_bull, + direction + ); + direction = iff!( + direction.seq(&ONE) & close.slt(&prev_up), + trend_bear, + direction + ); } - let first_non_empty = direction - .iter() - .find(|&&el| el.is_some()) - .map(|&el| -el.unwrap()); - - direction = direction.nz(first_non_empty); - let trend = iff!(direction.seq(&ONE), up, dn); (direction, trend) @@ -82,8 +75,8 @@ mod tests { ]; let (direction, supertrend) = supertrend(&hl2, &close, &atr, factor); - let result_direction: Vec = direction.into(); - let result_supertrend: Vec = supertrend.into(); + let result_direction: Vec = direction.into(); + let result_supertrend: Vec = supertrend.into(); assert_eq!(high.len(), low.len()); assert_eq!(high.len(), close.len()); @@ -114,16 +107,16 @@ mod tests { let factor = 3.0; let expected_supertrend = vec![ - 6.223499, 6.207499, 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.2073326, - 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.2073326, - 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.2073326, 6.1522994, 6.1522994, + 6.073501, 6.077501, 6.0926676, 6.107279, 6.1168528, 6.122902, 6.122902, 6.1263456, + 6.1263456, 6.1263456, 6.131306, 6.131306, 6.133747, 6.133747, 6.133747, 6.133747, + 6.133747, 6.133747, 6.134674, 6.134674, 6.134674, 6.1417, 6.1522994, 6.1522994, 6.1522994, 6.1522994, 6.1522994, 6.1522994, 6.1522994, 6.1522994, 6.1522994, 6.1522994, 6.1522994, 6.1580467, 6.1580467, 6.1580467, ]; let expected_direction = vec![ - -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, - -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, - 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, + 1.0, 1.0, ]; let (direction, supertrend) = supertrend(&hl2, &close, &atr, factor); diff --git a/ta_lib/indicators/trend/src/t3.rs b/ta_lib/indicators/trend/src/t3.rs index 352a97bb..634034f0 100644 --- a/ta_lib/indicators/trend/src/t3.rs +++ b/ta_lib/indicators/trend/src/t3.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn t3(source: &Series, period: usize) -> Series { +pub fn t3(source: &Price, period: Period) -> Price { let alpha = 0.618; let ema1 = source.smooth(Smooth::EMA, period); @@ -27,7 +27,7 @@ mod tests { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); let expected = vec![1.0, 1.2803686, 1.8820143, 2.717381, 3.6838531]; - let result: Vec = t3(&source, 3).into(); + let result: Vec = t3(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/tema.rs b/ta_lib/indicators/trend/src/tema.rs index 8fde0a79..3bbd76d6 100644 --- a/ta_lib/indicators/trend/src/tema.rs +++ b/ta_lib/indicators/trend/src/tema.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn tema(source: &Series, period: usize) -> Series { +pub fn tema(source: &Price, period: Period) -> Price { source.smooth(Smooth::TEMA, period) } @@ -13,7 +13,7 @@ mod tests { let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); let expected = vec![1.0, 1.875, 2.9375, 4.0, 5.03125]; - let result: Vec = tema(&source, 3).into(); + let result: Vec = tema(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/tma.rs b/ta_lib/indicators/trend/src/tma.rs deleted file mode 100644 index 2ab7503a..00000000 --- a/ta_lib/indicators/trend/src/tma.rs +++ /dev/null @@ -1,23 +0,0 @@ -use core::prelude::*; - -pub fn tma(source: &Series, period: usize) -> Series { - let n = (0.5 * period as f32).signum() as usize; - let m = (0.5 * period as f32 + 1.) as usize; - - source.smooth(Smooth::SMA, n).smooth(Smooth::SMA, m) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_tma() { - let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); - let expected = vec![1.0, 1.5, 2.5, 3.5, 4.5]; - - let result: Vec = tma(&source, 3).into(); - - assert_eq!(result, expected); - } -} diff --git a/ta_lib/indicators/trend/src/trima.rs b/ta_lib/indicators/trend/src/trima.rs new file mode 100644 index 00000000..4b0afcc0 --- /dev/null +++ b/ta_lib/indicators/trend/src/trima.rs @@ -0,0 +1,25 @@ +use core::prelude::*; + +pub fn trima(source: &Price, period: Period) -> Price { + let period_half = HALF * period as Scalar; + + let n = period_half.ceil() as Period; + let m = (period_half.floor() + 1.) as Period; + + source.smooth(Smooth::SMA, n).smooth(Smooth::SMA, m) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_trima() { + let source = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let expected = vec![1.0, 1.25, 2.0, 3.0, 4.0]; + + let result: Vec = trima(&source, 3).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/trend/src/ults.rs b/ta_lib/indicators/trend/src/ults.rs new file mode 100644 index 00000000..1fe12e45 --- /dev/null +++ b/ta_lib/indicators/trend/src/ults.rs @@ -0,0 +1,21 @@ +use core::prelude::*; + +pub fn ults(source: &Price, period: Period) -> Price { + source.smooth(Smooth::ULTS, period) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ults() { + let source = Series::from([18.898, 18.838, 18.881, 18.925, 18.846]); + let period = 3; + let expected = vec![18.898, 18.85439, 18.853537, 18.922752, 18.880928]; + + let result: Vec = ults(&source, period).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/trend/src/vi.rs b/ta_lib/indicators/trend/src/vi.rs index eb2ab8bf..a293d147 100644 --- a/ta_lib/indicators/trend/src/vi.rs +++ b/ta_lib/indicators/trend/src/vi.rs @@ -1,16 +1,11 @@ use core::prelude::*; -pub fn vi( - high: &Series, - low: &Series, - atr: &Series, - period: usize, -) -> (Series, Series) { +pub fn vi(high: &Price, low: &Price, atr: &Price, period: Period) -> (Price, Price) { let vmp = (high - low.shift(1)).abs().sum(period); let vmm = (low - high.shift(1)).abs().sum(period); - let sum_atr = atr.sum(period); + let atrs = atr.sum(period); - (vmp / &sum_atr, vmm / &sum_atr) + (vmp / &atrs, vmm / &atrs) } #[cfg(test)] @@ -50,8 +45,8 @@ mod tests { ]; let (vip, vim) = vi(&high, &low, &atr, 2); - let vvip: Vec = vip.into(); - let vvim: Vec = vim.into(); + let vvip: Vec = vip.into(); + let vvim: Vec = vim.into(); assert_eq!(vvip, expected_vip); assert_eq!(vvim, expected_vim); diff --git a/ta_lib/indicators/trend/src/vidya.rs b/ta_lib/indicators/trend/src/vidya.rs index 91c3f334..3ce57811 100644 --- a/ta_lib/indicators/trend/src/vidya.rs +++ b/ta_lib/indicators/trend/src/vidya.rs @@ -1,8 +1,8 @@ use core::prelude::*; -pub fn vidya(source: &Series, fast_period: usize, slow_period: usize) -> Series { - let k = source.std(fast_period) / source.std(slow_period); - let alpha = 2. / (fast_period as f32 + 1.) * k.nz(Some(ZERO)); +pub fn vidya(source: &Price, period_fast: Period, period_slow: Period) -> Price { + let k = source.std(period_fast) / source.std(period_slow); + let alpha = 2. / ((period_fast + 1) as Scalar) * k.nz(Some(ZERO)); source.ew(&alpha, source) } @@ -18,7 +18,7 @@ mod tests { let slow_period = 3; let expected = vec![100.0, 103.33333, 110.46114, 114.34566, 119.90917]; - let result: Vec = vidya(&source, fast_period, slow_period).into(); + let result: Vec = vidya(&source, fast_period, slow_period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/vwema.rs b/ta_lib/indicators/trend/src/vwema.rs index 28db0b79..53d78ef5 100644 --- a/ta_lib/indicators/trend/src/vwema.rs +++ b/ta_lib/indicators/trend/src/vwema.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn vwema(source: &Series, volume: &Series, period: usize) -> Series { +pub fn vwema(source: &Price, volume: &Price, period: Period) -> Price { (source * volume).smooth(Smooth::EMA, period) / volume.smooth(Smooth::EMA, period) } @@ -15,7 +15,7 @@ mod tests { let period = 3; let expected = vec![100.0, 102.77778, 112.34501, 118.14274, 124.07811]; - let result: Vec = vwema(&source, &volume, period).into(); + let result: Vec = vwema(&source, &volume, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/vwma.rs b/ta_lib/indicators/trend/src/vwma.rs index 53731626..be359791 100644 --- a/ta_lib/indicators/trend/src/vwma.rs +++ b/ta_lib/indicators/trend/src/vwma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn vwma(source: &Series, volume: &Series, period: usize) -> Series { +pub fn vwma(source: &Price, volume: &Price, period: Period) -> Price { (source * volume).smooth(Smooth::SMA, period) / volume.smooth(Smooth::SMA, period) } diff --git a/ta_lib/indicators/trend/src/wma.rs b/ta_lib/indicators/trend/src/wma.rs index f748f778..a30e3f78 100644 --- a/ta_lib/indicators/trend/src/wma.rs +++ b/ta_lib/indicators/trend/src/wma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn wma(source: &Series, period: usize) -> Series { +pub fn wma(source: &Price, period: Period) -> Price { source.smooth(Smooth::WMA, period) } @@ -16,12 +16,12 @@ mod tests { ]); let period = 3; let expected = [ - 0.0, 0.0, 6.5467167, 6.573034, 6.557817, 6.5248, 6.5190334, 6.5399, 6.5366497, - 6.5326996, 6.5363164, 6.5327663, 6.5179005, 6.4862, 6.4804664, 6.487, 6.5022836, - 6.516834, 6.5373, 6.5378666, + 6.5231996, 6.5393333, 6.5467167, 6.573034, 6.557817, 6.5248, 6.5190334, 6.5399, + 6.5366497, 6.5326996, 6.5363164, 6.5327663, 6.5179005, 6.4862, 6.4804664, 6.487, + 6.5022836, 6.516834, 6.5373, 6.5378666, ]; - let result: Vec = wma(&source, period).into(); + let result: Vec = wma(&source, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/zlema.rs b/ta_lib/indicators/trend/src/zlema.rs index 3fdc040f..5ed6e610 100644 --- a/ta_lib/indicators/trend/src/zlema.rs +++ b/ta_lib/indicators/trend/src/zlema.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn zlema(source: &Series, period: usize) -> Series { +pub fn zlema(source: &Price, period: Period) -> Price { source.smooth(Smooth::ZLEMA, period) } @@ -12,9 +12,9 @@ mod tests { fn test_zlema() { let source = Series::from([18.898, 18.838, 18.881, 18.925, 18.846]); let period = 3; - let expected = vec![0.0, 18.777998, 18.851, 18.91, 18.838501]; + let expected = vec![18.898, 18.838, 18.881, 18.925, 18.846]; - let result: Vec = zlema(&source, period).into(); + let result: Vec = zlema(&source, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/zlhma.rs b/ta_lib/indicators/trend/src/zlhma.rs index 9e156c75..dfb28d2f 100644 --- a/ta_lib/indicators/trend/src/zlhma.rs +++ b/ta_lib/indicators/trend/src/zlhma.rs @@ -1,9 +1,9 @@ use core::prelude::*; -pub fn zlhma(source: &Series, period: usize, smooth_period: usize) -> Series { +pub fn zlhma(source: &Price, period: Period, period_smooth: Period) -> Price { source .smooth(Smooth::HMA, period) - .smooth(Smooth::HMA, smooth_period) + .smooth(Smooth::HMA, period_smooth) } #[cfg(test)] @@ -20,17 +20,17 @@ mod tests { 7.1560, ]); let period = 3; - let smooth_period = 2; + let period_smooth = 2; let expected = vec![ - 0.0, 0.0, 0.0, 7.1295276, 7.1204166, 7.137084, 7.145639, 7.1091666, 7.101055, - 7.1242785, 7.1356115, 7.159471, 7.1533885, 7.139722, 7.1393056, 7.1438065, 7.1512504, - 7.1530538, 7.140639, 7.1419444, 7.15478, 7.1451106, 7.133417, 7.128304, 7.1368594, - 7.125918, 7.1144705, 7.111918, 7.1284165, 7.125667, 7.1206403, 7.123055, 7.117305, - 7.124556, 7.124973, 7.110944, 7.119223, 7.133779, 7.136498, 7.132387, 7.113446, - 7.1199985, 7.172027, 7.1665835, 7.151444, + 7.1135, 7.0681653, 7.133167, 7.1317506, 7.117639, 7.1519723, 7.1459727, 7.0828333, + 7.1051674, 7.135945, 7.1408334, 7.178806, 7.141832, 7.137277, 7.1376395, 7.1475835, + 7.1553617, 7.1536107, 7.131306, 7.146833, 7.1612234, 7.1357775, 7.129084, 7.124083, + 7.144527, 7.11325, 7.111916, 7.1083617, 7.1425276, 7.1171117, 7.1228623, 7.122388, + 7.113417, 7.131778, 7.1216393, 7.1024995, 7.128556, 7.140001, 7.138055, 7.129722, + 7.099668, 7.1302214, 7.2068057, 7.1503606, 7.1527786, ]; - let result: Vec = zlhma(&source, period, smooth_period).into(); + let result: Vec = zlhma(&source, period, period_smooth).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/zlsma.rs b/ta_lib/indicators/trend/src/zlsma.rs index d04c5574..f9097022 100644 --- a/ta_lib/indicators/trend/src/zlsma.rs +++ b/ta_lib/indicators/trend/src/zlsma.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn zlsma(source: &Series, period: usize) -> Series { +pub fn zlsma(source: &Price, period: Period) -> Price { let lsma = source.smooth(Smooth::LSMA, period); 2. * &lsma - lsma.smooth(Smooth::LSMA, period) @@ -11,25 +11,16 @@ mod tests { use super::*; #[test] - fn test_zlsma() { + fn test_lsma() { let source = Series::from([ - 7.1135, 7.088, 7.112, 7.1205, 7.1195, 7.136, 7.1405, 7.112, 7.1095, 7.1220, 7.1310, - 7.1550, 7.1480, 7.1435, 7.1405, 7.1440, 7.1495, 7.1515, 7.1415, 7.1445, 7.1525, 7.1440, - 7.1370, 7.1305, 7.1375, 7.1250, 7.1190, 7.1135, 7.1280, 7.1220, 7.1230, 7.1225, 7.1180, - 7.1250, 7.1230, 7.1130, 7.1210, 7.13, 7.134, 7.132, 7.116, 7.1235, 7.1645, 7.1565, - 7.1560, + 12.529, 12.504, 12.517, 12.542, 12.547, 12.577, 12.539, 12.577, 12.490, 12.490, ]); - let period = 3; let expected = vec![ - 7.1135, 7.088, 7.1106257, 7.1236806, 7.1175256, 7.135417, 7.1420703, 7.1117597, - 7.1072783, 7.1239357, 7.1312175, 7.1539693, 7.1497893, 7.1412854, 7.141484, 7.143811, - 7.149786, 7.15152, 7.1415567, 7.143571, 7.1534204, 7.1443624, 7.135909, 7.1310244, - 7.137115, 7.1262765, 7.117335, 7.114386, 7.127291, 7.1236925, 7.12111, 7.1235223, - 7.117843, 7.1245036, 7.1240044, 7.1124487, 7.1202955, 7.1312103, 7.1337156, 7.131895, - 7.116152, 7.1222796, 7.1652455, 7.1591597, 7.152091, + 12.529, 12.504, 12.515945, 12.543776, 12.547166, 12.574857, 12.54283, 12.570409, + 12.499587, 12.478523, ]; - let result: Vec = zlsma(&source, period).into(); + let result: Vec = zlsma(&source, 3).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/trend/src/zltema.rs b/ta_lib/indicators/trend/src/zltema.rs index f9f42178..105bd108 100644 --- a/ta_lib/indicators/trend/src/zltema.rs +++ b/ta_lib/indicators/trend/src/zltema.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn zltema(source: &Series, period: usize) -> Series { +pub fn zltema(source: &Price, period: Period) -> Price { source .smooth(Smooth::TEMA, period) .smooth(Smooth::TEMA, period) @@ -16,7 +16,7 @@ mod tests { let period = 3; let expected = vec![18.898, 18.852058, 18.865294, 18.910984, 18.869732]; - let result: Vec = zltema(&source, period).into(); + let result: Vec = zltema(&source, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/volatility/src/atr.rs b/ta_lib/indicators/volatility/src/atr.rs deleted file mode 100644 index fc7cfd6f..00000000 --- a/ta_lib/indicators/volatility/src/atr.rs +++ /dev/null @@ -1,56 +0,0 @@ -use crate::tr; -use core::prelude::*; - -pub fn atr( - high: &Series, - low: &Series, - close: &Series, - smooth_type: Smooth, - period: usize, -) -> Series { - tr(high, low, close).smooth(smooth_type, period) -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_atr_smma() { - let high = Series::from([ - 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, - 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, - ]); - let low = Series::from([ - 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, - 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, - ]); - let close = Series::from([ - 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, - 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, - ]); - let period = 3; - let expected = [ - 0.01819992, - 0.03396654, - 0.044011116, - 0.05464077, - 0.05036052, - 0.05254035, - 0.051826984, - 0.040484603, - 0.036689714, - 0.032826394, - 0.033317544, - 0.039045, - 0.03442996, - 0.047553174, - 0.03966879, - 0.03404585, - ]; - - let result: Vec = atr(&high, &low, &close, Smooth::SMMA, period).into(); - - assert_eq!(result, expected); - } -} diff --git a/ta_lib/indicators/volatility/src/bb.rs b/ta_lib/indicators/volatility/src/bb.rs index 5117f5f6..de39caa1 100644 --- a/ta_lib/indicators/volatility/src/bb.rs +++ b/ta_lib/indicators/volatility/src/bb.rs @@ -1,18 +1,25 @@ use core::prelude::*; -pub fn bb( - source: &Series, - smooth_type: Smooth, - period: usize, - factor: f32, -) -> (Series, Series, Series) { - let middle_band = source.smooth(smooth_type, period); - let std_mul = source.std(period) * factor; - - let upper_band = &middle_band + &std_mul; - let lower_band = &middle_band - &std_mul; - - (upper_band, middle_band, lower_band) +pub fn bb(source: &Price, smooth: Smooth, period: Period, factor: Scalar) -> (Price, Price, Price) { + let middle = source.smooth(smooth, period); + let volatility = factor * source.std(period); + + let upper = &middle + &volatility; + let lower = &middle - &volatility; + + (upper, middle, lower) +} + +pub fn bbp(source: &Price, smooth: Smooth, period: Period, factor: Scalar) -> Price { + let (upb, _, lb) = bb(source, smooth, period, factor); + + (source - &lb) / (upb - lb) +} + +pub fn bbw(source: &Price, smooth: Smooth, period: Period, factor: Scalar) -> Price { + let (upb, mb, lb) = bb(source, smooth, period, factor); + + SCALE * (upb - lb) / mb } #[cfg(test)] @@ -37,9 +44,9 @@ mod tests { let (upper_band, middle_band, lower_band) = bb(&source, Smooth::SMA, period, factor); - let result_upper_band: Vec = upper_band.into(); - let result_middle_band: Vec = middle_band.into(); - let result_lower_band: Vec = lower_band.into(); + let result_upper_band: Vec = upper_band.into(); + let result_middle_band: Vec = middle_band.into(); + let result_lower_band: Vec = lower_band.into(); for i in 0..source.len() { let a = result_upper_band[i]; @@ -55,4 +62,32 @@ mod tests { assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); } } + + #[test] + fn test_bbp() { + let source = Series::from([2.0, 4.0, 6.0, 8.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0]); + let period = 3; + let factor = 2.0; + let expected = [ + 0.0, 0.75, 0.80618626, 0.80618614, 0.8061864, 0.5, 0.19381316, 0.19381316, 0.19381405, + 0.19381405, + ]; + let result: Vec = bbp(&source, Smooth::SMA, period, factor).into(); + + assert_eq!(result, expected); + } + + #[test] + fn test_bbw() { + let source = Series::from([2.0, 4.0, 6.0, 8.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0]); + let period = 3; + let factor = 2.0; + let expected = [ + 0.0, 133.33333, 163.2993, 108.86625, 81.64961, 36.288662, 36.288662, 40.824745, + 46.65699, 54.433155, + ]; + let result: Vec = bbw(&source, Smooth::SMA, period, factor).into(); + + assert_eq!(result, expected); + } } diff --git a/ta_lib/indicators/volatility/src/bbw.rs b/ta_lib/indicators/volatility/src/bbw.rs deleted file mode 100644 index aea1e1e9..00000000 --- a/ta_lib/indicators/volatility/src/bbw.rs +++ /dev/null @@ -1,30 +0,0 @@ -use core::prelude::*; - -pub fn bbw(source: &Series, smooth_type: Smooth, period: usize, factor: f32) -> Series { - let middle_band = source.smooth(smooth_type, period); - let std_mul = source.std(period) * factor; - - let upper_band = &middle_band + &std_mul; - let lower_band = &middle_band - &std_mul; - - (upper_band - lower_band) / &middle_band -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_bbw() { - let source = Series::from([2.0, 4.0, 6.0, 8.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0]); - let period = 3; - let factor = 2.0; - let expected = [ - 0.0, 1.3333334, 1.632993, 1.0886625, 0.81649613, 0.36288664, 0.36288664, 0.40824747, - 0.4665699, 0.54433155, - ]; - let result: Vec = bbw(&source, Smooth::SMA, period, factor).into(); - - assert_eq!(result, expected); - } -} diff --git a/ta_lib/indicators/volatility/src/dch.rs b/ta_lib/indicators/volatility/src/dch.rs index 8caa44ec..34ad49b8 100644 --- a/ta_lib/indicators/volatility/src/dch.rs +++ b/ta_lib/indicators/volatility/src/dch.rs @@ -1,16 +1,18 @@ use core::prelude::*; -pub fn dch( - high: &Series, - low: &Series, - period: usize, -) -> (Series, Series, Series) { - let upper_band = high.highest(period); - let lower_band = low.lowest(period); +pub fn dch(high: &Price, low: &Price, period: Period) -> (Price, Price, Price) { + let upper = high.highest(period); + let lower = low.lowest(period); - let middle_band = 0.5 * (&upper_band + &lower_band); + let middle = HALF * (&upper + &lower); - (upper_band, middle_band, lower_band) + (upper, middle, lower) +} + +pub fn dchw(high: &Price, low: &Price, period: Period) -> Price { + let (upb, _, lb) = dch(high, low, period); + + upb - lb } #[cfg(test)] @@ -29,12 +31,25 @@ mod tests { let (upper, middle, lower) = dch(&high, &low, period); - let result_upper: Vec = upper.into(); - let result_lower: Vec = lower.into(); - let result_middle: Vec = middle.into(); + let result_upper: Vec = upper.into(); + let result_lower: Vec = lower.into(); + let result_middle: Vec = middle.into(); assert_eq!(result_upper, expected_upper); assert_eq!(result_lower, expected_lower); assert_eq!(result_middle, expected_middle); } + + #[test] + fn test_dchw() { + let high = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let low = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let period = 3; + + let expected = vec![0.0, 1.0, 2.0, 2.0, 2.0]; + + let result: Vec = dchw(&high, &low, period).into(); + + assert_eq!(result, expected); + } } diff --git a/ta_lib/indicators/volatility/src/gkyz.rs b/ta_lib/indicators/volatility/src/gkyz.rs new file mode 100644 index 00000000..fb743b7d --- /dev/null +++ b/ta_lib/indicators/volatility/src/gkyz.rs @@ -0,0 +1,34 @@ +use core::prelude::*; + +pub fn gkyz(open: &Price, high: &Price, low: &Price, close: &Price, period: Period) -> Price { + let gkyzl = (open / close.shift(1).nz(Some(ZERO))).log(); + let pkl = (high / low).log(); + let gkl = (close / open).log(); + let gm = 2.0 * 2.0_f32.ln() - 1.0; + + let gkyzs = (1.0 / period as Scalar) * gkyzl.pow(2).sum(period); + let pks = (1.0 / (2.0 * period as Scalar)) * pkl.pow(2).sum(period); + let gs = (gm / period as Scalar) * gkl.pow(2).sum(period); + + (gkyzs + pks - gs).sqrt() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_garman_klass_yang_zhang() { + let open = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let high = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let low = Series::from([3.0, 2.0, 3.0, 4.0, 5.0]); + let close = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let period = 3; + + let expected = vec![0.0, 0.60109, 0.6450658, 0.49248216, 0.31461933]; + + let result: Vec = gkyz(&open, &high, &low, &close, period).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/volatility/src/kb.rs b/ta_lib/indicators/volatility/src/kb.rs index a7253b36..dc0d3191 100644 --- a/ta_lib/indicators/volatility/src/kb.rs +++ b/ta_lib/indicators/volatility/src/kb.rs @@ -1,17 +1,25 @@ use core::prelude::*; -pub fn kb( - source: &Series, - period: usize, - factor: f32, -) -> (Series, Series, Series) { - let middle_band = 0.5 * (source.highest(period) + source.lowest(period)); - let volatility = source.std(period).highest(period) * factor; - - let upper_band = &middle_band + &volatility; - let lower_band = &middle_band - &volatility; - - (upper_band, middle_band, lower_band) +pub fn kb(source: &Price, period: usize, factor: Scalar) -> (Price, Price, Price) { + let middle = HALF * (source.highest(period) + source.lowest(period)); + let volatility = factor * source.std(period).highest(period); + + let upper = &middle + &volatility; + let lower = &middle - &volatility; + + (upper, middle, lower) +} + +pub fn kbp(source: &Price, period: usize, factor: Scalar) -> Price { + let (upb, _, lb) = kb(source, period, factor); + + (source - &lb) / (upb - lb) +} + +pub fn kbw(source: &Price, period: usize, factor: Scalar) -> Price { + let (upb, mb, lb) = kb(source, period, factor); + + SCALE * (upb - lb) / mb } #[cfg(test)] @@ -36,9 +44,9 @@ mod tests { let (upper_band, middle_band, lower_band) = kb(&source, period, factor); - let result_upper_band: Vec = upper_band.into(); - let result_middle_band: Vec = middle_band.into(); - let result_lower_band: Vec = lower_band.into(); + let result_upper_band: Vec = upper_band.into(); + let result_middle_band: Vec = middle_band.into(); + let result_lower_band: Vec = lower_band.into(); for i in 0..source.len() { let a = result_upper_band[i]; @@ -54,4 +62,34 @@ mod tests { assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); } } + + #[test] + fn test_kbp() { + let source = Series::from([2.0, 4.0, 6.0, 8.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0]); + let period = 3; + let factor = 2.0; + let expected = [ + 0.0, 0.75, 0.80618626, 0.80618614, 0.80618614, 0.5, 0.3469068, 0.19381316, 0.19381405, + 0.19381405, + ]; + + let result: Vec = kbp(&source, period, factor).into(); + + assert_eq!(result, expected); + } + + #[test] + fn test_kbw() { + let source = Series::from([2.0, 4.0, 6.0, 8.0, 10.0, 9.0, 8.0, 7.0, 6.0, 5.0]); + let period = 3; + let factor = 2.0; + let expected = [ + 0.0, 133.33333, 163.2993, 108.86625, 81.64969, 72.5775, 72.57743, 40.824745, 46.65699, + 54.433155, + ]; + + let result: Vec = kbw(&source, period, factor).into(); + + assert_eq!(result, expected); + } } diff --git a/ta_lib/indicators/volatility/src/kch.rs b/ta_lib/indicators/volatility/src/kch.rs index 9e3193ef..f39922cf 100644 --- a/ta_lib/indicators/volatility/src/kch.rs +++ b/ta_lib/indicators/volatility/src/kch.rs @@ -1,19 +1,31 @@ use core::prelude::*; pub fn kch( - source: &Series, - atr: &Series, - smooth_type: Smooth, - period: usize, - factor: f32, -) -> (Series, Series, Series) { - let middle_band = source.smooth(smooth_type, period); - let atr = atr * factor; - - let upper_band = &middle_band + &atr; - let lower_band = &middle_band - &atr; - - (upper_band, middle_band, lower_band) + source: &Price, + smooth: Smooth, + atr: &Price, + period: Period, + factor: Scalar, +) -> (Price, Price, Price) { + let middle = source.smooth(smooth, period); + let volatility = factor * atr; + + let upper = &middle + &volatility; + let lower = &middle - &volatility; + + (upper, middle, lower) +} + +pub fn kchp(source: &Price, smooth: Smooth, atr: &Price, period: Period, factor: Scalar) -> Price { + let (upc, _, lc) = kch(source, smooth, atr, period, factor); + + (source - &lc) / (upc - lc) +} + +pub fn kchw(source: &Price, smooth: Smooth, atr: &Price, period: Period, factor: Scalar) -> Price { + let (upc, mc, lc) = kch(source, smooth, atr, period, factor); + + SCALE * (upc - lc) / mc } #[cfg(test)] @@ -52,11 +64,11 @@ mod tests { 19.168068, ]; - let (upper_band, middle_band, lower_band) = kch(&hlc3, &atr, Smooth::EMA, period, factor); + let (upper_band, middle_band, lower_band) = kch(&hlc3, Smooth::EMA, &atr, period, factor); - let result_upper_band: Vec = upper_band.into(); - let result_middle_band: Vec = middle_band.into(); - let result_lower_band: Vec = lower_band.into(); + let result_upper_band: Vec = upper_band.into(); + let result_middle_band: Vec = middle_band.into(); + let result_lower_band: Vec = lower_band.into(); for i in 0..high.len() { let a = result_upper_band[i]; @@ -72,4 +84,56 @@ mod tests { assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); } } + + #[test] + fn test_kchp() { + let close = Series::from([ + 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, + ]); + let high = Series::from([ + 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, + ]); + let low = Series::from([ + 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, + ]); + let period = 3; + let atr_period = 3; + let atr = atr(&high, &low, &close, Smooth::SMMA, atr_period); + let hlc3 = typical_price(&high, &low, &close); + let factor = 2.0; + let expected = [ + 0.5, 0.47801208, 0.5513957, 0.6472284, 0.55733025, 0.6086896, 0.6256523, 0.64774823, + 0.6573705, + ]; + + let result: Vec = kchp(&hlc3, Smooth::EMA, &atr, period, factor).into(); + + assert_eq!(result, expected); + } + + #[test] + fn test_kchw() { + let close = Series::from([ + 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, + ]); + let high = Series::from([ + 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, + ]); + let low = Series::from([ + 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, + ]); + let period = 3; + let atr_period = 3; + let atr = atr(&high, &low, &close, Smooth::SMMA, atr_period); + let hlc3 = typical_price(&high, &low, &close); + let factor = 2.0; + let expected = [ + 0.8164454, 0.7537607, 1.060274, 1.0539857, 1.2310988, 1.7222717, 1.9907007, 1.9314072, + 1.9380906, + ]; + + let result: Vec = kchw(&hlc3, Smooth::EMA, &atr, period, factor).into(); + + assert_eq!(result, expected); + } } diff --git a/ta_lib/indicators/volatility/src/lib.rs b/ta_lib/indicators/volatility/src/lib.rs index 6fe81888..14e02a2a 100644 --- a/ta_lib/indicators/volatility/src/lib.rs +++ b/ta_lib/indicators/volatility/src/lib.rs @@ -1,19 +1,21 @@ -mod atr; mod bb; -mod bbw; mod dch; +mod gkyz; mod kb; mod kch; +mod pk; mod ppb; -mod snatr; +mod rs; mod tr; +mod yz; -pub use atr::atr; -pub use bb::bb; -pub use bbw::bbw; -pub use dch::dch; -pub use kb::kb; -pub use kch::kch; -pub use ppb::ppb; -pub use snatr::snatr; -pub use tr::tr; +pub use bb::{bb, bbp, bbw}; +pub use dch::{dch, dchw}; +pub use gkyz::gkyz; +pub use kb::{kb, kbp, kbw}; +pub use kch::{kch, kchp, kchw}; +pub use pk::pk; +pub use ppb::{ppb, ppbp, ppbw}; +pub use rs::rs; +pub use tr::{atr, snatr, tr, wtr}; +pub use yz::yz; diff --git a/ta_lib/indicators/volatility/src/pk.rs b/ta_lib/indicators/volatility/src/pk.rs new file mode 100644 index 00000000..f6ad2c7e --- /dev/null +++ b/ta_lib/indicators/volatility/src/pk.rs @@ -0,0 +1,29 @@ +use core::prelude::*; + +pub fn pk(high: &Price, low: &Price, period: Period) -> Price { + let hll = (high / low).log(); + + let factor = 1. / (4.0 * period as Scalar * 2.0_f32.ln()); + + let hls = factor * hll.pow(2).sum(period); + + hls.sqrt() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parkinson() { + let high = Series::from([1.0, 2.0, 3.0, 2.0, 5.0]); + let low = Series::from([3.0, 2.0, 3.0, 4.0, 5.0]); + let period = 3; + + let expected = vec![0.38092643, 0.38092643, 0.38092643, 0.24033782, 0.24033782]; + + let result: Vec = pk(&high, &low, period).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/volatility/src/ppb.rs b/ta_lib/indicators/volatility/src/ppb.rs index 57c68333..5b0c5a27 100644 --- a/ta_lib/indicators/volatility/src/ppb.rs +++ b/ta_lib/indicators/volatility/src/ppb.rs @@ -1,67 +1,151 @@ use core::prelude::*; pub fn ppb( - high: &Series, - low: &Series, - close: &Series, - smooth_type: Smooth, - period: usize, - factor: f32, -) -> (Series, Series, Series) { - let ppvih = high.std(period).highest(period) * factor; - let ppvil = low.std(period).lowest(period) * factor; - - let middle_band = close.smooth(smooth_type, period); - - let upper_band = &middle_band + ppvih; - let lower_band = &middle_band - ppvil; - - (upper_band, middle_band, lower_band) + source: &Price, + high: &Price, + low: &Price, + smooth: Smooth, + period: Period, + factor: Scalar, +) -> (Price, Price, Price) { + let ppvih = factor * high.std(period).highest(period); + let ppvil = factor * low.std(period).lowest(period); + + let middle = source.smooth(smooth, period); + + let upper = &middle + ppvih; + let lower = &middle - ppvil; + + (upper, middle, lower) +} + +pub fn ppbp( + source: &Price, + high: &Price, + low: &Price, + smooth: Smooth, + period: Period, + factor: Scalar, +) -> Price { + let (upb, _, lb) = ppb(source, high, low, smooth, period, factor); + + (source - &lb) / (upb - lb) } -#[test] -fn test_ppb() { - let high = Series::from([ - 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, - ]); - let low = Series::from([ - 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, - ]); - let close = Series::from([ - 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, - ]); - let factor = 2.0; - let period = 3; - let epsilon = 0.0001; - let expected_upper_band = [ - 19.102, 19.116625, 19.149145, 19.207697, 19.22603, 19.27279, 19.338594, 19.405262, - 19.468927, - ]; - let expected_middle_band = [ - 19.102, 19.101, 19.116, 19.142332, 19.160666, 19.194666, 19.237333, 19.304, 19.367666, - ]; - let expected_lower_band = [ - 19.102, 19.101, 19.116, 19.131283, 19.149616, 19.194666, 19.237333, 19.304, 19.299559, - ]; - - let (upper_band, middle_band, lower_band) = - ppb(&high, &low, &close, Smooth::SMA, period, factor); - - let result_upper_band: Vec = upper_band.into(); - let result_middle_band: Vec = middle_band.into(); - let result_lower_band: Vec = lower_band.into(); - - for i in 0..high.len() { - let a = result_upper_band[i]; - let b = expected_upper_band[i]; - assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); - - let a = result_middle_band[i]; - let b = expected_middle_band[i]; - assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); - - let a = result_lower_band[i]; - let b = expected_lower_band[i]; - assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); +pub fn ppbw( + source: &Price, + high: &Price, + low: &Price, + smooth: Smooth, + period: Period, + factor: Scalar, +) -> Price { + let (upb, mb, lb) = ppb(source, high, low, smooth, period, factor); + + SCALE * (upb - lb) / mb +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_ppb() { + let source = Series::from([ + 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, + ]); + let high = Series::from([ + 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, + ]); + let low = Series::from([ + 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, + ]); + let factor = 2.0; + let period = 3; + let epsilon = 0.0001; + let expected_upper_band = [ + 19.102, 19.116625, 19.149145, 19.207697, 19.22603, 19.27279, 19.338594, 19.405262, + 19.468927, + ]; + let expected_middle_band = [ + 19.102, 19.101, 19.116, 19.142332, 19.160666, 19.194666, 19.237333, 19.304, 19.367666, + ]; + let expected_lower_band = [ + 19.102, 19.101, 19.116, 19.131283, 19.149616, 19.194666, 19.237333, 19.304, 19.299559, + ]; + + let (upper_band, middle_band, lower_band) = + ppb(&source, &high, &low, Smooth::SMA, period, factor); + + let result_upper_band: Vec = upper_band.into(); + let result_middle_band: Vec = middle_band.into(); + let result_lower_band: Vec = lower_band.into(); + + for i in 0..high.len() { + let a = result_upper_band[i]; + let b = expected_upper_band[i]; + assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); + + let a = result_middle_band[i]; + let b = expected_middle_band[i]; + assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); + + let a = result_lower_band[i]; + let b = expected_lower_band[i]; + assert!((a - b).abs() < epsilon, "at position {}: {} != {}", i, a, b); + } + } + + #[test] + fn test_ppbp() { + let source = Series::from([ + 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, + ]); + let high = Series::from([ + 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, + ]); + let low = Series::from([ + 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, + ]); + let factor = 2.0; + let period = 3; + let expected = [ + 0.0, + -0.063964844, + 0.9051099, + 0.6506003, + 0.070439056, + 0.682666, + 0.70774156, + 0.5036542, + 0.8232956, + ]; + + let result: Vec = ppbp(&source, &high, &low, Smooth::SMA, period, factor).into(); + + assert_eq!(result, expected); + } + + #[test] + fn test_ppbw() { + let source = Series::from([ + 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, + ]); + let high = Series::from([ + 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, + ]); + let low = Series::from([ + 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, + ]); + let factor = 2.0; + let period = 3; + let expected = [ + 0.0, 0.081802, 0.17339352, 0.39918908, 0.39880714, 0.40701413, 0.5263783, 0.52456045, + 0.8744923, + ]; + + let result: Vec = ppbw(&source, &high, &low, Smooth::SMA, period, factor).into(); + + assert_eq!(result, expected); } } diff --git a/ta_lib/indicators/volatility/src/rs.rs b/ta_lib/indicators/volatility/src/rs.rs new file mode 100644 index 00000000..2dbcf2b0 --- /dev/null +++ b/ta_lib/indicators/volatility/src/rs.rs @@ -0,0 +1,32 @@ +use core::prelude::*; + +pub fn rs(open: &Price, high: &Price, low: &Price, close: &Price, period: Period) -> Price { + let hl = (high / close).log() * (high / open).log(); + let ll = (low / close).log() * (low / open).log(); + let factor = 1.0 / period as Scalar; + + let hs = factor * hl.sum(period); + let ls = factor * ll.sum(period); + + (hs + ls).sqrt() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_rogers_satchell() { + let open = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let high = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let low = Series::from([3.0, 2.0, 3.0, 4.0, 5.0]); + let close = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let period = 3; + + let expected = vec![0.63428414, 0.63428414, 0.63428414, 0.0, 0.0]; + + let result: Vec = rs(&open, &high, &low, &close, period).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/volatility/src/snatr.rs b/ta_lib/indicators/volatility/src/snatr.rs deleted file mode 100644 index 7a6b99ef..00000000 --- a/ta_lib/indicators/volatility/src/snatr.rs +++ /dev/null @@ -1,45 +0,0 @@ -use core::prelude::*; - -pub fn snatr( - atr: &Series, - atr_period: usize, - smooth_type: Smooth, - smoothing_period: usize, -) -> Series { - ((atr - atr.lowest(atr_period)) / (atr.highest(atr_period) - atr.lowest(atr_period))) - .smooth(smooth_type, smoothing_period) -} - -#[test] -fn test_snatr() { - use crate::atr; - - let high = Series::from([ - 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, - ]); - let low = Series::from([ - 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, - ]); - let close = Series::from([ - 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, - ]); - let atr_period = 3; - let atr = atr(&high, &low, &close, Smooth::SMMA, atr_period); - let period = 3; - let epsilon = 0.001; - let expected = [ - 0.0, 0.0, 0.5, 0.8257546, 0.99494743, 0.9974737, 1.0, 0.9014031, 0.5520136, - ]; - - let result: Vec = snatr(&atr, atr_period, Smooth::WMA, period).into(); - - for i in 0..high.len() { - assert!( - (result[i] - expected[i]).abs() < epsilon, - "at position {}: {} != {}", - i, - result[i], - expected[i] - ) - } -} diff --git a/ta_lib/indicators/volatility/src/tr.rs b/ta_lib/indicators/volatility/src/tr.rs index 9e499a50..36550b98 100644 --- a/ta_lib/indicators/volatility/src/tr.rs +++ b/ta_lib/indicators/volatility/src/tr.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn tr(high: &Series, low: &Series, close: &Series) -> Series { +pub fn tr(high: &Price, low: &Price, close: &Price) -> Price { let prev_close = close.shift(1); let diff = high - low; @@ -12,6 +12,26 @@ pub fn tr(high: &Series, low: &Series, close: &Series) -> Series< ) } +pub fn wtr(high: &Price, low: &Price, close: &Price) -> Price { + let prev_close = close.shift(1); + let diff = high - low; + + iff!( + high.shift(1).na(), + diff, + diff.max(&(high - &prev_close)) + .max(&(low.negate() + &prev_close)) + ) +} + +pub fn atr(high: &Price, low: &Price, close: &Price, smooth: Smooth, period: Period) -> Price { + tr(high, low, close).smooth(smooth, period) +} + +pub fn snatr(atr: &Price, period: Period, smooth: Smooth, period_smooth: Period) -> Price { + atr.normalize(period, SCALE).smooth(smooth, period_smooth) +} + #[cfg(test)] mod tests { use super::*; @@ -49,7 +69,111 @@ mod tests { 0.022799969, ]; - let result: Vec = tr(&high, &low, &close).into(); + let result: Vec = tr(&high, &low, &close).into(); + + assert_eq!(result.len(), close.len()); + assert_eq!(result, expected); + } + + #[test] + fn test_wtrue_range() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let expected = vec![ + 0.01819992, + 0.06549978, + 0.064100266, + 0.07590008, + 0.041800022, + 0.056900024, + 0.050400257, + 0.017799854, + 0.029099941, + 0.025099754, + 0.03429985, + 0.050499916, + 0.02519989, + 0.07379961, + 0.023900032, + 0.022799969, + ]; + + let result: Vec = wtr(&high, &low, &close).into(); + + assert_eq!(result.len(), close.len()); + assert_eq!(result, expected); + } + + #[test] + fn test_atr_smma() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let period = 3; + let expected = [ + 0.01819992, + 0.03396654, + 0.044011116, + 0.05464077, + 0.05036052, + 0.05254035, + 0.051826984, + 0.040484603, + 0.036689714, + 0.032826394, + 0.033317544, + 0.039045, + 0.03442996, + 0.047553174, + 0.03966879, + 0.03404585, + ]; + + let result: Vec = atr(&high, &low, &close, Smooth::SMMA, period).into(); + + assert_eq!(result, expected); + } + + #[test] + fn test_snatr() { + use crate::atr; + + let high = Series::from([ + 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, + ]); + let low = Series::from([ + 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, + ]); + let close = Series::from([ + 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, + ]); + let atr_period = 3; + let atr = atr(&high, &low, &close, Smooth::SMMA, atr_period); + let period = 3; + let expected = [ + 0.0, 0.0, 50.0, 82.575455, 99.49475, 99.747375, 100.0, 90.14032, 55.201355, + ]; + + let result: Vec = snatr(&atr, atr_period, Smooth::WMA, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/volatility/src/yz.rs b/ta_lib/indicators/volatility/src/yz.rs new file mode 100644 index 00000000..bb1f7a60 --- /dev/null +++ b/ta_lib/indicators/volatility/src/yz.rs @@ -0,0 +1,40 @@ +use crate::rs; +use core::prelude::*; + +pub fn yz(open: &Price, high: &Price, low: &Price, close: &Price, period: Period) -> Price { + let oc = (open / close.shift(1).nz(Some(ZERO))).log(); + let ochat = oc.ma(period); + + let co = (close / open).log(); + let cohat = co.ma(period); + + let factor = 1. / (period - 1) as Scalar; + + let ov = factor * (oc - ochat).pow(2).sum(period); + let oc = factor * (co - cohat).pow(2).sum(period); + + let k = 0.34 / (1.34 + (period + 1) as Scalar / (period - 1) as Scalar); + let rs = rs(open, high, low, close, period).pow(2); + + (ov + k * oc + (1. - k) * rs).sqrt() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_yang_zhang() { + let open = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let high = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let low = Series::from([3.0, 2.0, 3.0, 4.0, 5.0]); + let close = Series::from([1.0, 2.0, 3.0, 4.0, 5.0]); + let period = 3; + + let expected = vec![0.0, 0.64916766, 0.649761, 0.27574953, 0.13916442]; + + let result: Vec = yz(&open, &high, &low, &close, period).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/indicators/volume/src/cmf.rs b/ta_lib/indicators/volume/src/cmf.rs index 78a8711d..506a2343 100644 --- a/ta_lib/indicators/volume/src/cmf.rs +++ b/ta_lib/indicators/volume/src/cmf.rs @@ -1,12 +1,6 @@ use core::prelude::*; -pub fn cmf( - high: &Series, - low: &Series, - close: &Series, - volume: &Series, - period: usize, -) -> Series { +pub fn cmf(high: &Price, low: &Price, close: &Price, volume: &Price, period: Period) -> Price { let mfv = iff!( (close.seq(high) & close.seq(low)) | high.seq(low), Series::zero(close.len()), @@ -47,7 +41,7 @@ mod tests { 0.32079986, ]; - let result: Vec = cmf(&high, &low, &close, &volume, period).into(); + let result: Vec = cmf(&high, &low, &close, &volume, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/volume/src/eom.rs b/ta_lib/indicators/volume/src/eom.rs index 7004ff18..32881439 100644 --- a/ta_lib/indicators/volume/src/eom.rs +++ b/ta_lib/indicators/volume/src/eom.rs @@ -1,15 +1,14 @@ use core::prelude::*; pub fn eom( - hl2: &Series, - high: &Series, - low: &Series, - volume: &Series, - smooth_type: Smooth, - period: usize, - divisor: f32, -) -> Series { - (divisor * hl2.change(1) * (high - low) / volume).smooth(smooth_type, period) + hl2: &Price, + high: &Price, + low: &Price, + volume: &Price, + smooth: Smooth, + period: Period, +) -> Price { + (SCALE * SCALE * hl2.change(1) * (high - low) / volume).smooth(smooth, period) } #[cfg(test)] @@ -33,11 +32,10 @@ mod tests { 0.00023862119, ]; - let hlc = median_price(&high, &low); + let hl2 = median_price(&high, &low); let period = 2; - let divisor = 10000.0; - let result: Vec = eom(&hlc, &high, &low, &volume, Smooth::SMA, period, divisor).into(); + let result: Vec = eom(&hl2, &high, &low, &volume, Smooth::SMA, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/volume/src/lib.rs b/ta_lib/indicators/volume/src/lib.rs index f80d867d..2b492cf8 100644 --- a/ta_lib/indicators/volume/src/lib.rs +++ b/ta_lib/indicators/volume/src/lib.rs @@ -3,7 +3,6 @@ mod eom; mod mfi; mod nvol; mod obv; -mod vo; mod vwap; pub use cmf::cmf; @@ -11,5 +10,4 @@ pub use eom::eom; pub use mfi::mfi; pub use nvol::nvol; pub use obv::obv; -pub use vo::vo; pub use vwap::vwap; diff --git a/ta_lib/indicators/volume/src/mfi.rs b/ta_lib/indicators/volume/src/mfi.rs index e46afddf..c3b99f08 100644 --- a/ta_lib/indicators/volume/src/mfi.rs +++ b/ta_lib/indicators/volume/src/mfi.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn mfi(hlc3: &Series, volume: &Series, period: usize) -> Series { +pub fn mfi(hlc3: &Price, volume: &Price, period: Period) -> Price { let changes = hlc3.change(1); let volume_hlc3 = volume * hlc3; @@ -13,7 +13,7 @@ pub fn mfi(hlc3: &Series, volume: &Series, period: usize) -> Series = mfi(&hlc3, &volume, period).into(); + let result: Vec = mfi(&hlc3, &volume, period).into(); for i in 0..hlc3.len() { assert!( diff --git a/ta_lib/indicators/volume/src/nvol.rs b/ta_lib/indicators/volume/src/nvol.rs index 74a4783c..d7e73e97 100644 --- a/ta_lib/indicators/volume/src/nvol.rs +++ b/ta_lib/indicators/volume/src/nvol.rs @@ -1,7 +1,7 @@ use core::prelude::*; -pub fn nvol(volume: &Series, smooth_type: Smooth, period: usize) -> Series { - SCALE * volume / volume.smooth(smooth_type, period) +pub fn nvol(volume: &Price, smooth: Smooth, period: Period) -> Price { + SCALE * volume / volume.smooth(smooth, period) } #[cfg(test)] @@ -15,7 +15,7 @@ mod tests { let expected = [100.0, 23.115578, 132.68292, 72.897194, 163.8051, 28.640778]; - let result: Vec = nvol(&volume, Smooth::SMA, period).into(); + let result: Vec = nvol(&volume, Smooth::SMA, period).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/volume/src/obv.rs b/ta_lib/indicators/volume/src/obv.rs index 70e5a071..468b8f27 100644 --- a/ta_lib/indicators/volume/src/obv.rs +++ b/ta_lib/indicators/volume/src/obv.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn obv(source: &Series, volume: &Series) -> Series { +pub fn obv(source: &Price, volume: &Price) -> Price { (source.change(1).nz(Some(ZERO)).sign() * volume).cumsum() } @@ -22,7 +22,7 @@ mod tests { 3798.0, 9213.0, 16323.0, 18495.0, 11113.0, 13868.0, 15998.0, 37986.0, 47427.0, ]; - let result: Vec = obv(&close, &volume).into(); + let result: Vec = obv(&close, &volume).into(); assert_eq!(result, expected); } diff --git a/ta_lib/indicators/volume/src/vo.rs b/ta_lib/indicators/volume/src/vo.rs deleted file mode 100644 index 60beb650..00000000 --- a/ta_lib/indicators/volume/src/vo.rs +++ /dev/null @@ -1,37 +0,0 @@ -use core::prelude::*; - -pub fn vo( - source: &Series, - smooth_type: Smooth, - fast_period: usize, - slow_period: usize, -) -> Series { - let vo_short = source.smooth(smooth_type, fast_period); - let vo_long = source.smooth(smooth_type, slow_period); - - SCALE * (vo_short - &vo_long) / vo_long -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_vo() { - let source = Series::from([1.0, 2.0, 3.0, 2.0, 1.0]); - let expected = [0.0, 11.1111, 13.5802, 2.83224, -10.71604]; - let epsilon = 0.001; - - let result: Vec = vo(&source, Smooth::EMA, 2, 3).into(); - - for i in 0..source.len() { - assert!( - (result[i] - expected[i]).abs() < epsilon, - "at position {}: {} != {}", - i, - result[i], - expected[i] - ) - } - } -} diff --git a/ta_lib/indicators/volume/src/vwap.rs b/ta_lib/indicators/volume/src/vwap.rs index 6a6bd332..0f6b2579 100644 --- a/ta_lib/indicators/volume/src/vwap.rs +++ b/ta_lib/indicators/volume/src/vwap.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn vwap(source: &Series, volume: &Series) -> Series { +pub fn vwap(source: &Price, volume: &Price) -> Price { (source * volume).cumsum() / volume.cumsum() } @@ -19,7 +19,7 @@ mod tests { let expected = [1.5, 2.5, 3.5]; let epsilon = 0.001; - let result: Vec = vwap(&hlc3, &volume).into(); + let result: Vec = vwap(&hlc3, &volume).into(); for i in 0..high.len() { assert!( diff --git a/ta_lib/patterns/bands/Cargo.toml b/ta_lib/patterns/bands/Cargo.toml new file mode 100644 index 00000000..a8b44b67 --- /dev/null +++ b/ta_lib/patterns/bands/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "bands" +resolver = "2" + +authors.workspace = true +edition.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] diff --git a/ta_lib/patterns/bands/src/lib.rs b/ta_lib/patterns/bands/src/lib.rs new file mode 100644 index 00000000..eda363d6 --- /dev/null +++ b/ta_lib/patterns/bands/src/lib.rs @@ -0,0 +1 @@ +pub mod macros; diff --git a/ta_lib/patterns/bands/src/macros.rs b/ta_lib/patterns/bands/src/macros.rs new file mode 100644 index 00000000..b3391670 --- /dev/null +++ b/ta_lib/patterns/bands/src/macros.rs @@ -0,0 +1,53 @@ +#[macro_export] +macro_rules! a { + ($upper_band:expr, $lower_band:expr, $source:expr) => {{ + let prev_source = $source.shift(1); + + ( + $source.slt(&$lower_band) & prev_source.sgt(&$lower_band.shift(1)), + $source.sgt(&$upper_band) & prev_source.slt(&$upper_band.shift(1)), + ) + }}; +} + +#[macro_export] +macro_rules! c { + ($upper_band:expr, $middle_band:expr, $lower_band:expr, $source:expr) => {{ + let prev_source = $source.shift(1); + + ( + $source.sgt(&$lower_band) + & prev_source.slt(&$lower_band.shift(1)) + & $source.slt(&$middle_band), + $source.slt(&$upper_band) + & prev_source.sgt(&$upper_band.shift(1)) + & $source.sgt(&$middle_band), + ) + }}; +} + +#[macro_export] +macro_rules! r { + ($upper_band:expr, $lower_band:expr, $source:expr) => {{ + let prev_source = $source.shift(1); + let back_2_source = $source.shift(2); + let back_3_source = $source.shift(3); + let back_4_source = $source.shift(4); + let back_5_source = $source.shift(5); + + ( + $source.sgt(&$lower_band) + & prev_source.slt(&$lower_band.shift(1)) + & back_2_source.slt(&$lower_band.shift(2)) + & back_3_source.slt(&$lower_band.shift(3)) + & back_4_source.slt(&$lower_band.shift(4)) + & back_5_source.slt(&$lower_band.shift(5)), + $source.slt(&$upper_band) + & prev_source.sgt(&$upper_band.shift(1)) + & back_2_source.sgt(&$upper_band.shift(2)) + & back_3_source.sgt(&$upper_band.shift(3)) + & back_4_source.sgt(&$upper_band.shift(4)) + & back_5_source.sgt(&$upper_band.shift(5)), + ) + }}; +} diff --git a/ta_lib/patterns/candlestick/src/barrier.rs b/ta_lib/patterns/candlestick/src/barrier.rs index c252a430..a190b422 100644 --- a/ta_lib/patterns/candlestick/src/barrier.rs +++ b/ta_lib/patterns/candlestick/src/barrier.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, low: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, low: &Price, close: &Price) -> Rule { let back_2_low = low.shift(2); close.shift(1).sgt(&open.shift(1)) @@ -10,7 +10,7 @@ pub fn bullish(open: &Series, low: &Series, close: &Series) -> Se & back_2_low.seq(&low.shift(3)) } -pub fn bearish(open: &Series, high: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, high: &Price, close: &Price) -> Rule { let back_2_high = high.shift(2); close.shift(1).slt(&open.shift(1)) diff --git a/ta_lib/patterns/candlestick/src/blockade.rs b/ta_lib/patterns/candlestick/src/blockade.rs index 1ceb8755..f2138e39 100644 --- a/ta_lib/patterns/candlestick/src/blockade.rs +++ b/ta_lib/patterns/candlestick/src/blockade.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let close_4_back = close.shift(4); @@ -29,12 +24,7 @@ pub fn bullish( & high.shift(3).slt(&high_4_back) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let close_4_back = close.shift(4); diff --git a/ta_lib/patterns/candlestick/src/bottle.rs b/ta_lib/patterns/candlestick/src/bottle.rs index ee3ff368..bb87a054 100644 --- a/ta_lib/patterns/candlestick/src/bottle.rs +++ b/ta_lib/patterns/candlestick/src/bottle.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, low: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let back_2_close = close.shift(2); @@ -13,7 +13,7 @@ pub fn bullish(open: &Series, low: &Series, close: &Series) -> Se & prev_close.sgt(&back_2_close) } -pub fn bearish(open: &Series, high: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, high: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let back_2_close = close.shift(2); diff --git a/ta_lib/patterns/candlestick/src/breakaway.rs b/ta_lib/patterns/candlestick/src/breakaway.rs index 57c9c75a..6e5ba427 100644 --- a/ta_lib/patterns/candlestick/src/breakaway.rs +++ b/ta_lib/patterns/candlestick/src/breakaway.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { let back_3_open = open.shift(3); let back_4_close = close.shift(4); @@ -13,7 +13,7 @@ pub fn bullish(open: &Series, close: &Series) -> Series { & back_3_open.slt(&back_4_close) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { let back_3_open = open.shift(3); let back_4_close = close.shift(4); diff --git a/ta_lib/patterns/candlestick/src/counterattack.rs b/ta_lib/patterns/candlestick/src/counterattack.rs index 2e5f1879..ebac5feb 100644 --- a/ta_lib/patterns/candlestick/src/counterattack.rs +++ b/ta_lib/patterns/candlestick/src/counterattack.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); open.slt(&prev_close) @@ -9,7 +9,7 @@ pub fn bullish(open: &Series, close: &Series) -> Series { & close.seq(&prev_close) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); open.sgt(&prev_close) diff --git a/ta_lib/patterns/candlestick/src/doji.rs b/ta_lib/patterns/candlestick/src/doji.rs index 1854ed04..f784578c 100644 --- a/ta_lib/patterns/candlestick/src/doji.rs +++ b/ta_lib/patterns/candlestick/src/doji.rs @@ -1,10 +1,10 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { close.sgt(open) & close.shift(1).seq(&open.shift(1)) & close.shift(2).slt(&open.shift(2)) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { close.slt(open) & close.shift(1).seq(&open.shift(1)) & close.shift(2).sgt(&open.shift(2)) } diff --git a/ta_lib/patterns/candlestick/src/doji_double.rs b/ta_lib/patterns/candlestick/src/doji_double.rs index b3f5d20e..0eff8a4a 100644 --- a/ta_lib/patterns/candlestick/src/doji_double.rs +++ b/ta_lib/patterns/candlestick/src/doji_double.rs @@ -1,12 +1,12 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { close.shift(1).seq(&open.shift(1)) & close.shift(2).seq(&open.shift(2)) & close.shift(3).slt(&open.shift(3)) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { close.shift(1).seq(&open.shift(1)) & close.shift(2).seq(&open.shift(2)) & close.shift(3).sgt(&open.shift(3)) diff --git a/ta_lib/patterns/candlestick/src/doppelganger.rs b/ta_lib/patterns/candlestick/src/doppelganger.rs index 4fef3476..169e9f15 100644 --- a/ta_lib/patterns/candlestick/src/doppelganger.rs +++ b/ta_lib/patterns/candlestick/src/doppelganger.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let back_2_close = close.shift(2); @@ -23,12 +18,7 @@ pub fn bullish( .seq(&back_2_close.min(&back_2_open)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let back_2_close = close.shift(2); diff --git a/ta_lib/patterns/candlestick/src/double_trouble.rs b/ta_lib/patterns/candlestick/src/double_trouble.rs index a94e6e1e..c7dd6f7a 100644 --- a/ta_lib/patterns/candlestick/src/double_trouble.rs +++ b/ta_lib/patterns/candlestick/src/double_trouble.rs @@ -1,12 +1,7 @@ use core::prelude::*; use volatility::atr; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let atr = atr(high, low, close, Smooth::SMMA, 10); let prev_close = close.shift(1); @@ -16,12 +11,7 @@ pub fn bullish( & (close - open).sgt(&(2.0 * atr.shift(1))) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let atr = atr(high, low, close, Smooth::SMMA, 10); let prev_close = close.shift(1); diff --git a/ta_lib/patterns/candlestick/src/engulfing.rs b/ta_lib/patterns/candlestick/src/engulfing.rs index 5c40d4f6..724d8240 100644 --- a/ta_lib/patterns/candlestick/src/engulfing.rs +++ b/ta_lib/patterns/candlestick/src/engulfing.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let body = (close - open).abs(); close.sgt(open) @@ -15,12 +10,7 @@ pub fn bullish( & body.sgte(&(2.0 * body.shift(1))) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let body = (close - open).abs(); close.slt(open) diff --git a/ta_lib/patterns/candlestick/src/euphoria.rs b/ta_lib/patterns/candlestick/src/euphoria.rs index 28439f70..7fb4a310 100644 --- a/ta_lib/patterns/candlestick/src/euphoria.rs +++ b/ta_lib/patterns/candlestick/src/euphoria.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { let body = (close - open).abs(); let back_2_body = body.shift(2); @@ -12,7 +12,7 @@ pub fn bullish(open: &Series, close: &Series) -> Series { & back_2_body.sgt(&body.shift(3)) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { let body = (close - open).abs(); let back_2_body = body.shift(2); diff --git a/ta_lib/patterns/candlestick/src/euphoria_extreme.rs b/ta_lib/patterns/candlestick/src/euphoria_extreme.rs index 2b33118d..dfcfa019 100644 --- a/ta_lib/patterns/candlestick/src/euphoria_extreme.rs +++ b/ta_lib/patterns/candlestick/src/euphoria_extreme.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { let body = (close - open).abs(); let back_2_body = body.shift(2); @@ -16,7 +16,7 @@ pub fn bullish(open: &Series, close: &Series) -> Series { & back_3_body.sgt(&body.shift(4)) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { let body = (close - open).abs(); let back_2_body = body.shift(2); diff --git a/ta_lib/patterns/candlestick/src/golden.rs b/ta_lib/patterns/candlestick/src/golden.rs index 8aee3c79..d4b5a644 100644 --- a/ta_lib/patterns/candlestick/src/golden.rs +++ b/ta_lib/patterns/candlestick/src/golden.rs @@ -1,25 +1,15 @@ use core::prelude::*; -const GOLDEN_RATIO: f32 = 2.618; - -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +const GOLDEN_RATIO: Scalar = 2.618; + +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let back_2_low = low.shift(2); let golden_low = &back_2_low + GOLDEN_RATIO * (high.shift(2) - &back_2_low); low.slte(&open.shift(1)) & close.shift(1).sgt(&golden_low) & close.shift(2).sgt(&open.shift(2)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let back_2_high = high.shift(2); let golden_high = &back_2_high - GOLDEN_RATIO * (&back_2_high - low.shift(2)); diff --git a/ta_lib/patterns/candlestick/src/h.rs b/ta_lib/patterns/candlestick/src/h.rs index 2c155cbe..435f58cd 100644 --- a/ta_lib/patterns/candlestick/src/h.rs +++ b/ta_lib/patterns/candlestick/src/h.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); prev_close.sgt(&open.shift(1)) @@ -15,12 +10,7 @@ pub fn bullish( & low.shift(1).sgte(&low.shift(2)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); prev_close.slt(&open.shift(1)) diff --git a/ta_lib/patterns/candlestick/src/hammer.rs b/ta_lib/patterns/candlestick/src/hammer.rs index 15dc5097..1cff088a 100644 --- a/ta_lib/patterns/candlestick/src/hammer.rs +++ b/ta_lib/patterns/candlestick/src/hammer.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, high: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, high: &Price, close: &Price) -> Rule { let body = (close - open).abs(); let back_2_body = body.shift(2); @@ -13,7 +13,7 @@ pub fn bullish(open: &Series, high: &Series, close: &Series) -> S & back_2_body.slt(&body.shift(3)) } -pub fn bearish(open: &Series, low: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, low: &Price, close: &Price) -> Rule { let body = (close - open).abs(); let back_2_body = body.shift(2); diff --git a/ta_lib/patterns/candlestick/src/harami_flexible.rs b/ta_lib/patterns/candlestick/src/harami_flexible.rs index 0acc2fcb..cb3c6596 100644 --- a/ta_lib/patterns/candlestick/src/harami_flexible.rs +++ b/ta_lib/patterns/candlestick/src/harami_flexible.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); @@ -18,12 +13,7 @@ pub fn bullish( & close.shift(2).slt(&open.shift(2)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); diff --git a/ta_lib/patterns/candlestick/src/harami_strict.rs b/ta_lib/patterns/candlestick/src/harami_strict.rs index e99b0723..17bcedb9 100644 --- a/ta_lib/patterns/candlestick/src/harami_strict.rs +++ b/ta_lib/patterns/candlestick/src/harami_strict.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); @@ -16,12 +11,7 @@ pub fn bullish( & close.shift(2).slt(&open.shift(2)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); diff --git a/ta_lib/patterns/candlestick/src/hexad.rs b/ta_lib/patterns/candlestick/src/hexad.rs index 0c8985ff..88b7f7b1 100644 --- a/ta_lib/patterns/candlestick/src/hexad.rs +++ b/ta_lib/patterns/candlestick/src/hexad.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, high: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, high: &Price, close: &Price) -> Rule { close.sgt(open) & close.shift(1).sgt(&open.shift(1)) & close.shift(2).sgt(&open.shift(2)) @@ -10,7 +10,7 @@ pub fn bullish(open: &Series, high: &Series, close: &Series) -> S & close.sgt(&high.shift(5)) } -pub fn bearish(open: &Series, low: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, low: &Price, close: &Price) -> Rule { close.slt(open) & close.shift(1).slt(&open.shift(1)) & close.shift(2).slt(&open.shift(2)) diff --git a/ta_lib/patterns/candlestick/src/hikkake.rs b/ta_lib/patterns/candlestick/src/hikkake.rs index 12c3d689..7087f95c 100644 --- a/ta_lib/patterns/candlestick/src/hikkake.rs +++ b/ta_lib/patterns/candlestick/src/hikkake.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let back_4_close = close.shift(4); let back_3_high = high.shift(3); @@ -22,12 +17,7 @@ pub fn bullish( & back_4_close.sgt(&open.shift(4)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let back_4_close = close.shift(4); let back_3_low = low.shift(3); diff --git a/ta_lib/patterns/candlestick/src/kangaroo_tail.rs b/ta_lib/patterns/candlestick/src/kangaroo_tail.rs index 1ca27737..81b84d29 100644 --- a/ta_lib/patterns/candlestick/src/kangaroo_tail.rs +++ b/ta_lib/patterns/candlestick/src/kangaroo_tail.rs @@ -1,13 +1,8 @@ use core::prelude::*; -const RANGE_RATIO: f32 = 0.66; +const RANGE_RATIO: Scalar = 0.66; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let range = high - low; let two_third_low_range = low + &range * RANGE_RATIO; let prev_low = low.shift(1); @@ -27,12 +22,7 @@ pub fn bullish( & low.slte(&low.lowest(13)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let range = high - low; let two_third_high_range = high - &range * RANGE_RATIO; let prev_low = low.shift(1); diff --git a/ta_lib/patterns/candlestick/src/lib.rs b/ta_lib/patterns/candlestick/src/lib.rs index ffb4fbbd..3d7f06b9 100644 --- a/ta_lib/patterns/candlestick/src/lib.rs +++ b/ta_lib/patterns/candlestick/src/lib.rs @@ -23,6 +23,7 @@ pub mod master_candle; pub mod on_neck; pub mod piercing; pub mod quintuplets; +pub mod r; pub mod shrinking; pub mod slingshot; pub mod split; @@ -30,3 +31,4 @@ pub mod tasuki; pub mod three_candles; pub mod three_methods; pub mod three_one_two; +pub mod tweezers; diff --git a/ta_lib/patterns/candlestick/src/marubozu.rs b/ta_lib/patterns/candlestick/src/marubozu.rs index 081da9cf..c3f2f568 100644 --- a/ta_lib/patterns/candlestick/src/marubozu.rs +++ b/ta_lib/patterns/candlestick/src/marubozu.rs @@ -1,23 +1,13 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); prev_close.sgt(&prev_open) & high.shift(1).seq(&prev_close) & low.shift(1).seq(&prev_open) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); diff --git a/ta_lib/patterns/candlestick/src/master_candle.rs b/ta_lib/patterns/candlestick/src/master_candle.rs index 0cecbac8..015fb84f 100644 --- a/ta_lib/patterns/candlestick/src/master_candle.rs +++ b/ta_lib/patterns/candlestick/src/master_candle.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let back_6_high = high.shift(6); let back_6_low = low.shift(6); @@ -23,12 +18,7 @@ pub fn bullish( & low.shift(5).sgt(&back_6_low) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let back_6_high = high.shift(6); let back_6_low = low.shift(6); diff --git a/ta_lib/patterns/candlestick/src/on_neck.rs b/ta_lib/patterns/candlestick/src/on_neck.rs index 898efa00..485ff9b3 100644 --- a/ta_lib/patterns/candlestick/src/on_neck.rs +++ b/ta_lib/patterns/candlestick/src/on_neck.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, high: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, high: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); close.seq(&prev_close) @@ -10,7 +10,7 @@ pub fn bullish(open: &Series, high: &Series, close: &Series) -> S & close.sgt(open) } -pub fn bearish(open: &Series, low: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); close.seq(&prev_close) diff --git a/ta_lib/patterns/candlestick/src/piercing.rs b/ta_lib/patterns/candlestick/src/piercing.rs index 0da70dac..2cd47ad8 100644 --- a/ta_lib/patterns/candlestick/src/piercing.rs +++ b/ta_lib/patterns/candlestick/src/piercing.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); @@ -11,7 +11,7 @@ pub fn bullish(open: &Series, close: &Series) -> Series { & close.sgte(&(0.5 * (prev_close + prev_open))) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); diff --git a/ta_lib/patterns/candlestick/src/quintuplets.rs b/ta_lib/patterns/candlestick/src/quintuplets.rs index 295643a3..358996f1 100644 --- a/ta_lib/patterns/candlestick/src/quintuplets.rs +++ b/ta_lib/patterns/candlestick/src/quintuplets.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { let body = (open - close).abs(); let prev_close = close.shift(1); @@ -29,7 +29,7 @@ pub fn bullish(open: &Series, close: &Series) -> Series { & back_4_body.slt(&body.shift(5)) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { let body = (open - close).abs(); let prev_close = close.shift(1); diff --git a/ta_lib/patterns/candlestick/src/r.rs b/ta_lib/patterns/candlestick/src/r.rs new file mode 100644 index 00000000..4debc7c3 --- /dev/null +++ b/ta_lib/patterns/candlestick/src/r.rs @@ -0,0 +1,56 @@ +use core::prelude::*; + +pub fn bullish(low: &Price, close: &Price) -> Rule { + let prev_close = close.shift(1); + let back_2_close = close.shift(2); + let prev_low = low.shift(1); + let back_2_low = low.shift(2); + + low.sgt(&prev_low) + & prev_low.sgt(&back_2_low) + & back_2_low.slt(&low.shift(3)) + & close.sgt(&prev_close) + & prev_close.sgt(&back_2_close) + & back_2_close.sgt(&close.shift(3)) +} + +pub fn bearish(high: &Price, close: &Price) -> Rule { + let prev_close = close.shift(1); + let back_2_close = close.shift(2); + let prev_high = high.shift(1); + let back_2_high = high.shift(2); + + high.slt(&prev_high) + & prev_high.slt(&back_2_high) + & back_2_high.sgt(&high.shift(3)) + & close.slt(&prev_close) + & prev_close.slt(&back_2_close) + & back_2_close.slt(&close.shift(3)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_r_bullish() { + let low = Series::from([0.4818, 0.4815, 0.4812, 0.4836, 0.4850]); + let close = Series::from([0.4822, 0.4818, 0.4837, 0.4856, 0.4888]); + let expected = vec![false, false, false, false, true]; + + let result: Vec = bullish(&low, &close).into(); + + assert_eq!(result, expected); + } + + #[test] + fn test_r_bearish() { + let high = Series::from([0.4802, 0.4807, 0.4808, 0.4796, 0.4791]); + let close = Series::from([0.4801, 0.4799, 0.4794, 0.4785, 0.4783]); + let expected = vec![false, false, false, false, true]; + + let result: Vec = bearish(&high, &close).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/patterns/candlestick/src/shrinking.rs b/ta_lib/patterns/candlestick/src/shrinking.rs index e1d10416..900fdff7 100644 --- a/ta_lib/patterns/candlestick/src/shrinking.rs +++ b/ta_lib/patterns/candlestick/src/shrinking.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_high = high.shift(1); let back_2_high = high.shift(2); let back_3_high = high.shift(3); @@ -23,12 +18,7 @@ pub fn bullish( & close.shift(4).slt(&open.shift(4)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_low = low.shift(1); let back_2_low = low.shift(2); let back_3_low = low.shift(3); diff --git a/ta_lib/patterns/candlestick/src/slingshot.rs b/ta_lib/patterns/candlestick/src/slingshot.rs index 39b32ff9..064796a6 100644 --- a/ta_lib/patterns/candlestick/src/slingshot.rs +++ b/ta_lib/patterns/candlestick/src/slingshot.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let close_3_back = close.shift(3); let close_4_back = close.shift(4); @@ -20,12 +15,7 @@ pub fn bullish( & low.shift(1).sgt(&close_4_back) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let back_3_close = close.shift(3); let back_4_close = close.shift(4); diff --git a/ta_lib/patterns/candlestick/src/split.rs b/ta_lib/patterns/candlestick/src/split.rs index 6321d987..adac3ea6 100644 --- a/ta_lib/patterns/candlestick/src/split.rs +++ b/ta_lib/patterns/candlestick/src/split.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); @@ -18,12 +13,7 @@ pub fn bullish( & close.sgt(&prev_open) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_close = close.shift(1); let prev_open = open.shift(1); diff --git a/ta_lib/patterns/candlestick/src/tasuki.rs b/ta_lib/patterns/candlestick/src/tasuki.rs index db90f3d6..6a936c4c 100644 --- a/ta_lib/patterns/candlestick/src/tasuki.rs +++ b/ta_lib/patterns/candlestick/src/tasuki.rs @@ -1,8 +1,7 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { let prev_open = open.shift(1); - let back_2_close = close.shift(2); close.slt(open) @@ -13,9 +12,8 @@ pub fn bullish(open: &Series, close: &Series) -> Series { & prev_open.sgt(&back_2_close) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { let prev_open = open.shift(1); - let back_2_close = close.shift(2); close.sgt(open) diff --git a/ta_lib/patterns/candlestick/src/three_candles.rs b/ta_lib/patterns/candlestick/src/three_candles.rs index 3c81f83e..da4e4e75 100644 --- a/ta_lib/patterns/candlestick/src/three_candles.rs +++ b/ta_lib/patterns/candlestick/src/three_candles.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn bullish(open: &Series, close: &Series) -> Series { +pub fn bullish(open: &Price, close: &Price) -> Rule { let body = (open - close).abs(); let prev_close = close.shift(1); @@ -17,7 +17,7 @@ pub fn bullish(open: &Series, close: &Series) -> Series { & back_2_body.sgte(&back_2_body.highest(5)) } -pub fn bearish(open: &Series, close: &Series) -> Series { +pub fn bearish(open: &Price, close: &Price) -> Rule { let body = (open - close).abs(); let prev_close = close.shift(1); diff --git a/ta_lib/patterns/candlestick/src/three_methods.rs b/ta_lib/patterns/candlestick/src/three_methods.rs index 974ecaae..0a62268f 100644 --- a/ta_lib/patterns/candlestick/src/three_methods.rs +++ b/ta_lib/patterns/candlestick/src/three_methods.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_low = low.shift(1); let back_4_low = low.shift(4); @@ -23,12 +18,7 @@ pub fn bullish( & back_4_close.sgt(&open.shift(4)) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_high = high.shift(1); let back_4_high = high.shift(4); diff --git a/ta_lib/patterns/candlestick/src/three_one_two.rs b/ta_lib/patterns/candlestick/src/three_one_two.rs index 5b122aa6..0612b3a4 100644 --- a/ta_lib/patterns/candlestick/src/three_one_two.rs +++ b/ta_lib/patterns/candlestick/src/three_one_two.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn bullish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bullish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_high = high.shift(1); let back_2_high = high.shift(2); @@ -19,12 +14,7 @@ pub fn bullish( & high.sgt(&prev_high) } -pub fn bearish( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn bearish(open: &Price, high: &Price, low: &Price, close: &Price) -> Rule { let prev_low = low.shift(1); let back_2_low = low.shift(2); diff --git a/ta_lib/patterns/candlestick/src/tweezers.rs b/ta_lib/patterns/candlestick/src/tweezers.rs new file mode 100644 index 00000000..50f2594b --- /dev/null +++ b/ta_lib/patterns/candlestick/src/tweezers.rs @@ -0,0 +1,60 @@ +use core::prelude::*; + +const RANGE: Scalar = 0.0005; + +pub fn bullish(open: &Price, low: &Price, close: &Price) -> Rule { + let prev_close = close.shift(1); + let prev_open = open.shift(1); + let body = (close - open).abs(); + let prev_body = body.shift(1); + + close.sgt(open) + & low.seq(&low.shift(1)) + & body.slt(&RANGE) + & prev_body.slt(&RANGE) + & prev_close.slt(&prev_open) + & close.shift(2).slt(&open.shift(1)) +} + +pub fn bearish(open: &Price, high: &Price, close: &Price) -> Rule { + let prev_close = close.shift(1); + let prev_open = open.shift(1); + let body = (close - open).abs(); + let prev_body = body.shift(1); + + close.slt(open) + & high.seq(&high.shift(1)) + & body.slt(&RANGE) + & prev_body.slt(&RANGE) + & prev_close.sgt(&prev_open) + & close.shift(2).sgt(&open.shift(1)) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_tweezers_bullish() { + let open = Series::from([4.0, 3.0, 4.0, 3.0, 4.0]); + let low = Series::from([4.0, 3.0, 4.0, 3.0, 4.0]); + let close = Series::from([4.5, 3.5, 4.5, 3.5, 4.5]); + let expected = vec![false, false, false, false, false]; + + let result: Vec = bullish(&open, &low, &close).into(); + + assert_eq!(result, expected); + } + + #[test] + fn test_tweezers_bearish() { + let open = Series::from([4.0, 3.0, 4.0, 3.0, 4.0]); + let high = Series::from([4.0, 3.0, 4.0, 3.0, 4.0]); + let close = Series::from([3.5, 2.5, 3.5, 2.5, 3.5]); + let expected = vec![false, false, false, false, false]; + + let result: Vec = bearish(&open, &high, &close).into(); + + assert_eq!(result, expected); + } +} diff --git a/ta_lib/patterns/channel/Cargo.toml b/ta_lib/patterns/channel/Cargo.toml new file mode 100644 index 00000000..d22726e3 --- /dev/null +++ b/ta_lib/patterns/channel/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "channel" +resolver = "2" + +authors.workspace = true +edition.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] diff --git a/ta_lib/patterns/channel/src/lib.rs b/ta_lib/patterns/channel/src/lib.rs new file mode 100644 index 00000000..eda363d6 --- /dev/null +++ b/ta_lib/patterns/channel/src/lib.rs @@ -0,0 +1 @@ +pub mod macros; diff --git a/ta_lib/patterns/channel/src/macros.rs b/ta_lib/patterns/channel/src/macros.rs new file mode 100644 index 00000000..48ef25a3 --- /dev/null +++ b/ta_lib/patterns/channel/src/macros.rs @@ -0,0 +1,24 @@ +#[macro_export] +macro_rules! a { + ($source:expr, $upper_channel:expr, $lower_channel:expr) => {{ + let prev_source = $source.shift(1); + + ( + $source.sgt(&$lower_channel) & prev_source.slt(&$lower_channel.shift(1)), + $source.slt(&$upper_channel) & prev_source.sgt(&$upper_channel.shift(1)), + ) + }}; +} + +#[macro_export] +macro_rules! c { + ($low:expr, $high:expr, $upper_channel:expr, $lower_channel:expr) => {{ + let prev_lwch = $lower_channel.shift(1); + let prev_upch = $upper_channel.shift(1); + + ( + $low.slt(&prev_lwch) & $low.shift(1).sgt(&prev_lwch), + $high.sgt(&prev_upch) & $high.shift(1).slt(&prev_upch), + ) + }}; +} diff --git a/ta_lib/patterns/osc/Cargo.toml b/ta_lib/patterns/osc/Cargo.toml new file mode 100644 index 00000000..510730cc --- /dev/null +++ b/ta_lib/patterns/osc/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "osc" +resolver = "2" + +authors.workspace = true +edition.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] diff --git a/ta_lib/patterns/osc/src/lib.rs b/ta_lib/patterns/osc/src/lib.rs new file mode 100644 index 00000000..eda363d6 --- /dev/null +++ b/ta_lib/patterns/osc/src/lib.rs @@ -0,0 +1 @@ +pub mod macros; diff --git a/ta_lib/patterns/osc/src/macros.rs b/ta_lib/patterns/osc/src/macros.rs new file mode 100644 index 00000000..e9ce7cbb --- /dev/null +++ b/ta_lib/patterns/osc/src/macros.rs @@ -0,0 +1,68 @@ +#[macro_export] +macro_rules! a { + ($osc:expr, $lower_const:expr, $upper_const:expr) => {{ + let prev_osc = $osc.shift(1); + + ( + $osc.slt(&$lower_const) & prev_osc.sgt(&$lower_const), + $osc.sgt(&$upper_const) & prev_osc.slt(&$upper_const), + ) + }}; +} + +#[macro_export] +macro_rules! c { + ($osc:expr, $lower_const:expr, $upper_const:expr) => {{ + let prev_osc = $osc.shift(1); + + ( + $osc.sgt(&$lower_const) & prev_osc.slt(&$lower_const), + $osc.slt(&$upper_const) & prev_osc.sgt(&$upper_const), + ) + }}; +} + +#[macro_export] +macro_rules! v { + ($osc:expr, $lower_const:expr, $upper_const:expr) => {{ + let prev_osc = $osc.shift(1); + let osc_2_back = $osc.shift(2); + + ( + $osc.sgt(&$lower_const) & prev_osc.slt(&$lower_const) & osc_2_back.sgt(&$lower_const), + $osc.slt(&$upper_const) & prev_osc.sgt(&$upper_const) & osc_2_back.slt(&$upper_const), + ) + }}; +} + +#[macro_export] +macro_rules! w { + ($osc:expr, $lower_const:expr, $upper_const:expr) => {{ + let prev_osc = $osc.shift(1); + let osc_2_back = $osc.shift(2); + let osc_3_back = $osc.shift(3); + let osc_4_back = $osc.shift(4); + let osc_5_back = $osc.shift(5); + + ( + $osc.sgt(&$lower_const) + & prev_osc.slt(&$lower_const) + & prev_osc.slt(&osc_2_back) + & osc_2_back.slt(&$lower_cons) + & osc_2_back.sgt(&osc_3_back) + & osc_3_back.slt(&$lower_cons) + & osc_3_back.slt(&osc_4_back) + & osc_4_back.slt(&$lower_cons) + & osc_5_back.sgt(&$lower_cons), + $osc.slt(&$upper_const) + & prev_osc.sgt(&$upper_const) + & prev_osc.sgt(&osc_2_back) + & osc_2_back.sgt(&$upper_const) + & osc_2_back.slt(&osc_3_back) + & osc_3_back.sgt(&$upper_const) + & osc_3_back.sgt(&osc_4_back) + & osc_4_back.sgt(&$upper_const) + & osc_5_back.slt(&$upper_const), + ) + }}; +} diff --git a/ta_lib/patterns/trail/Cargo.toml b/ta_lib/patterns/trail/Cargo.toml new file mode 100644 index 00000000..98a30f67 --- /dev/null +++ b/ta_lib/patterns/trail/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "trail" +resolver = "2" + +authors.workspace = true +edition.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +version.workspace = true + +[dependencies] +core = { path = "../../core" } \ No newline at end of file diff --git a/ta_lib/patterns/trail/src/lib.rs b/ta_lib/patterns/trail/src/lib.rs new file mode 100644 index 00000000..eda363d6 --- /dev/null +++ b/ta_lib/patterns/trail/src/lib.rs @@ -0,0 +1 @@ +pub mod macros; diff --git a/ta_lib/patterns/trail/src/macros.rs b/ta_lib/patterns/trail/src/macros.rs new file mode 100644 index 00000000..f83b8e83 --- /dev/null +++ b/ta_lib/patterns/trail/src/macros.rs @@ -0,0 +1,18 @@ +#[macro_export] +macro_rules! f { + ($direction:expr) => {{ + ($direction.cross_over(&ZERO), $direction.cross_under(&ZERO)) + }}; +} + +#[macro_export] +macro_rules! p { + ($trend:expr, $high:expr, $low:expr, $close:expr) => {{ + let prev_trend = $trend.shift(1); + + ( + $low.shift(1).cross_under(&prev_trend) & $close.sgt(&$trend), + $high.shift(1).cross_over(&prev_trend) & $close.slt(&$trend), + ) + }}; +} diff --git a/ta_lib/price/src/average.rs b/ta_lib/price/src/average.rs index 39580bbb..e4d9b78d 100644 --- a/ta_lib/price/src/average.rs +++ b/ta_lib/price/src/average.rs @@ -1,11 +1,6 @@ use core::prelude::*; -pub fn average_price( - open: &Series, - high: &Series, - low: &Series, - close: &Series, -) -> Series { +pub fn average_price(open: &Price, high: &Price, low: &Price, close: &Price) -> Price { (open + high + low + close) / 4. } @@ -22,7 +17,7 @@ mod tests { let expected = vec![1.5625, 3.125, 4.6875]; - let result: Vec = average_price(&open, &high, &low, &close).into(); + let result: Vec = average_price(&open, &high, &low, &close).into(); assert_eq!(result, expected); } diff --git a/ta_lib/price/src/median.rs b/ta_lib/price/src/median.rs index db2cd349..71bbfda2 100644 --- a/ta_lib/price/src/median.rs +++ b/ta_lib/price/src/median.rs @@ -1,7 +1,7 @@ use core::prelude::*; -pub fn median_price(high: &Series, low: &Series) -> Series { - 0.5 * (high + low) +pub fn median_price(high: &Price, low: &Price) -> Price { + HALF * (high + low) } #[cfg(test)] @@ -15,7 +15,7 @@ mod tests { let expected = vec![0.75, 1.5, 2.5]; - let result: Vec = median_price(&high, &low).into(); + let result: Vec = median_price(&high, &low).into(); assert_eq!(result, expected); } diff --git a/ta_lib/price/src/typical.rs b/ta_lib/price/src/typical.rs index 2d27abb6..5e007401 100644 --- a/ta_lib/price/src/typical.rs +++ b/ta_lib/price/src/typical.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn typical_price(high: &Series, low: &Series, close: &Series) -> Series { +pub fn typical_price(high: &Price, low: &Price, close: &Price) -> Price { (high + low + close) / 3. } @@ -16,7 +16,7 @@ mod tests { let expected = vec![0.75, 1.5, 2.25]; - let result: Vec = typical_price(&high, &low, &close).into(); + let result: Vec = typical_price(&high, &low, &close).into(); assert_eq!(result, expected); } diff --git a/ta_lib/price/src/wcl.rs b/ta_lib/price/src/wcl.rs index 63ffecee..13813ef0 100644 --- a/ta_lib/price/src/wcl.rs +++ b/ta_lib/price/src/wcl.rs @@ -1,6 +1,6 @@ use core::prelude::*; -pub fn wcl(high: &Series, low: &Series, close: &Series) -> Series { +pub fn wcl(high: &Price, low: &Price, close: &Price) -> Price { (high + low + (close * 2.)) / 4. } @@ -16,7 +16,7 @@ mod tests { let expected = vec![0.75, 1.5, 2.25]; - let result: Vec = wcl(&high, &low, &close).into(); + let result: Vec = wcl(&high, &low, &close).into(); assert_eq!(result, expected); } diff --git a/ta_lib/strategies/base/Cargo.toml b/ta_lib/strategies/base/Cargo.toml index 7b086a0e..19e09286 100644 --- a/ta_lib/strategies/base/Cargo.toml +++ b/ta_lib/strategies/base/Cargo.toml @@ -12,5 +12,6 @@ repository.workspace = true [dependencies] once_cell = "1.19" core = { path = "../../core" } +timeseries = { path = "../../timeseries" } price = { path = "../../price" } volatility = { path = "../../indicators/volatility" } diff --git a/ta_lib/strategies/base/src/constants.rs b/ta_lib/strategies/base/src/constants.rs deleted file mode 100644 index 62db500e..00000000 --- a/ta_lib/strategies/base/src/constants.rs +++ /dev/null @@ -1,2 +0,0 @@ -pub const ZERO_LINE: f32 = 0.; -pub const NEUTRALITY_LINE: f32 = 50.; diff --git a/ta_lib/strategies/base/src/ffi.rs b/ta_lib/strategies/base/src/ffi.rs index 9e5db873..64f5c30f 100644 --- a/ta_lib/strategies/base/src/ffi.rs +++ b/ta_lib/strategies/base/src/ffi.rs @@ -1,19 +1,23 @@ use crate::{ - BaseLine, BaseStrategy, Confirm, Exit, Pulse, Signal, StopLoss, Strategy, TradeAction, OHLCV, + BaseLine, BaseStrategy, Confirm, Exit, Pulse, Signal, StopLoss, Strategy, TradeAction, }; use once_cell::sync::Lazy; use std::collections::HashMap; -use std::sync::{Mutex, MutexGuard, RwLock}; +use std::sync::atomic::{AtomicI32, Ordering}; +use std::sync::{Arc, RwLock}; +use timeseries::prelude::*; -static STRATEGY_ID_TO_INSTANCE: Lazy< - RwLock>>, -> = Lazy::new(|| RwLock::new(HashMap::new())); +type StrgTableType = Lazy>>>>; -static STRATEGY_ID_COUNTER: Lazy> = Lazy::new(|| RwLock::new(0)); +static STRATEGIES: StrgTableType = Lazy::new(|| Arc::new(RwLock::new(HashMap::new()))); +static STRATEGIES_ID_COUNTER: Lazy = Lazy::new(|| AtomicI32::new(0)); -static ALLOC_MUTEX: Lazy> = Lazy::new(|| Mutex::new(())); +fn generate_strategy_id() -> i32 { + STRATEGIES_ID_COUNTER.fetch_add(1, Ordering::SeqCst) +} pub fn register_strategy( + timeseries: Box, signal: Box, confirm: Box, pulse: Box, @@ -21,23 +25,23 @@ pub fn register_strategy( stop_loss: Box, exit: Box, ) -> i32 { - let mut id_counter = STRATEGY_ID_COUNTER.write().unwrap(); - *id_counter += 1; - - let current_id = *id_counter; - STRATEGY_ID_TO_INSTANCE.write().unwrap().insert( - current_id, - Box::new(BaseStrategy::new( - signal, confirm, pulse, base_line, stop_loss, exit, - )), - ); - - current_id + let mut strategies = STRATEGIES.write().unwrap(); + + let strategy_id = generate_strategy_id(); + + let strategy = Box::new(BaseStrategy::new( + timeseries, signal, confirm, pulse, base_line, stop_loss, exit, + )); + + strategies.insert(strategy_id, strategy); + + strategy_id } #[no_mangle] -pub fn unregister_strategy(strategy_id: i32) -> i32 { - let mut strategies = STRATEGY_ID_TO_INSTANCE.write().unwrap(); +pub fn strategy_unregister(strategy_id: i32) -> i32 { + let mut strategies = STRATEGIES.write().unwrap(); + strategies.remove(&strategy_id).is_some() as i32 } @@ -51,9 +55,10 @@ pub fn strategy_next( close: f32, volume: f32, ) -> (i32, f32) { - let mut strategies = STRATEGY_ID_TO_INSTANCE.write().unwrap(); + let mut strategies = STRATEGIES.write().unwrap(); + if let Some(strategy) = strategies.get_mut(&strategy_id) { - let ohlcv = OHLCV { + let bar = OHLCV { ts, open, high, @@ -62,7 +67,7 @@ pub fn strategy_next( volume, }; - let result = strategy.next(ohlcv); + let result = strategy.next(&bar); match result { TradeAction::GoLong(entry_price) => (1, entry_price), @@ -77,11 +82,27 @@ pub fn strategy_next( } #[no_mangle] -pub fn strategy_stop_loss(strategy_id: i32) -> (f32, f32) { - let mut strategies = STRATEGY_ID_TO_INSTANCE.write().unwrap(); - if let Some(strategy) = strategies.get_mut(&strategy_id) { - let stop_loss_levels = strategy.stop_loss(); +pub fn strategy_stop_loss( + strategy_id: i32, + ts: i64, + open: f32, + high: f32, + low: f32, + close: f32, + volume: f32, +) -> (f32, f32) { + let strategies = STRATEGIES.read().unwrap(); + if let Some(strategy) = strategies.get(&strategy_id) { + let bar = OHLCV { + ts, + open, + high, + low, + close, + volume, + }; + let stop_loss_levels = strategy.stop_loss(&bar); (stop_loss_levels.long, stop_loss_levels.short) } else { (-1.0, -1.0) @@ -90,10 +111,329 @@ pub fn strategy_stop_loss(strategy_id: i32) -> (f32, f32) { #[no_mangle] pub fn allocate(size: usize) -> *mut u8 { - let _guard: MutexGuard<_> = ALLOC_MUTEX.lock().unwrap(); - let mut buf = Vec::with_capacity(size); let ptr = buf.as_mut_ptr(); std::mem::forget(buf); ptr } + +#[cfg(test)] +mod tests { + use super::*; + use core::prelude::*; + + const PERIOD: usize = 7; + + struct MockSignal; + impl Signal for MockSignal { + fn lookback(&self) -> usize { + PERIOD + } + + fn trigger(&self, bar: &OHLCVSeries) -> (Series, Series) { + let len = bar.len(); + (Series::one(len).into(), Series::zero(len).into()) + } + } + + struct MockConfirm; + impl Confirm for MockConfirm { + fn lookback(&self) -> usize { + PERIOD + } + + fn filter(&self, bar: &OHLCVSeries) -> (Series, Series) { + let len = bar.len(); + (Series::one(len).into(), Series::zero(len).into()) + } + } + + struct MockPulse; + impl Pulse for MockPulse { + fn lookback(&self) -> usize { + PERIOD + } + + fn assess(&self, bar: &OHLCVSeries) -> (Series, Series) { + let len = bar.len(); + (Series::one(len).into(), Series::zero(len).into()) + } + } + + struct MockBaseLine; + impl BaseLine for MockBaseLine { + fn lookback(&self) -> usize { + PERIOD + } + + fn filter(&self, bar: &OHLCVSeries) -> (Series, Series) { + let len = bar.len(); + (Series::one(len).into(), Series::zero(len).into()) + } + + fn close(&self, bar: &OHLCVSeries) -> (Series, Series) { + let len = bar.len(); + (Series::zero(len).into(), Series::zero(len).into()) + } + } + + struct MockStopLoss; + impl StopLoss for MockStopLoss { + fn lookback(&self) -> usize { + PERIOD + } + + fn find(&self, bar: &OHLCVSeries) -> (Series, Series) { + let len = bar.len(); + (Series::zero(len), Series::zero(len)) + } + } + + struct MockExit; + impl Exit for MockExit { + fn lookback(&self) -> usize { + PERIOD + } + + fn close(&self, bar: &OHLCVSeries) -> (Series, Series) { + let len = bar.len(); + (Series::zero(len).into(), Series::zero(len).into()) + } + } + + #[test] + fn test_register_strategy() { + let timeseries = Box::::default(); + let signal = Box::new(MockSignal); + let confirm = Box::new(MockConfirm); + let pulse = Box::new(MockPulse); + let base_line = Box::new(MockBaseLine); + let stop_loss = Box::new(MockStopLoss); + let exit = Box::new(MockExit); + + let strategy_id = register_strategy( + timeseries, signal, confirm, pulse, base_line, stop_loss, exit, + ); + + assert!(strategy_id >= 0); + } + + #[test] + fn test_strategy_unregister() { + let timeseries = Box::::default(); + let signal = Box::new(MockSignal); + let confirm = Box::new(MockConfirm); + let pulse = Box::new(MockPulse); + let base_line = Box::new(MockBaseLine); + let stop_loss = Box::new(MockStopLoss); + let exit = Box::new(MockExit); + + let strategy_id = register_strategy( + timeseries, signal, confirm, pulse, base_line, stop_loss, exit, + ); + + assert_eq!(strategy_unregister(strategy_id), 1); + } + + #[test] + fn test_strategy_next() { + let timeseries = Box::::default(); + let signal = Box::new(MockSignal); + let confirm = Box::new(MockConfirm); + let pulse = Box::new(MockPulse); + let base_line = Box::new(MockBaseLine); + let stop_loss = Box::new(MockStopLoss); + let exit = Box::new(MockExit); + let ohlcv: Vec = vec![ + OHLCV { + ts: 1722710400876, + open: 0.29098, + high: 0.29309, + low: 0.29062, + close: 0.29215, + volume: 241728.0, + }, + OHLCV { + ts: 1722710700876, + open: 0.29215, + high: 0.2933, + low: 0.29193, + close: 0.29239, + volume: 88614.0, + }, + OHLCV { + ts: 1722711000877, + open: 0.29239, + high: 0.29256, + low: 0.28962, + close: 0.28982, + volume: 162963.0, + }, + OHLCV { + ts: 1722711300876, + open: 0.28982, + high: 0.2903, + low: 0.28909, + close: 0.28939, + volume: 201946.0, + }, + OHLCV { + ts: 1722711600883, + open: 0.28939, + high: 0.2911, + low: 0.28926, + close: 0.2911, + volume: 162808.0, + }, + OHLCV { + ts: 1722711900876, + open: 0.2911, + high: 0.29201, + low: 0.2897, + close: 0.29057, + volume: 170885.0, + }, + OHLCV { + ts: 1722712200876, + open: 0.29057, + high: 0.2918, + low: 0.28919, + close: 0.29152, + volume: 172555.0, + }, + OHLCV { + ts: 1722712500877, + open: 0.29152, + high: 0.29212, + low: 0.29027, + close: 0.29027, + volume: 101626.0, + }, + OHLCV { + ts: 1722712800876, + open: 0.29027, + high: 0.29029, + low: 0.28891, + close: 0.29026, + volume: 181359.0, + }, + OHLCV { + ts: 1722713100877, + open: 0.29026, + high: 0.29133, + low: 0.28933, + close: 0.29085, + volume: 79674.0, + }, + OHLCV { + ts: 1722713400876, + open: 0.29085, + high: 0.29111, + low: 0.28844, + close: 0.28854, + volume: 157827.0, + }, + OHLCV { + ts: 1722713700878, + open: 0.28854, + high: 0.29161, + low: 0.28852, + close: 0.29103, + volume: 213401.0, + }, + OHLCV { + ts: 1722714000876, + open: 0.29103, + high: 0.29145, + low: 0.29007, + close: 0.29025, + volume: 89210.0, + }, + OHLCV { + ts: 1722714300876, + open: 0.29025, + high: 0.29166, + low: 0.29005, + close: 0.29123, + volume: 80272.0, + }, + OHLCV { + ts: 1722714600876, + open: 0.29123, + high: 0.29235, + low: 0.29051, + close: 0.29082, + volume: 315809.0, + }, + OHLCV { + ts: 1722714300876, + open: 0.29025, + high: 0.29166, + low: 0.29005, + close: 0.29123, + volume: 80272.0, + }, + OHLCV { + ts: 1722714900876, + open: 0.29082, + high: 0.29196, + low: 0.28921, + close: 0.28935, + volume: 190734.0, + }, + OHLCV { + ts: 1722715200876, + open: 0.28935, + high: 0.28994, + low: 0.28853, + close: 0.28854, + volume: 249121.0, + }, + OHLCV { + ts: 1722716100876, + open: 0.288, + high: 0.29089, + low: 0.28766, + close: 0.2902, + volume: 100654.0, + }, + OHLCV { + ts: 1722715500877, + open: 0.28854, + high: 0.28924, + low: 0.28808, + close: 0.28915, + volume: 465408.0, + }, + OHLCV { + ts: 1722715800876, + open: 0.28915, + high: 0.28954, + low: 0.28712, + close: 0.288, + volume: 218446.0, + }, + ]; + + let strategy_id = register_strategy( + timeseries, signal, confirm, pulse, base_line, stop_loss, exit, + ); + + let mut res = vec![]; + for bar in &ohlcv { + let (action, _) = strategy_next( + strategy_id, + bar.ts, + bar.open, + bar.high, + bar.low, + bar.close, + bar.volume, + ); + res.push(action); + } + + assert_eq!(res.len(), ohlcv.len()); + assert_eq!(res[res.len() - 1], 1); + } +} diff --git a/ta_lib/strategies/base/src/lib.rs b/ta_lib/strategies/base/src/lib.rs index 5a39f843..232eee4d 100644 --- a/ta_lib/strategies/base/src/lib.rs +++ b/ta_lib/strategies/base/src/lib.rs @@ -1,17 +1,13 @@ extern crate alloc; -mod constants; mod ffi; -mod model; mod source; mod strategy; mod traits; mod volatility; pub mod prelude { - pub use crate::constants::*; pub use crate::ffi::*; - pub use crate::model::{OHLCVSeries, OHLCV}; pub use crate::source::*; pub use crate::strategy::{BaseStrategy, StopLossLevels, TradeAction}; pub use crate::traits::*; diff --git a/ta_lib/strategies/base/src/model.rs b/ta_lib/strategies/base/src/model.rs deleted file mode 100644 index 99fd9134..00000000 --- a/ta_lib/strategies/base/src/model.rs +++ /dev/null @@ -1,88 +0,0 @@ -use core::prelude::*; -use std::collections::{HashSet, VecDeque}; - -#[derive(Debug, Copy, Clone)] -pub struct OHLCV { - pub ts: i64, - pub open: f32, - pub high: f32, - pub low: f32, - pub close: f32, - pub volume: f32, -} - -#[derive(Debug, Clone)] -pub struct OHLCVSeries { - open: Series, - high: Series, - low: Series, - close: Series, - volume: Series, -} - -impl OHLCVSeries { - pub fn from_data(data: &VecDeque) -> Self { - let len = data.len(); - - let mut open = Vec::with_capacity(len); - let mut high = Vec::with_capacity(len); - let mut low = Vec::with_capacity(len); - let mut close = Vec::with_capacity(len); - let mut volume = Vec::with_capacity(len); - - let mut visited = HashSet::new(); - - let mut sorted_data: Vec<_> = data.iter().collect(); - sorted_data.sort_by_key(|v| v.ts); - - for ohlcv in sorted_data.iter() { - if !visited.contains(&ohlcv.ts) { - open.push(ohlcv.open); - high.push(ohlcv.high); - low.push(ohlcv.low); - close.push(ohlcv.close); - volume.push(ohlcv.volume); - - visited.insert(ohlcv.ts); - } - } - - Self { - open: Series::from(open), - high: Series::from(high), - low: Series::from(low), - close: Series::from(close), - volume: Series::from(volume), - } - } - - #[inline] - pub fn len(&self) -> usize { - self.close.len() - } - - #[inline] - pub fn open(&self) -> &Series { - &self.open - } - - #[inline] - pub fn high(&self) -> &Series { - &self.high - } - - #[inline] - pub fn low(&self) -> &Series { - &self.low - } - - #[inline] - pub fn close(&self) -> &Series { - &self.close - } - - #[inline] - pub fn volume(&self) -> &Series { - &self.volume - } -} diff --git a/ta_lib/strategies/base/src/source.rs b/ta_lib/strategies/base/src/source.rs index 025fac67..8dbc4688 100644 --- a/ta_lib/strategies/base/src/source.rs +++ b/ta_lib/strategies/base/src/source.rs @@ -1,6 +1,6 @@ -use crate::OHLCVSeries; use core::prelude::*; use price::prelude::*; +use timeseries::prelude::*; #[derive(Copy, Clone)] pub enum SourceType { @@ -12,12 +12,12 @@ pub enum SourceType { } pub trait Source { - fn source(&self, source_type: SourceType) -> Series; + fn source(&self, source_type: SourceType) -> Price; } impl Source for OHLCVSeries { #[inline] - fn source(&self, source_type: SourceType) -> Series { + fn source(&self, source_type: SourceType) -> Price { match source_type { SourceType::CLOSE => self.close().clone(), SourceType::HL2 => median_price(self.high(), self.low()), diff --git a/ta_lib/strategies/base/src/strategy.rs b/ta_lib/strategies/base/src/strategy.rs index 2ffe225d..f88b184f 100644 --- a/ta_lib/strategies/base/src/strategy.rs +++ b/ta_lib/strategies/base/src/strategy.rs @@ -1,28 +1,28 @@ use crate::source::{Source, SourceType}; -use crate::{BaseLine, Confirm, Exit, OHLCVSeries, Pulse, Signal, StopLoss, Strategy, OHLCV}; -use std::collections::VecDeque; +use crate::{BaseLine, Confirm, Exit, Pulse, Signal, StopLoss, Strategy}; +use core::prelude::*; +use timeseries::prelude::*; -const DEFAULT_LOOKBACK: usize = 55; -const DEFAULT_STOP_LEVEL: f32 = -1.0; -const DEFAULT_BUFF_SIZE: f32 = 1.236; +const DEFAULT_LOOKBACK: Period = 16; +const DEFAULT_STOP_LEVEL: Scalar = -1.0; #[derive(Debug, PartialEq)] pub enum TradeAction { - GoLong(f32), - GoShort(f32), - ExitLong(f32), - ExitShort(f32), + GoLong(Scalar), + GoShort(Scalar), + ExitLong(Scalar), + ExitShort(Scalar), DoNothing, } #[derive(Debug)] pub struct StopLossLevels { - pub long: f32, - pub short: f32, + pub long: Scalar, + pub short: Scalar, } pub struct BaseStrategy { - data: VecDeque, + timeseries: Box, signal: Box, confirm: Box, pulse: Box, @@ -34,6 +34,7 @@ pub struct BaseStrategy { impl BaseStrategy { pub fn new( + timeseries: Box, signal: Box, confirm: Box, pulse: Box, @@ -53,7 +54,7 @@ impl BaseStrategy { let lookback_period = lookbacks.into_iter().max().unwrap_or(DEFAULT_LOOKBACK); Self { - data: VecDeque::with_capacity(lookback_period), + timeseries, signal, confirm, pulse, @@ -64,47 +65,43 @@ impl BaseStrategy { } } - fn store(&mut self, data: OHLCV) { - let buf_size = (self.lookback_period as f32 * DEFAULT_BUFF_SIZE) as usize; - - if self.data.len() > buf_size { - self.data.pop_front(); - } - - self.data.push_back(data); + fn store(&mut self, bar: &OHLCV) { + self.timeseries.add(bar) } #[inline(always)] fn can_process(&self) -> bool { - self.data.len() >= self.lookback_period + self.timeseries.len() >= self.lookback_period } - #[inline(always)] - fn ohlcv_series(&self) -> OHLCVSeries { - OHLCVSeries::from_data(&self.data) + fn ohlcv(&self) -> OHLCVSeries { + self.timeseries.ohlcv(self.lookback_period) } } impl Strategy for BaseStrategy { - fn next(&mut self, data: OHLCV) -> TradeAction { - self.store(data); + fn next(&mut self, bar: &OHLCV) -> TradeAction { + self.store(bar); if !self.can_process() { return TradeAction::DoNothing; } - let theo_price = self.suggested_entry(); + let ohlcv = self.ohlcv(); + + let bar_index = ohlcv.bar_index(bar); + let theo_price = self.suggested_entry(&ohlcv, bar_index); - match self.trade_signals() { - (true, false, false, false) => TradeAction::GoLong(theo_price), - (false, true, false, false) => TradeAction::GoShort(theo_price), - (false, false, true, false) => TradeAction::ExitLong(theo_price), - (false, false, false, true) => TradeAction::ExitShort(theo_price), + match self.trade_signals(&ohlcv, bar_index) { + (true, _, _, _) => TradeAction::GoLong(theo_price), + (_, true, _, _) => TradeAction::GoShort(theo_price), + (_, _, true, _) => TradeAction::ExitLong(theo_price), + (_, _, _, true) => TradeAction::ExitShort(theo_price), _ => TradeAction::DoNothing, } } - fn stop_loss(&self) -> StopLossLevels { + fn stop_loss(&self, bar: &OHLCV) -> StopLossLevels { if !self.can_process() { return StopLossLevels { long: DEFAULT_STOP_LEVEL, @@ -112,7 +109,9 @@ impl Strategy for BaseStrategy { }; } - let (stop_loss_long, stop_loss_short) = self.stop_loss_levels(); + let ohlcv = self.ohlcv(); + let bar_index = ohlcv.bar_index(bar); + let (stop_loss_long, stop_loss_short) = self.stop_loss_levels(&ohlcv, bar_index); StopLossLevels { long: stop_loss_long, @@ -122,47 +121,47 @@ impl Strategy for BaseStrategy { } impl BaseStrategy { - fn trade_signals(&self) -> (bool, bool, bool, bool) { - let series = self.ohlcv_series(); + fn trade_signals(&self, ohlcv: &OHLCVSeries, bar_index: usize) -> (bool, bool, bool, bool) { + let (signal_go_long, signal_go_short) = self.signal.trigger(ohlcv); - let (go_long_trigger, go_short_trigger) = self.signal.generate(&series); - let (go_long_baseline, go_short_baseline) = self.base_line.generate(&series); - let (go_long_confirm, go_short_confirm) = self.confirm.validate(&series); - let (go_long_momentum, go_short_momentum) = self.pulse.assess(&series); - let (filter_long_baseline, filter_short_baseline) = self.base_line.filter(&series); - let (exit_long_eval, exit_short_eval) = self.exit.evaluate(&series); + let (baseline_confirm_long, baseline_confirm_short) = self.base_line.filter(ohlcv); + let (primary_confirm_long, primary_confirm_short) = self.confirm.filter(ohlcv); + let (pulse_confirm_long, pulse_confirm_short) = self.pulse.assess(ohlcv); - let go_long_signal = go_long_trigger | go_long_baseline; - let go_short_signal = go_short_trigger | go_short_baseline; + let (exit_close_long, exit_close_short) = self.exit.close(ohlcv); + let (baseline_close_long, baseline_close_short) = self.base_line.close(ohlcv); - let go_long = (go_long_signal & filter_long_baseline & go_long_confirm & go_long_momentum) - .last() - .unwrap_or(false); - let go_short = - (go_short_signal & filter_short_baseline & go_short_confirm & go_short_momentum) - .last() - .unwrap_or(false); + let confirm_long = primary_confirm_long & pulse_confirm_long; + let confirm_short = primary_confirm_short & pulse_confirm_short; + + let base_go_long = signal_go_long & baseline_confirm_long & confirm_long; + let base_go_short = signal_go_short & baseline_confirm_short & confirm_short; - let exit_long = exit_long_eval.last().unwrap_or(false); - let exit_short = exit_short_eval.last().unwrap_or(false); + let go_long = base_go_long.get(bar_index).unwrap_or(false); + let go_short = base_go_short.get(bar_index).unwrap_or(false); + + let exit_long = (exit_close_long | baseline_close_long) + .get(bar_index) + .unwrap_or(false); + let exit_short = (exit_close_short | baseline_close_short) + .get(bar_index) + .unwrap_or(false); (go_long, go_short, exit_long, exit_short) } - fn suggested_entry(&self) -> f32 { - self.ohlcv_series() - .source(SourceType::OHLC4) - .last() - .unwrap_or(std::f32::NAN) + fn suggested_entry(&self, ohlcv: &OHLCVSeries, bar_index: usize) -> Scalar { + ohlcv + .source(SourceType::CLOSE) + .get(bar_index) + .unwrap_or(NAN) } - fn stop_loss_levels(&self) -> (f32, f32) { - let series = self.ohlcv_series(); - - let (sl_long_find, sl_short_find) = self.stop_loss.find(&series); + fn stop_loss_levels(&self, ohlcv: &OHLCVSeries, bar_index: usize) -> (Scalar, Scalar) { + let (sl_long_find, sl_short_find) = self.stop_loss.find(ohlcv); - let stop_loss_long = sl_long_find.last().unwrap_or(std::f32::NAN); - let stop_loss_short = sl_short_find.last().unwrap_or(std::f32::NAN); + let stop_loss_long = sl_long_find.get(bar_index).unwrap_or(NAN); + let stop_loss_short = sl_short_find.get(bar_index).unwrap_or(NAN); (stop_loss_long, stop_loss_short) } @@ -170,15 +169,9 @@ impl BaseStrategy { #[cfg(test)] mod tests { - use crate::source::{Source, SourceType}; - use crate::{ - BaseLine, BaseStrategy, Confirm, Exit, OHLCVSeries, Pulse, Signal, StopLoss, Strategy, - TradeAction, OHLCV, - }; + use crate::{BaseLine, BaseStrategy, Confirm, Exit, Pulse, Signal, StopLoss}; use core::Series; - - const DEFAULT_BUFF_SIZE: f32 = 1.3; - const DEFAULT_LOOKBACK: usize = 55; + use timeseries::{BaseTimeSeries, OHLCVSeries}; struct MockSignal { fast_period: usize, @@ -189,7 +182,7 @@ mod tests { self.fast_period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let len = data.len(); (Series::one(len).into(), Series::zero(len).into()) } @@ -204,7 +197,7 @@ mod tests { self.period } - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { let len = data.len(); (Series::one(len).into(), Series::zero(len).into()) } @@ -234,14 +227,14 @@ mod tests { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { let len = data.len(); (Series::one(len).into(), Series::zero(len).into()) } - fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { let len = data.len(); - (Series::one(len).into(), Series::zero(len).into()) + (Series::zero(len).into(), Series::one(len).into()) } } @@ -271,7 +264,7 @@ mod tests { 0 } - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { let len = data.len(); (Series::one(len).into(), Series::zero(len).into()) } @@ -280,6 +273,7 @@ mod tests { #[test] fn test_base_strategy_lookback() { let strategy = BaseStrategy::new( + Box::::default(), Box::new(MockSignal { fast_period: 10 }), Box::new(MockConfirm { period: 1 }), Box::new(MockPulse { period: 7 }), @@ -290,50 +284,6 @@ mod tests { }), Box::new(MockExit {}), ); - assert_eq!(strategy.lookback_period, 55); - } - - #[test] - fn test_strategy_data() { - let mut strategy = BaseStrategy::new( - Box::new(MockSignal { fast_period: 10 }), - Box::new(MockConfirm { period: 1 }), - Box::new(MockPulse { period: 7 }), - Box::new(MockBaseLine { period: 15 }), - Box::new(MockStopLoss { - period: 2, - multi: 2.0, - }), - Box::new(MockExit {}), - ); - let lookback = (DEFAULT_BUFF_SIZE * DEFAULT_LOOKBACK as f32) as usize; - let data = OHLCV { - ts: 1710297600000, - open: 1.0, - high: 2.0, - low: 0.5, - close: 1.5, - volume: 100.0, - }; - let ohlcvs = vec![data; lookback]; - - let mut action = TradeAction::DoNothing; - - for ohlcv in ohlcvs { - action = strategy.next(ohlcv); - } - - let series = OHLCVSeries::from_data(&strategy.data); - - let hl2: Vec = series.source(SourceType::HL2).into(); - let hlc3: Vec = series.source(SourceType::HLC3).into(); - let hlcc4: Vec = series.source(SourceType::HLCC4).into(); - let ohlc4: Vec = series.source(SourceType::OHLC4).into(); - - assert_eq!(hl2, vec![1.25]); - assert_eq!(hlc3, vec![1.333_333_4]); - assert_eq!(hlcc4, vec![1.375]); - assert_eq!(ohlc4, vec![1.25]); - assert_eq!(action, TradeAction::DoNothing); + assert_eq!(strategy.lookback_period, 16); } } diff --git a/ta_lib/strategies/base/src/traits.rs b/ta_lib/strategies/base/src/traits.rs index 2ff366e0..226dfee1 100644 --- a/ta_lib/strategies/base/src/traits.rs +++ b/ta_lib/strategies/base/src/traits.rs @@ -1,38 +1,39 @@ -use crate::{OHLCVSeries, StopLossLevels, TradeAction, OHLCV}; +use crate::{StopLossLevels, TradeAction}; use core::prelude::*; +use timeseries::prelude::*; pub trait Signal: Send + Sync { fn lookback(&self) -> usize; - fn generate(&self, data: &OHLCVSeries) -> (Series, Series); + fn trigger(&self, data: &OHLCVSeries) -> (Rule, Rule); } pub trait Confirm: Send + Sync { fn lookback(&self) -> usize; - fn validate(&self, data: &OHLCVSeries) -> (Series, Series); + fn filter(&self, data: &OHLCVSeries) -> (Rule, Rule); } pub trait Pulse: Send + Sync { fn lookback(&self) -> usize; - fn assess(&self, data: &OHLCVSeries) -> (Series, Series); + fn assess(&self, data: &OHLCVSeries) -> (Rule, Rule); } pub trait BaseLine: Send + Sync { fn lookback(&self) -> usize; - fn filter(&self, data: &OHLCVSeries) -> (Series, Series); - fn generate(&self, data: &OHLCVSeries) -> (Series, Series); + fn filter(&self, data: &OHLCVSeries) -> (Rule, Rule); + fn close(&self, data: &OHLCVSeries) -> (Rule, Rule); } pub trait Exit: Send + Sync { fn lookback(&self) -> usize; - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series); + fn close(&self, data: &OHLCVSeries) -> (Rule, Rule); } pub trait StopLoss: Send + Sync { fn lookback(&self) -> usize; - fn find(&self, data: &OHLCVSeries) -> (Series, Series); + fn find(&self, data: &OHLCVSeries) -> (Price, Price); } pub trait Strategy { - fn next(&mut self, ohlcv: OHLCV) -> TradeAction; - fn stop_loss(&self) -> StopLossLevels; + fn next(&mut self, bar: &OHLCV) -> TradeAction; + fn stop_loss(&self, bar: &OHLCV) -> StopLossLevels; } diff --git a/ta_lib/strategies/base/src/volatility.rs b/ta_lib/strategies/base/src/volatility.rs index 51cf688f..ab36c202 100644 --- a/ta_lib/strategies/base/src/volatility.rs +++ b/ta_lib/strategies/base/src/volatility.rs @@ -1,20 +1,26 @@ -use crate::OHLCVSeries; use core::prelude::*; -use volatility::{atr, tr}; +use timeseries::prelude::*; +use volatility::{tr, wtr}; pub trait Volatility { - fn atr(&self, period: usize) -> Series; - fn tr(&self) -> Series; + fn atr(&self, smooth: Smooth, period: usize) -> Price; + fn tr(&self) -> Price; + fn wtr(&self) -> Price; } impl Volatility for OHLCVSeries { #[inline] - fn atr(&self, period: usize) -> Series { - atr(self.high(), self.low(), self.close(), Smooth::SMMA, period) + fn atr(&self, smooth: Smooth, period: usize) -> Price { + self.tr().smooth(smooth, period) } #[inline] - fn tr(&self) -> Series { + fn tr(&self) -> Price { tr(self.high(), self.low(), self.close()) } + + #[inline] + fn wtr(&self) -> Price { + wtr(self.high(), self.low(), self.close()) + } } diff --git a/ta_lib/strategies/baseline/Cargo.toml b/ta_lib/strategies/baseline/Cargo.toml index c8cbb6b4..27da318e 100644 --- a/ta_lib/strategies/baseline/Cargo.toml +++ b/ta_lib/strategies/baseline/Cargo.toml @@ -15,4 +15,5 @@ base = { path = "../base" } core = { path = "../../core" } trend = { path = "../../indicators/trend" } signal = { path = "../signal" } -indicator = { path = "../indicator" } \ No newline at end of file +indicator = { path = "../indicator" } +timeseries = { path = "../../timeseries" } \ No newline at end of file diff --git a/ta_lib/strategies/baseline/src/ma.rs b/ta_lib/strategies/baseline/src/ma.rs index 7c8405bb..926f94bc 100644 --- a/ta_lib/strategies/baseline/src/ma.rs +++ b/ta_lib/strategies/baseline/src/ma.rs @@ -1,69 +1,50 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; -use signal::{MaCrossSignal, MaQuadrupleSignal, MaSurpassSignal, MaTestingGroundSignal}; +use timeseries::prelude::*; -const DEFAULT_ATR_LOOKBACK: usize = 14; -const DEFAULT_ATR_FACTOR: f32 = 1.236; +const DEFAULT_ATR_LOOKBACK: Period = 14; +const DEFAULT_ATR_FACTOR: Scalar = 1.2; +const DEFAULT_ATR_SMOOTH: Smooth = Smooth::EMA; pub struct MaBaseLine { - source_type: SourceType, + source: SourceType, ma: MovingAverageType, period: usize, - signal: Vec>, } impl MaBaseLine { - pub fn new(source_type: SourceType, ma: MovingAverageType, period: f32) -> Self { + pub fn new(source: SourceType, ma: MovingAverageType, period: f32) -> Self { Self { - source_type, + source, ma, period: period as usize, - signal: vec![ - Box::new(MaSurpassSignal::new(source_type, ma, period)), - Box::new(MaQuadrupleSignal::new(source_type, ma, period)), - ], } } } impl BaseLine for MaBaseLine { fn lookback(&self) -> usize { - let mut m = std::cmp::max(DEFAULT_ATR_LOOKBACK, self.period); - - for signal in &self.signal { - m = std::cmp::max(m, signal.lookback()); - } - - m + std::cmp::max(DEFAULT_ATR_LOOKBACK, self.period) } fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { - let ma = ma_indicator(&self.ma, data, self.source_type, self.period); - let prev_ma = ma.shift(1); + let ma = ma_indicator(&self.ma, data, self.source, self.period); + let close = data.close(); - let dist = (&ma - data.close()).abs(); - let atr = data.atr(DEFAULT_ATR_LOOKBACK) * DEFAULT_ATR_FACTOR; + let dist = (&ma - close).abs(); + let atr = data.atr(DEFAULT_ATR_SMOOTH, DEFAULT_ATR_LOOKBACK) * DEFAULT_ATR_FACTOR; ( - ma.sgt(&prev_ma) & dist.slt(&atr), - ma.slt(&prev_ma) & dist.slt(&atr), + close.sgt(&ma) & dist.slt(&atr), + close.slt(&ma) & dist.slt(&atr), ) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let lookback = self.lookback(); - - let mut go_long_signal: Series = Series::zero(lookback).into(); - let mut go_short_signal: Series = Series::zero(lookback).into(); - - for signal in &self.signal { - let (go_long, go_short) = signal.generate(data); - - go_long_signal = go_long_signal | go_long; - go_short_signal = go_short_signal | go_short; - } + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { + let ma = ma_indicator(&self.ma, data, self.source, self.period); + let close = data.close(); - (go_long_signal, go_short_signal) + (close.cross_under(&ma), close.cross_over(&ma)) } } diff --git a/ta_lib/strategies/confirm/Cargo.toml b/ta_lib/strategies/confirm/Cargo.toml index dfd52da2..29d6b866 100644 --- a/ta_lib/strategies/confirm/Cargo.toml +++ b/ta_lib/strategies/confirm/Cargo.toml @@ -12,6 +12,8 @@ repository.workspace = true [dependencies] base = { path = "../base" } core = { path = "../../core" } +timeseries = { path = "../../timeseries" } momentum = { path = "../../indicators/momentum" } trend = { path = "../../indicators/trend" } -volume = { path = "../../indicators/volume" } \ No newline at end of file +volume = { path = "../../indicators/volume" } +volatility = { path = "../../indicators/volatility" } \ No newline at end of file diff --git a/ta_lib/strategies/confirm/src/bb.rs b/ta_lib/strategies/confirm/src/bb.rs new file mode 100644 index 00000000..f6eeaa8c --- /dev/null +++ b/ta_lib/strategies/confirm/src/bb.rs @@ -0,0 +1,37 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; +use volatility::bb; + +pub struct BbConfirm { + smooth: Smooth, + period: usize, + factor: f32, +} + +impl BbConfirm { + pub fn new(smooth: Smooth, period: f32, factor: f32) -> Self { + Self { + smooth, + period: period as usize, + factor, + } + } +} + +impl Confirm for BbConfirm { + fn lookback(&self) -> usize { + self.period + } + + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { + let close = data.close(); + let prev_close = close.shift(1); + let (upper_bb, _, lower_bb) = bb(close, self.smooth, self.period, self.factor); + + ( + close.sgt(&lower_bb) & prev_close.slt(&lower_bb), + close.slt(&upper_bb) & prev_close.sgt(&upper_bb), + ) + } +} diff --git a/ta_lib/strategies/pulse/src/braid.rs b/ta_lib/strategies/confirm/src/braid.rs similarity index 80% rename from ta_lib/strategies/pulse/src/braid.rs rename to ta_lib/strategies/confirm/src/braid.rs index 0d229290..20c5c86b 100644 --- a/ta_lib/strategies/pulse/src/braid.rs +++ b/ta_lib/strategies/confirm/src/braid.rs @@ -1,23 +1,26 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; -pub struct BraidPulse { +pub struct BraidConfirm { smooth_type: Smooth, fast_period: usize, slow_period: usize, open_period: usize, strength: f32, - atr_period: usize, + smooth_atr: Smooth, + period_atr: usize, } -impl BraidPulse { +impl BraidConfirm { pub fn new( smooth_type: Smooth, fast_period: f32, slow_period: f32, open_period: f32, strength: f32, - atr_period: f32, + smooth_atr: Smooth, + period_atr: f32, ) -> Self { Self { smooth_type, @@ -25,23 +28,24 @@ impl BraidPulse { slow_period: slow_period as usize, open_period: open_period as usize, strength, - atr_period: atr_period as usize, + smooth_atr, + period_atr: period_atr as usize, } } } -impl Pulse for BraidPulse { +impl Confirm for BraidConfirm { fn lookback(&self) -> usize { let adj_lookback = std::cmp::max(self.fast_period, self.slow_period); std::cmp::max(adj_lookback, self.open_period) } - fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { let fast_ma = data.close().smooth(self.smooth_type, self.fast_period); let open_ma = data.open().smooth(self.smooth_type, self.open_period); let slow_ma = data.close().smooth(self.smooth_type, self.slow_period); - let filter = data.atr(self.atr_period) * self.strength / 100.0; + let filter = data.atr(self.smooth_atr, self.period_atr) * self.strength / 100.0; let max = fast_ma.max(&open_ma).max(&slow_ma); let min = fast_ma.min(&open_ma).min(&slow_ma); @@ -49,12 +53,8 @@ impl Pulse for BraidPulse { let histogram = max - min; ( - histogram.sgt(&filter) - & (fast_ma.cross_over(&open_ma) - | (histogram.cross_over(&filter) & fast_ma.sgt(&open_ma))), - histogram.sgt(&filter) - & (fast_ma.cross_under(&open_ma) - | (histogram.cross_over(&filter) & fast_ma.slt(&open_ma))), + histogram.sgt(&filter) & fast_ma.sgt(&open_ma), + histogram.sgt(&filter) & fast_ma.slt(&open_ma), ) } } @@ -62,12 +62,11 @@ impl Pulse for BraidPulse { #[cfg(test)] mod tests { use super::*; - use std::collections::VecDeque; #[test] - fn test_pulse_braid() { - let pulse = BraidPulse::new(Smooth::LSMA, 3.0, 14.0, 7.0, 40.0, 14.0); - let data = VecDeque::from([ + fn test_confirm_braid() { + let confirm = BraidConfirm::new(Smooth::LSMA, 3.0, 14.0, 7.0, 40.0, Smooth::SMMA, 14.0); + let data = vec![ OHLCV { ts: 1679827200, open: 4.8914, @@ -188,18 +187,18 @@ mod tests { close: 4.8925, volume: 100.0, }, - ]); - let series = OHLCVSeries::from_data(&data); + ]; + let series = OHLCVSeries::from(data); - let (long_signal, short_signal) = pulse.assess(&series); + let (long_signal, short_signal) = confirm.filter(&series); let expected_long_signal = vec![ - false, false, false, false, false, false, false, false, false, false, false, false, + true, true, false, false, false, false, false, false, false, false, false, false, false, false, true, ]; let expected_short_signal = vec![ - false, false, false, false, false, false, true, false, false, false, false, false, - false, false, false, + false, false, false, false, false, false, true, true, true, true, true, true, true, + true, false, ]; let result_long_signal: Vec = long_signal.into(); diff --git a/ta_lib/strategies/confirm/src/cc.rs b/ta_lib/strategies/confirm/src/cc.rs new file mode 100644 index 00000000..ad9d4b07 --- /dev/null +++ b/ta_lib/strategies/confirm/src/cc.rs @@ -0,0 +1,62 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::cc; +use timeseries::prelude::*; + +pub struct CcConfirm { + source: SourceType, + period_fast: usize, + period_slow: usize, + smooth: Smooth, + period_smooth: usize, + smooth_signal: Smooth, + period_signal: usize, +} + +impl CcConfirm { + pub fn new( + source: SourceType, + period_fast: f32, + period_slow: f32, + smooth: Smooth, + period_smooth: f32, + smooth_signal: Smooth, + period_signal: f32, + ) -> Self { + Self { + source, + period_fast: period_fast as usize, + period_slow: period_slow as usize, + smooth, + period_smooth: period_smooth as usize, + smooth_signal, + period_signal: period_signal as usize, + } + } +} + +impl Confirm for CcConfirm { + fn lookback(&self) -> usize { + std::cmp::max( + std::cmp::max(self.period_fast, self.period_slow), + std::cmp::max(self.period_smooth, self.period_signal), + ) + } + + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { + let cc = cc( + &data.source(self.source), + self.period_fast, + self.period_slow, + self.smooth, + self.period_smooth, + ); + + let signal = cc.smooth(self.smooth_signal, self.period_signal); + + ( + cc.sgt(&signal) & cc.sgt(&ZERO), + cc.slt(&signal) & cc.slt(&ZERO), + ) + } +} diff --git a/ta_lib/strategies/confirm/src/cci.rs b/ta_lib/strategies/confirm/src/cci.rs index 4f36810f..ea572255 100644 --- a/ta_lib/strategies/confirm/src/cci.rs +++ b/ta_lib/strategies/confirm/src/cci.rs @@ -1,41 +1,51 @@ use base::prelude::*; use core::prelude::*; use momentum::cci; +use timeseries::prelude::*; -const CCI_UPPER_BARRIER: f32 = 50.; -const CCI_LOWER_BARRIER: f32 = -50.; +const CCI_UPPER_NEUTRALITY: f32 = 50.; +const CCI_LOWER_NEUTRALITY: f32 = -50.; +const CCI_UPPER_BARRIER: f32 = 100.; +const CCI_LOWER_BARRIER: f32 = -100.; pub struct CciConfirm { - source_type: SourceType, - smooth_type: Smooth, + source: SourceType, period: usize, factor: f32, + smooth: Smooth, + period_smooth: usize, } impl CciConfirm { - pub fn new(source_type: SourceType, smooth_type: Smooth, period: f32, factor: f32) -> Self { + pub fn new( + source: SourceType, + period: f32, + factor: f32, + smooth: Smooth, + period_smooth: f32, + ) -> Self { Self { - source_type, - smooth_type, + source, period: period as usize, factor, + smooth, + period_smooth: period_smooth as usize, } } } impl Confirm for CciConfirm { fn lookback(&self) -> usize { - self.period + std::cmp::max(self.period, self.period_smooth) } - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { - let cci = cci( - &data.source(self.source_type), - self.smooth_type, - self.period, - self.factor, - ); + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { + let cci = cci(&data.source(self.source), self.period, self.factor) + .smooth(self.smooth, self.period_smooth); - (cci.sgt(&CCI_UPPER_BARRIER), cci.slt(&CCI_LOWER_BARRIER)) + ( + cci.sgt(&CCI_UPPER_NEUTRALITY) & cci.slt(&CCI_UPPER_BARRIER), + cci.slt(&CCI_LOWER_NEUTRALITY) & cci.sgt(&CCI_LOWER_BARRIER), + ) } } diff --git a/ta_lib/strategies/confirm/src/didi.rs b/ta_lib/strategies/confirm/src/didi.rs new file mode 100644 index 00000000..1955042b --- /dev/null +++ b/ta_lib/strategies/confirm/src/didi.rs @@ -0,0 +1,54 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; + +const DIDI_NEUTRALITY: f32 = 1.; + +pub struct DidiConfirm { + source: SourceType, + smooth: Smooth, + period_medium: usize, + period_slow: usize, + smooth_signal: Smooth, + period_signal: usize, +} + +impl DidiConfirm { + pub fn new( + source: SourceType, + smooth: Smooth, + period_medium: f32, + period_slow: f32, + smooth_signal: Smooth, + period_signal: f32, + ) -> Self { + Self { + source, + smooth, + period_medium: period_medium as usize, + period_slow: period_slow as usize, + smooth_signal, + period_signal: period_signal as usize, + } + } +} + +impl Confirm for DidiConfirm { + fn lookback(&self) -> usize { + std::cmp::max( + self.period_signal, + std::cmp::max(self.period_medium, self.period_slow), + ) + } + + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { + let source = data.source(self.source); + + let med_line = source.smooth(self.smooth, self.period_medium); + let long_line = source.smooth(self.smooth, self.period_slow) / med_line; + + let signal = long_line.smooth(self.smooth_signal, self.period_signal); + + (signal.sgt(&DIDI_NEUTRALITY), signal.slt(&DIDI_NEUTRALITY)) + } +} diff --git a/ta_lib/strategies/confirm/src/dpo.rs b/ta_lib/strategies/confirm/src/dpo.rs index 48f06164..7ccc47c4 100644 --- a/ta_lib/strategies/confirm/src/dpo.rs +++ b/ta_lib/strategies/confirm/src/dpo.rs @@ -1,18 +1,22 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use trend::dpo; +const DPO_UPPER_BARRIER: f32 = 0.005; +const DPO_LOWER_BARRIER: f32 = -0.005; + pub struct DpoConfirm { - source_type: SourceType, - smooth_type: Smooth, + source: SourceType, + smooth: Smooth, period: usize, } impl DpoConfirm { - pub fn new(source_type: SourceType, smooth_type: Smooth, period: f32) -> Self { + pub fn new(source: SourceType, smooth: Smooth, period: f32) -> Self { Self { - source_type, - smooth_type, + source, + smooth, period: period as usize, } } @@ -23,13 +27,9 @@ impl Confirm for DpoConfirm { self.period } - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { - let dpo = dpo( - &data.source(self.source_type), - self.smooth_type, - self.period, - ); + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { + let dpo = dpo(&data.source(self.source), self.smooth, self.period); - (dpo.sgt(&ZERO_LINE), dpo.slt(&ZERO_LINE)) + (dpo.sgt(&DPO_UPPER_BARRIER), dpo.slt(&DPO_LOWER_BARRIER)) } } diff --git a/ta_lib/strategies/confirm/src/dso.rs b/ta_lib/strategies/confirm/src/dso.rs deleted file mode 100644 index 0976d052..00000000 --- a/ta_lib/strategies/confirm/src/dso.rs +++ /dev/null @@ -1,199 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use momentum::dso; - -pub struct DsoConfirm { - source_type: SourceType, - smooth_type: Smooth, - smooth_period: usize, - k_period: usize, - d_period: usize, -} - -impl DsoConfirm { - pub fn new( - source_type: SourceType, - smooth_type: Smooth, - smooth_period: f32, - k_period: f32, - d_period: f32, - ) -> Self { - Self { - source_type, - smooth_type, - smooth_period: smooth_period as usize, - k_period: k_period as usize, - d_period: d_period as usize, - } - } -} - -impl Confirm for DsoConfirm { - fn lookback(&self) -> usize { - let period = std::cmp::max(self.smooth_period, self.k_period); - std::cmp::max(period, self.d_period) - } - - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { - let (k, d) = dso( - &data.source(self.source_type), - self.smooth_type, - self.smooth_period, - self.k_period, - self.d_period, - ); - - (k.sgt(&d), k.slt(&d)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::VecDeque; - - #[test] - fn test_confirm_dso() { - let confirm = DsoConfirm::new(SourceType::CLOSE, Smooth::EMA, 13.0, 8.0, 9.0); - let data = VecDeque::from([ - OHLCV { - ts: 1679827200, - open: 4.8914, - high: 4.9045, - low: 4.8895, - close: 4.8995, - volume: 100.0, - }, - OHLCV { - ts: 1679827500, - open: 4.8995, - high: 4.9073, - low: 4.8995, - close: 4.9061, - volume: 100.0, - }, - OHLCV { - ts: 1679827800, - open: 4.9061, - high: 4.9070, - low: 4.9001, - close: 4.9001, - volume: 100.0, - }, - OHLCV { - ts: 1679828100, - open: 4.9001, - high: 4.9053, - low: 4.8995, - close: 4.9053, - volume: 100.0, - }, - OHLCV { - ts: 1679828400, - open: 4.9053, - high: 4.9093, - low: 4.9046, - close: 4.9087, - volume: 100.0, - }, - OHLCV { - ts: 1679828700, - open: 4.9087, - high: 4.9154, - low: 4.9087, - close: 4.9131, - volume: 100.0, - }, - OHLCV { - ts: 1679829000, - open: 4.9131, - high: 4.9131, - low: 4.9040, - close: 4.9041, - volume: 100.0, - }, - OHLCV { - ts: 1679829300, - open: 4.9041, - high: 4.9068, - low: 4.8988, - close: 4.9023, - volume: 100.0, - }, - OHLCV { - ts: 1679829600, - open: 4.9023, - high: 4.9051, - low: 4.8949, - close: 4.9010, - volume: 100.0, - }, - OHLCV { - ts: 1679829900, - open: 4.9010, - high: 4.9052, - low: 4.8969, - close: 4.8969, - volume: 100.0, - }, - OHLCV { - ts: 1679830200, - open: 4.8969, - high: 4.8969, - low: 4.8819, - close: 4.8895, - volume: 100.0, - }, - OHLCV { - ts: 1679830500, - open: 4.8895, - high: 4.8928, - low: 4.8851, - close: 4.8901, - volume: 100.0, - }, - OHLCV { - ts: 1679830800, - open: 4.8901, - high: 4.8910, - low: 4.8813, - close: 4.8855, - volume: 100.0, - }, - OHLCV { - ts: 1679831100, - open: 4.8855, - high: 4.8864, - low: 4.8816, - close: 4.8824, - volume: 100.0, - }, - OHLCV { - ts: 1679831400, - open: 4.8824, - high: 4.8934, - low: 4.8814, - close: 4.8925, - volume: 100.0, - }, - ]); - let series = OHLCVSeries::from_data(&data); - - let (long_signal, short_signal) = confirm.validate(&series); - - let expected_long_signal = vec![ - false, true, true, true, true, true, true, true, true, true, false, false, false, - false, false, - ]; - let expected_short_signal = vec![ - false, false, false, false, false, false, false, false, false, false, true, true, true, - true, true, - ]; - - let result_long_signal: Vec = long_signal.into(); - let result_short_signal: Vec = short_signal.into(); - - assert_eq!(result_long_signal, expected_long_signal); - assert_eq!(result_short_signal, expected_short_signal); - } -} diff --git a/ta_lib/strategies/confirm/src/dumb.rs b/ta_lib/strategies/confirm/src/dumb.rs index 03629ede..0ea8bdde 100644 --- a/ta_lib/strategies/confirm/src/dumb.rs +++ b/ta_lib/strategies/confirm/src/dumb.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; pub struct DumbConfirm { period: usize, @@ -18,7 +19,7 @@ impl Confirm for DumbConfirm { self.period } - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { let len = data.len(); (Series::one(len).into(), Series::one(len).into()) diff --git a/ta_lib/strategies/confirm/src/eom.rs b/ta_lib/strategies/confirm/src/eom.rs index 6d54ad95..99dcf7e1 100644 --- a/ta_lib/strategies/confirm/src/eom.rs +++ b/ta_lib/strategies/confirm/src/eom.rs @@ -1,21 +1,20 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use volume::eom; pub struct EomConfirm { - source_type: SourceType, - smooth_type: Smooth, + source: SourceType, + smooth: Smooth, period: usize, - divisor: f32, } impl EomConfirm { - pub fn new(source_type: SourceType, smooth_type: Smooth, period: f32, divisor: f32) -> Self { + pub fn new(source: SourceType, smooth: Smooth, period: f32) -> Self { Self { - source_type, - smooth_type, + source, + smooth, period: period as usize, - divisor, } } } @@ -25,17 +24,16 @@ impl Confirm for EomConfirm { self.period } - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { let eom = eom( - &data.source(self.source_type), + &data.source(self.source), data.high(), data.low(), data.volume(), - self.smooth_type, + self.smooth, self.period, - self.divisor, ); - (eom.sgt(&ZERO_LINE), eom.slt(&ZERO_LINE)) + (eom.sgt(&ZERO), eom.slt(&ZERO)) } } diff --git a/ta_lib/strategies/confirm/src/lib.rs b/ta_lib/strategies/confirm/src/lib.rs index 5b1210e5..f7c3c16d 100644 --- a/ta_lib/strategies/confirm/src/lib.rs +++ b/ta_lib/strategies/confirm/src/lib.rs @@ -1,21 +1,25 @@ +mod bb; +mod braid; +mod cc; mod cci; +mod didi; mod dpo; -mod dso; mod dumb; mod eom; -mod roc; mod rsi_neutrality; mod rsi_signalline; mod stc; -mod vi; +mod wpr; +pub use bb::BbConfirm; +pub use braid::BraidConfirm; +pub use cc::CcConfirm; pub use cci::CciConfirm; +pub use didi::DidiConfirm; pub use dpo::DpoConfirm; -pub use dso::DsoConfirm; pub use dumb::DumbConfirm; pub use eom::EomConfirm; -pub use roc::RocConfirm; pub use rsi_neutrality::RsiNeutralityConfirm; pub use rsi_signalline::RsiSignalLineConfirm; pub use stc::StcConfirm; -pub use vi::ViConfirm; +pub use wpr::WprConfirm; diff --git a/ta_lib/strategies/confirm/src/roc.rs b/ta_lib/strategies/confirm/src/roc.rs deleted file mode 100644 index 751f2211..00000000 --- a/ta_lib/strategies/confirm/src/roc.rs +++ /dev/null @@ -1,29 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use momentum::roc; - -pub struct RocConfirm { - source_type: SourceType, - period: usize, -} - -impl RocConfirm { - pub fn new(source_type: SourceType, period: f32) -> Self { - Self { - source_type, - period: period as usize, - } - } -} - -impl Confirm for RocConfirm { - fn lookback(&self) -> usize { - self.period - } - - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { - let roc = roc(&data.source(self.source_type), self.period); - - (roc.sgt(&ZERO_LINE), roc.slt(&ZERO_LINE)) - } -} diff --git a/ta_lib/strategies/confirm/src/rsi_neutrality.rs b/ta_lib/strategies/confirm/src/rsi_neutrality.rs index 28e9eff3..f266022f 100644 --- a/ta_lib/strategies/confirm/src/rsi_neutrality.rs +++ b/ta_lib/strategies/confirm/src/rsi_neutrality.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::rsi; +use timeseries::prelude::*; const RSI_UPPER_BARRIER: f32 = 75.0; const RSI_LOWER_BARRIER: f32 = 25.0; @@ -28,7 +29,7 @@ impl Confirm for RsiNeutralityConfirm { self.period } - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { let rsi = rsi( &data.source(self.source_type), self.smooth_type, @@ -37,20 +38,20 @@ impl Confirm for RsiNeutralityConfirm { let lower_barrier = RSI_LOWER_BARRIER + self.threshold; let upper_barrier = RSI_UPPER_BARRIER - self.threshold; - let lower_neutrality = NEUTRALITY_LINE - self.threshold; - let upper_neutrality = NEUTRALITY_LINE + self.threshold; + let lower_neutrality = NEUTRALITY - self.threshold; + let upper_neutrality = NEUTRALITY + self.threshold; let prev_rsi = rsi.shift(1); let back_2_rsi = rsi.shift(2); let back_3_rsi = rsi.shift(3); ( - rsi.sgt(&NEUTRALITY_LINE) + rsi.sgt(&NEUTRALITY) & rsi.slt(&upper_barrier) & prev_rsi.sgt(&lower_neutrality) & back_2_rsi.sgt(&lower_neutrality) & back_3_rsi.sgt(&lower_neutrality), - rsi.slt(&NEUTRALITY_LINE) + rsi.slt(&NEUTRALITY) & rsi.sgt(&lower_barrier) & prev_rsi.slt(&upper_neutrality) & back_2_rsi.slt(&upper_neutrality) diff --git a/ta_lib/strategies/confirm/src/rsi_signalline.rs b/ta_lib/strategies/confirm/src/rsi_signalline.rs index 18faa016..d500bfee 100644 --- a/ta_lib/strategies/confirm/src/rsi_signalline.rs +++ b/ta_lib/strategies/confirm/src/rsi_signalline.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::rsi; +use timeseries::prelude::*; const RSI_UPPER_BARRIER: f32 = 75.; const RSI_LOWER_BARRIER: f32 = 35.; @@ -39,7 +40,7 @@ impl Confirm for RsiSignalLineConfirm { std::cmp::max(self.rsi_period, self.smooth_period) } - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { let rsi = rsi( &data.source(self.source_type), self.smooth_type, @@ -51,8 +52,8 @@ impl Confirm for RsiSignalLineConfirm { let lower_barrier = RSI_LOWER_BARRIER - self.threshold; ( - rsi.slt(&upper_barrier) & rsi.sgt(&signal), - rsi.sgt(&lower_barrier) & rsi.slt(&signal), + rsi.sgt(&signal) & rsi.slt(&upper_barrier), + rsi.slt(&signal) & rsi.sgt(&lower_barrier), ) } } diff --git a/ta_lib/strategies/confirm/src/stc.rs b/ta_lib/strategies/confirm/src/stc.rs index 2ade7809..96ee1a6f 100644 --- a/ta_lib/strategies/confirm/src/stc.rs +++ b/ta_lib/strategies/confirm/src/stc.rs @@ -1,9 +1,10 @@ use base::prelude::*; use core::prelude::*; use momentum::stc; +use timeseries::prelude::*; -const LOWER_LINE: f32 = 25.; const UPPER_LINE: f32 = 75.; +const LOWER_LINE: f32 = 25.; pub struct StcConfirm { source_type: SourceType, @@ -45,7 +46,7 @@ impl Confirm for StcConfirm { std::cmp::max(adj_lookback_three, self.d_second) } - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { let stc = stc( &data.source(self.source_type), self.smooth_type, @@ -58,8 +59,8 @@ impl Confirm for StcConfirm { let prev_stc = stc.shift(1); ( - stc.sgt(&prev_stc) & stc.sgt(&UPPER_LINE), - stc.slt(&prev_stc) & stc.slt(&LOWER_LINE), + stc.sgt(&UPPER_LINE) & stc.sgte(&prev_stc), + stc.slt(&LOWER_LINE) & stc.slte(&prev_stc), ) } } diff --git a/ta_lib/strategies/confirm/src/vi.rs b/ta_lib/strategies/confirm/src/vi.rs deleted file mode 100644 index 478e4a17..00000000 --- a/ta_lib/strategies/confirm/src/vi.rs +++ /dev/null @@ -1,143 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use trend::vi; - -pub struct ViConfirm { - atr_period: usize, - period: usize, -} - -impl ViConfirm { - pub fn new(atr_period: f32, period: f32) -> Self { - Self { - atr_period: atr_period as usize, - period: period as usize, - } - } -} - -impl Confirm for ViConfirm { - fn lookback(&self) -> usize { - std::cmp::max(self.atr_period, self.period) - } - - fn validate(&self, data: &OHLCVSeries) -> (Series, Series) { - let (vip, vim) = vi( - data.high(), - data.low(), - &data.atr(self.atr_period), - self.period, - ); - - (vip.sgt(&vim), vip.slt(&vim)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::VecDeque; - - #[test] - fn test_confirm_vi() { - let confirm = ViConfirm::new(1.0, 3.0); - let data = VecDeque::from([ - OHLCV { - ts: 1679827200, - open: 6.490, - high: 6.514, - low: 6.490, - close: 6.511, - volume: 100.0, - }, - OHLCV { - ts: 1679827500, - open: 6.511, - high: 6.522, - low: 6.506, - close: 6.512, - volume: 100.0, - }, - OHLCV { - ts: 1679827800, - open: 6.512, - high: 6.513, - low: 6.496, - close: 6.512, - volume: 100.0, - }, - OHLCV { - ts: 1679828100, - open: 6.512, - high: 6.528, - low: 6.507, - close: 6.527, - volume: 100.0, - }, - OHLCV { - ts: 1679828400, - open: 6.527, - high: 6.530, - low: 6.497, - close: 6.500, - volume: 100.0, - }, - OHLCV { - ts: 1679828700, - open: 6.500, - high: 6.508, - low: 6.489, - close: 6.505, - volume: 100.0, - }, - OHLCV { - ts: 1679829000, - open: 6.505, - high: 6.510, - low: 6.483, - close: 6.492, - volume: 100.0, - }, - OHLCV { - ts: 1679829300, - open: 6.492, - high: 6.496, - low: 6.481, - close: 6.491, - volume: 100.0, - }, - OHLCV { - ts: 1679829600, - open: 6.491, - high: 6.512, - low: 6.486, - close: 6.499, - volume: 100.0, - }, - OHLCV { - ts: 1679829900, - open: 6.499, - high: 6.500, - low: 6.481, - close: 6.486, - volume: 100.0, - }, - ]); - let series = OHLCVSeries::from_data(&data); - - let (long_signal, short_signal) = confirm.validate(&series); - - let expected_long_signal = vec![ - false, true, true, true, false, false, false, false, true, false, - ]; - let expected_short_signal = vec![ - false, false, false, false, true, true, true, true, false, true, - ]; - - let result_long_signal: Vec = long_signal.into(); - let result_short_signal: Vec = short_signal.into(); - - assert_eq!(result_long_signal, expected_long_signal); - assert_eq!(result_short_signal, expected_short_signal); - } -} diff --git a/ta_lib/strategies/confirm/src/wpr.rs b/ta_lib/strategies/confirm/src/wpr.rs new file mode 100644 index 00000000..78fdd83d --- /dev/null +++ b/ta_lib/strategies/confirm/src/wpr.rs @@ -0,0 +1,41 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::wpr; +use timeseries::prelude::*; + +pub struct WprConfirm { + source: SourceType, + period: usize, + smooth_signal: Smooth, + period_signal: usize, +} + +impl WprConfirm { + pub fn new(source: SourceType, period: f32, smooth_signal: Smooth, period_signal: f32) -> Self { + Self { + source, + period: period as usize, + smooth_signal, + period_signal: period_signal as usize, + } + } +} + +impl Confirm for WprConfirm { + fn lookback(&self) -> usize { + std::cmp::max(self.period, self.period_signal) + } + + fn filter(&self, data: &OHLCVSeries) -> (Series, Series) { + let wrp = wpr( + &data.source(self.source), + data.high(), + data.low(), + self.period, + ); + + let signal = wrp.smooth(self.smooth_signal, self.period_signal); + + (wrp.sgt(&signal), wrp.slt(&signal)) + } +} diff --git a/ta_lib/strategies/exit/Cargo.toml b/ta_lib/strategies/exit/Cargo.toml index 40247050..b7356aab 100644 --- a/ta_lib/strategies/exit/Cargo.toml +++ b/ta_lib/strategies/exit/Cargo.toml @@ -16,4 +16,5 @@ candlestick = { path = "../../patterns/candlestick" } trend = { path = "../../indicators/trend" } momentum = { path = "../../indicators/momentum" } volume = { path = "../../indicators/volume" } -indicator = { path = "../indicator" } \ No newline at end of file +indicator = { path = "../indicator" } +timeseries = { path = "../../timeseries" } \ No newline at end of file diff --git a/ta_lib/strategies/exit/src/ast.rs b/ta_lib/strategies/exit/src/ast.rs index d1299c7d..de7e9a78 100644 --- a/ta_lib/strategies/exit/src/ast.rs +++ b/ta_lib/strategies/exit/src/ast.rs @@ -1,18 +1,21 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use trend::ast; pub struct AstExit { source_type: SourceType, - atr_period: usize, + smooth_atr: Smooth, + period_atr: usize, factor: f32, } impl AstExit { - pub fn new(source_type: SourceType, atr_period: f32, factor: f32) -> Self { + pub fn new(source_type: SourceType, smooth_atr: Smooth, period_atr: f32, factor: f32) -> Self { Self { source_type, - atr_period: atr_period as usize, + smooth_atr, + period_atr: period_atr as usize, factor, } } @@ -20,13 +23,13 @@ impl AstExit { impl Exit for AstExit { fn lookback(&self) -> usize { - self.atr_period + self.period_atr } - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { let (direction, _) = ast( &data.source(self.source_type), - &data.atr(self.atr_period), + &data.atr(self.smooth_atr, self.period_atr), self.factor, ); diff --git a/ta_lib/strategies/exit/src/cci.rs b/ta_lib/strategies/exit/src/cci.rs deleted file mode 100644 index e0c33f79..00000000 --- a/ta_lib/strategies/exit/src/cci.rs +++ /dev/null @@ -1,51 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use momentum::cci; - -const CCI_OVERBOUGHT: f32 = 100.0; -const CCI_OVERSOLD: f32 = -100.0; - -pub struct CciExit { - source_type: SourceType, - smooth_type: Smooth, - period: usize, - factor: f32, - threshold: f32, -} - -impl CciExit { - pub fn new( - source_type: SourceType, - smooth_type: Smooth, - period: f32, - factor: f32, - threshold: f32, - ) -> Self { - Self { - source_type, - smooth_type, - period: period as usize, - factor, - threshold, - } - } -} - -impl Exit for CciExit { - fn lookback(&self) -> usize { - self.period - } - - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { - let rsi = cci( - &data.source(self.source_type), - self.smooth_type, - self.period, - self.factor, - ); - let upper_bound = CCI_OVERBOUGHT - self.threshold; - let lower_bound = CCI_OVERSOLD + self.threshold; - - (rsi.cross_under(&upper_bound), rsi.cross_over(&lower_bound)) - } -} diff --git a/ta_lib/strategies/exit/src/dumb.rs b/ta_lib/strategies/exit/src/dumb.rs index 117fe00d..0193641e 100644 --- a/ta_lib/strategies/exit/src/dumb.rs +++ b/ta_lib/strategies/exit/src/dumb.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; pub struct DumbExit {} @@ -8,7 +9,7 @@ impl Exit for DumbExit { 0 } - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { let len = data.len(); (Series::zero(len).into(), Series::zero(len).into()) diff --git a/ta_lib/strategies/exit/src/highlow.rs b/ta_lib/strategies/exit/src/highlow.rs index 43aa5bd5..eff6a104 100644 --- a/ta_lib/strategies/exit/src/highlow.rs +++ b/ta_lib/strategies/exit/src/highlow.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; pub struct HighLowExit { period: usize, @@ -18,7 +19,7 @@ impl Exit for HighLowExit { self.period } - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { ( data.close().sgt(&data.high().shift(self.period)), data.close().slt(&data.low().shift(self.period)), diff --git a/ta_lib/strategies/exit/src/lib.rs b/ta_lib/strategies/exit/src/lib.rs index 1087cc68..7f61e4ab 100644 --- a/ta_lib/strategies/exit/src/lib.rs +++ b/ta_lib/strategies/exit/src/lib.rs @@ -1,17 +1,19 @@ mod ast; -mod cci; mod dumb; mod highlow; mod ma; +mod mad; mod mfi; +mod rex; mod rsi; mod trix; pub use ast::AstExit; -pub use cci::CciExit; pub use dumb::DumbExit; pub use highlow::HighLowExit; pub use ma::MaExit; +pub use mad::MadExit; pub use mfi::MfiExit; +pub use rex::RexExit; pub use rsi::RsiExit; pub use trix::TrixExit; diff --git a/ta_lib/strategies/exit/src/ma.rs b/ta_lib/strategies/exit/src/ma.rs index c7d6d1ea..3cb6641d 100644 --- a/ta_lib/strategies/exit/src/ma.rs +++ b/ta_lib/strategies/exit/src/ma.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; +use timeseries::prelude::*; pub struct MaExit { source_type: SourceType, @@ -23,7 +24,7 @@ impl Exit for MaExit { self.period } - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { let ma = ma_indicator(&self.ma, data, self.source_type, self.period); (data.close().cross_under(&ma), data.close().cross_over(&ma)) diff --git a/ta_lib/strategies/exit/src/mad.rs b/ta_lib/strategies/exit/src/mad.rs new file mode 100644 index 00000000..f1ff54cc --- /dev/null +++ b/ta_lib/strategies/exit/src/mad.rs @@ -0,0 +1,33 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; + +pub struct MadExit { + source: SourceType, + period_fast: usize, + period_slow: usize, +} + +impl MadExit { + pub fn new(source: SourceType, period_fast: f32, period_slow: f32) -> Self { + Self { + source, + period_fast: period_fast as usize, + period_slow: period_slow as usize, + } + } +} + +impl Exit for MadExit { + fn lookback(&self) -> usize { + std::cmp::max(self.period_fast, self.period_slow) + } + + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { + let mad = + data.source(self.source) + .spread_pct(Smooth::SMA, self.period_fast, self.period_slow); + + (mad.cross_under(&ZERO), mad.cross_over(&ZERO)) + } +} diff --git a/ta_lib/strategies/exit/src/mfi.rs b/ta_lib/strategies/exit/src/mfi.rs index 9d1d76ef..5b804beb 100644 --- a/ta_lib/strategies/exit/src/mfi.rs +++ b/ta_lib/strategies/exit/src/mfi.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use volume::mfi; const MFI_OVERBOUGHT: f32 = 80.0; @@ -26,7 +27,7 @@ impl Exit for MfiExit { self.period } - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { let mfi = mfi(&data.source(self.source_type), data.volume(), self.period); let upper_bound = MFI_OVERBOUGHT - self.threshold; let lower_bound = MFI_OVERSOLD + self.threshold; diff --git a/ta_lib/strategies/exit/src/rex.rs b/ta_lib/strategies/exit/src/rex.rs new file mode 100644 index 00000000..16716f55 --- /dev/null +++ b/ta_lib/strategies/exit/src/rex.rs @@ -0,0 +1,51 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::rex; +use timeseries::prelude::*; + +pub struct RexExit { + source: SourceType, + smooth: Smooth, + period: usize, + smooth_signal: Smooth, + period_signal: usize, +} + +impl RexExit { + pub fn new( + source: SourceType, + smooth: Smooth, + period: f32, + smooth_signal: Smooth, + period_signal: f32, + ) -> Self { + Self { + source, + smooth, + period: period as usize, + smooth_signal, + period_signal: period_signal as usize, + } + } +} + +impl Exit for RexExit { + fn lookback(&self) -> usize { + std::cmp::max(self.period, self.period_signal) + } + + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { + let rex = rex( + &data.source(self.source), + data.open(), + data.high(), + data.low(), + self.smooth, + self.period, + ); + + let signal_line = rex.smooth(self.smooth_signal, self.period_signal); + + (rex.cross_under(&signal_line), rex.cross_over(&signal_line)) + } +} diff --git a/ta_lib/strategies/exit/src/rsi.rs b/ta_lib/strategies/exit/src/rsi.rs index 52fbf835..7630feb4 100644 --- a/ta_lib/strategies/exit/src/rsi.rs +++ b/ta_lib/strategies/exit/src/rsi.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::rsi; +use timeseries::prelude::*; const RSI_OVERBOUGHT: f32 = 70.0; const RSI_OVERSOLD: f32 = 30.0; @@ -28,7 +29,7 @@ impl Exit for RsiExit { self.period } - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { let rsi = rsi( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/exit/src/trix.rs b/ta_lib/strategies/exit/src/trix.rs index 659dad97..0cdf764b 100644 --- a/ta_lib/strategies/exit/src/trix.rs +++ b/ta_lib/strategies/exit/src/trix.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::trix; +use timeseries::prelude::*; pub struct TrixExit { source_type: SourceType, @@ -30,7 +31,7 @@ impl Exit for TrixExit { std::cmp::max(self.period, self.signal_period) } - fn evaluate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn close(&self, data: &OHLCVSeries) -> (Series, Series) { let trix = trix( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/indicator/Cargo.toml b/ta_lib/strategies/indicator/Cargo.toml index 9c78a962..44d11178 100644 --- a/ta_lib/strategies/indicator/Cargo.toml +++ b/ta_lib/strategies/indicator/Cargo.toml @@ -13,4 +13,5 @@ repository.workspace = true core = { path = "../../core" } base = { path = "../base" } candlestick = { path = "../../patterns/candlestick" } -trend = { path = "../../indicators/trend" } \ No newline at end of file +trend = { path = "../../indicators/trend" } +timeseries = { path = "../../timeseries" } \ No newline at end of file diff --git a/ta_lib/strategies/indicator/src/candle.rs b/ta_lib/strategies/indicator/src/candle.rs index 4b52aaf1..65017eb9 100644 --- a/ta_lib/strategies/indicator/src/candle.rs +++ b/ta_lib/strategies/indicator/src/candle.rs @@ -1,9 +1,6 @@ -use base::prelude::*; -use candlestick::{ - bottle, double_trouble, golden, h, hexad, hikkake, marubozu, master_candle, quintuplets, - slingshot, tasuki, three_candles, three_methods, three_one_two, -}; +use candlestick::*; use core::prelude::*; +use timeseries::prelude::*; #[derive(Copy, Clone)] pub enum CandleTrendType { @@ -23,10 +20,7 @@ pub enum CandleTrendType { THREE_ONE_TWO, } -pub fn candlestick_trend_indicator( - candle: &CandleTrendType, - data: &OHLCVSeries, -) -> (Series, Series) { +pub fn candlestick_trend_indicator(candle: &CandleTrendType, data: &OHLCVSeries) -> (Rule, Rule) { match candle { CandleTrendType::BOTTLE => ( bottle::bullish(data.open(), data.low(), data.close()), @@ -86,3 +80,65 @@ pub fn candlestick_trend_indicator( ), } } + +#[derive(Copy, Clone)] +pub enum CandleReversalType { + DOJI, + ENGULFING, + EUPHORIA, + KANGAROO, + R, + SPLIT, + TWEEZERS, + HARAMIS, + HARAMIF, + HAMMER, +} + +pub fn candlestick_reversal_indicator( + candle: &CandleReversalType, + data: &OHLCVSeries, +) -> (Rule, Rule) { + match candle { + CandleReversalType::DOJI => ( + doji::bullish(data.open(), data.close()), + doji::bearish(data.open(), data.close()), + ), + CandleReversalType::ENGULFING => ( + engulfing::bullish(data.open(), data.high(), data.low(), data.close()), + engulfing::bearish(data.open(), data.high(), data.low(), data.close()), + ), + CandleReversalType::EUPHORIA => ( + euphoria::bullish(data.open(), data.close()), + euphoria::bearish(data.open(), data.close()), + ), + CandleReversalType::HAMMER => ( + hammer::bullish(data.open(), data.high(), data.close()), + hammer::bearish(data.open(), data.low(), data.close()), + ), + CandleReversalType::HARAMIS => ( + harami_strict::bullish(data.open(), data.high(), data.low(), data.close()), + harami_strict::bearish(data.open(), data.high(), data.low(), data.close()), + ), + CandleReversalType::HARAMIF => ( + harami_flexible::bullish(data.open(), data.high(), data.low(), data.close()), + harami_flexible::bearish(data.open(), data.high(), data.low(), data.close()), + ), + CandleReversalType::KANGAROO => ( + kangaroo_tail::bullish(data.open(), data.high(), data.low(), data.close()), + kangaroo_tail::bearish(data.open(), data.high(), data.low(), data.close()), + ), + CandleReversalType::R => ( + r::bullish(data.low(), data.close()), + r::bearish(data.high(), data.close()), + ), + CandleReversalType::SPLIT => ( + split::bullish(data.open(), data.high(), data.low(), data.close()), + split::bearish(data.open(), data.high(), data.low(), data.close()), + ), + CandleReversalType::TWEEZERS => ( + tweezers::bullish(data.open(), data.low(), data.close()), + tweezers::bearish(data.open(), data.high(), data.close()), + ), + } +} diff --git a/ta_lib/strategies/indicator/src/ma.rs b/ta_lib/strategies/indicator/src/ma.rs index e8249908..41aa2349 100644 --- a/ta_lib/strategies/indicator/src/ma.rs +++ b/ta_lib/strategies/indicator/src/ma.rs @@ -1,8 +1,9 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use trend::{ - alma, cama, dema, ema, frama, gma, hema, hma, kama, kjs, lsma, md, rmsma, sinwma, sma, smma, - t3, tema, tma, vidya, vwema, vwma, wma, zlema, zlhma, zlsma, zltema, + alma, cama, dema, ema, frama, gma, hema, hma, kama, lsma, md, midpoint, rmsma, sinwma, slsma, + sma, smma, t3, tema, trima, ults, vidya, vwema, vwma, wma, zlema, zlhma, zlsma, zltema, }; #[derive(Copy, Clone)] @@ -21,11 +22,14 @@ pub enum MovingAverageType { MD, RMSMA, SINWMA, + SLSMA, SMA, SMMA, TTHREE, TEMA, - TMA, + TL, + TRIMA, + ULTS, VIDYA, VWMA, VWEMA, @@ -39,44 +43,41 @@ pub enum MovingAverageType { pub fn ma_indicator( ma: &MovingAverageType, data: &OHLCVSeries, - source_type: SourceType, + source: SourceType, period: usize, -) -> Series { +) -> Price { + let source = data.source(source); + match ma { - MovingAverageType::ALMA => alma(&data.source(source_type), period, 0.85, 6.0), - MovingAverageType::CAMA => cama( - &data.source(source_type), - data.high(), - data.low(), - &data.tr(), - period, - ), - MovingAverageType::DEMA => dema(&data.source(source_type), period), - MovingAverageType::EMA => ema(&data.source(source_type), period), - MovingAverageType::FRAMA => { - frama(&data.source(source_type), data.high(), data.low(), period) - } - MovingAverageType::GMA => gma(&data.source(source_type), period), - MovingAverageType::HMA => hma(&data.source(source_type), period), - MovingAverageType::HEMA => hema(&data.source(source_type), period), - MovingAverageType::KAMA => kama(&data.source(source_type), period), - MovingAverageType::KJS => kjs(data.high(), data.low(), period), - MovingAverageType::LSMA => lsma(&data.source(source_type), period), - MovingAverageType::MD => md(&data.source(source_type), period), - MovingAverageType::RMSMA => rmsma(&data.source(source_type), period), - MovingAverageType::SINWMA => sinwma(&data.source(source_type), period), - MovingAverageType::SMA => sma(&data.source(source_type), period), - MovingAverageType::SMMA => smma(&data.source(source_type), period), - MovingAverageType::TTHREE => t3(&data.source(source_type), period), - MovingAverageType::TEMA => tema(&data.source(source_type), period), - MovingAverageType::TMA => tma(&data.source(source_type), period), - MovingAverageType::VIDYA => vidya(&data.source(source_type), period, 3 * period), - MovingAverageType::VWMA => vwma(&data.source(source_type), data.volume(), period), - MovingAverageType::VWEMA => vwema(&data.source(source_type), data.volume(), period), - MovingAverageType::WMA => wma(&data.source(source_type), period), - MovingAverageType::ZLEMA => zlema(&data.source(source_type), period), - MovingAverageType::ZLSMA => zlsma(&data.source(source_type), period), - MovingAverageType::ZLTEMA => zltema(&data.source(source_type), period), - MovingAverageType::ZLHMA => zlhma(&data.source(source_type), period, 3), + MovingAverageType::ALMA => alma(&source, period, 0.85, 6.0), + MovingAverageType::CAMA => cama(&source, data.high(), data.low(), &data.wtr(), period), + MovingAverageType::DEMA => dema(&source, period), + MovingAverageType::EMA => ema(&source, period), + MovingAverageType::FRAMA => frama(&source, data.high(), data.low(), period), + MovingAverageType::GMA => gma(&source, period), + MovingAverageType::HMA => hma(&source, period), + MovingAverageType::HEMA => hema(&source, period), + MovingAverageType::KAMA => kama(&source, period), + MovingAverageType::KJS => midpoint(data.high(), data.low(), 26), + MovingAverageType::LSMA => lsma(&source, period), + MovingAverageType::MD => md(&source, period), + MovingAverageType::RMSMA => rmsma(&source, period), + MovingAverageType::SINWMA => sinwma(&source, period), + MovingAverageType::SLSMA => slsma(&source, period), + MovingAverageType::SMA => sma(&source, period), + MovingAverageType::SMMA => smma(&source, period), + MovingAverageType::TTHREE => t3(&source, period), + MovingAverageType::TEMA => tema(&source, period), + MovingAverageType::TL => midpoint(data.high(), data.low(), 55), + MovingAverageType::TRIMA => trima(&source, period), + MovingAverageType::ULTS => ults(&source, period), + MovingAverageType::VIDYA => vidya(&source, period, 3 * period), + MovingAverageType::VWMA => vwma(&source, data.volume(), period), + MovingAverageType::VWEMA => vwema(&source, data.volume(), period), + MovingAverageType::WMA => wma(&source, period), + MovingAverageType::ZLEMA => zlema(&source, period), + MovingAverageType::ZLSMA => zlsma(&source, period), + MovingAverageType::ZLTEMA => zltema(&source, period), + MovingAverageType::ZLHMA => zlhma(&source, period, 3), } } diff --git a/ta_lib/strategies/pulse/Cargo.toml b/ta_lib/strategies/pulse/Cargo.toml index 15188d66..719e3c63 100644 --- a/ta_lib/strategies/pulse/Cargo.toml +++ b/ta_lib/strategies/pulse/Cargo.toml @@ -17,3 +17,4 @@ volume = { path = "../../indicators/volume" } trend = { path = "../../indicators/trend" } momentum = { path = "../../indicators/momentum" } volatility = { path = "../../indicators/volatility" } +timeseries = { path = "../../timeseries" } \ No newline at end of file diff --git a/ta_lib/strategies/pulse/src/adx.rs b/ta_lib/strategies/pulse/src/adx.rs index 53b98a08..9c1d77b2 100644 --- a/ta_lib/strategies/pulse/src/adx.rs +++ b/ta_lib/strategies/pulse/src/adx.rs @@ -1,22 +1,23 @@ use base::prelude::*; use core::prelude::*; use momentum::dmi; +use timeseries::prelude::*; -const ADX_LOWER_BARRIER: f32 = 25.; +const ADX_THRESHOLD: f32 = 30.; pub struct AdxPulse { - smooth_type: Smooth, - adx_period: usize, - di_period: usize, + smooth: Smooth, + period_adx: usize, + period_di: usize, threshold: f32, } impl AdxPulse { - pub fn new(smooth_type: Smooth, adx_period: f32, di_period: f32, threshold: f32) -> Self { + pub fn new(smooth: Smooth, period_adx: f32, period_di: f32, threshold: f32) -> Self { Self { - smooth_type, - adx_period: adx_period as usize, - di_period: di_period as usize, + smooth, + period_adx: period_adx as usize, + period_di: period_di as usize, threshold, } } @@ -24,20 +25,22 @@ impl AdxPulse { impl Pulse for AdxPulse { fn lookback(&self) -> usize { - std::cmp::max(self.adx_period, self.di_period) + std::cmp::max(self.period_adx, self.period_di) } fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { - let (adx, _, _) = dmi( + let atr = data.atr(self.smooth, self.period_di); + let (_, _, adx) = dmi( data.high(), data.low(), - &data.atr(self.di_period), - self.smooth_type, - self.adx_period, - self.di_period, + &atr, + self.smooth, + self.period_adx, + self.period_di, ); - let adx_lower = ADX_LOWER_BARRIER + self.threshold; - (adx.sgt(&adx_lower), adx.sgt(&adx_lower)) + let barrier = ADX_THRESHOLD - self.threshold; + + (adx.sgt(&barrier), adx.sgt(&barrier)) } } diff --git a/ta_lib/strategies/pulse/src/chop.rs b/ta_lib/strategies/pulse/src/chop.rs index 70ea6c60..1593fa0f 100644 --- a/ta_lib/strategies/pulse/src/chop.rs +++ b/ta_lib/strategies/pulse/src/chop.rs @@ -1,20 +1,23 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use trend::chop; -const CHOP_MIDDLE_LINE: f32 = 38.2; +const CHOP_LINE: f32 = 38.2; pub struct ChopPulse { period: usize, - atr_period: usize, + smooth_atr: Smooth, + period_atr: usize, threshold: f32, } impl ChopPulse { - pub fn new(period: f32, atr_period: f32, threshold: f32) -> Self { + pub fn new(period: f32, smooth_atr: Smooth, period_atr: f32, threshold: f32) -> Self { Self { period: period as usize, - atr_period: atr_period as usize, + smooth_atr, + period_atr: period_atr as usize, threshold, } } @@ -22,32 +25,31 @@ impl ChopPulse { impl Pulse for ChopPulse { fn lookback(&self) -> usize { - std::cmp::max(self.period, self.atr_period) + std::cmp::max(self.period, self.period_atr) } fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { let chop = chop( data.high(), data.low(), - &data.atr(self.atr_period), + &data.atr(self.smooth_atr, self.period_atr), self.period, ); - let lower_chop = CHOP_MIDDLE_LINE + self.threshold; + let barrier = CHOP_LINE + self.threshold; - (chop.slt(&lower_chop), chop.slt(&lower_chop)) + (chop.slt(&barrier), chop.slt(&barrier)) } } #[cfg(test)] mod tests { use super::*; - use std::collections::VecDeque; #[test] fn test_pulse_chop() { - let pulse = ChopPulse::new(6.0, 1.0, 0.0); - let data = VecDeque::from([ + let pulse = ChopPulse::new(6.0, Smooth::SMMA, 1.0, 0.0); + let data = vec![ OHLCV { ts: 1679825700, open: 5.993, @@ -128,8 +130,8 @@ mod tests { close: 5.943, volume: 100.0, }, - ]); - let series = OHLCVSeries::from_data(&data); + ]; + let series = OHLCVSeries::from(data); let (long_signal, short_signal) = pulse.assess(&series); diff --git a/ta_lib/strategies/pulse/src/dumb.rs b/ta_lib/strategies/pulse/src/dumb.rs index 0ef0a11f..ad643b64 100644 --- a/ta_lib/strategies/pulse/src/dumb.rs +++ b/ta_lib/strategies/pulse/src/dumb.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; pub struct DumbPulse { period: usize, diff --git a/ta_lib/strategies/pulse/src/lib.rs b/ta_lib/strategies/pulse/src/lib.rs index 714735e0..0b6198ec 100644 --- a/ta_lib/strategies/pulse/src/lib.rs +++ b/ta_lib/strategies/pulse/src/lib.rs @@ -1,17 +1,19 @@ mod adx; -mod braid; mod chop; mod dumb; mod nvol; +mod sqz; mod tdfi; mod vo; mod wae; +mod yz; pub use adx::AdxPulse; -pub use braid::BraidPulse; pub use chop::ChopPulse; pub use dumb::DumbPulse; pub use nvol::NvolPulse; +pub use sqz::SqzPulse; pub use tdfi::TdfiPulse; pub use vo::VoPulse; pub use wae::WaePulse; +pub use yz::YzPulse; diff --git a/ta_lib/strategies/pulse/src/nvol.rs b/ta_lib/strategies/pulse/src/nvol.rs index e14bf034..8c10c99d 100644 --- a/ta_lib/strategies/pulse/src/nvol.rs +++ b/ta_lib/strategies/pulse/src/nvol.rs @@ -1,18 +1,19 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use volume::nvol; const NVOL_LINE: f32 = 100.0; pub struct NvolPulse { - smooth_type: Smooth, + smooth: Smooth, period: usize, } impl NvolPulse { - pub fn new(smooth_type: Smooth, period: f32) -> Self { + pub fn new(smooth: Smooth, period: f32) -> Self { Self { - smooth_type, + smooth, period: period as usize, } } @@ -24,7 +25,7 @@ impl Pulse for NvolPulse { } fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { - let nvol = nvol(data.volume(), self.smooth_type, self.period); + let nvol = nvol(data.volume(), self.smooth, self.period); (nvol.sgt(&NVOL_LINE), nvol.sgt(&NVOL_LINE)) } diff --git a/ta_lib/strategies/pulse/src/sqz.rs b/ta_lib/strategies/pulse/src/sqz.rs new file mode 100644 index 00000000..76ac24cb --- /dev/null +++ b/ta_lib/strategies/pulse/src/sqz.rs @@ -0,0 +1,55 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; +use volatility::{bb, kch}; + +pub struct SqzPulse { + source: SourceType, + smooth: Smooth, + period: usize, + smooth_atr: Smooth, + period_atr: usize, + factor_bb: f32, + factor_kch: f32, +} + +impl SqzPulse { + pub fn new( + source: SourceType, + smooth: Smooth, + period: f32, + smooth_atr: Smooth, + period_atr: f32, + factor_bb: f32, + factor_kch: f32, + ) -> Self { + Self { + source, + smooth, + period: period as usize, + smooth_atr, + period_atr: period_atr as usize, + factor_bb, + factor_kch, + } + } +} + +impl Pulse for SqzPulse { + fn lookback(&self) -> usize { + std::cmp::max(self.period, self.period_atr) + } + + fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { + let source = data.source(self.source); + let atr = data.atr(self.smooth_atr, self.period_atr); + + let (upbb, _, lwbb) = bb(&source, self.smooth, self.period, self.factor_bb); + let (upkch, _, lwkch) = kch(&source, self.smooth, &atr, self.period, self.factor_kch); + + ( + upbb.sgt(&upkch) & lwbb.slt(&lwkch), + upbb.sgt(&upkch) & lwbb.slt(&lwkch), + ) + } +} diff --git a/ta_lib/strategies/pulse/src/tdfi.rs b/ta_lib/strategies/pulse/src/tdfi.rs index d2d99af8..763f1fe6 100644 --- a/ta_lib/strategies/pulse/src/tdfi.rs +++ b/ta_lib/strategies/pulse/src/tdfi.rs @@ -1,22 +1,23 @@ use base::prelude::*; use core::prelude::*; use momentum::tdfi; +use timeseries::prelude::*; const TDFI_UPPER_LINE: f32 = 0.05; const TDFI_LOWER_LINE: f32 = -0.05; pub struct TdfiPulse { - source_type: SourceType, - smooth_type: Smooth, + source: SourceType, + smooth: Smooth, period: usize, n: usize, } impl TdfiPulse { - pub fn new(source_type: SourceType, smooth_type: Smooth, period: f32, n: f32) -> Self { + pub fn new(source: SourceType, smooth: Smooth, period: f32, n: f32) -> Self { Self { - source_type, - smooth_type, + source, + smooth, period: period as usize, n: n as usize, } @@ -29,12 +30,7 @@ impl Pulse for TdfiPulse { } fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { - let tdfi = tdfi( - &data.source(self.source_type), - self.smooth_type, - self.period, - self.n, - ); + let tdfi = tdfi(&data.source(self.source), self.smooth, self.period, self.n); (tdfi.sgt(&TDFI_UPPER_LINE), tdfi.slt(&TDFI_LOWER_LINE)) } @@ -43,12 +39,11 @@ impl Pulse for TdfiPulse { #[cfg(test)] mod tests { use super::*; - use std::collections::VecDeque; #[test] fn test_pulse_tdfi() { let pulse = TdfiPulse::new(SourceType::CLOSE, Smooth::TEMA, 6.0, 3.0); - let data = VecDeque::from([ + let data = vec![ OHLCV { ts: 1679825700, open: 5.993, @@ -129,8 +124,8 @@ mod tests { close: 5.943, volume: 100.0, }, - ]); - let series = OHLCVSeries::from_data(&data); + ]; + let series = OHLCVSeries::from(data); let (long_signal, short_signal) = pulse.assess(&series); diff --git a/ta_lib/strategies/pulse/src/vo.rs b/ta_lib/strategies/pulse/src/vo.rs index c62a1e6f..b15ae586 100644 --- a/ta_lib/strategies/pulse/src/vo.rs +++ b/ta_lib/strategies/pulse/src/vo.rs @@ -1,36 +1,33 @@ use base::prelude::*; use core::prelude::*; -use volume::vo; +use timeseries::prelude::*; pub struct VoPulse { - smooth_type: Smooth, - fast_period: usize, - slow_period: usize, + smooth: Smooth, + period_fast: usize, + period_slow: usize, } impl VoPulse { - pub fn new(smooth_type: Smooth, fast_period: f32, slow_period: f32) -> Self { + pub fn new(smooth: Smooth, period_fast: f32, period_slow: f32) -> Self { Self { - smooth_type, - fast_period: fast_period as usize, - slow_period: slow_period as usize, + smooth, + period_fast: period_fast as usize, + period_slow: period_slow as usize, } } } impl Pulse for VoPulse { fn lookback(&self) -> usize { - std::cmp::max(self.fast_period, self.slow_period) + std::cmp::max(self.period_fast, self.period_slow) } fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { - let vo = vo( - data.volume(), - self.smooth_type, - self.fast_period, - self.slow_period, - ); + let vo = data + .volume() + .spread_pct(self.smooth, self.period_fast, self.period_slow); - (vo.sgt(&ZERO_LINE), vo.sgt(&ZERO_LINE)) + (vo.sgt(&ZERO), vo.sgt(&ZERO)) } } diff --git a/ta_lib/strategies/pulse/src/wae.rs b/ta_lib/strategies/pulse/src/wae.rs index 8016614e..691c9b0a 100644 --- a/ta_lib/strategies/pulse/src/wae.rs +++ b/ta_lib/strategies/pulse/src/wae.rs @@ -1,74 +1,67 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use volatility::bb; pub struct WaePulse { - smooth_type: Smooth, - fast_period: usize, - slow_period: usize, + source: SourceType, + smooth: Smooth, + period_fast: usize, + period_slow: usize, smooth_bb: Smooth, - bb_period: usize, + period_bb: usize, factor: f32, strength: f32, - atr_period: usize, - dz_factor: f32, } impl WaePulse { pub fn new( - smooth_type: Smooth, - fast_period: f32, - slow_period: f32, + source: SourceType, + smooth: Smooth, + period_fast: f32, + period_slow: f32, smooth_bb: Smooth, - bb_period: f32, + period_bb: f32, factor: f32, strength: f32, - atr_period: f32, - dz_factor: f32, ) -> Self { Self { - smooth_type, - fast_period: fast_period as usize, - slow_period: slow_period as usize, + source, + smooth, + period_fast: period_fast as usize, + period_slow: period_slow as usize, smooth_bb, - bb_period: bb_period as usize, + period_bb: period_bb as usize, factor, strength, - atr_period: atr_period as usize, - dz_factor, } } } impl Pulse for WaePulse { fn lookback(&self) -> usize { - let mut adj_lookback = std::cmp::max(self.fast_period, self.slow_period); - adj_lookback = std::cmp::max(adj_lookback, self.bb_period); - std::cmp::max(adj_lookback, self.atr_period) + let adj_lookback = std::cmp::max(self.period_fast, self.period_slow); + std::cmp::max(adj_lookback, self.period_bb) } fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { - let dz = data.atr(self.atr_period) * self.dz_factor; + let source = data.source(self.source); - let (upper_bb, _, lower_bb) = bb(data.close(), self.smooth_bb, self.bb_period, self.factor); - let e = upper_bb - lower_bb; + let (upper_bb, _, lower_bb) = bb(&source, self.smooth_bb, self.period_bb, self.factor); - let prev_close = data.close().shift(1); + let e = upper_bb - lower_bb; - let macd_line = data.close().smooth(self.smooth_type, self.fast_period) - - data.close().smooth(self.smooth_type, self.slow_period); - let prev_macd_line = prev_close.smooth(self.smooth_type, self.fast_period) - - prev_close.smooth(self.smooth_type, self.slow_period); - let t = (macd_line - prev_macd_line) * self.strength; + let diff = + self.strength * source.spread_diff(self.smooth, self.period_fast, self.period_slow, 1); let zero = Series::zero(data.len()); - let up = iff!(t.sgte(&ZERO), t, zero); - let down = iff!(t.slt(&ZERO), t.negate(), zero); + let up = iff!(diff.sgte(&ZERO), diff, zero); + let down = iff!(diff.slt(&ZERO), diff.negate(), zero); ( - up.sgt(&up.shift(1)) & up.sgt(&e) & up.sgt(&dz), - down.sgt(&down.shift(1)) & down.sgt(&e) & down.sgt(&dz), + up.sgt(&up.shift(1)) & up.sgt(&e), + down.sgt(&down.shift(1)) & down.sgt(&e), ) } } diff --git a/ta_lib/strategies/pulse/src/yz.rs b/ta_lib/strategies/pulse/src/yz.rs new file mode 100644 index 00000000..ae6263e4 --- /dev/null +++ b/ta_lib/strategies/pulse/src/yz.rs @@ -0,0 +1,39 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; +use volatility::yz; + +pub struct YzPulse { + period: usize, + smooth_signal: Smooth, + period_signal: usize, +} + +impl YzPulse { + pub fn new(period: f32, smooth_signal: Smooth, period_signal: f32) -> Self { + Self { + period: period as usize, + smooth_signal, + period_signal: period_signal as usize, + } + } +} + +impl Pulse for YzPulse { + fn lookback(&self) -> usize { + std::cmp::max(self.period, self.period_signal) + } + + fn assess(&self, data: &OHLCVSeries) -> (Series, Series) { + let yz = yz( + data.open(), + data.high(), + data.low(), + data.close(), + self.period, + ); + let signal = yz.smooth(self.smooth_signal, self.period_signal); + + (yz.sgt(&signal), yz.sgt(&signal)) + } +} diff --git a/ta_lib/strategies/signal/Cargo.toml b/ta_lib/strategies/signal/Cargo.toml index cd873e14..e92226eb 100644 --- a/ta_lib/strategies/signal/Cargo.toml +++ b/ta_lib/strategies/signal/Cargo.toml @@ -18,4 +18,9 @@ trend = { path = "../../indicators/trend" } momentum = { path = "../../indicators/momentum" } volatility = { path = "../../indicators/volatility" } volume = { path = "../../indicators/volume" } -candlestick = { path = "../../patterns/candlestick" } \ No newline at end of file +candlestick = { path = "../../patterns/candlestick" } +osc = { path = "../../patterns/osc" } +channel = { path = "../../patterns/channel" } +bands = { path = "../../patterns/bands" } +trail = { path = "../../patterns/trail" } +timeseries = { path = "../../timeseries" } \ No newline at end of file diff --git a/ta_lib/strategies/signal/src/bb/macd_bb.rs b/ta_lib/strategies/signal/src/bb/macd.rs similarity index 93% rename from ta_lib/strategies/signal/src/bb/macd_bb.rs rename to ta_lib/strategies/signal/src/bb/macd.rs index 9158444c..aa568ea5 100644 --- a/ta_lib/strategies/signal/src/bb/macd_bb.rs +++ b/ta_lib/strategies/signal/src/bb/macd.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::macd; +use timeseries::prelude::*; use volatility::bb; pub struct MacdBbSignal { @@ -45,7 +46,7 @@ impl Signal for MacdBbSignal { std::cmp::max(adj_lookback, self.bb_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (macd_line, _, _) = macd( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/signal/src/bb/mod.rs b/ta_lib/strategies/signal/src/bb/mod.rs index 6e326789..f6eebb28 100644 --- a/ta_lib/strategies/signal/src/bb/mod.rs +++ b/ta_lib/strategies/signal/src/bb/mod.rs @@ -1,5 +1,5 @@ -mod macd_bb; -mod vwap_bb; +mod macd; +mod vwap; -pub use macd_bb::MacdBbSignal; -pub use vwap_bb::VwapBbSignal; +pub use macd::MacdBbSignal; +pub use vwap::VwapBbSignal; diff --git a/ta_lib/strategies/signal/src/bb/vwap_bb.rs b/ta_lib/strategies/signal/src/bb/vwap.rs similarity index 90% rename from ta_lib/strategies/signal/src/bb/vwap_bb.rs rename to ta_lib/strategies/signal/src/bb/vwap.rs index 6088bc27..165076a0 100644 --- a/ta_lib/strategies/signal/src/bb/vwap_bb.rs +++ b/ta_lib/strategies/signal/src/bb/vwap.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use volatility::bb; use volume::vwap; @@ -34,7 +35,7 @@ impl Signal for VwapBbSignal { std::cmp::max(self.period, self.bb_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let vwap = vwap(&data.source(self.source_type), data.volume()); let (upper_bb, _, lower_bb) = bb(&vwap, self.bb_smooth, self.bb_period, self.factor); diff --git a/ta_lib/strategies/signal/src/breakout/dch_ma2_breakout.rs b/ta_lib/strategies/signal/src/breakout/dch_ma2.rs similarity index 64% rename from ta_lib/strategies/signal/src/breakout/dch_ma2_breakout.rs rename to ta_lib/strategies/signal/src/breakout/dch_ma2.rs index 0300d013..2493e3a5 100644 --- a/ta_lib/strategies/signal/src/breakout/dch_ma2_breakout.rs +++ b/ta_lib/strategies/signal/src/breakout/dch_ma2.rs @@ -1,10 +1,11 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; +use timeseries::prelude::*; use volatility::dch; pub struct DchMa2BreakoutSignal { - source_type: SourceType, + source: SourceType, dch_period: usize, ma: MovingAverageType, fast_period: usize, @@ -13,14 +14,14 @@ pub struct DchMa2BreakoutSignal { impl DchMa2BreakoutSignal { pub fn new( - source_type: SourceType, + source: SourceType, dch_period: f32, ma: MovingAverageType, fast_period: f32, slow_period: f32, ) -> Self { Self { - source_type, + source, dch_period: dch_period as usize, ma, fast_period: fast_period as usize, @@ -35,15 +36,16 @@ impl Signal for DchMa2BreakoutSignal { std::cmp::max(adj_lookback, self.dch_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (upper_band, _, lower_band) = dch(data.high(), data.low(), self.dch_period); - let ma_short = ma_indicator(&self.ma, data, self.source_type, self.fast_period); - let ma_long = ma_indicator(&self.ma, data, self.source_type, self.slow_period); + let ma_short = ma_indicator(&self.ma, data, self.source, self.fast_period); + let ma_long = ma_indicator(&self.ma, data, self.source, self.slow_period); + let source = data.close(); ( - data.close().sgt(&upper_band.shift(1)) & ma_short.sgt(&ma_long), - data.close().slt(&lower_band.shift(1)) & ma_short.slt(&ma_long), + source.sgt(&upper_band.shift(1)) & ma_short.sgt(&ma_long), + source.slt(&lower_band.shift(1)) & ma_short.slt(&ma_long), ) } } diff --git a/ta_lib/strategies/signal/src/breakout/mod.rs b/ta_lib/strategies/signal/src/breakout/mod.rs index 46f9ac52..616f29b9 100644 --- a/ta_lib/strategies/signal/src/breakout/mod.rs +++ b/ta_lib/strategies/signal/src/breakout/mod.rs @@ -1,3 +1,3 @@ -mod dch_ma2_breakout; +mod dch_ma2; -pub use dch_ma2_breakout::DchMa2BreakoutSignal; +pub use dch_ma2::DchMa2BreakoutSignal; diff --git a/ta_lib/strategies/signal/src/pattern/macd_colorswitch.rs b/ta_lib/strategies/signal/src/colorswitch/macd.rs similarity index 90% rename from ta_lib/strategies/signal/src/pattern/macd_colorswitch.rs rename to ta_lib/strategies/signal/src/colorswitch/macd.rs index b4ceae1e..ba878dab 100644 --- a/ta_lib/strategies/signal/src/pattern/macd_colorswitch.rs +++ b/ta_lib/strategies/signal/src/colorswitch/macd.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::macd; +use timeseries::prelude::*; pub struct MacdColorSwitchSignal { source_type: SourceType, @@ -34,7 +35,7 @@ impl Signal for MacdColorSwitchSignal { std::cmp::max(adj_lookback, self.signal_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (_, _, histogram) = macd( &data.source(self.source_type), self.smooth_type, @@ -48,11 +49,11 @@ impl Signal for MacdColorSwitchSignal { let back_3_histogram = histogram.shift(3); ( - histogram.slt(&ZERO_LINE) + histogram.slt(&ZERO) & histogram.sgt(&prev_histogram) & prev_histogram.slt(&back_2_histogram) & back_2_histogram.slt(&back_3_histogram), - histogram.sgt(&ZERO_LINE) + histogram.sgt(&ZERO) & histogram.slt(&prev_histogram) & prev_histogram.sgt(&back_2_histogram) & back_2_histogram.sgt(&back_3_histogram), diff --git a/ta_lib/strategies/signal/src/colorswitch/mod.rs b/ta_lib/strategies/signal/src/colorswitch/mod.rs new file mode 100644 index 00000000..317e0503 --- /dev/null +++ b/ta_lib/strategies/signal/src/colorswitch/mod.rs @@ -0,0 +1,3 @@ +mod macd; + +pub use macd::MacdColorSwitchSignal; diff --git a/ta_lib/strategies/signal/src/contrarian/kch_a.rs b/ta_lib/strategies/signal/src/contrarian/kch_a.rs new file mode 100644 index 00000000..31fce2bf --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/kch_a.rs @@ -0,0 +1,207 @@ +use base::prelude::*; +use channel::a; +use core::prelude::*; +use timeseries::prelude::*; +use volatility::kch; + +pub struct KchASignal { + source: SourceType, + smooth: Smooth, + period: usize, + smooth_atr: Smooth, + period_atr: usize, + factor: f32, +} + +impl KchASignal { + pub fn new( + source: SourceType, + smooth: Smooth, + period: f32, + smooth_atr: Smooth, + period_atr: f32, + factor: f32, + ) -> Self { + Self { + source, + smooth, + period: period as usize, + smooth_atr, + period_atr: period_atr as usize, + factor, + } + } +} + +impl Signal for KchASignal { + fn lookback(&self) -> usize { + std::cmp::max(self.period, self.period_atr) + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let source = data.source(self.source); + let atr = data.atr(self.smooth_atr, self.period_atr); + let (upper, _, lower) = kch(&source, self.smooth, &atr, self.period, self.factor); + + a!(source, upper, lower) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_kch_a_ults_signal() { + let signal = KchASignal::new(SourceType::HLC3, Smooth::ULTS, 5.0, Smooth::ULTS, 5.0, 0.3); + let data = vec![ + OHLCV { + ts: 1679827200, + open: 0.29437, + high: 0.29606, + low: 0.29415, + close: 0.29456, + volume: 100.0, + }, + OHLCV { + ts: 1679827500, + open: 0.29456, + high: 0.29623, + low: 0.29456, + close: 0.29603, + volume: 100.0, + }, + OHLCV { + ts: 1679827800, + open: 0.29603, + high: 0.29620, + low: 0.29263, + close: 0.29263, + volume: 100.0, + }, + OHLCV { + ts: 1679828100, + open: 0.29263, + high: 0.29329, + low: 0.28850, + close: 0.28877, + volume: 100.0, + }, + OHLCV { + ts: 1679828400, + open: 0.28877, + high: 0.29104, + low: 0.28599, + close: 0.29085, + volume: 100.0, + }, + OHLCV { + ts: 1679828700, + open: 0.29085, + high: 0.29393, + low: 0.29085, + close: 0.29241, + volume: 100.0, + }, + OHLCV { + ts: 1679829000, + open: 0.29241, + high: 0.29318, + low: 0.29202, + close: 0.29287, + volume: 100.0, + }, + OHLCV { + ts: 1679829300, + open: 0.29287, + high: 0.29355, + low: 0.29223, + close: 0.29304, + volume: 100.0, + }, + OHLCV { + ts: 1679829600, + open: 0.29304, + high: 0.29305, + low: 0.29130, + close: 0.29153, + volume: 100.0, + }, + OHLCV { + ts: 1679829900, + open: 0.29153, + high: 0.29216, + low: 0.28969, + close: 0.28991, + volume: 100.0, + }, + OHLCV { + ts: 1679830200, + open: 0.28991, + high: 0.29068, + low: 0.28866, + close: 0.28879, + volume: 100.0, + }, + OHLCV { + ts: 1679830500, + open: 0.28879, + high: 0.28941, + low: 0.28830, + close: 0.28860, + volume: 100.0, + }, + OHLCV { + ts: 1679830800, + open: 0.28860, + high: 0.29012, + low: 0.28837, + close: 0.28940, + volume: 100.0, + }, + OHLCV { + ts: 1679831100, + open: 0.28940, + high: 0.29074, + low: 0.28940, + close: 0.29074, + volume: 100.0, + }, + OHLCV { + ts: 1679831400, + open: 0.29074, + high: 0.29270, + low: 0.29074, + close: 0.29270, + volume: 100.0, + }, + OHLCV { + ts: 1679831700, + open: 0.29270, + high: 0.29419, + low: 0.29270, + close: 0.29390, + volume: 100.0, + }, + ]; + + let series = OHLCVSeries::from(data); + + let (long_signal, short_signal) = signal.trigger(&series); + + let expected_long_signal = vec![ + false, false, false, true, false, false, false, false, false, true, false, false, + false, false, false, false, + ]; + let expected_short_signal = vec![ + false, false, false, false, false, false, true, false, false, false, false, false, + false, false, true, false, + ]; + + let result_long_signal: Vec = long_signal.into(); + let result_short_signal: Vec = short_signal.into(); + + assert_eq!(result_long_signal, expected_long_signal); + assert_eq!(result_short_signal, expected_short_signal); + } +} diff --git a/ta_lib/strategies/signal/src/contrarian/kch_c.rs b/ta_lib/strategies/signal/src/contrarian/kch_c.rs new file mode 100644 index 00000000..d0d72819 --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/kch_c.rs @@ -0,0 +1,48 @@ +use base::prelude::*; +use channel::c; +use core::prelude::*; +use timeseries::prelude::*; +use volatility::kch; + +pub struct KchCSignal { + source: SourceType, + smooth: Smooth, + period: usize, + smooth_atr: Smooth, + period_atr: usize, + factor: f32, +} + +impl KchCSignal { + pub fn new( + source: SourceType, + smooth: Smooth, + period: f32, + smooth_atr: Smooth, + period_atr: f32, + factor: f32, + ) -> Self { + Self { + source, + smooth, + period: period as usize, + smooth_atr, + period_atr: period_atr as usize, + factor, + } + } +} + +impl Signal for KchCSignal { + fn lookback(&self) -> usize { + std::cmp::max(self.period, self.period_atr) + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let source = data.source(self.source); + let atr = data.atr(self.smooth_atr, self.period_atr); + let (upper, _, lower) = kch(&source, self.smooth, &atr, self.period, self.factor); + + c!(data.low(), data.high(), upper, lower) + } +} diff --git a/ta_lib/strategies/signal/src/contrarian/mod.rs b/ta_lib/strategies/signal/src/contrarian/mod.rs new file mode 100644 index 00000000..e8078f1d --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/mod.rs @@ -0,0 +1,21 @@ +mod kch_a; +mod kch_c; +mod rsi_c; +mod rsi_d; +mod rsi_nt; +mod rsi_u; +mod rsi_v; +mod snatr; +mod stoch_e; +mod tii_v; + +pub use kch_a::KchASignal; +pub use kch_c::KchCSignal; +pub use rsi_c::RsiCSignal; +pub use rsi_d::RsiDSignal; +pub use rsi_nt::RsiNtSignal; +pub use rsi_u::RsiUSignal; +pub use rsi_v::RsiVSignal; +pub use snatr::SnatrSignal; +pub use stoch_e::StochESignal; +pub use tii_v::TiiVSignal; diff --git a/ta_lib/strategies/signal/src/contrarian/rsi_c.rs b/ta_lib/strategies/signal/src/contrarian/rsi_c.rs new file mode 100644 index 00000000..a4773623 --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/rsi_c.rs @@ -0,0 +1,40 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::rsi; +use osc::c; +use timeseries::prelude::*; + +const RSI_UPPER_BARRIER: f32 = 70.0; +const RSI_LOWER_BARRIER: f32 = 30.0; + +pub struct RsiCSignal { + source: SourceType, + smooth: Smooth, + period: usize, + threshold: f32, +} + +impl RsiCSignal { + pub fn new(source: SourceType, smooth: Smooth, period: f32, threshold: f32) -> Self { + Self { + source, + smooth, + period: period as usize, + threshold, + } + } +} + +impl Signal for RsiCSignal { + fn lookback(&self) -> usize { + self.period + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + c!( + rsi(&data.source(self.source), self.smooth, self.period), + RSI_LOWER_BARRIER + self.threshold, + RSI_UPPER_BARRIER - self.threshold + ) + } +} diff --git a/ta_lib/strategies/signal/src/contrarian/rsi_d.rs b/ta_lib/strategies/signal/src/contrarian/rsi_d.rs new file mode 100644 index 00000000..07975243 --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/rsi_d.rs @@ -0,0 +1,55 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::rsi; +use timeseries::prelude::*; + +const RSI_UPPER_BARRIER: f32 = 80.0; +const RSI_LOWER_BARRIER: f32 = 20.0; + +pub struct RsiDSignal { + source: SourceType, + smooth: Smooth, + period: usize, + threshold: f32, +} + +impl RsiDSignal { + pub fn new(source: SourceType, smooth: Smooth, period: f32, threshold: f32) -> Self { + Self { + source, + smooth, + period: period as usize, + threshold, + } + } +} + +impl Signal for RsiDSignal { + fn lookback(&self) -> usize { + self.period + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let rsi = rsi(&data.source(self.source), self.smooth, self.period); + let lower_barrier = RSI_LOWER_BARRIER + self.threshold; + let upper_barrier = RSI_UPPER_BARRIER - self.threshold; + + let prev_rsi = rsi.shift(1); + let back_2_rsi = rsi.shift(2); + let back_3_rsi = rsi.shift(3); + let back_4_rsi = rsi.shift(4); + + ( + rsi.slt(&prev_rsi) + & prev_rsi.slt(&back_2_rsi) + & back_2_rsi.slt(&back_3_rsi) + & back_3_rsi.slt(&lower_barrier) + & back_4_rsi.sgt(&lower_barrier), + rsi.sgt(&prev_rsi) + & prev_rsi.sgt(&back_2_rsi) + & back_2_rsi.sgt(&back_3_rsi) + & back_3_rsi.sgt(&upper_barrier) + & back_4_rsi.slt(&upper_barrier), + ) + } +} diff --git a/ta_lib/strategies/signal/src/contrarian/rsi_nt.rs b/ta_lib/strategies/signal/src/contrarian/rsi_nt.rs new file mode 100644 index 00000000..363e33f2 --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/rsi_nt.rs @@ -0,0 +1,53 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::rsi; +use timeseries::prelude::*; + +const RSI_UPPER_BARRIER: f32 = 70.0; +const RSI_LOWER_BARRIER: f32 = 30.0; + +pub struct RsiNtSignal { + source: SourceType, + smooth: Smooth, + period: usize, + threshold: f32, +} + +impl RsiNtSignal { + pub fn new(source: SourceType, smooth: Smooth, period: f32, threshold: f32) -> Self { + Self { + source, + smooth, + period: period as usize, + threshold, + } + } +} + +impl Signal for RsiNtSignal { + fn lookback(&self) -> usize { + self.period + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let rsi = rsi(&data.source(self.source), self.smooth, self.period); + let low = data.low(); + let high = data.high(); + + let lower_barrier = RSI_LOWER_BARRIER + self.threshold; + let upper_barrier = RSI_UPPER_BARRIER - self.threshold; + + let prev_rsi = rsi.shift(1); + + ( + rsi.sgt(&prev_rsi) + & rsi.slt(&lower_barrier) + & prev_rsi.slt(&lower_barrier) + & low.slt(&low.shift(1)), + rsi.slt(&prev_rsi) + & rsi.sgt(&upper_barrier) + & prev_rsi.sgt(&upper_barrier) + & high.sgt(&high.shift(1)), + ) + } +} diff --git a/ta_lib/strategies/signal/src/contrarian/rsi_u.rs b/ta_lib/strategies/signal/src/contrarian/rsi_u.rs new file mode 100644 index 00000000..27125450 --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/rsi_u.rs @@ -0,0 +1,46 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::rsi; +use timeseries::prelude::*; + +pub struct RsiUSignal { + source: SourceType, + smooth: Smooth, + period: usize, + threshold: f32, +} + +impl RsiUSignal { + pub fn new(source: SourceType, smooth: Smooth, period: f32, threshold: f32) -> Self { + Self { + source, + smooth, + period: period as usize, + threshold, + } + } +} + +impl Signal for RsiUSignal { + fn lookback(&self) -> usize { + self.period + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let rsi = rsi(&data.source(self.source), self.smooth, self.period); + let prev_rsi = rsi.shift(1); + let back_2_rsi = rsi.shift(2); + let back_3_rsi = rsi.shift(3); + + ( + rsi.sgt(&prev_rsi) + & prev_rsi.seq(&back_2_rsi) + & back_2_rsi.slt(&back_3_rsi) + & rsi.slt(&NEUTRALITY), + rsi.slt(&prev_rsi) + & prev_rsi.seq(&back_2_rsi) + & back_2_rsi.sgt(&back_3_rsi) + & rsi.sgt(&NEUTRALITY), + ) + } +} diff --git a/ta_lib/strategies/signal/src/contrarian/rsi_v.rs b/ta_lib/strategies/signal/src/contrarian/rsi_v.rs new file mode 100644 index 00000000..70f465e3 --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/rsi_v.rs @@ -0,0 +1,40 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::rsi; +use osc::v; +use timeseries::prelude::*; + +const RSI_UPPER_BARRIER: f32 = 80.0; +const RSI_LOWER_BARRIER: f32 = 20.0; + +pub struct RsiVSignal { + source: SourceType, + smooth: Smooth, + period: usize, + threshold: f32, +} + +impl RsiVSignal { + pub fn new(source: SourceType, smooth: Smooth, period: f32, threshold: f32) -> Self { + Self { + source, + smooth, + period: period as usize, + threshold, + } + } +} + +impl Signal for RsiVSignal { + fn lookback(&self) -> usize { + self.period + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + v!( + rsi(&data.source(self.source), self.smooth, self.period), + RSI_LOWER_BARRIER + self.threshold, + RSI_UPPER_BARRIER - self.threshold + ) + } +} diff --git a/ta_lib/strategies/signal/src/reversal/snatr_reversal.rs b/ta_lib/strategies/signal/src/contrarian/snatr.rs similarity index 76% rename from ta_lib/strategies/signal/src/reversal/snatr_reversal.rs rename to ta_lib/strategies/signal/src/contrarian/snatr.rs index b8f5cb9e..4255567c 100644 --- a/ta_lib/strategies/signal/src/reversal/snatr_reversal.rs +++ b/ta_lib/strategies/signal/src/contrarian/snatr.rs @@ -1,18 +1,19 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use volatility::snatr; -const SNATR_UPPER_BARRIER: f32 = 0.8; -const SNATR_LOWER_BARRIER: f32 = 0.2; +const SNATR_UPPER_BARRIER: f32 = 80.; +const SNATR_LOWER_BARRIER: f32 = 20.; -pub struct SnatrReversalSignal { +pub struct SnatrSignal { smooth_type: Smooth, atr_period: usize, atr_smooth_period: usize, threshold: f32, } -impl SnatrReversalSignal { +impl SnatrSignal { pub fn new( smooth_type: Smooth, atr_period: f32, @@ -28,14 +29,14 @@ impl SnatrReversalSignal { } } -impl Signal for SnatrReversalSignal { +impl Signal for SnatrSignal { fn lookback(&self) -> usize { std::cmp::max(self.atr_period, self.atr_smooth_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let snatr = snatr( - &data.atr(self.atr_period), + &data.atr(Smooth::SMMA, self.atr_period), self.atr_period, self.smooth_type, self.atr_smooth_period, diff --git a/ta_lib/strategies/signal/src/contrarian/stoch_e.rs b/ta_lib/strategies/signal/src/contrarian/stoch_e.rs new file mode 100644 index 00000000..71241ce2 --- /dev/null +++ b/ta_lib/strategies/signal/src/contrarian/stoch_e.rs @@ -0,0 +1,64 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::stochosc; +use osc::c; +use timeseries::prelude::*; + +const STOCH_E_UPPER_BARRIER: f32 = 95.0; +const STOCH_E_LOWER_BARRIER: f32 = 5.0; +const STOCH_OVERBOUGHT: f32 = 70.0; +const STOCH_OVERSOLD: f32 = 30.0; + +pub struct StochESignal { + source: SourceType, + smooth: Smooth, + period: usize, + period_k: usize, + period_d: usize, + threshold: f32, +} + +impl StochESignal { + pub fn new( + source: SourceType, + smooth: Smooth, + period: f32, + period_k: f32, + period_d: f32, + threshold: f32, + ) -> Self { + Self { + source, + smooth, + period: period as usize, + period_k: period_k as usize, + period_d: period_d as usize, + threshold, + } + } +} + +impl Signal for StochESignal { + fn lookback(&self) -> usize { + std::cmp::max(std::cmp::max(self.period, self.period_k), self.period_d) + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let (k, _) = stochosc( + &data.source(self.source), + data.high(), + data.low(), + self.smooth, + self.period, + self.period_k, + self.period_d, + ); + + let (st_lg, st_sh) = c!(k, STOCH_E_LOWER_BARRIER, STOCH_E_UPPER_BARRIER); + + ( + st_lg & k.slt(&STOCH_OVERSOLD), + st_sh & k.sgt(&STOCH_OVERBOUGHT), + ) + } +} diff --git a/ta_lib/strategies/signal/src/pattern/tii_v.rs b/ta_lib/strategies/signal/src/contrarian/tii_v.rs similarity index 71% rename from ta_lib/strategies/signal/src/pattern/tii_v.rs rename to ta_lib/strategies/signal/src/contrarian/tii_v.rs index 6627378b..c9aaef57 100644 --- a/ta_lib/strategies/signal/src/pattern/tii_v.rs +++ b/ta_lib/strategies/signal/src/contrarian/tii_v.rs @@ -1,27 +1,23 @@ use base::prelude::*; use core::prelude::*; use momentum::tii; +use timeseries::prelude::*; const TII_UPPER_BARRIER: f32 = 100.0; const TII_LOWER_BARRIER: f32 = 0.0; pub struct TiiVSignal { - source_type: SourceType, - smooth_type: Smooth, + source: SourceType, + smooth: Smooth, major_period: usize, minor_period: usize, } impl TiiVSignal { - pub fn new( - source_type: SourceType, - smooth_type: Smooth, - major_period: f32, - minor_period: f32, - ) -> Self { + pub fn new(source: SourceType, smooth: Smooth, major_period: f32, minor_period: f32) -> Self { Self { - source_type, - smooth_type, + source, + smooth, major_period: major_period as usize, minor_period: minor_period as usize, } @@ -33,10 +29,10 @@ impl Signal for TiiVSignal { std::cmp::max(self.minor_period, self.major_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let tii = tii( - &data.source(self.source_type), - self.smooth_type, + &data.source(self.source), + self.smooth, self.major_period, self.minor_period, ); diff --git a/ta_lib/strategies/signal/src/flip/ce.rs b/ta_lib/strategies/signal/src/flip/ce.rs new file mode 100644 index 00000000..3f207e12 --- /dev/null +++ b/ta_lib/strategies/signal/src/flip/ce.rs @@ -0,0 +1,48 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; +use trail::f; +use trend::ce; + +pub struct CeFlipSignal { + source_type: SourceType, + period: usize, + smooth_atr: Smooth, + period_atr: usize, + factor: f32, +} + +impl CeFlipSignal { + pub fn new( + source_type: SourceType, + period: f32, + smooth_atr: Smooth, + period_atr: f32, + factor: f32, + ) -> Self { + Self { + source_type, + period: period as usize, + smooth_atr, + period_atr: period_atr as usize, + factor, + } + } +} + +impl Signal for CeFlipSignal { + fn lookback(&self) -> usize { + std::cmp::max(self.period, self.period_atr) + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let (direction, _) = ce( + &data.source(self.source_type), + &data.atr(self.smooth_atr, self.period_atr), + self.period, + self.factor, + ); + + f!(direction) + } +} diff --git a/ta_lib/strategies/signal/src/flip/ce_flip.rs b/ta_lib/strategies/signal/src/flip/ce_flip.rs deleted file mode 100644 index f1cac03d..00000000 --- a/ta_lib/strategies/signal/src/flip/ce_flip.rs +++ /dev/null @@ -1,36 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use trend::ce; - -pub struct CeFlipSignal { - period: usize, - atr_period: usize, - factor: f32, -} - -impl CeFlipSignal { - pub fn new(period: f32, atr_period: f32, factor: f32) -> Self { - Self { - period: period as usize, - atr_period: atr_period as usize, - factor, - } - } -} - -impl Signal for CeFlipSignal { - fn lookback(&self) -> usize { - std::cmp::max(self.period, self.atr_period) - } - - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let (direction, _) = ce( - data.close(), - &data.atr(self.atr_period), - self.period, - self.factor, - ); - - (direction.cross_under(&ZERO), direction.cross_over(&ZERO)) - } -} diff --git a/ta_lib/strategies/signal/src/flip/mod.rs b/ta_lib/strategies/signal/src/flip/mod.rs index 56f2478b..bc1b43d3 100644 --- a/ta_lib/strategies/signal/src/flip/mod.rs +++ b/ta_lib/strategies/signal/src/flip/mod.rs @@ -1,5 +1,5 @@ -mod ce_flip; -mod supertrend_flip; +mod ce; +mod supertrend; -pub use ce_flip::CeFlipSignal; -pub use supertrend_flip::SupertrendFlipSignal; +pub use ce::CeFlipSignal; +pub use supertrend::SupertrendFlipSignal; diff --git a/ta_lib/strategies/signal/src/flip/supertrend.rs b/ta_lib/strategies/signal/src/flip/supertrend.rs new file mode 100644 index 00000000..92457677 --- /dev/null +++ b/ta_lib/strategies/signal/src/flip/supertrend.rs @@ -0,0 +1,40 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; +use trail::f; +use trend::supertrend; + +pub struct SupertrendFlipSignal { + source: SourceType, + smooth_atr: Smooth, + period_atr: usize, + factor: f32, +} + +impl SupertrendFlipSignal { + pub fn new(source: SourceType, smooth_atr: Smooth, period_atr: f32, factor: f32) -> Self { + Self { + source, + smooth_atr, + period_atr: period_atr as usize, + factor, + } + } +} + +impl Signal for SupertrendFlipSignal { + fn lookback(&self) -> usize { + self.period_atr + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let (direction, _) = supertrend( + &data.source(self.source), + data.close(), + &data.atr(self.smooth_atr, self.period_atr), + self.factor, + ); + + f!(direction) + } +} diff --git a/ta_lib/strategies/signal/src/flip/supertrend_flip.rs b/ta_lib/strategies/signal/src/flip/supertrend_flip.rs deleted file mode 100644 index 12e8b9aa..00000000 --- a/ta_lib/strategies/signal/src/flip/supertrend_flip.rs +++ /dev/null @@ -1,357 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use trend::supertrend; - -pub struct SupertrendFlipSignal { - source_type: SourceType, - atr_period: usize, - factor: f32, -} - -impl SupertrendFlipSignal { - pub fn new(source_type: SourceType, atr_period: f32, factor: f32) -> Self { - Self { - source_type, - atr_period: atr_period as usize, - factor, - } - } -} - -impl Signal for SupertrendFlipSignal { - fn lookback(&self) -> usize { - self.atr_period - } - - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let (direction, _) = supertrend( - &data.source(self.source_type), - data.close(), - &data.atr(self.atr_period), - self.factor, - ); - - (direction.cross_over(&ZERO), direction.cross_under(&ZERO)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::VecDeque; - - #[test] - fn test_supertrend_flip_signal() { - let signal = SupertrendFlipSignal::new(SourceType::HL2, 3.0, 3.0); - let data = VecDeque::from([ - OHLCV { - ts: 1679827200, - open: 6.161, - high: 6.161, - low: 6.136, - close: 6.146, - volume: 100.0, - }, - OHLCV { - ts: 1679827500, - open: 6.146, - high: 6.150, - low: 6.135, - close: 6.148, - volume: 100.0, - }, - OHLCV { - ts: 1679827800, - open: 6.148, - high: 6.157, - low: 6.143, - close: 6.155, - volume: 100.0, - }, - OHLCV { - ts: 1679828100, - open: 6.155, - high: 6.174, - low: 6.155, - close: 6.174, - volume: 100.0, - }, - OHLCV { - ts: 1679828400, - open: 6.174, - high: 6.179, - low: 6.163, - close: 6.173, - volume: 100.0, - }, - OHLCV { - ts: 1679828700, - open: 6.173, - high: 6.192, - low: 6.170, - close: 6.172, - volume: 100.0, - }, - OHLCV { - ts: 1679829000, - open: 6.172, - high: 6.184, - low: 6.167, - close: 6.182, - volume: 100.0, - }, - OHLCV { - ts: 1679829300, - open: 6.182, - high: 6.183, - low: 6.170, - close: 6.176, - volume: 100.0, - }, - OHLCV { - ts: 1679829600, - open: 6.176, - high: 6.185, - low: 6.161, - close: 6.167, - volume: 100.0, - }, - OHLCV { - ts: 1679829900, - open: 6.167, - high: 6.193, - low: 6.165, - close: 6.193, - volume: 100.0, - }, - OHLCV { - ts: 1679830200, - open: 6.193, - high: 6.213, - low: 6.188, - close: 6.201, - volume: 100.0, - }, - OHLCV { - ts: 1679830500, - open: 6.201, - high: 6.201, - low: 6.183, - close: 6.198, - volume: 100.0, - }, - OHLCV { - ts: 1679830800, - open: 6.198, - high: 6.205, - low: 6.186, - close: 6.188, - volume: 100.0, - }, - OHLCV { - ts: 1679831100, - open: 6.188, - high: 6.188, - low: 6.168, - close: 6.174, - volume: 100.0, - }, - OHLCV { - ts: 1679831400, - open: 6.174, - high: 6.180, - low: 6.164, - close: 6.176, - volume: 100.0, - }, - OHLCV { - ts: 1679831700, - open: 6.176, - high: 6.194, - low: 6.176, - close: 6.191, - volume: 100.0, - }, - OHLCV { - ts: 1679832000, - open: 6.191, - high: 6.191, - low: 6.169, - close: 6.175, - volume: 100.0, - }, - OHLCV { - ts: 1679832300, - open: 6.175, - high: 6.184, - low: 6.175, - close: 6.184, - volume: 100.0, - }, - OHLCV { - ts: 1679832600, - open: 6.184, - high: 6.194, - low: 6.176, - close: 6.188, - volume: 100.0, - }, - OHLCV { - ts: 1679832900, - open: 6.188, - high: 6.188, - low: 6.171, - close: 6.179, - volume: 100.0, - }, - OHLCV { - ts: 1679833200, - open: 6.179, - high: 6.188, - low: 6.171, - close: 6.184, - volume: 100.0, - }, - OHLCV { - ts: 1679833500, - open: 6.184, - high: 6.195, - low: 6.182, - close: 6.195, - volume: 100.0, - }, - OHLCV { - ts: 1679833800, - open: 6.195, - high: 6.212, - low: 6.193, - close: 6.210, - volume: 100.0, - }, - OHLCV { - ts: 1679834100, - open: 6.210, - high: 6.210, - low: 6.180, - close: 6.192, - volume: 100.0, - }, - OHLCV { - ts: 1679834400, - open: 6.192, - high: 6.193, - low: 6.152, - close: 6.173, - volume: 100.0, - }, - OHLCV { - ts: 1679834700, - open: 6.173, - high: 6.178, - low: 6.161, - close: 6.174, - volume: 100.0, - }, - OHLCV { - ts: 1679835000, - open: 6.174, - high: 6.189, - low: 6.161, - close: 6.189, - volume: 100.0, - }, - OHLCV { - ts: 1679835300, - open: 6.189, - high: 6.197, - low: 6.183, - close: 6.194, - volume: 100.0, - }, - OHLCV { - ts: 1679835600, - open: 6.194, - high: 6.205, - low: 6.189, - close: 6.202, - volume: 100.0, - }, - OHLCV { - ts: 1679835900, - open: 6.202, - high: 6.232, - low: 6.193, - close: 6.231, - volume: 100.0, - }, - OHLCV { - ts: 1679836200, - open: 6.231, - high: 6.236, - low: 6.215, - close: 6.218, - volume: 100.0, - }, - OHLCV { - ts: 1679836500, - open: 6.218, - high: 6.222, - low: 6.205, - close: 6.208, - volume: 100.0, - }, - OHLCV { - ts: 1679836800, - open: 6.208, - high: 6.233, - low: 6.208, - close: 6.224, - volume: 100.0, - }, - OHLCV { - ts: 1679837100, - open: 6.224, - high: 6.231, - low: 6.213, - close: 6.220, - volume: 100.0, - }, - OHLCV { - ts: 1679837400, - open: 6.220, - high: 6.224, - low: 6.196, - close: 6.208, - volume: 100.0, - }, - OHLCV { - ts: 1679837700, - open: 6.208, - high: 6.219, - low: 6.202, - close: 6.204, - volume: 100.0, - }, - ]); - let series = OHLCVSeries::from_data(&data); - - let (long_signal, short_signal) = signal.generate(&series); - - let expected_long_signal = vec![ - false, false, false, false, false, false, false, false, false, false, false, false, - false, false, false, false, false, false, false, false, false, false, true, false, - false, false, false, false, false, false, false, false, false, false, false, false, - ]; - let expected_short_signal = vec![ - false, false, false, false, false, false, false, false, false, false, false, false, - false, false, false, false, false, false, false, false, false, false, false, false, - false, false, false, false, false, false, false, false, false, false, false, false, - ]; - - let result_long_signal: Vec = long_signal.into(); - let result_short_signal: Vec = short_signal.into(); - - assert_eq!(result_long_signal, expected_long_signal); - assert_eq!(result_short_signal, expected_short_signal); - } -} diff --git a/ta_lib/strategies/signal/src/lib.rs b/ta_lib/strategies/signal/src/lib.rs index c740f668..e71eb0ad 100644 --- a/ta_lib/strategies/signal/src/lib.rs +++ b/ta_lib/strategies/signal/src/lib.rs @@ -1,19 +1,25 @@ mod bb; mod breakout; +mod colorswitch; +mod contrarian; mod flip; mod ma; mod neutrality; mod pattern; -mod reversal; +mod pullback; mod signalline; +mod twolinescross; mod zerocross; pub use bb::*; pub use breakout::*; +pub use colorswitch::*; +pub use contrarian::*; pub use flip::*; pub use ma::*; pub use neutrality::*; pub use pattern::*; -pub use reversal::*; +pub use pullback::*; pub use signalline::*; +pub use twolinescross::*; pub use zerocross::*; diff --git a/ta_lib/strategies/signal/src/ma/ma2_rsi.rs b/ta_lib/strategies/signal/src/ma/ma2_rsi.rs index 07ba1b52..d8bbc60d 100644 --- a/ta_lib/strategies/signal/src/ma/ma2_rsi.rs +++ b/ta_lib/strategies/signal/src/ma/ma2_rsi.rs @@ -2,6 +2,7 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; use momentum::rsi; +use timeseries::prelude::*; const RSI_UPPER_BARRIER: f32 = 85.0; const RSI_LOWER_BARRIER: f32 = 15.0; @@ -44,7 +45,7 @@ impl Signal for Ma2RsiSignal { std::cmp::max(adj_lookback, self.rsi_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let rsi = rsi( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/signal/src/ma/ma3_cross.rs b/ta_lib/strategies/signal/src/ma/ma3_cross.rs index 069837b6..fd3ce32f 100644 --- a/ta_lib/strategies/signal/src/ma/ma3_cross.rs +++ b/ta_lib/strategies/signal/src/ma/ma3_cross.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; +use timeseries::prelude::*; pub struct Ma3CrossSignal { source_type: SourceType, @@ -34,7 +35,7 @@ impl Signal for Ma3CrossSignal { std::cmp::max(adjusted_lookback, self.medium_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let short_ma = ma_indicator(&self.ma, data, self.source_type, self.fast_period); let medium_ma = ma_indicator(&self.ma, data, self.source_type, self.medium_period); let long_ma = ma_indicator(&self.ma, data, self.source_type, self.slow_period); diff --git a/ta_lib/strategies/signal/src/ma/ma_cross.rs b/ta_lib/strategies/signal/src/ma/ma_cross.rs index de605cee..09d9a552 100644 --- a/ta_lib/strategies/signal/src/ma/ma_cross.rs +++ b/ta_lib/strategies/signal/src/ma/ma_cross.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; +use timeseries::prelude::*; pub struct MaCrossSignal { source_type: SourceType, @@ -23,7 +24,7 @@ impl Signal for MaCrossSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let ma = ma_indicator(&self.ma, data, self.source_type, self.period); (data.close().cross_over(&ma), data.close().cross_under(&ma)) diff --git a/ta_lib/strategies/signal/src/ma/ma_quadruple.rs b/ta_lib/strategies/signal/src/ma/ma_quadruple.rs index ec90e71d..3d3cc540 100644 --- a/ta_lib/strategies/signal/src/ma/ma_quadruple.rs +++ b/ta_lib/strategies/signal/src/ma/ma_quadruple.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; +use timeseries::prelude::*; pub struct MaQuadrupleSignal { source_type: SourceType, @@ -23,7 +24,7 @@ impl Signal for MaQuadrupleSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let ma = ma_indicator(&self.ma, data, self.source_type, self.period); let prev_ma = ma.shift(1); diff --git a/ta_lib/strategies/signal/src/ma/ma_surpass.rs b/ta_lib/strategies/signal/src/ma/ma_surpass.rs index 6d3c6dc0..2359fbf4 100644 --- a/ta_lib/strategies/signal/src/ma/ma_surpass.rs +++ b/ta_lib/strategies/signal/src/ma/ma_surpass.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; +use timeseries::prelude::*; pub struct MaSurpassSignal { source_type: SourceType, @@ -23,7 +24,7 @@ impl Signal for MaSurpassSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let ma = ma_indicator(&self.ma, data, self.source_type, self.period); let prev_ma = ma.shift(1); diff --git a/ta_lib/strategies/signal/src/ma/ma_testing_ground.rs b/ta_lib/strategies/signal/src/ma/ma_testing_ground.rs index d9a1f9b1..945d461b 100644 --- a/ta_lib/strategies/signal/src/ma/ma_testing_ground.rs +++ b/ta_lib/strategies/signal/src/ma/ma_testing_ground.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use indicator::{ma_indicator, MovingAverageType}; +use timeseries::prelude::*; pub struct MaTestingGroundSignal { source_type: SourceType, @@ -23,7 +24,7 @@ impl Signal for MaTestingGroundSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let ma = ma_indicator(&self.ma, data, self.source_type, self.period); let prev_ma = ma.shift(1); diff --git a/ta_lib/strategies/signal/src/ma/vwap_cross.rs b/ta_lib/strategies/signal/src/ma/vwap_cross.rs index 36340b8b..5de9a799 100644 --- a/ta_lib/strategies/signal/src/ma/vwap_cross.rs +++ b/ta_lib/strategies/signal/src/ma/vwap_cross.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use volume::vwap; pub struct VwapCrossSignal { @@ -21,7 +22,7 @@ impl Signal for VwapCrossSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let vwap = vwap(&data.source(self.source_type), data.volume()); ( diff --git a/ta_lib/strategies/signal/src/neutrality/dso_neutrality_cross.rs b/ta_lib/strategies/signal/src/neutrality/dso_cross.rs similarity index 84% rename from ta_lib/strategies/signal/src/neutrality/dso_neutrality_cross.rs rename to ta_lib/strategies/signal/src/neutrality/dso_cross.rs index 9283923a..9f49c845 100644 --- a/ta_lib/strategies/signal/src/neutrality/dso_neutrality_cross.rs +++ b/ta_lib/strategies/signal/src/neutrality/dso_cross.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::dso; +use timeseries::prelude::*; pub struct DsoNeutralityCrossSignal { source_type: SourceType, @@ -34,7 +35,7 @@ impl Signal for DsoNeutralityCrossSignal { std::cmp::max(period, self.d_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (k, _) = dso( &data.source(self.source_type), self.smooth_type, @@ -43,9 +44,6 @@ impl Signal for DsoNeutralityCrossSignal { self.d_period, ); - ( - k.cross_over(&NEUTRALITY_LINE), - k.cross_under(&NEUTRALITY_LINE), - ) + (k.cross_over(&NEUTRALITY), k.cross_under(&NEUTRALITY)) } } diff --git a/ta_lib/strategies/signal/src/neutrality/mod.rs b/ta_lib/strategies/signal/src/neutrality/mod.rs index 08eee2cb..7390fde7 100644 --- a/ta_lib/strategies/signal/src/neutrality/mod.rs +++ b/ta_lib/strategies/signal/src/neutrality/mod.rs @@ -1,11 +1,11 @@ -mod dso_neutrality_cross; -mod rsi_neutrality_cross; -mod rsi_neutrality_pullback; -mod rsi_neutrality_rejection; -mod tii_neutrality_cross; +mod dso_cross; +mod rsi_cross; +mod rsi_pullback; +mod rsi_rejection; +mod tii_cross; -pub use dso_neutrality_cross::DsoNeutralityCrossSignal; -pub use rsi_neutrality_cross::RsiNeutralityCrossSignal; -pub use rsi_neutrality_pullback::RsiNeutralityPullbackSignal; -pub use rsi_neutrality_rejection::RsiNeutralityRejectionSignal; -pub use tii_neutrality_cross::TiiNeutralityCrossSignal; +pub use dso_cross::DsoNeutralityCrossSignal; +pub use rsi_cross::RsiNeutralityCrossSignal; +pub use rsi_pullback::RsiNeutralityPullbackSignal; +pub use rsi_rejection::RsiNeutralityRejectionSignal; +pub use tii_cross::TiiNeutralityCrossSignal; diff --git a/ta_lib/strategies/signal/src/neutrality/rsi_neutrality_cross.rs b/ta_lib/strategies/signal/src/neutrality/rsi_cross.rs similarity index 62% rename from ta_lib/strategies/signal/src/neutrality/rsi_neutrality_cross.rs rename to ta_lib/strategies/signal/src/neutrality/rsi_cross.rs index 46fc16d5..d42333df 100644 --- a/ta_lib/strategies/signal/src/neutrality/rsi_neutrality_cross.rs +++ b/ta_lib/strategies/signal/src/neutrality/rsi_cross.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::rsi; +use timeseries::prelude::*; pub struct RsiNeutralityCrossSignal { source_type: SourceType, @@ -30,14 +31,14 @@ impl Signal for RsiNeutralityCrossSignal { self.rsi_period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let rsi = rsi( &data.source(self.source_type), self.smooth_type, self.rsi_period, ); - let upper_neutrality = NEUTRALITY_LINE + self.threshold; - let lower_neutrality = NEUTRALITY_LINE - self.threshold; + let upper_neutrality = NEUTRALITY + self.threshold; + let lower_neutrality = NEUTRALITY - self.threshold; let prev_rsi = rsi.shift(1); let back_2_rsi = rsi.shift(2); @@ -46,15 +47,15 @@ impl Signal for RsiNeutralityCrossSignal { ( rsi.sgt(&upper_neutrality) - & prev_rsi.sgt(&NEUTRALITY_LINE) - & back_2_rsi.slt(&NEUTRALITY_LINE) - & back_3_rsi.slt(&NEUTRALITY_LINE) - & back_4_rsi.slt(&NEUTRALITY_LINE), + & prev_rsi.sgt(&NEUTRALITY) + & back_2_rsi.slt(&NEUTRALITY) + & back_3_rsi.slt(&NEUTRALITY) + & back_4_rsi.slt(&NEUTRALITY), rsi.slt(&lower_neutrality) - & prev_rsi.slt(&NEUTRALITY_LINE) - & back_2_rsi.sgt(&NEUTRALITY_LINE) - & back_3_rsi.sgt(&NEUTRALITY_LINE) - & back_4_rsi.sgt(&NEUTRALITY_LINE), + & prev_rsi.slt(&NEUTRALITY) + & back_2_rsi.sgt(&NEUTRALITY) + & back_3_rsi.sgt(&NEUTRALITY) + & back_4_rsi.sgt(&NEUTRALITY), ) } } diff --git a/ta_lib/strategies/signal/src/neutrality/rsi_neutrality_pullback.rs b/ta_lib/strategies/signal/src/neutrality/rsi_pullback.rs similarity index 69% rename from ta_lib/strategies/signal/src/neutrality/rsi_neutrality_pullback.rs rename to ta_lib/strategies/signal/src/neutrality/rsi_pullback.rs index fc54a0a6..8e53fc0e 100644 --- a/ta_lib/strategies/signal/src/neutrality/rsi_neutrality_pullback.rs +++ b/ta_lib/strategies/signal/src/neutrality/rsi_pullback.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::rsi; +use timeseries::prelude::*; pub struct RsiNeutralityPullbackSignal { source_type: SourceType, @@ -30,30 +31,30 @@ impl Signal for RsiNeutralityPullbackSignal { self.rsi_period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let rsi = rsi( &data.source(self.source_type), self.smooth_type, self.rsi_period, ); - let upper_neutrality = NEUTRALITY_LINE + self.threshold; - let lower_neutrality = NEUTRALITY_LINE - self.threshold; + let upper_neutrality = NEUTRALITY + self.threshold; + let lower_neutrality = NEUTRALITY - self.threshold; let prev_rsi = rsi.shift(1); let back_2_rsi = rsi.shift(2); let back_3_rsi = rsi.shift(3); ( - prev_rsi.sgt(&NEUTRALITY_LINE) + prev_rsi.sgt(&NEUTRALITY) & prev_rsi.slt(&lower_neutrality) & prev_rsi.slt(&back_2_rsi) - & back_2_rsi.sgt(&NEUTRALITY_LINE) - & back_3_rsi.slt(&NEUTRALITY_LINE), - prev_rsi.slt(&NEUTRALITY_LINE) + & back_2_rsi.sgt(&NEUTRALITY) + & back_3_rsi.slt(&NEUTRALITY), + prev_rsi.slt(&NEUTRALITY) & prev_rsi.sgt(&upper_neutrality) & prev_rsi.sgt(&back_2_rsi) - & back_2_rsi.slt(&NEUTRALITY_LINE) - & back_3_rsi.sgt(&NEUTRALITY_LINE), + & back_2_rsi.slt(&NEUTRALITY) + & back_3_rsi.sgt(&NEUTRALITY), ) } } diff --git a/ta_lib/strategies/signal/src/neutrality/rsi_neutrality_rejection.rs b/ta_lib/strategies/signal/src/neutrality/rsi_rejection.rs similarity index 66% rename from ta_lib/strategies/signal/src/neutrality/rsi_neutrality_rejection.rs rename to ta_lib/strategies/signal/src/neutrality/rsi_rejection.rs index 85dd2343..0ac1a527 100644 --- a/ta_lib/strategies/signal/src/neutrality/rsi_neutrality_rejection.rs +++ b/ta_lib/strategies/signal/src/neutrality/rsi_rejection.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::rsi; +use timeseries::prelude::*; pub struct RsiNeutralityRejectionSignal { source_type: SourceType, @@ -30,14 +31,14 @@ impl Signal for RsiNeutralityRejectionSignal { self.rsi_period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let rsi = rsi( &data.source(self.source_type), self.smooth_type, self.rsi_period, ); - let upper_neutrality = NEUTRALITY_LINE + self.threshold; - let lower_neutrality = NEUTRALITY_LINE - self.threshold; + let upper_neutrality = NEUTRALITY + self.threshold; + let lower_neutrality = NEUTRALITY - self.threshold; let prev_rsi = rsi.shift(1); let back_2_rsi = rsi.shift(2); @@ -45,13 +46,13 @@ impl Signal for RsiNeutralityRejectionSignal { ( rsi.sgt(&upper_neutrality) - & prev_rsi.slt(&NEUTRALITY_LINE) - & back_2_rsi.sgt(&NEUTRALITY_LINE) - & back_3_rsi.sgt(&NEUTRALITY_LINE), + & prev_rsi.slt(&NEUTRALITY) + & back_2_rsi.sgt(&NEUTRALITY) + & back_3_rsi.sgt(&NEUTRALITY), rsi.slt(&lower_neutrality) - & prev_rsi.sgt(&NEUTRALITY_LINE) - & back_2_rsi.slt(&NEUTRALITY_LINE) - & back_3_rsi.slt(&NEUTRALITY_LINE), + & prev_rsi.sgt(&NEUTRALITY) + & back_2_rsi.slt(&NEUTRALITY) + & back_3_rsi.slt(&NEUTRALITY), ) } } diff --git a/ta_lib/strategies/signal/src/neutrality/tii_neutrality_cross.rs b/ta_lib/strategies/signal/src/neutrality/tii_cross.rs similarity index 82% rename from ta_lib/strategies/signal/src/neutrality/tii_neutrality_cross.rs rename to ta_lib/strategies/signal/src/neutrality/tii_cross.rs index 163a5e56..0d8e8566 100644 --- a/ta_lib/strategies/signal/src/neutrality/tii_neutrality_cross.rs +++ b/ta_lib/strategies/signal/src/neutrality/tii_cross.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::tii; +use timeseries::prelude::*; pub struct TiiNeutralityCrossSignal { source_type: SourceType, @@ -30,7 +31,7 @@ impl Signal for TiiNeutralityCrossSignal { std::cmp::max(self.minor_period, self.major_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let tii = tii( &data.source(self.source_type), self.smooth_type, @@ -38,9 +39,6 @@ impl Signal for TiiNeutralityCrossSignal { self.minor_period, ); - ( - tii.cross_over(&NEUTRALITY_LINE), - tii.cross_under(&NEUTRALITY_LINE), - ) + (tii.cross_over(&NEUTRALITY), tii.cross_under(&NEUTRALITY)) } } diff --git a/ta_lib/strategies/signal/src/pattern/ao_saucer.rs b/ta_lib/strategies/signal/src/pattern/ao_saucer.rs index 50d8a82d..ab8516db 100644 --- a/ta_lib/strategies/signal/src/pattern/ao_saucer.rs +++ b/ta_lib/strategies/signal/src/pattern/ao_saucer.rs @@ -1,59 +1,51 @@ use base::prelude::*; use core::prelude::*; -use momentum::ao; +use timeseries::prelude::*; pub struct AoSaucerSignal { - source_type: SourceType, - smooth_type: Smooth, - fast_period: usize, - slow_period: usize, + source: SourceType, + smooth: Smooth, + period_fast: usize, + period_slow: usize, } impl AoSaucerSignal { - pub fn new( - source_type: SourceType, - smooth_type: Smooth, - fast_period: f32, - slow_period: f32, - ) -> Self { + pub fn new(source: SourceType, smooth: Smooth, period_fast: f32, period_slow: f32) -> Self { Self { - source_type, - smooth_type, - fast_period: fast_period as usize, - slow_period: slow_period as usize, + source, + smooth, + period_fast: period_fast as usize, + period_slow: period_slow as usize, } } } impl Signal for AoSaucerSignal { fn lookback(&self) -> usize { - std::cmp::max(self.fast_period, self.slow_period) + std::cmp::max(self.period_fast, self.period_slow) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let ao = ao( - &data.source(self.source_type), - self.smooth_type, - self.fast_period, - self.slow_period, - ); + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let ao = data + .source(self.source) + .spread(self.smooth, self.period_fast, self.period_slow); let diff = &ao - ao.shift(1); let prev_diff = diff.shift(1); let back_2_diff = diff.shift(2); ( - ao.sgt(&ZERO_LINE) - & diff.sgt(&ZERO_LINE) + ao.sgt(&ZERO) + & diff.sgt(&ZERO) & diff.sgt(&prev_diff) - & prev_diff.slt(&ZERO_LINE) - & back_2_diff.slt(&ZERO_LINE) + & prev_diff.slt(&ZERO) + & back_2_diff.slt(&ZERO) & prev_diff.slt(&back_2_diff), - ao.slt(&ZERO_LINE) - & diff.slt(&ZERO_LINE) + ao.slt(&ZERO) + & diff.slt(&ZERO) & diff.slt(&prev_diff) - & prev_diff.sgt(&ZERO_LINE) - & back_2_diff.sgt(&ZERO_LINE) + & prev_diff.sgt(&ZERO) + & back_2_diff.sgt(&ZERO) & prev_diff.slt(&back_2_diff), ) } diff --git a/ta_lib/strategies/signal/src/pattern/candlestick_reversal.rs b/ta_lib/strategies/signal/src/pattern/candlestick_reversal.rs new file mode 100644 index 00000000..98a83f3c --- /dev/null +++ b/ta_lib/strategies/signal/src/pattern/candlestick_reversal.rs @@ -0,0 +1,26 @@ +use base::prelude::*; +use core::prelude::*; +use indicator::{candlestick_reversal_indicator, CandleReversalType}; +use timeseries::prelude::*; + +const DEFAULT_LOOKBACK: usize = 200; + +pub struct CandlestickReversalSignal { + candle: CandleReversalType, +} + +impl CandlestickReversalSignal { + pub fn new(candle: CandleReversalType) -> Self { + Self { candle } + } +} + +impl Signal for CandlestickReversalSignal { + fn lookback(&self) -> usize { + DEFAULT_LOOKBACK + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + candlestick_reversal_indicator(&self.candle, data) + } +} diff --git a/ta_lib/strategies/signal/src/pattern/candlestick_trend.rs b/ta_lib/strategies/signal/src/pattern/candlestick_trend.rs index b7c81ab1..7eba6d1e 100644 --- a/ta_lib/strategies/signal/src/pattern/candlestick_trend.rs +++ b/ta_lib/strategies/signal/src/pattern/candlestick_trend.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use indicator::{candlestick_trend_indicator, CandleTrendType}; +use timeseries::prelude::*; const DEFAULT_LOOKBACK: usize = 13; @@ -19,7 +20,7 @@ impl Signal for CandlestickTrendSignal { DEFAULT_LOOKBACK } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { candlestick_trend_indicator(&self.candle, data) } } diff --git a/ta_lib/strategies/signal/src/pattern/hl.rs b/ta_lib/strategies/signal/src/pattern/hl.rs index 5417a536..57f074b6 100644 --- a/ta_lib/strategies/signal/src/pattern/hl.rs +++ b/ta_lib/strategies/signal/src/pattern/hl.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; pub struct HighLowSignal { period: usize, @@ -18,7 +19,7 @@ impl Signal for HighLowSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { ( data.low().seq(&data.low().shift(1)), data.high().seq(&data.high().shift(1)), diff --git a/ta_lib/strategies/signal/src/pattern/mod.rs b/ta_lib/strategies/signal/src/pattern/mod.rs index f7701b1b..5449f7ac 100644 --- a/ta_lib/strategies/signal/src/pattern/mod.rs +++ b/ta_lib/strategies/signal/src/pattern/mod.rs @@ -1,13 +1,11 @@ mod ao_saucer; +mod candlestick_reversal; mod candlestick_trend; mod hl; -mod macd_colorswitch; -mod rsi_v; -mod tii_v; +mod spread; pub use ao_saucer::AoSaucerSignal; +pub use candlestick_reversal::CandlestickReversalSignal; pub use candlestick_trend::CandlestickTrendSignal; pub use hl::HighLowSignal; -pub use macd_colorswitch::MacdColorSwitchSignal; -pub use rsi_v::RsiVSignal; -pub use tii_v::TiiVSignal; +pub use spread::SpreadSignal; diff --git a/ta_lib/strategies/signal/src/pattern/rsi_v.rs b/ta_lib/strategies/signal/src/pattern/rsi_v.rs deleted file mode 100644 index 3a6aa837..00000000 --- a/ta_lib/strategies/signal/src/pattern/rsi_v.rs +++ /dev/null @@ -1,57 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use momentum::rsi; - -const RSI_UPPER_BARRIER: f32 = 80.0; -const RSI_LOWER_BARRIER: f32 = 20.0; - -pub struct RsiVSignal { - source_type: SourceType, - smooth_type: Smooth, - rsi_period: usize, - threshold: f32, -} - -impl RsiVSignal { - pub fn new( - source_type: SourceType, - smooth_type: Smooth, - rsi_period: f32, - threshold: f32, - ) -> Self { - Self { - source_type, - smooth_type, - rsi_period: rsi_period as usize, - threshold, - } - } -} - -impl Signal for RsiVSignal { - fn lookback(&self) -> usize { - self.rsi_period - } - - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let rsi = rsi( - &data.source(self.source_type), - self.smooth_type, - self.rsi_period, - ); - let lower_barrier = RSI_LOWER_BARRIER + self.threshold; - let upper_barrier = RSI_UPPER_BARRIER - self.threshold; - - let prev_rsi = rsi.shift(1); - let rsi_2_back = rsi.shift(2); - - ( - rsi.sgt(&lower_barrier) - & prev_rsi.slt(&RSI_LOWER_BARRIER) - & rsi_2_back.sgt(&RSI_LOWER_BARRIER), - rsi.slt(&upper_barrier) - & prev_rsi.sgt(&RSI_UPPER_BARRIER) - & rsi_2_back.slt(&RSI_UPPER_BARRIER), - ) - } -} diff --git a/ta_lib/strategies/signal/src/pattern/spread.rs b/ta_lib/strategies/signal/src/pattern/spread.rs new file mode 100644 index 00000000..973eb8f8 --- /dev/null +++ b/ta_lib/strategies/signal/src/pattern/spread.rs @@ -0,0 +1,40 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; + +pub struct SpreadSignal { + source: SourceType, + smooth: Smooth, + period_fast: usize, + period_slow: usize, +} + +impl SpreadSignal { + pub fn new(source: SourceType, smooth: Smooth, period_fast: f32, period_slow: f32) -> Self { + Self { + source, + smooth, + period_fast: period_fast as usize, + period_slow: period_slow as usize, + } + } +} + +impl Signal for SpreadSignal { + fn lookback(&self) -> usize { + std::cmp::max(self.period_fast, self.period_slow) + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let spread = + data.source(self.source) + .spread(self.smooth, self.period_fast, self.period_slow); + + let prev_spread = spread.shift(1); + + ( + spread.cross_over(&prev_spread), + spread.cross_under(&prev_spread), + ) + } +} diff --git a/ta_lib/strategies/signal/src/pullback/mod.rs b/ta_lib/strategies/signal/src/pullback/mod.rs new file mode 100644 index 00000000..82b0cf35 --- /dev/null +++ b/ta_lib/strategies/signal/src/pullback/mod.rs @@ -0,0 +1,3 @@ +mod supertrend; + +pub use supertrend::SupertrendPullbackSignal; diff --git a/ta_lib/strategies/signal/src/pullback/supertrend.rs b/ta_lib/strategies/signal/src/pullback/supertrend.rs new file mode 100644 index 00000000..2a3e087f --- /dev/null +++ b/ta_lib/strategies/signal/src/pullback/supertrend.rs @@ -0,0 +1,39 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; +use trail::p; +use trend::supertrend; + +pub struct SupertrendPullbackSignal { + source: SourceType, + smooth_atr: Smooth, + period_atr: usize, + factor: f32, +} + +impl SupertrendPullbackSignal { + pub fn new(source: SourceType, smooth_atr: Smooth, period_atr: f32, factor: f32) -> Self { + Self { + source, + smooth_atr, + period_atr: period_atr as usize, + factor, + } + } +} + +impl Signal for SupertrendPullbackSignal { + fn lookback(&self) -> usize { + self.period_atr + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let source = data.source(self.source); + let atr = data.atr(self.smooth_atr, self.period_atr); + let close = data.close(); + + let (_, trend) = supertrend(&source, close, &atr, self.factor); + + p!(trend, data.high(), data.low(), close) + } +} diff --git a/ta_lib/strategies/signal/src/reversal/mod.rs b/ta_lib/strategies/signal/src/reversal/mod.rs deleted file mode 100644 index ee85c8a3..00000000 --- a/ta_lib/strategies/signal/src/reversal/mod.rs +++ /dev/null @@ -1,7 +0,0 @@ -mod dmi_reversal; -mod snatr_reversal; -mod vi_reversal; - -pub use dmi_reversal::DmiReversalSignal; -pub use snatr_reversal::SnatrReversalSignal; -pub use vi_reversal::ViReversalSignal; diff --git a/ta_lib/strategies/signal/src/signalline/di_signalline.rs b/ta_lib/strategies/signal/src/signalline/di.rs similarity index 90% rename from ta_lib/strategies/signal/src/signalline/di_signalline.rs rename to ta_lib/strategies/signal/src/signalline/di.rs index 6d6b4c57..6de768f9 100644 --- a/ta_lib/strategies/signal/src/signalline/di_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/di.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::di; +use timeseries::prelude::*; pub struct DiSignalLineSignal { source_type: SourceType, @@ -30,7 +31,7 @@ impl Signal for DiSignalLineSignal { std::cmp::max(self.period, self.signal_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let di = di( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/signal/src/signalline/dso_signalline.rs b/ta_lib/strategies/signal/src/signalline/dso.rs similarity index 91% rename from ta_lib/strategies/signal/src/signalline/dso_signalline.rs rename to ta_lib/strategies/signal/src/signalline/dso.rs index 983925ce..f62b24ea 100644 --- a/ta_lib/strategies/signal/src/signalline/dso_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/dso.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::dso; +use timeseries::prelude::*; pub struct DsoSignalLineSignal { source_type: SourceType, @@ -34,7 +35,7 @@ impl Signal for DsoSignalLineSignal { std::cmp::max(period, self.d_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (k, d) = dso( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/signal/src/signalline/kst_signalline.rs b/ta_lib/strategies/signal/src/signalline/kst.rs similarity index 96% rename from ta_lib/strategies/signal/src/signalline/kst_signalline.rs rename to ta_lib/strategies/signal/src/signalline/kst.rs index 785f1b12..0e6c9743 100644 --- a/ta_lib/strategies/signal/src/signalline/kst_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/kst.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::kst; +use timeseries::prelude::*; pub struct KstSignalLineSignal { source_type: SourceType, @@ -58,7 +59,7 @@ impl Signal for KstSignalLineSignal { std::cmp::max(adjusted_lookback_seven, self.signal_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let kst = kst( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/signal/src/signalline/macd_signalline.rs b/ta_lib/strategies/signal/src/signalline/macd.rs similarity index 92% rename from ta_lib/strategies/signal/src/signalline/macd_signalline.rs rename to ta_lib/strategies/signal/src/signalline/macd.rs index a3415d32..da93ab61 100644 --- a/ta_lib/strategies/signal/src/signalline/macd_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/macd.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::macd; +use timeseries::prelude::*; pub struct MacdSignalLineSignal { source_type: SourceType, @@ -34,7 +35,7 @@ impl Signal for MacdSignalLineSignal { std::cmp::max(adj_lookback, self.signal_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (macd_line, signal_line, _) = macd( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/signal/src/signalline/mod.rs b/ta_lib/strategies/signal/src/signalline/mod.rs index b1757558..62c3c338 100644 --- a/ta_lib/strategies/signal/src/signalline/mod.rs +++ b/ta_lib/strategies/signal/src/signalline/mod.rs @@ -1,19 +1,19 @@ -mod di_signalline; -mod dso_signalline; -mod kst_signalline; -mod macd_signalline; -mod qstick_signalline; -mod rsi_signalline; -mod stoch_signalline; -mod trix_signalline; -mod tsi_signalline; +mod di; +mod dso; +mod kst; +mod macd; +mod qstick; +mod rsi; +mod stoch; +mod trix; +mod tsi; -pub use di_signalline::DiSignalLineSignal; -pub use dso_signalline::DsoSignalLineSignal; -pub use kst_signalline::KstSignalLineSignal; -pub use macd_signalline::MacdSignalLineSignal; -pub use qstick_signalline::QstickSignalLineSignal; -pub use rsi_signalline::RsiSignalLineSignal; -pub use stoch_signalline::StochSignalLineSignal; -pub use trix_signalline::TrixSignalLineSignal; -pub use tsi_signalline::TsiSignalLineSignal; +pub use di::DiSignalLineSignal; +pub use dso::DsoSignalLineSignal; +pub use kst::KstSignalLineSignal; +pub use macd::MacdSignalLineSignal; +pub use qstick::QstickSignalLineSignal; +pub use rsi::RsiSignalLineSignal; +pub use stoch::StochSignalLineSignal; +pub use trix::TrixSignalLineSignal; +pub use tsi::TsiSignalLineSignal; diff --git a/ta_lib/strategies/signal/src/signalline/qstick_signalline.rs b/ta_lib/strategies/signal/src/signalline/qstick.rs similarity index 86% rename from ta_lib/strategies/signal/src/signalline/qstick_signalline.rs rename to ta_lib/strategies/signal/src/signalline/qstick.rs index dcdca8d3..95b86aa0 100644 --- a/ta_lib/strategies/signal/src/signalline/qstick_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/qstick.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; -use trend::qstick; +use momentum::qstick; +use timeseries::prelude::*; pub struct QstickSignalLineSignal { smooth_type: Smooth, @@ -23,7 +24,7 @@ impl Signal for QstickSignalLineSignal { std::cmp::max(self.period, self.signal_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let qstick = qstick(data.open(), data.close(), self.smooth_type, self.period); let signal_line = qstick.smooth(self.smooth_type, self.signal_period); diff --git a/ta_lib/strategies/signal/src/signalline/rsi_signalline.rs b/ta_lib/strategies/signal/src/signalline/rsi.rs similarity index 88% rename from ta_lib/strategies/signal/src/signalline/rsi_signalline.rs rename to ta_lib/strategies/signal/src/signalline/rsi.rs index 88146a2d..4a6a0e06 100644 --- a/ta_lib/strategies/signal/src/signalline/rsi_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/rsi.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::rsi; +use timeseries::prelude::*; pub struct RsiSignalLineSignal { source_type: SourceType, @@ -36,7 +37,7 @@ impl Signal for RsiSignalLineSignal { std::cmp::max(self.rsi_period, self.smooth_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let rsi = rsi( &data.source(self.source_type), self.smooth_type, @@ -44,8 +45,8 @@ impl Signal for RsiSignalLineSignal { ); let rsi_ma = rsi.smooth(self.smooth_signal, self.smooth_period); - let upper_neutrality = NEUTRALITY_LINE + self.threshold; - let lower_neutrality = NEUTRALITY_LINE - self.threshold; + let upper_neutrality = NEUTRALITY + self.threshold; + let lower_neutrality = NEUTRALITY - self.threshold; let prev_rsi = rsi.shift(1); let back_2_rsi = rsi.shift(2); diff --git a/ta_lib/strategies/signal/src/signalline/stoch_signalline.rs b/ta_lib/strategies/signal/src/signalline/stoch.rs similarity index 90% rename from ta_lib/strategies/signal/src/signalline/stoch_signalline.rs rename to ta_lib/strategies/signal/src/signalline/stoch.rs index 7da9f2de..0b24f127 100644 --- a/ta_lib/strategies/signal/src/signalline/stoch_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/stoch.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::stochosc; +use timeseries::prelude::*; pub struct StochSignalLineSignal { smooth_type: Smooth, @@ -26,7 +27,7 @@ impl Signal for StochSignalLineSignal { std::cmp::max(adjusted_lookback, self.d_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (k, d) = stochosc( data.high(), data.low(), diff --git a/ta_lib/strategies/signal/src/signalline/trix_signalline.rs b/ta_lib/strategies/signal/src/signalline/trix.rs similarity index 90% rename from ta_lib/strategies/signal/src/signalline/trix_signalline.rs rename to ta_lib/strategies/signal/src/signalline/trix.rs index 99004d55..d013fb9c 100644 --- a/ta_lib/strategies/signal/src/signalline/trix_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/trix.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::trix; +use timeseries::prelude::*; pub struct TrixSignalLineSignal { source_type: SourceType, @@ -30,7 +31,7 @@ impl Signal for TrixSignalLineSignal { std::cmp::max(self.period, self.signal_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let trix = trix( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/signal/src/signalline/tsi_signalline.rs b/ta_lib/strategies/signal/src/signalline/tsi.rs similarity index 91% rename from ta_lib/strategies/signal/src/signalline/tsi_signalline.rs rename to ta_lib/strategies/signal/src/signalline/tsi.rs index 39271c6b..55f5aaaf 100644 --- a/ta_lib/strategies/signal/src/signalline/tsi_signalline.rs +++ b/ta_lib/strategies/signal/src/signalline/tsi.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::tsi; +use timeseries::prelude::*; pub struct TsiSignalLineSignal { source_type: SourceType, @@ -34,7 +35,7 @@ impl Signal for TsiSignalLineSignal { std::cmp::max(adj_lookback, self.signal_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let tsi = tsi( &data.source(self.source_type), self.smooth_type, diff --git a/ta_lib/strategies/signal/src/reversal/dmi_reversal.rs b/ta_lib/strategies/signal/src/twolinescross/dmi.rs similarity index 70% rename from ta_lib/strategies/signal/src/reversal/dmi_reversal.rs rename to ta_lib/strategies/signal/src/twolinescross/dmi.rs index 67aaee66..c4da7795 100644 --- a/ta_lib/strategies/signal/src/reversal/dmi_reversal.rs +++ b/ta_lib/strategies/signal/src/twolinescross/dmi.rs @@ -1,53 +1,51 @@ use base::prelude::*; use core::prelude::*; use momentum::dmi; +use timeseries::prelude::*; -pub struct DmiReversalSignal { - smooth_type: Smooth, - adx_period: usize, - di_period: usize, +pub struct Dmi2LinesCrossSignal { + smooth: Smooth, + period_adx: usize, + period_di: usize, } -impl DmiReversalSignal { - pub fn new(smooth_type: Smooth, adx_period: f32, di_period: f32) -> Self { +impl Dmi2LinesCrossSignal { + pub fn new(smooth: Smooth, period_adx: f32, period_di: f32) -> Self { Self { - smooth_type, - adx_period: adx_period as usize, - di_period: di_period as usize, + smooth, + period_adx: period_adx as usize, + period_di: period_di as usize, } } } -impl Signal for DmiReversalSignal { +impl Signal for Dmi2LinesCrossSignal { fn lookback(&self) -> usize { - std::cmp::max(self.adx_period, self.di_period) + std::cmp::max(self.period_adx, self.period_di) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let (_, di_plus, di_minus) = dmi( + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let (dip, dim, _) = dmi( data.high(), data.low(), - &data.atr(self.di_period), - self.smooth_type, - self.adx_period, - self.di_period, + &data.atr(self.smooth, self.period_di), + self.smooth, + self.period_adx, + self.period_di, ); - ( - di_plus.cross_over(&di_minus), - di_plus.cross_under(&di_minus), - ) + + (dip.cross_over(&dim), dip.cross_under(&dim)) } } #[cfg(test)] mod tests { use super::*; - use std::collections::VecDeque; #[test] - fn test_signal_dmi_reversal() { - let signal = DmiReversalSignal::new(Smooth::SMMA, 3.0, 3.0); - let data = VecDeque::from([ + fn test_signal_dmi_cross() { + let signal = Dmi2LinesCrossSignal::new(Smooth::SMMA, 3.0, 3.0); + let data = vec![ OHLCV { ts: 1679827200, open: 0.010631, @@ -112,10 +110,10 @@ mod tests { close: 0.010515, volume: 100.0, }, - ]); - let series = OHLCVSeries::from_data(&data); + ]; + let series = OHLCVSeries::from(data); - let (dip, dim) = signal.generate(&series); + let (dip, dim) = signal.trigger(&series); let expected_long_signal = vec![false, false, false, false, false, false, false, false]; let expected_short_signal = vec![false, false, false, false, false, true, false, false]; diff --git a/ta_lib/strategies/signal/src/twolinescross/mod.rs b/ta_lib/strategies/signal/src/twolinescross/mod.rs new file mode 100644 index 00000000..a67b4a50 --- /dev/null +++ b/ta_lib/strategies/signal/src/twolinescross/mod.rs @@ -0,0 +1,5 @@ +mod dmi; +mod vi; + +pub use dmi::Dmi2LinesCrossSignal; +pub use vi::Vi2LinesCrossSignal; diff --git a/ta_lib/strategies/signal/src/reversal/vi_reversal.rs b/ta_lib/strategies/signal/src/twolinescross/vi.rs similarity index 85% rename from ta_lib/strategies/signal/src/reversal/vi_reversal.rs rename to ta_lib/strategies/signal/src/twolinescross/vi.rs index 1aac539e..8088e852 100644 --- a/ta_lib/strategies/signal/src/reversal/vi_reversal.rs +++ b/ta_lib/strategies/signal/src/twolinescross/vi.rs @@ -1,31 +1,34 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use trend::vi; -pub struct ViReversalSignal { - atr_period: usize, +pub struct Vi2LinesCrossSignal { period: usize, + smooth_atr: Smooth, + period_atr: usize, } -impl ViReversalSignal { - pub fn new(atr_period: f32, period: f32) -> Self { +impl Vi2LinesCrossSignal { + pub fn new(period: f32, smooth_atr: Smooth, period_atr: f32) -> Self { Self { - atr_period: atr_period as usize, period: period as usize, + smooth_atr, + period_atr: period_atr as usize, } } } -impl Signal for ViReversalSignal { +impl Signal for Vi2LinesCrossSignal { fn lookback(&self) -> usize { - std::cmp::max(self.atr_period, self.period) + std::cmp::max(self.period_atr, self.period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (vip, vim) = vi( data.high(), data.low(), - &data.atr(self.atr_period), + &data.atr(self.smooth_atr, self.period_atr), self.period, ); @@ -36,12 +39,11 @@ impl Signal for ViReversalSignal { #[cfg(test)] mod tests { use super::*; - use std::collections::VecDeque; #[test] - fn test_signal_vi_reversal() { - let signal = ViReversalSignal::new(1.0, 2.0); - let data = VecDeque::from([ + fn test_signal_vi_cross() { + let signal = Vi2LinesCrossSignal::new(2.0, Smooth::SMMA, 1.0); + let data = vec![ OHLCV { ts: 1679827200, open: 4.8914, @@ -162,10 +164,10 @@ mod tests { close: 4.8925, volume: 100.0, }, - ]); - let series = OHLCVSeries::from_data(&data); + ]; + let series = OHLCVSeries::from(data); - let (vip, vim) = signal.generate(&series); + let (vip, vim) = signal.trigger(&series); let expected_long_signal = vec![ false, false, false, false, true, false, false, false, false, false, false, false, diff --git a/ta_lib/strategies/signal/src/zerocross/ao.rs b/ta_lib/strategies/signal/src/zerocross/ao.rs new file mode 100644 index 00000000..75e78860 --- /dev/null +++ b/ta_lib/strategies/signal/src/zerocross/ao.rs @@ -0,0 +1,35 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; + +pub struct AoZeroCrossSignal { + source: SourceType, + smooth: Smooth, + period_fast: usize, + period_slow: usize, +} + +impl AoZeroCrossSignal { + pub fn new(source: SourceType, smooth: Smooth, period_fast: f32, period_slow: f32) -> Self { + Self { + source, + smooth, + period_fast: period_fast as usize, + period_slow: period_slow as usize, + } + } +} + +impl Signal for AoZeroCrossSignal { + fn lookback(&self) -> usize { + std::cmp::max(self.period_fast, self.period_slow) + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let ao = data + .source(self.source) + .spread(self.smooth, self.period_fast, self.period_slow); + + (ao.cross_over(&ZERO), ao.cross_under(&ZERO)) + } +} diff --git a/ta_lib/strategies/signal/src/zerocross/ao_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/ao_zerocross.rs deleted file mode 100644 index 034e2d57..00000000 --- a/ta_lib/strategies/signal/src/zerocross/ao_zerocross.rs +++ /dev/null @@ -1,43 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use momentum::ao; - -pub struct AoZeroCrossSignal { - source_type: SourceType, - smooth_type: Smooth, - fast_period: usize, - slow_period: usize, -} - -impl AoZeroCrossSignal { - pub fn new( - source_type: SourceType, - smooth_type: Smooth, - fast_period: f32, - slow_period: f32, - ) -> Self { - Self { - source_type, - smooth_type, - fast_period: fast_period as usize, - slow_period: slow_period as usize, - } - } -} - -impl Signal for AoZeroCrossSignal { - fn lookback(&self) -> usize { - std::cmp::max(self.fast_period, self.slow_period) - } - - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let ao = ao( - &data.source(self.source_type), - self.smooth_type, - self.fast_period, - self.slow_period, - ); - - (ao.cross_over(&ZERO_LINE), ao.cross_under(&ZERO_LINE)) - } -} diff --git a/ta_lib/strategies/signal/src/zerocross/bop_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/bop.rs similarity index 80% rename from ta_lib/strategies/signal/src/zerocross/bop_zerocross.rs rename to ta_lib/strategies/signal/src/zerocross/bop.rs index 02e03f96..ed7bbc2e 100644 --- a/ta_lib/strategies/signal/src/zerocross/bop_zerocross.rs +++ b/ta_lib/strategies/signal/src/zerocross/bop.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::bop; +use timeseries::prelude::*; pub struct BopZeroCrossSignal { smooth_type: Smooth, @@ -21,7 +22,7 @@ impl Signal for BopZeroCrossSignal { self.smooth_period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let bop = bop( data.open(), data.high(), @@ -31,6 +32,6 @@ impl Signal for BopZeroCrossSignal { self.smooth_period, ); - (bop.cross_over(&ZERO_LINE), bop.cross_under(&ZERO_LINE)) + (bop.cross_over(&ZERO), bop.cross_under(&ZERO)) } } diff --git a/ta_lib/strategies/signal/src/zerocross/cc_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/cc.rs similarity index 87% rename from ta_lib/strategies/signal/src/zerocross/cc_zerocross.rs rename to ta_lib/strategies/signal/src/zerocross/cc.rs index 144826aa..5cfef2a8 100644 --- a/ta_lib/strategies/signal/src/zerocross/cc_zerocross.rs +++ b/ta_lib/strategies/signal/src/zerocross/cc.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::cc; +use timeseries::prelude::*; pub struct CcZeroCrossSignal { source_type: SourceType, @@ -34,7 +35,7 @@ impl Signal for CcZeroCrossSignal { std::cmp::max(adj_lookback, self.smooth_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let cc = cc( &data.source(self.source_type), self.fast_period, @@ -43,6 +44,6 @@ impl Signal for CcZeroCrossSignal { self.smooth_period, ); - (cc.cross_over(&ZERO_LINE), cc.cross_under(&ZERO_LINE)) + (cc.cross_over(&ZERO), cc.cross_under(&ZERO)) } } diff --git a/ta_lib/strategies/signal/src/zerocross/cfo_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/cfo.rs similarity index 76% rename from ta_lib/strategies/signal/src/zerocross/cfo_zerocross.rs rename to ta_lib/strategies/signal/src/zerocross/cfo.rs index cc86aa48..66a4f67c 100644 --- a/ta_lib/strategies/signal/src/zerocross/cfo_zerocross.rs +++ b/ta_lib/strategies/signal/src/zerocross/cfo.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::cfo; +use timeseries::prelude::*; pub struct CfoZeroCrossSignal { source_type: SourceType, @@ -21,9 +22,9 @@ impl Signal for CfoZeroCrossSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let cfo = cfo(&data.source(self.source_type), self.period); - (cfo.cross_over(&ZERO_LINE), cfo.cross_under(&ZERO_LINE)) + (cfo.cross_over(&ZERO), cfo.cross_under(&ZERO)) } } diff --git a/ta_lib/strategies/signal/src/zerocross/di_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/di.rs similarity index 80% rename from ta_lib/strategies/signal/src/zerocross/di_zerocross.rs rename to ta_lib/strategies/signal/src/zerocross/di.rs index 53918ac2..c3bbce1c 100644 --- a/ta_lib/strategies/signal/src/zerocross/di_zerocross.rs +++ b/ta_lib/strategies/signal/src/zerocross/di.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::di; +use timeseries::prelude::*; pub struct DiZeroCrossSignal { source_type: SourceType, @@ -23,13 +24,13 @@ impl Signal for DiZeroCrossSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let di = di( &data.source(self.source_type), self.smooth_type, self.period, ); - (di.cross_over(&ZERO_LINE), di.cross_under(&ZERO_LINE)) + (di.cross_over(&ZERO), di.cross_under(&ZERO)) } } diff --git a/ta_lib/strategies/signal/src/zerocross/macd_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/macd.rs similarity index 85% rename from ta_lib/strategies/signal/src/zerocross/macd_zerocross.rs rename to ta_lib/strategies/signal/src/zerocross/macd.rs index a4105cf9..8b6dae09 100644 --- a/ta_lib/strategies/signal/src/zerocross/macd_zerocross.rs +++ b/ta_lib/strategies/signal/src/zerocross/macd.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::macd; +use timeseries::prelude::*; pub struct MacdZeroCrossSignal { source_type: SourceType, @@ -34,7 +35,7 @@ impl Signal for MacdZeroCrossSignal { std::cmp::max(adj_lookback, self.signal_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let (macd_line, _, _) = macd( &data.source(self.source_type), self.smooth_type, @@ -43,9 +44,6 @@ impl Signal for MacdZeroCrossSignal { self.signal_period, ); - ( - macd_line.cross_over(&ZERO_LINE), - macd_line.cross_under(&ZERO_LINE), - ) + (macd_line.cross_over(&ZERO), macd_line.cross_under(&ZERO)) } } diff --git a/ta_lib/strategies/signal/src/zerocross/mad.rs b/ta_lib/strategies/signal/src/zerocross/mad.rs new file mode 100644 index 00000000..fb5622f9 --- /dev/null +++ b/ta_lib/strategies/signal/src/zerocross/mad.rs @@ -0,0 +1,35 @@ +use base::prelude::*; +use core::prelude::*; +use timeseries::prelude::*; + +pub struct MadZeroCrossSignal { + source: SourceType, + smooth: Smooth, + period_fast: usize, + period_slow: usize, +} + +impl MadZeroCrossSignal { + pub fn new(source: SourceType, smooth: Smooth, period_fast: f32, period_slow: f32) -> Self { + Self { + source, + smooth, + period_fast: period_fast as usize, + period_slow: period_slow as usize, + } + } +} + +impl Signal for MadZeroCrossSignal { + fn lookback(&self) -> usize { + std::cmp::max(self.period_fast, self.period_slow) + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let mad = + data.source(self.source) + .spread_pct(self.smooth, self.period_fast, self.period_slow); + + (mad.cross_over(&ZERO), mad.cross_under(&ZERO)) + } +} diff --git a/ta_lib/strategies/signal/src/zerocross/mod.rs b/ta_lib/strategies/signal/src/zerocross/mod.rs index d702c0e2..8b88582d 100644 --- a/ta_lib/strategies/signal/src/zerocross/mod.rs +++ b/ta_lib/strategies/signal/src/zerocross/mod.rs @@ -1,21 +1,23 @@ -mod ao_zerocross; -mod bop_zerocross; -mod cc_zerocross; -mod cfo_zerocross; -mod di_zerocross; -mod macd_zerocross; -mod qstick_zerocross; -mod roc_zerocross; -mod trix_zerocross; -mod tsi_zerocross; +mod ao; +mod bop; +mod cc; +mod cfo; +mod di; +mod macd; +mod mad; +mod qstick; +mod roc; +mod trix; +mod tsi; -pub use ao_zerocross::AoZeroCrossSignal; -pub use bop_zerocross::BopZeroCrossSignal; -pub use cc_zerocross::CcZeroCrossSignal; -pub use cfo_zerocross::CfoZeroCrossSignal; -pub use di_zerocross::DiZeroCrossSignal; -pub use macd_zerocross::MacdZeroCrossSignal; -pub use qstick_zerocross::QstickZeroCrossSignal; -pub use roc_zerocross::RocZeroCrossSignal; -pub use trix_zerocross::TrixZeroCrossSignal; -pub use tsi_zerocross::TsiZeroCrossSignal; +pub use ao::AoZeroCrossSignal; +pub use bop::BopZeroCrossSignal; +pub use cc::CcZeroCrossSignal; +pub use cfo::CfoZeroCrossSignal; +pub use di::DiZeroCrossSignal; +pub use macd::MacdZeroCrossSignal; +pub use mad::MadZeroCrossSignal; +pub use qstick::QstickZeroCrossSignal; +pub use roc::RocZeroCrossSignal; +pub use trix::TrixZeroCrossSignal; +pub use tsi::TsiZeroCrossSignal; diff --git a/ta_lib/strategies/signal/src/zerocross/qstick_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/qstick.rs similarity index 71% rename from ta_lib/strategies/signal/src/zerocross/qstick_zerocross.rs rename to ta_lib/strategies/signal/src/zerocross/qstick.rs index 87030b0b..86edbabd 100644 --- a/ta_lib/strategies/signal/src/zerocross/qstick_zerocross.rs +++ b/ta_lib/strategies/signal/src/zerocross/qstick.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; -use trend::qstick; +use momentum::qstick; +use timeseries::prelude::*; pub struct QstickZeroCrossSignal { smooth_type: Smooth, @@ -21,12 +22,9 @@ impl Signal for QstickZeroCrossSignal { self.period } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let qstick = qstick(data.open(), data.close(), self.smooth_type, self.period); - ( - qstick.cross_over(&ZERO_LINE), - qstick.cross_under(&ZERO_LINE), - ) + (qstick.cross_over(&ZERO), qstick.cross_under(&ZERO)) } } diff --git a/ta_lib/strategies/signal/src/zerocross/roc.rs b/ta_lib/strategies/signal/src/zerocross/roc.rs new file mode 100644 index 00000000..af96c955 --- /dev/null +++ b/ta_lib/strategies/signal/src/zerocross/roc.rs @@ -0,0 +1,30 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::roc; +use timeseries::prelude::*; + +pub struct RocZeroCrossSignal { + source: SourceType, + period: usize, +} + +impl RocZeroCrossSignal { + pub fn new(source: SourceType, period: f32) -> Self { + Self { + source, + period: period as usize, + } + } +} + +impl Signal for RocZeroCrossSignal { + fn lookback(&self) -> usize { + self.period + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let roc = roc(&data.source(self.source), self.period); + + (roc.cross_over(&ZERO), roc.cross_under(&ZERO)) + } +} diff --git a/ta_lib/strategies/signal/src/zerocross/roc_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/roc_zerocross.rs deleted file mode 100644 index f00e3da1..00000000 --- a/ta_lib/strategies/signal/src/zerocross/roc_zerocross.rs +++ /dev/null @@ -1,29 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use momentum::roc; - -pub struct RocZeroCrossSignal { - source_type: SourceType, - period: usize, -} - -impl RocZeroCrossSignal { - pub fn new(source_type: SourceType, period: f32) -> Self { - Self { - source_type, - period: period as usize, - } - } -} - -impl Signal for RocZeroCrossSignal { - fn lookback(&self) -> usize { - self.period - } - - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let roc = roc(&data.source(self.source_type), self.period); - - (roc.cross_over(&ZERO_LINE), roc.cross_under(&ZERO_LINE)) - } -} diff --git a/ta_lib/strategies/signal/src/zerocross/trix.rs b/ta_lib/strategies/signal/src/zerocross/trix.rs new file mode 100644 index 00000000..e29141cb --- /dev/null +++ b/ta_lib/strategies/signal/src/zerocross/trix.rs @@ -0,0 +1,32 @@ +use base::prelude::*; +use core::prelude::*; +use momentum::trix; +use timeseries::prelude::*; + +pub struct TrixZeroCrossSignal { + source: SourceType, + smooth: Smooth, + period: usize, +} + +impl TrixZeroCrossSignal { + pub fn new(source: SourceType, smooth: Smooth, period: f32) -> Self { + Self { + source, + smooth, + period: period as usize, + } + } +} + +impl Signal for TrixZeroCrossSignal { + fn lookback(&self) -> usize { + self.period + } + + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { + let trix = trix(&data.source(self.source), self.smooth, self.period); + + (trix.cross_over(&ZERO), trix.cross_under(&ZERO)) + } +} diff --git a/ta_lib/strategies/signal/src/zerocross/trix_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/trix_zerocross.rs deleted file mode 100644 index a289eabe..00000000 --- a/ta_lib/strategies/signal/src/zerocross/trix_zerocross.rs +++ /dev/null @@ -1,35 +0,0 @@ -use base::prelude::*; -use core::prelude::*; -use momentum::trix; - -pub struct TrixZeroCrossSignal { - source_type: SourceType, - smooth_type: Smooth, - period: usize, -} - -impl TrixZeroCrossSignal { - pub fn new(source_type: SourceType, smooth_type: Smooth, period: f32) -> Self { - Self { - source_type, - smooth_type, - period: period as usize, - } - } -} - -impl Signal for TrixZeroCrossSignal { - fn lookback(&self) -> usize { - self.period - } - - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { - let trix = trix( - &data.source(self.source_type), - self.smooth_type, - self.period, - ); - - (trix.cross_over(&ZERO_LINE), trix.cross_under(&ZERO_LINE)) - } -} diff --git a/ta_lib/strategies/signal/src/zerocross/tsi_zerocross.rs b/ta_lib/strategies/signal/src/zerocross/tsi.rs similarity index 84% rename from ta_lib/strategies/signal/src/zerocross/tsi_zerocross.rs rename to ta_lib/strategies/signal/src/zerocross/tsi.rs index bcd15000..964ab739 100644 --- a/ta_lib/strategies/signal/src/zerocross/tsi_zerocross.rs +++ b/ta_lib/strategies/signal/src/zerocross/tsi.rs @@ -1,6 +1,7 @@ use base::prelude::*; use core::prelude::*; use momentum::tsi; +use timeseries::prelude::*; pub struct TsiZeroCrossSignal { source_type: SourceType, @@ -30,7 +31,7 @@ impl Signal for TsiZeroCrossSignal { std::cmp::max(self.fast_period, self.slow_period) } - fn generate(&self, data: &OHLCVSeries) -> (Series, Series) { + fn trigger(&self, data: &OHLCVSeries) -> (Series, Series) { let tsi = tsi( &data.source(self.source_type), self.smooth_type, @@ -38,6 +39,6 @@ impl Signal for TsiZeroCrossSignal { self.fast_period, ); - (tsi.cross_over(&ZERO_LINE), tsi.cross_under(&ZERO_LINE)) + (tsi.cross_over(&ZERO), tsi.cross_under(&ZERO)) } } diff --git a/ta_lib/strategies/stop_loss/Cargo.toml b/ta_lib/strategies/stop_loss/Cargo.toml index f8e5fa8f..baa85efa 100644 --- a/ta_lib/strategies/stop_loss/Cargo.toml +++ b/ta_lib/strategies/stop_loss/Cargo.toml @@ -13,3 +13,4 @@ repository.workspace = true core = { path = "../../core" } base = { path = "../base" } volatility = { path = "../../indicators/volatility" } +timeseries = { path = "../../timeseries" } \ No newline at end of file diff --git a/ta_lib/strategies/stop_loss/src/atr.rs b/ta_lib/strategies/stop_loss/src/atr.rs index 63a21733..3150e064 100644 --- a/ta_lib/strategies/stop_loss/src/atr.rs +++ b/ta_lib/strategies/stop_loss/src/atr.rs @@ -1,14 +1,17 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; pub struct AtrStopLoss { + smooth: Smooth, period: usize, factor: f32, } impl AtrStopLoss { - pub fn new(period: f32, factor: f32) -> Self { + pub fn new(smooth: Smooth, period: f32, factor: f32) -> Self { Self { + smooth, period: period as usize, factor, } @@ -20,8 +23,8 @@ impl StopLoss for AtrStopLoss { self.period } - fn find(&self, data: &OHLCVSeries) -> (Series, Series) { - let atr_multi = data.atr(self.period) * self.factor; + fn find(&self, data: &OHLCVSeries) -> (Price, Price) { + let atr_multi = data.atr(self.smooth, self.period) * self.factor; (data.low() - &atr_multi, data.high() + &atr_multi) } diff --git a/ta_lib/strategies/stop_loss/src/dch.rs b/ta_lib/strategies/stop_loss/src/dch.rs index 484de52f..d84b7531 100644 --- a/ta_lib/strategies/stop_loss/src/dch.rs +++ b/ta_lib/strategies/stop_loss/src/dch.rs @@ -1,5 +1,6 @@ use base::prelude::*; use core::prelude::*; +use timeseries::prelude::*; use volatility::dch; pub struct DchStopLoss { @@ -21,7 +22,7 @@ impl StopLoss for DchStopLoss { self.period } - fn find(&self, data: &OHLCVSeries) -> (Series, Series) { + fn find(&self, data: &OHLCVSeries) -> (Price, Price) { let (upper, _, lower) = dch(data.high(), data.low(), self.period); let volatility = data.close().std(self.period).highest(self.period) * self.factor; diff --git a/ta_lib/strategies/trend_follow/Cargo.toml b/ta_lib/strategies/trend_follow/Cargo.toml index 77b0980e..c1c38a86 100644 --- a/ta_lib/strategies/trend_follow/Cargo.toml +++ b/ta_lib/strategies/trend_follow/Cargo.toml @@ -11,6 +11,7 @@ repository.workspace = true [dependencies] core = { path = "../../core" } +timeseries = { path = "../../timeseries" } base = { path = "../base" } stop_loss = { path = "../stop_loss" } signal = { path = "../signal" } diff --git a/ta_lib/strategies/trend_follow/src/config/baseline_config.rs b/ta_lib/strategies/trend_follow/src/config/baseline_config.rs index cd6cb6ed..5b81ac97 100644 --- a/ta_lib/strategies/trend_follow/src/config/baseline_config.rs +++ b/ta_lib/strategies/trend_follow/src/config/baseline_config.rs @@ -3,9 +3,5 @@ use serde::Deserialize; #[derive(Deserialize)] #[serde(tag = "type")] pub enum BaseLineConfig { - Ma { - source_type: f32, - ma: f32, - period: f32, - }, + Ma { source: f32, ma: f32, period: f32 }, } diff --git a/ta_lib/strategies/trend_follow/src/config/confirm_config.rs b/ta_lib/strategies/trend_follow/src/config/confirm_config.rs index 068e78d3..4b1ca732 100644 --- a/ta_lib/strategies/trend_follow/src/config/confirm_config.rs +++ b/ta_lib/strategies/trend_follow/src/config/confirm_config.rs @@ -3,6 +3,13 @@ use serde::Deserialize; #[derive(Deserialize)] #[serde(tag = "type")] pub enum ConfirmConfig { + // contrarian + BbC { + smooth: f32, + period: f32, + factor: f32, + }, + // trend Dpo { source_type: f32, smooth_type: f32, @@ -12,13 +19,13 @@ pub enum ConfirmConfig { source_type: f32, smooth_type: f32, period: f32, - divisor: f32, }, Cci { - source_type: f32, - smooth_type: f32, + source: f32, period: f32, factor: f32, + smooth: f32, + period_smooth: f32, }, Dumb { period: f32, @@ -37,10 +44,6 @@ pub enum ConfirmConfig { smooth_period: f32, threshold: f32, }, - Roc { - source_type: f32, - period: f32, - }, Stc { source_type: f32, smooth_type: f32, @@ -50,15 +53,36 @@ pub enum ConfirmConfig { d_first: f32, d_second: f32, }, - Dso { - source_type: f32, + Braid { smooth_type: f32, - smooth_period: f32, - k_period: f32, - d_period: f32, + fast_period: f32, + slow_period: f32, + open_period: f32, + strength: f32, + smooth_atr: f32, + period_atr: f32, }, - Vi { - atr_period: f32, + Wpr { + source: f32, period: f32, + smooth_signal: f32, + period_signal: f32, + }, + Didi { + source: f32, + smooth: f32, + period_medium: f32, + period_slow: f32, + smooth_signal: f32, + period_signal: f32, + }, + Cc { + source: f32, + period_fast: f32, + period_slow: f32, + smooth: f32, + period_smooth: f32, + smooth_signal: f32, + period_signal: f32, }, } diff --git a/ta_lib/strategies/trend_follow/src/config/exit_config.rs b/ta_lib/strategies/trend_follow/src/config/exit_config.rs index a6b735f6..5748e8ef 100644 --- a/ta_lib/strategies/trend_follow/src/config/exit_config.rs +++ b/ta_lib/strategies/trend_follow/src/config/exit_config.rs @@ -5,16 +5,10 @@ use serde::Deserialize; pub enum ExitConfig { Ast { source_type: f32, - atr_period: f32, + smooth_atr: f32, + period_atr: f32, factor: f32, }, - Cci { - source_type: f32, - smooth_type: f32, - period: f32, - factor: f32, - threshold: f32, - }, Dumb {}, HighLow { period: f32, @@ -41,4 +35,16 @@ pub enum ExitConfig { period: f32, signal_period: f32, }, + Rex { + source: f32, + smooth: f32, + period: f32, + smooth_signal: f32, + period_signal: f32, + }, + Mad { + source: f32, + period_fast: f32, + period_slow: f32, + }, } diff --git a/ta_lib/strategies/trend_follow/src/config/pulse_config.rs b/ta_lib/strategies/trend_follow/src/config/pulse_config.rs index 730d4c83..0825c476 100644 --- a/ta_lib/strategies/trend_follow/src/config/pulse_config.rs +++ b/ta_lib/strategies/trend_follow/src/config/pulse_config.rs @@ -4,51 +4,57 @@ use serde::Deserialize; #[serde(tag = "type")] pub enum PulseConfig { Adx { - smooth_type: f32, - adx_period: f32, - di_period: f32, + smooth: f32, + period: f32, + period_di: f32, threshold: f32, }, - Braid { - smooth_type: f32, - fast_period: f32, - slow_period: f32, - open_period: f32, - strength: f32, - atr_period: f32, - }, Dumb { period: f32, }, Chop { - atr_period: f32, period: f32, + smooth_atr: f32, + period_atr: f32, threshold: f32, }, Nvol { - smooth_type: f32, + smooth: f32, period: f32, }, Vo { - smooth_type: f32, - fast_period: f32, - slow_period: f32, + smooth: f32, + period_fast: f32, + period_slow: f32, }, Tdfi { - source_type: f32, - smooth_type: f32, + source: f32, + smooth: f32, period: f32, n: f32, }, Wae { - smooth_type: f32, - fast_period: f32, - slow_period: f32, + source: f32, + smooth: f32, + period_fast: f32, + period_slow: f32, smooth_bb: f32, - bb_period: f32, + period_bb: f32, factor: f32, strength: f32, - atr_period: f32, - dz_factor: f32, + }, + Yz { + period: f32, + smooth_signal: f32, + period_signal: f32, + }, + Sqz { + source: f32, + smooth: f32, + period: f32, + smooth_atr: f32, + period_atr: f32, + factor_bb: f32, + factor_kch: f32, }, } diff --git a/ta_lib/strategies/trend_follow/src/config/signal_config.rs b/ta_lib/strategies/trend_follow/src/config/signal_config.rs index 7f25dfb1..daf7a805 100644 --- a/ta_lib/strategies/trend_follow/src/config/signal_config.rs +++ b/ta_lib/strategies/trend_follow/src/config/signal_config.rs @@ -37,6 +37,12 @@ pub enum SignalConfig { slow_period: f32, signal_period: f32, }, + MadZeroCross { + source: f32, + smooth: f32, + period_fast: f32, + period_slow: f32, + }, QstickZeroCross { smooth_type: f32, period: f32, @@ -140,31 +146,36 @@ pub enum SignalConfig { bb_period: f32, factor: f32, }, - // Reversal - DmiReversal { + // 2 lines cross + Dmi2LinesCross { smooth_type: f32, adx_period: f32, di_period: f32, }, - SnatrReversal { - smooth_type: f32, - atr_period: f32, - atr_smooth_period: f32, - threshold: f32, - }, - ViReversal { - atr_period: f32, + Vi2LinesCross { period: f32, + smooth_atr: f32, + period_atr: f32, }, // Flip CeFlip { + source_type: f32, period: f32, - atr_period: f32, + smooth_atr: f32, + period_atr: f32, factor: f32, }, SupFlip { source_type: f32, - atr_period: f32, + smooth_atr: f32, + period_atr: f32, + factor: f32, + }, + // Pullback + SupPullback { + source: f32, + smooth_atr: f32, + period_atr: f32, factor: f32, }, // Ma @@ -215,12 +226,22 @@ pub enum SignalConfig { fast_period: f32, slow_period: f32, }, + Spread { + source: f32, + smooth: f32, + period_fast: f32, + period_slow: f32, + }, CandlestickTrend { candle: f32, }, + CandlestickReversal { + candle: f32, + }, HighLow { period: f32, }, + // Color Switch MacdColorSwitch { source_type: f32, smooth_type: f32, @@ -228,18 +249,73 @@ pub enum SignalConfig { slow_period: f32, signal_period: f32, }, - RsiV { - source_type: f32, + // Contrarian + KchA { + source: f32, + smooth: f32, + period: f32, + smooth_atr: f32, + period_atr: f32, + factor: f32, + }, + KchC { + source: f32, + smooth: f32, + period: f32, + smooth_atr: f32, + period_atr: f32, + factor: f32, + }, + Snatr { smooth_type: f32, - rsi_period: f32, + atr_period: f32, + atr_smooth_period: f32, + threshold: f32, + }, + RsiC { + source: f32, + smooth: f32, + period: f32, + threshold: f32, + }, + RsiD { + source: f32, + smooth: f32, + period: f32, + threshold: f32, + }, + RsiNt { + source: f32, + smooth: f32, + period: f32, + threshold: f32, + }, + RsiV { + source: f32, + smooth: f32, + period: f32, + threshold: f32, + }, + RsiU { + source: f32, + smooth: f32, + period: f32, threshold: f32, }, TiiV { - source_type: f32, - smooth_type: f32, + source: f32, + smooth: f32, major_period: f32, minor_period: f32, }, + StochE { + source: f32, + smooth: f32, + period: f32, + period_k: f32, + period_d: f32, + threshold: f32, + }, // Neutrality DsoNeutralityCross { source_type: f32, diff --git a/ta_lib/strategies/trend_follow/src/config/stoploss_config.rs b/ta_lib/strategies/trend_follow/src/config/stoploss_config.rs index 5304bdb1..a6531231 100644 --- a/ta_lib/strategies/trend_follow/src/config/stoploss_config.rs +++ b/ta_lib/strategies/trend_follow/src/config/stoploss_config.rs @@ -3,6 +3,13 @@ use serde::Deserialize; #[derive(Deserialize)] #[serde(tag = "type")] pub enum StopLossConfig { - Atr { period: f32, factor: f32 }, - Dch { period: f32, factor: f32 }, + Atr { + smooth: f32, + period: f32, + factor: f32, + }, + Dch { + period: f32, + factor: f32, + }, } diff --git a/ta_lib/strategies/trend_follow/src/deserialize/candle_deserialize.rs b/ta_lib/strategies/trend_follow/src/deserialize/candle_deserialize.rs index 3a30e3ba..c6e272ce 100644 --- a/ta_lib/strategies/trend_follow/src/deserialize/candle_deserialize.rs +++ b/ta_lib/strategies/trend_follow/src/deserialize/candle_deserialize.rs @@ -1,4 +1,4 @@ -use indicator::CandleTrendType; +use indicator::{CandleReversalType, CandleTrendType}; #[inline] pub fn candletrend_deserialize(candle: usize) -> CandleTrendType { @@ -20,3 +20,20 @@ pub fn candletrend_deserialize(candle: usize) -> CandleTrendType { _ => CandleTrendType::THREE_CANDLES, } } + +#[inline] +pub fn candlereversal_deserialize(candle: usize) -> CandleReversalType { + match candle { + 1 => CandleReversalType::DOJI, + 2 => CandleReversalType::ENGULFING, + 3 => CandleReversalType::EUPHORIA, + 4 => CandleReversalType::HAMMER, + 5 => CandleReversalType::HARAMIF, + 6 => CandleReversalType::HARAMIS, + 7 => CandleReversalType::KANGAROO, + 8 => CandleReversalType::R, + 9 => CandleReversalType::SPLIT, + 10 => CandleReversalType::TWEEZERS, + _ => CandleReversalType::R, + } +} diff --git a/ta_lib/strategies/trend_follow/src/deserialize/ma_deserialize.rs b/ta_lib/strategies/trend_follow/src/deserialize/ma_deserialize.rs index 538b6099..37cc15cb 100644 --- a/ta_lib/strategies/trend_follow/src/deserialize/ma_deserialize.rs +++ b/ta_lib/strategies/trend_follow/src/deserialize/ma_deserialize.rs @@ -17,19 +17,22 @@ pub fn ma_deserialize(ma: usize) -> MovingAverageType { 12 => MovingAverageType::MD, 13 => MovingAverageType::RMSMA, 14 => MovingAverageType::SINWMA, - 15 => MovingAverageType::SMA, - 16 => MovingAverageType::SMMA, - 17 => MovingAverageType::TTHREE, - 18 => MovingAverageType::TEMA, - 19 => MovingAverageType::TMA, - 20 => MovingAverageType::VIDYA, - 21 => MovingAverageType::VWMA, - 22 => MovingAverageType::VWEMA, - 23 => MovingAverageType::WMA, - 24 => MovingAverageType::ZLEMA, - 25 => MovingAverageType::ZLSMA, - 26 => MovingAverageType::ZLTEMA, - 27 => MovingAverageType::ZLHMA, + 15 => MovingAverageType::SLSMA, + 16 => MovingAverageType::SMA, + 17 => MovingAverageType::SMMA, + 18 => MovingAverageType::TTHREE, + 19 => MovingAverageType::TEMA, + 20 => MovingAverageType::TL, + 21 => MovingAverageType::TRIMA, + 22 => MovingAverageType::ULTS, + 23 => MovingAverageType::VIDYA, + 24 => MovingAverageType::VWMA, + 25 => MovingAverageType::VWEMA, + 26 => MovingAverageType::WMA, + 27 => MovingAverageType::ZLEMA, + 28 => MovingAverageType::ZLSMA, + 29 => MovingAverageType::ZLTEMA, + 30 => MovingAverageType::ZLHMA, _ => MovingAverageType::SMA, } } diff --git a/ta_lib/strategies/trend_follow/src/deserialize/mod.rs b/ta_lib/strategies/trend_follow/src/deserialize/mod.rs index 48522e3d..b7a5140b 100644 --- a/ta_lib/strategies/trend_follow/src/deserialize/mod.rs +++ b/ta_lib/strategies/trend_follow/src/deserialize/mod.rs @@ -3,7 +3,7 @@ mod ma_deserialize; mod smooth_deserialize; mod source_deserialize; -pub use candle_deserialize::candletrend_deserialize; +pub use candle_deserialize::{candlereversal_deserialize, candletrend_deserialize}; pub use ma_deserialize::ma_deserialize; pub use smooth_deserialize::smooth_deserialize; pub use source_deserialize::source_deserialize; diff --git a/ta_lib/strategies/trend_follow/src/deserialize/smooth_deserialize.rs b/ta_lib/strategies/trend_follow/src/deserialize/smooth_deserialize.rs index 32b72fc0..9c456a7b 100644 --- a/ta_lib/strategies/trend_follow/src/deserialize/smooth_deserialize.rs +++ b/ta_lib/strategies/trend_follow/src/deserialize/smooth_deserialize.rs @@ -12,6 +12,8 @@ pub fn smooth_deserialize(smooth: usize) -> Smooth { 7 => Smooth::ZLEMA, 8 => Smooth::LSMA, 9 => Smooth::TEMA, + 10 => Smooth::DEMA, + 11 => Smooth::ULTS, _ => Smooth::EMA, } } diff --git a/ta_lib/strategies/trend_follow/src/ffi.rs b/ta_lib/strategies/trend_follow/src/ffi.rs index 1e5dfbfb..169b855d 100644 --- a/ta_lib/strategies/trend_follow/src/ffi.rs +++ b/ta_lib/strategies/trend_follow/src/ffi.rs @@ -5,6 +5,7 @@ use crate::mapper::{ map_to_baseline, map_to_confirm, map_to_exit, map_to_pulse, map_to_signal, map_to_stoploss, }; use base::prelude::*; +use timeseries::prelude::*; fn read_from_memory(ptr: *const u8, len: usize) -> &'static [u8] { unsafe { std::slice::from_raw_parts(ptr, len) } @@ -40,6 +41,7 @@ pub fn register( let exit: ExitConfig = serde_json::from_slice(exit_buffer).unwrap(); register_strategy( + Box::::default(), map_to_signal(signal), map_to_confirm(confirm), map_to_pulse(pulse), diff --git a/ta_lib/strategies/trend_follow/src/mapper/baseline_mapper.rs b/ta_lib/strategies/trend_follow/src/mapper/baseline_mapper.rs index f0c98997..750cea6d 100644 --- a/ta_lib/strategies/trend_follow/src/mapper/baseline_mapper.rs +++ b/ta_lib/strategies/trend_follow/src/mapper/baseline_mapper.rs @@ -6,12 +6,8 @@ use baseline::*; #[inline] pub fn map_to_baseline(config: BaseLineConfig) -> Box { match config { - BaseLineConfig::Ma { - source_type, - ma, - period, - } => Box::new(MaBaseLine::new( - source_deserialize(source_type as usize), + BaseLineConfig::Ma { source, ma, period } => Box::new(MaBaseLine::new( + source_deserialize(source as usize), ma_deserialize(ma as usize), period, )), diff --git a/ta_lib/strategies/trend_follow/src/mapper/confirm_mapper.rs b/ta_lib/strategies/trend_follow/src/mapper/confirm_mapper.rs index b7e99c96..839a7ec8 100644 --- a/ta_lib/strategies/trend_follow/src/mapper/confirm_mapper.rs +++ b/ta_lib/strategies/trend_follow/src/mapper/confirm_mapper.rs @@ -6,6 +6,17 @@ use confirm::*; #[inline] pub fn map_to_confirm(config: ConfirmConfig) -> Box { match config { + // contrarian + ConfirmConfig::BbC { + smooth, + period, + factor, + } => Box::new(BbConfirm::new( + smooth_deserialize(smooth as usize), + period, + factor, + )), + // trend ConfirmConfig::Dpo { source_type, smooth_type, @@ -15,40 +26,55 @@ pub fn map_to_confirm(config: ConfirmConfig) -> Box { smooth_deserialize(smooth_type as usize), period, )), - ConfirmConfig::Dso { - source_type, - smooth_type, - smooth_period, - k_period, - d_period, - } => Box::new(DsoConfirm::new( - source_deserialize(source_type as usize), - smooth_deserialize(smooth_type as usize), - smooth_period, - k_period, - d_period, + ConfirmConfig::Cc { + source, + period_fast, + period_slow, + smooth, + period_smooth, + smooth_signal, + period_signal, + } => Box::new(CcConfirm::new( + source_deserialize(source as usize), + period_fast, + period_slow, + smooth_deserialize(smooth as usize), + period_smooth, + smooth_deserialize(smooth_signal as usize), + period_signal, )), ConfirmConfig::Cci { - source_type, - smooth_type, + source, period, factor, + smooth, + period_smooth, } => Box::new(CciConfirm::new( - source_deserialize(source_type as usize), - smooth_deserialize(smooth_type as usize), + source_deserialize(source as usize), period, factor, + smooth_deserialize(smooth as usize), + period_smooth, + )), + ConfirmConfig::Wpr { + source, + period, + smooth_signal, + period_signal, + } => Box::new(WprConfirm::new( + source_deserialize(source as usize), + period, + smooth_deserialize(smooth_signal as usize), + period_signal, )), ConfirmConfig::Eom { source_type, smooth_type, period, - divisor, } => Box::new(EomConfirm::new( source_deserialize(source_type as usize), smooth_deserialize(smooth_type as usize), period, - divisor, )), ConfirmConfig::Dumb { period } => Box::new(DumbConfirm::new(period)), ConfirmConfig::RsiSignalLine { @@ -77,14 +103,6 @@ pub fn map_to_confirm(config: ConfirmConfig) -> Box { period, threshold, )), - ConfirmConfig::Roc { - source_type, - period, - } => Box::new(RocConfirm::new( - source_deserialize(source_type as usize), - period, - )), - ConfirmConfig::Vi { atr_period, period } => Box::new(ViConfirm::new(atr_period, period)), ConfirmConfig::Stc { source_type, smooth_type, @@ -102,5 +120,37 @@ pub fn map_to_confirm(config: ConfirmConfig) -> Box { d_first, d_second, )), + ConfirmConfig::Braid { + smooth_type, + fast_period, + slow_period, + open_period, + strength, + smooth_atr, + period_atr, + } => Box::new(BraidConfirm::new( + smooth_deserialize(smooth_type as usize), + fast_period, + slow_period, + open_period, + strength, + smooth_deserialize(smooth_atr as usize), + period_atr, + )), + ConfirmConfig::Didi { + source, + smooth, + period_medium, + period_slow, + smooth_signal, + period_signal, + } => Box::new(DidiConfirm::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period_medium, + period_slow, + smooth_deserialize(smooth_signal as usize), + period_signal, + )), } } diff --git a/ta_lib/strategies/trend_follow/src/mapper/exit_mapper.rs b/ta_lib/strategies/trend_follow/src/mapper/exit_mapper.rs index 21110c81..bb9faaca 100644 --- a/ta_lib/strategies/trend_follow/src/mapper/exit_mapper.rs +++ b/ta_lib/strategies/trend_follow/src/mapper/exit_mapper.rs @@ -8,26 +8,15 @@ pub fn map_to_exit(config: ExitConfig) -> Box { match config { ExitConfig::Ast { source_type, - atr_period, + smooth_atr, + period_atr, factor, } => Box::new(AstExit::new( source_deserialize(source_type as usize), - atr_period, + smooth_deserialize(smooth_atr as usize), + period_atr, factor, )), - ExitConfig::Cci { - source_type, - smooth_type, - period, - factor, - threshold, - } => Box::new(CciExit::new( - source_deserialize(source_type as usize), - smooth_deserialize(smooth_type as usize), - period, - factor, - threshold, - )), ExitConfig::Dumb {} => Box::new(DumbExit {}), ExitConfig::HighLow { period } => Box::new(HighLowExit::new(period)), ExitConfig::Rsi { @@ -70,5 +59,27 @@ pub fn map_to_exit(config: ExitConfig) -> Box { ma_deserialize(ma as usize), period, )), + ExitConfig::Rex { + source, + smooth, + period, + smooth_signal, + period_signal, + } => Box::new(RexExit::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + smooth_deserialize(smooth_signal as usize), + period_signal, + )), + ExitConfig::Mad { + source, + period_fast, + period_slow, + } => Box::new(MadExit::new( + source_deserialize(source as usize), + period_fast, + period_slow, + )), } } diff --git a/ta_lib/strategies/trend_follow/src/mapper/pulse_mapper.rs b/ta_lib/strategies/trend_follow/src/mapper/pulse_mapper.rs index 6bd140d6..4edfee55 100644 --- a/ta_lib/strategies/trend_follow/src/mapper/pulse_mapper.rs +++ b/ta_lib/strategies/trend_follow/src/mapper/pulse_mapper.rs @@ -7,84 +7,95 @@ use pulse::*; pub fn map_to_pulse(config: PulseConfig) -> Box { match config { PulseConfig::Adx { - smooth_type, - adx_period, - di_period, + smooth, + period, + period_di, threshold, } => Box::new(AdxPulse::new( - smooth_deserialize(smooth_type as usize), - adx_period, - di_period, + smooth_deserialize(smooth as usize), + period, + period_di, threshold, )), - PulseConfig::Braid { - smooth_type, - fast_period, - slow_period, - open_period, - strength, - atr_period, - } => Box::new(BraidPulse::new( - smooth_deserialize(smooth_type as usize), - fast_period, - slow_period, - open_period, - strength, - atr_period, - )), PulseConfig::Chop { - atr_period, period, + smooth_atr, + period_atr, threshold, - } => Box::new(ChopPulse::new(atr_period, period, threshold)), - PulseConfig::Dumb { period } => Box::new(DumbPulse::new(period)), - PulseConfig::Nvol { - smooth_type, - period, - } => Box::new(NvolPulse::new( - smooth_deserialize(smooth_type as usize), + } => Box::new(ChopPulse::new( period, + smooth_deserialize(smooth_atr as usize), + period_atr, + threshold, )), + PulseConfig::Dumb { period } => Box::new(DumbPulse::new(period)), + PulseConfig::Nvol { smooth, period } => { + Box::new(NvolPulse::new(smooth_deserialize(smooth as usize), period)) + } PulseConfig::Tdfi { - source_type, - smooth_type, + source, + smooth, period, n, } => Box::new(TdfiPulse::new( - source_deserialize(source_type as usize), - smooth_deserialize(smooth_type as usize), + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), period, n, )), PulseConfig::Vo { - smooth_type, - fast_period, - slow_period, + smooth, + period_fast, + period_slow, } => Box::new(VoPulse::new( - smooth_deserialize(smooth_type as usize), - fast_period, - slow_period, + smooth_deserialize(smooth as usize), + period_fast, + period_slow, )), PulseConfig::Wae { - smooth_type, - fast_period, - slow_period, + source, + smooth, + period_fast, + period_slow, smooth_bb, - bb_period, + period_bb, factor, strength, - atr_period, - dz_factor, } => Box::new(WaePulse::new( - smooth_deserialize(smooth_type as usize), - fast_period, - slow_period, + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period_fast, + period_slow, smooth_deserialize(smooth_bb as usize), - bb_period, + period_bb, factor, strength, - atr_period, - dz_factor, + )), + PulseConfig::Yz { + period, + smooth_signal, + period_signal, + } => Box::new(YzPulse::new( + period, + smooth_deserialize(smooth_signal as usize), + period_signal, + )), + PulseConfig::Sqz { + source, + smooth, + period, + smooth_atr, + period_atr, + factor_bb, + factor_kch, + } => Box::new(SqzPulse::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + smooth_deserialize(smooth_atr as usize), + period_atr, + factor_bb, + factor_kch, )), } } diff --git a/ta_lib/strategies/trend_follow/src/mapper/signal_mapper.rs b/ta_lib/strategies/trend_follow/src/mapper/signal_mapper.rs index bacfbebe..b82d5c39 100644 --- a/ta_lib/strategies/trend_follow/src/mapper/signal_mapper.rs +++ b/ta_lib/strategies/trend_follow/src/mapper/signal_mapper.rs @@ -1,6 +1,7 @@ use crate::config::SignalConfig; use crate::deserialize::{ - candletrend_deserialize, ma_deserialize, smooth_deserialize, source_deserialize, + candlereversal_deserialize, candletrend_deserialize, ma_deserialize, smooth_deserialize, + source_deserialize, }; use base::prelude::*; use signal::*; @@ -69,6 +70,17 @@ pub fn map_to_signal(config: SignalConfig) -> Box { slow_period, signal_period, )), + SignalConfig::MadZeroCross { + source, + smooth, + period_fast, + period_slow, + } => Box::new(MadZeroCrossSignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period_fast, + period_slow, + )), SignalConfig::RocZeroCross { source_type, period, @@ -227,17 +239,39 @@ pub fn map_to_signal(config: SignalConfig) -> Box { )), // Flip SignalConfig::CeFlip { + source_type, period, - atr_period, + smooth_atr, + period_atr, + factor, + } => Box::new(CeFlipSignal::new( + source_deserialize(source_type as usize), + period, + smooth_deserialize(smooth_atr as usize), + period_atr, factor, - } => Box::new(CeFlipSignal::new(period, atr_period, factor)), + )), SignalConfig::SupFlip { source_type, - atr_period, + smooth_atr, + period_atr, factor, } => Box::new(SupertrendFlipSignal::new( source_deserialize(source_type as usize), - atr_period, + smooth_deserialize(smooth_atr as usize), + period_atr, + factor, + )), + // Pullback + SignalConfig::SupPullback { + source, + smooth_atr, + period_atr, + factor, + } => Box::new(SupertrendPullbackSignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth_atr as usize), + period_atr, factor, )), // Pattern @@ -252,6 +286,25 @@ pub fn map_to_signal(config: SignalConfig) -> Box { fast_period, slow_period, )), + SignalConfig::Spread { + source, + smooth, + period_fast, + period_slow, + } => Box::new(SpreadSignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period_fast, + period_slow, + )), + SignalConfig::HighLow { period } => Box::new(HighLowSignal::new(period)), + SignalConfig::CandlestickTrend { candle } => Box::new(CandlestickTrendSignal::new( + candletrend_deserialize(candle as usize), + )), + SignalConfig::CandlestickReversal { candle } => Box::new(CandlestickReversalSignal::new( + candlereversal_deserialize(candle as usize), + )), + // Color Switch SignalConfig::MacdColorSwitch { source_type, smooth_type, @@ -265,32 +318,129 @@ pub fn map_to_signal(config: SignalConfig) -> Box { slow_period, signal_period, )), - SignalConfig::TiiV { - source_type, + // Contrarian + SignalConfig::KchA { + source, + smooth, + period, + smooth_atr, + period_atr, + factor, + } => Box::new(KchASignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + smooth_deserialize(smooth_atr as usize), + period_atr, + factor, + )), + SignalConfig::KchC { + source, + smooth, + period, + smooth_atr, + period_atr, + factor, + } => Box::new(KchCSignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + smooth_deserialize(smooth_atr as usize), + period_atr, + factor, + )), + SignalConfig::Snatr { smooth_type, + atr_period, + atr_smooth_period, + threshold, + } => Box::new(SnatrSignal::new( + smooth_deserialize(smooth_type as usize), + atr_period, + atr_smooth_period, + threshold, + )), + SignalConfig::TiiV { + source, + smooth, major_period, minor_period, } => Box::new(TiiVSignal::new( - source_deserialize(source_type as usize), - smooth_deserialize(smooth_type as usize), + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), major_period, minor_period, )), + SignalConfig::StochE { + source, + smooth, + period, + period_k, + period_d, + threshold, + } => Box::new(StochESignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + period_k, + period_d, + threshold, + )), + SignalConfig::RsiC { + source, + smooth, + period, + threshold, + } => Box::new(RsiCSignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + threshold, + )), + SignalConfig::RsiD { + source, + smooth, + period, + threshold, + } => Box::new(RsiDSignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + threshold, + )), + SignalConfig::RsiNt { + source, + smooth, + period, + threshold, + } => Box::new(RsiNtSignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + threshold, + )), + SignalConfig::RsiU { + source, + smooth, + period, + threshold, + } => Box::new(RsiUSignal::new( + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, + threshold, + )), SignalConfig::RsiV { - source_type, - smooth_type, - rsi_period, + source, + smooth, + period, threshold, } => Box::new(RsiVSignal::new( - source_deserialize(source_type as usize), - smooth_deserialize(smooth_type as usize), - rsi_period, + source_deserialize(source as usize), + smooth_deserialize(smooth as usize), + period, threshold, )), - SignalConfig::HighLow { period } => Box::new(HighLowSignal::new(period)), - SignalConfig::CandlestickTrend { candle } => Box::new(CandlestickTrendSignal::new( - candletrend_deserialize(candle as usize), - )), // BB SignalConfig::MacdBb { source_type, @@ -456,30 +606,25 @@ pub fn map_to_signal(config: SignalConfig) -> Box { source_deserialize(source_type as usize), period, )), - // Reversal - SignalConfig::SnatrReversal { - smooth_type, - atr_period, - atr_smooth_period, - threshold, - } => Box::new(SnatrReversalSignal::new( - smooth_deserialize(smooth_type as usize), - atr_period, - atr_smooth_period, - threshold, - )), - SignalConfig::DmiReversal { + // 2 lines cross + SignalConfig::Dmi2LinesCross { smooth_type, adx_period, di_period, - } => Box::new(DmiReversalSignal::new( + } => Box::new(Dmi2LinesCrossSignal::new( smooth_deserialize(smooth_type as usize), adx_period, di_period, )), - SignalConfig::ViReversal { period, atr_period } => { - Box::new(ViReversalSignal::new(period, atr_period)) - } + SignalConfig::Vi2LinesCross { + period, + smooth_atr, + period_atr, + } => Box::new(Vi2LinesCrossSignal::new( + period, + smooth_deserialize(smooth_atr as usize), + period_atr, + )), // Breakout SignalConfig::DchMa2Breakout { source_type, diff --git a/ta_lib/strategies/trend_follow/src/mapper/stoploss_mapper.rs b/ta_lib/strategies/trend_follow/src/mapper/stoploss_mapper.rs index d8a7612c..1d76a242 100644 --- a/ta_lib/strategies/trend_follow/src/mapper/stoploss_mapper.rs +++ b/ta_lib/strategies/trend_follow/src/mapper/stoploss_mapper.rs @@ -1,11 +1,20 @@ use crate::config::StopLossConfig; +use crate::deserialize::smooth_deserialize; use base::prelude::*; use stop_loss::*; #[inline] pub fn map_to_stoploss(config: StopLossConfig) -> Box { match config { - StopLossConfig::Atr { period, factor } => Box::new(AtrStopLoss::new(period, factor)), + StopLossConfig::Atr { + smooth, + period, + factor, + } => Box::new(AtrStopLoss::new( + smooth_deserialize(smooth as usize), + period, + factor, + )), StopLossConfig::Dch { period, factor } => Box::new(DchStopLoss::new(period, factor)), } } diff --git a/ta_lib/timeseries/Cargo.toml b/ta_lib/timeseries/Cargo.toml new file mode 100644 index 00000000..38c463b3 --- /dev/null +++ b/ta_lib/timeseries/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "timeseries" +authors.workspace = true +edition.workspace = true +license.workspace = true +readme.workspace = true +repository.workspace = true +version.workspace = true + + +[dependencies] +core = { path = "../core" } +momentum = { path = "../indicators/momentum" } +volume = { path = "../indicators/volume" } +trend = { path = "../indicators/trend" } +price = { path = "../price" } +volatility = { path = "../indicators/volatility" } +serde = { version = "1.0", default-features = false, features = ["derive"] } \ No newline at end of file diff --git a/ta_lib/timeseries/src/lib.rs b/ta_lib/timeseries/src/lib.rs new file mode 100644 index 00000000..b4b16d8b --- /dev/null +++ b/ta_lib/timeseries/src/lib.rs @@ -0,0 +1,13 @@ +mod model; +mod ohlcv; +mod ta; +mod traits; + +pub mod prelude { + pub use crate::model::BaseTimeSeries; + pub use crate::ohlcv::{OHLCVSeries, OHLCV}; + pub use crate::ta::TechAnalysis; + pub use crate::traits::*; +} + +pub use prelude::*; diff --git a/ta_lib/timeseries/src/model.rs b/ta_lib/timeseries/src/model.rs new file mode 100644 index 00000000..53a1dac9 --- /dev/null +++ b/ta_lib/timeseries/src/model.rs @@ -0,0 +1,386 @@ +use crate::{OHLCVSeries, TechAnalysis, TimeSeries, OHLCV}; +use core::prelude::*; +use momentum::{cci, dmi, macd, roc, rsi, stochosc}; +use price::{typical_price, wcl}; +use std::collections::BTreeMap; +use trend::spp; +use volatility::{bb, gkyz, kch, tr, yz}; +use volume::{mfi, nvol, obv, vwap}; + +#[derive(Debug, Clone)] +pub struct BaseTimeSeries { + data: BTreeMap, +} + +impl Default for BaseTimeSeries { + fn default() -> Self { + Self::new() + } +} + +impl BaseTimeSeries { + pub fn new() -> Self { + Self { + data: BTreeMap::new(), + } + } +} + +impl TimeSeries for BaseTimeSeries { + fn add(&mut self, bar: &OHLCV) { + self.data.insert(bar.ts, *bar); + } + + fn next_bar(&self, bar: &OHLCV) -> Option { + self.data.range(bar.ts..).nth(1).map(|(_, &v)| v) + } + + fn prev_bar(&self, bar: &OHLCV) -> Option { + self.data.range(..bar.ts).next_back().map(|(_, &v)| v) + } + + fn back_n_bars(&self, bar: &OHLCV, n: usize) -> Vec { + self.data + .range(..bar.ts) + .rev() + .take(n) + .map(|(_, &v)| v) + .collect() + } + + #[inline] + fn len(&self) -> usize { + self.data.len() + } + + fn ohlcv(&self, size: usize) -> OHLCVSeries { + let len = self.len(); + let start_index = if len >= size { len - size } else { 0 }; + + OHLCVSeries::from( + self.data + .range(..) + .skip(start_index) + .map(|(_, &v)| v) + .collect::>(), + ) + } + + fn ta(&self, bar: &OHLCV) -> TechAnalysis { + let periods = [2, 14, 12, 26, 9, 5, 10, 1, 3, 11]; + let factors = [1.8, 0.015, 1.0]; + + let end_index = self + .data + .keys() + .position(|&ts| ts >= bar.ts) + .unwrap_or_else(|| self.len()); + let max_period = periods.iter().max().unwrap_or(&0); + + let start_index = if end_index > *max_period { + end_index - max_period + } else { + 0 + }; + + let series = OHLCVSeries::from( + self.data + .values() + .skip(start_index) + .take(end_index - start_index) + .copied() + .collect::>(), + ); + + let open = series.open(); + let high = series.high(); + let low = series.low(); + let source = series.close(); + let volume = series.volume(); + let hlc3 = typical_price(high, low, source); + let hlcc4 = wcl(high, low, source); + + let rsi2 = rsi(source, Smooth::SMMA, periods[0]); + let rsi14 = rsi(source, Smooth::SMMA, periods[1]); + let ema5 = source.smooth(Smooth::EMA, periods[5]); + let ema11 = source.smooth(Smooth::EMA, periods[9]); + + let (_, _, histogram) = macd(source, Smooth::EMA, periods[2], periods[3], periods[4]); + let ppo = source.spread_pct(Smooth::EMA, periods[2], periods[3]); + let vo = volume.spread_pct(Smooth::EMA, periods[5], periods[6]); + let nvol = nvol(volume, Smooth::SMA, periods[4]); + let obv = obv(source, volume); + let mfi = mfi(&hlc3, volume, periods[1]); + let tr = tr(high, low, source); + let atr = tr.smooth(Smooth::SMMA, periods[1]); + let gkyz = gkyz(open, high, low, source, periods[3]); + let yz = yz(open, high, low, source, periods[3]); + let (upb, _, lwb) = bb(source, Smooth::SMA, periods[5], factors[0]); + let (upkch, _, lwkch) = kch( + source, + Smooth::SMA, + &gkyz.smooth(Smooth::SMA, periods[1]), + periods[5], + factors[2], + ); + let ebb = &upb - &lwb; + let ekch = &upkch - &lwkch; + let (k, d) = stochosc( + source, + high, + low, + Smooth::SMA, + periods[1], + periods[7], + periods[8], + ); + let cci = cci(&hlc3, periods[5], factors[1]); + let roc9 = roc(source, periods[4]); + let roc14 = roc(source, periods[1]); + let hh = high.highest(periods[5]); + let ll = low.lowest(periods[5]); + let (support, resistance) = spp(high, low, source, Smooth::SMA, periods[2]); + + let (dp, dm, _) = dmi(high, low, &atr, Smooth::SMMA, periods[1], periods[1]); + + let dmi = dp - dm; + let vwap = vwap(&hlc3, volume); + + TechAnalysis { + frsi: rsi2.into(), + srsi: rsi14.into(), + fma: ema5.into(), + sma: ema11.into(), + froc: roc9.into(), + sroc: roc14.into(), + macd: histogram.into(), + ppo: ppo.into(), + cci: cci.into(), + obv: obv.into(), + vo: vo.into(), + nvol: nvol.into(), + mfi: mfi.into(), + tr: tr.into(), + gkyz: gkyz.into(), + yz: yz.into(), + upb: upb.into(), + lwb: lwb.into(), + ebb: ebb.into(), + ekch: ekch.into(), + k: k.into(), + d: d.into(), + hh: hh.into(), + ll: ll.into(), + support: support.into(), + resistance: resistance.into(), + dmi: dmi.into(), + vwap: vwap.into(), + close: source.clone().into(), + hlc3: hlc3.into(), + hlcc4: hlcc4.into(), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_remove_dublicate() { + let data = vec![ + OHLCV { + ts: 1679826900, + open: 5.992, + high: 5.993, + low: 5.976, + close: 5.980, + volume: 100.0, + }, + OHLCV { + ts: 1679825700, + open: 5.993, + high: 6.000, + low: 5.983, + close: 5.997, + volume: 100.0, + }, + OHLCV { + ts: 1679826000, + open: 5.997, + high: 6.001, + low: 5.989, + close: 6.001, + volume: 100.0, + }, + OHLCV { + ts: 1679826000, + open: 6.001, + high: 6.0013, + low: 5.993, + close: 6.007, + volume: 100.0, + }, + OHLCV { + ts: 1679826600, + open: 6.007, + high: 6.008, + low: 5.980, + close: 5.992, + volume: 100.0, + }, + ]; + let mut ts = BaseTimeSeries::new(); + + for bar in &data { + ts.add(bar); + } + + assert_eq!(ts.len(), data.len() - 1) + } + + #[test] + fn test_right_order() { + let data = vec![ + OHLCV { + ts: 1679825700, + open: 5.993, + high: 6.000, + low: 5.983, + close: 5.997, + volume: 100.0, + }, + OHLCV { + ts: 1679826000, + open: 5.997, + high: 6.001, + low: 5.989, + close: 6.001, + volume: 100.0, + }, + OHLCV { + ts: 1679826600, + open: 6.007, + high: 6.008, + low: 5.980, + close: 5.992, + volume: 100.0, + }, + OHLCV { + ts: 1679826300, + open: 6.001, + high: 6.0013, + low: 5.993, + close: 6.007, + volume: 100.0, + }, + OHLCV { + ts: 1679826900, + open: 5.992, + high: 5.993, + low: 5.976, + close: 5.980, + volume: 100.0, + }, + ]; + let mut ts = BaseTimeSeries::new(); + + for bar in &data { + ts.add(bar); + } + + let curr_bar = OHLCV { + ts: 1679826000, + open: 5.997, + high: 6.001, + low: 5.989, + close: 6.001, + volume: 100.0, + }; + + let next_bar = OHLCV { + ts: 1679826300, + open: 6.001, + high: 6.0013, + low: 5.993, + close: 6.007, + volume: 100.0, + }; + + let prev_bar = OHLCV { + ts: 1679825700, + open: 5.993, + high: 6.000, + low: 5.983, + close: 5.997, + volume: 100.0, + }; + let n = 2; + + let back_bars = ts.back_n_bars(&curr_bar, n); + + assert_eq!(ts.next_bar(&curr_bar).unwrap(), next_bar); + assert_eq!(ts.prev_bar(&curr_bar).unwrap(), prev_bar); + assert_eq!(back_bars.len(), 1); + assert_eq!(back_bars[0], prev_bar); + } + + #[test] + fn test_ohlcv() { + let data = vec![ + OHLCV { + ts: 1679825700, + open: 5.993, + high: 6.000, + low: 5.983, + close: 5.997, + volume: 100.0, + }, + OHLCV { + ts: 1679826000, + open: 5.997, + high: 6.001, + low: 5.989, + close: 6.001, + volume: 100.0, + }, + OHLCV { + ts: 1679826300, + open: 6.001, + high: 6.0013, + low: 5.993, + close: 6.007, + volume: 100.0, + }, + OHLCV { + ts: 1679826600, + open: 6.007, + high: 6.008, + low: 5.980, + close: 5.992, + volume: 100.0, + }, + OHLCV { + ts: 1679826900, + open: 5.992, + high: 5.993, + low: 5.976, + close: 5.980, + volume: 100.0, + }, + ]; + let mut ts = BaseTimeSeries::new(); + + for bar in &data { + ts.add(bar); + } + + let series = ts.ohlcv(3); + let close: Vec = series.close().clone().into(); + + assert_eq!(series.len(), 3); + assert_eq!(close[0], 6.007); + assert_eq!(close[1], 5.992); + assert_eq!(close[2], 5.980); + } +} diff --git a/ta_lib/timeseries/src/ohlcv.rs b/ta_lib/timeseries/src/ohlcv.rs new file mode 100644 index 00000000..5d0302b2 --- /dev/null +++ b/ta_lib/timeseries/src/ohlcv.rs @@ -0,0 +1,141 @@ +use core::prelude::*; +use serde::{Deserialize, Serialize}; +use std::fmt; + +#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)] +pub struct OHLCV { + pub ts: i64, + pub open: f32, + pub high: f32, + pub low: f32, + pub close: f32, + pub volume: f32, +} + +#[derive(Debug, Clone)] +pub struct OHLCVSeries { + ts: Vec, + open: Series, + high: Series, + low: Series, + close: Series, + volume: Series, +} + +impl OHLCVSeries { + pub fn open(&self) -> &Series { + &self.open + } + + pub fn high(&self) -> &Series { + &self.high + } + + pub fn low(&self) -> &Series { + &self.low + } + + pub fn close(&self) -> &Series { + &self.close + } + + pub fn volume(&self) -> &Series { + &self.volume + } + + #[inline(always)] + pub fn len(&self) -> usize { + self.close.len() + } + + pub fn bar_index(&self, bar: &OHLCV) -> usize { + self.ts + .binary_search_by(|&ts| ts.cmp(&bar.ts)) + .unwrap_or_else(|_| self.len()) + } +} + +impl<'a> From<&'a [OHLCV]> for OHLCVSeries { + fn from(data: &'a [OHLCV]) -> Self { + let len = data.len(); + + let mut ts = Vec::with_capacity(len); + let mut open = Vec::with_capacity(len); + let mut high = Vec::with_capacity(len); + let mut low = Vec::with_capacity(len); + let mut close = Vec::with_capacity(len); + let mut volume = Vec::with_capacity(len); + + for bar in data { + ts.push(bar.ts); + open.push(bar.open); + high.push(bar.high); + low.push(bar.low); + close.push(bar.close); + volume.push(bar.volume); + } + + Self { + ts, + open: Series::from(open), + high: Series::from(high), + low: Series::from(low), + close: Series::from(close), + volume: Series::from(volume), + } + } +} + +impl From> for OHLCVSeries { + fn from(data: Vec) -> Self { + OHLCVSeries::from(&data[..]) + } +} + +impl fmt::Display for OHLCV { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "OHLCV {{ ts: {}, open: {}, high: {}, low: {}, close: {}, volume: {} }}", + self.ts, self.open, self.high, self.low, self.close, self.volume + ) + } +} + +impl fmt::Display for OHLCVSeries { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + writeln!(f, "OHLCV:")?; + writeln!( + f, + "Index | Timestamp | Open | High | Low | Close | Volume" + )?; + writeln!( + f, + "--------------------------------------------------------------------------" + )?; + for i in 0..self.len() { + writeln!( + f, + "{:<5} | {:<10} | {:<8} | {:<8} | {:<8} | {:<8} | {:<8}", + i, + self.ts[i], + self.open + .get(i) + .map_or("None".to_string(), |v| v.to_string()), + self.high + .get(i) + .map_or("None".to_string(), |v| v.to_string()), + self.low + .get(i) + .map_or("None".to_string(), |v| v.to_string()), + self.close + .get(i) + .map_or("None".to_string(), |v| v.to_string()), + self.volume + .get(i) + .map_or("None".to_string(), |v| v.to_string()), + )?; + } + Ok(()) + } +} diff --git a/ta_lib/timeseries/src/ta.rs b/ta_lib/timeseries/src/ta.rs new file mode 100644 index 00000000..9e0f8d21 --- /dev/null +++ b/ta_lib/timeseries/src/ta.rs @@ -0,0 +1,36 @@ +use serde::Serialize; + +#[derive(Debug, Serialize)] +pub struct TechAnalysis { + pub frsi: Vec, + pub srsi: Vec, + pub fma: Vec, + pub sma: Vec, + pub froc: Vec, + pub sroc: Vec, + pub macd: Vec, + pub ppo: Vec, + pub cci: Vec, + pub obv: Vec, + pub vo: Vec, + pub nvol: Vec, + pub mfi: Vec, + pub tr: Vec, + pub gkyz: Vec, + pub yz: Vec, + pub upb: Vec, + pub lwb: Vec, + pub ebb: Vec, + pub ekch: Vec, + pub k: Vec, + pub d: Vec, + pub hh: Vec, + pub ll: Vec, + pub support: Vec, + pub resistance: Vec, + pub dmi: Vec, + pub vwap: Vec, + pub close: Vec, + pub hlc3: Vec, + pub hlcc4: Vec, +} diff --git a/ta_lib/timeseries/src/traits.rs b/ta_lib/timeseries/src/traits.rs new file mode 100644 index 00000000..004bcedb --- /dev/null +++ b/ta_lib/timeseries/src/traits.rs @@ -0,0 +1,11 @@ +use crate::{OHLCVSeries, TechAnalysis, OHLCV}; + +pub trait TimeSeries: Send + Sync { + fn add(&mut self, bar: &OHLCV); + fn next_bar(&self, bar: &OHLCV) -> Option; + fn prev_bar(&self, bar: &OHLCV) -> Option; + fn back_n_bars(&self, bar: &OHLCV, n: usize) -> Vec; + fn ohlcv(&self, size: usize) -> OHLCVSeries; + fn ta(&self, bar: &OHLCV) -> TechAnalysis; + fn len(&self) -> usize; +} diff --git a/tr.rs b/tr.rs new file mode 100644 index 00000000..e0c096af --- /dev/null +++ b/tr.rs @@ -0,0 +1,191 @@ +use core::prelude::*; + +pub fn tr(high: &Price, low: &Price, close: &Price) -> Price { + let prev_close = close.shift(1); + let diff = high - low; + + iff!( + high.shift(1).na(), + diff, + diff.max(&(high - &prev_close).abs()) + .max(&(low - &prev_close).abs()) + ) +} + +pub fn wtr(high: &Price, low: &Price, close: &Price) -> Price { + let prev_close = close.shift(1); + let diff = high - low; + + iff!( + high.shift(1).na(), + diff, + diff.max(&(high - &prev_close)) + .max(&(low.negate() + &prev_close)) + ) +} + +pub fn atr( + high: &Price, + low: &Price, + close: &Price, + smooth: Smooth, + period: Period, +) -> Price { + tr(high, low, close).smooth(smooth, period) +} + +pub fn snatr( + atr: &Price, + period: Period, + smooth: Smooth, + period_smooth: Period, +) -> Price { + atr.normalize(period, SCALE).smooth(smooth, period_smooth) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_true_range() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let expected = vec![ + 0.01819992, + 0.06549978, + 0.064100266, + 0.07590008, + 0.041800022, + 0.056900024, + 0.050400257, + 0.017799854, + 0.029099941, + 0.025099754, + 0.03429985, + 0.050499916, + 0.02519989, + 0.07379961, + 0.023900032, + 0.022799969, + ]; + + let result: Vec = tr(&high, &low, &close).into(); + + assert_eq!(result.len(), close.len()); + assert_eq!(result, expected); + } + + #[test] + fn test_wtrue_range() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let expected = vec![ + 0.01819992, + 0.06549978, + 0.064100266, + 0.07590008, + 0.041800022, + 0.056900024, + 0.050400257, + 0.017799854, + 0.029099941, + 0.025099754, + 0.03429985, + 0.050499916, + 0.02519989, + 0.07379961, + 0.023900032, + 0.022799969, + ]; + + let result: Vec = wtr(&high, &low, &close).into(); + + assert_eq!(result.len(), close.len()); + assert_eq!(result, expected); + } + + #[test] + fn test_atr_smma() { + let high = Series::from([ + 6.5600, 6.6049, 6.5942, 6.5541, 6.5300, 6.5700, 6.5630, 6.5362, 6.5497, 6.5480, 6.5325, + 6.5065, 6.4866, 6.5536, 6.5142, 6.5294, + ]); + let low = Series::from([ + 6.5418, 6.5394, 6.5301, 6.4782, 6.4882, 6.5131, 6.5126, 6.5184, 6.5206, 6.5229, 6.4982, + 6.4560, 6.4614, 6.4798, 6.4903, 6.5066, + ]); + let close = Series::from([ + 6.5541, 6.5942, 6.5348, 6.4950, 6.5298, 6.5616, 6.5223, 6.5300, 6.5452, 6.5254, 6.5038, + 6.4614, 6.4854, 6.4966, 6.5117, 6.5270, + ]); + let period = 3; + let expected = [ + 0.01819992, + 0.03396654, + 0.044011116, + 0.05464077, + 0.05036052, + 0.05254035, + 0.051826984, + 0.040484603, + 0.036689714, + 0.032826394, + 0.033317544, + 0.039045, + 0.03442996, + 0.047553174, + 0.03966879, + 0.03404585, + ]; + + let result: Vec = atr(&high, &low, &close, Smooth::SMMA, period).into(); + + assert_eq!(result, expected); + } + + #[test] + fn test_snatr() { + use crate::atr; + + let high = Series::from([ + 19.129, 19.116, 19.154, 19.195, 19.217, 19.285, 19.341, 19.394, 19.450, + ]); + let low = Series::from([ + 19.090, 19.086, 19.074, 19.145, 19.141, 19.155, 19.219, 19.306, 19.355, + ]); + let close = Series::from([ + 19.102, 19.100, 19.146, 19.181, 19.155, 19.248, 19.309, 19.355, 19.439, + ]); + let atr_period = 3; + let atr = atr(&high, &low, &close, Smooth::SMMA, atr_period); + let period = 3; + let expected = [ + 0.0, 0.0, 50.0, 82.575455, 99.49475, 99.747375, 100.0, 90.14032, 55.201355, + ]; + + let result: Vec = snatr(&atr, atr_period, Smooth::WMA, period).into(); + + assert_eq!(result, expected); + } +}