qa/L0_pytorch_debug_unittest/test.sh — 28 additions & 13 deletions
@@ -2,7 +2,19 @@
 #
 # See LICENSE for license information.
 
+function error_exit() {
+    echo "Error: $1"
+    exit 1
+}
+
+function test_fail() {
+    RET=1
+    FAILED_CASES="$FAILED_CASES $1"
+    echo "Error: sub-test failed: $1"
+}
+
+RET=0
+FAILED_CASES=""
+
 : ${TE_PATH:=/opt/transformerengine}
 : ${NVTE_TEST_NVINSPECT_FEATURE_DIRS:=$TE_PATH/transformer_engine/debug/features}
@@ -14,24 +26,27 @@ mkdir -p "$XML_LOG_DIR"
 # Nvinspect will be disabled if no feature is active.
 : ${NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE:=$TE_PATH/tests/pytorch/debug/test_configs/dummy_feature.yaml}
 
-FAIL=0
-
 # It is not installed as a requirement,
 # because it is not available on PyPI.
 pip uninstall -y nvdlfw-inspect
 pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git
 
-pip install pytest==8.2.1
-pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1
-pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
-pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1
+pip install pytest==8.2.1 || error_exit "Failed to install pytest"
+
+pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity.xml $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_sanity.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_config.xml $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_config.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics.xml $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || test_fail "test_numerics.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_log.xml $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_log.py"
+NVTE_TORCH_COMPILE=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_api_features.xml $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_api_features.py"
+pytest -v -s --junitxml=$XML_LOG_DIR/test_perf.xml $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || test_fail "test_perf.py"
 
 # standard sanity and numerics tests with initialized debug
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1
-NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1
-
-exit $FAIL
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_sanity_2.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "debug test_sanity.py"
+NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 NVTE_FUSED_ATTN=0 pytest -v -s --junitxml=$XML_LOG_DIR/test_numerics_2.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "debug test_numerics.py"
+
+if [ "$RET" -ne 0 ]; then
+    echo "Error in the following test cases:$FAILED_CASES"
+    exit 1
+fi
+echo "All tests passed"
+exit 0
tests/pytorch/debug/run_distributed.py — 4 additions & 3 deletions
@@ -668,11 +668,12 @@ def _run_test_with_combinations(
     _init_distributed()
 
     test_log_expert_parallel()
-    for parallel_mode in ["column", "row"]:
-        for gather_weight in [True, False]:
-            test_log_distributed(parallel_mode, gather_weight)
 
     if fp8_available:
+        for parallel_mode in ["column", "row"]:
+            for gather_weight in [True, False]:
+                test_log_distributed(parallel_mode, gather_weight)
+
         for parallel_mode in ["row", "column"]:
             test_disable_fp8_layer(parallel_mode)
tests/pytorch/debug/test_configs/test_switch_to_nondebug_mode.yaml — new file, 11 additions
@@ -0,0 +1,11 @@
+test_switch_to_nondebug_mode:
+  enabled: True
+  layers:
+    layer_name_regex_pattern: .*
+  transformer_engine:
+    TestDummyFeature:
+      enabled: True
+      inspect_only_once: True
+      tensors: [weight, activation, gradient, output, wgrad, dgrad]
+      gemms: [wgrad, dgrad, fprop]
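
Aside: the options under TestDummyFeature are what eventually reach the feature's API methods as its config dict (see _test_dummy_feature.py below). A rough illustration with a hypothetical standalone reader — not the actual nvdlfw-inspect loader — assuming PyYAML is available:

import yaml  # PyYAML, an assumption for this sketch

with open("test_switch_to_nondebug_mode.yaml") as f:
    cfg = yaml.safe_load(f)

# Drill down to the TestDummyFeature options under the transformer_engine namespace.
feature_cfg = cfg["test_switch_to_nondebug_mode"]["transformer_engine"]["TestDummyFeature"]
assert feature_cfg["inspect_only_once"] is True
assert "fprop" in feature_cfg["gemms"]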

tests/pytorch/debug/test_perf.py — 55 additions & 59 deletions
@@ -6,74 +6,70 @@
 import pytest
 import torch
 import transformer_engine.pytorch as te
-import time
 
 import nvdlfw_inspect.api as debug_api
 
 from transformer_engine.debug.pytorch.debug_state import TEDebugState
 
 
-def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs):
-    debug_api.end_debug()
-    TEDebugState._reset()
-    if debug_tools_initialized:
-        # This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS.
-        # So after 1 warm-up iteration, this layers should work in non-debug mode.
-        debug_api.initialize(
-            config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs
-        )
-
-    try:
-        if layer == "linear":
-            model = torch.nn.Sequential(
-                te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2")
-            ).cuda()
-            NUM_ITERS = 1800
-        elif layer == "transformer":
-            model = torch.nn.Sequential(
-                te.TransformerLayer(1, 1, 1, name="transformer1"),
-                te.TransformerLayer(1, 1, 1, name="transformer2"),
-            ).cuda()
-            NUM_ITERS = 200
-
-        NUM_INVOCATIONS_PER_ITER = 10
-
-        x = torch.randn(1, 1, 1).cuda()
-
-        y = model(x)
-        y.sum().backward()
-        debug_api.step()
-        torch.cuda.synchronize()
-
-        time_start = time.time()
-        for i in range(NUM_ITERS):
-            for _ in range(NUM_INVOCATIONS_PER_ITER):
-                y = model(x)
-                y.sum().backward()
-            if debug_tools_initialized:
-                debug_api.step()
-        torch.cuda.synchronize()
-        time_end = time.time()
-
-    finally:
-        if debug_tools_initialized:
-            debug_api.end_debug()
-
-    return time_end - time_start
-
-
-@pytest.mark.parametrize("layer", ["linear", "transformer"])
-def test_cpu_overhead(layer, configs_dir, feature_dirs):
-    # runs one layer many times on very small tensor
-    # - gpu time should be negligible, so time should be dominated by cpu time.
-    # if layers does not invoke any feature in current iteration,
-    # then it changed into non-debug mode and should not have any non-negligible cpu overhead
-    # compared to layer without debug tools initialized.
-
-    with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs)
-    without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs)
-
-    print(f"with_debug_tools: {with_debug_tools} s")
-    print(f"without_debug_tools: {without_debug_tools} s")
-
-    assert with_debug_tools < without_debug_tools * 1.25  # 25% overhead margin
+@pytest.mark.parametrize("use_microbatching", [False, True])
+def test_layer_switches_to_nondebug_mode(configs_dir, feature_dirs, use_microbatching):
+    """
+    Test that layers switch to non-debug mode when no features are active.
+
+    Uses TestDummyFeature with inspect_only_once=True, which makes inspect_tensor_enabled return (False, None).
+    TE should:
+    1. Call inspect_tensor_enabled to check if the feature is needed.
+    2. Never call inspect_tensor.
+    3. Allow layers to switch to non-debug mode for optimal performance,
+       so that inspect_tensor_enabled is never called again.
+
+    Tests both with and without microbatching to ensure proper behavior in both scenarios.
+    """
+
+    try:
+        debug_api.initialize(
+            config_file=configs_dir + "/test_switch_to_nondebug_mode.yaml",
+            feature_dirs=feature_dirs,
+        )
+        import transformer_engine.debug.features._test_dummy_feature as dummy_feature
+
+        # Reset counters
+        dummy_feature._inspect_tensor_enabled_call_count = 0
+        dummy_feature._inspect_tensor_call_count = 0
+
+        model = te.Linear(256, 256, name="test_linear").cuda()
+        x = torch.randn(8, 256, 256).cuda()
+
+        # Run multiple iterations
+        for i in range(20):
+            if use_microbatching:
+                # Alternate between first and non-first microbatch
+                is_first_microbatch = i % 2 == 0
+                y = model(x, is_first_microbatch=is_first_microbatch)
+            else:
+                # Run without specifying is_first_microbatch
+                y = model(x)
+            y.sum().backward()
+            debug_api.step()
+
+        # Verify inspect_tensor_enabled was called only once per tensor
+        # (activation, weight, gradient, output, wgrad, dgrad)
+        enabled_call_count = dummy_feature._inspect_tensor_enabled_call_count
+        microbatch_info = "with microbatching" if use_microbatching else "without microbatching"
+        assert enabled_call_count == 6, (
+            f"inspect_tensor_enabled was called {enabled_call_count} times ({microbatch_info}), "
+            "but should be called 6 times to check if feature is needed for each tensor "
+            "(activation, weight, gradient, output, wgrad, dgrad)"
+        )
+
+        # Verify inspect_tensor was never called - it should not be called if inspect_tensor_enabled returns (False, None)
+        inspect_call_count = dummy_feature._inspect_tensor_call_count
+        assert inspect_call_count == 0, (
+            f"inspect_tensor was called {inspect_call_count} times ({microbatch_info}), "
+            "but should never be called when inspect_tensor_enabled returns (False, None)"
+        )
+
+    finally:
+        debug_api.end_debug()
+        TEDebugState._reset()
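
Aside: the contract this test encodes is easy to picture as a small gate between the framework and a feature. The sketch below uses hypothetical names (HookGate, QuietFeature, maybe_inspect) and is not TE's real hook machinery; it only illustrates the query-once, cache, and skip behavior that the assertions above check.

# Hypothetical sketch of the check-once contract: the framework asks
# inspect_tensor_enabled once per tensor kind, caches the verdict, and never
# calls inspect_tensor for kinds whose verdict was (False, None).
class HookGate:
    def __init__(self, feature):
        self.feature = feature
        self.verdicts = {}  # tensor kind -> bool, filled at most once per kind

    def maybe_inspect(self, kind, tensor):
        if kind not in self.verdicts:
            verdict = self.feature.inspect_tensor_enabled(kind)
            # The API may answer with a plain bool or a (bool, extra) tuple.
            self.verdicts[kind] = verdict[0] if isinstance(verdict, tuple) else verdict
        if self.verdicts[kind]:
            self.feature.inspect_tensor(kind, tensor)


class QuietFeature:
    """Mirrors TestDummyFeature with inspect_only_once=True."""

    def __init__(self):
        self.enabled_calls = 0

    def inspect_tensor_enabled(self, kind):
        self.enabled_calls += 1
        return (False, None)

    def inspect_tensor(self, kind, tensor):
        raise AssertionError("must never be reached")


gate = HookGate(QuietFeature())
for _ in range(20):  # many iterations, but only one enabled-check per kind
    for kind in ("activation", "weight", "gradient", "output", "wgrad", "dgrad"):
        gate.maybe_inspect(kind, tensor=None)
assert gate.feature.enabled_calls == 6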
transformer_engine/debug/features/_test_dummy_feature.py — 41 additions & 5 deletions
@@ -7,19 +7,55 @@
 from nvdlfw_inspect.registry import Registry, api_method
 from transformer_engine.debug.features.api import TEConfigAPIMapper
 
+# Module-level counters for tracking invocations
+# NOTE: These must be accessed via the full module path
+# (transformer_engine.debug.features._test_dummy_feature._inspect_tensor_enabled_call_count)
+# to ensure the same module instance is used when the feature is loaded by the debug framework
+# and when imported by tests. Using just the variable name would create separate instances
+# in different import contexts.
+_inspect_tensor_enabled_call_count = 0
+_inspect_tensor_call_count = 0
+
 
 @Registry.register_feature(namespace="transformer_engine")
 class TestDummyFeature(TEConfigAPIMapper):
     """
-    This is feature used only in tests. It invokes look_at_tensor_before_process
-    and does nothing.
+    This is a feature used only in tests. It invokes inspect_tensor and does nothing.
 
     If no features are used, then TE layer automatically switches to the non-debug mode.
     This feature is invoked for each GEMM to prevent this behavior.
+
+    Config options:
+    - inspect_only_once: if True, return (False, None) from inspect_tensor_enabled to test caching behavior
+
+    Note: This feature always tracks invocations for testing purposes.
     """
 
     @api_method
-    def inspect_tensor_enabled(self, *_args, **_kwargs):
-        """API call used to determine whether to run look_at_tensor_before_process
-        in the forward pass."""
+    def inspect_tensor_enabled(self, config, *_args, **_kwargs):
+        """API call used to determine whether to run inspect_tensor in the forward pass.
+
+        Always tracks calls for testing purposes.
+
+        Returns:
+        - If inspect_only_once=True in config: returns (False, None) - check once, never call inspect_tensor
+        - Otherwise: returns True - the feature is always enabled
+        """
+        # Access the counter via the full module path to ensure we're modifying the same
+        # module-level variable regardless of import context (debug framework vs test import)
+        import transformer_engine.debug.features._test_dummy_feature as dummy_feature  # pylint: disable=import-self
+
+        dummy_feature._inspect_tensor_enabled_call_count += 1
+
+        inspect_only_once = config.get("inspect_only_once", False)
+        if inspect_only_once:
+            return False, None
         return True
+
+    @api_method
+    def inspect_tensor(self, _config, *_args, **_kwargs):
+        """This method does nothing but always tracks invocations for testing."""
+        # Access the counter via the full module path to ensure shared state across import contexts
+        import transformer_engine.debug.features._test_dummy_feature as dummy_feature  # pylint: disable=import-self
+
+        dummy_feature._inspect_tensor_call_count += 1
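
Aside: the import-context pitfall described in the NOTE above is plain Python semantics and can be reproduced without TE at all. A minimal self-contained demo, using a throwaway module object rather than the real feature module:

import types

# Stand-in for a module that holds a module-level counter.
m = types.ModuleType("m")
m.counter = 0

# "from m import counter" copies the current value into a new local binding:
counter = m.counter
counter += 1               # rebinds the local name only
assert m.counter == 0      # module-level state is untouched

# "import m" and going through the attribute mutates the shared state:
m.counter += 1
assert m.counter == 1      # visible to every other importer of m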