Skip to content

Commit

Permalink
fix(plots): use pandas regex str replacement in lineage fate correlation
Browse files Browse the repository at this point in the history
Signed-off-by: Cameron Smith <[email protected]>
  • Loading branch information
cameronraysmith committed Aug 25, 2024
1 parent 7f438b3 commit 34cfd15
Showing 1 changed file with 97 additions and 19 deletions.
116 changes: 97 additions & 19 deletions src/pyrovelocity/plots/_lineage_fate_correlation.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
get_posterior_sample_angle_uncertainty,
)
from pyrovelocity.plots._vector_fields import plot_vector_field_uncertainty
from pyrovelocity.utils import load_anndata_from_path

__all__ = ["plot_lineage_fate_correlation"]

Expand All @@ -31,23 +32,81 @@

@beartype
def plot_lineage_fate_correlation(
posterior_samples_path: str | Path,
adata_pyrovelocity: str | Path,
adata_scvelo: str | Path,
adata_cospar: AnnData,
ax: Axes,
posterior_samples_path: str | Path | AnnData,
adata_pyrovelocity: str | Path | AnnData,
adata_scvelo: str | Path | AnnData,
adata_cospar: str | Path | AnnData,
ax: Axes | np.ndarray,
fig: Figure,
state_color_dict: Dict,
# state_color_dict: Dict,
ylabel: str = "Unipotent Monocyte lineage",
dotsize: int = 3,
scale: float = 0.35,
arrow: float = 3.5,
):
"""
Plot lineage fate correlation with shared latent time estimates.
Args:
posterior_samples_path (str | Path): Path to the posterior samples.
adata_pyrovelocity (str | Path): Path to the Pyro-Velocity AnnData object.
adata_scvelo (str | Path): Path to the scVelo AnnData object.
adata_cospar (AnnData): AnnData object with COSPAR results.
ax (Axes): Matplotlib axes.
fig (Figure): Matplotlib figure.
state_color_dict (Dict): Dictionary with cell state colors.
ylabel (str, optional): Label for y axis. Defaults to "Unipotent Monocyte lineage".
dotsize (int, optional): Size of plotted points. Defaults to 3.
scale (float, optional): Plot scale. Defaults to 0.35.
arrow (float, optional): Arrow size. Defaults to 3.5.
Examples:
>>> # xdoctest: +SKIP
>>> import matplotlib.pyplot as plt
>>> import scanpy as sc
>>> from pyrovelocity.io.datasets import larry_cospar, larry_mono
>>> from pyrovelocity.utils import load_anndata_from_path
>>> from pyrovelocity.plots import plot_lineage_fate_correlation
...
>>> fig, ax = plt.subplots(1, 9)
>>> fig.set_size_inches(17, 2.75)
>>> fig.subplots_adjust(
... hspace=0.4, wspace=0.2, left=0.01, right=0.99, top=0.95, bottom=0.3
>>> )
...
>>> data_set_name = "larry_mono"
>>> model_name = "model2"
>>> data_set_model_pairing = f"{data_set_name}_{model_name}"
>>> model_path = f"models/{data_set_model_pairing}"
...
>>> adata_pyrovelocity = load_anndata_from_path(f"{model_path}/postprocessed.h5ad")
>>> # color_dict = dict(
... # zip(
... # adata_pyrovelocity.obs.state_info.cat.categories,
... # adata_pyrovelocity.uns["state_info_colors"],
... # )
... # )
>>> adata_dynamical = load_anndata_from_path(f"data/processed/larry_mono_processed.h5ad")
>>> adata_cospar = load_anndata_from_path(f"data/external/larry_cospar.h5ad")
>>> plot_lineage_fate_correlation(
... posterior_samples_path=f"{model_path}/pyrovelocity.pkl.zst",
... adata_pyrovelocity=adata_pyrovelocity,
... adata_scvelo=adata_dynamical,
... adata_cospar=adata_cospar,
... ax=ax,
... fig=fig,
... )
"""
posterior_samples = CompressedPickle.load(posterior_samples_path)
embed_mean = posterior_samples["vector_field_posterior_mean"]

adata_scvelo = scv.read(adata_scvelo)
adata_pyrovelocity = scv.read(adata_pyrovelocity)
if isinstance(adata_pyrovelocity, str | Path):
adata_pyrovelocity = load_anndata_from_path(adata_pyrovelocity)
if isinstance(adata_scvelo, str | Path):
adata_scvelo = load_anndata_from_path(adata_scvelo)
if isinstance(adata_cospar, str | Path):
adata_cospar = load_anndata_from_path(adata_cospar)

adata_input_clone = get_clone_trajectory(adata_scvelo)
adata_input_clone.obsm["clone_vector_emb"][
np.isnan(adata_input_clone.obsm["clone_vector_emb"])
Expand Down Expand Up @@ -85,7 +144,7 @@ def plot_lineage_fate_correlation(
x="X",
y="Y",
hue="celltype",
palette=state_color_dict,
# palette=state_color_dict,
ax=ax[0],
s=dotsize,
alpha=0.90,
Expand Down Expand Up @@ -211,21 +270,31 @@ def plot_lineage_fate_correlation(
"Pyro-Velocity cosine similarity: %.2f" % pyro_cos_mean, fontsize=7
)

# The obs names in adata_pyrovelocity have a "-N" suffix that
# is not present in the adata_cospar obs Index.
patched_adata_pyrovelocity_obs_names = (
adata_pyrovelocity.obs_names.str.replace(
r"-\d",
"",
regex=True,
)
)
adata_cospar_obs_subset = adata_cospar[
patched_adata_pyrovelocity_obs_names, :
]
scv.pl.scatter(
adata_cospar[adata_pyrovelocity.obs_names.str.replace(r"-\d", ""), :],
adata=adata_cospar_obs_subset,
basis="emb",
fontsize=7,
color="fate_potency",
color="fate_potency_transition_map",
cmap="inferno_r",
show=False,
ax=ax[6],
s=dotsize,
)
ax[6].set_title("Clonal fate potency", fontsize=7)
gold = adata_cospar[
adata_pyrovelocity.obs_names.str.replace(r"-\d", ""), :
].obs.fate_potency
select = ~np.isnan(gold)
gold_standard = adata_cospar_obs_subset.obs.fate_potency_transition_map
select = ~np.isnan(gold_standard)
scv.pl.scatter(
adata_scvelo,
c="latent_time",
Expand All @@ -238,9 +307,9 @@ def plot_lineage_fate_correlation(
)
ax[7].set_title(
"Scvelo latent time\ncorrelation: %.2f"
% spearmanr(-gold[select], adata_scvelo.obs.latent_time.values[select])[
0
],
% spearmanr(
-gold_standard[select], adata_scvelo.obs.latent_time.values[select]
)[0],
fontsize=7,
)
plot_posterior_time(
Expand All @@ -255,8 +324,17 @@ def plot_lineage_fate_correlation(
ax[8].set_title(
"Pyro-Velocity shared time\ncorrelation: %.2f"
% spearmanr(
-gold[select],
-gold_standard[select],
posterior_samples["cell_time"].mean(0).flatten()[select],
)[0],
fontsize=7,
)

for ext in ["", ".png"]:
fig.savefig(
f"lineage_fate_correlation.pdf{ext}",
facecolor=fig.get_facecolor(),
bbox_inches="tight",
edgecolor="none",
dpi=300,
)

0 comments on commit 34cfd15

Please sign in to comment.