Neural networks are remarkably over-parameterized. A ResNet-50 contains approximately 25.6 million parameters, yet research has repeatedly shown that a large majority of those weights, in many cases 90% or more, can be removed with only a small loss in accuracy. This observation raises a fundamental question: if most weights are unnecessary, why do we train dense networks at all? Pruning is the systematic study and practice of identifying and removing redundant parameters from trained (or even untrained) neural networks. This post provides a thorough, mathematically grounded treatment of pruning fundamentals: the theoretical motivations, the algorithmic machinery, and the practical considerations that make pruning one of the most important tools in the model compression toolkit.
Overview#
Why Pruning Matters#
Modern deep learning models have grown to extraordinary sizes. GPT-3 contains 175 billion parameters, and the largest vision transformers exceed 600 million. Yet these models carry enormous redundancy: the weights in a trained neural network are not all equally important, and a large fraction of them contribute very little to the final output. Pruning exploits this redundancy by zeroing out (or physically removing) unimportant weights, yielding models that are smaller, potentially faster, and often nearly as accurate.
The practical benefits of pruning are threefold:
- Memory reduction: A sparse model stores fewer nonzero values, reducing the memory footprint.
- Computation reduction: Multiplying by zero is trivial, so sparse models can skip unnecessary multiply-accumulate operations, provided the kernels or hardware can exploit the sparsity pattern.
- Energy efficiency: Fewer operations mean less energy consumption, which is critical for edge deployment and large-scale inference.
Historical Context#
The idea of pruning neural networks is not new. In 1990, Yann LeCun and colleagues published Optimal Brain Damage (OBD), which used second-order information (the diagonal of the Hessian matrix) to decide which weights to remove. Three years later, Hassibi, Stork, and Wolff introduced Optimal Brain Surgeon (OBS), which generalized OBD by using the full inverse Hessian. These methods established the theoretical foundations that modern pruning still builds upon.
The field saw renewed interest around 2015-2019, driven by the deployment of deep learning on mobile devices and the publication of the Lottery Ticket Hypothesis (Frankle & Carbin, 2019), which provided a compelling theoretical narrative for why pruning works so well.
Pruning in the Model Compression Toolkit#
Pruning is one of four major model compression techniques:
| Technique | What It Does | Typical Savings |
|---|---|---|
| Pruning | Removes unnecessary weights (sets to zero or deletes) | 10-100x parameter reduction |
| Quantization | Reduces numerical precision (FP32 to INT8 or lower) | 2-4x memory reduction |
| Knowledge Distillation | Trains a smaller “student” model to mimic a larger “teacher” | Architecture-dependent |
| Neural Architecture Search (NAS) | Searches for efficient architectures automatically | Architecture-dependent |
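These techniques are complementary. A practitioner might first prune a model to 90% sparsity (10x fewer nonzero weights), then quantize the remaining weights from FP32 to INT8 (another 4x), for a combined compression ratio of up to about 40x; storing the positions of the nonzero weights adds overhead that lowers the realized figure.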
The Pruning Pipeline#
The standard pruning workflow follows a three-stage pipeline:
```
+------------------+     +------------------+     +------------------+
| 1. Train Dense   | --> | 2. Prune Weights | --> | 3. Fine-Tune     |
|    Network       |     |    (apply mask M)|     |    (recover acc.)|
+------------------+     +------------------+     +------------------+
                                  ^                        |
                                  |                        |
                                  +-- Iterate (optional) --+
```
Step 1 (Train): Train the full (dense) network to convergence or near-convergence.
Step 2 (Prune): Evaluate each weight according to some importance criterion and zero out the least important ones.
Step 3 (Fine-tune): Retrain the pruned network for a few epochs to recover any lost accuracy.
In iterative pruning, steps 2 and 3 are repeated multiple times, each time removing a small fraction of the remaining weights. This gradual approach typically preserves accuracy far better than one-shot pruning to the same final sparsity.
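The pipeline above can be sketched in a few lines of NumPy. This is a minimal illustration, not a reference implementation: it assumes a toy linear-regression model trained by full-batch gradient descent, plain magnitude pruning, and hypothetical helper names (`train`, `prune_by_magnitude`, `iterative_prune`). The mask is re-applied after every update so that pruned weights stay at zero during fine-tuning.

```python
import numpy as np


def train(w, X, y, mask, steps=200, lr=0.1):
    """Full-batch gradient descent on mean squared error; pruned weights stay at zero."""
    w = w * mask
    for _ in range(steps):
        grad = 2.0 * X.T @ (X @ w - y) / len(y)
        w = (w - lr * grad) * mask  # re-apply the mask after every update
    return w


def prune_by_magnitude(w, mask, frac):
    """Return a new mask with `frac` of the surviving weights (smallest |w|) removed."""
    alive = np.flatnonzero(mask)
    n_prune = int(round(frac * alive.size))
    weakest = alive[np.argsort(np.abs(w[alive]))[:n_prune]]
    new_mask = mask.copy()
    new_mask[weakest] = 0.0
    return new_mask


def iterative_prune(X, y, rounds=3, frac=0.5, seed=0):
    rng = np.random.default_rng(seed)
    mask = np.ones(X.shape[1])
    w = train(rng.normal(0.0, 0.1, X.shape[1]), X, y, mask)  # Step 1: train dense
    for _ in range(rounds):
        mask = prune_by_magnitude(w, mask, frac)  # Step 2: prune
        w = train(w, X, y, mask)                  # Step 3: fine-tune from the trained values
    return w, mask


# Toy problem: 16 features, only the first two matter.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 16))
w_true = np.zeros(16)
w_true[:2] = [2.0, -3.0]
y = X @ w_true
w, mask = iterative_prune(X, y)
print(int(mask.sum()), np.round(w[:2], 3))
```

Each round halves the surviving weights (16, 8, 4, 2), and fine-tuning continues from the trained weights rather than restarting.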
The Theory of Over-Parameterization#
Why Neural Networks Are Over-Parameterized#
A neural network with \(n\) parameters defines a function \(f_\theta: \mathbb{R}^d \to \mathbb{R}^k\) where \(\theta \in \mathbb{R}^n\). For a dataset of \(m\) training examples, classical learning theory suggests that roughly \(n \approx m\) parameters are enough to fit the data, and that many more invite overfitting. In practice, successful deep networks have \(n \gg m\), often by orders of magnitude.
This over-parameterization is not a bug; it is a feature. Over-parameterized networks:
- Converge more easily: The loss landscape becomes smoother with more parameters, making gradient descent more likely to find good minima.
- Generalize better: Counter-intuitively, larger models often generalize better, a phenomenon partially explained by implicit regularization of SGD.
- Contain redundant substructures: Many different subsets of parameters can represent the same function.
The third point is the key insight for pruning. If the same function can be represented by many subsets of the parameters, then we can find a small subset that works well and discard the rest.
Lottery Ticket Hypothesis (Frankle & Carbin, 2019)#
The Lottery Ticket Hypothesis (LTH) provides perhaps the most elegant theoretical framework for understanding why pruning works. It was introduced by Jonathan Frankle and Michael Carbin in their 2019 ICLR paper.
Statement#
Lottery Ticket Hypothesis: A randomly-initialized, dense neural network \(f(x; \theta_0)\) contains a subnetwork \(f(x; m \odot \theta_0)\) that, when trained in isolation from the same initialization \(\theta_0\), can match the test accuracy of the original network after training for at most the same number of iterations.
Here, \(m \in \{0, 1\}^{|\theta|}\) is a binary mask, \(\theta_0\) is the initial set of weights, and \(\odot\) denotes element-wise multiplication. The subnetwork \(f(x; m \odot \theta_0)\) is called a winning ticket.
Formal Definition#
Let us define this precisely. Consider:
- A neural network architecture \(\mathcal{A}\) with parameter space \(\Theta \subseteq \mathbb{R}^n\)
- An initialization distribution \(\mathcal{D}_\theta\) (e.g., Kaiming normal)
- A training algorithm \(\text{Train}(\theta, D, T)\) that trains parameters \(\theta\) on dataset \(D\) for \(T\) iterations
- Initial parameters \(\theta_0 \sim \mathcal{D}_\theta\)
After training: \(\theta_T = \text{Train}(\theta_0, D, T)\) with test accuracy \(a(\theta_T)\).
The LTH claims there exists a mask \(m \in \{0,1\}^n\) with \(\|m\|_0 \ll n\) such that:
$$\theta_T' = \text{Train}(m \odot \theta_0, D, T')$$
achieves \(a(\theta_T') \geq a(\theta_T)\) with \(T' \leq T\), where \(\|m\|_0 / n\) is small (e.g., 10-20% of the original parameters).
Iterative Magnitude Pruning (IMP) Algorithm#
The winning tickets are found via Iterative Magnitude Pruning (IMP):
```
Algorithm: Iterative Magnitude Pruning (IMP)
--------------------------------------------
Input: Network f(x; theta), pruning rate p per round,
       number of rounds R, dataset D

1. Initialize theta_0 randomly; set m_0 = all ones (nothing pruned)
2. For round r = 1, 2, ..., R:
3.     Train the masked network to convergence
       (pruned weights stay at zero during training):
           theta_T = Train(m_{r-1} * theta_0, D, T)
4.     Compute importance scores:
           s_i = |theta_T[i]| for all unmasked weights
5.     Determine threshold tau_r (the p-quantile of the unmasked scores):
           tau_r = Percentile(s, 100 * p)
6.     Update mask:
           m_r[i] = m_{r-1}[i] AND (s_i >= tau_r)
7.     Reset surviving weights to their INITIAL values:
           theta = m_r * theta_0
8. Return final mask m_R and initial weights theta_0
```
Key insight: In step 7, the surviving weights are reset to their original initialization \(\theta_0\), not to their trained values. This is what makes LTH remarkable: the structure of the winning ticket, combined with the specific initial values, is what enables successful training.
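A minimal sketch of IMP under toy assumptions: a linear model trained by gradient descent stands in for the network, and the function names are hypothetical. The detail to notice is the rewind: every round trains `mask * theta_0` from the original initialization rather than continuing from the trained weights.

```python
import numpy as np


def train(theta, mask, X, y, steps=300, lr=0.05):
    """Gradient descent on MSE; masked weights are held at zero throughout training."""
    theta = theta * mask
    for _ in range(steps):
        grad = 2.0 * X.T @ (X @ theta - y) / len(y)
        theta = (theta - lr * grad) * mask
    return theta


def iterative_magnitude_pruning(X, y, p=0.2, rounds=3, seed=0):
    rng = np.random.default_rng(seed)
    theta_0 = rng.normal(0.0, 0.1, size=X.shape[1])  # step 1: random init, kept for rewinding
    mask = np.ones_like(theta_0)                      # m_0: nothing pruned yet
    for _ in range(rounds):
        theta_T = train(theta_0, mask, X, y)          # step 3: train m_{r-1} * theta_0
        alive = np.flatnonzero(mask)
        n_prune = int(round(p * alive.size))
        scores = np.abs(theta_T[alive])               # step 4: magnitude scores
        mask[alive[np.argsort(scores)[:n_prune]]] = 0.0  # steps 5-6: drop the lowest p fraction
        # step 7: rewind. The next round trains mask * theta_0, not theta_T.
    return mask, theta_0


rng = np.random.default_rng(1)
X = rng.normal(size=(200, 10))
y = X @ np.array([3.0, -2.0, 1.5, 0, 0, 0, 0, 0, 0, 0])
mask, theta_0 = iterative_magnitude_pruning(X, y)
winning_ticket = mask * theta_0  # sparse network at its ORIGINAL initial values
print(mask, winning_ticket)
```

With 10 weights and \(p = 0.2\), three rounds leave 5 weights, and the three weights that carry the target function survive.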
Numerical Example: Suppose we have a tiny network with 10 weights, and we prune 20% per round for 3 rounds.
Round 0: 10 weights active (100%)
Round 1: Remove bottom 20% -> 8 weights active (80%)
Round 2: Remove bottom 20% of remaining -> ~6 weights active (64%)
Round 3: Remove bottom 20% of remaining -> ~5 weights active (51.2%)
Final sparsity: 1 - 0.8^3 = 1 - 0.512 = 48.8%
In general, after \(R\) rounds of pruning fraction \(p\), the sparsity is:
$$s_R = 1 - (1-p)^R$$
For \(p = 0.2\) and \(R = 10\): \(s_{10} = 1 - 0.8^{10} = 1 - 0.107 = 89.3\%\) sparsity.
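The round-by-round arithmetic is easy to check in code. This sketch assumes that, when counting whole weights, each round removes `round(p * survivors)` of them; the closed form \(1 - (1-p)^R\) ignores that rounding. The helper `rounds_needed` (a hypothetical name) inverts the formula to find how many rounds reach a target sparsity.

```python
import math


def sparsity_after(p, rounds):
    """Fraction of weights removed after `rounds` rounds that each prune a fraction p of survivors."""
    return 1.0 - (1.0 - p) ** rounds


def rounds_needed(target_sparsity, p):
    """Smallest R with 1 - (1 - p)^R >= target_sparsity."""
    return math.ceil(math.log(1.0 - target_sparsity) / math.log(1.0 - p))


def surviving_counts(n, p, rounds):
    """Integer survivor counts when each round prunes round(p * survivors) weights."""
    counts = [n]
    for _ in range(rounds):
        counts.append(counts[-1] - round(p * counts[-1]))
    return counts


print(surviving_counts(10, 0.2, 3))       # [10, 8, 6, 5]
print(round(sparsity_after(0.2, 10), 3))  # 0.893
print(rounds_needed(0.9, 0.2))            # 11
```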
Rewinding Variants#
The original LTH resets weights to \(\theta_0\) (initialization). Later work introduced k-epoch rewinding:
- Full rewind (\(k = 0\)): Reset to \(\theta_0\). Works well on small networks (MNIST, small CIFAR-10 models).
- Early rewind (\(k > 0\)): Reset to \(\theta_k\), the weights after \(k\) epochs of training. Frankle et al. (2020) showed that rewinding to epoch \(k\) (where \(k\) is small, e.g., 1-5% of total training) is necessary for larger networks and datasets.
The rewinding point \(k\) marks roughly when the network has “found its trajectory” in the loss landscape. Before that point, the outcome of training is unstable: it depends strongly on SGD noise such as the order of the data. After it, the network has settled into a basin of attraction.
Evidence and Experiments#
Frankle and Carbin demonstrated the LTH on several architectures:
| Network | Dataset | Sparsity at Matching Accuracy | Parameters Remaining |
|---|---|---|---|
| LeNet-300-100 | MNIST | 96.4% | 3.6% |
| Conv-2/4/6 | CIFAR-10 | 88.2-95.0% | 5.0-11.8% |
| ResNet-18 | CIFAR-10 | ~90% (with rewinding) | ~10% |
| VGG-19 | CIFAR-10 | ~93.5% | ~6.5% |
The winning tickets not only matched the original accuracy but often reached it faster (in fewer training iterations) and sometimes even exceeded it. Random subnetworks of the same size, by contrast, performed significantly worse, as did the winning-ticket structure when randomly reinitialized, indicating that both the specific structure and the specific initial values of the winning ticket matter.
Limitations at Scale#
The original LTH faces challenges at scale:
- Computational cost: IMP requires training the full network \(R\) times, making it very expensive for large models.
- Rewinding necessity: For ImageNet-scale models, rewind to initialization (\(k=0\)) fails. Rewinding to \(k > 0\) is required, weakening the original claim.
- Task specificity: Winning tickets found for one task do not necessarily transfer to other tasks, though some universality has been observed.
Linear Mode Connectivity#
Linear Mode Connectivity (LMC) provides additional insight into why rewinding works. Two models \(\theta_A\) and \(\theta_B\) are linearly mode connected if every point on the line segment between them has low loss:
$$L(\alpha \theta_A + (1-\alpha) \theta_B) \leq \max(L(\theta_A), L(\theta_B)) \quad \forall \alpha \in [0, 1]$$
Frankle et al. (2020) showed that networks trained from the same rewind point \(\theta_k\) but with different data orders are linearly mode connected (in practice, up to a small tolerance), while networks trained from \(\theta_0\) are often not. This suggests that by epoch \(k\), the network has committed to a particular “basin” of the loss landscape, and the specific initialization within that basin (i.e., \(\theta_k\)) is what matters.
This explains why \(k\)-epoch rewinding works: \(\theta_k\) lies in the right basin, and the mask found by IMP identifies which parameters are important within that basin.
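Linear mode connectivity can be checked numerically by evaluating the loss along the segment. The sketch below is illustrative only: it applies the loss-based criterion above to two hand-written toy loss functions (a convex bowl and a double well) that stand in for network losses, and the name `interpolation_barrier` is hypothetical. In practice the check is usually made on test error with a small tolerance.

```python
import numpy as np


def interpolation_barrier(loss_fn, theta_a, theta_b, num=21):
    """Max loss on the segment between theta_a and theta_b, minus the larger endpoint loss.

    A value <= 0 means the pair satisfies the linear-mode-connectivity condition
    (at the sampled points); a positive value is the height of the loss barrier.
    """
    alphas = np.linspace(0.0, 1.0, num)
    losses = np.array([loss_fn(a * theta_a + (1.0 - a) * theta_b) for a in alphas])
    return losses.max() - max(losses[0], losses[-1])


def bowl(t):
    """Convex loss: every pair of points is linearly connected."""
    return float(np.sum((t - 3.0) ** 2))


def double_well(t):
    """Minima at -1 and +1 separated by a barrier of height 1 at 0."""
    return float(np.sum((t ** 2 - 1.0) ** 2))


print(interpolation_barrier(bowl, np.array([0.0, 0.0]), np.array([4.0, 1.0])))
print(interpolation_barrier(double_well, np.array([1.0]), np.array([-1.0])))
```

The bowl yields a barrier of zero, while the two minima of the double well sit in different basins and are separated by a barrier of height 1.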
Strong Lottery Ticket Hypothesis#
The Strong Lottery Ticket Hypothesis makes an even bolder claim:
A sufficiently over-parameterized random network contains a subnetwork that, without any training, achieves accuracy comparable to a trained network.
This means the winning ticket does not even need to be trained: it exists “at birth.” The key idea is that a sufficiently large random network contains, with high probability, a subnetwork that closely approximates any given smaller target network. One can think of this as the neural network equivalent of the infinite monkey theorem.
Edge-Popup Algorithm#
Ramanujan et al. (2020) proposed the Edge-Popup algorithm to find these subnetworks:
```
Algorithm: Edge-Popup
---------------------
Input: Random (fixed) weights theta, target sparsity s

1. Initialize popup scores S_i ~ N(0, sigma^2) for each weight
2. For each training step:
3.     Compute mask: m_i = 1 if S_i is in the top (1-s) fraction, else 0
4.     Forward pass: y = f(x; m * theta)    [theta is FIXED]
5.     Backward pass: compute dL/dS_i (straight-through estimator)
6.     Update scores: S_i <- S_i - eta * dL/dS_i
7. Return mask m (weights theta are NEVER updated)
```
The top-\(k\) thresholding step has zero gradient almost everywhere, so Edge-Popup passes gradients to the scores with a straight-through estimator: the thresholding is treated as the identity in the backward pass. This makes the scores \(S_i\) trainable by gradient descent. The weights \(\theta\) themselves are never modified; only the selection of which weights to include is learned.
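Here is a toy sketch of the Edge-Popup idea in NumPy. It is not the authors' implementation: it assumes a linear model whose \(n\) fixed random weights each act on one input feature, targets generated by a hidden subset of those same weights (so a perfect subnetwork exists by construction), and plain gradient descent on the scores with the straight-through estimator.

```python
import numpy as np


def edge_popup_demo(n=20, k=5, steps=500, lr=0.1, seed=0):
    """Toy Edge-Popup: learn scores that select k of n FIXED random weights."""
    rng = np.random.default_rng(seed)
    theta = rng.normal(size=n)                 # fixed random weights, never trained
    theta_start = theta.copy()
    X = rng.normal(size=(500, n))
    hidden = rng.choice(n, size=k, replace=False)
    y = X[:, hidden] @ theta[hidden]           # targets produced by a hidden subnetwork
    scores = rng.normal(0.0, 0.01, size=n)     # popup scores S_i

    def top_k_mask(s):
        m = np.zeros(n)
        m[np.argsort(s)[-k:]] = 1.0            # keep the k highest scores
        return m

    def loss(m):
        return float(np.mean((X @ (m * theta) - y) ** 2))

    initial_loss = loss(top_k_mask(scores))
    for _ in range(steps):
        m = top_k_mask(scores)
        grad_w = 2.0 * X.T @ (X @ (m * theta) - y) / len(y)   # dL/d(m_i * theta_i)
        # Straight-through estimator: treat top-k as identity in the backward pass,
        # so dL/dS_i = dL/d(m_i * theta_i) * theta_i for kept AND dropped weights.
        scores -= lr * grad_w * theta
    final_mask = top_k_mask(scores)
    return initial_loss, loss(final_mask), final_mask, np.array_equal(theta, theta_start)


initial_loss, final_loss, mask, unchanged = edge_popup_demo()
print(f"loss {initial_loss:.3f} -> {final_loss:.5f}, weights unchanged: {unchanged}")
```

Because the weights never change, any drop in loss comes entirely from choosing a better subset.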
Proof Sketch of Existence#
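The existence of good subnetworks in random networks can be argued probabilistically. Consider a target network with weights \(w_1^*, w_2^*, \ldots, w_k^*\) and a random network whose weights are drawn i.i.d. from \(\mathcal{N}(0, \sigma^2)\). To keep the sketch simple, assume each target weight \(w_j^*\) gets its own pool of \(n_0\) random candidate weights, so the random network has \(n = k n_0 \gg k\) weights in total.
For small \(\epsilon\), the probability that a single candidate \(w_i\) falls within \(\epsilon\) of \(w_j^*\) is approximately the width of the interval times the Gaussian density at \(w_j^*\):
$$p_j = P(|w_i - w_j^*| < \epsilon) \approx \frac{2\epsilon}{\sigma\sqrt{2\pi}} e^{-\frac{(w_j^*)^2}{2\sigma^2}}$$
so the probability that at least one of the \(n_0\) candidates in the pool matches is:
$$P(\exists i : |w_i - w_j^*| < \epsilon) = 1 - (1 - p_j)^{n_0}$$
For large \(n_0\), this probability approaches 1. Let \(p_{\min} = \min_j p_j\). By a union bound over all \(k\) target weights, the probability that the random network contains an \(\epsilon\)-approximate copy of the target network is at least:
$$P(\text{match all}) \geq 1 - k (1 - p_{\min})^{n_0} \geq 1 - k\, e^{-n_0 p_{\min}}$$
For target weights of bounded magnitude, \(p_{\min}\) grows linearly with \(\epsilon\), so \(n_0 = O(\log(k/\delta)/\epsilon)\) candidates per weight, i.e. \(n = O(k \log(k/\delta)/\epsilon)\) weights in total, push the failure probability below \(\delta\). This per-weight argument ignores how the matched weights must fit together inside a layered architecture; more rigorous treatments (e.g., Malach et al., 2020; Pensia et al., 2020) formalize the result for networks with multiple layers.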
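The single-pool probabilities in this sketch can be checked by simulation. The snippet assumes one target weight and one pool of \(n_0\) Gaussian candidates, with illustrative values \(w^* = 0.3\), \(\epsilon = 0.01\), \(\sigma = 0.5\), \(n_0 = 50\). It compares the exact interval probability with the density approximation, and the closed form \(1 - (1-p)^{n_0}\) with a Monte Carlo estimate; the function names are hypothetical.

```python
import math

import numpy as np


def single_match_prob(w_star, eps, sigma):
    """Exact P(|w - w_star| < eps) for one weight w ~ N(0, sigma^2)."""
    def cdf(z):
        return 0.5 * (1.0 + math.erf(z / (sigma * math.sqrt(2.0))))
    return cdf(w_star + eps) - cdf(w_star - eps)


def density_approx(w_star, eps, sigma):
    """Small-eps approximation: interval width times the Gaussian density at w_star."""
    return 2.0 * eps / (sigma * math.sqrt(2.0 * math.pi)) * math.exp(-w_star ** 2 / (2.0 * sigma ** 2))


def pool_match_prob_mc(w_star, eps, sigma, n0, trials=20_000, seed=0):
    """Monte Carlo estimate of P(at least one of n0 random weights is within eps of w_star)."""
    rng = np.random.default_rng(seed)
    pools = rng.normal(0.0, sigma, size=(trials, n0))
    return float(np.mean(np.any(np.abs(pools - w_star) < eps, axis=1)))


w_star, eps, sigma, n0 = 0.3, 0.01, 0.5, 50
p = single_match_prob(w_star, eps, sigma)
print(p, density_approx(w_star, eps, sigma))  # nearly identical for small eps
print(1.0 - (1.0 - p) ** n0, pool_match_prob_mc(w_star, eps, sigma, n0))
```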
Weight Magnitude Pruning#
Weight magnitude pruning is the simplest and most widely used pruning criterion. The fundamental assumption is straightforward: small weights contribute little to the network’s output, so they can be safely removed.
L1-Norm Pruning#
The L1-norm criterion assigns importance based on absolute value:
$$\text{score}(w_i) = |w_i|$$
Weights with the smallest absolute values are pruned first. The rationale is that a weight close to zero has minimal effect on the neuron’s output: if \(w_i \approx 0\), then the contribution \(w_i \cdot x_i\) stays small unless the input \(x_i\) is unusually large.
L2-Norm Pruning#
The L2-norm criterion squares the weights:
$$\text{score}(w_i) = w_i^2$$
For individual weights this is equivalent to the L1 criterion for pruning purposes (the ordering is identical since \(|a| > |b| \iff a^2 > b^2\) for real numbers). The two differ in structured pruning of filters, where the L1 and L2 norms of weight vectors can rank filters differently.
For a filter \(F_j\) with weights \(\{w_1, w_2, \ldots, w_k\}\):
$$\text{L1-score}(F_j) = \sum_{i=1}^{k} |w_i|, \quad \text{L2-score}(F_j) = \sqrt{\sum_{i=1}^{k} w_i^2}$$
These can produce different rankings. A filter with many small weights might score higher under L1, but lower under L2, than a filter with one large weight and many zeros.
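The ranking flip is easy to reproduce. In this small, illustrative example, a filter of four weights equal to 0.3 beats a filter with a single weight of 1.0 under the L1 norm (1.2 vs 1.0) but loses under the L2 norm (0.6 vs 1.0); `filter_scores` is a hypothetical helper.

```python
import numpy as np


def filter_scores(weights):
    """weights: array of shape (num_filters, k); returns (L1, L2) norm per filter."""
    l1 = np.abs(weights).sum(axis=1)
    l2 = np.sqrt((weights ** 2).sum(axis=1))
    return l1, l2


filters = np.array([
    [0.3, 0.3, 0.3, 0.3],   # many small weights
    [1.0, 0.0, 0.0, 0.0],   # one large weight, rest zero
])
l1, l2 = filter_scores(filters)
print("L1:", l1, "-> most important filter", int(np.argmax(l1)))
print("L2:", l2, "-> most important filter", int(np.argmax(l2)))
```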
Global vs Local Pruning#
There are two strategies for deciding which weights to prune:
Local pruning prunes \(p\%\) of weights independently in each layer:
$$\tau_l = \text{Percentile}_p(\{|w_i| : w_i \in W_l\})$$
Weight \(w_i\) in layer \(l\) is pruned if \(|w_i| < \tau_l\).
Global pruning uses a single threshold across all layers:
$$\tau = \text{Percentile}_p(\{|w_i| : w_i \in W_1 \cup W_2 \cup \cdots \cup W_L\})$$
Weight \(w_i\) in any layer is pruned if \(|w_i| < \tau\).
Why Global Is Generally Better#
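Different layers have different weight distributions and different sensitivities to pruning. Layers with many parameters (large fan-in) are typically initialized with, and tend to keep, smaller weights, and they are often the most redundant. Small layers such as the first convolution tend to have larger weights and are more sensitive, since every later layer depends on the low-level features they extract. Global pruning adapts to this automatically: because it compares magnitudes across layers, it removes more weights from large, redundant layers and fewer from small, sensitive ones. Magnitude is only a proxy for sensitivity, however, and a layer whose weights happen to be uniformly small can be pruned far more than it should be.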
Numerical Example: Consider a 2-layer network.
Layer 1 weights: [0.01, 0.05, 0.10, 0.20, 0.50]
Layer 2 weights: [0.30, 0.40, 0.60, 0.80, 1.00]
Target: prune 40% (remove 4 out of 10 weights)
Local pruning (40% per layer):
- Layer 1: remove 2 smallest -> prune 0.01, 0.05 -> keep [0.10, 0.20, 0.50]
- Layer 2: remove 2 smallest -> prune 0.30, 0.40 -> keep [0.60, 0.80, 1.00]
Global pruning (40% overall):
- All weights sorted: [0.01, 0.05, 0.10, 0.20, 0.30, 0.40, 0.50, 0.60, 0.80, 1.00]
- Global threshold: the 4 smallest magnitudes are 0.01, 0.05, 0.10, and 0.20, all in Layer 1, so \(\tau = 0.20\); to remove exactly 4 weights, the weight sitting exactly at the threshold is pruned as well (\(|w_i| \leq \tau\))
- Layer 1: keep [0.50] (4 pruned from Layer 1)
- Layer 2: keep [0.30, 0.40, 0.60, 0.80, 1.00] (0 pruned from Layer 2)
Global pruning aggressively prunes Layer 1 (which has smaller magnitudes) and preserves Layer 2 entirely. Whether this is better depends on the network, but in many reported comparisons global pruning outperforms local pruning, because it allocates sparsity non-uniformly across layers, using weight magnitude as a proxy for layer sensitivity.
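Both strategies can be sketched as follows (hypothetical helper names, NumPy only). One assumption differs slightly from the threshold formulas above: the sketch ranks weights and removes exactly the requested number, so a tie at the threshold, like the 0.20 in the example, cannot change the count.

```python
import numpy as np


def _mask_smallest(values, n_prune):
    """1/0 mask over a flat array that zeroes the n_prune smallest entries."""
    mask = np.ones(values.size)
    mask[np.argsort(values, kind="stable")[:n_prune]] = 0.0
    return mask


def local_masks(layers, frac):
    """Prune the same fraction of weights independently in every layer."""
    masks = []
    for w in layers:
        n_prune = int(round(frac * w.size))
        masks.append(_mask_smallest(np.abs(w).ravel(), n_prune).reshape(w.shape))
    return masks


def global_masks(layers, frac):
    """Prune the smallest weights across all layers using one shared ranking."""
    flat = np.concatenate([np.abs(w).ravel() for w in layers])
    flat_mask = _mask_smallest(flat, int(round(frac * flat.size)))
    masks, start = [], 0
    for w in layers:
        masks.append(flat_mask[start:start + w.size].reshape(w.shape))
        start += w.size
    return masks


layer1 = np.array([0.01, 0.05, 0.10, 0.20, 0.50])
layer2 = np.array([0.30, 0.40, 0.60, 0.80, 1.00])
print(local_masks([layer1, layer2], 0.4))   # 2 pruned from each layer
print(global_masks([layer1, layer2], 0.4))  # all 4 pruned from layer 1
```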
Sparsity Ratio#
The sparsity ratio quantifies the fraction of parameters that have been pruned:
$$s = 1 - \frac{n_{\text{nonzero}}}{n_{\text{total}}}$$
A sparsity of 0.90 (or 90%) means 90% of weights are zero and only 10% remain. The compression ratio is:
$$\text{CR} = \frac{1}{1 - s} = \frac{n_{\text{total}}}{n_{\text{nonzero}}}$$
At 90% sparsity, \(\text{CR} = 10\times\).
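The sparsity and compression-ratio formulas translate directly to code. The `ideal_combined_ratio` helper is a hypothetical addition that multiplies the pruning ratio by the FP32-to-INT8 bit-width ratio and ignores the cost of storing sparse indices, so it is an upper bound rather than a measured model size.

```python
import numpy as np


def sparsity(w):
    """s = 1 - (nonzero count / total count)."""
    return 1.0 - np.count_nonzero(w) / w.size


def compression_ratio(s):
    """CR = 1 / (1 - s) = total / nonzero."""
    return 1.0 / (1.0 - s)


def ideal_combined_ratio(s, bits_before=32, bits_after=8):
    """Pruning ratio times quantization ratio, ignoring sparse-index overhead."""
    return compression_ratio(s) * bits_before / bits_after


w = np.array([[0.52, 0.0, 0.81], [0.0, 0.95, 0.0], [0.0, -0.68, 0.0]])
print(sparsity(w), compression_ratio(sparsity(w)))  # 4 of 9 weights remain
print(compression_ratio(0.9), ideal_combined_ratio(0.9))
```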
The Accuracy-Sparsity Curve#
A typical accuracy-vs-sparsity curve has a characteristic shape: accuracy is nearly flat up to high sparsity, then drops sharply.
```
Accuracy
  |
1 |* * * * * * * * * * * * * * * *
  |                                *
  |                                  *
  |                                    *
  |                                      *
  |                                        *
  +------------------------------------------> Sparsity
   0%    20%    40%    60%    80%  90% 95% 99%
```
The “knee” of the curve, where accuracy begins to drop significantly, varies by network and dataset. For many networks on standard benchmarks, this knee occurs between 80% and 95% sparsity, meaning the network can tolerate removing the vast majority of its weights with minimal performance degradation.
Step-by-Step Pruning Example#
Consider the following \(3 \times 3\) weight matrix:
$$W = \begin{bmatrix} 0.52 & -0.03 & 0.81 \\ -0.17 & 0.95 & 0.04 \\ 0.11 & -0.68 & -0.02 \end{bmatrix}$$
Step 1: Compute importance scores (L1-norm: absolute value):
$$|W| = \begin{bmatrix} 0.52 & 0.03 & 0.81 \\ 0.17 & 0.95 & 0.04 \\ 0.11 & 0.68 & 0.02 \end{bmatrix}$$
Step 2: Flatten and sort: \([0.02, 0.03, 0.04, 0.11, 0.17, 0.52, 0.68, 0.81, 0.95]\)
Step 3: Choose sparsity \(s = 55.6\%\) (prune 5 of 9 weights). Threshold: the 5th smallest value is 0.17, so \(\tau = 0.17\). Prune all weights with \(|w_i| \leq \tau\).
Step 4: Construct the binary mask:
$$M = \begin{bmatrix} 1 & 0 & 1 \\ 0 & 1 & 0 \\ 0 & 1 & 0 \end{bmatrix}$$
Step 5: Apply mask:
$$W_{\text{pruned}} = W \odot M = \begin{bmatrix} 0.52 & 0 & 0.81 \\ 0 & 0.95 & 0 \\ 0 & -0.68 & 0 \end{bmatrix}$$
Result: 4 nonzero weights remain out of 9, giving sparsity \(s = 1 - 4/9 = 55.6\%\).
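The five steps above fit in one small NumPy function. This sketch (the name `magnitude_prune` is hypothetical) prunes a fixed count of weights by rank rather than comparing against \(\tau\), which gives the same mask here because there are no ties.

```python
import numpy as np


def magnitude_prune(W, n_prune):
    """Zero out the n_prune entries of W with the smallest absolute value."""
    order = np.argsort(np.abs(W).ravel(), kind="stable")  # Steps 1-2: score and sort
    mask = np.ones(W.size)
    mask[order[:n_prune]] = 0.0                           # Steps 3-4: build the mask
    mask = mask.reshape(W.shape)
    return W * mask, mask                                 # Step 5: apply it


W = np.array([[0.52, -0.03, 0.81],
              [-0.17, 0.95, 0.04],
              [0.11, -0.68, -0.02]])
W_pruned, M = magnitude_prune(W, n_prune=5)
print(M)
print(W_pruned)
```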
Sensitivity-Based Pruning#
Weight magnitude pruning ignores a crucial factor: the curvature of the loss surface. A small weight might sit in a region of high curvature, meaning removing it causes a large increase in loss. A large weight might sit in a flat region, meaning its removal barely matters. Sensitivity-based methods use second-order information to account for this.
Optimal Brain Damage (OBD, LeCun 1990)#
Key Idea#
Optimal Brain Damage uses the Hessian matrix — the matrix of second partial derivatives of the loss — to estimate how much the loss will change when a weight is removed.
Derivation#
Consider the loss function \(L(\theta)\) where \(\theta \in \mathbb{R}^n\) is the vector of all weights. We want to estimate \(\delta L = L(\theta + \delta\theta) - L(\theta)\) when we set some weight \(w_q\) to zero (i.e., \(\delta w_q = -w_q\)).
Taylor expansion of the loss around the current weights:
$$L(\theta + \delta\theta) = L(\theta) + \sum_i \frac{\partial L}{\partial w_i} \delta w_i + \frac{1}{2} \sum_i \sum_j \frac{\partial^2 L}{\partial w_i \partial w_j} \delta w_i \delta w_j + O(\|\delta\theta\|^3)$$
In compact notation:
$$\delta L = g^T \delta\theta + \frac{1}{2} \delta\theta^T H \delta\theta + O(\|\delta\theta\|^3)$$
where \(g = \nabla_\theta L\) is the gradient and \(H = \nabla^2_\theta L\) is the Hessian.
Assumption 1 (Convergence): The network is trained to a local minimum, so the gradient is approximately zero:
$$g \approx 0 \implies g^T \delta\theta \approx 0$$
Assumption 2 (Diagonal Hessian): The off-diagonal elements of the Hessian are negligible:
$$H_{ij} \approx 0 \quad \text{for } i \neq j$$
Under these two assumptions, the loss change simplifies dramatically:
$$\delta L \approx \frac{1}{2} \sum_i H_{ii} (\delta w_i)^2 = \frac{1}{2} \sum_i h_{ii} (\delta w_i)^2$$
where \(h_{ii} = \frac{\partial^2 L}{\partial w_i^2}\) is the \(i\)-th diagonal element of the Hessian.
When we prune weight \(w_q\), we set it to zero: \(\delta w_q = -w_q\) and \(\delta w_i = 0\) for \(i \neq q\). Therefore:
$$\delta L_q \approx \frac{1}{2} h_{qq} w_q^2$$
This is the OBD saliency score:
$$\boxed{s_q^{\text{OBD}} = \frac{1}{2} h_{qq} w_q^2}$$
Weights with the smallest saliency are pruned first, as they cause the least increase in loss.
Interpreting the Saliency Score#
The OBD saliency score \(s_q = \frac{1}{2} h_{qq} w_q^2\) is the product of two factors:
- \(w_q^2\): the magnitude of the weight (same as magnitude pruning).
- \(h_{qq}\): the curvature of the loss with respect to that weight.
A weight is deemed unimportant if this product is small: the weight itself is small (\(w_q^2\) is small), the loss landscape is flat in that direction (\(h_{qq}\) is small), or both. This uses more information than magnitude pruning alone.
Numerical Example: Consider three weights with their Hessian diagonals:
| Weight | \(w_q\) | \(h_{qq}\) | \(s_q = \frac{1}{2}h_{qq}w_q^2\) | Magnitude rank (1 = most important) | OBD rank (1 = most important) |
|---|---|---|---|---|---|
| A | 0.10 | 100.0 | 0.500 | 3 (prune first) | 2 |
| B | 0.50 | 0.1 | 0.013 | 1 (keep) | 3 (prune first) |
| C | 0.30 | 20.0 | 0.900 | 2 | 1 (keep) |
Magnitude pruning would remove weight A first (smallest magnitude). But OBD recognizes that A sits in a high-curvature region (\(h_{qq} = 100\)) and removing it would cause a large loss increase. Instead, OBD removes weight B first — despite its large magnitude, the flat curvature (\(h_{qq} = 0.1\)) means its removal is nearly harmless.
Computing the Diagonal Hessian#
The diagonal Hessian entries can be estimated with a backpropagation-like pass; the original OBD paper computed them at roughly the cost of a gradient computation by using an additional approximation. For a loss function \(L\), the diagonal entry is:
$$h_{ii} = \frac{\partial^2 L}{\partial w_i^2}$$
This can be estimated empirically by averaging over a batch of training examples:
$$h_{ii} \approx \frac{1}{|B|} \sum_{(x,y) \in B} \frac{\partial^2 L(x, y; \theta)}{\partial w_i^2}$$
Alternatively, one can use the empirical Fisher approximation (closely related to the Gauss-Newton approximation for log-likelihood losses), which only requires first-order derivatives, namely the squared per-example gradients:
$$h_{ii} \approx \frac{1}{|B|} \sum_{(x,y) \in B} \left(\frac{\partial L(x, y; \theta)}{\partial w_i}\right)^2$$
This approximation is the basis of the Fisher information approach discussed later.
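A quick numerical check of the OBD saliency, assuming a linear least-squares model as a stand-in for a network because its Hessian is known exactly (\(H = \frac{2}{N}X^TX\)) and the fitted weights have zero gradient. When one weight is removed at a time, the quadratic expansion is exact for this loss, so the saliency should match the actual loss increase; the diagonal assumption only starts to matter when several weights are removed together and the cross terms \(H_{ij}\,\delta w_i\,\delta w_j\) appear.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.normal(size=(500, 4))
y = X @ np.array([1.5, -0.2, 0.8, 0.05]) + 0.1 * rng.normal(size=500)


def loss(w):
    """Mean squared error; its Hessian is H = (2/N) X^T X, independent of w."""
    return float(np.mean((X @ w - y) ** 2))


w_star = np.linalg.lstsq(X, y, rcond=None)[0]  # "trained" weights: gradient is zero here
h_diag = 2.0 * np.mean(X ** 2, axis=0)         # diagonal of H
saliency = 0.5 * h_diag * w_star ** 2          # OBD: s_q = 1/2 h_qq w_q^2

# Check: actual loss increase from zeroing each weight on its own.
actual = np.array([loss(np.where(np.arange(4) == q, 0.0, w_star)) - loss(w_star)
                   for q in range(4)])
print(np.round(saliency, 5), np.round(actual, 5))
print("prune first:", int(np.argmin(saliency)))
```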
OBD Algorithm#
```
Algorithm: Optimal Brain Damage
-------------------------------
Input: Trained network with weights theta, dataset D,
       number of weights to prune K

1. Compute diagonal Hessian h_ii for all weights:
       h_ii = (1/|D|) * sum over (x,y) in D of d^2L/dw_i^2
2. Compute saliency for each weight:
       s_i = 0.5 * h_ii * w_i^2
3. Sort weights by saliency in ascending order
4. Prune the K weights with smallest saliency (set to zero)
5. Fine-tune the remaining weights
6. Optionally repeat from step 1
```
Optimal Brain Surgeon (OBS, Hassibi et al. 1993)#
Removing the Diagonal Assumption#
OBD assumes the Hessian is diagonal, which is often a poor approximation. In practice, weights interact with each other, and the off-diagonal terms of the Hessian capture these interactions. Optimal Brain Surgeon (OBS) removes this assumption and uses the full inverse Hessian.
Derivation Using Lagrange Multipliers#
We want to find the weight change \(\delta\theta\) that minimizes the loss increase when weight \(w_q\) is set to zero. This is a constrained optimization problem:
Objective: Minimize \(\delta L = \frac{1}{2} \delta\theta^T H \delta\theta\) (assuming convergence, so \(g \approx 0\))
Constraint: \(e_q^T (\theta + \delta\theta) = 0\), i.e., the \(q\)-th weight becomes zero.
This constraint can be rewritten as:
$$e_q^T \delta\theta + w_q = 0$$
where \(e_q\) is the \(q\)-th standard basis vector.
Setting up the Lagrangian:
$$\mathcal{L}(\delta\theta, \lambda) = \frac{1}{2} \delta\theta^T H \delta\theta + \lambda(e_q^T \delta\theta + w_q)$$
Taking the derivative with respect to \(\delta\theta\) and setting it to zero:
$$\frac{\partial \mathcal{L}}{\partial \delta\theta} = H \delta\theta + \lambda e_q = 0$$
$$\delta\theta = -\lambda H^{-1} e_q$$
Substituting back into the constraint:
$$e_q^T(-\lambda H^{-1} e_q) + w_q = 0$$
$$-\lambda [H^{-1}]_{qq} + w_q = 0$$
$$\lambda = \frac{w_q}{[H^{-1}]_{qq}}$$
Therefore, the optimal weight update when pruning weight \(q\) is:
$$\boxed{\delta\theta = -\frac{w_q}{[H^{-1}]_{qq}} H^{-1} e_q}$$
This is remarkable: when we remove weight \(q\), OBS tells us to also adjust all other weights to optimally compensate. The adjustment is proportional to the \(q\)-th column of \(H^{-1}\).
The resulting increase in loss is:
$$\delta L = \frac{1}{2} \delta\theta^T H \delta\theta = \frac{1}{2} \frac{w_q^2}{[H^{-1}]_{qq}^2} (H^{-1} e_q)^T H (H^{-1} e_q)$$
$$= \frac{1}{2} \frac{w_q^2}{[H^{-1}]_{qq}^2} e_q^T H^{-1} e_q = \frac{1}{2} \frac{w_q^2}{[H^{-1}]_{qq}^2} [H^{-1}]_{qq}$$
$$\boxed{L_q^{\text{OBS}} = \frac{w_q^2}{2[H^{-1}]_{qq}}}$$
Comparison with OBD#
| Aspect | OBD | OBS |
|---|---|---|
| Hessian assumption | Diagonal | Full |
| Weight update | Only pruned weight set to zero | All weights adjusted optimally |
| Saliency | \(\frac{1}{2} h_{qq} w_q^2\) | \(\frac{w_q^2}{2[H^{-1}]_{qq}}\) |
| Computational cost | \(O(n)\) | \(O(n^2)\) to \(O(n^3)\) |
| Accuracy after pruning | Good | Better (due to optimal compensation) |
Note the crucial difference: OBD uses the Hessian diagonal \(h_{qq}\) directly, while OBS uses the inverse Hessian diagonal \([H^{-1}]_{qq}\). These are generally different quantities. If the Hessian were truly diagonal, \([H^{-1}]_{qq} = 1/h_{qq}\), and the OBS saliency would reduce to \(\frac{1}{2} h_{qq} w_q^2\), recovering OBD. When off-diagonal terms are significant, OBS provides a better estimate: for a positive definite \(H\), \([H^{-1}]_{qq} \geq 1/h_{qq}\), so the OBS saliency never exceeds the OBD saliency, reflecting the loss that the compensating update recovers.
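The OBS update and saliency can be checked on a toy least-squares problem with deliberately correlated features, so that the Hessian has large off-diagonal entries. This setup is an illustrative assumption, not taken from the OBS paper. Because the loss is exactly quadratic and the gradient is zero, the OBS saliency should equal the loss increase after the compensating update, and it should be no larger than the OBD saliency for the same weight.

```python
import numpy as np

rng = np.random.default_rng(0)
base = rng.normal(size=(500, 1))
X = base + 0.3 * rng.normal(size=(500, 3))         # 3 strongly correlated features
y = X @ np.array([1.0, 0.5, -0.7]) + 0.1 * rng.normal(size=500)


def loss(w):
    return float(np.mean((X @ w - y) ** 2))


w = np.linalg.lstsq(X, y, rcond=None)[0]           # converged weights (gradient = 0)
H = 2.0 * X.T @ X / len(y)                         # exact Hessian of the MSE loss
H_inv = np.linalg.inv(H)

q = int(np.argmin(w ** 2 / np.diag(H_inv)))        # weight with the smallest OBS saliency
obs_saliency = w[q] ** 2 / (2.0 * H_inv[q, q])
obd_saliency = 0.5 * H[q, q] * w[q] ** 2

w_obs = w - (w[q] / H_inv[q, q]) * H_inv[:, q]     # OBS: zero w_q and adjust all others
w_obd = w.copy()
w_obd[q] = 0.0                                     # OBD: zero w_q only

print(q, obs_saliency, loss(w_obs) - loss(w))
print(obd_saliency, loss(w_obd) - loss(w))
```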
Connection to GPTQ#
The OBS framework directly inspired GPTQ (Frantar et al., 2022), a state-of-the-art post-training quantization method for large language models. GPTQ uses the same Lagrangian formulation but applies it to quantization rather than pruning: instead of constraining a weight to be zero, it constrains the weight to the nearest quantization level. The optimal compensation formula is identical in structure, with the quantization error replacing \(w_q\).
Fisher Information Based Pruning#
The Fisher Information Matrix (FIM) provides yet another way to estimate weight importance. It is closely related to the Hessian but can be computed using only first-order derivatives.
Definition#
For a model with parameters \(\theta\) that defines a conditional distribution \(p(y|x, \theta)\):
$$F = \mathbb{E}_{x \sim p(x)} \mathbb{E}_{y \sim p(y|x,\theta)} \left[\nabla_\theta \log p(y|x,\theta) \cdot \nabla_\theta \log p(y|x,\theta)^T\right]$$
Relationship to the Hessian#
For models trained with negative log-likelihood loss \(L = -\log p(y|x,\theta)\), the Fisher information matrix equals the expected Hessian of the loss (under the model’s own distribution):
$$F = \mathbb{E}\left[-\nabla^2_\theta \log p(y|x,\theta)\right] = \mathbb{E}[H]$$
This means the Fisher matrix is a positive semi-definite approximation to the Hessian, and it can be computed using only gradient samples; no second derivatives are needed.
Efficient Computation#
In practice, the full Fisher matrix is too large to store (\(n \times n\) for \(n\) parameters), so we use its diagonal. The expectation over \(y\) is also commonly replaced by the observed training labels, which gives the empirical Fisher:
$$F_{ii} \approx \frac{1}{|B|} \sum_{(x,y) \in B} \left(\frac{\partial L(x,y;\theta)}{\partial w_i}\right)^2$$
This is simply the average squared per-example gradient for each weight, computed over a batch \(B\) of training data.
Fisher Pruning Criterion#
The Fisher-based saliency score for weight \(w_q\) is:
$$s_q^{\text{Fisher}} = \frac{1}{2} F_{qq} w_q^2$$
This has the same form as OBD (\(\frac{1}{2} h_{qq} w_q^2\)) but uses the Fisher diagonal instead of the Hessian diagonal. The advantage is computational: no second derivatives are needed.
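A minimal sketch of the Fisher criterion, assuming the per-example gradients are already available as an array of shape `(batch, n_weights)`; the helper names are hypothetical. For instance, per-example gradients 0.5, -0.3, 0.7, -0.1 for a weight of 0.4 give \(F = 0.21\) and a saliency of 0.0168. The last line shows why the gradients must be squared per example: squaring the mean gradient instead gives 0.04.

```python
import numpy as np


def fisher_diagonal(per_example_grads):
    """Empirical Fisher diagonal: mean of SQUARED per-example gradients, (batch, n) -> (n,)."""
    g = np.asarray(per_example_grads, dtype=float)
    return np.mean(g ** 2, axis=0)


def fisher_saliency(weights, per_example_grads):
    """s_q = 1/2 * F_qq * w_q^2 for every weight."""
    return 0.5 * fisher_diagonal(per_example_grads) * np.asarray(weights, dtype=float) ** 2


grads = np.array([[0.5], [-0.3], [0.7], [-0.1]])  # one weight, four examples
print(fisher_diagonal(grads), fisher_saliency([0.4], grads))
print(np.mean(grads, axis=0) ** 2)  # squared mean gradient: NOT the Fisher diagonal
```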
Numerical Example: Given a batch of 4 training examples, suppose the gradients for weight \(w_3 = 0.4\) are:
$$\frac{\partial L}{\partial w_3} \in \{0.5, -0.3, 0.7, -0.1\}$$
Then:
$$F_{33} = \frac{1}{4}(0.5^2 + 0.3^2 + 0.7^2 + 0.1^2) = \frac{1}{4}(0.25 + 0.09 + 0.49 + 0.01) = \frac{0.84}{4} = 0.21$$
$$s_3^{\text{Fisher}} = \frac{1}{2} \times 0.21 \times 0.4^2 = \frac{1}{2} \times 0.21 \times 0.16 = 0.0168$$
First-Order (Gradient) Pruning Methods#
Second-order methods (OBD, OBS, Fisher) can be expensive, especially for large models. First-order methods use only gradient information and offer a practical middle ground between simple magnitude pruning and expensive Hessian-based approaches.
Taylor Expansion: First-Order Term#
Revisiting the Taylor expansion of the loss, and not assuming the gradient is zero:
$$\delta L \approx \sum_i g_i \delta w_i + \frac{1}{2} \sum_i h_{ii} (\delta w_i)^2$$
When we prune weight \(w_q\) (set \(\delta w_q = -w_q\)), the first-order contribution is:
$$\delta L^{(1)}_q = g_q \cdot (-w_q) = -w_q \cdot \frac{\partial L}{\partial w_q}$$
To make this a non-negative importance score, we take the absolute value:
$$\boxed{s_q^{\text{Taylor-FO}} = \left|w_q \cdot \frac{\partial L}{\partial w_q}\right|}$$
This is the Taylor first-order (Taylor-FO) pruning criterion. It measures importance as the product of weight magnitude and gradient magnitude.
Gradient x Weight Interpretation#
The Taylor-FO score \(|w \cdot g|\) has an intuitive interpretation. Consider the function output \(y = w \cdot x\):
- If \(|w|\) is large but \(|g| = |\partial L / \partial w|\) is small, the weight contributes significantly to the output, yet small changes to it barely affect the loss. The weight may still be important: a small gradient can simply mean it is already near its optimal value, which is the situation second-order methods account for.
- If \(|w|\) is small but \(|g|\) is large, then the weight contributes little now but the loss is very sensitive to it — it is in the process of being optimized and may become important.
- If both \(|w|\) and \(|g|\) are large, the weight is clearly important.
- If both are small, the weight is clearly unimportant.
The product captures both magnitude and sensitivity, providing a richer importance measure than either alone.
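Scoring with Taylor-FO is a single elementwise product. The weights and gradients below are made-up values chosen to contrast with magnitude pruning: the largest weight has a tiny gradient, so Taylor-FO ranks it second-least important, while magnitude pruning would keep it to the end. `taylor_fo_scores` is a hypothetical helper.

```python
import numpy as np


def taylor_fo_scores(weights, grads):
    """Taylor first-order importance |w_q * dL/dw_q| for every weight."""
    return np.abs(np.asarray(weights) * np.asarray(grads))


w = np.array([2.0, 0.05, 0.5, 0.01])
g = np.array([0.001, 3.0, 0.8, 0.002])
scores = taylor_fo_scores(w, g)
print(scores)
print(np.argsort(scores))     # Taylor-FO pruning order, least important first
print(np.argsort(np.abs(w)))  # magnitude pruning order, for comparison
```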
Movement Pruning (Sanh et al., 2020)#
Movement pruning was introduced for fine-tuning pretrained models (e.g., BERT) and is based on a philosophically different idea: importance is determined not by the current magnitude but by how weights move during training.
Motivation#
When fine-tuning a pretrained model, the initial weight magnitudes reflect the pretraining task, not the target task. Magnitude pruning would preserve weights that were important for the original task, which may not be the ones important for the fine-tuning task. Movement pruning instead looks at which weights are actively being used by the optimizer.
Score Definition#
Each weight \(w_i\) is assigned a score \(S_i\) that accumulates information about how the weight moves during fine-tuning:
$$S_i^{(t+1)} = S_i^{(t)} - \alpha \cdot w_i^{(t)} \cdot \frac{\partial L^{(t)}}{\partial w_i}$$
Here \(\alpha\) is a scaling factor. To see what the sign of the update means, recall the gradient descent update rule:
$$w_i^{(t+1)} = w_i^{(t)} - \eta \frac{\partial L}{\partial w_i}$$
The weight moves away from zero (increases in magnitude) when:
$$\text{sign}(w_i) = -\text{sign}\left(\frac{\partial L}{\partial w_i}\right)$$
In this case, \(w_i \cdot \frac{\partial L}{\partial w_i} < 0\), so the movement score increases. Conversely, weights moving toward zero have \(w_i \cdot \frac{\partial L}{\partial w_i} > 0\), and their score decreases.
Keeping the weights with the highest scores therefore keeps weights that move away from zero and prunes weights that move toward zero.
Soft vs Hard Movement Pruning#
Hard movement pruning applies a binary mask based on the top-\(k\) scores:
$$m_i = \begin{cases} 1 & \text{if } S_i \text{ is in the top-}(1-s) \text{ fraction} \\ 0 & \text{otherwise} \end{cases}$$
Soft movement pruning uses a smooth threshold with a straight-through estimator:
$$m_i = \sigma\left(\frac{S_i - \tau}{\beta}\right)$$
where \(\sigma\) is the sigmoid function, \(\tau\) is a learned threshold, and \(\beta\) is a temperature. This allows gradients to flow through the mask during training.
Soft movement pruning generally outperforms hard movement pruning, especially at high sparsity levels, because the smooth mask allows for more nuanced importance estimates during training.
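The accumulation rule and both masks can be sketched in plain Python. This assumes the sign convention of Sanh et al. (the score accumulates \(-w \cdot g\), so high scores are kept), a single accumulation step with hand-picked weights and gradients, and a fixed \(\tau\) and \(\beta\) for the soft mask; the sketch does not learn the threshold or implement the straight-through estimator. All function names are illustrative.

```python
import math


def update_movement_scores(scores, weights, grads, alpha=1.0):
    """One accumulation step S_i <- S_i - alpha * w_i * dL/dw_i.

    Weights moving away from zero (w_i * g_i < 0) gain score; weights
    moving toward zero (w_i * g_i > 0) lose score.
    """
    return [s - alpha * w * g for s, w, g in zip(scores, weights, grads)]


def hard_movement_mask(scores, sparsity):
    """Keep the top (1 - sparsity) fraction of weights by score."""
    n_keep = round(len(scores) * (1.0 - sparsity))
    ranked = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)
    keep = set(ranked[:n_keep])
    return [1 if i in keep else 0 for i in range(len(scores))]


def soft_movement_mask(scores, tau, beta):
    """Sigmoid relaxation m_i = sigma((S_i - tau) / beta), with tau and beta fixed here."""
    return [1.0 / (1.0 + math.exp(-(s - tau) / beta)) for s in scores]


weights = [0.5, 0.5, -0.2, 0.1]
grads = [-1.0, 1.0, 1.0, 0.3]   # w[0], w[2] move away from zero; w[1], w[3] move toward it
scores = update_movement_scores([0.0] * 4, weights, grads)
print(scores)
print(hard_movement_mask(scores, sparsity=0.5))   # [1, 0, 1, 0]
print(soft_movement_mask(scores, tau=0.0, beta=0.1))
```

Magnitude pruning would keep `w[0]` and `w[1]`, the two largest weights; movement pruning keeps `w[0]` and `w[2]` instead, because `w[1]` is being pushed toward zero.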
Pruning Schedule and Strategy#
The when and how much of pruning is just as important as the what. This section covers the major strategies for scheduling pruning operations.
One-Shot Pruning#
The simplest approach: prune all weights at once to the target sparsity.
Figure: accuracy over time for one-shot pruning (train, prune, fine-tune). Accuracy drops sharply at the single prune step and recovers only partially during fine-tuning.
Advantages: Simple and fast, with only one prune-and-retrain cycle.
Disadvantages: The sudden removal of many weights causes a large, immediate accuracy drop. Fine-tuning may not fully recover this loss, especially at high sparsity.
Iterative Pruning#
Iterative pruning removes a small fraction of weights at each step, fine-tuning between steps:
Figure: accuracy over time for iterative pruning, alternating prune steps (p1 to p4) with fine-tune steps (ft). Each prune step causes a small dip that the following fine-tune step recovers.
Each pruning step removes only a small fraction of weights, and fine-tuning restores accuracy before the next pruning step. This is much gentler than one-shot pruning and typically achieves better final accuracy at the same sparsity.
Cubic Sparsity Schedule (Zhu & Gupta, 2017)#
Rather than pruning a fixed percentage at each step, the cubic sparsity schedule gradually ramps up the sparsity according to a cubic polynomial:
$$s_t = s_f + (s_i - s_f)\left(1 - \frac{t - t_0}{n \Delta t}\right)^3$$
where:
- \(s_t\): sparsity at step \(t\)
- \(s_i\): initial sparsity (usually 0)
- \(s_f\): final (target) sparsity
- \(t_0\): the step at which pruning begins
- \(\Delta t\): the interval between pruning operations
- \(n\): the number of pruning steps (so pruning ends at step \(t_0 + n\Delta t\))
Understanding the Cubic Schedule#
Let us define the normalized time variable \(\tau = \frac{t - t_0}{n \Delta t} \in [0, 1]\).
With \(s_i = 0\):
$$s(\tau) = s_f(1 - (1 - \tau)^3)$$
Let us compute \(s\) at several points:
| \(\tau\) (progress) | \((1-\tau)^3\) | \(s(\tau) / s_f\) | Description |
|---|---|---|---|
| 0.0 | 1.000 | 0.000 | Start: no pruning |
| 0.1 | 0.729 | 0.271 | 27.1% of target sparsity reached |
| 0.2 | 0.512 | 0.488 | 48.8% |
| 0.3 | 0.343 | 0.657 | 65.7% |
| 0.5 | 0.125 | 0.875 | 87.5% |
| 0.7 | 0.027 | 0.973 | 97.3% |
| 1.0 | 0.000 | 1.000 | End: full target sparsity |
The schedule is aggressive at the start and gentle at the end. Most pruning happens in the first half of the schedule. This is desirable because:
- Early in the schedule, many clearly unimportant weights exist and can be safely removed.
- Late in the schedule, the remaining weights are more important, so we prune slowly and give the network more time to adapt.
Figure: normalized sparsity \(s/s_f\) versus \(\tau\) over \([0, 1]\). The curve rises steeply at first and flattens as it approaches 1.0.
Numerical Example#
Suppose we want to prune a ResNet-50 from \(s_i = 0\) to \(s_f = 0.90\) (90% sparsity), starting at epoch 10 (\(t_0 = 10\)), pruning every 2 epochs (\(\Delta t = 2\)), for 20 pruning steps (\(n = 20\)). Pruning ends at epoch \(10 + 20 \times 2 = 50\).
Sparsity at epoch 20 (\(t = 20\), \(\tau = (20-10)/(20 \times 2) = 0.25\)):
$$s_{20} = 0.90 \times (1 - (1 - 0.25)^3) = 0.90 \times (1 - 0.422) = 0.90 \times 0.578 = 0.520$$
At epoch 20, the network is already at 52.0% sparsity (more than halfway to the target).
Sparsity at epoch 40 (\(\tau = 0.75\)):
$$s_{40} = 0.90 \times (1 - 0.25^3) = 0.90 \times (1 - 0.016) = 0.90 \times 0.984 = 0.886$$
At epoch 40, the network is at 88.6% sparsity, nearly at the target, with 10 more epochs of gentle pruning remaining.
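The schedule is short enough to write out directly. The sketch below reproduces the numbers above; holding the sparsity at \(s_i\) before \(t_0\) and at \(s_f\) after \(t_0 + n\Delta t\) is an assumption about how a training loop would use the formula, and `cubic_sparsity` is a hypothetical helper name.

```python
def cubic_sparsity(t, s_i, s_f, t0, n, dt):
    """Cubic sparsity schedule of Zhu & Gupta (2017).

    Holds s_i before t0 and s_f after t0 + n * dt (the clamping is our assumption).
    """
    end = t0 + n * dt
    if t <= t0:
        return s_i
    if t >= end:
        return s_f
    tau = (t - t0) / (n * dt)              # normalized progress in [0, 1]
    return s_f + (s_i - s_f) * (1.0 - tau) ** 3


# ResNet-50 example from the text: 0 -> 90% sparsity, epochs 10 to 50, every 2 epochs.
for epoch in range(10, 51, 2):
    s = cubic_sparsity(epoch, s_i=0.0, s_f=0.90, t0=10, n=20, dt=2)
    print(f"epoch {epoch:2d}: target sparsity {s:.3f}")
```

In a real training loop, each prunable layer (or the whole network, for global pruning) would be re-pruned by magnitude to the printed target at each of these epochs.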
Pruning at Initialization (Before Training)#
A provocative question: can we prune the network before training, saving the cost of training the full dense network? Several methods attempt this.
SNIP (Single-shot Network Pruning, Lee et al. 2019)#
SNIP introduces a binary mask variable \(c_j \in \{0, 1\}\) for each weight, so that the effective weight is \(c_j \cdot w_j\). The importance of each weight is measured by the sensitivity of the loss to the mask variable:
$$g_j = \frac{\partial L(c \odot \theta; x, y)}{\partial c_j} \Bigg|_{c = \mathbf{1}}$$
By the chain rule:
$$g_j = \frac{\partial L}{\partial (c_j w_j)} \cdot w_j = \frac{\partial L}{\partial w_j'} \cdot w_j$$
where \(w_j' = c_j w_j\). The normalized importance score is:
$$s_j = \frac{|g_j|}{\sum_k |g_k|}$$
The top \((1-s)\) fraction of weights by score are kept, where \(s\) (without a subscript) is the target sparsity, and the rest are pruned.
Algorithm: SNIP
---------------
Input: Initialized network f(x; theta_0), dataset D,
target sparsity s
1. Sample a single mini-batch (x, y) from D
2. Forward pass with all masks c = 1:
L = Loss(f(x; 1 * theta_0), y)
3. Backward pass: compute g_j = dL/dc_j = (dL/dw_j) * w_j for all j
4. Compute normalized scores: s_j = |g_j| / sum_k(|g_k|)
5. Create mask: m_j = 1 if s_j is among the top (1 - s) fraction
of scores, else m_j = 0
6. Train the pruned network f(x; m * theta_0) normally
SNIP is remarkable in its simplicity: it requires only a single forward-backward pass on a single mini-batch to determine the pruning mask, before any training has occurred.
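For intuition, the SNIP criterion can be run on a toy linear model where the gradient is available in closed form. The sketch assumes \(y = \sum_j c_j w_j x_j\) with a mean squared-error loss, a random Gaussian initialization, and a synthetic target; `snip_mask` is a hypothetical name. In a real network the same \(g_j\) would come from one autograd backward pass, as in step 3.

```python
import random


def snip_mask(weights, batch_x, batch_y, sparsity):
    """SNIP mask for a toy linear model y = sum_j c_j * w_j * x_j.

    Loss: L = mean over the batch of 0.5 * (y - t)^2, evaluated at c = 1.
    """
    n, batch = len(weights), len(batch_x)
    grad_w = [0.0] * n
    for x, t in zip(batch_x, batch_y):
        y = sum(w * xi for w, xi in zip(weights, x))      # forward pass, all c_j = 1
        for j in range(n):
            grad_w[j] += (y - t) * x[j] / batch            # dL/dw_j
    g = [w * gw for w, gw in zip(weights, grad_w)]         # dL/dc_j = dL/dw_j * w_j
    total = sum(abs(v) for v in g)
    scores = [abs(v) / total for v in g]                   # normalized sensitivities s_j
    n_keep = round(n * (1.0 - sparsity))
    ranked = sorted(range(n), key=lambda j: scores[j], reverse=True)
    keep = set(ranked[:n_keep])
    return [1 if j in keep else 0 for j in range(n)], scores


rng = random.Random(0)
weights = [rng.gauss(0.0, 1.0) for _ in range(10)]                 # toy initialization
batch_x = [[rng.gauss(0.0, 1.0) for _ in range(10)] for _ in range(32)]
batch_y = [x[0] + 2.0 * x[1] for x in batch_x]                    # toy target
mask, scores = snip_mask(weights, batch_x, batch_y, sparsity=0.8)
print("kept indices:", [j for j, m in enumerate(mask) if m])
```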
GraSP (Gradient Signal Preservation, Wang et al. 2020)#
GraSP argues that pruning should preserve the ability of gradients to flow through the network. Specifically, it aims to maximize the gradient flow after pruning.
The gradient flow is measured by the squared gradient norm \(\|g\|^2 = g^T g\). GraSP estimates how pruning weight \(j\) changes this quantity with a first-order Taylor expansion. Because \(\partial g / \partial c_j = H_{:,j}\, w_j\) at \(c = \mathbf{1}\), the score is:
$$S_j = -\frac{\partial}{\partial c_j}\left(\tfrac{1}{2} g^T g\right)\Bigg|_{c=\mathbf{1}} = -(Hg)_j \cdot w_j$$
where \((Hg)_j\) is the \(j\)-th element of the Hessian-gradient product. Since pruning moves \(c_j\) from 1 to 0, \(S_j\) is the first-order estimate of the resulting change in \(\tfrac{1}{2}\|g\|^2\). The product \(Hg\) can be computed efficiently with one additional backward pass (the "Pearlmutter trick"), without ever forming \(H\). Weights with large positive \(S_j\) are pruned (their removal increases gradient flow), and weights with large negative \(S_j\) are kept (their removal would greatly reduce gradient flow).
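The following sketch computes GraSP scores for a toy least-squares model, where the Hessian is exact and cheap: \(g = X^T(Xw - t)/B\) and \(H = X^TX/B\). Forming \(H\) explicitly is only viable at this toy scale; a real implementation would use a Hessian-vector product instead. The data and function names are illustrative assumptions.

```python
import random


def gradient(w, X, t):
    """g = X^T (X w - t) / B for L(w) = mean_b 0.5 * (w . x_b - t_b)^2."""
    B = len(X)
    residual = [sum(wj * xj for wj, xj in zip(w, x)) - tb for x, tb in zip(X, t)]
    return [sum(X[b][j] * residual[b] for b in range(B)) / B for j in range(len(w))]


def hessian(X):
    """H = X^T X / B, exact for this quadratic loss."""
    B, n = len(X), len(X[0])
    return [[sum(X[b][i] * X[b][j] for b in range(B)) / B for j in range(n)] for i in range(n)]


def grasp_scores(w, X, t):
    """S_j = -w_j * (H g)_j."""
    g, H = gradient(w, X, t), hessian(X)
    Hg = [sum(hij * gj for hij, gj in zip(row, g)) for row in H]
    return [-wj * hgj for wj, hgj in zip(w, Hg)]


def grasp_mask(scores, sparsity):
    """Prune the `sparsity` fraction of weights with the LARGEST scores."""
    n_prune = round(len(scores) * sparsity)
    ranked = sorted(range(len(scores)), key=lambda j: scores[j], reverse=True)
    pruned = set(ranked[:n_prune])
    return [0 if j in pruned else 1 for j in range(len(scores))]


def half_grad_norm_sq(w, X, t):
    """0.5 * ||g||^2, the gradient-flow quantity GraSP tries to preserve."""
    return 0.5 * sum(v * v for v in gradient(w, X, t))


rng = random.Random(1)
X = [[rng.gauss(0.0, 1.0) for _ in range(5)] for _ in range(20)]
t = [rng.gauss(0.0, 1.0) for _ in range(20)]
w = [rng.gauss(0.0, 1.0) for _ in range(5)]
S = grasp_scores(w, X, t)
print("scores:", [round(s, 4) for s in S])
print("mask at 40% sparsity:", grasp_mask(S, 0.4))
```

Because the toy loss is quadratic, each \(S_j\) equals the derivative of `half_grad_norm_sq` along the pruning direction \(-w_j e_j\), which is easy to confirm with a finite difference.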
SynFlow (Iterative Synaptic Flow Pruning, Tanaka et al. 2020)#
SynFlow addresses a critical failure mode of pruning-at-initialization methods: layer collapse. Layer collapse occurs when all weights in a layer are pruned, disconnecting the network completely. Once a layer is fully pruned, no gradient can flow through the network, and accuracy drops to random chance.
SynFlow avoids layer collapse through a data-free, iterative pruning criterion. The key idea is to measure the total “flow” of signals through each synaptic path in the network.
For a network with \(L\) layers with weight matrices \(W_1, W_2, \ldots, W_L\), define the synaptic flow score:
$$R = \mathbf{1}^T \left(\prod_{l=1}^{L} |W_l|\right) \mathbf{1}$$
This is the sum of all products of absolute weight values along every path from input to output. The score for individual weight \(\theta_j\) in layer \(l\) is:
$$\boxed{R_j = \frac{\partial R}{\partial \theta_j} \odot \theta_j}$$
Since \(R\) is built from absolute values, \(R_j\) is always non-negative, and it is zero only if the weight lies on no active path. The scores also obey a conservation law: because \(R\) is linear in each layer's \(|W_l|\), the scores within any single layer sum to \(R\). A layer that has already lost most of its weights therefore concentrates \(R\) on the few that remain, giving them large scores that protect them from pruning. This is how SynFlow avoids layer collapse, provided pruning proceeds in small iterative steps.
The iterative procedure is critical. SynFlow does not prune to the target sparsity in one shot. Instead, it iteratively prunes a fraction \(\rho\) of the remaining weights per iteration:
Algorithm: SynFlow (Iterative)
------------------------------
Input: Network with weights theta, target sparsity s,
number of iterations T
1. Compute per-iteration pruning fraction: rho = 1 - (1-s)^(1/T)
2. For t = 1, ..., T:
3. Compute R = 1^T * (prod_l |W_l|) * 1 using only unpruned weights
4. For each remaining weight theta_j:
5. R_j = (dR/d|theta_j|) * |theta_j|
6. Prune the rho fraction of remaining weights with smallest R_j
7. Return final mask
Note that SynFlow is entirely data-free: it does not use any training data. The scores depend only on the network weights and architecture.
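A compact way to compute SynFlow scores for a stack of linear layers is to push an all-ones vector forward through the \(|W_l|\) and pull an all-ones vector backward: the score of entry \((i, j)\) in layer \(l\) is then the backward value at output \(i\), times \(|W_l[i][j]|\), times the forward value at input \(j\). The sketch below assumes plain linear layers (no biases or nonlinearities), Gaussian weights, and toy layer sizes. It also computes each round's target count directly from \((1-s)^{k/T}\), which matches pruning \(\rho\) of the remaining weights per round up to rounding. Function names are hypothetical.

```python
import random


def synflow_scores(weights, masks):
    """SynFlow scores for a stack of linear layers, weights[l] shaped (out_l, in_l).

    R = 1^T |W_L| ... |W_1| 1 on the masked weights; the score of entry (i, j) of
    layer l is (dR/d|W_l[i][j]|) * |W_l[i][j]| = bwd[l][i] * |W_l[i][j]| * fwd[l][j].
    """
    A = [[[abs(w) * m for w, m in zip(wr, mr)] for wr, mr in zip(W, M)]
         for W, M in zip(weights, masks)]
    fwd = [[1.0] * len(A[0][0])]          # fwd[l]: all-ones input pushed through layers < l
    for W in A:
        fwd.append([sum(w * v for w, v in zip(row, fwd[-1])) for row in W])
    bwd = [[1.0] * len(A[-1])]            # bwd[l]: all-ones output pulled back through layers > l
    for W in reversed(A[1:]):
        nxt = bwd[0]
        bwd.insert(0, [sum(nxt[i] * W[i][j] for i in range(len(W))) for j in range(len(W[0]))])
    R = sum(fwd[-1])
    scores = [[[bwd[l][i] * A[l][i][j] * fwd[l][j] for j in range(len(A[l][0]))]
               for i in range(len(A[l]))] for l in range(len(A))]
    return R, scores


def synflow_prune(weights, sparsity, iterations):
    """Data-free iterative pruning: after round k, (1 - sparsity)^(k / T) of weights remain."""
    masks = [[[1] * len(W[0]) for _ in W] for W in weights]
    total = sum(len(W) * len(W[0]) for W in weights)
    for k in range(1, iterations + 1):
        _, scores = synflow_scores(weights, masks)
        target = round(total * (1.0 - sparsity) ** (k / iterations))
        active = sorted((scores[l][i][j], l, i, j)
                        for l, W in enumerate(weights)
                        for i in range(len(W)) for j in range(len(W[0]))
                        if masks[l][i][j])
        for _, l, i, j in active[:len(active) - target]:   # drop the lowest scores
            masks[l][i][j] = 0
    return masks


rng = random.Random(0)
sizes = [8, 6, 6, 3]                                        # toy layer widths
weights = [[[rng.gauss(0.0, 1.0) for _ in range(n_in)] for _ in range(n_out)]
           for n_in, n_out in zip(sizes[:-1], sizes[1:])]
R, scores = synflow_scores(weights, [[[1] * len(W[0]) for _ in W] for W in weights])
print("R =", round(R, 3), "per-layer score sums:", [round(sum(map(sum, S)), 3) for S in scores])
masks = synflow_prune(weights, sparsity=0.9, iterations=20)
print("remaining per layer:", [sum(map(sum, M)) for M in masks])
```

The first print shows the conservation law at work: every layer's scores sum to the same \(R\).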
Comparison of Pruning-at-Initialization Methods#
| Method | Data Required | Iterations | Avoids Layer Collapse | CIFAR-10 (90% sparsity, ResNet-20) |
|---|---|---|---|---|
| SNIP | 1 mini-batch | 1 | No | ~91.5% |
| GraSP | 1-2 mini-batches | 1 | No | ~91.2% |
| SynFlow | None | Multiple | Yes | ~91.0% |
| Random | None | 1 | No | ~89.5% |
| Magnitude (after training) | Full training | 1 | No | ~92.5% |
All three methods significantly outperform random pruning and approach the accuracy of magnitude pruning (which requires full training). SNIP is the simplest and often the most accurate for moderate sparsity; GraSP is better at extreme sparsity; SynFlow is the safest (no layer collapse) and requires no data.
Pruning Masks and Sparse Representations#
After pruning, we need to efficiently represent the sparse weight matrices. Naively storing the full matrix with zeros wastes memory. Several sparse formats exist, each with different trade-offs.
Binary Mask Representation#
The simplest representation stores the original dense matrix alongside a binary mask:
$$W_{\text{pruned}} = W \odot M, \quad M \in \{0, 1\}^{m \times n}$$
This is conceptually simple and easy to implement but provides limited compression: we still store the full matrix plus a mask. The mask itself can be compressed since it is binary (1 bit per element instead of 32 bits for a float), but we still store zeros in \(W\).
CSR (Compressed Sparse Row) Format#
The Compressed Sparse Row (CSR) format is one of the most common sparse matrix representations. It stores only the nonzero elements along with their positions.
A CSR representation consists of three arrays:
- values: the nonzero elements, read row by row
- col_indices: the column index of each nonzero element
- row_ptr: for each row \(i\), row_ptr[i] gives the index into values where row \(i\) starts; the final entry equals the total number of nonzeros
Example: Consider the pruned matrix from our earlier example:
$$W_{\text{pruned}} = \begin{bmatrix} 0.52 & 0 & 0.81 \\ 0 & 0.95 & 0 \\ 0 & -0.68 & 0 \end{bmatrix}$$
Original matrix (3x3):
col 0 col 1 col 2
row 0 [ 0.52, 0, 0.81]
row 1 [ 0, 0.95, 0 ]
row 2 [ 0, -0.68, 0 ]
CSR representation:
values: [0.52, 0.81, 0.95, -0.68]
col_indices: [0, 2, 1, 1 ]
row_ptr: [0, 2, 3, 4 ]
row_ptr[0] = 0: row 0 starts at index 0
row_ptr[1] = 2: row 1 starts at index 2
row_ptr[2] = 3: row 2 starts at index 3
row_ptr[3] = 4: row 2 ends (4 elements total)
Row 0 elements: values[0:2] = [0.52, 0.81]
at columns col_indices[0:2] = [0, 2]
Row 1 elements: values[2:3] = [0.95]
at columns col_indices[2:3] = [1]
Row 2 elements: values[3:4] = [-0.68]
at columns col_indices[3:4] = [1]
CSC (Compressed Sparse Column) Format#
CSC is the column-oriented counterpart of CSR. It stores nonzero elements column by column:
- values: nonzero elements, read column by column
- row_indices: the row index of each nonzero element
- col_ptr: for each column \(j\), col_ptr[j] gives the index into values where column \(j\) starts; the final entry equals the total number of nonzeros
CSC is preferred when column access patterns dominate (e.g., for matrix-vector multiplication \(Ax\) where \(A\) is accessed column-wise).
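The CSR and CSC layouts described above can be built from a dense matrix in a few lines. This is a plain-Python sketch for illustration (the function names are hypothetical); production code would use an optimized sparse library. Running it on the example matrix reproduces the CSR arrays listed earlier.

```python
def dense_to_csr(W):
    """Return (values, col_indices, row_ptr), scanning the matrix row by row."""
    values, col_indices, row_ptr = [], [], [0]
    for row in W:
        for j, v in enumerate(row):
            if v != 0:
                values.append(v)
                col_indices.append(j)
        row_ptr.append(len(values))        # end of this row = start of the next
    return values, col_indices, row_ptr


def dense_to_csc(W):
    """Return (values, row_indices, col_ptr), scanning the matrix column by column."""
    values, row_indices, col_ptr = [], [], [0]
    for j in range(len(W[0])):
        for i in range(len(W)):
            if W[i][j] != 0:
                values.append(W[i][j])
                row_indices.append(i)
        col_ptr.append(len(values))
    return values, row_indices, col_ptr


def csr_row(values, col_indices, row_ptr, i):
    """(column, value) pairs stored for row i: the slice row_ptr[i]:row_ptr[i + 1]."""
    start, end = row_ptr[i], row_ptr[i + 1]
    return list(zip(col_indices[start:end], values[start:end]))


W_pruned = [[0.52, 0.0, 0.81],
            [0.0, 0.95, 0.0],
            [0.0, -0.68, 0.0]]
csr = dense_to_csr(W_pruned)
csc = dense_to_csc(W_pruned)
print(csr)               # ([0.52, 0.81, 0.95, -0.68], [0, 2, 1, 1], [0, 2, 3, 4])
print(csc)               # ([0.52, 0.95, -0.68, 0.81], [0, 1, 2, 0], [0, 1, 3, 4])
print(csr_row(*csr, 0))  # [(0, 0.52), (2, 0.81)]
```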
COO (Coordinate) Format#
The COO format stores each nonzero element as a (row, column, value) triple:
For the same matrix:
row: [0, 0, 1, 2 ]
col: [0, 2, 1, 1 ]
values: [0.52, 0.81, 0.95, -0.68]
COO is simple and flexible but less memory-efficient than CSR/CSC for large matrices (it stores two indices per nonzero element instead of one index plus a pointer array).
Block Sparse Formats#
In block sparse formats, the sparsity structure is defined at the level of blocks (e.g., \(4 \times 4\) or \(8 \times 8\) submatrices) rather than individual elements. A block is either entirely zero or entirely nonzero.
Dense matrix (8x8): Block sparse (2x2 blocks):
[x x 0 0 x x 0 0] [X X . . X X . .]
[x x 0 0 x x 0 0] [X X . . X X . .]
[0 0 x x 0 0 0 0] [. . X X . . . .]
[0 0 x x 0 0 0 0] [. . X X . . . .]
[0 0 0 0 x x x x] [. . . . X X X X]
[0 0 0 0 x x x x] [. . . . X X X X]
[x x x x 0 0 0 0] [X X X X . . . .]
[x x x x 0 0 0 0] [X X X X . . . .]
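'X' = nonzero block, '.' = zero block
Block sparse formats are important for hardware efficiency. GPUs execute dense tiles of data efficiently, and unstructured sparsity maps poorly onto their compute units, whereas block sparsity can deliver real speedups. A related hardware-friendly pattern is the fine-grained 2:4 structured sparsity supported by NVIDIA A100 Sparse Tensor Cores, which keeps at most two nonzeros in every group of four consecutive weights.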
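One simple way to produce a block-sparse pattern is block-wise magnitude pruning: score each tile by its L2 norm and zero out the weakest tiles. The sketch below assumes 2x2 tiles, a random 8x8 matrix, and a 25% keep fraction, and prints the block-level mask in the same X/. notation as the figure. The scoring rule and function names are illustrative choices, not a prescribed method.

```python
import math
import random


def block_prune(W, block, keep_fraction):
    """Zero whole block x block tiles, keeping the tiles with the largest L2 norm."""
    n_rows, n_cols = len(W), len(W[0])
    if n_rows % block or n_cols % block:
        raise ValueError("matrix dimensions must be multiples of the block size")
    tiles = []
    for bi in range(0, n_rows, block):
        for bj in range(0, n_cols, block):
            norm = math.sqrt(sum(W[i][j] ** 2
                                 for i in range(bi, bi + block)
                                 for j in range(bj, bj + block)))
            tiles.append((norm, bi, bj))
    n_keep = round(len(tiles) * keep_fraction)
    kept = {(bi, bj) for _, bi, bj in sorted(tiles, reverse=True)[:n_keep]}
    return [[W[i][j] if (i - i % block, j - j % block) in kept else 0.0
             for j in range(n_cols)] for i in range(n_rows)]


def block_mask(W, block):
    """Block-level mask: 1 for a tile holding any nonzero ('X'), 0 otherwise ('.')."""
    return [[int(any(W[i][j] != 0
                     for i in range(bi, bi + block)
                     for j in range(bj, bj + block)))
             for bj in range(0, len(W[0]), block)]
            for bi in range(0, len(W), block)]


rng = random.Random(0)
W = [[rng.gauss(0.0, 1.0) for _ in range(8)] for _ in range(8)]
pruned = block_prune(W, block=2, keep_fraction=0.25)   # keep 4 of the 16 tiles
for row in block_mask(pruned, block=2):
    print(" ".join("X" if b else "." for b in row))
```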
Storage Savings at Different Sparsity Levels#
For a matrix with \(n\) elements and \(r\) rows stored as FP32 (4 bytes each), with 32-bit indices:
| Sparsity | nnz | Dense (bytes) | CSR (bytes) | Compression ratio (dense / CSR) |
|---|---|---|---|---|
| 0% | \(n\) | \(4n\) | \(4n + 4n + 4(r+1)\) | 0.5x (larger!) |
| 50% | \(0.5n\) | \(4n\) | \(4(0.5n) + 4(0.5n) + 4(r+1)\) | ~1x |
| 80% | \(0.2n\) | \(4n\) | \(4(0.2n) + 4(0.2n) + 4(r+1)\) | ~2.5x |
| 90% | \(0.1n\) | \(4n\) | \(4(0.1n) + 4(0.1n) + 4(r+1)\) | ~5x |
| 95% | \(0.05n\) | \(4n\) | \(4(0.05n) + 4(0.05n) + 4(r+1)\) | ~10x |
| 99% | \(0.01n\) | \(4n\) | \(4(0.01n) + 4(0.01n) + 4(r+1)\) | ~50x |
Note: CSR storage = (values: 4 bytes/nnz) + (col_indices: 4 bytes/nnz) + (row_ptr: 4 bytes/(rows+1)). For large matrices, the row_ptr overhead is negligible.
The crossover point where CSR becomes beneficial is around 50% sparsity. Below 50%, the overhead of storing indices makes CSR larger than dense storage. This is why pruning to at least 50% sparsity (and preferably 80%+) is needed for memory benefits.
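The table's arithmetic can be checked with a short calculation. The sketch assumes a hypothetical 4096 x 4096 FP32 matrix and 4-byte row pointers, and also tries 2-byte column indices to show how index width moves the crossover.

```python
def dense_bytes(n_elements, value_bytes=4):
    """FP32 dense storage."""
    return value_bytes * n_elements


def csr_bytes(n_elements, n_rows, sparsity, value_bytes=4, index_bytes=4, ptr_bytes=4):
    """CSR storage: a value and a column index per nonzero, plus n_rows + 1 row pointers."""
    nnz = round(n_elements * (1.0 - sparsity))
    return (value_bytes + index_bytes) * nnz + ptr_bytes * (n_rows + 1)


rows, cols = 4096, 4096            # a hypothetical 4096 x 4096 weight matrix
n = rows * cols
for s in (0.0, 0.5, 0.8, 0.9, 0.95, 0.99):
    ratio = dense_bytes(n) / csr_bytes(n, rows, s)
    ratio16 = dense_bytes(n) / csr_bytes(n, rows, s, index_bytes=2)
    print(f"sparsity {s:4.0%}: dense/CSR = {ratio:5.2f}x (32-bit cols), {ratio16:5.2f}x (16-bit cols)")
```

With 16-bit column indices (valid here because there are fewer than 65,536 columns), the crossover falls from 50% to about 33% sparsity, since \(4(1-s) + 2(1-s) = 4\) at \(s = 1/3\).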
Regrowth and Dynamic Sparse Training#
All methods discussed so far assume a fixed sparsity pattern: once a weight is pruned, it stays pruned. Dynamic sparse training challenges this assumption by allowing pruned weights to return (regrow) while other weights are pruned, maintaining a constant sparsity level throughout training.
Sparse-to-Sparse Training#
The key idea is to train a sparse network from the start, never materializing the full dense network:
- Traditional pruning: Dense (train) --> Sparse (prune + fine-tune). Memory: O(n).
- Dynamic sparse training: Sparse (train) --> Sparse (regrow + prune) --> Sparse (regrow + prune) --> ... Memory: O(k), where k << n.
This is significant for memory: we never need to store a dense \(n\)-parameter model, only the sparse \(k\)-parameter model.
SET (Sparse Evolutionary Training, Mocanu et al. 2018)#
SET was one of the first dynamic sparse training methods. At each regrowth step:
- Prune: Remove a fraction of weights with the smallest magnitudes.
- Regrow: Add the same number of new connections at random positions.
Algorithm: SET
--------------
Input: Initial sparse network (random topology),
prune/regrow fraction f, dataset D
1. Initialize random sparse topology with sparsity s
2. For each epoch:
3. Train the sparse network on D
4. If regrowth step:
5. Let k = f * (number of nonzero weights)
6. Remove k weights with smallest |w_i|
7. Add k weights at random zero positions
8. (initialize new weights to 0 or small random values)
SET demonstrates that the topology of the sparse network can be optimized during training, not just the weight values. The network “evolves” its connectivity structure over time.
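The SET topology update can be sketched on a single flattened layer. This assumes weights and the 0/1 mask are stored as flat Python lists, regrown weights are drawn from a small Gaussian (`init_scale` is a hypothetical parameter), and the training step between updates is omitted.

```python
import random


def set_prune_regrow(weights, mask, fraction, rng, init_scale=0.01):
    """One SET topology update on flat lists of weights and 0/1 mask entries (in place).

    Drops the `fraction` of active weights with the smallest |w|, then activates
    the same number of inactive positions chosen uniformly at random. The random
    draw may re-pick a just-dropped position.
    """
    active = [i for i, m in enumerate(mask) if m]
    k = int(fraction * len(active))
    for i in sorted(active, key=lambda i: abs(weights[i]))[:k]:   # prune
        mask[i], weights[i] = 0, 0.0
    inactive = [i for i, m in enumerate(mask) if not m]
    for i in rng.sample(inactive, k):                              # regrow
        mask[i], weights[i] = 1, rng.gauss(0.0, init_scale)
    return k


rng = random.Random(0)
n, n_active = 100, 10                       # 90% sparse toy layer
mask = [0] * n
for i in rng.sample(range(n), n_active):
    mask[i] = 1
weights = [rng.gauss(0.0, 1.0) if m else 0.0 for m in mask]
for epoch in range(5):
    # ... a real loop would train the active weights here ...
    set_prune_regrow(weights, mask, fraction=0.3, rng=rng)
print("active connections:", sum(mask))    # stays at 10
```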
RigL (Rigged Lottery, Evci et al. 2020)#
RigL improves upon SET by using gradient information instead of random selection to decide which connections to regrow. The key insight is that the gradient of the loss with respect to a zero (pruned) weight is a first-order estimate of how quickly the loss would decrease if that connection were activated.
Gradient-Based Regrowth#
For a pruned weight \(w_j = 0\), the gradient \(\frac{\partial L}{\partial w_j}\) is still well-defined: it is the gradient of the loss with respect to the weight, as if it were active. Connections with the largest gradient magnitude are those expected, to first order, to reduce the loss fastest once activated.
Algorithm: RigL
---------------
Input: Initial sparse network, sparsity s,
drop fraction alpha(t), update interval Delta_T, dataset D
1. Initialize sparse network with Erdos-Renyi topology
2. For each training step t:
3. Forward pass: y = f(x; W_sparse)
4. Backward pass: compute gradients g_i for the active weights
5. Update active weights: w_i <- w_i - eta * g_i
6. If regrowth step (every Delta_T steps):
7. // PRUNE: remove lowest-magnitude active weights
8. k = alpha(t) * nnz(W)
9. Drop k active weights with smallest |w_i|
10. // REGROW: activate highest-gradient zero weights
11. Compute dense gradients g_j for ALL weights
(including zero/pruned weights)
12. Activate k zero weights with largest |g_j|
13. Initialize new weights to 0
Key detail in step 13: newly regrown weights are initialized to zero, so regrowth leaves the network's output unchanged. This might seem counterproductive, but because these connections were chosen for their large gradients, the next training steps push them quickly toward useful values.
Key detail in step 11: at regrowth steps, gradients are computed for all weights, including pruned ones. In a dense layer \(y = Wx\), the gradient \(\partial L / \partial W_{ij} = (\partial L / \partial y_i) \cdot x_j\) can be computed regardless of whether \(W_{ij}\) is zero. This costs as much as a dense backward pass for that layer, but because it is needed only once every \(\Delta T\) steps, the amortized overhead of RigL over purely sparse training is small.
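The drop/grow step of RigL fits in a few lines. In this sketch the weights, mask, and dense gradients are flat Python lists with hand-picked values, and \(\alpha(t)\) uses the cosine decay \(\tfrac{\alpha}{2}(1 + \cos(\pi t / T_{\text{end}}))\) described by Evci et al.; the function names and the flat-list layout are illustrative assumptions.

```python
import math


def cosine_drop_fraction(t, alpha, t_end):
    """Cosine-decayed drop fraction alpha/2 * (1 + cos(pi * t / t_end)); 0 from t_end on."""
    if t >= t_end:
        return 0.0
    return 0.5 * alpha * (1.0 + math.cos(math.pi * t / t_end))


def rigl_update(weights, mask, dense_grads, drop_fraction):
    """One RigL drop/grow step on flat lists, modified in place.

    dense_grads must hold dL/dw for every position, including inactive ones.
    """
    active = [i for i, m in enumerate(mask) if m]
    k = int(drop_fraction * len(active))
    dropped = sorted(active, key=lambda i: abs(weights[i]))[:k]       # smallest |w|
    for i in dropped:
        mask[i], weights[i] = 0, 0.0
    inactive = [i for i, m in enumerate(mask) if not m]
    grown = sorted(inactive, key=lambda i: abs(dense_grads[i]), reverse=True)[:k]
    for i in grown:                                                    # largest |g|
        mask[i], weights[i] = 1, 0.0                                   # start at zero
    return dropped, grown


weights = [0.9, -0.05, 0.0, 0.0, 0.4, 0.0]
mask = [1, 1, 0, 0, 1, 0]
dense_grads = [0.1, 0.2, -0.8, 0.05, 0.3, 0.6]
dropped, grown = rigl_update(weights, mask, dense_grads, drop_fraction=0.34)
print(dropped, grown, mask)   # [1] [2] [1, 0, 1, 0, 1, 0]
```

The weakest active weight (index 1) is dropped, and the inactive position with the largest gradient magnitude (index 2) is activated at zero, so the number of active connections stays constant.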
Why Gradient-Based Regrowth Outperforms Random#
Consider a network with 1000 pruned connections. In SET, connections are regrown at random, so each candidate is equally likely to be chosen whether or not it would help. In RigL, we regrow the connections whose activation would, to first order, reduce the loss the most. This is a much better-informed strategy, especially as training progresses and the remaining improvements become more specific.
Empirically, on ImageNet with ResNet-50, RigL clearly outperforms SET at the same sparsity and stays within a few points of the dense baseline even at 90% sparsity:
| Method | Sparsity | Top-1 Accuracy (ImageNet, ResNet-50) |
|---|---|---|
| Dense baseline | 0% | 76.8% |
| Static sparse (magnitude) | 80% | 74.6% |
| SET | 80% | 72.9% |
| RigL | 80% | 74.6% |
| RigL | 90% | 73.2% |
| RigL (ERK distribution) | 90% | 73.0% |
Top-KAST and Other Methods#
Top-KAST (Jayakumar et al., 2020) takes a different approach: at each forward pass, it selects the top-\(k\) weights by magnitude and only uses those for computation. The backward pass computes gradients for a slightly larger set (top-\(k’\) with \(k’ > k\)) to allow exploration.
Other notable dynamic sparse training methods include:
- MEST (Memory-Economic Sparse Training): combines structured and unstructured sparsity
- OptG (Optimal Gradient-based regrowth): analyzes the optimal frequency and fraction for prune-regrow cycles
- AC/DC (Alternating Compressed/DeCompressed training): alternates between dense and sparse phases during training
Comparison of Dynamic Sparse Training Methods#
| Method | Regrowth Criterion | Pruning Criterion | Extra Cost vs Static Sparse | Key Advantage |
|---|---|---|---|---|
| SET | Random | Magnitude | None | Simplicity |
| RigL | Gradient magnitude | Magnitude | Dense backward pass every \(\Delta T\) steps | Best accuracy |
| Top-KAST | Top-k by magnitude | Implicit (not top-k) | Slightly larger backward | No explicit regrow step |
| MEST | Mixed | Mixed | Moderate | Structured sparsity |
Measuring and Evaluating Pruning#
Pruning a model is only useful if the pruned model is actually better in some practical sense. This section discusses how to measure and evaluate pruning quality.
Key Metrics#
Sparsity (\(s\)): The fraction of zero weights.
$$s = 1 - \frac{\text{nnz}(W)}{|W|}$$
FLOPs reduction: The theoretical reduction in floating-point operations. For a sparse linear layer with weight matrix \(W \in \mathbb{R}^{m \times n}\):
$$\text{FLOPs}_{\text{dense}} = 2mn, \quad \text{FLOPs}_{\text{sparse}} = 2 \cdot \text{nnz}(W)$$
At sparsity \(s\): \(\text{FLOPs}_{\text{sparse}} = (1-s) \cdot \text{FLOPs}_{\text{dense}}\)
Memory savings: Depends on the sparse format used (see previous section).
Actual speedup: The wall-clock time reduction when running inference.
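The sparsity and FLOPs formulas can be checked on the 3x3 example matrix from the CSR section. The sketch counts two FLOPs (one multiply, one add) per stored nonzero during a CSR matrix-vector product; it measures operation counts only, not wall-clock time, which is exactly where the gap discussed next appears. Function names are hypothetical.

```python
def sparsity(W):
    """s = 1 - nnz(W) / |W|."""
    total = sum(len(row) for row in W)
    nnz = sum(1 for row in W for v in row if v != 0)
    return 1.0 - nnz / total


def to_csr(W):
    """(values, col_indices, row_ptr) for a dense row-major matrix."""
    values, cols, row_ptr = [], [], [0]
    for row in W:
        for j, v in enumerate(row):
            if v != 0:
                values.append(v)
                cols.append(j)
        row_ptr.append(len(values))
    return values, cols, row_ptr


def csr_matvec(values, cols, row_ptr, x):
    """y = W x over stored nonzeros only; returns (y, flops) with 2 FLOPs per nonzero."""
    y, flops = [], 0
    for i in range(len(row_ptr) - 1):
        acc = 0.0
        for k in range(row_ptr[i], row_ptr[i + 1]):
            acc += values[k] * x[cols[k]]      # one multiply and one add
            flops += 2
        y.append(acc)
    return y, flops


W = [[0.52, 0.0, 0.81],
     [0.0, 0.95, 0.0],
     [0.0, -0.68, 0.0]]
m, n = len(W), len(W[0])
y, sparse_flops = csr_matvec(*to_csr(W), [1.0, 2.0, 3.0])
s = sparsity(W)
print(f"s = {s:.3f}, dense FLOPs = {2 * m * n}, sparse FLOPs = {sparse_flops}")
print(f"(1 - s) * dense FLOPs = {(1 - s) * 2 * m * n:.1f}")
```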
The Gap Between Theoretical and Actual Speedup#
One of the most important practical considerations in pruning is the speedup gap: the difference between the theoretical FLOPs reduction and the actual wall-clock speedup.
Theoretical speedup at 90% sparsity: 10x
Actual speedup (unstructured): 1.0-2.0x (!)
Actual speedup (structured): 3.0-5.0x
Actual speedup (2:4 on A100, fixed 50% sparsity): up to ~2.0x
Why the gap? Several reasons:
Memory bandwidth: Many operations are memory-bound, not compute-bound. Sparse formats require extra memory accesses for indices, which can offset computational savings.
Irregular access patterns: Unstructured sparsity creates irregular memory access patterns that defeat hardware prefetchers and cache hierarchies.
Software overhead: Sparse matrix libraries have overhead for managing the sparse data structure, and most deep learning frameworks are heavily optimized for dense operations.
Parallelism loss: Dense matrix multiplication maps naturally onto a GPU's parallel architecture (SIMD/SIMT), while sparse operations have irregular parallelism.
Hardware support: Most current hardware is designed for dense computation. Only specific hardware (e.g., NVIDIA A100 with 2:4 sparsity, Cerebras CS-2) has native sparse support.
Wall-Clock Time vs FLOPs#
This gap means that FLOPs is a poor proxy for actual runtime in the context of sparsity. A 10x reduction in FLOPs might translate to only a 1.5x speedup on a GPU. Practitioners should always measure wall-clock time on the target hardware, not just count FLOPs.
The situation is better for structured pruning (removing entire filters, attention heads, etc.), where the sparsity pattern is regular and maps well to hardware. This is why structured pruning is often preferred in practice despite slightly worse accuracy-sparsity trade-offs.
Accuracy vs Sparsity Pareto Curves#
The standard way to evaluate a pruning method is to plot accuracy against sparsity across a range of sparsity levels. A method is superior if its curve dominates another (higher accuracy at every sparsity level).
Figure: accuracy (roughly 86% to 96%) versus sparsity (50% to 95%) for OBS (*), OBD (+), and magnitude pruning (o). The OBS curve lies above OBD, which lies above magnitude pruning, and all three fall as sparsity increases.
A better pruning criterion (e.g., OBS vs magnitude) shifts the curve to the right, achieving the same accuracy at higher sparsity; equivalently, it achieves higher accuracy at the same sparsity level.
Per-Layer Sparsity Distribution#
Global pruning naturally assigns different sparsity levels to different layers. Analyzing this distribution provides insight into the network’s structure:
Layer-wise sparsity in a pruned ResNet-50 (90% overall):
| Layer | Approximate sparsity |
|---|---|
| conv1 (first layer) | ~20% |
| layer1.0.conv1 | ~50% |
| layer1.0.conv2 | ~60% |
| layer2.0.conv1 | ~80% |
| layer2.0.conv2 | ~85% |
| layer3.0.conv1 | ~95% |
| layer3.0.conv2 | ~97% |
| layer4.0.conv1 | ~98% |
| fc (last layer) | ~50% |
Pattern: middle and late layers are the most prunable; the first and last layers are the most sensitive.
This pattern recurs across many architectures: early layers (which extract basic features like edges and textures) and the final classification layer are relatively sensitive to pruning, while the middle and later convolutional layers (which extract higher-level features) are highly redundant.
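Global magnitude pruning applies one threshold to every layer, so per-layer sparsity falls out of the weight distributions rather than being set by hand. The sketch below assumes three hypothetical layers whose weights simply differ in scale, standing in for the differences training produces; the names, sizes, and scales are invented and do not reproduce the ResNet-50 numbers in the table.

```python
import random


def global_magnitude_prune(layers, sparsity):
    """One shared |w| threshold across all layers; returns name -> 0/1 mask.

    Assumes no ties at the threshold (true almost surely for continuous weights).
    """
    all_mags = sorted(abs(w) for ws in layers.values() for w in ws)
    n_prune = round(len(all_mags) * sparsity)
    threshold = all_mags[n_prune - 1] if n_prune > 0 else -1.0
    return {name: [1 if abs(w) > threshold else 0 for w in ws]
            for name, ws in layers.items()}


def layer_sparsity(mask):
    return 1.0 - sum(mask) / len(mask)


rng = random.Random(0)
# Hypothetical layers whose weights differ in scale, standing in for a trained network.
layers = {
    "first": [rng.gauss(0.0, 1.0) for _ in range(100)],
    "middle": [rng.gauss(0.0, 0.1) for _ in range(2000)],
    "last": [rng.gauss(0.0, 0.5) for _ in range(200)],
}
masks = global_magnitude_prune(layers, sparsity=0.9)
for name, mask in masks.items():
    print(f"{name:>6}: {layer_sparsity(mask):.1%} sparse")
total = sum(len(m) for m in masks.values())
print(f"overall: {1 - sum(sum(m) for m in masks.values()) / total:.1%} sparse")
```

The large, small-magnitude middle layer absorbs almost all of the pruning, while the smaller layers with larger weights keep most of their connections.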
Summary#
Comparison of All Pruning Criteria#
| Criterion | Information Used | Computational Cost | Requires Training | Handles Interactions | Typical Accuracy |
|---|---|---|---|---|---|
| Magnitude (L1) | \(\lvert w_i \rvert\) | \(O(n)\) | Yes | No | Good |
| OBD | \(w_i, h_{ii}\) | \(O(n)\) | Yes | Partial (diagonal) | Better |
| OBS | \(w_i, H^{-1}\) | \(O(n^2)\) - \(O(n^3)\) | Yes | Yes (full Hessian) | Best |
| Fisher | \(w_i, F_{ii}\) | \(O(n \cdot B)\) | Yes | Partial | Better |
| Taylor-FO | \(w_i, g_i\) | \(O(n)\) | Yes | No | Good |
| Movement | \(\sum w_i g_i\) | \(O(n \cdot T)\) | During fine-tune | Temporal | Best for fine-tuning |
| SNIP | \(\partial L/\partial c_j\) | \(O(n)\) | No (1 batch) | No | Moderate |
| GraSP | \(Hg \cdot w\) | \(O(n)\) | No (1-2 batches) | Partial | Moderate |
| SynFlow | Path products | \(O(n \cdot T)\) | No (data-free) | Layer-aware | Moderate |
where \(n\) = number of parameters, \(B\) = batch size, and \(T\) = number of iterations (fine-tuning steps for movement pruning, pruning rounds for SynFlow).
Key Takeaways#
Neural networks are vastly over-parameterized. Typically 80-95% of weights can be removed with less than 1% accuracy loss.
The Lottery Ticket Hypothesis provides deep theoretical insight: dense networks contain sparse “winning tickets” that can match full accuracy when trained from their original initialization.
Magnitude pruning is a strong baseline. Despite its simplicity, it is competitive with more sophisticated methods in many settings.
Second-order methods (OBD, OBS) are theoretically superior but computationally expensive. They are most useful when pruning to extreme sparsity or when each pruned weight must be carefully chosen.
The pruning schedule matters enormously. Iterative pruning with a cubic schedule significantly outperforms one-shot pruning at the same final sparsity.
Pruning at initialization is possible (SNIP, GraSP, SynFlow) and saves the cost of training a full dense network, though with some accuracy penalty.
Dynamic sparse training (SET, RigL) eliminates the need for a dense training phase entirely, achieving competitive accuracy while maintaining a sparse network throughout training.
Theoretical speedup does not equal actual speedup. Unstructured sparsity maps poorly to current hardware. Structured pruning or specialized hardware (e.g., NVIDIA 2:4 sparsity) is needed for real wall-clock improvements.
Global pruning generally outperforms local pruning because it allows non-uniform sparsity distribution across layers, allocating more capacity to sensitive layers.
Pruning is complementary to other compression techniques. The best results come from combining pruning with quantization, distillation, and architecture search.
What Comes Next#
This post covered the fundamentals of pruning: the criteria for deciding which weights to remove. But we have not yet addressed a critical distinction: unstructured vs structured pruning. Unstructured pruning removes individual weights anywhere in the network, while structured pruning removes entire neurons, filters, or attention heads. This distinction has profound implications for hardware efficiency and practical deployment.
In the next post, we will explore structured pruning in detail — how to prune entire channels and filters, the group sparsity framework, and why structured pruning is often preferred in practice despite its less favorable accuracy-sparsity trade-off.
References#
- LeCun, Y., Denker, J.S., & Solla, S.A. (1990). Optimal Brain Damage. NeurIPS.
- Hassibi, B., & Stork, D.G. (1993). Second Order Derivatives for Network Pruning: Optimal Brain Surgeon. NeurIPS.
- Frankle, J., & Carbin, M. (2019). The Lottery Ticket Hypothesis: Finding Sparse, Trainable Neural Networks. ICLR.
- Frankle, J., Dziugaite, G.K., Roy, D.M., & Carbin, M. (2020). Linear Mode Connectivity and the Lottery Ticket Hypothesis. ICML.
- Ramanujan, V., et al. (2020). What’s Hidden in a Randomly Weighted Neural Network? CVPR.
- Malach, E., et al. (2020). Proving the Lottery Ticket Hypothesis: Pruning is All You Need. ICML.
- Zhu, M., & Gupta, S. (2017). To Prune, or Not to Prune: Exploring the Efficacy of Pruning for Model Compression. NeurIPS Workshop.
- Lee, N., Ajanthan, T., & Torr, P.H.S. (2019). SNIP: Single-shot Network Pruning based on Connection Sensitivity. ICLR.
- Wang, C., Zhang, G., & Grosse, R. (2020). Picking Winning Tickets Before Training by Preserving Gradient Flow. ICLR.
- Tanaka, H., et al. (2020). Pruning Neural Networks without Any Data by Iteratively Conserving Synaptic Flow. NeurIPS.
- Sanh, V., Wolf, T., & Rush, A.M. (2020). Movement Pruning: Adaptive Sparsity during Fine-Tuning. NeurIPS.
- Mocanu, D.C., et al. (2018). Scalable Training of Artificial Neural Networks with Adaptive Sparse Connectivity Inspired by Network Science. Nature Communications.
- Evci, U., et al. (2020). Rigging the Lottery: Making All Tickets Winners. ICML.
- Frantar, E., et al. (2023). GPTQ: Accurate Post-Training Quantization for Generative Pre-Trained Transformers. ICLR.
- Jayakumar, S., et al. (2020). Top-KAST: Top-K Always Sparse Training. NeurIPS.