diff --git a/main.ipynb b/main.ipynb index 61190e2..2f00b5b 100644 --- a/main.ipynb +++ b/main.ipynb @@ -20,9 +20,23 @@ "\n", "from IPython.display import Javascript, display\n", "from IPython.utils.capture import capture_output\n", - "from kindergarten import plot\n", + "from kindergarten.core import Kindergarten\n", + "from kindergarten.constants import MAX_NUM_TRACES\n", + "\n", "from plotly.express import data\n", "\n", + "from typing import Any\n", + "import dash_bootstrap_components as dbc\n", + "import plotly.graph_objs as go\n", + "from dash import callback_context, dcc, html\n", + "from dash.dependencies import Input, Output\n", + "from jupyter_dash import JupyterDash\n", + "from plotly.subplots import make_subplots\n", + "\n", + "from kindergarten.constants import MAX_NUM_TRACES\n", + "from kindergarten.tab import Tab\n", + "from kindergarten.core import Kindergarten\n", + "\n", "\n", "# suppress JupyterDash deprecation warning\n", "warnings.filterwarnings(\n", @@ -45,6 +59,87 @@ "tips = data.tips()\n", "wind = data.wind()\n", "\n", + "# update Kindergarten\n", + "class KindergartenAdjusted(Kindergarten):\n", + " def _initialize_app(self):\n", + " self.app.config.suppress_callback_exceptions = True\n", + "\n", + " self.app.layout = dbc.Container(\n", + " [\n", + " dbc.Tabs(\n", + " [\n", + " dbc.Tab(\n", + " self.tabs[i].component(),\n", + " label=\"Trace {}\".format(i),\n", + " id=\"tab-{}\".format(i),\n", + " )\n", + " for i in range(len(self.tabs))\n", + " ],\n", + " id=\"tabs\",\n", + " ),\n", + " dbc.Row(dbc.Col(dcc.Graph(\n", + " id=\"graph\", \n", + " style={\"height\": \"calc(100vh - 250px)\"} # Added this line\n", + " ))),\n", + " ],\n", + " fluid=True,\n", + " className=\"dash-bootstrap\",\n", + " )\n", + "\n", + " for i in range(len(self.tabs)):\n", + " @self.app.callback(\n", + " Output(\"selector-{}\".format(i), \"children\"),\n", + " [\n", + " Input(\"graph-type-{}\".format(i), \"value\"),\n", + " Input(\"dataframe-{}\".format(i), \"value\"),\n", + " ],\n", + " prevent_initial_call=True,\n", + " )\n", + " def _on_graph_type_or_dataframe_change_update_selector(\n", + " graph_type, dataframe_name\n", + " ):\n", + " triggered_component_id = callback_context.triggered_id\n", + " kw, tab_id = triggered_component_id.rsplit(\"-\", 1)\n", + "\n", + " self.tabs[int(tab_id)].update_option(\n", + " kw, graph_type if kw == \"graph-type\" else dataframe_name\n", + " )\n", + " return self.tabs[int(tab_id)].options_component()\n", + "\n", + " all_options = sum(\n", + " [list(tab.options.values()) for tab in self.tabs],\n", + " [],\n", + " )\n", + " inputs = (\n", + " [Input(option.id, \"value\") for option in all_options]\n", + " + [Input(\"graph-type-{}\".format(i), \"value\") for i in range(len(self.tabs))]\n", + " + [Input(\"dataframe-{}\".format(i), \"value\") for i in range(len(self.tabs))]\n", + " )\n", + " input_names = (\n", + " [option.id for option in all_options]\n", + " + [\"graph-type-{}\".format(i) for i in range(len(self.tabs))]\n", + " + [\"dataframe-{}\".format(i) for i in range(len(self.tabs))]\n", + " )\n", + "\n", + " @self.app.callback(\n", + " Output(\"graph\", \"figure\"), inputs, prevent_initial_call=False\n", + " )\n", + " def _on_change_update_graph(*args) -> Any:\n", + " triggered_component_id = callback_context.triggered_id\n", + "\n", + " if triggered_component_id is not None:\n", + " kwargs = dict(zip(input_names, args))\n", + " value = kwargs[triggered_component_id]\n", + "\n", + " kw, tab_id = triggered_component_id.rsplit(\"-\", 1)\n", + " self.tabs[int(tab_id)].update_option(kw, value)\n", + "\n", + " return self._figure()\n", + " \n", + "def plot(num_traces=MAX_NUM_TRACES):\n", + " KindergartenAdjusted(num_traces).run()\n", + "\n", + "\n", "# launch plotting interface\n", "with capture_output() as captured:\n", " plot()\n", @@ -65,7 +160,7 @@ " flags=re.DOTALL\n", " )\n", "\n", - "Javascript(clean)\n" + "Javascript(clean)" ] } ],