|
| 1 | +import dash |
| 2 | +import dash_bootstrap_components as dbc |
| 3 | +import pandas as pd |
| 4 | +import plotly.express as px |
| 5 | +from dash import Input, Output, dcc, html |
| 6 | + |
| 7 | +app = dash.Dash(external_stylesheets=[dbc.themes.BOOTSTRAP]) |
| 8 | +# the style arguments for the sidebar. We use position:fixed and a fixed width |
| 9 | +SIDEBAR_STYLE = { |
| 10 | + "position": "fixed", |
| 11 | + "top": 0, |
| 12 | + "left": 0, |
| 13 | + "bottom": 0, |
| 14 | + "width": "16rem", |
| 15 | + "padding": "2rem 1rem", |
| 16 | + "background-color": "#f8f9fa", |
| 17 | +} |
| 18 | + |
| 19 | +CONTENT_STYLE = { |
| 20 | + "margin-left": "18rem", |
| 21 | + "margin-right": "2rem", |
| 22 | + "padding": "2rem 1rem", |
| 23 | +} |
| 24 | + |
| 25 | +sidebar = html.Div( |
| 26 | + children=[ |
| 27 | + dcc.Input(id="sample", type="number", min=0, max=1, value=0.1, step=0.01), |
| 28 | + html.Div("Plot scale"), |
| 29 | + dcc.RadioItems(["Linear", "Log"], id="scale"), |
| 30 | + ], |
| 31 | + style=SIDEBAR_STYLE, |
| 32 | +) |
| 33 | + |
| 34 | +content = html.Div( |
| 35 | + id="page-content", |
| 36 | + style=CONTENT_STYLE, |
| 37 | + children=[ |
| 38 | + html.Div(id="max-value", style={"padding-top": "50px"}), |
| 39 | + dcc.Graph(id="scatter-plot"), |
| 40 | + dcc.Graph(id="histogram"), |
| 41 | + dcc.Store(id="sampled-dataset"), |
| 42 | + ], |
| 43 | +) |
| 44 | + |
| 45 | +app.layout = html.Div([dcc.Location(id="url"), sidebar, content]) |
| 46 | + |
| 47 | + |
| 48 | +@app.callback(Output("sampled-dataset", "data"), Input("sample", "value")) |
| 49 | +def cache_dataset(sample): |
| 50 | + df = pd.read_csv("nyc-taxi.csv") |
| 51 | + df = df.sample(frac=sample) |
| 52 | + |
| 53 | + # To cache data in this way we need to seiralize it to json |
| 54 | + json = df.to_json(date_format="iso", orient="split") |
| 55 | + return json |
| 56 | + |
| 57 | + |
| 58 | +@app.callback(Output("max-value", "children"), Input("sampled-dataset", "data")) |
| 59 | +def update_max_value(sampled_df): |
| 60 | + df = pd.read_json(sampled_df, orient="split") |
| 61 | + return f'First taxi id: {df["taxi_id"].iloc[0]}' |
| 62 | + |
| 63 | + |
| 64 | +@app.callback( |
| 65 | + Output("scatter-plot", "figure"), |
| 66 | + Input("sampled-dataset", "data"), |
| 67 | + Input("scale", "value"), |
| 68 | +) |
| 69 | +def update_scatter(sampled_df, scale): |
| 70 | + df = pd.read_json(sampled_df, orient="split") |
| 71 | + scale = scale == "Log" |
| 72 | + fig = px.scatter(df, x="total_amount", y="tip_amount", log_x=scale, log_y=scale) |
| 73 | + fig.update_layout(transition_duration=500) |
| 74 | + return fig |
| 75 | + |
| 76 | + |
| 77 | +@app.callback(Output("histogram", "figure"), Input("sampled-dataset", "data")) |
| 78 | +def update_histogram(sampled_df): |
| 79 | + df = pd.read_json(sampled_df, orient="split") |
| 80 | + fig = px.histogram(df, x="total_amount") |
| 81 | + fig.update_layout(transition_duration=500) |
| 82 | + return fig |
| 83 | + |
| 84 | + |
| 85 | +if __name__ == "__main__": |
| 86 | + app.run_server(debug=True) |
0 commit comments