
Commit

feat: add absolute abundance
svekut committed Oct 13, 2023
1 parent 3386c41 commit 4eda306
Showing 6 changed files with 194 additions and 167 deletions.
165 changes: 73 additions & 92 deletions vamb/h_loss.py
@@ -71,28 +71,33 @@ def collate_fn_concat_hloss(num_categories, table_parent, batch):
a = _torch.stack([i[0] for i in batch])
b = _torch.stack([i[1] for i in batch])
c = _torch.stack([i[2] for i in batch])
d = [(i[3],) for i in batch]
return a, b, c, collate_fn_labels_hloss(num_categories, table_parent, d)[0]
d = _torch.stack([i[3] for i in batch])
e = [i[4] for i in batch]
return a, b, c, d, collate_fn_labels_hloss(num_categories, table_parent, e)[0]


def collate_fn_semisupervised_hloss(num_categories, table_parent, batch):
a = _torch.stack([i[0] for i in batch])
b = _torch.stack([i[1] for i in batch])
c = _torch.stack([i[2] for i in batch])
d = [i[3] for i in batch]
e = _torch.stack([i[4] for i in batch])
d = _torch.stack([i[3] for i in batch])
e = [i[4] for i in batch]
f = _torch.stack([i[5] for i in batch])
g = _torch.stack([i[6] for i in batch])
h = [i[7] for i in batch]
h = _torch.stack([i[7] for i in batch])
m = _torch.stack([i[8] for i in batch])
n = [i[9] for i in batch]
return (
a,
b,
c,
collate_fn_labels_hloss(num_categories, table_parent, d)[0],
e,
d,
collate_fn_labels_hloss(num_categories, table_parent, e)[0],
f,
g,
collate_fn_labels_hloss(num_categories, table_parent, h)[0],
h,
m,
collate_fn_labels_hloss(num_categories, table_parent, n)[0],
)


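For orientation, a minimal sketch (not code from this repository) of how a concat-style batch with the new total-abundance column can be assembled; the item layout (depths, tnf, total_abundance, weights, label) mirrors the TensorDataset built further down, and the plain stacking of labels stands in for collate_fn_labels_hloss:

import torch

def collate_concat_sketch(batch):
    # Each batch item is assumed to be (depths, tnf, total_abundance, weights, label).
    depths = torch.stack([item[0] for item in batch])
    tnf = torch.stack([item[1] for item in batch])
    abundance = torch.stack([item[2] for item in batch])  # the new absolute-abundance column
    weights = torch.stack([item[3] for item in batch])
    labels = torch.stack([item[4] for item in batch]).long()
    return depths, tnf, abundance, weights, labels

A DataLoader built over such a dataset would simply pass collate_fn=collate_concat_sketch.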
@@ -107,7 +112,7 @@ def make_dataloader_labels_hloss(
destroy=False,
cuda=False,
):
_, _, _, batchsize, n_workers, cuda, mask = _semisupervised_encode._make_dataset(
_, _, _, _, batchsize, n_workers, cuda, mask = _semisupervised_encode._make_dataset(
rpkm, tnf, lengths, batchsize=batchsize, destroy=destroy, cuda=cuda
)
dataset = _TensorDataset(_torch.Tensor(labels).long())
@@ -139,6 +144,7 @@ def make_dataloader_concat_hloss(
(
depthstensor,
tnftensor,
total_abundance_tensor,
weightstensor,
batchsize,
n_workers,
@@ -149,7 +155,7 @@
)
labelstensor = _torch.Tensor(labels).long()
dataset = _TensorDataset(
depthstensor, tnftensor, weightstensor, labelstensor
depthstensor, tnftensor, total_abundance_tensor, weightstensor, labelstensor
)
dataloader = _DataLoader(
dataset=dataset,
@@ -194,11 +200,13 @@ def make_dataloader_semisupervised_hloss(
dataloader_vamb.dataset.tensors[0][indices_unsup_vamb],
dataloader_vamb.dataset.tensors[1][indices_unsup_vamb],
dataloader_vamb.dataset.tensors[2][indices_unsup_vamb],
dataloader_vamb.dataset.tensors[3][indices_unsup_vamb],
dataloader_labels.dataset.tensors[0][indices_unsup_labels],
dataloader_joint.dataset.tensors[0][indices_sup],
dataloader_joint.dataset.tensors[1][indices_sup],
dataloader_joint.dataset.tensors[2][indices_sup],
dataloader_joint.dataset.tensors[3][indices_sup],
dataloader_joint.dataset.tensors[4][indices_sup],
)
dataloader_all = _DataLoader(
dataset=dataset_all,
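A rough sketch of the dataset assembly above, under the assumption that the three dataloaders wrap row-aligned TensorDatasets and that the three index sets have equal length; names are illustrative only:

from torch.utils.data import TensorDataset

def semisupervised_dataset_sketch(ds_vamb, ds_labels, ds_joint,
                                  idx_unsup_vamb, idx_unsup_labels, idx_sup):
    # Gather the unsupervised VAMB view (now depths, tnf, total abundance, weights),
    # the unsupervised label view, and the supervised joint view into one dataset;
    # all three index collections must select the same number of rows.
    return TensorDataset(
        *[t[idx_unsup_vamb] for t in ds_vamb.tensors],
        ds_labels.tensors[0][idx_unsup_labels],
        *[t[idx_sup] for t in ds_joint.tensors],
    )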
@@ -344,18 +352,6 @@ def calc_loss(self, labels_in, labels_out, mu, logsigma):
loss = ce_labels * ce_labels_weight + kld * kld_weight
return loss, ce_labels, kld, _torch.tensor(0)

# _, labels_in_indices = labels_in.max(dim=1)
# with _torch.no_grad():
# gt_node = self.eval_label_map.to_node[labels_in_indices.cpu()]
# prob = self.pred_fn(labels_out)
# prob = prob.cpu().numpy()
# pred = _infer.argmax_with_confidence(
# self.specificity, prob, 0.5, self.not_trivial
# )
# pred = _hier.truncate_given_lca(gt_node, pred, self.find_lca(gt_node, pred))

# return loss, ce_labels, kld, _torch.tensor(_np.sum(gt_node == pred))

def trainmodel(
self,
dataloader: _DataLoader[tuple[Tensor]],
@@ -520,47 +516,44 @@ def calc_loss(
depths_out,
tnf_in,
tnf_out,
abundance_in,
abundance_out,
labels_in,
labels_out,
mu,
logsigma,
weights,
):
# If multiple samples, use cross entropy, else use SSE for abundance
if self.nsamples > 1:
# Add 1e-9 to depths_out to avoid numerical instability.
ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1).mean()
ce_weight = (1 - self.alpha) / _log(self.nsamples)
else:
ce = (depths_out - depths_in).pow(2).sum(dim=1).mean()
ce_weight = 1 - self.alpha
ab_sse = (abundance_out - abundance_in).pow(2).sum(dim=1)
# Add 1e-9 to depths_out to avoid numerical instability.
ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1)
sse = (tnf_out - tnf_in).pow(2).sum(dim=1)
kld = 0.5 * (mu.pow(2)).sum(dim=1)

ce_labels = self.loss_fn(labels_out, labels_in)
# Avoid having the denominator be zero
if self.nsamples == 1:
ce_weight = 0.0
else:
ce_weight = ((1 - self.alpha) * (self.nsamples - 1)) / (
self.nsamples * _log(self.nsamples)
)

ce_labels_weight = 1.0
sse = (tnf_out - tnf_in).pow(2).sum(dim=1).mean()
kld = -0.5 * (1 + logsigma - mu.pow(2) - logsigma.exp()).sum(dim=1).mean()
ab_sse_weight = (1 - self.alpha) * (1 / self.nsamples)
sse_weight = self.alpha / self.ntnf
kld_weight = 1 / (self.nlatent * self.beta)

ce_labels = self.loss_fn(labels_out, labels_in)
ce_labels_weight = 1.0

loss = (
ce * ce_weight
ab_sse * ab_sse_weight
+ ce * ce_weight
+ sse * sse_weight
+ ce_labels * ce_labels_weight
+ kld * kld_weight
) * weights
return loss, ce, sse, ce_labels, kld, _torch.tensor(0)

# _, labels_in_indices = labels_in.max(dim=1)
# with _torch.no_grad():
# gt_node = self.eval_label_map.to_node[labels_in_indices.cpu()]
# prob = self.pred_fn(labels_out)
# prob = prob.cpu().numpy()
# pred = _infer.argmax_with_confidence(
# self.specificity, prob, 0.5, self.not_trivial
# )
# pred = _hier.truncate_given_lca(gt_node, pred, self.find_lca(gt_node, pred))
# return loss, ce, sse, ce_labels, kld, _torch.tensor(_np.sum(gt_node == pred))


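As a standalone illustration of the reconstruction loss assembled above (a sketch only: the per-term weights follow the expressions in this hunk, and ce_labels is assumed to be the already-computed hierarchical label loss per contig):

import math
import torch

def concat_loss_sketch(depths_in, depths_out, tnf_in, tnf_out,
                       abundance_in, abundance_out, ce_labels, mu,
                       alpha, nsamples, ntnf, nlatent, beta, weights):
    # Per-contig reconstruction terms.
    ab_sse = (abundance_out - abundance_in).pow(2).sum(dim=1)   # new absolute-abundance SSE
    ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1)    # 1e-9 avoids log(0)
    sse = (tnf_out - tnf_in).pow(2).sum(dim=1)
    kld = 0.5 * mu.pow(2).sum(dim=1)

    # With a single sample the depth cross-entropy weight would divide by log(1) = 0,
    # so it is set to zero instead.
    if nsamples == 1:
        ce_weight = 0.0
    else:
        ce_weight = ((1 - alpha) * (nsamples - 1)) / (nsamples * math.log(nsamples))
    ab_sse_weight = (1 - alpha) / nsamples
    sse_weight = alpha / ntnf
    kld_weight = 1 / (nlatent * beta)

    # The label loss enters with weight 1.0, as in the diff.
    loss = (ab_sse * ab_sse_weight + ce * ce_weight + sse * sse_weight
            + ce_labels + kld * kld_weight) * weights
    return loss.mean()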
def kld_gauss(p_mu, p_logstd, q_mu, q_logstd):
loss = (
@@ -710,6 +703,8 @@ def calc_loss_joint(
depths_out,
tnf_in,
tnf_out,
abundance_in,
abundance_out,
labels_in,
labels_out,
mu_sup,
Expand All @@ -720,48 +715,43 @@ def calc_loss_joint(
logsigma_labels_unsup,
weights,
):
# If multiple samples, use cross entropy, else use SSE for abundance
if self.VAEVamb.nsamples > 1:
# Add 1e-9 to depths_out to avoid numerical instability.
ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1).mean()
ce_weight = (1 - self.VAEVamb.alpha) / _log(self.VAEVamb.nsamples)
ab_sse = (abundance_out - abundance_in).pow(2).sum(dim=1)
# Add 1e-9 to depths_out to avoid numerical instability.
ce = -((depths_out + 1e-9).log() * depths_in).sum(dim=1)
sse = (tnf_out - tnf_in).pow(2).sum(dim=1)

# Avoid having the denominator be zero
if self.VAEVamb.nsamples == 1:
ce_weight = 0.0
else:
ce = (depths_out - depths_in).pow(2).sum(dim=1).mean()
ce_weight = 1 - self.VAEVamb.alpha
ce_weight = ((1 - self.alpha) * (self.nsamples - 1)) / (
self.nsamples * _log(self.nsamples)
)

ab_sse_weight = (1 - self.alpha) * (1 / self.nsamples)
sse_weight = self.alpha / self.ntnf

# _, labels_in_indices = labels_in.max(dim=1)
# ce_labels = _nn.CrossEntropyLoss()(labels_out, labels_in_indices)
ce_labels = self.VAEJoint.loss_fn(labels_out, labels_in)
ce_labels_weight = 1.0 # TODO: figure out

sse = (tnf_out - tnf_in).pow(2).sum(dim=1)
sse_weight = self.VAEVamb.alpha / self.VAEVamb.ntnf
kld_weight = 1 / (self.VAEVamb.nlatent * self.VAEVamb.beta)
reconstruction_loss = (
ce * ce_weight + sse * sse_weight + ce_labels * ce_labels_weight
)
weighed_ab = ab_sse * ab_sse_weight
weighed_ce = ce * ce_weight
weighed_sse = sse * sse_weight
weighed_labels = ce_labels * ce_labels_weight

reconstruction_loss = weighed_ce + weighed_ab + weighed_sse + weighed_labels

kld_vamb = kld_gauss(mu_sup, logsigma_sup, mu_vamb_unsup, logsigma_vamb_unsup)
kld_labels = kld_gauss(
mu_sup, logsigma_sup, mu_labels_unsup, logsigma_labels_unsup
)
kld = kld_vamb + kld_labels
kld_weight = 1 / (self.nlatent * self.beta)
kld_loss = kld * kld_weight

loss = (reconstruction_loss + kld_loss) * weights
return loss.mean(), ce.mean(), sse.mean(), ce_labels.mean(), kld_vamb.mean(), kld_labels.mean(), _torch.tensor(0)

# _, labels_in_indices = labels_in.max(dim=1)
# with _torch.no_grad():
# gt_node = self.VAEJoint.eval_label_map.to_node[labels_in_indices.cpu()]
# prob = self.VAEJoint.pred_fn(labels_out)
# prob = prob.cpu().numpy()
# pred = _infer.argmax_with_confidence(
# self.VAEJoint.specificity, prob, 0.5, self.VAEJoint.not_trivial
# )
# pred = _hier.truncate_given_lca(gt_node, pred, self.VAEJoint.find_lca(gt_node, pred))
# return loss.mean(), ce.mean(), sse.mean(), ce_labels.mean(), kld_vamb.mean(), kld_labels.mean(), _torch.tensor(_np.sum(gt_node == pred))


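The kld_gauss helper above is truncated in this view; a standard closed-form KL divergence between two diagonal Gaussians, which such a helper typically computes, can be sketched as follows (an assumption, not a copy of the repository code):

import torch

def kld_gauss_sketch(p_mu, p_logstd, q_mu, q_logstd):
    # KL( N(p_mu, exp(p_logstd)^2) || N(q_mu, exp(q_logstd)^2) ), summed over latent dims.
    return (
        q_logstd - p_logstd
        + (p_logstd.exp().pow(2) + (p_mu - q_mu).pow(2)) / (2 * q_logstd.exp().pow(2))
        - 0.5
    ).sum(dim=1)

In calc_loss_joint above this term is evaluated twice, pulling the supervised posterior toward both the VAMB-only and the labels-only encoders, and is then added to the weighted reconstruction terms (abundance SSE, depth CE, TNF SSE, label CE).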
class VAMB2Label(_nn.Module):
"""Label predictor, subclass of torch.nn.Module.
@@ -847,7 +837,7 @@ def __init__(

# Add all other hidden layers
for nin, nout in zip(
[self.nsamples + self.ntnf] + self.nhiddens, self.nhiddens
[self.nsamples + self.ntnf + 1] + self.nhiddens, self.nhiddens
):
self.encoderlayers.append(_nn.Linear(nin, nout))
self.encodernorms.append(_nn.BatchNorm1d(nout))
@@ -899,35 +889,23 @@ def _predict(self, tensor: Tensor) -> tuple[Tensor, Tensor]:
labels_out = reconstruction.narrow(1, 0, self.nlabels)
return labels_out

def forward(self, depths, tnf, weights):
tensor = _torch.cat((depths, tnf), 1)
def forward(self, depths, tnf, abundances, weights):
tensor = _torch.cat((depths, tnf, abundances), 1)
labels_out = self._predict(tensor)
return labels_out

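A minimal sketch of the input concatenation that motivates the +1 on the encoder's first layer above (illustrative names only; the real module additionally uses batch normalisation, dropout and the hierarchical label head):

import torch
import torch.nn as nn

class ConcatInputSketch(nn.Module):
    def __init__(self, nsamples, ntnf, nhidden, nlabels):
        super().__init__()
        # Per-sample depths + TNF features + one extra column for total abundance.
        self.hidden = nn.Linear(nsamples + ntnf + 1, nhidden)
        self.out = nn.Linear(nhidden, nlabels)

    def forward(self, depths, tnf, abundances):
        # abundances is expected to have shape (batch, 1).
        x = torch.cat((depths, tnf, abundances), dim=1)
        return self.out(torch.relu(self.hidden(x)))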
def calc_loss(self, labels_in, labels_out):
ce_labels = self.loss_fn(labels_out, labels_in)
return ce_labels, _torch.tensor(0)

# _, labels_in_indices = labels_in.max(dim=1)
# with _torch.no_grad():
# gt_node = self.eval_label_map.to_node[labels_in_indices.cpu()]
# prob = self.pred_fn(labels_out)
# prob = prob.cpu().numpy()
# pred = _infer.argmax_with_confidence(
# self.specificity, prob, 0.5, self.not_trivial
# )
# pred = _hier.truncate_given_lca(gt_node, pred, self.find_lca(gt_node, pred))

# return ce_labels, _torch.tensor(_np.sum(gt_node == pred))

def predict(self, data_loader) -> _np.ndarray:
self.eval()

new_data_loader = _encode.set_batchsize(
data_loader, data_loader.batch_size, encode=True
)

depths_array, _, _ = data_loader.dataset.tensors
depths_array, _, _, _ = data_loader.dataset.tensors
length = len(depths_array)

# We make a Numpy array instead of a Torch array because, if we create
@@ -938,15 +916,16 @@ def predict(self, data_loader) -> _np.ndarray:

row = 0
with _torch.no_grad():
for depths, tnf, weights in new_data_loader:
for depths, tnf, abundances, weights in new_data_loader:
# Move input to GPU if requested
if self.usecuda:
depths = depths.cuda()
tnf = tnf.cuda()
abundances = abundances.cuda()
weights = weights.cuda()

# Evaluate
labels = self(depths, tnf, weights)
labels = self(depths, tnf, abundances, weights)
with _torch.no_grad():
prob = self.pred_fn(labels)
if self.usecuda:
@@ -1003,20 +982,22 @@ def trainepoch(self, data_loader, epoch, optimizer, batchsteps, logfile):
collate_fn=data_loader.collate_fn,
)

for depths_in, tnf_in, weights, labels_in in data_loader:
for depths_in, tnf_in, abundances_in, weights, labels_in in data_loader:
depths_in.requires_grad = True
tnf_in.requires_grad = True
abundances_in.requires_grad = True
labels_in.requires_grad = True

if self.usecuda:
depths_in = depths_in.cuda()
tnf_in = tnf_in.cuda()
abundances_in = abundances_in.cuda()
weights = weights.cuda()
labels_in = labels_in.cuda()

optimizer.zero_grad()

labels_out = self(depths_in, tnf_in, weights)
labels_out = self(depths_in, tnf_in, abundances_in, weights)

loss, correct_labels = self.calc_loss(labels_in, labels_out)

(Diffs for the other 5 changed files are not shown.)
