diff --git a/src/pyrovelocity/plots/_lineage_fate_correlation.py b/src/pyrovelocity/plots/_lineage_fate_correlation.py index 695316705..2d0e6b654 100644 --- a/src/pyrovelocity/plots/_lineage_fate_correlation.py +++ b/src/pyrovelocity/plots/_lineage_fate_correlation.py @@ -23,6 +23,7 @@ get_posterior_sample_angle_uncertainty, ) from pyrovelocity.plots._vector_fields import plot_vector_field_uncertainty +from pyrovelocity.utils import load_anndata_from_path __all__ = ["plot_lineage_fate_correlation"] @@ -31,23 +32,81 @@ @beartype def plot_lineage_fate_correlation( - posterior_samples_path: str | Path, - adata_pyrovelocity: str | Path, - adata_scvelo: str | Path, - adata_cospar: AnnData, - ax: Axes, + posterior_samples_path: str | Path | AnnData, + adata_pyrovelocity: str | Path | AnnData, + adata_scvelo: str | Path | AnnData, + adata_cospar: str | Path | AnnData, + ax: Axes | np.ndarray, fig: Figure, - state_color_dict: Dict, + # state_color_dict: Dict, ylabel: str = "Unipotent Monocyte lineage", dotsize: int = 3, scale: float = 0.35, arrow: float = 3.5, ): + """ + Plot lineage fate correlation with shared latent time estimates. + + Args: + posterior_samples_path (str | Path): Path to the posterior samples. + adata_pyrovelocity (str | Path): Path to the Pyro-Velocity AnnData object. + adata_scvelo (str | Path): Path to the scVelo AnnData object. + adata_cospar (AnnData): AnnData object with COSPAR results. + ax (Axes): Matplotlib axes. + fig (Figure): Matplotlib figure. + state_color_dict (Dict): Dictionary with cell state colors. + ylabel (str, optional): Label for y axis. Defaults to "Unipotent Monocyte lineage". + dotsize (int, optional): Size of plotted points. Defaults to 3. + scale (float, optional): Plot scale. Defaults to 0.35. + arrow (float, optional): Arrow size. Defaults to 3.5. + + Examples: + >>> # xdoctest: +SKIP + >>> import matplotlib.pyplot as plt + >>> import scanpy as sc + >>> from pyrovelocity.io.datasets import larry_cospar, larry_mono + >>> from pyrovelocity.utils import load_anndata_from_path + >>> from pyrovelocity.plots import plot_lineage_fate_correlation + ... + >>> fig, ax = plt.subplots(1, 9) + >>> fig.set_size_inches(17, 2.75) + >>> fig.subplots_adjust( + ... hspace=0.4, wspace=0.2, left=0.01, right=0.99, top=0.95, bottom=0.3 + >>> ) + ... + >>> data_set_name = "larry_mono" + >>> model_name = "model2" + >>> data_set_model_pairing = f"{data_set_name}_{model_name}" + >>> model_path = f"models/{data_set_model_pairing}" + ... + >>> adata_pyrovelocity = load_anndata_from_path(f"{model_path}/postprocessed.h5ad") + >>> # color_dict = dict( + ... # zip( + ... # adata_pyrovelocity.obs.state_info.cat.categories, + ... # adata_pyrovelocity.uns["state_info_colors"], + ... # ) + ... # ) + >>> adata_dynamical = load_anndata_from_path(f"data/processed/larry_mono_processed.h5ad") + >>> adata_cospar = load_anndata_from_path(f"data/external/larry_cospar.h5ad") + >>> plot_lineage_fate_correlation( + ... posterior_samples_path=f"{model_path}/pyrovelocity.pkl.zst", + ... adata_pyrovelocity=adata_pyrovelocity, + ... adata_scvelo=adata_dynamical, + ... adata_cospar=adata_cospar, + ... ax=ax, + ... fig=fig, + ... ) + """ posterior_samples = CompressedPickle.load(posterior_samples_path) embed_mean = posterior_samples["vector_field_posterior_mean"] - adata_scvelo = scv.read(adata_scvelo) - adata_pyrovelocity = scv.read(adata_pyrovelocity) + if isinstance(adata_pyrovelocity, str | Path): + adata_pyrovelocity = load_anndata_from_path(adata_pyrovelocity) + if isinstance(adata_scvelo, str | Path): + adata_scvelo = load_anndata_from_path(adata_scvelo) + if isinstance(adata_cospar, str | Path): + adata_cospar = load_anndata_from_path(adata_cospar) + adata_input_clone = get_clone_trajectory(adata_scvelo) adata_input_clone.obsm["clone_vector_emb"][ np.isnan(adata_input_clone.obsm["clone_vector_emb"]) @@ -85,7 +144,7 @@ def plot_lineage_fate_correlation( x="X", y="Y", hue="celltype", - palette=state_color_dict, + # palette=state_color_dict, ax=ax[0], s=dotsize, alpha=0.90, @@ -211,21 +270,31 @@ def plot_lineage_fate_correlation( "Pyro-Velocity cosine similarity: %.2f" % pyro_cos_mean, fontsize=7 ) + # The obs names in adata_pyrovelocity have a "-N" suffix that + # is not present in the adata_cospar obs Index. + patched_adata_pyrovelocity_obs_names = ( + adata_pyrovelocity.obs_names.str.replace( + r"-\d", + "", + regex=True, + ) + ) + adata_cospar_obs_subset = adata_cospar[ + patched_adata_pyrovelocity_obs_names, : + ] scv.pl.scatter( - adata_cospar[adata_pyrovelocity.obs_names.str.replace(r"-\d", ""), :], + adata=adata_cospar_obs_subset, basis="emb", fontsize=7, - color="fate_potency", + color="fate_potency_transition_map", cmap="inferno_r", show=False, ax=ax[6], s=dotsize, ) ax[6].set_title("Clonal fate potency", fontsize=7) - gold = adata_cospar[ - adata_pyrovelocity.obs_names.str.replace(r"-\d", ""), : - ].obs.fate_potency - select = ~np.isnan(gold) + gold_standard = adata_cospar_obs_subset.obs.fate_potency_transition_map + select = ~np.isnan(gold_standard) scv.pl.scatter( adata_scvelo, c="latent_time", @@ -238,9 +307,9 @@ def plot_lineage_fate_correlation( ) ax[7].set_title( "Scvelo latent time\ncorrelation: %.2f" - % spearmanr(-gold[select], adata_scvelo.obs.latent_time.values[select])[ - 0 - ], + % spearmanr( + -gold_standard[select], adata_scvelo.obs.latent_time.values[select] + )[0], fontsize=7, ) plot_posterior_time( @@ -255,8 +324,17 @@ def plot_lineage_fate_correlation( ax[8].set_title( "Pyro-Velocity shared time\ncorrelation: %.2f" % spearmanr( - -gold[select], + -gold_standard[select], posterior_samples["cell_time"].mean(0).flatten()[select], )[0], fontsize=7, ) + + for ext in ["", ".png"]: + fig.savefig( + f"lineage_fate_correlation.pdf{ext}", + facecolor=fig.get_facecolor(), + bbox_inches="tight", + edgecolor="none", + dpi=300, + )