Skip to content

Commit

Permalink
Ocean (#49)
Browse files Browse the repository at this point in the history
* ocean
  • Loading branch information
m5l14i11 authored Oct 2, 2024
1 parent 2349496 commit a3343f0
Show file tree
Hide file tree
Showing 46 changed files with 3,982 additions and 2,711 deletions.
8 changes: 7 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,10 @@ LOG_LEVEL=INFO
WASM_FOLDER=wasm
REGIME=default

COPILOT_MODEL_PATH=
COPILOT_MODEL_PATH=
OCEAN_EMB_PATH=

MASTER_ADDR=
MASTER_PORT=
WORLD_SIZE=1
RANK=0
1 change: 1 addition & 0 deletions .python-version
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
3.12
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,13 @@ build-timeseries:
cp $(TA_LIB_DIR)/target/wasm32-wasi/release/ffi.wasm $(WASM_DIR)/timeseries.wasm

run:
pipenv run python3 quant.py
uv run python3 quant.py

format:
cargo fmt --all --manifest-path=$(TA_LIB_PATH)
pipenv run black .
pipenv run ruff . --fix
uv run black .
uv run ruff check . --unsafe-fixes --fix

update:
cargo update --manifest-path=$(TA_LIB_PATH)
pipenv update
uv sync
36 changes: 0 additions & 36 deletions Pipfile

This file was deleted.

2,525 changes: 0 additions & 2,525 deletions Pipfile.lock

This file was deleted.

8 changes: 6 additions & 2 deletions config.default.ini
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ buf_size = 50
base_dir = tmp

[bus]
piority_groups = 8
piority_groups = 13
num_workers = 5

[backtest]
Expand Down Expand Up @@ -56,4 +56,8 @@ n_threads = 7
n_gpu_layers = 2
n_batch = 256
max_tokens = 66
temperature = 0.52
temperature = 0.52

[ocean]
top_k = 20
emb_file = ''
26 changes: 16 additions & 10 deletions copilot/_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,11 @@ def fit(self, X, y=None):

new_centers = np.array(
[
X[self.labels_ == j].mean(axis=0)
if len(X[self.labels_ == j]) > 0
else self.cluster_centers_[j]
(
X[self.labels_ == j].mean(axis=0)
if len(X[self.labels_ == j]) > 0
else self.cluster_centers_[j]
)
for j in range(self.n_clusters)
]
)
Expand Down Expand Up @@ -182,7 +184,7 @@ async def _evaluate_signal(self, msg: EvaluateSignal) -> SignalRisk:
else signal_trend_risk_prompt
)

