Fix Loss Function to Improve Model Convergence for AutoEncoder (#1460)
This PR addresses a convergence issue in the dfencoder model: previously, the model had difficulty converging when trained exclusively on numerical features.
It fixes the way the different loss types are combined in the model's loss function so that backpropagation works correctly.
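
Concretely, the model now combines all per-feature losses into a single scalar tensor and backpropagates it once (the separate `do_backward` helper is removed in favor of calling `backward()` on the net loss directly). A minimal sketch of that pattern (the tensors, shapes, and values below are made up for illustration, not taken from dfencoder):

    import torch

    # Illustrative per-feature losses; in the model these come from
    # self.mse(...), self.bce(...) and self.cce(...) respectively.
    mse_loss = torch.rand(3, requires_grad=True)  # 3 numerical features
    bce_loss = torch.rand(1, requires_grad=True)  # 1 binary feature
    cce_loss = torch.rand(2, requires_grad=True)  # 2 categorical features

    # Combine everything into one scalar that stays on the autograd graph,
    # weighting each feature equally, then backpropagate once.
    feature_count = mse_loss.numel() + bce_loss.numel() + cce_loss.numel()
    net_loss = (mse_loss.sum() + bce_loss.sum() + cce_loss.sum()) / feature_count
    net_loss.backward()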

Note: This may alter the exact values resulting from calling `fit()` on the model. Previously, categorical features were weighted much more heavily than binary or numerical features (all numerical features shared a combined weight of 1, all binary features shared a combined weight of 1, and each categorical feature had its own weight of 1). Now all features are weighted equally, which may impact the trained weights.
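
As a worked toy example of the reweighting (the loss values below are made up; three numerical, two binary, and two categorical features are assumed):

    # Hypothetical per-feature losses, already averaged per feature.
    mse = [0.2, 0.4, 0.6]  # numerical features
    bce = [0.1, 0.3]       # binary features
    cce = [0.5, 0.7]       # categorical features

    # Old combination: numerical features shared one weight, binary features
    # shared one weight, and each categorical feature carried its own weight.
    old_loss = sum(mse) / len(mse) + sum(bce) / len(bce) + sum(cce)  # = 1.8

    # New combination: every feature contributes equally.
    new_loss = (sum(mse) + sum(bce) + sum(cce)) / 7  # = 0.4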

Closes #1455

Authors:
  - https://github.com/hsin-c
  - Michael Demoret (https://github.com/mdemoret-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1460
hsin-c authored Jan 22, 2024
1 parent 05d6747 commit 04389b8
Showing 3 changed files with 101 additions and 66 deletions.
81 changes: 49 additions & 32 deletions morpheus/models/dfencoder/autoencoder.py
@@ -1,4 +1,4 @@
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -233,6 +233,9 @@ def get_scaler(self, name):
}
return scalers[name]

def get_feature_count(self):
return len(self.numeric_fts) + len(self.binary_fts) + len(self.categorical_fts)

def _init_numeric(self, df=None):
"""Initializes the numerical features of the model by either using preset numerical scaler parameters
or by using the input data.
@@ -626,8 +629,10 @@ def preprocess_data(
return preprocessed_data

def compute_loss(self, num, bin, cat, target_df, should_log=True, _id=False):

num_target, bin_target, codes = self.compute_targets(target_df)
return self.compute_loss_from_targets(

mse, bce, cce, net = self.compute_loss_from_targets(
num=num,
bin=bin,
cat=cat,
@@ -638,6 +643,10 @@ def compute_loss(self, num, bin, cat, target_df, should_log=True, _id=False):
_id=_id,
)

net = net.cpu().item()

return mse, bce, cce, net

def compute_loss_from_targets(self, num, bin, cat, num_target, bin_target, cat_target, should_log=True, _id=False):
"""Computes the loss from targets.
@@ -670,38 +679,45 @@ def compute_loss_from_targets(self, num, bin, cat, num_target, bin_target, cat_t
should_log = True
else:
should_log = False
net_loss = []
mse_loss = self.mse(num, num_target)
net_loss += list(mse_loss.mean(dim=0).cpu().detach().numpy())
mse_loss = mse_loss.mean()
bce_loss = self.bce(bin, bin_target)

net_loss += list(bce_loss.mean(dim=0).cpu().detach().numpy())
bce_loss = bce_loss.mean()
cce_loss = []
for i, ft in enumerate(self.categorical_fts):
loss = self.cce(cat[i], cat_target[i])
loss = loss.mean()
cce_loss.append(loss)
val = loss.cpu().item()
net_loss += [val]
# Calculate the numerical loss (per feature)
mse_loss: torch.Tensor = self.mse(num, num_target).mean(dim=0)

# Calculate the binary loss (per feature)
bce_loss: torch.Tensor = self.bce(bin, bin_target).mean(dim=0)

# To calc the categorical loss, we need to average the loss of each categorical feature independently (since
# they will have a different number of categories)
cce_loss_list = []

for i in range(len(self.categorical_fts)):
# Take the full mean but ensure the output is a 1x1 tensor to make it easier to concatenate
cce_loss_list.append(self.cce(cat[i], cat_target[i]).mean(dim=0, keepdim=True))

if (len(cce_loss_list) > 0):
cce_loss = torch.cat(cce_loss_list)
else:
cce_loss = torch.Tensor().to(self.device)

# The net loss should have one loss per feature
net_loss = 0
for loss in [mse_loss, bce_loss, cce_loss]:
if len(loss) > 0:
net_loss += loss.sum()
net_loss /= self.get_feature_count()

if should_log:
# Convert it to a list of numpy
net_loss_list = torch.cat((mse_loss, bce_loss, cce_loss)).tolist()

if self.training:
self.logger.training_step(net_loss)
self.logger.training_step(net_loss_list)
elif _id:
self.logger.id_val_step(net_loss)
self.logger.id_val_step(net_loss_list)
elif not self.training:
self.logger.val_step(net_loss)

net_loss = np.array(net_loss).mean()
return mse_loss, bce_loss, cce_loss, net_loss
self.logger.val_step(net_loss_list)

def do_backward(self, mse, bce, cce):
# running `backward()` seperately on mse/bce/cce is equivalent to summing them up and run `backward()` once
loss_fn = mse + bce
for ls in cce:
loss_fn += ls
loss_fn.backward()
return mse_loss.mean(), bce_loss.mean(), cce_loss.mean(), net_loss

def compute_baseline_performance(self, in_, out_):
"""
@@ -729,6 +745,7 @@ def compute_baseline_performance(self, in_, out_):
codes_pred.append(pred)
mse_loss, bce_loss, cce_loss, net_loss = self.compute_loss(num_pred, bin_pred, codes_pred, out_,
should_log=False)

if isinstance(self.logger, BasicLogger):
self.logger.baseline_loss = net_loss
return net_loss
@@ -981,11 +998,11 @@ def _fit_batch(self, input_swapped, num_target, bin_target, cat_target, **kwargs
cat_target=cat_target,
should_log=True,
)
self.do_backward(mse, bce, cce)
net_loss.backward()
self.optim.step()
self.optim.zero_grad()

return net_loss
return net_loss.cpu().item()

def _compute_baseline_performance_from_dataset(self, validation_dataset):
self.eval()
@@ -1028,7 +1045,7 @@ def _compute_batch_baseline_performance(
cat_target=cat_target,
should_log=False
)
return net_loss
return net_loss.cpu().item()

def _validate_dataset(self, validation_dataset, rank=None):
"""Runs a validation loop on the given validation dataset, computing and returning the average loss of both the original
@@ -1108,7 +1125,7 @@ def _validate_batch(self, input_original, input_swapped, num_target, bin_target,
cat_target=cat_target,
should_log=True,
)
return orig_net_loss, net_loss
return orig_net_loss.cpu().item(), net_loss.cpu().item()

def _populate_loss_stats_from_dataset(self, dataset):
"""Populates the `self.feature_loss_stats` dict with feature losses computed using the provided dataset.
25 changes: 21 additions & 4 deletions tests/dfencoder/test_autoencoder.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,6 +18,7 @@
import typing
from unittest.mock import patch

import numpy as np
import pandas as pd
import pytest
import torch
@@ -374,7 +375,7 @@ def test_auto_encoder_get_anomaly_score(train_ae: autoencoder.AutoEncoder, train
train_ae.fit(train_df, epochs=1)
anomaly_score = train_ae.get_anomaly_score(train_df)
assert len(anomaly_score) == len(train_df)
assert round(anomaly_score.mean().item(), 2) == 2.28
assert round(anomaly_score.mean().item(), 2) == 2.29
assert round(anomaly_score.std().item(), 2) == 0.11


@@ -478,8 +479,24 @@ def test_auto_encoder_get_results(train_ae: autoencoder.AutoEncoder, train_df: p
assert 'max_abs_z' in results.columns
assert 'mean_abs_z' in results.columns

assert round(results.loc[0, 'max_abs_z'], 2) == 2.5
assert np.isclose(results.loc[0, 'max_abs_z'], 2.51, atol=1e-2)

# Numpy float has different precision checks than python float, so we wrap it.
assert round(float(results.loc[0, 'mean_abs_z']), 3) == 0.335
assert np.isclose(results.loc[0, 'mean_abs_z'], 0.361, atol=1e-3)
assert results.loc[0, 'z_loss_scaler_type'] == 'z'


@pytest.mark.usefixtures("manual_seed")
def test_auto_encoder_num_only_convergence(train_ae: autoencoder.AutoEncoder):
num_df = pd.DataFrame({
'num_feat_1': [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9],
'num_feat_2': [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1],
})

train_ae.fit(num_df, epochs=50)

avg_loss = np.sum([np.array(loss[1])
for loss in train_ae.logger.train_fts.values()], axis=0) / len(train_ae.logger.train_fts)

# Make sure the model converges with numerical feats only
assert avg_loss[-1] < avg_loss[0] / 2
61 changes: 31 additions & 30 deletions tests/dfencoder/test_dfencoder_distributed_e2e.py
@@ -1,5 +1,5 @@
#!/usr/bin/env python
# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -42,43 +42,44 @@
"log_count",
"location_incr",
"app_incr",
"has_error",
]
LOSS_TYPES = ["train", "val", "id_val"]
# 75th quantile of the losses from 100 times of offline training
LOSS_TARGETS = {
"train": {
"log_count": 0.33991,
"location_incr": 0.30789,
"app_incr": 0.17698,
"has_error": 0.00878,
"app_name": 0.13066,
"browser_type": 0.39804,
"os": 0.09882,
"country": 0.06063,
"city": 0.32344,
"log_count": 0.31612,
"location_incr": 0.27285,
"app_incr": 0.13989,
"has_error": 0.00536,
"app_name": 0.13652,
"browser_type": 0.39303,
"os": 0.00115,
"country": 0.00102,
"city": 0.30947
},
"val": {
"log_count": 0.3384,
"location_incr": 0.31456,
"app_incr": 0.16201,
"has_error": 0.00614,
"app_name": 0.11907,
"browser_type": 0.38239,
"os": 0.00064,
"country": 0.0042,
"city": 0.32161,
"log_count": 0.27835,
"location_incr": 0.28686,
"app_incr": 0.13064,
"has_error": 0.00364,
"app_name": 0.13276,
"browser_type": 0.36868,
"os": 2e-05,
"country": 0.00168,
"city": 0.31735
},
"id_val": {
"log_count": 0.07079,
"location_incr": 0.05318,
"app_incr": 0.03659,
"has_error": 0.0046,
"app_name": 0.03542,
"browser_type": 0.0915,
"os": 0.00057,
"country": 0.00343,
"city": 0.08525,
},
"log_count": 0.04845,
"location_incr": 0.02274,
"app_incr": 0.01639,
"has_error": 0.00255,
"app_name": 0.04597,
"browser_type": 0.08826,
"os": 2e-05,
"country": 0.00146,
"city": 0.07591
}
}
LOSS_TOLERANCE_RATIO = 1.25

@@ -146,7 +147,7 @@ def _run_test(rank, world_size):
min_cats=1,
device=rank,
preset_numerical_scaler_params=preset_numerical_scaler_params,
binary_feature_list=[],
binary_feature_list=['has_error'],
preset_cats=preset_cats,
eval_batch_size=1024,
patience=5,
