Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion align_app/adm/decider/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from align_utils.models import ADMResult, Decision, ChoiceInfo
from .decider import MultiprocessDecider
from .client import get_decision
from .client import get_decision, is_model_cached
from .types import DeciderParams

__all__ = [
"MultiprocessDecider",
"get_decision",
"is_model_cached",
"DeciderParams",
"ADMResult",
"Decision",
Expand Down
7 changes: 7 additions & 0 deletions align_app/adm/decider/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
"""

import atexit
from typing import Dict, Any
from align_utils.models import ADMResult
from .decider import MultiprocessDecider
from .types import DeciderParams
Expand All @@ -19,6 +20,12 @@ def _get_process_manager():
return _decider


async def is_model_cached(resolved_config: Dict[str, Any]) -> bool:
    """Return True when the worker already holds the model for this config.

    Delegates to the module's shared MultiprocessDecider, which queries the
    worker process for a cache hit on the resolved configuration.
    """
    manager = _get_process_manager()
    cached = await manager.is_model_cached(resolved_config)
    return cached


async def get_decision(params: DeciderParams) -> ADMResult:
"""Get a decision using DeciderParams.

Expand Down
9 changes: 8 additions & 1 deletion align_app/adm/decider/decider.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Dict, Any
from align_utils.models import ADMResult
from .types import DeciderParams
from .worker import decider_worker_func
from .worker import decider_worker_func, CacheQuery, CacheQueryResult
from .multiprocess_worker import (
WorkerHandle,
create_worker,
Expand All @@ -13,6 +14,12 @@ class MultiprocessDecider:
def __init__(self):
self.worker: WorkerHandle = create_worker(decider_worker_func)

async def is_model_cached(self, resolved_config: Dict[str, Any]) -> bool:
self.worker, result = await send(self.worker, CacheQuery(resolved_config))
if isinstance(result, CacheQueryResult):
return result.is_cached
return False

async def get_decision(self, params: DeciderParams) -> ADMResult:
self.worker, result = await send(self.worker, params)

Expand Down
18 changes: 18 additions & 0 deletions align_app/adm/decider/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import json
import logging
import traceback
from dataclasses import dataclass
from typing import Dict, Tuple, Callable, Any
from multiprocessing import Queue
from align_utils.models import ADMResult
Expand All @@ -15,6 +16,16 @@ def extract_cache_key(resolved_config: Dict[str, Any]) -> str:
return hashlib.md5(cache_str.encode()).hexdigest()


@dataclass
class CacheQuery:
    """Task-queue message asking the worker whether a model is already loaded."""

    # Fully-resolved ADM configuration; the worker hashes it with
    # extract_cache_key to test membership in its model cache.
    resolved_config: Dict[str, Any]


@dataclass
class CacheQueryResult:
    """Result-queue reply to a CacheQuery."""

    # True when the worker's model cache already contains the queried config.
    is_cached: bool


def decider_worker_func(task_queue: Queue, result_queue: Queue):
root_logger = logging.getLogger()
root_logger.setLevel("WARNING")
Expand All @@ -24,6 +35,13 @@ def decider_worker_func(task_queue: Queue, result_queue: Queue):
try:
for task in iter(task_queue.get, None):
try:
if isinstance(task, CacheQuery):
cache_key = extract_cache_key(task.resolved_config)
result_queue.put(
CacheQueryResult(is_cached=cache_key in model_cache)
)
continue

params: DeciderParams = task
cache_key = extract_cache_key(params.resolved_config)

Expand Down
3 changes: 0 additions & 3 deletions align_app/app/alerts_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,3 @@ def show(self, message: str, timeout: int = -1):
self.server.state.alert_message = message
self.server.state.alert_timeout = timeout
self.server.state.alert_visible = True

def hide(self):
self.server.state.alert_visible = False
7 changes: 0 additions & 7 deletions align_app/app/runs_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,13 +103,6 @@ async def execute_run_decision(self, run_id: str) -> Optional[Run]:

return await self._execute_with_cache(run, probe.choices or [])

def has_cached_decision(self, run_id: str) -> bool:
run = runs_core.get_run(self._runs, run_id)
if not run:
return False
cache_key = run.compute_cache_key()
return runs_core.get_cached_decision(self._runs, cache_key) is not None

def get_run(self, run_id: str) -> Optional[Run]:
run = runs_core.get_run(self._runs, run_id)
if run:
Expand Down
12 changes: 6 additions & 6 deletions align_app/app/runs_state_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .runs_registry import RunsRegistry
from .runs_table_filter import RunsTableFilter
from ..adm.decider.types import DeciderParams
from ..adm.decider import is_model_cached
from ..adm.system_adm_discovery import discover_system_adms
from ..utils.utils import get_id
from .runs_presentation import extract_base_scenarios
Expand Down Expand Up @@ -616,12 +617,11 @@ async def _execute_run_decision(self, run_id: str):
with self.state:
self._add_pending_cache_key(cache_key)

is_cached = self.runs_registry.has_cached_decision(run_id)
if not is_cached:
self._alerts.show("Loading model...")
await self.server.network_completion

self._alerts.show("Making decision...")
run = self.runs_registry.get_run(run_id)
if run and await is_model_cached(run.decider_params.resolved_config):
self._alerts.show("Deciding...")
else:
self._alerts.show("Loading model and deciding...")
await self.server.network_completion

try:
Expand Down
2 changes: 1 addition & 1 deletion align_app/app/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -1544,7 +1544,7 @@ def __init__(
with vuetify3.VSnackbar(
v_model=("alert_visible", False),
text=("alert_message", ""),
location="bottom left",
location="bottom right",
color="white",
timeout=("alert_timeout", -1),
content_class="text-h6 font-weight-medium",
Expand Down