WIP: train [no ci]
Signed-off-by: Cameron Smith <[email protected]>
cameronraysmith committed Jan 27, 2024
1 parent 78456f8 commit deffc49
Showing 7 changed files with 211 additions and 90 deletions.
23 changes: 16 additions & 7 deletions src/pyrovelocity/_velocity_model.py
@@ -1,21 +1,30 @@
-from typing import Optional, Tuple, Union
+from typing import Optional
+from typing import Tuple
+from typing import Union
 
 import pyro
 import torch
 from beartype import beartype
-from jaxtyping import Float, jaxtyped
+from jaxtyping import Float
+from jaxtyping import jaxtyped
 from pyro import poutine
-from pyro.distributions import Bernoulli, LogNormal, Normal, Poisson
-from pyro.nn import PyroModule, PyroSample
+from pyro.distributions import Bernoulli
+from pyro.distributions import LogNormal
+from pyro.distributions import Normal
+from pyro.distributions import Poisson
+from pyro.nn import PyroModule
+from pyro.nn import PyroSample
 from pyro.primitives import plate
 from scvi.nn import Decoder
 
 # from torch.distributions import Bernoulli
-from torch.nn.functional import relu, softplus
+from torch.nn.functional import relu
+from torch.nn.functional import softplus
 
 from pyrovelocity.logging import configure_logging
 from pyrovelocity.utils import mRNA
 
 
 logger = configure_logging(__name__)
 
 RNAInputType = Union[
@@ -435,8 +444,8 @@ def get_rna(
     beta: RNAInputType,
     gamma: RNAInputType,
     t: Float[torch.Tensor, "num_cells time"],
-    u0: Float[torch.Tensor, ""],
-    s0: Float[torch.Tensor, ""],
+    u0: RNAInputType,
+    s0: RNAInputType,
     t0: RNAInputType,
     switching: Optional[RNAInputType] = None,
     u_inf: Optional[RNAInputType] = None,
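The broadened u0/s0 annotations rely on the jaxtyping and beartype imports above for runtime shape checking. A minimal sketch of that pattern, with a hypothetical function and axis name not taken from this commit:

    import torch
    from beartype import beartype
    from jaxtyping import Float, jaxtyped

    @jaxtyped(typechecker=beartype)
    def scale_rates(
        alpha: Float[torch.Tensor, "num_genes"],
        factor: float,
    ) -> Float[torch.Tensor, "num_genes"]:
        # Tensor shapes are validated at call time; a mismatch raises an error.
        return factor * alpha

    scale_rates(torch.ones(5), 2.0)  # passes: shape (5,) binds "num_genes"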
4 changes: 2 additions & 2 deletions src/pyrovelocity/api.py
@@ -26,7 +26,7 @@ def train_model(
     adata: str | AnnData,
     guide_type: str = "auto",
     model_type: str = "auto",
-    svi_train: bool = False,  # svi_train alreadys turn off
+    svi_train: bool = False,
     batch_size: int = -1,
     train_size: float = 1.0,
     use_gpu: int | bool = False,
@@ -41,7 +41,7 @@ def train_model(
     max_epochs: int = 3000,
     include_prior: bool = True,
     library_size: bool = True,
-    offset: bool = False,
+    offset: bool = True,
     input_type: str = "raw",
     cell_specific_kinetics: Optional[str] = None,
     kinetics_num: int = 2,
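For context, a hedged sketch of calling train_model with the new offset default; the argument names are taken from the hunks above, while the path and return handling are illustrative only:

    from pyrovelocity.api import train_model

    # Hypothetical invocation; the processed-data path is an assumption.
    result = train_model(
        adata="data/processed/simulated_processed.h5ad",
        guide_type="auto",
        model_type="auto",
        offset=True,  # new default in this commit (previously False)
        max_epochs=3000,
    )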
42 changes: 39 additions & 3 deletions src/pyrovelocity/data.py
@@ -28,6 +28,8 @@ def download_dataset(
     data_external_path: str = "data/external",
     source: str = "pyrovelocity",
     data_url: Optional[str] = None,
+    n_obs: Optional[int] = None,
+    n_vars: Optional[int] = None,
 ) -> Path:
     """
     Downloads a dataset based on the specified parameters and returns the path
@@ -40,6 +42,8 @@ def download_dataset(
         data_external_path (Path): Path where the downloaded data will be stored. Default is 'data/external'.
         source (str): The source type of the dataset. Default is 'pyrovelocity'.
         data_url (str): URL from where the dataset can be downloaded. Takes precedence over source.
+        n_obs (int): Number of observations to sample from the dataset. Defaults to None.
+        n_vars (int): Number of variables to sample from the dataset. Defaults to None.
 
     Returns:
         Path: The path to the downloaded dataset file.
@@ -50,11 +54,15 @@ def download_dataset(
     ... 'simulated',
     ... str(tmp) + '/data/external',
     ... 'simulate',
+    ... n_obs=100,
+    ... n_vars=300,
     ... ) # xdoctest: +SKIP
     >>> simulated_dataset = download_dataset(
     ... 'simulated_path',
     ... tmp / Path('data/external'),
     ... 'simulate',
+    ... n_obs=100,
+    ... n_vars=300,
     ... ) # xdoctest: +SKIP
     >>> pancreas_dataset = download_dataset(
     ... data_set_name='pancreas_direct',
@@ -130,11 +138,24 @@ def download_dataset(
             raise
         else:
             adata = download_method(file_path=data_path)
+            if n_obs is not None and n_vars is not None:
+                adata, _ = subset(adata=adata, n_obs=n_obs, n_vars=n_vars)
+                adata.write(data_path)
+            elif n_obs is not None:
+                adata, _ = subset(adata=adata, n_obs=n_obs)
+                adata.write(data_path)
+            elif n_vars is not None:
+                logger.warning("n_vars is ignored if n_obs is not provided")
+
     elif source == "simulate":
+        if n_obs is None or n_vars is None:
+            raise ValueError(
+                f"n_obs and n_vars must be provided if source is 'simulate'"
+            )
         logger.info(f"Generating {data_set_name} data from simulation...")
         adata = generate_sample_data(
-            n_obs=3000,
-            n_vars=1000,
+            n_obs=n_obs,
+            n_vars=n_vars,
             noise_model="gillespie",
             random_seed=99,
         )
@@ -146,7 +167,7 @@ def download_dataset(
f"Please specify a valid source or URL that resolves to a .h5ad file."
)

print_attributes(adata)
print_attributes(adata.copy())
print_anndata(adata)

if data_path.is_file() and os.access(str(data_path), os.R_OK):
@@ -226,6 +247,7 @@ def subset(
     file_path: Optional[str | Path] = None,
     adata: Optional[anndata._core.anndata.AnnData] = None,
     n_obs: int = 100,
+    n_vars: Optional[int] = None,
     save_subset: bool = False,
     output_path: Optional[str | Path] = None,
 ) -> Tuple[anndata._core.anndata.AnnData, str | Path | None]:
@@ -262,8 +284,21 @@ def subset(
f"n_obs ({n_obs}) is greater than the number of observations in the dataset ({adata.n_obs})"
)
n_obs = adata.n_obs
logger.info(f"constructing data subset")
print_anndata(adata)

if n_vars is not None:
if n_vars > adata.n_vars:
logger.warning(
f"n_vars ({n_vars}) is greater than the number of variables in the dataset ({adata.n_vars})"
)
n_vars = adata.n_vars
selected_vars_indices = np.random.choice(adata.n_vars, n_vars)
logger.info(f"selected {n_vars} vars from {adata.n_vars}")
adata = adata[:, selected_vars_indices]

selected_obs_indices = np.random.choice(adata.n_obs, n_obs)
logger.info(f"selected {n_obs} obs from {adata.n_obs}")
adata = adata[selected_obs_indices]
adata.obs_names_make_unique()

@@ -279,4 +314,5 @@ def subset(
         adata.write(output_path)
         logger.info(f"saved {n_obs} obs subset: {output_path}")
 
+    print_anndata(adata)
     return adata.copy(), output_path
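A hedged sketch of the extended subsetting API; the parameter names come from this diff, while the toy AnnData is purely illustrative:

    import anndata
    import numpy as np

    from pyrovelocity.data import subset

    # Toy AnnData for illustration only.
    adata = anndata.AnnData(
        X=np.random.poisson(1.0, size=(500, 400)).astype(np.float32)
    )
    # Randomly sample 100 observations and 300 variables (n_vars is new here).
    adata_small, _ = subset(adata=adata, n_obs=100, n_vars=300)
    print(adata_small.shape)  # expected: (100, 300)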
16 changes: 9 additions & 7 deletions src/pyrovelocity/preprocess.py
@@ -1,6 +1,8 @@
 import os
 from pathlib import Path
-from typing import List, Optional, Tuple
+from typing import List
+from typing import Optional
+from typing import Tuple
 
 import anndata
 import anndata._core.anndata
@@ -15,11 +17,10 @@
 import pyrovelocity.datasets
 from pyrovelocity.cytotrace import cytotrace_sparse
 from pyrovelocity.logging import configure_logging
-from pyrovelocity.utils import (
-    ensure_numpy_array,
-    print_anndata,
-    print_attributes,
-)
+from pyrovelocity.utils import ensure_numpy_array
+from pyrovelocity.utils import print_anndata
+from pyrovelocity.utils import print_attributes
 
 
 logger = configure_logging(__name__)
 
@@ -153,7 +154,8 @@ def preprocess_data(
     copy_raw_counts(adata)
     print_anndata(adata)
 
-    if process_cytotrace:
+    # if process_cytotrace:
+    if "pancreas" in data_set_name:
         print("Processing data with cytotrace ...")
         cytotrace_sparse(adata, layer="spliced")
 
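With this hunk, cytotrace preprocessing is keyed off the dataset name rather than the process_cytotrace flag. A small illustration of the guard, with assumed dataset names:

    # Hypothetical names; only the substring check mirrors the diff above.
    for data_set_name in ["pancreas", "pancreas_direct", "simulated"]:
        runs_cytotrace = "pancreas" in data_set_name
        print(data_set_name, "->", "cytotrace" if runs_cytotrace else "skipped")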
30 changes: 14 additions & 16 deletions src/pyrovelocity/train.py
@@ -21,6 +21,7 @@

 pyrovelocity_train_types_defaults: Dict[str, Tuple[Type, Any]] = {
     "adata": (str, "/data/processed/simulated_processed.h5ad"),
+    "use_gpu": (bool, False),
 }
 
 pyrovelocity_train_fields = create_dataclass_from_callable(
@@ -39,7 +40,7 @@
 @beartype
 def train_dataset(
     data_set_name: str = "simulated",
-    model_identifier: str = "model1",
+    model_identifier: str = "model2",
     pyrovelocity_train_model_args: Optional[PyroVelocityTrainInterface] = None,
     force: bool = False,
 ) -> Tuple[Path, Path, Path, Path, Path, Path]:
@@ -80,18 +81,20 @@ def train_dataset(
     data_model = f"{data_set_name}_{model_identifier}"
     model_dir = Path(f"models/{data_model}")
 
+    trained_data_path = model_dir / "trained.h5ad"
+    model_path = model_dir / "model"
+    posterior_samples_path = model_dir / "posterior_samples.pkl.zst"
+    metrics_path = model_dir / "metrics.json"
+    run_info_path = model_dir / "run_info.json"
+    loss_plot_path = model_dir / "ELBO.png"
+
     if pyrovelocity_train_model_args is None:
         processed_path = Path(f"data/processed/{data_set_name}_processed.h5ad")
         pyrovelocity_train_model_args = PyroVelocityTrainInterface(
             adata=str(processed_path)
         )
 
-    trained_data_path = model_dir / "trained.h5ad"
-    model_path = model_dir / "model"
-    posterior_samples_path = model_dir / "posterior_samples.pkl.zst"
-    pyrovelocity_data_path = model_dir / "pyrovelocity.pkl.zst"
-    metrics_path = model_dir / "metrics.json"
-    run_info_path = model_dir / "run_info.json"
+    pyrovelocity_train_model_args.loss_plot_path = str(loss_plot_path)
 
     logger.info(f"\n\nTraining: {data_model}\n\n")
 
@@ -117,24 +120,22 @@ def train_dataset(
     if (
         os.path.isfile(trained_data_path)
         and os.path.exists(model_path)
-        and os.path.isfile(pyrovelocity_data_path)
         and os.path.isfile(posterior_samples_path)
         and not force
     ):
         logger.info(
             f"\n{trained_data_path}\n"
             f"{model_path}\n"
-            f"{pyrovelocity_data_path}\n"
             f"{posterior_samples_path}\n"
             "all exist, set `force=True` to overwrite."
         )
         return (
             trained_data_path,
             model_path,
             posterior_samples_path,
-            pyrovelocity_data_path,
             metrics_path,
             run_info_path,
+            loss_plot_path,
         )
     else:
         logger.info(f"Training model: {data_model}")
Expand All @@ -158,10 +159,7 @@ def train_dataset(

         run_id = run.info.run_id
 
-        logger.info(
-            f"\nSaving pyrovelocity data: {pyrovelocity_data_path}\n"
-            f"Saving posterior samples: {posterior_samples_path}\n"
-        )
+        logger.info(f"Saving posterior samples: {posterior_samples_path}\n")
         CompressedPickle.save(
             posterior_samples_path,
             posterior_samples,
@@ -214,17 +212,17 @@ def check_shared_time(posterior_samples, adata):
f"{trained_data_path}\n"
f"{model_path}\n"
f"{posterior_samples_path}\n"
f"{pyrovelocity_data_path}\n"
f"{metrics_path}\n"
f"{run_info_path}\n"
f"{loss_plot_path}\n"
)
return (
trained_data_path,
model_path,
posterior_samples_path,
pyrovelocity_data_path,
metrics_path,
run_info_path,
loss_plot_path,
)


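The return tuple now carries loss_plot_path in place of the removed pyrovelocity_data_path. A sketch of unpacking it, using only names that appear in this diff (assumes the processed dataset already exists on disk):

    from pyrovelocity.train import train_dataset

    (
        trained_data_path,
        model_path,
        posterior_samples_path,
        metrics_path,
        run_info_path,
        loss_plot_path,
    ) = train_dataset(
        data_set_name="simulated",
        model_identifier="model2",  # new default in this commit
    )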
21 changes: 7 additions & 14 deletions src/pyrovelocity/workflows/cli/execution_utils.py
@@ -8,31 +8,22 @@
 import sys
 import threading
 import time
-from dataclasses import dataclass
+from dataclasses import asdict, dataclass, is_dataclass
 from datetime import timedelta
 from textwrap import dedent
-from typing import Any
-from typing import Dict
-from typing import List
-from typing import Tuple
-from typing import Union
+from typing import Any, Dict, List, Tuple, Union
 
 from dataclasses_json import dataclass_json
-from dulwich.repo import NotGitRepository
-from dulwich.repo import Repo
+from dulwich.repo import NotGitRepository, Repo
 from flytekit import WorkflowExecutionPhase
 from flytekit.core.base_task import PythonTask
 from flytekit.core.workflow import WorkflowBase
 from flytekit.exceptions.system import FlyteSystemException
 from flytekit.exceptions.user import FlyteTimeout
 from flytekit.remote import FlyteRemote
 from flytekit.remote.executions import FlyteWorkflowExecution
-from hydra.conf import HelpConf
-from hydra.conf import HydraConf
-from hydra.conf import JobConf
-from hydra_zen import ZenStore
-from hydra_zen import builds
-from hydra_zen import make_custom_builds_fn
+from hydra.conf import HelpConf, HydraConf, JobConf
+from hydra_zen import ZenStore, builds, make_custom_builds_fn
 
 
 @dataclass_json
@@ -146,6 +137,8 @@ def generate_entity_inputs(
         # check if the type is a built-in type
         if isinstance(param_type, type) and param_type.__module__ == "builtins":
             inputs[name] = default
+        elif is_dataclass(default):
+            inputs[name] = fbuilds(param_type, **asdict(default))
         else:
             # dynamically import the type if it's not a built-in type
             type_module = importlib.import_module(param_type.__module__)
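The new elif branch converts dataclass defaults into hydra-zen configs. A minimal, self-contained sketch of that pattern, assuming fbuilds is built with make_custom_builds_fn as elsewhere in this module; the TrainConfig dataclass is hypothetical:

    from dataclasses import asdict, dataclass, is_dataclass

    from hydra_zen import instantiate, make_custom_builds_fn

    fbuilds = make_custom_builds_fn(populate_full_signature=True)

    @dataclass
    class TrainConfig:  # hypothetical example dataclass
        max_epochs: int = 3000
        use_gpu: bool = False

    default = TrainConfig()
    if is_dataclass(default):
        # Build a structured config whose fields mirror the dataclass instance.
        conf = fbuilds(TrainConfig, **asdict(default))
        print(instantiate(conf))  # TrainConfig(max_epochs=3000, use_gpu=False)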
