Normalization¶
RMSNorm (Root Mean Square Layer Normalization) kernels with optional residual fusion.
Overview¶
RMSNorm is a simpler alternative to LayerNorm used in modern LLMs like LLaMA, Mistral, and Qwen. Unlike LayerNorm, it doesn't center the input (no mean subtraction), which reduces computation.
Mathematical formula:

\[
y_i = \frac{x_i}{\sqrt{\frac{1}{n} \sum_{j=1}^{n} x_j^2 + \epsilon}} \cdot \gamma_i
\]
Where:
- \(x\) is the input tensor
- \(n\) is the hidden dimension
- \(\epsilon\) is a small constant for numerical stability (typically 1e-6)
- \(\gamma\) is the learnable weight parameter
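As a quick numeric check, take \(x = (3, 4)\) with \(n = 2\), \(\gamma = (1, 1)\), and \(\epsilon \approx 0\):

\[
\mathrm{RMS}(x) = \sqrt{\tfrac{1}{2}\left(3^2 + 4^2\right)} = \sqrt{12.5} \approx 3.536,
\qquad
y = \frac{x}{\mathrm{RMS}(x)} \approx (0.849,\ 1.131)
\]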
Performance Characteristics¶
RMSNorm is a memory-bound operation with low arithmetic intensity (~1-2 FLOPs/byte). The Triton kernel provides speedups by:
- Fusing operations: Combines variance computation, normalization, and scaling in a single kernel
- Reducing memory traffic: Reads input once, writes output once
- Using FP32 accumulation: Ensures numerical stability while keeping I/O in FP16
| Configuration | PyTorch | Triton | Speedup |
|---|---|---|---|
| hidden=4096, seq=2048 | 45 μs | 12 μs | 3.8x |
| hidden=8192, seq=2048 | 89 μs | 24 μs | 3.7x |
| hidden=4096, seq=8192 | 178 μs | 47 μs | 3.8x |
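Timings like these can be reproduced with Triton's built-in benchmarking helper. A sketch (the shapes below mirror the first table row; the hardware behind the numbers above is not specified here):

```python
import torch
import triton.testing
from rotalabs_accel import rmsnorm

x = torch.randn(2048, 4096, device="cuda", dtype=torch.float16)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)

# do_bench times the callable over repeated runs and reports milliseconds
ms = triton.testing.do_bench(lambda: rmsnorm(x, weight, eps=1e-6))
print(f"rmsnorm: {ms * 1e3:.1f} us")
```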
Usage Examples¶
Basic RMSNorm¶
```python
import torch
from rotalabs_accel import TritonRMSNorm, rmsnorm

# Module API (recommended for models)
norm = TritonRMSNorm(hidden_size=4096, eps=1e-6)
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = norm(x)

# Functional API (for custom implementations)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)
y = rmsnorm(x, weight, eps=1e-6)
```
Fused Residual + RMSNorm¶
In transformer blocks, RMSNorm is typically applied after adding a residual:
```python
from rotalabs_accel import rmsnorm_residual_fused

# Standard pattern (two operations):
#   x = x + residual
#   x = rmsnorm(x, weight)

# Fused version (one operation, ~2x faster):
x = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)
```
The fused version eliminates an intermediate tensor allocation and memory round-trip.
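Because the fused call is a drop-in for the two-step pattern, equivalence is easy to check on real tensors (a quick sanity check, with FP16-appropriate tolerances):

```python
import torch
from rotalabs_accel import rmsnorm, rmsnorm_residual_fused

x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
residual = torch.randn_like(x)
weight = torch.ones(4096, device="cuda", dtype=torch.float16)

y_ref = rmsnorm(x + residual, weight, eps=1e-6)                  # two ops
y_fused = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)  # one kernel
torch.testing.assert_close(y_fused, y_ref, rtol=1e-2, atol=1e-2)
```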
API Reference¶
Functions¶
rmsnorm ¶
Apply RMS normalization to input tensor.
Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape `(..., hidden_dim)`. | *required* |
| `weight` | `Tensor` | Learnable scale parameter of shape `(hidden_dim,)`. | *required* |
| `eps` | `float` | Small constant for numerical stability. | `1e-06` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Normalized tensor of same shape as input. |
Example
```python
x = torch.randn(2, 8, 64)
weight = torch.ones(64)
y = rmsnorm(x, weight, eps=1e-6)
```
rmsnorm_torch ¶
PyTorch reference implementation of RMSNorm.
Works on any device (CPU or CUDA).
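The reference path simply follows the formula above. A minimal sketch of such an implementation (illustrative; `rmsnorm_reference` is a hypothetical name, not the library's source):

```python
import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor,
                      eps: float = 1e-6) -> torch.Tensor:
    # Accumulate the mean of squares in FP32 for stability (see
    # "Numerical Stability" below), then cast back to the input dtype.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return (x_normed * weight.float()).to(x.dtype)
```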
rmsnorm_residual_fused ¶
```python
rmsnorm_residual_fused(x: Tensor, residual: Tensor, weight: Tensor, eps: float = 1e-06) -> torch.Tensor
```
Fused RMSNorm with residual addition.
Computes: y = RMSNorm(x + residual) * weight
Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Tensor` | Input tensor of shape `(..., hidden_dim)`. | *required* |
| `residual` | `Tensor` | Residual tensor of same shape as `x`. | *required* |
| `weight` | `Tensor` | Learnable scale parameter of shape `(hidden_dim,)`. | *required* |
| `eps` | `float` | Small constant for numerical stability. | `1e-06` |
Returns:
| Type | Description |
|---|---|
| `Tensor` | Normalized tensor of same shape as input. |
Example
```python
x = torch.randn(2, 8, 64)
residual = torch.randn_like(x)
weight = torch.ones(64)
y = rmsnorm_residual_fused(x, residual, weight, eps=1e-6)
```
Classes¶
TritonRMSNorm ¶
Bases: Module
RMSNorm layer with automatic Triton/PyTorch dispatch.
Drop-in replacement for torch.nn.RMSNorm with identical interface. Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `hidden_size` | `int` | Size of the last dimension to normalize over. | *required* |
| `eps` | `float` | Small constant for numerical stability. | `1e-06` |
Example
```python
norm = TritonRMSNorm(64)
x = torch.randn(2, 8, 64)
y = norm(x)
```
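Because the interface mirrors torch.nn.RMSNorm, existing modules can be swapped in place. A hedged sketch (`swap_rmsnorm` is a hypothetical helper; it assumes PyTorch ≥ 2.4 for `nn.RMSNorm`, and that `TritonRMSNorm` exposes a `weight` parameter, as the identical-interface claim implies):

```python
import torch.nn as nn
from rotalabs_accel import TritonRMSNorm

def swap_rmsnorm(module: nn.Module) -> None:
    # Hypothetical helper: recursively replace nn.RMSNorm children with
    # TritonRMSNorm, copying the learned scale so outputs are unchanged.
    for name, child in module.named_children():
        if isinstance(child, nn.RMSNorm):
            new = TritonRMSNorm(
                hidden_size=child.normalized_shape[-1],
                eps=child.eps if child.eps is not None else 1e-6,
            )
            new.weight.data.copy_(child.weight.data)  # assumes a .weight param
            setattr(module, name, new)
        else:
            swap_rmsnorm(child)
```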
Implementation Notes¶
Automatic Dispatch¶
The rmsnorm function automatically selects the best implementation:
```python
def rmsnorm(x, weight, eps=1e-6):
    if HAS_TRITON and x.is_cuda and weight.is_cuda:
        return _rmsnorm_triton(x, weight, eps)
    return rmsnorm_torch(x, weight, eps)
```
Numerical Stability¶
All implementations use FP32 accumulation for the variance computation, even when inputs are FP16/BF16. This prevents numerical issues with large hidden dimensions.
Block Size Selection¶
The Triton kernel automatically selects its block size based on the hidden dimension; the exact policy lives in the kernel source.
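A common Triton pattern for row-wise kernels like this one, shown purely as an illustration of the idea (not the shipped heuristic):

```python
import triton

def pick_block_size(hidden_dim: int) -> int:
    # Round the row length up to the next power of two so that a single
    # program instance can reduce an entire row in one pass.
    return triton.next_power_of_2(hidden_dim)
```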
References¶
- Root Mean Square Layer Normalization - Original RMSNorm paper
- LLaMA: Open and Efficient Foundation Language Models - Uses RMSNorm