Structural pruning defrag — mask first, compact later (like disk defrag) #94

@joelteply

Insight from Joel

Pruning should work in two phases, like disk defragmentation:

Phase 1: Mask (fast, reversible)

  • Mark heads as pruned with a mask (don't modify weights)
  • Forward hooks zero the output of masked heads
  • Model still has same tensor dimensions — no restructuring
  • Can unmask if we change our mind
  • Training sees the mask via hooks, not zeroed weights (avoids NaN gradients)
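A minimal sketch of Phase 1, assuming PyTorch and an attention block whose output projection receives the concatenated heads as `(batch, seq, num_heads * head_dim)`. It uses a forward pre-hook on that projection as one possible way to zero masked heads. The helper `make_head_mask_hook` and the toy `out_proj` are hypothetical names for illustration, not part of the project.

```python
import torch
import torch.nn as nn

def make_head_mask_hook(keep, num_heads):
    """Return a forward pre-hook that zeros masked heads in the input of an
    output projection. `keep` is a bool tensor of shape (num_heads,);
    False means the head is masked off. (Hypothetical helper.)"""
    def hook(module, args):
        (x,) = args                                   # (batch, seq, num_heads * head_dim)
        b, s, d = x.shape
        per_head = x.reshape(b, s, num_heads, d // num_heads)
        per_head = per_head * keep.view(1, 1, num_heads, 1).to(x.dtype)
        return (per_head.reshape(b, s, d),)           # replaces the module's input
    return hook

torch.manual_seed(0)
num_heads, head_dim = 4, 8
# Toy stand-in for an attention output projection (assumption).
out_proj = nn.Linear(num_heads * head_dim, 16, bias=False)
weights_before = out_proj.weight.detach().clone()
keep = torch.tensor([True, False, True, True])        # head 1 is masked off

x = torch.randn(2, 5, num_heads * head_dim, requires_grad=True)
handle = out_proj.register_forward_pre_hook(make_head_mask_hook(keep, num_heads))
masked_out = out_proj(x)
masked_out.sum().backward()                           # gradients flow through the mask

handle.remove()                                       # unmask: weights were never modified
```

Removing the hook handle restores the original behavior, which is what makes this phase reversible. The masked head receives zero gradient through the hook while the stored weights stay as they were.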

Phase 2: Defrag/Compact (batch, permanent)

  • After accumulating enough pruned heads (e.g., >20% pruned)
  • Actually restructure tensors: remove masked dimensions
  • Model gets physically smaller (less VRAM, faster inference)
  • Save the compacted model
  • Like defrag: data is gone, now reclaim the space
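Phase 2 can be sketched as follows, assuming PyTorch and a simplified block with only a value projection and an output projection (the attention step between them is omitted because it acts per head). `kept_channels` and `compact_heads` are hypothetical helpers. A real model would also need its query/key projections and its head-count config updated.

```python
import torch
import torch.nn as nn

def kept_channels(keep, head_dim):
    """Channel indices that belong to the heads marked True in `keep`."""
    heads = torch.arange(keep.numel())[keep]
    return (heads[:, None] * head_dim + torch.arange(head_dim)).reshape(-1)

def compact_heads(v_proj, out_proj, keep, head_dim):
    """Build smaller layers without the masked heads (assumes both have a bias)."""
    idx = kept_channels(keep, head_dim)
    new_v = nn.Linear(v_proj.in_features, idx.numel())
    new_o = nn.Linear(idx.numel(), out_proj.out_features)
    with torch.no_grad():
        new_v.weight.copy_(v_proj.weight[idx])        # drop output rows of masked heads
        new_v.bias.copy_(v_proj.bias[idx])
        new_o.weight.copy_(out_proj.weight[:, idx])   # drop matching input columns
        new_o.bias.copy_(out_proj.bias)
    return new_v, new_o

torch.manual_seed(0)
hidden, num_heads, head_dim = 16, 4, 4
v_proj = nn.Linear(hidden, num_heads * head_dim)
out_proj = nn.Linear(num_heads * head_dim, hidden)
keep = torch.tensor([True, False, True, False])       # two of four heads pruned
x = torch.randn(2, 3, hidden)

with torch.no_grad():
    channel_mask = keep.float().repeat_interleave(head_dim)
    y_masked = out_proj(v_proj(x) * channel_mask)     # Phase 1 behavior
    new_v, new_o = compact_heads(v_proj, out_proj, keep, head_dim)
    y_compact = new_o(new_v(x))                       # Phase 2 behavior
```

In this toy setup the compacted layers reproduce the masked output while holding fewer parameters, which is the property the defrag step relies on.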

Why This Matters

The current approach zeros the weights directly, which produces NaN during the backward pass (0/0 in the softmax gradients). Masking avoids this because the base weights stay intact: the mask only zeros a head's output AFTER the computation.

Threshold-based Defrag

  • Don't defrag after every prune cycle (expensive restructuring)
  • Defrag when cumulative pruning exceeds threshold (20%, 30%, etc.)
  • Or defrag at the END of all forge cycles, once
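The threshold rule could look like the sketch below. It assumes `pruning_mask` is a list of per-layer lists of booleans where False means masked off, matching the convention in the Implementation snippet. `pruned_fraction`, `should_defrag`, and the `final_cycle` flag are hypothetical names.

```python
def pruned_fraction(pruning_mask):
    """Fraction of heads marked False (pruned) across all layers."""
    total = sum(len(layer) for layer in pruning_mask)
    pruned = sum(1 for layer in pruning_mask for keep in layer if not keep)
    return pruned / total if total else 0.0

def should_defrag(pruning_mask, threshold=0.20, final_cycle=False):
    """Defrag once cumulative pruning exceeds the threshold,
    or unconditionally after the last forge cycle."""
    return final_cycle or pruned_fraction(pruning_mask) > threshold

# Two layers with four heads each, nothing pruned yet.
pruning_mask = [[True] * 4 for _ in range(2)]

pruning_mask[0][1] = False                  # 1 of 8 heads pruned (12.5%)
early = should_defrag(pruning_mask)         # below the 20% threshold

pruning_mask[1][0] = False
pruning_mask[1][3] = False                  # 3 of 8 heads pruned (37.5%)
later = should_defrag(pruning_mask)         # above the 20% threshold
```

Here `early` is False and `later` is True, so the expensive restructuring runs only once several prune cycles have accumulated.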

Implementation

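# Phase 1: mask (True = keep, False = pruned)
pruning_mask[layer_idx][head_idx] = False  # mark head as pruned; weights stay untouched
# Forward hook zeros the masked heads in the per-head output:
#   output[:, :, masked_heads] = 0

# Phase 2: defrag (once the threshold is reached, or after the last forge cycle)
model = structurally_prune(model, pruning_mask)  # remove masked heads from the tensors
# Tensors are resized; the model is physically smaller
model.save_pretrained("compacted/")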

Dependencies