prompt = template.format(
template.format(
side=side,
strategy_type=strategy_type,
entry=curr_bar.close,
Expand All @@ -195,12 +197,16 @@ async def _evaluate_signal(self, msg: EvaluateSignal) -> SignalRisk:
cci=momentum.cci[-self.bars_n :],
roc=momentum.sroc[-self.bars_n :],
nvol=volume.nvol[-self.bars_n :],
support=trend.support[-self.bars_n :]
if side == PositionSide.SHORT
else trend.resistance[-self.bars_n :],
resistance=trend.resistance[-self.bars_n :]
if side == PositionSide.SHORT
else trend.support[-self.bars_n :],
support=(
trend.support[-self.bars_n :]
if side == PositionSide.SHORT
else trend.resistance[-self.bars_n :]
),
resistance=(
trend.resistance[-self.bars_n :]
if side == PositionSide.SHORT
else trend.support[-self.bars_n :]
),
vwap=volume.vwap[-self.bars_n :],
upper_bb=volatility.upb[-self.bars_n :],
lower_bb=volatility.lwb[-self.bars_n :],
Expand Down
17 changes: 4 additions & 13 deletions core/commands/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from core.events._base import EventMeta
from core.groups.command import CommandGroup
from core.models.broker import MarginMode, PositionMode
from core.models.entity.position import Position
from core.models.exchange import ExchangeType
from core.models.symbol import Symbol

from ._base import Command
Expand All @@ -12,24 +12,15 @@
@dataclass(frozen=True)
class BrokerCommand(Command):
meta: EventMeta = field(
default_factory=lambda: EventMeta(priority=1, group=CommandGroup.broker),
default_factory=lambda: EventMeta(priority=5, group=CommandGroup.broker),
init=False,
)


@dataclass(frozen=True)
class UpdateSettings(BrokerCommand):
class UpdateSymbolSettings(BrokerCommand):
exchange: ExchangeType
symbol: Symbol
leverage: int
position_mode: PositionMode
margin_mode: MarginMode


@dataclass(frozen=True)
class OpenPosition(BrokerCommand):
position: Position


@dataclass(frozen=True)
class ClosePosition(BrokerCommand):
position: Position
25 changes: 25 additions & 0 deletions core/commands/position.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from dataclasses import dataclass, field

from core.events._base import EventMeta
from core.groups.command import CommandGroup
from core.models.entity.position import Position

from ._base import Command


@dataclass(frozen=True)
class PositionCommand(Command):
meta: EventMeta = field(
default_factory=lambda: EventMeta(priority=1, group=CommandGroup.position),
init=False,
)


@dataclass(frozen=True)
class OpenPosition(PositionCommand):
position: Position


@dataclass(frozen=True)
class ClosePosition(PositionCommand):
position: Position
1 change: 1 addition & 0 deletions core/groups/command.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ class CommandGroup(Enum):
broker = auto()
portfolio = auto()
market = auto()
position = auto()

def __str__(self):
return self.name
8 changes: 7 additions & 1 deletion core/mixins/_event_handler.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Callable, Dict, Type


Expand All @@ -10,6 +11,11 @@ def register_handler(self, event_type: Type[Any], handler: Callable):

async def handle_event(self, event: Any) -> Any:
handler = self._handlers.get(type(event))

if handler:
return await handler(event)
if asyncio.iscoroutinefunction(handler):
return await handler(event)
else:
return await asyncio.to_thread(handler, event)

return None
7 changes: 7 additions & 0 deletions core/models/cap.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
from enum import Enum, auto


class CapType(Enum):
A = auto()
B = auto()
C = auto()
2 changes: 1 addition & 1 deletion core/models/entity/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,7 +359,7 @@ def target_filter(target, tp):
dist = abs(curr_price - entry_price)

trl_ratio = trl_dist / entry_price
exit_ratio = exit_dist / entry_price
exit_dist / entry_price
dist_ratio = dist / entry_price
is_exit = session_risk == SessionRiskType.EXIT

Expand Down
10 changes: 6 additions & 4 deletions core/models/entity/position_risk.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,11 +87,13 @@ def smooth_savgol(*arrays: np.ndarray) -> List[np.ndarray]:

def smooth_spline(*arrays: np.ndarray, s: float = 1.0, k: int = 3) -> List[np.ndarray]:
return [
UnivariateSpline(np.arange(len(array)), array, s=s, k=min(k, len(array) - 1))(
np.arange(len(array))
(
UnivariateSpline(
np.arange(len(array)), array, s=s, k=min(k, len(array) - 1)
)(np.arange(len(array)))
if len(array) > k
else array
)
if len(array) > k
else array
for array in arrays
]

Expand Down
19 changes: 7 additions & 12 deletions core/queries/broker.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,29 @@
from dataclasses import dataclass, field
from typing import List
from typing import List, Optional

from core.events._base import EventMeta
from core.groups.query import QueryGroup
from core.models.entity.position import Position
from core.models.cap import CapType
from core.models.exchange import ExchangeType
from core.models.symbol import Symbol

from ._base import Query


@dataclass(frozen=True)
class GetSymbols(Query[List[Symbol]]):
exchange: ExchangeType
cap: Optional[CapType] = None
meta: EventMeta = field(
default_factory=lambda: EventMeta(priority=3, group=QueryGroup.broker),
init=False,
)


@dataclass(frozen=True)
class GetSymbol(Query[Symbol]):
class GetSimularSymbols(Query[List[Symbol]]):
symbol: Symbol
meta: EventMeta = field(
default_factory=lambda: EventMeta(priority=3, group=QueryGroup.broker),
init=False,
)


@dataclass(frozen=True)
class HasPosition(Query[bool]):
position: Position
exchange: ExchangeType
meta: EventMeta = field(
default_factory=lambda: EventMeta(priority=3, group=QueryGroup.broker),
init=False,
Expand Down
9 changes: 9 additions & 0 deletions core/queries/position.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,12 @@ class GetClosePosition(Query[Order]):
default_factory=lambda: EventMeta(priority=1, group=QueryGroup.broker),
init=False,
)


@dataclass(frozen=True)
class HasPosition(Query[bool]):
position: Position
meta: EventMeta = field(
default_factory=lambda: EventMeta(priority=3, group=QueryGroup.broker),
init=False,
)
12 changes: 7 additions & 5 deletions exchange/_bybit.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,9 +245,11 @@ def fetch_position(self, symbol: Symbol, side: PositionSide):

if position and position["entryPrice"] is not None:
return {
"position_side": PositionSide.LONG
if position["side"] == "long"
else PositionSide.SHORT,
"position_side": (
PositionSide.LONG
if position["side"] == "long"
else PositionSide.SHORT
),
"entry_price": float(position.get("entryPrice", 0)),
"position_size": float(position.get("contracts", 0)),
"open_fee": -float(position.get("info", {}).get("curRealisedPnl", 0)),
Expand All @@ -269,8 +271,8 @@ def fetch_ohlcv(
symbol: Symbol,
timeframe: Timeframe,
in_sample: Lookback,
out_sample: Lookback | None,
batch_size: int,
out_sample: Lookback | None = None,
batch_size: int = 512,
):
in_sample = TIMEFRAMES_TO_LOOKBACK[(in_sample, timeframe)]
out_sample = (
Expand Down
6 changes: 5 additions & 1 deletion exchange/_exchange_factory.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from functools import lru_cache

from core.interfaces.abstract_exchange import AbstractExchange
from core.interfaces.abstract_exhange_factory import AbstractExchangeFactory
from core.interfaces.abstract_secret_service import AbstractSecretService
from core.models.exchange import ExchangeType
from exchange._bybit import Bybit

from ._bybit import Bybit


class ExchangeFactory(AbstractExchangeFactory):
Expand All @@ -12,6 +15,7 @@ def __init__(self, secret: AbstractSecretService):
super().__init__()
self.secret = secret

@lru_cache(maxsize=None)
def create(self, type: ExchangeType) -> AbstractExchange:
if type not in self._exchange_type:
raise ValueError(f"Unknown Exchange: {type}")
Expand Down
5 changes: 2 additions & 3 deletions executor/_market_actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Union

from core.actors import StrategyActor
from core.commands.broker import ClosePosition, OpenPosition
from core.commands.position import ClosePosition, OpenPosition
from core.events.position import (
BrokerPositionClosed,
BrokerPositionOpened,
Expand All @@ -12,8 +12,7 @@
from core.mixins import EventHandlerMixin
from core.models.symbol import Symbol
from core.models.timeframe import Timeframe
from core.queries.broker import HasPosition
from core.queries.position import GetClosePosition, GetOpenPosition
from core.queries.position import GetClosePosition, GetOpenPosition, HasPosition

logger = logging.getLogger(__name__)

Expand Down
6 changes: 3 additions & 3 deletions infrastructure/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,9 @@ def load(self, config_path: str) -> None:
self._config[current_section][key] = array_values
else:
try:
self._config[current_section][
key
] = self._parse_single_value(value)
self._config[current_section][key] = (
self._parse_single_value(value)
)
except ValueError:
array_values = value.split(",")

Expand Down
Loading

0 comments on commit a3343f0

Please sign in to comment.