
Normalization

RMSNorm (Root Mean Square Layer Normalization) kernels with optional residual fusion.

Overview

RMSNorm is a simpler alternative to LayerNorm used in modern LLMs like LLaMA, Mistral, and Qwen. Unlike LayerNorm, it doesn't center the input (no mean subtraction), which reduces computation.

Mathematical formula:

\[ \text{RMSNorm}(x) = \frac{x}{\sqrt{\frac{1}{n}\sum_{i=1}^{n} x_i^2 + \epsilon}} \cdot \gamma \]

Where:

  • \(x\) is the input tensor
  • \(n\) is the hidden dimension
  • \(\epsilon\) is a small constant for numerical stability (typically 1e-6)
  • \(\gamma\) is the learnable weight parameter
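
The formula maps directly onto a handful of tensor operations. A minimal illustration in plain PyTorch (rmsnorm_reference is a throwaway name for this sketch, not part of the library API):

import torch

def rmsnorm_reference(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Mean of squares over the hidden dimension; no mean subtraction, unlike LayerNorm.
    mean_sq = x.pow(2).mean(dim=-1, keepdim=True)
    # Scale by the reciprocal root mean square, then apply the learnable weight.
    return x * torch.rsqrt(mean_sq + eps) * gamma

x = torch.randn(2, 8, 64)
gamma = torch.ones(64)
y = rmsnorm_reference(x, gamma)  # same shape as x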

Performance Characteristics

RMSNorm is a memory-bound operation with low arithmetic intensity (~1-2 FLOPs/byte). The Triton kernel provides speedups by:

  1. Fusing operations: Combines variance computation, normalization, and scaling in a single kernel
  2. Reducing memory traffic: Reads input once, writes output once
  3. Using FP32 accumulation: Ensures numerical stability while keeping I/O in FP16

Configuration            PyTorch   Triton   Speedup
hidden=4096, seq=2048    45 μs     12 μs    3.8x
hidden=8192, seq=2048    89 μs     24 μs    3.7x
hidden=4096, seq=8192    178 μs    47 μs    3.8x
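
The timings above are indicative and will vary with GPU model and clock state. A rough reproduction sketch (assumes a CUDA device; rmsnorm_torch is imported from the rotalabs_accel.kernels.normalization module listed in the API reference below):

import torch
from rotalabs_accel import rmsnorm
from rotalabs_accel.kernels.normalization import rmsnorm_torch

x = torch.randn(2048, 4096, device="cuda", dtype=torch.float16)
w = torch.ones(4096, device="cuda", dtype=torch.float16)

def bench_us(fn, iters=100):
    # Warm up, then time with CUDA events; returns microseconds per call.
    for _ in range(10):
        fn()
    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)
    start.record()
    for _ in range(iters):
        fn()
    end.record()
    torch.cuda.synchronize()
    return start.elapsed_time(end) * 1000 / iters

print("triton :", bench_us(lambda: rmsnorm(x, w)), "us")
print("pytorch:", bench_us(lambda: rmsnorm_torch(x, w)), "us")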

Usage Examples

Basic RMSNorm

import torch
from rotalabs_accel import TritonRMSNorm, rmsnorm

# Module API (recommended for models)
norm = TritonRMSNorm(hidden_size=4096, eps=1e-6)
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = norm(x)

# Functional API (for custom implementations)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)
y = rmsnorm(x, weight, eps=1e-6)

Fused Residual + RMSNorm

In transformer blocks, RMSNorm is typically applied after adding a residual:

from rotalabs_accel import rmsnorm_residual_fused

# Standard pattern (two operations):
# x = x + residual
# x = rmsnorm(x, weight)

# Fused version (one operation, ~2x faster):
x = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)

The fused version eliminates an intermediate tensor allocation and memory round-trip.
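
To see what the fusion replaces, the fused call should match the explicit add-then-normalize path up to floating-point tolerance. A quick check (sketch, assumes a CUDA device):

import torch
from rotalabs_accel import rmsnorm, rmsnorm_residual_fused

x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)

# Unfused reference: materialize the sum, then normalize.
ref = rmsnorm(x + residual, weight, eps=1e-6)

# Fused path: same math, without writing the intermediate sum to memory.
fused = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)

torch.testing.assert_close(fused, ref, rtol=1e-2, atol=1e-2)  # FP16 tolerances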


API Reference

Functions

rmsnorm

rmsnorm(x: Tensor, weight: Tensor, eps: float = 1e-06) -> torch.Tensor

Apply RMS normalization to input tensor.

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

Name     Type     Description                                          Default
x        Tensor   Input tensor of shape (..., hidden_dim).             required
weight   Tensor   Learnable scale parameter of shape (hidden_dim,).    required
eps      float    Small constant for numerical stability.              1e-06

Returns:

Type     Description
Tensor   Normalized tensor of same shape as input.

Example

x = torch.randn(2, 8, 64)
weight = torch.ones(64)
y = rmsnorm(x, weight, eps=1e-6)

Source code in src/rotalabs_accel/kernels/normalization.py
def rmsnorm(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Apply RMS normalization to input tensor.

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        x: Input tensor of shape (..., hidden_dim).
        weight: Learnable scale parameter of shape (hidden_dim,).
        eps: Small constant for numerical stability.

    Returns:
        Normalized tensor of same shape as input.

    Example:
        >>> x = torch.randn(2, 8, 64)
        >>> weight = torch.ones(64)
        >>> y = rmsnorm(x, weight, eps=1e-6)
    """
    assert x.shape[-1] == weight.shape[0], f"Hidden dim mismatch: {x.shape[-1]} vs {weight.shape[0]}"

    # Use Triton kernel if available and on CUDA
    if HAS_TRITON and x.is_cuda and weight.is_cuda:
        return _rmsnorm_triton(x, weight, eps)

    # Fallback to PyTorch
    return rmsnorm_torch(x, weight, eps)

rmsnorm_torch

rmsnorm_torch(x: Tensor, weight: Tensor, eps: float = 1e-06) -> torch.Tensor

PyTorch reference implementation of RMSNorm.

Works on any device (CPU or CUDA).

Source code in src/rotalabs_accel/kernels/normalization.py
def rmsnorm_torch(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    PyTorch reference implementation of RMSNorm.

    Works on any device (CPU or CUDA).
    """
    # Compute in FP32 for numerical stability
    x_fp32 = x.float()
    rms = torch.rsqrt(x_fp32.pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x_fp32 * rms).to(x.dtype) * weight
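
Because it is pure PyTorch, rmsnorm_torch also runs on CPU tensors, which makes it convenient for GPU-free unit tests (illustrative snippet; the import path matches the "Source code" location above):

import torch
from rotalabs_accel.kernels.normalization import rmsnorm_torch

# Works on any device; the FP32 accumulation happens inside even for BF16/FP16 inputs.
x = torch.randn(2, 8, 64, dtype=torch.bfloat16)
weight = torch.ones(64, dtype=torch.bfloat16)
y = rmsnorm_torch(x, weight, eps=1e-6)
print(y.shape, y.dtype)  # torch.Size([2, 8, 64]) torch.bfloat16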

rmsnorm_residual_fused

rmsnorm_residual_fused(x: Tensor, residual: Tensor, weight: Tensor, eps: float = 1e-06) -> torch.Tensor

Fused RMSNorm with residual addition.

Computes: y = RMSNorm(x + residual) * weight

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

Name       Type     Description                                          Default
x          Tensor   Input tensor of shape (..., hidden_dim).             required
residual   Tensor   Residual tensor of same shape as x.                  required
weight     Tensor   Learnable scale parameter of shape (hidden_dim,).    required
eps        float    Small constant for numerical stability.              1e-06

Returns:

Type     Description
Tensor   Normalized tensor of same shape as input.

Example

x = torch.randn(2, 8, 64)
residual = torch.randn_like(x)
weight = torch.ones(64)
y = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)

Source code in src/rotalabs_accel/kernels/normalization.py
def rmsnorm_residual_fused(
    x: torch.Tensor,
    residual: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
) -> torch.Tensor:
    """
    Fused RMSNorm with residual addition.

    Computes: y = RMSNorm(x + residual) * weight

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        x: Input tensor of shape (..., hidden_dim).
        residual: Residual tensor of same shape as x.
        weight: Learnable scale parameter of shape (hidden_dim,).
        eps: Small constant for numerical stability.

    Returns:
        Normalized tensor of same shape as input.

    Example:
        >>> x = torch.randn(2, 8, 64)
        >>> residual = torch.randn_like(x)
        >>> weight = torch.ones(64)
        >>> y = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)
    """
    assert x.shape == residual.shape, f"Shape mismatch: x={x.shape}, residual={residual.shape}"
    assert x.shape[-1] == weight.shape[0], f"Hidden dim mismatch: {x.shape[-1]} vs {weight.shape[0]}"

    # Use Triton kernel if available and on CUDA
    if HAS_TRITON and x.is_cuda and residual.is_cuda and weight.is_cuda:
        return _rmsnorm_residual_triton(x, residual, weight, eps)

    # Fallback to PyTorch
    return rmsnorm_residual_torch(x, residual, weight, eps)
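
The PyTorch fallback rmsnorm_residual_torch is not reproduced on this page. Given the documented contract (y = RMSNorm(x + residual) * weight), a minimal equivalent would be the following sketch; the library's actual fallback may differ in detail:

def rmsnorm_residual_torch_sketch(x, residual, weight, eps=1e-6):
    # Hypothetical stand-in for illustration only: add the residual,
    # then reuse the FP32-accumulated reference implementation.
    return rmsnorm_torch(x + residual, weight, eps)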

Classes

TritonRMSNorm

Bases: Module

RMSNorm layer with automatic Triton/PyTorch dispatch.

Drop-in replacement for torch.nn.RMSNorm with identical interface. Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

Name          Type    Description                                      Default
hidden_size   int     Size of the last dimension to normalize over.    required
eps           float   Small constant for numerical stability.          1e-06

Example

norm = TritonRMSNorm(64)
x = torch.randn(2, 8, 64)
y = norm(x)

Source code in src/rotalabs_accel/kernels/normalization.py
class TritonRMSNorm(torch.nn.Module):
    """
    RMSNorm layer with automatic Triton/PyTorch dispatch.

    Drop-in replacement for torch.nn.RMSNorm with identical interface.
    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        hidden_size: Size of the last dimension to normalize over.
        eps: Small constant for numerical stability.

    Example:
        >>> norm = TritonRMSNorm(64)
        >>> x = torch.randn(2, 8, 64)
        >>> y = norm(x)
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return rmsnorm(x, self.weight, self.eps)

    def extra_repr(self) -> str:
        return f"{self.hidden_size}, eps={self.eps}"

__init__

__init__(hidden_size: int, eps: float = 1e-06)

Source code in src/rotalabs_accel/kernels/normalization.py
def __init__(self, hidden_size: int, eps: float = 1e-6):
    super().__init__()
    self.hidden_size = hidden_size
    self.eps = eps
    self.weight = torch.nn.Parameter(torch.ones(hidden_size))

forward

forward(x: Tensor) -> torch.Tensor

Source code in src/rotalabs_accel/kernels/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return rmsnorm(x, self.weight, self.eps)
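
Because the module stores its scale as a regular nn.Parameter, it drops into a model like any other layer. A minimal illustration (the surrounding Block module is a placeholder, not taken from any particular model):

import torch
from rotalabs_accel import TritonRMSNorm

class Block(torch.nn.Module):
    def __init__(self, hidden_size: int = 4096):
        super().__init__()
        self.norm = TritonRMSNorm(hidden_size, eps=1e-6)
        self.proj = torch.nn.Linear(hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize, then project; rmsnorm dispatches to the Triton kernel on CUDA.
        return self.proj(self.norm(x))

block = Block().cuda().half()
y = block(torch.randn(2, 16, 4096, device="cuda", dtype=torch.float16))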

Implementation Notes

Automatic Dispatch

The rmsnorm function automatically selects the best implementation:

def rmsnorm(x, weight, eps=1e-6):
    if HAS_TRITON and x.is_cuda and weight.is_cuda:
        return _rmsnorm_triton(x, weight, eps)
    return rmsnorm_torch(x, weight, eps)

Numerical Stability

All implementations use FP32 accumulation for the variance computation, even when inputs are FP16/BF16. This prevents numerical issues with large hidden dimensions.
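
The hazard is easy to reproduce: squaring even moderately large activations overflows FP16, while the same computation in FP32 is fine.

import torch

# 300.0 is representable in FP16, but 300.0**2 = 90000 exceeds the FP16 maximum (~65504).
x = torch.full((4096,), 300.0, dtype=torch.float16)

print((x * x).float().mean())   # inf: the squares already overflowed in FP16
print((x.float() ** 2).mean())  # 90000.0: squaring in FP32 avoids the overflow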

Block Size Selection

The Triton kernel automatically selects block sizes based on the hidden dimension:

BLOCK_SIZE = min(triton.next_power_of_2(hidden_dim), 8192)
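
For reference, this is what that heuristic yields for a few common hidden dimensions (requires triton to be installed):

import triton

for hidden_dim in (2048, 4096, 5120, 8192, 16384):
    block = min(triton.next_power_of_2(hidden_dim), 8192)
    print(hidden_dim, "->", block)
# 2048 -> 2048, 4096 -> 4096, 5120 -> 8192, 8192 -> 8192, 16384 -> 8192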
