
Commit

Add more type checking
jakobnissen committed Nov 7, 2023
1 parent 2d028ad commit 10653bc
Showing 6 changed files with 162 additions and 142 deletions.
76 changes: 36 additions & 40 deletions vamb/__main__.py
@@ -12,7 +12,7 @@
 import time
 import random
 from math import isfinite
-from typing import Optional, IO, Tuple
+from typing import Optional, IO, Tuple, Union
 from pathlib import Path
 from collections.abc import Sequence
 from collections import defaultdict
@@ -34,7 +34,7 @@
 sys.path.append(parentdir)


-def try_make_dir(name):
+def try_make_dir(name: Union[Path, str]):
     try:
         os.mkdir(name)
     except FileExistsError:
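The `Union[Path, str]` annotation matches both ways callers can invoke this helper, since `os.mkdir` accepts any `str` or `os.PathLike` argument. A minimal sketch of the idiom (the example calls are illustrative, not from the commit):

```python
import os
from pathlib import Path
from typing import Union

def try_make_dir(name: Union[Path, str]):
    try:
        os.mkdir(name)
    except FileExistsError:
        pass  # directory already exists, nothing to do

try_make_dir("results")        # a plain str type-checks
try_make_dir(Path("results"))  # and so does a Path
```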
@@ -607,7 +607,7 @@ def trainvae(
     begintime = time.time()
     log("\nCreating and training VAE", logfile)

-    nsamples = data_loader.dataset.tensors[0].shape[1]
+    nsamples = data_loader.dataset.tensors[0].shape[1]  # type:ignore
     vae = vamb.encode.VAE(
         nsamples,
         nhiddens=vae_options.nhiddens,
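The `# type:ignore` works around `DataLoader.dataset` being typed as the generic `Dataset`, which declares no `tensors` attribute; only `TensorDataset` has one. A hedged alternative sketch, assuming the loader is always built from a `TensorDataset`, narrows with `cast` instead of suppressing the error:

```python
from typing import cast
from torch.utils.data import DataLoader, TensorDataset

def nsamples_of(data_loader: DataLoader) -> int:
    # cast() records what we know at runtime: this particular loader
    # always wraps a TensorDataset, which does carry `tensors`.
    dataset = cast(TensorDataset, data_loader.dataset)
    return dataset.tensors[0].shape[1]
```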
@@ -654,7 +654,7 @@ def trainaae(
 ) -> tuple[np.ndarray, dict[str, set[str]]]:
     begintime = time.time()
     log("\nCreating and training AAE", logfile)
-    nsamples = data_loader.dataset.tensors[0].shape[1]
+    nsamples = data_loader.dataset.tensors[0].shape[1]  # type:ignore

     aae = vamb.aamb_encode.AAE(
         nsamples,
@@ -729,7 +729,7 @@ def cluster(

     cluster_generator = vamb.cluster.ClusterGenerator(
         latent,
-        lengths,
+        lengths,  # type:ignore
         windowsize=cluster_options.window_size,
         minsuccesses=cluster_options.min_successes,
         destroy=True,
@@ -874,12 +874,8 @@ def run(
         abundance_options=abundance_options,
         logfile=logfile,
     )
-    if hasattr(abundance, "matrix"):
-        rpkms = abundance.matrix
-    else:
-        rpkms = abundance
     data_loader = vamb.encode.make_dataloader(
-        rpkms,
+        abundance.matrix,
         composition.matrix,
         composition.metadata.lengths,
         256,  # dummy value - we change this before using the actual loader
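Dropping the `hasattr` branch means `abundance` is always treated as an object exposing a `.matrix` ndarray, so the value passed onward has one precise static type instead of a union the checker cannot verify. A toy sketch of the difference (the `Abundance` class here is a stand-in, not vamb's real one):

```python
import numpy as np

class Abundance:
    # Stand-in for vamb's Abundance type; assumes only a .matrix field.
    def __init__(self, matrix: np.ndarray) -> None:
        self.matrix = matrix

def rpkms_of(abundance: Abundance) -> np.ndarray:
    # Direct attribute access: the result is statically a plain ndarray.
    # The removed hasattr() branch instead produced "ndarray or Abundance",
    # which a type checker cannot safely hand to make_dataloader.
    return abundance.matrix
```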
@@ -915,7 +911,7 @@ def run(
         lrate=training_options.lrate,
         alpha=encoder_options.alpha,
         logfile=logfile,
-        contignames=composition.metadata.identifiers,
+        contignames=composition.metadata.identifiers,  # type:ignore
     )
     print("", file=logfile)

@@ -940,8 +936,8 @@ def run(
         cluster_options,
         clusterspath,
         latent,
-        comp_metadata.identifiers,
-        comp_metadata.lengths,
+        comp_metadata.identifiers,  # type:ignore
+        comp_metadata.lengths,  # type:ignore
         vamb_options,
         logfile,
         "vae_",
@@ -958,7 +954,7 @@ def run(
             vamb_options.out_dir,
             clusterspath,
             path,
-            comp_metadata.identifiers,
+            comp_metadata.identifiers,  # type:ignore
             comp_metadata.lengths,
             vamb_options.min_fasta_output_size,
             logfile,
@@ -979,8 +975,8 @@ def run(
         cluster_options,
         clusterspath,
         latent_z,
-        comp_metadata.identifiers,
-        comp_metadata.lengths,
+        comp_metadata.identifiers,  # type:ignore
+        comp_metadata.lengths,  # type:ignore
         vamb_options,
         logfile,
         "aae_z_",
@@ -998,7 +994,7 @@ def run(
             vamb_options.out_dir,
             clusterspath,
             path,
-            comp_metadata.identifiers,
+            comp_metadata.identifiers,  # type:ignore
             comp_metadata.lengths,
             vamb_options.min_fasta_output_size,
             logfile,
@@ -1036,7 +1032,7 @@ def run(
             vamb_options.out_dir,
             clusterspath,
             path,
-            comp_metadata.identifiers,
+            comp_metadata.identifiers,  # type:ignore
             comp_metadata.lengths,
             vamb_options.min_fasta_output_size,
             logfile,
@@ -1053,7 +1049,7 @@ def parse_mmseqs_taxonomy(
     taxonomy_path: Path,
     contignames: list[str],  # already masked
     logfile: IO[str],
-) -> Tuple[list[int], list[str]]:
+) -> pd.Series:
     df_mmseq = pd.read_csv(taxonomy_path, delimiter="\t", header=None)
     assert (
         len(df_mmseq.columns) >= 9
@@ -1066,7 +1062,7 @@ def parse_mmseqs_taxonomy(
     # TODO: rethink when we start working with other domains
     missing_contigs = set(contignames) - set(df_mmseq[0])
     for c in missing_contigs:
-        new_row = {i: np.nan for i in range(9)}
+        new_row: dict[int, Union[str, float]] = {i: np.nan for i in range(9)}
         new_row[0] = c
         new_row[2] = "domain"
         new_row[8] = "d_Bacteria"
@@ -1086,10 +1082,10 @@ def parse_mmseqs_taxonomy(


 def predict_taxonomy(
-    rpkms: np.array,
-    tnfs: np.array,
-    lengths: np.array,
-    contignames: np.array,
+    rpkms: np.ndarray,
+    tnfs: np.ndarray,
+    lengths: np.ndarray,
+    contignames: np.ndarray,
     taxonomy_path: Path,
     out_dir: Path,
     predictor_training_options: PredictorTrainingOptions,
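`np.array` is a factory function, not a class, so it is not a valid type annotation; `np.ndarray` names the actual array type. A quick illustration (the helper is hypothetical):

```python
import numpy as np

def total_length(lengths: np.ndarray) -> int:
    # The parameter is annotated with the ndarray *class*;
    # np.array() remains the usual way to construct one.
    return int(lengths.sum())

print(total_length(np.array([100, 250, 400])))  # 750
```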
@@ -1170,12 +1166,11 @@ def predict_taxonomy(
     row = 0
     for predicted_vector, predicted_labels in model.predict(dataloader_vamb):
         N = len(predicted_vector)
-        predictions = []
-        probs = []
-        labels = []
+        predictions: list[str] = []
+        probs: list[float] = []
         for i in range(predicted_vector.shape[0]):
             label = predicted_labels[i]
-            pred_labels = [label]
+            pred_labels: list[str] = [label]
             while table_parent[label] != -1:
                 pred_labels.append(table_parent[label])
                 label = table_parent[label]
@@ -1188,7 +1183,6 @@ def predict_taxonomy(
             absolute_probs = predicted_vector[i][threshold_mask]
             absolute_prob = ";".join(map(str, absolute_probs[1:]))
             probs.append(absolute_prob)
-            labels.append(pred_labels)
         df_gt = pd.DataFrame(
             {"contigs": contignames[row : row + N], "lengths": lengths[row : row + N]}
         )
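An empty list literal carries no element type, so annotating the binding pins down what may be appended; the never-read `labels` accumulator is deleted outright. A minimal sketch of the idiom:

```python
# Annotating at the binding site tells the checker what the list holds
# before any append has run.
predictions: list[str] = []
probs: list[float] = []

predictions.append("d_Bacteria;p_Proteobacteria")
probs.append(0.97)
```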
@@ -1214,22 +1208,18 @@ def extract_and_filter_data(
     comp_options: CompositionOptions,
     abundance_options: AbundanceOptions,
     logfile: IO[str],
-):
+) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
     composition, abundance = load_composition_and_abundance(
         vamb_options=vamb_options,
         comp_options=comp_options,
         abundance_options=abundance_options,
         logfile=logfile,
     )
-    if hasattr(abundance, "matrix"):
-        rpkms = abundance.matrix
-    else:
-        rpkms = abundance

     log(f"{len(composition.metadata.identifiers)} contig names", logfile, 0)

     return (
-        rpkms,
+        abundance.matrix,
         composition.matrix,
         composition.metadata.lengths,
         composition.metadata.identifiers,
@@ -1250,6 +1240,7 @@ def run_taxonomy_predictor(
         abundance_options=abundance_options,
         logfile=logfile,
     )
+    assert taxonomy_options.taxonomy_path is not None
    predict_taxonomy(
         rpkms=rpkms,
         tnfs=tnfs,
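`taxonomy_options.taxonomy_path` is presumably declared `Optional[Path]`, so passing it to code that expects a `Path` fails type checking; the `assert ... is not None` narrows the type for everything that follows. A sketch of the idiom (the function is hypothetical):

```python
from pathlib import Path
from typing import Optional

def first_line(taxonomy_path: Optional[Path]) -> str:
    # After this assert, the checker treats taxonomy_path as a plain Path,
    # so the method calls below are accepted.
    assert taxonomy_path is not None
    return taxonomy_path.read_text().splitlines()[0]
```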
@@ -1313,9 +1304,10 @@ def run_vaevae(
         graph_column = df_gt["predictions"]
     elif taxonomy_options.no_predictor:
         log("Using mmseqs taxonomy for semisupervised learning", logfile, 0)
+        assert taxonomy_options.taxonomy_path is not None
         graph_column = parse_mmseqs_taxonomy(
             taxonomy_path=taxonomy_options.taxonomy_path,
-            contignames=contignames,
+            contignames=contignames,  # type:ignore
             logfile=logfile,
         )
     else:
@@ -1326,6 +1318,7 @@ def run_vaevae(
     classes_order = np.array(list(graph_column.str.split(";").str[-1]))
     targets = np.array([ind_nodes[i] for i in classes_order])

+    assert vae_options is not None
     vae = vamb.h_loss.VAEVAEHLoss(
         rpkms.shape[1],
         len(nodes),
@@ -1630,7 +1623,7 @@ def init_encoder_and_training(self):
             lrate=self.args.lrate,
         )

-    def run_inner(self, logfile):
+    def run_inner(self, logfile: Optional[IO[str]]):
         run(
             vamb_options=self.vamb_options,
             comp_options=self.comp_options,
@@ -1657,7 +1650,7 @@ def __init__(self, args):
             lrate=self.args.lrate,
         )

-    def run_inner(self, logfile):
+    def run_inner(self, logfile: Optional[IO[str]]):
         run(
             vamb_options=self.vamb_options,
             comp_options=self.comp_options,
@@ -1696,7 +1689,7 @@ def __init__(self, args):
             ploss=self.args.ploss,
         )

-    def run_inner(self, logfile):
+    def run_inner(self, logfile: Optional[IO[str]]):
         run_vaevae(
             vamb_options=self.vamb_options,
             comp_options=self.comp_options,
@@ -1731,7 +1724,7 @@ def __init__(self, args):
             no_predictor=None,
         )

-    def run_inner(self, logfile):
+    def run_inner(self, logfile: Optional[IO[str]]):
         run_reclustering(
             vamb_options=self.vamb_options,
             comp_options=self.comp_options,
@@ -2320,6 +2313,9 @@ def main():
         runner = classes_map[args.model_subcommand](args)
     elif args.subcommand == RECLUSTER:
         runner = ReclusteringArguments(args)
+    else:
+        # There are no more subcommands
+        assert False
     runner.run()
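The trailing `else: assert False` makes the subcommand dispatch exhaustive: without it, the checker flags `runner` as possibly unbound when no branch matches. A self-contained sketch of the pattern (the names are illustrative, not vamb's):

```python
def build_greeting(kind: str) -> str:
    if kind == "short":
        greeting = "hi"
    elif kind == "long":
        greeting = "hello there"
    else:
        # Unreachable if callers validate `kind` first; the assert both
        # documents that and proves to the checker that `greeting` is bound.
        assert False
    return greeting
```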


4 changes: 3 additions & 1 deletion vamb/encode.py
@@ -381,7 +381,9 @@ def trainepoch(
         epoch_absseloss = 0.0

         if epoch in batchsteps:
-            data_loader = set_batchsize(data_loader, data_loader.batch_size * 2)
+            data_loader = set_batchsize(
+                data_loader, data_loader.batch_size * 2
+            )  # type:ignore

         for depths_in, tnf_in, abundance_in, weights in data_loader:
             depths_in.requires_grad = True
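`DataLoader.batch_size` is annotated `Optional[int]` in PyTorch (it is `None` when a custom batch sampler is used), so `batch_size * 2` trips the checker; the commit reformats the call and suppresses the error. A hedged alternative sketch that narrows instead of ignoring, assuming the loader is always built with a fixed batch size:

```python
from torch.utils.data import DataLoader

def doubled_batchsize(data_loader: DataLoader) -> int:
    # batch_size is Optional[int]; this loader is always constructed with
    # an explicit batch size, so narrow it before the arithmetic.
    assert data_loader.batch_size is not None
    return data_loader.batch_size * 2
```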