Quantization¶
INT8 symmetric quantization utilities for efficient inference.
Overview¶
This module provides utilities for quantizing neural network weights to INT8 format, enabling memory-efficient inference with minimal accuracy loss.
Quantization Scheme¶
We use symmetric quantization with the following formula:

```
scale  = max(|x|) / (2^(bits-1) - 1)
x_int8 = clamp(round(x / scale), -2^(bits-1), 2^(bits-1) - 1)
```

To dequantize:

```
x ≈ x_int8 * scale
```
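As a concrete example, take an FP32 tensor whose largest magnitude is 2.0 with bits = 8: the scale is 2.0 / 127 ≈ 0.0157. The snippet below hand-rolls the scheme above (it does not call the library) to show the round trip:

```python
import torch

x = torch.tensor([0.5, -1.25, 2.0, -0.1], dtype=torch.float32)

bits = 8
qmax = 2 ** (bits - 1) - 1              # 127 for INT8
scale = x.abs().max() / qmax            # one scale for the whole tensor
x_int8 = torch.clamp(torch.round(x / scale), -qmax - 1, qmax).to(torch.int8)

x_restored = x_int8.to(torch.float32) * scale
print(x_int8)      # e.g. tensor([ 32, -79, 127,  -6], dtype=torch.int8)
print(x_restored)  # close to the original values
```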
Benefits¶
| Aspect | FP16 | INT8 |
|---|---|---|
| Memory per weight | 2 bytes | 1 byte |
| Memory reduction | - | 50% |
| Accuracy | Baseline | ~99.5%+ of baseline |
Use Cases¶
- Large model inference: Fit bigger models in GPU memory
- Deployment: Reduce model size for edge devices
- Batched inference: Handle more concurrent requests
Quick Start¶
```python
import torch

from rotalabs_accel import (
    quantize_symmetric,
    dequantize,
    quantize_weight_per_channel,
    calculate_quantization_error,
    QuantizedLinear,
)

# Quantize a weight tensor
weight = torch.randn(4096, 4096, dtype=torch.float16)
weight_int8, scale = quantize_symmetric(weight)

# Check accuracy
errors = calculate_quantization_error(weight, weight_int8, scale)
print(f"SNR: {errors['snr_db']:.1f} dB")

# Use in a model
linear = torch.nn.Linear(4096, 4096)
qlinear = QuantizedLinear.from_linear(linear)
```
Quantization Granularity¶
Per-Tensor Quantization¶
One scale for the entire tensor. Fastest but lowest accuracy.
Per-Channel Quantization¶
One scale per output channel. Better accuracy, minimal overhead.
Recommendation: Use per-channel for best accuracy/speed tradeoff.
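In code, the difference is just the granularity argument: per-tensor passes no `dim`, while per-channel uses `dim=0` on a `(out_features, in_features)` weight or the `quantize_weight_per_channel` convenience function. A brief comparison sketch using the functions documented below (per-channel typically reports a higher SNR):

```python
import torch
from rotalabs_accel import (
    quantize_symmetric,
    quantize_weight_per_channel,
    calculate_quantization_error,
)

W = torch.randn(4096, 4096, dtype=torch.float16)

# Per-tensor: a single scalar scale for the whole matrix
W_int8_pt, scale_pt = quantize_symmetric(W)

# Per-channel: one scale per output channel (row)
W_int8_pc, scales_pc = quantize_weight_per_channel(W)

err_pt = calculate_quantization_error(W, W_int8_pt, scale_pt)
err_pc = calculate_quantization_error(W, W_int8_pc, scales_pc, dim=0)
print(f"per-tensor SNR:  {err_pt['snr_db']:.1f} dB")
print(f"per-channel SNR: {err_pc['snr_db']:.1f} dB")  # usually the higher of the two
```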
API Reference¶
Functions¶
quantize_symmetric ¶
quantize_symmetric(tensor: Tensor, bits: int = 8, dim: Optional[int] = None) -> tuple[torch.Tensor, torch.Tensor]
Symmetric quantization of tensor to INT8.
Computes per-tensor or per-channel quantization using the symmetric scheme:

- scale = max(|tensor|) / (2^(bits-1) - 1)
- quantized = round(tensor / scale).clamp(-128, 127)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `tensor` | `Tensor` | Input tensor to quantize (typically FP16 or FP32 weights). | *required* |
| `bits` | `int` | Number of bits for quantization (default: 8). | `8` |
| `dim` | `Optional[int]` | Dimension for per-channel quantization. If None, uses per-tensor. For weight matrices (out_features, in_features), use dim=0 for per-output-channel quantization. | `None` |
Returns:

| Type | Description |
|---|---|
| `tuple[Tensor, Tensor]` | Tuple of (quantized_int8, scale). |
Example

```python
weight = torch.randn(4096, 4096, dtype=torch.float16)
weight_int8, scale = quantize_symmetric(weight)
weight_fp16 = dequantize(weight_int8, scale)
error = (weight - weight_fp16).abs().max()
```
dequantize ¶
dequantize(quantized: Tensor, scale: Tensor, dtype: dtype = torch.float16, dim: Optional[int] = None) -> torch.Tensor
Dequantize INT8 tensor back to floating point.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `quantized` | `Tensor` | INT8 quantized tensor. | *required* |
| `scale` | `Tensor` | Scale factor(s) from quantization. | *required* |
| `dtype` | `dtype` | Output dtype (default: float16). | `float16` |
| `dim` | `Optional[int]` | Dimension along which scale was computed (for broadcasting). For weight matrices with per-output-channel quantization, use dim=0. | `None` |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Dequantized tensor in specified dtype. |
Example

```python
weight_int8, scale = quantize_symmetric(weight_fp16)
weight_restored = dequantize(weight_int8, scale)
```
quantize_weight_per_channel ¶
quantize_weight_per_channel(weight: Tensor) -> tuple[torch.Tensor, torch.Tensor]
Quantize weight matrix with per-output-channel scales.
For a weight matrix of shape (out_features, in_features), computes one scale per output channel (row).
This is the common scheme for W8A16 inference, providing good accuracy while allowing efficient dequantization during matmul.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `weight` | `Tensor` | Weight tensor of shape (out_features, in_features). | *required* |
Returns:

| Type | Description |
|---|---|
| `tuple[Tensor, Tensor]` | Tuple of (weight_int8, scales). |
Example

```python
W = torch.randn(4096, 4096, dtype=torch.float16)
W_int8, scales = quantize_weight_per_channel(W)
```
calculate_quantization_error ¶
calculate_quantization_error(original: Tensor, quantized: Tensor, scale: Tensor, dim: Optional[int] = None) -> dict[str, float]
Calculate quantization error metrics.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `original` | `Tensor` | Original FP tensor. | *required* |
| `quantized` | `Tensor` | INT8 quantized tensor. | *required* |
| `scale` | `Tensor` | Scale factor(s). | *required* |
| `dim` | `Optional[int]` | Dimension for scale broadcasting. | `None` |
Returns:

| Type | Description |
|---|---|
| `dict[str, float]` | Dictionary with error metrics: `max_abs_error`, `mean_abs_error`, `relative_error_pct`, `snr_db` (see Error Metrics below). |
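A minimal usage sketch, following the signature above and the Quick Start example:

```python
import torch
from rotalabs_accel import quantize_symmetric, calculate_quantization_error

weight = torch.randn(4096, 4096, dtype=torch.float16)
weight_int8, scale = quantize_symmetric(weight)

errors = calculate_quantization_error(weight, weight_int8, scale)
for name, value in errors.items():
    print(f"{name}: {value:.4f}")
```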
Classes¶
QuantizedLinear ¶
Bases: Module
Linear layer with INT8 quantized weights.
Stores weights in INT8 format and dequantizes on-the-fly during forward pass. This is a reference implementation - the actual kernel-level optimization happens in the Triton INT8 GEMM kernel.
For W8A16 inference:

- Weights: INT8 (2x memory reduction)
- Activations: FP16
- Compute: FP16 with FP32 accumulation
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input dimension. | *required* |
| `out_features` | `int` | Output dimension. | *required* |
| `bias` | `bool` | Whether to include bias (default: False for LLM weights). | `False` |
Example

```python
linear = QuantizedLinear(4096, 4096)
linear.quantize_weights(pretrained_weight)
y = linear(x)  # x is FP16, y is FP16
```
__init__ ¶
quantize_weights ¶
Quantize and store weights.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `weight` | `Tensor` | FP16 or FP32 weight tensor of shape (out_features, in_features). | *required* |
forward ¶
Forward pass with on-the-fly dequantization.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape (..., in_features). | *required* |
Returns:

| Type | Description |
|---|---|
| `Tensor` | Output tensor of shape (..., out_features). |
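To make "on-the-fly dequantization" concrete, the forward pass conceptually performs a per-channel dequantize followed by a standard FP16 matmul. The sketch below illustrates the idea as a free function; it is not the module's actual code, and the `weight_int8` / `scale` names are assumptions:

```python
import torch
import torch.nn.functional as F

def w8a16_forward(x: torch.Tensor, weight_int8: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    """Sketch of W8A16: INT8 weights are expanded to FP16 just before the matmul.

    x:           (..., in_features) FP16 activations
    weight_int8: (out_features, in_features) INT8 weights
    scale:       (out_features,) per-output-channel scales
    """
    # Dequantize: broadcast each row's scale across its in_features entries
    weight_fp16 = weight_int8.to(torch.float16) * scale.to(torch.float16).unsqueeze(1)
    # Standard linear: y = x @ W^T
    return F.linear(x, weight_fp16)
```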
from_linear classmethod ¶
Create QuantizedLinear from existing nn.Linear.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `linear` | `Linear` | Pretrained linear layer. | *required* |
Returns:

| Type | Description |
|---|---|
| `QuantizedLinear` | QuantizedLinear with quantized weights. |
Best Practices¶
1. Quantize After Training¶
Quantize pretrained weights, not randomly initialized ones:
```python
# Good: quantize pretrained weights
model = load_pretrained_model()
for module in model.modules():
    if isinstance(module, nn.Linear):
        qmodule = QuantizedLinear.from_linear(module)
        # replace module with qmodule (e.g. via setattr on the parent module)
```
2. Evaluate Before Deployment¶
Always check quantization accuracy on your specific model:
```python
# Run validation before and after quantization
baseline_loss = evaluate(model)
quantize_model(model)
quantized_loss = evaluate(model)
print(f"Loss increase: {quantized_loss - baseline_loss:.4f}")
```
3. Keep Certain Layers in FP16¶
Some layers are more sensitive to quantization:
- First and last layers
- Layers with small weight magnitudes
- Attention output projections
```python
# Skip quantizing sensitive layers
for name, module in model.named_modules():
    if "lm_head" in name or "embed" in name:
        continue  # Keep in FP16
    if isinstance(module, nn.Linear):
        qmodule = QuantizedLinear.from_linear(module)  # Quantize
```
Error Metrics¶
The calculate_quantization_error function returns:
| Metric | Description | Typical Value |
|---|---|---|
| `max_abs_error` | Maximum absolute difference | < 0.02 |
| `mean_abs_error` | Mean absolute difference | < 0.005 |
| `relative_error_pct` | Max relative error for significant values | < 1% |
| `snr_db` | Signal-to-noise ratio | > 40 dB |
Values may vary based on weight distribution. Lower SNR indicates more quantization error.
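For reference, `snr_db` is a standard signal-to-noise ratio in decibels. A hedged sketch of how such a metric can be computed (the library's exact definition may differ):

```python
import torch

def snr_db(original: torch.Tensor, dequantized: torch.Tensor) -> float:
    """10 * log10(signal power / noise power), in dB."""
    original = original.float()
    noise = original - dequantized.float()
    signal_power = original.pow(2).mean()
    noise_power = noise.pow(2).mean().clamp_min(1e-12)  # guard against zero noise
    return (10.0 * torch.log10(signal_power / noise_power)).item()
```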