
Commit

upd
m5l14i11 committed Oct 4, 2024
1 parent cc7bff5 commit b88338c
Showing 4 changed files with 310 additions and 369 deletions.
590 changes: 268 additions & 322 deletions notebooks/pipelines/ocean.ipynb

Large diffs are not rendered by default.

8 changes: 3 additions & 5 deletions notebooks/pipelines/ocean_train.py
@@ -5,7 +5,6 @@
 import torch.distributed as dist
 import torch.multiprocessing as mp
 from torch.nn.parallel import DistributedDataParallel as DDP
-from torch.optim.lr_scheduler import StepLR
 from torch.utils.data import DataLoader, TensorDataset
 from torch.utils.data.distributed import DistributedSampler

@@ -17,12 +16,11 @@ def to_train(
     feature_path: str,
     latent_dim=16,
     lr=1e-6,
+    wd=1e-5,
     rank=0,
     world_size=1,
     batch_size=2,
     num_workers=os.cpu_count(),
-    lr_scheduler_step=5,
-    lr_scheduler_gamma=0.1,
 ):
     features = np.load(feature_path)
     _, segment_length, n_features = features.shape
@@ -66,8 +64,8 @@ def to_train(
     device_ids = [rank] if torch.cuda.is_available() else None
     model = DDP(model, device_ids=device_ids)
 
-    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
-    scheduler = StepLR(optimizer, step_size=lr_scheduler_step, gamma=lr_scheduler_gamma)
+    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
+    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)
 
     return model, dataloader, optimizer, scheduler, device

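As a sanity check on the training change above, here is a minimal, self-contained sketch of the new optimizer and scheduler pair. The toy model and hyperparameter values are placeholders rather than the project's real training setup; only the AdamW and CosineAnnealingLR calls mirror the diff.

import torch

# Toy stand-in; the real pipeline wraps its model in DistributedDataParallel.
model = torch.nn.Linear(16, 16)

# AdamW decouples weight decay from the gradient-based update, so the new
# `wd` parameter acts as genuine shrinkage instead of being folded into
# the gradients as plain Adam would do.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-6, weight_decay=1e-5)

# Cosine annealing decays the learning rate smoothly toward eta_min (0 by
# default) over T_max scheduler steps, replacing StepLR's discrete drops.
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

for epoch in range(100):
    # ... forward pass, loss.backward(), optimizer.step() would go here ...
    scheduler.step()

Note that T_max=100 is hard-coded in the commit, so the schedule completes one half-cosine over 100 scheduler steps regardless of the actual epoch count.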
18 changes: 10 additions & 8 deletions ocean/_actor.py
@@ -33,7 +33,7 @@ async def on_receive(self, event: OceanEvent):
 
     def _register_event_handlers(self):
         self.register_handler(GetSymbols, self._get_symbols)
-        self.register_handler(GetSimularSymbols, self._get_simular_symbols)
+        self.register_handler(GetSimularSymbols, self._get_similar_symbols)
         self.register_handler(UpdateSymbolSettings, self._update_symbol_settings)
 
     def _get_symbols(self, event: GetSymbols):
@@ -43,27 +43,27 @@ def _get_symbols(self, event: GetSymbols):
         if not event.cap:
             return symbols
 
-        simular_symbols = self.gsim.find_similar_by_cap(
+        similar_symbols = self.gsim.find_similar_by_cap(
             event.cap, top_k=self.config.get("top_k")
         )
 
-        if not simular_symbols:
+        if not similar_symbols:
             return symbols
 
-        return [symbol for symbol in symbols if symbol.name in simular_symbols]
+        return [symbol for symbol in symbols if symbol.name in similar_symbols]
 
-    def _get_simular_symbols(self, event: GetSimularSymbols):
+    def _get_similar_symbols(self, event: GetSimularSymbols):
         exchange = self.exchange_factory.create(event.exchange)
         symbols = exchange.fetch_future_symbols()
 
-        simular_symbols = self.gsim.find_similar_symbols(
+        similar_symbols = self.gsim.find_similar_symbols(
             event.symbol.name, top_k=self.config.get("top_k")
         )
 
-        if not simular_symbols:
+        if not similar_symbols:
             return []
 
-        return [symbol for symbol in symbols if symbol.name in simular_symbols]
+        return [symbol for symbol in symbols if symbol.name in similar_symbols]
 
     def _update_symbol_settings(self, event: UpdateSymbolSettings):
         exchange = self.exchange_factory.create(event.exchange)
@@ -77,3 +77,5 @@ def _init_embeddings(self):
 
         for symbol, emb in embs:
             self.gsim.insert(emb, symbol)
+
+        self.gsim.perform_clustering()
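The _init_embeddings change also shifts when clustering happens: embeddings are inserted first and then clustered once, eagerly, at startup, instead of relying on find_similar_by_cap to trigger clustering lazily on the first lookup. A hypothetical illustration with stub classes follows; only the call order mirrors the diff.

class StubGsim:
    # Minimal stand-in for the real gsim index.
    def __init__(self):
        self.emb = {}

    def insert(self, emb, symbol):
        self.emb[symbol] = emb

    def perform_clustering(self):
        # The real method fits a GaussianMixture over all stored embeddings.
        print(f"clustering {len(self.emb)} embeddings")

gsim = StubGsim()

# Insert every embedding, then cluster once, so the first similarity
# lookup does not pay the clustering cost.
for symbol, emb in [("BTCUSDT", [0.1, 0.2]), ("ETHUSDT", [0.1, 0.3])]:
    gsim.insert(emb, symbol)

gsim.perform_clustering()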
63 changes: 29 additions & 34 deletions ocean/_gsim.py
@@ -1,3 +1,4 @@
+import hashlib
 from dataclasses import dataclass, field
 from heapq import heappop, heappush, heappushpop
 from typing import Dict, List, Optional, Tuple
@@ -46,7 +47,12 @@ def __eq__(self, other: "Node"):
         )
 
     def __hash__(self):
-        return hash((self.data.round(decimals=5).tobytes(), self.level))
+        hasher = hashlib.md5()
+
+        hasher.update(self.data.round(decimals=5).tobytes())
+        hasher.update(self.level.to_bytes(4, "little"))
+
+        return int(hasher.hexdigest(), 16)
 
     def __lt__(self, other: "Node"):
         return np.isclose(
@@ -67,7 +73,7 @@ def __init__(
         self.ef_construction = ef_construction
         self.ef_search = ef_search
         self.entry_point = None
-        self.probas = self._set_probas(1 / np.log(max_neighbors))
+        self.cum_probas = np.cumsum(self._set_probas(1 / np.log(max_neighbors)))
         self.emb = {}
         self.clusters = None
         self.centroids = None
@@ -137,7 +143,7 @@ def find_similar_by_cap(self, cap: CapType, top_k: int = 10):
             return []
 
         if self.clusters is None or self.centroids is None:
-            self._perform_clustering()
+            self.perform_clustering()
 
         sorted_clusters_by_magnitude = sorted(
             self.centroids.items(), key=lambda x: np.linalg.norm(x[1])
@@ -174,13 +180,24 @@ def find_similar_by_cap(self, cap: CapType, top_k: int = 10):
 
         return [node.meta.get("symbol") for _, node in similar]
 
-    def _random_level(self) -> int:
-        level = 0
+    def perform_clustering(self, n_clusters: int = 3) -> None:
+        if not self.emb:
+            return
 
-        while level < self.max_level and np.random.rand() < self.probas[level]:
-            level += 1
+        symbols = list(self.emb.keys())
+        embeddings = np.array(list(self.emb.values()))
 
-        return level
+        gmm = GaussianMixture(
+            n_components=n_clusters, init_params="k-means++", random_state=1337
+        )
+
+        self.clusters = gmm.fit_predict(embeddings)
+        self.symbol_cluster_map = dict(zip(symbols, self.clusters))
+        self.centroids = self._calculate_cluster_centroids()
+
+    def _random_level(self) -> int:
+        rand_val = np.random.rand()
+        return int(np.searchsorted(self.cum_probas, rand_val))
 
     def _beam_search(
         self, entry_node: Node, query: np.ndarray, level: int, ef: int
@@ -203,8 +220,7 @@ def _beam_search(
                     visited.add(neighbor)
                     neighbor_dist = self._distance(neighbor.data, query)
 
-                    if len(candidates) < ef or neighbor_dist < beam[0][0]:
-                        heappush(candidates, (neighbor_dist, neighbor))
+                    heappush(candidates, (neighbor_dist, neighbor))
 
         return beam

@@ -233,21 +249,6 @@ def _select_best_neighbors(
     ) -> List[Tuple[float, Node]]:
         return sorted(neighbors, key=lambda x: (x[0], id(x[1])))
 
-    def _perform_clustering(self, n_clusters: int = 3) -> None:
-        if not self.emb:
-            return
-
-        symbols = list(self.emb.keys())
-        embeddings = np.array(list(self.emb.values()))
-
-        gmm = GaussianMixture(
-            n_components=n_clusters, init_params="k-means++", random_state=1337
-        )
-
-        self.clusters = gmm.fit_predict(embeddings)
-        self.symbol_cluster_map = dict(zip(symbols, self.clusters))
-        self.centroids = self._calculate_cluster_centroids()
-
     def _calculate_cluster_centroids(self) -> Dict[int, np.ndarray]:
         clusters = np.array(
             [self.symbol_cluster_map[symbol] for symbol in self.emb.keys()]
@@ -264,16 +265,10 @@ def _calculate_cluster_centroids(self) -> Dict[int, np.ndarray]:
 
     @staticmethod
    def _set_probas(m_l: float) -> List[float]:
-        level = 0
-        probas = []
-
-        while True:
-            proba = np.exp(-level / m_l) * (1 - np.exp(-1 / m_l))
-            if proba < 1e-9:
-                break
+        levels = np.arange(0, 100)
 
-            probas.append(proba)
-            level += 1
+        probas = np.exp(-levels / m_l) * (1 - np.exp(-1 / m_l))
+        probas = probas[probas >= 1e-9]
 
         return probas

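Two of the _gsim.py changes deserve a closer look. First, Node.__hash__ switches from the built-in hash(), whose output for bytes and str is salted per interpreter process via PYTHONHASHSEED, to an md5 digest, so node hashes are now stable across runs. Second, level sampling becomes a single inverse-CDF draw: _set_probas builds the geometric level distribution in one vectorized expression, and _random_level locates one uniform sample in the cumulative probabilities with np.searchsorted instead of walking levels in a loop. A self-contained sketch of both ideas follows; max_neighbors = 16 is a placeholder value.

import hashlib

import numpy as np

def stable_node_hash(data: np.ndarray, level: int) -> int:
    # md5 over the rounded vector bytes plus the level is deterministic
    # across processes, unlike hash(), which salts bytes per run.
    hasher = hashlib.md5()
    hasher.update(data.round(decimals=5).tobytes())
    hasher.update(level.to_bytes(4, "little"))
    return int(hasher.hexdigest(), 16)

def set_probas(m_l: float) -> np.ndarray:
    # Geometric level distribution: P(level = l) decays exponentially in l,
    # and entries below 1e-9 are dropped as a negligible tail.
    levels = np.arange(0, 100)
    probas = np.exp(-levels / m_l) * (1 - np.exp(-1 / m_l))
    return probas[probas >= 1e-9]

m_l = 1 / np.log(16)  # 16 stands in for max_neighbors
cum_probas = np.cumsum(set_probas(m_l))

def random_level() -> int:
    # Inverse-CDF sampling: one uniform draw located in the cumulative
    # distribution replaces the old per-level probability loop.
    return int(np.searchsorted(cum_probas, np.random.rand()))

print(stable_node_hash(np.array([0.1, 0.2]), level=0))
print([random_level() for _ in range(10)])

One behavioral note visible in the diff itself: the rewritten _random_level no longer checks self.max_level; the 1e-9 truncation in _set_probas bounds the highest reachable level instead, which is not necessarily the same cap as before.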
