INT8 GEMM¶
W8A16 (INT8 weights, FP16 activations) matrix multiplication kernels.
Overview¶
INT8 GEMM enables memory-efficient inference by storing weights in 8-bit format while keeping activations in FP16. This provides:
- 2x memory reduction for weight storage
- Faster inference due to reduced memory bandwidth requirements
- Minimal accuracy loss with per-channel quantization
Computation scheme: y = x @ (weight_int8 * scale).T + bias, where weight_int8 has shape (N, K) and scale holds one FP32 value per output channel.
The dequantization happens in registers during the matmul, so the reduction in memory traffic translates directly into a speedup.
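In PyTorch terms, the computation is equivalent to the following sketch (shapes are illustrative; FP32 tensors are used here so the snippet also runs on CPU, whereas the kernel itself operates on FP16 activations on CUDA):

```python
import torch

# Sketch of the W8A16 computation (not the fused kernel).
x = torch.randn(2, 8, 64)                                       # activations (..., K)
w_int8 = torch.randint(-127, 128, (128, 64), dtype=torch.int8)  # weights (N, K)
scale = torch.rand(128) * 0.01                                  # per-output-channel scales (N,)

# Dequantize, then multiply -- the fused kernel does this per tile, in registers.
w_fp = (w_int8.float() * scale.unsqueeze(1)).to(x.dtype)        # (N, K)
y = x @ w_fp.T                                                  # output (..., N)
```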
Performance Characteristics¶
| Configuration | FP16 GEMM | INT8 GEMM | Speedup | Memory Saved |
|---|---|---|---|---|
| 4096x4096 | 156 μs | 48 μs | 3.3x | 16 MB |
| 8192x8192 | 620 μs | 189 μs | 3.3x | 64 MB |
| 4096x11008 | 418 μs | 128 μs | 3.3x | 43 MB |
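Exact numbers depend on the GPU and problem shape. A rough timing sketch along these lines can be used for comparison (this is not the benchmark harness that produced the table, and it assumes quantize_weight_per_channel returns tensors on the input device):

```python
import torch
from rotalabs_accel import int8_gemm, quantize_weight_per_channel

M, K, N = 4096, 4096, 4096
x = torch.randn(M, K, device="cuda", dtype=torch.float16)
w = torch.randn(N, K, device="cuda", dtype=torch.float16)
w_int8, scales = quantize_weight_per_channel(w)

def time_ms(fn, iters=100):
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    for _ in range(10):            # warmup
        fn()
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

print("FP16 GEMM:", time_ms(lambda: x @ w.T), "ms")
print("INT8 GEMM:", time_ms(lambda: int8_gemm(x, w_int8, scales)), "ms")
```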
Usage Examples¶
High-Level: QuantizedLinear¶
The easiest way to use INT8 inference:
```python
import torch
from rotalabs_accel import QuantizedLinear

# Convert existing pretrained layer
linear = torch.nn.Linear(4096, 4096)
linear.load_state_dict(pretrained_weights)

# Quantize to INT8
qlinear = QuantizedLinear.from_linear(linear)
qlinear = qlinear.cuda()

# Use like normal
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = qlinear(x)  # Output is FP16
```
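As a quick sanity check, you can continue the snippet above and compare against the original layer run in FP16; the difference should be on the order of the per-channel quantization error (a sketch, not required for normal use):

```python
# Sanity check: the quantized layer should closely track the original FP16 layer.
linear = linear.cuda().half()
y_ref = linear(x)
print((y - y_ref).abs().max())   # should be on the order of the quantization error
```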
Low-Level: Int8Linear¶
For more control over quantization:
```python
from rotalabs_accel import Int8Linear

# Create layer
linear = Int8Linear(
    in_features=4096,
    out_features=4096,
    bias=False,
)

# Quantize weights manually
linear.quantize_weights(pretrained_weight_fp16)

# Forward pass
y = linear(x)
```
Functional API¶
For custom implementations:
```python
from rotalabs_accel import int8_gemm, quantize_weight_per_channel

# Quantize weights once
weight_int8, scales = quantize_weight_per_channel(weight_fp16)

# Use in forward pass
output = int8_gemm(x, weight_int8, scales)
```
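Custom integrations can be validated against the PyTorch reference int8_gemm_torch (documented below). A minimal sketch, assuming quantize_weight_per_channel keeps tensors on the input device:

```python
import torch
from rotalabs_accel import int8_gemm, int8_gemm_torch, quantize_weight_per_channel

w = torch.randn(128, 64, dtype=torch.float16, device="cuda")
x = torch.randn(2, 8, 64, dtype=torch.float16, device="cuda")
w_int8, scales = quantize_weight_per_channel(w)

y_kernel = int8_gemm(x, w_int8, scales)        # optimized path (Triton on CUDA)
y_ref = int8_gemm_torch(x, w_int8, scales)     # PyTorch reference
print((y_kernel - y_ref).abs().max())          # expected to be small (FP16 rounding only)
```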
Quantizing an Entire Model¶
```python
import torch
from transformers import AutoModelForCausalLM
from rotalabs_accel import QuantizedLinear

def quantize_model(model):
    """Replace all Linear layers with QuantizedLinear."""
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            quantized = QuantizedLinear.from_linear(module)
            setattr(model, name, quantized)
        else:
            quantize_model(module)
    return model

# Quantize LLM
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-hf",
    torch_dtype=torch.float16,  # load weights in FP16
)
model = quantize_model(model)
model = model.cuda()

# Memory usage: ~14 GB -> ~7 GB
```
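To verify the rough memory estimate, a sketch that sums parameter and buffer sizes (whether the INT8 weights are registered as parameters or buffers is an implementation detail, so both are counted):

```python
def model_bytes(m):
    # Count both parameters and buffers, since quantized weights may be stored as either.
    params = sum(p.numel() * p.element_size() for p in m.parameters())
    buffers = sum(b.numel() * b.element_size() for b in m.buffers())
    return params + buffers

print(f"approx. weight memory: {model_bytes(model) / 1e9:.1f} GB")
```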
API Reference¶
Functions¶
int8_gemm ¶
int8_gemm(x: Tensor, weight_int8: Tensor, scale: Tensor, bias: Optional[Tensor] = None, weight_transposed: Optional[Tensor] = None, scale_fp16: Optional[Tensor] = None, use_cublas: Optional[bool] = None) -> torch.Tensor
W8A16 GEMM: FP16 activations x INT8 weights.
Computes: y = x @ (weight_int8 * scale).T + bias
Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | FP16 activation tensor of shape (..., K). | required |
| weight_int8 | Tensor | INT8 weight tensor of shape (N, K). | required |
| scale | Tensor | FP32 scale tensor of shape (N,) for per-output-channel dequant. | required |
| bias | Optional[Tensor] | Optional FP16 bias of shape (N,). | None |
| weight_transposed | Optional[Tensor] | Optional pre-transposed weight (K, N). | None |
| scale_fp16 | Optional[Tensor] | Optional pre-converted FP16 scale (unused in Triton path). | None |
| use_cublas | Optional[bool] | Unused, kept for API compatibility. | None |

Returns:

| Type | Description |
|---|---|
| Tensor | FP16 output tensor of shape (..., N). |
Example
```python
x = torch.randn(2, 8, 64)
weight_int8 = torch.randint(-128, 127, (128, 64), dtype=torch.int8)
scale = torch.ones(128)
y = int8_gemm(x, weight_int8, scale)
```
Source code in src/rotalabs_accel/kernels/gemm.py
int8_gemm_torch ¶
int8_gemm_torch(x: Tensor, weight_int8: Tensor, scale: Tensor, bias: Optional[Tensor] = None) -> torch.Tensor
PyTorch reference implementation of W8A16 GEMM.
Works on any device (CPU or CUDA).
Source code in src/rotalabs_accel/kernels/gemm.py
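A minimal usage sketch of the reference path on CPU (FP32 activations keep it runnable without a GPU; shapes follow the int8_gemm example above):

```python
import torch
from rotalabs_accel import int8_gemm_torch

x = torch.randn(2, 8, 64)                                        # FP32 activations run fine on CPU
weight_int8 = torch.randint(-128, 127, (128, 64), dtype=torch.int8)
scale = torch.ones(128)
y = int8_gemm_torch(x, weight_int8, scale)
print(y.shape)   # torch.Size([2, 8, 128])
```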
Classes¶
Int8Linear ¶
Bases: Module
Linear layer using INT8 weights with optimized GEMM kernel.
Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimension (K). | required |
| out_features | int | Output dimension (N). | required |
| bias | bool | Whether to include bias. | False |
Example
```python
linear = Int8Linear(64, 128)
linear.quantize_weights(torch.randn(128, 64))
y = linear(torch.randn(2, 8, 64))
```
Source code in src/rotalabs_accel/kernels/gemm.py
__init__ ¶
Source code in src/rotalabs_accel/kernels/gemm.py
quantize_weights ¶
Quantize and store FP16/FP32 weights as INT8.
Source code in src/rotalabs_accel/kernels/gemm.py
forward ¶
Forward pass using optimized INT8 GEMM kernel.
Source code in src/rotalabs_accel/kernels/gemm.py
Quantization Details¶
Symmetric Quantization¶
We use symmetric quantization with per-channel scales:
```
scale[i] = max(|weight[i, :]|) / 127
weight_int8[i, :] = round(weight[i, :] / scale[i]).clamp(-128, 127)
```
This provides a good balance between accuracy and performance.
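As a concrete illustration, here is a minimal sketch of the scheme above. It is not necessarily identical to the library's quantize_weight_per_channel, but it follows the same formulas:

```python
import torch

def quantize_per_channel_symmetric(weight: torch.Tensor):
    """Sketch of the symmetric per-output-channel scheme (illustrative helper)."""
    # One scale per output row: scale[i] = max(|weight[i, :]|) / 127
    scale = weight.abs().amax(dim=1).float() / 127.0
    scale = scale.clamp(min=1e-8)  # guard against all-zero rows
    w_int8 = torch.round(weight.float() / scale.unsqueeze(1)).clamp(-128, 127).to(torch.int8)
    return w_int8, scale

w = torch.randn(128, 64)
w_int8, scale = quantize_per_channel_symmetric(w)
w_dequant = w_int8.float() * scale.unsqueeze(1)   # reconstruction, e.g. for error checks
```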
Per-Channel vs Per-Tensor¶
| Method | Accuracy | Memory | Speed |
|---|---|---|---|
| Per-tensor | Lower | Minimal overhead | Fastest |
| Per-channel | Higher | 1 scale per output | Slightly slower |
We default to per-channel (per-output-row) quantization for better accuracy.
Quantization Error¶
Typical quantization errors for random weights:
| Metric | Value |
|---|---|
| Max absolute error | ~0.01 |
| Mean absolute error | ~0.002 |
| SNR | ~45-50 dB |
For real model weights, errors may vary based on weight distribution.
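A sketch for reproducing these metrics on a random weight matrix using the per-channel scheme above (as noted, exact values depend on the weight shape and distribution):

```python
import torch

# Per-channel symmetric quantization, then the error metrics from the table above.
w = torch.randn(4096, 4096)
scale = w.abs().amax(dim=1) / 127.0
w_int8 = torch.round(w / scale.unsqueeze(1)).clamp(-128, 127).to(torch.int8)
w_hat = w_int8.float() * scale.unsqueeze(1)

err = w - w_hat
snr_db = 10 * torch.log10(w.pow(2).mean() / err.pow(2).mean())
print(f"max abs error:  {err.abs().max().item():.4f}")
print(f"mean abs error: {err.abs().mean().item():.4f}")
print(f"SNR:            {snr_db.item():.1f} dB")
```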
Implementation Notes¶
Memory Layout¶
- Weights: INT8, shape (out_features, in_features)
- Scales: FP32, shape (out_features,)
- Activations: FP16, shape (..., in_features)
- Output: FP16, shape (..., out_features)
Kernel Strategy¶
The Triton kernel:
- Loads weight tiles as INT8 (1 byte per element)
- Dequantizes in registers using the scale vector
- Performs FP16 matmul with FP32 accumulation
- Stores FP16 output
This minimizes memory traffic while maintaining numerical precision.
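For intuition, here is a plain-PyTorch emulation of that strategy (a sketch, not the Triton kernel; the helper name and block size are illustrative):

```python
import torch

def w8a16_tile_reference(x, w_int8, scale, block_k=64):
    """Plain-PyTorch emulation of the tiled kernel strategy (illustrative only)."""
    M, K = x.shape
    N = w_int8.shape[0]
    acc = torch.zeros(M, N, dtype=torch.float32)                   # FP32 accumulator
    for k0 in range(0, K, block_k):
        x_tile = x[:, k0:k0 + block_k].float()                     # FP16 activations, widened for matmul
        w_tile = w_int8[:, k0:k0 + block_k].to(torch.float16)      # 1-byte weight loads, dequantized...
        w_tile = w_tile * scale.to(torch.float16).unsqueeze(1)     # ...with per-output-channel scales
        acc += x_tile @ w_tile.float().T                           # accumulate in FP32
    return acc.to(torch.float16)                                   # store FP16 output

x = torch.randn(8, 256, dtype=torch.float16)
w_int8 = torch.randint(-127, 128, (128, 256), dtype=torch.int8)
scale = torch.rand(128) * 0.01
y = w8a16_tile_reference(x, w_int8, scale)   # shape (8, 128), dtype float16
```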
Fallback Behavior¶
On CPU or without Triton, the kernel falls back to:
```python
def int8_gemm_torch(x, weight_int8, scale, bias=None):
    weight_fp = (weight_int8.float() * scale.unsqueeze(1)).to(x.dtype)
    out = x @ weight_fp.T
    if bias is not None:
        out = out + bias
    return out
```
Comparison with Other Quantization Methods¶
| Method | Bits | Scheme | Accuracy | Speed |
|---|---|---|---|---|
| INT8 (this) | 8 | Symmetric | High | Fast |
| GPTQ | 4 | Asymmetric + groups | Medium-High | Moderate |
| AWQ | 4 | Activation-aware | High | Moderate |
| FP8 (Hopper) | 8 | Native hardware | Very High | Very Fast |
INT8 symmetric quantization is a good default choice that works on all GPUs and provides a solid accuracy/speed tradeoff.