Skip to content

Commit

Permalink
Update example plots
Browse files Browse the repository at this point in the history
  • Loading branch information
reidjohnson committed Sep 14, 2024
1 parent f54d55c commit 4b73d73
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 29 deletions.
2 changes: 2 additions & 0 deletions docs/sphinx_requirements.txt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
sphinx
altair
geopandas
numpydoc
pillow
pydata_sphinx_theme
Expand All @@ -8,4 +9,5 @@ skops
sphinxcontrib.bibtex
sphinx-design
sphinxext-altair
vega-datasets
vl-convert-python
77 changes: 57 additions & 20 deletions examples/plot_qrf_huggingface_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
import tempfile

import altair as alt
import geopandas as gpd
import numpy as np
import pandas as pd
from sklearn import datasets
from skops import hub_utils
from vega_datasets import data

import quantile_forest
from quantile_forest import RandomForestQuantileRegressor
Expand Down Expand Up @@ -166,16 +168,17 @@ def fit_and_upload_model(token, repo_id, local_dir="./local_repo", random_state=
X, y = datasets.fetch_california_housing(as_frame=True, return_X_y=True)
y_pred = qrf.predict(X, quantiles=quantiles) * 100_000 # predict in dollars


df = (
pd.DataFrame(y_pred, columns=quantiles)
.reset_index()
.sample(frac=sample_frac, random_state=random_state)
.melt(id_vars=["index"], var_name="quantile", value_name="value")
.rename(columns={q: f"q_{q:.3g}" for q in quantiles})
.merge(X[["Latitude", "Longitude", "Population"]].reset_index(), on="index", how="right")
)


def plot_quantiles_by_latlon(df, quantiles, color_scheme="cividis"):
def plot_quantiles_by_latlon(df, quantiles, color_scheme="lightgreyred"):
"""Plot quantile predictions on California Housing dataset by lat/lon."""
# Slider for varying the displayed quantile estimates.
slider = alt.binding_range(
Expand All @@ -187,33 +190,67 @@ def plot_quantiles_by_latlon(df, quantiles, color_scheme="cividis"):

quantile_val = alt.param(name="quantile", value=0.5, bind=slider)

# Load the US counties data and filter to California counties.
ca_counties = (
gpd.read_file(data.us_10m.url, layer="counties")
.set_crs("EPSG:4326")
.assign(**{"county_fips": lambda x: x["id"].astype(int)})
.drop(columns=["id"])
.query("(county_fips >= 6000) & (county_fips < 7000)")
)

x_min = df[[f"q_{q:.3g}" for q in quantiles]].min().min()
x_max = df[[f"q_{q:.3g}" for q in quantiles]].max().max()

df = (
gpd.GeoDataFrame(
df, geometry=gpd.points_from_xy(df["Longitude"], df["Latitude"]), crs="4326"
)
.sjoin(ca_counties, how="right")
.drop(columns=["index_left0"])
.assign(
**{f"w_q_{q:.3g}": lambda x, q=q: x[f"q_{q:.3g}"] * x["Population"] for q in quantiles}
)
)

grouped = (
df.groupby("county_fips")
.agg({**{f"w_q_{q:.3g}": "sum" for q in quantiles}, **{"Population": "sum"}})
.reset_index()
.assign(
**{f"q_{q:.3g}": lambda x, q=q: x[f"w_q_{q:.3g}"] / x["Population"] for q in quantiles}
)
)

df = (
df[["county_fips", "Latitude", "Longitude", "geometry"]]
.drop_duplicates(subset=["county_fips"])
.merge(
grouped[["county_fips", "Population"] + [f"q_{q:.3g}" for q in quantiles]],
on="county_fips",
how="left",
)
)

chart = (
alt.Chart(df)
.add_params(quantile_val)
.transform_filter("datum.quantile == quantile")
.mark_circle()
.transform_calculate(quantile_col="'q_' + quantile")
.transform_calculate(value=f"datum[datum.quantile_col]")
.mark_geoshape(stroke="black", strokeWidth=0)
.encode(
x=alt.X(
"Longitude:Q",
axis=alt.Axis(tickMinStep=1, format=".1f"),
scale=alt.Scale(zero=False),
title="Longitude",
),
y=alt.Y(
"Latitude:Q",
axis=alt.Axis(tickMinStep=1, format=".1f"),
scale=alt.Scale(zero=False),
title="Latitude",
color=alt.Color(
"value:Q",
scale=alt.Scale(domain=[x_min, x_max], scheme=color_scheme),
title="Prediction",
),
color=alt.Color("value:Q", scale=alt.Scale(scheme=color_scheme), title="Prediction"),
size=alt.Size("Population:Q"),
tooltip=[
alt.Tooltip("index:N", title="Row ID"),
alt.Tooltip("Latitude:Q", format=".2f", title="Latitude"),
alt.Tooltip("Longitude:Q", format=".2f", title="Longitude"),
alt.Tooltip("county_fips:N", title="County FIPS"),
alt.Tooltip("Population:N", format=",.0f", title="Population"),
alt.Tooltip("value:Q", format="$,.0f", title="Predicted Value"),
],
)
.project(type="mercator")
.properties(
title="Quantile Predictions on the California Housing Dataset",
height=650,
Expand Down
4 changes: 2 additions & 2 deletions examples/plot_qrf_interpolation_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,8 @@
"method": ["Actual"] * len(y),
"X": [f"Sample {idx + 1} ({x})" for idx, x in enumerate(X.tolist())],
"y_pred": y.tolist(),
"y_pred_low": y.tolist(),
"y_pred_upp": y.tolist(),
"y_pred_low": [None] * len(y),
"y_pred_upp": [None] * len(y),
"quantile_low": [None] * len(y),
"quantile_upp": [None] * len(y),
}
Expand Down
13 changes: 6 additions & 7 deletions examples/plot_qrf_multitarget.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
quantiles = np.linspace(0, 1, num=41, endpoint=True).round(3).tolist()

# Define functions that generate targets; each function maps to one target.
funcs = [
target_funcs = [
{
"signal": lambda x: np.log1p(x + 1),
"noise": lambda x: np.log1p(x) * random_state.uniform(size=len(x)),
Expand All @@ -36,18 +36,19 @@
},
]

legend = {k: v for f in funcs for k, v in f["legend"].items()}


def make_funcs_Xy(funcs, n_samples, bounds):
"""Make a dataset from specified function(s) with signal and noise."""
"""Make a dataset from specified function(s)."""
x = np.linspace(*bounds, n_samples)
y = np.empty((len(x), len(funcs)))
for i, func in enumerate(funcs):
y[:, i] = func["signal"](x) + func["noise"](x)
y[:, i] = func(x)
return np.atleast_2d(x).T, y


funcs = [lambda x, f=f: f["signal"](x) + f["noise"](x) for f in target_funcs]
legend = {k: v for f in target_funcs for k, v in f["legend"].items()}

# Create a dataset with multiple target variables.
X, y = make_funcs_Xy(funcs, n_samples, bounds)

Expand All @@ -63,7 +64,6 @@ def make_funcs_Xy(funcs, n_samples, bounds):
{
"x": np.tile(X.squeeze(), len(funcs)),
"y": y.reshape(-1, order="F"),
"y_true": np.concatenate([f["signal"](X.squeeze()) for f in funcs]),
"y_pred": np.concatenate([y_pred[:, i, len(quantiles) // 2] for i in range(len(funcs))]),
"target": np.concatenate([[str(i)] * len(X) for i in range(len(funcs))]),
**{f"q_{q_i:.3g}": y_i.ravel() for q_i, y_i in zip(quantiles, y_pred.T)},
Expand Down Expand Up @@ -95,7 +95,6 @@ def plot_multitargets(df, legend):
alt.Tooltip("target:N", title="Target"),
alt.Tooltip("x:Q", format=",.3f", title="X"),
alt.Tooltip("y:Q", format=",.3f", title="Y"),
alt.Tooltip("y_true:Q", format=",.3f", title="Y"),
alt.Tooltip("y_pred:Q", format=",.3f", title="Predicted Y"),
alt.Tooltip("y_pred_low:Q", format=",.3f", title="Predicted Lower Y"),
alt.Tooltip("y_pred_upp:Q", format=",.3f", title="Predicted Upper Y"),
Expand Down

0 comments on commit 4b73d73

Please sign in to comment.