Skip to content

Commit 89fd807

Browse files
committed
refactor(gui): fix issues with big refactor
1 parent a206d6a commit 89fd807

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

gui/app.py

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -206,10 +206,9 @@ def __init__(self):
206206
)
207207

208208

209-
def flatten_attributes(component_group, absolute_name: str, d=None) -> OrderedDict:
210-
if d is None:
211-
d = OrderedDict()
212-
209+
def flatten_attributes(
210+
component_group, absolute_name: str, d: OrderedDict
211+
) -> OrderedDict:
213212
if not hasattr(component_group, "__dict__"):
214213
return d
215214

@@ -218,14 +217,14 @@ def flatten_attributes(component_group, absolute_name: str, d=None) -> OrderedDi
218217
if name.startswith("_"):
219218
# Private attribute
220219
continue
221-
elif elem in component_group.__dict__.values():
220+
elif elem in d.values():
222221
# Don't duplicate any tiems
223222
continue
224223
elif isinstance(elem, Component):
225224
# Only add components to dict
226225
d[new_absolute_name] = elem
227226
else:
228-
d = flatten_attributes(elem, new_absolute_name, d=d)
227+
flatten_attributes(elem, new_absolute_name, d)
229228

230229
return d
231230

@@ -250,26 +249,35 @@ def __init__(self, demo: gr.Blocks) -> None:
250249
show_progress=False,
251250
)
252251

252+
ignore = ["df", "predictions_plot"]
253253
self.run.click(
254-
create_processing_function(self, ignore=["df", "predictions_plot"]),
255-
inputs=list(flatten_attributes(self, "interface").values()),
254+
create_processing_function(self, ignore=ignore),
255+
inputs=[
256+
v
257+
for k, v in flatten_attributes(self, "interface", OrderedDict()).items()
258+
if last_part(k) not in ignore
259+
],
256260
outputs=[self.results.df, self.results.predictions_plot],
257261
show_progress=True,
258262
)
259263

260264

265+
def last_part(k: str) -> str:
266+
return k.split(".")[-1]
267+
268+
261269
def create_processing_function(interface: AppInterface, ignore=[]):
262-
d = flatten_attributes(interface, "interface")
263-
keys = [k.split(".")[-1] for k in d.keys()]
264-
keys = [k for k in keys if k not in ignore]
270+
d = flatten_attributes(interface, "interface", OrderedDict())
271+
keys = [k for k in map(last_part, d.keys()) if k not in ignore]
265272
_, idx, counts = np.unique(keys, return_index=True, return_counts=True)
266273
if np.any(counts > 1):
267274
raise AssertionError("Bad keys: " + ",".join(np.array(keys)[idx[counts > 1]]))
268275

269-
def f(components):
276+
def f(*components):
270277
n = len(components)
271278
assert n == len(keys)
272-
return processing(**{keys[i]: components[i] for i in range(n)})
279+
for output in processing(**{keys[i]: components[i] for i in range(n)}):
280+
yield output
273281

274282
return f
275283

0 commit comments

Comments
 (0)