feat: different hierarchical losses

sgalkina committed Sep 6, 2023
1 parent 8fbf6ee commit 9ef817e

Showing 20 changed files with 652 additions and 109 deletions.
21 changes: 9 additions & 12 deletions vamb/__main__.py
@@ -1047,7 +1047,7 @@ def parse_mmseqs_taxonomy(
ind_map = {c: i for i, c in enumerate(contignames)}
indices_mmseq = [ind_map[c] for c in df_mmseq[0]]
graph_column = df_mmseq[8]
species_column = df_mmseq[df_mmseq[2] == "species"][3]
species_column = df_mmseq[df_mmseq[2].isin(["species", "subspecies"])][8].str.split(';').str[6]
species_dict = species_column.value_counts()
unique_species = sorted(
list(species_column.unique()), key=lambda x: species_dict[x], reverse=True
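
The new extraction keeps contigs annotated at species level or deeper and pulls the species name out of the full lineage string in column 8, instead of reading column 3. A minimal sketch on toy data (the column layout and rank prefixes are assumptions modelled on the mmseqs2 taxonomy TSV this code reads):

# Toy sketch of the new species extraction; column numbers and lineage
# format are assumptions mirroring the mmseqs2 taxonomy TSV.
import pandas as pd

df_mmseq = pd.DataFrame({
    0: ["contig_1", "contig_2"],
    2: ["species", "subspecies"],
    8: [
        "d_Bacteria;p_Firmicutes;c_Bacilli;o_Lactobacillales;"
        "f_Lactobacillaceae;g_Lactobacillus;s_Lactobacillus crispatus",
        "d_Bacteria;p_Proteobacteria;c_Gammaproteobacteria;o_Enterobacterales;"
        "f_Enterobacteriaceae;g_Escherichia;s_Escherichia coli;t_K-12",
    ],
})

# Keep species- and subspecies-level rows, then take the 7th lineage
# field (index 6), which holds the species name in this layout.
species_column = (
    df_mmseq[df_mmseq[2].isin(["species", "subspecies"])][8]
    .str.split(";").str[6]
)
print(species_column.tolist())
# ['s_Lactobacillus crispatus', 's_Escherichia coli']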
@@ -1172,11 +1172,7 @@ def predict_taxonomy(
df_gt = pd.DataFrame({"contigs": names, "lengths": lengths_masked})
nodes_ar = np.array(nodes)

df_mmseq = pd.read_csv(taxonomy_path, delimiter="\t", header=None)
df_mmseq_sp = df_mmseq[(df_mmseq[2] == "species")]
mmseq_map = {k: v for k, v in zip(df_mmseq_sp[0], df_mmseq_sp[8])}
log(f"Using threshold {predictor_training_options.softmax_threshold}", logfile, 0)

predictions = []
probs = []
for i in range(len(df_gt)):
@@ -1288,7 +1284,12 @@ def run_vaevae(

nodes, ind_nodes, table_parent = vamb.h_loss.make_graph(graph_column.unique())

classes_order = np.array(list(graph_column.str.split(";").str[-1]))
classes_order = list(graph_column.str.split(";").str[-1])
missing_nodes_mmseqs = list(set(range(rpkms.shape[0])) - set(indices_mmseq))
if missing_nodes_mmseqs:
indices_mmseq.extend(missing_nodes_mmseqs)
classes_order.extend(["d_Bacteria"]*len(missing_nodes_mmseqs))
classes_order = np.array(classes_order)
targets = [ind_nodes[i] for i in classes_order]

vae = vamb.h_loss.VAEVAEHLoss(
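
The backfill added above guarantees every contig has a taxonomy target: contigs absent from the mmseqs2 output are appended with the domain-level fallback label. A toy run with made-up indices:

# Toy run of the backfill: contigs 1 and 4 had no mmseqs2 hit, so they
# are appended with the fallback label "d_Bacteria".
import numpy as np

n_contigs = 5                    # stands in for rpkms.shape[0]
indices_mmseq = [0, 2, 3]
classes_order = ["s_A", "s_B", "s_C"]

missing_nodes_mmseqs = list(set(range(n_contigs)) - set(indices_mmseq))
if missing_nodes_mmseqs:
    indices_mmseq.extend(missing_nodes_mmseqs)
    classes_order.extend(["d_Bacteria"] * len(missing_nodes_mmseqs))
classes_order = np.array(classes_order)

print(indices_mmseq)    # [0, 2, 3, 1, 4] here (set iteration order)
print(classes_order)    # ['s_A' 's_B' 's_C' 'd_Bacteria' 'd_Bacteria']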
@@ -1382,7 +1383,7 @@ def run_vaevae(
cluster_options,
clusterspath,
latent_both,
composition.metadata.identifiers,
composition.metadata.identifiers[indices_mmseq],
vamb_options,
logfile,
"vaevae_",
@@ -1906,11 +1907,7 @@ def main():
help="path to results_taxonomy_predictor.csv file, output of taxonomy_predictor model",
)
taxonomys.add_argument(
"--no_predictor",
metavar="",
type=bool,
default=False,
help="do not complete mmseqs search with taxonomy predictions",
"--no_predictor", help="do not complete mmseqs search with taxonomy predictions [False]", action="store_true"
)
taxonomys.add_argument(
"--n_species",
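Switching --no_predictor from type=bool to action="store_true" also fixes a classic argparse pitfall: type=bool applies bool() to the raw argument string, and bool("False") is True, so the old flag could not be reliably turned off. A quick demonstration of the new behaviour:

# Demonstrating the store_true flag: absent -> False, present -> True.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--no_predictor",
    help="do not complete mmseqs search with taxonomy predictions [False]",
    action="store_true",
)
print(parser.parse_args([]).no_predictor)                   # False
print(parser.parse_args(["--no_predictor"]).no_predictor)   # True
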
6 changes: 4 additions & 2 deletions vamb/cluster.py
@@ -17,9 +17,11 @@
from typing import Optional
from collections.abc import Sequence, Iterable

_DEFAULT_RADIUS = 0.06
_DEFAULT_RADIUS = 0.12
# _DEFAULT_RADIUS = 0.06
# Distance within which to search for medoid point
_MEDOID_RADIUS = 0.05
_MEDOID_RADIUS = 0.1
# _MEDOID_RADIUS = 0.05

_DELTA_X = 0.005
_XMAX = 0.3
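Both clustering radii are doubled here, with the old values kept as comments. Illustratively, a larger radius admits more candidate points around each medoid (made-up distances):

# Illustration only: doubling the medoid search radius admits more
# neighbours. The distances below are invented cosine distances.
import numpy as np

dists = np.array([0.03, 0.07, 0.09, 0.20])
print((dists <= 0.05).sum())   # 1 neighbour within the old radius
print((dists <= 0.10).sum())   # 3 neighbours within the new radius
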
46 changes: 42 additions & 4 deletions vamb/h_loss.py
@@ -14,6 +14,7 @@

import torch as _torch
from torch import nn as _nn
import torch.nn.functional as F
from torch.utils.data.dataset import TensorDataset as _TensorDataset
from torch.utils.data import DataLoader as _DataLoader
from torch.optim import Adam as _Adam
@@ -267,14 +268,22 @@ def __init__(
self.table_parent = table_parent
self.tree = _hier.Hierarchy(table_parent)
self.loss_fn = _hlosses_fast.HierSoftmaxCrossEntropy(self.tree)
# self.loss_fn = _hlosses_fast.RandomCutLoss(
# self.tree, 0.1, permit_root_cut=False, with_leaf_targets=True)
self.pred_helper = _hlosses_fast.HierLogSoftmax(self.tree)
# self.pred_helper = _hlosses_fast.SumAncestors(self.tree, exclude_root=True)
if self.usecuda:
self.loss_fn = self.loss_fn.cuda()
self.pred_helper = self.pred_helper.cuda()
self.pred_fn = partial(
lambda log_softmax_fn, theta: log_softmax_fn(theta).exp(),
self.pred_helper,
)
# self.pred_fn = partial(
# lambda sum_ancestor_fn, theta: _torch.exp(_hlosses_fast.multilabel_log_likelihood(
# sum_ancestor_fn(theta), replace_root=True, temperature=10.0)),
# self.pred_helper,
# )
self.specificity = -self.tree.num_leaf_descendants()
self.not_trivial = self.tree.num_children() != 1
self.find_lca = _hier.FindLCA(self.tree)
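
The pred_fn above partially applies the helper module, so the whole closure can be swapped when a different loss/helper pair is commented in. A self-contained sketch of the same pattern, with plain log_softmax standing in for the tree-aware HierLogSoftmax:

# Sketch of the pred_fn pattern: partial() binds the helper, and the
# lambda exponentiates its log-probabilities. F.log_softmax is only a
# stand-in for HierLogSoftmax here.
from functools import partial

import torch
import torch.nn.functional as F

pred_helper = partial(F.log_softmax, dim=-1)
pred_fn = partial(
    lambda log_softmax_fn, theta: log_softmax_fn(theta).exp(),
    pred_helper,
)

theta = torch.randn(2, 5)      # raw scores
prob = pred_fn(theta)
print(prob.sum(dim=-1))        # rows sum to 1 for this flat stand-in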
@@ -289,6 +298,7 @@ def __init__(

def calc_loss(self, labels_in, labels_out, mu, logsigma):
ce_labels = self.loss_fn(labels_out[:, 1:], labels_in)
# ce_labels = self.loss_fn(labels_out, labels_in)

ce_labels_weight = 1.0 # TODO: figure out
kld = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=1).mean()
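
The kld term is the closed-form KL divergence between the encoder's diagonal Gaussian N(mu, sigma^2) and a standard normal. A quick check against torch.distributions, treating logsigma as the log-variance as the formula implies:

# Verifying the closed-form KLD against torch.distributions; logsigma is
# interpreted as log-variance, so the standard deviation is exp(logsigma / 2).
import torch
from torch.distributions import Normal, kl_divergence

mu = torch.randn(4, 8)
logsigma = torch.randn(4, 8)

closed = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=1)
reference = kl_divergence(
    Normal(mu, (0.5 * logsigma).exp()), Normal(0.0, 1.0)
).sum(dim=1)
print(torch.allclose(closed, reference, atol=1e-5))   # True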
@@ -712,7 +722,7 @@ def __init__(
self.alpha = alpha
self.beta = beta
self.nhiddens = nhiddens
self.nlabels = nlabels
# self.nlabels = nlabels
self.dropout = dropout

# Initialize lists for holding hidden layers
@@ -728,6 +738,10 @@ def __init__(
self.encoderlayers.append(_nn.Linear(nin, nout))
self.encodernorms.append(_nn.BatchNorm1d(nout))

self.tree = _hier.Hierarchy(table_parent)
self.nlabels = nlabels
self.n_tree_nodes = nlabels
# self.nlabels = self.tree.leaf_mask().nonzero()[0].shape[0]
# Reconstruction (output) layer
self.outputlayer = _nn.Linear(self.nhiddens[0], self.nlabels)
# Activation functions
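
For orientation, the table_parent consumed by _hier.Hierarchy above is a parent-index array over the taxonomy nodes, with the root marked by a negative index (see the parents() fix in vamb/hier.py below). A toy version of that representation, without the vamb classes:

# Toy parent table for a 5-node taxonomy: root -> {A, B}, A -> {A1, A2}.
# Node i's parent is table_parent[i]; the root carries -1.
import numpy as np

nodes = ["root", "A", "B", "A1", "A2"]
table_parent = np.array([-1, 0, 0, 1, 1])

# One statistic a Hierarchy derives: children per node.
num_children = np.bincount(
    table_parent[table_parent >= 0], minlength=len(nodes)
)
print(num_children)   # [2 2 0 0 0]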
@@ -737,19 +751,40 @@ def __init__(
if cuda:
self.cuda()

self.nlabels = nlabels
self.nodes = nodes
self.table_parent = table_parent
self.tree = _hier.Hierarchy(table_parent)

self.loss_fn = _hlosses_fast.HierSoftmaxCrossEntropy(self.tree)
# self.loss_fn = _hlosses_fast.RandomCutLoss(
# self.tree, 0.1, permit_root_cut=False, with_leaf_targets=True)
# self.loss_fn = _hlosses_fast.MarginLoss(
# self.tree, with_leaf_targets=False,
# hardness='soft', margin='incorrect', tau=0.01)
# self.loss_fn = _hlosses_fast.FlatSoftmaxNLL(self.tree)
self.pred_helper = _hlosses_fast.HierLogSoftmax(self.tree)
# self.pred_helper = _hlosses_fast.SumAncestors(self.tree, exclude_root=True)
# self.pred_helper = _hlosses_fast.SumDescendants(self.tree, strict=False)
# self.pred_helper = _hlosses_fast.SumLeafDescendants(self.tree, strict=False)
if self.usecuda:
self.loss_fn = self.loss_fn.cuda()
self.pred_helper = self.pred_helper.cuda()
self.pred_fn = partial(
lambda log_softmax_fn, theta: log_softmax_fn(theta).exp(),
self.pred_helper,
)
# self.pred_fn = partial(
# lambda sum_ancestor_fn, theta: _torch.exp(_hlosses_fast.multilabel_log_likelihood(
# sum_ancestor_fn(theta), replace_root=True, temperature=10.0)),
# self.pred_helper,
# )
# self.pred_fn = partial(
# lambda sum_fn, theta: sum_fn(F.softmax(theta, dim=-1), dim=-1),
# self.pred_helper,
# )
# self.pred_fn = partial(
# lambda sum_fn, theta: sum_fn(F.softmax(theta, dim=-1), dim=-1),
# self.pred_helper,
# )
self.specificity = -self.tree.num_leaf_descendants()
self.not_trivial = self.tree.num_children() != 1
self.find_lca = _hier.FindLCA(self.tree)
@@ -781,11 +816,13 @@ def forward(self, depths, tnf, weights):

def calc_loss(self, labels_in, labels_out):
ce_labels = self.loss_fn(labels_out[:, 1:], labels_in)
# ce_labels = self.loss_fn(labels_out, labels_in)

_, labels_in_indices = labels_in.max(dim=1)
with _torch.no_grad():
gt_node = self.eval_label_map.to_node[labels_in_indices.cpu()]
prob = self.pred_fn(labels_out[:, 1:])
# prob = self.pred_fn(labels_out)
prob = prob.cpu().numpy()
pred = _infer.argmax_with_confidence(
self.specificity, prob, 0.5, self.not_trivial
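
A recurring detail in calc_loss and predict: labels_out[:, 1:] drops the network's first output column before the hierarchical softmax, while the commented-out variants keep the full tensor. Assuming column 0 corresponds to the root node, the slice looks like this:

# Toy slice: drop column 0 (assumed here to be the root node's score)
# before handing the scores to the loss / prediction function.
import torch

labels_out = torch.randn(3, 6)   # batch of 3; root + 5 non-root nodes
theta = labels_out[:, 1:]
print(theta.shape)               # torch.Size([3, 5])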
@@ -807,7 +844,7 @@ def predict(self, data_loader) -> _np.ndarray:
# We make a Numpy array instead of a Torch array because, if we create
# a Torch array, then convert it to Numpy, Numpy will believe it doesn't
# own the memory block, and array resizes will not be permitted.
latent = _np.empty((length, self.nlabels), dtype=_np.float32)
latent = _np.empty((length, self.n_tree_nodes), dtype=_np.float32)

row = 0
with _torch.no_grad():
@@ -822,6 +859,7 @@
labels = self(depths, tnf, weights)
with _torch.no_grad():
prob = self.pred_fn(labels[:, 1:])
# prob = self.pred_fn(labels)

if self.usecuda:
prob = prob.cpu()
2 changes: 1 addition & 1 deletion vamb/hier.py
@@ -25,7 +25,7 @@ def edges(self) -> List[Tuple[int, int]]:
def parents(self, root_loop: bool = False) -> np.ndarray:
if root_loop:
return np.where(
self._parents >= 0, self._parents, np.arange(len(self._parents))
np.array(self._parents) >= 0, np.array(self._parents), np.arange(len(self._parents))
)
else:
return np.array(self._parents)
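
This fix matters when self._parents is a plain Python list: comparing a list to an integer raises a TypeError, so np.where could not build the root-loop view. A minimal reproduction of the fixed behaviour:

# The root (-1) is redirected to itself; other nodes keep their parent.
import numpy as np

parents = [-1, 0, 0, 1]           # plain list, as _parents may be
p = np.array(parents)
root_loop = np.where(p >= 0, p, np.arange(len(p)))
print(root_loop)                  # [0 0 0 1] -> node 0 loops to itself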