From 5230d932461bb67c295e5b2d2a90a6b07baf8f17 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 13 Oct 2025 09:05:38 +0000 Subject: [PATCH 1/9] code drop Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/run_distributed.py | 4 + tests/pytorch/debug/test_perf.py | 95 ++++++++----------- .../debug/features/_test_dummy_feature.py | 52 +++++++++- 3 files changed, 93 insertions(+), 58 deletions(-) diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index 716c16056f..4babb821b0 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -283,6 +283,8 @@ def _compute_dynamic_range(tensor): @run_debug_test def test_log_distributed(parallel_mode, gather_weight, **kwargs): + if not fp8_available: + return # skip - test requires FP8 _prepare_config_test_log_distributed(kwargs["config_file"]) _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) set_weight_tensor_tp_group_reduce(gather_weight) @@ -366,6 +368,8 @@ def get_stat(tensor, stat): @run_debug_test def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs): + if not fp8_available: + return # skip - test requires FP8 from test_log import LOG_QUANTIZED_CONFIG kwargs["config_file"].write(LOG_QUANTIZED_CONFIG) diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py index 2d4b62b23f..ab598d800f 100644 --- a/tests/pytorch/debug/test_perf.py +++ b/tests/pytorch/debug/test_perf.py @@ -6,71 +6,60 @@ import pytest import torch import transformer_engine.pytorch as te -import time import nvdlfw_inspect.api as debug_api from transformer_engine.debug.pytorch.debug_state import TEDebugState -def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs): +def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs): + """ + Test that layers switch to non-debug mode when no features are active. + + Uses TestDummyFeature with inspect_only_once=True, which makes inspect_tensor_enabled return (False, None). + The TE should: + 1. Call inspect_tensor_enabled to check if feature is needed + 2. Never call inspect_tensor + 3. Allow layers to switch to non-debug mode for optimal performance, + so that inspect_tensor_enabled is never called again. + """ + debug_api.end_debug() TEDebugState._reset() - if debug_tools_initialized: - # This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS. - # So after 1 warm-up iteration, this layers should work in non-debug mode. 
- debug_api.initialize( - config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs - ) try: - if layer == "linear": - model = torch.nn.Sequential( - te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2") - ).cuda() - NUM_ITERS = 18000 - elif layer == "transformer": - model = torch.nn.Sequential( - te.TransformerLayer(1, 1, 1, name="transformer1"), - te.TransformerLayer(1, 1, 1, name="transformer2"), - ).cuda() - NUM_ITERS = 2000 - - x = torch.randn(1, 1, 1).cuda() + debug_api.initialize( + config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml", + feature_dirs=feature_dirs + ) + from transformer_engine.debug.features._test_dummy_feature import TestDummyFeature + TestDummyFeature.reset_call_counts() - y = model(x) - y.sum().backward() - debug_api.step() - torch.cuda.synchronize() + model = te.Linear(256, 256, name="test_linear").cuda() + x = torch.randn(8, 256, 256).cuda() - time_start = time.time() - for i in range(NUM_ITERS): - y = model(x) + # Run multiple iterations with is_first_microbatch + for i in range(20): + is_first_microbatch = (i % 2 == 0) # Alternate between True and False + y = model(x, is_first_microbatch=is_first_microbatch) y.sum().backward() - if debug_tools_initialized: - debug_api.step() - torch.cuda.synchronize() - time_end = time.time() - - finally: - if debug_tools_initialized: - debug_api.end_debug() - - return time_end - time_start - - -@pytest.mark.parametrize("layer", ["linear", "transformer"]) -def test_cpu_overhead(layer, configs_dir, feature_dirs): - # runs one layer many times on very small tensor - # - gpu time should be negligible, so time should be dominated by cpu time. - # if layers does not invoke any feature in current iteration, - # then it changed into non-debug mode and should not have any non-negligible cpu overhead - # compared to layer without debug tools initialized. 
- - with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs) - without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs) + debug_api.step() + + # Verify inspect_tensor_enabled was called only once per tensor + # (input, activation, weight, output, wgrad, dgrad) + enabled_call_count = TestDummyFeature.get_inspect_tensor_enabled_call_count() + assert enabled_call_count == 6, ( + "inspect_tensor_enabled should be called to check if feature is needed for each tensor " + "(input, activation, weight, output, wgrad, dgrad)" + ) - print(f"with_debug_tools: {with_debug_tools} s") - print(f"without_debug_tools: {without_debug_tools} s") + # Verify inspect_tensor was never called - it should not be called if inspect_tensor_enabled returns (False, None) + inspect_call_count = TestDummyFeature.get_inspect_tensor_call_count() + assert inspect_call_count == 0, ( + f"inspect_tensor was called {inspect_call_count} times, " + f"but should never be called when inspect_tensor_enabled returns (False, None)" + ) - assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin + finally: + debug_api.end_debug() + TEDebugState._reset() diff --git a/transformer_engine/debug/features/_test_dummy_feature.py b/transformer_engine/debug/features/_test_dummy_feature.py index c8a31a3436..140d8dbd2e 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -7,19 +7,61 @@ from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.api import TEConfigAPIMapper +import transformer_engine + +_inspect_tensor_enabled_call_count = 0 +_inspect_tensor_call_count = 0 @Registry.register_feature(namespace="transformer_engine") class TestDummyFeature(TEConfigAPIMapper): """ - This is feature used only in tests. It invokes look_at_tensor_before_process - and does nothing. + This is feature used only in tests. It invokes inspect_tensor and does nothing. If no features are used, then TE layer automatically switches to the non-debug mode. This feature is invoked for each GEMM to prevent this behavior. + + Config options: + - inspect_only_once: if True, return (False, None) from inspect_tensor_enabled to test caching behavior + + Note: This feature always tracks invocations for testing purposes. """ @api_method - def inspect_tensor_enabled(self, *_args, **_kwargs): - """API call used to determine whether to run look_at_tensor_before_process - in the forward pass.""" + def inspect_tensor_enabled(self, config, *_args, **_kwargs): + """API call used to determine whether to run inspect_tensor in the forward pass. + + Always tracks calls for testing purposes. 
+ + Returns: + - If inspect_only_once=True in config: returns (False, None) - check once, never call inspect_tensor + - Otherwise: returns True - feature is always enabled + """ + transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count += 1 + + inspect_only_once = config.get("inspect_only_once", False) + if inspect_only_once: + return False, None return True + + @api_method + def inspect_tensor(self, config, *_args, **_kwargs): + """This method does nothing but always tracks invocations for testing.""" + transformer_engine.debug.features._test_dummy_feature._inspect_tensor_call_count += 1 + + @classmethod + def reset_call_counts(cls): + """Reset the call counters for testing.""" + transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count = 0 + transformer_engine.debug.features._test_dummy_feature._inspect_tensor_call_count = 0 + + @classmethod + def get_inspect_tensor_enabled_call_count(cls): + """Get the number of times inspect_tensor_enabled was called.""" + transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count + return _inspect_tensor_enabled_call_count + + @classmethod + def get_inspect_tensor_call_count(cls): + """Get the number of times inspect_tensor was called.""" + transformer_engine.debug.features._test_dummy_feature._inspect_tensor_call_count + return _inspect_tensor_call_count From 7de75963b0fb95d2418b93c9101c998747367329 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Oct 2025 09:18:24 +0000 Subject: [PATCH 2/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_perf.py | 11 ++++++----- .../debug/features/_test_dummy_feature.py | 15 +++++++++------ 2 files changed, 15 insertions(+), 11 deletions(-) diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py index ab598d800f..866b2d3e74 100644 --- a/tests/pytorch/debug/test_perf.py +++ b/tests/pytorch/debug/test_perf.py @@ -15,7 +15,7 @@ def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs): """ Test that layers switch to non-debug mode when no features are active. - + Uses TestDummyFeature with inspect_only_once=True, which makes inspect_tensor_enabled return (False, None). The TE should: 1. 
Call inspect_tensor_enabled to check if feature is needed @@ -30,9 +30,10 @@ def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs): try: debug_api.initialize( config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml", - feature_dirs=feature_dirs + feature_dirs=feature_dirs, ) from transformer_engine.debug.features._test_dummy_feature import TestDummyFeature + TestDummyFeature.reset_call_counts() model = te.Linear(256, 256, name="test_linear").cuda() @@ -40,12 +41,12 @@ def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs): # Run multiple iterations with is_first_microbatch for i in range(20): - is_first_microbatch = (i % 2 == 0) # Alternate between True and False + is_first_microbatch = i % 2 == 0 # Alternate between True and False y = model(x, is_first_microbatch=is_first_microbatch) y.sum().backward() debug_api.step() - # Verify inspect_tensor_enabled was called only once per tensor + # Verify inspect_tensor_enabled was called only once per tensor # (input, activation, weight, output, wgrad, dgrad) enabled_call_count = TestDummyFeature.get_inspect_tensor_enabled_call_count() assert enabled_call_count == 6, ( @@ -57,7 +58,7 @@ def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs): inspect_call_count = TestDummyFeature.get_inspect_tensor_call_count() assert inspect_call_count == 0, ( f"inspect_tensor was called {inspect_call_count} times, " - f"but should never be called when inspect_tensor_enabled returns (False, None)" + "but should never be called when inspect_tensor_enabled returns (False, None)" ) finally: diff --git a/transformer_engine/debug/features/_test_dummy_feature.py b/transformer_engine/debug/features/_test_dummy_feature.py index 140d8dbd2e..4690fa99ad 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -12,6 +12,7 @@ _inspect_tensor_enabled_call_count = 0 _inspect_tensor_call_count = 0 + @Registry.register_feature(namespace="transformer_engine") class TestDummyFeature(TEConfigAPIMapper): """ @@ -19,25 +20,27 @@ class TestDummyFeature(TEConfigAPIMapper): If no features are used, then TE layer automatically switches to the non-debug mode. This feature is invoked for each GEMM to prevent this behavior. - + Config options: - inspect_only_once: if True, return (False, None) from inspect_tensor_enabled to test caching behavior - + Note: This feature always tracks invocations for testing purposes. """ @api_method def inspect_tensor_enabled(self, config, *_args, **_kwargs): """API call used to determine whether to run inspect_tensor in the forward pass. - + Always tracks calls for testing purposes. 
- + Returns: - If inspect_only_once=True in config: returns (False, None) - check once, never call inspect_tensor - Otherwise: returns True - feature is always enabled """ - transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count += 1 - + transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count += ( + 1 + ) + inspect_only_once = config.get("inspect_only_once", False) if inspect_only_once: return False, None From 0af7bc6e16d1c22150ce922c438c3b43a658f005 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 13 Oct 2025 09:19:07 +0000 Subject: [PATCH 3/9] fix Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/run_distributed.py | 12 +++++------- .../test_configs/test_switch_to_nondebug_mode.yaml | 11 +++++++++++ 2 files changed, 16 insertions(+), 7 deletions(-) create mode 100644 tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index 4babb821b0..f6adcc2aa3 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -283,8 +283,6 @@ def _compute_dynamic_range(tensor): @run_debug_test def test_log_distributed(parallel_mode, gather_weight, **kwargs): - if not fp8_available: - return # skip - test requires FP8 _prepare_config_test_log_distributed(kwargs["config_file"]) _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) set_weight_tensor_tp_group_reduce(gather_weight) @@ -368,8 +366,6 @@ def get_stat(tensor, stat): @run_debug_test def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs): - if not fp8_available: - return # skip - test requires FP8 from test_log import LOG_QUANTIZED_CONFIG kwargs["config_file"].write(LOG_QUANTIZED_CONFIG) @@ -672,11 +668,13 @@ def _run_test_with_combinations( _init_distributed() test_log_expert_parallel() - for parallel_mode in ["column", "row"]: - for gather_weight in [True, False]: - test_log_distributed(parallel_mode, gather_weight) if fp8_available: + for parallel_mode in ["column", "row"]: + for gather_weight in [True, False]: + test_log_distributed(parallel_mode, gather_weight) + + for parallel_mode in ["row", "column"]: test_disable_fp8_layer(parallel_mode) diff --git a/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml b/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml new file mode 100644 index 0000000000..224be46180 --- /dev/null +++ b/tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml @@ -0,0 +1,11 @@ +test_switch_to_nondebug_mode: + enabled: True + layers: + layer_name_regex_pattern: .* + transformer_engine: + TestDummyFeature: + enabled: True + inspect_only_once: True + tensors: [weight, activation, gradient, output, wgrad, dgrad] + gemms: [wgrad, dgrad, fprop] + From ea3d54ef2ec0eb3a225338030ecb7d9af60bb061 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Mon, 13 Oct 2025 09:38:10 +0000 Subject: [PATCH 4/9] fix: Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_perf.py | 42 +++++++++++-------- .../debug/features/_test_dummy_feature.py | 27 ++++-------- 2 files changed, 34 insertions(+), 35 deletions(-) diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py index 866b2d3e74..b46eb95f57 100644 --- a/tests/pytorch/debug/test_perf.py +++ b/tests/pytorch/debug/test_perf.py @@ -12,7 +12,8 @@ from transformer_engine.debug.pytorch.debug_state import TEDebugState -def 
test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs): +@pytest.mark.parametrize("use_microbatching", [False, True]) +def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs, use_microbatching): """ Test that layers switch to non-debug mode when no features are active. @@ -22,42 +23,49 @@ def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs): 2. Never call inspect_tensor 3. Allow layers to switch to non-debug mode for optimal performance, so that inspect_tensor_enabled is never called again. + + Tests both with and without microbatching to ensure proper behavior in both scenarios. """ - - debug_api.end_debug() - TEDebugState._reset() - + try: debug_api.initialize( config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml", feature_dirs=feature_dirs, ) - from transformer_engine.debug.features._test_dummy_feature import TestDummyFeature - - TestDummyFeature.reset_call_counts() + import transformer_engine.debug.features._test_dummy_feature as dummy_feature + # Reset counters + dummy_feature._inspect_tensor_enabled_call_count = 0 + dummy_feature._inspect_tensor_call_count = 0 model = te.Linear(256, 256, name="test_linear").cuda() x = torch.randn(8, 256, 256).cuda() - # Run multiple iterations with is_first_microbatch + # Run multiple iterations for i in range(20): - is_first_microbatch = i % 2 == 0 # Alternate between True and False - y = model(x, is_first_microbatch=is_first_microbatch) + if use_microbatching: + # Alternate between first and non-first microbatch + is_first_microbatch = i % 2 == 0 + y = model(x, is_first_microbatch=is_first_microbatch) + else: + # Run without specifying is_first_microbatch + y = model(x) y.sum().backward() debug_api.step() # Verify inspect_tensor_enabled was called only once per tensor - # (input, activation, weight, output, wgrad, dgrad) - enabled_call_count = TestDummyFeature.get_inspect_tensor_enabled_call_count() + # (activation, weight, gradient, output, wgrad, dgrad) + enabled_call_count = dummy_feature._inspect_tensor_enabled_call_count + microbatch_info = "with microbatching" if use_microbatching else "without microbatching" assert enabled_call_count == 6, ( - "inspect_tensor_enabled should be called to check if feature is needed for each tensor " - "(input, activation, weight, output, wgrad, dgrad)" + f"inspect_tensor_enabled was called {enabled_call_count} times ({microbatch_info}), " + "but should be called 6 times to check if feature is needed for each tensor " + "(activation, weight, gradient, output, wgrad, dgrad)" ) # Verify inspect_tensor was never called - it should not be called if inspect_tensor_enabled returns (False, None) - inspect_call_count = TestDummyFeature.get_inspect_tensor_call_count() + inspect_call_count = dummy_feature._inspect_tensor_call_count assert inspect_call_count == 0, ( - f"inspect_tensor was called {inspect_call_count} times, " + f"inspect_tensor was called {inspect_call_count} times ({microbatch_info}), " "but should never be called when inspect_tensor_enabled returns (False, None)" ) diff --git a/transformer_engine/debug/features/_test_dummy_feature.py b/transformer_engine/debug/features/_test_dummy_feature.py index 4690fa99ad..c3c020c89a 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -9,6 +9,12 @@ import transformer_engine +# Module-level counters for tracking invocations +# NOTE: These must be accessed via the full module path +# 
(transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count) +# to ensure the same module instance is used when the feature is loaded by the debug framework +# and when imported by tests. Using just the variable name would create separate instances +# in different import contexts. _inspect_tensor_enabled_call_count = 0 _inspect_tensor_call_count = 0 @@ -37,6 +43,8 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): - If inspect_only_once=True in config: returns (False, None) - check once, never call inspect_tensor - Otherwise: returns True - feature is always enabled """ + # Access counter via full module path to ensure we're modifying the same module-level + # variable regardless of import context (debug framework vs test import) transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count += ( 1 ) @@ -49,22 +57,5 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): @api_method def inspect_tensor(self, config, *_args, **_kwargs): """This method does nothing but always tracks invocations for testing.""" + # Access counter via full module path to ensure shared state across import contexts transformer_engine.debug.features._test_dummy_feature._inspect_tensor_call_count += 1 - - @classmethod - def reset_call_counts(cls): - """Reset the call counters for testing.""" - transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count = 0 - transformer_engine.debug.features._test_dummy_feature._inspect_tensor_call_count = 0 - - @classmethod - def get_inspect_tensor_enabled_call_count(cls): - """Get the number of times inspect_tensor_enabled was called.""" - transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count - return _inspect_tensor_enabled_call_count - - @classmethod - def get_inspect_tensor_call_count(cls): - """Get the number of times inspect_tensor was called.""" - transformer_engine.debug.features._test_dummy_feature._inspect_tensor_call_count - return _inspect_tensor_call_count From 42b9f503e4f8fa54c810c95b1cb58eb7ccedb125 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 13 Oct 2025 09:39:05 +0000 Subject: [PATCH 5/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/run_distributed.py | 1 - tests/pytorch/debug/test_perf.py | 5 +++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index f6adcc2aa3..5620c8e646 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -674,7 +674,6 @@ def _run_test_with_combinations( for gather_weight in [True, False]: test_log_distributed(parallel_mode, gather_weight) - for parallel_mode in ["row", "column"]: test_disable_fp8_layer(parallel_mode) diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py index b46eb95f57..c8c9ae3c1f 100644 --- a/tests/pytorch/debug/test_perf.py +++ b/tests/pytorch/debug/test_perf.py @@ -23,16 +23,17 @@ def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs, use_microbat 2. Never call inspect_tensor 3. Allow layers to switch to non-debug mode for optimal performance, so that inspect_tensor_enabled is never called again. - + Tests both with and without microbatching to ensure proper behavior in both scenarios. 
""" - + try: debug_api.initialize( config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml", feature_dirs=feature_dirs, ) import transformer_engine.debug.features._test_dummy_feature as dummy_feature + # Reset counters dummy_feature._inspect_tensor_enabled_call_count = 0 dummy_feature._inspect_tensor_call_count = 0 From e361166a988dff95b97f5e0c43e342b36925a301 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 17 Oct 2025 11:45:41 +0000 Subject: [PATCH 6/9] fix Signed-off-by: Pawel Gadzinski --- qa/L0_pytorch_debug_unittest/test.sh | 41 +++++++++++++------ .../debug/features/_test_dummy_feature.py | 8 ++-- 2 files changed, 32 insertions(+), 17 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index 7f19dda670..fb9829b0d9 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -2,7 +2,19 @@ # # See LICENSE for license information. +function error_exit() { + echo "Error: $1" + exit 1 +} +function test_fail() { + RET=1 + FAILED_CASES="$FAILED_CASES $1" + echo "Error: sub-test failed: $1" +} + +RET=0 +FAILED_CASES="" : ${TE_PATH:=/opt/transformerengine} : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features} @@ -14,24 +26,27 @@ mkdir -p "$XML_LOG_DIR" # Nvinspect will be disabled if no feature is active. : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml} -FAIL=0 - # It is not installed as a requirement, # because it is not available on PyPI. pip uninstall -y nvdlfw-inspect pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git -pip install pytest==8.2.1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 -pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pip install pytest==8.2.1 || error_exit "Failed to install pytest" +pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS 
--configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py" +NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py" +pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 - -exit $FAIL +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" + +if [ "$RET" -ne 0 ]; then + echo "Error in the following test cases:$FAILED_CASES" + exit 1 +fi +echo "All tests passed" +exit 0 diff --git a/transformer_engine/debug/features/_test_dummy_feature.py b/transformer_engine/debug/features/_test_dummy_feature.py index c3c020c89a..1982905993 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -45,9 +45,8 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): """ # Access counter via full module path to ensure we're modifying the same module-level # variable regardless of import context (debug framework vs test import) - transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count += ( - 1 - ) + import transformer_engine.debug.features._test_dummy_feature as dummy_feature + dummy_feature._inspect_tensor_enabled_call_count += 1 inspect_only_once = config.get("inspect_only_once", False) if inspect_only_once: @@ -58,4 +57,5 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): def inspect_tensor(self, config, *_args, **_kwargs): """This method does nothing but always tracks invocations for testing.""" # Access counter via full module path to ensure shared state across import contexts - transformer_engine.debug.features._test_dummy_feature._inspect_tensor_call_count += 1 + import transformer_engine.debug.features._test_dummy_feature as dummy_feature + dummy_feature._inspect_tensor_call_count += 1 From 
aa44016184acdbf069aefa2d46ea44fa2aa363c5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 17 Oct 2025 11:48:03 +0000 Subject: [PATCH 7/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/debug/features/_test_dummy_feature.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/transformer_engine/debug/features/_test_dummy_feature.py b/transformer_engine/debug/features/_test_dummy_feature.py index 1982905993..6fd131d5e2 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -46,6 +46,7 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): # Access counter via full module path to ensure we're modifying the same module-level # variable regardless of import context (debug framework vs test import) import transformer_engine.debug.features._test_dummy_feature as dummy_feature + dummy_feature._inspect_tensor_enabled_call_count += 1 inspect_only_once = config.get("inspect_only_once", False) @@ -58,4 +59,5 @@ def inspect_tensor(self, config, *_args, **_kwargs): """This method does nothing but always tracks invocations for testing.""" # Access counter via full module path to ensure shared state across import contexts import transformer_engine.debug.features._test_dummy_feature as dummy_feature + dummy_feature._inspect_tensor_call_count += 1 From 3684e4e8d3fd9b6763e163badaacd302fca269cb Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 17 Oct 2025 11:48:21 +0000 Subject: [PATCH 8/9] fix Signed-off-by: Pawel Gadzinski --- qa/L0_pytorch_debug_unittest/test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index fb9829b0d9..ec2ab97536 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -41,8 +41,8 @@ NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py" # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py" +NVTE_TEST_NVINSPECT_ENABLED=1 
NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" From 6c2ac247847488caabf54f58a695bc340d608cad Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Wed, 22 Oct 2025 12:31:17 +0000 Subject: [PATCH 9/9] fix Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/features/_test_dummy_feature.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/transformer_engine/debug/features/_test_dummy_feature.py b/transformer_engine/debug/features/_test_dummy_feature.py index 6fd131d5e2..4dee97b707 100644 --- a/transformer_engine/debug/features/_test_dummy_feature.py +++ b/transformer_engine/debug/features/_test_dummy_feature.py @@ -7,8 +7,6 @@ from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.api import TEConfigAPIMapper -import transformer_engine - # Module-level counters for tracking invocations # NOTE: These must be accessed via the full module path # (transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count) @@ -45,7 +43,7 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): """ # Access counter via full module path to ensure we're modifying the same module-level # variable regardless of import context (debug framework vs test import) - import transformer_engine.debug.features._test_dummy_feature as dummy_feature + import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self dummy_feature._inspect_tensor_enabled_call_count += 1 @@ -55,9 +53,9 @@ def inspect_tensor_enabled(self, config, *_args, **_kwargs): return True @api_method - def inspect_tensor(self, config, *_args, **_kwargs): + def inspect_tensor(self, _config, *_args, **_kwargs): """This method does nothing but always tracks invocations for testing.""" # Access counter via full module path to ensure shared state across import contexts - import transformer_engine.debug.features._test_dummy_feature as dummy_feature + import transformer_engine.debug.features._test_dummy_feature as dummy_feature # pylint: disable=import-self dummy_feature._inspect_tensor_call_count += 1
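
Illustration (not part of the patch series): a minimal sketch of the contract exercised by test_layer_switches_to_nondebug_mode, using the same Registry / api_method / TEConfigAPIMapper API that _test_dummy_feature.py uses above. The class name and the assertion inside inspect_tensor are hypothetical; the sketch only shows that a feature returning (False, None) from inspect_tensor_enabled is expected never to receive inspect_tensor calls, which is what allows TE to drop the layer into non-debug mode and stop re-checking.

from nvdlfw_inspect.registry import Registry, api_method
from transformer_engine.debug.features.api import TEConfigAPIMapper


@Registry.register_feature(namespace="transformer_engine")
class OptOutOnceFeature(TEConfigAPIMapper):
    """Hypothetical feature that opts out of tensor inspection after the first check."""

    @api_method
    def inspect_tensor_enabled(self, config, *_args, **_kwargs):
        # Returning (False, None) tells TE that inspect_tensor is never needed for
        # this tensor, so the layer can switch to non-debug mode and skip any
        # further capability checks for it.
        return False, None

    @api_method
    def inspect_tensor(self, _config, *_args, **_kwargs):
        # Never expected to run once inspect_tensor_enabled has returned (False, None).
        raise AssertionError("inspect_tensor should not be called for an opted-out tensor")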