Skip to content

Commit

Permalink
bug fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
bfclarke committed Feb 5, 2025
1 parent 254e341 commit ea86b96
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 11 deletions.
3 changes: 1 addition & 2 deletions deeprvat/deeprvat/associate.py
Original file line number Diff line number Diff line change
Expand Up @@ -867,10 +867,9 @@ def compute_burdens_(
annotation_columns=config.get("annotations", None),
) # TODO: Use AnnGenoDataModule, stage="associate"

logger.info("Caching genotypes to memory")

# TODO: Decide whether to do this.
# Current implementation has a bug; also, does it help with overall execution time?
# logger.info("Caching genotypes to memory")
# ds.anngeno.cache_genotypes()

logger.info("Loading models")
Expand Down
15 changes: 6 additions & 9 deletions deeprvat/deeprvat/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ def cli():

def train_(
config: Dict,
training_regions: Dict[int, np.ndarray],
training_regions: Dict[str, np.ndarray],
log_dir: str,
sample_set: Optional[Set[str]] = None,
checkpoint_file: Optional[str] = None,
Expand Down Expand Up @@ -157,17 +157,17 @@ def train_(

# load datamodule
covariates = config["covariates"]
phenotypes = config["training"]["phenotypes"]
if isinstance(phenotypes, dict):
phenotypes = list(phenotypes.keys())
# phenotypes = config["training"]["phenotypes"]
# if isinstance(phenotypes, dict):
# phenotypes = list(phenotypes.keys())
annotation_columns = config.get("annotations", None)
dm = AnnGenoDataModule(
anngeno_filename,
training_regions=training_regions,
train_proportion=train_proportion,
sample_set=sample_set,
covariates=covariates,
phenotypes=phenotypes,
# phenotypes=phenotypes,
annotation_columns=annotation_columns,
**config["training"][
"dataloader_config"
Expand All @@ -182,7 +182,7 @@ def train_(
n_annotations=len(annotation_columns),
n_covariates=len(covariates),
n_genes=sum([rs.shape[0] for rs in training_regions.values()]),
n_phenotypes=len(phenotypes),
n_phenotypes=len(training_regions),
gene_covariatephenotype_mask=dm.gene_covariatephenotype_mask,
**config["model"].get("kwargs", {}),
)
Expand Down Expand Up @@ -217,9 +217,6 @@ def train_(

while True:
try:
import ipdb

ipdb.set_trace()
trainer.fit(model, dm)
except RuntimeError as e:
            # if batch_size is chosen too big, it will be reduced until it fits the GPU
Expand Down

0 comments on commit ea86b96

Please sign in to comment.