Overview#
The simplest pruning strategy — removing weights with the smallest magnitudes — served as the foundation of neural network compression for decades. However, magnitude pruning rests on a fragile assumption: that a weight’s current absolute value is a reliable proxy for its importance to the network’s output. In practice, this assumption breaks in numerous scenarios. A weight that is currently small may be in the middle of growing toward a critical value during training. A weight that is currently large may be redundant given other weights in its layer. And at initialization, before any training has occurred, all weights are random — magnitude tells us nothing about future importance.
These limitations motivated a rich body of research into advanced pruning methods that go far beyond magnitude. The field can be organized along three axes:
When to prune:
- Before training (pruning at initialization): SNIP, GraSP, SynFlow
- During training (dynamic/continuous pruning): Movement Pruning, Continuous Sparsification, STR, Powerpropagation
- After training (post-hoc pruning): Taylor pruning, OBS/OBD, magnitude pruning
What criterion to use:
- Magnitude (weight size)
- Gradient-based (first-order or second-order Taylor expansion)
- Movement-based (training dynamics)
- Sensitivity-based (connection sensitivity)
- Gradient-flow preservation (trainability)
- Path-based (synaptic flow)
- Regularization-induced (L1, Group LASSO, Hoyer)
How to schedule pruning:
- One-shot (prune everything at once)
- Iterative (prune gradually over multiple rounds)
- Continuous (soft masks that evolve during training)
This post provides a comprehensive treatment of each major advanced method, with full mathematical derivations, algorithmic pseudocode, and critical analysis. We assume familiarity with basic pruning concepts (masking, sparsity ratios, structured vs. unstructured pruning) covered in the Pruning Fundamentals post.
Movement Pruning (Sanh et al., 2020) — Deep Dive#
Motivation: Beyond the Snapshot#
Magnitude pruning evaluates weights based on a static snapshot — their current values at the moment pruning is applied. This is fundamentally at odds with how neural networks learn. During training, weights are in constant flux. A weight’s current magnitude tells us where it is, but not where it is going.
Consider two weights during fine-tuning of a pre-trained BERT model:
- Weight A has magnitude 0.5, but the gradient is pushing it toward zero. It is becoming less important.
- Weight B has magnitude 0.1, but the gradient is pushing it away from zero. It is becoming more important.
Magnitude pruning would keep A and remove B — the exact opposite of what training dynamics suggest. This problem is especially acute during fine-tuning of pre-trained models, where the initial magnitudes reflect the pre-training task, not the target task.
Movement pruning addresses this by scoring weights based on their movement during training — specifically, whether they are moving toward zero (becoming unimportant) or away from zero (becoming important).
Score Definition and Derivation#
The movement score for weight \(w_i\) at training step \(t\) is defined as:
$$S_i^{(t)} = S_i^{(t-1)} - \alpha \cdot w_i^{(t)} \cdot \frac{\partial L}{\partial w_i^{(t)}}$$
where \(\alpha\) is a scaling factor (typically absorbed into the learning rate), and \(S_i^{(0)} = 0\). Note the negative sign: the score accumulates the weight times the negative gradient, which is the actual update direction.
To understand this formula, consider the weight update rule in standard gradient descent:
$$w_i^{(t+1)} = w_i^{(t)} - \eta \frac{\partial L}{\partial w_i^{(t)}}$$
The change in the weight is:
$$\Delta w_i^{(t)} = w_i^{(t+1)} - w_i^{(t)} = -\eta \frac{\partial L}{\partial w_i^{(t)}}$$
Writing \(g_i = \frac{\partial L}{\partial w_i}\), the product of weight and gradient can be expressed through the update:
$$w_i \cdot g_i = -\frac{1}{\eta} w_i \cdot \Delta w_i$$
The sign of \(w_i \cdot g_i\) therefore tells us which way the weight is moving relative to zero. There are two cases:
Case 1: Positive product (\(w_i > 0\) and \(g_i > 0\), or \(w_i < 0\) and \(g_i < 0\)). Then \(w_i \cdot \Delta w_i < 0\): the update has the opposite sign to the weight, so the weight is moving toward zero. If \(w_i > 0\), the update \(\Delta w_i = -\eta g_i < 0\) shrinks it; if \(w_i < 0\), the update \(\Delta w_i = -\eta g_i > 0\) also shrinks its magnitude.
Case 2: Negative product (\(w_i\) and \(g_i\) have opposite signs). Then \(w_i \cdot \Delta w_i > 0\): the update has the same sign as the weight, so the weight is moving away from zero and growing in magnitude.
Sanh et al. argue that weights moving away from zero are gaining importance (the fine-tuning task needs them), while weights moving toward zero are losing importance. Because the score accumulates \(-w_i \cdot g_i\), this gives:
- If the weight is positive and the update pushes it further positive (away from zero), the product \(w_i \cdot (-g_i)\) is positive, giving a high score.
- If the weight is negative and the update pushes it further negative (away from zero), the product is also positive.
- If the weight is moving toward zero, the product is negative, giving a low score.
Weights with the highest scores are kept; weights with the lowest scores are pruned.
In practice, the accumulated score can be equivalently written as:
$$S_i = \sum_{t=1}^{T} w_i^{(t)} \cdot \Delta w_i^{(t)}$$where \(\Delta w_i^{(t)} = -\eta \frac{\partial L}{\partial w_i^{(t)}}\) is the actual weight update. This sum is positive when the weight consistently moves away from zero and negative when it consistently moves toward zero.
Numerical Example#
Consider a weight \(w = 0.3\) over three training steps:
| Step \(t\) | \(w^{(t)}\) | \(g^{(t)} = \frac{\partial L}{\partial w}\) | \(\Delta w = -0.01 \cdot g\) | \(w \cdot \Delta w\) | Cumulative \(S\) |
|---|---|---|---|---|---|
| 1 | 0.30 | -2.0 | +0.02 | +0.006 | +0.006 |
| 2 | 0.32 | -1.5 | +0.015 | +0.0048 | +0.0108 |
| 3 | 0.335 | -1.0 | +0.01 | +0.00335 | +0.01415 |
The weight starts at 0.3 and the gradient consistently pushes it to grow (gradient is negative, so update is positive). The weight moves away from zero at every step, accumulating a positive score of 0.01415. This weight is important — movement pruning will keep it.
Now consider another weight \(w = 0.5\) that is moving toward zero:
| Step \(t\) | \(w^{(t)}\) | \(g^{(t)}\) | \(\Delta w\) | \(w \cdot \Delta w\) | Cumulative \(S\) |
|---|---|---|---|---|---|
| 1 | 0.50 | +3.0 | -0.03 | -0.015 | -0.015 |
| 2 | 0.47 | +2.5 | -0.025 | -0.01175 | -0.02675 |
| 3 | 0.445 | +2.0 | -0.02 | -0.0089 | -0.03565 |
Despite having a larger magnitude (0.5 vs 0.3), this weight accumulates a negative score of -0.03565. Movement pruning will prune it. Magnitude pruning would have kept it and pruned the first weight — the opposite decision.
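The two tables above can be reproduced with a few lines of Python. This is a minimal sketch: the helper name `movement_score` is hypothetical, the learning rate of 0.01 and the gradient values are taken from the tables, and plain gradient descent is assumed.

```python
def movement_score(w0, grads, lr=0.01):
    """Accumulate S = sum_t w_t * delta_w_t under plain gradient descent."""
    w, score = w0, 0.0
    for g in grads:
        delta_w = -lr * g      # actual update applied to the weight
        score += w * delta_w   # positive when the weight moves away from zero
        w += delta_w
    return score, w

# Gradients copied from the two tables above.
score_a, final_a = movement_score(0.30, [-2.0, -1.5, -1.0])
score_b, final_b = movement_score(0.50, [3.0, 2.5, 2.0])

print(f"growing weight:   S = {score_a:+.5f}, final w = {final_a:.3f}")
print(f"shrinking weight: S = {score_b:+.5f}, final w = {final_b:.3f}")
```

Ranking by `score_a` and `score_b` keeps the smaller weight and prunes the larger one, which is the reverse of what ranking by the starting magnitudes 0.3 and 0.5 would do.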
Soft Movement Pruning#
In soft movement pruning, the scores \(S_i\) are converted to binary masks via a threshold \(\tau\), and the straight-through estimator (STE) is used to propagate gradients through the non-differentiable thresholding operation.
The mask is:
$$m_i = \mathbb{1}[S_i > \tau]$$where \(\tau\) is chosen to achieve the target sparsity level (e.g., the \(k\)-th percentile of all scores for \(k\%\) sparsity).
The effective weight is \(\tilde{w}_i = m_i \cdot w_i\).
Forward pass: Use \(\tilde{w}_i = \mathbb{1}[S_i > \tau] \cdot w_i\).
Backward pass (STE): Pretend the threshold function is the identity, so:
$$\frac{\partial L}{\partial S_i} \approx \frac{\partial L}{\partial m_i} = w_i \cdot \frac{\partial L}{\partial \tilde{w}_i}$$This allows the scores to be updated via gradient descent alongside the weights.
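As a rough illustration of the thresholded mask and the straight-through gradient, the sketch below works on plain Python lists. The function name `soft_movement_mask` and the example numbers are hypothetical, and `grad_w_tilde` stands in for the gradient that a deep learning framework would compute in its backward pass.

```python
def soft_movement_mask(weights, scores, grad_w_tilde, sparsity):
    """Threshold mask plus straight-through score gradient on Python lists.

    grad_w_tilde stands in for dL/dw_tilde, which a framework's backward
    pass would normally supply.
    """
    n_prune = int(sparsity * len(scores))
    # tau is the largest of the n_prune lowest scores, so S_i > tau keeps the rest.
    tau = sorted(scores)[n_prune - 1] if n_prune > 0 else float("-inf")
    mask = [1 if s > tau else 0 for s in scores]
    w_tilde = [m * w for m, w in zip(mask, weights)]
    # STE: treat the threshold as the identity, so dL/dS_i = w_i * dL/dw_tilde_i.
    grad_scores = [w * g for w, g in zip(weights, grad_w_tilde)]
    return tau, mask, w_tilde, grad_scores

weights = [0.5, -0.2, 0.1, 0.8]
scores = [0.3, -0.1, 0.05, -0.4]
grad_w_tilde = [0.1, -0.3, 0.2, 0.05]

tau, mask, w_tilde, grad_scores = soft_movement_mask(
    weights, scores, grad_w_tilde, sparsity=0.5
)
print(tau, mask, w_tilde, grad_scores)
```

Note that masked weights still receive a score gradient, which is what lets a pruned connection return later if its score rises above the threshold. With tied scores the strict comparison can prune more than the requested fraction; a real implementation would need a tie-breaking rule.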
Hard Movement Pruning#
In hard movement pruning, the top-\(k\) weights by movement score are kept at each step, and the rest are zeroed out. In Sanh et al. (2020), the scores in this variant are still learned through the straight-through estimator; what changes is how the mask is selected: a per-step top-\(k\) ranking instead of a comparison against the threshold \(\tau\).
Full Algorithm (Soft Movement Pruning)#
Algorithm: Soft Movement Pruning
Input: Pre-trained model weights W, target sparsity s, training data D,
       learning rate eta, score learning rate eta_S
1. Initialize scores S_i = 0 for all weights
2. Initialize threshold tau below all scores (e.g. tau = -inf), so no weight is masked at the start
3. For each training step t = 1, ..., T:
   a. Compute masks: m_i = 1[S_i > tau]
   b. Compute effective weights: w_tilde_i = m_i * w_i
   c. Forward pass with w_tilde to compute loss L
   d. Backward pass to compute gradients dL/dw_tilde_i
   e. Update scores (STE, gradient descent on S): S_i <- S_i - eta_S * w_i * dL/dw_tilde_i
   f. Update weights: w_i <- w_i - eta * dL/dw_tilde_i
   g. Update threshold tau so that fraction s_t of weights have S_i < tau
      (the current sparsity s_t increases linearly from 0 to the target s over a warmup period)
4. Return pruned model: w_final_i = 1[S_i > tau_final] * w_i
Comparison with Magnitude Pruning on BERT#
Sanh et al. (2020) evaluated movement pruning on BERT fine-tuning tasks from the GLUE benchmark. Key findings:
| Method | MNLI (acc) | QQP (F1) | SQuAD (F1) | Sparsity |
|---|---|---|---|---|
| BERT (dense) | 84.6 | 88.0 | 88.5 | 0% |
| Magnitude Pruning | 78.3 | 85.2 | 79.1 | 90% |
| Movement Pruning (soft) | 82.3 | 87.5 | 85.6 | 90% |
| Movement Pruning (hard) | 81.2 | 86.8 | 84.1 | 90% |
At 90% sparsity (only 10% of weights remaining), soft movement pruning outperforms magnitude pruning on every task in the table: by 4.0 points on MNLI, 2.3 points on QQP, and 6.5 points on SQuAD. The gap widens at higher sparsity levels. At 97% sparsity, magnitude pruning nearly collapses while movement pruning retains meaningful performance.
When Movement Pruning Excels#
Movement pruning is most effective when:
- Fine-tuning pre-trained models: The initial magnitudes reflect the pre-training distribution, not the target task. Movement captures the adaptation dynamics.
- High sparsity regimes: At moderate sparsity (50-70%), magnitude and movement pruning perform similarly. At extreme sparsity (90%+), movement pruning pulls ahead significantly.
- Transfer learning: When the source and target domains differ, the weights that matter most for the target task may differ substantially from those that were large after pre-training.
Movement pruning is less advantageous when training from scratch, because there is no pre-existing magnitude distribution to overcome — the weights and their movements develop together.
Gradient-Based Pruning Methods#
Gradient-based methods use the loss function’s sensitivity to weight removal as the pruning criterion. This section covers four increasingly sophisticated approaches.
First-Order Taylor Pruning#
Derivation from First Principles#
We want to estimate the change in loss \(\delta L\) when weight \(w_i\) is removed (set to zero). Removing \(w_i\) means applying a perturbation \(\delta w_i = -w_i\) (since the new value is \(0 = w_i + \delta w_i\), so \(\delta w_i = -w_i\)).
The Taylor expansion of the loss around the current weights is:
$$L(w + \delta w) = L(w) + \sum_i \frac{\partial L}{\partial w_i} \delta w_i + \frac{1}{2} \sum_{i,j} \frac{\partial^2 L}{\partial w_i \partial w_j} \delta w_i \delta w_j + \cdots$$For a single weight removal (\(\delta w_j = 0\) for \(j \neq i\), \(\delta w_i = -w_i\)):
$$\delta L = L(w + \delta w) - L(w) \approx \frac{\partial L}{\partial w_i} (-w_i) = -w_i \frac{\partial L}{\partial w_i}$$The importance score is the absolute value of this change:
$$\text{score}(w_i) = \left| w_i \cdot \frac{\partial L}{\partial w_i} \right|$$A large score means removing this weight would cause a large change in loss — so the weight is important and should be kept.
Accumulation Over Mini-Batches#
Since the gradient \(\frac{\partial L}{\partial w_i}\) varies across mini-batches, we accumulate the score over \(B\) batches:
$$\text{score}(w_i) = \left| \frac{1}{B} \sum_{b=1}^{B} w_i^{(b)} \cdot \frac{\partial L_b}{\partial w_i} \right|$$In practice, if weights change slowly (small learning rate or evaluation-only), we can factor out:
$$\text{score}(w_i) \approx \left| w_i \cdot \frac{1}{B} \sum_{b=1}^{B} \frac{\partial L_b}{\partial w_i} \right|$$
Comparison with Magnitude Pruning#
- Magnitude pruning uses score \(= |w_i|\). It ignores gradient information entirely.
- Taylor pruning uses score \(= |w_i \cdot g_i|\). It considers both magnitude and gradient.
A weight with large magnitude but near-zero gradient (meaning the loss is insensitive to it) will get a low Taylor score but a high magnitude score. Conversely, a small weight with a large gradient will be scored high by Taylor but low by magnitude.
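The contrast between the two criteria can be seen on a toy example. The sketch below assumes the weights are held fixed while gradients from two mini-batches are averaged; all numbers and helper names (`magnitude_scores`, `taylor_scores`, `lowest_index`) are made up for illustration.

```python
def magnitude_scores(weights):
    return [abs(w) for w in weights]

def taylor_scores(weights, batch_grads):
    """|w_i * mean_b g_i^(b)|, with weights held fixed across batches."""
    n_batches = len(batch_grads)
    mean_grads = [sum(col) / n_batches for col in zip(*batch_grads)]
    return [abs(w * g) for w, g in zip(weights, mean_grads)]

def lowest_index(scores):
    """Index of the weight that would be pruned first."""
    return min(range(len(scores)), key=scores.__getitem__)

weights = [0.9, 0.05, -0.4]
batch_grads = [
    [0.01, 4.0, -0.5],   # gradients from mini-batch 1
    [-0.01, 6.0, -0.7],  # gradients from mini-batch 2
]

mag = magnitude_scores(weights)
tay = taylor_scores(weights, batch_grads)
print("magnitude prunes index", lowest_index(mag))
print("taylor prunes index   ", lowest_index(tay))
```

Here the largest weight has gradients that cancel across batches, so the Taylor criterion prunes it first, while the magnitude criterion prunes the small weight that carries the largest gradient.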
Second-Order Taylor Pruning#
Full Derivation#
Including the second-order term in the Taylor expansion for a single weight removal:
$$\delta L \approx -w_i \frac{\partial L}{\partial w_i} + \frac{1}{2} w_i^2 \frac{\partial^2 L}{\partial w_i^2}$$
Let \(g_i = \frac{\partial L}{\partial w_i}\) and \(h_{ii} = \frac{\partial^2 L}{\partial w_i^2}\) (diagonal of the Hessian). Substituting the perturbation \(\delta w_i = -w_i\) term by term gives:
$$\delta L \approx g_i \cdot (-w_i) + \frac{1}{2} h_{ii} \cdot (-w_i)^2 = -w_i g_i + \frac{1}{2} w_i^2 h_{ii}$$
The importance score is the absolute value of this estimated change:
$$\text{score}(w_i) = \left| -w_i g_i + \frac{1}{2} w_i^2 h_{ii} \right|$$
Because of the absolute value, this is the same quantity as \(\left| w_i g_i - \frac{1}{2} w_i^2 h_{ii} \right|\); only the relative sign between the two terms matters.
Connection to Optimal Brain Damage (OBD)#
LeCun et al. (1990) introduced Optimal Brain Damage, which assumes that near a local minimum, \(g_i \approx 0\), so the first-order term vanishes:
$$\delta L \approx \frac{1}{2} w_i^2 h_{ii}$$
This is the saliency in OBD. Weights with the smallest saliency are pruned because removing them causes the least increase in loss. This requires the diagonal of the Hessian, which LeCun et al. compute with a backpropagation-style recursion for second derivatives; the empirical Fisher approximation described next is a common cheaper substitute.
Efficient Hessian Diagonal Computation#
Computing the full Hessian \(H \in \mathbb{R}^{n \times n}\) is intractable for modern networks. Even the diagonal requires \(O(n)\) additional storage and computation.
Empirical Fisher approximation: For a loss function \(L\), the empirical Fisher information provides an approximation to the Hessian diagonal:
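$$h_{ii} \approx \mathbb{E}\left[\left(\frac{\partial L}{\partial w_i}\right)^2\right] = \frac{1}{B}\sum_{b=1}^{B} \left(\frac{\partial L_b}{\partial w_i}\right)^2$$
This is simply the mean squared gradient, which is cheap to accumulate during training. The justification is that, for a negative log-likelihood loss, the Fisher information equals the expected Hessian when the model’s predictive distribution matches the data distribution. Away from that regime, and especially early in training, the empirical Fisher should be treated as a heuristic curvature proxy rather than an exact Hessian diagonal.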
Numerical Example#
Consider a weight \(w = 0.8\) with gradient \(g = 0.1\) and Hessian diagonal \(h = 2.0\):
- First-order score: \(|w \cdot g| = |0.8 \times 0.1| = 0.08\)
- Second-order score: \(|-w \cdot g + 0.5 \cdot w^2 \cdot h| = |-0.08 + 0.5 \times 0.64 \times 2.0| = |-0.08 + 0.64| = 0.56\)
The second-order term (0.64) dominates, revealing that this weight occupies a region of high curvature — removing it would cause a large loss increase despite the small first-order effect.
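The arithmetic in this example is easy to check in code. The following sketch uses hypothetical helper names and the values \(w = 0.8\), \(g = 0.1\), \(h = 2.0\) from the example; it also evaluates the OBD saliency, which drops the first-order term.

```python
def first_order_score(w, g):
    """|w * g|: first-order Taylor estimate of the loss change."""
    return abs(w * g)

def second_order_score(w, g, h):
    """|-w * g + 0.5 * w^2 * h|: adds the diagonal curvature term."""
    return abs(-w * g + 0.5 * w * w * h)

def obd_saliency(w, h):
    """0.5 * w^2 * h: the second-order score with g assumed to be zero."""
    return 0.5 * w * w * h

w, g, h = 0.8, 0.1, 2.0
first = first_order_score(w, g)
second = second_order_score(w, g, h)
obd = obd_saliency(w, h)
print(f"first-order:  {first:.2f}")
print(f"second-order: {second:.2f}")
print(f"OBD saliency: {obd:.2f}")
```

The OBD saliency equals the curvature term alone, so comparing it with the second-order score shows how much the nonzero gradient shifts the estimate.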
SNIP (Single-shot Network Pruning, Lee et al. 2019) — Full Detail#
Core Idea: Connection Sensitivity at Initialization#
SNIP answers a radical question: can we determine which connections to prune before any training occurs, using only a single mini-batch of data? If so, we save enormous computational cost — there is no need for iterative pruning-retraining cycles.
The key idea is to introduce mask variables \(c_j \in \{0, 1\}\) for each connection, where \(c_j = 1\) means the connection is active and \(c_j = 0\) means it is pruned. The effective weight is:
$$w'_j = c_j \cdot w_j$$We then measure the sensitivity of the loss to each mask variable, evaluated at \(c = \mathbf{1}\) (all connections active):
$$g_j(w; \mathcal{D}) = \frac{\partial L(c \odot w; \mathcal{D})}{\partial c_j}\bigg|_{c=\mathbf{1}}$$
Full Chain Rule Derivation#
Let us derive this sensitivity explicitly. The loss depends on the effective weights \(w' = c \odot w\). By the chain rule:
$$\frac{\partial L}{\partial c_j} = \frac{\partial L}{\partial w'_j} \cdot \frac{\partial w'_j}{\partial c_j}$$
Since \(w'_j = c_j \cdot w_j\):
$$\frac{\partial w'_j}{\partial c_j} = w_j$$Therefore:
$$\frac{\partial L}{\partial c_j}\bigg|_{c=\mathbf{1}} = w_j \cdot \frac{\partial L}{\partial w'_j}\bigg|_{w'=w} = w_j \cdot \frac{\partial L}{\partial w_j}$$This is exactly the same as the first-order Taylor pruning score. The connection sensitivity is:
$$g_j = w_j \cdot \frac{\partial L(w; \mathcal{D})}{\partial w_j}$$The intuition is clear: \(g_j\) measures how much the loss would change if connection \(j\) were removed, to first order.
Normalized Score#
SNIP normalizes the sensitivity magnitudes by their total over all \(n\) connections in the network, so that each score is a fraction of the overall sensitivity:
$$\kappa_j = \frac{|g_j|}{\sum_{k=1}^{n} |g_k|}$$
This ensures \(\sum_j \kappa_j = 1\), giving each connection a share of the total sensitivity. Because every score is divided by the same global constant, the normalization does not change the ranking. Connections with the top \(\kappa\) values (up to the desired remaining ratio) are kept.
Algorithm Step by Step#
Algorithm: SNIP (Single-shot Network Pruning)
Input: Randomly initialized network with weights w, target sparsity s,
       one mini-batch of data (x, y) from training set
1. Initialize all mask variables: c_j = 1 for all j
2. Forward pass: compute L(c * w; (x,y)) with current masks
3. Backward pass: compute dL/dc_j for all j
   This gives g_j = w_j * dL/dw_j for all j
4. Compute normalized scores: kappa_j = |g_j| / sum_k(|g_k|)
5. Determine threshold tau such that fraction s of weights
   have kappa_j < tau
6. Set final masks: m_j = 1 if kappa_j >= tau, else m_j = 0
7. Apply masks: w'_j = m_j * w_j
8. Train the pruned network from this initialization
Numerical Example#
Consider a tiny 3-weight network at initialization:
| Weight | \(w_j\) | \(\frac{\partial L}{\partial w_j}\) | \(g_j = w_j \cdot \frac{\partial L}{\partial w_j}\) | \(|g_j|\) | \(\kappa_j\) |
|---|---|---|---|---|---|
| \(w_1\) | 0.5 | -2.0 | -1.0 | 1.0 | 0.435 |
| \(w_2\) | -0.3 | 1.0 | -0.3 | 0.3 | 0.130 |
| \(w_3\) | 0.8 | -1.25 | -1.0 | 1.0 | 0.435 |
Sum of \(|g_j|\) = 2.3. If we want 33% sparsity (prune 1 of 3 weights), we prune \(w_2\) which has the lowest \(\kappa_2 = 0.130\).
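The same three-weight example can be run through a small scoring routine. This is a sketch under simplifying assumptions: the gradients are given directly rather than computed by a backward pass, and the function name `snip_mask` is hypothetical.

```python
def snip_mask(weights, grads, sparsity):
    """Normalized connection sensitivities and a keep/prune mask."""
    sens = [abs(w * g) for w, g in zip(weights, grads)]   # |g_j| = |w_j * dL/dw_j|
    total = sum(sens)
    kappa = [s / total for s in sens]                     # shares that sum to 1
    n_prune = int(round(sparsity * len(weights)))
    order = sorted(range(len(kappa)), key=kappa.__getitem__)
    pruned = set(order[:n_prune])                         # lowest sensitivities go first
    mask = [0 if j in pruned else 1 for j in range(len(kappa))]
    return kappa, mask

# Values from the table above: (w_j, dL/dw_j) for three weights.
kappa, mask = snip_mask([0.5, -0.3, 0.8], [-2.0, 1.0, -1.25], sparsity=1 / 3)
print([round(k, 3) for k in kappa])
print(mask)
```

The mask keeps \(w_1\) and \(w_3\) and removes \(w_2\), matching the table. Ranking by index instead of by threshold avoids ambiguity when two connections tie, as \(w_1\) and \(w_3\) do here.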
Why It Works with Just One Mini-Batch#
The empirical finding is that connection sensitivities are surprisingly stable across different mini-batches at initialization. The reason is that the sensitivity primarily reflects the network topology — how information flows through the randomly initialized graph. A connection that sits on many high-activation paths will be sensitive regardless of which specific data points are used. This topological property is relatively stable.
Limitations#
- Instability at high sparsity: At 95%+ sparsity, SNIP’s single-shot decision becomes unreliable because the interactions between pruned connections matter (the first-order approximation breaks down).
- Layer collapse: SNIP can allocate zero connections to entire layers, especially narrow bottleneck layers, causing the network to lose all representational capacity in those layers. This happens because the normalization does not account for the structural role of each layer.
- No iterative refinement: Once the mask is set, there is no way to recover from a bad decision.
GraSP (Gradient Signal Preservation, Wang et al. 2020)#
Motivation: From Local Sensitivity to Gradient Flow#
SNIP measures the local effect of removing each connection — how much the loss changes. But this ignores a crucial property: the network still needs to be trained after pruning. What matters is not just the loss at initialization, but whether the pruned network can be effectively trained.
GraSP approaches this by asking: which connections, if removed, would most reduce the gradient flow through the network? If gradient flow is impaired, training will stall.
Key Quantity: Gradient Flow#
GraSP measures gradient flow by the squared gradient norm:
$$\mathcal{G}(w) = g^T g$$
where \(g = \nabla_w L\) is the gradient vector. To see why this is a natural quantity, consider the effect of a gradient descent step on the loss:
$$L(w - \eta g) \approx L(w) - \eta\, g^T g + \frac{\eta^2}{2} g^T H g$$
where \(H = \nabla_w^2 L\) is the Hessian matrix. To first order in \(\eta\), the loss decreases by \(\eta\, g^T g\), so a larger gradient norm means a single step can make more progress. GraSP seeks to avoid reducing this quantity when pruning.
Score Derivation#
We want the score for connection \(j\) to measure how much removing it would change the gradient flow. Introducing mask variables \(c_j\) as in SNIP, removing connection \(j\) corresponds to the perturbation \(\delta c_j = -1\), so to first order the change in gradient flow is:
$$S_j = -\frac{\partial (g^T g)}{\partial c_j}\bigg|_{c=\mathbf{1}}$$
A negative score means that removing the connection decreases gradient flow, so the connection is important and should be kept. A positive score means that removing it would increase gradient flow. GraSP therefore prunes the weights with the highest scores and keeps those with the lowest (most negative) scores.
Up to a constant factor, this derivative has a simple closed form:
$$S_j = -\left(H g\right)_j \cdot w_j$$
where \((Hg)_j\) denotes the \(j\)-th component of the Hessian-gradient product \(Hg\), and the \(w_j\) factor comes from the chain rule through the mask variable (same as in SNIP). The Hessian enters because it describes how the gradient itself changes when a weight is perturbed.
Efficient Computation via Hessian-Gradient Product#
The Hessian matrix \(H\) is far too large to compute explicitly. However, we only need the product \(Hg\), which can be computed efficiently using a finite difference approximation:
$$Hg \approx \frac{\nabla L(w + \epsilon g) - \nabla L(w)}{\epsilon}$$
for a small \(\epsilon\) (typically \(10^{-5}\) to \(10^{-3}\)). This requires:
- One forward-backward pass at \(w\) to get \(g = \nabla L(w)\)
- One forward-backward pass at \(w + \epsilon g\) to get \(\nabla L(w + \epsilon g)\)
Total cost: two forward-backward passes, regardless of network size.
Full Derivation of the Score#
Starting from the quantity we want to preserve:
$$\mathcal{G} = g^T g = \sum_{i} g_i^2$$
We need \(\frac{\partial \mathcal{G}}{\partial c_k}\). Using the chain rule with the substitution \(w'_k = c_k w_k\):
$$\frac{\partial \mathcal{G}}{\partial c_k} = \frac{\partial \mathcal{G}}{\partial w'_k} \cdot \frac{\partial w'_k}{\partial c_k} = \frac{\partial \mathcal{G}}{\partial w'_k} \cdot w_k$$
Differentiating the sum of squares and using \(\frac{\partial g_i}{\partial w'_k} = H_{ik}\) (so that \(\frac{\partial g}{\partial w'_k} = H_{:,k}\), the \(k\)-th column of \(H\)) together with the symmetry of \(H\):
$$\frac{\partial \mathcal{G}}{\partial w'_k} = 2 \sum_i g_i H_{ik} = 2 (Hg)_k$$
Therefore:
$$S_k = -\frac{\partial \mathcal{G}}{\partial c_k} = -2 w_k \cdot (Hg)_k$$
The factor of 2 is a constant and does not affect ranking, so the practical score is:
$$S_k = -(Hg)_k \cdot w_k$$
Connections with the lowest \(S_k\) (most negative) are kept, and those with the highest \(S_k\) are pruned. This means we remove the connections whose removal reduces gradient flow the least.
Algorithm Pseudocode#
Algorithm: GraSP (Gradient Signal Preservation)
Input: Randomly initialized network with weights w, target sparsity s,
       one mini-batch of data (x, y), perturbation scale epsilon
1. Forward-backward pass at w:
   g = grad_w L(w; (x,y))
2. Perturbed forward-backward pass:
   g_perturbed = grad_w L(w + epsilon * g; (x,y))
3. Compute Hessian-gradient product:
   Hg = (g_perturbed - g) / epsilon
4. Compute scores for each weight j:
   S_j = -(Hg)_j * w_j
5. Determine threshold tau such that fraction s of weights have S_j > tau
6. Create masks: m_j = 1 if S_j <= tau, else m_j = 0
7. Apply masks and train: w'_j = m_j * w_j, then train normally
Why Preserving Gradient Flow Leads to Better Trainability#
Consider a pruned network where an entire layer has been stripped of most connections. The gradients flowing backward through that layer will be severely attenuated (because the layer’s Jacobian has very low rank). GraSP explicitly measures this effect through the change in \(g^T g\). If pruning a connection would create such a bottleneck, its score \(S_j\) will be strongly negative, which marks it for preservation.
SNIP, by contrast, only measures the first-order loss change and is blind to this trainability consideration. This is why GraSP consistently outperforms SNIP at high sparsity levels.
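The finite-difference Hessian-gradient product used by GraSP can be checked on a toy problem where the Hessian is known. The sketch below is an illustration under strong assumptions, not the paper's implementation: the loss is a hypothetical quadratic \(L(w) = \frac{1}{2} w^T A w - b^T w\), so its gradient is \(Aw - b\) and its Hessian is exactly \(A\), and all names and numbers are made up.

```python
# Symmetric matrix that plays the role of the Hessian of the toy loss.
A = [[2.0, 0.5, 0.0],
     [0.5, 1.0, 0.3],
     [0.0, 0.3, 3.0]]
b = [1.0, -1.0, 0.5]

def matvec(M, v):
    return [sum(m * x for m, x in zip(row, v)) for row in M]

def grad(w):
    """Gradient of the toy loss L(w) = 0.5 * w^T A w - b^T w."""
    return [aw - bi for aw, bi in zip(matvec(A, w), b)]

def hessian_gradient_product(w, eps=1e-3):
    """Finite-difference estimate of Hg from two gradient evaluations."""
    g = grad(w)
    w_shift = [wi + eps * gi for wi, gi in zip(w, g)]
    g_shift = grad(w_shift)
    return [(gs - gi) / eps for gs, gi in zip(g_shift, g)], g

def grasp_scores(w, eps=1e-3):
    """S_j = -(Hg)_j * w_j for every weight."""
    hg, _ = hessian_gradient_product(w, eps)
    return [-h * wi for h, wi in zip(hg, w)]

w = [0.4, -0.6, 0.2]
hg_fd, g = hessian_gradient_product(w)
hg_exact = matvec(A, g)          # available here only because H = A is known
scores = grasp_scores(w)
print(hg_fd, hg_exact, scores)
```

For a quadratic loss the finite difference matches \(Ag\) up to floating-point rounding. For a real network the estimate depends on \(\epsilon\), and an exact Hessian-vector product from automatic differentiation is a common alternative.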
SynFlow (Synaptic Flow Pruning, Tanaka et al. 2020)#
Motivation: Data-Free Pruning and Layer Collapse#
Both SNIP and GraSP require data to compute their scores. SynFlow asks: can we prune effectively without any training data at all?
More importantly, SynFlow addresses the layer collapse problem that plagues SNIP and GraSP at high sparsity. Layer collapse occurs when an entire layer loses all its connections, rendering the network unable to propagate information regardless of how the remaining weights are set.
The Layer Collapse Theorem#
Tanaka et al. prove a fundamental theorem:
Theorem: Any pruning score that is (1) positive and (2) conservative (i.e., satisfies a flow conservation property through the network) will avoid layer collapse when applied iteratively.
The intuition is that a conservative scoring function ensures that if any path through the network is important, every connection along that path receives a nonzero score. Therefore, no layer can be completely zeroed out.
Synaptic Saliency#
The SynFlow score is based on the synaptic saliency, defined using a special loss function that does not require data:
$$\mathcal{R} = \mathbf{1}^T \left(\prod_{l=1}^{L} |\theta^{(l)}|\right) \mathbf{1}$$where \(\theta^{(l)}\) is the weight matrix of layer \(l\), \(|\cdot|\) denotes element-wise absolute value, and \(\mathbf{1}\) is a vector of ones.
This quantity is the sum of all path products through the network. A path is a sequence of weights, one from each layer, that connects an input node to an output node. The product of absolute values along a path measures the signal magnitude that path can carry.
The synaptic saliency for weight \(\theta_j^{(l)}\) in layer \(l\) is:
$$R_j^{(l)} = \frac{\partial \mathcal{R}}{\partial \theta_j^{(l)}} \odot \theta_j^{(l)}$$
Derivation of Why SynFlow Avoids Layer Collapse#
Let us show that the SynFlow score is positive and conservative.
Positivity: Since \(\mathcal{R}\) is a product of absolute values, \(\frac{\partial \mathcal{R}}{\partial |\theta_j^{(l)}|} \geq 0\) for all weights (it is a sum of non-negative path products that include \(|\theta_j^{(l)}|\)). Therefore:
$$R_j^{(l)} = \frac{\partial \mathcal{R}}{\partial |\theta_j^{(l)}|} \cdot |\theta_j^{(l)}| \geq 0$$The score is zero only if all paths through weight \(j\) have at least one other zero weight. As long as there exists any nonzero path through weight \(j\), the score is strictly positive.
Conservation: Consider the total score across all weights in a single layer \(l\). By the structure of the product, the sum of scores in layer \(l\) equals the sum of scores in any other layer \(l'\):
$$\sum_j R_j^{(l)} = \sum_k R_k^{(l')} = \mathcal{R}$$This conservation property means the scoring budget is equally distributed across layers. When pruning iteratively, each layer loses connections proportionally to the path-level importance, preventing any layer from being disproportionately pruned.
Proof that layer collapse is avoided: Suppose for contradiction that iterative SynFlow pruning removes all weights from layer \(l\). Before the last weight in layer \(l\) is removed, it must have been the weight with the lowest score globally. But since at least one path goes through this weight (the last remaining path through layer \(l\)), its score is strictly positive. And since all weights in other layers that also lie on this path also have positive scores, the conservation property ensures the scores are balanced. The last weight in a bottleneck layer will therefore have a score comparable to weights in other layers, preventing its removal before comparable weights elsewhere. (The formal proof uses induction on the number of pruning iterations.)
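Positivity and conservation can be verified numerically on a tiny network. The sketch below assumes a two-layer linear network with no biases or nonlinearities, writes the derivative of \(\mathcal{R}\) by hand instead of using automatic differentiation, and uses a hypothetical function name and made-up weights.

```python
def synflow_scores(W1, W2):
    """SynFlow scores for R = 1^T |W2| |W1| 1 (W1: hidden x in, W2: out x hidden)."""
    A1 = [[abs(x) for x in row] for row in W1]
    A2 = [[abs(x) for x in row] for row in W2]
    hidden_in = [sum(row) for row in A1]          # |W1| @ 1: signal reaching each hidden unit
    hidden_out = [sum(col) for col in zip(*A2)]   # 1^T @ |W2|: signal leaving each hidden unit
    R = sum(a * b for a, b in zip(hidden_in, hidden_out))
    # dR/d|W1[h][i]| = hidden_out[h]; dR/d|W2[o][h]| = hidden_in[h].
    S1 = [[hidden_out[h] * a for a in A1[h]] for h in range(len(A1))]
    S2 = [[hidden_in[h] * a for h, a in enumerate(row)] for row in A2]
    return R, S1, S2

W1 = [[0.5, -1.0], [0.0, 2.0], [1.5, 0.5]]   # 3 hidden units, 2 inputs
W2 = [[1.0, -0.5, 0.0], [0.2, 0.0, -2.0]]    # 2 outputs, 3 hidden units

R, S1, S2 = synflow_scores(W1, W2)
layer1_total = sum(sum(row) for row in S1)
layer2_total = sum(sum(row) for row in S2)
print(R, layer1_total, layer2_total)
```

Both layer totals equal \(\mathcal{R}\), and every score is non-negative. Weights that are exactly zero receive a score of zero, since no path with nonzero product passes through them.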
Iterative SynFlow Algorithm#
A key feature of SynFlow is that it is applied iteratively rather than in a single shot. If the target sparsity is \(s\) (fraction to remove) and we use \(n\) iterations, each iteration prunes a fraction:
$$s_{\text{iter}} = 1 - (1 - s)^{1/n}$$For example, to reach 90% sparsity (\(s = 0.9\)) in \(n = 100\) iterations:
$$s_{\text{iter}} = 1 - 0.1^{0.01} = 1 - 0.977 = 0.023$$Each iteration prunes about 2.3% of the currently remaining weights.
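The per-iteration rate compounds to the target sparsity, which a short check makes concrete. The function name `per_iteration_sparsity` is a hypothetical helper; the values \(s = 0.9\) and \(n = 100\) come from the example above.

```python
def per_iteration_sparsity(target_sparsity, n_iters):
    """Fraction of remaining weights to prune in each of n_iters rounds."""
    return 1.0 - (1.0 - target_sparsity) ** (1.0 / n_iters)

s_iter = per_iteration_sparsity(0.9, 100)

# Apply the same rate 100 times and track the fraction of weights left.
remaining = 1.0
for _ in range(100):
    remaining *= 1.0 - s_iter

print(f"s_iter    = {s_iter:.4f}")
print(f"remaining = {remaining:.4f}")
```

After 100 rounds the remaining fraction returns to 0.1, confirming that pruning a constant share of the surviving weights each round reaches 90% sparsity overall.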
Algorithm: Iterative SynFlow
Input: Initialized network with weights theta, target sparsity s,
       number of iterations n
1. Replace all weights with absolute values: theta <- |theta|
2. Compute per-iteration sparsity: s_iter = 1 - (1 - s)^(1/n)
3. For iteration i = 1, ..., n:
   a. Forward pass with an all-ones input vector through the network:
      R_total = 1^T * (prod_{l} theta^{(l)}) * 1
   b. Backward pass: compute dR/d(theta_j) for all weights
   c. Compute scores: S_j = dR/d(theta_j) * theta_j
   d. Among currently unmasked weights, find threshold tau such
      that fraction s_iter have S_j < tau
   e. Mask weights below threshold: m_j = 0 if S_j < tau
   f. Apply masks: theta_j <- m_j * theta_j
4. Return final mask m (apply to original signed weights for training)
Comparison with SNIP and GraSP#
| Property | SNIP | GraSP | SynFlow |
|---|---|---|---|
| Data required | 1 mini-batch | 1 mini-batch | None |
| Forward-backward passes | 1 | 2 | n (iterations) |
| Criterion | Connection sensitivity | Gradient flow preservation | Synaptic flow (path products) |
| Avoids layer collapse | No | No | Yes (provably) |
| Performance at 95% sparsity | Moderate | Good | Good |
| Performance at 99% sparsity | Poor (collapse) | Moderate (partial collapse) | Good (no collapse) |
SynFlow’s primary advantage is robustness at extreme sparsity levels, where SNIP and GraSP suffer from layer collapse. Its disadvantage is that being data-free, it cannot leverage task-specific information, which matters more at moderate sparsity levels.
Pruning During Training#
Rather than pruning before or after training, several methods integrate pruning into the training process itself, allowing the mask and weights to co-evolve.
Continuous Sparsification (Savarese et al., 2020)#
Reparameterization#
Instead of binary masks, Continuous Sparsification uses a differentiable relaxation. Each weight is reparameterized as:
$$w_i = \hat{w}_i \cdot \sigma(s_i)$$where \(\hat{w}_i\) is the underlying weight parameter, \(s_i\) is a learnable mask logit, and \(\sigma\) is the sigmoid function:
$$\sigma(s) = \frac{1}{1 + e^{-s}}$$When \(s_i \to +\infty\), \(\sigma(s_i) \to 1\) and the weight is fully active. When \(s_i \to -\infty\), \(\sigma(s_i) \to 0\) and the weight is effectively pruned.
Joint Training Objective#
The total loss includes a sparsity-inducing penalty:
$$L_{\text{total}} = L_{\text{task}}(\hat{w} \odot \sigma(s)) + \lambda \sum_i \sigma(s_i)$$The penalty \(\sum_i \sigma(s_i)\) is a differentiable proxy for the number of active connections (since each \(\sigma(s_i) \in (0,1)\) approximates a binary mask).
Gradient Derivations#
Gradient with respect to \(\hat{w}_i\):
$$\frac{\partial L_{\text{total}}}{\partial \hat{w}_i} = \frac{\partial L_{\text{task}}}{\partial w_i} \cdot \sigma(s_i)$$This is intuitive: the gradient for the weight is scaled by the mask value. Nearly-pruned weights (\(\sigma(s_i) \approx 0\)) receive nearly zero gradient, so they stop learning.
Gradient with respect to \(s_i\):
$$\frac{\partial L_{\text{total}}}{\partial s_i} = \frac{\partial L_{\text{task}}}{\partial w_i} \cdot \hat{w}_i \cdot \sigma'(s_i) + \lambda \cdot \sigma'(s_i)$$where \(\sigma'(s) = \sigma(s)(1 - \sigma(s))\). Expanding:
$$\frac{\partial L_{\text{total}}}{\partial s_i} = \sigma(s_i)(1 - \sigma(s_i)) \left[\hat{w}_i \frac{\partial L_{\text{task}}}{\partial w_i} + \lambda\right]$$The first factor \(\sigma(s_i)(1-\sigma(s_i))\) is largest when \(s_i = 0\) (mask at 0.5) and vanishes as \(s_i \to \pm\infty\). This means mask decisions are most actively refined when they are uncertain.
The second factor has two competing terms:
- \(\hat{w}_i \frac{\partial L_{\text{task}}}{\partial w_i}\): the task-driven signal (keep if removing hurts)
- \(\lambda\): the sparsity pressure (always pushes toward pruning)
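The two gradient formulas above can be checked numerically on a single weight. The sketch below assumes a hypothetical stand-in for the task loss (squared error against a target value); `task_loss`, `total_loss`, and `analytic_grads` are illustrative names, not part of any library.

```python
import math


def sigmoid(s):
    return 1.0 / (1.0 + math.exp(-s))


def task_loss(w, target=1.0):
    """Hypothetical stand-in for L_task: squared error on one weight."""
    return 0.5 * (w - target) ** 2


def total_loss(w_hat, s, lam, target=1.0):
    """L_total = L_task(w_hat * sigmoid(s)) + lam * sigmoid(s)."""
    gate = sigmoid(s)
    return task_loss(w_hat * gate, target) + lam * gate


def analytic_grads(w_hat, s, lam, target=1.0):
    """Gradients from the formulas in the text, for the toy task loss."""
    gate = sigmoid(s)
    dl_dw = w_hat * gate - target                       # dL_task/dw
    grad_w_hat = dl_dw * gate                           # scaled by the mask
    grad_s = gate * (1.0 - gate) * (w_hat * dl_dw + lam)
    return grad_w_hat, grad_s


grad_w_hat, grad_s = analytic_grads(w_hat=0.8, s=0.3, lam=0.05)
```

Comparing these values against central finite differences of `total_loss` is a quick way to confirm the derivation before wiring the reparameterization into a real training loop.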
Annealing Schedule for \(\lambda\)#
To avoid premature pruning, \(\lambda\) is typically annealed from 0 to its final value over the course of training:
$$\lambda(t) = \lambda_{\text{final}} \cdot \min\left(1, \frac{t}{T_{\text{warmup}}}\right)$$During warmup, the network learns useful features with minimal sparsity pressure. Then \(\lambda\) increases, gradually pushing unnecessary connections to zero.
At the end of training, the soft masks are binarized:
$$m_i = \begin{cases} 1 & \text{if } \sigma(s_i) > 0.5 \text{ (equivalently, } s_i > 0\text{)} \\ 0 & \text{otherwise} \end{cases}$$
Soft Threshold Reparameterization (STR, Kusupati et al., 2020)#
Learnable Per-Layer Thresholds#
STR takes a different approach: instead of learning a mask for each weight individually, it learns a single threshold per layer that determines the sparsity pattern.
The effective weight is:
$$w'_i = \text{sign}(w_i) \cdot \max\left(|w_i| - \text{softplus}(t_l), \, 0\right)$$where \(t_l\) is the learnable threshold parameter for layer \(l\), and:
$$\text{softplus}(t) = \log(1 + e^t)$$The softplus function ensures the threshold is always positive (you cannot have a negative threshold for magnitude).
This is the soft thresholding operator from proximal optimization, but with a learned threshold. Weights with magnitude below \(\text{softplus}(t_l)\) are set exactly to zero, and weights above the threshold are shrunk toward zero by the threshold amount.
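As a minimal sketch of the forward computation described above, the following applies a softplus-parameterized threshold to a list of weights. The helper names `softplus` and `str_forward` and the example values are assumptions chosen for illustration.

```python
import math


def softplus(t):
    """log(1 + e^t), always positive."""
    return math.log1p(math.exp(t))


def str_forward(weights, t):
    """Effective weights sign(w) * max(|w| - softplus(t), 0) for one layer."""
    tau = softplus(t)
    return [math.copysign(max(abs(w) - tau, 0.0), w) for w in weights]


# Pick t so that the threshold softplus(t) is 0.5.
t = math.log(math.expm1(0.5))
effective = str_forward([-1.0, -0.3, 0.2, 0.9], t)
# Weights with |w| <= 0.5 become zero; the others shrink by 0.5.
```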
Visualization of Soft Thresholding#
[Figure: soft thresholding curve of w' against w. The output is exactly zero for w in [-tau, tau] and has slope 1 outside that interval.]
w' = sign(w) * max(|w| - tau, 0)
Weights within [-tau, tau] are exactly zero.
Weights outside are shrunk by tau.
Gradient Through Soft Thresholding#
The gradient of \(w'_i\) with respect to \(w_i\) is:
$$\frac{\partial w'_i}{\partial w_i} = \begin{cases} 1 & \text{if } |w_i| > \text{softplus}(t_l) \\ 0 & \text{if } |w_i| \leq \text{softplus}(t_l) \end{cases}$$The gradient of the loss with respect to the threshold parameter \(t_l\) (summed over all weights in layer \(l\)):
$$\frac{\partial L}{\partial t_l} = \sum_{i \in \text{layer } l} \frac{\partial L}{\partial w'_i} \cdot \frac{\partial w'_i}{\partial t_l}$$For weights above the threshold:
$$\frac{\partial w'_i}{\partial t_l} = -\text{sign}(w_i) \cdot \sigma(t_l)$$where \(\sigma(t_l) = \frac{e^{t_l}}{1+e^{t_l}}\) is the derivative of softplus. For weights below the threshold, the gradient is zero (they are already pruned).
Therefore:
$$\frac{\partial L}{\partial t_l} = -\sigma(t_l) \sum_{\substack{i \in \text{layer } l \\ |w_i| > \text{softplus}(t_l)}} \text{sign}(w_i) \cdot \frac{\partial L}{\partial w'_i}$$This gradient naturally balances: if raising the threshold (shrinking and pruning more weights) would increase the loss, the gradient is positive and the descent step pushes the threshold down. If the loss is insensitive to shrinking the surviving weights, the gradient is near zero, allowing whatever sparsity pressure the objective places on \(t_l\) (such as regularization) to dominate.
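The threshold gradient can be verified on a toy layer. This sketch assumes a hypothetical squared-error loss on the effective weights; `loss`, `threshold_grad`, and the sample values are illustrative only.

```python
import math


def softplus(t):
    return math.log1p(math.exp(t))


def sigmoid(t):
    return 1.0 / (1.0 + math.exp(-t))


def soft_threshold(w, tau):
    return math.copysign(max(abs(w) - tau, 0.0), w)


def loss(weights, targets, t):
    """Hypothetical loss: 0.5 * sum (w'_i - y_i)^2 over the layer."""
    tau = softplus(t)
    return 0.5 * sum((soft_threshold(w, tau) - y) ** 2
                     for w, y in zip(weights, targets))


def threshold_grad(weights, targets, t):
    """dL/dt = -sigmoid(t) * sum over surviving weights of sign(w) * dL/dw'."""
    tau = softplus(t)
    total = 0.0
    for w, y in zip(weights, targets):
        if abs(w) > tau:                      # pruned weights contribute nothing
            dl_dw_eff = soft_threshold(w, tau) - y
            total += math.copysign(1.0, w) * dl_dw_eff
    return -sigmoid(t) * total


weights = [-1.0, -0.3, 0.2, 0.9]
targets = [-0.8, 0.0, 0.0, 0.9]
grad_t = threshold_grad(weights, targets, t=0.0)
```

In this example the surviving weights are already shrunk below their targets, so the gradient is positive and a descent step lowers the threshold.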
Automatic Per-Layer Sparsity#
A major advantage of STR is that it automatically learns the appropriate sparsity for each layer. Layers where weights are more uniformly distributed (less redundancy) will learn lower thresholds. Layers with many near-zero weights will learn higher thresholds. This eliminates the need for manual per-layer sparsity allocation, which is a significant hyperparameter burden in other methods.
Powerpropagation (Schwarz et al., 2021)#
Power Reparameterization#
Powerpropagation introduces a simple but elegant reparameterization:
$$w_i = \text{sign}(\hat{w}_i) \cdot |\hat{w}_i|^\alpha$$where \(\hat{w}_i\) is the underlying parameter and \(\alpha > 1\) is a fixed exponent (typically \(\alpha = 2\)).
This mapping is a strictly increasing bijection on \(\mathbb{R}\) (with \(0 \mapsto 0\)), so it does not change the representational capacity of the network. However, it fundamentally changes the optimization landscape.
Gradient Analysis#
The gradient of the loss with respect to the underlying parameter \(\hat{w}_i\) is:
$$\frac{\partial L}{\partial \hat{w}_i} = \frac{\partial L}{\partial w_i} \cdot \frac{\partial w_i}{\partial \hat{w}_i}$$Computing the derivative of the reparameterization:
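$$\frac{\partial w_i}{\partial \hat{w}_i} = \alpha \cdot |\hat{w}_i|^{\alpha - 1}$$(For \(\hat{w}_i \neq 0\) the sign factor is locally constant, and \(\frac{d}{d\hat{w}_i}|\hat{w}_i|^\alpha = \alpha |\hat{w}_i|^{\alpha-1}\,\text{sign}(\hat{w}_i)\); the two sign factors multiply to 1. At \(\hat{w}_i = 0\) the derivative is 0 for \(\alpha > 1\).)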
Therefore:
$$\frac{\partial L}{\partial \hat{w}_i} = \alpha \cdot |\hat{w}_i|^{\alpha - 1} \cdot \frac{\partial L}{\partial w_i}$$
The “Rich Get Richer” Effect#
Consider two parameters \(\hat{w}_A = 1.0\) and \(\hat{w}_B = 0.1\) with \(\alpha = 2\). Even if the loss gradient \(\frac{\partial L}{\partial w}\) is the same for both, the effective gradients are:
$$\frac{\partial L}{\partial \hat{w}_A} = 2 \times 1.0^1 \times g = 2g$$$$\frac{\partial L}{\partial \hat{w}_B} = 2 \times 0.1^1 \times g = 0.2g$$The larger parameter receives a 10x larger gradient update. This creates a positive feedback loop: large weights grow faster, small weights grow slower. Over training, the distribution of weights becomes increasingly bimodal — a cluster near zero and a cluster at large magnitudes. This is exactly the distribution we want for pruning.
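The worked numbers above can be reproduced with a few lines. This is a minimal sketch; `powerprop_weight` and `powerprop_grad` are hypothetical helper names, and the upstream gradient `g` is an arbitrary example value.

```python
def powerprop_weight(w_hat, alpha=2.0):
    """Effective weight sign(w_hat) * |w_hat|^alpha."""
    sign = (w_hat > 0) - (w_hat < 0)
    return sign * abs(w_hat) ** alpha


def powerprop_grad(w_hat, dl_dw, alpha=2.0):
    """Gradient w.r.t. the underlying parameter: alpha * |w_hat|^(alpha-1) * dL/dw."""
    return alpha * abs(w_hat) ** (alpha - 1.0) * dl_dw


g = 0.5                                  # same upstream gradient for both
grad_a = powerprop_grad(1.0, g)          # 2 * 1.0 * g
grad_b = powerprop_grad(0.1, g)          # 2 * 0.1 * g
ratio = grad_a / grad_b                  # the larger parameter moves 10x faster
```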
Natural Sparsity Emergence#
As training progresses with powerpropagation, the weight distribution naturally evolves:
[Figure: two histograms of weight magnitude |w| (x-axis) against count (y-axis). Left panel, "Standard training": roughly Gaussian. Right panel, "Powerpropagation (alpha=2)": bimodal, concentrated at 0 and at large values.]
After training, we can simply threshold the weights at a small value to achieve sparsity, without any explicit pruning criterion needed. The optimization dynamics have already separated important from unimportant weights.
Advantages#
- No pruning schedule: Sparsity emerges naturally during training.
- No additional hyperparameters (beyond \(\alpha\)): No target sparsity, threshold schedule, or mask learning rate.
- Smooth optimization: For \(\alpha > 1\) the reparameterization is differentiable everywhere, including at zero, where its derivative is 0.
- Compatible with any optimizer: Works with SGD, Adam, etc.
Pruning with Regularization#
Regularization provides a principled framework for inducing sparsity during training by adding penalty terms that encourage weights to become zero.
L1 Regularization (Weight Decay toward Sparsity)#
Formulation#
The L1-regularized objective is:
$$L_{\text{total}} = L_{\text{task}}(w) + \lambda \sum_{i=1}^{n} |w_i|$$where \(\lambda > 0\) controls the sparsity-accuracy tradeoff.
Why Gradient Descent Fails for L1#
The L1 penalty \(|w_i|\) is not differentiable at \(w_i = 0\). The subdifferential is:
$$\partial |w_i| = \begin{cases} \{+1\} & w_i > 0 \\ [-1, +1] & w_i = 0 \\ \{-1\} & w_i < 0 \end{cases}$$Standard gradient descent with a subgradient will oscillate around zero without ever reaching it exactly, because the gradient of the task loss will generically be nonzero, preventing the weight from settling at exactly zero.
Proximal Gradient Descent: Full Derivation#
The correct algorithm for L1 optimization is proximal gradient descent. At each step, we:
- Take a gradient step on the smooth part: \(\tilde{w}_i = w_i - \eta \frac{\partial L_{\text{task}}}{\partial w_i}\)
- Apply the proximal operator for the L1 penalty: \(w_i^{\text{new}} = \text{sign}(\tilde{w}_i)\max(|\tilde{w}_i| - \eta\lambda, 0)\)
This is the soft thresholding operator. Let us derive it from first principles.
The proximal operator for a function \(h\) is defined as:
$$\text{prox}_h(v) = \arg\min_x \left\{ h(x) + \frac{1}{2}||x - v||^2 \right\}$$For \(h(x) = \eta\lambda|x|\) applied to a scalar:
$$\text{prox}_{\eta\lambda|\cdot|}(v) = \arg\min_x \left\{ \eta\lambda|x| + \frac{1}{2}(x - v)^2 \right\}$$Taking the derivative and setting to zero (for \(x > 0\)):
$$\eta\lambda + (x - v) = 0 \implies x = v - \eta\lambda$$This is valid only if \(x > 0\), i.e., \(v > \eta\lambda\).
For \(x < 0\):
$$-\eta\lambda + (x - v) = 0 \implies x = v + \eta\lambda$$This is valid only if \(x < 0\), i.e., \(v < -\eta\lambda\).
For \(|v| \leq \eta\lambda\), the minimum is at \(x = 0\) (check by evaluating the objective at \(x = 0\) vs. the boundary cases).
Combining:
$$\text{prox}_{\eta\lambda|\cdot|}(v) = \begin{cases} v - \eta\lambda & v > \eta\lambda \\ 0 & |v| \leq \eta\lambda \\ v + \eta\lambda & v < -\eta\lambda \end{cases} = \text{sign}(v)\max(|v| - \eta\lambda, 0)$$
Why L1 Produces Exact Zeros but L2 Does Not#
This is a fundamental geometric property. Consider the regularized objective:
$$\min_w L_{\text{task}}(w) + \lambda R(w)$$Equivalently, this is a constrained optimization:
$$\min_w L_{\text{task}}(w) \quad \text{s.t.} \quad R(w) \leq c$$for some constant \(c\) determined by \(\lambda\).
[Figure: two sketches in the (w1, w2) plane. Left, "L2 Constraint (circle)": the loss contours touch the circle at an optimum that is generally nonzero in both coordinates. Right, "L1 Constraint (diamond)": the loss contours touch the diamond at a corner, giving a sparse optimum.]
The L1 constraint region is a diamond (cross-polytope) with corners on the axes. Loss contours are elliptical. The tangent point between an elliptical contour and a diamond is much more likely to occur at a corner (where one or more coordinates are zero) than at an interior point. In contrast, the L2 constraint region is a circle (sphere), which has no corners; tangent points occur at arbitrary locations, almost never on an axis.
More precisely, for L1 the optimum lands on a corner or edge of the diamond (one or more coordinates exactly zero) for a non-negligible set of loss functions, whereas for L2 the optimum generically has all coordinates nonzero.
Numerical Example#
Starting from \(w = 0.15\) with \(\eta = 0.1\) and \(\lambda = 0.5\):
Task gradient: \(\frac{\partial L_{\text{task}}}{\partial w} = 0.8\)
Step 1 (gradient): \(\tilde{w} = 0.15 - 0.1 \times 0.8 = 0.15 - 0.08 = 0.07\)
Step 2 (proximal): \(w^{\text{new}} = \text{sign}(0.07)\max(|0.07| - 0.5 \times 0.1, 0) = \max(0.07 - 0.05, 0) = 0.02\)
After one more step with similar gradient: \(\tilde{w} = 0.02 - 0.08 = -0.06\), then \(w^{\text{new}} = \text{sign}(-0.06)\max(0.06 - 0.05, 0) = -0.01\).
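In this example the weight crosses zero rather than stopping there, because the task gradient (0.8) exceeds \(\lambda\) (0.5). The proximal operator returns exactly zero whenever \(|\tilde{w}| \leq \eta\lambda = 0.05\); a weight sitting at zero stays there as long as \(|\partial L_{\text{task}}/\partial w| \leq \lambda\), which is what lets L1-regularized weights settle at exactly zero.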
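The numerical example can be reproduced with a short sketch of one proximal gradient step. The function names `soft_threshold` and `prox_grad_step` are illustrative, and the final call uses a smaller task gradient (0.3, an assumed value) to show a case where the result is exactly zero.

```python
def soft_threshold(v, thresh):
    """Proximal operator of thresh * |x|: shrink toward zero, clip at zero."""
    if v > thresh:
        return v - thresh
    if v < -thresh:
        return v + thresh
    return 0.0


def prox_grad_step(w, grad, lr, lam):
    """Gradient step on the smooth loss, then soft thresholding by lr * lam."""
    return soft_threshold(w - lr * grad, lr * lam)


w0 = 0.15
w1 = prox_grad_step(w0, grad=0.8, lr=0.1, lam=0.5)   # 0.07 -> 0.02
w2 = prox_grad_step(w1, grad=0.8, lr=0.1, lam=0.5)   # -0.06 -> -0.01

# With a task gradient smaller than lam, the intermediate value falls inside
# [-lr * lam, lr * lam] and the weight becomes exactly zero.
w_zero = prox_grad_step(w1, grad=0.3, lr=0.1, lam=0.5)
```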
Group LASSO for Structured Sparsity#
Formulation#
While L1 regularization produces unstructured sparsity (individual weights become zero), many hardware platforms require structured sparsity — entire filters, channels, or attention heads removed.
Group LASSO achieves this by penalizing the \(\ell_2\) norm of predefined groups of weights:
$$L_{\text{reg}} = \lambda \sum_{g=1}^{G} ||W_g||_2 = \lambda \sum_{g=1}^{G} \sqrt{\sum_{i \in g} w_i^2}$$where \(W_g\) denotes the vector of weights in group \(g\).
Proximal Operator Derivation#
The proximal operator for Group LASSO requires solving:
$$\text{prox}_{\eta\lambda||\cdot||_2}(V_g) = \arg\min_{X_g} \left\{ \eta\lambda ||X_g||_2 + \frac{1}{2}||X_g - V_g||_2^2 \right\}$$Taking the gradient (for \(X_g \neq 0\)):
$$\eta\lambda \frac{X_g}{||X_g||_2} + (X_g - V_g) = 0$$This implies \(X_g\) is parallel to \(V_g\) (since the gradient points in the direction of \(X_g\), and the remaining term is \(V_g - X_g\)). Write \(X_g = \beta V_g\) for some \(\beta > 0\):
$$\eta\lambda \frac{\beta V_g}{\beta ||V_g||_2} + \beta V_g - V_g = 0$$$$\frac{\eta\lambda}{||V_g||_2} V_g + (\beta - 1) V_g = 0$$$$\beta = 1 - \frac{\eta\lambda}{||V_g||_2}$$This is valid when \(\beta > 0\), i.e., \(||V_g||_2 > \eta\lambda\). Otherwise, \(X_g = 0\).
The complete proximal operator is:
$$\text{prox}_{\eta\lambda||\cdot||_2}(V_g) = \left(1 - \frac{\eta\lambda}{||V_g||_2}\right)_+ V_g = \max\left(1 - \frac{\eta\lambda}{||V_g||_2}, \, 0\right) \cdot V_g$$When \(||V_g||_2 \leq \eta\lambda\), the entire group is set to zero simultaneously. This is the mechanism for structured sparsity — all weights in a group live or die together.
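A minimal sketch of the group proximal operator follows. The name `group_prox` and the example vectors are assumptions for illustration; `thresh` plays the role of \(\eta\lambda\).

```python
import math


def group_prox(v, thresh):
    """Block soft thresholding: scale the whole group, or zero it out."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm <= thresh:
        return [0.0] * len(v)          # the entire group dies together
    scale = 1.0 - thresh / norm
    return [scale * x for x in v]


kept = group_prox([3.0, 4.0], thresh=1.0)     # norm 5, scaled by 0.8
pruned = group_prox([0.3, 0.4], thresh=1.0)   # norm 0.5 <= 1, all zeros
```

Unlike elementwise soft thresholding, the direction of a surviving group is preserved; only its norm shrinks by `thresh`.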
How to Define Groups#
The choice of groups determines the type of structured sparsity:
| Group Definition | Sparsity Type | Hardware Benefit |
|---|---|---|
| All weights in one conv filter | Filter pruning | Reduces output channels |
| All weights connecting to one input channel | Channel pruning | Reduces input channels |
| All weights in one attention head | Head pruning | Removes entire head computation |
| All weights in one row of FC layer | Neuron pruning | Removes one neuron |
| Block of weights (e.g., 4x4) | Block sparsity | Maps to block-sparse kernels |
Hoyer Regularization#
The Hoyer Sparsity Measure#
The Hoyer measure quantifies the sparsity of a vector \(x \in \mathbb{R}^n\) using the ratio of L1 and L2 norms:
$$H(x) = \frac{\left(\sum_{i=1}^{n} |x_i|\right)^2}{\sum_{i=1}^{n} x_i^2}$$This ratio ranges from 1 (when only one element is nonzero — maximally sparse) to \(n\) (when all elements have equal magnitude — maximally dense). However, \(H\) is not normalized to \([0,1]\).
Normalized Hoyer Measure#
The normalized version maps to \([0,1]\):
$$\hat{H}(x) = \frac{\sqrt{n} - \frac{\sum|x_i|}{\sqrt{\sum x_i^2}}}{\sqrt{n} - 1}$$This equals 1 for a maximally sparse vector (one nonzero entry) and 0 for a maximally dense vector (all entries equal magnitude).
Derivation of the Normalization#
The ratio \(\frac{||x||_1}{||x||_2} = \frac{\sum|x_i|}{\sqrt{\sum x_i^2}}\) satisfies:
- Minimum (most sparse): when \(x = (a, 0, 0, \ldots, 0)\), the ratio is \(\frac{|a|}{|a|} = 1\).
- Maximum (most dense): when \(x = (a, a, \ldots, a)\), the ratio is \(\frac{n|a|}{\sqrt{n}|a|} = \sqrt{n}\).
The upper bound follows from the Cauchy-Schwarz inequality and the lower bound from \(||x||_2 \leq ||x||_1\): \(1 \leq \frac{||x||_1}{||x||_2} \leq \sqrt{n}\).
The normalized Hoyer inverts and scales this:
$$\hat{H}(x) = \frac{\sqrt{n} - \frac{||x||_1}{||x||_2}}{\sqrt{n} - 1} \in [0, 1]$$
Use as Regularization#
Adding Hoyer regularization:
$$L_{\text{total}} = L_{\text{task}} + \lambda \cdot (1 - \hat{H}(w))$$This penalizes dense (low-sparsity) weight distributions. Minimizing \(1 - \hat{H}\) is equivalent to maximizing \(\hat{H}\), pushing toward sparsity.
Advantages over L1#
- Scale-invariant: \(\hat{H}(x) = \hat{H}(\alpha x)\) for any \(\alpha \neq 0\). L1 is not scale-invariant — it penalizes large weights even if they are sparse.
- Balanced sparsity pressure: Hoyer does not favor small weights over large ones. It measures the shape of the distribution, not its scale.
- Better gradient properties: The gradient of \(\hat{H}\) provides more uniform pressure across weights of different magnitudes, avoiding the pathological behavior of L1 where large weights receive constant gradient regardless of sparsity.
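The normalized Hoyer measure and its scale invariance can be checked directly. This is a small sketch; `normalized_hoyer` is an illustrative name, and the function assumes a nonzero vector with at least two entries.

```python
import math


def normalized_hoyer(x):
    """(sqrt(n) - ||x||_1 / ||x||_2) / (sqrt(n) - 1); needs n >= 2, x != 0."""
    n = len(x)
    l1 = sum(abs(v) for v in x)
    l2 = math.sqrt(sum(v * v for v in x))
    return (math.sqrt(n) - l1 / l2) / (math.sqrt(n) - 1.0)


sparse = normalized_hoyer([0.0, 0.0, 5.0, 0.0])     # one nonzero entry
dense = normalized_hoyer([1.0, 1.0, 1.0, 1.0])      # equal magnitudes

x = [0.5, -2.0, 0.0, 1.0]
same_shape = normalized_hoyer([10.0 * v for v in x])  # rescaled copy of x
```

The L1 norm of the rescaled vector is ten times larger, while the Hoyer value is unchanged, which is the scale-invariance property listed above.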
Combinatorial Optimization Approaches#
Pruning as Combinatorial Optimization#
The pruning problem can be formally stated as:
$$\min_{m \in \{0,1\}^n} L(w \odot m) \quad \text{subject to} \quad ||m||_0 \leq k$$where \(m\) is a binary mask, \(w\) are the (fixed) weights, and \(k\) is the budget of nonzero weights.
This is a combinatorial optimization problem — we must choose the best \(k\) out of \(n\) weights to keep. The number of possible masks is \(\binom{n}{k}\), which is astronomical for modern networks (e.g., \(\binom{10^8}{10^7}\)).
The problem is NP-hard in general. However, the structure of neural network loss functions admits useful approximations.
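To get a feel for the size of the search space, the number of masks can be computed in log space with the log-gamma function. This is a rough illustrative calculation; `log10_num_masks` is a hypothetical helper name.

```python
import math


def log10_num_masks(n, k):
    """log10 of C(n, k), computed via lgamma to avoid huge integers."""
    ln_count = (math.lgamma(n + 1)
                - math.lgamma(k + 1)
                - math.lgamma(n - k + 1))
    return ln_count / math.log(10)


small = log10_num_masks(10, 3)             # C(10, 3) = 120
digits = log10_num_masks(10**8, 10**7)     # keeping 10% of 100M weights
```

For the example in the text, the count has on the order of fourteen million decimal digits, which rules out any form of enumeration.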
oBERT (Optimal BERT Surgeon, 2022)#
Applying OBS to Transformers#
Optimal Brain Surgeon (OBS), introduced by Hassibi and Stork (1993), uses the second-order Taylor expansion to optimally prune weights while compensating for the pruning error via weight updates to remaining weights:
$$\delta L \approx -w_i g_i + \frac{1}{2} w_i^2 [H^{-1}]_{ii}^{-1}$$The key insight of OBS over OBD is that after pruning weight \(w_i\), the remaining weights should be updated to compensate:
$$\delta w = -\frac{w_i}{[H^{-1}]_{ii}} H^{-1} e_i$$where \(e_i\) is the \(i\)-th standard basis vector.
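The OBS saliency and compensating update can be illustrated on a two-parameter problem. The sketch below assumes the gradient is zero (a converged model), so only the second-order term remains, and it hard-codes a 2x2 inverse to stay self-contained; `inverse_2x2`, `obs_prune`, and `quadratic_increase` are hypothetical names.

```python
def inverse_2x2(h):
    (a, b), (c, d) = h
    det = a * d - b * c
    return [[d / det, -b / det], [-c / det, a / det]]


def obs_prune(w, h, i):
    """Prune weight i; return (saliency, compensated weights)."""
    h_inv = inverse_2x2(h)
    saliency = 0.5 * w[i] ** 2 / h_inv[i][i]
    scale = -w[i] / h_inv[i][i]
    # delta_w = scale * H^{-1} e_i, i.e. scale times column i of H^{-1}
    new_w = [wj + scale * h_inv[j][i] for j, wj in enumerate(w)]
    new_w[i] = 0.0      # the update already gives zero up to rounding
    return saliency, new_w


def quadratic_increase(delta, h):
    """0.5 * delta^T H delta, the loss change at a minimum."""
    n = len(delta)
    return 0.5 * sum(delta[r] * h[r][c] * delta[c]
                     for r in range(n) for c in range(n))


h = [[2.0, 1.0], [1.0, 2.0]]
w = [1.0, 3.0]
saliency, new_w = obs_prune(w, h, i=0)
delta = [a - b for a, b in zip(new_w, w)]
```

Here the second weight moves from 3.0 to 3.5 to absorb part of the error, and the quadratic loss increase of the full update matches the predicted saliency.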
oBERT adapts this framework for BERT-scale models (hundreds of millions of parameters) through several innovations:
Row-wise Hessian computation: Instead of computing the full Hessian (impossible at BERT scale), oBERT computes the Hessian independently for each row of each weight matrix. For a weight matrix \(W \in \mathbb{R}^{m \times n}\), this requires \(m\) Hessian matrices of size \(n \times n\), rather than one matrix of size \(mn \times mn\).
The row-wise Hessian for row \(r\) of a linear layer \(y = Wx + b\) is:
$$H_r = \frac{1}{B} \sum_{b=1}^{B} x_b x_b^T \cdot h_{rr}^{(\text{out})}$$where \(x_b\) is the input activation for sample \(b\) and \(h_{rr}^{(\text{out})}\) is the diagonal element of the output Hessian corresponding to row \(r\).
Gradual pruning with OBS updates: Rather than pruning all target weights at once, oBERT prunes in multiple steps, recomputing the Hessian after each step:
Algorithm: oBERT (Optimal BERT Surgeon)
Input: Fine-tuned BERT model, target sparsity s, calibration data D
Number of pruning steps P
1. s_step = 1 - (1-s)^(1/P) // per-step sparsity
2. For step p = 1, ..., P:
a. Compute row-wise Hessians H_r for each row of each layer
using calibration data D
b. For each row r, compute OBS saliencies:
sal_i = w_i^2 / (2 * [H_r^{-1}]_{ii})
c. Select weights to prune: bottom s_step fraction by saliency
(among currently unpruned weights)
d. For each pruned weight i, update remaining weights in same row:
delta_w = -w_i / [H_r^{-1}]_{ii} * H_r^{-1} * e_i
e. Apply weight updates and zero out pruned weights
3. (Optional) Fine-tune the pruned model for a few epochs
Return: Pruned BERT model
Results Compared to Magnitude Pruning#
| Method | SQuAD F1 | MNLI Acc | Sparsity | Pruning Time |
|---|---|---|---|---|
| Magnitude (one-shot) | 78.2 | 76.1 | 90% | Minutes |
| Magnitude (gradual) | 83.1 | 80.5 | 90% | Hours (retraining) |
| oBERT (one-shot) | 85.3 | 82.7 | 90% | Hours (Hessian) |
| oBERT (gradual) | 86.8 | 83.9 | 90% | Hours (Hessian) |
oBERT achieves significantly better accuracy than magnitude pruning at the same sparsity, especially in the one-shot setting where no retraining is needed. The cost is computing the row-wise Hessians, which requires a calibration dataset pass.
Combinatorial Brain Surgeon (CBS)#
From Greedy to Submodular Optimization#
CBS frames pruning as a submodular function maximization problem. The key observation is that the marginal benefit of keeping an additional weight exhibits diminishing returns — a hallmark of submodularity.
Define the set function:
$$F(S) = L(w \odot m_S) - L(w)$$where \(S \subseteq \{1, \ldots, n\}\) is the set of pruned weights and \(m_S\) is the corresponding mask (0 for pruned weights, 1 for kept weights). \(F(S)\) measures the loss increase from pruning the weights in \(S\).
We want to find the set \(S\) with \(|S| = n - k\) (pruning \(n-k\) weights) that minimizes \(F(S)\), i.e., causes the least loss increase.
Under the second-order approximation (pruning weight \(i\) is the perturbation \(\delta w_i = -w_i\)):
$$F(S) \approx -\sum_{i \in S} w_i g_i + \frac{1}{2} \sum_{i,j \in S} w_i w_j H_{ij}$$The cross terms \(H_{ij}\) capture interactions between pruned weights. When every pairwise term \(w_i w_j H_{ij}\) with \(i \neq j\) is nonnegative, \(F\) is supermodular, and the complementary problem (choosing the set of kept weights to maximize the loss increase avoided) is submodular. A positive semi-definite Hessian alone does not guarantee this sign condition, so in general the structure holds only approximately.
Guarantees via Submodularity#
For monotone submodular function maximization with a cardinality constraint, the greedy algorithm provides a \((1 - 1/e)\)-approximation guarantee:
$$F_{\text{greedy}}(S) \geq \left(1 - \frac{1}{e}\right) F_{\text{optimal}}(S)$$This means the greedy solution achieves at least 63.2% of the optimal solution quality. While this bound is for the worst case, in practice the greedy solution is typically much closer to optimal.
The greedy algorithm iteratively selects the weight whose removal causes the smallest marginal increase in loss, accounting for previously pruned weights.
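The greedy selection step can be sketched under the second-order approximation, with the first-order term entering as \(-w_i g_i\) because pruning weight \(i\) perturbs it by \(-w_i\). This is an illustration of the greedy idea only, not the CBS authors' algorithm; `greedy_prune` and the example Hessian are assumptions, and the Hessian is taken to be symmetric.

```python
def greedy_prune(w, g, h, n_prune):
    """Greedily pick weights whose removal adds the least estimated loss.

    Marginal cost of adding i to the pruned set S (symmetric H):
        -w_i g_i + 0.5 * w_i^2 H_ii + sum_{j in S} w_i w_j H_ij
    """
    pruned = []
    remaining = set(range(len(w)))
    for _ in range(n_prune):
        def marginal(i):
            cross = sum(w[i] * w[j] * h[i][j] for j in pruned)
            return -w[i] * g[i] + 0.5 * w[i] ** 2 * h[i][i] + cross

        best = min(sorted(remaining), key=marginal)   # ties go to lowest index
        pruned.append(best)
        remaining.remove(best)
    return pruned


# Weights 0 and 1 interact strongly, so pruning both is costly.
w = [1.0, 1.0, 1.0]
g = [0.0, 0.0, 0.0]
h = [[1.0, 0.9, 0.0],
     [0.9, 1.0, 0.0],
     [0.0, 0.0, 1.2]]
order = greedy_prune(w, g, h, n_prune=2)
```

Ranking by individual saliency alone would prune weights 0 and 1, since each costs 0.5 in isolation; the cross term raises the cost of the second one to 1.4, so the greedy procedure picks weight 2 instead.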
Pruning with Knowledge Distillation#
Motivation#
Pruning inevitably removes some model capacity, leading to accuracy degradation. Knowledge distillation can recover much of this lost accuracy by transferring knowledge from the original unpruned model (the teacher) to the pruned model (the student).
Distillation Loss Functions#
Logit-Level Distillation#
The student is trained to match the teacher’s soft output distribution:
$$L_{\text{KD}} = (1 - \alpha) L_{\text{CE}}(y, \sigma(z_S)) + \alpha \cdot T^2 \cdot \text{KL}(\sigma(z_T/T) \| \sigma(z_S/T))$$where \(z_S, z_T\) are student and teacher logits, \(T\) is the temperature, \(\sigma\) is softmax, and \(\alpha\) balances the hard label loss and distillation loss.
The temperature parameter \(T > 1\) softens the probability distribution, revealing the teacher’s relative confidence across classes (the “dark knowledge”). The \(T^2\) scaling factor compensates for the reduced gradient magnitude at higher temperatures.
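The logit-level loss can be written out for a single example in plain Python. This is a minimal sketch assuming unbatched logits; `softmax`, `kd_loss`, and the example logits are illustrative, and a real training loop would use a framework's batched, differentiable equivalents.

```python
import math


def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)                               # subtract max for stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_loss(student_logits, teacher_logits, label, alpha=0.5, temperature=4.0):
    """(1 - alpha) * CE(hard label) + alpha * T^2 * KL(teacher || student)."""
    ce = -math.log(softmax(student_logits)[label])
    soft_t = softmax(teacher_logits, temperature)
    soft_s = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(soft_t, soft_s))
    return (1.0 - alpha) * ce + alpha * temperature ** 2 * kl


student = [2.0, 0.5, -1.0]
teacher = [3.0, 1.5, -2.0]
loss = kd_loss(student, teacher, label=0)
matched = kd_loss(student, student, label=0)   # KL term vanishes
```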
Feature-Map Distillation#
For deeper knowledge transfer, we align intermediate representations:
$$L_{\text{FD}} = \sum_{l \in \mathcal{L}} ||f_l^S - \phi_l(f_l^T)||^2$$where \(f_l^S\) and \(f_l^T\) are the student and teacher feature maps at layer \(l\), \(\mathcal{L}\) is the set of matched layers, and \(\phi_l\) is a learned adaptation layer (typically a 1x1 convolution) that matches dimensions when the student has fewer channels than the teacher.
The adaptation layer is necessary because the pruned student may have different feature dimensions than the teacher. Its parameters are trained jointly with the student.
Attention Transfer#
For transformer models, we can specifically align attention patterns:
$$L_{\text{AT}} = \sum_{l=1}^{L} \sum_{h=1}^{H} ||A_{l,h}^S - A_{l,h}^T||_F^2$$where \(A_{l,h}^S, A_{l,h}^T \in \mathbb{R}^{n \times n}\) are the attention matrices for layer \(l\), head \(h\), with \(n\) being the sequence length and \(||\cdot||_F\) the Frobenius norm.
This loss ensures the pruned model maintains similar attention patterns to the teacher, preserving the learned relational structure between tokens.
Progressive Pruning + Distillation Pipeline#
The most effective approach combines gradual pruning with continuous distillation:
Algorithm: Progressive Pruning with Knowledge Distillation
Input: Teacher model T (unpruned), initial student S = copy of T
Target sparsity s, pruning steps P, training epochs E_per_step
Temperature tau, distillation weight alpha, feature-loss weight beta
1. Initialize student S as a copy of teacher T
2. s_per_step = 1 - (1-s)^(1/P)
3. For pruning step p = 1, ..., P:
a. PRUNE: Remove bottom s_per_step fraction of remaining
weights in S (by chosen criterion: magnitude, Taylor, etc.)
b. DISTILL: For epoch e = 1, ..., E_per_step:
For each mini-batch (x, y):
i. Teacher forward: z_T = T(x), features f_T
ii. Student forward: z_S = S(x), features f_S
iii. Compute combined loss:
L = (1-alpha) * CE(y, softmax(z_S))
+ alpha * tau^2 * KL(softmax(z_T/tau) || softmax(z_S/tau))
+ beta * sum_l ||f_l^S - phi_l(f_l^T)||^2
iv. Update student weights (only unpruned ones)
4. Final binarization: zero out all masked weights
Return: Pruned and distilled student model S
Why Distillation Recovers Accuracy Lost to Pruning#
The effectiveness of distillation after pruning can be understood through several lenses:
Richer supervision: The teacher’s soft targets contain more information per sample than hard labels. For a 1000-class problem, a hard label carries at most \(\log_2(1000) \approx 10\) bits. A soft target is a full distribution over 1000 classes; stored as one 32-bit float per class it occupies \(1000 \times 32 = 32000\) bits, which is a loose upper bound on its information content rather than a measurement. Even so, the extra signal about inter-class similarity helps the student learn more efficiently from fewer parameters.
Implicit regularization: The teacher’s output distribution acts as a form of label smoothing, preventing the pruned student from overfitting to the training data with its reduced capacity.
Feature alignment: Feature-map distillation provides layer-wise supervision, turning the student training from a single end-to-end optimization into multiple local optimization problems — each intermediate layer has its own target, making optimization easier.
Knowledge preservation: The teacher encodes relationships learned during its full-capacity training (e.g., “cats are more similar to dogs than to cars”). Without distillation, the pruned student must rediscover these relationships with fewer parameters. With distillation, these relationships are directly taught.
Lottery Ticket Variants and Extensions#
Deconstructing Lottery Tickets (Zhou et al., 2019)#
The original Lottery Ticket Hypothesis (Frankle & Carbin, 2019) states that dense networks contain sparse subnetworks that, when trained from their original initialization, can match the full network’s accuracy. But which component of the winning ticket actually matters?
Zhou et al. systematically ablate the three components of a winning ticket:
- The mask (which weights are kept)
- The sign of the initial weights
- The magnitude of the initial weights
Experimental Findings#
| Mask | Signs | Magnitudes | Accuracy (% of full) |
|---|---|---|---|
| Winning | Original | Original | 100% (baseline) |
| Winning | Original | Random | 89% |
| Winning | Random | Original | 62% |
| Winning | Original | Constant | 85% |
| Random | Original | Original | 41% |
The striking finding is that the mask + signs alone (with random or constant magnitudes) can achieve 85-89% of the full winning ticket’s accuracy. Randomizing the signs while keeping the mask and original magnitudes drops accuracy to 62%, and a random mask with original weights drops it to 41%.
The Supermask Discovery#
Even more remarkably, Zhou et al. discover that the mask alone, without any training, can achieve non-trivial accuracy. By using the mask as a binary selector over randomly initialized (but fixed) weights, they find supermasks that achieve well above chance accuracy on MNIST and even respectable accuracy on CIFAR-10.
This is found by treating the mask selection as an optimization problem: learn a score for each weight, threshold to get the mask, and use the straight-through estimator for gradients. The underlying weights are never changed.
The implication is profound: a sufficiently large random network contains within it — as a subnetwork selected by an appropriate mask — a model that performs well without any weight training. This connects to the random feature theory and provides theoretical support for the overparameterization hypothesis.
Multi-Prize Lottery Ticket Hypothesis#
The original lottery ticket work identified a single winning ticket. Subsequent work demonstrates that multiple winning tickets exist within the same dense network.
Key findings:
- Independent tickets: Different pruning seeds yield different winning tickets with similar accuracy. The winning subnetwork is not unique.
- Ensemble diversity: Different winning tickets make different errors, so ensembling sparse subnetworks can exceed the dense network’s accuracy.
- Functional diversity: Despite similar accuracy, different tickets learn different internal representations (measured by CKA similarity), suggesting they have found different local minima in the loss landscape.
The practical implication is that we can extract multiple complementary sparse models from a single dense training run, amortizing the training cost:
Dense Network (100M params)
|
+---> Ticket 1 (10M params, 95% acc on Task A)
|
+---> Ticket 2 (10M params, 94% acc on Task A, different errors)
|
+---> Ticket 3 (10M params, 95% acc on Task A, different errors)
|
Ensemble of 3 tickets (30M params total, 96.5% acc)
vs. Dense network (100M params, 96% acc)
Dual Lottery Ticket Hypothesis (2022)#
The Dual Lottery Ticket Hypothesis inverts the relationship between sparse and dense networks:
Standard LTH: Dense networks contain winning sparse subnetworks.
Dual LTH: Sparse networks contain winning dense subnetworks that can be densified (expanded) to recover full accuracy.
More precisely, given a sparse network at some sparsity level, there exist dense substructures within it — sets of weights that, if duplicated and rearranged, can construct a dense network with comparable accuracy.
The practical algorithm works as follows:
- Train a sparse network (via any pruning method)
- Identify the “skeleton” — the structure of nonzero weights
- Grow the skeleton by reactivating pruned connections, initialized based on the existing sparse weights (e.g., via interpolation or local averaging)
- Fine-tune the densified network
The key insight is that the sparse network has already learned the essential structure and approximate weight values. Densification adds capacity where the sparse network is most constrained, recovering accuracy more efficiently than training a new dense network from scratch.
This creates a bidirectional relationship:
Dense Network
| ^
| Prune (LTH) | Densify (Dual LTH)
v |
Sparse Network
Dense -> Sparse: Pruning finds winning tickets
Sparse -> Dense: Densification finds winning expansions
Evaluation Framework#
Metrics Beyond Accuracy#
Evaluating pruning methods requires a multidimensional assessment. A method that achieves high accuracy but requires days of computation to find the mask may be impractical. Conversely, a fast method that achieves slightly lower accuracy may be preferred in practice.
| Metric | Definition | Why It Matters |
|---|---|---|
| Top-1 Accuracy | Classification accuracy on test set | Primary quality metric |
| FLOPs Remaining Ratio | \(\frac{\text{FLOPs (pruned)}}{\text{FLOPs (dense)}}\) | Theoretical speedup |
| Parameter Remaining Ratio | \(\frac{\text{Params (pruned)}}{\text{Params (dense)}}\) | Memory savings |
| Actual Inference Latency | Wall-clock time per sample | Real-world speedup (may differ from FLOPs) |
| Memory Footprint | Peak memory during inference (MB) | Deployment constraint |
| Pruning Cost | GPU-hours to find the mask + retrain | Total resource consumption |
| Accuracy per FLOP | \(\frac{\text{Accuracy}}{\text{FLOPs remaining ratio}}\) | Efficiency of pruning |
Standardized Benchmarks#
| Domain | Dataset | Model | Standard Sparsities |
|---|---|---|---|
| Vision | ImageNet | ResNet-50 | 50%, 70%, 80%, 90%, 95% |
| Vision | CIFAR-10 | VGG-16, ResNet-20 | 90%, 95%, 98% |
| NLP | GLUE | BERT-base | 70%, 80%, 90%, 95% |
| NLP | SQuAD | BERT-base | 70%, 80%, 90%, 95% |
| LLM | WikiText | GPT-2, LLaMA | 50%, 60%, 70% unstructured; 2:4 semi-structured (50%) |
Fair Comparison: Same Training Budget Analysis#
A critical but often overlooked aspect of pruning evaluation is ensuring a fair computational budget. Consider two methods:
- Method A: Prune at initialization, then train for 100 epochs. Total cost: 100 training epochs.
- Method B: Train for 50 epochs, prune, retrain for 50 epochs. Total cost: 100 training epochs + pruning overhead.
- Method C: Train for 100 epochs, prune, retrain for 100 epochs. Total cost: 200 training epochs.
Method C will almost always achieve higher accuracy, but at 2x the computational cost. Comparing its accuracy to Method A’s is misleading.
The fair comparison approach is to fix the total training budget (e.g., 100 GPU-hours) and compare what each method achieves within that budget. Under this framework, methods like SNIP (which prune before training) gain a significant advantage: they spend their entire budget on training the sparse network, while iterative methods must split the budget between training, pruning, and retraining.
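A budget check of this kind is easy to automate. The sketch below encodes Methods A, B, and C as hypothetical `Schedule` records and keeps only those that fit a fixed epoch budget; it measures cost in epochs and ignores the pruning overhead itself, which is a simplification.

```python
from dataclasses import dataclass


@dataclass
class Schedule:
    name: str
    pretrain_epochs: int  # dense training before the mask is found
    retrain_epochs: int   # training after the mask is applied

    @property
    def total_epochs(self):
        return self.pretrain_epochs + self.retrain_epochs


def comparable(schedules, budget):
    """Return the names of schedules whose total training cost fits the budget."""
    return [s.name for s in schedules if s.total_epochs <= budget]


schedules = [
    Schedule("A: prune at init", 0, 100),
    Schedule("B: train-prune-retrain", 50, 50),
    Schedule("C: full train + full retrain", 100, 100),
]
print(comparable(schedules, budget=100))
# ['A: prune at init', 'B: train-prune-retrain']
```

Method C is filtered out at a 100-epoch budget; to include it fairly, A and B would need to be rerun with 200 epochs each.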
Common Pitfalls in Pruning Evaluation#
| Pitfall | Description | Impact |
|---|---|---|
| Unequal training budgets | Comparing methods with different total training epochs | Inflates accuracy of high-cost methods |
| Missing actual latency | Reporting only FLOPs/parameter reduction | Unstructured sparsity may not speed up real hardware |
| Cherry-picked sparsity | Reporting only the sparsity level where method excels | Hides poor performance at other sparsity levels |
| Single seed | Reporting results from one random seed | Hides variance, especially at high sparsity |
| Dense baseline mismatch | Comparing against a weak dense baseline | Inflates relative accuracy retention |
| Ignoring fine-tuning | Not fine-tuning after pruning | Underestimates post-hoc methods |
| Layer-wise vs. global | Not specifying whether sparsity is per-layer or global | Different allocations yield very different results |
| Comparing structured vs. unstructured | Mixing structured and unstructured methods in same table | Not comparable — different hardware requirements |
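The "Layer-wise vs. global" row deserves a concrete illustration, since the same nominal sparsity can produce very different masks. The following sketch applies plain magnitude pruning under both scopes; the `magnitude_masks` helper and the toy weights are assumptions made for the example.

```python
def magnitude_masks(layers, sparsity, scope="global"):
    """Return one 0/1 mask per layer, pruning the smallest-magnitude weights.

    layers:   list of lists of floats
    sparsity: fraction of weights to remove (0.0 to 1.0)
    scope:    "global" ranks all weights together; "layer" ranks within each layer
    """
    def prune(values, k):
        # Indices of the k smallest magnitudes are dropped
        drop = set(sorted(range(len(values)), key=lambda i: abs(values[i]))[:k])
        return [0 if i in drop else 1 for i in range(len(values))]

    if scope == "layer":
        return [prune(w, int(sparsity * len(w))) for w in layers]

    flat = [w for layer in layers for w in layer]
    flat_mask = prune(flat, int(sparsity * len(flat)))
    masks, start = [], 0
    for layer in layers:  # split the flat mask back into per-layer masks
        masks.append(flat_mask[start:start + len(layer)])
        start += len(layer)
    return masks


layers = [[0.9, -0.8, 0.7, 0.6], [0.05, -0.04, 0.03, 0.5]]
print(magnitude_masks(layers, 0.5, "layer"))   # [[1, 1, 0, 0], [1, 0, 0, 1]]
print(magnitude_masks(layers, 0.5, "global"))  # [[1, 1, 1, 1], [0, 0, 0, 0]]
```

Both masks are "50% sparse", yet the global variant removes the second layer entirely because its weights are on a smaller scale. A results table that does not state the scope leaves this difference invisible.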
Summary#
Complete Taxonomy#
| Method | Type | Criterion | When Applied | Data Needed | Cost |
|---|---|---|---|---|---|
| Magnitude | Post-training | \(\lvert w_i \rvert\) | After training | None | Negligible |
| First-Order Taylor | Post-training | \(\lvert w_i g_i \rvert\) | After training | Calibration set | 1 forward-backward |
| Second-Order Taylor / OBD | Post-training | \(w_i^2 h_{ii}\) | After training | Calibration set | Hessian diagonal |
| OBS / oBERT | Post-training | \(w_i^2 / [H^{-1}]_{ii}\) | After training | Calibration set | Row-wise Hessian inverse |
| SNIP | At initialization | \(\lvert w_j \cdot \partial L/\partial w_j \rvert\) | Before training | 1 mini-batch | 1 forward-backward |
| GraSP | At initialization | \(-(Hg)_j \cdot w_j\) | Before training | 1 mini-batch | 2 forward-backward |
| SynFlow | At initialization | Path product saliency | Before training | None | n forward-backward |
| Movement Pruning | During training | \(\sum w \cdot \Delta w\) | During fine-tuning | Training data | Full training |
| Continuous Sparsification | During training | Learned sigmoid masks | During training | Training data | Full training |
| STR | During training | Learned per-layer threshold | During training | Training data | Full training |
| Powerpropagation | During training | Power reparameterization | During training | Training data | Full training |
| L1 Regularization | During training | Proximal threshold | During training | Training data | Full training |
| Group LASSO | During training | Group norm threshold | During training | Training data | Full training |
| CBS | Post-training | Submodular optimization | After training | Calibration set | Greedy selection |
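To see how the first four criteria in the table can disagree, the sketch below scores two weights under each of them. The `saliencies` helper, the toy values, and the restriction to a 2x2 Hessian (so the inverse can be written by hand) are all assumptions for illustration; constant factors such as 1/2 are dropped because they do not affect the ranking.

```python
def saliencies(w, g, H):
    """Score two weights under four post-training criteria.

    w: weights, g: loss gradients, H: 2x2 Hessian as nested lists.
    Lower score = better pruning candidate.
    """
    det = H[0][0] * H[1][1] - H[0][1] * H[1][0]
    # Diagonal of the inverse of a 2x2 matrix
    h_inv_diag = [H[1][1] / det, H[0][0] / det]
    return {
        "magnitude": [abs(wi) for wi in w],
        "taylor1": [abs(wi * gi) for wi, gi in zip(w, g)],
        "obd": [wi ** 2 * H[i][i] for i, wi in enumerate(w)],
        "obs": [wi ** 2 / h_inv_diag[i] for i, wi in enumerate(w)],
    }


w = [1.0, 0.5]
g = [0.01, 0.2]
H = [[1.0, 0.9], [0.9, 1.0]]  # strongly correlated weights
scores = saliencies(w, g, H)
for name, values in scores.items():
    print(f"{name:10s} {values}")
```

With these values, magnitude, OBD, and OBS pick the second weight as the pruning candidate, while first-order Taylor picks the first because its gradient is tiny. The OBS scores are also much smaller than the OBD scores: with a strongly off-diagonal Hessian, the surviving weight can compensate for the removed one, which the diagonal approximation cannot capture.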
Method Selection Guide#
START: What is your scenario?
|
+-- "I have a pre-trained model and want to prune quickly"
|   |
|   +-- Small model (< 100M params) --> OBS / oBERT
|   +-- Large model (> 1B params) --> Magnitude or First-Order Taylor
|   +-- Need structured sparsity --> Group LASSO + fine-tune
|
+-- "I want to train a sparse model from scratch"
|   |
|   +-- Have training data --> Continuous Sparsification or STR
|   +-- No training data yet --> SynFlow (data-free)
|   +-- Want simplicity --> Powerpropagation
|
+-- "I am fine-tuning a pre-trained model (BERT, etc.)"
|   |
|   +-- Movement Pruning (best for transfer learning)
|   +-- + Knowledge Distillation for maximum accuracy recovery
|
+-- "I need pruning at initialization (one-shot, minimal cost)"
|   |
|   +-- Moderate sparsity (< 90%) --> SNIP
|   +-- High sparsity (> 95%) --> SynFlow (avoids layer collapse)
|   +-- Care about trainability --> GraSP
|
+-- "I want theoretical guarantees"
    |
    +-- Submodularity guarantees --> CBS
    +-- Layer collapse avoidance --> SynFlow
    +-- Optimal weight compensation --> OBS / oBERT
Key Takeaways#
No single method dominates all scenarios. The best pruning method depends on the computational budget, model size, sparsity target, and whether you are training from scratch or fine-tuning.
Movement trumps magnitude for fine-tuning. When pruning pre-trained models, the training dynamics (captured by movement pruning) are far more informative than the static weight magnitudes.
Gradient-based methods form a spectrum. First-order Taylor is cheap but approximate. Second-order methods (OBD, OBS) are more accurate but expensive. SNIP and GraSP operate at initialization: they find the mask without any prior training, at some cost in final accuracy.
Layer collapse is a real failure mode. At high sparsity, methods like SNIP and GraSP can catastrophically remove entire layers. SynFlow’s conservation-based approach provably prevents this.
Continuous pruning methods are the most flexible. Methods like STR and Continuous Sparsification that learn masks during training can automatically discover per-layer sparsity ratios, eliminating a major hyperparameter burden.
Knowledge distillation is nearly always beneficial. Regardless of the pruning method used, adding a distillation loss from the unpruned teacher consistently improves the pruned model’s accuracy.
Evaluation must be fair. When comparing pruning methods, control for total computational budget, report actual latency (not just FLOPs), test across multiple sparsity levels, and use multiple random seeds.
The frontier is moving toward structured sparsity. While unstructured pruning achieves higher accuracy at a given sparsity level, structured pruning (via Group LASSO, block sparsity, or N:M patterns) is increasingly favored because it translates directly to hardware speedups.
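As a concrete instance of the N:M patterns mentioned above, the sketch below builds a 2:4 mask by keeping the two largest-magnitude weights in every group of four consecutive weights. The `n_m_mask` helper is illustrative, and selecting by magnitude is only one possible criterion; any of the saliency scores from the taxonomy could be substituted.

```python
def n_m_mask(row, n=2, m=4):
    """Keep the n largest-magnitude weights in every group of m consecutive weights."""
    if len(row) % m:
        raise ValueError("row length must be a multiple of m")
    mask = []
    for start in range(0, len(row), m):
        group = row[start:start + m]
        # Positions (within the group) of the n largest magnitudes
        keep = set(sorted(range(m), key=lambda i: abs(group[i]), reverse=True)[:n])
        mask.extend(1 if i in keep else 0 for i in range(m))
    return mask


row = [0.1, -0.9, 0.4, 0.05, 0.3, 0.2, -0.7, 0.6]
print(n_m_mask(row))  # [0, 1, 1, 0, 0, 0, 1, 1]
```

Every group ends up exactly 50% sparse, which is the regularity that lets hardware skip the zeros without per-weight index bookkeeping.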
Preview: Pruning for Large Language Models#
The methods covered in this post were primarily developed and evaluated on models with up to hundreds of millions of parameters. The next post in this series, Pruning for LLMs, tackles the unique challenges that arise when pruning models with billions to hundreds of billions of parameters:
- SparseGPT: One-shot pruning for GPT-scale models using approximate OBS with lazy Hessian updates
- Wanda: Pruning by weights and activations — a magnitude-like criterion boosted by activation norms
- 2:4 Structured Sparsity: NVIDIA’s hardware-native sparsity pattern and how to achieve it
- Pruning + Quantization: Combining complementary compression techniques
- The “pruning paradox” for LLMs: Why larger models are easier to prune than smaller ones
- Scaling laws for sparse models: How sparsity interacts with model and data scale
These LLM-specific methods build directly on the foundations covered here, adapting the principles of Taylor expansion, Hessian approximation, and structured sparsity to the extreme scale of modern language models.