Skip to content

Conversation

@kevalmorabia97
Copy link
Collaborator

@kevalmorabia97 kevalmorabia97 commented Dec 22, 2025

What does this PR do?

Type of change: New feature

  • So far users we didnt have the NAS step from the Minitron paper so users had to manually find a pruned architecture which fits that params constraints and for that they have to tweak different combinations of width and depth pruning
  • This PR adds a simplified version of the NAS from paper. We first find all candidate subnets that fit the user's params constraint and sort them by parameter count.
    • [Paper] Then we pick top K candidates, do distillation for ~2B tokens and then select the one with best score (LM Loss / MMLU / other metric we care about). Note that the one among these top K with highest params is often not the best pruned model
    • [ModelOpt] Then we pick top K candidates, and select the one with best score (LM Loss / MMLU / other metric we care about). While doing KD gives better indication on which one to pick, skipping it makes the pruning much faster, much less compute intense, and finish everything in single prune API instead of first exporting top K models, doing KD and eval for all K models separately. We do print a Note in pruning step to let users know this so they can do KD if they want slightly better pruned model.
    • Further full KD is still needed as usual
  • We also restrict the search space choices (e.g. hidden_size multiple of 256, ffn_hidden_size multiple of 512) to make the process efficient. Users can configure this if they want to.

Usage

Pruning API is same as before:

import modelopt.torch.prune as mtp
mtp.prune(
    model,
    mode="mcore_minitron",
    constraints=constraints,
    dummy_input=None,  # Not used
    config=config,
)
  1. Manual Pruning (Existing):
constraints = {"export_config": {"hidden_size: 3072", "ffn_hidden_size": 9216}}
config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores.pth"}
mtp.prune(...)
  1. NAS-based Auto Pruning (New):
constraints = {"params": 6e9}. # prune to 6B params
config = {"forward_loop": forward_loop, "checkpoint": "/path/to/cache/pruning/scores.pth"}

# define the score_func to maximize (e.g MMLU, negative val loss, etc.)
from modelopt.torch.utils.plugins.megatron_mmlu import megatron_mmlu
def score_func(m):
    return megatron_mmlu(m, tokenizer, percentage=0.05)  # 5% sampled data for faster eval
config["score_func"] = score_func

# overwrite search space choices (showing defaults):
config["max_width_pruning"] = 0.4
config["max_depth_pruning"] = 0.2
config["hparams_to_skip"] = [] # can be used to disable pruning some hparams e.g. ["num_attention_heads"]
config["top_k"] = 10 # might be better to use 20 at the cost of longer time to prune

mtp.prune(...)

To configure search space (shows defaults):

ss_config = mtp.mcore_minitron.get_mcore_minitron_config(
    hidden_size_divisor=256,
    ffn_hidden_size_divisor=512,
    mamba_head_dim_divisor=8,
    num_moe_experts_divisor=8,
    num_layers_divisor=2,
)
mtp.prune(model, mode=[("mcore_minitron", ss_config)], ....)

Testing

Qwen3-8B -> 6B (~2 hours on 8xA5000)

     0.4350 score -> {'num_layers': 34, 'hidden_size': 3328, 'ffn_hidden_size': 11264}
BEST 0.5705 score -> {'num_layers': 30, 'hidden_size': 3584, 'ffn_hidden_size': 11776}
     0.4051 score -> {'num_layers': 36, 'hidden_size': 3840, 'ffn_hidden_size': 8192}
     0.4593 score -> {'num_layers': 36, 'hidden_size': 3584, 'ffn_hidden_size': 9216}
     0.2737 score -> {'num_layers': 36, 'hidden_size': 3072, 'ffn_hidden_size': 11776}
     0.5556 score -> {'num_layers': 32, 'hidden_size': 3584, 'ffn_hidden_size': 10752}
     0.3198 score -> {'num_layers': 28, 'hidden_size': 4096, 'ffn_hidden_size': 10240}
     0.4119 score -> {'num_layers': 36, 'hidden_size': 4096, 'ffn_hidden_size': 7168}
     0.3808 score -> {'num_layers': 36, 'hidden_size': 3328, 'ffn_hidden_size': 10240}
     0.4783 score -> {'num_layers': 34, 'hidden_size': 3840, 'ffn_hidden_size': 8704}

