From bfab0c4bf551c37227af8f2ae4d851baec13e57c Mon Sep 17 00:00:00 2001 From: Jinpu Zhou Date: Fri, 24 Oct 2025 12:06:52 -0700 Subject: [PATCH] add normalised L2 loss metrics (#3482) Summary: Pull Request resolved: https://github.com/meta-pytorch/torchrec/pull/3482 Differential Revision: D85412386 --- torchrec/metrics/metric_module.py | 2 + torchrec/metrics/metrics_config.py | 1 + torchrec/metrics/metrics_namespace.py | 5 + torchrec/metrics/nmse.py | 164 +++++++++++++ torchrec/metrics/tests/test_nmse.py | 336 ++++++++++++++++++++++++++ 5 files changed, 508 insertions(+) create mode 100644 torchrec/metrics/nmse.py create mode 100644 torchrec/metrics/tests/test_nmse.py diff --git a/torchrec/metrics/metric_module.py b/torchrec/metrics/metric_module.py index 8e04cca6e..4ab3bb162 100644 --- a/torchrec/metrics/metric_module.py +++ b/torchrec/metrics/metric_module.py @@ -53,6 +53,7 @@ from torchrec.metrics.ne import NEMetric from torchrec.metrics.ne_positive import NEPositiveMetric from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric +from torchrec.metrics.nmse import NMSEMetric from torchrec.metrics.output import OutputMetric from torchrec.metrics.precision import PrecisionMetric from torchrec.metrics.precision_session import PrecisionSessionMetric @@ -105,6 +106,7 @@ RecMetricEnum.CALI_FREE_NE: CaliFreeNEMetric, RecMetricEnum.UNWEIGHTED_NE: UnweightedNEMetric, RecMetricEnum.HINDSIGHT_TARGET_PR: HindsightTargetPRMetric, + RecMetricEnum.NMSE: NMSEMetric, } diff --git a/torchrec/metrics/metrics_config.py b/torchrec/metrics/metrics_config.py index eb83538c6..ac04e8e99 100644 --- a/torchrec/metrics/metrics_config.py +++ b/torchrec/metrics/metrics_config.py @@ -50,6 +50,7 @@ class RecMetricEnum(RecMetricEnumBase): CALI_FREE_NE = "cali_free_ne" UNWEIGHTED_NE = "unweighted_ne" HINDSIGHT_TARGET_PR = "hindsight_target_pr" + NMSE = "nmse" @dataclass(unsafe_hash=True, eq=True) diff --git a/torchrec/metrics/metrics_namespace.py b/torchrec/metrics/metrics_namespace.py index 36a03a25a..9dea96b0d 100644 --- a/torchrec/metrics/metrics_namespace.py +++ b/torchrec/metrics/metrics_namespace.py @@ -90,6 +90,9 @@ class MetricName(MetricNameBase): EFFECTIVE_SAMPLE_RATE = "effective_sample_rate" + NMSE = "nmse" + NRMSE = "nrmse" + class MetricNamespaceBase(StrValueMixin, Enum): pass @@ -148,6 +151,8 @@ class MetricNamespace(MetricNamespaceBase): # This is particularly useful for MTML models train with composite pipelines to figure out per-batch blending ratio. EFFECTIVE_RATE = "effective_rate" + NMSE = "nmse" + class MetricPrefix(StrValueMixin, Enum): DEFAULT = "" diff --git a/torchrec/metrics/nmse.py b/torchrec/metrics/nmse.py new file mode 100644 index 000000000..01463520b --- /dev/null +++ b/torchrec/metrics/nmse.py @@ -0,0 +1,164 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. 
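
# NMSE and NRMSE normalize a model's MSE and RMSE by those of a constant
# all-ones predictor evaluated on the same labels and weights; a value below
# 1.0 means the model beats that trivial baseline.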

# pyre-strict

from typing import Any, cast, Dict, List, Optional, Type

import torch

from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.mse import (
    compute_error_sum,
    compute_mse,
    compute_rmse,
    ERROR_SUM,
    get_mse_states,
    MSEMetricComputation,
    WEIGHTED_NUM_SAMPES,
)
from torchrec.metrics.rec_metric import (
    MetricComputationReport,
    RecMetric,
    RecMetricException,
)

CONST_PRED_ERROR_SUM = "const_pred_error_sum"


def compute_norm(
    model_error_sum: torch.Tensor, baseline_error_sum: torch.Tensor
) -> torch.Tensor:
    # Report 0 when the baseline error is 0 to avoid dividing by zero.
    return torch.where(
        baseline_error_sum == 0,
        torch.tensor(0.0),
        model_error_sum / baseline_error_sum,
    ).double()


def get_norm_mse_states(
    labels: torch.Tensor,
    predictions: torch.Tensor,
    weights: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    return {
        **get_mse_states(labels, predictions, weights),
        # Weighted squared-error sum of a constant all-ones predictor,
        # used as the normalization baseline.
        CONST_PRED_ERROR_SUM: compute_error_sum(
            labels, torch.ones_like(labels), weights
        ),
    }


class NMSEMetricComputation(MSEMetricComputation):
    r"""
    This class extends MSEMetricComputation to compute normalized L2 regression
    metrics: NMSE (the model's MSE divided by the MSE of a constant all-ones
    predictor) and NRMSE (the analogous ratio of RMSEs).

    The constructor arguments are defined in RecMetricComputation.
    See the docstring of RecMetricComputation for more detail.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._add_state(
            CONST_PRED_ERROR_SUM,
            torch.zeros(self._n_tasks, dtype=torch.double),
            add_window_state=True,
            dist_reduce_fx="sum",
            persistent=True,
        )

    def update(
        self,
        *,
        predictions: Optional[torch.Tensor],
        labels: torch.Tensor,
        weights: Optional[torch.Tensor],
        **kwargs: Dict[str, Any],
    ) -> None:
        if predictions is None or weights is None:
            raise RecMetricException(
                "Inputs 'predictions' and 'weights' should not be None for NMSEMetricComputation update"
            )
        states = get_norm_mse_states(labels, predictions, weights)
        num_samples = predictions.shape[-1]
        for state_name, state_value in states.items():
            state = getattr(self, state_name)
            state += state_value
            self._aggregate_window_state(state_name, state_value, num_samples)

    def _compute(self) -> List[MetricComputationReport]:
        mse = compute_mse(
            cast(torch.Tensor, self.error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        const_pred_mse = compute_mse(
            cast(torch.Tensor, self.const_pred_error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        nmse = compute_norm(mse, const_pred_mse)

        rmse = compute_rmse(
            cast(torch.Tensor, self.error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        const_pred_rmse = compute_rmse(
            cast(torch.Tensor, self.const_pred_error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        nrmse = compute_norm(rmse, const_pred_rmse)

        window_mse = compute_mse(
            self.get_window_state(ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_const_pred_mse = compute_mse(
            self.get_window_state(CONST_PRED_ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_nmse = compute_norm(window_mse, window_const_pred_mse)

        window_rmse = compute_rmse(
            self.get_window_state(ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_const_pred_rmse = compute_rmse(
            self.get_window_state(CONST_PRED_ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_nrmse = compute_norm(window_rmse, window_const_pred_rmse)

        return [
            MetricComputationReport(
name=MetricName.NMSE, + metric_prefix=MetricPrefix.LIFETIME, + value=nmse, + ), + MetricComputationReport( + name=MetricName.NRMSE, + metric_prefix=MetricPrefix.LIFETIME, + value=nrmse, + ), + MetricComputationReport( + name=MetricName.NMSE, + metric_prefix=MetricPrefix.WINDOW, + value=window_nmse, + ), + MetricComputationReport( + name=MetricName.NRMSE, + metric_prefix=MetricPrefix.WINDOW, + value=window_nrmse, + ), + ] + + +class NMSEMetric(RecMetric): + _namespace: MetricNamespace = MetricNamespace.NMSE + _computation_class: Type[NMSEMetricComputation] = NMSEMetricComputation diff --git a/torchrec/metrics/tests/test_nmse.py b/torchrec/metrics/tests/test_nmse.py new file mode 100644 index 000000000..803598bf5 --- /dev/null +++ b/torchrec/metrics/tests/test_nmse.py @@ -0,0 +1,336 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-strict + +import unittest +from typing import Dict, Iterable, Optional, Type, Union + +import torch +from torch import no_grad +from torchrec.metrics.metrics_config import DefaultTaskInfo +from torchrec.metrics.mse import compute_mse, compute_rmse +from torchrec.metrics.nmse import compute_norm, get_norm_mse_states, NMSEMetric +from torchrec.metrics.rec_metric import RecComputeMode, RecMetric, RecTaskInfo +from torchrec.metrics.test_utils import ( + metric_test_helper, + rec_metric_gpu_sync_test_launcher, + rec_metric_value_test_launcher, + sync_test_helper, + TestMetric, +) + + +WORLD_SIZE = 4 + + +class TestNMSEMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + return get_norm_mse_states(labels, predictions, weights) + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + mse = compute_mse(states["error_sum"], states["weighted_num_samples"]) + const_pred_mse = compute_mse( + states["const_pred_error_sum"], states["weighted_num_samples"] + ) + return compute_norm(mse, const_pred_mse) + + +class TestNRMSEMetric(TestMetric): + @staticmethod + def _get_states( + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + required_inputs_tensor: Optional[torch.Tensor] = None, + ) -> Dict[str, torch.Tensor]: + return get_norm_mse_states(labels, predictions, weights) + + @staticmethod + def _compute(states: Dict[str, torch.Tensor]) -> torch.Tensor: + rmse = compute_rmse(states["error_sum"], states["weighted_num_samples"]) + const_pred_rmse = compute_rmse( + states["const_pred_error_sum"], states["weighted_num_samples"] + ) + return compute_norm(rmse, const_pred_rmse) + + +class NMSEMetricTest(unittest.TestCase): + clazz: Type[RecMetric] = NMSEMetric + nmse_task_name: str = "nmse" + nrmse_task_name: str = "nrmse" + + def test_nmse_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NMSEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestNMSEMetric, + metric_name=NMSEMetricTest.nmse_task_name, + task_names=["t1", "t2"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_nmse_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NMSEMetric, + 
target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestNMSEMetric, + metric_name=NMSEMetricTest.nmse_task_name, + task_names=["t1", "t2"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_nmse_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NMSEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestNMSEMetric, + metric_name=NMSEMetricTest.nmse_task_name, + task_names=["t1", "t2"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_nrmse_unfused(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NMSEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestNRMSEMetric, + metric_name=NMSEMetricTest.nrmse_task_name, + task_names=["t1", "t2"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_nrmse_fused_tasks(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NMSEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_COMPUTATION, + test_clazz=TestNRMSEMetric, + metric_name=NMSEMetricTest.nrmse_task_name, + task_names=["t1", "t2"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + def test_nrmse_fused_tasks_and_states(self) -> None: + rec_metric_value_test_launcher( + target_clazz=NMSEMetric, + target_compute_mode=RecComputeMode.FUSED_TASKS_AND_STATES_COMPUTATION, + test_clazz=TestNRMSEMetric, + metric_name=NMSEMetricTest.nrmse_task_name, + task_names=["t1", "t2"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=WORLD_SIZE, + entry_point=metric_test_helper, + ) + + +class NMSEGPUSyncTest(unittest.TestCase): + clazz: Type[RecMetric] = NMSEMetric + task_name: str = "nmse" + + def test_sync_nmse(self) -> None: + rec_metric_gpu_sync_test_launcher( + target_clazz=NMSEMetric, + target_compute_mode=RecComputeMode.UNFUSED_TASKS_COMPUTATION, + test_clazz=TestNMSEMetric, + metric_name=NMSEGPUSyncTest.task_name, + task_names=["t1"], + fused_update_limit=0, + compute_on_all_ranks=False, + should_validate_update=False, + world_size=2, + batch_size=5, + batch_window_size=20, + entry_point=sync_test_helper, + ) + + +def generate_model_outputs_cases() -> Iterable[Dict[str, Union[float, torch.Tensor]]]: + return [ + # Perfect predictions - NMSE should be 0 + { + "labels": torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]), + "predictions": torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]), + "weights": torch.tensor([[1.0] * 5]), + "expected_nmse": torch.tensor([0.0]), + }, + # Constant predictor (all 1.0) - NMSE should be 1.0 + { + "labels": torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]), + "predictions": torch.tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]), + "weights": torch.tensor([[1.0] * 5]), + "expected_nmse": torch.tensor([1.0]), + }, + # Better than constant predictor + { + "labels": torch.tensor([[1.0, 2.0, 3.0]]), + "predictions": torch.tensor([[1.5, 2.0, 2.5]]), + "weights": torch.tensor([[1.0, 1.0, 1.0]]), + "expected_nmse": torch.tensor([0.1]), + }, + # With non-uniform weights + { + "labels": torch.tensor([[1.0, 2.0, 3.0, 4.0]]), + "predictions": torch.tensor([[1.0, 2.0, 3.0, 4.0]]), + 
"weights": torch.tensor([[0.5, 1.0, 1.5, 2.0]]), + "expected_nmse": torch.tensor([0.0]), + }, + ] + + +class NMSEMetricValueTest(unittest.TestCase): + r"""This set of tests verify the computation logic of NMSE in several + corner cases that we know the computation results. The goal is to + provide some confidence of the correctness of the math formula. + """ + + def setUp(self) -> None: + self.predictions = {"DefaultTask": None} + self.weights = {"DefaultTask": None} + self.labels = {"DefaultTask": None} + self.batches = { + "predictions": self.predictions, + "weights": self.weights, + "labels": self.labels, + } + self.nmse = NMSEMetric( + world_size=1, + my_rank=0, + batch_size=100, + tasks=[DefaultTaskInfo], + ) + + def test_calc_nmse_perfect(self) -> None: + """Test NMSE when predictions are perfect (NMSE should be 0)""" + self.predictions["DefaultTask"] = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + self.labels["DefaultTask"] = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + self.weights["DefaultTask"] = torch.Tensor([[1.0] * 5]) + + expected_nmse = torch.tensor([0.0], dtype=torch.double) + self.nmse.update(**self.batches) + actual_nmse = self.nmse.compute()["nmse-DefaultTask|window_nmse"] + self.assertTrue(torch.allclose(expected_nmse, actual_nmse, atol=1e-6)) + + def test_calc_nmse_constant_predictor(self) -> None: + """Test NMSE when predictions are all constant (NMSE should be 1.0)""" + self.predictions["DefaultTask"] = torch.Tensor([[1.0, 1.0, 1.0, 1.0, 1.0]]) + self.labels["DefaultTask"] = torch.Tensor([[1.0, 2.0, 3.0, 4.0, 5.0]]) + self.weights["DefaultTask"] = torch.Tensor([[1.0] * 5]) + + expected_nmse = torch.tensor([1.0], dtype=torch.double) + self.nmse.update(**self.batches) + actual_nmse = self.nmse.compute()["nmse-DefaultTask|window_nmse"] + self.assertTrue(torch.allclose(expected_nmse, actual_nmse, atol=1e-6)) + + def test_calc_nmse_better_than_baseline(self) -> None: + """Test NMSE when predictions are better than baseline (NMSE should be < 1.0)""" + self.predictions["DefaultTask"] = torch.Tensor([[1.5, 2.0, 2.5]]) + self.labels["DefaultTask"] = torch.Tensor([[1.0, 2.0, 3.0]]) + self.weights["DefaultTask"] = torch.Tensor([[1.0, 1.0, 1.0]]) + + # Model MSE = ((1.5-1)^2 + (2-2)^2 + (2.5-3)^2) / 3 = (0.25 + 0 + 0.25) / 3 = 0.5/3 + # Baseline MSE = ((1-1)^2 + (1-2)^2 + (1-3)^2) / 3 = (0 + 1 + 4) / 3 = 5/3 + # NMSE = (0.5/3) / (5/3) = 0.5/5 = 0.1 + expected_nmse = torch.tensor([0.1], dtype=torch.double) + self.nmse.update(**self.batches) + actual_nmse = self.nmse.compute()["nmse-DefaultTask|window_nmse"] + self.assertTrue(torch.allclose(expected_nmse, actual_nmse, atol=1e-6)) + + +class NMSEThresholdValueTest(unittest.TestCase): + """This set of tests verify the computation logic of NMSE with various scenarios.""" + + @no_grad() + def _test_nmse_helper( + self, + labels: torch.Tensor, + predictions: torch.Tensor, + weights: torch.Tensor, + expected_nmse: torch.Tensor, + ) -> None: + num_task = labels.shape[0] + batch_size = labels.shape[0] + task_list = [] + predictions_dict: Dict[str, torch.Tensor] = {} + labels_dict: Dict[str, torch.Tensor] = {} + weights_dict: Dict[str, torch.Tensor] = {} + + for i in range(num_task): + task_info = RecTaskInfo( + name=f"Task:{i}", + label_name="label", + prediction_name="prediction", + weight_name="weight", + ) + task_list.append(task_info) + predictions_dict[task_info.name] = predictions[i] + labels_dict[task_info.name] = labels[i] + weights_dict[task_info.name] = weights[i] + + nmse = NMSEMetric( + world_size=1, + my_rank=0, + 
batch_size=batch_size, + tasks=task_list, + ) + nmse.update( + predictions=predictions_dict, + labels=labels_dict, + weights=weights_dict, + ) + actual_nmse = nmse.compute() + + for task_id, task in enumerate(task_list): + cur_actual_nmse = actual_nmse[f"nmse-{task.name}|window_nmse"] + cur_expected_nmse = expected_nmse[task_id].unsqueeze(dim=0) + + torch.testing.assert_close( + cur_actual_nmse, + cur_expected_nmse, + atol=1e-4, + rtol=1e-4, + check_dtype=False, + msg=f"Actual: {cur_actual_nmse}, Expected: {cur_expected_nmse}", + ) + + def test_nmse_values(self) -> None: + test_data = generate_model_outputs_cases() + for inputs in test_data: + try: + # pyre-ignore[6]: All values in generate_model_outputs_cases are torch.Tensor + self._test_nmse_helper(**inputs) + except AssertionError: + print("Assertion error caught with data set ", inputs) + raise
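
---

A minimal usage sketch of the new metric, mirroring NMSEMetricValueTest above
(the constructor arguments and the compute() key are taken from the tests in
this patch; the expected value matches test_calc_nmse_better_than_baseline):

    import torch
    from torchrec.metrics.metrics_config import DefaultTaskInfo
    from torchrec.metrics.nmse import NMSEMetric

    nmse = NMSEMetric(world_size=1, my_rank=0, batch_size=100, tasks=[DefaultTaskInfo])
    nmse.update(
        predictions={"DefaultTask": torch.tensor([[1.5, 2.0, 2.5]])},
        labels={"DefaultTask": torch.tensor([[1.0, 2.0, 3.0]])},
        weights={"DefaultTask": torch.tensor([[1.0, 1.0, 1.0]])},
    )
    # Model MSE = 0.5/3; all-ones baseline MSE = 5/3; NMSE = 0.1.
    print(nmse.compute()["nmse-DefaultTask|window_nmse"])  # ~0.1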