Overview#
Pruning, the removal of unnecessary parameters from a neural network, has been a core compression technique since the late 1980s. The foundational ideas, Optimal Brain Damage (LeCun et al., 1989) and Optimal Brain Surgeon (Hassibi & Stork, 1993), were developed for networks with thousands of parameters. Convolutional neural networks (CNNs) with millions of parameters became the primary testbed for pruning research throughout the 2010s. But the rise of large language models (LLMs) with billions of parameters has fundamentally changed the pruning landscape. Nearly every assumption from the CNN pruning era must be revisited.
Why LLM Pruning Is Different from CNN Pruning#
In classical CNN pruning, the standard workflow is: (1) train a dense model to convergence, (2) prune according to some criterion, (3) fine-tune (retrain) the pruned model to recover accuracy. This “prune-then-retrain” loop can be repeated iteratively, sometimes achieving extreme sparsity levels (95%+) with minimal accuracy loss.
For LLMs, this workflow is largely impractical:
| Aspect | CNN Pruning | LLM Pruning |
|---|---|---|
| Model size | 5M-60M parameters | 7B-175B+ parameters |
| Training cost | Hours to days on 1-8 GPUs | Weeks to months on 1000s of GPUs |
| Training data | Well-defined datasets (ImageNet) | Trillions of tokens, often proprietary |
| Retraining feasibility | Standard practice | Prohibitively expensive |
| Task scope | Single task (classification) | General-purpose (generation, reasoning, QA, …) |
| Architecture | Conv layers dominate | Attention + MLP, non-convolutional |
| Activation patterns | ReLU gives natural sparsity | GeLU/SiLU — no natural sparsity |
| Sensitivity to pruning | Gradually degrades | Can catastrophically collapse |
The key consequence is that LLM pruning methods must work without retraining — either as one-shot post-training methods or with only minimal calibration. This constraint has driven an entirely new family of algorithms.
Scale Challenges#
Consider the scale of modern LLMs:
- LLaMA-2 70B: 70 billion parameters, requiring 140 GB in FP16. Training cost estimated at $2-5 million.
- GPT-3 175B: 175 billion parameters, requiring 350 GB in FP16. Training cost estimated at $5-12 million on 2020 hardware.
- Llama 3.1 405B: 405 billion parameters, requiring 810 GB in FP16.
Even a single epoch of fine-tuning on these models requires massive compute. For LLaMA-70B, one pass over the RedPajama dataset (~1.2T tokens) at an estimated 300 tokens/sec/GPU on 8 A100s would take approximately:
$$\text{Time} = \frac{1.2 \times 10^{12}}{300 \times 8} \approx 5 \times 10^{8} \text{ seconds} \approx 15.8 \text{ years}$$
Even with 1024 GPUs, that is still ~45 days. This makes iterative prune-retrain cycles effectively impossible for most practitioners.
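The estimate above can be reproduced with a few lines of arithmetic. This is a minimal sketch that assumes perfect linear scaling across GPUs; the helper name `epoch_seconds` is illustrative, and the throughput figure is the estimate quoted in the text, not a measurement.

```python
SECONDS_PER_DAY = 86_400
SECONDS_PER_YEAR = 365.25 * SECONDS_PER_DAY


def epoch_seconds(total_tokens: float, tokens_per_sec_per_gpu: float, n_gpus: int) -> float:
    """Wall-clock seconds for one pass over the data, assuming linear scaling."""
    return total_tokens / (tokens_per_sec_per_gpu * n_gpus)


tokens = 1.2e12  # approximate RedPajama size quoted in the text
t_8 = epoch_seconds(tokens, 300, 8)
t_1024 = epoch_seconds(tokens, 300, 1024)

print(f"8 GPUs:    {t_8 / SECONDS_PER_YEAR:.1f} years")
print(f"1024 GPUs: {t_1024 / SECONDS_PER_DAY:.1f} days")
```

Real clusters scale sub-linearly, so these figures are best read as optimistic lower bounds.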
Memory vs. Compute Bottleneck in LLM Inference#
LLM inference is overwhelmingly memory-bandwidth bound, not compute bound. During autoregressive generation, each new token requires reading the entire model from memory but performs relatively little computation (a single matrix-vector product per layer). The arithmetic intensity is:
$$\text{Arithmetic Intensity} = \frac{\text{FLOPs}}{\text{Bytes Accessed}} \approx \frac{2 \times d_{\text{model}}^2}{d_{\text{model}}^2 \times \text{bytes per param}} = \frac{2}{\text{bytes per param}}$$
Here a \(d_{\text{model}} \times d_{\text{model}}\) weight matrix contributes one multiply and one add per weight (2 FLOPs), and each weight is read from memory once. For FP16 (2 bytes per parameter), this gives an arithmetic intensity of about 1 FLOP/byte, far below the compute-to-bandwidth ratio of modern GPUs (typically 50-200 FLOPs/byte for Tensor Cores). This means that reducing the number of parameters directly reduces inference latency, because the bottleneck is reading weights from memory.
Pruning therefore has a direct path to speedup — provided the sparsity pattern is hardware-friendly. Unstructured sparsity reduces parameter count but may not reduce memory traffic without sparse format support. Structured sparsity (2:4 patterns, head removal, layer removal) offers more straightforward hardware acceleration.
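To make the bandwidth argument concrete, the sketch below estimates a lower bound on per-token decode latency from weight traffic alone. The bandwidth value (2 TB/s) and the `density` parameter are illustrative assumptions; the model also assumes that a sparse format stores only the surviving weights with no index overhead, which real formats do not achieve.

```python
def weight_bytes(n_params: float, bytes_per_param: float = 2.0) -> float:
    """Bytes needed to store the weights (FP16 by default)."""
    return n_params * bytes_per_param


def min_seconds_per_token(
    n_params: float,
    bytes_per_param: float = 2.0,
    bandwidth_bytes_per_s: float = 2.0e12,  # assumed aggregate memory bandwidth
    density: float = 1.0,                   # fraction of weights kept after pruning
) -> float:
    """Lower bound on decode latency if every stored weight is read once per token."""
    return weight_bytes(n_params * density, bytes_per_param) / bandwidth_bytes_per_s


n = 70e9  # a 70B-parameter model
print(f"FP16 weights: {weight_bytes(n) / 1e9:.0f} GB")
print(f"dense:      {min_seconds_per_token(n) * 1e3:.1f} ms/token")
print(f"50% sparse: {min_seconds_per_token(n, density=0.5) * 1e3:.1f} ms/token")
```

The halved figure is only reachable when the hardware and kernel can actually skip the pruned weights, which is why the sparsity pattern matters.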
LLM Architecture and Pruning Targets#
To prune effectively, we must understand exactly where the parameters live in a transformer-based LLM. Modern decoder-only LLMs (GPT, LLaMA, Mistral, etc.) share a common architecture.
Transformer Block Anatomy#
Each transformer block consists of two main sub-blocks: Multi-Head Self-Attention (MHA) and a Feed-Forward Network (MLP). In LLaMA-style architectures, the MLP uses a gated structure (SwiGLU).
┌─────────────────────────────────────────────────────────┐
│ Transformer Block l │
│ │
│ ┌───────────────────────────────────────────────────┐ │
│ │ RMSNorm (attn_norm) │ │
│ │ Params: d_model │ │
│ └─────────────────────┬─────────────────────────────┘ │
│ │ │
│ ┌─────────────────────▼─────────────────────────────┐ │
│ │ Multi-Head Self-Attention │ │
│ │ │ │
│ │ ┌──────┐ ┌──────┐ ┌──────┐ │ │
│ │ │ W_Q │ │ W_K │ │ W_V │ (Linear projs) │ │
│ │ │d x d │ │d x d'│ │d x d'│ d' = d for MHA │ │
│ │ └──┬───┘ └──┬───┘ └──┬───┘ d' = d/GQA for GQA│ │
│ │ │ │ │ │ │
│ │ ▼ ▼ ▼ │ │
│ │ ┌──────────────────────┐ │ │
│ │ │ Scaled Dot-Product │ n_heads parallel │ │
│ │ │ Attention per head │ d_head = d / n_heads │ │
│ │ └──────────┬───────────┘ │ │
│ │ │ │ │
│ │ ┌──────────▼───────────┐ │ │
│ │ │ W_O │ (Output projection) │ │
│ │ │ d_model x d_model │ │ │
│ │ └──────────┬───────────┘ │ │
│ │ │ │ │
│ └──────────────┼────────────────────────────────────┘ │
│ │ │
│ (+ residual) │
│ │ │
│ ┌──────────────▼────────────────────────────────────┐ │
│ │ RMSNorm (ffn_norm) │ │
│ │ Params: d_model │ │
│ └──────────────┬────────────────────────────────────┘ │
│ │ │
│ ┌──────────────▼────────────────────────────────────┐ │
│ │ MLP (SwiGLU) │ │
│ │ │ │
│ │ ┌────────────┐ ┌────────────┐ │ │
│ │ │ W_gate │ │ W_up │ │ │
│ │ │ d x d_ffn │ │ d x d_ffn │ │ │
│ │ └─────┬──────┘ └─────┬──────┘ │ │
│ │ │ │ │ │
│ │ ▼ ▼ │ │
│ │ SiLU(x) * linear(x) │ │
│ │ └───────┬───────┘ │ │
│ │ ▼ │ │
│ │ ┌──────────────┐ │ │
│ │ │ W_down │ │ │
│ │ │ d_ffn x d │ │ │
│ │ └──────┬───────┘ │ │
│ │ │ │ │
│ └────────────────┼──────────────────────────────────┘ │
│ │ │
│ (+ residual) │
│ │ │
│ ▼ │
│ Output to next block │
└─────────────────────────────────────────────────────────┘
Parameter Distribution#
For a LLaMA-style model with hidden dimension \(d\), FFN dimension \(d_{\text{ffn}}\), \(L\) layers, vocabulary size \(V\), and GQA groups \(n_{\text{kv}}\) (with \(n_h\) attention heads):
| Component | Parameters per Layer | LLaMA-7B (d=4096, d_ffn=11008, L=32) | % of Model Total (summed over 32 layers) |
|---|---|---|---|
| \(W_Q\) | \(d \times d\) | 16,777,216 | 8.0% |
| \(W_K\) | \(d \times d\) | 16,777,216 | 8.0% |
| \(W_V\) | \(d \times d\) | 16,777,216 | 8.0% |
| \(W_O\) | \(d \times d\) | 16,777,216 | 8.0% |
| \(W_{\text{gate}}\) | \(d \times d_{\text{ffn}}\) | 45,088,768 | 21.4% |
| \(W_{\text{up}}\) | \(d \times d_{\text{ffn}}\) | 45,088,768 | 21.4% |
| \(W_{\text{down}}\) | \(d_{\text{ffn}} \times d\) | 45,088,768 | 21.4% |
| RMSNorm (x2) | \(2d\) | 8,192 | ~0% |
| Attention total | \(4d^2\) | 67,108,864 | ~31.9% |
| MLP total | \(3d \cdot d_{\text{ffn}}\) | 135,266,304 | ~64.2% |
| Per-layer total | — | 202,375,168 | ~96.1% |
Across all 32 layers: \(32 \times 202{,}375{,}168 \approx 6.48 \times 10^9\) parameters in transformer blocks. Adding the embedding layer (\(V \times d = 32000 \times 4096 \approx 131M\)) and the final LM head (another \(\approx 131M\), since it is a separate matrix of the same shape), the total comes to approximately 6.74 billion parameters.
Key observation: The MLP layers account for roughly two-thirds of each transformer block’s parameters. This makes them a primary pruning target. The attention layers, while smaller, contain highly structured redundancy (many heads learn similar patterns).
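The counts above follow directly from the layer shapes. The sketch below recomputes them under the assumptions of standard multi-head attention (no GQA), an untied LM head, and one final RMSNorm; the function name `llama_param_counts` is illustrative.

```python
def llama_param_counts(d: int = 4096, d_ffn: int = 11008, n_layers: int = 32, vocab: int = 32000) -> dict:
    """Parameter counts for a LLaMA-style decoder (MHA, untied LM head assumed)."""
    attn = 4 * d * d              # W_Q, W_K, W_V, W_O
    mlp = 3 * d * d_ffn           # W_gate, W_up, W_down
    norms = 2 * d                 # two RMSNorm weight vectors per block
    block = attn + mlp + norms
    embed = vocab * d
    lm_head = vocab * d
    final_norm = d
    total = n_layers * block + embed + lm_head + final_norm
    return {"attn": attn, "mlp": mlp, "block": block, "total": total}


counts = llama_param_counts()
mlp_share = counts["mlp"] / counts["block"]
print(f"attention per layer: {counts['attn']:,}")
print(f"MLP per layer:       {counts['mlp']:,}")
print(f"MLP share of block:  {mlp_share:.1%}")
print(f"model total:         {counts['total']:,}")
```

Changing `d_ffn` or `n_layers` shows how quickly the MLP share grows with wider feed-forward layers.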
Which Components Are Most Redundant?#
Empirical studies consistently find:
- MLP layers tolerate higher sparsity than attention layers. The gate-up-down structure in SwiGLU creates natural redundancy — many neurons activate only for specific input patterns.
- Middle layers of the network are more compressible than the first and last few layers. The first layers learn low-level token representations; the last layers directly drive the output distribution. Both are sensitive to perturbation.
- Attention heads show extreme variance in importance. In a 32-head layer, often 8-12 heads can be removed with minimal impact, while 2-3 heads are absolutely critical (removing any one of them causes significant quality degradation).
- The embedding layer is large (131M in LLaMA-7B) but highly critical — it is the only interface between discrete tokens and continuous representations. Pruning the embedding table is rarely done.
Challenges Unique to LLM Pruning#
Retraining Is Prohibitively Expensive#
As computed above, even a single epoch of retraining on the original data is infeasible for most organizations. But the problem is actually worse than the raw compute cost suggests:
- Training data may be unavailable. Many LLMs are trained on proprietary datasets. Even for open-weight models like LLaMA, the exact training data mix and preprocessing are not fully reproducible.
- Hyperparameter sensitivity. Fine-tuning a pruned LLM requires careful learning rate schedules. Too high a learning rate causes catastrophic forgetting; too low fails to recover from pruning damage. This requires expensive sweeps.
- Multi-task generalization. Unlike CNNs (where we fine-tune for one task), LLMs must maintain performance across thousands of tasks simultaneously. Retraining on any single task’s data degrades others.
This motivates one-shot pruning (prune once, no retraining) and few-shot calibration (use a small calibration dataset to guide pruning decisions, but do not update the model through backpropagation).
Activation Outliers in LLMs#
A phenomenon unique to large-scale transformers is the emergence of activation outliers — a small number of hidden dimensions that consistently produce activation magnitudes 10-100x larger than the rest. This was first systematically documented by Dettmers et al. (2022) in the context of quantization (the “LLM.int8()” paper) but has profound implications for pruning.
Consider a weight matrix \(W \in \mathbb{R}^{m \times n}\) applied to input \(x \in \mathbb{R}^n\). The output for row \(i\) is:
$$y_i = \sum_{j=1}^{n} W_{ij} x_j$$
If feature dimension \(j^*\) consistently has \(|x_{j^*}| \gg |x_j|\) for all other \(j\), then removing (pruning) weight \(W_{ij^*}\) eliminates a disproportionately large contribution to the output, even if \(W_{ij^*}\) itself is small. This is precisely the insight that motivates activation-aware pruning (Wanda).
Activation outliers typically appear in fewer than 1% of hidden dimensions but contribute over 50% of the output magnitude. Dettmers et al. report that they become systematic across layers at around 6.7 billion parameters, and they grow more extreme as the model grows.
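A small synthetic experiment illustrates the effect. The data below is random and the 50x outlier scale is a made-up value for illustration; the point is only that a small weight on an outlier feature can matter more than a larger weight on an ordinary feature.

```python
import numpy as np

rng = np.random.default_rng(0)
n_features, n_tokens = 8, 256
X = rng.normal(size=(n_features, n_tokens))
X[3] *= 50.0  # feature 3 plays the role of an outlier dimension (hypothetical scale)

w = np.full(n_features, 0.5)
w[3] = 0.05   # deliberately small weight attached to the outlier feature


def output_error_if_pruned(w: np.ndarray, X: np.ndarray, j: int) -> float:
    """RMS change in y = w @ X when w[j] is zeroed without compensation."""
    return float(np.sqrt(np.mean((w[j] * X[j]) ** 2)))


err_outlier = output_error_if_pruned(w, X, 3)
err_normal = output_error_if_pruned(w, X, 0)
print(f"|w| = {w[3]:.2f} on outlier feature  -> RMS error {err_outlier:.3f}")
print(f"|w| = {w[0]:.2f} on ordinary feature -> RMS error {err_normal:.3f}")
```

Magnitude pruning would remove the weight on feature 3 first, even though it is the more damaging choice here.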
Attention Pattern Diversity#
Not all attention heads serve the same function. Empirical analysis of LLMs reveals distinct head types:
- Positional heads: Attend to nearby tokens (local context). These implement n-gram-like patterns.
- Retrieval heads: Attend to specific semantic content regardless of position. Critical for factual recall.
- Induction heads: Copy patterns from earlier in the context. Essential for in-context learning.
- Sink heads: Attend primarily to the first token or special tokens. These are often important for model stability but carry little semantic information.
Pruning a retrieval head may destroy factual knowledge while pruning a redundant positional head may have negligible effect. Any head pruning strategy must account for this heterogeneity.
The Calibration Data Problem#
One-shot pruning methods (SparseGPT, Wanda) require a small calibration dataset to estimate the Hessian or compute activation statistics. The choice of calibration data matters significantly:
- Too narrow (e.g., only code): the pruned model works well on code but degrades on natural language.
- Too broad (random web text): may not capture critical patterns for specialized tasks.
- Too small (< 64 samples): high variance in importance estimates.
- Standard practice: 128 random sequences from C4 (web text), each 2048 tokens. This has become the de facto standard since the SparseGPT paper.
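A calibration set of this shape can be drawn as in the sketch below. It assumes the corpus has already been tokenized into one long list of token IDs; the function name `sample_calibration_windows` and the single-stream simplification are assumptions for illustration, not the procedure of any particular codebase.

```python
import random


def sample_calibration_windows(token_ids, n_samples=128, seq_len=2048, seed=0):
    """Draw random contiguous windows of seq_len tokens from one long token stream."""
    if len(token_ids) < seq_len:
        raise ValueError("token stream is shorter than one calibration sequence")
    rng = random.Random(seed)  # fixed seed so pruning runs are reproducible
    max_start = len(token_ids) - seq_len
    starts = [rng.randint(0, max_start) for _ in range(n_samples)]
    return [token_ids[s:s + seq_len] for s in starts]


stream = list(range(500_000))  # stand-in for tokenized web text
calib = sample_calibration_windows(stream)
print(len(calib), "sequences of", len(calib[0]), "tokens")
```

Fixing the seed matters in practice: with only 128 samples, different draws can shift the resulting perplexity slightly.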
Perplexity as Evaluation Metric#
The primary metric for evaluating pruned LLMs is perplexity (PPL), measured on a held-out dataset (typically WikiText-2 or C4):
$$\text{PPL} = \exp\left(-\frac{1}{N}\sum_{i=1}^{N} \log p(x_i \mid x_{<i})\right)$$
where \(N\) is the total number of tokens and \(p(x_i \mid x_{<i})\) is the model’s predicted probability of token \(x_i\) given all preceding tokens.
Perplexity is the exponentiated average negative log-likelihood. Intuitively, a perplexity of \(k\) means the model is “as uncertain as if it were choosing uniformly among \(k\) options at each step.” Lower is better.
Numerical example: If a dense LLaMA-7B model achieves PPL = 5.68 on WikiText-2, and a 50% sparse version achieves PPL = 5.95, the relative increase is:
$$\frac{5.95 - 5.68}{5.68} \approx 4.75\%$$
This is generally considered acceptable. An increase to PPL = 7.0 (23% relative) would be considered a significant degradation. An increase beyond PPL = 10 typically indicates catastrophic quality loss.
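Both the perplexity definition and the relative-increase calculation are easy to check numerically. The sketch below takes per-token natural-log probabilities as input; the helper names are illustrative.

```python
import math


def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood (natural log) over all tokens."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


def relative_increase(ppl_dense: float, ppl_sparse: float) -> float:
    """Fractional perplexity increase of the sparse model over the dense one."""
    return (ppl_sparse - ppl_dense) / ppl_dense


# A model that always assigns probability 1/8 to the correct token
# is as uncertain as a uniform choice among 8 options.
uniform_log_probs = [math.log(1 / 8)] * 100
ppl_uniform = perplexity(uniform_log_probs)
rel = relative_increase(5.68, 5.95)
print(f"uniform-over-8 perplexity: {ppl_uniform:.3f}")
print(f"relative increase: {rel:.2%}")
```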
SparseGPT (Frantar & Alistarh, 2023) — Complete Deep Dive#
SparseGPT is the foundational algorithm for one-shot LLM pruning. It demonstrated for the first time that GPT-scale models (up to 175 billion parameters) can be pruned to 50-60% sparsity in a single pass, without any retraining, while maintaining competitive perplexity.
Problem Statement#
Given a pre-trained weight matrix \(W \in \mathbb{R}^{d_{\text{row}} \times d_{\text{col}}}\) and a calibration dataset producing input activations \(X \in \mathbb{R}^{d_{\text{col}} \times N}\), find a sparse weight matrix \(\hat{W}\) that minimizes the reconstruction error:
$$\min_{\hat{W}} \| WX - \hat{W}X \|_F^2 \quad \text{subject to} \quad \hat{W} \text{ has at most } k \text{ non-zeros per row}$$This is a layer-wise objective: we prune each weight matrix independently, aiming to preserve the layer’s input-output behavior on the calibration data.
The Optimal Brain Surgeon (OBS) Framework#
SparseGPT builds upon the Optimal Brain Surgeon (OBS) framework. The idea is: when we prune (set to zero) a single weight \(w_q\), we should adjust all remaining weights to compensate for the error introduced. This is fundamentally different from magnitude pruning, which simply removes weights without compensation.
Derivation from First Principles#
Consider the second-order Taylor expansion of the loss around the current weight vector \(w\):
$$\mathcal{L}(w + \delta_w) \approx \mathcal{L}(w) + g^T \delta_w + \frac{1}{2} \delta_w^T H \delta_w$$where \(g = \nabla_w \mathcal{L}\) is the gradient and \(H = \nabla_w^2 \mathcal{L}\) is the Hessian. At a (local) minimum, the gradient \(g \approx 0\), so:
$$\Delta\mathcal{L} = \mathcal{L}(w + \delta_w) - \mathcal{L}(w) \approx \frac{1}{2} \delta_w^T H \delta_w$$We want to find the weight update \(\delta_w\) that minimizes this loss increase, subject to the constraint that weight \(q\) is pruned (set to zero):
$$\min_{\delta_w} \frac{1}{2} \delta_w^T H \delta_w \quad \text{subject to} \quad e_q^T (w + \delta_w) = 0$$where \(e_q\) is the unit vector with a 1 in position \(q\). The constraint says: the new weight at position \(q\) must be zero, i.e., \(w_q + \delta_{w_q} = 0\), meaning \(\delta_{w_q} = -w_q\).
Solving via Lagrange Multipliers#
Form the Lagrangian:
$$\mathcal{L}(\delta_w, \lambda) = \frac{1}{2} \delta_w^T H \delta_w + \lambda \left(e_q^T w + e_q^T \delta_w\right)$$
Take partial derivatives and set them to zero:
Condition 1: \(\frac{\partial \mathcal{L}}{\partial \delta_w} = 0\):
$$H \delta_w + \lambda e_q = 0$$
$$\delta_w = -\lambda H^{-1} e_q \tag{1}$$
Condition 2: \(\frac{\partial \mathcal{L}}{\partial \lambda} = 0\):
$$e_q^T w + e_q^T \delta_w = 0$$
$$w_q + \delta_{w_q} = 0 \tag{2}$$
Now substitute (1) into (2). The \(q\)-th component of \(\delta_w\) is:
$$\delta_{w_q} = e_q^T \delta_w = -\lambda e_q^T H^{-1} e_q = -\lambda [H^{-1}]_{qq}$$
Substituting into (2):
$$w_q - \lambda [H^{-1}]_{qq} = 0$$
$$\lambda = \frac{w_q}{[H^{-1}]_{qq}} \tag{3}$$
Substituting (3) back into (1):
$$\boxed{\delta_w = -\frac{w_q}{[H^{-1}]_{qq}} \cdot H^{-1} e_q = -\frac{w_q}{[H^{-1}]_{qq}} \cdot (H^{-1})_{:,q}}$$
This is the OBS weight update formula. When we prune weight \(q\), we shift every remaining weight by an amount proportional to the corresponding column of \(H^{-1}\).
Saliency Score#
The increase in loss from pruning weight \(q\) (with optimal compensation) is obtained by substituting the optimal \(\delta_w\) back into the objective:
$$\Delta\mathcal{L}_q = \frac{1}{2} \delta_w^T H \delta_w$$Using \(\delta_w = -\lambda H^{-1} e_q\):
$$\Delta\mathcal{L}_q = \frac{1}{2} \lambda^2 e_q^T H^{-1} H H^{-1} e_q = \frac{1}{2} \lambda^2 [H^{-1}]_{qq}$$Substituting \(\lambda = w_q / [H^{-1}]_{qq}\):
$$\boxed{\text{Saliency}_q = \Delta\mathcal{L}_q = \frac{w_q^2}{2[H^{-1}]_{qq}}}$$This score tells us the cost of pruning weight \(q\) — weights with low saliency should be pruned first. Note this differs from simple magnitude pruning (\(|w_q|^2\)) by the factor \(1/[H^{-1}]_{qq}\), which accounts for the curvature of the loss landscape. A weight might be large but cheap to prune if the Hessian indicates the loss is flat in that direction.
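The update and saliency formulas can be checked numerically. The sketch below uses a random symmetric positive definite matrix as a stand-in Hessian and inverts it explicitly, which is fine at this size but is not how a large-scale implementation would proceed; `obs_prune` is an illustrative name.

```python
import numpy as np


def obs_prune(w: np.ndarray, H: np.ndarray, q: int):
    """Zero weight q and compensate the others; returns (new_w, saliency)."""
    H_inv = np.linalg.inv(H)
    delta = -(w[q] / H_inv[q, q]) * H_inv[:, q]   # OBS weight update
    saliency = w[q] ** 2 / (2.0 * H_inv[q, q])    # predicted loss increase
    new_w = w + delta
    new_w[q] = 0.0  # remove floating-point residue; exact arithmetic gives 0
    return new_w, saliency


rng = np.random.default_rng(0)
A = rng.normal(size=(5, 20))
H = A @ A.T                 # stand-in Hessian, symmetric positive definite
w = rng.normal(size=5)

new_w, saliency = obs_prune(w, H, q=2)
delta = new_w - w
loss_increase = 0.5 * delta @ H @ delta           # quadratic model of the loss
naive_increase = 0.5 * w[2] ** 2 * H[2, 2]        # zeroing w_q with no compensation

print(f"saliency (formula):        {saliency:.6f}")
print(f"loss increase (quadratic): {loss_increase:.6f}")
print(f"loss increase (no comp.):  {naive_increase:.6f}")
```

The first two numbers agree, and the uncompensated increase is never smaller, because \([H^{-1}]_{qq} \geq 1/H_{qq}\) for a positive definite \(H\).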
Hessian Computation for Transformers#
For a linear layer \(y = Wx\) with squared-error reconstruction loss \(\| WX - \hat{W}X \|_F^2\), the Hessian with respect to a single row \(w_i\) of \(W\) is:
$$H = 2 X X^T$$where \(X \in \mathbb{R}^{d_{\text{col}} \times N}\) is the matrix of input activations (\(N\) calibration samples). This is because the loss for row \(i\) is:
$$\ell_i = \| w_i X - \hat{w}_i X \|^2 = (w_i - \hat{w}_i) X X^T (w_i - \hat{w}_i)^T$$and \(\nabla^2_{\delta_{w_i}} \ell_i = 2 X X^T\).
This is a critical simplification: the Hessian is the same for every row of the weight matrix, and it depends only on the input activations, not on the weights themselves. This means we compute \(H = 2XX^T \in \mathbb{R}^{d_{\text{col}} \times d_{\text{col}}}\) once and reuse it for every row.
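The claim that one matrix \(H = 2XX^T\) serves every row can be verified directly. The sketch below uses random data and compares the row-wise reconstruction loss computed two ways; all sizes are arbitrary.

```python
import numpy as np

rng = np.random.default_rng(0)
d_row, d_col, n = 3, 6, 50
X = rng.normal(size=(d_col, n))                  # calibration activations
W = rng.normal(size=(d_row, d_col))              # original weights
W_hat = W + 0.1 * rng.normal(size=W.shape)       # any perturbed weights

H = 2.0 * X @ X.T   # one (d_col x d_col) matrix, independent of W


def row_loss_direct(i: int) -> float:
    """Squared output error of row i on the calibration data."""
    return float(np.sum((W[i] @ X - W_hat[i] @ X) ** 2))


def row_loss_quadratic(i: int) -> float:
    """Same loss via the shared Hessian: 0.5 * delta H delta^T."""
    delta = W_hat[i] - W[i]
    return float(0.5 * delta @ H @ delta)


for i in range(d_row):
    print(i, row_loss_direct(i), row_loss_quadratic(i))
```

Because the loss is exactly quadratic in the weights, the two columns match up to rounding, not just approximately.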
In practice, \(H\) is regularized to ensure invertibility:
$$H \leftarrow H + \epsilon I, \quad \epsilon = 10^{-2} \cdot \text{mean}(\text{diag}(H))$$
The Row-Wise Pruning Trick#
A key insight in SparseGPT is that the pruning problem decomposes across rows. Since each row \(w_i\) of \(W\) produces one output dimension independently, and the Hessian is shared, we can process each row separately:
┌─────────────────────────────────────────────────────────────┐
│ Weight Matrix W (d_row x d_col) │
│ │
│ Row 0: [ w_00 w_01 w_02 ... w_0,d_col ] ──► Prune │
│ Row 1: [ w_10 w_11 w_12 ... w_1,d_col ] ──► Prune │
│ Row 2: [ w_20 w_21 w_22 ... w_2,d_col ] ──► Prune │
│ ... ... │
│ Row m: [ w_m0 w_m1 w_m2 ... w_m,d_col ] ──► Prune │
│ │
│ Each row is pruned independently using the SAME H^{-1} │
│ (shared Hessian depends only on input activations X) │
└─────────────────────────────────────────────────────────────┘
For each row, we iterate through columns from left to right. At each column, we decide whether to prune it (based on saliency) and, if so, apply the OBS update to all remaining columns in that row.
Lazy Batch Updates#
Processing columns one at a time is computationally expensive because each OBS update modifies all remaining weights. SparseGPT introduces lazy batch updates: process columns in blocks of size \(B\) (typically \(B = 128\)).
Within a block of \(B\) columns, pruning decisions and local weight updates are computed. The global update to all remaining columns (beyond the current block) is accumulated and applied once at the end of the block. This reduces the number of expensive full-row updates from \(d_{\text{col}}\) to \(d_{\text{col}} / B\).
Let the current block span columns \(\mathcal{B} = \{b, b+1, \ldots, b+B-1\}\). Partition the weight row \(w\) and the inverse Hessian accordingly. Within block \(\mathcal{B}\):
- For each column \(q \in \mathcal{B}\): compute saliency, decide whether to prune, and apply the OBS update only within the block.
- After processing all \(B\) columns: apply the accumulated update to all columns in \(\{b+B, b+B+1, \ldots, d_{\text{col}}-1\}\) in a single matrix operation.
The accumulated update for the remaining columns is:
$$\delta_{w_{\text{remaining}}} = -\left(\sum_{q \in \mathcal{B}, \text{pruned}} \frac{w_q}{[H^{-1}]_{qq}} \cdot (H^{-1})_{\text{remaining}, q}\right)$$This can be computed as a matrix multiply, making it efficient on modern hardware.
Cholesky Decomposition for Efficient Updates#
Recomputing the inverse Hessian after every eliminated column would be too expensive, so SparseGPT computes \(H^{-1}\) once and works with its Cholesky decomposition. Since \(H\) is symmetric positive definite (after regularization), so is \(H^{-1}\), and we can write:
$$H^{-1} = L L^T$$where \(L\) is lower triangular. This provides an efficient way to extract the quantities we need:
- \([H^{-1}]_{qq}\) is simply the squared norm of the \(q\)-th row of \(L\).
- The column \((H^{-1})_{:,q}\) can be extracted from \(L\).
As we process columns left to right and “eliminate” pruned columns, we can update the Cholesky factor incrementally rather than recomputing from scratch. This is mathematically equivalent to a Cholesky downdate operation.
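The pieces described so far (shared Hessian, damping, column-wise OBS updates, lazy block updates, and a Cholesky factor of \(H^{-1}\)) can be combined into a compact NumPy sketch. It is an illustration under simplifying assumptions, not the reference implementation: it inverts \(H\) explicitly, uses the transposed factor \(U = L^T\) (so \(H^{-1} = U^T U\)), and picks the mask once per block from the scores \(w^2 / U_{jj}^2\). The function name and defaults are illustrative.

```python
import numpy as np


def sparsegpt_prune(W, X, sparsity=0.5, block_size=128, damp=0.01):
    """One-shot pruning of W (d_row x d_col) given activations X (d_col x N)."""
    W = W.astype(np.float64).copy()
    d_col = W.shape[1]
    H = 2.0 * X @ X.T
    H += damp * np.mean(np.diag(H)) * np.eye(d_col)   # damping for invertibility
    # Upper factor U with H^{-1} = U^T U. Row j of U describes the inverse
    # Hessian that remains after columns 0..j-1 have been eliminated.
    U = np.linalg.cholesky(np.linalg.inv(H)).T
    for b in range(0, d_col, block_size):
        e = min(b + block_size, d_col)
        W_blk = W[:, b:e].copy()
        U_blk = U[b:e, b:e]
        # Block-local mask: prune the lowest-scoring fraction across all rows.
        score = W_blk ** 2 / np.diag(U_blk)[None, :] ** 2
        k = int(score.size * sparsity)
        thresh = np.sort(score, axis=None)[k - 1] if k > 0 else -np.inf
        mask = score <= thresh
        Err = np.zeros_like(W_blk)
        for j in range(e - b):
            w = W_blk[:, j].copy()
            q = np.where(mask[:, j], 0.0, w)          # pruned entries become 0
            err = (w - q) / U_blk[j, j]
            W_blk[:, j:] -= np.outer(err, U_blk[j, j:])   # compensate inside block
            W_blk[:, j] = q                               # store exact values
            Err[:, j] = err
        W[:, b:e] = W_blk
        W[:, e:] -= Err @ U[b:e, e:]                  # lazy update of later columns
    return W


rng = np.random.default_rng(0)
X = rng.normal(size=(16, 64))
W = rng.normal(size=(4, 16))
W_sparse = sparsegpt_prune(W, X, sparsity=0.5, block_size=8)
print("fraction of zeros:", np.mean(W_sparse == 0.0))
```

Only columns to the right of the current one are updated, so already-finalized columns keep their values and pruned entries stay exactly zero.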
Full Algorithm Pseudocode#
Algorithm: SparseGPT (per weight matrix W, given input activations X)
────────────────────────────────────────────────────────────────
Input: W ∈ R^{d_row × d_col} (pre-trained weight matrix)
       X ∈ R^{d_col × N} (calibration activations)
       s ∈ (0, 1) (target sparsity ratio)
       B (block size, default 128)
Compute Hessian: H = 2 * X * X^T + ε * I
Compute Cholesky of H^{-1}: L = cholesky(H^{-1})
For b = 0, B, 2B, ..., d_col - B:              // block loop
    Err = zeros(d_row, B)                      // error accumulator, reset for each block
    For j = b, b+1, ..., b+B-1:                // column loop within block
        // Pruning decision for column j (across all rows)
        //   Option A: global threshold on saliency
        //   Option B: per-row top-k sparsity (prune s fraction)
        For each row i where w_{ij} is selected for pruning:
            Err[i, j-b] = w_{ij} / [H^{-1}]_{jj}
            w_{ij} = 0
        // Update remaining weights in current block
        // (Err already contains the division by [H^{-1}]_{jj})
        W[:, j+1:b+B] -= Err[:, j-b:j-b+1] * H^{-1}[j, j+1:b+B]
    // Lazy batch update: propagate error to all remaining columns
    W[:, b+B:] -= Err * H^{-1}[b:b+B, b+B:]
Return W (now sparse)
────────────────────────────────────────────────────────────────
Complexity Analysis#
- Hessian computation: \(O(d_{\text{col}}^2 \times N)\) — matrix multiply \(XX^T\).
- Cholesky decomposition: \(O(d_{\text{col}}^3)\).
- Per-row pruning: \(O(d_{\text{col}}^2)\) due to the block updates.
- Total per weight matrix: \(O(d_{\text{row}} \times d_{\text{col}}^2)\) for the pruning loop, plus \(O(d_{\text{col}}^3)\) for the Hessian inverse. Since \(d_{\text{row}}\) and \(d_{\text{col}}\) are typically both \(O(d)\), the total is \(O(d^3)\) per weight matrix, and \(O(L \times d^3)\) for the entire model.
For LLaMA-7B (\(d = 4096\), \(L = 32\), 7 weight matrices per layer): approximately \(32 \times 7 \times 4096^3 \approx 1.5 \times 10^{13}\) FLOPs. On a single A100 GPU, this takes approximately 10-15 minutes — remarkably efficient for pruning a 7B model.
Numerical Example#
Let us work through a complete example with a small \(2 \times 4\) weight matrix (two output rows, four input columns). Consider:
$$W = \begin{bmatrix} 0.5 & -0.3 & 0.8 & 0.1 \\ -0.2 & 0.6 & -0.1 & 0.4 \end{bmatrix}$$
Suppose from calibration data we compute:
$$H = XX^T = \begin{bmatrix} 2.0 & 0.5 & 0.3 & 0.1 \\ 0.5 & 1.5 & 0.2 & 0.4 \\ 0.3 & 0.2 & 3.0 & 0.1 \\ 0.1 & 0.4 & 0.1 & 1.0 \end{bmatrix}$$
(We drop the factor of 2 for simplicity and absorb it into the regularization.)
First, compute \(H^{-1}\) (with slight regularization):
$$H^{-1} \approx \begin{bmatrix} 0.556 & -0.160 & -0.045 & 0.028 \\ -0.160 & 0.767 & -0.022 & -0.286 \\ -0.045 & -0.022 & 0.342 & -0.008 \\ 0.028 & -0.286 & -0.008 & 1.133 \end{bmatrix}$$
Target: 50% sparsity (prune 2 out of 4 weights per row).
Process Row 0: \(w_0 = [0.5, -0.3, 0.8, 0.1]\)
Compute saliency scores:
- \(S_0 = \frac{0.5^2}{2 \times 0.556} = \frac{0.25}{1.112} = 0.225\)
- \(S_1 = \frac{(-0.3)^2}{2 \times 0.767} = \frac{0.09}{1.534} = 0.059\)
- \(S_2 = \frac{0.8^2}{2 \times 0.342} = \frac{0.64}{0.684} = 0.936\)
- \(S_3 = \frac{0.1^2}{2 \times 1.133} = \frac{0.01}{2.266} = 0.004\)
Sorted by saliency: \(S_3 = 0.004 < S_1 = 0.059 < S_0 = 0.225 < S_2 = 0.936\)
Prune columns 3 and 1 (lowest saliency). Start with column 3 (\(w_{0,3} = 0.1\)):
$$\delta_w = -\frac{0.1}{1.133} \cdot (H^{-1})_{:,3} = -0.0883 \cdot [0.028, -0.286, -0.008, 1.133]^T$$
$$\delta_w = [-0.0025, 0.0252, 0.0007, -0.1]$$
Update: \(w_0 \leftarrow [0.5 - 0.0025, -0.3 + 0.0252, 0.8 + 0.0007, 0] = [0.4975, -0.2748, 0.8007, 0]\)
Now prune column 1 (\(w_{0,1} = -0.2748\)). Column 3 has already been eliminated, so the inverse Hessian is first updated by removing row and column 3: \(H^{-1} \leftarrow H^{-1} - (H^{-1})_{:,3} (H^{-1})_{3,:} / [H^{-1}]_{33}\). This gives \([H^{-1}]_{11} = 0.767 - 0.286^2 / 1.133 \approx 0.695\), \([H^{-1}]_{01} \approx -0.153\), and \([H^{-1}]_{21} \approx -0.024\), with zeros in row and column 3:
$$\delta_w = -\frac{-0.2748}{0.695} \cdot (H^{-1})_{:,1} = 0.3955 \cdot [-0.153, 0.695, -0.024, 0]^T$$
$$\delta_{w_0} = 0.3955 \times (-0.153) = -0.0605$$
$$\delta_{w_2} = 0.3955 \times (-0.024) = -0.0095$$
Final Row 0: \(\hat{w}_0 = [0.4975 - 0.0605, \ 0, \ 0.8007 - 0.0095, \ 0] = [0.4370, \ 0, \ 0.7912, \ 0]\)
The same process repeats for Row 1. The key observation is that pruning with OBS compensation preserves the row’s output better than simply zeroing the weights. Without compensation, Row 0 would be \([0.5, 0, 0.8, 0]\), but with compensation it becomes \([0.4370, 0, 0.7912, 0]\): the remaining weights are adjusted to partially absorb the error from the pruned weights.
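The Row 0 calculation can be replayed in NumPy. The sketch below inverts \(H\) exactly, so its digits may differ slightly from the rounded hand calculation; it also removes each pruned column from the inverse Hessian before the next step. The function names are illustrative.

```python
import numpy as np

H = np.array([[2.0, 0.5, 0.3, 0.1],
              [0.5, 1.5, 0.2, 0.4],
              [0.3, 0.2, 3.0, 0.1],
              [0.1, 0.4, 0.1, 1.0]])
w = np.array([0.5, -0.3, 0.8, 0.1])   # Row 0 of W


def obs_prune_sequence(w, H, order):
    """Prune the listed columns one by one, compensating all surviving weights."""
    w = w.copy()
    H_inv = np.linalg.inv(H)
    for q in order:
        w -= (w[q] / H_inv[q, q]) * H_inv[:, q]
        w[q] = 0.0
        # Eliminate column q from the inverse Hessian (rank-one downdate),
        # then clear its row and column so later steps cannot touch w[q].
        H_inv = H_inv - np.outer(H_inv[:, q], H_inv[q, :]) / H_inv[q, q]
        H_inv[q, :] = 0.0
        H_inv[:, q] = 0.0
    return w


def quad_error(w_new):
    """Reconstruction error (w_new - w) H (w_new - w)^T for Row 0."""
    d = w_new - w
    return float(d @ H @ d)


w_obs = obs_prune_sequence(w, H, order=[3, 1])
w_naive = w * np.array([1.0, 0.0, 1.0, 0.0])   # zero columns 1 and 3, no compensation

print("compensated:", np.round(w_obs, 4), "error:", round(quad_error(w_obs), 4))
print("naive:      ", w_naive, "error:", round(quad_error(w_naive), 4))
```

For a fixed mask, sequential OBS with the downdated inverse reaches the least-squares optimum, so its error cannot exceed the uncompensated one.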
SparseGPT Results#
| Model | Sparsity | WikiText-2 PPL (Dense) | WikiText-2 PPL (Sparse) | PPL Increase |
|---|---|---|---|---|
| OPT-175B | 50% unstructured | 8.34 | 8.52 | +2.2% |
| OPT-175B | 60% unstructured | 8.34 | 9.21 | +10.4% |
| OPT-175B | 2:4 structured | 8.34 | 9.14 | +9.6% |
| LLaMA-7B | 50% unstructured | 5.68 | 5.95 | +4.8% |
| LLaMA-7B | 60% unstructured | 5.68 | 6.55 | +15.3% |
| LLaMA-13B | 50% unstructured | 5.09 | 5.28 | +3.7% |
| LLaMA-30B | 50% unstructured | 4.10 | 4.21 | +2.7% |
| LLaMA-65B | 50% unstructured | 3.53 | 3.60 | +2.0% |
Key observations: (1) Larger models are easier to prune: perplexity on LLaMA-65B rises by only 2% at 50% sparsity. (2) Beyond 60% unstructured sparsity, quality degrades rapidly. (3) 2:4 structured sparsity, which is also 50% sparse, costs more than unstructured 50% sparsity (+9.6% vs. +2.2% on OPT-175B) and lands close to 60% unstructured sparsity.
Wanda (Pruning by Weights AND Activations, Sun et al., 2023) — Complete Deep Dive#
Motivation#
SparseGPT achieves excellent results but requires computing and inverting the Hessian matrix \(H \in \mathbb{R}^{d_{\text{col}} \times d_{\text{col}}}\) for every weight matrix. For LLaMA-7B this means storing and decomposing a \(4096 \times 4096\) matrix for each projection that reads the hidden state, and an \(11008 \times 11008\) matrix for each \(W_{\text{down}}\). While feasible, this is computationally non-trivial and requires careful numerical implementation (Cholesky decomposition, lazy updates, etc.).
Wanda asks: can we achieve comparable pruning quality with a much simpler metric?
Key Insight: Weight Magnitude Alone Is Insufficient#
Classical magnitude pruning ranks weights by \(|W_{ij}|\) and prunes the smallest ones. This ignores a crucial factor: how much the input actually activates each weight. Consider two weights:
- Weight A: \(|W_A| = 0.01\) connected to a feature with \(\|X_A\|_2 = 100\). Output contribution: \(0.01 \times 100 = 1.0\)
- Weight B: \(|W_B| = 0.5\) connected to a feature with \(\|X_B\|_2 = 0.01\). Output contribution: \(0.5 \times 0.01 = 0.005\)
Magnitude pruning would keep Weight B and prune Weight A. But Weight A contributes 200x more to the output! This discrepancy is especially severe in LLMs because of activation outliers — a few features have enormous activation norms.
The Wanda Pruning Metric#
Wanda’s metric combines weight magnitude with input activation norm:
$$\boxed{\text{Score}(W_{ij}) = |W_{ij}| \cdot \|X_j\|_2}$$
where:
- \(W_{ij}\) is the weight connecting input feature \(j\) to output neuron \(i\)
- \(X_j \in \mathbb{R}^{N}\) is the vector of activations for feature \(j\) across all \(N\) calibration tokens
- \(\|X_j\|_2 = \sqrt{\sum_{n=1}^{N} x_{j,n}^2}\) is the L2 norm across calibration samples
Weights with low scores are pruned. The intuition is clear: a weight is unimportant if it is small, or if the feature it connects to is rarely activated, or both.
Mathematical Justification#
Consider the output of a linear layer for a single output neuron \(i\):
$$y_i = \sum_{j=1}^{d_{\text{col}}} W_{ij} x_j$$The error introduced by pruning weight \(W_{ij}\) (setting it to zero without compensation) is:
$$\Delta y_i = W_{ij} x_j$$The squared error, aggregated over \(N\) calibration samples, is:
$$\text{Error}_{ij} = \sum_{n=1}^{N} (W_{ij} \cdot x_{j,n})^2 = W_{ij}^2 \sum_{n=1}^{N} x_{j,n}^2 = W_{ij}^2 \cdot \|X_j\|_2^2$$Taking the square root to get a more stable metric:
$$\sqrt{\text{Error}_{ij}} = |W_{ij}| \cdot \|X_j\|_2$$
This is exactly the Wanda score. It is the output error caused by removing each weight in isolation, and it ignores inter-weight correlations (which is what the Hessian captures in SparseGPT).
To make this more rigorous, note the connection to the diagonal of the Hessian. For the row-wise squared-error loss, the Hessian is \(H = 2XX^T\). The diagonal entry is:
$$H_{jj} = 2\|X_j\|_2^2$$The OBS saliency score (from SparseGPT) is:
$$\text{Saliency}_j^{\text{OBS}} = \frac{W_{ij}^2}{2[H^{-1}]_{jj}}$$
If we approximate \(H\) as diagonal (i.e., ignore off-diagonal correlations between features), then \([H^{-1}]_{jj} = 1/H_{jj} = 1/(2\|X_j\|_2^2)\), and:
$$\text{Saliency}_j^{\text{diag}} = \frac{W_{ij}^2}{2 / (2\|X_j\|_2^2)} = W_{ij}^2 \cdot \|X_j\|_2^2$$This is exactly the square of the Wanda score. Therefore, Wanda can be understood as the OBS framework under a diagonal Hessian approximation — it ignores feature correlations but captures the first-order importance of each weight.
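The identity between the squared Wanda score and the diagonal-Hessian OBS saliency can be confirmed on random data. This is a small numerical check with arbitrary sizes, not part of either method's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
d_col, n = 6, 40
X = rng.normal(size=(d_col, n))   # calibration activations
w = rng.normal(size=d_col)        # one row of W

wanda = np.abs(w) * np.linalg.norm(X, axis=1)   # |W_ij| * ||X_j||_2

H_diag = 2.0 * np.sum(X ** 2, axis=1)           # diagonal of H = 2 X X^T
H_inv_diag = 1.0 / H_diag                       # inverse under the diagonal approximation
obs_diag = w ** 2 / (2.0 * H_inv_diag)          # OBS saliency with diagonal H

print(np.round(wanda ** 2, 4))
print(np.round(obs_diag, 4))
```

Since squaring is monotone on non-negative scores, both metrics rank the weights of a row identically.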
Per-Output Pruning#
Wanda prunes within each row (output neuron) independently. For a target sparsity of \(s\), in each row of \(d_{\text{col}}\) weights, the \(\lfloor s \times d_{\text{col}} \rfloor\) weights with the lowest Wanda scores are set to zero. This ensures uniform sparsity across output neurons, which is important for maintaining balanced representations.
Algorithm#
Algorithm: Wanda (per weight matrix W, given input activations X)
────────────────────────────────────────────────────────────────
Input: W ∈ R^{d_row × d_col} (pre-trained weight matrix)
       X ∈ R^{d_col × N} (calibration activations)
       s ∈ (0, 1) (target sparsity ratio)
Compute activation norms: a_j = ||X_j||_2 for j = 1, ..., d_col
Compute score matrix: S_{ij} = |W_{ij}| * a_j
For each row i = 1, ..., d_row:
    Find the indices of the ⌊s × d_col⌋ smallest entries in S[i, :]
    Set W[i, those indices] = 0
Return W (now sparse)
────────────────────────────────────────────────────────────────
This is remarkably simple: the entire algorithm is essentially 5 lines of code. No Hessian computation, no Cholesky decomposition, no iterative updates. The only computation beyond the weight matrix itself is a single pass over the calibration data to compute \(\|X_j\|_2\).
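A NumPy version of the same procedure is sketched below, followed by the Weight A / Weight B example from earlier in this section. The function name `wanda_prune` is illustrative, and a production implementation would operate on framework tensors with activation norms accumulated during a forward pass.

```python
import numpy as np


def wanda_prune(W: np.ndarray, X: np.ndarray, sparsity: float = 0.5) -> np.ndarray:
    """Zero the floor(sparsity * d_col) lowest-scoring weights in every row."""
    act_norm = np.linalg.norm(X, axis=1)             # ||X_j||_2 per input feature
    scores = np.abs(W) * act_norm[None, :]           # S_ij = |W_ij| * ||X_j||_2
    k = int(np.floor(sparsity * W.shape[1]))
    W_sparse = W.copy()
    if k > 0:
        # Indices of the k smallest scores in each row (order within them is irrelevant).
        prune_idx = np.argpartition(scores, k - 1, axis=1)[:, :k]
        np.put_along_axis(W_sparse, prune_idx, 0.0, axis=1)
    return W_sparse


# Weight A (0.01, activation norm 100) vs. Weight B (0.5, activation norm 0.01).
W_ab = np.array([[0.01, 0.5]])
X_ab = np.array([[100.0], [0.01]])
pruned = wanda_prune(W_ab, X_ab, sparsity=0.5)
print(pruned)  # Weight A survives, Weight B is pruned
```

Using `argpartition` instead of a full sort keeps the selection linear in the row length, which is a design choice of this sketch.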
Speed Comparison#
| Operation | SparseGPT | Wanda |
|---|---|---|
| Activation statistics | \(O(d^2 N)\) (full \(XX^T\)) | \(O(dN)\) (column norms only) |
| Hessian inversion | \(O(d^3)\) Cholesky | Not needed |
| Pruning pass | \(O(d_{\text{row}} \times d^2)\) with block updates | \(O(d_{\text{row}} \times d \log d)\) sorting |
| Total for LLaMA-7B | ~10-15 minutes (1 GPU) | ~30 seconds (1 GPU) |
Wanda is approximately 20-30x faster than SparseGPT while achieving comparable quality at moderate sparsity levels.
N:M Sparsity Support#
Wanda naturally supports N:M structured sparsity. For 2:4 sparsity, within each group of 4 consecutive weights in a row, the 2 with the lowest Wanda scores are pruned. The only change to the algorithm is replacing global per-row sorting with local within-group sorting:
For each row i:
    For g = 0, 4, 8, ..., d_col - 4:   // groups of 4
        scores = S[i, g:g+4]
        prune the 2 entries with the lowest scores
Results Comparison: Wanda vs. SparseGPT#
| Model | Method | Sparsity | WikiText-2 PPL | C4 PPL |
|---|---|---|---|---|
| LLaMA-7B | Dense | 0% | 5.68 | 7.08 |
| LLaMA-7B | Magnitude | 50% | 17.29 | 19.54 |
| LLaMA-7B | SparseGPT | 50% | 5.95 | 7.51 |
| LLaMA-7B | Wanda | 50% | 6.01 | 7.56 |
| LLaMA-7B | SparseGPT | 2:4 | 7.16 | 8.82 |
| LLaMA-7B | Wanda | 2:4 | 7.26 | 8.91 |
| LLaMA-13B | Magnitude | 50% | 10.13 | 12.40 |
| LLaMA-13B | SparseGPT | 50% | 5.28 | 6.75 |
| LLaMA-13B | Wanda | 50% | 5.33 | 6.79 |
| LLaMA-30B | SparseGPT | 50% | 4.21 | 5.76 |
| LLaMA-30B | Wanda | 50% | 4.27 | 5.82 |
| LLaMA-65B | SparseGPT | 50% | 3.60 | 5.23 |
| LLaMA-65B | Wanda | 50% | 3.64 | 5.27 |
Key findings:
- Wanda is within 1-2% PPL of SparseGPT at 50% sparsity across all model sizes.
- Simple magnitude pruning is catastrophically bad (17.29 vs. 5.68 for LLaMA-7B), demonstrating that activation awareness is essential.
- The gap between Wanda and SparseGPT widens at higher sparsity (60%+), where the Hessian-based weight updates of SparseGPT become more valuable.
- Both methods dramatically outperform magnitude pruning, confirming that activation outliers dominate pruning behavior in LLMs.
Magnitude Pruning Revisited for LLMs#
Despite the clear superiority of activation-aware methods, it is instructive to understand exactly when and why simple magnitude pruning fails for LLMs, and how it can be partially rescued.
Global vs. Per-Layer vs. Per-Row Thresholds#
There are three granularities for setting the pruning threshold:
Global: Compute a single threshold across all weights in the model. Prune all weights below it. This leads to highly non-uniform sparsity — sensitive layers may be over-pruned while redundant layers are under-pruned.
Per-layer (uniform): Apply the same sparsity ratio to every layer. This is the most common approach but ignores the fact that different layers have vastly different sensitivity. Empirically, the first and last layers are much more sensitive — applying 60% sparsity uniformly often means the first layer is catastrophically damaged.
Per-row (per-output): Apply the target sparsity independently within each row of each weight matrix. This ensures balanced sparsity across output neurons and is the approach used by both SparseGPT and Wanda.
Why Per-Layer Uniform Sparsity Fails#
Different layers in an LLM have dramatically different weight distributions and sensitivity to pruning. Consider the weight magnitude statistics for LLaMA-7B:
Layer Sensitivity Analysis (LLaMA-7B, 50% unstructured pruning)
| Layer | Avg \|W\| (×10⁻³) | Outlier Ratio | PPL if pruned alone |
|---|---|---|---|
| Layer 0 | 8.2 | 0.3% | 6.42 (+13.0%) |
| Layer 1 | 5.1 | 0.5% | 5.82 (+2.5%) |
| Layer 2 | 4.8 | 0.7% | 5.79 (+1.9%) |
| ... | ... | ... | ... |
| Layer 15 | 3.9 | 1.2% | 5.72 (+0.7%) |
| Layer 16 | 3.8 | 1.4% | 5.71 (+0.5%) |
| ... | ... | ... | ... |
| Layer 30 | 4.5 | 0.8% | 5.75 (+1.2%) |
| Layer 31 | 6.1 | 0.4% | 6.15 (+8.3%) |
The first layer (Layer 0) and last layer (Layer 31) are significantly more sensitive. An intelligent strategy would prune these layers less aggressively and allocate the saved budget to middle layers.
Outlier-Aware Magnitude Pruning#
A simple rescue for magnitude pruning is to protect outlier channels. If we know that feature dimensions with high activation norms are critical, we can:
- Identify the top-\(k\) features by activation norm (e.g., \(k = 0.01 \times d\)).
- Mark all weights connected to these features as unprunable.
- Apply magnitude pruning only to the remaining weights.
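This hybrid approach significantly improves magnitude pruning but still falls short of Wanda and SparseGPT because it only protects a binary set of features rather than smoothly weighting all features by their activation norms. A related method, OWL (Outlier Weighed Layerwise sparsity, Yin et al., 2023), uses outlier statistics differently: it assigns each layer its own sparsity ratio based on that layer's share of outliers, and is applied on top of a pruning criterion such as Wanda or SparseGPT.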
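The three-step rescue can be sketched as follows. The function name, the per-row application of the sparsity target, and the default `protect_frac = 0.01` are assumptions chosen for illustration; the text above does not fix these details.

```python
import numpy as np

def outlier_protected_magnitude_prune(W, X, sparsity, protect_frac=0.01):
    """Magnitude pruning that never touches weights of high-activation features.

    W: (d_row, d_col) weights, X: (d_col, N) calibration activations.
    """
    act_norm = np.linalg.norm(X, axis=1)
    k = max(1, int(round(protect_frac * W.shape[1])))
    protected = np.argsort(act_norm)[-k:]        # top-k features by activation norm

    magnitude = np.abs(W).astype(float)
    magnitude[:, protected] = np.inf             # protected columns are never the smallest

    # Prune per row among the unprotected weights only.
    n_prune = min(int(sparsity * W.shape[1]), W.shape[1] - k)
    prune_idx = np.argsort(magnitude, axis=1)[:, :n_prune]
    W_pruned = W.copy()
    np.put_along_axis(W_pruned, prune_idx, 0.0, axis=1)
    return W_pruned, protected

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 100))
X = rng.normal(size=(100, 32))
X[7] *= 100.0                                    # feature 7 is the outlier channel
W_pruned, protected = outlier_protected_magnitude_prune(W, X, sparsity=0.5)
```

Note that the protection is all-or-nothing: a column is either exempt or scored by magnitude alone, which is the binary behavior that Wanda's multiplicative score avoids.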
Structural Pruning for LLMs#
Unstructured sparsity (arbitrary individual weights set to zero) offers maximum flexibility but requires sparse matrix formats and specialized kernels to achieve actual speedups. Structural pruning removes entire computational units — heads, neurons, or layers — producing a smaller dense model that runs faster on standard hardware without any sparse acceleration support.
Attention Head Pruning#
A multi-head attention layer with \(n_h\) heads computes:
$$\text{MHA}(x) = \text{Concat}(\text{head}_1, \ldots, \text{head}_{n_h}) W_O$$where each head is:
$$\text{head}_h = \text{Softmax}\left(\frac{(xW_Q^h)(xW_K^h)^T}{\sqrt{d_h}}\right) (xW_V^h)$$Removing head \(h\) is equivalent to zeroing out the \(h\)-th block of \(d_h\) columns in \(W_Q\), \(W_K\), \(W_V\) and the corresponding \(d_h\) rows in \(W_O\). The parameter savings per removed head are:
$$\Delta P = 4 \times d_{\text{model}} \times d_h = 4 \times d_{\text{model}} \times \frac{d_{\text{model}}}{n_h}$$For LLaMA-7B (\(d = 4096\), \(n_h = 32\), \(d_h = 128\)): removing one head saves \(4 \times 4096 \times 128 = 2{,}097{,}152\) parameters per layer, or 67.1M across 32 layers.
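The arithmetic above is easy to reproduce. The helper below is a hypothetical convenience function and assumes standard multi-head attention (one K/V head per query head, no grouped-query attention).

```python
def head_pruning_savings(d_model, n_heads, n_layers, heads_removed_per_layer=1):
    """Parameters saved by removing attention heads (standard MHA assumed)."""
    d_head = d_model // n_heads
    per_head = 4 * d_model * d_head          # slices of W_Q, W_K, W_V and W_O
    total = per_head * heads_removed_per_layer * n_layers
    return per_head, total

per_head, total = head_pruning_savings(d_model=4096, n_heads=32, n_layers=32)
print(per_head, total)   # 2097152 67108864
```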
Head Importance Scoring Methods#
Gradient-based importance: Measure how much the loss changes when a head’s output is perturbed:
$$I_h = \mathbb{E}\left[|\nabla_{\text{Attn}_h} \mathcal{L} \cdot \text{Attn}_h|\right]$$This is a first-order Taylor approximation of the loss change from removing the head. The expectation is taken over calibration samples. Heads where the gradient-activation product is large are critical; those where it is small contribute little to the loss.
Full derivation: The loss change from scaling head \(h\)’s output by \((1 - \epsilon)\) is:
$$\Delta\mathcal{L} \approx -\epsilon \cdot \nabla_{\text{Attn}_h} \mathcal{L}^T \cdot \text{Attn}_h$$Setting \(\epsilon = 1\) (complete removal) and taking the absolute value gives the importance score. The absolute value handles the sign ambiguity (removing a head might increase or decrease the loss, and we care about the magnitude of the change).
Taylor expansion importance: A slightly different formulation uses the element-wise product:
$$I_h = \left|\text{Attn}_h^T \nabla_{\text{Attn}_h} \mathcal{L}\right|$$This is equivalent to the gradient-based method but expressed as an inner product, making it a scalar importance score per head.
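A small sketch of the gradient-times-activation score, assuming the per-head outputs and their loss gradients have already been collected into arrays (in practice they would come from hooks in an autograd framework; the array layout here is an assumption).

```python
import numpy as np

def head_importance(head_out, head_grad):
    """First-order importance per head.

    head_out, head_grad: arrays of shape (n_samples, n_heads, d_head) holding
    each head's output and the gradient of the loss with respect to it.
    """
    per_sample = np.abs(np.sum(head_out * head_grad, axis=-1))  # |grad . Attn_h|
    return per_sample.mean(axis=0)                              # expectation over samples

rng = np.random.default_rng(0)
out = rng.normal(size=(64, 4, 8))
grad = rng.normal(size=(64, 4, 8))
grad[:, 2, :] = 0.0             # toy setup: the loss does not depend on head 2
scores = head_importance(out, grad)
ranking = np.argsort(scores)    # least important head first
```

In this toy setup head 2 receives a score of exactly zero and is ranked first for removal.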
Entropy-based importance: Measure the entropy of each head’s attention distribution:
$$\text{Entropy}_h = -\sum_{i,j} A_{h,ij} \log A_{h,ij}$$where \(A_h\) is the attention matrix for head \(h\). Heads with very low entropy (attending to 1-2 positions) are often critical (they implement specific retrieval or positional functions). Heads with high entropy (near-uniform attention) are often redundant.
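The entropy score can be computed directly from attention matrices. The sketch below averages the row entropies over query positions so that heads are comparable across sequence lengths; that normalization is an assumption added here, not part of the formula above.

```python
import numpy as np

def attention_entropy(A, eps=1e-12):
    """Mean attention entropy per head.

    A: (n_heads, n_queries, n_keys) attention weights, each row summing to 1.
    """
    row_entropy = -(A * np.log(A + eps)).sum(axis=-1)   # entropy of each query's distribution
    return row_entropy.mean(axis=-1)                    # average over queries

n_q, n_k = 4, 8
sharp = np.zeros((n_q, n_k))
sharp[:, 0] = 1.0                                       # every query attends to one key
uniform = np.full((n_q, n_k), 1.0 / n_k)                # near-uniform attention
H = attention_entropy(np.stack([sharp, uniform]))
```

The sharp head has entropy close to 0, while the uniform head reaches the maximum of \(\log 8\) for 8 keys.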
Effects on Different Capabilities#
Empirical studies show that different types of heads affect different capabilities:
| Head Type | Pruning Effect | Typical Fraction |
|---|---|---|
| Positional (local) | Minor degradation in fluency | 30-40% |
| Retrieval (content) | Factual accuracy drops | 10-15% |
| Induction (copy) | In-context learning degrades | 5-10% |
| Sink (first-token) | Model stability issues | 5-10% |
| Redundant/noisy | No measurable effect | 25-40% |
Typically, 25-50% of attention heads can be removed with minimal impact on perplexity, but the effect on specific downstream tasks (especially factual QA and in-context learning) can be more severe than perplexity alone suggests.
Width Pruning (Neuron/Channel Removal)#
Width pruning removes entire neurons (rows or columns) from weight matrices, reducing the hidden dimension of MLP or attention layers.
SliceGPT (Ashkboos et al., 2024)#
SliceGPT introduces a mathematically elegant approach to structural pruning based on orthogonal transformations that make weight matrices more amenable to column/row removal.
Key Idea: Before pruning columns from a weight matrix, apply an orthogonal rotation that concentrates the important information into a subset of columns. This is analogous to PCA — project into a basis where most of the variance is captured by a few dimensions, then discard the rest.
Mathematical Framework:
Consider a linear layer \(y = Wx\). We can insert an orthogonal matrix \(Q\) (where \(Q^T Q = I\)) without changing the computation:
$$y = Wx = W(QQ^T)x = (WQ)(Q^T x) = W' x'$$where \(W' = WQ\) and \(x' = Q^T x\). The rotated weight matrix \(W'\) has the same Frobenius norm as \(W\) (orthogonal transformations preserve norms) but its columns may have very different magnitudes.
If we choose \(Q\) such that the columns of \(W'\) are sorted by importance (e.g., by their contribution to the output variance), we can prune the least important columns of \(W'\) with minimal error.
Choosing Q via PCA: Compute the covariance of the layer’s output:
$$\Sigma_y = W \Sigma_x W^T$$where \(\Sigma_x = XX^T / N\) is the input covariance. The eigendecomposition of \(\Sigma_x\) gives:
$$\Sigma_x = Q \Lambda Q^T$$where \(\Lambda = \text{diag}(\lambda_1, \lambda_2, \ldots)\) with eigenvalues sorted in decreasing order. Using this \(Q\) as our rotation, the rotated input \(x' = Q^T x\) has covariance \(\Lambda\), a diagonal matrix. The rotated weight \(W' = WQ\) applied to \(x'\) gives:
$$y_i = \sum_j W'_{ij} x'_j$$The contribution of feature \(j\) to the output variance is proportional to \(\lambda_j \|W'_{:,j}\|^2\). Since eigenvalues decay (often rapidly for LLMs), features corresponding to small eigenvalues can be pruned.
Absorbing the rotation: The beauty of SliceGPT is that the rotation can be absorbed into adjacent layers. For consecutive layers \(y = W_2(W_1 x)\), inserting \(QQ^T\) between them:
$$y = W_2 (Q Q^T) W_1 x = (W_2 Q)(Q^T W_1) x = W_2' W_1' x$$After absorbing \(Q\) into both weight matrices, we can slice (prune) the shared dimension. The model’s computational graph is unchanged — no rotation is needed at inference time.
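The rotate-then-slice idea can be checked numerically on two consecutive linear maps. This is a toy sketch under strong assumptions: there is no nonlinearity, normalization, or residual connection between the layers (real transformer blocks need extra care there), and the shapes are chosen so that the shared dimension has rank 6.

```python
import numpy as np

rng = np.random.default_rng(0)
d_in, d_mid, d_out, n_samples = 6, 16, 5, 512

W1 = rng.normal(size=(d_mid, d_in))      # first layer:  h = W1 x
W2 = rng.normal(size=(d_out, d_mid))     # second layer: y = W2 h
X = rng.normal(size=(d_in, n_samples))   # calibration inputs, one column per sample

H = W1 @ X                               # activations on the shared dimension
Y = W2 @ H                               # reference output

# PCA of the shared dimension: Sigma = Q Lambda Q^T, eigenvalues descending.
eigvals, Q = np.linalg.eigh(H @ H.T / n_samples)   # eigh returns ascending order
order = np.argsort(eigvals)[::-1]
Q = Q[:, order]

# Absorb the rotation into both weight matrices: W2' = W2 Q, W1' = Q^T W1.
W2_rot, W1_rot = W2 @ Q, Q.T @ W1
rotation_is_exact = np.allclose(W2_rot @ (W1_rot @ X), Y)

def rel_error(k):
    """Relative output error when only the first k rotated directions are kept."""
    Y_sliced = W2_rot[:, :k] @ (W1_rot[:k, :] @ X)
    return np.linalg.norm(Y_sliced - Y) / np.linalg.norm(Y)

err_keep_6 = rel_error(6)   # H has rank 6 here, so 6 of 16 directions suffice
err_keep_3 = rel_error(3)   # slicing below the rank loses information
```

The rotation alone leaves the output unchanged, and slicing only removes directions in which the calibration activations carry little or no variance.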
Algorithm:
Algorithm: SliceGPT
────────────────────────────────────────────────────────────────
1. For each layer l = 1, ..., L:
   a. Collect input activations X_l from calibration data
   b. Compute input covariance: Σ_l = X_l X_l^T / N
   c. Eigendecompose: Σ_l = Q_l Λ_l Q_l^T
   d. Sort eigenvectors by eigenvalue (descending)
2. For each pair of adjacent layers (l-1, l):
   a. Absorb Q_l into W_l and W_{l-1}:
      W_l ← W_l Q_l
      W_{l-1} ← Q_l^T W_{l-1}
3. Choose slicing fraction s (e.g., 25%)
4. Remove the last s fraction of columns from each W_l
   and the corresponding rows from W_{l-1}
5. The result is a smaller, dense model
────────────────────────────────────────────────────────────────
Complexity: The PCA step requires \(O(d^2 N + d^3)\) per layer for the covariance and eigendecomposition. The rotation absorption is \(O(d^2 d_{\text{ffn}})\). Total: \(O(L \times d^3)\), similar to SparseGPT.
Results: SliceGPT can remove 20-30% of the embedding dimension of Llama-2 7B/13B/70B while maintaining competitive perplexity. Unlike unstructured pruning, the resulting model is dense and runs faster on standard hardware without any sparse kernel support.
Layer Pruning for LLMs#
The most aggressive form of structural pruning removes entire transformer layers. This yields maximum speedup (each removed layer eliminates all its parameters and computation) but risks significant quality degradation.
Block Importance Metric: A simple and effective metric for layer importance is based on the observation that transformer layers learn residual functions. Layer \(l\) computes:
$$h_{l+1} = h_l + f_l(h_l)$$where \(f_l\) is the layer’s transformation (attention + MLP). If \(f_l(h_l) \approx 0\), the layer is doing little and can be skipped. This motivates the Block Influence (BI) metric:
$$\text{BI}_l = \frac{\|h_{l+1} - h_l\|_2}{\|h_l\|_2} = \frac{\|f_l(h_l)\|_2}{\|h_l\|_2}$$Layers with low BI scores contribute little to the residual stream and are candidates for removal.
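A sketch of the BI computation from a list of hidden states. Averaging the per-token ratio over calibration tokens is an assumption made here; the formula above does not say how tokens are aggregated.

```python
import numpy as np

def block_influence(hidden_states):
    """BI score for each layer.

    hidden_states: list [h_0, h_1, ..., h_L] of (n_tokens, d) arrays, where
    h_{l+1} is the residual stream after layer l. Returns L scores.
    """
    scores = []
    for h_in, h_out in zip(hidden_states[:-1], hidden_states[1:]):
        update_norm = np.linalg.norm(h_out - h_in, axis=-1)   # ||f_l(h_l)||_2 per token
        input_norm = np.linalg.norm(h_in, axis=-1)            # ||h_l||_2 per token
        scores.append(float(np.mean(update_norm / input_norm)))
    return scores

# Toy residual stream: the middle layer barely changes its input.
rng = np.random.default_rng(0)
h = rng.normal(size=(32, 64))
states = [h]
for scale in (0.4, 0.05, 0.3):
    h = h + scale * rng.normal(size=h.shape)
    states.append(h)

bi = block_influence(states)
least_important = int(np.argmin(bi))   # candidate for removal
```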
Shortened LLaMA (Kim et al., 2024) uses this approach, combined with optional light fine-tuning after layer removal:
Before Layer Pruning (32-layer LLaMA-7B):
┌───────────────────────────┐
│ Layer 0 (BI = 0.42) │ ← Keep (high importance)
├───────────────────────────┤
│ Layer 1 (BI = 0.31) │ ← Keep
├───────────────────────────┤
│ Layer 2 (BI = 0.28) │ ← Keep
├───────────────────────────┤
│ ... │
├───────────────────────────┤
│ Layer 15 (BI = 0.08) │ ← REMOVE (low importance)
├───────────────────────────┤
│ Layer 16 (BI = 0.07) │ ← REMOVE (low importance)
├───────────────────────────┤
│ Layer 17 (BI = 0.06) │ ← REMOVE (low importance)
├───────────────────────────┤
│ ... │
├───────────────────────────┤
│ Layer 31 (BI = 0.35) │ ← Keep (high importance)
└───────────────────────────┘
After Pruning (29-layer model):
┌───────────────────────────┐
│ Layer 0 (BI = 0.42) │
├───────────────────────────┤
│ Layer 1 (BI = 0.31) │
├───────────────────────────┤
│ ... │
├───────────────────────────┤
│ Layer 14 (BI = 0.12) │ ← Connected directly to
├───────────────────────────┤ former Layer 18
│ Layer 18 (BI = 0.15) │
├───────────────────────────┤
│ ... │
├───────────────────────────┤
│ Layer 31 (BI = 0.35) │
└───────────────────────────┘
Saved: 3 layers × 202M params = ~606M parameters (9% reduction)
Speed: ~10% faster inference (fewer sequential steps)
The key finding is that middle layers (especially around layers 15-20 in a 32-layer model) often have the lowest BI scores. These layers tend to learn highly similar representations, a phenomenon called layer redundancy. Removing 2-5 layers from a 32-layer model typically increases PPL by 0.5-2.0 points, which is often acceptable.
Depth Pruning (Dynamic Depth)#
Rather than permanently removing layers, dynamic depth methods decide at inference time whether to skip a layer based on the current input. This is especially useful because not all inputs need the same depth of processing — simple continuations (e.g., completing “The capital of France is”) may need fewer layers than complex reasoning tasks.
Early exit is a form of depth pruning: if the model’s hidden state at layer \(l\) is already confident enough, skip all remaining layers and project directly to the vocabulary:
$$\hat{y} = \text{LMHead}(\text{RMSNorm}(h_l)) \quad \text{if confidence}(h_l) > \theta$$Confidence can be measured by the entropy of the predicted distribution or by the norm of the residual \(\|f_l(h_l)\|\).
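The entropy variant of the exit rule can be sketched as follows. The logits are assumed to have been produced by applying the LM head to each layer's normalized hidden state; the function names and the toy logits are hypothetical. "Confidence above \(\theta\)" is expressed here as "entropy below `max_entropy`".

```python
import math

def softmax(logits):
    m = max(logits)                                  # subtract max for numerical stability
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

def entropy(probs):
    return -sum(p * math.log(p) for p in probs if p > 0.0)

def early_exit_layer(per_layer_logits, max_entropy):
    """Index of the first layer whose prediction is confident enough.

    Falls back to the last layer if no layer meets the threshold.
    """
    for layer, logits in enumerate(per_layer_logits):
        if entropy(softmax(logits)) < max_entropy:
            return layer
    return len(per_layer_logits) - 1

# Toy logits over a 4-token vocabulary that sharpen with depth.
logits_by_layer = [
    [0.0, 0.0, 0.0, 0.0],
    [1.0, 0.5, 0.0, 0.0],
    [6.0, 0.0, 0.0, 0.0],
    [9.0, 0.0, 0.0, 0.0],
]
exit_loose = early_exit_layer(logits_by_layer, max_entropy=0.1)
exit_strict = early_exit_layer(logits_by_layer, max_entropy=0.01)
```

A looser threshold exits earlier and a stricter one runs more layers, which is the accuracy versus latency trade-off this approach exposes.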
Compounding: Pruning + Quantization for LLMs#
Pruning and quantization are complementary compression techniques that can be applied together for multiplicative benefits. Pruning reduces the number of parameters; quantization reduces the bits per parameter. Combined:
$$\text{Compression Ratio} = \frac{1}{(1 - s)} \times \frac{b_{\text{original}}}{b_{\text{quantized}}}$$where \(s\) is the sparsity ratio, \(b_{\text{original}}\) is the original bitwidth, and \(b_{\text{quantized}}\) is the quantized bitwidth.
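As a quick check of the formula, the following helpers (hypothetical names) compute the compression ratio and the resulting weight storage. Storage for sparsity indices is deliberately ignored here.

```python
def compression_ratio(sparsity, bits_original, bits_quantized):
    """Combined ratio from pruning (1 / (1 - s)) and quantization (b_orig / b_quant)."""
    return (1.0 / (1.0 - sparsity)) * (bits_original / bits_quantized)

def weight_memory_gb(n_params, sparsity, bits):
    """Storage for the surviving weights only, in GB (1 GB = 1e9 bytes)."""
    return n_params * (1.0 - sparsity) * bits / 8 / 1e9

ratio = compression_ratio(sparsity=0.5, bits_original=16, bits_quantized=4)   # 8.0
dense_fp16 = weight_memory_gb(70e9, sparsity=0.0, bits=16)                    # 140.0
sparse_int4 = weight_memory_gb(70e9, sparsity=0.5, bits=4)                    # 17.5
```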
SparseGPT + GPTQ Pipeline#
The standard pipeline for maximum compression:
- Prune with SparseGPT to 50% unstructured sparsity (or 2:4 structured).
- Quantize with GPTQ to 4-bit weights.
- Store the sparse-quantized model in a compressed format.
The compression calculation for LLaMA-70B:
| Configuration | Params (effective) | Bits/param | Memory | Compression |
|---|---|---|---|---|
| Dense FP16 | 70B | 16 | 140 GB | 1x |
| 50% sparse FP16 | 35B | 16 | 70 GB | 2x |
| Dense INT4 | 70B | 4 | 35 GB | 4x |
| 50% sparse INT4 | 35B | 4 | 17.5 GB | 8x |
| 2:4 sparse INT4 | 35B | 4 | 17.5 GB | 8x (hw-accelerated) |
At 8x compression, the LLaMA-70B weights occupy 17.5 GB (before sparsity-index metadata and runtime buffers such as the KV-cache), which is within the memory of a single consumer GPU (RTX 4090 with 24 GB). This is a transformative reduction: from requiring multiple enterprise GPUs to running on consumer hardware.
Wanda + AWQ Combination#
Activation-aware Weight Quantization (AWQ) and Wanda share the same core insight: activation magnitudes determine importance. They can be combined naturally:
- Compute activation statistics (shared between both methods).
- Prune with Wanda using \(|W_{ij}| \cdot \|X_j\|_2\).
- Quantize with AWQ using activation-aware scaling factors.
- Both methods use the same calibration data, so the overhead of combining them is minimal.
2:4 Sparsity + INT4 Quantization#
NVIDIA’s Ampere and Hopper architectures provide hardware acceleration for 2:4 structured sparsity through the Sparse Tensor Cores. When combined with INT4 quantization:
- Sparsity gain: 2:4 pattern gives 2x throughput on Sparse Tensor Cores.
- Quantization gain: INT4 gives 2x throughput over INT8 on Tensor Cores.
- Combined throughput: up to 4x over dense INT8, or 8x over dense FP16.
Memory calculation for LLaMA-70B with 2:4 + INT4:
Dense FP16 parameters: \(70 \times 10^9 \times 2 = 140\) GB
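With 2:4 sparsity, we store 50% of the weights plus index metadata. The estimate below assumes 2 bits of metadata per group of 4. This is optimistic: there are 6 possible 2-of-4 patterns, which 2 bits cannot encode, and NVIDIA's compressed format keeps a 2-bit index per retained weight (4 bits per group), which would double the metadata figure to 8.75 GB: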
- Weight storage: \(70 \times 10^9 \times 0.5 \times 0.5 = 17.5\) GB (50% weights at 4 bits = 0.5 bytes)
- Sparsity metadata: \(70 \times 10^9 \times 2 / (8 \times 4) = 4.375\) GB (2 bits per group of 4)
- Total: approximately 22 GB
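The memory estimate can be parameterized by the metadata width. The function below is a hypothetical helper: `index_bits_per_group=2` reproduces the figures in the list above, while `4` corresponds to a 2-bit index per retained weight, which is how NVIDIA's documentation describes its compressed 2:4 layout (treat that as an assumption of this sketch).

```python
def sparse_2_4_memory_gb(n_params, weight_bits, index_bits_per_group):
    """Weight, metadata, and total storage in GB for a 2:4 sparse model."""
    kept = n_params // 2                          # 2 of every 4 weights survive
    groups = n_params // 4
    weights_gb = kept * weight_bits / 8 / 1e9
    metadata_gb = groups * index_bits_per_group / 8 / 1e9
    return weights_gb, metadata_gb, weights_gb + metadata_gb

n = 70_000_000_000
w2, m2, total2 = sparse_2_4_memory_gb(n, weight_bits=4, index_bits_per_group=2)
w4, m4, total4 = sparse_2_4_memory_gb(n, weight_bits=4, index_bits_per_group=4)
print(w2, m2, total2)   # 17.5 4.375 21.875
print(w4, m4, total4)   # 17.5 8.75 26.25
```

With INT4 weights the metadata is a sizeable fraction of the total, so the index width matters much more than it does at FP16.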
Accuracy-Compression Pareto Frontier#
| Method | Compression | LLaMA-7B PPL | LLaMA-13B PPL | Hardware Speedup |
|---|---|---|---|---|
| Dense FP16 | 1x | 5.68 | 5.09 | 1x (baseline) |
| GPTQ INT4 | 4x | 6.09 | 5.36 | ~2x |
| SparseGPT 50% | 2x | 5.95 | 5.28 | 1x (no hw support) |
| Wanda 50% | 2x | 6.01 | 5.33 | 1x (no hw support) |
| SparseGPT 2:4 | 2x | 7.16 | 5.85 | 2x (Ampere+) |
| SparseGPT 2:4 + GPTQ 4b | 8x | 8.20 | 6.45 | ~4x |
| Wanda 2:4 + AWQ 4b | 8x | 8.35 | 6.52 | ~4x |
| SparseGPT 50% + GPTQ 4b | 8x | 7.05 | 5.95 | ~2x (no sparse hw) |
The sweet spot for most deployments is either (a) quantization only (GPTQ/AWQ INT4) for 4x compression with minimal quality loss, or (b) 2:4 sparsity + INT4 for 8x compression when hardware supports sparse acceleration.
Semi-Structured and Pattern-Based Pruning for LLMs#
2:4 Sparsity on LLMs#
NVIDIA’s 2:4 (or N:M) sparsity constraint requires that exactly 2 out of every 4 consecutive weights are zero. This provides a regular pattern that Sparse Tensor Cores can exploit for 2x throughput.
Applying SparseGPT with 2:4 constraint: Replace the pruning decision step in SparseGPT with a constrained selection. Instead of choosing the globally least salient weights, within each group of 4 consecutive columns, prune the 2 with the lowest saliency. The OBS weight update still applies — remaining weights are compensated.
Applying Wanda with 2:4 constraint: Similarly, within each group of 4, prune the 2 with the lowest Wanda scores (\(|W_{ij}| \cdot \|X_j\|_2\)). This is the simplest possible implementation of hardware-friendly LLM pruning.
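A minimal NumPy sketch of the 2:4 variant, assuming groups are formed from 4 consecutive weights along each row and that `d_col` is divisible by 4. The function name is hypothetical.

```python
import numpy as np

def wanda_2_4_prune(W, X):
    """Zero the 2 lowest-scoring weights in every group of 4 along each row.

    W: (d_row, d_col) weights with d_col divisible by 4
    X: (d_col, N) calibration activations
    """
    d_row, d_col = W.shape
    if d_col % 4 != 0:
        raise ValueError("d_col must be divisible by 4")
    scores = np.abs(W) * np.linalg.norm(X, axis=1)      # Wanda score per weight
    groups = scores.reshape(d_row, d_col // 4, 4)
    low2 = np.argsort(groups, axis=-1)[..., :2]         # 2 smallest scores per group
    keep = np.ones(groups.shape, dtype=bool)
    np.put_along_axis(keep, low2, False, axis=-1)
    return np.where(keep.reshape(d_row, d_col), W, 0.0)

rng = np.random.default_rng(0)
W = rng.normal(size=(6, 16))
X = rng.normal(size=(16, 64))
W_24 = wanda_2_4_prune(W, X)
zeros_per_group = (W_24.reshape(6, 4, 4) == 0).sum(axis=-1)
```

Every group of 4 contains exactly 2 zeros, which is the pattern Sparse Tensor Cores expect.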
Training LLMs with N:M from Scratch vs. Post-Training N:M#
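Training with N:M sparsity from scratch typically relies on a straight-through estimator, as in SR-STE (Sparse-Refined Straight-Through Estimator, Zhou et al., 2021). The straight-through part works as follows (SR-STE additionally regularizes the masked weights):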
- Maintain a dense weight matrix \(W_{\text{dense}}\).
- At each forward pass, apply the N:M mask to get \(W_{\text{sparse}}\).
- Compute loss with \(W_{\text{sparse}}\).
- In the backward pass, compute gradients with respect to \(W_{\text{sparse}}\) but apply them to \(W_{\text{dense}}\) (straight-through).
- Periodically update the mask based on current weight magnitudes.
For LLMs, training from scratch with N:M is extremely expensive (it requires the same resources as dense training). Post-training N:M pruning (SparseGPT or Wanda with the 2:4 constraint) is therefore the practical choice, even though it yields slightly worse results.
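One straight-through training step can be sketched on a toy linear regression problem. This is an illustration under assumptions: the mask is recomputed at every step rather than periodically, the loss is a plain squared error, and no sparse-refined regularization term is included.

```python
import numpy as np

def nm_mask(W, n=2, m=4):
    """Boolean mask keeping the n largest-magnitude weights in each group of m."""
    rows, cols = W.shape
    groups = np.abs(W).reshape(rows, cols // m, m)
    drop = np.argsort(groups, axis=-1)[..., : m - n]   # smallest magnitudes per group
    mask = np.ones(groups.shape, dtype=bool)
    np.put_along_axis(mask, drop, False, axis=-1)
    return mask.reshape(rows, cols)

def ste_step(W_dense, x, y, lr=0.05):
    """Forward with masked weights, then update the dense weights."""
    mask = nm_mask(W_dense)
    W_sparse = W_dense * mask                          # N:M weights used in the forward pass
    err = W_sparse @ x - y
    loss = 0.5 * np.mean(np.sum(err ** 2, axis=0))
    grad_sparse = err @ x.T / x.shape[1]               # dLoss / dW_sparse
    # Straight-through: the gradient is applied to every dense weight,
    # including the ones that were masked out in the forward pass.
    return W_dense - lr * grad_sparse, mask, loss

rng = np.random.default_rng(0)
W = rng.normal(size=(4, 8))
x = rng.normal(size=(8, 64))
y = rng.normal(size=(4, 64))
W_next, mask, loss = ste_step(W, x, y)
```

Because masked weights keep receiving gradient, they can grow back and re-enter the mask later, which is what lets the sparsity pattern evolve during training.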
Performance on Actual Hardware#
Measured on NVIDIA A100 (Ampere) for LLaMA-7B inference, batch size 1, sequence length 2048:
| Configuration | Latency (ms/token) | Throughput (tokens/s) | Memory (GB) |
|---|---|---|---|
| Dense FP16 | 28.5 | 35.1 | 13.5 |
| Dense INT8 | 18.2 | 54.9 | 7.0 |
| 2:4 Sparse FP16 | 19.8 | 50.5 | 7.2 |
| 2:4 Sparse INT8 | 12.1 | 82.6 | 3.8 |
| Dense INT4 (GPTQ) | 11.5 | 86.9 | 3.8 |
| 2:4 Sparse INT4 | 8.2 | 121.9 | 2.2 |
The 2:4 sparse INT4 configuration achieves 3.5x speedup over dense FP16 with 6x memory reduction, while typically maintaining acceptable quality (PPL increase of 2-4 points on WikiText-2).
Pruning for Efficient LLM Inference#
KV-Cache Pruning#
During autoregressive generation, LLMs cache the key-value (KV) tensors for all previous tokens to avoid redundant computation. This KV-cache grows linearly with sequence length and becomes the dominant memory consumer for long sequences.
KV-cache memory formula:
$$\text{KV memory} = 2 \times L \times n_h \times d_h \times \text{seq\_len} \times \text{batch} \times b$$where:
- Factor 2 = keys + values
- \(L\) = number of layers
- \(n_h\) = number of attention heads (or KV heads in GQA)
- \(d_h\) = head dimension
- \(\text{seq\_len}\) = sequence length
- \(\text{batch}\) = batch size
- \(b\) = bytes per element
Numerical example for LLaMA-7B (\(L=32, n_h=32, d_h=128\)):
At sequence length 4096, FP16:
$$\text{KV memory} = 2 \times 32 \times 32 \times 128 \times 4096 \times 1 \times 2 = 2{,}147{,}483{,}648 \text{ bytes} \approx 2 \text{ GB}$$At sequence length 128K (long-context models):
$$\text{KV memory} = 2 \times 32 \times 32 \times 128 \times 131072 \times 1 \times 2 \approx 64 \text{ GB}$$This is larger than the model weights themselves (13.5 GB for LLaMA-7B). For long-context applications, KV-cache memory is the primary bottleneck.
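The formula translates directly into a small helper (hypothetical name), which reproduces the two LLaMA-7B figures above:

```python
def kv_cache_bytes(n_layers, n_kv_heads, d_head, seq_len, batch=1, bytes_per_elem=2):
    """KV-cache size in bytes; the leading 2 accounts for keys and values."""
    return 2 * n_layers * n_kv_heads * d_head * seq_len * batch * bytes_per_elem

short_ctx = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=4096)
long_ctx = kv_cache_bytes(n_layers=32, n_kv_heads=32, d_head=128, seq_len=131072)
print(short_ctx, short_ctx / 2**30)   # 2147483648 2.0
print(long_ctx, long_ctx / 2**30)     # 68719476736 64.0
```

Passing a smaller `n_kv_heads` shows the saving from grouped-query attention, since the cache scales linearly with the number of KV heads.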
KV-cache pruning removes tokens from the cache that are unlikely to be attended to in the future, keeping the cache at a manageable size.
KV-Cache Pruning Strategies:
─────────────────────────────────────────────────────────────
Full KV-Cache (no pruning):
┌────────────────────────────────────────────────────┐
│ t1 │ t2 │ t3 │ t4 │ t5 │ t6 │ ... │ t_n │ NEW │
└────────────────────────────────────────────────────┘
Memory: O(n) — grows without bound
Sliding Window (fixed):
┌────────────────────────────────────────────────────┐
│ ██ │ ██ │ ██ │ ██ │ t_{n-3} │ t_{n-2} │ t_{n-1} │ NEW │
└────────────────────────────────────────────────────┘
██ = evicted Memory: O(window_size)
Problem: loses long-range dependencies
H2O (Heavy-Hitter Oracle):
┌────────────────────────────────────────────────────┐
│ t1 │ ██ │ t3 │ ██ │ ██ │ t6 │ t_{n-2} │ t_{n-1} │ NEW │
└────────────────────────────────────────────────────┘
Keep: high-attention tokens + recent tokens
Memory: O(budget) — fixed budget
StreamingLLM (sink + window):
┌────────────────────────────────────────────────────┐
│ t1 │ t2 │ ██ │ ██ │ ██ │ ██ │ t_{n-2} │ t_{n-1} │ NEW │
└────────────────────────────────────────────────────┘
Keep: first few tokens (sinks) + recent window
Memory: O(sinks + window_size)
─────────────────────────────────────────────────────────────
H2O (Heavy-Hitter Oracle) — Zhang et al., 2023#
H2O observes that attention patterns in LLMs follow a power-law distribution: a small fraction of tokens receive a disproportionately large share of attention across all heads and layers. These “heavy hitter” tokens carry critical information.
Algorithm:
- Maintain a cumulative attention score for each cached token: the sum of the attention weights that token has received from all queries generated so far.
- Set a fixed budget \(B\) for the KV-cache.
- When the cache exceeds \(B\), evict the token with the lowest cumulative attention score (excluding the most recent \(W\) tokens, which are always kept).
The combination of heavy hitters (globally important tokens) + recent tokens (locally important for coherent generation) preserves both long-range factual knowledge and short-range fluency.
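The eviction policy can be sketched with plain Python containers. This is a simplified, single-head illustration: the function names and the toy attention pattern are assumptions, and a real implementation tracks scores per head and per layer on the GPU.

```python
def h2o_update(cache, scores, new_pos, attn, budget, recent):
    """Add one token to the cache and evict the weakest old token if over budget.

    cache:  list of cached token positions, oldest first (modified in place)
    scores: dict position -> cumulative attention received (modified in place)
    attn:   dict position -> attention weight from the newest query
    recent: number of most recent tokens that are never evicted (1 <= recent < budget)
    """
    cache.append(new_pos)
    scores[new_pos] = 0.0
    for pos in cache:
        scores[pos] += attn.get(pos, 0.0)
    if len(cache) > budget:
        candidates = cache[:-recent]                  # exclude the recent window
        victim = min(candidates, key=lambda p: scores[p])
        cache.remove(victim)
        del scores[victim]

def toy_attention(visible):
    """Hypothetical pattern: 70% of attention on position 0, the rest shared equally."""
    if len(visible) == 1:
        return {visible[0]: 1.0}
    rest = [p for p in visible if p != 0]
    attn = {p: 0.3 / len(rest) for p in rest}
    attn[0] = 0.7
    return attn

cache, scores = [], {}
for t in range(6):
    h2o_update(cache, scores, t, toy_attention(cache + [t]), budget=4, recent=2)
print(cache)   # [0, 1, 4, 5]
```

Position 0 survives as a heavy hitter, positions 4 and 5 survive as the recent window, and the low-score middle tokens are evicted.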
Results: H2O can reduce KV-cache by 5-10x (keeping only 10-20% of tokens) with minimal perplexity increase on long-context benchmarks. This enables LLaMA-7B to handle 100K+ token contexts within the memory budget originally needed for 10K tokens.
StreamingLLM — Xiao et al., 2023#
StreamingLLM makes a simpler but powerful observation: the first few tokens in the sequence (the “attention sinks”) receive abnormally high attention from all subsequent tokens, regardless of their semantic content. This is an artifact of the softmax function — it must assign attention mass somewhere, and the first positions serve as a default sink.
Strategy: Keep the first \(S\) tokens (sinks, typically \(S = 4\)) and a sliding window of the most recent \(W\) tokens. Evict everything in between.
Memory: \(O(S + W)\) — constant regardless of total sequence length. This enables truly infinite-length streaming generation.
Limitation: Information between the sinks and the window is lost. StreamingLLM is best suited for streaming applications (chatbots, real-time transcription) where maintaining a complete history is not required.
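The retention rule is simple enough to express as an index computation. The function name and the default sizes below are illustrative assumptions.

```python
def streaming_llm_keep(seq_len, n_sinks=4, window=8):
    """Positions retained in the KV-cache: the first n_sinks tokens plus a recent window."""
    sinks = list(range(min(n_sinks, seq_len)))
    recent_start = max(n_sinks, seq_len - window)   # never overlap with the sinks
    recent = list(range(recent_start, seq_len))
    return sinks + recent

print(streaming_llm_keep(20))            # [0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19]
print(streaming_llm_keep(6))             # [0, 1, 2, 3, 4, 5]
print(len(streaming_llm_keep(100000)))   # 12
```

The cache size stays at `n_sinks + window` no matter how long the stream runs, which is the constant-memory property described above.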
Activation Pruning / Dynamic Sparsity#
A complementary approach to weight pruning is activation pruning: dynamically skipping computations based on the input at inference time. This exploits the fact that for any given input token, a large fraction of MLP neurons produce negligibly small outputs.
ReLU-Based Sparsity#
Models using ReLU activation (rather than GeLU or SiLU) naturally produce sparse activations — roughly 90%+ of neurons output zero for any given input. This sparsity can be exploited by only computing the non-zero neurons.
Some recent work has explored replacing GeLU/SiLU with ReLU in LLMs specifically to unlock this activation sparsity. The quality impact is small (< 0.5% PPL increase) but the inference speedup can be significant (2-3x for MLP layers).
Deja Vu: Contextual Sparsity Prediction (Liu et al., 2023)#
Deja Vu takes activation pruning further by training a small predictor network that, given the current hidden state, predicts which neurons and attention heads will be active (have significant output) for the current input.
Architecture:
┌──────────────────────────────────────────────────┐
│           Deja Vu: Contextual Sparsity           │
│                                                  │
│  Input hidden state h_l                          │
│           │                                      │
│           ▼                                      │
│  ┌─────────────────┐                             │
│  │ Small Predictor │  (2-layer MLP, ~0.5% of     │
│  │     Network     │   main model size)          │
│  └────────┬────────┘                             │
│           │                                      │
│           ▼                                      │
│  Predicted active set: {neuron_3, neuron_7,      │
│  neuron_15, ...} (~25% of total neurons)         │
│           │                                      │
│           ▼                                      │
│  ┌─────────────────┐                             │
│  │  Execute ONLY   │  (75% computation saved     │
│  │ active neurons  │   in this layer)            │
│  └────────┬────────┘                             │
│           │                                      │
│           ▼                                      │
│  Output h_{l+1} (approximately same as full)     │
└──────────────────────────────────────────────────┘
Results: Deja Vu demonstrates that 75%+ of neurons can be skipped per token with less than 0.1% accuracy loss. For MLP layers (which dominate computation), this translates to a 2-4x speedup. The predictor adds negligible overhead because it is much smaller than the main model.
Latency analysis: For a LLaMA-7B MLP layer with \(d = 4096\), \(d_{\text{ffn}} = 11008\):
- Full MLP: \(3 \times d \times d_{\text{ffn}} = 3 \times 4096 \times 11008 \approx 135M\) MACs
- With 75% sparsity prediction: \(3 \times 4096 \times 2752 \approx 34M\) MACs + predictor overhead (~0.5M MACs)
- Speedup: \(135M / 34.5M \approx 3.9\times\)
However, the practical speedup is lower (typically 2-3x) because: (1) sparse computation has lower hardware utilization than dense, (2) the predictor introduces a serial dependency (must run before the main layer), and (3) memory access patterns become irregular.
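The theoretical figure can be reproduced with exact integer MAC counts. The 500,000 MAC predictor cost is the rough estimate quoted above, not a measured value, and the helper name is hypothetical.

```python
def mlp_macs(d_model, d_ffn):
    """MACs per token for a gated MLP: gate, up, and down projections."""
    return 3 * d_model * d_ffn

full = mlp_macs(4096, 11008)             # 135266304
active = mlp_macs(4096, 11008 // 4)      # 33816576, only 25% of neurons computed
predictor = 500_000                      # approximate predictor overhead
theoretical_speedup = full / (active + predictor)
print(full, active, round(theoretical_speedup, 2))
```

This gives roughly 3.9x, an upper bound that ignores the utilization, serialization, and memory-access effects listed above.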
Pruning Emerging Architectures#
Mamba/State Space Models Pruning#
State Space Models (SSMs) like Mamba replace attention with a linear recurrence, eliminating the KV-cache entirely. Pruning SSMs involves:
- Width pruning: Reduce the state dimension \(N\) or the expansion factor \(E\). SSMs typically use \(N = 16\) and \(E = 2\), so there is less room to prune than in transformers.
- Channel pruning: Remove input/output channels from the SSM block. Similar to MLP neuron pruning.
- Selective state removal: Some state dimensions contribute more than others. Importance can be measured by the magnitude of the \(C\) matrix (output projection).
SSM pruning is less studied than transformer pruning because SSMs are inherently more parameter-efficient. A Mamba model with equivalent quality to a transformer is typically 2-3x smaller, reducing the need for aggressive pruning.
Mixture of Experts (MoE) Pruning#
MoE models (Mixtral, Switch Transformer) have a unique structure: each MLP layer is replaced by \(E\) expert MLPs, and a router selects the top-\(k\) experts per token. The MLP parameter count is therefore roughly \(E \times\) that of the dense layer it replaces, but only a \(k/E\) fraction of it is activated per token.
Expert merging: Identify pairs of experts that learn similar functions (measured by weight cosine similarity or output correlation on calibration data). Merge them by averaging their weights:
$$W_{\text{merged}} = \frac{1}{2}(W_{\text{expert}_i} + W_{\text{expert}_j})$$This reduces the number of experts (and total parameters) while maintaining the routing structure. Mixtral-8x7B (8 experts) can often be compressed to 4-6 experts with modest quality loss.
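A sketch of similarity-based merging, treating each expert as a single weight matrix for simplicity (real experts contain several matrices, and the router must also be remapped, which is not shown). The function names are hypothetical.

```python
import numpy as np

def most_similar_pair(experts):
    """Return the indices and cosine similarity of the two most similar experts."""
    flat = [w.ravel() / np.linalg.norm(w) for w in experts]
    best_pair, best_sim = None, -np.inf
    for i in range(len(flat)):
        for j in range(i + 1, len(flat)):
            sim = float(flat[i] @ flat[j])
            if sim > best_sim:
                best_pair, best_sim = (i, j), sim
    return best_pair, best_sim

def merge_pair(experts, i, j):
    """Replace experts i and j by their average; the merged expert is appended last."""
    merged = 0.5 * (experts[i] + experts[j])
    return [w for k, w in enumerate(experts) if k not in (i, j)] + [merged]

rng = np.random.default_rng(0)
experts = [rng.normal(size=(8, 16)) for _ in range(4)]
experts[3] = experts[1] + 0.01 * rng.normal(size=(8, 16))   # near-duplicate of expert 1
(i, j), sim = most_similar_pair(experts)
merged_experts = merge_pair(experts, i, j)
```

In practice the similarity would be measured on calibration outputs as well as weights, since two experts can compute similar functions with differently permuted neurons.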
Expert dropping: Simply remove the least-used experts (those routed to least frequently). The router is adjusted to distribute their probability mass among remaining experts. This is the simplest approach but can cause load imbalance.
Multi-Modal LLM Pruning#
Multi-modal LLMs (LLaVA, GPT-4V) combine a vision encoder with a language model. Pruning these models requires separate strategies:
- Vision encoder: Typically a ViT, which can be pruned using standard CNN/ViT pruning techniques (token merging, attention head pruning, patch embedding pruning).
- Language model: Uses the LLM pruning techniques discussed throughout this post.
- Cross-modal projector: Usually a small MLP connecting vision and language spaces. This is typically not pruned because it is small and critical for alignment.
The key question is how to balance pruning between the vision and language components. Empirically, the language model tolerates pruning better because it is much larger and more redundant. A common strategy is to prune the language model to 50% sparsity while leaving the vision encoder at 20-30% sparsity.
Evaluation and Benchmarking#
Standard Benchmarks#
LLM pruning methods are evaluated on:
Perplexity (PPL) on held-out text:
- WikiText-2: ~250K tokens of Wikipedia articles. The standard benchmark for language modeling quality.
- C4: Colossal Clean Crawled Corpus. Larger and more diverse than WikiText-2.
Zero-shot downstream tasks:
- ARC (AI2 Reasoning Challenge): Science exam questions. Tests reasoning.
- HellaSwag: Sentence completion. Tests common sense.
- WinoGrande: Pronoun resolution. Tests commonsense reasoning.
- MMLU: Massive Multitask Language Understanding. 57 subjects, tests broad knowledge.
- TruthfulQA: Tests tendency to generate truthful answers.
Generation quality: Human evaluation or GPT-4-based evaluation of generation fluency, coherence, and factual accuracy.
Comprehensive Comparison Table#
| Method | Type | Sparsity | LLaMA-7B WikiText PPL | ARC | HellaSwag | WinoGrande | Time (7B) | Memory |
|---|---|---|---|---|---|---|---|---|
| Dense | — | 0% | 5.68 | 51.2 | 76.2 | 70.0 | — | 13.5 GB |
| Magnitude | Unstr. | 50% | 17.29 | 35.4 | 52.1 | 54.3 | <1 min | 13.5 GB |
| SparseGPT | Unstr. | 50% | 5.95 | 49.8 | 74.5 | 68.9 | ~15 min | 14.5 GB |
| SparseGPT | 2:4 | 50% | 7.16 | 46.2 | 70.8 | 66.1 | ~15 min | 14.5 GB |
| Wanda | Unstr. | 50% | 6.01 | 49.5 | 74.1 | 68.5 | ~30 sec | 13.8 GB |
| Wanda | 2:4 | 50% | 7.26 | 45.8 | 70.2 | 65.7 | ~30 sec | 13.8 GB |
| SliceGPT | Struct. | 25% dim | 6.85 | 47.5 | 72.0 | 67.2 | ~20 min | ~10 GB |
| Layer Prune | Struct. | 3 layers | 7.50 | 44.1 | 69.5 | 64.8 | <1 min | ~12 GB |
| Head Prune | Struct. | 25% heads | 6.45 | 48.2 | 73.0 | 68.0 | ~5 min | ~11 GB |
Key takeaways from the table:
- Magnitude pruning catastrophically fails at 50% sparsity for LLMs.
- SparseGPT and Wanda achieve comparable quality, with SparseGPT slightly better.
- Structured pruning (SliceGPT, head pruning) yields real dense-model speedups at the cost of slightly worse quality than unstructured methods at the same compression ratio.
- Layer pruning is the most aggressive but can maintain reasonable quality when removing only a few layers.
Throughput Measurements#
Actual inference throughput depends heavily on the sparsity pattern, hardware, and software stack:
| Configuration | Tokens/sec (A100, bs=1) | Tokens/sec (RTX 4090, bs=1) |
|---|---|---|
| Dense FP16 | 35 | 28 |
| Dense INT4 (GPTQ) | 87 | 72 |
| 2:4 Sparse FP16 | 51 | N/A (not reported) |
| 2:4 Sparse INT8 | 83 | N/A |
| 50% Unstructured FP16 | 36 | 29 (no speedup without sparse kernels) |
| SliceGPT 25% reduction | 47 | 38 |
| 3 Layers Removed | 39 | 31 |
Critical observation: unstructured sparsity provides no speedup on standard hardware. The 50% sparse model runs at the same speed as the dense model because standard dense GEMM kernels process all matrix entries regardless of sparsity. Only specialized sparse kernels (available on NVIDIA Ampere+ for 2:4 patterns) provide actual acceleration.
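The reason 2:4 sparsity can be accelerated is its regularity: every contiguous group of four weights holds at most two nonzeros, so the hardware knows in advance how much work to skip. The sketch below builds such a mask for a single weight row in plain Python. It ranks weights by magnitude purely for illustration (an assumption of this example); SparseGPT or Wanda would rank by their own saliency scores but enforce the same group constraint.

```python
from typing import List


def two_four_mask(row: List[float]) -> List[int]:
    """Return a 0/1 mask keeping the 2 largest-magnitude weights in every group of 4."""
    if len(row) % 4 != 0:
        raise ValueError("row length must be a multiple of 4 for a 2:4 pattern")
    mask = [0] * len(row)
    for start in range(0, len(row), 4):
        group = range(start, start + 4)
        # Rank the four positions by |weight| and keep the top two.
        keep = sorted(group, key=lambda i: abs(row[i]), reverse=True)[:2]
        for i in keep:
            mask[i] = 1
    return mask


def apply_mask(row: List[float], mask: List[int]) -> List[float]:
    """Zero out the weights whose mask entry is 0."""
    return [w * m for w, m in zip(row, mask)]


if __name__ == "__main__":
    row = [0.9, -0.1, 0.05, -0.7, 0.2, 0.3, -0.4, 0.01]
    mask = two_four_mask(row)
    print(mask)
    print(apply_mask(row, mask))
```

An unstructured 50% mask has the same number of zeros but no such per-group guarantee, which is why a dense GEMM kernel cannot exploit it.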
Practical Guide#
Decision Flowchart: Which Method to Use?#
Start: I want to compress my LLM
│
├── Q: Do you need actual inference speedup on current hardware?
│ │
│ ├── YES ──► Q: Does your hardware support 2:4 sparsity?
│ │ │
│ │ ├── YES ──► SparseGPT or Wanda with 2:4 constraint
│ │ │ + INT4 quantization (GPTQ/AWQ)
│ │ │ Expected: ~4x speedup, ~8x compression
│ │ │
│ │ └── NO ──► Quantization alone (GPTQ/AWQ INT4)
│ │ is your best bet for speedup.
│ │ Consider SliceGPT for structural reduction.
│ │ Expected: ~2-3x speedup, ~4x compression
│ │
│ └── NO (just want smaller model for storage/transfer)
│ │
│ ├── Q: How much quality loss is acceptable?
│ │ │
│ │ ├── Minimal (<5% PPL increase)
│ │ │ └── Wanda or SparseGPT at 50% unstructured
│ │ │ + GPTQ INT4
│ │ │ Expected: ~8x size reduction
│ │ │
│ │ └── Moderate (5-15% PPL increase)
│ │ └── SparseGPT at 60% + INT4
│ │ or 2:4 sparsity + INT4
│ │ Expected: ~10-16x size reduction
│ │
│ └── Q: How much time can you spend on pruning?
│ │
│ ├── Minutes ──► Wanda (30 sec for 7B model)
│ │
│ └── Hours OK ──► SparseGPT (better quality)
│
└── Q: Do you need to fit model in very limited memory?
│
├── < 4 GB ──► Aggressive: 2:4 + INT4 + layer pruning
│ May need to accept significant quality loss
│
└── < 8 GB ──► 50% sparse + INT4, or INT4 alone
for 7B models this is quite achievable
Calibration Data Selection Tips#
- Use diverse data: 128 sequences from C4 is the standard. Do not use only code, only dialogue, or only one domain.
- Match the deployment domain (partially): If the model will primarily be used for code generation, include some code in the calibration set (but not exclusively).
- Sequence length matters: Use the same context length for calibration as you expect during inference. Short calibration sequences (512 tokens) may not capture long-range patterns needed for long-context inference (4096+ tokens).
- More is not always better: Beyond 128-256 sequences, additional calibration data provides diminishing returns for both SparseGPT and Wanda. The Hessian / activation statistics converge quickly.
- Avoid degenerate samples: Remove sequences that are mostly whitespace, repetitive characters, or non-natural-language content. These can skew activation statistics.
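The tips above can be sketched as a small filtering and sampling step. This is a minimal illustration, not a reference pipeline: the helper names and all thresholds (`min_chars`, `max_whitespace_frac`, `min_distinct_chars`) are hypothetical and should be tuned on the actual corpus. Only the default of 128 samples comes from the text.

```python
import random
from typing import List


def is_degenerate(text: str, min_chars: int = 200,
                  max_whitespace_frac: float = 0.5,
                  min_distinct_chars: int = 10) -> bool:
    """Heuristic check for samples that could skew activation statistics."""
    if len(text) < min_chars:
        return True
    whitespace = sum(ch.isspace() for ch in text)
    if whitespace / len(text) > max_whitespace_frac:
        return True
    # Very few distinct characters indicates repetitive filler such as "=====".
    return len(set(text)) < min_distinct_chars


def select_calibration(texts: List[str], n_samples: int = 128,
                       seed: int = 0) -> List[str]:
    """Drop degenerate texts, then draw a reproducible random subset."""
    clean = [t for t in texts if not is_degenerate(t)]
    rng = random.Random(seed)
    return rng.sample(clean, min(n_samples, len(clean)))
```

Sampling from a pool that mixes several domains, rather than from a single source, is what provides the diversity recommended in the first tip; the fixed seed keeps pruning runs comparable.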
Common Failures and Debugging#
| Symptom | Likely Cause | Solution |
|---|---|---|
| PPL explodes (>100) | First or last layer over-pruned | Reduce sparsity for layer 0 and layer L-1 |
| Repetitive generation | Attention heads pruned too aggressively | Reduce head pruning ratio, especially for induction heads |
| Factual hallucinations increase | Retrieval heads removed | Use gradient-based head importance scoring instead of random selection |
| NaN/Inf in outputs | Numerical issues in Hessian computation | Increase regularization \(\epsilon\) in \(H + \epsilon I\) |
| No speedup despite high sparsity | Using unstructured sparsity on dense hardware | Switch to 2:4 structured sparsity or use sparse-aware kernels |
| Quality varies across tasks | Calibration data too narrow | Use more diverse calibration data |
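For the NaN/Inf row, the regularization \(H + \epsilon I\) can be sketched as follows. The choice of scaling \(\epsilon\) relative to the mean of the Hessian diagonal is an assumption of this example (it keeps the damping proportional to the layer's activation scale); the function name and the default `rel_eps` are illustrative.

```python
from typing import List

Matrix = List[List[float]]


def dampen_hessian(H: Matrix, rel_eps: float = 0.01) -> Matrix:
    """Return a copy of H with eps added to the diagonal, eps = rel_eps * mean(diag(H))."""
    n = len(H)
    if any(len(row) != n for row in H):
        raise ValueError("H must be square")
    mean_diag = sum(H[i][i] for i in range(n)) / n
    eps = rel_eps * mean_diag
    damped = [row[:] for row in H]  # copy so the caller's matrix is untouched
    for i in range(n):
        damped[i][i] += eps
    return damped


if __name__ == "__main__":
    # A singular 2x2 example: the determinant is 0 before damping.
    H = [[1.0, 1.0], [1.0, 1.0]]
    D = dampen_hessian(H)
    print(D[0][0] * D[1][1] - D[0][1] * D[1][0])  # small but positive
```

If NaNs persist, raising `rel_eps` by an order of magnitude is a cheap first experiment before investigating the calibration data itself.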
Tools and Frameworks#
- SparseML (Neural Magic): Production-ready framework for pruning and sparse inference. Supports SparseGPT and Wanda with optimized CPU sparse kernels.
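- llama.cpp: Inference runtime built on the GGML library and the GGUF model format. Pruned checkpoints convert like any other model, but the standard kernels are dense, so unstructured zeros by themselves bring no speedup.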
- NVIDIA TensorRT-LLM: Supports 2:4 structured sparsity with Sparse Tensor Core acceleration on Ampere and Hopper GPUs.
- torch.ao (Architecture Optimization): PyTorch’s built-in support for structured pruning and semi-structured sparsity.
- Hugging Face PEFT + Pruning: Integration of pruning with parameter-efficient fine-tuning for combined compression and adaptation.
Current State and Future Directions (2024-2025)#
Trends#
From post-training to training-aware LLM pruning: As more organizations train LLMs from scratch, there is increasing interest in incorporating sparsity during pre-training. Cerebras and Neural Magic have demonstrated that training with 80% unstructured sparsity from the start can match dense model quality while reducing training compute by 2-5x.
Scaling laws for sparse models: Researchers are extending the Chinchilla scaling laws to sparse models. Early results suggest that sparse models follow different scaling laws, with the optimal sparsity increasing as model size grows. A 70B model may achieve its best quality-efficiency tradeoff at 60-70% sparsity, while a 7B model peaks at 40-50%.
Hardware co-design for sparsity: Next-generation AI accelerators are being designed with native support for flexible sparsity patterns (beyond 2:4). NVIDIA’s Blackwell architecture extends sparse support, and startups like Cerebras and Groq build sparsity into their fundamental architecture.
Compositional compression: Combining pruning + quantization + distillation + architecture search in a unified framework. Rather than applying each technique independently, joint optimization finds better Pareto-optimal points.
Pruning for inference-time efficiency: KV-cache pruning, dynamic depth, and speculative decoding are becoming standard components of production LLM serving stacks. These techniques reduce per-query cost without modifying the model weights.
Open Problems and Research Directions#
Pruning with guaranteed capability preservation: Current methods optimize for perplexity, which does not capture all capabilities (reasoning, factual knowledge, instruction following). How can we prune while provably preserving specific capabilities?
Structured sparsity beyond 2:4: Can we design hardware-friendly sparsity patterns (e.g., 4:8, block sparsity) that offer higher compression while maintaining real speedups?
Pruning for fine-tuned/RLHF models: Most pruning research focuses on base models. How does pruning interact with instruction tuning and RLHF? Preliminary evidence suggests that RLHF-tuned models are more sensitive to pruning.
Data-free pruning: Eliminating the need for calibration data entirely. Some methods explore using the weight statistics alone (e.g., weight distribution shape) to guide pruning decisions.
Pruning as a regularizer: Some evidence suggests that mild pruning (10-20%) can actually improve generalization on out-of-distribution tasks by removing spurious correlations learned during training.
Summary#
Master Comparison of All LLM Pruning Methods#
| Method | Year | Type | Retraining | Hessian | Speed | Best For |
|---|---|---|---|---|---|---|
| Magnitude | Classic | Unstructured | Optional | No | Instant | Baseline only |
| SparseGPT | 2023 | Unstructured/2:4 | No | Yes (\(H^{-1}\)) | Minutes | Best quality, one-shot |
| Wanda | 2023 | Unstructured/2:4 | No | No | Seconds | Speed-quality balance |
| SliceGPT | 2024 | Structured (width) | No | No (PCA) | Minutes | Dense model reduction |
| Head Pruning | Various | Structured (heads) | Optional | Optional | Minutes | Attention compression |
| Layer Pruning | Various | Structured (depth) | Optional | No | Instant | Aggressive compression |
| H2O | 2023 | KV-cache | No | No | Runtime | Long-context serving |
| StreamingLLM | 2023 | KV-cache | No | No | Runtime | Infinite-length streaming |
| Deja Vu | 2023 | Dynamic (activation) | Yes (predictor) | No | Runtime | Per-token efficiency |
Key Takeaways#
LLM pruning requires fundamentally different approaches than CNN pruning because retraining is infeasible. One-shot methods (SparseGPT, Wanda) are the practical standard.
Activation awareness is essential: Simple magnitude pruning fails catastrophically for LLMs because of activation outliers. Both SparseGPT (via the full Hessian) and Wanda (via activation norms) address this, with Wanda offering a remarkably simple and effective solution.
The OBS framework (SparseGPT) provides the theoretical gold standard: The formula \(\delta_w = -\frac{w_q}{[H^{-1}]_{qq}} \cdot (H^{-1})_{:,q}\) optimally compensates for each pruned weight. Wanda approximates this with a diagonal Hessian assumption, trading some accuracy for dramatic simplicity.
Structured pruning is necessary for real speedups: Unstructured sparsity reduces model size but does not reduce inference latency on standard hardware. 2:4 structured sparsity (with hardware support), head pruning, layer pruning, and width reduction (SliceGPT) provide actual deployment benefits.
Pruning and quantization are complementary: The combination of 50% sparsity + INT4 quantization achieves up to 8x compression relative to FP16, before accounting for sparse-index and quantization-scale overhead, potentially fitting LLaMA-70B in under 20 GB of memory.
KV-cache pruning is a separate but equally important frontier: For long-context applications, KV-cache memory dominates. H2O and StreamingLLM provide practical solutions that enable 5-10x longer contexts within the same memory budget.
The field is moving toward inference-time adaptive sparsity: Rather than static pruning, methods like Deja Vu and dynamic depth adapt the computation per input, exploiting the fact that not all tokens need the same amount of processing.
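The KV-cache takeaway can be made concrete with a minimal sketch of an attention-sink eviction rule in the spirit of StreamingLLM (Xiao et al., 2023): keep the first few tokens plus a sliding window of the most recent ones, and evict everything in between. The function name and the default sizes are assumptions of this example, not values taken from the text above.

```python
from typing import List


def streaming_keep_indices(seq_len: int, n_sink: int = 4,
                           window: int = 1020) -> List[int]:
    """Positions kept in the KV cache: the first n_sink tokens plus the last `window` tokens."""
    if seq_len <= n_sink + window:
        # Everything still fits in the budget, so nothing is evicted.
        return list(range(seq_len))
    sinks = list(range(n_sink))
    recent = list(range(seq_len - window, seq_len))
    return sinks + recent


if __name__ == "__main__":
    print(streaming_keep_indices(10, n_sink=2, window=3))
    print(len(streaming_keep_indices(100_000)))
```

The cache size is bounded by `n_sink + window` regardless of how long the stream grows, which is where the memory saving comes from. H2O differs in that it chooses which middle tokens to keep from accumulated attention scores rather than dropping all of them.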
Series Wrap-Up: Pruning and Quantization Together#
This post concludes the AI Accelerator series on model compression. Across the series, we have covered:
- Quantization fundamentals: Number representation, uniform/non-uniform schemes, calibration strategies.
- Post-training quantization (PTQ): GPTQ, AWQ, and their application to LLMs.
- Quantization-aware training (QAT): Training with simulated quantization for maximum quality.
- Extreme and mixed-precision quantization: Sub-4-bit methods, mixed precision allocation.
- Pruning fundamentals: Magnitude pruning, OBD, OBS, theoretical foundations.
- Structured vs. unstructured pruning: Filter pruning, channel pruning, hardware implications.
- Advanced pruning methods: Lottery ticket hypothesis, progressive pruning, neural architecture search.
- Pruning for LLMs (this post): SparseGPT, Wanda, KV-cache pruning, and the unique challenges of billion-parameter models.
The central lesson across the entire series is that model compression is not a single technique but a toolbox. The best deployment strategy depends on the specific constraints: target hardware, latency requirements, memory budget, quality tolerance, and engineering resources. The table below summarizes the overall compression landscape:
| Technique | Compression | Quality Loss | Hardware Support | Engineering Effort |
|---|---|---|---|---|
| FP16 (from FP32) | 2x | ~0% | Universal | Trivial |
| INT8 quantization | 4x (from FP32) | <0.1% | Excellent | Low |
| INT4 quantization | 8x (from FP32) | 0.5-2% | Good (Ampere+) | Medium |
| 50% unstructured pruning | 2x | 1-5% | Poor (needs sparse HW) | Medium |
| 2:4 structured sparsity | 2x | 2-8% | Good (Ampere+) | Medium |
| Layer removal (10%) | 1.1x | 3-10% | Universal | Low |
| 2:4 + INT4 | 8x (effective) | 5-15% | Good (Ampere+) | High |
| Pruning + INT4 + distillation | 8-16x | 2-5% | Variable | Very high |
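The compression column mixes FP32 and FP16 baselines, so it can help to recompute the ratios from bits per weight. The sketch below does that arithmetic under stated assumptions: sparse layouts may carry a fixed number of index bits per kept weight (2 bits is assumed here for a 2:4 layout), and quantization scales or zero-points are ignored. All names are illustrative.

```python
def bits_per_weight(value_bits: int, keep_n: int = 4, group_m: int = 4,
                    index_bits: int = 0) -> float:
    """Average stored bits per original weight when keep_n of every group_m weights are kept."""
    return keep_n * (value_bits + index_bits) / group_m


def compression_ratio(bpw: float, baseline_bits: int = 16) -> float:
    """Compression relative to a dense baseline of `baseline_bits` per weight."""
    return baseline_bits / bpw


def model_size_gb(n_params: float, bpw: float) -> float:
    """Model size in gigabytes (1 GB = 1e9 bytes) at `bpw` bits per weight."""
    return n_params * bpw / 8 / 1e9


if __name__ == "__main__":
    layouts = {
        "Dense FP16": bits_per_weight(16),
        "Dense INT4": bits_per_weight(4),
        "2:4 + INT4 (no metadata)": bits_per_weight(4, keep_n=2),
        "2:4 + INT4 (2-bit indices)": bits_per_weight(4, keep_n=2, index_bits=2),
    }
    for name, bpw in layouts.items():
        print(f"{name:<28} {bpw:4.1f} bits/weight  "
              f"{compression_ratio(bpw):5.2f}x vs FP16  "
              f"{compression_ratio(bpw, 32):5.2f}x vs FP32  "
              f"{model_size_gb(70e9, bpw):6.2f} GB for 70B params")
```

Under these assumptions, the 8x figure for 2:4 + INT4 corresponds to an FP16 baseline with metadata ignored (17.5 GB for 70B parameters). Counting 2 index bits per kept weight gives roughly 5.3x, or about 26 GB.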
As hardware evolves to better support sparsity and low-precision arithmetic, and as algorithms improve to minimize quality loss at extreme compression ratios, the gap between a 70-billion-parameter model and its 10x-compressed counterpart will continue to narrow. The goal is not merely to make models smaller, but to make the full power of frontier language models accessible on the hardware that people actually have.
References
- Frantar, E., & Alistarh, D. (2023). SparseGPT: Massive Language Models Can Be Accurately Pruned in One-Shot. ICML 2023.
- Sun, M., et al. (2023). A Simple and Effective Pruning Approach for Large Language Models. ICLR 2024.
- Ashkboos, S., et al. (2024). SliceGPT: Compress Large Language Models by Deleting Rows and Columns. ICLR 2024.
- Zhang, Z., et al. (2023). H2O: Heavy-Hitter Oracle for Efficient Generative Inference of Large Language Models. NeurIPS 2023.
- Xiao, G., et al. (2023). Efficient Streaming Language Models with Attention Sinks. ICLR 2024.
- Liu, Z., et al. (2023). Deja Vu: Contextual Sparsity for Efficient LLMs at Inference Time. ICML 2023.
- Kim, B., et al. (2024). Shortened LLaMA: A Simple Depth Pruning for Large Language Models. ICLR 2024 Workshop.
- Yin, L., et al. (2023). Outlier Weighed Layerwise Sparsity (OWL): A Missing Secret Sauce for Pruning LLMs to High Sparsity. ICML 2024.
- Dettmers, T., et al. (2022). LLM.int8(): 8-bit Matrix Multiplication for Transformers at Scale. NeurIPS 2022.
- LeCun, Y., Denker, J., & Solla, S. (1989). Optimal Brain Damage. NeurIPS 1989.
- Hassibi, B., Stork, D. G., & Wolff, G. J. (1993). Optimal Brain Surgeon and General Network Pruning. IEEE ICNN 1993.