Nemotron-Nano-9B-v2 -> 7B (~2.5 hours on 8xA5000)

     0.2629 score -> {'num_layers': 54, 'hidden_size': 4352, 'mamba_num_heads':  88, 'mamba_head_dim': 72, 'ffn_hidden_size': 15360}
     0.2778 score -> {'num_layers': 50, 'hidden_size': 4480, 'mamba_num_heads': 128, 'mamba_head_dim': 56, 'ffn_hidden_size': 15680}
     0.5041 score -> {'num_layers': 56, 'hidden_size': 4096, 'mamba_num_heads':  96, 'mamba_head_dim': 80, 'ffn_hidden_size': 14336}
BEST 0.6043 score -> {'num_layers': 48, 'hidden_size': 4352, 'mamba_num_heads': 120, 'mamba_head_dim': 80, 'ffn_hidden_size': 13824}
     0.0772 score -> {'num_layers': 56, 'hidden_size': 4352, 'mamba_num_heads': 112, 'mamba_head_dim': 80, 'ffn_hidden_size': 10240}
     0.3550 score -> {'num_layers': 50, 'hidden_size': 4480, 'mamba_num_heads': 112, 'mamba_head_dim': 64, 'ffn_hidden_size': 15680}
     0.1016 score -> {'num_layers': 56, 'hidden_size': 4352, 'mamba_num_heads': 120, 'mamba_head_dim': 72, 'ffn_hidden_size': 10752}
     0.5461 score -> {'num_layers': 46, 'hidden_size': 4480, 'mamba_num_heads': 128, 'mamba_head_dim': 72, 'ffn_hidden_size': 14848}
     0.1992 score -> {'num_layers': 54, 'hidden_size': 4480, 'mamba_num_heads':  80, 'mamba_head_dim': 80, 'ffn_hidden_size': 14336}
     0.5881 score -> {'num_layers': 48, 'hidden_size': 4480, 'mamba_num_heads': 112, 'mamba_head_dim': 80, 'ffn_hidden_size': 13824}

Before your PR is "Ready for review"

  • Make sure you read and follow Contributor guidelines and your commits are signed.
  • Is this change backward compatible?: Yes
  • Did you write any new necessary tests?: Yes
  • Did you add or update any necessary documentation?: Yes
  • Did you update Changelog?: Yes

Additional Information

OMNIML-3043

Summary by CodeRabbit

Release Notes

  • New Features

    • Added NAS-based Auto Pruning for Minitron models as an alternative to manual pruning using parameter constraints
    • Introduced parameter counting capabilities for architecture search
  • Documentation

    • Expanded pruning guides with detailed examples and workflows for both manual and automatic pruning approaches
    • Updated configuration documentation with granular divisor parameters
  • Improvements

    • Enhanced parameter counting support for models with dynamic or mixture-of-experts modules

✏️ Tip: You can customize this high-level summary in your review settings.

@copy-pr-bot
Copy link

copy-pr-bot bot commented Dec 22, 2025

Auto-sync is disabled for draft pull requests in this repository. Workflows must be run manually.

Contributors can view more details about this message here.

@codecov
Copy link

codecov bot commented Dec 22, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 74.18%. Comparing base (951c6aa) to head (c735e7b).
⚠️ Report is 11 commits behind head on main.

Additional details and impacted files
@@            Coverage Diff             @@
##             main     #720      +/-   ##
==========================================
- Coverage   74.23%   74.18%   -0.05%     
==========================================
  Files         192      192              
  Lines       19033    19240     +207     
==========================================
+ Hits        14129    14274     +145     
- Misses       4904     4966      +62     

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@kevalmorabia97 kevalmorabia97 force-pushed the kmorabia/minitron-auto branch 3 times, most recently from 11bc408 to 1cbe285 Compare January 2, 2026 12:41
@kevalmorabia97 kevalmorabia97 changed the title Add parameter best auto-pruning for Minitron Add NAS to Minitron pruning for parameter best auto-pruning Jan 8, 2026
@kevalmorabia97 kevalmorabia97 marked this pull request as ready for review January 8, 2026 13:09
@kevalmorabia97 kevalmorabia97 requested review from a team as code owners January 8, 2026 13:09
@kevalmorabia97 kevalmorabia97 changed the title Add NAS to Minitron pruning for parameter best auto-pruning Add NAS to Minitron pruning for parameter based auto-pruning Jan 8, 2026
@kevalmorabia97
Copy link
Collaborator Author

@coderabbitai review

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 8, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 8, 2026

Important

Review skipped

Auto incremental reviews are disabled on this repository.

Please check the settings in the CodeRabbit UI or the .coderabbit.yaml file in this repository. To trigger a single review, invoke the @coderabbitai review command.

You can disable this status message by setting the reviews.review_status to false in the CodeRabbit configuration file.

