Skip to content

Commit e169b74

Browse files
committed
supporting custom python code compatible with chaining widgets
1 parent 4ac946a commit e169b74

File tree

7 files changed

+160
-28
lines changed

7 files changed

+160
-28
lines changed

ipyprogressivis/js/src/commands.js

+5-3
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ export function setBackup(nbtracker, backupstring) {
9393
backupCell.model.sharedModel.setMetadata("progressivis_backup", backupstring);
9494
}
9595

96-
export function createStageCells(nbtracker, tag, md, code) {
96+
export function createStageCells(nbtracker, tag, md, code, rw, run) {
9797
var crtWidget = nbtracker.currentWidget;
9898
var notebook = crtWidget.content;
9999
var tag = tag.toString();
@@ -128,9 +128,11 @@ export function createStageCells(nbtracker, tag, md, code) {
128128
});
129129
notebook.activeCellIndex = i + 1;
130130
var cell = notebook.widgets[i + 1];
131-
NotebookActions.run(notebook, crtWidget.sessionContext);
131+
if(run){
132+
NotebookActions.run(notebook, crtWidget.sessionContext);
133+
}
132134
cell.model.sharedModel.setMetadata("trusted", true);
133-
cell.model.sharedModel.setMetadata("editable", false);
135+
cell.model.sharedModel.setMetadata("editable", rw);
134136
cell.model.sharedModel.setMetadata("deletable", false);
135137
cell.model.sharedModel.setMetadata("progressivis_tag", tag);
136138
}

ipyprogressivis/js/src/labplugin.js

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ export const progressivisPlugin = {
8787
label: "Create stage cells",
8888
caption: "Create stage cells",
8989
execute: (args) => {
90-
cmds.createStageCells(nbtracker, args.tag, args.md, args.code);
90+
cmds.createStageCells(nbtracker, args.tag, args.md, args.code, args.rw, args.run);
9191
},
9292
});
9393
const TalkerView = class extends DOMWidgetView {

ipyprogressivis/widgets/chaining/__init__.py

+4-1
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
from .histogram import HistogramW
1414
from .iscaler import ScalerW
1515
from .any_vega import AnyVegaW
16+
from .code_cell import CodeCellW
17+
1618
__all__ = [
1719
"Constructor",
1820
"DescStatsW",
@@ -27,5 +29,6 @@
2729
"ScalerW",
2830
"FacadeCreatorW",
2931
"HeatmapW",
30-
"AnyVegaW"
32+
"AnyVegaW",
33+
"CodeCellW"
3134
]
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from .utils import VBoxTyped, TypedBase, stage_register
2+
import ipywidgets as ipw
3+
4+
5+
class CodeCellW(VBoxTyped):
6+
class Typed(TypedBase):
7+
dongle: ipw.Label
8+
9+
def initialize(self) -> None:
10+
self.c_.dongle = ipw.Label("Chaining ...")
11+
12+
13+
stage_register["Python"] = CodeCellW

ipyprogressivis/widgets/chaining/constructor.py

+15-3
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,11 @@
66
make_button,
77
set_dag,
88
_Dag,
9+
Proxy,
910
DAGWidget,
1011
RootVBox,
1112
TypedBox,
12-
NodeVBox,
13+
NodeCarrier,
1314
TypedBase,
1415
get_widget_by_id,
1516
get_widget_by_key,
@@ -58,6 +59,7 @@ class Typed(TypedBase):
5859
record_ck: ipw.Checkbox
5960
csv: Optional[ipw.HBox]
6061
parquet: Optional[ipw.HBox]
62+
custom: Optional[ipw.HBox]
6163
replay: ipw.Button
6264

6365
last_created = None
@@ -126,6 +128,8 @@ def _start_scheduler_cb(self, btn: ipw.Button) -> None:
126128
self.child.csv = self.make_loader_box(ftype="csv", disabled=self._locked())
127129
self.child.parquet = self.make_loader_box(ftype="parquet",
128130
disabled=self._locked())
131+
self.child.custom = self.make_loader_box(ftype="custom",
132+
disabled=self._locked())
129133
self._arch_list = [b642json(elt)
130134
for elt in bunpack(self._backup.value)
131135
] if self._backup.value else []
@@ -143,6 +147,8 @@ def _play_mode_cb(self, change: dict[str, AnyType]) -> None:
143147
wg.disabled = is_replay
144148
for wg in self.child.parquet.children:
145149
wg.disabled = is_replay
150+
for wg in self.child.custom.children:
151+
wg.disabled = is_replay
146152

147153
if is_replay:
148154
self.child.record_ck.value = False
@@ -160,6 +166,7 @@ def _replay_cb(self, btn: ipw.Button) -> None:
160166
btn.disabled = True
161167
self.child.csv.children[-1].disabled = True
162168
self.child.parquet.children[-1].disabled = True
169+
self.child.custom.children[-1].disabled = True
163170
self.child.record_ck.value = False
164171
self.child.record_ck.disabled = True
165172
self.child.play_mode_radio.disabled = True
@@ -170,13 +177,18 @@ def _replay_cb(self, btn: ipw.Button) -> None:
170177
replay_next(self)
171178

172179
@staticmethod
173-
def widget_by_id(key: int) -> NodeVBox:
180+
def widget_by_id(key: int) -> NodeCarrier:
174181
return get_widget_by_id(key)
175182

176183
@staticmethod
177-
def widget(key: str, num: int = 0) -> NodeVBox:
184+
def widget(key: str, num: int = 0) -> NodeCarrier:
178185
return get_widget_by_key(key, num)
179186

187+
@staticmethod
188+
def proxy(key: str, num: int = 0) -> Proxy:
189+
widget = get_widget_by_key(key, num)
190+
return Proxy(widget)
191+
180192
@property
181193
def dom_id(self) -> str:
182194
return f"prog_{id(self)}"
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
from .utils import VBoxTyped, TypedBase
2+
import ipywidgets as ipw
3+
4+
5+
class CustomLoaderW(VBoxTyped):
6+
class Typed(TypedBase):
7+
dongle: ipw.Label
8+
9+
def initialize(self) -> None:
10+
self.c_.dongle = ipw.Label("Custom loader")

ipyprogressivis/widgets/chaining/utils.py

+112-20
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
import ipywidgets as ipw
88
from progressivis.table.dshape import dataframe_dshape
99
from progressivis.vis import DataShape
10-
from progressivis.core import Sink
10+
from progressivis.core import Sink, Scheduler
1111
from progressivis.core import Module
1212
from progressivis.table.table_facade import TableFacade
1313
from progressivis.core.utils import normalize_columns
@@ -69,6 +69,47 @@ class Header:
6969
modules_out: Sidecar
7070

7171

72+
class Proxy:
73+
def __init__(self, carrier: "NodeCarrier") -> None:
74+
self.__carrier = carrier
75+
76+
@property
77+
def input_module(self) -> ModuleOrFacade | None:
78+
if self.__carrier is PARAMS["constructor"]:
79+
return None
80+
return self.__carrier._input_module
81+
82+
@property
83+
def input_slot(self) -> str | None:
84+
if self.__carrier is PARAMS["constructor"]:
85+
return None
86+
return self.__carrier._input_slot
87+
88+
@property
89+
def input_dtypes(self) -> dict[str, str] | None:
90+
if self.__carrier is PARAMS["constructor"]:
91+
return None
92+
return self.__carrier._dtypes
93+
94+
@property
95+
def scheduler(self) -> Scheduler:
96+
if self.__carrier is PARAMS["constructor"]:
97+
assert PARAMS["constructor"] is not None
98+
return cast(Scheduler, PARAMS["constructor"].scheduler)
99+
assert self.input_module is not None
100+
return self.input_module.scheduler()
101+
102+
def resume(self, output_module: ModuleOrFacade,
103+
output_slot: str = "result",
104+
output_dtypes: dict[str, str] | None = None,
105+
freeze: bool = False) -> "NodeCarrier":
106+
self.__carrier._output_module = output_module
107+
self.__carrier._output_slot = output_slot
108+
self.__carrier._output_dtypes = output_dtypes
109+
self.__carrier.make_chaining_box()
110+
return self.__carrier
111+
112+
72113
def get_header() -> Header:
73114
"""
74115
NB: call this function ONLY from the first cell of the notebook!!
@@ -357,11 +398,11 @@ def make_button(
357398

358399

359400
stage_register: Dict[str, AnyType] = {}
360-
parent_widget: Optional["NodeVBox"] = None
401+
parent_widget: Optional["NodeCarrier"] = None
361402
parent_dtypes: Optional[Dict[str, str]] = None
362403
key_by_id: Dict[int, Tuple[str, int]] = {}
363-
widget_by_id: Dict[int, "NodeVBox"] = {}
364-
widget_by_key: Dict[Tuple[str, int], "NodeVBox"] = {}
404+
widget_by_id: Dict[int, "NodeCarrier"] = {}
405+
widget_by_key: Dict[Tuple[str, int], "NodeCarrier"] = {}
365406
widget_numbers: Dict[str, int] = defaultdict(int)
366407
recording_state: bool = False
367408

@@ -412,13 +453,16 @@ def create_loader_widget(
412453
ctx = dict(parent=obj, dtypes=dtypes, input_module=obj._output_module, dag=dag)
413454
from .csv_loader import CsvLoaderW
414455
from .parquet_loader import ParquetLoaderW
456+
from .custom_loader import CustomLoaderW
415457

416-
loader: Union[CsvLoaderW, ParquetLoaderW]
458+
loader: CsvLoaderW | ParquetLoaderW | CustomLoaderW
417459
if ftype == "csv":
418460
loader = CsvLoaderW()
419-
else:
420-
assert ftype == "parquet"
461+
elif ftype == "parquet":
421462
loader = ParquetLoaderW()
463+
else:
464+
assert ftype == "custom"
465+
loader = CustomLoaderW()
422466
if frozen is not None:
423467
loader.frozen_kw = frozen
424468
stage = NodeCarrier(ctx, loader)
@@ -435,11 +479,11 @@ def create_loader_widget(
435479
return stage
436480

437481

438-
def get_widget_by_id(key: int) -> "NodeVBox":
482+
def get_widget_by_id(key: int) -> "NodeCarrier":
439483
return widget_by_id[key]
440484

441485

442-
def get_widget_by_key(key: str, num: int) -> "NodeVBox":
486+
def get_widget_by_key(key: str, num: int) -> "NodeCarrier":
443487
return widget_by_key[(key, num)]
444488

445489

@@ -453,7 +497,7 @@ def set_recording_state(val: bool) -> None:
453497

454498

455499
def _make_btn_start_loader(
456-
obj: "NodeVBox", ftype: str, alias: WidgetType, frozen: AnyType = None
500+
obj: "NodeCarrier", ftype: str, alias: WidgetType, frozen: AnyType = None
457501
) -> Callable[..., None]:
458502
def _cbk(btn: ipw.Button) -> None:
459503
global parent_widget
@@ -468,7 +512,7 @@ def _cbk(btn: ipw.Button) -> None:
468512

469513

470514
def replay_start_loader(
471-
obj: "NodeVBox", ftype: str, alias: str, frozen: AnyType | None = None
515+
obj: "NodeCarrier", ftype: str, alias: str, frozen: AnyType | None = None
472516
) -> None:
473517
global parent_widget
474518
parent_widget = obj
@@ -477,7 +521,7 @@ def replay_start_loader(
477521

478522

479523
def replay_new_stage(
480-
obj: "NodeVBox", title: str, frozen: AnyType | None = None
524+
obj: "NodeCarrier", title: str, frozen: AnyType | None = None
481525
) -> None:
482526
class _FakeSel:
483527
value: str
@@ -675,6 +719,58 @@ def get_previous(obj: "ChainingWidget") -> "ChainingWidget":
675719

676720
new_stage_cell_0 = "Constructor.widget('{key}'){end}"
677721
new_stage_cell = "Constructor.widget('{key}', {num}){end}"
722+
new_stage_cell_code = ("proxy = Constructor.proxy('{key}', {num})\n"
723+
"# proxy object provides the following attributes:\n"
724+
"# input_module: Module | TableFacade \n"
725+
"# input_slot: str \n"
726+
"# input_dtypes: dict[str, str] | None\n"
727+
"# scheduler: Scheduler\n"
728+
"# Warning: keep the code above unchanged\n"
729+
"# Put your own code here\n"
730+
"...\n"
731+
"...\n"
732+
"...\n"
733+
"# fill in the following variables:\n"
734+
"output_module: 'Module | TableFacade' = ...\n"
735+
"output_slot: str = 'result'\n"
736+
"output_dtypes: dict[str, str] | None = None\n"
737+
"freeze: bool = False\n"
738+
"# Warning: keep the code below unchanged\n"
739+
"proxy.resume(output_module, output_slot, output_dtypes, freeze)"
740+
"{end}"
741+
)
742+
743+
new_loader_cell_code = ("proxy = Constructor.proxy('{key}', {num})\n"
744+
"scheduler = proxy.scheduler\n"
745+
"# Warning: keep the code above unchanged\n"
746+
"# Put your own imports here\n"
747+
"... \n"
748+
"with scheduler:\n"
749+
" # Put your own code here\n"
750+
" ...\n"
751+
" ...\n"
752+
" # fill in the following variables:\n"
753+
" output_module: 'Module | TableFacade' = ...\n"
754+
" output_slot: str = 'result'\n"
755+
" output_dtypes: dict[str, str] | None = None\n"
756+
" freeze: bool = False\n"
757+
" # Warning: keep the code below unchanged\n"
758+
" display(proxy.resume(output_module, output_slot,"
759+
" output_dtypes, freeze))"
760+
"{end}"
761+
)
762+
763+
764+
def get_stage_cell(key: str, num: int, end: str) -> tuple[str, bool, bool]:
765+
if key == "Python":
766+
return new_stage_cell_code.format(key=key, num=num, end=end), True, False
767+
return new_stage_cell.format(key=key, num=num, end=end), False, True
768+
769+
770+
def get_loader_cell(key: str, ftype: str, num: int, end: str) -> tuple[str, bool, bool]:
771+
if ftype == "custom":
772+
return new_loader_cell_code.format(key=key, num=num, end=end), True, False
773+
return new_stage_cell.format(key=key, num=num, end=end), False, True
678774

679775

680776
def add_new_stage(parent: "ChainingWidget", title: str, frozen: AnyType = None) -> None:
@@ -686,8 +782,8 @@ def add_new_stage(parent: "ChainingWidget", title: str, frozen: AnyType = None)
686782
if frozen is not None:
687783
end = ".run()"
688784
md = "## " + title + (f"[{n}]" if n else "")
689-
code = new_stage_cell.format(key=title, num=n, end=end)
690-
labcommand("progressivis:create_stage_cells", tag=tag, md=md, code=code)
785+
code, rw, run = get_stage_cell(key=title, num=n, end=end)
786+
labcommand("progressivis:create_stage_cells", tag=tag, md=md, code=code, rw=rw, run=run)
691787
add_to_record(dict(title=title, parent=parent_key))
692788

693789

@@ -703,14 +799,10 @@ def add_new_loader(
703799
end = ".run();"
704800
if alias:
705801
md = f"## {alias}"
706-
code = new_stage_cell_0.format(key=alias, end=end)
707802
else:
708803
md = "## " + title + (f"[{n}]" if n else "")
709-
if n:
710-
code = new_stage_cell.format(key=title, num=n, end=end)
711-
else:
712-
code = new_stage_cell_0.format(key=title, end=end)
713-
labcommand("progressivis:create_stage_cells", tag=tag, md=md, code=code)
804+
code, rw, run = get_loader_cell(key=alias or title, ftype=ftype, num=n, end=end)
805+
labcommand("progressivis:create_stage_cells", tag=tag, md=md, code=code, rw=rw, run=run)
714806
add_to_record(dict(ftype=ftype, alias=alias))
715807

716808

0 commit comments

Comments
 (0)