Skip to content

Commit

Permalink
fixes+improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
xtianpoli committed Nov 5, 2024
1 parent ead5d6a commit 9c0f3fa
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 75 deletions.
24 changes: 2 additions & 22 deletions ipyprogressivis/widgets/chaining/join.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from .utils import (
make_button,
stage_register,
append_child, VBox, amend_last_record, get_recording_state, disable_all, runner
append_child, VBox, amend_last_record, get_recording_state, disable_all, runner, needs_dtypes
)
import ipywidgets as ipw
from progressivis.table.group_by import UTIME_SHORT_D
Expand Down Expand Up @@ -61,6 +61,7 @@ class JoinW(VBox):
def __init__(self) -> None:
super().__init__()

@needs_dtypes
def initialize(self) -> None:
self.output_dtypes = None
dd_list = [(f"{k}[{n}]" if n
Expand Down Expand Up @@ -158,15 +159,9 @@ def run(self) -> None:
content = self.frozen_kw
if (key := content["primary_inp"]) != "parent":
primary_wg = self.get_widget_by_key(key)
if primary_wg.output_dtypes is None:
primary_wg.compute_dtypes_then_call(self.run)
return
self.dag.add_parent(self.title, primary_wg.title)
if (key := content["related_inp"]) != "parent":
related_wg = self.get_widget_by_key(key)
if related_wg.output_dtypes is None:
related_wg.compute_dtypes_then_call(self.run)
return
self.dag.add_parent(self.title, related_wg.title)
self.output_module = self.init_join(**content)
self.output_slot = "result"
Expand All @@ -177,16 +172,6 @@ def init_join(self, primary_cols: list[str], related_cols: list[str],
inv_mask: str,
how: Literal['inner', 'outer']
) -> Join:
"""kw = dict(
primary_cols=primary_cols,
related_cols=related_cols,
primary_on=primary_on,
related_on=related_on,
primary_inp=primary_inp,
related_inp=related_inp,
inv_mask=inv_mask,
how=how
)"""
if primary_inp == "parent":
primary_wg = self.parent
related_wg = self.get_widget_by_key(related_inp)
Expand All @@ -195,15 +180,10 @@ def init_join(self, primary_cols: list[str], related_cols: list[str],
assert isinstance(primary_inp, tuple)
primary_wg = self.get_widget_by_key(primary_inp)
related_wg = self.parent
"""second_wg = primary_wg
if second_wg.output_dtypes is None:
second_wg."""
s = self.input_module.scheduler()
with s:
assert primary_wg is not None
assert related_wg is not None
assert primary_wg.output_dtypes is not None
assert related_wg.output_dtypes is not None
join = Join(how=how, inv_mask=inv_mask, scheduler=s)
join.create_dependent_modules(
related_module=cast(Module, related_wg.output_module),
Expand Down
59 changes: 6 additions & 53 deletions ipyprogressivis/widgets/chaining/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,6 @@ def resume(self) -> "NodeCarrier":
else:
self.__carrier.make_progress_bar()
parent_widget = self.__carrier
print("replay_next_if() from resume")
replay_next_if(self.__carrier)
return self.__carrier

Expand Down Expand Up @@ -358,11 +357,9 @@ def replay_next(obj: Optional[Union["Constructor", "NodeVBox"]] = None) -> None:
if (
parent is not None and tuple(parent) in PARAMS["deleted_stages"]
): # skipping deleted
print("replay_next_if() from replay_next 1")
return replay_next_if()
if "deleted" in stage:
PARAMS["deleted_stages"].add((stage["title"], stage["number"]))
print("replay_next_if() from replay_next 2")
return replay_next_if()
if obj is None and stage and "ftype" not in stage: # not a loader => has a parent
assert parent is not None
Expand Down Expand Up @@ -722,16 +719,7 @@ class _FakeSel:
sel.value = title
global parent_widget
parent_widget = obj
if obj._output_dtypes is None and False:
s = obj._output_module.scheduler()
with s:
ds = DataShape(scheduler=s)
ds.input.table = obj._output_module.output.result
ds.on_after_run(obj.make_guess_types_toc2(sel, frozen, number=number))
sink = Sink(scheduler=s)
sink.input.inp = ds.output.result
else:
add_new_stage(obj, title, frozen=frozen, number=number)
add_new_stage(obj, title, frozen=frozen, number=number)


def remove_tagged_cells(tag: int) -> None:
Expand Down Expand Up @@ -763,11 +751,6 @@ class ChainingProtocol(Protocol):
_output_dtypes: Optional[Dict[str, str]]
_output_module: ModuleOrFacade

def make_guess_types_toc2(
self, sel: ipw.Select, frozen: AnyType | None = None, number: int | None = None
) -> Callable[..., AnyType]:
...

def _make_btn_chain_it_cb(
self, sel: AnyType, frozen: AnyType | None = None, number: int | None = None
) -> Callable[..., None]:
Expand All @@ -777,26 +760,6 @@ def _make_btn_chain_it_cb(
class ChainingMixin:
_output_module: ModuleOrFacade

def make_guess_types_toc2(
self, sel: ipw.Select, frozen: AnyType | None = None, number: int | None = None
) -> Callable[..., AnyType]:
def _guess(m: Module, run_number: int) -> None:
global parent_dtypes
assert hasattr(m, "result")
if m.result is None:
return
parent_dtypes = {
k: "datetime64" if str(v)[0] == "6" else v
for (k, v) in m.result.items()
}
self.output_dtypes = parent_dtypes
add_new_stage(self, sel.value, frozen, number=number) # type: ignore
with m.scheduler() as dataflow:
deps = dataflow.collateral_damage(m.name)
dataflow.delete_modules(*deps)

return _guess

def _make_btn_chain_it_cb(
self: ChainingProtocol,
sel: AnyType,
Expand All @@ -806,16 +769,7 @@ def _make_btn_chain_it_cb(
def _cbk(btn: ipw.Button) -> None:
global parent_widget
parent_widget = self # type: ignore
if self._output_dtypes is None:
s = self._output_module.scheduler()
with s:
ds = DataShape(scheduler=s)
ds.input.table = self._output_module.output.result
ds.on_after_run(self.make_guess_types_toc2(sel, frozen=frozen))
sink = Sink(scheduler=s)
sink.input.inp = ds.output.result
else:
add_new_stage(self, sel.value, frozen=frozen, number=number) # type: ignore
add_new_stage(self, sel.value, frozen=frozen, number=number) # type: ignore

return _cbk

Expand Down Expand Up @@ -1075,7 +1029,6 @@ def add_new_loader(
code, rw, run = get_loader_cell(
key=alias or title, ftype=ftype, num=n, end=end, frozen=frozen
)
print("RUN IS:", run)
labcommand(
"progressivis:create_stage_cells", frozen=frozen, tag=tag, md=md, code=code, rw=rw, run=run
)
Expand Down Expand Up @@ -1238,7 +1191,10 @@ def _guess2(m: Module, run_number: int) -> None:
k: "datetime64" if str(v)[0] == "6" else v
for (k, v) in m.result.items()
}
self_ = args[0] # type: ignore
if hasattr(fun, "__self__"): # i.e. fun is a bound method
self_ = fun.__self__
else:
self_ = args[0] # type: ignore
self_.carrier._dtypes = self.output_dtypes
fun(*args, **kw)
with m.scheduler() as dataflow:
Expand Down Expand Up @@ -1298,19 +1254,16 @@ def post_run(self) -> "NodeCarrier":
# chaining_boxes_to_make.append(self)
else:
self.carrier.make_progress_bar()
print("replay_next_if() from post_run", self)
replay_next_if(self.carrier)
return self.carrier

def post_delete(self) -> "NodeCarrier":
self.carrier.children = (ipw.Label("deleted"),)
print("replay_next_if() from post_delete")
replay_next_if()
return self.carrier

def manage_replay(self) -> None:
if self._do_replay_next:
print("replay_next_if() from manage_replay")
replay_next_if()


Expand Down

0 comments on commit 9c0f3fa

Please sign in to comment.