Skip to content

Commit

Permalink
Added kind='scatter' functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
imperorrp committed Sep 3, 2024
1 parent 700f1f1 commit 89370cd
Showing 1 changed file with 109 additions and 49 deletions.
158 changes: 109 additions & 49 deletions src/arviz_plots/plots/ppcplot.py
Original file line number Diff line number Diff line change
Expand Up @@ -295,6 +295,10 @@ def plot_ppc(
)
pc_kwargs["aes"] = pc_kwargs.get("aes", {}).copy()
pc_kwargs["aes"].setdefault("overlay", sample_dims) # setting overlay dim
pc_kwargs.setdefault("y", np.linspace(0.01, 0.1, 11))
pc_kwargs["aes"].setdefault(
"y", sample_dims
) # setting y aesthetic for sample_dims('ppc_dim' or other) in case of kind="scatter"
plot_collection = PlotCollection.wrap(
pp_distribution,
backend=backend,
Expand All @@ -309,7 +313,10 @@ def plot_ppc(
aes_map = {}
else:
aes_map = aes_map.copy()
aes_map.setdefault("predictive", plot_collection.aes_set)
if kind == "scatter":
aes_map.setdefault("predictive", plot_collection.aes_set)
else:
aes_map.setdefault("predictive", plot_collection.aes_set.difference("y"))
# setting aggregate aes_map to `[]` so `overlay` isn't applied for it
aes_map.setdefault("aggregate", [])
if labeller is None:
Expand Down Expand Up @@ -341,39 +348,60 @@ def plot_ppc(
# print(f"\nreduce_dims = {reduce_dims!r}")

if pp_kwargs is not False:
_, pp_density_aes, _ = filter_aes(plot_collection, aes_map, "predictive", reduce_dims)
_, pp_density_aes, pp_ignore = filter_aes(
plot_collection, aes_map, "predictive", reduce_dims
)
# print(f"\npp_density_aes = {pp_density_aes}\npp_density_ignore= {pp_density_ignore}")

# getting first default color from color cycle and picking it
pp_default_color = plot_bknd.get_default_aes("color", 1, {})[0]
if "color" not in pp_density_aes:
pp_kwargs.setdefault("color", pp_default_color)

if "alpha" not in pp_density_aes:
pp_kwargs.setdefault("alpha", 0.2)
if kind == "scatter":
# plot scatter plot for PPCs
# print(f"\n pp_dt = {pp_dt!r}")

# passing plot_kwargs["predictive"] to plot_kwargs_dist (unlike tracedistplot there are
# multiple artists generated via plot_dist in this plot- plot_ppc)
plot_kwargs_dist[kind] = pp_kwargs
if "marker" not in pp_density_aes:
pp_kwargs.setdefault("marker", ".")
if "size" not in pp_density_aes:
pp_kwargs.setdefault("size", 30)

# calling plot_dist with plot_collection and customized args
plot_collection.map(
trace_rug,
"predictive",
data=pp_distribution,
ignore_aes=pp_ignore,
xname=False,
**pp_kwargs,
)

plot_dist(
pp_dt,
group="pp_dt",
sample_dims=reduce_dims,
kind=kind,
plot_collection=plot_collection,
labeller=labeller,
aes_map={
kind: aes_map.get("predictive", {})
}, # aes_map[kind] is set to "predictive" aes_map
plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "predictive" plot_kwargs
stats_kwargs={
"density": stats_kwargs.get("predictive", {})
}, # stats_kwargs["density"] is set to "predictive" stats_kwargs
# via plot_dist
)
else: # continue with plotting through `plot_dist``
if "alpha" not in pp_density_aes:
pp_kwargs.setdefault("alpha", 0.2)

# passing plot_kwargs["predictive"] to plot_kwargs_dist (unlike tracedistplot there are
# multiple artists generated via plot_dist in this plot- plot_ppc)
plot_kwargs_dist[kind] = pp_kwargs

# calling plot_dist with plot_collection and customized args

plot_dist(
pp_dt,
group="pp_dt",
sample_dims=reduce_dims,
kind=kind,
plot_collection=plot_collection,
labeller=labeller,
aes_map={
kind: aes_map.get("predictive", {})
}, # aes_map[kind] is set to "predictive" aes_map
plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] set to "predictive" plot_kwargs
stats_kwargs={
"density": stats_kwargs.get("predictive", {})
}, # stats_kwargs["density"] is set to "predictive" stats_kwargs
# via plot_dist
)

# ---------STEP 2 (PPC AGGREGATE)-------------

Expand All @@ -400,18 +428,25 @@ def plot_ppc(
aggregate_reduce_dims = reduce_dims + list(sample_dims)
# print(f"\n aggregate_reduce_dims = {aggregate_reduce_dims}")

if kind == "scatter":
aggregate_kind = "kde"
# setting so that aggregate is still plotted as a kde in plot_dist

else:
aggregate_kind = kind

# passing plot_kwargs["aggregate"] to plot_kwargs_dist
plot_kwargs_dist[kind] = aggregate_kwargs
plot_kwargs_dist[aggregate_kind] = aggregate_kwargs

plot_dist(
pp_dt,
group="pp_dt",
sample_dims=aggregate_reduce_dims,
kind=kind,
kind=aggregate_kind,
plot_collection=plot_collection,
labeller=labeller,
aes_map={
kind: aes_map.get("aggregate", {})
aggregate_kind: aes_map.get("aggregate", {})
}, # aes_map[kind] is set to "aggregate" aes_map
plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "aggregate" plot_kwargs
stats_kwargs={
Expand All @@ -432,33 +467,54 @@ def plot_ppc(
)
obs_kwargs = copy(plot_kwargs.get("observed", {}))

_, obs_density_aes, _ = filter_aes(plot_collection, aes_map, "observed", reduce_dims)
_, obs_density_aes, obs_density_ignore = filter_aes(
plot_collection, aes_map, "observed", reduce_dims
)
# print(f"\nobs_density_dims = {obs_density_dims}\nobs_density_aes = {obs_density_aes}")

if "color" not in obs_density_aes:
obs_kwargs.setdefault("color", "black")

# print(f"\nobs_kwargs = {obs_kwargs}")
if kind == "scatter":
if "marker" not in obs_density_aes:
obs_kwargs.setdefault("marker", ".")
if "size" not in obs_density_aes:
obs_kwargs.setdefault("size", 30)

# passing plot_kwargs["observed"] to plot_kwargs_dist
plot_kwargs_dist[kind] = obs_kwargs
# print(f"\nobs_distribution = {obs_distribution}")

obs_dt = DataTree(name="obs_dt", data=obs_distribution)
plot_dist(
obs_dt,
group="obs_dt",
sample_dims=reduce_dims,
kind=kind,
plot_collection=plot_collection,
labeller=labeller,
aes_map={
kind: aes_map.get("observed", {})
}, # aes_map[kind] is set to "observed" aes_map
plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "observed" plot_kwargs
stats_kwargs={
"density": stats_kwargs.get("observed", {})
}, # stats_kwargs["density"] is set to "observed" stats_kwargs
)
plot_collection.map(
trace_rug,
"observed",
data=obs_distribution,
ignore_aes=obs_density_ignore,
xname=False,
y=0,
**obs_kwargs,
)

else:
# print(f"\nobs_kwargs = {obs_kwargs}")

# passing plot_kwargs["observed"] to plot_kwargs_dist
plot_kwargs_dist[kind] = obs_kwargs

obs_dt = DataTree(name="obs_dt", data=obs_distribution)
plot_dist(
obs_dt,
group="obs_dt",
sample_dims=reduce_dims,
kind=kind,
plot_collection=plot_collection,
labeller=labeller,
aes_map={
kind: aes_map.get("observed", {})
}, # aes_map[kind] is set to "observed" aes_map
plot_kwargs=plot_kwargs_dist, # plot_kwargs[kind] is set to "observed" plot_kwargs
stats_kwargs={
"density": stats_kwargs.get("observed", {})
}, # stats_kwargs["density"] is set to "observed" stats_kwargs
)

# ---------STEP 4 (observed rug plot)-----------
if observed_rug:
Expand All @@ -470,8 +526,12 @@ def plot_ppc(
)
if "color" not in rug_aes:
rug_kwargs.setdefault("color", "black")
if "marker" not in rug_aes:
rug_kwargs.setdefault("marker", "|")
if kind == "scatter":
if "marker" not in rug_aes:
rug_kwargs.setdefault("marker", ".")
else:
if "marker" not in rug_aes:
rug_kwargs.setdefault("marker", "|")
if "size" not in rug_aes:
rug_kwargs.setdefault("size", 30)

Expand Down

0 comments on commit 89370cd

Please sign in to comment.