1. Introduction: Why Quantization Matters#
Modern deep neural networks demand enormous compute and memory. A single forward pass of a large language model can require hundreds of gigabytes of memory and trillions of floating-point operations. Quantization addresses this by representing weights and activations with fewer bits, yielding smaller models and faster inference.
There are two dominant paradigms:
| Paradigm | When Applied | Calibration Data | Accuracy |
|---|---|---|---|
| Post-Training Quantization (PTQ) | After training | Small calibration set | Good for >= 8-bit |
| Quantization-Aware Training (QAT) | During training | Full training set | Superior, especially < 8-bit |
PTQ is convenient but often struggles at low bit-widths (4-bit, 2-bit, binary). QAT embeds quantization into the training loop so the network learns to compensate for quantization error, which typically yields higher accuracy than PTQ, with the gap widening as the bit-width shrinks.
This post provides a thorough treatment of QAT: the mathematics, the algorithms, the engineering, and the practical decision-making.
2. Quantization Fundamentals Recap#
2.1 Uniform Affine Quantization#
The standard uniform quantization maps a floating-point value \(x\) to an integer \(x_q\):
$$x_q = \text{clamp}\!\left(\left\lfloor \frac{x}{s} \right\rceil + z,\; q_{\min},\; q_{\max}\right)$$
where \(s\) is the scale, \(z\) is the zero-point, and \(\lfloor \cdot \rceil\) denotes rounding to the nearest integer. For a \(b\)-bit unsigned quantization:
$$q_{\min} = 0, \quad q_{\max} = 2^b - 1$$
For signed quantization:
$$q_{\min} = -2^{b-1}, \quad q_{\max} = 2^{b-1} - 1$$
The dequantization step recovers an approximation:
$$\hat{x} = s \cdot (x_q - z)$$
2.2 Symmetric vs. Asymmetric#
| Property | Symmetric | Asymmetric |
|---|---|---|
| Zero-point | \(z = 0\) | \(z \neq 0\) |
| Range | \([-\alpha, \alpha]\) | \([\beta_{\min}, \beta_{\max}]\) |
| Use case | Weights (often symmetric around 0) | Activations (e.g., after ReLU, non-negative) |
| Hardware | Simpler (no zero-point offset) | Slightly more complex |
For symmetric quantization, the scale is:
$$s = \frac{\alpha}{q_{\max}}$$where \(\alpha = \max(|x_{\min}|, |x_{\max}|)\).
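As a minimal sketch of the formulas above, the following plain-Python functions quantize a scalar, dequantize it, and compute a symmetric scale. The function names are illustrative choices, not taken from any library, and Python's built-in `round` is used as the rounding operator (it rounds exact ties to the even integer, which may differ from a given hardware rounding mode).

```python
def quantize(x, scale, zero_point, qmin, qmax):
    """Map a float to an integer code: clamp(round(x / s) + z, qmin, qmax)."""
    q = round(x / scale) + zero_point
    return max(qmin, min(qmax, q))


def dequantize(q, scale, zero_point):
    """Recover the floating-point approximation: s * (q - z)."""
    return scale * (q - zero_point)


def symmetric_scale(values, bits):
    """Symmetric signed scale: alpha / qmax, with alpha the largest magnitude."""
    qmax = 2 ** (bits - 1) - 1
    alpha = max(abs(min(values)), abs(max(values)))
    return alpha / qmax


values = [-1.0, 0.3, 0.5]
scale = symmetric_scale(values, bits=8)      # alpha = 1.0, qmax = 127
code = quantize(0.3, scale, 0, -128, 127)    # integer code for 0.3
approx = dequantize(code, scale, 0)          # value seen after dequantization
```

For in-range inputs the round-trip error is bounded by half a step, which is the quantization error that QAT trains the network to tolerate.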
2.3 Per-Tensor vs. Per-Channel#
Per-tensor quantization uses a single \((s, z)\) pair for the entire tensor. Per-channel quantization assigns a separate \((s_c, z_c)\) for each output channel of a convolution or linear layer. Per-channel is almost always preferred for weights because different channels can have vastly different dynamic ranges.
Per-Tensor:                      Per-Channel:
+---------------------------+    +---------------------------+
| s=0.02, z=0               |    | ch0: s=0.01,  z=0         |
| applies to ALL elements   |    | ch1: s=0.03,  z=0         |
+---------------------------+    | ch2: s=0.005, z=0         |
                                 | ...                       |
                                 +---------------------------+
3. The Straight-Through Estimator (STE)#
3.1 The Core Problem#
Quantization involves rounding, and rounding is a piecewise-constant function. Its true gradient is zero almost everywhere and undefined at the half-integer points where the rounded value jumps:
$$\frac{\partial \lfloor x \rceil}{\partial x} = 0 \quad \text{a.e.}$$This means that if we naively insert quantization into the computation graph, gradient-based optimization halts entirely because no gradient signal flows through the quantization nodes.
3.2 Bengio’s Straight-Through Estimator#
The Straight-Through Estimator (STE), popularized by Bengio et al. (2013), resolves this by approximating the gradient of the rounding function as the identity:
$$\frac{\partial \lfloor x \rceil}{\partial x} \approx 1$$More precisely, let \(Q(x)\) be the full quantize-then-dequantize operation. In the forward pass, we compute:
$$\hat{x} = Q(x) = s \cdot \left(\text{clamp}\!\left(\left\lfloor \frac{x}{s} \right\rceil + z,\; q_{\min},\; q_{\max}\right) - z\right)$$In the backward pass, we pretend that \(Q\) is the identity within the clipping range:
$$\frac{\partial \mathcal{L}}{\partial x} \approx \frac{\partial \mathcal{L}}{\partial \hat{x}} \cdot \mathbf{1}_{x \in [x_{\min}, x_{\max}]}$$
where \(\mathbf{1}_{x \in [x_{\min}, x_{\max}]}\) is the indicator function that passes gradients only when \(x\) is within the quantization range \([x_{\min}, x_{\max}] = [s(q_{\min} - z),\; s(q_{\max} - z)]\).
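The forward and backward rules can be sketched for a single scalar without any autograd framework. This is an illustrative sketch, not a framework implementation: `fake_quant` evaluates \(Q(x)\), and `fake_quant_ste_grad` applies the indicator-masked identity to an incoming gradient. Both names are hypothetical.

```python
def fake_quant(x, scale, zero_point, qmin, qmax):
    """Forward pass: quantize, clamp, then dequantize back to float."""
    q = max(qmin, min(qmax, round(x / scale) + zero_point))
    return scale * (q - zero_point)


def fake_quant_ste_grad(x, grad_out, scale, zero_point, qmin, qmax):
    """Backward pass under the STE: pass grad_out through unchanged when x
    lies inside the representable range, and block it otherwise."""
    x_min = scale * (qmin - zero_point)
    x_max = scale * (qmax - zero_point)
    return grad_out if x_min <= x <= x_max else 0.0


# 4-bit signed grid with step 0.1: representable range is about [-0.8, 0.7].
inside = fake_quant(0.33, 0.1, 0, -8, 7)    # snaps to the nearest grid point
clipped = fake_quant(1.5, 0.1, 0, -8, 7)    # saturates at the top of the range
g_inside = fake_quant_ste_grad(0.33, 2.0, 0.1, 0, -8, 7)
g_clipped = fake_quant_ste_grad(1.5, 2.0, 0.1, 0, -8, 7)
```

In an autograd framework the same pair would typically be registered as the forward and backward of one custom operation.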
3.3 STE as a Subgradient Method#
The STE can be interpreted through the lens of subgradient optimization. The rounding function \(r(x) = \lfloor x \rceil\) is the proximal operator of the indicator function for integers. The STE gradient \(\frac{\partial r}{\partial x} = 1\) corresponds to a subgradient of the piecewise-linear interpolation of the rounding function, which is precisely the identity function.
Formally, consider the “soft” relaxation:
$$r_{\text{soft}}(x) = x$$We have \(r_{\text{soft}}(x) = r(x)\) at every integer, and the gradient \(\nabla r_{\text{soft}} = 1\) everywhere. The STE simply uses this smooth surrogate’s gradient while evaluating the hard function in the forward pass.
3.4 STE with Clipping Gradient#
The complete STE with clipping can be written as a single expression using the indicator:
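$$\frac{\partial Q(x)}{\partial x} = \begin{cases} 1 & \text{if } q_{\min} \leq \frac{x}{s} + z \leq q_{\max} \\ 0 & \text{otherwise} \end{cases}$$
Outside the clipping range the clamp is constant, so the gradient is zero: changing a clipped value slightly does not change the output, and the STE respects that. The cost is that a value far outside the range receives no signal through \(x\) and cannot move back on its own (see Section 3.5), which is one motivation for learning the clipping range itself, as LSQ and PACT do.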
Forward Pass:
x ---> [ Round + Clamp ] ---> x_q ---> [ Dequantize ] ---> x_hat
(non-differentiable)
Backward Pass (STE):
dL/dx <--- [ Identity * Indicator ] <--- dL/dx_hat
(differentiable surrogate)
3.5 Limitations of the STE#
The STE introduces a gradient mismatch: the forward function and the backward function are different. This has several consequences:
- Biased gradients: The expected gradient under STE does not equal the true gradient (which is zero). This bias can cause optimization to converge to suboptimal points.
- Accumulation of error: In very deep networks or at very low bit-widths, the accumulated gradient mismatch can destabilize training.
- Dead neurons: If a weight is pushed far outside the clipping range, it receives zero gradient and cannot recover.
Despite these limitations, the STE works remarkably well in practice and remains the foundation of nearly all QAT methods.
4. Fake Quantization Nodes#
4.1 Concept#
A fake quantization node (also called a simulated quantization node) is the operational core of QAT. It performs quantization and immediate dequantization in the forward pass, so the output remains in floating-point but carries quantization error:
$$\text{FakeQuant}(x) = s \cdot \left(\text{clamp}\!\left(\left\lfloor \frac{x}{s} \right\rceil + z,\; q_{\min},\; q_{\max}\right) - z\right)$$
The key insight is that tensors keep their floating-point data type throughout training, so standard GPU hardware and autograd frameworks work normally. The quantization error enters as a deterministic perturbation of each value.
4.2 Placement in the Graph#
Fake quantization nodes are inserted at specific points:
Original Graph QAT Graph
+-----------+ +------------+
input ------->| Conv2d | input ----->| FakeQuant |
| | +------------+
| weights | |
+-----------+ +-----v------+
| | Conv2d |<-- FakeQuant(weights)
v +------------+
+-----------+ |
| BN | v
+-----------+ +------------+
| | BN |
v +------------+
+-----------+ |
| ReLU | v
+-----------+ +------------+
| ReLU |
+------------+
|
v
+------------+
| FakeQuant | (activation)
+------------+
The typical placement rules are:
- Weight fake quantization: Applied to weights before each convolution or linear layer.
- Activation fake quantization: Applied after the activation function (e.g., ReLU), since the activation’s output range is what the next layer will see at inference.
- Input fake quantization: Applied to the model’s input to simulate input quantization.
4.3 Observer and Fake Quantization Interplay#
During QAT, each fake quantization node contains an observer that tracks running statistics to determine \(s\) and \(z\):
| Observer Type | Description |
|---|---|
| MinMax | Tracks global min/max over all batches |
| MovingAverage | Exponential moving average of min/max |
| Histogram | Builds histogram, minimizes KL divergence or MSE |
| Percentile | Uses p-th and (100-p)-th percentile to exclude outliers |
The observer updates its statistics during the forward pass, and the fake quantization node uses the computed \(s, z\) to perform the quantize-dequantize operation.
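To make the observer role concrete, here is a small sketch of a moving-average min/max observer that can also be frozen. The class name `EmaMinMaxObserver`, its `enabled` flag, and the default momentum are assumptions made for illustration; they are not the API of any framework.

```python
class EmaMinMaxObserver:
    """Tracks an exponential moving average of batch min/max values."""

    def __init__(self, momentum=0.9, qmin=0, qmax=255):
        self.momentum = momentum
        self.qmin, self.qmax = qmin, qmax
        self.min_val = None
        self.max_val = None
        self.enabled = True  # set to False to freeze the statistics

    def update(self, batch):
        if not self.enabled:
            return
        lo, hi = min(batch), max(batch)
        if self.min_val is None:
            self.min_val, self.max_val = lo, hi
        else:
            m = self.momentum
            self.min_val = m * self.min_val + (1 - m) * lo
            self.max_val = m * self.max_val + (1 - m) * hi

    def qparams(self):
        """Derive (scale, zero_point) so that the range, widened to contain
        zero, maps onto [qmin, qmax]."""
        lo = min(self.min_val, 0.0)
        hi = max(self.max_val, 0.0)
        scale = max((hi - lo) / (self.qmax - self.qmin), 1e-12)
        zero_point = self.qmin - round(lo / scale)
        zero_point = max(self.qmin, min(self.qmax, zero_point))
        return scale, zero_point


obs = EmaMinMaxObserver()
obs.update([0.0, 1.0])
obs.update([0.0, 3.0])          # max moves only 10% of the way toward 3.0
scale, zero_point = obs.qparams()
obs.enabled = False
obs.update([-5.0, 10.0])        # ignored: statistics are frozen
```

Widening the range to include zero is a common design choice so that exact zeros (padding, ReLU outputs) are representable without error.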
5. The QAT Training Pipeline#
5.1 Overall Workflow#
The standard QAT pipeline follows these steps:
Step 1: Train FP32 model to convergence (or load pretrained)
|
v
Step 2: Prepare QAT model
- Insert FakeQuant nodes for weights and activations
- Attach observers
- Optionally fold BatchNorm layers
|
v
Step 3: Calibrate observers (a few batches in eval mode)
- Observers collect activation statistics
- No weight updates
|
v
Step 4: Fine-tune with QAT (train mode)
- Observers may freeze or continue updating
- Typically 10-30% of original training epochs
- Lower learning rate (1/10 to 1/100 of original)
|
v
Step 5: Convert to quantized model
- Remove FakeQuant nodes
- Store integer weights with scales/zero-points
- Ready for integer-only inference
5.2 Learning Rate Schedule#
QAT is essentially fine-tuning, so the learning rate should be significantly lower than the original training. Common practices:
- Start at 1% to 10% of the peak training learning rate.
- Use cosine annealing or step decay.
- Total QAT epochs: typically 5 to 30, depending on the model and target bit-width.
5.3 Observer Freezing#
A critical but often overlooked detail: observers should be frozen after a warm-up period. If observers keep updating throughout training, the quantization grid shifts every step, introducing noise that can destabilize convergence. The recommended practice is:
- Epoch 0 to N_obs: Observers active, collecting statistics.
- Epoch N_obs to end: Observers frozen, fake quantization uses fixed \(s, z\).
In PyTorch, this is typically done by calling model.apply(torch.ao.quantization.disable_observer) once the warm-up period ends.
6. Learned Step Size Quantization (LSQ)#
6.1 Motivation#
Standard QAT uses fixed or heuristically determined quantization parameters. LSQ (Esser et al., 2020) proposes making the step size (scale \(s\)) a learnable parameter optimized jointly with the network weights via gradient descent.
6.2 Formulation#
LSQ uses symmetric uniform quantization. For a weight or activation \(x\), the quantized-then-dequantized value is:
$$\hat{x} = s \cdot \text{clamp}\!\left(\left\lfloor \frac{x}{s} \right\rceil, -Q_N, Q_P\right)$$where \(Q_N = 2^{b-1}\) and \(Q_P = 2^{b-1} - 1\) for \(b\)-bit signed quantization.
6.3 Gradient of the Step Size#
Using the STE, the gradient of the loss \(\mathcal{L}\) with respect to the step size \(s\) is derived as follows. Let \(\bar{x} = x / s\) and \(\hat{q} = \text{clamp}(\lfloor \bar{x} \rceil, -Q_N, Q_P)\). Then \(\hat{x} = s \cdot \hat{q}\), and by the chain rule:
$$\frac{\partial \mathcal{L}}{\partial s} = \frac{\partial \mathcal{L}}{\partial \hat{x}} \cdot \frac{\partial \hat{x}}{\partial s}$$
Applying the product rule to \(\hat{x} = s \cdot \hat{q}\):
$$\frac{\partial \hat{x}}{\partial s} = \hat{q} + s \cdot \frac{\partial \hat{q}}{\partial s}$$
Under the STE, rounding is treated as the identity, so when \(\bar{x}\) is in range \(\frac{\partial \hat{q}}{\partial s} \approx \frac{\partial \bar{x}}{\partial s} = -x/s^2\), and:
$$\frac{\partial \hat{x}}{\partial s} = \hat{q} - \frac{x}{s} = \lfloor \bar{x} \rceil - \bar{x}$$
This is the quantization residual, measured in units of the step size. When \(\bar{x}\) is clipped, \(\hat{q}\) is a constant (\(-Q_N\) or \(Q_P\)), so \(\frac{\partial \hat{q}}{\partial s} = 0\) and only the first term survives. Combining the cases:
$$\frac{\partial \hat{x}}{\partial s} = \begin{cases} \lfloor \bar{x} \rceil - \bar{x} & \text{if } -Q_N \leq \bar{x} \leq Q_P \\ -Q_N & \text{if } \bar{x} < -Q_N \\ Q_P & \text{if } \bar{x} > Q_P \end{cases}$$
6.4 Scale Gradient Scaling#
A crucial practical detail in LSQ is the gradient scale factor. The step size \(s\) is a single scalar, but it affects every element in the tensor. Without scaling, the gradient magnitude for \(s\) would be disproportionately large compared to individual weight gradients. LSQ proposes:
$$g_s = \frac{1}{\sqrt{N \cdot Q_P}}$$
where \(N\) is the number of elements in the tensor. The step size update becomes:
$$s \leftarrow s - \eta \cdot g_s \cdot \frac{\partial \mathcal{L}}{\partial s}$$
6.5 Initialization#
The initial step size is set based on the tensor’s initial statistics:
$$s_0 = \frac{2 \cdot \text{mean}(|x|)}{\sqrt{Q_P}}$$This heuristic ensures that the initial quantization grid covers the bulk of the value distribution without being dominated by outliers.
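The pieces of LSQ described in Sections 6.2 to 6.5 fit in a few scalar functions. The sketch below is written per element for readability; the function names are hypothetical, and a real implementation would operate on whole tensors inside an autograd framework.

```python
import math


def lsq_quantize(x, s, q_n, q_p):
    """Forward pass: s * clamp(round(x / s), -Q_N, Q_P)."""
    return s * max(-q_n, min(q_p, round(x / s)))


def lsq_step_grad(x, s, q_n, q_p):
    """d(x_hat)/ds under the STE: the quantization residual when in range,
    and the saturated integer level when clipped."""
    x_bar = x / s
    if x_bar < -q_n:
        return float(-q_n)
    if x_bar > q_p:
        return float(q_p)
    return round(x_bar) - x_bar


def lsq_grad_scale(num_elements, q_p):
    """Gradient scale g_s = 1 / sqrt(N * Q_P)."""
    return 1.0 / math.sqrt(num_elements * q_p)


def lsq_init_step(values, q_p):
    """Initial step size s_0 = 2 * mean(|x|) / sqrt(Q_P)."""
    mean_abs = sum(abs(v) for v in values) / len(values)
    return 2.0 * mean_abs / math.sqrt(q_p)


# 4-bit signed quantization: Q_N = 8, Q_P = 7.
residual = lsq_step_grad(0.33, 0.1, 8, 7)   # in range: round(3.3) - 3.3
upper = lsq_step_grad(5.0, 0.1, 8, 7)       # clipped above
lower = lsq_step_grad(-5.0, 0.1, 8, 7)      # clipped below
s0 = lsq_init_step([1.0, -1.0, 2.0, -2.0], 7)
```

The total gradient for \(s\) is the sum over elements of `grad_out * lsq_step_grad(...)`, multiplied by `lsq_grad_scale(...)` before the optimizer step.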
7. LSQ+ (Learned Step Size Quantization Plus)#
7.1 Extension to Asymmetric Quantization#
LSQ+ (Bhalgat et al., 2020) extends LSQ by also learning the zero-point offset \(\beta\) as a continuous parameter:
$$\hat{x} = s \cdot \text{clamp}\!\left(\left\lfloor \frac{x - \beta}{s} \right\rceil, q_{\min}, q_{\max}\right) + \beta$$
7.2 Gradients#
To derive the gradient with respect to \(\beta\), write \(\bar{x} = (x - \beta)/s\), so that:
$$\hat{x} = s \cdot \hat{q} + \beta, \quad \hat{q} = \text{clamp}(\lfloor \bar{x} \rceil, q_{\min}, q_{\max})$$
$$\frac{\partial \hat{x}}{\partial \beta} = s \cdot \frac{\partial \hat{q}}{\partial \beta} + 1$$
Under the STE, when \(\bar{x}\) is in range, \(\frac{\partial \hat{q}}{\partial \beta} \approx \frac{\partial \bar{x}}{\partial \beta} = -1/s\), so:
$$\frac{\partial \hat{x}}{\partial \beta} = s \cdot (-1/s) + 1 = 0$$
When \(\bar{x}\) is clipped, \(\hat{q}\) saturates and \(\frac{\partial \hat{q}}{\partial \beta} = 0\), so:
$$\frac{\partial \hat{x}}{\partial \beta} = 1$$
Combining the two cases:
$$\frac{\partial \hat{x}}{\partial \beta} = \begin{cases} 0 & \text{if } q_{\min} \leq \bar{x} \leq q_{\max} \\ 1 & \text{otherwise} \end{cases}$$
This means the offset \(\beta\) receives gradient only from values that are being clipped, naturally pushing the quantization window to cover the distribution better.
7.3 Practical Benefit#
LSQ+ is particularly beneficial for activations that are not centered around zero, such as outputs of layers without batch normalization or after certain non-linearities like Swish/GELU where outputs can be slightly negative.
8. PACT: Parameterized Clipping Activation#
8.1 Key Idea#
PACT (Choi et al., 2018) focuses specifically on activation quantization. The insight is that clipping activations to a learned upper bound \(\alpha\) before quantization significantly reduces quantization error.
For ReLU activations:
$$\text{PACT}(x) = 0.5 \cdot (|x| - |x - \alpha| + \alpha) = \begin{cases} 0 & \text{if } x \leq 0 \\ x & \text{if } 0 < x < \alpha \\ \alpha & \text{if } x \geq \alpha \end{cases}$$This is simply a clipped ReLU where the clipping threshold \(\alpha\) is learned.
8.2 Quantization#
After clipping, the activation is uniformly quantized to \(b\) bits:
$$\hat{x} = \frac{\alpha}{2^b - 1} \cdot \left\lfloor \frac{y \cdot (2^b - 1)}{\alpha} \right\rceil, \quad y = \text{PACT}(x)$$
8.3 Gradient of the Clipping Parameter#
$$\frac{\partial \mathcal{L}}{\partial \alpha} = \sum_i \frac{\partial \mathcal{L}}{\partial \hat{x}_i} \cdot \frac{\partial \hat{x}_i}{\partial \alpha}$$
In the original PACT formulation, the STE treats the quantizer as the identity, so \(\frac{\partial \hat{x}_i}{\partial y_i} \approx 1\) and the gradient reduces to that of the clipping function:
$$\frac{\partial \hat{x}_i}{\partial \alpha} \approx \frac{\partial y_i}{\partial \alpha} = \begin{cases} 0 & \text{if } x_i < \alpha \\ 1 & \text{if } x_i \geq \alpha \end{cases}$$
Only the clipped elements contribute to the update of \(\alpha\). Some implementations additionally differentiate the scale \(\alpha / (2^b - 1)\) inside the quantizer, which adds a quantization-residual term for elements with \(0 < x_i < \alpha\), analogous to the in-range case of LSQ.
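A short scalar sketch of PACT, assuming the simplest STE treatment in which only clipped elements contribute a gradient of 1 to \(\alpha\). The function names are illustrative.

```python
def pact_clip(x, alpha):
    """Clipped ReLU written in the closed form 0.5 * (|x| - |x - alpha| + alpha)."""
    return 0.5 * (abs(x) - abs(x - alpha) + alpha)


def pact_quantize(x, alpha, bits):
    """Clip to [0, alpha], then quantize uniformly to 2**bits - 1 steps."""
    levels = 2 ** bits - 1
    y = pact_clip(x, alpha)
    return alpha / levels * round(y * levels / alpha)


def pact_alpha_grad(x, alpha):
    """d(x_hat)/d(alpha) with the quantizer treated as identity:
    1 for clipped elements, 0 for everything else."""
    return 1.0 if x >= alpha else 0.0


below = pact_clip(-1.0, 2.0)    # negative inputs map to 0
middle = pact_clip(1.0, 2.0)    # in-range inputs pass through
above = pact_clip(5.0, 2.0)     # large inputs saturate at alpha
```

The closed form avoids branching, which is why it is convenient inside a computation graph.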
8.4 PACT vs. LSQ#
| Aspect | PACT | LSQ |
|---|---|---|
| Learned parameter | Clipping bound \(\alpha\) | Step size \(s\) |
| Applies to | Activations primarily | Both weights and activations |
| Quantization grid | Derived from \(\alpha\) | Directly the step size |
| Flexibility | Moderate | Higher |
| Publication | arXiv 2018 | ICLR 2020 |
9. DoReFa-Net#
9.1 Overview#
DoReFa-Net (Zhou et al., 2016) quantizes weights, activations, and gradients during training, enabling low-bitwidth computation throughout the training process itself.
9.2 Weight Quantization#
Weights are first normalized to \([0, 1]\) using:
$$w_n = \frac{\tanh(w)}{2 \cdot \max(|\tanh(w)|)} + 0.5$$Then quantized to \(k\) bits:
$$w_q = \frac{1}{2^k - 1} \cdot \text{round}(w_n \cdot (2^k - 1))$$The final weight used is \(2 w_q - 1\) to map back to \([-1, 1]\).
9.3 Activation Quantization#
Activations are assumed to be in \([0, 1]\) (after a bounded activation like sigmoid or clipped ReLU):
$$a_q = \frac{1}{2^k - 1} \cdot \text{round}(a \cdot (2^k - 1))$$
9.4 Gradient Quantization#
This is the distinctive contribution of DoReFa-Net. Gradients are quantized stochastically to \(k\) bits. In simplified form, the gradient \(g\) is first normalized to \([0, 1]\):
$$g_n = \frac{g - \min(g)}{\max(g) - \min(g)}$$
Then stochastic quantization is applied:
$$g_q = \frac{1}{2^k - 1} \cdot \left\lfloor g_n \cdot (2^k - 1) + \epsilon \right\rfloor$$
where \(\epsilon \sim \text{Uniform}(0, 1)\). Adding uniform noise before the floor is stochastic rounding: it ensures that \(\mathbb{E}[g_q] = g_n\), providing an unbiased estimator. The quantized value is then mapped back to the original gradient range.
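The unbiasedness claim is easy to check empirically. The sketch below implements the simplified normalize-then-stochastically-round scheme from this section (not the exact formula from the DoReFa-Net paper) and averages many samples; the seed and sample count are arbitrary choices.

```python
import math
import random


def normalize(grads):
    """Min-max normalize a list of gradients to [0, 1]."""
    lo, hi = min(grads), max(grads)
    return [(g - lo) / (hi - lo) for g in grads]


def stochastic_quantize(g_n, bits, rng):
    """Stochastic rounding to 2**bits - 1 steps: floor(g_n * levels + eps) / levels,
    with eps drawn uniformly from [0, 1)."""
    levels = 2 ** bits - 1
    return math.floor(g_n * levels + rng.random()) / levels


rng = random.Random(0)
samples = [stochastic_quantize(0.4, 2, rng) for _ in range(20000)]
mean = sum(samples) / len(samples)   # close to 0.4 even though 0.4 is not on the grid
```

With 2 bits the only outputs for 0.4 are 1/3 and 2/3; the sampling frequencies adjust so that the average lands on 0.4.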
9.5 Why Gradient Quantization Matters#
Gradient quantization reduces communication bandwidth in distributed training and memory consumption for gradient storage. However, it introduces additional variance, so more bits are typically needed for gradients (8-bit or 16-bit) compared to weights and activations.
10. Batch Normalization Folding in QAT#
10.1 The Problem#
Batch normalization (BN) applies an affine transformation after normalization:
$$y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta_{\text{bn}}$$At inference time, this is typically folded into the preceding convolution/linear layer for efficiency. If we have a convolution \(y = Wx + b\), the folded weights and bias become:
$$W_{\text{fold}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot W$$
$$b_{\text{fold}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} \cdot (b - \mu) + \beta_{\text{bn}}$$
10.2 The QAT Complication#
If we fold BN before QAT, the folded weights are different from the training-time weights, and the quantization parameters computed during QAT would be wrong. Conversely, if we do not fold BN during QAT, the quantization nodes do not see the actual inference-time weights.
10.3 Simulated BN Folding#
The solution is simulated BN folding during QAT. During each forward pass:
- Compute the folded weights: \(W_{\text{fold}} = \frac{\gamma}{\sqrt{\sigma^2 + \epsilon}} W\)
- Apply fake quantization to \(W_{\text{fold}}\) (not the original \(W\)).
- Compute the convolution with \(\text{FakeQuant}(W_{\text{fold}})\).
- Add the folded bias.
QAT with BN Folding (Training):
W ----> [ BN Fold ] ----> W_fold ----> [ FakeQuant ] ----> W_fq
|
x ----> [ FakeQuant ] -----+ |
| |
v v
[ Conv2d(x, W_fq) + b_fold ]
|
v
[ ReLU ]
|
v
[ FakeQuant ]
10.4 Running Statistics#
During QAT with simulated BN folding, the BN running mean and variance are still updated using the batch statistics. However, the folded weights for quantization use the running (exponential moving average) statistics, not the batch statistics. This avoids instability from batch-to-batch fluctuations.
After training, the final running statistics are used to compute the permanently folded weights for inference.
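The folding formulas from Section 10.1 can be verified numerically. This sketch uses a single scalar weight standing in for a convolution kernel, which is an assumption made to keep the check short; the algebra is identical per output channel.

```python
import math


def conv_then_bn(x, w, b, gamma, beta, mean, var, eps=1e-5):
    """Reference path: linear layer followed by batch normalization."""
    y = w * x + b
    return gamma * (y - mean) / math.sqrt(var + eps) + beta


def fold_bn(w, b, gamma, beta, mean, var, eps=1e-5):
    """Fold BN statistics into the weight and bias of the preceding layer."""
    factor = gamma / math.sqrt(var + eps)
    return factor * w, factor * (b - mean) + beta


params = dict(gamma=1.5, beta=-0.2, mean=0.3, var=4.0)
w_fold, b_fold = fold_bn(0.5, 0.1, **params)
reference = conv_then_bn(2.0, 0.5, 0.1, **params)
folded = w_fold * 2.0 + b_fold   # single fused layer, no BN at inference
```

In simulated folding, fake quantization would be applied to `w_fold` rather than to the raw weight.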
10.5 Numerical Stability#
When \(\sigma\) is very small, the folding factor \(\gamma / \sqrt{\sigma^2 + \epsilon}\) can be extremely large, amplifying weight magnitudes and potentially causing overflow in low-bitwidth quantization. Practical mitigations include:
- Using a larger \(\epsilon\) in BN (e.g., \(10^{-3}\) instead of \(10^{-5}\)).
- Clipping the folding factor.
- Monitoring the distribution of folded weights during training.
11. Knowledge Distillation Combined with QAT#
11.1 Motivation#
Knowledge distillation (KD) uses a high-capacity teacher model to guide the training of a smaller student model. When the student is a quantized model, KD helps recover accuracy lost to quantization.
11.2 Standard KD + QAT Loss#
The combined loss function is:
$$\mathcal{L} = (1 - \lambda) \cdot \mathcal{L}_{\text{CE}}(y, \hat{y}_S) + \lambda \cdot T^2 \cdot \text{KL}\!\left(\sigma\!\left(\frac{z_T}{T}\right) \| \sigma\!\left(\frac{z_S}{T}\right)\right)$$where:
- \(y\) is the ground-truth label
- \(\hat{y}_S\) is the student’s prediction
- \(z_T, z_S\) are teacher and student logits
- \(T\) is the temperature
- \(\lambda\) balances the two losses
- \(\sigma\) is softmax
- The \(T^2\) factor compensates for the reduced gradient magnitude at higher temperatures
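The combined loss can be written out for a single example in plain Python. This is a sketch under the definitions above; the default values of `lam` and `temperature` are arbitrary illustrative choices, not recommendations.

```python
import math


def softmax(logits, temperature=1.0):
    """Numerically stable softmax of logits / temperature."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(v - m) for v in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kd_qat_loss(student_logits, teacher_logits, label, lam=0.5, temperature=4.0):
    """(1 - lam) * CE(label, student) + lam * T^2 * KL(teacher_T || student_T)."""
    ce = -math.log(softmax(student_logits)[label])
    p_t = softmax(teacher_logits, temperature)
    p_s = softmax(student_logits, temperature)
    kl = sum(pt * math.log(pt / ps) for pt, ps in zip(p_t, p_s))
    return (1 - lam) * ce + lam * temperature ** 2 * kl


# When student and teacher agree, the KL term vanishes and only CE remains.
agree = kd_qat_loss([1.0, 2.0, 3.0], [1.0, 2.0, 3.0], label=2)
disagree = kd_qat_loss([1.0, 2.0, 3.0], [3.0, 2.0, 1.0], label=2)
```

During QAT the student logits come from the fake-quantized network, and gradients flow only into the student.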
11.3 Feature-Level Distillation#
Beyond logit-level KD, feature-level distillation can be applied:
$$\mathcal{L}_{\text{feat}} = \sum_{l \in \mathcal{S}} \left\| f_l^T - \phi(f_l^S) \right\|_2^2$$where \(f_l^T\) and \(f_l^S\) are intermediate features from the teacher and student at layer \(l\), and \(\phi\) is a learnable projection to match dimensions if needed.
11.4 Self-Distillation for QAT#
A common variant uses the same architecture as both teacher (FP32) and student (quantized). The FP32 pretrained model serves as the teacher, and its quantized copy is the student. This avoids the need to train a separate teacher.
+------------------+ +------------------+
| FP32 Teacher | | Quantized Student|
| (frozen) | | (training) |
| | | |
| Input --> Logits | | Input --> Logits |
+--------+---------+ +--------+---------+
| |
+----------+ +-----------+
| |
v v
[ KL Divergence ]
+
[ CE with labels ]
=
[ Total QAT Loss ]
11.5 Practical Results#
KD + QAT typically improves accuracy over QAT alone, with the benefit increasing at lower bit-widths. In the illustrative table below, the gain is about 1 point at W4A4 and 2 to 3 points at W2A2.
| Method | W4A4 Top-1 (ResNet-50) | W2A2 Top-1 (ResNet-18) |
|---|---|---|
| QAT only | 75.1% | 58.4% |
| QAT + KD (logit) | 76.0% | 60.8% |
| QAT + KD (feature) | 76.3% | 61.5% |
(Illustrative numbers; exact values vary by implementation.)
12. Progressive Quantization#
12.1 Concept#
Rather than directly quantizing from 32-bit to the target bit-width, progressive quantization reduces the bit-width gradually over training:
$$32 \rightarrow 16 \rightarrow 8 \rightarrow 4 \rightarrow 2 \text{ bits}$$At each stage, the model adapts to the coarser quantization grid before moving to the next level.
12.2 Schedule#
A typical progressive schedule:
| Phase | Epochs | Weight Bits | Activation Bits |
|---|---|---|---|
| 1 | 0–10 | 8 | 8 |
| 2 | 10–25 | 4 | 8 |
| 3 | 25–40 | 4 | 4 |
| 4 | 40–60 | 2 | 4 |
12.3 Smooth Bit-Width Transition#
Some methods use a continuous relaxation of the bit-width. Instead of discrete jumps, the effective bit-width is annealed:
$$b(t) = b_{\text{start}} + (b_{\text{end}} - b_{\text{start}}) \cdot \frac{t}{T}$$where \(t\) is the current training step and \(T\) is the total training steps. The quantization step size is adjusted accordingly:
$$s(t) = \frac{\alpha}{2^{b(t)} - 1}$$At non-integer \(b(t)\), this is implemented by interpolating between the two nearest integer bit-width quantizations.
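A minimal sketch of the annealed schedule, assuming unsigned quantization on \([0, \alpha]\) and a linear blend between the two neighboring integer bit-widths. The blend rule is one plausible reading of "interpolating"; other methods may interpolate differently.

```python
import math


def bit_width(step, total_steps, b_start, b_end):
    """Linear schedule b(t) = b_start + (b_end - b_start) * t / T."""
    return b_start + (b_end - b_start) * step / total_steps


def quantize_unsigned(x, alpha, bits):
    """Uniform quantization on [0, alpha] with step alpha / (2**bits - 1)."""
    levels = 2 ** bits - 1
    s = alpha / levels
    return s * max(0, min(levels, round(x / s)))


def blended_quantize(x, alpha, b):
    """For non-integer b, mix the floor(b)-bit and ceil(b)-bit quantizers
    in proportion to the fractional part of b."""
    lo, hi = math.floor(b), math.ceil(b)
    if lo == hi:
        return quantize_unsigned(x, alpha, lo)
    frac = b - lo
    return ((1 - frac) * quantize_unsigned(x, alpha, lo)
            + frac * quantize_unsigned(x, alpha, hi))


halfway_bits = bit_width(50, 100, 8, 2)      # midpoint of an 8 -> 2 bit schedule
q2 = quantize_unsigned(0.3, 1.0, 2)
q3 = quantize_unsigned(0.3, 1.0, 3)
q_mix = blended_quantize(0.3, 1.0, 2.5)      # halfway between q2 and q3
```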
12.4 Benefits#
Progressive quantization is particularly effective for extremely low bit-widths (2-bit, ternary, binary) where direct quantization from FP32 causes too large a loss surface discontinuity for the optimizer to handle.
13. Mixed-Precision QAT#
13.1 Observation#
Not all layers are equally sensitive to quantization. Early layers (which extract low-level features) and the final classifier layer tend to be more sensitive, while middle layers are often robust to aggressive quantization.
13.2 Problem Formulation#
Mixed-precision quantization assigns different bit-widths \(b_l\) to each layer \(l\), solving:
$$\min_{\{b_l\}} \mathcal{L}(\{b_l\}) \quad \text{s.t.} \quad \sum_l \text{Cost}(b_l) \leq \text{Budget}$$where Cost can be model size, latency, or energy.
13.3 Search Methods#
| Method | Approach | Pros | Cons |
|---|---|---|---|
| HAQ (Wang et al.) | Reinforcement learning | Hardware-aware | Expensive search |
| DNAS | Differentiable NAS | End-to-end gradient | Memory intensive |
| HAWQ (Dong et al.) | Hessian-based sensitivity | Principled, fast | Approximation needed |
| Once-for-All | Supernet training | Amortized cost | Training complexity |
13.4 HAWQ: Hessian-Aware Quantization#
HAWQ uses the Hessian trace (or top eigenvalue) to measure layer sensitivity:
$$\Omega_l = \text{tr}(H_l) \approx \text{sensitivity of layer } l \text{ to quantization}$$Layers with larger Hessian trace are more sensitive and should receive more bits. The bit-width allocation is then a knapsack problem:
$$\min_{\{b_l\}} \sum_l \Omega_l \cdot \delta_l(b_l) \quad \text{s.t.} \quad \sum_l b_l \cdot n_l \leq B$$where \(\delta_l(b_l)\) is the perturbation from quantizing layer \(l\) to \(b_l\) bits and \(n_l\) is the number of parameters in layer \(l\).
13.5 Differentiable Mixed-Precision#
In differentiable approaches, each layer maintains a probability distribution over candidate bit-widths:
$$\hat{x}_l = \sum_{b \in \mathcal{B}} \frac{\exp(\alpha_l^b)}{\sum_{b'} \exp(\alpha_l^{b'})} \cdot Q_b(x_l)$$where \(\alpha_l^b\) are learnable architecture parameters. During training, all bit-width options are computed (or approximated via Gumbel-Softmax), and the architecture parameters converge to select the best bit-width per layer.
14. QLoRA: Quantized Low-Rank Adaptation#
14.1 Context#
QLoRA (Dettmers et al., 2023) enables fine-tuning of large language models (LLMs) on consumer hardware by combining 4-bit quantization of the base model with Low-Rank Adaptation (LoRA). It is not classical QAT but a closely related quantization-during-training technique.
14.2 Three Key Innovations#
Innovation 1: NormalFloat 4-bit (NF4)
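NF4 is a data type designed for normally distributed weights, which the QLoRA authors describe as information-theoretically optimal for that case. Its 16 levels are estimated from quantiles of the standard normal distribution and then rescaled to \([-1, 1]\):
$$q_i = \frac{1}{2}\left(\Phi^{-1}\!\left(\frac{i}{2^k + 1}\right) + \Phi^{-1}\!\left(\frac{i + 1}{2^k + 1}\right)\right)$$
where \(\Phi^{-1}\) is the inverse CDF (quantile function) of the standard normal and \(k = 4\). The aim is for each quantization bin to hold roughly equal probability mass. Because a symmetric grid with \(2^k\) levels has no exact zero, the negative and positive halves are built separately (7 negative levels and 8 positive levels, joined by an exact 0), which makes the grid slightly asymmetric.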
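One way to build an NF4-style grid with only the standard library is sketched below. It follows the asymmetric idea (separate negative and positive halves joined by an exact zero, then rescaling to \([-1, 1]\)). The `offset` constant and the evenly spaced probabilities are assumptions chosen because they reproduce the commonly published NF4 values to roughly three decimals; treat the function as an approximation, not a reference implementation.

```python
from statistics import NormalDist


def linspace(start, stop, num):
    """Evenly spaced values from start to stop, inclusive."""
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def nf4_style_levels(offset=0.9677083):
    """Build 16 levels: 7 negative, an exact zero, and 8 positive, taken from
    standard-normal quantiles and rescaled so the largest magnitude is 1."""
    ppf = NormalDist().inv_cdf
    # Drop the final probability (0.5), whose quantile is the shared zero.
    positive = [ppf(p) for p in linspace(offset, 0.5, 9)[:-1]]
    negative = [-ppf(p) for p in linspace(offset, 0.5, 8)[:-1]]
    levels = sorted(negative + [0.0] + positive)
    scale = max(abs(v) for v in levels)
    return [v / scale for v in levels]


levels = nf4_style_levels()
```

Quantizing a weight block then amounts to dividing by the block's absolute maximum and picking the nearest of these 16 levels.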
NF4 Quantization Levels (16 values for 4 bits, rounded to two decimals):
-1.00 -0.70 -0.53 -0.39 -0.28 -0.18 -0.09 0.00
0.08 0.16 0.25 0.34 0.44 0.56 0.72 1.00
Innovation 2: Double Quantization
The quantization constants (scales) themselves are quantized. For a block size of 64:
- First quantization: FP32 weights to NF4, with one FP32 scale per 64 weights. Stored as-is, the scales cost 32/64 = 0.5 bits per weight.
- Second quantization: the FP32 scales are quantized to FP8 in blocks of 256, with one FP32 second-level scale per block. The scales now cost 8/64 = 0.125 bits per weight, and the second-level scales add 32/(64 × 256) ≈ 0.002 bits per weight.
Bits per parameter without and with double quantization:
$$\text{Without double quant: } 4 + \frac{32}{64} = 4.5 \text{ bits/param}$$
$$\text{With double quant: } 4 + \frac{8}{64} + \frac{32}{64 \times 256} \approx 4.127 \text{ bits/param}$$
Double quantization therefore saves approximately 0.37 bits per parameter, which for a 65B model translates to:
$$65 \times 10^9 \times 0.37 / 8 \approx 3.0 \text{ GB savings}$$
Innovation 3: Paged Optimizers
QLoRA uses NVIDIA unified memory to page optimizer states between GPU and CPU memory, preventing out-of-memory errors during gradient checkpointing spikes.
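The bits-per-parameter arithmetic for double quantization feeds directly into the memory estimate, and it can be reproduced in a few lines. The function name and argument names are illustrative; the block sizes and bit-widths are the ones quoted under Innovation 2.

```python
def bits_per_param(weight_bits=4, block=64, scale_bits=32,
                   dq_scale_bits=None, dq_block=256):
    """Storage cost per weight, including quantization constants.

    With dq_scale_bits=None, each block keeps one scale of scale_bits bits.
    Otherwise the scales are stored in dq_scale_bits bits, plus one
    scale_bits second-level constant per dq_block first-level scales."""
    if dq_scale_bits is None:
        return weight_bits + scale_bits / block
    return weight_bits + dq_scale_bits / block + scale_bits / (block * dq_block)


plain = bits_per_param()                    # NF4 with FP32 scales
double = bits_per_param(dq_scale_bits=8)    # NF4 with double quantization
n_params = 65e9
savings_gb = n_params * (plain - double) / 8 / 1e9
base_model_gb = n_params * double / 8 / 1e9
```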
14.3 Memory Calculation#
For a 65B parameter model:
| Component | Memory |
|---|---|
| Base model (NF4 + double quant) | \(65 \times 10^9 \times 4.127 / 8 \approx 33.5\) GB |
| LoRA adapters (FP16, rank 64) | ~0.8 GB (depending on which layers) |
| Optimizer states (AdamW, FP32 for LoRA) | ~2.4 GB |
| Activations + gradients | ~5–10 GB (with gradient checkpointing) |
| Total | ~42–47 GB |
This fits on a single 48 GB GPU (e.g., A6000), whereas full fine-tuning in FP16 would require:
$$65 \times 10^9 \times 2 \text{ (model)} + 65 \times 10^9 \times 2 \text{ (grad)} + 65 \times 10^9 \times 8 \text{ (Adam states)} = 780 \text{ GB}$$
14.4 Training Dynamics#
During QLoRA fine-tuning, the base model weights remain frozen in NF4. Only the LoRA adapters (low-rank matrices \(A\) and \(B\)) are trained in FP16/BF16:
$$h = W_{\text{NF4}} x + s \cdot B A x$$where \(W_{\text{NF4}}\) is the quantized frozen weight, \(A \in \mathbb{R}^{r \times d}\), \(B \in \mathbb{R}^{d \times r}\), \(r \ll d\), and \(s\) is a scaling factor.
Gradients flow through \(W_{\text{NF4}}\) via dequantization (NF4 to BF16 on the fly) but do not update \(W_{\text{NF4}}\). Only \(A\) and \(B\) receive updates.
15. LLM-QAT: Quantization-Aware Training for Large Language Models#
15.1 Challenges of QAT at LLM Scale#
Applying classical QAT to LLMs (billions of parameters) presents unique challenges:
- Training cost: Full QAT requires backpropagation through the entire model with fake quantization nodes, which is expensive at scale.
- Data requirements: QAT typically needs the full training dataset, which for LLMs is often proprietary or enormous.
- Activation quantization: LLM activations exhibit extreme outlier distributions (especially in attention layers), making activation quantization difficult.
15.2 Data-Free Distillation#
LLM-QAT (Liu et al., 2023) addresses the data problem by generating training data from the FP model itself:
- Prompt the FP32 teacher model with random or seed tokens.
- Generate sequences via autoregressive sampling.
- Use these generated sequences as the training data for QAT.
This is effectively data-free distillation: the teacher provides both the data and the soft targets.
15.3 KV-Cache Quantization#
LLM-QAT specifically addresses key-value cache quantization, which is critical for inference efficiency in autoregressive generation:
$$\text{Attention}(Q, K, V) = \text{softmax}\!\left(\frac{Q K^T}{\sqrt{d_k}}\right) V$$During QAT, fake quantization is applied to the cached \(K\) and \(V\) matrices:
$$K_q = \text{FakeQuant}(K), \quad V_q = \text{FakeQuant}(V)$$This trains the model to be robust to quantized KV-cache at inference time.
15.4 Results#
LLM-QAT achieves W4A8-KV4 (4-bit weights, 8-bit activations, 4-bit KV-cache) with minimal perplexity degradation on LLaMA models, where PTQ methods suffer significant quality loss especially on the KV-cache quantization.
16. Binary and Ternary Networks#
16.1 Binary Neural Networks#
Binary networks represent weights (and optionally activations) using only \(\{-1, +1\}\), replacing multiplications with XNOR operations and additions with popcount.
Binarization function:
$$w_b = \text{sign}(w) = \begin{cases} +1 & \text{if } w \geq 0 \\ -1 & \text{if } w < 0 \end{cases}$$
Gradient via STE:
$$\frac{\partial \text{sign}(w)}{\partial w} \approx \mathbf{1}_{|w| \leq 1}$$
16.2 XNOR-Net#
XNOR-Net (Rastegari et al., 2016) introduces a scaling factor to improve the approximation quality. For a convolution \(W * X\) where both \(W\) and \(X\) are binarized:
$$W * X \approx (\text{sign}(W) \circledast \text{sign}(X)) \odot \alpha \odot K$$where:
- \(\circledast\) is the binary convolution (XNOR + popcount)
- \(\alpha\) is a per-filter scaling factor
- \(K\) captures the mean absolute value of the input patches
Optimal \(\alpha\) derivation:
We want to minimize:
$$J(\alpha) = \|W - \alpha \cdot \text{sign}(W)\|^2$$
Expanding:
$$J(\alpha) = \|W\|^2 - 2\alpha \cdot W^T \text{sign}(W) + \alpha^2 \|\text{sign}(W)\|^2$$
Note that \(W^T \text{sign}(W) = \sum_i |w_i| = \|W\|_1\) and \(\|\text{sign}(W)\|^2 = n\) (the number of elements). Taking the derivative and setting it to zero:
$$\frac{\partial J}{\partial \alpha} = -2 \|W\|_1 + 2\alpha n = 0$$
$$\alpha^* = \frac{\|W\|_1}{n} = \frac{1}{n}\sum_{i=1}^{n}|w_i|$$
So the optimal scaling factor is simply the mean absolute value of the weights.
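The closed-form result is easy to sanity-check numerically. The sketch below uses synthetic Gaussian weights (an assumption made purely for the demo) and compares \(J\) at \(\alpha^*\) with \(J\) at slightly smaller and larger values.

```python
import random

rng = random.Random(0)
weights = [rng.gauss(0.0, 0.05) for _ in range(1000)]  # synthetic filter weights


def sign(w):
    return 1.0 if w >= 0 else -1.0


def binarization_error(weights, alpha):
    """J(alpha) = ||W - alpha * sign(W)||^2"""
    return sum((w - alpha * sign(w)) ** 2 for w in weights)


# Closed-form optimum: the mean absolute value of the weights.
alpha_star = sum(abs(w) for w in weights) / len(weights)

j_star = binarization_error(weights, alpha_star)
j_lower = binarization_error(weights, 0.9 * alpha_star)
j_higher = binarization_error(weights, 1.1 * alpha_star)
```

Because \(J\) is a convex quadratic in \(\alpha\), both `j_lower` and `j_higher` come out larger than `j_star`.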
16.3 Computational Advantage of Binary Convolution#
A standard convolution with \(c\) input channels and \(k \times k\) kernel requires \(c \times k \times k\) multiply-accumulate (MAC) operations per output pixel. A binary convolution replaces this with:
- XNOR: \(c \times k \times k\) XNOR operations (1 clock cycle each on most hardware).
- Popcount: Count the number of 1s in the result.
- Scale: Multiply by \(\alpha\) (one real multiplication per output pixel).
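On a 64-bit processor, 64 binary operations can be packed into a single XNOR instruction, giving a theoretical 64x reduction in operation count for the dot product. Measured speedups are somewhat lower; the XNOR-Net paper reports about 58x faster convolutional operations on CPU.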
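To make the XNOR + popcount trick concrete, here is a small pure-Python sketch. It packs \(\pm 1\) vectors into integers (bit set means \(+1\)) and recovers the dot product as twice the number of agreeing positions minus \(n\). The helper names are invented for this example, and Python integers stand in for 64-bit machine words.

```python
def pack_bits(values):
    """Pack a list of +1/-1 values into an int: bit i is 1 when values[i] == +1."""
    word = 0
    for i, v in enumerate(values):
        if v > 0:
            word |= 1 << i
    return word


def binary_dot(a_packed, b_packed, n):
    """Dot product of two {-1,+1}^n vectors via XNOR + popcount."""
    mask = (1 << n) - 1
    xnor = ~(a_packed ^ b_packed) & mask  # 1 where the signs agree
    matches = bin(xnor).count("1")        # popcount
    return 2 * matches - n                # agreements minus disagreements


w = [+1, -1, -1, +1, +1, -1, +1, +1]
x = [+1, +1, -1, -1, +1, -1, -1, +1]
reference = sum(wi * xi for wi, xi in zip(w, x))        # ordinary dot product
fast = binary_dot(pack_bits(w), pack_bits(x), len(w))   # same value, bitwise
```

Here both `reference` and `fast` equal 2: five positions agree and three disagree.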
FP32 Convolution:                  Binary Convolution:
w1*x1 + w2*x2 + ... + wn*xn        popcount(XNOR(W_packed, X_packed)) * alpha
n multiplications                  n/64 XNOR ops + 1 multiplication
n additions                        n/64 popcount ops + 1 scaling
16.4 Ternary Weight Networks (TWN)#
TWN (Li et al., 2016) extends binary to ternary: weights take values in \(\{-1, 0, +1\}\). The ternarization function with threshold \(\Delta\):
$$w_t = \begin{cases} +1 & \text{if } w > \Delta \\ 0 & \text{if } |w| \leq \Delta \\ -1 & \text{if } w < -\Delta \end{cases}$$
Optimal threshold \(\Delta\):
TWN minimizes \(\|W - \alpha \cdot W_t\|^2\), where \(W_t\) is the ternary weight tensor. The joint problem over \(\alpha\) and \(\Delta\) has no closed-form solution in general, so the threshold is approximated as:
$$\Delta^* \approx 0.7 \cdot \mathbb{E}[|W|] = 0.7 \cdot \frac{\|W\|_1}{n}$$
The factor 0.7 is a rule of thumb obtained by solving the problem under simple distributional assumptions: uniformly distributed weights give \(\Delta^* = \tfrac{2}{3}\,\mathbb{E}[|W|]\), and normally distributed weights give \(\Delta^* \approx 0.6\sigma \approx 0.75 \cdot \mathbb{E}[|W|]\).
With the threshold \(\Delta\) determined, the optimal scaling factor is:
$$\alpha^* = \frac{\sum_{i: |w_i| > \Delta} |w_i|}{|\{i : |w_i| > \Delta\}|}$$
which is the mean absolute value of the non-zero (non-pruned) weights.
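The threshold and scale formulas translate directly into a few lines of NumPy. This is a sketch on synthetic Gaussian weights (the distribution and size are assumptions for the demo), with a binary quantizer included only as a point of comparison.

```python
import numpy as np


def ternarize(w):
    """TWN-style ternarization: threshold 0.7 * mean|w|, scale = mean of surviving |w|."""
    delta = 0.7 * np.abs(w).mean()
    mask = np.abs(w) > delta
    alpha = np.abs(w[mask]).mean() if mask.any() else 0.0
    w_t = np.sign(w) * mask  # values in {-1, 0, +1}
    return alpha, w_t, delta


rng = np.random.default_rng(0)
W = rng.normal(0.0, 0.05, size=10_000)

alpha, W_t, delta = ternarize(W)
ternary_error = np.sum((W - alpha * W_t) ** 2)

# Binary baseline with its own optimal scale (mean absolute value).
alpha_bin = np.abs(W).mean()
binary_error = np.sum((W - alpha_bin * np.sign(W)) ** 2)
```

On weights like these the ternary reconstruction error is noticeably lower than the binary one, which is the motivation for spending the extra level on zero.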
16.5 Trained Ternary Quantization (TTQ)#
TTQ (Zhu et al., 2017) learns asymmetric scaling factors \(\alpha_p\) (positive) and \(\alpha_n\) (negative):
$$w_t = \begin{cases} \alpha_p & \text{if } w > \Delta \\ 0 & \text{if } |w| \leq \Delta \\ -\alpha_n & \text{if } w < -\Delta \end{cases}$$
Gradient derivations:
Using the STE for the ternarization and direct gradients for the scaling factors:
For \(\alpha_p\):
$$\frac{\partial \mathcal{L}}{\partial \alpha_p} = \sum_{i: w_i > \Delta} \frac{\partial \mathcal{L}}{\partial w_{t,i}} \cdot 1 = \sum_{i: w_i > \Delta} \frac{\partial \mathcal{L}}{\partial w_{t,i}}$$
For \(\alpha_n\):
$$\frac{\partial \mathcal{L}}{\partial \alpha_n} = \sum_{i: w_i < -\Delta} \frac{\partial \mathcal{L}}{\partial w_{t,i}} \cdot (-1) = -\sum_{i: w_i < -\Delta} \frac{\partial \mathcal{L}}{\partial w_{t,i}}$$
For the latent full-precision weights \(w\) (via STE):
$$\frac{\partial \mathcal{L}}{\partial w_i} = \frac{\partial \mathcal{L}}{\partial w_{t,i}} \cdot \begin{cases} \alpha_p & \text{if } w_i > \Delta \\ 1 & \text{if } |w_i| \leq \Delta \\ \alpha_n & \text{if } w_i < -\Delta \end{cases}$$
In the TTQ paper the passed-through gradient is scaled by the factor of the region the weight currently falls in. It is never blocked by the ternarization, so the latent weights keep being updated and can change their ternary assignment at the next forward pass.
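The scaling-factor gradients can be checked against finite differences. The sketch below uses a toy loss that is linear in \(w_t\), so that \(\partial \mathcal{L} / \partial w_{t,i}\) is known exactly; the weights, upstream gradients, and helper names are all made up for illustration.

```python
def ttq_forward(w, alpha_p, alpha_n, delta):
    """Map latent weights to {-alpha_n, 0, +alpha_p}."""
    return [alpha_p if wi > delta else (-alpha_n if wi < -delta else 0.0) for wi in w]


def ttq_scale_grads(w, grad_wt, delta):
    """Gradients of the loss w.r.t. alpha_p and alpha_n, given dL/dw_t."""
    grad_alpha_p = sum(g for wi, g in zip(w, grad_wt) if wi > delta)
    grad_alpha_n = -sum(g for wi, g in zip(w, grad_wt) if wi < -delta)
    return grad_alpha_p, grad_alpha_n


w = [0.8, -0.6, 0.1, -0.05, 0.4, -0.9]
upstream = [0.5, -1.0, 2.0, 0.3, -0.2, 0.7]  # pretend dL/dw_t from backprop
delta, alpha_p, alpha_n = 0.3, 1.2, 0.9


def loss(ap, an):
    # Toy loss, linear in w_t, so dL/dw_t equals `upstream` exactly.
    return sum(g * wt for g, wt in zip(upstream, ttq_forward(w, ap, an, delta)))


grad_p, grad_n = ttq_scale_grads(w, upstream, delta)

# Finite-difference check of both analytic gradients.
eps = 1e-6
fd_p = (loss(alpha_p + eps, alpha_n) - loss(alpha_p, alpha_n)) / eps
fd_n = (loss(alpha_p, alpha_n + eps) - loss(alpha_p, alpha_n)) / eps
```

The analytic and finite-difference values agree (both gradients are 0.3 for this toy input), including the sign flip for \(\alpha_n\).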
16.6 Comparison of Binary/Ternary Methods#
| Method | Weight Values | Activation Bits | Scaling | ImageNet Top-1 (ResNet-18) |
|---|---|---|---|---|
| Full Precision | FP32 | FP32 | N/A | 69.6% |
| BWN | \(\{-\alpha, +\alpha\}\) | FP32 | Per-filter | 60.8% |
| XNOR-Net | \(\{-1, +1\}\) | 1-bit | Per-filter + input | 51.2% |
| TWN | \(\{-\alpha, 0, +\alpha\}\) | FP32 | Per-layer | 61.8% |
| TTQ | \(\{-\alpha_n, 0, +\alpha_p\}\) | FP32 | Per-layer, learned | 66.6% |
(Approximate reference values from the original papers.)
17. Practical Considerations#
17.1 Which Layers to Quantize#
Not all layers should be quantized equally:
| Layer Type | Recommendation |
|---|---|
| First conv layer | 8-bit (sensitive) |
| Last FC / classifier | 8-bit (sensitive) |
| Middle conv layers | 4-bit (robust) |
| Depthwise separable | 8-bit (few params, high sensitivity) |
| Attention QKV | 8-bit (outlier-prone) |
| Embedding layers | 8-bit or higher |
17.2 Handling Activation Outliers#
LLMs and Vision Transformers often exhibit activation outliers (values 10-100x larger than the median). Strategies:
- Per-token quantization: Separate scale per sequence position.
- SmoothQuant: Migrate quantization difficulty from activations to weights by channel-wise scaling.
- Clipped quantization: Learn clipping bounds (PACT/LSQ).
- Mixed precision: Keep outlier-prone layers in higher precision.
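Of the strategies above, SmoothQuant is the easiest to show in a few lines. The sketch below assumes the usual per-channel smoothing factor \(s_j = \max|X_j|^{m} / \max|W_j|^{1-m}\) with migration strength \(m = 0.5\); the tensor shapes and the single synthetic outlier channel are assumptions for the demo.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((32, 16))       # activations: tokens x channels
X[:, 3] *= 50.0                         # one outlier channel
W = 0.1 * rng.standard_normal((16, 8))  # weights: in_channels x out_channels

migration = 0.5                         # migration strength hyperparameter
act_max = np.abs(X).max(axis=0)         # per input channel
w_max = np.abs(W).max(axis=1)           # per input channel
s = act_max ** migration / w_max ** (1.0 - migration)

X_smooth = X / s                        # activations become easier to quantize
W_smooth = W * s[:, None]               # weights absorb the scale

same_output = np.allclose(X @ W, X_smooth @ W_smooth)
range_before = np.abs(X).max()
range_after = np.abs(X_smooth).max()
```

The layer output is mathematically unchanged, while the activation range shrinks because the outlier channel's magnitude has been partly moved into the weights.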
17.3 Calibration Dataset Size#
| Purpose | Recommended Size |
|---|---|
| PTQ calibration | 256–1024 samples |
| QAT observer warm-up | 1–5 epochs over full data |
| QAT fine-tuning | 10–30% of original training |
| QLoRA | Same as standard fine-tuning |
17.4 Common Pitfalls#
- Not freezing observers: Leads to oscillating quantization grids and training instability.
- Too high learning rate: QAT is fine-tuning; large LR causes the model to diverge.
- Ignoring BN folding: The quantized model will behave differently at inference if BN was not folded during QAT.
- Symmetric quantization for asymmetric distributions: ReLU outputs are non-negative; use asymmetric quantization for activations.
- Quantizing skip connections: Residual additions require careful attention to ensure both branches share compatible quantization parameters.
- Ignoring hardware constraints: A 3-bit quantization might be optimal in theory but unsupported by target hardware.
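As a small illustration of the symmetric-versus-asymmetric pitfall listed above, the following sketch quantizes synthetic post-ReLU activations both ways at 8 bits. The data is random and min/max calibration is assumed; real activations and observers will differ.

```python
import numpy as np

rng = np.random.default_rng(0)
acts = np.maximum(rng.standard_normal(10_000), 0.0)  # synthetic post-ReLU activations

# Symmetric signed 8-bit: the grid covers [-max, +max], but the negative half is never used.
s_sym = acts.max() / 127
acts_sym = s_sym * np.clip(np.round(acts / s_sym), -128, 127)

# Asymmetric unsigned 8-bit: all 256 levels cover [0, max] (the zero-point is 0 here).
s_asym = acts.max() / 255
acts_asym = s_asym * np.clip(np.round(acts / s_asym), 0, 255)

mse_sym = np.mean((acts - acts_sym) ** 2)
mse_asym = np.mean((acts - acts_asym) ** 2)
```

The asymmetric grid has roughly half the step size, so its mean squared error is lower.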
17.5 Debugging QAT#
A systematic debugging checklist:
- Verify FP32 accuracy first: The pretrained model should match expected baseline.
- Check observer statistics: Ensure min/max values are reasonable (no NaN, no extreme ranges).
- Monitor per-layer quantization error: Compute \(\|W - Q(W)\|_2 / \|W\|_2\) per layer.
- Inspect gradient norms: If gradients vanish or explode after inserting fake quantization, something is wrong.
- Compare FP32 forward vs. fake-quant forward: On the same input, the output difference indicates total quantization noise.
- Profile accuracy vs. epoch: Accuracy should recover and stabilize; if it diverges, reduce LR or increase bit-width.
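The per-layer error check from the list above might look like the sketch below. It assumes symmetric per-tensor quantization with a max-based scale, and the layer names and synthetic weights are hypothetical; in practice you would iterate over the real model's parameters and use the same quantizer as your QAT setup.

```python
import numpy as np


def fake_quant_symmetric(w, bits):
    """Symmetric per-tensor fake quantization (quantize, then dequantize)."""
    qmax = 2 ** (bits - 1) - 1
    scale = np.abs(w).max() / qmax
    return scale * np.clip(np.round(w / scale), -qmax - 1, qmax)


def relative_quant_error(w, bits):
    """||W - Q(W)||_2 / ||W||_2 for one tensor."""
    return np.linalg.norm(w - fake_quant_symmetric(w, bits)) / np.linalg.norm(w)


rng = np.random.default_rng(0)
layers = {  # hypothetical layer names with synthetic weights
    "conv1": rng.normal(0, 0.10, size=(16, 3, 3, 3)),
    "layer2.conv": rng.normal(0, 0.02, size=(64, 32, 3, 3)),
    "fc": rng.normal(0, 0.05, size=(10, 64)),
}
report = {
    name: {bits: relative_quant_error(w, bits) for bits in (8, 4)}
    for name, w in layers.items()
}
```

Layers whose relative error is much larger than their neighbours' are good candidates for a higher bit-width or a learned clipping range.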
18. Framework Comparison#
18.1 PyTorch (torch.ao.quantization)#
PyTorch offers a mature QAT pipeline via torch.ao.quantization:
import torch
from torch.ao.quantization import (
    convert,
    disable_observer,
    fuse_modules_qat,
    get_default_qat_qconfig,
    prepare_qat,
)
# `model`, `train_loader`, `optimizer`, `train_one_epoch`, `num_epochs`, and
# `observer_freeze_epoch` are assumed to be defined elsewhere.
# Step 1: Define QAT config (the model must be in training mode for QAT)
model.train()
model.qconfig = get_default_qat_qconfig('fbgemm')  # or 'qnnpack'
# Step 2: Fuse modules (Conv+BN+ReLU); the QAT variant keeps BN trainable
model_fused = fuse_modules_qat(model, [['conv1', 'bn1', 'relu1']])
# Step 3: Prepare QAT (inserts fake quant nodes)
model_prepared = prepare_qat(model_fused)
# Step 4: Fine-tune
for epoch in range(num_epochs):
    train_one_epoch(model_prepared, train_loader, optimizer)
    if epoch == observer_freeze_epoch:
        # Freeze observer statistics so the quantization grid stops moving
        model_prepared.apply(disable_observer)
# Step 5: Convert to quantized model
model_quantized = convert(model_prepared.eval())
Pros: Native integration, extensive operator support, easy debugging. Cons: Limited to specific backends (fbgemm for x86, qnnpack for ARM).
18.2 TensorFlow / TF Model Optimization Toolkit#
import tensorflow as tf
import tensorflow_model_optimization as tfmot
# `model` is assumed to be an existing Keras model.
# Apply QAT to entire model
qat_model = tfmot.quantization.keras.quantize_model(model)
# Or selective quantization: annotate only the Dense layers
def apply_quantization_to_dense(layer):
    if isinstance(layer, tf.keras.layers.Dense):
        return tfmot.quantization.keras.quantize_annotate_layer(layer)
    return layer
annotated_model = tf.keras.models.clone_model(
    model, clone_function=apply_quantization_to_dense
)
qat_model = tfmot.quantization.keras.quantize_apply(annotated_model)
Pros: Good TFLite integration, well-documented. Cons: Less flexible custom quantization, Keras-centric.
18.3 NVIDIA TensorRT#
TensorRT is primarily an inference engine but supports QAT model import:
- Train with QAT in PyTorch (using TensorRT-compatible fake quantization nodes from the pytorch-quantization library).
- Export to ONNX with Q/DQ (Quantize/Dequantize) nodes.
- Import into TensorRT, which recognizes Q/DQ patterns and fuses them into INT8 kernels.
PyTorch QAT Model
|
v
[ Export to ONNX with Q/DQ nodes ]
|
v
[ TensorRT Builder ]
|
v
[ Optimized INT8 Engine ]
Pros: Best inference performance on NVIDIA GPUs, hardware-aware optimization. Cons: NVIDIA-only, limited to supported layer patterns.
18.4 Qualcomm AIMET#
AIMET (AI Model Efficiency Toolkit) provides advanced QAT features:
- Adaptive rounding (AdaRound): Learns whether to round up or down per weight element.
- Cross-layer equalization (CLE): Balances weight ranges across layers before quantization.
- Bias correction: Corrects bias shift introduced by quantization.
- Sequential MSE: Optimizes quantization parameters layer-by-layer to minimize reconstruction error.
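from aimet_torch.quantsim import QuantizationSimModel
# `model`, `dummy_input`, `forward_pass_callback`, `forward_pass_callback_args`,
# `train_loader`, `optimizer`, `train_one_epoch`, and `num_epochs` are assumed
# to be defined elsewhere.
sim = QuantizationSimModel(model, dummy_input,
                           quant_scheme='tf_enhanced',
                           default_param_bw=8,
                           default_output_bw=8)
# Calibrate: compute the initial quantization encodings
sim.compute_encodings(forward_pass_callback, forward_pass_callback_args)
# Fine-tune
for epoch in range(num_epochs):
    train_one_epoch(sim.model, train_loader, optimizer)
# Export the model together with its encodings
sim.export('./output', 'quantized_model', dummy_input)
Pros: Targets Qualcomm Snapdragon (widely deployed), advanced PTQ/QAT techniques. Cons: Qualcomm-focused, smaller community.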
18.5 Summary Comparison#
| Feature | PyTorch | TensorFlow | TensorRT | AIMET |
|---|---|---|---|---|
| QAT support | Native | Via toolkit | Import only | Native |
| Custom quantizers | Easy | Moderate | Limited | Moderate |
| Target hardware | x86, ARM | Mobile (TFLite) | NVIDIA GPU | Snapdragon |
| Mixed-precision QAT | Manual | Limited | Automatic | Manual |
| BN folding | Built-in | Built-in | Automatic | Built-in |
| Community size | Largest | Large | Large | Small |
| LSQ / learnable params | Custom needed | Custom needed | N/A | Supported |
19. PTQ vs. QAT Decision Matrix#
Choosing between PTQ and QAT depends on multiple factors. Use the following decision matrix:
START
  |
  v
+-----------------+
| Target >= 8-bit |---YES---> Try PTQ first
+-----------------+                |
  |                                v
  NO                     +------------------+
  |                      | PTQ accuracy OK? |--YES--> Use PTQ
  v                      +------------------+
+-----------------+                |
| Have training   |                NO
| data + compute? |                |
+-----------------+                v
  |           |          Use QAT (fine-tune
 YES          NO          from PTQ model)
  |           |
  v           v
Use QAT   Try advanced PTQ
          (GPTQ, AWQ, etc.)
              |
              v
     +------------------+
     | Accuracy OK?     |--YES--> Use advanced PTQ
     +------------------+
              |
              NO
              v
     Need QAT (or accept
     accuracy trade-off)
19.1 Detailed Comparison Table#
| Criterion | PTQ | QAT |
|---|---|---|
| Training data needed | Small calibration set (256-1024 samples) | Full training set |
| Compute cost | Minutes | Hours to days |
| Accuracy at 8-bit | Excellent (< 1% drop) | Near-zero drop |
| Accuracy at 4-bit (weights) | Good with advanced methods (GPTQ, AWQ) | Excellent |
| Accuracy at 4-bit (weights + activations) | Moderate to poor | Good |
| Accuracy at 2-bit | Poor | Moderate (with progressive/KD) |
| Accuracy at 1-bit (binary) | Not applicable | Possible with specialized methods |
| Implementation complexity | Low | Moderate to high |
| Hyperparameter tuning | Minimal | Significant (LR, epochs, observer schedule) |
| Model architecture changes | None | May need BN folding, skip connection handling |
| Reproducibility | High (deterministic) | Moderate (training variance) |
| Time-to-deployment | Fast | Slower |
| Best for | Production, 8-bit, quick deployment | Low-bitwidth, accuracy-critical, research |
19.2 Recommended Workflow#
- Always start with PTQ. If the accuracy meets requirements, stop.
- If PTQ fails: Try advanced PTQ (GPTQ for weights, SmoothQuant for activations).
- If advanced PTQ fails: Apply QAT, starting from the PTQ model as initialization.
- If QAT alone is insufficient: Add knowledge distillation and/or progressive quantization.
- For extreme compression (binary/ternary): Use specialized architectures (XNOR-Net, ReActNet) trained from scratch with QAT.
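For readers who prefer code to flowcharts, the decision flow at the start of this section can be written as a small function. This is only an encoding of that diagram: the function name, argument names, and returned strings are invented here, and the accuracy flags are assumed to come from your own offline evaluation.

```python
def recommend_quantization(target_bits, ptq_ok, has_training_resources, advanced_ptq_ok):
    """Return a recommendation following the PTQ-vs-QAT decision flow.

    ptq_ok / advanced_ptq_ok: whether that method met the accuracy target.
    has_training_resources: whether training data and compute are available.
    """
    if target_bits >= 8:
        # High bit-widths: PTQ first, QAT only as a fallback.
        return "PTQ" if ptq_ok else "QAT initialized from the PTQ model"
    if has_training_resources:
        return "QAT"
    if advanced_ptq_ok:
        return "advanced PTQ (GPTQ, AWQ, ...)"
    return "QAT needed, or accept the accuracy trade-off"


# Example: a 4-bit target with no training budget, where advanced PTQ is good enough.
choice = recommend_quantization(
    target_bits=4, ptq_ok=False, has_training_resources=False, advanced_ptq_ok=True
)
```

Keeping the rule in one place like this makes it easy to adjust thresholds (for example, treating 6-bit targets like 8-bit ones) as your hardware and accuracy requirements change.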
20. Emerging Directions#
20.1 Quantization for Diffusion Models#
Diffusion models pose unique challenges because the noise level changes at each denoising step. Time-step-aware quantization adapts the quantization parameters based on the current diffusion time step.
20.2 Quantization for Mixture-of-Experts (MoE)#
MoE models like Mixtral have sparse activation patterns. Quantizing inactive experts more aggressively (or offloading them in low precision) can dramatically reduce memory with minimal accuracy impact.
20.3 FP8 Training#
NVIDIA’s Hopper architecture natively supports FP8 (E4M3 and E5M2 formats). FP8 training can be viewed as a form of QAT where the “quantization” is to a low-precision floating-point format rather than integer. As with the STE, the backward pass treats the cast to FP8 as an identity; the per-tensor scaling factors that keep values within FP8 range are typically managed in software.
20.4 Learnable Quantization Beyond Uniform#
Non-uniform quantization (e.g., log-scale, power-of-two, lookup-table-based) can better match the actual weight/activation distributions. APoT (Additive Powers-of-Two) explores this space, while EWGS (Element-Wise Gradient Scaling) takes a related direction by refining the STE gradient itself.
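A plain power-of-two grid is the simplest non-uniform quantizer to sketch. Note that this is not APoT itself (which builds levels from sums of powers of two); the level construction and function names below are assumptions chosen to keep the example short.

```python
def power_of_two_levels(bits, alpha=1.0):
    """Signed level set {0, +/- alpha * 2^-k}: a plain power-of-two grid, not full APoT."""
    num_exponents = 2 ** (bits - 1) - 1
    positive = [alpha * 2.0 ** -k for k in range(num_exponents)]
    return sorted([-p for p in positive] + [0.0] + positive)


def quantize_to_levels(x, levels):
    """Lookup-table quantization: snap x to the nearest allowed level."""
    return min(levels, key=lambda level: abs(x - level))


levels = power_of_two_levels(bits=3)  # [-1, -0.5, -0.25, 0, 0.25, 0.5, 1]
samples = [0.9, 0.3, 0.1, -0.6, 0.02]
quantized = [quantize_to_levels(x, levels) for x in samples]
# quantized == [1.0, 0.25, 0.0, -0.5, 0.0]
```

The levels are denser near zero, where bell-shaped weight distributions put most of their mass, and multiplications by a power of two reduce to bit shifts on suitable hardware.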
21. Summary#
Quantization-Aware Training is the most powerful technique for producing high-accuracy quantized models, especially at low bit-widths. The key concepts are:
- Straight-Through Estimator: Enables gradient flow through non-differentiable quantization by approximating the backward pass as the identity within the clipping range.
- Fake Quantization Nodes: Simulate quantization during training while keeping computations in floating-point, allowing standard training infrastructure to be used.
- Learnable Quantization Parameters: Methods like LSQ, LSQ+, and PACT make the quantization grid parameters (step size, clipping bounds, offsets) learnable, improving accuracy.
- BN Folding: Must be simulated during QAT to ensure consistency between training and inference quantization.
- Knowledge Distillation: Provides complementary accuracy improvements, especially at extreme bit-widths.
- Binary/Ternary Networks: Push quantization to the extreme (1-2 bits), enabling dramatic speedups via XNOR/popcount operations at the cost of significant accuracy reduction.
- QLoRA and LLM-QAT: Extend quantization-aware techniques to the LLM regime with innovations like NF4, double quantization, and data-free distillation.
- Mixed-Precision: Allocates bits non-uniformly across layers based on sensitivity analysis, achieving better accuracy-efficiency trade-offs.
The field continues to evolve rapidly, driven by the relentless growth of model sizes and the demand for efficient deployment on diverse hardware platforms. Understanding QAT deeply is essential for any engineer working on deploying neural networks in resource-constrained environments.
References#
- Bengio, Y., Leonard, N., & Courville, A. (2013). Estimating or Propagating Gradients Through Stochastic Neurons for Conditional Computation. arXiv:1308.3432.
- Esser, S. K., et al. (2020). Learned Step Size Quantization (LSQ). ICLR 2020.
- Bhalgat, Y., et al. (2020). LSQ+: Improving Low-bit Quantization Through Learnable Offsets and Better Initialization. ECCV 2020.
- Choi, J., et al. (2018). PACT: Parameterized Clipping Activation for Quantized Neural Networks. ICLR 2018 Workshop.
- Zhou, S., et al. (2016). DoReFa-Net: Training Low Bitwidth Convolutional Neural Networks with Low Bitwidth Gradients. arXiv:1606.06160.
- Rastegari, M., et al. (2016). XNOR-Net: ImageNet Classification Using Binary Convolutional Neural Networks. ECCV 2016.
- Li, F., Zhang, B., & Liu, B. (2016). Ternary Weight Networks. arXiv:1605.04711.
- Zhu, C., et al. (2017). Trained Ternary Quantization (TTQ). ICLR 2017.
- Dettmers, T., et al. (2023). QLoRA: Efficient Finetuning of Quantized Language Models. NeurIPS 2023.
- Liu, Z., et al. (2023). LLM-QAT: Data-Free Quantization Aware Training for Large Language Models. arXiv:2305.17888.
- Wang, K., et al. (2019). HAQ: Hardware-Aware Automated Quantization. CVPR 2019.
- Dong, Z., et al. (2019). HAWQ: Hessian AWare Quantization of Neural Networks with Mixed-Precision. ICCV 2019.
- Jacob, B., et al. (2018). Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. CVPR 2018.