-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathdummies.py
54 lines (36 loc) · 1.42 KB
/
dummies.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from typing import Any, Sequence
from flotta.core.environment import Environment
from flotta.core.interfaces import SchedulerContext, SchedulerJob, BaseStep
from flotta.core.models import AggregationModel
from flotta.core.operations.core import Operation
from flotta.core.transformers.core import QueryTransformer
from flotta.logging import get_logger
from numpy.typing import ArrayLike
import numpy as np
# Module-level logger obtained via the project's get_logger helper with this
# module's __name__; DummyModel.train/aggregate log through it.
LOGGER = get_logger(__name__)
class DummyOp(Operation):
    """Placeholder operation that performs no work."""

    def exec(self, env: Environment) -> Environment:
        """Hand back the received environment exactly as it was given."""
        unchanged = env
        return unchanged
class DummyStep(BaseStep):
    """Placeholder step: passes the environment through and schedules nothing."""

    def step(self, env: Environment) -> Environment:
        """Return ``env`` untouched."""
        return env

    def jobs(self, context: SchedulerContext) -> Sequence[SchedulerJob]:
        """Produce a fresh, empty job list; ``context`` is not inspected."""
        no_jobs = list()
        return no_jobs

    def bind(self, jobs0: Sequence[SchedulerJob], jobs1: Sequence[SchedulerJob]) -> None:
        """Do nothing; neither job sequence is inspected."""
class DummyModel(AggregationModel):
    """Placeholder aggregation model that learns nothing and predicts zeros.

    ``train`` and ``aggregate`` only emit a log message and return ``None``.
    ``predict`` and ``classify`` return an all-zero float array whose shape
    matches the input.
    """

    def train(self, x, y) -> Any:
        """Log a training message and return ``None``; ``x`` and ``y`` are ignored."""
        LOGGER.info("training...")
        return None

    def aggregate(self, model_a, model_b) -> Any:
        """Log an aggregation message and return ``None``; both models are ignored."""
        LOGGER.info("aggregating...")
        return None

    def predict(self, x) -> np.ndarray:
        """Return an array of zeros with the same shape as ``x``.

        ``np.shape`` is used instead of ``x.shape``. It returns ``x.shape``
        when the attribute exists (NumPy arrays, DataFrames, ...) and
        otherwise converts the input first, so plain lists and tuples are
        accepted as well.
        """
        return np.zeros(np.shape(x))

    def classify(self, x) -> ArrayLike | np.ndarray:
        """Return an array of zeros with the same shape as ``x``.

        Like ``predict``, this accepts any array-like input, not only
        objects that expose a ``shape`` attribute.
        """
        return np.zeros(np.shape(x))
class DummyTransformer(QueryTransformer):
    """Placeholder transformer that leaves the environment untouched."""

    def transform(self, env: Environment) -> tuple[Environment, Any]:
        """Return ``env`` unchanged, paired with ``None`` as the second element."""
        extra = None
        return (env, extra)

    def aggregate(self, env: Environment) -> Environment:
        """Return ``env`` unchanged."""
        return env