📝 Walkthrough

Walkthrough

Adds parameter-count–based automatic NAS pruning for Minitron alongside existing manual export_config pruning; updates configs/APIs with granular divisors, implements search-and-prune utilities, extends docs and examples, and updates tests and utilities to exercise both pruning modes.

Changes

Cohort / File(s) Summary
Editor config & changelog
\.vscode/settings.json, CHANGELOG.rst
VSCode Python analysis key renamed; changelog documents new params-constrained NAS pruning and updates pruning guidelines link.
Documentation & examples
examples/pruning/README.md
Large rewrite adding Manual vs NAS-based Auto Pruning flows, examples, advanced config, prerequisites, and support matrix.
Minitron NAS & searcher core
modelopt/torch/prune/plugins/mcore_minitron.py
Major addition: param-based pruning/search (CandidateSubnet, grid search, run_search, top-k handling), new get_mcore_param_count, expanded default state/config, updated get_mcore_minitron_config signature (keyword-only granular divisors).
Megatron / NAS plugins
modelopt/torch/nas/plugins/megatron.py, modelopt/torch/nas/search_space.py
Simplified Megatron imports; removed MegatronConstraintsFunc and related constraint wiring; added num_layers_divisor handling; reduced verbose summary printing in sort_parameters.
Pruning module surface
modelopt/torch/prune/__init__.py, modelopt/torch/prune/pruning.py
Removed top-level NAS availability import; clarified export_config applicability wording for mcore_minitron.
Searcher runtime & logging
modelopt/torch/opt/searcher.py, modelopt/torch/utils/logging.py
Added print_rank_0 output in forward-loop orchestration; changed num2hrb unit label from "G" to "B".
Network/param utils & plugin exports
modelopt/torch/utils/network.py, modelopt/torch/utils/plugins/*
Added DynamicModule guard in param_num; clarified param counting via forward pass; added all exports for several megatron plugin modules.
Import utilities
modelopt/torch/utils/import_utils.py
Error message updated for import_plugin failures to include "modelopt" in message.
Tests & test utils (API/usage updates)
tests/_test_utils/torch/megatron/models.py, tests/_test_utils/torch/nas_prune/minitron_common.py
Added mamba_num_heads param to test model factory; prune_minitron signature changed from export_config to constraints wrapper and now builds granular get_mcore_minitron_config.
Tests — NAS / dynamic modules
tests/gpu/torch/nas/plugins/*.py
Updated tests to use granular divisors; added self-attention head-sorting test; adjusted MoE/Mamba hybrid configs.
Tests — pruning flows
tests/gpu/torch/prune/plugins/*.py
Switched prune calls to constraints wrapper, added NAS test paths, integrated MCoreMinitronSearcher utilities, reduced forward-loop iterations, added search-space generation tests.

Sequence Diagram

sequenceDiagram
    participant User
    participant MCoreMinitronSearcher
    participant SearchSpace
    participant Pruner
    participant Model
    participant Validator

    User->>MCoreMinitronSearcher: run_search(model, constraints)
    activate MCoreMinitronSearcher

    MCoreMinitronSearcher->>SearchSpace: _generate_search_space_combos(constraints)
    activate SearchSpace
    SearchSpace-->>MCoreMinitronSearcher: candidate_configs
    deactivate SearchSpace

    MCoreMinitronSearcher->>Pruner: sort_and_select_top_k(candidate_configs)
    activate Pruner
    Pruner-->>MCoreMinitronSearcher: top_k_configs
    deactivate Pruner

    loop for each config in top_k_configs
      MCoreMinitronSearcher->>Model: _prune(model, config)
      activate Model
      Model-->>MCoreMinitronSearcher: pruned_model
      deactivate Model

      MCoreMinitronSearcher->>Validator: evaluate(pruned_model)
      activate Validator
      Validator-->>MCoreMinitronSearcher: score
      deactivate Validator

      MCoreMinitronSearcher->>MCoreMinitronSearcher: record_candidate(config, score)
    end

    MCoreMinitronSearcher-->>User: best_architecture
    deactivate MCoreMinitronSearcher
Loading

Estimated Code Review Effort

🎯 4 (Complex) | ⏱️ ~60 minutes

🚥 Pre-merge checks | ✅ 2 | ❌ 1
❌ Failed checks (1 warning)
Check name Status Explanation Resolution
Docstring Coverage ⚠️ Warning Docstring coverage is 55.77% which is insufficient. The required threshold is 80.00%. Write docstrings for the functions missing them to satisfy the coverage threshold.
✅ Passed checks (2 passed)
Check name Status Explanation
Description Check ✅ Passed Check skipped - CodeRabbit’s high-level summary is enabled.
Title check ✅ Passed The title clearly summarizes the main change: adding NAS-based auto-pruning to Minitron based on parameter constraints, which is the core focus of this PR.

✏️ Tip: You can configure your own custom pre-merge checks in the settings.


Thanks for using CodeRabbit! It's free for OSS, and your support helps us grow. If you like it, consider giving us a shout-out.

❤️ Share

Comment @coderabbitai help to get the list of available commands and usage tips.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 4

🤖 Fix all issues with AI agents
In @.vscode/settings.json:
- Around line 43-45: Add the missing VS Code Python extension setting by adding
"python.analysis.extraPaths": ["./tests/"] alongside the existing
"cursorpyright.analysis.extraPaths": ["./tests/"] in the settings JSON; ensure
the new key is placed at the same object level as
"cursorpyright.analysis.extraPaths" and that the resulting .vscode settings
remain valid JSON (no trailing comma errors) so both IDEs use ./tests/ on the
Python path.

In @examples/pruning/README.md:
- Line 32: Fix the typo in the README sentence by replacing "requisred" with
"required" in the line that reads "For GradNAS pruning for Hugging Face BERT /
GPT-J, no additional dependencies are requisred." so it reads "For GradNAS
pruning for Hugging Face BERT / GPT-J, no additional dependencies are required."

In @modelopt/torch/opt/searcher.py:
- Line 215: The print_rank_0("Running forward loop...") log bypasses the silence
control; move that call inside the existing context manager (with no_stdout() if
silent else nullcontext():) or wrap it with an if not silent check so it only
runs when silent is False; update the code around print_rank_0, the no_stdout()
context usage, and the silent parameter passed from self.config["verbose"] to
ensure the message respects the provided silence flag.
🧹 Nitpick comments (3)
modelopt/torch/utils/network.py (1)

152-152: Optional: Fix grammar in docstring.

The phrase "This can helpful" is missing "be" and should read "This can be helpful".

✏️ Proposed grammar fix
-    This can helpful for MoE or dynamic modules, where the state dict might contain extra parameters that
+    This can be helpful for MoE or dynamic modules, where the state dict might contain extra parameters that
modelopt/torch/prune/plugins/mcore_minitron.py (1)

504-510: Input dictionary mutation may cause side effects.

The method mutates the input search_space dictionary by calling pop() on hparams_to_skip. If the caller expects the dictionary to remain unchanged (e.g., for logging or reuse), this could cause unexpected behavior.

♻️ Create a copy before mutation
+        # Create a copy to avoid mutating the input
+        search_space = dict(search_space)
+
         if hparams_to_skip:
             print_rank_0(f"Skipping {hparams_to_skip=} during search space generation...")
             for hparam in hparams_to_skip:
                 if hparam in search_space:
                     search_space.pop(hparam)
                 else:
                     warn(f"Hparam {hparam} not found in search space! Skipping...")
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (1)

288-289: Hardcoded parameter count may be fragile.

The assertion assert param_count == 31776.0 is a hardcoded check. If model architecture or default configurations change, this will fail silently without indicating what changed.

Consider adding a comment explaining how this value was computed, or computing it from model dimensions to make the test more self-documenting.

📜 Review details

Configuration used: Organization UI

Review profile: CHILL

Plan: Pro

📥 Commits

Reviewing files that changed from the base of the PR and between b655321 and 745e960.

📒 Files selected for processing (20)
  • .vscode/settings.json
  • CHANGELOG.rst
  • examples/pruning/README.md
  • modelopt/torch/nas/plugins/megatron.py
  • modelopt/torch/nas/search_space.py
  • modelopt/torch/opt/searcher.py
  • modelopt/torch/prune/__init__.py
  • modelopt/torch/prune/plugins/mcore_minitron.py
  • modelopt/torch/prune/pruning.py
  • modelopt/torch/utils/logging.py
  • modelopt/torch/utils/network.py
  • modelopt/torch/utils/plugins/megatron_generate.py
  • modelopt/torch/utils/plugins/megatron_mmlu.py
  • modelopt/torch/utils/plugins/megatron_preprocess_data.py
  • tests/_test_utils/torch/megatron/models.py
  • tests/_test_utils/torch/nas_prune/minitron_common.py
  • tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py
  • tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py
  • tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py
  • tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py
💤 Files with no reviewable changes (1)
  • modelopt/torch/prune/init.py
🧰 Additional context used
🧬 Code graph analysis (8)
modelopt/torch/utils/network.py (1)
modelopt/torch/opt/dynamic.py (1)
  • DynamicModule (338-896)
modelopt/torch/opt/searcher.py (2)
modelopt/torch/utils/logging.py (2)
  • no_stdout (99-103)
  • print_rank_0 (106-109)
modelopt/torch/utils/network.py (1)
  • run_forward_loop (510-580)
tests/_test_utils/torch/nas_prune/minitron_common.py (2)
modelopt/torch/prune/pruning.py (1)
  • prune (31-210)
modelopt/torch/prune/plugins/mcore_minitron.py (1)
  • get_mcore_minitron_config (630-657)
tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (3)
modelopt/torch/prune/plugins/mcore_minitron.py (2)
  • get_mcore_param_count (539-545)
  • get_mcore_minitron_config (630-657)
tests/_test_utils/torch/nas_prune/minitron_common.py (1)
  • prune_minitron (19-37)
tests/_test_utils/torch/megatron/utils.py (1)
  • run_mcore_inference_with_dummy_input (122-129)
modelopt/torch/prune/plugins/mcore_minitron.py (2)
modelopt/torch/nas/utils.py (3)
  • get_subnet_config (166-176)
  • sample (137-148)
  • sort_parameters (187-197)
modelopt/torch/nas/search_space.py (2)
  • sample (95-128)
  • sort_parameters (131-164)
modelopt/torch/nas/plugins/megatron.py (4)
modelopt/torch/opt/dynamic.py (2)
  • get_hparam (801-803)
  • get_hparam (1217-1222)
modelopt/torch/utils/network.py (1)
  • make_divisible (200-216)
modelopt/torch/opt/utils.py (1)
  • get_hparam (74-76)
modelopt/torch/opt/hparam.py (3)
  • choices (123-125)
  • choices (128-141)
  • original (157-159)
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (2)
modelopt/torch/prune/plugins/mcore_minitron.py (2)
  • get_mcore_minitron_config (630-657)
  • _generate_search_space_combos (472-536)
tests/_test_utils/torch/nas_prune/minitron_common.py (1)
  • prune_minitron (19-37)
tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py (2)
modelopt/torch/nas/modules/container.py (1)
  • DynamicModuleList (103-131)
modelopt/torch/nas/plugins/megatron.py (9)
  • NumAttentionHeadsHp (256-295)
  • choices (764-766)
  • active_slice (279-295)
  • active_slice (696-705)
  • active_slice (732-748)
  • active (725-729)
  • _get_output_size_indices (317-379)
  • expand_head_indices (248-253)
  • _get_input_size_indices (415-425)
🪛 LanguageTool
examples/pruning/README.md

[grammar] ~32-~32: Ensure spelling is correct
Context: .../ GPT-J, no additional dependencies are requisred. ## Getting Started As part of the pr...

(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)

⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (5)
  • GitHub Check: linux
  • GitHub Check: wait-checks / wait
  • GitHub Check: wait-checks / wait
  • GitHub Check: build-docs
  • GitHub Check: code-quality
🔇 Additional comments (37)
modelopt/torch/utils/logging.py (1)

46-53: LGTM! Aligns with ML community conventions.

The change from "G" (giga) to "B" (billion) for the third magnitude unit is appropriate for this logging utility in an ML context. The ML/AI community standardly uses "B" notation for parameter counts (e.g., "GPT-3 175B", "LLaMA 7B"), and this PR specifically deals with parameter-based pruning where consistency with that convention improves user experience.

tests/_test_utils/torch/megatron/models.py (3)

316-316: LGTM! Backward-compatible parameter addition.

The new mamba_num_heads parameter is correctly defined as optional with a None default, maintaining backward compatibility. The type hint and placement among other Mamba-specific parameters are appropriate.


350-350: LGTM! Parameter correctly propagated to TransformerConfig.

The mamba_num_heads parameter is properly passed through to the TransformerConfig constructor, consistent with how other Mamba-specific parameters are handled.


362-362: LGTM! Helpful version context added.

The comment clarifies that MoE support is available in Megatron-Core 0.16+, which is useful context for understanding the conditional logic.

modelopt/torch/utils/network.py (1)

115-121: LGTM! Proper defensive guard for DynamicModule.

The check correctly prevents incorrect parameter counting for DynamicModule instances where model.parameters() doesn't account for dynamic slicing. The error message clearly directs users to the appropriate alternative function.

modelopt/torch/utils/plugins/megatron_mmlu.py (1)

50-51: LGTM! Clean public API declaration.

The __all__ declaration correctly exports megatron_mmlu as the module's public interface. This aligns with similar updates across other plugin modules and follows Python best practices for explicit API exposure.

modelopt/torch/utils/plugins/megatron_generate.py (1)

27-28: LGTM! Correctly exports both public generation functions.

The __all__ declaration appropriately exports megatron_generate and megatron_prefill as the module's public API. Internal utilities like get_current_memory_info are correctly excluded from the export list.

modelopt/torch/utils/plugins/megatron_preprocess_data.py (1)

45-46: LGTM! Proper public API declaration for the preprocessing utility.

The __all__ declaration correctly exports megatron_preprocess_data as the module's public function. Private implementation classes (_Encoder, _Partition) and the CLI-only main() function are appropriately excluded.

modelopt/torch/opt/searcher.py (1)

38-38: LGTM!

The import addition is correct and properly used in the new log statement.

modelopt/torch/nas/search_space.py (1)

138-138: LGTM - Simplified logging output.

The removal of verbose summary printing streamlines the output while retaining the essential sorting notification.

modelopt/torch/prune/pruning.py (1)

81-81: LGTM - Documentation accurately reflects the updated API.

The wording change correctly indicates that export_config is now optional (for manual pruning) rather than required, aligning with the new NAS-based auto pruning mode that uses params constraints instead.

CHANGELOG.rst (2)

17-17: LGTM - Clear feature documentation.

The changelog entry accurately describes the new NAS-based auto pruning capability and references the appropriate documentation.


84-84: LGTM - Improved documentation link.

The direct anchor link provides better user experience by navigating directly to the pruning guidelines.

tests/gpu/torch/nas/plugins/test_megatron_mamba_dynamic_modules.py (2)

54-54: LGTM - Helpful clarification.

The comment clearly explains the pattern semantics for better test readability.


78-91: LGTM - Test updated to use new configuration API.

The test correctly uses the new get_mcore_minitron_config helper with explicit divisor parameters, aligning with the broader API changes while preserving test semantics.

examples/pruning/README.md (4)

28-28: LGTM - Updated prerequisites.

The container version update and pip upgrade instruction ensure users have the latest dependencies for the new features.


40-167: LGTM - Excellent documentation structure.

The comprehensive documentation clearly distinguishes between Manual Pruning and NAS-based Auto Pruning modes, with:

  • Clear mode descriptions and use cases
  • Well-structured code examples
  • Detailed "Under the Hood" explanations
  • Helpful tips and notes
  • Advanced configuration guidance

This will significantly improve user understanding and adoption of the new NAS-based auto pruning feature.


172-175: LGTM - Clear support matrix.

The updated matrix effectively distinguishes between manual (export_config) and auto (params with score_func) pruning modes for Minitron.


179-206: LGTM - Well-organized examples section.

The reorganized examples section provides clear pathways for users to find relevant framework-specific documentation and examples, with good references to external resources and tutorials.

modelopt/torch/prune/plugins/mcore_minitron.py (7)

166-171: LGTM - Clean dataclass definition.

The CandidateSubnet dataclass is a clean way to encapsulate subnet candidate information. Consider adding __slots__ for memory efficiency if many candidates are expected, but this is optional.


227-230: LGTM - Clear constraint validation.

The assertion properly validates that exactly one constraint type is provided and that it's one of the supported types (export_config or params).


322-365: LGTM - Comprehensive pruning logic.

The _prune method correctly handles both width and depth pruning, updates model config, and reinitializes MoE token dispatcher. The kv_channels handling for None case is a good defensive measure.


630-657: LGTM - Well-structured configuration API.

The keyword-only signature enforces explicit parameter naming, improving code clarity. The recursive _set_divisors helper correctly handles nested configuration dictionaries.


154-161: LGTM - Clean layer filtering and reindexing.

The refactored code cleanly separates the filtering of kept layers from their reindexing, making the logic easier to follow.


875-875: LGTM - Helpful logging for checkpoint restoration.

The added log message improves observability when activations and scores are loaded from a checkpoint.


583-588: Parameter counting logic for decoder layers is correct and properly handles selective layer counting.

The condition ("decoder.layers." not in name or layer_numbers_to_count is None) correctly excludes decoder layer parameters from the initial sum when layer_numbers_to_count is not None, then adds them back selectively in the loop. This is tested indirectly through the pruning test suite, including uneven pipeline parallelism scenarios.

modelopt/torch/nas/plugins/megatron.py (2)

27-48: LGTM - Cleaned up imports.

The import changes remove unused parallel state utilities and add the required make_divisible utility.


1037-1043: LGTM - Cleaner divisor handling with loop.

The refactored loop pattern for applying divisors to hidden_size and num_layers is cleaner and more maintainable than separate blocks of code. The preservation of the original value via | {hp.original} ensures the model can always restore to its original configuration.

tests/_test_utils/torch/nas_prune/minitron_common.py (1)

19-36: LGTM - API updated for constraints-based pruning.

The signature change from export_config to constraints properly reflects the expanded pruning API that now supports both export_config and params constraint types. The explicit divisor parameters in get_mcore_minitron_config provide clear configuration.

Note: The hardcoded divisors (mamba_head_dim_divisor=4, num_moe_experts_divisor=1, num_layers_divisor=1) may need to be parameterized if tests require different configurations in the future.

tests/gpu/torch/prune/plugins/test_mcore_mamba_minitron_pruning.py (4)

86-91: LGTM - Updated config with explicit divisors.

The explicit divisor parameters make the test configuration clearer and align with the updated get_mcore_minitron_config API.


175-176: Verify reduced iterations don't affect test reliability.

Forward loop iterations reduced from 5 to 2. While this speeds up tests, ensure it still collects sufficient activations for reliable importance estimation in the test scenarios.


345-373: Good handling of non-deterministic layer ordering.

The test correctly handles the fact that layer ordering can vary based on PP configuration by checking sorted_layers and providing expected values for known orderings. The RuntimeError with a clear message (line 372) helps identify when new orderings need to be handled.

The hardcoded expected values are extensive but necessary for deterministic validation. Consider extracting them to a constant or helper function if they need to be reused or modified frequently.


385-390: LGTM - Appropriate test skipping conditions.

The test correctly skips for configurations that aren't supported:

  • More than 4 GPUs (line 388)
  • MoE support in Mamba (line 390) with a TODO note
tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (2)

205-215: LGTM - Consistent constraint wrapper usage.

The test correctly wraps the export_config in a constraints dictionary and passes it to prune_minitron, aligning with the updated API.


510-534: LGTM - Thorough unit test for search space generation.

The test validates the _generate_search_space_combos method with:

  • Width pruning (50%) correctly filters hidden_size to [96, 128, 160] and num_attention_heads to [24, 32]
  • Depth pruning (25%) correctly filters num_layers to [7, 8]
  • hparams_to_skip correctly removes ffn_hidden_size from the search space
  • Total combinations: 3 × 2 × 2 = 12 ✓

The explicit verification of all 12 combinations ensures deterministic behavior.

tests/gpu/torch/nas/plugins/test_megatron_gpt_dynamic_modules.py (2)

86-98: LGTM - Explicit config for search space test.

The updated mtn.convert call with explicit divisor configuration makes the test setup clearer and more maintainable.


170-236: LGTM - Comprehensive test for attention head sorting and pruning.

This test thoroughly validates the NumAttentionHeadsHp class:

  1. Verifies choices are multiples of num_query_groups
  2. Tests importance-based ranking with enforce_order
  3. Validates active_slice returns correct indices after pruning
  4. Verifies QKV and projection layer index mappings are correct

The cleanup via destroy_model_parallel() is important since this test doesn't use the spawned multiprocess pattern.

"""
from modelopt.torch.opt.dynamic import DynamicModule

if isinstance(network, DynamicModule):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this seems to be applicable only for pruning. Quantization have QuantModule base block - should we have something similar for pruning and do:

if isinstance(network, PruneModule):

or

if hasattr(network, "some_special_pruning_attr")

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Since quantization doesnt change param count, I dont see a reason to use this function. Besides, users can still use this on Quantized + Exported model since its not longer DynamicModule. Currently we will just raise an exception if Quantized DynamicModule but later if we need to support it, we can think of such a way. Wdyt?

@kevalmorabia97
Copy link
Collaborator Author

@coderabbitai review

@coderabbitai
Copy link
Contributor

coderabbitai bot commented Jan 19, 2026

✅ Actions performed

Review triggered.

Note: CodeRabbit is an incremental review system and does not re-review already reviewed commits. This command is applicable only when automatic reviews are paused.

Copy link
Contributor

@coderabbitai coderabbitai bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actionable comments posted: 1

🤖 Fix all issues with AI agents
In `@modelopt/torch/prune/plugins/mcore_minitron.py`:
- Around line 543-549: The code is mutating the input dict search_space when
hparams_to_skip is provided; instead create a local shallow copy (e.g.,
new_search_space = dict(search_space) or use search_space.copy()) before
performing the pop operations so the caller's dict isn't changed, then operate
on and return/use that local copy for the rest of the search-space generation
path; update the block that references hparams_to_skip, the pop calls, and any
subsequent uses of search_space in the surrounding function (keep print_rank_0
and warn calls as-is but point them at the copied dict).
♻️ Duplicate comments (3)
.vscode/settings.json (1)

43-45: Add VS Code Python extension extraPaths alongside Cursor.

This drops python.analysis.extraPaths, which likely breaks the documented VS Code setup. Please add it back alongside cursorpyright.analysis.extraPaths.

✅ Proposed fix
     "cursorpyright.analysis.extraPaths": [
         "./tests/" // add tests to python path just like pytest does in pyproject.toml
-    ],
+    ],
+    "python.analysis.extraPaths": [
+        "./tests/"
+    ],
examples/pruning/README.md (1)

32-32: Fix typo: "requisred" should be "required".

📝 Proposed fix
-For GradNAS pruning for Hugging Face BERT / GPT-J, no additional dependencies are requisred.
+For GradNAS pruning for Hugging Face BERT / GPT-J, no additional dependencies are required.
modelopt/torch/prune/plugins/mcore_minitron.py (1)

479-491: Layer reversion logic may still leave model in inconsistent state.

The reversion code saves all_layers and start_layer_number before _prune, then attempts to restore after scoring. However, _prune calls drop_mcore_language_model_layers which modifies model.config.num_layers permanently (line 163 in drop_mcore_language_model_layers). The sample(max) call only resets hyperparameters, not model.config.num_layers.

This could cause assertions to fail at line 304 (assert sorted(self.sorted_layers) == list(range(1, self.model.config.num_layers + 1))) on subsequent iterations since model.config.num_layers would be the pruned count, not the original.

🛠️ Suggested fix

Save and restore model.config.num_layers alongside the layer list:

             if candidate.score is None:  # not restored from checkpoint
                 all_layers = self.model.decoder.layers
                 start_layer_number = all_layers[0].layer_number
+                original_num_layers = self.model.config.num_layers

                 self._prune(candidate.ss_config, prune_depth=True)
                 candidate.score = self.eval_score(silent=False)
                 self.save_search_checkpoint(verbose=False)

                 # reset to max subnet and revert dropped layers
                 sample(self.model, sample_func=max)
                 for layer in all_layers:
                     layer.layer_number = start_layer_number
                     start_layer_number += 1
                 self.model.decoder.layers = all_layers
+                self.model.config.num_layers = original_num_layers
🧹 Nitpick comments (3)
tests/_test_utils/torch/megatron/models.py (1)

316-352: Validate mamba_num_heads constraints (if required by Megatron).

If upstream expects mamba_num_heads to be compatible with hidden_size/mamba_head_dim, add a fast‑fail check so tests fail early with a clear error. Please verify the current Megatron TransformerConfig/Mamba requirements before adding the guard.

Possible guard (adjust per upstream rules)
 def get_mcore_mamba_hybrid_model(
@@
     config = TransformerConfig(
@@
         mamba_state_dim=mamba_state_dim,
         mamba_num_heads=mamba_num_heads,
         mamba_head_dim=mamba_head_dim,
         mamba_num_groups=mamba_num_groups,
@@
     )
+    if mamba_num_heads is not None:
+        assert mamba_num_heads > 0
+        # If upstream expects divisibility or exact match, enforce it here:
+        # assert hidden_size % mamba_num_heads == 0
+        # assert mamba_num_heads * mamba_head_dim == hidden_size
modelopt/torch/prune/plugins/mcore_minitron.py (1)

166-171: Consider using frozen=True for the dataclass.

CandidateSubnet is used as a data container where ss_config and params should be immutable after creation. Only score gets updated. Consider using frozen=True or making score explicitly mutable via a different pattern to prevent accidental modifications.

tests/gpu/torch/prune/plugins/test_mcore_gpt_minitron_pruning.py (1)

510-534: Good test coverage for the new NAS search space generation.

The test thoroughly verifies _generate_search_space_combos:

  • Validates correct filtering based on max_width_pruning and max_depth_pruning thresholds.
  • Confirms hparams_to_skip properly excludes ffn_hidden_size.
  • Asserts both the count and exact expected combinations.

One minor note: This test doesn't require GPU and could potentially be placed in a non-GPU test module to speed up CI when GPU resources are constrained. Consider moving to a unit test file if reorganizing tests in the future.

Signed-off-by: Keval Morabia <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants