
Rotary Position Embeddings (RoPE)

Rotary Position Embeddings for encoding position information in attention layers.

Overview

RoPE encodes position information by rotating query and key vectors in 2D subspaces, so that attention scores depend only on relative positions. It is used in LLaMA, Mistral, Qwen, and most modern LLMs.

Mathematical formula:

For a vector \(x\) at position \(m\), the rotated vector is:

\[ \text{RoPE}(x, m) = \begin{pmatrix} x_0 \\ x_1 \\ x_2 \\ x_3 \\ \vdots \end{pmatrix} \odot \begin{pmatrix} \cos(m\theta_0) \\ \cos(m\theta_0) \\ \cos(m\theta_1) \\ \cos(m\theta_1) \\ \vdots \end{pmatrix} + \begin{pmatrix} -x_1 \\ x_0 \\ -x_3 \\ x_2 \\ \vdots \end{pmatrix} \odot \begin{pmatrix} \sin(m\theta_0) \\ \sin(m\theta_0) \\ \sin(m\theta_1) \\ \sin(m\theta_1) \\ \vdots \end{pmatrix} \]

where \(\theta_i = \frac{1}{\text{base}^{2i/d}}\), \(d\) is the head dimension, and the base is typically 10000.
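
Equivalently, each adjacent pair \((x_{2i}, x_{2i+1})\) is rotated by the angle \(m\theta_i\); this is the same element-wise formula written as a 2D rotation:

\[ \begin{pmatrix} x'_{2i} \\ x'_{2i+1} \end{pmatrix} = \begin{pmatrix} \cos(m\theta_i) & -\sin(m\theta_i) \\ \sin(m\theta_i) & \cos(m\theta_i) \end{pmatrix} \begin{pmatrix} x_{2i} \\ x_{2i+1} \end{pmatrix} \]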

Key Properties

  1. Relative position encoding: The dot product between rotated query and key vectors depends only on their relative position (see the sketch below this list)
  2. Long-range decay: Attention naturally decays with distance due to the rotation frequencies
  3. No learned parameters: Position encodings are computed, not learned
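
A minimal sketch of the first property using the functional API (CPU, small shapes for illustration):

import torch
from rotalabs_accel import build_rope_cache, apply_rope

# Rotate the same q and k at different absolute positions and compare dot products.
cos, sin = build_rope_cache(seq_len=64, head_dim=8)
q = torch.randn(1, 1, 1, 8)
k = torch.randn(1, 1, 1, 8)

def score(m, n):
    q_m, _ = apply_rope(q, q, cos[m:m+1], sin[m:m+1])
    _, k_n = apply_rope(k, k, cos[n:n+1], sin[n:n+1])
    return (q_m * k_n).sum().item()

# Same relative offset (m - n = 2) at different absolute positions gives the same score.
print(score(5, 3), score(10, 8))  # approximately equal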

Performance Characteristics

| Configuration | PyTorch | Triton | Speedup |
| --- | --- | --- | --- |
| head_dim=128, seq=2048 | 67 μs | 23 μs | 2.9x |
| head_dim=128, seq=8192 | 267 μs | 92 μs | 2.9x |
| head_dim=64, seq=2048 | 34 μs | 12 μs | 2.8x |

Usage Examples

import torch
from rotalabs_accel import RotaryEmbedding

# Create RoPE module
rope = RotaryEmbedding(
    dim=128,           # Head dimension
    max_seq_len=8192,  # Maximum sequence length
    base=10000.0,      # Frequency base (standard is 10000)
)

# Query and Key tensors
# Shape: [batch, seq_len, num_heads, head_dim]
q = torch.randn(2, 512, 32, 128, device="cuda", dtype=torch.float16)
k = torch.randn(2, 512, 32, 128, device="cuda", dtype=torch.float16)

# Apply RoPE (the sequence length is taken from the input shape)
q_rot, k_rot = rope(q, k)
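
In a full attention layer, the rotated tensors feed straight into the attention product. A sketch with torch.nn.functional.scaled_dot_product_attention; the value projection v is assumed here and is not part of this library:

import torch.nn.functional as F

# Value projection (assumed, same layout as q and k)
v = torch.randn(2, 512, 32, 128, device="cuda", dtype=torch.float16)

# scaled_dot_product_attention expects [batch, heads, seq, head_dim]
out = F.scaled_dot_product_attention(
    q_rot.transpose(1, 2).to(v.dtype),  # cast back in case the rotation promoted to float32
    k_rot.transpose(1, 2).to(v.dtype),
    v.transpose(1, 2),
    is_causal=True,
)
out = out.transpose(1, 2)  # back to [batch, seq, heads, head_dim]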

Functional API

from rotalabs_accel import build_rope_cache, apply_rope

# Build cache once (at model initialization)
cos, sin = build_rope_cache(
    seq_len=8192,
    head_dim=128,
    base=10000.0,
    device="cuda",
)

# Apply during forward pass
# Slice the cache to the current sequence length
seq_len = q.shape[1]
q_rot, k_rot = apply_rope(q, k, cos[:seq_len], sin[:seq_len])

With Grouped Query Attention (GQA)

RoPE works with different numbers of Q and K heads:

# LLaMA 3 style: 32 Q heads, 8 KV heads
q = torch.randn(2, 512, 32, 128, device="cuda")  # 32 heads
k = torch.randn(2, 512, 8, 128, device="cuda")   # 8 heads

# RoPE is applied independently per head, so differing Q/KV head counts are handled automatically
q_rot, k_rot = rope(q, k)
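
After RoPE, the KV heads are typically repeated (or broadcast) to match the query head count before the attention product; a sketch, not part of this library:

# 32 query heads share 8 KV heads in groups of 4
num_groups = q_rot.shape[2] // k_rot.shape[2]            # 32 // 8 = 4
k_expanded = k_rot.repeat_interleave(num_groups, dim=2)  # [2, 512, 32, 128]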

Position Offset (for KV Cache)

During generation with a KV cache, each new token must be rotated at its absolute position. Pass position_ids to supply the offset:

# First token: position 0 (the default when position_ids is omitted)
q1, k1 = rope(q[:, :1], k[:, :1])
cached_k = k1

# Next token: position 1
# position_ids has shape [batch, seq] and selects the absolute positions to rotate by
position_ids = torch.ones(q.shape[0], 1, dtype=torch.long, device=q.device)
q2, k2 = rope(q[:, 1:2], k[:, 1:2], position_ids=position_ids)
cached_k = torch.cat([cached_k, k2], dim=1)

API Reference

Functions

apply_rope

apply_rope(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor, use_triton: Optional[bool] = None) -> Tuple[torch.Tensor, torch.Tensor]

Apply Rotary Position Embeddings to query and key tensors.

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query tensor [batch, seq, heads, head_dim] | required |
| k | Tensor | Key tensor [batch, seq, heads, head_dim] | required |
| cos | Tensor | Cosine cache for positions | required |
| sin | Tensor | Sine cache for positions | required |
| use_triton | Optional[bool] | Force Triton (True) or PyTorch (False). None = auto. | None |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Tensor, Tensor] | Tuple of (q_rotated, k_rotated) with same shapes as inputs. |

Example

q = torch.randn(2, 16, 4, 32)
k = torch.randn(2, 16, 4, 32)
cos, sin = build_rope_cache(16, 32)
q_rot, k_rot = apply_rope(q, k, cos, sin)

Source code in src/rotalabs_accel/kernels/rope.py
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    use_triton: Optional[bool] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply Rotary Position Embeddings to query and key tensors.

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        q: Query tensor [batch, seq, heads, head_dim]
        k: Key tensor [batch, seq, heads, head_dim]
        cos: Cosine cache for positions
        sin: Sine cache for positions
        use_triton: Force Triton (True) or PyTorch (False). None = auto.

    Returns:
        Tuple of (q_rotated, k_rotated) with same shapes as inputs.

    Example:
        >>> q = torch.randn(2, 16, 4, 32)
        >>> k = torch.randn(2, 16, 4, 32)
        >>> cos, sin = build_rope_cache(16, 32)
        >>> q_rot, k_rot = apply_rope(q, k, cos, sin)
    """
    if use_triton is None:
        use_triton = HAS_TRITON and q.is_cuda

    if use_triton and HAS_TRITON:
        return _rope_triton(q, k, cos, sin)
    else:
        return rope_torch(q, k, cos, sin)

rope_torch

rope_torch(q: Tensor, k: Tensor, cos: Tensor, sin: Tensor) -> Tuple[torch.Tensor, torch.Tensor]

PyTorch reference implementation of RoPE.

Works on any device (CPU or CUDA).

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query tensor [batch, seq, heads, head_dim] | required |
| k | Tensor | Key tensor [batch, seq, heads, head_dim] | required |
| cos | Tensor | Cosine cache [seq, head_dim/2] or broadcastable | required |
| sin | Tensor | Sine cache [seq, head_dim/2] or broadcastable | required |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Tensor, Tensor] | Tuple of (q_rotated, k_rotated). |

Source code in src/rotalabs_accel/kernels/rope.py
def rope_torch(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    PyTorch reference implementation of RoPE.

    Works on any device (CPU or CUDA).

    Args:
        q: Query tensor [batch, seq, heads, head_dim]
        k: Key tensor [batch, seq, heads, head_dim]
        cos: Cosine cache [seq, head_dim/2] or broadcastable
        sin: Sine cache [seq, head_dim/2] or broadcastable

    Returns:
        Tuple of (q_rotated, k_rotated).
    """
    # Reshape for rotation: split last dim into pairs
    q_reshape = q.view(*q.shape[:-1], -1, 2)  # [..., head_dim/2, 2]
    k_reshape = k.view(*k.shape[:-1], -1, 2)

    # Expand cos/sin if needed
    if cos.dim() == 2:
        cos = cos.unsqueeze(0).unsqueeze(2)  # [1, seq, 1, head_dim/2]
        sin = sin.unsqueeze(0).unsqueeze(2)

    # Apply rotation
    q_rot = torch.stack([
        q_reshape[..., 0] * cos - q_reshape[..., 1] * sin,
        q_reshape[..., 0] * sin + q_reshape[..., 1] * cos,
    ], dim=-1).flatten(-2)

    k_rot = torch.stack([
        k_reshape[..., 0] * cos - k_reshape[..., 1] * sin,
        k_reshape[..., 0] * sin + k_reshape[..., 1] * cos,
    ], dim=-1).flatten(-2)

    return q_rot, k_rot

build_rope_cache

build_rope_cache(seq_len: int, head_dim: int, base: float = 10000.0, device: Optional[device] = None, dtype: dtype = torch.float32) -> Tuple[torch.Tensor, torch.Tensor]

Build cosine and sine caches for RoPE.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| seq_len | int | Maximum sequence length. | required |
| head_dim | int | Dimension of each attention head. | required |
| base | float | Base for the frequency computation. | 10000.0 |
| device | Optional[device] | Device for the tensors. | None |
| dtype | dtype | Data type for the tensors. | float32 |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Tensor, Tensor] | Tuple of (cos_cache, sin_cache), each of shape [seq_len, head_dim/2]. |

Example

cos, sin = build_rope_cache(2048, 128, device='cuda')
print(cos.shape)  # torch.Size([2048, 64])

Source code in src/rotalabs_accel/kernels/rope.py
def build_rope_cache(
    seq_len: int,
    head_dim: int,
    base: float = 10000.0,
    device: Optional[torch.device] = None,
    dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Build cosine and sine caches for RoPE.

    Args:
        seq_len: Maximum sequence length.
        head_dim: Dimension of each attention head.
        base: Base for the frequency computation (default: 10000).
        device: Device for the tensors.
        dtype: Data type for the tensors.

    Returns:
        Tuple of (cos_cache, sin_cache), each of shape [seq_len, head_dim/2].

    Example:
        >>> cos, sin = build_rope_cache(2048, 128, device='cuda')
        >>> print(cos.shape)  # torch.Size([2048, 64])
    """
    assert head_dim % 2 == 0, f"Head dim must be even, got {head_dim}"

    # Compute frequencies: theta_i = base^(-2i/d) for i in [0, d/2)
    half_dim = head_dim // 2
    freq_seq = torch.arange(half_dim, device=device, dtype=dtype)
    inv_freq = 1.0 / (base ** (freq_seq / half_dim))

    # Compute positions
    positions = torch.arange(seq_len, device=device, dtype=dtype)

    # Outer product: [seq_len, half_dim]
    angles = torch.outer(positions, inv_freq)

    # Compute cos and sin
    cos_cache = torch.cos(angles)
    sin_cache = torch.sin(angles)

    return cos_cache, sin_cache

Classes

RotaryEmbedding

Bases: Module

Rotary Position Embedding module.

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| dim | int | Dimension of each attention head (head_dim). | required |
| max_seq_len | int | Maximum sequence length to cache. | 2048 |
| base | float | Base for frequency computation. | 10000.0 |

Example

rope = RotaryEmbedding(dim=32, max_seq_len=128)
q = torch.randn(2, 16, 4, 32)
k = torch.randn(2, 16, 4, 32)
q_rot, k_rot = rope(q, k)

Source code in src/rotalabs_accel/kernels/rope.py
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary Position Embedding module.

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        dim: Dimension of each attention head (head_dim).
        max_seq_len: Maximum sequence length to cache.
        base: Base for frequency computation (default: 10000).

    Example:
        >>> rope = RotaryEmbedding(dim=32, max_seq_len=128)
        >>> q = torch.randn(2, 16, 4, 32)
        >>> k = torch.randn(2, 16, 4, 32)
        >>> q_rot, k_rot = rope(q, k)
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 2048,
        base: float = 10000.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Build initial cache
        cos, sin = build_rope_cache(max_seq_len, dim, base)
        self.register_buffer("cos_cache", cos, persistent=False)
        self.register_buffer("sin_cache", sin, persistent=False)

    def _extend_cache(self, seq_len: int) -> None:
        """Extend cache if sequence is longer than current cache."""
        if seq_len <= self.cos_cache.shape[0]:
            return

        new_len = max(seq_len, self.cos_cache.shape[0] * 2)
        cos, sin = build_rope_cache(
            new_len,
            self.dim,
            self.base,
            device=self.cos_cache.device,
            dtype=self.cos_cache.dtype,
        )
        self.cos_cache = cos
        self.sin_cache = sin

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to query and key tensors.

        Args:
            q: Query tensor [batch, seq, heads, head_dim]
            k: Key tensor [batch, seq, heads, head_dim]
            position_ids: Optional position indices [batch, seq].

        Returns:
            Tuple of (q_rotated, k_rotated).
        """
        seq_len = q.shape[1]
        self._extend_cache(seq_len)

        if position_ids is None:
            cos = self.cos_cache[:seq_len]
            sin = self.sin_cache[:seq_len]
        else:
            cos = self.cos_cache[position_ids]
            sin = self.sin_cache[position_ids]

        return apply_rope(q, k, cos, sin)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, max_seq_len={self.max_seq_len}, base={self.base}"

__init__

__init__(dim: int, max_seq_len: int = 2048, base: float = 10000.0)

Source code in src/rotalabs_accel/kernels/rope.py
def __init__(
    self,
    dim: int,
    max_seq_len: int = 2048,
    base: float = 10000.0,
):
    super().__init__()
    self.dim = dim
    self.max_seq_len = max_seq_len
    self.base = base

    # Build initial cache
    cos, sin = build_rope_cache(max_seq_len, dim, base)
    self.register_buffer("cos_cache", cos, persistent=False)
    self.register_buffer("sin_cache", sin, persistent=False)

forward

forward(q: Tensor, k: Tensor, position_ids: Optional[Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]

Apply RoPE to query and key tensors.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| q | Tensor | Query tensor [batch, seq, heads, head_dim] | required |
| k | Tensor | Key tensor [batch, seq, heads, head_dim] | required |
| position_ids | Optional[Tensor] | Optional position indices [batch, seq]. | None |

Returns:

| Type | Description |
| --- | --- |
| Tuple[Tensor, Tensor] | Tuple of (q_rotated, k_rotated). |

Source code in src/rotalabs_accel/kernels/rope.py
def forward(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply RoPE to query and key tensors.

    Args:
        q: Query tensor [batch, seq, heads, head_dim]
        k: Key tensor [batch, seq, heads, head_dim]
        position_ids: Optional position indices [batch, seq].

    Returns:
        Tuple of (q_rotated, k_rotated).
    """
    seq_len = q.shape[1]
    self._extend_cache(seq_len)

    if position_ids is None:
        cos = self.cos_cache[:seq_len]
        sin = self.sin_cache[:seq_len]
    else:
        cos = self.cos_cache[position_ids]
        sin = self.sin_cache[position_ids]

    return apply_rope(q, k, cos, sin)

Implementation Notes

Cache Precomputation

The cos/sin tables are computed once and reused:

def build_rope_cache(seq_len, head_dim, base=10000.0, device="cuda"):
    # Compute frequencies: theta_i = base^(-2i/d) for each pair of dimensions
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, device=device).float() / head_dim))

    # Compute position angles: outer product of positions and frequencies
    t = torch.arange(seq_len, device=device, dtype=torch.float32)
    freqs = torch.outer(t, inv_freq)

    # Cache cos and sin, each of shape [seq_len, head_dim/2]
    cos = torch.cos(freqs)
    sin = torch.sin(freqs)

    return cos, sin

Memory Layout

The rotation is applied to pairs of adjacent dimensions:

  • \((x_0, x_1)\) rotated by \(\theta_0\)
  • \((x_2, x_3)\) rotated by \(\theta_1\)
  • etc.

This "interleaved" layout matches LLaMA and most modern models. Some older models use "sequential" layout where first half and second half are paired.

Extended Context (YaRN, NTK)

For context lengths beyond the training length, a common approach is to increase the frequency base, which stretches the rotation wavelengths:

# Larger base for roughly 4x context extension
# (a simple approximation; see the NTK-aware sketch below)
base = 10000 * 4.0

rope = RotaryEmbedding(dim=128, max_seq_len=32768, base=base)
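
A sketch of the NTK-aware variant: the base is scaled by the extension factor raised to \(d/(d-2)\), so the lowest frequency is stretched by the full factor while the highest frequencies are barely changed. The helper ntk_scaled_base is hypothetical, not part of this library:

def ntk_scaled_base(base: float = 10000.0, scale: float = 4.0, head_dim: int = 128) -> float:
    # NTK-aware rule: base' = base * scale^(d / (d - 2))
    return base * scale ** (head_dim / (head_dim - 2))

rope = RotaryEmbedding(dim=128, max_seq_len=32768, base=ntk_scaled_base(scale=4.0, head_dim=128))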

References