Skip to content

Commit

Permalink
refactor: feature loading functions
Browse files Browse the repository at this point in the history
  • Loading branch information
sgalkina committed Oct 6, 2023
1 parent fe4e1df commit aaaa2a7
Show file tree
Hide file tree
Showing 10 changed files with 192 additions and 121 deletions.
175 changes: 99 additions & 76 deletions vamb/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -1079,42 +1079,44 @@ def parse_mmseqs_taxonomy(
if list(df_mmseq[0]) != list(contignames):
raise AssertionError(f'The contig names of taxonomy entries are not the same as in the contigs metadata')

species_column = df_mmseq[df_mmseq[2].isin(["species", "subspecies"])][8].str.split(';').str[6]
species_dict = species_column.value_counts()
unique_species = sorted(
list(species_column.unique()), key=lambda x: species_dict[x], reverse=True
)
log(
f"Found {len(unique_species)} unique species in mmseqs taxonomy file",
logfile,
1,
)
if len(unique_species) > n_species:
log(
f"Pruning the taxonomy tree, only keeping {n_species} most abundant species",
logfile,
1,
)
log(
f"Removing the species with less than {species_dict[unique_species[n_species]]} contigs",
logfile,
1,
)
non_abundant_species = set(unique_species[n_species:])
df_mmseq["tax"] = df_mmseq[8]
df_mmseq.loc[df_mmseq[3].isin(non_abundant_species), "tax"] = (
df_mmseq.loc[df_mmseq[3].isin(non_abundant_species), 8]
.str.split(";")
.str[:1]
.map(lambda x: ";".join(x))
)
graph_column = df_mmseq["tax"]
# species_column = df_mmseq[df_mmseq[2].isin(["species", "subspecies"])][8].str.split(';').str[6]
# species_dict = species_column.value_counts()
# unique_species = sorted(
# list(species_column.unique()), key=lambda x: species_dict[x], reverse=True
# )
# log(
# f"Found {len(unique_species)} unique species in mmseqs taxonomy file",
# logfile,
# 1,
# )
# if len(unique_species) > n_species:
# log(
# f"Pruning the taxonomy tree, only keeping {n_species} most abundant species",
# logfile,
# 1,
# )
# log(
# f"Removing the species with less than {species_dict[unique_species[n_species]]} contigs",
# logfile,
# 1,
# )
# non_abundant_species = set(unique_species[n_species:])
# df_mmseq["tax"] = df_mmseq[8]
# df_mmseq.loc[df_mmseq[3].isin(non_abundant_species), "tax"] = (
# df_mmseq.loc[df_mmseq[3].isin(non_abundant_species), 8]
# .str.split(";")
# .str[:1]
# .map(lambda x: ";".join(x))
# )
# graph_column = df_mmseq["tax"]
return graph_column


def predict_taxonomy(
composition: vamb.parsecontigs.Composition, # metadata already masked
abundance: vamb.parsebam.Abundance,
rpkms: np.array,
tnfs: np.array,
lengths: np.array,
contignames: np.array,
taxonomy_path: Path,
n_species: int,
out_dir: Path,
Expand All @@ -1123,12 +1125,6 @@ def predict_taxonomy(
logfile: IO[str],
):
begintime = time.time()
tnfs, lengths = composition.matrix, composition.metadata.lengths
contignames = composition.metadata.identifiers
if hasattr(abundance, 'matrix'):
rpkms = abundance.matrix
else:
rpkms = abundance

graph_column = parse_mmseqs_taxonomy(
taxonomy_path=taxonomy_path,
Expand All @@ -1148,6 +1144,7 @@ def predict_taxonomy(
len(nodes),
nodes,
table_parent,
nhiddens=[512, 512, 512, 512],
hier_loss=predictor_training_options.ploss,
cuda=cuda,
)
Expand All @@ -1174,9 +1171,6 @@ def predict_taxonomy(
log(f"Number of sequences unsuitable for encoding: {n_discarded}", logfile, 1)
log(f"Number of sequences remaining: {len(mask_vamb) - n_discarded}", logfile, 1)

names = composition.metadata.identifiers # already masked
lengths_masked = lengths

predictortime = time.time()
log(
f"Starting training the taxonomy predictor at {str(datetime.datetime.now())}",
Expand All @@ -1202,7 +1196,7 @@ def predict_taxonomy(
predicted_vector, predicted_labels = model.predict(dataloader_vamb)

log("Writing the taxonomy predictions", logfile, 0)
df_gt = pd.DataFrame({"contigs": names, "lengths": lengths_masked})
df_gt = pd.DataFrame({"contigs": contignames, "lengths": lengths})
nodes_ar = np.array(nodes)

log(f"Using threshold {predictor_training_options.softmax_threshold}", logfile, 0)
Expand Down Expand Up @@ -1236,6 +1230,47 @@ def predict_taxonomy(
)


def extract_and_filter_data(
vamb_options: VambOptions,
comp_options: CompositionOptions,
abundance_options: AbundanceOptions,
logfile: IO[str],
):
composition, abundance = load_composition_and_abundance(
vamb_options=vamb_options,
comp_options=comp_options,
abundance_options=abundance_options,
logfile=logfile,
)
tnfs, lengths = composition.matrix, composition.metadata.lengths
if hasattr(abundance, 'matrix'):
rpkms = abundance.matrix
else:
rpkms = abundance

_, mask_vamb = vamb.encode.make_dataloader(
rpkms,
tnfs,
lengths,
)

log(f"{len(composition.metadata.identifiers)} contig names", logfile, 0)

composition.metadata.filter_mask(mask_vamb)

log(
f"{len(composition.metadata.identifiers)} contig names after filtering",
logfile,
0,
)
return (
rpkms[mask_vamb],
tnfs[mask_vamb],
composition.metadata.lengths,
composition.metadata.identifiers,
)


def run_taxonomy_predictor(
vamb_options: VambOptions,
comp_options: CompositionOptions,
Expand All @@ -1244,15 +1279,17 @@ def run_taxonomy_predictor(
taxonomy_options: TaxonomyOptions,
logfile: IO[str],
):
composition, abundance = load_composition_and_abundance(
rpkms, tnfs, lengths, contignames = extract_and_filter_data(
vamb_options=vamb_options,
comp_options=comp_options,
abundance_options=abundance_options,
logfile=logfile,
)
predict_taxonomy(
composition=composition,
abundance=abundance,
rpkms=rpkms,
tnfs=tnfs,
lengths=lengths,
contignames=contignames,
taxonomy_path=taxonomy_options.taxonomy_path,
n_species=taxonomy_options.n_species,
out_dir=vamb_options.out_dir,
Expand All @@ -1274,43 +1311,22 @@ def run_vaevae(
logfile: IO[str],
):
vae_options = encoder_options.vae_options
composition, abundance = load_composition_and_abundance(
rpkms, tnfs, lengths, contignames = extract_and_filter_data(
vamb_options=vamb_options,
comp_options=comp_options,
abundance_options=abundance_options,
logfile=logfile,
)
tnfs, lengths = composition.matrix, composition.metadata.lengths
if hasattr(abundance, 'matrix'):
rpkms = abundance.matrix
else:
rpkms = abundance

dataloader_vamb, mask_vamb = vamb.encode.make_dataloader(
rpkms,
tnfs,
lengths,
batchsize=vae_training_options.batchsize,
cuda=vamb_options.cuda,
)

log(f"{len(composition.metadata.identifiers)} contig names", logfile, 0)

composition.metadata.filter_mask(mask_vamb)

log(
f"{len(composition.metadata.identifiers)} contig names after filtering",
logfile,
0,
)

if (
taxonomy_options.taxonomy_path is not None and not taxonomy_options.no_predictor
):
log("Predicting missing values from mmseqs taxonomy", logfile, 0)
predict_taxonomy(
composition=composition,
abundance=abundance,
rpkms=rpkms,
tnfs=tnfs,
lengths=lengths,
contignames=contignames,
taxonomy_path=taxonomy_options.taxonomy_path,
n_species=taxonomy_options.n_species,
out_dir=vamb_options.out_dir,
Expand All @@ -1333,7 +1349,7 @@ def run_vaevae(
log("Using mmseqs taxonomy for semisupervised learning", logfile, 0)
graph_column = parse_mmseqs_taxonomy(
taxonomy_path=taxonomy_options.taxonomy_path,
contignames=composition.metadata.identifiers,
contignames=contignames,
n_species=taxonomy_options.n_species,
logfile=logfile,
)
Expand All @@ -1359,6 +1375,13 @@ def run_vaevae(
logfile=logfile,
)

dataloader_vamb, _ = vamb.encode.make_dataloader(
rpkms,
tnfs,
lengths,
batchsize=vae_training_options.batchsize,
cuda=vamb_options.cuda,
)
dataloader_joint, _ = vamb.h_loss.make_dataloader_concat_hloss(
rpkms,
tnfs,
Expand Down Expand Up @@ -1426,8 +1449,8 @@ def run_vaevae(
cluster_options,
clusterspath,
latent_both,
composition.metadata.identifiers,
composition.metadata.lengths,
contignames,
lengths,
vamb_options,
logfile,
"vaevae_",
Expand All @@ -1444,8 +1467,8 @@ def run_vaevae(
vamb_options.out_dir,
clusterspath,
path,
composition.metadata.identifiers,
composition.metadata.lengths,
contignames,
lengths,
vamb_options.min_fasta_output_size,
logfile,
separator=cluster_options.binsplit_separator,
Expand Down
3 changes: 1 addition & 2 deletions workflow_vaevae/src/longread_human_no_predictor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

vamb \
--model vaevae \
--outdir /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vaevae_flat_softmax \
--outdir /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vaevae_flat_softmax__fix2 \
--fasta /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/contigs_2kbp.fna \
--rpkm /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vambout/abundance.npz \
--taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/human_longread_taxonomy_metabuli_otu.tsv \
Expand All @@ -17,7 +17,6 @@ vamb \
-t 1024 \
-q \
-o C \
--cuda \
--minfasta 200000

# vamb \
Expand Down
33 changes: 33 additions & 0 deletions workflow_vaevae/src/longread_human_no_predictor_debug.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
#!/usr/bin/bash
# --taxonomy_predictions /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vaevaeout/results_taxonomy_predictor.csv \
# --taxonomy /home/projects/cpr_10006/people/svekut/mmseq2/longread_taxonomy_2023.tsv \

# --cuda \


vamb \
--model vaevae \
--outdir /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vaevae_flat_softmax_debug_mmseqs_predictor \
--fasta /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/contigs_2kbp.fna \
--rpkm /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vambout/abundance.npz \
--taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/human_longread_taxonomy_full.tsv \
-l 64 \
-e 500 \
-t 1024 \
-q \
-o C \
--cuda \
--minfasta 200000

# vamb \
# --model reclustering \
# --latent_path /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/raw_dbscan_full/vaevae_latent.npy \
# --clusters_path /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/raw_dbscan_full/vaevae_clusters.tsv \
# --fasta /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/contigs_2kbp.fna \
# --rpkm /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vambout/abundance.npz \
# --outdir /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/raw_dbscan_reclustering \
# --hmmout_path /home/projects/cpr_10006/projects/semi_vamb/data/marker_genes/markers_human.hmmout \
# --taxonomy_predictions /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vaevaeout/results_taxonomy_predictor.csv \
# --algorithm dbscan \
# --minfasta 200000

8 changes: 5 additions & 3 deletions workflow_vaevae/src/longread_human_predictor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ run_id=$1

vamb \
--model taxonomy_predictor \
--outdir /home/projects/cpr_10006/people/svekut/long_read_human_kfold_predictor_lengths_${run_id} \
--outdir /home/projects/cpr_10006/people/svekut/long_read_human_kfold_predictor_flat_softmax_${run_id} \
--fasta /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/contigs_2kbp.fna \
--rpkm /home/projects/cpr_10006/projects/semi_vamb/data/human_longread/vambout/abundance.npz \
--taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/long_read_human_taxonomy_${run_id}.tsv \
-pe 100 \
-pq 25 75 \
--cuda
-pq \
-pt 1024 \
--cuda \
-ploss flat_softmax
8 changes: 5 additions & 3 deletions workflow_vaevae/src/longread_sludge_predictor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ run_id=$1

vamb \
--model taxonomy_predictor \
--outdir /home/projects/cpr_10006/people/svekut/long_read_sludge_kfold_predictor_lengths_${run_id} \
--outdir /home/projects/cpr_10006/people/svekut/long_read_sludge_kfold_predictor_flat_softmax_${run_id} \
--fasta /home/projects/cpr_10006/projects/semi_vamb/data/sludge/contigs_2kbp.fna \
--rpkm /home/projects/cpr_10006/projects/semi_vamb/data/sludge/vambout/abundance.npz \
--taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/long_read_sludge_taxonomy_${run_id}.tsv \
-pe 100 \
-pq 25 75 \
--cuda
-pq \
-pt 1024 \
--cuda \
-ploss flat_softmax
13 changes: 7 additions & 6 deletions workflow_vaevae/src/shortread_CAMI2_predictor.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,16 @@ run_id=$2
keyword=$3

# --taxonomy /home/projects/cpr_10006/people/svekut/mmseq2/${dataset}_taxonomy.tsv \

# --taxonomy /home/projects/cpr_10006/people/svekut/mmseq2/${dataset}_taxonomy_${run_id}.tsv \
#
vamb \
--model taxonomy_predictor \
--outdir /home/projects/cpr_10006/people/svekut/cami2_${dataset}_predictor_${keyword}_${run_id}_gpu \
--outdir /home/projects/cpr_10006/people/svekut/cami2_${dataset}_predictor_${keyword}_${run_id}_layers \
--fasta /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/contigs_2kbp.fna.gz \
--rpkm /home/projects/cpr_10006/projects/vamb/data/datasets/cami2_${dataset}/abundance.npz \
--taxonomy /home/projects/cpr_10006/people/svekut/04_mmseq2/taxonomy_cami_kfold/${dataset}_taxonomy_${run_id}.tsv \
-pe 100 \
-pq 25 50 75 \
-pt 256 \
-ploss ${keyword} \
--cuda
-pq \
-pt 1024 \
--cuda \
-ploss ${keyword}
Loading

0 comments on commit aaaa2a7

Please sign in to comment.