Modules

Drop-in nn.Module replacements with optimized Triton kernels.

Overview

These modules provide the same interface as their PyTorch counterparts but use optimized Triton kernels when available. If Triton isn't installed or the input is on CPU, they automatically fall back to pure PyTorch.

import torch
from rotalabs_accel import TritonRMSNorm

x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)

# These work identically, but TritonRMSNorm is faster on GPU
norm_pytorch = torch.nn.RMSNorm(4096).to(x.device, x.dtype)
norm_triton = TritonRMSNorm(4096).to(x.device, x.dtype)

# Same API, same results, different speed
y1 = norm_pytorch(x)
y2 = norm_triton(x)  # Up to 3.8x faster on GPU
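
A quick equivalence check on CPU exercises the PyTorch fallback path. This is a minimal sketch: it assumes both modules keep their default all-ones weight and the same eps, and torch.nn.RMSNorm requires PyTorch 2.4+.

import torch
from rotalabs_accel import TritonRMSNorm

x = torch.randn(2, 16, 4096)                    # CPU tensor -> fallback path
y_ref = torch.nn.RMSNorm(4096, eps=1e-6)(x)
y_opt = TritonRMSNorm(4096, eps=1e-6)(x)
assert torch.allclose(y_ref, y_opt, atol=1e-5)  # same results either way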

Module Summary

Module           Replaces     Speedup (up to)  Use Case
TritonRMSNorm    nn.RMSNorm   3.8x             LLaMA, Mistral normalization
SwiGLU           Custom FFN   2.9x             LLaMA, PaLM FFN layers
RotaryEmbedding  Manual RoPE  2.9x             Position encoding
Int8Linear       nn.Linear    3.3x             Memory-efficient inference
QuantizedLinear  nn.Linear    3.3x             Easy model quantization

TritonRMSNorm

RMS normalization layer, used in LLaMA, Mistral, Qwen.

import torch
from rotalabs_accel import TritonRMSNorm

norm = TritonRMSNorm(hidden_size=4096, eps=1e-6).to("cuda", torch.float16)
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = norm(x)

Bases: Module

RMSNorm layer with automatic Triton/PyTorch dispatch.

Drop-in replacement for torch.nn.RMSNorm with identical interface. Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

Name         Type   Description                                    Default
hidden_size  int    Size of the last dimension to normalize over.  required
eps          float  Small constant for numerical stability.        1e-06
Example

norm = TritonRMSNorm(64)
x = torch.randn(2, 8, 64)
y = norm(x)

Source code in src/rotalabs_accel/kernels/normalization.py
class TritonRMSNorm(torch.nn.Module):
    """
    RMSNorm layer with automatic Triton/PyTorch dispatch.

    Drop-in replacement for torch.nn.RMSNorm with identical interface.
    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        hidden_size: Size of the last dimension to normalize over.
        eps: Small constant for numerical stability.

    Example:
        >>> norm = TritonRMSNorm(64)
        >>> x = torch.randn(2, 8, 64)
        >>> y = norm(x)
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6):
        super().__init__()
        self.hidden_size = hidden_size
        self.eps = eps
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return rmsnorm(x, self.weight, self.eps)

    def extra_repr(self) -> str:
        return f"{self.hidden_size}, eps={self.eps}"

__init__

__init__(hidden_size: int, eps: float = 1e-06)
Source code in src/rotalabs_accel/kernels/normalization.py
def __init__(self, hidden_size: int, eps: float = 1e-6):
    super().__init__()
    self.hidden_size = hidden_size
    self.eps = eps
    self.weight = torch.nn.Parameter(torch.ones(hidden_size))

forward

forward(x: Tensor) -> torch.Tensor
Source code in src/rotalabs_accel/kernels/normalization.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return rmsnorm(x, self.weight, self.eps)
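
For reference, the fallback path computes the standard RMS normalization used by LLaMA-style models. The sketch below is illustrative only; the library's rmsnorm function in normalization.py is the actual implementation and may differ in detail.

import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Normalize by the root mean square over the last dimension, then apply the learned scale.
    rms = torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x * rms * weight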

SwiGLU

Complete SwiGLU FFN module with gate, up, and down projections.

import torch
from rotalabs_accel import SwiGLU

ffn = SwiGLU(
    hidden_size=4096,
    intermediate_size=11008,
    bias=False,
).to("cuda", torch.float16)
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = ffn(x)  # Shape: (2, 512, 4096)

Bases: Module

SwiGLU module with linear projections.

Implements the full SwiGLU FFN:

y = (silu(x @ W_gate) * (x @ W_up)) @ W_down

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

Name               Type  Description                             Default
hidden_size        int   Input/output dimension.                 required
intermediate_size  int   Intermediate dimension for the FFN.     required
bias               bool  Whether to use bias in linear layers.   False
Example

swiglu = SwiGLU(hidden_size=64, intermediate_size=256)
x = torch.randn(2, 8, 64)
y = swiglu(x)  # Shape: (2, 8, 64)

Source code in src/rotalabs_accel/kernels/activations.py
class SwiGLU(torch.nn.Module):
    """
    SwiGLU module with linear projections.

    Implements the full SwiGLU FFN:
        y = (silu(x @ W_gate) * (x @ W_up)) @ W_down

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        hidden_size: Input/output dimension.
        intermediate_size: Intermediate dimension for the FFN.
        bias: Whether to use bias in linear layers.

    Example:
        >>> swiglu = SwiGLU(hidden_size=64, intermediate_size=256)
        >>> x = torch.randn(2, 8, 64)
        >>> y = swiglu(x)  # Shape: (2, 8, 64)
    """

    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        bias: bool = False,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size

        self.w_gate = torch.nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.w_up = torch.nn.Linear(hidden_size, intermediate_size, bias=bias)
        self.w_down = torch.nn.Linear(intermediate_size, hidden_size, bias=bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        gate = self.w_gate(x)
        up = self.w_up(x)
        return self.w_down(swiglu_fused(gate, up))

    def extra_repr(self) -> str:
        return f"hidden_size={self.hidden_size}, intermediate_size={self.intermediate_size}"

__init__

__init__(hidden_size: int, intermediate_size: int, bias: bool = False)
Source code in src/rotalabs_accel/kernels/activations.py
def __init__(
    self,
    hidden_size: int,
    intermediate_size: int,
    bias: bool = False,
):
    super().__init__()
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size

    self.w_gate = torch.nn.Linear(hidden_size, intermediate_size, bias=bias)
    self.w_up = torch.nn.Linear(hidden_size, intermediate_size, bias=bias)
    self.w_down = torch.nn.Linear(intermediate_size, hidden_size, bias=bias)

forward

forward(x: Tensor) -> torch.Tensor
Source code in src/rotalabs_accel/kernels/activations.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    gate = self.w_gate(x)
    up = self.w_up(x)
    return self.w_down(swiglu_fused(gate, up))
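
The fused kernel replaces the elementwise part of this forward pass. To make the formula above concrete, an unfused PyTorch equivalent of swiglu_fused(gate, up) would look like the sketch below (illustrative only):

import torch
import torch.nn.functional as F

def swiglu_unfused(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    # silu(x @ W_gate) * (x @ W_up), computed on already-projected tensors
    return F.silu(gate) * up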

RotaryEmbedding

Rotary Position Embeddings with automatic cache management.

import torch
from rotalabs_accel import RotaryEmbedding

rope = RotaryEmbedding(
    dim=128,
    max_seq_len=8192,
    base=10000.0,
).cuda()

# Apply to query and key: [batch, seq, heads, head_dim]
q = torch.randn(2, 512, 32, 128, device="cuda")
k = torch.randn(2, 512, 32, 128, device="cuda")
q_rot, k_rot = rope(q, k)

Bases: Module

Rotary Position Embedding module.

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

Name         Type   Description                                        Default
dim          int    Dimension of each attention head (head_dim).       required
max_seq_len  int    Maximum sequence length to cache.                  2048
base         float  Base for frequency computation (default: 10000).   10000.0
Example

rope = RotaryEmbedding(dim=32, max_seq_len=128)
q = torch.randn(2, 16, 4, 32)
k = torch.randn(2, 16, 4, 32)
q_rot, k_rot = rope(q, k)

Source code in src/rotalabs_accel/kernels/rope.py
class RotaryEmbedding(torch.nn.Module):
    """
    Rotary Position Embedding module.

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        dim: Dimension of each attention head (head_dim).
        max_seq_len: Maximum sequence length to cache.
        base: Base for frequency computation (default: 10000).

    Example:
        >>> rope = RotaryEmbedding(dim=32, max_seq_len=128)
        >>> q = torch.randn(2, 16, 4, 32)
        >>> k = torch.randn(2, 16, 4, 32)
        >>> q_rot, k_rot = rope(q, k)
    """

    def __init__(
        self,
        dim: int,
        max_seq_len: int = 2048,
        base: float = 10000.0,
    ):
        super().__init__()
        self.dim = dim
        self.max_seq_len = max_seq_len
        self.base = base

        # Build initial cache
        cos, sin = build_rope_cache(max_seq_len, dim, base)
        self.register_buffer("cos_cache", cos, persistent=False)
        self.register_buffer("sin_cache", sin, persistent=False)

    def _extend_cache(self, seq_len: int) -> None:
        """Extend cache if sequence is longer than current cache."""
        if seq_len <= self.cos_cache.shape[0]:
            return

        new_len = max(seq_len, self.cos_cache.shape[0] * 2)
        cos, sin = build_rope_cache(
            new_len,
            self.dim,
            self.base,
            device=self.cos_cache.device,
            dtype=self.cos_cache.dtype,
        )
        self.cos_cache = cos
        self.sin_cache = sin

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        position_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Apply RoPE to query and key tensors.

        Args:
            q: Query tensor [batch, seq, heads, head_dim]
            k: Key tensor [batch, seq, heads, head_dim]
            position_ids: Optional position indices [batch, seq].

        Returns:
            Tuple of (q_rotated, k_rotated).
        """
        seq_len = q.shape[1]
        self._extend_cache(seq_len)

        if position_ids is None:
            cos = self.cos_cache[:seq_len]
            sin = self.sin_cache[:seq_len]
        else:
            cos = self.cos_cache[position_ids]
            sin = self.sin_cache[position_ids]

        return apply_rope(q, k, cos, sin)

    def extra_repr(self) -> str:
        return f"dim={self.dim}, max_seq_len={self.max_seq_len}, base={self.base}"

__init__

__init__(dim: int, max_seq_len: int = 2048, base: float = 10000.0)
Source code in src/rotalabs_accel/kernels/rope.py
def __init__(
    self,
    dim: int,
    max_seq_len: int = 2048,
    base: float = 10000.0,
):
    super().__init__()
    self.dim = dim
    self.max_seq_len = max_seq_len
    self.base = base

    # Build initial cache
    cos, sin = build_rope_cache(max_seq_len, dim, base)
    self.register_buffer("cos_cache", cos, persistent=False)
    self.register_buffer("sin_cache", sin, persistent=False)

forward

forward(q: Tensor, k: Tensor, position_ids: Optional[Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]

Apply RoPE to query and key tensors.

Parameters:

Name          Type              Description                                  Default
q             Tensor            Query tensor [batch, seq, heads, head_dim]   required
k             Tensor            Key tensor [batch, seq, heads, head_dim]     required
position_ids  Optional[Tensor]  Optional position indices [batch, seq].      None

Returns:

Type                   Description
Tuple[Tensor, Tensor]  Tuple of (q_rotated, k_rotated).

Source code in src/rotalabs_accel/kernels/rope.py
def forward(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply RoPE to query and key tensors.

    Args:
        q: Query tensor [batch, seq, heads, head_dim]
        k: Key tensor [batch, seq, heads, head_dim]
        position_ids: Optional position indices [batch, seq].

    Returns:
        Tuple of (q_rotated, k_rotated).
    """
    seq_len = q.shape[1]
    self._extend_cache(seq_len)

    if position_ids is None:
        cos = self.cos_cache[:seq_len]
        sin = self.sin_cache[:seq_len]
    else:
        cos = self.cos_cache[position_ids]
        sin = self.sin_cache[position_ids]

    return apply_rope(q, k, cos, sin)
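
Because forward calls _extend_cache, sequences longer than max_seq_len are handled transparently: the cos/sin cache is rebuilt to at least the requested length (growing by at least 2x). A small illustration:

import torch
from rotalabs_accel import RotaryEmbedding

rope = RotaryEmbedding(dim=64, max_seq_len=1024)
q = torch.randn(1, 4096, 8, 64)   # seq_len (4096) exceeds the initial cache (1024)
k = torch.randn(1, 4096, 8, 64)
q_rot, k_rot = rope(q, k)
print(rope.cos_cache.shape[0])    # cache now covers at least 4096 positions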

Int8Linear

Linear layer with INT8 quantized weights.

import torch
from rotalabs_accel import Int8Linear

linear = Int8Linear(
    in_features=4096,
    out_features=4096,
    bias=False,
).cuda()

# Quantize a pretrained FP16/FP32 weight of shape (out_features, in_features)
pretrained_weight = torch.randn(4096, 4096)  # random stand-in for checkpoint weights
linear.quantize_weights(pretrained_weight)

x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = linear(x)

Bases: Module

Linear layer using INT8 weights with optimized GEMM kernel.

Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

Parameters:

Name          Type  Description               Default
in_features   int   Input dimension (K).      required
out_features  int   Output dimension (N).     required
bias          bool  Whether to include bias.  False
Example

linear = Int8Linear(64, 128)
linear.quantize_weights(torch.randn(128, 64))
y = linear(torch.randn(2, 8, 64))

Source code in src/rotalabs_accel/kernels/gemm.py
class Int8Linear(torch.nn.Module):
    """
    Linear layer using INT8 weights with optimized GEMM kernel.

    Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.

    Args:
        in_features: Input dimension (K).
        out_features: Output dimension (N).
        bias: Whether to include bias.

    Example:
        >>> linear = Int8Linear(64, 128)
        >>> linear.quantize_weights(torch.randn(128, 64))
        >>> y = linear(torch.randn(2, 8, 64))
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # INT8 weights: (out_features, in_features)
        self.register_buffer(
            "weight_int8",
            torch.zeros(out_features, in_features, dtype=torch.int8)
        )
        self.register_buffer(
            "scale",
            torch.ones(out_features, dtype=torch.float32)
        )
        # Pre-transposed weight for Triton path
        self.register_buffer(
            "weight_transposed",
            torch.zeros(in_features, out_features, dtype=torch.int8)
        )

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(out_features, dtype=torch.float16))
        else:
            self.register_parameter("bias", None)

        self._quantized = False

    def quantize_weights(self, weight: torch.Tensor) -> None:
        """Quantize and store FP16/FP32 weights as INT8."""
        assert weight.shape == (self.out_features, self.in_features)

        # Quantization happens on CPU
        weight_cpu = weight.cpu() if weight.is_cuda else weight
        weight_int8, scale = quantize_weight_per_channel(weight_cpu)

        # Move to same device as buffers and copy
        device = self.weight_int8.device
        self.weight_int8.copy_(weight_int8.to(device))
        self.scale.copy_(scale.to(device))
        self.weight_transposed.copy_(weight_int8.t().contiguous().to(device))
        self._quantized = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Forward pass using optimized INT8 GEMM kernel."""
        if not self._quantized:
            raise RuntimeError("Weights not quantized. Call quantize_weights() first.")

        return int8_gemm(
            x, self.weight_int8, self.scale, self.bias,
            weight_transposed=self.weight_transposed,
        )

    @classmethod
    def from_linear(cls, linear: torch.nn.Linear) -> "Int8Linear":
        """Convert nn.Linear to Int8Linear."""
        has_bias = linear.bias is not None
        int8_linear = cls(linear.in_features, linear.out_features, bias=has_bias)

        # Move to same device as input linear, then quantize
        device = linear.weight.device
        int8_linear = int8_linear.to(device)
        int8_linear.quantize_weights(linear.weight.data)

        # Copy bias if present
        if has_bias:
            int8_linear.bias.data.copy_(linear.bias.data.half())

        return int8_linear

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, quantized={self._quantized}"
        )

__init__

__init__(in_features: int, out_features: int, bias: bool = False)
Source code in src/rotalabs_accel/kernels/gemm.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    bias: bool = False,
):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features

    # INT8 weights: (out_features, in_features)
    self.register_buffer(
        "weight_int8",
        torch.zeros(out_features, in_features, dtype=torch.int8)
    )
    self.register_buffer(
        "scale",
        torch.ones(out_features, dtype=torch.float32)
    )
    # Pre-transposed weight for Triton path
    self.register_buffer(
        "weight_transposed",
        torch.zeros(in_features, out_features, dtype=torch.int8)
    )

    if bias:
        self.bias = torch.nn.Parameter(torch.zeros(out_features, dtype=torch.float16))
    else:
        self.register_parameter("bias", None)

    self._quantized = False

quantize_weights

quantize_weights(weight: Tensor) -> None

Quantize and store FP16/FP32 weights as INT8.

Source code in src/rotalabs_accel/kernels/gemm.py
def quantize_weights(self, weight: torch.Tensor) -> None:
    """Quantize and store FP16/FP32 weights as INT8."""
    assert weight.shape == (self.out_features, self.in_features)

    # Quantization happens on CPU
    weight_cpu = weight.cpu() if weight.is_cuda else weight
    weight_int8, scale = quantize_weight_per_channel(weight_cpu)

    # Move to same device as buffers and copy
    device = self.weight_int8.device
    self.weight_int8.copy_(weight_int8.to(device))
    self.scale.copy_(scale.to(device))
    self.weight_transposed.copy_(weight_int8.t().contiguous().to(device))
    self._quantized = True

forward

forward(x: Tensor) -> torch.Tensor

Forward pass using optimized INT8 GEMM kernel.

Source code in src/rotalabs_accel/kernels/gemm.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """Forward pass using optimized INT8 GEMM kernel."""
    if not self._quantized:
        raise RuntimeError("Weights not quantized. Call quantize_weights() first.")

    return int8_gemm(
        x, self.weight_int8, self.scale, self.bias,
        weight_transposed=self.weight_transposed,
    )
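
Converting a pretrained layer is a one-liner via the from_linear classmethod shown in the class source above. A minimal conversion sketch; note from the source that the module stores both weight_int8 and a pre-transposed INT8 copy for the Triton path:

import torch
from rotalabs_accel import Int8Linear

# Convert an existing FP16 projection; from_linear quantizes the weights in place.
fp16_linear = torch.nn.Linear(4096, 4096, bias=False).half()
int8_linear = Int8Linear.from_linear(fp16_linear)
print(int8_linear)  # in_features=4096, out_features=4096, bias=False, quantized=True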

QuantizedLinear

Higher-level quantized linear with easy conversion from nn.Linear.

import torch
from rotalabs_accel import QuantizedLinear

# Convert existing layer
linear = torch.nn.Linear(4096, 4096)
qlinear = QuantizedLinear.from_linear(linear)

# Use like normal
x = torch.randn(2, 4096)
y = qlinear(x)

Bases: Module

Linear layer with INT8 quantized weights.

Stores weights in INT8 format and dequantizes them on the fly during the forward pass. This is a reference implementation; the actual kernel-level optimization happens in the Triton INT8 GEMM kernel.

For W8A16 inference:

- Weights: INT8 (2x memory reduction)
- Activations: FP16
- Compute: FP16 with FP32 accumulation

Parameters:

Name          Type  Description                                                 Default
in_features   int   Input dimension.                                            required
out_features  int   Output dimension.                                           required
bias          bool  Whether to include bias (default: False for LLM weights).   False
Example

linear = QuantizedLinear(4096, 4096)
linear.quantize_weights(pretrained_weight)
y = linear(x)  # x is FP16, y is FP16

Source code in src/rotalabs_accel/quantization/symmetric.py
class QuantizedLinear(torch.nn.Module):
    """
    Linear layer with INT8 quantized weights.

    Stores weights in INT8 format and dequantizes on-the-fly during forward pass.
    This is a reference implementation - the actual kernel-level optimization
    happens in the Triton INT8 GEMM kernel.

    For W8A16 inference:
    - Weights: INT8 (2x memory reduction)
    - Activations: FP16
    - Compute: FP16 with FP32 accumulation

    Args:
        in_features: Input dimension.
        out_features: Output dimension.
        bias: Whether to include bias (default: False for LLM weights).

    Example:
        >>> linear = QuantizedLinear(4096, 4096)
        >>> linear.quantize_weights(pretrained_weight)
        >>> y = linear(x)  # x is FP16, y is FP16
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias: bool = False,
    ):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features

        # Register buffers for quantized weights
        self.register_buffer(
            "weight_int8",
            torch.zeros(out_features, in_features, dtype=torch.int8)
        )
        self.register_buffer(
            "weight_scale",
            torch.ones(out_features, dtype=torch.float32)
        )

        if bias:
            self.bias = torch.nn.Parameter(torch.zeros(out_features))
        else:
            self.register_parameter("bias", None)

        self._quantized = False

    def quantize_weights(self, weight: torch.Tensor) -> None:
        """
        Quantize and store weights.

        Args:
            weight: FP16 or FP32 weight tensor of shape (out_features, in_features).
        """
        assert weight.shape == (self.out_features, self.in_features)
        weight_int8, scale = quantize_weight_per_channel(weight)
        self.weight_int8.copy_(weight_int8)
        self.weight_scale.copy_(scale)
        self._quantized = True

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass with on-the-fly dequantization.

        Args:
            x: Input tensor of shape (..., in_features).

        Returns:
            Output tensor of shape (..., out_features).
        """
        if not self._quantized:
            raise RuntimeError("Weights not quantized. Call quantize_weights() first.")

        # Dequantize weights
        weight = dequantize(self.weight_int8, self.weight_scale, dtype=x.dtype, dim=0)

        # Compute matmul
        y = torch.nn.functional.linear(x, weight, self.bias)

        return y

    @classmethod
    def from_linear(cls, linear: torch.nn.Linear) -> "QuantizedLinear":
        """
        Create QuantizedLinear from existing nn.Linear.

        Args:
            linear: Pretrained linear layer.

        Returns:
            QuantizedLinear with quantized weights.
        """
        has_bias = linear.bias is not None
        quantized = cls(linear.in_features, linear.out_features, bias=has_bias)
        quantized.quantize_weights(linear.weight.data)
        if has_bias:
            quantized.bias.data.copy_(linear.bias.data)
        return quantized

    def extra_repr(self) -> str:
        return (
            f"in_features={self.in_features}, out_features={self.out_features}, "
            f"bias={self.bias is not None}, quantized={self._quantized}"
        )

__init__

__init__(in_features: int, out_features: int, bias: bool = False)
Source code in src/rotalabs_accel/quantization/symmetric.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    bias: bool = False,
):
    super().__init__()
    self.in_features = in_features
    self.out_features = out_features

    # Register buffers for quantized weights
    self.register_buffer(
        "weight_int8",
        torch.zeros(out_features, in_features, dtype=torch.int8)
    )
    self.register_buffer(
        "weight_scale",
        torch.ones(out_features, dtype=torch.float32)
    )

    if bias:
        self.bias = torch.nn.Parameter(torch.zeros(out_features))
    else:
        self.register_parameter("bias", None)

    self._quantized = False

quantize_weights

quantize_weights(weight: Tensor) -> None

Quantize and store weights.

Parameters:

Name    Type    Description                                                        Default
weight  Tensor  FP16 or FP32 weight tensor of shape (out_features, in_features).   required
Source code in src/rotalabs_accel/quantization/symmetric.py
def quantize_weights(self, weight: torch.Tensor) -> None:
    """
    Quantize and store weights.

    Args:
        weight: FP16 or FP32 weight tensor of shape (out_features, in_features).
    """
    assert weight.shape == (self.out_features, self.in_features)
    weight_int8, scale = quantize_weight_per_channel(weight)
    self.weight_int8.copy_(weight_int8)
    self.weight_scale.copy_(scale)
    self._quantized = True
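
The quantize_weight_per_channel call above produces one FP32 scale per output channel (the scale buffer has shape (out_features,), and dequantization uses dim=0). A rough sketch of that symmetric per-channel scheme is shown below; it is illustrative only, and the library's implementation may differ in detail.

import torch

def quantize_per_channel_sketch(weight: torch.Tensor):
    # One scale per output channel (row), chosen so the largest magnitude maps to 127.
    scale = weight.abs().amax(dim=1).clamp(min=1e-8) / 127.0
    weight_int8 = torch.round(weight / scale[:, None]).clamp(-127, 127).to(torch.int8)
    return weight_int8, scale.float()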

forward

forward(x: Tensor) -> torch.Tensor

Forward pass with on-the-fly dequantization.

Parameters:

Name  Type    Description                                 Default
x     Tensor  Input tensor of shape (..., in_features).   required

Returns:

Type    Description
Tensor  Output tensor of shape (..., out_features).

Source code in src/rotalabs_accel/quantization/symmetric.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass with on-the-fly dequantization.

    Args:
        x: Input tensor of shape (..., in_features).

    Returns:
        Output tensor of shape (..., out_features).
    """
    if not self._quantized:
        raise RuntimeError("Weights not quantized. Call quantize_weights() first.")

    # Dequantize weights
    weight = dequantize(self.weight_int8, self.weight_scale, dtype=x.dtype, dim=0)

    # Compute matmul
    y = torch.nn.functional.linear(x, weight, self.bias)

    return y

from_linear classmethod

from_linear(linear: Linear) -> QuantizedLinear

Create QuantizedLinear from existing nn.Linear.

Parameters:

Name    Type    Description               Default
linear  Linear  Pretrained linear layer.  required

Returns:

Type             Description
QuantizedLinear  QuantizedLinear with quantized weights.

Source code in src/rotalabs_accel/quantization/symmetric.py
@classmethod
def from_linear(cls, linear: torch.nn.Linear) -> "QuantizedLinear":
    """
    Create QuantizedLinear from existing nn.Linear.

    Args:
        linear: Pretrained linear layer.

    Returns:
        QuantizedLinear with quantized weights.
    """
    has_bias = linear.bias is not None
    quantized = cls(linear.in_features, linear.out_features, bias=has_bias)
    quantized.quantize_weights(linear.weight.data)
    if has_bias:
        quantized.bias.data.copy_(linear.bias.data)
    return quantized
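
A quick way to gauge the error introduced by the W8A16 scheme is to compare against the original layer. This is illustrative only; the error magnitude depends on the weight distribution.

import torch
from rotalabs_accel import QuantizedLinear

linear = torch.nn.Linear(256, 512, bias=False)
qlinear = QuantizedLinear.from_linear(linear)

x = torch.randn(4, 256)
max_err = (linear(x) - qlinear(x)).abs().max()
print(f"max abs error: {max_err.item():.4f}")  # small but nonzero quantization error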

Using with Existing Models

Replace Layers in a Model

import torch

from rotalabs_accel import TritonRMSNorm, SwiGLU, QuantizedLinear

def optimize_model(model):
    """Replace layers with optimized versions, preserving pretrained weights."""
    for name, module in model.named_children():
        # Replace RMSNorm, copying the learned scale and epsilon
        if isinstance(module, torch.nn.RMSNorm):
            eps = module.eps if module.eps is not None else 1e-6
            new_norm = TritonRMSNorm(module.weight.shape[0], eps=eps)
            new_norm.weight.data.copy_(module.weight.data)
            setattr(model, name, new_norm)

        # Quantize Linear (from_linear copies weights and bias)
        elif isinstance(module, torch.nn.Linear):
            setattr(model, name, QuantizedLinear.from_linear(module))

        # Recurse into container modules
        else:
            optimize_model(module)

    return model
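
For example, applied to a toy module (purely illustrative; torch.nn.RMSNorm requires PyTorch 2.4+):

import torch

class Block(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.norm = torch.nn.RMSNorm(64)
        self.proj = torch.nn.Linear(64, 64)

    def forward(self, x):
        return self.proj(self.norm(x))

block = optimize_model(Block())
print(block)  # norm -> TritonRMSNorm, proj -> QuantizedLinear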

With Hugging Face Transformers

from transformers import AutoModelForCausalLM
from rotalabs_accel import TritonRMSNorm

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Replace all RMSNorm layers, copying the pretrained scales and epsilon
for layer in model.model.layers:
    for attr in ("input_layernorm", "post_attention_layernorm"):
        old_norm = getattr(layer, attr)
        new_norm = TritonRMSNorm(
            old_norm.weight.shape[0],
            eps=old_norm.variance_epsilon,  # LlamaRMSNorm stores eps as variance_epsilon
        )
        new_norm.weight.data.copy_(old_norm.weight.data)
        setattr(layer, attr, new_norm)