Probing

Activation probing for interpretability analysis.

Optional Dependency

The probing module requires the optional extras install, pip install rotalabs-probe[gpu], which pulls in PyTorch, Transformers, and scikit-learn.

Available Components

| Component | Description |
| --- | --- |
| ActivationHook | Extract activations from model layers |
| LinearProbe | Train linear classifiers on activations |
| SteeringVector | Compute and apply steering vectors |

Module

Activation probing for sandbagging detection.

This module provides tools for analyzing model activations to detect sandbagging behavior at the representation level.

Key components:

- ActivationHook: Capture hidden states during forward pass
- SteeringVector: Represent behavioral directions in activation space
- extract_caa_vector: Extract vectors using Contrastive Activation Addition
- LinearProbe: Train classifiers on activation patterns
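
These pieces compose into a pipeline: capture activations, then train a probe on them. A minimal end-to-end sketch, assuming a small HuggingFace checkpoint (gpt2) and that the components are importable from the submodules named in the source listings below; the genuine/sandbag texts and labels are placeholders, not part of the library:

import numpy as np
from transformers import AutoModelForCausalLM, AutoTokenizer

from rotalabs_probe.probing.extraction import extract_activations
from rotalabs_probe.probing.probes import LinearProbe

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# Placeholder data: genuine answers vs. suspected sandbagging.
genuine = ["The answer is 42.", "Paris is the capital of France."]
sandbag = ["I am not sure, maybe 7?", "I cannot recall the capital."]

acts = extract_activations(
    model, tokenizer, genuine + sandbag, layer_indices=[6], token_position="last"
)
X = acts[6].numpy()                    # (n_samples, hidden_dim)
y = np.array([0] * len(genuine) + [1] * len(sandbag))

probe = LinearProbe()
probe.fit(X, y, cv_folds=2)            # tiny placeholder data, so few folds
print(probe.cv_accuracy, probe.cv_std)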

ActivationHook

Hook for capturing activations from specific model layers.

Works with HuggingFace transformers models (GPT-2, Mistral, Llama, etc.).

Example

hook = ActivationHook(model, layer_indices=[10, 15, 20])
with hook:
    outputs = model(**inputs)
act = hook.cache.get("layer_15")  # (batch, seq, hidden)
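
A slightly fuller sketch of the same pattern, assuming a gpt2 checkpoint from transformers; cached entries use the layer_{idx} key convention visible in the source below:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from rotalabs_probe.probing.hooks import ActivationHook

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")
inputs = tokenizer("Hello world", return_tensors="pt")

# token_position="last" caches only the final token: (batch, 1, hidden)
hook = ActivationHook(model, layer_indices=[4, 8], token_position="last")
with hook, torch.no_grad():
    model(**inputs)

for idx in [4, 8]:
    act = hook.cache.get(f"layer_{idx}")
    print(idx, tuple(act.shape))  # (1, 1, 768) for gpt2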

Source code in src/rotalabs_probe/probing/hooks.py
class ActivationHook:
    """Hook for capturing activations from specific model layers.

    Works with HuggingFace transformers models (GPT-2, Mistral, Llama, etc).

    Example:
        >>> hook = ActivationHook(model, layer_indices=[10, 15, 20])
        >>> with hook:
        ...     outputs = model(**inputs)
        >>> act = hook.cache.get("layer_15")  # (batch, seq, hidden)
    """

    def __init__(
        self,
        model: nn.Module,
        layer_indices: List[int],
        component: str = "residual",
        token_position: str = "all",
    ):
        """Initialize activation hook.

        Args:
            model: HuggingFace model to hook
            layer_indices: Which layers to capture
            component: What to capture - "residual", "attn", or "mlp"
            token_position: "all", "last", or "first"
        """
        self.model = model
        self.layer_indices = layer_indices
        self.component = component
        self.token_position = token_position
        self.cache = ActivationCache()
        self._handles: List[Any] = []

    def _get_layers(self) -> nn.ModuleList:
        """Get the transformer layers from the model."""
        # XXX: this is ugly but HF doesn't have a consistent API for this
        if hasattr(self.model, "model"):
            inner = self.model.model
            if hasattr(inner, "layers"):
                return inner.layers  # Llama, Mistral
            elif hasattr(inner, "decoder"):
                return inner.decoder.layers
        if hasattr(self.model, "transformer"):
            if hasattr(self.model.transformer, "h"):
                return self.model.transformer.h  # GPT-2
        if hasattr(self.model, "gpt_neox"):
            return self.model.gpt_neox.layers

        # TODO: add support for more architectures as needed
        raise ValueError("Could not find transformer layers in model architecture")

    def _make_hook(self, layer_idx: int):
        """Create a hook function for a specific layer."""
        def hook_fn(module, input, output):
            # Handle different output formats
            if isinstance(output, tuple):
                hidden_states = output[0]
            else:
                hidden_states = output

            # Store based on token position
            if self.token_position == "last":
                self.cache.store(f"layer_{layer_idx}", hidden_states[:, -1:, :])
            elif self.token_position == "first":
                self.cache.store(f"layer_{layer_idx}", hidden_states[:, :1, :])
            else:  # all
                self.cache.store(f"layer_{layer_idx}", hidden_states)

        return hook_fn

    def __enter__(self):
        """Register hooks on specified layers."""
        self.cache.clear()
        layers = self._get_layers()

        for idx in self.layer_indices:
            if idx >= len(layers):
                raise ValueError(f"Layer {idx} out of range (model has {len(layers)} layers)")

            layer = layers[idx]
            handle = layer.register_forward_hook(self._make_hook(idx))
            self._handles.append(handle)

        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Remove hooks."""
        for handle in self._handles:
            handle.remove()
        self._handles.clear()
        return False

__init__(model: nn.Module, layer_indices: List[int], component: str = 'residual', token_position: str = 'all')

Initialize activation hook.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | Module | HuggingFace model to hook | required |
| layer_indices | List[int] | Which layers to capture | required |
| component | str | What to capture: "residual", "attn", or "mlp" | 'residual' |
| token_position | str | "all", "last", or "first" | 'all' |
Source code in src/rotalabs_probe/probing/hooks.py
def __init__(
    self,
    model: nn.Module,
    layer_indices: List[int],
    component: str = "residual",
    token_position: str = "all",
):
    """Initialize activation hook.

    Args:
        model: HuggingFace model to hook
        layer_indices: Which layers to capture
        component: What to capture - "residual", "attn", or "mlp"
        token_position: "all", "last", or "first"
    """
    self.model = model
    self.layer_indices = layer_indices
    self.component = component
    self.token_position = token_position
    self.cache = ActivationCache()
    self._handles: List[Any] = []

__enter__()

Register hooks on specified layers.

Source code in src/rotalabs_probe/probing/hooks.py
def __enter__(self):
    """Register hooks on specified layers."""
    self.cache.clear()
    layers = self._get_layers()

    for idx in self.layer_indices:
        if idx >= len(layers):
            raise ValueError(f"Layer {idx} out of range (model has {len(layers)} layers)")

        layer = layers[idx]
        handle = layer.register_forward_hook(self._make_hook(idx))
        self._handles.append(handle)

    return self

__exit__(exc_type, exc_val, exc_tb)

Remove hooks.

Source code in src/rotalabs_probe/probing/hooks.py
def __exit__(self, exc_type, exc_val, exc_tb):
    """Remove hooks."""
    for handle in self._handles:
        handle.remove()
    self._handles.clear()
    return False

LinearProbe

Linear probe for binary classification on activation patterns.

Uses logistic regression to learn a linear decision boundary in activation space for detecting sandbagging behavior.

Example

probe = LinearProbe()
probe.fit(X_train, y_train)
predictions = probe.predict(X_test)
accuracy = probe.score(X_test, y_test)
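
A sketch of a full fit/evaluate/persist round trip, assuming activations are already available as NumPy arrays; the random data and the artifacts/probe.pkl path are placeholders:

import numpy as np

from rotalabs_probe.probing.probes import LinearProbe

rng = np.random.default_rng(0)
X = rng.normal(size=(200, 768))    # stand-in activation vectors
y = rng.integers(0, 2, size=200)   # stand-in binary labels

probe = LinearProbe(C=0.5)
probe.fit(X, y, cv_folds=5)
print(f"CV accuracy: {probe.cv_accuracy:.3f} +/- {probe.cv_std:.3f}")

probe.save("artifacts/probe.pkl")
restored = LinearProbe.load("artifacts/probe.pkl")
assert np.allclose(restored.coef, probe.coef)  # round trip preserves the direction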

Source code in src/rotalabs_probe/probing/probes.py
class LinearProbe:
    """Linear probe for binary classification on activation patterns.

    Uses logistic regression to learn a linear decision boundary
    in activation space for detecting sandbagging behavior.

    Example:
        >>> probe = LinearProbe()
        >>> probe.fit(X_train, y_train)
        >>> predictions = probe.predict(X_test)
        >>> accuracy = probe.score(X_test, y_test)
    """

    def __init__(
        self,
        C: float = 1.0,
        max_iter: int = 1000,
        random_state: int = 42,
    ):
        """Initialize linear probe.

        Args:
            C: Inverse regularization strength
            max_iter: Maximum iterations for optimization
            random_state: Random seed for reproducibility
        """
        if not SKLEARN_AVAILABLE:
            raise ImportError("sklearn required for LinearProbe. Install: pip install scikit-learn")

        self.C = C
        self.max_iter = max_iter
        self.random_state = random_state
        self._model: Optional[LogisticRegression] = None
        self._cv_scores: Optional[np.ndarray] = None

    def fit(
        self,
        X: np.ndarray,
        y: np.ndarray,
        cv_folds: int = 5,
    ) -> "LinearProbe":
        """Fit the probe to training data.

        Args:
            X: Activation vectors (n_samples, hidden_dim)
            y: Binary labels (0=genuine, 1=sandbagging)
            cv_folds: Number of cross-validation folds

        Returns:
            self
        """
        self._model = LogisticRegression(
            penalty="l2",
            C=self.C,
            solver="lbfgs",
            max_iter=self.max_iter,
            random_state=self.random_state,
        )

        # Cross-validation for accuracy estimate
        self._cv_scores = cross_val_score(
            self._model, X, y, cv=cv_folds, scoring="accuracy"
        )

        # Fit on full data
        self._model.fit(X, y)

        return self

    def predict(self, X: np.ndarray) -> np.ndarray:
        """Predict labels for new activations."""
        if self._model is None:
            raise RuntimeError("Probe not fitted. Call fit() first.")
        return self._model.predict(X)

    def predict_proba(self, X: np.ndarray) -> np.ndarray:
        """Get probability estimates for each class."""
        if self._model is None:
            raise RuntimeError("Probe not fitted. Call fit() first.")
        return self._model.predict_proba(X)

    def score(self, X: np.ndarray, y: np.ndarray) -> float:
        """Compute accuracy on test data."""
        if self._model is None:
            raise RuntimeError("Probe not fitted. Call fit() first.")
        return self._model.score(X, y)

    @property
    def cv_accuracy(self) -> float:
        """Mean cross-validation accuracy."""
        if self._cv_scores is None:
            raise RuntimeError("Probe not fitted. Call fit() first.")
        return self._cv_scores.mean()

    @property
    def cv_std(self) -> float:
        """Standard deviation of cross-validation accuracy."""
        if self._cv_scores is None:
            raise RuntimeError("Probe not fitted. Call fit() first.")
        return self._cv_scores.std()

    @property
    def coef(self) -> np.ndarray:
        """Coefficients of the linear classifier (the probe direction)."""
        if self._model is None:
            raise RuntimeError("Probe not fitted. Call fit() first.")
        return self._model.coef_[0]

    def save(self, path: Path) -> None:
        """Save probe to disk."""
        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)
        with open(path, "wb") as f:
            pickle.dump({
                "model": self._model,
                "cv_scores": self._cv_scores,
                "C": self.C,
                "max_iter": self.max_iter,
                "random_state": self.random_state,
            }, f)

    @classmethod
    def load(cls, path: Path) -> "LinearProbe":
        """Load probe from disk."""
        with open(path, "rb") as f:
            data = pickle.load(f)

        probe = cls(
            C=data["C"],
            max_iter=data["max_iter"],
            random_state=data["random_state"],
        )
        probe._model = data["model"]
        probe._cv_scores = data["cv_scores"]
        return probe

cv_accuracy: float property

Mean cross-validation accuracy.

cv_std: float property

Standard deviation of cross-validation accuracy.

coef: np.ndarray property

Coefficients of the linear classifier (the probe direction).

__init__(C: float = 1.0, max_iter: int = 1000, random_state: int = 42)

Initialize linear probe.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| C | float | Inverse regularization strength | 1.0 |
| max_iter | int | Maximum iterations for optimization | 1000 |
| random_state | int | Random seed for reproducibility | 42 |
Source code in src/rotalabs_probe/probing/probes.py
def __init__(
    self,
    C: float = 1.0,
    max_iter: int = 1000,
    random_state: int = 42,
):
    """Initialize linear probe.

    Args:
        C: Inverse regularization strength
        max_iter: Maximum iterations for optimization
        random_state: Random seed for reproducibility
    """
    if not SKLEARN_AVAILABLE:
        raise ImportError("sklearn required for LinearProbe. Install: pip install scikit-learn")

    self.C = C
    self.max_iter = max_iter
    self.random_state = random_state
    self._model: Optional[LogisticRegression] = None
    self._cv_scores: Optional[np.ndarray] = None

fit(X: np.ndarray, y: np.ndarray, cv_folds: int = 5) -> LinearProbe

Fit the probe to training data.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| X | ndarray | Activation vectors (n_samples, hidden_dim) | required |
| y | ndarray | Binary labels (0=genuine, 1=sandbagging) | required |
| cv_folds | int | Number of cross-validation folds | 5 |

Returns:

| Type | Description |
| --- | --- |
| LinearProbe | self |

Source code in src/rotalabs_probe/probing/probes.py
def fit(
    self,
    X: np.ndarray,
    y: np.ndarray,
    cv_folds: int = 5,
) -> "LinearProbe":
    """Fit the probe to training data.

    Args:
        X: Activation vectors (n_samples, hidden_dim)
        y: Binary labels (0=genuine, 1=sandbagging)
        cv_folds: Number of cross-validation folds

    Returns:
        self
    """
    self._model = LogisticRegression(
        penalty="l2",
        C=self.C,
        solver="lbfgs",
        max_iter=self.max_iter,
        random_state=self.random_state,
    )

    # Cross-validation for accuracy estimate
    self._cv_scores = cross_val_score(
        self._model, X, y, cv=cv_folds, scoring="accuracy"
    )

    # Fit on full data
    self._model.fit(X, y)

    return self

predict(X: np.ndarray) -> np.ndarray

Predict labels for new activations.

Source code in src/rotalabs_probe/probing/probes.py
def predict(self, X: np.ndarray) -> np.ndarray:
    """Predict labels for new activations."""
    if self._model is None:
        raise RuntimeError("Probe not fitted. Call fit() first.")
    return self._model.predict(X)

predict_proba(X: np.ndarray) -> np.ndarray

Get probability estimates for each class.

Source code in src/rotalabs_probe/probing/probes.py
def predict_proba(self, X: np.ndarray) -> np.ndarray:
    """Get probability estimates for each class."""
    if self._model is None:
        raise RuntimeError("Probe not fitted. Call fit() first.")
    return self._model.predict_proba(X)

score(X: np.ndarray, y: np.ndarray) -> float

Compute accuracy on test data.

Source code in src/rotalabs_probe/probing/probes.py
def score(self, X: np.ndarray, y: np.ndarray) -> float:
    """Compute accuracy on test data."""
    if self._model is None:
        raise RuntimeError("Probe not fitted. Call fit() first.")
    return self._model.score(X, y)

save(path: Path) -> None

Save probe to disk.

Source code in src/rotalabs_probe/probing/probes.py
def save(self, path: Path) -> None:
    """Save probe to disk."""
    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        pickle.dump({
            "model": self._model,
            "cv_scores": self._cv_scores,
            "C": self.C,
            "max_iter": self.max_iter,
            "random_state": self.random_state,
        }, f)

load(path: Path) -> LinearProbe classmethod

Load probe from disk.

Source code in src/rotalabs_probe/probing/probes.py
@classmethod
def load(cls, path: Path) -> "LinearProbe":
    """Load probe from disk."""
    with open(path, "rb") as f:
        data = pickle.load(f)

    probe = cls(
        C=data["C"],
        max_iter=data["max_iter"],
        random_state=data["random_state"],
    )
    probe._model = data["model"]
    probe._cv_scores = data["cv_scores"]
    return probe

SteeringVector dataclass

A vector in activation space representing a behavioral direction.

Created by computing mean(positive_activations) - mean(negative_activations) using Contrastive Activation Addition (CAA).

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| behavior | str | Name of the behavior (e.g., "sandbagging") |
| layer_index | int | Which layer this vector was extracted from |
| vector | Tensor | The actual steering vector tensor |
| model_name | str | Model used for extraction |
| extraction_method | str | Method used (typically "caa") |
| metadata | Dict[str, Any] | Additional extraction details |
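
To make the construction concrete, a toy sketch of the CAA difference-of-means; the random tensors stand in for real model activations:

import torch

from rotalabs_probe.probing.vectors import SteeringVector

# Stand-in activations: 8 positive and 8 negative examples, hidden dim 16.
pos = torch.randn(8, 16)
neg = torch.randn(8, 16)

# The CAA direction is the difference of the two class means.
vec = SteeringVector(
    behavior="sandbagging",
    layer_index=10,
    vector=pos.mean(dim=0) - neg.mean(dim=0),
)
print(vec)  # SteeringVector(behavior='sandbagging', layer=10, dim=16, norm=...)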

Source code in src/rotalabs_probe/probing/vectors.py
@dataclass
class SteeringVector:
    """A vector in activation space representing a behavioral direction.

    Created by computing mean(positive_activations) - mean(negative_activations)
    using Contrastive Activation Addition (CAA).

    Attributes:
        behavior: Name of the behavior (e.g., "sandbagging")
        layer_index: Which layer this vector was extracted from
        vector: The actual steering vector tensor
        model_name: Model used for extraction
        extraction_method: Method used (typically "caa")
        metadata: Additional extraction details
    """

    behavior: str
    layer_index: int
    vector: torch.Tensor
    model_name: str = "unknown"
    extraction_method: str = "caa"
    metadata: Dict[str, Any] = field(default_factory=dict)

    @property
    def norm(self) -> float:
        """L2 norm of the steering vector."""
        return self.vector.norm().item()

    @property
    def dim(self) -> int:
        """Dimensionality of the vector."""
        return self.vector.shape[-1]

    def to(self, device: str) -> "SteeringVector":
        """Move vector to specified device."""
        return SteeringVector(
            behavior=self.behavior,
            layer_index=self.layer_index,
            vector=self.vector.to(device),
            model_name=self.model_name,
            extraction_method=self.extraction_method,
            metadata=self.metadata,
        )

    def normalize(self) -> "SteeringVector":
        """Return unit-normalized version of this vector."""
        return SteeringVector(
            behavior=self.behavior,
            layer_index=self.layer_index,
            vector=self.vector / self.norm,
            model_name=self.model_name,
            extraction_method=self.extraction_method,
            metadata={**self.metadata, "normalized": True},
        )

    def save(self, path: Path) -> None:
        """Save vector to disk.

        Creates:
            - {path}.pt: The vector tensor
            - {path}_meta.json: Metadata
        """
        import json

        path = Path(path)
        path.parent.mkdir(parents=True, exist_ok=True)

        # Save tensor
        torch.save(self.vector, f"{path}.pt")

        # Save metadata
        meta = {
            "behavior": self.behavior,
            "layer_index": self.layer_index,
            "model_name": self.model_name,
            "extraction_method": self.extraction_method,
            "norm": self.norm,
            "dim": self.dim,
            **self.metadata,
        }
        with open(f"{path}_meta.json", "w") as f:
            json.dump(meta, f, indent=2)

    @classmethod
    def load(cls, path: Path) -> "SteeringVector":
        """Load vector from disk."""
        import json

        path = Path(path)

        # Load tensor
        vector = torch.load(f"{path}.pt", weights_only=True)

        # Load metadata
        with open(f"{path}_meta.json") as f:
            meta = json.load(f)

        return cls(
            behavior=meta["behavior"],
            layer_index=meta["layer_index"],
            vector=vector,
            model_name=meta.get("model_name", "unknown"),
            extraction_method=meta.get("extraction_method", "caa"),
            metadata={k: v for k, v in meta.items()
                      if k not in ["behavior", "layer_index", "model_name",
                                   "extraction_method", "norm", "dim"]},
        )

    def cosine_similarity(self, other: "SteeringVector") -> float:
        """Compute cosine similarity with another vector."""
        return torch.nn.functional.cosine_similarity(
            self.vector.unsqueeze(0),
            other.vector.unsqueeze(0),
        ).item()

    def __repr__(self) -> str:
        return (
            f"SteeringVector(behavior='{self.behavior}', "
            f"layer={self.layer_index}, dim={self.dim}, norm={self.norm:.4f})"
        )

norm: float property

L2 norm of the steering vector.

dim: int property

Dimensionality of the vector.

to(device: str) -> SteeringVector

Move vector to specified device.

Source code in src/rotalabs_probe/probing/vectors.py
def to(self, device: str) -> "SteeringVector":
    """Move vector to specified device."""
    return SteeringVector(
        behavior=self.behavior,
        layer_index=self.layer_index,
        vector=self.vector.to(device),
        model_name=self.model_name,
        extraction_method=self.extraction_method,
        metadata=self.metadata,
    )

normalize() -> SteeringVector

Return unit-normalized version of this vector.

Source code in src/rotalabs_probe/probing/vectors.py
def normalize(self) -> "SteeringVector":
    """Return unit-normalized version of this vector."""
    return SteeringVector(
        behavior=self.behavior,
        layer_index=self.layer_index,
        vector=self.vector / self.norm,
        model_name=self.model_name,
        extraction_method=self.extraction_method,
        metadata={**self.metadata, "normalized": True},
    )

save(path: Path) -> None

Save vector to disk.

Creates
  • {path}.pt: The vector tensor
  • {path}_meta.json: Metadata
Source code in src/rotalabs_probe/probing/vectors.py
def save(self, path: Path) -> None:
    """Save vector to disk.

    Creates:
        - {path}.pt: The vector tensor
        - {path}_meta.json: Metadata
    """
    import json

    path = Path(path)
    path.parent.mkdir(parents=True, exist_ok=True)

    # Save tensor
    torch.save(self.vector, f"{path}.pt")

    # Save metadata
    meta = {
        "behavior": self.behavior,
        "layer_index": self.layer_index,
        "model_name": self.model_name,
        "extraction_method": self.extraction_method,
        "norm": self.norm,
        "dim": self.dim,
        **self.metadata,
    }
    with open(f"{path}_meta.json", "w") as f:
        json.dump(meta, f, indent=2)

load(path: Path) -> SteeringVector classmethod

Load vector from disk.

Source code in src/rotalabs_probe/probing/vectors.py
@classmethod
def load(cls, path: Path) -> "SteeringVector":
    """Load vector from disk."""
    import json

    path = Path(path)

    # Load tensor
    vector = torch.load(f"{path}.pt", weights_only=True)

    # Load metadata
    with open(f"{path}_meta.json") as f:
        meta = json.load(f)

    return cls(
        behavior=meta["behavior"],
        layer_index=meta["layer_index"],
        vector=vector,
        model_name=meta.get("model_name", "unknown"),
        extraction_method=meta.get("extraction_method", "caa"),
        metadata={k: v for k, v in meta.items()
                  if k not in ["behavior", "layer_index", "model_name",
                               "extraction_method", "norm", "dim"]},
    )

cosine_similarity(other: SteeringVector) -> float

Compute cosine similarity with another vector.

Source code in src/rotalabs_probe/probing/vectors.py
def cosine_similarity(self, other: "SteeringVector") -> float:
    """Compute cosine similarity with another vector."""
    return torch.nn.functional.cosine_similarity(
        self.vector.unsqueeze(0),
        other.vector.unsqueeze(0),
    ).item()
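
A brief sketch combining persistence and comparison; the random vectors and the vectors/ directory are illustrative placeholders:

import torch

from rotalabs_probe.probing.vectors import SteeringVector

a = SteeringVector("sandbagging", 10, torch.randn(768))
b = SteeringVector("sandbagging", 12, torch.randn(768))

a.save("vectors/sandbag_l10")  # writes sandbag_l10.pt and sandbag_l10_meta.json
restored = SteeringVector.load("vectors/sandbag_l10")

# cosine_similarity is scale-invariant, so normalize() is optional here.
print(a.cosine_similarity(b))
print(restored.cosine_similarity(a))  # ~1.0 after a save/load round trip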

extract_activations(model, tokenizer, texts: List[str], layer_indices: List[int], token_position: Literal['last', 'first', 'mean'] = 'last', show_progress: bool = True) -> Dict[int, torch.Tensor]

Extract activations for multiple texts at specified layers.

Source code in src/rotalabs_probe/probing/extraction.py
def extract_activations(
    model,
    tokenizer,
    texts: List[str],
    layer_indices: List[int],
    token_position: Literal["last", "first", "mean"] = "last",
    show_progress: bool = True,
) -> Dict[int, torch.Tensor]:
    """Extract activations for multiple texts at specified layers."""
    # FIXME: this is slow for large datasets, could batch but hook handling is tricky
    device = next(model.parameters()).device
    model.eval()

    # Initialize storage
    layer_activations = {idx: [] for idx in layer_indices}

    iterator = tqdm(texts, desc="Extracting", disable=not show_progress)

    for text in iterator:
        inputs = tokenizer(text, return_tensors="pt", truncation=True, max_length=512)
        inputs = {k: v.to(device) for k, v in inputs.items()}

        hook = ActivationHook(model, layer_indices, component="residual", token_position="all")

        with hook:
            with torch.no_grad():
                model(**inputs)

        for layer_idx in layer_indices:
            activation = hook.cache.get(f"layer_{layer_idx}")
            if activation is None:
                raise RuntimeError(f"Failed to capture layer {layer_idx}")

            if token_position == "last":
                act = activation[0, -1, :]
            elif token_position == "first":
                act = activation[0, 0, :]
            else:
                act = activation[0].mean(dim=0)

            layer_activations[layer_idx].append(act.cpu())

    # Stack into tensors
    return {
        idx: torch.stack(acts)
        for idx, acts in layer_activations.items()
    }
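
A usage sketch, assuming a gpt2 checkpoint; the returned dict maps each requested layer index to a stacked (n_texts, hidden_dim) tensor:

from transformers import AutoModelForCausalLM, AutoTokenizer

from rotalabs_probe.probing.extraction import extract_activations

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

texts = ["First example.", "Second example.", "Third example."]
acts = extract_activations(
    model, tokenizer, texts, layer_indices=[4, 8], token_position="mean"
)
print(acts[4].shape)  # torch.Size([3, 768]) for gpt2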

extract_caa_vector(model, tokenizer, contrast_pairs: List[Dict[str, str]], layer_idx: int, token_position: Literal['last', 'first', 'mean'] = 'last', behavior: str = 'sandbagging', show_progress: bool = True) -> SteeringVector

Extract steering vector using Contrastive Activation Addition.

The core idea: compute mean(positive_acts) - mean(negative_acts) to find the direction in activation space that corresponds to the target behavior.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model | | HuggingFace model | required |
| tokenizer | | Corresponding tokenizer | required |
| contrast_pairs | List[Dict[str, str]] | List of dicts with "positive" and "negative" keys | required |
| layer_idx | int | Which layer to extract from | required |
| token_position | Literal['last', 'first', 'mean'] | Which token position to use | 'last' |
| behavior | str | Name of the behavior being extracted | 'sandbagging' |
| show_progress | bool | Show progress bar | True |

Returns:

| Type | Description |
| --- | --- |
| SteeringVector | SteeringVector for the extracted direction |

Source code in src/rotalabs_probe/probing/extraction.py
def extract_caa_vector(
    model,
    tokenizer,
    contrast_pairs: List[Dict[str, str]],
    layer_idx: int,
    token_position: Literal["last", "first", "mean"] = "last",
    behavior: str = "sandbagging",
    show_progress: bool = True,
) -> SteeringVector:
    """Extract steering vector using Contrastive Activation Addition.

    The core idea: compute mean(positive_acts) - mean(negative_acts)
    to find the direction in activation space that corresponds to
    the target behavior.

    Args:
        model: HuggingFace model
        tokenizer: Corresponding tokenizer
        contrast_pairs: List of dicts with "positive" and "negative" keys
        layer_idx: Which layer to extract from
        token_position: Which token position to use
        behavior: Name of the behavior being extracted
        show_progress: Show progress bar

    Returns:
        SteeringVector for the extracted direction
    """
    device = next(model.parameters()).device
    model.eval()

    positive_activations = []
    negative_activations = []

    iterator = tqdm(contrast_pairs, desc=f"Layer {layer_idx}", disable=not show_progress)

    for pair in iterator:
        pos_text = pair["positive"]
        neg_text = pair["negative"]

        # Extract positive activation
        pos_act = _get_activation(
            model, tokenizer, pos_text, layer_idx, token_position, device
        )
        positive_activations.append(pos_act)

        # Extract negative activation
        neg_act = _get_activation(
            model, tokenizer, neg_text, layer_idx, token_position, device
        )
        negative_activations.append(neg_act)

    # Compute mean activations
    pos_mean = torch.stack(positive_activations).mean(dim=0)
    neg_mean = torch.stack(negative_activations).mean(dim=0)

    # NOTE: this is the core of CAA - surprisingly simple but it works
    # see the original paper for theoretical justification
    steering_vector = pos_mean - neg_mean

    model_name = getattr(model.config, "_name_or_path", "unknown")

    return SteeringVector(
        behavior=behavior,
        layer_index=layer_idx,
        vector=steering_vector.cpu(),
        model_name=model_name,
        extraction_method="caa",
        metadata={
            "num_pairs": len(contrast_pairs),
            "token_position": token_position,
            "pos_mean_norm": pos_mean.norm().item(),
            "neg_mean_norm": neg_mean.norm().item(),
        },
    )
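
Finally, a sketch of CAA extraction end to end, assuming gpt2 and a handful of placeholder contrast pairs; real use would need many more pairs:

from transformers import AutoModelForCausalLM, AutoTokenizer

from rotalabs_probe.probing.extraction import extract_caa_vector

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

# "positive" exhibits the target behavior; "negative" does not.
contrast_pairs = [
    {"positive": "I will pretend not to know this.",
     "negative": "The answer is 42."},
    {"positive": "I'd rather give a vague reply.",
     "negative": "Paris is the capital of France."},
]

vec = extract_caa_vector(
    model, tokenizer, contrast_pairs, layer_idx=6, token_position="last"
)
print(vec)  # SteeringVector(behavior='sandbagging', layer=6, ...)
vec.save("vectors/sandbagging_l6")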