INT8 GEMM

W8A16 (INT8 weights, FP16 activations) matrix multiplication kernels.

Overview

INT8 GEMM enables memory-efficient inference by storing weights in 8-bit format while keeping activations in FP16. This provides:

  • 2x memory reduction for weight storage
  • Faster inference due to reduced memory bandwidth requirements
  • Minimal accuracy loss with per-channel quantization

Computation scheme:

output = (activation_fp16 @ weight_int8.dequantize()) + bias

The dequantization happens in registers during the matmul, so the memory traffic reduction directly translates to speedup.
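
The equivalence is easy to verify numerically. The sketch below (assuming a CUDA device and the functional API documented further down this page) compares the fused kernel against an explicit dequantize-then-matmul reference; the tolerance is illustrative.

import torch
from rotalabs_accel import int8_gemm, quantize_weight_per_channel

# Quantize a random weight matrix per output channel (done once, on CPU)
weight_fp16 = torch.randn(128, 64, dtype=torch.float16)
weight_int8, scales = quantize_weight_per_channel(weight_fp16)
weight_int8, scales = weight_int8.cuda(), scales.cuda()

x = torch.randn(2, 8, 64, device="cuda", dtype=torch.float16)

# Fused kernel: dequantization happens inside the matmul
y_kernel = int8_gemm(x, weight_int8, scales)

# Reference: materialize the dequantized weight, then matmul
weight_dequant = (weight_int8.float() * scales.unsqueeze(1)).half()
y_ref = x @ weight_dequant.T

print(torch.allclose(y_kernel, y_ref, atol=1e-2, rtol=1e-2))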

Performance Characteristics

| Configuration | FP16 GEMM | INT8 GEMM | Speedup | Memory Saved |
|---------------|-----------|-----------|---------|--------------|
| 4096x4096     | 156 μs    | 48 μs     | 3.3x    | 16 MB        |
| 8192x8192     | 620 μs    | 189 μs    | 3.3x    | 64 MB        |
| 4096x11008    | 418 μs    | 128 μs    | 3.3x    | 43 MB        |
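
These timings are indicative; they vary with GPU model, batch shape, and driver. A minimal harness along the following lines (our own sketch, using CUDA events with warmup; shapes are illustrative) can reproduce the comparison on your hardware.

import torch
from rotalabs_accel import int8_gemm, quantize_weight_per_channel

N, K, M = 4096, 4096, 512  # weight is (N, K), M rows of activations

weight_fp16 = torch.randn(N, K, dtype=torch.float16)
weight_int8, scales = quantize_weight_per_channel(weight_fp16)
weight_fp16 = weight_fp16.cuda()
weight_int8, scales = weight_int8.cuda(), scales.cuda()
x = torch.randn(M, K, device="cuda", dtype=torch.float16)

def bench_us(fn, iters=100):
    # Warm up, then time the average call with CUDA events
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000 / iters  # ms -> μs per call

print(f"FP16 GEMM: {bench_us(lambda: x @ weight_fp16.t()):.1f} us")
print(f"INT8 GEMM: {bench_us(lambda: int8_gemm(x, weight_int8, scales)):.1f} us")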

Usage Examples

High-Level: QuantizedLinear

The easiest way to use INT8 inference:

import torch
from rotalabs_accel import QuantizedLinear

# Convert existing pretrained layer
linear = torch.nn.Linear(4096, 4096)
linear.load_state_dict(pretrained_weights)

# Quantize to INT8
qlinear = QuantizedLinear.from_linear(linear)
qlinear = qlinear.cuda()

# Use like normal
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = qlinear(x)  # Output is FP16

Low-Level: Int8Linear

For more control over quantization:

from rotalabs_accel import Int8Linear

# Create layer
linear = Int8Linear(
    in_features=4096,
    out_features=4096,
    bias=False,
)

# Quantize weights manually
linear.quantize_weights(pretrained_weight_fp16)

# Forward pass
y = linear(x)

Functional API

For custom implementations:

from rotalabs_accel import int8_gemm, quantize_weight_per_channel

# Quantize weights once
weight_int8, scales = quantize_weight_per_channel(weight_fp16)

# Use in forward pass
output = int8_gemm(x, weight_int8, scales)

Quantizing an Entire Model

import torch
from transformers import AutoModelForCausalLM

from rotalabs_accel import QuantizedLinear

def quantize_model(model):
    """Replace all Linear layers with QuantizedLinear."""
    for name, module in model.named_children():
        if isinstance(module, torch.nn.Linear):
            quantized = QuantizedLinear.from_linear(module)
            setattr(model, name, quantized)
        else:
            quantize_model(module)
    return model

# Quantize LLM
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
model = quantize_model(model)
model = model.cuda()

# Memory usage: 14GB -> 7GB (approximately)
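
To check the savings on your own model, compare the total size of parameters and buffers before and after conversion. The helper below is our own sketch, not part of the library; counting both parameters and buffers covers the INT8 weights regardless of how the quantized layers register them (Int8Linear, documented below, stores them as buffers).

import torch

def model_bytes(model: torch.nn.Module) -> int:
    # Sum the bytes held by all parameters and registered buffers
    params = sum(p.numel() * p.element_size() for p in model.parameters())
    buffers = sum(b.numel() * b.element_size() for b in model.buffers())
    return params + buffers

print(f"model footprint: {model_bytes(model) / 1e9:.2f} GB")

Call it before and after quantize_model(model) to see the reduction.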

API Reference

Functions

int8_gemm

int8_gemm(x: Tensor, weight_int8: Tensor, scale: Tensor, bias: Optional[Tensor] = None, weight_transposed: Optional[Tensor] = None, scale_fp16: Optional[Tensor] = None, use_cublas: Optional[bool] = None) -> torch.Tensor

W8A16 GEMM: FP16 activations x INT8 weights.

Computes: y = x @ (weight_int8 * scale) + bias

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| x | Tensor | FP16 activation tensor of shape (..., K). | required |
| weight_int8 | Tensor | INT8 weight tensor of shape (N, K). | required |
| scale | Tensor | FP32 scale tensor of shape (N,) for per-output-channel dequant. | required |
| bias | Optional[Tensor] | Optional FP16 bias of shape (N,). | None |
| weight_transposed | Optional[Tensor] | Optional pre-transposed weight (K, N). | None |
| scale_fp16 | Optional[Tensor] | Optional pre-converted FP16 scale (unused in Triton path). | None |
| use_cublas | Optional[bool] | Unused, kept for API compatibility. | None |

Returns:

| Type | Description |
|------|-------------|
| Tensor | FP16 output tensor of shape (..., N). |

Example

x = torch.randn(2, 8, 64)
weight_int8 = torch.randint(-128, 127, (128, 64), dtype=torch.int8)
scale = torch.ones(128)
y = int8_gemm(x, weight_int8, scale)

Source code in src/rotalabs_accel/kernels/gemm.py
def int8_gemm(
    x: torch.Tensor,
    weight_int8: torch.Tensor,
    scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
    weight_transposed: Optional[torch.Tensor] = None,
    scale_fp16: Optional[torch.Tensor] = None,
    use_cublas: Optional[bool] = None,
) -> torch.Tensor:
    """
    W8A16 GEMM: FP16 activations x INT8 weights.

    Computes: y = x @ (weight_int8 * scale) + bias

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        x: FP16 activation tensor of shape (..., K).
        weight_int8: INT8 weight tensor of shape (N, K).
        scale: FP32 scale tensor of shape (N,) for per-output-channel dequant.
        bias: Optional FP16 bias of shape (N,).
        weight_transposed: Optional pre-transposed weight (K, N).
        scale_fp16: Optional pre-converted FP16 scale (unused in Triton path).
        use_cublas: Unused, kept for API compatibility.

    Returns:
        FP16 output tensor of shape (..., N).

    Example:
        >>> x = torch.randn(2, 8, 64)
        >>> weight_int8 = torch.randint(-128, 127, (128, 64), dtype=torch.int8)
        >>> scale = torch.ones(128)
        >>> y = int8_gemm(x, weight_int8, scale)
    """
    assert weight_int8.dtype == torch.int8, "Weight must be INT8"

    # Get dimensions
    original_shape = x.shape
    K = x.shape[-1]
    N = weight_int8.shape[0]

    assert weight_int8.shape == (N, K), f"Weight shape mismatch: {weight_int8.shape} vs ({N}, {K})"
    assert scale.shape == (N,), f"Scale shape mismatch: {scale.shape} vs ({N},)"

    # Use Triton kernel if available and on CUDA
    if HAS_TRITON and x.is_cuda and weight_int8.is_cuda and scale.is_cuda:
        return _int8_gemm_triton(x, weight_int8, scale, bias, weight_transposed)

    # Fallback to PyTorch
    return int8_gemm_torch(x, weight_int8, scale, bias)

int8_gemm_torch

int8_gemm_torch(x: Tensor, weight_int8: Tensor, scale: Tensor, bias: Optional[Tensor] = None) -> torch.Tensor

PyTorch reference implementation of W8A16 GEMM.

Works on any device (CPU or CUDA).

Source code in src/rotalabs_accel/kernels/gemm.py
def int8_gemm_torch(
    x: torch.Tensor,
    weight_int8: torch.Tensor,
    scale: torch.Tensor,
    bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    PyTorch reference implementation of W8A16 GEMM.

    Works on any device (CPU or CUDA).
    """
    # Dequantize weight to same dtype as input
    weight_dequant = dequantize(weight_int8, scale, dtype=x.dtype, dim=0)

    # Standard matmul
    y = torch.nn.functional.linear(x, weight_dequant, bias)

    return y

Classes

Int8Linear

Bases: Module

Linear layer using INT8 weights with optimized GEMM kernel.

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| in_features | int | Input dimension (K). | required |
| out_features | int | Output dimension (N). | required |
| bias | bool | Whether to include bias. | False |

Example

linear = Int8Linear(64, 128)
linear.quantize_weights(torch.randn(128, 64))
y = linear(torch.randn(2, 8, 64))

Source code in src/rotalabs_accel/kernels/gemm.py
class Int8Linear(torch.nn.Module):
    """
    Linear layer using INT8 weights with optimized GEMM kernel.

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        in_features: Input dimension (K).
        out_features: Output dimension (N).
        bias: Whether to include bias.

    Example:
        >>> linear = Int8Linear(64, 128)
        >>> linear.quantize_weights(torch.randn(128, 64))
        >>> y = linear(torch.randn(2, 8, 64))
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # INT8 weights: (out_features, in_features)
        self.register_buffer(
            "weight_int8",
            torch.zeros(out_features, in_features, dtype=torch.int8)
        )
        self.register_buffer(
            "scale",
            torch.ones(out_features, dtype=torch.float32)
        )
        # Pre-transposed weight for Triton path
        self.register_buffer(
            "weight_transposed",
            torch.zeros(in_features, out_features, dtype=torch.int8)
        )

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(out_features, dtype=torch.float16))
        else:
            self.register_parameter("bias", None)

        self._quantized = False

    def quantize_weights(self, weight: torch.Tensor) -> None:
        """Quantize and store FP16/FP32 weights as INT8."""
        assert weight.shape == (self.out_features, self.in_features)

        # Quantization happens on CPU
        weight_cpu = weight.cpu() if weight.is_cuda else weight
        weight_int8, scale = quantize_weight_per_channel(weight_cpu)

        # Move to same device as buffers and copy
        device = self.weight_int8.device
        self.weight_int8.copy_(weight_int8.to(device))
        self.scale.copy_(scale.to(device))
        self.weight_transposed.copy_(weight_int8.t().contiguous().to(device))
        self._quantized = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass using optimized INT8 GEMM kernel."""
        if not self._quantized:
            raise RuntimeError("Weights not quantized. Call quantize_weights() first.")

        return int8_gemm(
            x, self.weight_int8, self.scale, self.bias,
            weight_transposed=self.weight_transposed,
        )

    @classmethod
    def from_linear(cls, linear: torch.nn.Linear) -> "Int8Linear":
        """Convert nn.Linear to Int8Linear."""
        has_bias = linear.bias is not None
        int8_linear = cls(linear.in_features, linear.out_features, bias=has_bias)

        # Move to same device as input linear, then quantize
        device = linear.weight.device
        int8_linear = int8_linear.to(device)
        int8_linear.quantize_weights(linear.weight.data)

        # Copy bias if present
        if has_bias:
            int8_linear.bias.data.copy_(linear.bias.data.half())

        return int8_linear

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, quantized={self._quantized}"
        )

__init__

__init__(in_features: int, out_features: int, bias: bool = False)

Source code in src/rotalabs_accel/kernels/gemm.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    bias: bool = False,
):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features

    # INT8 weights: (out_features, in_features)
    self.register_buffer(
        "weight_int8",
        torch.zeros(out_features, in_features, dtype=torch.int8)
    )
    self.register_buffer(
        "scale",
        torch.ones(out_features, dtype=torch.float32)
    )
    # Pre-transposed weight for Triton path
    self.register_buffer(
        "weight_transposed",
        torch.zeros(in_features, out_features, dtype=torch.int8)
    )

    if bias:
        self.bias = torch.nn.Parameter(torch.zeros(out_features, dtype=torch.float16))
    else:
        self.register_parameter("bias", None)

    self._quantized = False

quantize_weights
quantize_weights(weight: Tensor) -> None

Quantize and store FP16/FP32 weights as INT8.

Source code in src/rotalabs_accel/kernels/gemm.py
def quantize_weights(self, weight: torch.Tensor) -> None:
    """Quantize and store FP16/FP32 weights as INT8."""
    assert weight.shape == (self.out_features, self.in_features)

    # Quantization happens on CPU
    weight_cpu = weight.cpu() if weight.is_cuda else weight
    weight_int8, scale = quantize_weight_per_channel(weight_cpu)

    # Move to same device as buffers and copy
    device = self.weight_int8.device
    self.weight_int8.copy_(weight_int8.to(device))
    self.scale.copy_(scale.to(device))
    self.weight_transposed.copy_(weight_int8.t().contiguous().to(device))
    self._quantized = True

forward
forward(x: Tensor) -> torch.Tensor

Forward pass using optimized INT8 GEMM kernel.

Source code in src/rotalabs_accel/kernels/gemm.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass using optimized INT8 GEMM kernel."""
    if not self._quantized:
        raise RuntimeError("Weights not quantized. Call quantize_weights() first.")

    return int8_gemm(
        x, self.weight_int8, self.scale, self.bias,
        weight_transposed=self.weight_transposed,
    )

Quantization Details

Symmetric Quantization

We use symmetric quantization with per-channel scales:

scale[i] = max(|weight[i, :]|) / 127
weight_int8[i, :] = round(weight[i, :] / scale[i]).clamp(-128, 127)

This provides a good balance between accuracy and performance.
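
The two lines above map directly onto PyTorch. The sketch below illustrates the scheme; it is not necessarily the exact implementation behind quantize_weight_per_channel (the epsilon guard against all-zero rows is our own addition).

import torch

def quantize_per_channel(weight: torch.Tensor):
    # One symmetric scale per output row (dim 0)
    scale = weight.float().abs().amax(dim=1) / 127.0
    scale = scale.clamp(min=1e-8)  # guard against all-zero rows
    weight_int8 = torch.round(weight.float() / scale.unsqueeze(1)).clamp(-128, 127).to(torch.int8)
    return weight_int8, scale

weight = torch.randn(128, 64)
weight_int8, scale = quantize_per_channel(weight)
weight_dequant = weight_int8.float() * scale.unsqueeze(1)
print((weight - weight_dequant).abs().max())  # small round-off error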

Per-Channel vs Per-Tensor

| Method | Accuracy | Memory | Speed |
|--------|----------|--------|-------|
| Per-tensor | Lower | Minimal overhead | Fastest |
| Per-channel | Higher | 1 scale per output | Slightly slower |

We default to per-channel (per-output-row) quantization for better accuracy.

Quantization Error

Typical quantization errors for random weights:

| Metric | Value |
|--------|-------|
| Max absolute error | ~0.01 |
| Mean absolute error | ~0.002 |
| SNR | ~45-50 dB |

For real model weights, errors may vary based on weight distribution.
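
These figures can be reproduced for any weight matrix with a short measurement. The sketch below uses the functional API from above and the standard definition of SNR as 10*log10 of signal power over error power.

import torch
from rotalabs_accel import quantize_weight_per_channel

weight = torch.randn(4096, 4096, dtype=torch.float16)
weight_int8, scale = quantize_weight_per_channel(weight)

# Reconstruct the FP weights and compare against the originals
weight_dequant = weight_int8.float() * scale.unsqueeze(1)
err = weight.float() - weight_dequant

snr_db = 10 * torch.log10(weight.float().pow(2).mean() / err.pow(2).mean())
print(f"max |err|  = {err.abs().max().item():.4f}")
print(f"mean |err| = {err.abs().mean().item():.4f}")
print(f"SNR        = {snr_db.item():.1f} dB")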

Implementation Notes

Memory Layout

  • Weights: INT8, shape (out_features, in_features)
  • Scales: FP32, shape (out_features,)
  • Activations: FP16, shape (..., in_features)
  • Output: FP16, shape (..., out_features)

Kernel Strategy

The Triton kernel:

  1. Loads weight tiles as INT8 (1 byte per element)
  2. Dequantizes in registers using the scale vector
  3. Performs FP16 matmul with FP32 accumulation
  4. Stores FP16 output

This minimizes memory traffic while maintaining numerical precision.
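
For reference, the overall shape of such a kernel in Triton looks roughly like the sketch below. This is a simplified illustration of the four steps above, not the library's actual kernel (which also handles pre-transposed weights, bias, and autotuned block sizes); strides and block sizes are left as parameters.

import triton
import triton.language as tl

@triton.jit
def w8a16_kernel(
    x_ptr, w_ptr, scale_ptr, y_ptr,
    M, N, K,
    stride_xm, stride_xk, stride_wn, stride_wk, stride_ym, stride_yn,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    # Per-output-channel scales for this tile of N
    scale = tl.load(scale_ptr + offs_n, mask=offs_n < N, other=1.0)

    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k0 in range(0, K, BLOCK_K):
        k = k0 + offs_k
        # 1. Load FP16 activations and INT8 weights (1 byte per element)
        x = tl.load(x_ptr + offs_m[:, None] * stride_xm + k[None, :] * stride_xk,
                    mask=(offs_m[:, None] < M) & (k[None, :] < K), other=0.0)
        w = tl.load(w_ptr + offs_n[None, :] * stride_wn + k[:, None] * stride_wk,
                    mask=(offs_n[None, :] < N) & (k[:, None] < K), other=0)
        # 2. Dequantize in registers using the scale vector
        w = w.to(tl.float16) * scale[None, :].to(tl.float16)
        # 3. FP16 matmul with FP32 accumulation
        acc += tl.dot(x, w)

    # 4. Store the FP16 output tile
    y_ptrs = y_ptr + offs_m[:, None] * stride_ym + offs_n[None, :] * stride_yn
    tl.store(y_ptrs, acc.to(tl.float16),
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))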

Fallback Behavior

On CPU, or when Triton is not available, int8_gemm falls back to a plain PyTorch implementation equivalent to:

def int8_gemm_torch(x, weight_int8, scale, bias=None):
    # Dequantize per output channel, then run a standard FP matmul
    weight_fp = (weight_int8.float() * scale.unsqueeze(1)).to(x.dtype)
    return torch.nn.functional.linear(x, weight_fp, bias)

Comparison with Other Quantization Methods

| Method | Bits | Scheme | Accuracy | Speed |
|--------|------|--------|----------|-------|
| INT8 (this) | 8 | Symmetric | High | Fast |
| GPTQ | 4 | Asymmetric + groups | Medium-High | Moderate |
| AWQ | 4 | Activation-aware | High | Moderate |
| FP8 (Hopper) | 8 | Native hardware | Very High | Very Fast |

INT8 symmetric quantization is a good default choice that works on all GPUs and provides a solid accuracy/speed tradeoff.
