From 1bfe7b721eb225b49c3a6c67e7c62617bae9466d Mon Sep 17 00:00:00 2001 From: drisspg Date: Fri, 20 Dec 2024 13:46:41 -0800 Subject: [PATCH] hacked up stack-info: PR: https://github.com/pytorch-labs/torchfix/pull/88, branch: drisspg/stack/1 --- .../visitors/deprecated_symbols/__init__.py | 6 +- .../deprecated_symbols/size_average.py | 60 +++++++++++++++++++ 2 files changed, 64 insertions(+), 2 deletions(-) create mode 100644 torchfix/visitors/deprecated_symbols/size_average.py diff --git a/torchfix/visitors/deprecated_symbols/__init__.py b/torchfix/visitors/deprecated_symbols/__init__.py index 40885ee..65e221d 100644 --- a/torchfix/visitors/deprecated_symbols/__init__.py +++ b/torchfix/visitors/deprecated_symbols/__init__.py @@ -15,6 +15,7 @@ from .chain_matmul import call_replacement_chain_matmul from .cholesky import call_replacement_cholesky from .qr import call_replacement_qr +from .size_average import call_replacement_loss from .range import call_replacement_range @@ -54,6 +55,7 @@ def _call_replacement( "torch.qr": call_replacement_qr, "torch.cuda.amp.autocast": call_replacement_cuda_amp_autocast, "torch.cpu.amp.autocast": call_replacement_cpu_amp_autocast, + "torch.nn.functional.soft_margin_loss": call_replacement_loss } replacement = None @@ -103,7 +105,8 @@ def visit_Call(self, node) -> None: qualified_name = self.get_qualified_name_for_call(node) if qualified_name is None: return - + self.deprecated_config["torch.nn.functional.soft_margin_loss"] = {} + self.deprecated_config["torch.nn.functional.soft_margin_loss"]["remove_pr"] = None if qualified_name in self.deprecated_config: if self.deprecated_config[qualified_name]["remove_pr"] is None: error_code = self.ERRORS[1].error_code @@ -112,7 +115,6 @@ def visit_Call(self, node) -> None: error_code = self.ERRORS[0].error_code message = self.ERRORS[0].message(old_name=qualified_name) replacement = self._call_replacement(node, qualified_name) - reference = self.deprecated_config[qualified_name].get("reference") if reference is not None: message = f"{message}: {reference}" diff --git a/torchfix/visitors/deprecated_symbols/size_average.py b/torchfix/visitors/deprecated_symbols/size_average.py new file mode 100644 index 0000000..742b548 --- /dev/null +++ b/torchfix/visitors/deprecated_symbols/size_average.py @@ -0,0 +1,60 @@ +"""size_average and reduce are deprecated, please use reduction='mean' instead.""" + +import libcst as cst +from ...common import TorchVisitor, get_module_name +from torch.nn._reduction import legacy_get_string + +def call_replacement_loss(node: cst.Call) -> cst.CSTNode: + """ + Replace loss function that contains size_average / reduce with a new loss function + that uses reduction='mean' instead. Uses the logic from torch.nn._reduction to + determine the correct reduction value. + + Args: + node: The CST Call node representing the loss function call + + Returns: + A new CST node with updated reduction parameter + """ + # Extract existing arguments + input_arg = TorchVisitor.get_specific_arg(node, "input", 0) + target_arg = TorchVisitor.get_specific_arg(node, "target", 1) + + size_average_arg = TorchVisitor.get_specific_arg(node, "size_average", 2) + reduce_arg = TorchVisitor.get_specific_arg(node, "reduce", 3) + + # Ensure input and target args maintain their commas + input_arg = cst.ensure_type(input_arg, cst.Arg).with_changes( + comma=cst.MaybeSentinel.DEFAULT + ) + + target_arg = cst.ensure_type(target_arg, cst.Arg).with_changes( + comma=cst.MaybeSentinel.DEFAULT + ) + + # Extract size_average and reduce values + size_average_value = None + reduce_value = None + + if size_average_arg: + size_average_value = getattr(size_average_arg.value, "value", True) + if reduce_arg: + reduce_value = getattr(reduce_arg.value, "value", True) + + if size_average_value is None and reduce_value is None: + # We want to return the original call as is + return node + # Use legacy_get_string to determine the correct reduction value + reduction = legacy_get_string(size_average_value, reduce_value, emit_warning=False) + + # Create new reduction argument + reduction_arg = cst.Arg( + value=cst.SimpleString(f"'{reduction}'"), + keyword=cst.Name("reduction"), + comma=cst.MaybeSentinel.DEFAULT, + ) + + # Build new arguments list + new_args = [input_arg, target_arg, reduction_arg] + replacement = node.with_changes(args=new_args) + return replacement