2 changes: 2 additions & 0 deletions torchrec/metrics/metric_module.py
@@ -53,6 +53,7 @@
from torchrec.metrics.ne import NEMetric
from torchrec.metrics.ne_positive import NEPositiveMetric
from torchrec.metrics.ne_with_recalibration import RecalibratedNEMetric
from torchrec.metrics.nmse import NMSEMetric
from torchrec.metrics.output import OutputMetric
from torchrec.metrics.precision import PrecisionMetric
from torchrec.metrics.precision_session import PrecisionSessionMetric
@@ -105,6 +106,7 @@
    RecMetricEnum.CALI_FREE_NE: CaliFreeNEMetric,
    RecMetricEnum.UNWEIGHTED_NE: UnweightedNEMetric,
    RecMetricEnum.HINDSIGHT_TARGET_PR: HindsightTargetPRMetric,
    RecMetricEnum.NMSE: NMSEMetric,
}


1 change: 1 addition & 0 deletions torchrec/metrics/metrics_config.py
@@ -50,6 +50,7 @@ class RecMetricEnum(RecMetricEnumBase):
CALI_FREE_NE = "cali_free_ne"
UNWEIGHTED_NE = "unweighted_ne"
HINDSIGHT_TARGET_PR = "hindsight_target_pr"
NMSE = "nmse"


@dataclass(unsafe_hash=True, eq=True)
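With the new enum value in place and metric_module.py mapping it to NMSEMetric, the metric can be requested like any other RecMetric through the metrics config. A minimal, hypothetical sketch (the task and field values below are illustrative and not taken from this PR; surrounding config fields may differ):

# Hypothetical sketch: wiring RecMetricEnum.NMSE into the metrics config.
from torchrec.metrics.metrics_config import RecMetricDef, RecMetricEnum, RecTaskInfo

task = RecTaskInfo(
    name="my_task",
    label_name="label",
    prediction_name="prediction",
    weight_name="weight",
)

# Passed as part of the rec_metrics mapping in the metrics config; the exact
# surrounding config fields are assumptions for this example.
rec_metrics = {
    RecMetricEnum.NMSE: RecMetricDef(rec_tasks=[task], window_size=1000),
}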
5 changes: 5 additions & 0 deletions torchrec/metrics/metrics_namespace.py
@@ -90,6 +90,9 @@ class MetricName(MetricNameBase):

    EFFECTIVE_SAMPLE_RATE = "effective_sample_rate"

    NMSE = "nmse"
    NRMSE = "nrmse"


class MetricNamespaceBase(StrValueMixin, Enum):
    pass
@@ -148,6 +151,8 @@ class MetricNamespace(MetricNamespaceBase):
    # This is particularly useful for MTML models trained with composite pipelines to figure out the per-batch blending ratio.
    EFFECTIVE_RATE = "effective_rate"

    NMSE = "nmse"


class MetricPrefix(StrValueMixin, Enum):
    DEFAULT = ""
Expand Down
164 changes: 164 additions & 0 deletions torchrec/metrics/nmse.py
@@ -0,0 +1,164 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from typing import Any, cast, Dict, List, Optional, Type

import torch

from torchrec.metrics.metrics_namespace import MetricName, MetricNamespace, MetricPrefix
from torchrec.metrics.mse import (
    compute_error_sum,
    compute_mse,
    compute_rmse,
    ERROR_SUM,
    get_mse_states,
    MSEMetricComputation,
    WEIGHTED_NUM_SAMPES,
)
from torchrec.metrics.rec_metric import (
    MetricComputationReport,
    RecMetric,
    RecMetricException,
)

CONST_PRED_ERROR_SUM = "const_pred_error_sum"


def compute_norm(
    model_error_sum: torch.Tensor, baseline_error_sum: torch.Tensor
) -> torch.Tensor:
    # Guard against division by zero: report 0.0 when the baseline error sum is zero.
    return torch.where(
        baseline_error_sum == 0,
        torch.tensor(0.0),
        model_error_sum / baseline_error_sum,
    ).double()


def get_norm_mse_states(
    labels: torch.Tensor,
    predictions: torch.Tensor,
    weights: torch.Tensor,
) -> Dict[str, torch.Tensor]:
    # MSE states plus the error sum of a constant all-ones baseline prediction,
    # which serves as the normalizer in compute_norm.
    return {
        **get_mse_states(labels, predictions, weights),
        CONST_PRED_ERROR_SUM: compute_error_sum(
            labels, torch.ones_like(labels), weights
        ),
    }


class NMSEMetricComputation(MSEMetricComputation):
    r"""
    This class extends MSEMetricComputation to compute normalized L2 regression
    metrics, NMSE and NRMSE: the model's (R)MSE divided by the (R)MSE of a
    constant all-ones baseline prediction.

    The constructor arguments are defined in RecMetricComputation.
    See the docstring of RecMetricComputation for more detail.
    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        self._add_state(
            CONST_PRED_ERROR_SUM,
            torch.zeros(self._n_tasks, dtype=torch.double),
            add_window_state=True,
            dist_reduce_fx="sum",
            persistent=True,
        )

    def update(
        self,
        *,
        predictions: Optional[torch.Tensor],
        labels: torch.Tensor,
        weights: Optional[torch.Tensor],
        **kwargs: Dict[str, Any],
    ) -> None:
        if predictions is None or weights is None:
            raise RecMetricException(
                "Inputs 'predictions' and 'weights' should not be None for NMSEMetricComputation update"
            )
        states = get_norm_mse_states(labels, predictions, weights)
        num_samples = predictions.shape[-1]
        for state_name, state_value in states.items():
            state = getattr(self, state_name)
            state += state_value
            self._aggregate_window_state(state_name, state_value, num_samples)

    def _compute(self) -> List[MetricComputationReport]:
        mse = compute_mse(
            cast(torch.Tensor, self.error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        const_pred_mse = compute_mse(
            cast(torch.Tensor, self.const_pred_error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        nmse = compute_norm(mse, const_pred_mse)

        rmse = compute_rmse(
            cast(torch.Tensor, self.error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        const_pred_rmse = compute_rmse(
            cast(torch.Tensor, self.const_pred_error_sum),
            cast(torch.Tensor, self.weighted_num_samples),
        )
        nrmse = compute_norm(rmse, const_pred_rmse)

        window_mse = compute_mse(
            self.get_window_state(ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_const_pred_mse = compute_mse(
            self.get_window_state(CONST_PRED_ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_nmse = compute_norm(window_mse, window_const_pred_mse)

        window_rmse = compute_rmse(
            self.get_window_state(ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_const_pred_rmse = compute_rmse(
            self.get_window_state(CONST_PRED_ERROR_SUM),
            self.get_window_state(WEIGHTED_NUM_SAMPES),
        )
        window_nrmse = compute_norm(window_rmse, window_const_pred_rmse)

        return [
            MetricComputationReport(
                name=MetricName.NMSE,
                metric_prefix=MetricPrefix.LIFETIME,
                value=nmse,
            ),
            MetricComputationReport(
                name=MetricName.NRMSE,
                metric_prefix=MetricPrefix.LIFETIME,
                value=nrmse,
            ),
            MetricComputationReport(
                name=MetricName.NMSE,
                metric_prefix=MetricPrefix.WINDOW,
                value=window_nmse,
            ),
            MetricComputationReport(
                name=MetricName.NRMSE,
                metric_prefix=MetricPrefix.WINDOW,
                value=window_nrmse,
            ),
        ]


class NMSEMetric(RecMetric):
    _namespace: MetricNamespace = MetricNamespace.NMSE
    _computation_class: Type[NMSEMetricComputation] = NMSEMetricComputation
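
For reference, the normalization performed by get_norm_mse_states and compute_norm above can be reproduced with plain torch. This is a standalone illustration, not part of the PR; it assumes compute_error_sum in torchrec.metrics.mse is the weighted sum of squared errors, and the sample values are made up:

# Standalone sketch (illustrative only) of the NMSE/NRMSE math added above.
import torch

labels = torch.tensor([[0.2, 0.8, 1.5, 0.9]])
predictions = torch.tensor([[0.3, 0.7, 1.2, 1.0]])
weights = torch.ones_like(labels)

def weighted_error_sum(labels, predictions, weights):
    # Weighted sum of squared errors (the role compute_error_sum plays above).
    return torch.sum(weights * torch.square(labels - predictions), dim=-1)

weighted_num_samples = weights.sum(dim=-1)

# Model MSE and the MSE of a constant all-ones baseline prediction,
# matching get_norm_mse_states / compute_norm above.
mse = weighted_error_sum(labels, predictions, weights) / weighted_num_samples
const_pred_mse = weighted_error_sum(labels, torch.ones_like(labels), weights) / weighted_num_samples

nmse = mse / const_pred_mse                   # lifetime NMSE
nrmse = mse.sqrt() / const_pred_mse.sqrt()    # lifetime NRMSE
print(nmse, nrmse)

The windowed NMSE and NRMSE reported by _compute apply the same ratio to the windowed error-sum and sample-count states.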