-
Notifications
You must be signed in to change notification settings - Fork 2
/
Copy pathworkbench.py
147 lines (104 loc) · 3.25 KB
/
workbench.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
# %%
from flotta.core.distributions import Collect
from flotta.core.estimators import GroupCountEstimator, MeanEstimator
from flotta.core.model_operations import Aggregation, Train, TrainTest
from flotta.core.models import FederatedRandomForestClassifier, StrategyRandomForestClassifier
from flotta.core.steps import Finalize, Parallel
from flotta.core.transformers import FederatedSplitter, FederatedKBinsDiscretizer
from flotta.schemas.workbench import WorkbenchResource
from flotta.workbench import (
Context,
Project,
Client,
Artifact,
ArtifactStatus,
DataSource,
)
import numpy as np
import json
# %% create the context
# Context is the workbench entry point: it wraps the connection to a
# flotta node reachable at the given URL.
ctx = Context("http://localhost:1456")
# %% load a project given a token
# the token both identifies the project on the node and authorizes access to it
project_token = "58981bcbab77ef4b8e01207134c38873e0936a9ab88cd76b243a2e2c85390b94"
project: Project = ctx.project(project_token)
# %% What is this project?
print(project)
# %% (for DEBUG) ask the context for clients of a project
clients: list[Client] = ctx.clients(project)
for c in clients:
    print(c)
# %% (for DEBUG) ask the context for data source of a project
# one entry per node-local datasource; these are the raw, per-client sources
datasources: list[DataSource] = ctx.datasources(project)
for datasource in datasources:
    print(datasource)  # <--- non aggregated
# %% working with data
# project.data is the aggregated view over all the project's datasources
ds = project.data  # <--- AggregatedDataSource
print(ds)
# this is like a describe, but per single feature
for feature in ds.features:
    print(feature)
# %% develop a filter query
# prepare transformation query with all features of the project's data
q = project.extract()
# inspect a feature data type
feature = q["variety"]
print(feature)
# add filter to the extraction query (keep only rows where variety < 2)
q = q.add(q["variety"] < 2)
# %% add transformer
# discretize "variety" into bins, writing the result to a new
# "variety_discr" feature; the transformer runs on each node at execution time
q = q.add(
    FederatedKBinsDiscretizer(
        features_in=[q["variety"]],
        features_out=[q["variety_discr"]],
    )
)
# %% statistics 1
# federated group-count on the discretized feature; submitting the
# estimator's steps sends the computation to the node for execution
gc = GroupCountEstimator(
    query=q,
    by=["variety_discr"],
    features=["variety_discr"],
)
ret = ctx.submit(project, gc.get_steps())
# %% statistics 2
# federated mean over all features selected by the query
me = MeanEstimator(query=q)
ret = ctx.submit(project, me.get_steps())
# %% prepare the model steps
# a random forest trained locally on each node; the MERGE strategy combines
# the estimators of every local forest into one aggregated model
model = FederatedRandomForestClassifier(
    n_estimators=10,
    strategy=StrategyRandomForestClassifier.MERGE,
)
# NOTE(review): this label name refers to the California-housing dataset,
# while the query cells above use the "variety" feature — confirm it
# matches the features actually present in this project's data
label = "MedHouseValDiscrete"
steps = [
    # runs on each node in parallel: split the local data into
    # train/test, train the model, then evaluate it
    Parallel(
        TrainTest(
            query=project.extract().add(
                FederatedSplitter(
                    random_state=42,
                    test_percentage=0.2,
                    label=label,
                )
            ),
            trainer=Train(model=model),
            model=model,
        ),
        # gather the partial (per-node) models for aggregation
        Collect(),
    ),
    # runs once after collection: merge the partial models into one
    Finalize(
        Aggregation(model=model),
    ),
]
# %% submit the task to the node, it will be converted to an Artifact
a: Artifact = ctx.submit(project, steps)
# NOTE(review): indent=True behaves as indent=1 (bool is an int subclass);
# likely meant indent=2 or 4
print(json.dumps(a.model_dump(), indent=True))  # view execution plan
# %% monitor learning progress
status: ArtifactStatus = ctx.status(a)
print(status)
# %% list produced resources
resources: list[WorkbenchResource] = ctx.list_resources(a)
for r in resources:
    print(r.resource_id, r.creation_time, r.is_ready)
# %% get latest produced resource (the result)
# the newest resource of the artifact; "model" holds the aggregated forest
agg_model: FederatedRandomForestClassifier = ctx.get_latest_resource(a)["model"]
# %% run local predictions with the aggregated model
print(agg_model.predict(np.array([[0, 0, 0, 0]])))
print(agg_model.predict(np.array([[1, 1, 1, 1]])))