Commit 974a0fc

feat: hierarchical loss configuration

sgalkina committed Sep 13, 2023
1 parent 6ed40cf commit 974a0fc
Showing 14 changed files with 494 additions and 173 deletions.
105 changes: 81 additions & 24 deletions vamb/__main__.py
@@ -308,10 +308,21 @@ def __init__(self, nepochs: int, batchsize: int, batchsteps: list[int]):


class PredictorTrainingOptions(VAETrainingOptions):
def __init__(self, nepochs: int, batchsize: int, batchsteps: list[int], softmax_threshold: float):
def __init__(
self,
nepochs: int,
batchsize: int,
batchsteps: list[int],
softmax_threshold: float,
ploss: str,
):
losses = ['flat_softmax', 'cond_softmax', 'soft_margin', 'random_cut']
if (softmax_threshold > 1) or (softmax_threshold < 0):
raise argparse.ArgumentTypeError(f"Softmax threshold should be between 0 and 1, currently {softmax_threshold}")
if ploss not in losses:
raise argparse.ArgumentTypeError(f"Predictor loss needs to be one of {losses}")
self.softmax_threshold = softmax_threshold
self.ploss = ploss
super(PredictorTrainingOptions, self).__init__(
nepochs,
batchsize,
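
The new constructor validates both the softmax threshold and the hierarchical loss name before any training starts. A self-contained sketch of those checks, mirroring the lines added above (the helper name is illustrative, not part of vamb):

    import argparse

    VALID_LOSSES = ("flat_softmax", "cond_softmax", "soft_margin", "random_cut")

    def check_predictor_options(softmax_threshold: float, ploss: str) -> None:
        # Mirrors the validation added to PredictorTrainingOptions.__init__.
        if softmax_threshold > 1 or softmax_threshold < 0:
            raise argparse.ArgumentTypeError(
                f"Softmax threshold should be between 0 and 1, currently {softmax_threshold}"
            )
        if ploss not in VALID_LOSSES:
            raise argparse.ArgumentTypeError(f"Predictor loss needs to be one of {VALID_LOSSES}")

    check_predictor_options(0.5, "cond_softmax")  # passes; out-of-range or unknown values raise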
@@ -684,6 +695,7 @@ def cluster(
latent: np.ndarray,
lengths: np.ndarray,
contignames: Sequence[str], # of dtype object
lengths: Sequence[int], # of dtype object
vamb_options: VambOptions,
logfile: IO[str],
cluster_prefix: str,
@@ -717,7 +729,8 @@ def cluster(
minsuccesses=cluster_options.min_successes,
destroy=True,
normalized=False,
cuda=vamb_options.cuda,
# cuda=vamb_options.cuda,
cuda=False, # disabled until clustering is fixed
rng_seed=vamb_options.seed,
)

@@ -928,6 +941,7 @@ def run(
latent,
comp_metadata.lengths,
comp_metadata.identifiers,
comp_metadata.lengths,
vamb_options,
logfile,
"vae_",
@@ -967,6 +981,7 @@ def run(
latent_z,
comp_metadata.lengths,
comp_metadata.identifiers,
comp_metadata.lengths,
vamb_options,
logfile,
"aae_z_",
@@ -1049,6 +1064,7 @@ def parse_mmseqs_taxonomy(
log(f"{len(df_mmseq)} lines in taxonomy file", logfile, 1)
log(f"{len(contignames)} contigs", logfile, 1)
ind_map = {c: i for i, c in enumerate(contignames)}
df_mmseq = df_mmseq[df_mmseq[0].isin(contignames)]
indices_mmseq = [ind_map[c] for c in df_mmseq[0]]
graph_column = df_mmseq[8]
species_column = df_mmseq[df_mmseq[2].isin(["species", "subspecies"])][8].str.split(';').str[6]
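
The species extraction above takes the seventh ";"-separated field of the mmseqs lineage string, but only for rows assigned at species or subspecies rank. A small pandas illustration with made-up column contents (assuming mmseqs2 puts the rank in column 2 and the full lineage in column 8, as the code above does):

    import pandas as pd

    df_mmseq = pd.DataFrame({
        2: ["species", "genus"],  # rank of the assignment
        8: ["d_Bacteria;p_X;c_X;o_X;f_X;g_X;s_Y",  # full 7-rank lineage
            "d_Bacteria;p_X;c_X;o_X;f_X;g_X"],
    })
    species = df_mmseq[df_mmseq[2].isin(["species", "subspecies"])][8].str.split(";").str[6]
    print(species.tolist())  # ['s_Y']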
@@ -1093,14 +1109,21 @@ def predict_taxonomy(
predictor_training_options: PredictorTrainingOptions,
cuda: bool,
logfile: IO[str],
mask=None,
):
begintime = time.time()
tnfs, lengths = composition.matrix, composition.metadata.lengths
contignames = composition.metadata.identifiers
if mask is not None:
tnfs, lengths = composition.matrix[mask], composition.metadata.lengths[mask]
contignames = composition.metadata.identifiers[mask]
else:
tnfs, lengths = composition.matrix, composition.metadata.lengths
contignames = composition.metadata.identifiers
if hasattr(abundance, 'matrix'):
rpkms = abundance.matrix
else:
rpkms = abundance
if mask is not None:
rpkms = rpkms[mask]

indices_mmseq, graph_column = parse_mmseqs_taxonomy(
taxonomy_path=taxonomy_path,
@@ -1113,13 +1136,14 @@ def predict_taxonomy(
log(f"{len(nodes)} nodes in the graph", logfile, 1)

classes_order = np.array(list(graph_column.str.split(";").str[-1]))
targets = [ind_nodes[i] for i in classes_order]
targets = np.array([ind_nodes[i] for i in classes_order])

model = vamb.h_loss.VAMB2Label(
rpkms.shape[1],
len(nodes),
nodes,
table_parent,
hier_loss=predictor_training_options.ploss,
cuda=cuda,
)

@@ -1146,7 +1170,7 @@
log(f"Number of sequences remaining: {len(mask_vamb) - n_discarded}", logfile, 1)

names = composition.metadata.identifiers[mask_vamb] # not a mutating operation, so the composition can be reused
lengths_masked = lengths[mask_vamb]
lengths_masked = lengths

predictortime = time.time()
log(
@@ -1170,7 +1194,7 @@
logfile,
0,
)
predicted_vector = model.predict(dataloader_vamb)
predicted_vector, predicted_labels = model.predict(dataloader_vamb)

log("Writing the taxonomy predictions", logfile, 0)
df_gt = pd.DataFrame({"contigs": names, "lengths": lengths_masked})
@@ -1179,15 +1203,24 @@
log(f"Using threshold {predictor_training_options.softmax_threshold}", logfile, 0)
predictions = []
probs = []
labels = []
for i in range(len(df_gt)):
label = predicted_labels[i]
pred_labels = [label]
while table_parent[label] != -1:
pred_labels.append(table_parent[label])
label = table_parent[label]
pred_labels = ";".join([nodes_ar[l] for l in pred_labels][::-1])
threshold_mask = predicted_vector[i] > predictor_training_options.softmax_threshold
pred_line = ";".join(nodes_ar[threshold_mask][1:])
predictions.append(pred_line)
absolute_probs = predicted_vector[i][threshold_mask]
absolute_prob = ';'.join(map(str, absolute_probs))
probs.append(absolute_prob)
labels.append(pred_labels)
df_gt["predictions"] = predictions
df_gt["abs_probabilities"] = probs
df_gt["predictions_labels"] = labels
predicted_path = out_dir.joinpath("results_taxonomy_predictor.csv")
df_gt.to_csv(predicted_path, index=None)
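
The loop above recovers each predicted label's full lineage by walking parent pointers up to the root. A standalone sketch of that walk, assuming table_parent[i] holds the parent index of node i and -1 marks the root (the toy graph is made up):

    import numpy as np

    nodes_ar = np.array(["root", "d_Bacteria", "p_Firmicutes", "s_Something"])
    table_parent = [-1, 0, 1, 2]  # parent index per node; -1 marks the root

    def lineage(label: int) -> str:
        path = [label]
        while table_parent[path[-1]] != -1:
            path.append(table_parent[path[-1]])
        return ";".join(nodes_ar[i] for i in reversed(path))

    print(lineage(3))  # root;d_Bacteria;p_Firmicutes;s_Something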

@@ -1248,6 +1281,14 @@ def run_vaevae(
else:
rpkms = abundance

dataloader_vamb, mask_vamb = vamb.encode.make_dataloader(
rpkms,
tnfs,
lengths,
batchsize=vae_training_options.batchsize,
cuda=vamb_options.cuda,
)

if (
taxonomy_options.taxonomy_path is not None and not taxonomy_options.no_predictor
):
@@ -1261,6 +1302,7 @@ def run_vaevae(
predictor_training_options=predictor_training_options,
cuda=vamb_options.cuda,
logfile=logfile,
mask=mask_vamb,
)
predictions_path = vamb_options.out_dir.joinpath("results_taxonomy_predictor.csv")
elif taxonomy_options.taxonomy_predictions_path is not None:
@@ -1273,13 +1315,13 @@ def run_vaevae(
if predictions_path is not None:
df_gt = pd.read_csv(predictions_path)
graph_column = df_gt["predictions"]
ind_map = {c: i for i, c in enumerate(composition.metadata.identifiers)}
ind_map = {c: i for i, c in enumerate(composition.metadata.identifiers[mask_vamb])}
indices_mmseq = [ind_map[c] for c in df_gt['contigs']]
elif taxonomy_options.no_predictor:
log("Using mmseqs taxonomy for semisupervised learning", logfile, 0)
indices_mmseq, graph_column = parse_mmseqs_taxonomy(
taxonomy_path=taxonomy_options.taxonomy_path,
contignames=composition.metadata.identifiers,
contignames=composition.metadata.identifiers[mask_vamb],
n_species=taxonomy_options.n_species,
logfile=logfile,
)
@@ -1289,7 +1331,7 @@ def run_vaevae(
nodes, ind_nodes, table_parent = vamb.h_loss.make_graph(graph_column.unique())

classes_order = list(graph_column.str.split(";").str[-1])
missing_nodes_mmseqs = list(set(range(rpkms.shape[0])) - set(indices_mmseq))
missing_nodes_mmseqs = list(set(range(rpkms[mask_vamb].shape[0])) - set(indices_mmseq))
if missing_nodes_mmseqs:
indices_mmseq.extend(missing_nodes_mmseqs)
classes_order.extend(["d_Bacteria"]*len(missing_nodes_mmseqs))
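
A toy illustration of the backfill above: contigs without an mmseqs hit are appended with the most general label, d_Bacteria, so every dataloader row still has a training target (all values made up):

    indices_mmseq = [0, 2, 3]             # contigs that received an annotation
    classes_order = ["s_A", "s_B", "s_C"]
    n_contigs = 5                         # rows in rpkms[mask_vamb]

    missing = sorted(set(range(n_contigs)) - set(indices_mmseq))
    indices_mmseq.extend(missing)
    classes_order.extend(["d_Bacteria"] * len(missing))
    print(indices_mmseq, classes_order)
    # [0, 2, 3, 1, 4] ['s_A', 's_B', 's_C', 'd_Bacteria', 'd_Bacteria']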
@@ -1309,34 +1351,30 @@ def run_vaevae(
cuda=vamb_options.cuda,
logfile=logfile,
)
dataloader_vamb, mask_vamb = vamb.encode.make_dataloader(
rpkms,
tnfs,
lengths,
cuda=vamb_options.cuda,
)

log(
f"{len(indices_mmseq)} mmseq indices",
logfile,
0,
)
dataloader_joint, _ = vamb.h_loss.make_dataloader_concat_hloss(
rpkms[indices_mmseq],
tnfs[indices_mmseq],
lengths[indices_mmseq],
rpkms[mask_vamb][indices_mmseq],
tnfs[mask_vamb][indices_mmseq],
lengths[mask_vamb][indices_mmseq],
targets,
len(nodes),
table_parent,
batchsize=vae_training_options.batchsize,
cuda=vamb_options.cuda,
)
dataloader_labels, _ = vamb.h_loss.make_dataloader_labels_hloss(
rpkms[indices_mmseq],
tnfs[indices_mmseq],
lengths[indices_mmseq],
rpkms[mask_vamb][indices_mmseq],
tnfs[mask_vamb][indices_mmseq],
lengths[mask_vamb][indices_mmseq],
targets,
len(nodes),
table_parent,
batchsize=vae_training_options.batchsize,
cuda=vamb_options.cuda,
)
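
Both dataloaders use the same double-indexing pattern: first the boolean mask_vamb produced by make_dataloader, then the reordering indices_mmseq, so the features stay row-aligned with targets. A minimal numpy illustration (shapes made up):

    import numpy as np

    rpkms = np.arange(12.0).reshape(6, 2)
    mask_vamb = np.array([True, True, False, True, True, True])
    indices_mmseq = [0, 2, 1, 3, 4]       # positions within the masked array

    aligned = rpkms[mask_vamb][indices_mmseq]
    print(aligned.shape)                  # (5, 2), in the same order as the targets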

@@ -1349,6 +1387,7 @@ def run_vaevae(
table_parent,
shapes,
vamb_options.seed,
batchsize=vae_training_options.batchsize,
cuda=vamb_options.cuda,
)
model_path = vamb_options.out_dir.joinpath("vaevae_model.pt")
@@ -1379,6 +1418,14 @@ def run_vaevae(
)

LATENT_PATH = vamb_options.out_dir.joinpath("vaevae_latent.npy")
# LATENT_PATH = str(vamb_options.out_dir).replace('__fix', '') + "/vaevae_latent.npy"
# latent_both = np.load(LATENT_PATH)
# latent_both = np.concatenate([rpkms[mask_full], tnfs[mask_full]], axis=1)
# log(
# f"{latent_both.shape} embedding shape",
# logfile,
# 0,
# )
np.save(LATENT_PATH, latent_both)

# Cluster, save tsv file
@@ -1388,6 +1435,7 @@ def run_vaevae(
clusterspath,
latent_both,
composition.metadata.identifiers[indices_mmseq],
composition.metadata.lengths[indices_mmseq],
vamb_options,
logfile,
"vaevae_",
@@ -1404,8 +1452,8 @@ def run_vaevae(
vamb_options.out_dir,
clusterspath,
path,
composition.metadata.identifiers,
composition.metadata.lengths,
composition.metadata.identifiers[indices_mmseq],
composition.metadata.lengths[indices_mmseq],
vamb_options.min_fasta_output_size,
logfile,
separator=cluster_options.binsplit_separator,
@@ -1771,6 +1819,14 @@ def main():
default=0.5,
help="conditional probability threshold for accepting the taxonomic prediction [0.5]",
)
pred_trainos.add_argument(
"-ploss",
dest="ploss",
metavar="",
type=str,
default="cond_softmax",
help='Hierarchical loss (one of flat_softmax, cond_softmax, soft_margin, random_cut) ["cond_softmax"]',
)
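
A self-contained check of how the new flag parses, mirroring the add_argument call above:

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-ploss",
        dest="ploss",
        metavar="",
        type=str,
        default="cond_softmax",
        help='Hierarchical loss (one of flat_softmax, cond_softmax, soft_margin, random_cut) ["cond_softmax"]',
    )
    args = parser.parse_args(["-ploss", "soft_margin"])
    assert args.ploss == "soft_margin"

Note that the parser itself accepts any string here; unknown loss names are only rejected later, by the PredictorTrainingOptions check added earlier in this commit.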

# AAE arguments
aaeos = parser.add_argument_group(title="AAE options", description=None)
@@ -2021,6 +2077,7 @@ def main():
batchsize=args.pred_batchsize,
batchsteps=args.pred_batchsteps,
softmax_threshold=args.pred_softmax_threshold,
ploss=args.ploss,
)
else:
assert args.model == "reclustering"
2 changes: 2 additions & 0 deletions vamb/cluster.py
@@ -296,6 +296,8 @@ def __init__(
self.order = _np.argsort(lengths)[::-1]
self.order_index = 0
self.lengths = _torch.Tensor(lengths)
if cuda:
self.lengths = self.lengths.cuda()
self.n_emitted_clusters = 0
self.n_remaining_points = len(torch_matrix)
self.peak_valley_ratio = 0.1
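
The two added lines move the lengths tensor to the GPU when clustering runs there, avoiding a CPU/GPU device mismatch in later tensor ops. The pattern in isolation (the cuda flag below is a stand-in for the constructor argument):

    import torch

    lengths = [2500, 1800, 4100]
    cuda = torch.cuda.is_available()  # stand-in for the constructor's cuda flag
    t = torch.Tensor(lengths)
    if cuda:
        t = t.cuda()                  # keep lengths on the same device as the matrix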
8 changes: 0 additions & 8 deletions vamb/encode.py
@@ -283,15 +283,7 @@ def _encode(self, tensor: Tensor) -> Tensor:
return mu

# sample with gaussian noise
<<<<<<< HEAD
<<<<<<< HEAD
def reparameterize(self, mu: Tensor) -> Tensor:
=======
def reparameterize(self, rng, mu: Tensor, logsigma: Tensor) -> Tensor:
>>>>>>> c2543e2 (Seed more sources of randomness)
=======
def reparameterize(self, mu: Tensor, logsigma: Tensor) -> Tensor:
>>>>>>> e4676f0 (feat: dbscan)
epsilon = _torch.randn(mu.size(0), mu.size(1))

if self.usecuda:
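
The deleted lines were leftover merge-conflict markers around competing reparameterize signatures; the commit keeps the mu-only variant. For reference, the textbook reparameterization trick looks like this (a generic sketch, not necessarily vamb's exact variant):

    import torch

    def reparameterize(mu: torch.Tensor, logsigma: torch.Tensor) -> torch.Tensor:
        # z = mu + sigma * eps with eps ~ N(0, I); keeps sampling differentiable
        epsilon = torch.randn_like(mu)
        return mu + torch.exp(logsigma) * epsilon

    z = reparameterize(torch.zeros(4, 2), torch.zeros(4, 2))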
(Diffs for the remaining 11 changed files are not shown.)