Overview#
Neural network pruning removes redundant parameters to produce smaller, faster models. The central question is not whether to prune, but how to prune – and the answer determines whether you get a real-world speedup or merely a theoretical one.
The Fundamental Tradeoff: Flexibility vs Hardware Efficiency#
Pruning methods span a spectrum defined by two opposing forces:
- Flexibility: the freedom to remove any individual weight regardless of its position in a tensor. More flexibility means higher sparsity at the same accuracy, because the optimizer can cherry-pick the least important parameters wherever they are.
- Hardware efficiency: the ability of actual processors to exploit the resulting sparsity. Modern GPUs, CPUs, and accelerators are optimized for dense, regular memory access patterns. The more structured the sparsity pattern, the easier it is for hardware to convert zero parameters into saved computation cycles.
This tradeoff is the single most important concept in pruning research:
Flexibility                                            Hardware Efficiency
(Accuracy)                                                 (Real Speedup)
|                                                                       |
|  Unstructured   N:M     Block        Channel   Filter   Layer         |
|  (individual)   (2:4)   (4x1, 2x4)                                    |
|  ================================================================>    |
|                                                                       |
|  Fine-grained  <--------------------------------->  Coarse-grained    |

Why Structure Matters for Real-World Speedup#
Consider a dense matrix multiply \(Y = XW\) where \(W \in \mathbb{R}^{m \times n}\). If we prune 90% of the weights in \(W\) at random positions:
- The number of nonzero multiply-accumulate operations drops to 10%.
- But the surviving weights are scattered across \(W\), so they cannot be packed into a smaller contiguous dense block.
- The GPU still allocates the full \(m \times n\) tensor in memory.
- Sparse indexing introduces overhead for every nonzero access.
- The Tensor Cores cannot be used because the operands are not dense tiles.
The result: theoretical 10x FLOPs reduction, actual 1.0-1.2x speedup on a standard GPU.
Now consider pruning 50% of the filters (entire columns of \(W\), since each output unit of \(Y = XW\) corresponds to one column). The resulting weight matrix is \(W' \in \mathbb{R}^{m \times n/2}\). This is a perfectly dense matrix – just smaller. The matrix multiply becomes \(Y' = XW'\), which runs at full Tensor Core efficiency on a smaller problem. The result: theoretical 2x FLOPs reduction, actual 1.8-1.95x speedup.
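The structured case can be sketched in a few lines of NumPy. This is an illustration only: it assumes the convention \(Y = XW\), in which each output unit is one column of the weight matrix, and the shapes and the `keep` indices are made up for the example.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((8, 6))    # batch of 8 inputs with 6 features
W = rng.standard_normal((6, 4))    # 4 output units, one per column

keep = np.array([0, 2])            # hypothetical choice: keep output units 0 and 2
W_small = W[:, keep]               # dense (6, 2) matrix, no zeros stored

Y_full = X @ W                     # original dense product, shape (8, 4)
Y_small = X @ W_small              # smaller dense product, shape (8, 2)
```

The pruned product is an ordinary dense matmul on a smaller operand, and its result matches the kept columns of the original output.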
Taxonomy of Pruning Granularity#
We classify pruning granularity from finest to coarsest:
| Granularity | Unit Removed | Sparsity Pattern | Hardware Friendly | Typical Sparsity |
|---|---|---|---|---|
| Unstructured | Single weight | Irregular | No (needs special HW) | 90-99% |
| N:M Sparsity | M-N of every M elements | Semi-structured | Yes (Ampere+) | 50% (2:4) |
| Block Sparse | k x k weight block | Regular blocks | Moderate | 50-90% |
| Channel | Input channel slice | Structured | Yes | 30-70% |
| Filter/Neuron | Output filter/neuron | Structured | Yes | 30-70% |
| Attention Head | Entire QKV head | Structured | Yes | 20-50% |
| Layer | Entire layer | Structured | Yes | 10-30% |
Unstructured (Fine-Grained) Pruning#
Definition and Formulation#
Unstructured pruning removes individual weights from a parameter tensor. Given a weight matrix \(W \in \mathbb{R}^{m \times n}\), we compute a binary mask \(M \in \{0, 1\}^{m \times n}\) and apply it element-wise:
$$W_{\text{pruned}} = W \odot M$$where \(\odot\) denotes the Hadamard (element-wise) product. The mask \(M\) is determined by some importance criterion. The simplest and most widely used is magnitude pruning:
$$M_{ij} = \begin{cases} 1 & \text{if } |W_{ij}| \geq \theta \\ 0 & \text{if } |W_{ij}| < \theta \end{cases}$$where \(\theta\) is a threshold chosen to achieve the desired sparsity level \(s\):
$$s = 1 - \frac{||M||_0}{m \cdot n}$$Here \(||M||_0\) counts the number of nonzero entries in \(M\).
Why Unstructured Pruning Achieves the Highest Sparsity#
The key insight is degrees of freedom. With \(m \times n\) independent binary decisions, the optimizer can pick the globally least important weights. For a given accuracy target, unstructured pruning always achieves equal or higher sparsity than any structured method, because structured methods impose additional constraints on which weights must be removed together.
Formally, let \(\mathcal{M}_u\) be the set of all possible unstructured masks and \(\mathcal{M}_s \subset \mathcal{M}_u\) be the set of structured masks. The optimal unstructured mask solves:
$$M^*_u = \arg\min_{M \in \mathcal{M}_u} \mathcal{L}(W \odot M) \quad \text{s.t.} \quad ||M||_0 \leq k$$

Since \(\mathcal{M}_s \subset \mathcal{M}_u\), we have \(\mathcal{L}(W \odot M^*_u) \leq \mathcal{L}(W \odot M^*_s)\) for any sparsity budget \(k\), where \(M^*_s\) is the best mask under the same budget restricted to \(\mathcal{M}_s\). In practice, unstructured methods routinely reach 90-98% sparsity with less than 1% accuracy loss, while structured methods often struggle beyond 50-70%.
The Sparsity Illusion: 95% Sparse but No Speedup on GPU#
This is the most common trap in pruning research. A paper reports “95% sparsity with only 0.5% accuracy drop” and claims massive compression. But when deployed:
Dense Model:
W = [0.3 0.0 0.7 0.1] Memory: 4 x 4 x 4B = 64 bytes
[0.0 0.5 0.0 0.0] GEMM: Dense 4x4, fully pipelined
[0.2 0.0 0.0 0.8] Tensor Cores: YES
[0.0 0.0 0.6 0.0] Actual speed: baseline
After Unstructured Pruning (11 of 16 entries zero, ~69% sparse):
W = [0.3 0.0 0.7 0.0] Memory: still 64 bytes (or more with index)
[0.0 0.5 0.0 0.0] GEMM: Sparse, irregular access
[0.0 0.0 0.0 0.8] Tensor Cores: NO
[0.0 0.0 0.6 0.0] Actual speed: ~1.0x (no speedup!)

The problem is architectural. Modern GPUs execute matrix multiplications by:
- Loading tiles (e.g., 16x16) from global memory into shared memory.
- Computing the tile product using Tensor Cores (dense fused multiply-add).
- Writing the output tile back.
Every step assumes dense, contiguous data. A sparse matrix with 95% zeros still occupies the same memory footprint unless converted to a sparse format. Even in sparse formats, the irregular access patterns prevent coalesced memory reads and Tensor Core utilization.
Irregular Memory Access Patterns#
Consider a simple sparse matrix-vector multiply \(y = Wx\) with \(W\) stored in Compressed Sparse Row (CSR) format:
Dense W: CSR Representation:
[0.3 0.0 0.7 0.0] values: [0.3, 0.7, 0.5, 0.8, 0.6]
[0.0 0.5 0.0 0.0] col_idx: [0, 2, 1, 3, 2 ]
[0.0 0.0 0.0 0.8] row_ptr: [0, 2, 3, 4, 5 ]
[0.0 0.0 0.6 0.0]
Memory access for row 0: x[0], x[2] (stride=2, non-contiguous)
Memory access for row 1: x[1] (single element)
Memory access for row 2: x[3] (single element)
Memory access for row 3: x[2] (single element)

Each row accesses different, unpredictable locations in \(x\). This is the opposite of what GPUs need (coalesced, predictable access). Cache lines are loaded but only partially used, wasting memory bandwidth.
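To make the indirect access concrete, here is a minimal pure-Python sketch of a CSR matrix-vector product using the arrays from the example above. The function name `csr_matvec` and the input vector `x` are assumptions for illustration; a real kernel would be written very differently.

```python
def csr_matvec(values, col_idx, row_ptr, x):
    """Compute y = W @ x for W stored in CSR format."""
    y = []
    for row in range(len(row_ptr) - 1):
        acc = 0.0
        # Nonzeros of this row live in values[row_ptr[row]:row_ptr[row + 1]]
        for k in range(row_ptr[row], row_ptr[row + 1]):
            acc += values[k] * x[col_idx[k]]   # indirect, data-dependent read of x
        y.append(acc)
    return y

values = [0.3, 0.7, 0.5, 0.8, 0.6]
col_idx = [0, 2, 1, 3, 2]
row_ptr = [0, 2, 3, 4, 5]
x = [1.0, 2.0, 3.0, 4.0]
y = csr_matvec(values, col_idx, row_ptr, x)    # approximately [2.4, 1.0, 3.2, 1.8]
```

The read `x[col_idx[k]]` is the source of the problem: the address depends on data loaded a moment earlier, so neighboring threads cannot be assigned neighboring addresses in advance.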
Sparse Matrix Storage Overhead Analysis#
Let us quantify the storage overhead. For a dense matrix \(W \in \mathbb{R}^{m \times n}\) with sparsity \(s\) (fraction of zeros):
Dense storage: \(m \times n \times b\) bytes, where \(b\) is bytes per element (4 for FP32, 2 for FP16).
CSR storage:
- Values array: \((1-s) \cdot m \cdot n \cdot b\) bytes
- Column indices: \((1-s) \cdot m \cdot n \cdot 4\) bytes (INT32)
- Row pointers: \((m+1) \cdot 4\) bytes
Total CSR: \((1-s) \cdot m \cdot n \cdot (b + 4) + (m+1) \cdot 4\)
The break-even point where CSR becomes smaller than dense occurs at:
$$(1-s) \cdot m \cdot n \cdot (b + 4) + (m+1) \cdot 4 < m \cdot n \cdot b$$

Solving for \(s\) (ignoring the row pointer term for large matrices):

$$(1-s)(b+4) < b$$

$$b + 4 - sb - 4s < b$$

$$4 < s(b+4)$$

$$s > \frac{4}{b+4}$$

- FP32 (\(b=4\)): \(s > 0.5\) (50% sparsity needed just to break even on storage).
- FP16 (\(b=2\)): \(s > 0.667\) (67% sparsity needed).
- INT8 (\(b=1\)): \(s > 0.8\) (80% sparsity needed).
This shows that sparse formats become less attractive as element size decreases – precisely when quantization is also applied.
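The break-even figures can be reproduced with a short calculation. The sketch below assumes 4-byte (INT32) indices as in the analysis above; the function names are chosen for this example only.

```python
def csr_bytes(m, n, sparsity, elem_bytes, index_bytes=4):
    """Approximate CSR size in bytes for an m x n matrix with the given fraction of zeros."""
    nnz = round((1.0 - sparsity) * m * n)
    return nnz * (elem_bytes + index_bytes) + (m + 1) * index_bytes

def dense_bytes(m, n, elem_bytes):
    """Dense size in bytes for an m x n matrix."""
    return m * n * elem_bytes

def break_even_sparsity(elem_bytes, index_bytes=4):
    """Sparsity above which CSR beats dense storage, ignoring the row-pointer array."""
    return index_bytes / (elem_bytes + index_bytes)

for name, b in [("FP32", 4), ("FP16", 2), ("INT8", 1)]:
    print(name, round(break_even_sparsity(b), 3))
```

Smaller index types (for example 16-bit column indices on narrow matrices) would lower the break-even point; that variation is not covered here.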
Numerical Example: Pruning a 4x4 Matrix#
Consider a fully connected layer with \(W \in \mathbb{R}^{4 \times 4}\):
$$W = \begin{bmatrix} 0.82 & -0.15 & 0.91 & 0.03 \\ -0.07 & 0.68 & -0.11 & 0.44 \\ 0.23 & -0.02 & -0.05 & 0.77 \\ -0.38 & 0.01 & 0.56 & -0.09 \end{bmatrix}$$Step 1: Compute magnitudes:
$$|W| = \begin{bmatrix} 0.82 & 0.15 & 0.91 & 0.03 \\ 0.07 & 0.68 & 0.11 & 0.44 \\ 0.23 & 0.02 & 0.05 & 0.77 \\ 0.38 & 0.01 & 0.56 & 0.09 \end{bmatrix}$$Step 2: Sort all 16 magnitudes: 0.01, 0.02, 0.03, 0.05, 0.07, 0.09, 0.11, 0.15, 0.23, 0.38, 0.44, 0.56, 0.68, 0.77, 0.82, 0.91
Step 3: For 50% sparsity, prune the 8 smallest. Under the rule above (keep \(|W_{ij}| \geq \theta\)), the threshold is \(\theta = 0.23\), the 9th sorted value, so every entry with magnitude \(< 0.23\) is pruned, including \(-0.15\):

$$M = \begin{bmatrix} 1 & 0 & 1 & 0 \\ 0 & 1 & 0 & 1 \\ 1 & 0 & 0 & 1 \\ 1 & 0 & 1 & 0 \end{bmatrix}$$

$$W_{\text{pruned}} = \begin{bmatrix} 0.82 & 0 & 0.91 & 0 \\ 0 & 0.68 & 0 & 0.44 \\ 0.23 & 0 & 0 & 0.77 \\ -0.38 & 0 & 0.56 & 0 \end{bmatrix}$$

Notice the irregular pattern: nonzeros are scattered with no spatial regularity. This matrix cannot be represented as a smaller dense matrix.
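A minimal NumPy sketch of magnitude pruning on the same matrix. Instead of comparing against a threshold, it selects the 8 smallest magnitudes by index, which is an implementation choice made here to hit the 50% target exactly and to avoid ambiguity when magnitudes tie; `magnitude_mask` is a name invented for this example.

```python
import numpy as np

def magnitude_mask(W, sparsity):
    """Binary mask that zeros the `sparsity` fraction of entries with smallest |W|."""
    n_prune = int(round(sparsity * W.size))
    order = np.argsort(np.abs(W), axis=None)   # flat indices, smallest magnitude first
    mask = np.ones(W.size, dtype=W.dtype)
    mask[order[:n_prune]] = 0
    return mask.reshape(W.shape)

W = np.array([[ 0.82, -0.15,  0.91,  0.03],
              [-0.07,  0.68, -0.11,  0.44],
              [ 0.23, -0.02, -0.05,  0.77],
              [-0.38,  0.01,  0.56, -0.09]])
M = magnitude_mask(W, sparsity=0.5)
W_pruned = W * M                               # 8 of the 16 entries survive
```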
When Unstructured Pruning Works: Specialized Hardware#
Unstructured pruning becomes practical on hardware designed for sparsity:
- Cerebras WSE-2/3: The wafer-scale engine has a dataflow architecture where each processing element can skip zero operands natively. Unstructured sparsity directly reduces compute.
- NVIDIA Sparse Tensor Cores (Ampere+): Support N:M structured sparsity (a constrained form), not fully unstructured.
- Graphcore IPU: Can exploit some levels of sparsity through its bulk synchronous parallel model.
- CPUs with branch-based kernels: For small models, CPU inference can use conditional branches to skip zero multiplications, though branch misprediction limits the benefit.
Structured Pruning – Detailed Taxonomy#
Filter/Kernel Pruning (Coarse-Grained)#
Filter pruning is the most widely used form of structured pruning for convolutional networks. It removes entire 3D filters from a convolutional layer, resulting in a smaller but fully dense layer.
Setup: Consider a convolutional layer with weight tensor \(W \in \mathbb{R}^{C_{out} \times C_{in} \times k_h \times k_w}\), where:
- \(C_{out}\): number of output channels (filters)
- \(C_{in}\): number of input channels
- \(k_h \times k_w\): kernel spatial dimensions
Each filter \(F_i \in \mathbb{R}^{C_{in} \times k_h \times k_w}\) for \(i = 1, \ldots, C_{out}\) produces one output feature map.
L1-Norm Filter Pruning (Li et al., 2017)#
The importance of filter \(i\) is measured by the sum of absolute values of all its parameters:
$$\text{score}(F_i) = ||F_i||_1 = \sum_{c=1}^{C_{in}} \sum_{k_1=1}^{k_h} \sum_{k_2=1}^{k_w} |F_i(c, k_1, k_2)|$$Derivation of why L1-norm is a reasonable proxy for importance:
The output of filter \(i\) at spatial location \((x, y)\) is:
$$Z_i(x, y) = \sum_{c=1}^{C_{in}} \sum_{k_1=1}^{k_h} \sum_{k_2=1}^{k_w} F_i(c, k_1, k_2) \cdot A(c, x+k_1, y+k_2)$$

where \(A\) is the input activation tensor. If the input activations are identically distributed (roughly zero mean and unit variance after BatchNorm), the triangle inequality bounds the expected magnitude of \(Z_i(x,y)\):

$$\mathbb{E}[|Z_i(x,y)|] \leq ||F_i||_1 \cdot \mathbb{E}[|A|]$$

A filter with a smaller L1-norm therefore has a lower ceiling on its expected activation magnitude and tends to contribute less to the network’s representational capacity. Removing it should cause less damage to the output.
Algorithm:
- For each layer \(l\), compute \(\text{score}(F_i^{(l)})\) for all \(i\).
- Sort filters by score within each layer (or globally).
- Remove the bottom \(p\%\) of filters per layer (or apply a global threshold).
- Remove corresponding structures in subsequent layers.
- Fine-tune the pruned network.
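The scoring and removal steps above can be sketched with NumPy for a single pair of consecutive conv layers. This is a simplified illustration: the function name is made up, BatchNorm parameters and fine-tuning are left out, and pruning is done per layer rather than globally.

```python
import numpy as np

def prune_filters_l1(W_l, b_l, W_next, prune_ratio):
    """Drop the lowest-L1 filters of layer l and the matching input channels of layer l+1.

    W_l:    (C_out, C_in, k_h, k_w) weights of layer l
    b_l:    (C_out,) bias of layer l
    W_next: (C_out_next, C_out, k_h', k_w') weights of layer l+1
    """
    scores = np.abs(W_l).sum(axis=(1, 2, 3))          # L1 norm per filter
    n_prune = int(prune_ratio * W_l.shape[0])
    keep = np.sort(np.argsort(scores)[n_prune:])      # surviving filters, original order
    return W_l[keep], b_l[keep], W_next[:, keep], keep

rng = np.random.default_rng(0)
W_l = rng.standard_normal((4, 3, 3, 3))
W_l[1] *= 0.01                                        # make filter 1 clearly unimportant
b_l = np.zeros(4)
W_next = rng.standard_normal((5, 4, 3, 3))
W_l_p, b_l_p, W_next_p, keep = prune_filters_l1(W_l, b_l, W_next, prune_ratio=0.25)
```

Both returned weight tensors are plain dense arrays with one fewer channel, so they can be loaded into standard conv layers of the reduced size.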
Geometric Median Filter Pruning (He et al., 2019)#
Instead of pruning the smallest filters, this method prunes filters that are most replaceable – those closest to the geometric median of all filters. The geometric median minimizes the sum of distances:
$$F_{\text{gm}} = \arg\min_{F} \sum_{i=1}^{C_{out}} ||F - F_i||_2$$The most replaceable filter \(j\) is:
$$j = \arg\min_{i} \sum_{k \neq i} ||F_i - F_k||_2$$Intuition: If filter \(i\) is close to many other filters, its function can be approximated by a linear combination of the remaining filters. Removing it causes minimal information loss. This is particularly useful when many filters have similar L1-norms but encode redundant features.
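A small NumPy sketch of the selection rule in the last equation: rank filters by their summed L2 distance to all other filters and pick the smallest. The function name and the toy filters are assumptions for illustration, and the pairwise-distance tensor used here would be memory-hungry for very wide layers.

```python
import numpy as np

def most_replaceable_filters(W, n_prune):
    """Indices of the n_prune filters with the smallest summed L2 distance to all others."""
    flat = W.reshape(W.shape[0], -1)                  # one row per filter
    diff = flat[:, None, :] - flat[None, :, :]        # pairwise differences
    dist = np.linalg.norm(diff, axis=-1)              # (C_out, C_out) distance matrix
    total = dist.sum(axis=1)                          # diagonal is zero, so k != i is implicit
    return np.argsort(total)[:n_prune]

# Three tiny filters lying on a line: the middle one is closest to the rest
W = np.array([[0.0, 0.0], [1.0, 0.0], [5.0, 0.0]]).reshape(3, 1, 1, 2)
idx = most_replaceable_filters(W, n_prune=1)          # selects filter 1
```

Note that the selected filter is not the one with the smallest norm (that would be filter 0), which is exactly the difference from norm-based ranking.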
Effect on Layer Dimensions#
Removing filter \(i\) from layer \(l\) has cascading effects:
BEFORE PRUNING:
Layer l: W^(l) in R^{C_out x C_in x k x k}
bias^(l) in R^{C_out}
BN^(l): gamma, beta, running_mean, running_var in R^{C_out}
Layer l+1: W^(l+1) in R^{C_out' x C_out x k' x k'}
AFTER REMOVING FILTER i FROM LAYER l:
Layer l: W^(l) in R^{(C_out-1) x C_in x k x k} [row i removed]
bias^(l) in R^{(C_out-1)} [element i removed]
BN^(l): all params in R^{(C_out-1)} [element i removed]
Layer l+1: W^(l+1) in R^{C_out' x (C_out-1) x k' x k'} [channel i removed]

This is the fundamental property of structured pruning: removing a filter from layer \(l\) changes the shape of two layers (\(l\) and \(l+1\)), but both remain fully dense tensors.
ASCII Diagram: Before and After Filter Pruning#
BEFORE: Conv Layer l (C_out=4, C_in=3, k=3x3)
Filter 0: [3x3x3] ----+
Filter 1: [3x3x3] ----+----> Output: [4 x H' x W']
Filter 2: [3x3x3] ----+ (4 output channels)
Filter 3: [3x3x3] ----+
Next Layer l+1 expects 4 input channels:
W^(l+1) shape: [C_out' x 4 x k' x k']
---------- Prune Filter 1 and Filter 3 (50% pruning) ----------
AFTER: Conv Layer l (C_out=2, C_in=3, k=3x3)
Filter 0: [3x3x3] ----+
+----> Output: [2 x H' x W']
Filter 2: [3x3x3] ----+ (2 output channels)
Next Layer l+1 now expects 2 input channels:
W^(l+1) shape: [C_out' x 2 x k' x k']
Result: Layer l is 50% smaller, Layer l+1 input dimension halved
Both remain DENSE tensors -> full hardware utilization

Full Numerical Example with a Small Conv Layer#
Consider a tiny conv layer: \(C_{out}=3, C_{in}=2, k=2\times 2\).
Filter 0:

$$F_0 = \begin{bmatrix} \begin{bmatrix} 0.5 & 0.3 \\ 0.1 & 0.2 \end{bmatrix}, \begin{bmatrix} -0.4 & 0.6 \\ 0.7 & -0.1 \end{bmatrix} \end{bmatrix}$$

Filter 1:

$$F_1 = \begin{bmatrix} \begin{bmatrix} 0.02 & -0.01 \\ 0.03 & -0.05 \end{bmatrix}, \begin{bmatrix} 0.04 & 0.01 \\ -0.02 & 0.06 \end{bmatrix} \end{bmatrix}$$

Filter 2:

$$F_2 = \begin{bmatrix} \begin{bmatrix} 0.8 & -0.3 \\ 0.4 & 0.9 \end{bmatrix}, \begin{bmatrix} -0.7 & 0.5 \\ 0.2 & 0.6 \end{bmatrix} \end{bmatrix}$$

Compute L1-norm scores:

$$\text{score}(F_0) = |0.5|+|0.3|+|0.1|+|0.2|+|-0.4|+|0.6|+|0.7|+|-0.1| = 2.9$$

$$\text{score}(F_1) = |0.02|+|-0.01|+|0.03|+|-0.05|+|0.04|+|0.01|+|-0.02|+|0.06| = 0.24$$

$$\text{score}(F_2) = |0.8|+|-0.3|+|0.4|+|0.9|+|-0.7|+|0.5|+|0.2|+|0.6| = 4.4$$

Ranking: \(F_1 (0.24) < F_0 (2.9) < F_2 (4.4)\)
Pruning: Remove \(F_1\) (lowest L1-norm). The pruned layer has \(C_{out}=2\), \(C_{in}=2\), \(k=2\times 2\). This is a standard dense conv layer that any framework can execute efficiently.
Channel Pruning#
Channel pruning removes input channels rather than output filters. While filter pruning operates on the output dimension, channel pruning operates on the input dimension of a weight tensor.
Channel Pruning via LASSO Regression (He et al., 2017)#
The goal is to select a subset of input channels that best reconstruct the output feature maps. For a layer with input \(X \in \mathbb{R}^{N \times C_{in} \times H \times W}\) (batch of activations) and filters \(W \in \mathbb{R}^{C_{out} \times C_{in} \times k \times k}\):
The output is \(Y = \sum_{c=1}^{C_{in}} X_c * W_c\) where \(*\) denotes convolution and \(X_c, W_c\) are the \(c\)-th channel slices.
Channel pruning introduces a channel selection vector \(\beta \in \{0,1\}^{C_{in}}\):
$$Y \approx \sum_{c=1}^{C_{in}} \beta_c \cdot X_c * W'_c$$The optimization problem is:
$$\min_{\beta, W'} \left\| Y - \sum_{c=1}^{C_{in}} \beta_c \cdot X_c * W'_c \right\|_F^2 \quad \text{s.t.} \quad ||\beta||_0 \leq C_{in} \cdot (1-s)$$

Since the \(\ell_0\) constraint makes the problem NP-hard, it is relaxed to an \(\ell_1\) penalty (LASSO):

$$\min_{\beta, W'} \left\| Y - \sum_{c=1}^{C_{in}} \beta_c \cdot X_c * W'_c \right\|_F^2 + \lambda ||\beta||_1$$

Derivation of the LASSO solution (for fixed \(W' = W\)):

Reformulate as a standard LASSO. Let \(Z_c = X_c * W_c \in \mathbb{R}^{N \times H' \times W'}\) be the contribution of channel \(c\). Flatten \(Y\) into a vector \(y\) and each \(Z_c\) into a vector \(z_c\), forming the matrix \(Z = [z_1, z_2, \ldots, z_{C_{in}}]\):

$$\min_{\beta} ||y - Z\beta||_2^2 + \lambda||\beta||_1$$

Taking the subgradient, the optimality condition is:

$$0 \in -2Z^T(y - Z\beta) + \lambda \partial ||\beta||_1$$

For each coordinate \(c\), with the other coordinates held fixed, the solution is the soft-thresholding operator:

$$\hat{\beta}_c = \text{sign}(r_c) \cdot \max\left(|r_c| - \frac{\lambda}{2||z_c||_2^2}, 0\right)$$

where \(r_c = z_c^T (y - Z_{-c}\beta_{-c}) / ||z_c||_2^2\) is the least-squares coefficient of channel \(c\) on the partial residual. Cycling over the coordinates until convergence (coordinate descent) solves the problem.
Channels with \(\hat{\beta}_c = 0\) are pruned.
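The coordinate-wise update can be sketched as a small coordinate-descent loop in NumPy. This is an illustration under stated assumptions: the weights are held fixed, the per-channel contributions are already flattened into the columns of a matrix, and the synthetic data, the value of `lam`, and the function names are invented for the example. The threshold is scaled by the squared column norm because the columns are not normalized.

```python
import numpy as np

def soft_threshold(v, t):
    """Shrink v toward zero by t, clamping at zero."""
    return np.sign(v) * np.maximum(np.abs(v) - t, 0.0)

def lasso_channel_select(Z, y, lam, n_iter=100):
    """Coordinate descent for min_beta ||y - Z beta||_2^2 + lam * ||beta||_1.

    Z: (n_samples, C_in) matrix whose column c is the flattened contribution of channel c.
    """
    n_channels = Z.shape[1]
    beta = np.zeros(n_channels)
    col_sq = (Z ** 2).sum(axis=0)                      # squared norm of each column
    for _ in range(n_iter):
        for c in range(n_channels):
            if col_sq[c] == 0.0:
                continue                               # channel contributes nothing
            residual = y - Z @ beta + Z[:, c] * beta[c]    # residual without channel c
            r_c = Z[:, c] @ residual / col_sq[c]
            beta[c] = soft_threshold(r_c, lam / (2.0 * col_sq[c]))
    return beta

rng = np.random.default_rng(0)
Z = rng.standard_normal((200, 4))
Z[:, 3] *= 0.01                                        # channel 3 barely contributes
y = Z.sum(axis=1)                                      # original output uses every channel
beta = lasso_channel_select(Z, y, lam=1.0)
pruned = np.flatnonzero(beta == 0.0)                   # channels selected for removal
```

In this toy setup the weak channel is driven exactly to zero while the other coefficients stay close to 1, shrunk slightly by the penalty.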
ThiNet: Pruning Channels Based on Next Layer’s Statistics#
ThiNet (Luo et al., 2017) takes a different approach: instead of analyzing the current layer, it selects channels to prune based on how well the next layer’s output can be reconstructed.
For layer \(l+1\) with weights \(W^{(l+1)}\), the output at a single spatial position is:
$$y = \sum_{c=1}^{C} \sum_{i=1}^{k} \sum_{j=1}^{k} W^{(l+1)}_{:,c,i,j} \cdot x_{c,i,j}$$

ThiNet uses a greedy algorithm to find the subset \(S \subset \{1, \ldots, C\}\) with \(|S| = C \cdot (1-s)\) that minimizes:

$$\min_S \sum_{\text{samples}} \left\| y - \sum_{c \in S} \sum_{i,j} W^{(l+1)}_{:,c,i,j} \cdot x_{c,i,j} \right\|^2$$

Relationship Between Filter and Channel Pruning (Duality)#
Filter pruning on layer \(l\) removes rows from \(W^{(l)}\) and columns from \(W^{(l+1)}\). Channel pruning on layer \(l\) removes columns from \(W^{(l)}\) and rows from \(W^{(l-1)}\). They are dual operations:
Layer l-1 Layer l Layer l+1
[C_out^{l-1} x C_in^{l-1}] [C_out^l x C_in^l] [C_out^{l+1} x C_in^{l+1}]
Filter pruning on l: removes row of W^l -> removes col of W^{l+1}
Channel pruning on l: removes col of W^l <- removes row of W^{l-1}
Filter pruning on layer l
= Channel pruning on layer l+1
(in terms of effect on W^{l+1})

Neuron/Head Pruning#
FC Layer Neuron Pruning#
For a fully connected layer \(y = Wx + b\) with \(W \in \mathbb{R}^{n \times m}\), pruning neuron \(i\) means removing row \(i\) of \(W\), element \(i\) of \(b\), and column \(i\) of the next layer’s weight matrix. This is mathematically identical to filter pruning but for FC layers.
Attention Head Pruning in Transformers (Michel et al., 2019)#
A multi-head attention layer computes:
$$\text{MultiHead}(Q, K, V) = \text{Concat}(\text{head}_1, \ldots, \text{head}_H) W^O$$where \(\text{head}_h = \text{Attention}(QW^Q_h, KW^K_h, VW^V_h)\).
Head importance score derivation:
Define a mask variable \(\xi_h \in \{0, 1\}\) for each head:

$$\text{MultiHead}(Q, K, V) = \sum_{h=1}^{H} \xi_h \cdot \text{head}_h \cdot W^O_h$$

The importance of head \(h\) is the expected absolute sensitivity of the loss to masking it:

$$I_h = \mathbb{E}_{x \sim \mathcal{D}} \left| \frac{\partial \mathcal{L}(x)}{\partial \xi_h} \right|$$

Write \(\text{Attn}_h = \text{head}_h \cdot W^O_h\) for the contribution of head \(h\), so that the layer output is \(\sum_h \xi_h \text{Attn}_h\). By the chain rule:

$$\frac{\partial \mathcal{L}}{\partial \xi_h} = \frac{\partial \mathcal{L}}{\partial (\xi_h \text{Attn}_h)} \cdot \frac{\partial (\xi_h \text{Attn}_h)}{\partial \xi_h} = \frac{\partial \mathcal{L}}{\partial \text{Attn}_h} \cdot \text{Attn}_h$$

where the last step evaluates the derivative at \(\xi_h = 1\) and the product is an inner product over the output dimensions. Therefore:

$$I_h = \mathbb{E}_{x} \left| \frac{\partial \mathcal{L}}{\partial \text{Attn}_h} \cdot \text{Attn}_h \right|$$

Taking the absolute value inside the expectation keeps inputs with sensitivities of opposite sign from cancelling each other. In practice, the expectation is estimated over a validation set. Michel et al. found that in BERT-base (12 layers, 12 heads each = 144 heads), up to 40% of heads can be pruned with less than 1% accuracy drop on many NLP benchmarks.
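A toy NumPy sketch of this score. It assumes the per-head contributions and the gradient of the loss with respect to the layer output are already available (in a real model they would come from autodiff); the squared-error loss, the random data, and the function name are made up for illustration. The sketch takes the absolute value per sample before averaging, so sensitivities of opposite sign do not cancel.

```python
import numpy as np

def head_importance(contrib, grad_out):
    """Per-head importance: mean over samples of |<dL/d(out), Attn_h>|.

    contrib:  (n_samples, n_heads, d_model) per-head contributions Attn_h
    grad_out: (n_samples, d_model) gradient of the loss w.r.t. the layer output
    """
    # dL/d(xi_h) at xi_h = 1 is the inner product of the output gradient with Attn_h
    sens = np.einsum("nd,nhd->nh", grad_out, contrib)
    return np.abs(sens).mean(axis=0)

rng = np.random.default_rng(0)
contrib = rng.standard_normal((32, 4, 8))
contrib[:, 2] *= 1e-3                      # head 2 contributes almost nothing
out = contrib.sum(axis=1)                  # layer output with all masks xi_h = 1
target = rng.standard_normal((32, 8))
grad_out = out - target                    # gradient of 0.5 * ||out - target||^2 (toy loss)
scores = head_importance(contrib, grad_out)
least_important = int(np.argmin(scores))   # head 2 in this setup
```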
ASCII Diagram: Transformer Block Before/After Head Pruning#
BEFORE: Multi-Head Attention (H=8 heads, d_model=512, d_k=64)
Input (512-dim)
|
+--[W_Q0, W_K0, W_V0]---> Head 0 (d_k=64)---+
+--[W_Q1, W_K1, W_V1]---> Head 1 (d_k=64)---+
+--[W_Q2, W_K2, W_V2]---> Head 2 (d_k=64)---+
+--[W_Q3, W_K3, W_V3]---> Head 3 (d_k=64)---+--Concat--> W_O --> 512-dim
+--[W_Q4, W_K4, W_V4]---> Head 4 (d_k=64)---+ (512x512)
+--[W_Q5, W_K5, W_V5]---> Head 5 (d_k=64)---+
+--[W_Q6, W_K6, W_V6]---> Head 6 (d_k=64)---+
+--[W_Q7, W_K7, W_V7]---> Head 7 (d_k=64)---+
Parameters: 3 * (512*64) * 8 + 512*512 = 786,432 + 262,144 = 1,048,576
---------- Prune heads {1, 3, 5, 7} (50% head pruning) ----------
AFTER: Multi-Head Attention (H=4 heads, d_model=512, d_k=64)
Input (512-dim)
|
+--[W_Q0, W_K0, W_V0]---> Head 0 (d_k=64)---+
+--[W_Q2, W_K2, W_V2]---> Head 2 (d_k=64)---+--Concat--> W_O --> 512-dim
+--[W_Q4, W_K4, W_V4]---> Head 4 (d_k=64)---+ (256x512)
+--[W_Q6, W_K6, W_V6]---> Head 6 (d_k=64)---+
Parameters: 3 * (512*64) * 4 + 256*512 = 393,216 + 131,072 = 524,288
Reduction: 50% of attention parameters

Layer Pruning#
Layer pruning removes entire layers from deep networks. This is the coarsest form of structured pruning.
Layer Importance Estimation#
For a network \(f = f_L \circ f_{L-1} \circ \cdots \circ f_1\), the importance of layer \(l\) can be estimated as the increase in loss when the layer is bypassed:
$$I_l = \mathcal{L}(f \text{ without } f_l) - \mathcal{L}(f)$$For networks with residual connections, “without \(f_l\)” means replacing \(x_{l+1} = x_l + f_l(x_l)\) with \(x_{l+1} = x_l\) (identity shortcut).
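The bypass-based estimate can be illustrated with a toy residual stack in NumPy. Everything here is an assumption for the sake of a runnable example: the blocks are single `tanh` layers, the loss is mean squared error, and the target is generated by the full network so that its baseline loss is zero.

```python
import numpy as np

def forward(x, blocks, skip=None):
    """Residual stack x <- x + tanh(x @ W); block `skip` is replaced by the identity."""
    for i, W in enumerate(blocks):
        if i != skip:
            x = x + np.tanh(x @ W)
    return x

def layer_importance(x, target, blocks):
    """I_l = loss with block l bypassed minus loss of the full network (MSE, toy setup)."""
    def loss(skip=None):
        return float(np.mean((forward(x, blocks, skip) - target) ** 2))
    base = loss()
    return [loss(skip=l) - base for l in range(len(blocks))]

rng = np.random.default_rng(0)
blocks = [0.5 * rng.standard_normal((6, 6)) for _ in range(4)]
blocks[2] = np.zeros((6, 6))               # block 2 outputs tanh(0) = 0, so it is redundant
x = rng.standard_normal((16, 6))
target = forward(x, blocks)                # pretend the full network fits the data exactly
scores = layer_importance(x, target, blocks)
```

The redundant block gets an importance of exactly zero, while bypassing any of the other blocks raises the loss.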
ResNet Layer Removal Studies#
Veit et al. (2016) showed that individual residual blocks in ResNet can be removed at test time with surprisingly small accuracy drops. For ResNet-110 on CIFAR-10:
- Removing 1 block from the middle: ~0.2% accuracy drop
- Removing 5 blocks: ~1.5% accuracy drop
- Removing 10 blocks: ~4% accuracy drop
This works because residual connections ensure that \(x_{l+1} = x_l + f_l(x_l)\), and if \(f_l(x_l)\) is small (which BN regularization encourages), then skipping the block has minimal effect.
When Is Layer Pruning Viable#
Layer pruning requires skip connections (residual, dense, or highway connections). Without them, removing a layer completely disconnects the forward pass. This is why layer pruning is primarily studied in:
- ResNets and ResNeXt (residual connections)
- DenseNets (dense connections provide alternative paths)
- Transformers (residual connections around attention and FFN)
Block/Group Pruning#
Block pruning removes rectangular groups of weights, providing a middle ground between unstructured and fully structured pruning.
Common Block Patterns#
1x1 (unstructured): 4x1 (vector): 2x4 (block):
[x . . .] [x . . .] [x x x x]
[. . x .] [x . . .] [x x x x]
[. x . .] [x . . .] [. . . .]
[. . . x] [x . . .] [. . . .]
1x4 (row vector): 4x4 (tile):
[. . . .] [x x x x]
[x x x x] [x x x x]
[. . . .] [x x x x]
[. . . .] [x x x x]
Legend: x = nonzero, . = zero (pruned)

Bank-Balanced Sparsity#
Bank-balanced sparsity (Cao et al., 2019) splits each row of the weight matrix into equal-sized banks (groups of consecutive elements) and enforces the same number of nonzeros in every bank. This enables balanced workload distribution across parallel hardware units.
For a matrix \(W \in \mathbb{R}^{m \times n}\) with bank size \(B\):
- Split each row into \(n/B\) banks of \(B\) consecutive elements.
- Within each bank, keep exactly \(k\) nonzero weights (the \(k\) largest magnitudes) and prune the rest.
- Every bank has an identical compute workload: \(k\) multiply-accumulates.
Because the load is balanced, no hardware unit waits on a heavier neighbor, which brings the realized speedup close to the theoretical one.
N:M Structured Sparsity (NVIDIA)#
Definition#
N:M sparsity requires exactly \(N\) nonzero values in every group of \(M\) consecutive elements along a specific dimension of the weight matrix. The most important instance is 2:4 sparsity: exactly 2 nonzero values per group of 4.
2:4 Sparsity on NVIDIA Ampere/Hopper#
NVIDIA introduced hardware support for 2:4 sparsity in the Ampere architecture (A100, 2020), continued in Hopper (H100, 2022) and Blackwell (B100/B200, 2024).
Key properties:
- Exactly 50% of weights are zero (2 out of every 4).
- The Sparse Tensor Core achieves 2x throughput compared to the Dense Tensor Core for the same matrix dimensions.
- The sparsity pattern is stored compactly as one 2-bit index per kept value (4 bits of metadata per group of 4).
Hardware Sparse Tensor Core Operation#
The sparse matrix multiply works as follows:
Sparse A (16x16, FP16, 2:4 along each row) x Dense B (16x8, FP16) = C (16x8, FP32)
(tile shapes are illustrative)
Sparse A storage:
Original A (16x16):
[0.5 0.0 0.3 0.0 | 0.0 0.7 0.0 0.1 | ...]
[0.0 0.2 0.0 0.8 | 0.4 0.0 0.0 0.6 | ...]
...
Compressed A (16x8) + metadata:
[0.5 0.3 0.7 0.1 ...] <- nonzero values only (half the columns)
[0.2 0.8 0.4 0.6 ...]
...
Metadata (one 2-bit index per kept value, two per group of 4):
[00 10 | 01 11 | ...] <- row 0: positions (0,2) in group 1, (1,3) in group 2
[01 11 | 00 11 | ...]
Hardware operation:
1. Load compressed tile of A (16x8) + metadata
2. Load dense tile of B (16x8)
3. Use metadata to select, for each kept value of A, the matching row of B
4. Execute 8 multiply-accumulates per output element (half of the original 16)
5. Accumulate into C (16x8)
Result: Same output as the dense multiply, at up to 2x throughput

The key insight is that the 2:4 pattern runs along the reduction dimension: the hardware uses the metadata to gather the matching elements of \(B\), then performs a dense multiply on the compressed operands. This avoids the irregular access patterns of general sparse formats.
ASCII Diagram of 2:4 Pattern in a Weight Matrix#
Original Dense Weight Matrix W (4x8):
+------+------+------+------+------+------+------+------+
| 0.82 |-0.15 | 0.91 | 0.03 | 0.44 | 0.02 |-0.68 | 0.11 |
| 0.07 | 0.68 |-0.11 | 0.44 |-0.38 | 0.56 | 0.01 |-0.09 |
| 0.23 |-0.02 | 0.05 | 0.77 | 0.90 |-0.34 | 0.12 | 0.67 |
| 0.45 | 0.31 |-0.88 | 0.04 | 0.19 | 0.73 |-0.55 | 0.08 |
+------+------+------+------+------+------+------+------+
Apply 2:4 sparsity (keep 2 largest per group of 4):
Group boundaries: [----group 1----] [----group 2----]
Row 0: [0.82, -0.15, 0.91, 0.03] -> keep 0.82, 0.91 (idx 0,2)
[0.44, 0.02,-0.68, 0.11] -> keep 0.44,-0.68 (idx 0,2)
Row 1: [0.07, 0.68,-0.11, 0.44] -> keep 0.68, 0.44 (idx 1,3)
[-0.38, 0.56, 0.01,-0.09] -> keep -0.38, 0.56 (idx 0,1)
2:4 Sparse Matrix:
+------+------+------+------+------+------+------+------+
| 0.82 | 0 | 0.91 | 0 | 0.44 | 0 |-0.68 | 0 |
| 0 | 0.68 | 0 | 0.44 |-0.38 | 0.56 | 0 | 0 |
| 0.23 | 0 | 0 | 0.77 | 0.90 | 0 | 0 | 0.67 |
| 0.45 | 0 |-0.88 | 0 | 0 | 0.73 |-0.55 | 0 |
+------+------+------+------+------+------+------+------+
Compressed Storage (nonzeros only):
+------+------+------+------+
| 0.82 | 0.91 | 0.44 |-0.68 | Metadata: [00,10 | 00,10]
| 0.68 | 0.44 |-0.38 | 0.56 | Metadata: [01,11 | 00,01]
| 0.23 | 0.77 | 0.90 | 0.67 | Metadata: [00,11 | 00,11]
| 0.45 |-0.88 | 0.73 |-0.55 | Metadata: [00,10 | 01,10]
+------+------+------+------+
(50% memory for values + small metadata overhead)

How 2:4 Is Enforced#
The simplest enforcement: for each group of 4 consecutive weights, keep the 2 with largest magnitude and zero the rest.
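Algorithm:

    import numpy as np

    def enforce_2_of_4(W):
        """Zero the 2 smallest-magnitude weights in every group of 4 along each row (in place)."""
        n_rows, n_cols = W.shape  # n_cols must be a multiple of 4
        for row in range(n_rows):
            for g in range(0, n_cols, 4):
                group = W[row, g:g+4]
                # Indices within the group, ordered from smallest to largest magnitude
                sorted_idx = np.argsort(np.abs(group))
                # Zero the 2 smallest
                W[row, g + sorted_idx[0]] = 0
                W[row, g + sorted_idx[1]] = 0
        return W

Numerical example: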
Group: \([0.45, 0.31, -0.88, 0.04]\)
Magnitudes: \([0.45, 0.31, 0.88, 0.04]\)
Sorted indices by magnitude: \([3, 1, 0, 2]\) (0.04, 0.31, 0.45, 0.88)
Zero indices 3 and 1: \([0.45, 0, -0.88, 0]\) – kept the two largest magnitudes.
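The same rule can be written without Python loops. The sketch below is one possible vectorized NumPy formulation for a general N:M pattern (the function name `nm_mask` is an assumption); it is applied to the 4x8 matrix from the diagram above.

```python
import numpy as np

def nm_mask(W, n=2, m=4):
    """Mask keeping the n largest-magnitude entries in each group of m along the last axis."""
    rows, cols = W.shape
    assert cols % m == 0, "number of columns must be a multiple of m"
    groups = np.abs(W).reshape(rows, cols // m, m)
    # Indices of the (m - n) smallest magnitudes inside every group
    drop = np.argsort(groups, axis=-1)[..., : m - n]
    mask = np.ones_like(groups)
    np.put_along_axis(mask, drop, 0.0, axis=-1)
    return mask.reshape(rows, cols)

W = np.array([[0.82, -0.15,  0.91, 0.03,  0.44,  0.02, -0.68,  0.11],
              [0.07,  0.68, -0.11, 0.44, -0.38,  0.56,  0.01, -0.09],
              [0.23, -0.02,  0.05, 0.77,  0.90, -0.34,  0.12,  0.67],
              [0.45,  0.31, -0.88, 0.04,  0.19,  0.73, -0.55,  0.08]])
W_sparse = W * nm_mask(W)
```

Ties within a group are broken by whatever order `argsort` returns, which is acceptable for a sketch but worth pinning down in a real pipeline.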
Training with N:M Sparsity: SR-STE#
Straight-Through Estimator (STE) is the standard approach for training with discrete constraints. SR-STE (Sparse-Refined STE) by Zhou et al. (2021) refines this for N:M sparsity.
Standard STE for N:M sparsity:
Forward pass uses the sparse weights:
$$W_s = \text{TopN:M}(W) = W \odot M(W)$$where \(M(W)\) is the mask selecting the top-\(N\) magnitudes per group of \(M\).
Backward pass ignores the pruning (straight-through):
$$\frac{\partial \mathcal{L}}{\partial W} \approx \frac{\partial \mathcal{L}}{\partial W_s}$$Problem: STE allows pruned weights to grow large during training because they never participate in the forward pass but receive gradient updates. When the mask is recomputed, large formerly-pruned weights may suddenly appear, causing instability.
SR-STE solution: Decay the pruned weights toward zero:
$$W^{(t+1)} = W^{(t)} - \eta \left[ \frac{\partial \mathcal{L}}{\partial W_s^{(t)}} + \lambda (1 - M^{(t)}) \odot W^{(t)} \right]$$

where:
- \(\frac{\partial \mathcal{L}}{\partial W_s^{(t)}}\): the straight-through gradient, applied to all weights (kept and pruned)
- \(\lambda (1 - M^{(t)}) \odot W^{(t)}\): weight decay applied only to pruned weights, pushing them toward zero
This keeps pruned weights small unless the gradient consistently favors them, making the mask more stable across training iterations.
Full training algorithm:
- Initialize dense weights \(W^{(0)}\).
- For each training step \(t\):
- a. Compute mask: \(M^{(t)} = M(W^{(t)})\), keeping the top-\(N\) magnitudes in each group
- b. Forward: \(W_s^{(t)} = W^{(t)} \odot M^{(t)}\), compute \(\mathcal{L}\)
- c. Backward: compute \(g = \partial \mathcal{L} / \partial W_s^{(t)}\)
- d. Update: \(W^{(t+1)} = W^{(t)} - \eta [g + \lambda(1-M^{(t)}) \odot W^{(t)}]\)
- Final model uses \(W_s = W \odot M\).
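A toy NumPy sketch of this training loop on a linear least-squares problem. It is not the authors' implementation: the loss, data, learning rate, and decay strength are invented for illustration, and the straight-through gradient is applied to all dense weights with the extra decay restricted to the currently pruned ones.

```python
import numpy as np

def nm_mask(W, n=2, m=4):
    """Mask keeping the n largest-magnitude entries in each group of m along the last axis."""
    groups = np.abs(W).reshape(W.shape[0], -1, m)
    drop = np.argsort(groups, axis=-1)[..., : m - n]
    mask = np.ones_like(groups)
    np.put_along_axis(mask, drop, 0.0, axis=-1)
    return mask.reshape(W.shape)

def sr_ste_step(W, X, Y, lr=0.05, lam=0.02):
    """One update for the toy loss L = mean over samples of ||x @ W_s.T - y||^2 / 2."""
    M = nm_mask(W)
    W_s = W * M                                    # forward pass uses sparse weights
    err = X @ W_s.T - Y
    g = err.T @ X / X.shape[0]                     # dL/dW_s, passed straight through to W
    return W - lr * (g + lam * (1.0 - M) * W)      # extra decay only on pruned weights

rng = np.random.default_rng(0)
X = rng.standard_normal((64, 8))
W_true = rng.standard_normal((3, 8)) * nm_mask(rng.standard_normal((3, 8)))
Y = X @ W_true.T                                   # targets from a 2:4-sparse teacher
W = rng.standard_normal((3, 8))

def sparse_loss(W):
    """Mean squared error of the 2:4-masked weights."""
    return float(np.mean((X @ (W * nm_mask(W)).T - Y) ** 2))

loss_before = sparse_loss(W)
for _ in range(500):
    W = sr_ste_step(W, X, Y)
loss_after = sparse_loss(W)
W_final = W * nm_mask(W)                           # deployed weights satisfy 2:4
```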
N:M Beyond 2:4#
| Pattern | Sparsity | Theoretical Speedup | HW Support (2025) | Typical Accuracy (ImageNet ResNet-50) |
|---|---|---|---|---|
| 1:4 | 75% | 4x | Research only | Top-1 drops ~2-3% |
| 2:4 | 50% | 2x | NVIDIA Ampere/Hopper/Blackwell | Top-1 drops < 0.5% |
| 2:8 | 75% | 4x | Research only | Top-1 drops ~1-2% |
| 4:8 | 50% | 2x | NVIDIA Hopper+ (planned) | Top-1 drops < 0.3% |
Mathematical Analysis: Why 2:4 Is the Sweet Spot#
The quality of an N:M pattern depends on the expected approximation error. For a group of \(M\) weights drawn i.i.d. from a symmetric distribution with variance \(\sigma^2\), the error from zeroing the \(M-N\) smallest-magnitude weights is:

$$\mathbb{E}\left[\sum_{i \in \text{pruned}} W_i^2\right] = \sum_{k=1}^{M-N} \mathbb{E}\left[W_{(k)}^2\right]$$

where \(W_{(k)}\) denotes the weight with the \(k\)-th smallest magnitude in the group.

For Gaussian weights, \(W_i^2/\sigma^2\) follows a chi-squared distribution with one degree of freedom, so each term is an order-statistic moment of that distribution:

$$\mathbb{E}[W_{(k)}^2] = \sigma^2 \int_0^\infty y \, f_{(k:M)}(y) \, dy$$

where \(f_{(k:M)}\) is the density of the \(k\)-th smallest of \(M\) such draws. These moments have no simple closed form in general and are evaluated numerically.

For 2:4: We remove the 2 smallest out of 4, with expected squared magnitudes of roughly \(0.12\sigma^2\) and \(0.41\sigma^2\). Total error: \(\approx 0.53\sigma^2\) per group. Fraction of total energy pruned: \(0.53 / (4 \cdot 1.0) \approx 13\%\).

For 1:4: We remove 3 out of 4, pruning roughly \(1.5\sigma^2\) per group. Fraction pruned: \(\approx 38\%\). Much more destructive.

The 2:4 pattern hits the sweet spot: 50% sparsity (which the hardware can double the throughput for) with only ~13% of the weight energy removed (which fine-tuning easily recovers).
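Energy fractions of this kind are easy to check by simulation. The following Monte Carlo sketch assumes i.i.d. standard Gaussian weights and magnitude-based selection within each group; the function name, sample count, and seed are arbitrary choices for the example.

```python
import numpy as np

def pruned_energy_fraction(n, m, n_groups=200_000, seed=0):
    """Monte Carlo estimate of the share of sum(W^2) removed by N:M magnitude pruning
    when weights are i.i.d. standard Gaussian."""
    rng = np.random.default_rng(seed)
    sq = np.sort(rng.standard_normal((n_groups, m)) ** 2, axis=1)   # ascending per group
    return float(sq[:, : m - n].sum() / sq.sum())

frac_2_4 = pruned_energy_fraction(2, 4)
frac_1_4 = pruned_energy_fraction(1, 4)
print(f"2:4 removes {frac_2_4:.1%} of the energy, 1:4 removes {frac_1_4:.1%}")
```

Trained weights are usually heavier-tailed than a Gaussian, so the fractions for a real network may differ from this idealized estimate.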
Pruning Criteria for Structured Pruning#
Norm-Based Criteria#
L1-Norm of Filters#
$$\text{score}_i = ||W_i||_1 = \sum_j |W_{i,j}|$$

Derivation: The L1-norm is the tightest convex relaxation of the L0-norm (number of nonzeros), so minimizing \(||W||_1\) encourages sparsity. Conversely, filters with large L1-norm carry more “weight” in the computation: by the triangle inequality, the expected output magnitude of filter \(i\) is bounded by its L1-norm times the expected input magnitude.
Advantages: Simple, fast to compute, no data needed. Disadvantages: Does not account for correlations between filters, or the actual data distribution.
L2-Norm of Filters#
$$\text{score}_i = ||W_i||_2 = \sqrt{\sum_j W_{i,j}^2}$$The L2-norm measures the energy of the filter. It is related to the expected output variance:
$$\text{Var}(Z_i) = ||W_i||_2^2 \cdot \text{Var}(X) \quad \text{(for i.i.d. inputs)}$$Filters with small L2-norm produce low-variance activations, contributing less to downstream discrimination.
Batch-Norm Scaling Factor (Network Slimming, Liu et al., 2017)#
This elegant method repurposes the BatchNorm scaling factor \(\gamma\) as a built-in importance indicator.
Background: BatchNorm normalizes activations channel-wise:
$$\hat{z}_c = \frac{z_c - \mu_c}{\sqrt{\sigma_c^2 + \epsilon}}$$$$\tilde{z}_c = \gamma_c \hat{z}_c + \beta_c$$where \(\gamma_c\) and \(\beta_c\) are learned per-channel parameters. If \(\gamma_c \to 0\), then channel \(c\) is effectively zeroed out regardless of the filter weights.
Method: Add L1 regularization on \(\gamma\) during training:
$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}} + \lambda \sum_l \sum_c |\gamma_c^{(l)}|$$Full training procedure:
- Train the network with the modified loss \(\mathcal{L}_{\text{total}}\).
- The L1 penalty on \(\gamma\) drives unimportant channels’ \(\gamma_c\) toward zero.
- After training, rank all \(\gamma_c\) across the network.
- Set a global threshold \(\theta\) such that a fraction \(s\) of channels have \(|\gamma_c| < \theta\).
- Prune those channels (and corresponding filters, BN parameters in adjacent layers).
- Fine-tune the pruned network.
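The ranking and thresholding steps of this procedure can be sketched in a few lines. This is an illustrative version only: the dictionary of per-layer \(\gamma\) values and the helper names are hypothetical, and a real implementation would read the values from the BatchNorm layers of a trained model.

```python
def global_gamma_threshold(gammas_per_layer, sparsity):
    """Threshold such that roughly `sparsity` of all channels fall strictly below it."""
    magnitudes = sorted(abs(g) for layer in gammas_per_layer.values() for g in layer)
    n_prune = min(int(sparsity * len(magnitudes)), len(magnitudes) - 1)
    return magnitudes[n_prune]

def channels_to_keep(gammas_per_layer, threshold):
    """Per-layer indices of channels whose |gamma| reaches the threshold."""
    return {
        name: [c for c, g in enumerate(layer) if abs(g) >= threshold]
        for name, layer in gammas_per_layer.items()
    }

# Toy BN scales for two layers (made-up numbers)
gammas = {
    "bn1": [0.9, 0.01, 0.5, 0.02],
    "bn2": [0.03, 0.7, 0.8, 0.6],
}
theta = global_gamma_threshold(gammas, sparsity=0.375)
kept = channels_to_keep(gammas, theta)
print(theta, kept)
```

Because the threshold is global, a layer can lose most or all of its channels; practical implementations usually enforce a per-layer minimum.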
Derivation of the L1 proximal gradient step: Since \(|\gamma|\) is non-smooth at zero, its gradient is undefined there. Plain SGD can still be run with a subgradient, but a proximal gradient step handles the non-smooth term exactly and produces true zeros:
$$\gamma_c^{(t+1)} = \text{prox}_{\eta\lambda|\cdot|}\left(\gamma_c^{(t)} - \eta \frac{\partial \mathcal{L}_{\text{task}}}{\partial \gamma_c}\right)$$where the proximal operator for L1 is the soft-thresholding function:
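$$\text{prox}_{\eta\lambda|\cdot|}(v) = \text{sign}(v) \cdot \max(|v| - \eta\lambda, 0)$$
In practice, weight decay on \(\gamma\) is not a substitute for this step: weight decay is an L2 penalty, which shrinks \(\gamma\) but does not drive it exactly to zero. A common shortcut (used in the reference Network Slimming code) is to add the subgradient \(\lambda \cdot \text{sign}(\gamma)\) to the gradient of each BN scale before the optimizer step. An exact proximal update requires a custom optimizer step.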
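A scalar version of the proximal update makes the soft-thresholding behaviour concrete. This is a minimal sketch for a single \(\gamma\) value; the names `soft_threshold` and `prox_step` are chosen for the example and are not part of any library.

```python
def soft_threshold(v, t):
    """Proximal operator of t * |.|: shrink v toward zero by t, clamp at zero."""
    if v > t:
        return v - t
    if v < -t:
        return v + t
    return 0.0

def prox_step(gamma, grad, lr, lam):
    """One proximal gradient step on a BN scale with an L1 penalty of strength lam."""
    return soft_threshold(gamma - lr * grad, lr * lam)

# A small gamma is snapped exactly to zero, a large one is only shrunk
small = prox_step(gamma=0.05, grad=0.2, lr=0.1, lam=0.5)
large = prox_step(gamma=1.0, grad=0.0, lr=0.1, lam=0.5)
print(small, large)
```

The exact zero for small values is what distinguishes this update from L2 weight decay, which only scales \(\gamma\) by a constant factor each step.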
Reconstruction-Based Criteria#
The idea is to prune structures that cause the minimal change in the layer’s output.
Formulation: Given input activations \(X \in \mathbb{R}^{N \times d_{in}}\) and current output \(Y = XW\), find a subset \(S\) of columns to keep (i.e., input features/channels) and new weights \(W’\) such that:
$$\min_{W', S} ||Y - X_S W'||_F^2 \quad \text{s.t.} \quad |S| = d_{in} - p$$where \(p\) is the number of channels/features to prune.
Derivation for the optimal W’ given S:
This is a standard least-squares problem. Partition \(X = [X_S, X_{\bar{S}}]\) and \(W = [W_S; W_{\bar{S}}]\). The original output is:
$$Y = X_S W_S + X_{\bar{S}} W_{\bar{S}}$$The reconstruction target is \(Y\) and the model is \(X_S W’\). Setting the derivative to zero:
$$\frac{\partial}{\partial W'} ||Y - X_S W'||_F^2 = -2 X_S^T (Y - X_S W') = 0$$$$W' = (X_S^T X_S)^{-1} X_S^T Y$$This is the ordinary least-squares solution. The reconstruction error for subset \(S\) is:
$$\text{err}(S) = ||Y - X_S (X_S^T X_S)^{-1} X_S^T Y||_F^2 = ||(I - P_S) Y||_F^2$$where \(P_S = X_S (X_S^T X_S)^{-1} X_S^T\) is the projection matrix onto the column space of \(X_S\). The optimal \(S\) minimizes this projection residual.
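Finding the optimal \(S\) is combinatorial, so implementations rely on heuristics. The sketch below uses a simple greedy backward elimination (an assumption for this example, not necessarily what any particular paper does) together with the least-squares refit derived above. The toy data is constructed so that one channel is unused and another is nearly a duplicate.

```python
import numpy as np

def reconstruction_error(X, Y, keep):
    """Refit W' by least squares on the kept columns and return (error, W')."""
    X_s = X[:, keep]
    W_new, *_ = np.linalg.lstsq(X_s, Y, rcond=None)
    residual = Y - X_s @ W_new
    return float(np.sum(residual ** 2)), W_new

def greedy_prune(X, Y, n_prune):
    """Greedily drop the channel whose removal increases the error the least."""
    keep = list(range(X.shape[1]))
    for _ in range(n_prune):
        errors = {
            c: reconstruction_error(X, Y, [k for k in keep if k != c])[0]
            for c in keep
        }
        keep.remove(min(errors, key=errors.get))
    return keep

rng = np.random.default_rng(0)
X = rng.standard_normal((200, 6))
X[:, 5] = X[:, 0] + 1e-3 * rng.standard_normal(200)  # channel 5 nearly duplicates channel 0
W = np.arange(1.0, 25.0).reshape(6, 4) / 10.0
W[3] = 0.0                                           # channel 3 is unused by the layer
Y = X @ W

keep = greedy_prune(X, Y, n_prune=2)
err, W_new = reconstruction_error(X, Y, keep)
print(keep, err)
```

Note that the refit matters: once one of the duplicated channels is dropped, the surviving one absorbs its weights, which a pure magnitude criterion would not capture.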
Gradient/Taylor-Based Criteria#
First-Order Taylor Expansion#
The importance of a pruning group \(g\) (filter, channel, head) can be estimated via a first-order Taylor expansion of the loss around the current parameters:
$$\mathcal{L}(W \text{ without group } g) \approx \mathcal{L}(W) - \sum_{i \in g} W_i \frac{\partial \mathcal{L}}{\partial W_i}$$Derivation: Let \(\delta W\) be the change in weights when group \(g\) is removed: \(\delta W_i = -W_i\) for \(i \in g\), \(\delta W_i = 0\) otherwise. By Taylor expansion:
$$\mathcal{L}(W + \delta W) \approx \mathcal{L}(W) + \sum_i \frac{\partial \mathcal{L}}{\partial W_i} \delta W_i + O(||\delta W||^2)$$$$= \mathcal{L}(W) - \sum_{i \in g} W_i \frac{\partial \mathcal{L}}{\partial W_i} + O(||\delta W||^2)$$The importance of group \(g\) is therefore:
$$I_g = \left| \sum_{i \in g} W_i \frac{\partial \mathcal{L}}{\partial W_i} \right|$$Activation-based variant (Molchanov et al., 2017): Instead of weight gradients, use activation gradients. For filter \(i\) producing activation \(a_i\):
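$$I_i = \left| \sum_{\text{spatial}} a_i \cdot \frac{\partial \mathcal{L}}{\partial a_i} \right|$$
For a bias-free linear filter this is equivalent to the weight-based score: by the chain rule, \(\sum_j W_{i,j} \, \partial \mathcal{L}/\partial W_{i,j} = \sum_{\text{spatial}} a_i \, \partial \mathcal{L}/\partial a_i\) whenever \(a_i\) is linear in \(W_i\). With a bias term, or when \(a_i\) is taken after normalization or a nonlinearity, the two scores differ. The activation form is convenient in practice, since activation gradients are readily available during backpropagation.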
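The relationship between the weight-based and activation-based scores can be checked on a toy problem. The sketch assumes a single bias-free linear filter and a squared-error loss, both chosen only because their gradients can be written by hand; it also compares the first-order estimate with the true loss change when the filter is removed.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((64, 9))  # 64 spatial positions, 9 weights in the filter
w = rng.standard_normal(9)        # the filter (one pruning group)
y = rng.standard_normal(64)       # regression target, stands in for the rest of the network

def loss(w_vec):
    return 0.5 * np.sum((X @ w_vec - y) ** 2)

a = X @ w             # activations of the filter
grad_a = a - y        # dL/da
grad_w = X.T @ grad_a # dL/dw via the chain rule

score_weights = abs(np.dot(w, grad_w))      # |sum_i W_i * dL/dW_i|
score_activations = abs(np.dot(a, grad_a))  # |sum_spatial a * dL/da|

# Compare the first-order estimate with the true change when the filter is zeroed
true_change = loss(np.zeros_like(w)) - loss(w)
first_order = -np.dot(w, grad_w)
remainder = true_change - first_order       # the O(||dW||^2) term
quadratic_term = 0.5 * w @ (X.T @ X) @ w

print(score_weights, score_activations, true_change, first_order)
```

For this quadratic loss the remainder equals the quadratic term exactly, which shows how much a first-order score can miss when the removed group is large.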
Second-Order (Hessian) Criteria#
The second-order Taylor expansion gives a more accurate importance estimate:
$$\Delta \mathcal{L}_g \approx -\sum_{i \in g} W_i g_i + \frac{1}{2} \sum_{i,j \in g} W_i H_{ij} W_j$$where \(g_i = \partial \mathcal{L}/\partial W_i\) and \(H_{ij} = \partial^2 \mathcal{L}/\partial W_i \partial W_j\).
Computing the full Hessian \(H\) is \(O(n^2)\) in parameters, which is infeasible for large networks. Approximations include:
- Diagonal Hessian: \(H_{ij} \approx 0\) for \(i \neq j\), giving \(I_g = \left|\sum_{i \in g} \left(W_i g_i - \frac{1}{2} H_{ii} W_i^2\right)\right|\)
- Fisher Information Matrix: \(H \approx F = \mathbb{E}[gg^T]\), which can be estimated from gradient samples
- Hessian trace: \(\text{tr}(H_g) = \sum_{i \in g} H_{ii}\), estimated via Hutchinson’s trace estimator
Hutchinson’s trace estimator derivation:
For any square matrix \(A\), if \(v\) is a random vector with \(\mathbb{E}[vv^T] = I\) (e.g., Rademacher \(\pm 1\) entries):
$$\mathbb{E}[v^T A v] = \mathbb{E}[\text{tr}(v^T A v)] = \mathbb{E}[\text{tr}(A v v^T)] = \text{tr}(A \mathbb{E}[vv^T]) = \text{tr}(A)$$The Hessian-vector product \(Hv\) can be computed efficiently via automatic differentiation (one extra backward pass), so the trace can be estimated without materializing \(H\).
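A small pure-Python sketch of the estimator is shown below. It only needs a matrix-vector product function, which is the role the Hessian-vector product plays in a real network; here an explicit 4x4 symmetric matrix stands in for the Hessian, and all names are illustrative.

```python
import random

def hutchinson_trace(matvec_fn, dim, n_samples=2000, seed=0):
    """Estimate tr(A) as the average of v^T A v over Rademacher vectors v."""
    rng = random.Random(seed)
    total = 0.0
    for _ in range(n_samples):
        v = [rng.choice((-1.0, 1.0)) for _ in range(dim)]
        Av = matvec_fn(v)  # in a network: one Hessian-vector product
        total += sum(vi * avi for vi, avi in zip(v, Av))
    return total / n_samples

# Toy symmetric matrix with trace 2 + 3 + 1 + 4 = 10
A = [
    [2.0, 0.5, 0.0, 0.0],
    [0.5, 3.0, 0.2, 0.0],
    [0.0, 0.2, 1.0, 0.1],
    [0.0, 0.0, 0.1, 4.0],
]

def matvec(v):
    return [sum(a * x for a, x in zip(row, v)) for row in A]

estimate = hutchinson_trace(matvec, dim=4)
print(estimate)
```

The variance of the estimate depends on the off-diagonal entries of the matrix, so strongly coupled parameters need more samples for the same accuracy.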
Accumulated Gradient Information#
In practice, importance scores computed from a single minibatch are noisy. The standard approach is to accumulate importance scores over multiple batches:
$$I_g = \frac{1}{B} \sum_{b=1}^{B} I_g^{(b)}$$Some methods use exponential moving averages for online estimation:
$$I_g^{(t)} = \alpha \cdot I_g^{(t-1)} + (1-\alpha) \cdot I_g^{\text{batch}(t)}$$
Learning-Based Criteria#
Learnable Pruning Masks with Gumbel-Softmax#
Instead of using a heuristic criterion, learn the pruning mask end-to-end. The mask \(M \in \{0,1\}^G\) (one bit per group) is a discrete variable, which is non-differentiable. The Gumbel-Softmax trick provides a continuous relaxation.
Derivation: For a binary mask variable \(m_g\) (keep or prune group \(g\)):
$$m_g = \begin{cases} 1 & \text{with probability } \sigma(\alpha_g) \\ 0 & \text{with probability } 1 - \sigma(\alpha_g) \end{cases}$$where \(\sigma\) is the sigmoid function and \(\alpha_g\) is a learnable logit.
The Gumbel-Softmax relaxation replaces the discrete sample with:
$$\tilde{m}_g = \sigma\left(\frac{\alpha_g + \log u - \log(1-u)}{\tau}\right)$$where \(u \sim \text{Uniform}(0,1)\) and \(\tau\) is a temperature parameter. As \(\tau \to 0\), \(\tilde{m}_g \to m_g\) (discrete). During training, \(\tau\) is annealed from a high value (smooth, easy to optimize) to a low value (near-discrete).
The loss becomes:
$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}}(W \odot \tilde{M}) + \lambda \sum_g \sigma(\alpha_g)$$where the regularization term encourages sparsity by penalizing the probability of keeping each group.
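To see how the temperature controls the relaxation, the sampling formula can be simulated directly. This is a minimal sketch of the binary case in plain Python (no autograd), with hypothetical helper names; it checks that at low temperature the samples are nearly binary and that the keep rate matches \(\sigma(\alpha_g)\).

```python
import math
import random

def sigmoid(x):
    # Numerically stable logistic function
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)

def relaxed_mask_sample(alpha, tau, rng):
    """One relaxed mask sample: sigmoid((alpha + log u - log(1 - u)) / tau)."""
    u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)  # keep the logs finite
    logistic_noise = math.log(u) - math.log(1.0 - u)
    return sigmoid((alpha + logistic_noise) / tau)

rng = random.Random(0)
alpha = 1.0  # learnable logit of one group (fixed here)

smooth = [relaxed_mask_sample(alpha, tau=5.0, rng=rng) for _ in range(5000)]
sharp = [relaxed_mask_sample(alpha, tau=0.05, rng=rng) for _ in range(5000)]

keep_rate = sum(m > 0.5 for m in sharp) / len(sharp)
near_binary = sum(m < 0.01 or m > 0.99 for m in sharp) / len(sharp)
print(keep_rate, sigmoid(alpha), near_binary)
```

In a training loop the same expression would be written with differentiable tensor operations so that gradients flow into \(\alpha_g\).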
AMC: AutoML for Model Compression#
AMC (He et al., 2018) uses reinforcement learning to find per-layer pruning ratios automatically.
State: For layer \(l\), the state vector includes:
- Layer type (conv, FC, etc.)
- Layer dimensions (\(C_{in}, C_{out}, k, H, W\))
- Current FLOPs and parameter count
- Remaining FLOPs budget
- Layer index
Action: Continuous action \(a_l \in [0, 1]\) specifying the pruning ratio for layer \(l\). (E.g., \(a_l = 0.3\) means prune 30% of filters in layer \(l\).)
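Reward: After pruning all layers with the chosen ratios, the pruned network is evaluated on a validation set. In the original paper this evaluation is done without fine-tuning, as a fast proxy for the accuracy after fine-tuning:
$$R = -\text{Error}(f_{\text{pruned}}) \quad \text{s.t.} \quad \text{FLOPs}(f_{\text{pruned}}) \leq \text{FLOPs}_{\text{target}}$$
In the paper the FLOPs constraint is enforced by bounding the action space so that the remaining layers can still meet the budget; applying a large negative penalty when the constraint is violated is a simpler alternative.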
Policy: A DDPG (Deep Deterministic Policy Gradient) agent learns a policy \(\pi(a_l | s_l)\) that maps layer states to pruning ratios. The agent processes layers sequentially, observing the updated state after each pruning decision.
Key finding: AMC consistently outperforms hand-crafted uniform pruning ratios. The learned policies tend to prune more aggressively in redundant layers (early conv layers with many similar filters) and preserve layers that are bottlenecks.
Structured Pruning Algorithms (Deep Dive)#
DepGraph (Dependency Graph-Based Pruning, 2023)#
Modern architectures have complex topologies (residual connections, concatenation, split, group convolutions) that make structured pruning non-trivial. Pruning a filter in one layer may require simultaneously pruning corresponding structures in multiple other layers.
The problem: In a ResNet block:
x ---> Conv1 ---> BN1 ---> ReLU ---> Conv2 ---> BN2 ---> (+) ---> out
|                                                         ^
+---------------------------------------------------------+
If we prune filter \(i\) from Conv1, we must also prune:
- BN1’s \(\gamma_i, \beta_i\), running mean/var index \(i\)
- Input channel \(i\) of Conv2
But if there is a residual connection, the output of Conv2 is added to \(x\). If Conv2’s output channels are pruned, the addition dimensions no longer match unless the same channels are pruned from \(x\) (which means pruning the same channels from the preceding layer).
DepGraph solution: Build a dependency graph where nodes are parameter groups and edges represent “must prune together” relationships.
Algorithm:
- Parse the computational graph of the network.
- For each layer, identify which dimension of its parameters corresponds to “output features” and “input features.”
- Create dependency edges:
- Conv output channels ↔ next layer’s input channels
- BN parameters ↔ corresponding conv output channels
- Residual addition: all inputs must have matching channel counts → coupled pruning groups
- Concatenation: each branch can be pruned independently
- Group all transitively connected parameters into “pruning groups.”
- Assign importance scores to each group.
- Prune the least important groups.
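The grouping step is essentially a connected-components computation. Below is a minimal union-find sketch of it, assuming the dependency edges have already been extracted from the computational graph. The node names and the `build_pruning_groups` helper are hypothetical and are not the Torch-Pruning API.

```python
def build_pruning_groups(nodes, edges):
    """Group nodes connected by 'must prune together' edges (union-find)."""
    parent = {n: n for n in nodes}

    def find(n):
        while parent[n] != n:
            parent[n] = parent[parent[n]]  # path halving
            n = parent[n]
        return n

    for a, b in edges:
        parent[find(a)] = find(b)

    groups = {}
    for n in nodes:
        groups.setdefault(find(n), []).append(n)
    return [sorted(g) for g in groups.values()]

# Channel dimensions of a residual basic block (names are illustrative)
nodes = ["prev.out", "conv1.in", "conv1.out", "bn1", "conv2.in", "conv2.out", "bn2", "add"]
edges = [
    ("conv1.out", "bn1"),      # BN follows conv1's output channels
    ("bn1", "conv2.in"),       # conv2 consumes those channels
    ("conv2.out", "bn2"),
    ("bn2", "add"),            # residual add couples both of its inputs
    ("prev.out", "add"),
    ("prev.out", "conv1.in"),  # the block input also feeds conv1
]
groups = build_pruning_groups(nodes, edges)
print(groups)
```

The block yields two groups: the internal channels between Conv1 and Conv2, and the residual-stream channels that are shared with the preceding block.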
ASCII Diagram: Dependency Graph for ResNet Block#
Dependency Graph for ResNet Basic Block:
[Conv1 output ch]---dep---[BN1 channels]---dep---[Conv2 input ch]
|
dep
|
[Conv2 output ch]---dep---[BN2 channels]---dep---[Add input ch]
|
dep (residual)
|
[Previous block output ch]
|
dep
|
[Previous BN output ch]
...
Pruning Group Example (if we want to prune channel i):
{Conv1.weight[i,:,:,:], BN1.gamma[i], BN1.beta[i],
Conv2.weight[:,i,:,:]}
For residual-connected channels:
{Conv2.weight[j,:,:,:], BN2.gamma[j], BN2.beta[j],
Prev_Conv.weight[j,:,:,:], Prev_BN.gamma[j], Prev_BN.beta[j], ...}
^--- all layers in the residual chain must prune channel j together
This automatic dependency resolution is what makes DepGraph applicable to arbitrary architectures (EfficientNet, ConvNeXt, Vision Transformers, etc.) without manual per-architecture pruning code.
Group Sparsity Regularization#
Instead of pruning after training, we can encourage structured sparsity during training through group sparsity regularization.
Group LASSO (L2,1 Norm)#
Partition the weight matrix into groups \(g_1, g_2, \ldots, g_G\) (e.g., each group is one filter). The group LASSO penalty is:
$$\Omega(W) = \sum_{g=1}^{G} ||W_{g}||_2 = \sum_{g=1}^{G} \sqrt{\sum_{i \in g} W_i^2}$$The training loss becomes:
$$\mathcal{L}_{\text{total}} = \mathcal{L}_{\text{task}}(W) + \lambda \sum_{g=1}^{G} ||W_g||_2$$Why group LASSO induces group sparsity (derivation):
The subdifferential of \(||W_g||_2\) with respect to \(W_g\) is:
$$\partial ||W_g||_2 = \begin{cases} \frac{W_g}{||W_g||_2} & \text{if } W_g \neq 0 \\ \{v : ||v||_2 \leq 1\} & \text{if } W_g = 0 \end{cases}$$At the optimum, for group \(g\):
$$0 \in \frac{\partial \mathcal{L}_{\text{task}}}{\partial W_g} + \lambda \partial ||W_g||_2$$
If \(||\frac{\partial \mathcal{L}_{\text{task}}}{\partial W_g}||_2 \leq \lambda\) with the gradient evaluated at \(W_g = 0\), the condition can be satisfied by a subgradient \(v\) with \(||v||_2 \leq 1\), so \(W_g = 0\) is optimal (the entire group is zeroed). This is the mechanism by which group LASSO drives entire groups to zero, unlike L2 regularization (weight decay), which shrinks all weights but never sets them exactly to zero.
Proximal Gradient Descent#
Since \(||W_g||_2\) is non-smooth at \(W_g = 0\), we use the proximal gradient method:
$$W^{(t+1)} = \text{prox}_{\eta\lambda\Omega}\left(W^{(t)} - \eta \nabla \mathcal{L}_{\text{task}}(W^{(t)})\right)$$Derivation of the proximal operator for group LASSO:
The proximal operator is defined as:
$$\text{prox}_{\eta\lambda||.||_2}(v) = \arg\min_u \frac{1}{2}||u - v||_2^2 + \eta\lambda ||u||_2$$This separates across groups. For a single group with parameter \(v_g\):
$$\text{prox}_{\eta\lambda||.||_2}(v_g) = \arg\min_{u_g} \frac{1}{2}||u_g - v_g||_2^2 + \eta\lambda ||u_g||_2$$Case 1: \(u_g = 0\). Objective = \(\frac{1}{2}||v_g||_2^2\).
Case 2: \(u_g \neq 0\). Take derivative and set to zero:
$$(u_g - v_g) + \eta\lambda \frac{u_g}{||u_g||_2} = 0$$$$u_g\left(1 + \frac{\eta\lambda}{||u_g||_2}\right) = v_g$$Since both \(u_g\) and \(v_g\) point in the same direction:
$$u_g = v_g \cdot \frac{||u_g||_2}{||u_g||_2 + \eta\lambda}$$Taking norms: \(||u_g|| = ||v_g|| \cdot \frac{||u_g||}{||u_g|| + \eta\lambda}\), so \(||u_g|| + \eta\lambda = ||v_g||\), giving \(||u_g|| = ||v_g|| - \eta\lambda\).
This is valid only when \(||v_g|| > \eta\lambda\). Otherwise, \(u_g = 0\).
Final proximal operator (group soft-thresholding):
$$\text{prox}_{\eta\lambda||\cdot||_2}(v_g) = \begin{cases} v_g \cdot \left(1 - \frac{\eta\lambda}{||v_g||_2}\right) & \text{if } ||v_g||_2 > \eta\lambda \\ 0 & \text{otherwise} \end{cases}$$This is the block soft-thresholding operator. When \(||v_g||_2 \leq \eta\lambda\), the entire group is set to zero in a single step.
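The block soft-thresholding operator is short enough to write out directly. The sketch below treats one group as a plain list of floats, with `t` standing for \(\eta\lambda\); the function name is chosen for this example.

```python
import math

def group_soft_threshold(v, t):
    """Block soft-thresholding: shrink the group's norm by t, or zero the whole group."""
    norm = math.sqrt(sum(x * x for x in v))
    if norm <= t:
        return [0.0] * len(v)
    scale = 1.0 - t / norm
    return [scale * x for x in v]

shrunk = group_soft_threshold([3.0, 4.0], t=1.0)  # norm 5 -> norm 4, direction kept
zeroed = group_soft_threshold([0.3, 0.4], t=1.0)  # norm 0.5 <= 1 -> whole group removed
print(shrunk, zeroed)
```

Unlike element-wise soft-thresholding, every weight in the group is scaled by the same factor, so the group either survives as a whole or vanishes as a whole.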
Soft Pruning vs Hard Pruning#
Hard pruning: Once a structure is pruned, it is permanently removed from the architecture. The pruned network has fewer parameters and cannot recover the pruned capacity.
Soft pruning (He et al., 2018): Set pruned structures to zero but keep them in the architecture. During fine-tuning, pruned weights can be updated (potentially becoming nonzero again). The pruning mask is periodically recomputed.
Hard Pruning Cycle:
Train -> Score -> Prune (permanent) -> Fine-tune -> Done
                  |
                  Removed from architecture
Soft Pruning Cycle:
Train -> Score -> Mask (temporary) -> Fine-tune -> Re-score -> Re-mask -> ...
                  |                                            |
                  Set to zero, but kept                        May unmask previously pruned
Comparison:
| Aspect | Hard Pruning | Soft Pruning |
|---|---|---|
| Architecture | Changes (smaller) | Unchanged (sparse) |
| Recovery | No regrowth possible | Regrowth possible |
| Final model | Truly smaller, dense | Needs mask enforcement |
| Accuracy | May lose info permanently | Generally higher accuracy |
| Compute during fine-tune | Less (smaller network) | More (full network) |
| Best for | Deployment | Finding optimal sparse structure |
Real-World Speedup Analysis#
Theoretical FLOPs Reduction vs Actual Wall-Clock Speedup#
The gap between theoretical and actual speedup is the most important practical consideration in pruning. We now analyze why this gap exists and how structured pruning closes it.
Why structured pruning gets real speedup:
Dense GEMM on smaller matrices: A pruned convolution with 50% of filters removed is simply a convolution with half the output channels. The hardware runs the same dense operation, just on a smaller tensor. All existing optimizations (tiling, vectorization, Tensor Core utilization) apply perfectly.
No sparse format overhead: No index arrays, no indirect memory access, no metadata. The weight tensor is a standard contiguous block.
Better memory bandwidth utilization: Smaller tensors mean less data transfer between DRAM, L2 cache, and compute units. For memory-bandwidth-bound operations (small batch sizes, depthwise convolutions), this is the dominant factor.
Speedup Measurements on Different Hardware#
| Method | Sparsity | Model | FLOPs Reduction | GPU (A100) | GPU (RTX 4090) | CPU (Intel Xeon) | Mobile (Snapdragon 8 Gen 2) |
|---|---|---|---|---|---|---|---|
| Unstructured magnitude | 90% | ResNet-50 | 10x | 1.0-1.2x | 1.0-1.1x | 1.5-2.0x | 1.0-1.3x |
| Unstructured magnitude | 95% | ResNet-50 | 20x | 1.1-1.3x | 1.0-1.2x | 1.8-2.5x | 1.1-1.4x |
| 2:4 N:M (NVIDIA ASP) | 50% | ResNet-50 | 2x | 1.8-2.0x | 1.7-1.9x | 1.0x (no HW) | 1.0x (no HW) |
| Filter pruning (50%) | 50% | ResNet-50 | ~2x | 1.7-1.9x | 1.7-1.9x | 1.6-1.8x | 1.5-1.7x |
| Filter pruning (70%) | 70% | ResNet-50 | ~3.3x | 2.5-3.0x | 2.5-2.9x | 2.2-2.8x | 2.0-2.5x |
| Channel pruning (50%) | 50% | MobileNetV2 | ~2x | 1.5-1.7x | 1.5-1.7x | 1.7-1.9x | 1.8-2.0x |
Key observations:
- Unstructured pruning shows almost no speedup on GPUs even at 90%+ sparsity.
- CPUs fare slightly better with unstructured pruning due to branch-based sparse kernels.
- Structured pruning achieves near-theoretical speedup across all platforms.
- 2:4 sparsity is excellent on NVIDIA GPUs but useless on other hardware.
- Mobile platforms benefit most from structured pruning (memory-bandwidth bound).
Roofline Model Analysis#
The roofline model relates computational performance to arithmetic intensity (FLOPs per byte of memory transferred).
Roofline Model: Sparse vs Dense (y-axis: Performance in TFLOPS, x-axis: Arithmetic Intensity in FLOPs / Byte)
x = Dense/Structured (on the roofline)
o = 2:4 Sparse (on a lower but real roofline)
u = Unstructured (below any roofline due to overhead: irregular access, cache misses, no vectorization)
Key: Structured pruning reduces problem size while staying on the optimal roofline. Unstructured pruning falls off the roofline entirely.
Analysis: Dense and structured-pruned operations ride the roofline – they achieve peak performance for their arithmetic intensity. Reducing the tensor size (structured pruning) moves the operating point left on the x-axis (less data to transfer) but stays on the roofline. The actual throughput equals \(\min(\text{peak compute}, \text{bandwidth} \times \text{arithmetic intensity})\).
Unstructured sparse operations fall below the roofline because:
- Cache misses from irregular access reduce effective bandwidth.
- No SIMD/Tensor Core utilization reduces effective peak compute.
- Index overhead increases bytes transferred without adding useful FLOPs.
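The roofline formula is easy to evaluate for a concrete GEMM. The sketch below uses made-up hardware numbers (100 TFLOPS peak, 1 TB/s bandwidth) and a simple traffic model in which each operand is read or written once; both are assumptions for illustration, not measurements of any real device.

```python
def gemm_cost(m, k, n, bytes_per_element=2):
    """FLOPs and bytes moved for Y = X @ W with X: (m, k), W: (k, n), FP16 by default."""
    flops = 2 * m * k * n
    bytes_moved = bytes_per_element * (m * k + k * n + m * n)
    return flops, bytes_moved

def attainable_flops_per_s(peak, bandwidth, intensity):
    """Roofline: min(peak compute, bandwidth * arithmetic intensity)."""
    return min(peak, bandwidth * intensity)

def gemm_time(m, k, n, peak, bandwidth):
    flops, bytes_moved = gemm_cost(m, k, n)
    intensity = flops / bytes_moved
    return flops / attainable_flops_per_s(peak, bandwidth, intensity)

PEAK = 100e12       # hypothetical peak compute, FLOPs/s
BANDWIDTH = 1e12    # hypothetical memory bandwidth, bytes/s

# Small-batch linear layer before and after removing 50% of the output features
t_dense = gemm_time(m=8, k=4096, n=4096, peak=PEAK, bandwidth=BANDWIDTH)
t_pruned = gemm_time(m=8, k=4096, n=2048, peak=PEAK, bandwidth=BANDWIDTH)
speedup = t_dense / t_pruned
print(f"modelled speedup: {speedup:.2f}x")
```

In this memory-bound regime the modelled time is dominated by weight traffic, so halving the weight matrix roughly halves the time even though the operation never approaches peak compute.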
Pruning + Other Compression Techniques#
Pruning + Quantization: Compound Compression#
Pruning and quantization are complementary: pruning reduces the number of parameters, quantization reduces the bits per parameter. The compound compression ratio multiplies:
$$\text{Compression}_{\text{total}} = \text{Compression}_{\text{prune}} \times \text{Compression}_{\text{quant}}$$
Example: 50% structured pruning (2x) + INT8 quantization (4x from FP32) = 8x total compression. Likewise, 50% pruning (2x) + INT4 quantization (8x from FP32) = 16x compression.
Order matters: There are two strategies:
Prune first, then quantize:
- Train dense model.
- Prune to target sparsity, fine-tune.
- Quantize the pruned model (PTQ or QAT).
This is simpler but may lose accuracy at the quantization step because the pruned model has fewer parameters to absorb quantization error.
Joint pruning and quantization:
- Train with both pruning masks and quantization-aware fake quantization.
- The model adapts to both constraints simultaneously.
- Generally achieves better accuracy but is more complex to implement.
N:M sparsity + INT8 quantization is the most deployment-friendly combination:
- 2:4 sparsity: 2x Tensor Core throughput
- INT8: 2x throughput over FP16 on Tensor Cores
- Combined: theoretically 4x throughput over dense FP16
- Memory: 50% weights x 50% bits = 25% of original FP16 storage for the values alone; the 2-bit position index stored per kept weight raises this to roughly 31%
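The storage arithmetic can be made explicit per group of four weights. The sketch assumes a 2-bit position index per kept value, matching the "2-bit meta" entry in the comparison table later in this post; the helper names are illustrative.

```python
def bits_per_group_dense(bits_per_weight, group_size=4):
    """Storage for one group of weights in a dense layout."""
    return group_size * bits_per_weight

def bits_per_group_2of4(bits_per_weight, index_bits=2):
    """Storage for one 2:4 group: two kept values plus one position index each."""
    kept = 2
    return kept * bits_per_weight + kept * index_bits

fp16_dense = bits_per_group_dense(16)  # 64 bits per group
int8_values_only = 2 * 8               # kept values without metadata
int8_2of4 = bits_per_group_2of4(8)     # kept values plus metadata

ratio_values_only = int8_values_only / fp16_dense
ratio_with_metadata = int8_2of4 / fp16_dense
print(ratio_values_only, ratio_with_metadata)
```

The metadata overhead is fixed per kept value, so it weighs more heavily as the value precision drops.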
Pruning + Knowledge Distillation#
Knowledge distillation uses a large teacher model to guide the training of a smaller student model. When combined with pruning:
Standard distillation loss:
$$\mathcal{L}_{\text{KD}} = \alpha \cdot \mathcal{L}_{\text{CE}}(y, \hat{y}_S) + (1-\alpha) \cdot T^2 \cdot \text{KL}(\sigma(\hat{z}_T/T) || \sigma(\hat{z}_S/T))$$where \(\hat{z}_T, \hat{z}_S\) are teacher and student logits, \(T\) is temperature, and \(\sigma\) is softmax.
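For a single example, the distillation loss can be computed with nothing but the standard library. This is a minimal sketch of the formula above; the logits are toy values and the function names are chosen for the example.

```python
import math

def softmax(logits, T=1.0):
    """Temperature-scaled softmax, shifted by the max for numerical stability."""
    scaled = [z / T for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def kd_loss(student_logits, teacher_logits, label, alpha=0.5, T=4.0):
    """alpha * CE(label, student) + (1 - alpha) * T^2 * KL(teacher_T || student_T)."""
    ce = -math.log(softmax(student_logits)[label])
    p_teacher = softmax(teacher_logits, T)
    p_student = softmax(student_logits, T)
    kl = sum(p * math.log(p / q) for p, q in zip(p_teacher, p_student) if p > 0)
    return alpha * ce + (1.0 - alpha) * T * T * kl

teacher = [1.0, 2.0, 3.0]
matched = kd_loss([1.0, 2.0, 3.0], teacher, label=2)
mismatched = kd_loss([3.0, 2.0, 1.0], teacher, label=2)
print(matched, mismatched)
```

The \(T^2\) factor keeps the gradient magnitude of the soft-target term comparable across temperatures, which is why it appears alongside the KL term.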
Feature-level distillation for structured pruning: When filter/channel pruning changes intermediate feature map dimensions, a linear projection aligns teacher and student feature maps:
$$\mathcal{L}_{\text{feat}} = \sum_l ||f_T^{(l)} - P_l \cdot f_S^{(l)}||_F^2$$where \(P_l\) is a learnable projection matrix that maps the student’s (smaller) feature maps to the teacher’s dimension. This is particularly important for structured pruning because entire feature channels are missing from the student.
Pruning + NAS#
Neural Architecture Search and pruning share a deep connection: pruning can be viewed as searching for optimal sub-architectures within a larger network.
Once-for-All (OFA) Networks (Cai et al., 2020):
- Train a single large network that supports elastic depth, width, and kernel size.
- At deployment time, extract a sub-network matching the target hardware constraints.
- The sub-network extraction is equivalent to structured pruning (removing layers, channels, reducing kernels).
- No fine-tuning needed because the large network was trained to support all sub-networks.
This unifies NAS and pruning into a single paradigm: train once, deploy many configurations.
Practical Implementation Guide#
PyTorch Pruning API (torch.nn.utils.prune)#
PyTorch provides built-in pruning utilities. Here is how the key functions work:
Unstructured pruning:
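import torch.nn as nn
import torch.nn.utils.prune as prune
# Example layer to prune (any module with a 'weight' parameter works)
module = nn.Conv2d(16, 32, kernel_size=3)
# L1 unstructured: prune 30% of weights by magnitude
prune.l1_unstructured(module, name='weight', amount=0.3)
# Creates: module.weight_mask (binary buffer), module.weight_orig (original parameter)
# module.weight is now a plain attribute, recomputed as weight_orig * weight_mask
# by a forward pre-hook before each forward pass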
# Random unstructured
prune.random_unstructured(module, name='weight', amount=0.3)
# Global unstructured: prune 20% globally across multiple layers
# (assumes `model` is an nn.Module with conv1, conv2, and fc1 submodules)
parameters_to_prune = [
    (model.conv1, 'weight'),
    (model.conv2, 'weight'),
    (model.fc1, 'weight'),
]
prune.global_unstructured(
    parameters_to_prune,
    pruning_method=prune.L1Unstructured,
    amount=0.2,
)
Structured pruning:
# Ln structured: prune 40% of filters by L2-norm (dim=0 = output channels)
prune.ln_structured(module, name='weight', amount=0.4, n=2, dim=0)
# This zeros out entire filters (output channels)
# Prune by L1-norm along input channel dimension
prune.ln_structured(module, name='weight', amount=0.3, n=1, dim=1)
Making pruning permanent (removing the reparametrization):
prune.remove(module, 'weight')
# Now module.weight is a regular parameter with zeros baked in
# For structured pruning, you still need to manually resize the tensor
# and adjust adjacent layers
Important caveat: PyTorch’s built-in pruning API only applies masks; it does not physically remove structures. For actual speedup from structured pruning, you must manually reconstruct the network with smaller layers. Libraries like Torch-Pruning (which implements DepGraph) and NNI handle this automatically.
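What "manually reconstruct the network" means can be shown on two stacked linear layers. The NumPy sketch below is an illustration under simple assumptions (a ReLU between the layers, toy random weights): it compares masking hidden units with physically slicing them out, and the two give the same output as long as the matching bias entries are removed too.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((5, 8))   # batch of 5 inputs with 8 features
W1 = rng.standard_normal((8, 6))  # first layer: 6 hidden units
b1 = rng.standard_normal(6)
W2 = rng.standard_normal((6, 3))  # second layer consumes the 6 hidden units

pruned = [1, 4]                   # hidden units selected for removal
keep = [i for i in range(6) if i not in pruned]

# Masked version: same shapes, pruned units zeroed (weights AND bias)
W1_masked = W1.copy()
W1_masked[:, pruned] = 0.0
b1_masked = b1.copy()
b1_masked[pruned] = 0.0
y_masked = np.maximum(X @ W1_masked + b1_masked, 0.0) @ W2

# Physically smaller version: drop columns of W1, entries of b1, and rows of W2
W1_small, b1_small, W2_small = W1[:, keep], b1[keep], W2[keep, :]
y_small = np.maximum(X @ W1_small + b1_small, 0.0) @ W2_small

print(W1_small.shape, W2_small.shape, np.allclose(y_masked, y_small))
```

If only the weight is masked and the bias is left in place, a pruned unit still outputs \(\text{ReLU}(b)\), so the masked and sliced networks would no longer agree.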
NVIDIA ASP (Automatic SParsity) for 2:4#
NVIDIA’s Automatic SParsity library applies 2:4 sparsity to PyTorch models with minimal code:
from apex.contrib.sparsity import ASP
# Prepare a trained dense model for 2:4 sparse fine-tuning
# (model, optimizer, criterion, and dataloader are placeholders that must already exist)
ASP.prune_trained_model(model, optimizer)
# This computes 2:4 masks for supported layers (e.g., Linear, Conv2d)
# and patches optimizer.step so the masks are re-applied after every update
# Fine-tuning loop runs as usual
for epoch in range(num_epochs):
    for inputs, targets in dataloader:
        optimizer.zero_grad()
        loss = criterion(model(inputs), targets)
        loss.backward()
        optimizer.step()  # ASP re-applies the 2:4 masks here
# Export the sparse model for inference
# TensorRT can use the 2:4 pattern for Sparse Tensor Core acceleration
# when sparse weights are enabled in the builder
Comparison of Pruning Tools#
| Tool | Framework | Structured | Unstructured | N:M | Auto Dependency | Key Feature |
|---|---|---|---|---|---|---|
| torch.nn.utils.prune | PyTorch | Mask only | Yes | No | No | Built-in, simple API |
| torch-pruning (DepGraph) | PyTorch | Yes (physical) | No | No | Yes | Handles any architecture |
| NVIDIA ASP | PyTorch | No | No | 2:4 | N/A | Sparse Tensor Core ready |
| NNI (Microsoft) | PyTorch/TF | Yes | Yes | No | Partial | Many algorithms built-in |
| Intel Neural Compressor | PyTorch/TF | Yes | Yes | No | No | CPU-optimized inference |
| TF Model Optimization | TensorFlow | Yes | Yes | No | No | TFLite integration |
| ONNX Runtime | ONNX | Partial | Yes | No | N/A | Cross-framework inference |
Choosing the Right Tool#
Decision criteria:
If targeting NVIDIA GPU inference with TensorRT: Use NVIDIA ASP for 2:4 sparsity. This is the path of least resistance to Sparse Tensor Core acceleration (up to 2x math throughput on the pruned layers; end-to-end gains are usually smaller).
If targeting mobile/edge (TFLite, Core ML): Use structured pruning via torch-pruning or TF Model Optimization. Physical tensor size reduction translates directly to latency reduction.
If targeting CPU inference: Structured pruning (filter/channel) with Intel Neural Compressor. CPU benefits from both smaller tensors and reduced memory bandwidth.
If targeting maximum compression (storage, not latency): Unstructured pruning at high sparsity + quantization. Store in sparse format. Accept no inference speedup on standard hardware.
If working with complex architectures: Use DepGraph (torch-pruning) for automatic dependency resolution.
Summary#
Structured vs Unstructured Decision Matrix#
Need real speedup     Need max compression     Have sparse HW
on standard HW?       (storage only)?          (Ampere+)?
       |                      |                      |
      YES                    YES                    YES
       |                      |                      |
   Structured            Unstructured            N:M (2:4)
    Pruning                Pruning               Sparsity
       |                      |                      |
 Filter/Channel           Magnitude             NVIDIA ASP
   + DepGraph          + Sparse Format          + Fine-tune
       |                      |                      |
 Actual 1.5-3x        Theoretical 10-20x         Actual 2x
    speedup              (1.0x on GPU)        on Tensor Cores
Complete Comparison Table#
| Dimension | Unstructured | N:M (2:4) | Block Sparse | Channel/Filter | Layer |
|---|---|---|---|---|---|
| Granularity | Individual weight | 2 of 4 elements | k x k block | Entire channel/filter | Entire layer |
| Typical sparsity | 90-99% | 50% | 50-80% | 30-70% | 10-30% |
| Accuracy at target | Best | Very good | Good | Good | Fair |
| GPU speedup | ~1.0x | 2.0x (Ampere+) | 1.3-1.8x | Near-theoretical | Near-theoretical |
| CPU speedup | 1.5-2.5x | 1.0x (no HW) | 1.2-1.5x | Near-theoretical | Near-theoretical |
| Mobile speedup | ~1.0x | 1.0x (no HW) | ~1.0x | Near-theoretical | Near-theoretical |
| Implementation ease | Easy (masking) | Easy (ASP) | Moderate | Hard (dependency) | Easy |
| Format overhead | High (indices) | Low (2-bit meta) | Moderate | None (dense) | None (dense) |
| Framework support | Excellent | NVIDIA only | Limited | Good (with libraries) | Manual |
| Best use case | Specialized HW | NVIDIA GPU inference | Research | General deployment | Very deep nets |
Key Takeaways#
Unstructured pruning is a compression technique, not an acceleration technique (on commodity hardware). Use it when storage size matters more than inference speed, or when deploying on sparsity-aware hardware like Cerebras.
Structured pruning is the most reliable way to get real speedup on GPUs, CPUs, and mobile devices without specialized hardware or sparse kernels. It produces smaller dense tensors that exploit all existing hardware optimizations.
2:4 sparsity is the current best compromise for NVIDIA GPU deployment: 50% sparsity with up to 2x Sparse Tensor Core math throughput on Ampere and later GPUs, typically small accuracy loss after fine-tuning, and easy implementation via ASP.
Dependency-aware pruning (DepGraph) is essential for modern architectures. Manual structured pruning is error-prone and architecture-specific; automatic dependency resolution makes it applicable to any model.
Compound compression (pruning + quantization + distillation) yields the best real-world results. The compression ratios multiply, and knowledge distillation recovers accuracy lost to aggressive pruning.
The pruning criterion matters less than the pruning granularity for speedup. A simple L1-norm filter pruning with proper fine-tuning often matches sophisticated criteria in accuracy, and the speedup is determined by the structure, not the selection method.
Always measure wall-clock time, not FLOPs. A method claiming 10x FLOPs reduction with no wall-clock improvement is not useful for deployment.
Next post: We will explore advanced pruning methods including lottery ticket hypothesis, pruning at initialization, gradual magnitude pruning schedules, and iterative pruning strategies that push the boundaries of how much we can prune while maintaining accuracy.