
Commit

test: VAEVAE
sgalkina committed Oct 24, 2023
1 parent 4eda306 commit 77803ca
Showing 8 changed files with 314 additions and 234 deletions.
179 changes: 90 additions & 89 deletions vamb/__main__.py
@@ -238,7 +238,13 @@ def __init__(


class ReclusteringOptions:
__slots__ = ["latent_path", "clusters_path", "hmmout_path", "binsplit_separator", "algorithm"]
__slots__ = [
"latent_path",
"clusters_path",
"hmmout_path",
"binsplit_separator",
"algorithm",
]

def __init__(
self,
@@ -309,18 +315,22 @@ def __init__(self, nepochs: int, batchsize: int, batchsteps: list[int]):

class PredictorTrainingOptions(VAETrainingOptions):
def __init__(
self,
nepochs: int,
batchsize: int,
batchsteps: list[int],
softmax_threshold: float,
ploss: str,
):
losses = ['flat_softmax', 'cond_softmax', 'soft_margin', 'random_cut']
self,
nepochs: int,
batchsize: int,
batchsteps: list[int],
softmax_threshold: float,
ploss: str,
):
losses = ["flat_softmax", "cond_softmax", "soft_margin", "random_cut"]
if (softmax_threshold > 1) or (softmax_threshold < 0):
raise argparse.ArgumentTypeError(f"Softmax threshold should be between 0 and 1, currently {softmax_threshold}")
raise argparse.ArgumentTypeError(
f"Softmax threshold should be between 0 and 1, currently {softmax_threshold}"
)
if ploss not in losses:
raise argparse.ArgumentTypeError(f"Predictor loss needs to be one of {losses}")
raise argparse.ArgumentTypeError(
f"Predictor loss needs to be one of {losses}"
)
self.softmax_threshold = softmax_threshold
self.ploss = ploss
super(PredictorTrainingOptions, self).__init__(
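For illustration only (not part of the commit), a minimal standalone sketch of the argument validation added above, with hypothetical values:

import argparse

VALID_LOSSES = ["flat_softmax", "cond_softmax", "soft_margin", "random_cut"]

def check_predictor_args(softmax_threshold: float, ploss: str) -> None:
    # The threshold is a probability cutoff, so it must lie in [0, 1]
    if softmax_threshold > 1 or softmax_threshold < 0:
        raise argparse.ArgumentTypeError(
            f"Softmax threshold should be between 0 and 1, currently {softmax_threshold}"
        )
    # The hierarchical loss must be one of the supported variants
    if ploss not in VALID_LOSSES:
        raise argparse.ArgumentTypeError(f"Predictor loss needs to be one of {VALID_LOSSES}")

check_predictor_args(0.5, "flat_softmax")    # passes silently
# check_predictor_args(1.5, "flat_softmax")  # would raise ArgumentTypeError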
@@ -558,7 +568,7 @@ def calc_rpkm(
# I don't want this check in any constructors of abundance, since the constructors
# should be able to skip this check in case comp and abundance are independent.
# But when running the main Vamb workflow, we need to assert this.
if hasattr(abundance, 'nseqs') and abundance.nseqs != comp_metadata.nseqs:
if hasattr(abundance, "nseqs") and abundance.nseqs != comp_metadata.nseqs:
assert not abundance_options.refcheck
raise ValueError(
f"Loaded abundance has {abundance.nseqs} sequences, "
@@ -729,7 +739,7 @@ def cluster(
destroy=True,
normalized=False,
# cuda=vamb_options.cuda,
cuda=False, # disabled until clustering is fixed
cuda=False, # disabled until clustering is fixed
rng_seed=vamb_options.seed,
)

@@ -868,9 +878,12 @@ def run(
abundance_options=abundance_options,
logfile=logfile,
)

if hasattr(abundance, "matrix"):
rpkms = abundance.matrix
else:
rpkms = abundance
data_loader, mask = vamb.encode.make_dataloader(
abundance.matrix,
rpkms,
composition.matrix,
composition.metadata.lengths,
256, # dummy value - we change this before using the actual loader
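As a standalone sketch (not part of the commit), the matrix-or-array fallback introduced above behaves like this, with a hypothetical minimal Abundance stand-in:

import numpy as np

class Abundance:  # hypothetical minimal stand-in
    def __init__(self, matrix):
        self.matrix = matrix

def as_rpkm_matrix(abundance):
    # Accept either an Abundance object carrying a .matrix, or a bare numpy array
    return abundance.matrix if hasattr(abundance, "matrix") else abundance

print(as_rpkm_matrix(Abundance(np.zeros((3, 2)))).shape)  # (3, 2)
print(as_rpkm_matrix(np.zeros((3, 2))).shape)             # (3, 2)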
@@ -1050,7 +1063,7 @@ def run(

def parse_mmseqs_taxonomy(
taxonomy_path: Path,
contignames: list[str], # already masked
contignames: list[str], # already masked
n_species: int,
logfile: IO[str],
) -> Tuple[list[int], list[str]]:
@@ -1068,8 +1081,8 @@ def parse_mmseqs_taxonomy(
for c in missing_contigs:
new_row = {i: np.nan for i in range(9)}
new_row[0] = c
new_row[2] = 'domain'
new_row[8] = 'd_Bacteria'
new_row[2] = "domain"
new_row[8] = "d_Bacteria"
df_mmseq = pd.concat([df_mmseq, pd.DataFrame([new_row])], ignore_index=True)

df_mmseq.set_index(0, inplace=True)
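A small illustrative pandas sketch of the padding step above, with made-up contig names and lineage strings:

import numpy as np
import pandas as pd

# Hypothetical mmseqs-style table: column 0 = contig name, 2 = rank, 8 = lineage
df_mmseq = pd.DataFrame({0: ["c1"], 2: ["species"], 8: ["d_Bacteria;p_X;c_Y;o_Z;f_F;g_G;s_S"]})
for c in ["c2", "c3"]:  # contigs absent from the mmseqs output
    # Pad unclassified contigs with a domain-level Bacteria row
    new_row = {i: np.nan for i in range(9)}
    new_row[0] = c
    new_row[2] = "domain"
    new_row[8] = "d_Bacteria"
    df_mmseq = pd.concat([df_mmseq, pd.DataFrame([new_row])], ignore_index=True)
print(df_mmseq[[0, 2, 8]])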
@@ -1078,46 +1091,18 @@ def parse_mmseqs_taxonomy(
graph_column = df_mmseq[8]

if list(df_mmseq[0]) != list(contignames):
raise AssertionError(f'The contig names of taxonomy entries are not the same as in the contigs metadata')
raise AssertionError(
f"The contig names of taxonomy entries are not the same as in the contigs metadata"
)

# species_column = df_mmseq[df_mmseq[2].isin(["species", "subspecies"])][8].str.split(';').str[6]
# species_dict = species_column.value_counts()
# unique_species = sorted(
# list(species_column.unique()), key=lambda x: species_dict[x], reverse=True
# )
# log(
# f"Found {len(unique_species)} unique species in mmseqs taxonomy file",
# logfile,
# 1,
# )
# if len(unique_species) > n_species:
# log(
# f"Pruning the taxonomy tree, only keeping {n_species} most abundant species",
# logfile,
# 1,
# )
# log(
# f"Removing the species with less than {species_dict[unique_species[n_species]]} contigs",
# logfile,
# 1,
# )
# non_abundant_species = set(unique_species[n_species:])
# df_mmseq["tax"] = df_mmseq[8]
# df_mmseq.loc[df_mmseq[3].isin(non_abundant_species), "tax"] = (
# df_mmseq.loc[df_mmseq[3].isin(non_abundant_species), 8]
# .str.split(";")
# .str[:1]
# .map(lambda x: ";".join(x))
# )
# graph_column = df_mmseq["tax"]
return graph_column


def predict_taxonomy(
rpkms: np.array,
tnfs: np.array,
lengths: np.array,
contignames: np.array,
rpkms: np.array,
tnfs: np.array,
lengths: np.array,
contignames: np.array,
taxonomy_path: Path,
n_species: int,
out_dir: Path,
@@ -1194,35 +1179,46 @@ def predict_taxonomy(
logfile,
0,
)
predicted_vector, predicted_labels = model.predict(dataloader_vamb)

log("Writing the taxonomy predictions", logfile, 0)
df_gt = pd.DataFrame({"contigs": contignames, "lengths": lengths})
predicted_path = out_dir.joinpath("results_taxonomy_predictor.csv")
nodes_ar = np.array(nodes)

log(f"Using threshold {predictor_training_options.softmax_threshold}", logfile, 0)
predictions = []
probs = []
labels = []
for i in range(len(df_gt)):
label = predicted_labels[i]
pred_labels = [label]
while table_parent[label] != -1:
pred_labels.append(table_parent[label])
label = table_parent[label]
pred_labels = ";".join([nodes_ar[l] for l in pred_labels][::-1])
threshold_mask = predicted_vector[i] > predictor_training_options.softmax_threshold
pred_line = ";".join(nodes_ar[threshold_mask][1:])
predictions.append(pred_line)
absolute_probs = predicted_vector[i][threshold_mask]
absolute_prob = ';'.join(map(str, absolute_probs))
probs.append(absolute_prob)
labels.append(pred_labels)
df_gt["predictions"] = predictions
df_gt["abs_probabilities"] = probs
df_gt["predictions_labels"] = labels
predicted_path = out_dir.joinpath("results_taxonomy_predictor.csv")
df_gt.to_csv(predicted_path, index=None)
row = 0
for predicted_vector, predicted_labels in model.predict(dataloader_vamb):
N = len(predicted_vector)
predictions = []
probs = []
labels = []
for i in range(predicted_vector.shape[0]):
label = predicted_labels[i]
pred_labels = [label]
while table_parent[label] != -1:
pred_labels.append(table_parent[label])
label = table_parent[label]
pred_labels = ";".join([nodes_ar[l] for l in pred_labels][::-1])
threshold_mask = (
predicted_vector[i] > predictor_training_options.softmax_threshold
)
pred_line = ";".join(nodes_ar[threshold_mask][1:])
predictions.append(pred_line)
absolute_probs = predicted_vector[i][threshold_mask]
absolute_prob = ";".join(map(str, absolute_probs[1:]))
probs.append(absolute_prob)
labels.append(pred_labels)
df_gt = pd.DataFrame(
{"contigs": contignames[row : row + N], "lengths": lengths[row : row + N]}
)
df_gt["predictions"] = predictions
df_gt["probabilities"] = probs
df_gt.to_csv(
predicted_path,
mode="a",
index=None,
header=not os.path.exists(predicted_path),
)
row += N
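As a standalone illustration of the lineage reconstruction and thresholding in the loop above, here is a minimal sketch with a hypothetical three-node taxonomy:

import numpy as np

# Hypothetical flattened taxonomy: index of each node's parent, -1 marks the root
nodes_ar = np.array(["root", "d_Bacteria", "s_Example"])
table_parent = [-1, 0, 1]
softmax_threshold = 0.5

predicted_vector = np.array([0.99, 0.95, 0.80])  # per-node probabilities for one contig
predicted_label = 2                              # index of the deepest predicted node

# Walk the parent table from the predicted node up to the root
label = predicted_label
pred_labels = [label]
while table_parent[label] != -1:
    pred_labels.append(table_parent[label])
    label = table_parent[label]
lineage = ";".join([nodes_ar[l] for l in pred_labels][::-1])

# Keep only nodes whose probability clears the threshold, dropping the root
threshold_mask = predicted_vector > softmax_threshold
pred_line = ";".join(nodes_ar[threshold_mask][1:])

print(lineage)    # root;d_Bacteria;s_Example
print(pred_line)  # d_Bacteria;s_Example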

log(
f"\nCompleted taxonomy predictions in {round(time.time() - begintime, 2)} seconds.",
@@ -1244,7 +1240,7 @@ def extract_and_filter_data(
logfile=logfile,
)
tnfs, lengths = composition.matrix, composition.metadata.lengths
if hasattr(abundance, 'matrix'):
if hasattr(abundance, "matrix"):
rpkms = abundance.matrix
else:
rpkms = abundance
@@ -1265,8 +1261,8 @@ def extract_and_filter_data(
0,
)
return (
rpkms[mask_vamb],
tnfs[mask_vamb],
rpkms[mask_vamb],
tnfs[mask_vamb],
composition.metadata.lengths,
composition.metadata.identifiers,
)
@@ -1319,9 +1315,7 @@ def run_vaevae(
logfile=logfile,
)

if (
taxonomy_options.taxonomy_path is not None and not taxonomy_options.no_predictor
):
if taxonomy_options.taxonomy_path is not None and not taxonomy_options.no_predictor:
log("Predicting missing values from mmseqs taxonomy", logfile, 0)
predict_taxonomy(
rpkms=rpkms,
@@ -1335,7 +1329,9 @@ def run_vaevae(
cuda=vamb_options.cuda,
logfile=logfile,
)
predictions_path = vamb_options.out_dir.joinpath("results_taxonomy_predictor.csv")
predictions_path = vamb_options.out_dir.joinpath(
"results_taxonomy_predictor.csv"
)
elif taxonomy_options.taxonomy_predictions_path is not None:
log("mmseqs taxonomy predictions are provided", logfile, 0)
predictions_path = taxonomy_options.taxonomy_predictions_path
@@ -1345,6 +1341,9 @@ def run_vaevae(

if predictions_path is not None:
df_gt = pd.read_csv(predictions_path)
df_gt.loc[
df_gt["predictions"].isna(), "predictions"
] = "d_Bacteria" # TODO: it's a hack
graph_column = df_gt["predictions"]
elif taxonomy_options.no_predictor:
log("Using mmseqs taxonomy for semisupervised learning", logfile, 0)
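The NaN fill-in marked as a hack above can be reproduced in isolation on a made-up frame:

import pandas as pd

df_gt = pd.DataFrame({"contigs": ["c1", "c2"], "predictions": ["d_Bacteria;p_X", None]})
# Contigs with no prediction fall back to the bacterial domain
df_gt.loc[df_gt["predictions"].isna(), "predictions"] = "d_Bacteria"
print(df_gt)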
@@ -1494,7 +1493,7 @@ def run_reclustering(
logfile=logfile,
)
tnfs, lengths = composition.matrix, composition.metadata.lengths
if hasattr(abundance, 'matrix'):
if hasattr(abundance, "matrix"):
rpkms = abundance.matrix
else:
rpkms = abundance
@@ -1543,7 +1542,7 @@ def run_reclustering(
clustersfile,
maybe_split,
rename=False,
cluster_prefix='recluster',
cluster_prefix="recluster",
)

print("", file=logfile)
@@ -1840,8 +1839,8 @@ def main():
dest="ploss",
metavar="",
type=str,
default="cond_softmax",
help='Hierarchical loss (one of flat_softmax, cond_softmax, soft_margin, random_cut) ["cond_softmax"]',
default="flat_softmax",
help='Hierarchical loss (one of flat_softmax, cond_softmax, soft_margin, random_cut) ["flat_softmax"]',
)

# AAE arguments
@@ -1983,7 +1982,9 @@ def main():
help="path to results_taxonomy_predictor.csv file, output of taxonomy_predictor model",
)
taxonomys.add_argument(
"--no_predictor", help="do not complete mmseqs search with taxonomy predictions [False]", action="store_true"
"--no_predictor",
help="do not complete mmseqs search with taxonomy predictions [False]",
action="store_true",
)
taxonomys.add_argument(
"--n_species",
3 changes: 3 additions & 0 deletions vamb/encode.py
@@ -390,10 +390,12 @@ def trainepoch(
for depths_in, tnf_in, abundance_in, weights in data_loader:
depths_in.requires_grad = True
tnf_in.requires_grad = True
abundance_in.requires_grad = True

if self.usecuda:
depths_in = depths_in.cuda()
tnf_in = tnf_in.cuda()
abundance_in = abundance_in.cuda()
weights = weights.cuda()

optimizer.zero_grad()
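A minimal PyTorch sketch (not part of the commit) of the symmetric treatment the abundance tensor now receives alongside depths and TNF; the tensor shapes and flag below are made up:

import torch

usecuda = torch.cuda.is_available()
depths_in = torch.rand(4, 3)       # made-up batch of per-sample depths
tnf_in = torch.rand(4, 103)        # made-up TNF features
abundance_in = torch.rand(4, 1)    # made-up abundance column

# All three inputs get gradients and, when requested, move to the GPU together
for t in (depths_in, tnf_in, abundance_in):
    t.requires_grad = True
if usecuda:
    depths_in = depths_in.cuda()
    tnf_in = tnf_in.cuda()
    abundance_in = abundance_in.cuda()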
@@ -471,6 +473,7 @@ def encode(self, data_loader) -> _np.ndarray:
if self.usecuda:
depths = depths.cuda()
tnf = tnf.cuda()
ab = ab.cuda()

# Evaluate
_, _, _, mu = self(depths, tnf, ab)