From 04389b8a34de2a2f3596bff843789cec353a46ea Mon Sep 17 00:00:00 2001 From: hsin-c <109615347+hsin-c@users.noreply.github.com> Date: Mon, 22 Jan 2024 08:22:34 -0800 Subject: [PATCH] Fix Loss Function to Improve Model Convergence for `AutoEncoder` (#1460) This PR addresses an issue in the dfencoder model related to its convergence behavior. Previously, the model exhibited difficulty in converging when trained exclusively with numerical features. This PR fixes the way different loss types are combined in the model's loss function to ensure that backpropagation works correctly. Note: This may alter the exact values resulting from calling `fit()` on the model. Before, categorical features were weighted much higher than binary or numerical categories (all numerical features shared a combined weight of 1, all binaries features shared a combined weight of 1, and each categorical feature had a weight of 1). Now all features are weighted equally which may impact the trained weights. Closes #1455 Authors: - https://github.com/hsin-c - Michael Demoret (https://github.com/mdemoret-nv) Approvers: - Michael Demoret (https://github.com/mdemoret-nv) URL: https://github.com/nv-morpheus/Morpheus/pull/1460 --- morpheus/models/dfencoder/autoencoder.py | 81 +++++++++++-------- tests/dfencoder/test_autoencoder.py | 25 +++++- .../test_dfencoder_distributed_e2e.py | 61 +++++++------- 3 files changed, 101 insertions(+), 66 deletions(-) diff --git a/morpheus/models/dfencoder/autoencoder.py b/morpheus/models/dfencoder/autoencoder.py index 820362cf8d..df429bbdf5 100644 --- a/morpheus/models/dfencoder/autoencoder.py +++ b/morpheus/models/dfencoder/autoencoder.py @@ -1,4 +1,4 @@ -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -233,6 +233,9 @@ def get_scaler(self, name): } return scalers[name] + def get_feature_count(self): + return len(self.numeric_fts) + len(self.binary_fts) + len(self.categorical_fts) + def _init_numeric(self, df=None): """Initializes the numerical features of the model by either using preset numerical scaler parameters or by using the input data. @@ -626,8 +629,10 @@ def preprocess_data( return preprocessed_data def compute_loss(self, num, bin, cat, target_df, should_log=True, _id=False): + num_target, bin_target, codes = self.compute_targets(target_df) - return self.compute_loss_from_targets( + + mse, bce, cce, net = self.compute_loss_from_targets( num=num, bin=bin, cat=cat, @@ -638,6 +643,10 @@ def compute_loss(self, num, bin, cat, target_df, should_log=True, _id=False): _id=_id, ) + net = net.cpu().item() + + return mse, bce, cce, net + def compute_loss_from_targets(self, num, bin, cat, num_target, bin_target, cat_target, should_log=True, _id=False): """Computes the loss from targets. @@ -670,38 +679,45 @@ def compute_loss_from_targets(self, num, bin, cat, num_target, bin_target, cat_t should_log = True else: should_log = False - net_loss = [] - mse_loss = self.mse(num, num_target) - net_loss += list(mse_loss.mean(dim=0).cpu().detach().numpy()) - mse_loss = mse_loss.mean() - bce_loss = self.bce(bin, bin_target) - net_loss += list(bce_loss.mean(dim=0).cpu().detach().numpy()) - bce_loss = bce_loss.mean() - cce_loss = [] - for i, ft in enumerate(self.categorical_fts): - loss = self.cce(cat[i], cat_target[i]) - loss = loss.mean() - cce_loss.append(loss) - val = loss.cpu().item() - net_loss += [val] + # Calculate the numerical loss (per feature) + mse_loss: torch.Tensor = self.mse(num, num_target).mean(dim=0) + + # Calculate the binary loss (per feature) + bce_loss: torch.Tensor = self.bce(bin, bin_target).mean(dim=0) + + # To calc the categorical loss, we need to average the loss of each categorical feature independently (since + # they will have a different number of categories) + cce_loss_list = [] + + for i in range(len(self.categorical_fts)): + # Take the full mean but ensure the output is a 1x1 tensor to make it easier to concatenate + cce_loss_list.append(self.cce(cat[i], cat_target[i]).mean(dim=0, keepdim=True)) + + if (len(cce_loss_list) > 0): + cce_loss = torch.cat(cce_loss_list) + else: + cce_loss = torch.Tensor().to(self.device) + + # The net loss should have one loss per feature + net_loss = 0 + for loss in [mse_loss, bce_loss, cce_loss]: + if len(loss) > 0: + net_loss += loss.sum() + net_loss /= self.get_feature_count() + if should_log: + # Convert it to a list of numpy + net_loss_list = torch.cat((mse_loss, bce_loss, cce_loss)).tolist() + if self.training: - self.logger.training_step(net_loss) + self.logger.training_step(net_loss_list) elif _id: - self.logger.id_val_step(net_loss) + self.logger.id_val_step(net_loss_list) elif not self.training: - self.logger.val_step(net_loss) - - net_loss = np.array(net_loss).mean() - return mse_loss, bce_loss, cce_loss, net_loss + self.logger.val_step(net_loss_list) - def do_backward(self, mse, bce, cce): - # running `backward()` seperately on mse/bce/cce is equivalent to summing them up and run `backward()` once - loss_fn = mse + bce - for ls in cce: - loss_fn += ls - loss_fn.backward() + return mse_loss.mean(), bce_loss.mean(), cce_loss.mean(), net_loss def compute_baseline_performance(self, in_, out_): """ @@ -729,6 +745,7 @@ def compute_baseline_performance(self, in_, out_): codes_pred.append(pred) mse_loss, bce_loss, cce_loss, net_loss = self.compute_loss(num_pred, bin_pred, codes_pred, out_, should_log=False) + if isinstance(self.logger, BasicLogger): self.logger.baseline_loss = net_loss return net_loss @@ -981,11 +998,11 @@ def _fit_batch(self, input_swapped, num_target, bin_target, cat_target, **kwargs cat_target=cat_target, should_log=True, ) - self.do_backward(mse, bce, cce) + net_loss.backward() self.optim.step() self.optim.zero_grad() - return net_loss + return net_loss.cpu().item() def _compute_baseline_performance_from_dataset(self, validation_dataset): self.eval() @@ -1028,7 +1045,7 @@ def _compute_batch_baseline_performance( cat_target=cat_target, should_log=False ) - return net_loss + return net_loss.cpu().item() def _validate_dataset(self, validation_dataset, rank=None): """Runs a validation loop on the given validation dataset, computing and returning the average loss of both the original @@ -1108,7 +1125,7 @@ def _validate_batch(self, input_original, input_swapped, num_target, bin_target, cat_target=cat_target, should_log=True, ) - return orig_net_loss, net_loss + return orig_net_loss.cpu().item(), net_loss.cpu().item() def _populate_loss_stats_from_dataset(self, dataset): """Populates the `self.feature_loss_stats` dict with feature losses computed using the provided dataset. diff --git a/tests/dfencoder/test_autoencoder.py b/tests/dfencoder/test_autoencoder.py index 1cf16eff49..70a85ec781 100755 --- a/tests/dfencoder/test_autoencoder.py +++ b/tests/dfencoder/test_autoencoder.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -18,6 +18,7 @@ import typing from unittest.mock import patch +import numpy as np import pandas as pd import pytest import torch @@ -374,7 +375,7 @@ def test_auto_encoder_get_anomaly_score(train_ae: autoencoder.AutoEncoder, train train_ae.fit(train_df, epochs=1) anomaly_score = train_ae.get_anomaly_score(train_df) assert len(anomaly_score) == len(train_df) - assert round(anomaly_score.mean().item(), 2) == 2.28 + assert round(anomaly_score.mean().item(), 2) == 2.29 assert round(anomaly_score.std().item(), 2) == 0.11 @@ -478,8 +479,24 @@ def test_auto_encoder_get_results(train_ae: autoencoder.AutoEncoder, train_df: p assert 'max_abs_z' in results.columns assert 'mean_abs_z' in results.columns - assert round(results.loc[0, 'max_abs_z'], 2) == 2.5 + assert np.isclose(results.loc[0, 'max_abs_z'], 2.51, atol=1e-2) # Numpy float has different precision checks than python float, so we wrap it. - assert round(float(results.loc[0, 'mean_abs_z']), 3) == 0.335 + assert np.isclose(results.loc[0, 'mean_abs_z'], 0.361, atol=1e-3) assert results.loc[0, 'z_loss_scaler_type'] == 'z' + + +@pytest.mark.usefixtures("manual_seed") +def test_auto_encoder_num_only_convergence(train_ae: autoencoder.AutoEncoder): + num_df = pd.DataFrame({ + 'num_feat_1': [5.1, 4.9, 4.7, 4.6, 5.0, 5.4, 4.6, 5.0, 4.4, 4.9], + 'num_feat_2': [3.5, 3.0, 3.2, 3.1, 3.6, 3.9, 3.4, 3.4, 2.9, 3.1], + }) + + train_ae.fit(num_df, epochs=50) + + avg_loss = np.sum([np.array(loss[1]) + for loss in train_ae.logger.train_fts.values()], axis=0) / len(train_ae.logger.train_fts) + + # Make sure the model converges with numerical feats only + assert avg_loss[-1] < avg_loss[0] / 2 diff --git a/tests/dfencoder/test_dfencoder_distributed_e2e.py b/tests/dfencoder/test_dfencoder_distributed_e2e.py index bd9d855173..6ec7913ae5 100644 --- a/tests/dfencoder/test_dfencoder_distributed_e2e.py +++ b/tests/dfencoder/test_dfencoder_distributed_e2e.py @@ -1,5 +1,5 @@ #!/usr/bin/env python -# SPDX-FileCopyrightText: Copyright (c) 2022-2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-FileCopyrightText: Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: Apache-2.0 # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -42,43 +42,44 @@ "log_count", "location_incr", "app_incr", + "has_error", ] LOSS_TYPES = ["train", "val", "id_val"] # 75th quantile of the losses from 100 times of offline training LOSS_TARGETS = { "train": { - "log_count": 0.33991, - "location_incr": 0.30789, - "app_incr": 0.17698, - "has_error": 0.00878, - "app_name": 0.13066, - "browser_type": 0.39804, - "os": 0.09882, - "country": 0.06063, - "city": 0.32344, + "log_count": 0.31612, + "location_incr": 0.27285, + "app_incr": 0.13989, + "has_error": 0.00536, + "app_name": 0.13652, + "browser_type": 0.39303, + "os": 0.00115, + "country": 0.00102, + "city": 0.30947 }, "val": { - "log_count": 0.3384, - "location_incr": 0.31456, - "app_incr": 0.16201, - "has_error": 0.00614, - "app_name": 0.11907, - "browser_type": 0.38239, - "os": 0.00064, - "country": 0.0042, - "city": 0.32161, + "log_count": 0.27835, + "location_incr": 0.28686, + "app_incr": 0.13064, + "has_error": 0.00364, + "app_name": 0.13276, + "browser_type": 0.36868, + "os": 2e-05, + "country": 0.00168, + "city": 0.31735 }, "id_val": { - "log_count": 0.07079, - "location_incr": 0.05318, - "app_incr": 0.03659, - "has_error": 0.0046, - "app_name": 0.03542, - "browser_type": 0.0915, - "os": 0.00057, - "country": 0.00343, - "city": 0.08525, - }, + "log_count": 0.04845, + "location_incr": 0.02274, + "app_incr": 0.01639, + "has_error": 0.00255, + "app_name": 0.04597, + "browser_type": 0.08826, + "os": 2e-05, + "country": 0.00146, + "city": 0.07591 + } } LOSS_TOLERANCE_RATIO = 1.25 @@ -146,7 +147,7 @@ def _run_test(rank, world_size): min_cats=1, device=rank, preset_numerical_scaler_params=preset_numerical_scaler_params, - binary_feature_list=[], + binary_feature_list=['has_error'], preset_cats=preset_cats, eval_batch_size=1024, patience=5,