Activation Hooks

Extract activations from model layers during inference.

ActivationHook

Hook for capturing activations from specific model layers.

Works with HuggingFace transformers models (GPT-2, Mistral, Llama, etc.).

Example

hook = ActivationHook(model, layer_indices=[10, 15, 20])
with hook:
    outputs = model(**inputs)
act = hook.cache.get("layer_15")  # (batch, seq, hidden)

Source code in src/rotalabs_probe/probing/hooks.py
class ActivationHook:
    """Hook for capturing activations from specific model layers.

    Works with HuggingFace transformers models (GPT-2, Mistral, Llama, etc).

    Example:
        >>> hook = ActivationHook(model, layer_indices=[10, 15, 20])
        >>> with hook:
        ...     outputs = model(**inputs)
        >>> act = hook.cache.get("layer_15")  # (batch, seq, hidden)
    """

    def __init__(
        self,
        model: nn.Module,
        layer_indices: List[int],
        component: str = "residual",
        token_position: str = "all",
    ):
        """Initialize activation hook.

        Args:
            model: HuggingFace model to hook
            layer_indices: Which layers to capture
            component: What to capture - "residual", "attn", or "mlp"
            token_position: "all", "last", or "first"
        """
        self.model = model
        self.layer_indices = layer_indices
        self.component = component
        self.token_position = token_position
        self.cache = ActivationCache()
        self._handles: List[Any] = []

    def _get_layers(self) -> nn.ModuleList:
        """Get the transformer layers from the model."""
        # XXX: this is ugly but HF doesn't have a consistent API for this
        if hasattr(self.model, "model"):
            inner = self.model.model
            if hasattr(inner, "layers"):
                return inner.layers  # Llama, Mistral
            elif hasattr(inner, "decoder"):
                return inner.decoder.layers
        if hasattr(self.model, "transformer"):
            if hasattr(self.model.transformer, "h"):
                return self.model.transformer.h  # GPT-2
        if hasattr(self.model, "gpt_neox"):
            return self.model.gpt_neox.layers

        # TODO: add support for more architectures as needed
        raise ValueError("Could not find transformer layers in model architecture")

    def _make_hook(self, layer_idx: int):
        """Create a hook function for a specific layer."""
        def hook_fn(module, input, output):
            # Handle different output formats
            if isinstance(output, tuple):
                hidden_states = output[0]
            else:
                hidden_states = output

            # Store based on token position
            if self.token_position == "last":
                self.cache.store(f"layer_{layer_idx}", hidden_states[:, -1:, :])
            elif self.token_position == "first":
                self.cache.store(f"layer_{layer_idx}", hidden_states[:, :1, :])
            else:  # all
                self.cache.store(f"layer_{layer_idx}", hidden_states)

        return hook_fn

    def __enter__(self):
        """Register hooks on specified layers."""
        self.cache.clear()
        layers = self._get_layers()

        for idx in self.layer_indices:
            if idx >= len(layers):
                raise ValueError(f"Layer {idx} out of range (model has {len(layers)} layers)")

            layer = layers[idx]
            handle = layer.register_forward_hook(self._make_hook(idx))
            self._handles.append(handle)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove hooks."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        return False

__init__(model: nn.Module, layer_indices: List[int], component: str = 'residual', token_position: str = 'all')

Initialize activation hook.

Parameters:

Name             Type        Description                                       Default
model            Module      HuggingFace model to hook                         required
layer_indices    List[int]   Which layers to capture                           required
component        str         What to capture - "residual", "attn", or "mlp"    'residual'
token_position   str         "all", "last", or "first"                         'all'
Source code in src/rotalabs_probe/probing/hooks.py
def __init__(
    self,
    model: nn.Module,
    layer_indices: List[int],
    component: str = "residual",
    token_position: str = "all",
):
    """Initialize activation hook.

    Args:
        model: HuggingFace model to hook
        layer_indices: Which layers to capture
        component: What to capture - "residual", "attn", or "mlp"
        token_position: "all", "last", or "first"
    """
    self.model = model
    self.layer_indices = layer_indices
    self.component = component
    self.token_position = token_position
    self.cache = ActivationCache()
    self._handles: List[Any] = []
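
A minimal construction sketch, assuming the base "gpt2" checkpoint from transformers (12 layers, hidden size 768) and an import path that mirrors the source location shown above; the prompt string is only for illustration.

from transformers import AutoModelForCausalLM, AutoTokenizer

from rotalabs_probe.probing.hooks import ActivationHook  # import path assumed from the source location above

model = AutoModelForCausalLM.from_pretrained("gpt2")
tokenizer = AutoTokenizer.from_pretrained("gpt2")
inputs = tokenizer("The quick brown fox", return_tensors="pt")

# Capture only the final token's hidden state at layer 10.
hook = ActivationHook(model, layer_indices=[10], token_position="last")
with hook:
    model(**inputs)

act = hook.cache.get("layer_10")  # shape (batch, 1, hidden) because token_position="last"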

__enter__()

Register hooks on specified layers.

Source code in src/rotalabs_probe/probing/hooks.py
def __enter__(self):
    """Register hooks on specified layers."""
    self.cache.clear()
    layers = self._get_layers()

    for idx in self.layer_indices:
        if idx >= len(layers):
            raise ValueError(f"Layer {idx} out of range (model has {len(layers)} layers)")

        layer = layers[idx]
        handle = layer.register_forward_hook(self._make_hook(idx))
        self._handles.append(handle)

    return self
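
Because the range check above compares against len(layers), valid indices run from 0 to the model's layer count minus one. A short sketch, assuming a transformers model whose config exposes num_hidden_layers (true for GPT-2, Llama, Mistral, and GPT-NeoX) and the model/inputs from the example above:

n_layers = model.config.num_hidden_layers  # 12 for the base "gpt2" checkpoint
hook = ActivationHook(model, layer_indices=[n_layers // 2, n_layers - 1])

with hook:                  # hooks registered here, cache cleared
    model(**inputs)
# hooks are removed on exit; the cache still holds the captured tensors
mid = hook.cache.get(f"layer_{n_layers // 2}")

# An out-of-range index fails fast at __enter__, before any forward pass:
# ActivationHook(model, layer_indices=[n_layers]).__enter__()  -> ValueError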

__exit__(exc_type, exc_val, exc_tb)

Remove hooks.

Source code in src/rotalabs_probe/probing/hooks.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Remove hooks."""
    for handle in self._handles:
        handle.remove()
    self._handles.clear()
    return False

extract_activations

Extract activations for multiple texts at specified layers, returning a dict that maps each layer index to a tensor of shape (num_texts, hidden_size).

Source code in src/rotalabs_probe/probing/extraction.py
def extract_activations(
    model,
    tokenizer,
    texts: List[str],
    layer_indices: List[int],
    token_position: Literal["last", "first", "mean"] = "last",
    show_progress: bool = True,
) -> Dict[int, torch.Tensor]:
    """Extract activations for multiple texts at specified layers."""
    # FIXME: this is slow for large datasets, could batch but hook handling is tricky
    device = next(model.parameters()).device
    model.eval()

    # Initialize storage
    layer_activations = {idx: [] for idx in layer_indices}

    iterator = tqdm(texts, desc="Extracting", disable=not show_progress)

    for text in iterator:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        hook = ActivationHook(model, layer_indices, component="residual", token_position="all")

        with hook:
            with torch.no_grad():
                model(**inputs)

        for layer_idx in layer_indices:
            activation = hook.cache.get(f"layer_{layer_idx}")
            if activation is None:
                raise RuntimeError(f"Failed to capture layer {layer_idx}")

            if token_position == "last":
                act = activation[0, -1, :]
            elif token_position == "first":
                act = activation[0, 0, :]
            else:
                act = activation[0].mean(dim=0)

            layer_activations[layer_idx].append(act.cpu())

    # Stack into tensors
    return {
        idx: torch.stack(acts)
        for idx, acts in layer_activations.items()
    }