Modules¶
Drop-in nn.Module replacements with optimized Triton kernels.
Overview¶
These modules provide the same interface as their PyTorch counterparts but use optimized Triton kernels when available. If Triton isn't installed or the input is on CPU, they automatically fall back to pure PyTorch.
import torch
from rotalabs_accel import TritonRMSNorm

# These work identically, but TritonRMSNorm is faster on GPU
norm_pytorch = torch.nn.RMSNorm(4096).cuda()
norm_triton = TritonRMSNorm(4096).cuda()
x = torch.randn(2, 512, 4096, device="cuda")

# Same API, same results, different speed
y1 = norm_pytorch(x)
y2 = norm_triton(x)  # Up to 3.8x faster on GPU
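As a quick sanity check, the two outputs should agree to within floating-point tolerance (assuming TritonRMSNorm initializes its scale to ones, as nn.RMSNorm does):

# Same math, so the results match up to kernel rounding differences
assert torch.allclose(y1, y2, atol=1e-3, rtol=1e-3)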
Module Summary¶
| Module | Replaces | Speedup | Use Case |
|---|---|---|---|
| TritonRMSNorm | nn.RMSNorm | 3.8x | LLaMA, Mistral normalization |
| SwiGLU | Custom FFN | 2.9x | LLaMA, PaLM FFN layers |
| RotaryEmbedding | Manual RoPE | 2.9x | Position encoding |
| Int8Linear | nn.Linear | 3.3x | Memory-efficient inference |
| QuantizedLinear | nn.Linear | 3.3x | Easy model quantization |
TritonRMSNorm¶
RMS normalization layer, as used in LLaMA, Mistral, and Qwen.
import torch
from rotalabs_accel import TritonRMSNorm
norm = TritonRMSNorm(hidden_size=4096, eps=1e-6)
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = norm(x)
Bases: Module
RMSNorm layer with automatic Triton/PyTorch dispatch.
Drop-in replacement for torch.nn.RMSNorm with identical interface. Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
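For reference, RMS normalization scales each vector by the root mean square of its last dimension. A minimal PyTorch sketch of the math the module computes (rmsnorm_reference is an illustrative helper, not part of the library, and not the Triton kernel itself):

import torch

def rmsnorm_reference(x, weight, eps=1e-6):
    # y = x / sqrt(mean(x^2) + eps) * weight, computed over the last dimension
    return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps) * weight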
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_size | int | Size of the last dimension to normalize over. | required |
| eps | float | Small constant for numerical stability. | 1e-06 |
Example
norm = TritonRMSNorm(64)
x = torch.randn(2, 8, 64)
y = norm(x)
Source code in src/rotalabs_accel/kernels/normalization.py
__init__ ¶
SwiGLU¶
Complete SwiGLU FFN module with gate, up, and down projections.
import torch
from rotalabs_accel import SwiGLU
ffn = SwiGLU(
hidden_size=4096,
intermediate_size=11008,
bias=False,
)
x = torch.randn(2, 512, 4096, device="cuda", dtype=torch.float16)
y = ffn(x) # Shape: (2, 512, 4096)
Bases: Module
SwiGLU module with linear projections.
Implements the full SwiGLU FFN:
y = (silu(x @ W_gate) * (x @ W_up)) @ W_down
Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
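Equivalently, in plain PyTorch terms (swiglu_reference is an illustrative sketch of the formula above, not the library's kernel code):

import torch.nn.functional as F

def swiglu_reference(x, w_gate, w_up, w_down):
    # w_gate, w_up: (hidden_size, intermediate_size); w_down: (intermediate_size, hidden_size)
    return (F.silu(x @ w_gate) * (x @ w_up)) @ w_down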
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| hidden_size | int | Input/output dimension. | required |
| intermediate_size | int | Intermediate dimension for the FFN. | required |
| bias | bool | Whether to use bias in linear layers. | False |
Example
swiglu = SwiGLU(hidden_size=64, intermediate_size=256)
x = torch.randn(2, 8, 64)
y = swiglu(x)  # Shape: (2, 8, 64)
Source code in src/rotalabs_accel/kernels/activations.py
__init__ ¶
Source code in src/rotalabs_accel/kernels/activations.py
RotaryEmbedding¶
Rotary Position Embeddings with automatic cache management.
import torch
from rotalabs_accel import RotaryEmbedding
rope = RotaryEmbedding(
dim=128,
max_seq_len=8192,
base=10000.0,
)
# Apply to query and key
q = torch.randn(2, 512, 32, 128, device="cuda")
k = torch.randn(2, 512, 32, 128, device="cuda")
q_rot, k_rot = rope(q, k)  # positions default to 0..seq-1 when position_ids is omitted
Bases: Module
Rotary Position Embedding module.
Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension of each attention head (head_dim). | required |
| max_seq_len | int | Maximum sequence length to cache. | 2048 |
| base | float | Base for frequency computation (default: 10000). | 10000.0 |
Example
rope = RotaryEmbedding(dim=32, max_seq_len=128)
q = torch.randn(2, 16, 4, 32)
k = torch.randn(2, 16, 4, 32)
q_rot, k_rot = rope(q, k)
Source code in src/rotalabs_accel/kernels/rope.py
__init__ ¶
Source code in src/rotalabs_accel/kernels/rope.py
forward ¶
forward(q: Tensor, k: Tensor, position_ids: Optional[Tensor] = None) -> Tuple[torch.Tensor, torch.Tensor]
Apply RoPE to query and key tensors.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| q | Tensor | Query tensor [batch, seq, heads, head_dim]. | required |
| k | Tensor | Key tensor [batch, seq, heads, head_dim]. | required |
| position_ids | Optional[Tensor] | Optional position indices [batch, seq]. | None |
Returns:

| Type | Description |
|---|---|
| Tuple[Tensor, Tensor] | Tuple of (q_rotated, k_rotated). |
Source code in src/rotalabs_accel/kernels/rope.py
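The optional position_ids argument is useful for incremental decoding, where new tokens do not start at position 0. A sketch, assuming position_ids holds absolute token positions (the usual RoPE convention):

import torch
from rotalabs_accel import RotaryEmbedding

rope = RotaryEmbedding(dim=128, max_seq_len=8192)

# Decode step: one new token per sequence, sitting at absolute position 512
q_step = torch.randn(2, 1, 32, 128, device="cuda")
k_step = torch.randn(2, 1, 32, 128, device="cuda")
position_ids = torch.full((2, 1), 512, dtype=torch.long, device="cuda")
q_rot, k_rot = rope(q_step, k_step, position_ids=position_ids)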
Int8Linear¶
Linear layer with INT8 quantized weights.
import torch
from rotalabs_accel import Int8Linear

linear = Int8Linear(
    in_features=4096,
    out_features=4096,
    bias=False,
)

# pretrained_weight: FP16/FP32 tensor of shape (out_features, in_features)
pretrained_weight = torch.randn(4096, 4096)
linear.quantize_weights(pretrained_weight)

x = torch.randn(2, 512, 4096)
y = linear(x)
Bases: Module
Linear layer using INT8 weights with optimized GEMM kernel.
Uses Triton kernel on CUDA when available, otherwise falls back to PyTorch.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimension (K). | required |
| out_features | int | Output dimension (N). | required |
| bias | bool | Whether to include bias. | False |
Example
linear = Int8Linear(64, 128)
linear.quantize_weights(torch.randn(128, 64))
y = linear(torch.randn(2, 8, 64))
Source code in src/rotalabs_accel/kernels/gemm.py
__init__ ¶
Source code in src/rotalabs_accel/kernels/gemm.py
quantize_weights ¶
Quantize and store FP16/FP32 weights as INT8.
Source code in src/rotalabs_accel/kernels/gemm.py
forward ¶
Forward pass using optimized INT8 GEMM kernel.
Source code in src/rotalabs_accel/kernels/gemm.py
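To quantize an existing layer, pass its weight matrix to quantize_weights. A sketch comparing against the unquantized output (the comparison is illustrative, not part of the library API):

import torch
from rotalabs_accel import Int8Linear

fp_linear = torch.nn.Linear(4096, 4096, bias=False)
int8_linear = Int8Linear(4096, 4096, bias=False)
int8_linear.quantize_weights(fp_linear.weight.data)

x = torch.randn(2, 16, 4096)
y_ref = fp_linear(x)
y_int8 = int8_linear(x)

# INT8 weights introduce a small quantization error, so the outputs differ slightly
print((y_ref - y_int8).abs().max())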
QuantizedLinear¶
Higher-level quantized linear with easy conversion from nn.Linear.
import torch
from rotalabs_accel import QuantizedLinear

# Convert an existing layer
linear = torch.nn.Linear(4096, 4096)
qlinear = QuantizedLinear.from_linear(linear)

# Use like normal
x = torch.randn(2, 512, 4096)
y = qlinear(x)
Bases: Module
Linear layer with INT8 quantized weights.
Stores weights in INT8 format and dequantizes them on the fly during the forward pass. This is a reference implementation; the kernel-level optimization happens in the Triton INT8 GEMM kernel.
For W8A16 inference:

- Weights: INT8 (2x memory reduction)
- Activations: FP16
- Compute: FP16 with FP32 accumulation
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input dimension. | required |
| out_features | int | Output dimension. | required |
| bias | bool | Whether to include bias (default: False for LLM weights). | False |
Example
linear = QuantizedLinear(4096, 4096)
linear.quantize_weights(pretrained_weight)
y = linear(x)  # x is FP16, y is FP16
Source code in src/rotalabs_accel/quantization/symmetric.py
__init__ ¶
Source code in src/rotalabs_accel/quantization/symmetric.py
quantize_weights ¶
Quantize and store weights.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| weight | Tensor | FP16 or FP32 weight tensor of shape (out_features, in_features). | required |
Source code in src/rotalabs_accel/quantization/symmetric.py
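Conceptually, symmetric per-output-channel INT8 quantization works as sketched below (symmetric_quantize and dequantize are illustrative helpers showing the general scheme; the library's exact implementation lives in quantization/symmetric.py):

import torch

def symmetric_quantize(weight: torch.Tensor):
    # One scale per output channel: scale = max(|w_row|) / 127
    scale = weight.abs().amax(dim=1, keepdim=True).clamp(min=1e-8) / 127.0
    w_int8 = torch.round(weight / scale).clamp(-127, 127).to(torch.int8)
    return w_int8, scale

def dequantize(w_int8: torch.Tensor, scale: torch.Tensor):
    # Recover an approximate floating-point weight for the forward pass
    return w_int8.to(scale.dtype) * scale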
forward ¶
Forward pass with on-the-fly dequantization.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor of shape (..., in_features). | required |
Returns:

| Type | Description |
|---|---|
| Tensor | Output tensor of shape (..., out_features). |
Source code in src/rotalabs_accel/quantization/symmetric.py
from_linear (classmethod) ¶
Create QuantizedLinear from existing nn.Linear.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| linear | Linear | Pretrained linear layer. | required |
Returns:

| Type | Description |
|---|---|
| QuantizedLinear | QuantizedLinear with quantized weights. |
Source code in src/rotalabs_accel/quantization/symmetric.py
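Because weights are stored as INT8 plus a small per-channel scale vector, converting a layer roughly halves its weight memory relative to FP16 (and quarters it relative to FP32). A back-of-the-envelope sketch:

import torch
from rotalabs_accel import QuantizedLinear

linear = torch.nn.Linear(4096, 4096, bias=False)
qlinear = QuantizedLinear.from_linear(linear)

n = linear.weight.numel()
print(f"FP16 weights: {n * 2 / 2**20:.0f} MiB")  # ~32 MiB
print(f"INT8 weights: {n * 1 / 2**20:.0f} MiB")  # ~16 MiB, plus per-channel scales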
Using with Existing Models¶
Replace Layers in a Model¶
import torch
from rotalabs_accel import TritonRMSNorm, SwiGLU, QuantizedLinear

def optimize_model(model):
    """Recursively replace supported layers with optimized versions."""
    for name, module in model.named_children():
        # Replace RMSNorm, copying the trained scale so outputs are preserved
        # (TritonRMSNorm is documented as a drop-in for nn.RMSNorm, so it exposes .weight)
        if isinstance(module, torch.nn.RMSNorm):
            new_norm = TritonRMSNorm(module.weight.shape[0], eps=module.eps or 1e-6)
            new_norm.weight.data.copy_(module.weight.data)
            setattr(model, name, new_norm)
        # Quantize Linear (from_linear already quantizes the pretrained weight)
        elif isinstance(module, torch.nn.Linear):
            setattr(model, name, QuantizedLinear.from_linear(module))
        # Recurse into submodules
        else:
            optimize_model(module)
    return model
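For example, on a small stand-in model (the layer sizes are arbitrary, just to show the call):

toy = torch.nn.Sequential(
    torch.nn.Linear(4096, 4096, bias=False),
    torch.nn.RMSNorm(4096),
)
toy = optimize_model(toy)
print(toy)  # Linear -> QuantizedLinear, RMSNorm -> TritonRMSNorm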
With Hugging Face Transformers¶
import torch
from transformers import AutoModelForCausalLM
from rotalabs_accel import TritonRMSNorm

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Replace all RMSNorm layers, keeping the pretrained scales
for layer in model.model.layers:
    for attr in ("input_layernorm", "post_attention_layernorm"):
        old_norm = getattr(layer, attr)
        new_norm = TritonRMSNorm(
            old_norm.weight.shape[0],
            eps=getattr(old_norm, "variance_epsilon", 1e-6),
        )
        new_norm.weight.data.copy_(old_norm.weight.data)
        setattr(layer, attr, new_norm)
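After the swap, a short forward pass confirms nothing broke (a sketch; assumes a CUDA device with enough memory for the FP16 model):

model = model.half().cuda()
input_ids = torch.randint(0, model.config.vocab_size, (1, 16), device="cuda")
with torch.no_grad():
    logits = model(input_ids).logits
print(logits.shape)  # (1, 16, vocab_size)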