Drrik

Drrik (দৃক) is a Sanskrit (সংস্কৃত) word meaning knowledge, eye, and direction. The framework extracts interpretable features from the MLP layers of transformer-based large language models using sparse autoencoders (SAEs), inspired by Anthropic's Towards Monosemanticity paper.

Installation

Setup with UV (recommended)

uv sync
source .venv/bin/activate

Setup with pip

# Create virtual environment
python -m venv .venv
source .venv/bin/activate

# Install dependencies
pip install -e .

Quick Start

Option 1: Using the CLI (Recommended)

The CLI provides a simple way to run the full pipeline with a YAML configuration file.

# Generate an example config file
drrik init-config -o config.yml

# Edit config.yml to customize your settings

# Run the full pipeline (extract -> train -> visualize)
drrik run config.yml

# Or run individual steps
drrik extract -c config.yml
drrik train -c config.yml
drrik visualize -c config.yml

Option 2: Python API

from drrik import ActivationExtractor, SparseAutoencoder, FeatureVisualizer

# 1. Extract MLP activations
extractor = ActivationExtractor(
    model_name="google/gemma-2b",  # 2B parameters, fits on 8GB VRAM
    dataset_name="wikitext",
    mlp_layers=[0],
    num_samples=1000,
)
activations, metadata = extractor.extract()

# 2. Train Sparse Autoencoder
sae = SparseAutoencoder(
    activation_dim=activations.shape[-1],
    hidden_dim=activations.shape[-1] * 8,  # 8x expansion
    l1_coefficient=0.01,
)
sae.fit(activations, num_epochs=50)

# 3. Visualize features
visualizer = FeatureVisualizer(sae, activations, metadata)
visualizer.save_all(n_features=10)

CLI Commands

The drrik CLI provides several commands:

  • drrik init-config - Generate an example YAML configuration file
  • drrik extract - Extract MLP activations from a model
  • drrik train - Train a sparse autoencoder
  • drrik visualize - Generate feature visualizations
  • drrik run - Run the full pipeline

Each command supports additional options:

drrik extract --config config.yml --output-dir ./outputs --device cuda
drrik train --config config.yml --activations ./outputs/activations.pkl
drrik visualize --config config.yml --n-features 20 --no-wandb

Activation Steering

After training an SAE, use the learned feature directions to steer model generation:

from drrik import SAESteering, SparseAutoencoder

sae = SparseAutoencoder.load("sae_model.pt")
steering = SAESteering(sae, model_name="google/gemma-2b", layer=0)

# Generate with a single feature
result = steering.generate(
    "The weather is",
    feature_idx=42,
    strength=3.0,
    max_new_tokens=50,
)

# Compare baseline vs steered at multiple strengths
comparison = steering.compare_steering(
    "The weather is",
    feature_idx=42,
    strengths=[0.0, 1.0, 2.0, 5.0],
)

# Combine multiple features
result = steering.generate(
    "The weather is",
    feature_indices=[10, 42],
    strengths=[1.0, 2.0],
)

# Baseline (no steering)
baseline = steering.generate("The weather is", max_new_tokens=50)

Example Scripts

# Basic usage example
python examples/basic_usage.py

# Advanced pipeline with custom configs
python examples/advanced_pipeline.py

# Load saved activations
python examples/load_saved_activations.py

# Using wandb integration
python examples/with_wandb.py

# SAE-based activation steering
python examples/sae_steering.py

Configuration

YAML Configuration (CLI)

The CLI uses YAML configuration files for easy setup:

# Model configuration
model_name: "google/gemma-2b"
torch_dtype: "float16"
device_map: "cpu"

# Dataset configuration
dataset_name: "wikitext"
dataset_config: "wikitext-2-raw-v1"
split: "train"
num_samples: 1000
max_length: 512
extraction_batch_size: 8

# Activation extraction
mlp_layers: [0]

# Sparse Autoencoder configuration
activation_dim: 2048
hidden_dim: 16384  # 8x expansion
l1_coefficient: 0.01
learning_rate: 0.0001
num_epochs: 50
training_batch_size: 256
validation_split: 0.1
resample_dead_neurons: true

# Visualization
n_features_to_visualize: 10

# Hardware configuration
extraction_device: "cpu"
training_device: "mps"

# Wandb integration (optional)
wandb_enabled: false
wandb_project: "drrik-experiments"

# Output
output_dir: "./drrik_output"

Generate an example config with: drrik init-config -o config.yml
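
Before committing to an expansion factor, it can help to estimate how large the resulting SAE will be. The sketch below is not part of drrik; sae_summary is a hypothetical helper, and it assumes one encoder matrix, one decoder matrix, and one bias vector each for the hidden and activation dimensions.

```python
def sae_summary(activation_dim: int, hidden_dim: int, bytes_per_param: int = 4) -> dict:
    """Estimate expansion factor, parameter count, and weight size of an SAE.

    Assumed layout (not confirmed by the drrik source): encoder and decoder
    matrices of activation_dim * hidden_dim entries each, plus an encoder
    bias of size hidden_dim and a decoder bias of size activation_dim.
    """
    if hidden_dim <= activation_dim:
        raise ValueError("an overcomplete SAE needs hidden_dim > activation_dim")
    weights = 2 * activation_dim * hidden_dim
    biases = hidden_dim + activation_dim
    params = weights + biases
    return {
        "expansion": hidden_dim / activation_dim,
        "params": params,
        "megabytes": params * bytes_per_param / 1e6,
    }


# Values taken from the example YAML config above.
summary = sae_summary(activation_dim=2048, hidden_dim=16384)
print(summary)
```

Under these assumptions the example config gives an 8x expansion with about 67 million parameters, roughly 268 MB of weights in float32. Optimizer state and activation batches add to that during training.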

Python API Configuration

The framework uses Pydantic for configuration. Key configuration classes:

ActivationExtractorConfig

from drrik.config import ActivationExtractorConfig, ModelConfig, DatasetConfig

config = ActivationExtractorConfig(
    model=ModelConfig(
        model_name="google/gemma-2b",
        torch_dtype="float16",
    ),
    dataset=DatasetConfig(
        dataset_name="wikitext",
        num_samples=1000,
        max_length=512,
    ),
    mlp_layers=[0, 1, 2],
    batch_size=8,
)

SparseAutoencoderConfig

from drrik.config import SparseAutoencoderConfig

sae_config = SparseAutoencoderConfig(
    activation_dim=2048,
    hidden_dim=4096,  # 2x expansion
    l1_coefficient=0.01,
    learning_rate=1e-4,
    resample_dead_neurons=True,
)

Environment Variables

For API keys and optional settings, create a .env file (see .env.example):

# HuggingFace Hub token (for gated models)
HUGGINGFACE_HUB_TOKEN=your_token_here

# Weights & Biases API key (optional, for experiment tracking)
WANDB_API_KEY=your_wandb_key_here

# Wandb settings
WANDB_PROJECT=drrik-experiments
WANDB_ENTITY=your_username
WANDB_MODE=online  # 'offline' logs locally without syncing; 'disabled' turns wandb off

Key Features

Supported Models

Any HuggingFace transformer model with MLP layers.

Important

For Apple Silicon users: Models like gemma-2b have internal weight matrices that exceed MPS kernel limits. Use device_map: "cpu" for activation extraction and training_device: "mps" for SAE training. Alternatively, use a smaller batch size and a smaller hidden dimension (that is, a smaller expansion factor).

Supported Datasets

Any dataset from HuggingFace Datasets.

SAE Features

Following the Anthropic paper:

  • Overcomplete basis: Hidden dimension > activation dimension
  • L1 sparsity: Encourages sparse feature activations
  • Decoder normalization: Prevents scaling collapse
  • Pre-encoder bias: As used in the paper
  • Dead neuron resampling: Reinitializes inactive neurons during training using a sliding window of recent batches for robust detection
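
The first four items can be illustrated with a single forward pass. This is a minimal pure-Python sketch of the general technique, not drrik's implementation: the function names are hypothetical, the decoder is stored as one direction vector per feature, and the loss is assumed to be the summed squared error plus the L1 penalty.

```python
import math
from typing import List, Tuple

Vector = List[float]


def normalize_rows(directions: List[Vector]) -> List[Vector]:
    """Decoder normalization: scale each feature direction to unit L2 norm."""
    normalized = []
    for d in directions:
        norm = math.sqrt(sum(a * a for a in d)) or 1.0  # leave zero vectors alone
        normalized.append([a / norm for a in d])
    return normalized


def sae_forward(
    x: Vector,
    w_enc: List[Vector],       # one encoder row per feature
    b_enc: Vector,             # one encoder bias per feature
    directions: List[Vector],  # one decoder direction per feature
    b_dec: Vector,             # decoder bias, also used as pre-encoder bias
    l1_coefficient: float = 0.01,
) -> Tuple[Vector, Vector, float]:
    """Return (features, reconstruction, loss) for one activation vector."""
    # Pre-encoder bias: subtract the decoder bias before encoding.
    centered = [a - b for a, b in zip(x, b_dec)]
    # ReLU encoder; with hidden_dim > activation_dim this is an overcomplete basis.
    features = [
        max(0.0, sum(w * c for w, c in zip(row, centered)) + b)
        for row, b in zip(w_enc, b_enc)
    ]
    # Reconstruction: decoder bias plus the weighted sum of feature directions.
    recon = list(b_dec)
    for f, d in zip(features, directions):
        for i, a in enumerate(d):
            recon[i] += f * a
    mse = sum((a - r) ** 2 for a, r in zip(x, recon))
    l1 = sum(features)  # features are non-negative after ReLU
    return features, recon, mse + l1_coefficient * l1


# Toy example: 2-dimensional activations, 3 features.
directions = normalize_rows([[3.0, 4.0], [1.0, 0.0], [0.0, 2.0]])
features, recon, loss = sae_forward(
    x=[1.0, -1.0],
    w_enc=directions,  # encoder initialised to the decoder directions
    b_enc=[0.0, 0.0, 0.0],
    directions=directions,
    b_dec=[0.0, 0.0],
)
print(features, recon, loss)
```

In the toy example only the second feature fires, so the L1 term stays small while the reconstruction misses the negative component of the input. Training trades these two terms off through l1_coefficient.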

Activation Steering

Use trained SAE features to steer language model generation by adding scaled decoder weight vectors to MLP activations during inference:

  • Single-feature steering: Bias output toward one learned feature direction
  • Multi-feature steering: Combine multiple feature directions with individual strengths
  • Comparison tools: Compare baseline vs steered outputs across strength levels
  • Feature analysis: Find top-activating features for a given input
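
The core operation behind both steering modes is a vector addition. The sketch below shows the idea on plain Python lists; steer is a hypothetical helper, and how SAESteering hooks this into the model is not shown here.

```python
import math
from typing import List, Sequence

Vector = List[float]


def steer(
    activation: Vector,
    directions: Sequence[Vector],  # one decoder direction per SAE feature
    feature_indices: Sequence[int],
    strengths: Sequence[float],
    normalize: bool = True,
) -> Vector:
    """Add scaled decoder directions to one MLP activation vector."""
    if len(feature_indices) != len(strengths):
        raise ValueError("feature_indices and strengths must have the same length")
    steered = list(activation)
    for idx, strength in zip(feature_indices, strengths):
        direction = directions[idx]
        if normalize:
            # Unit-normalise so strength is comparable across features.
            norm = math.sqrt(sum(a * a for a in direction)) or 1.0
            direction = [a / norm for a in direction]
        steered = [s + strength * a for s, a in zip(steered, direction)]
    return steered


directions = [[1.0, 0.0], [0.0, 2.0]]
activation = [0.5, 0.5]

single = steer(activation, directions, [1], [3.0])            # one feature
multi = steer(activation, directions, [0, 1], [1.0, 2.0])     # two features combined
baseline = steer(activation, directions, [1], [0.0])          # strength 0 changes nothing
print(single, multi, baseline)
```

A strength of 0.0 reproduces the unsteered activation, which is why it serves as the baseline when comparing strength levels.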

Wandb Integration

Optional wandb integration for experiment tracking:

from drrik import WandbConfig, get_settings

settings = get_settings()
wandb_config = WandbConfig(
    project="drrik-experiments",
    name="my-experiment",
    config={"model": "gemma-2b", "expansion": 8},
    enabled=settings.use_wandb,  # Auto-disables if no API key
)

# Use in training
sae.fit(activations, wandb_config=wandb_config, wandb_enabled=True)

# Use in visualization
visualizer = FeatureVisualizer(
    sae=sae,
    activations=activations,
    wandb_config=wandb_config,
    log_to_wandb=True,
)

The framework automatically logs:

  • Training metrics (loss, L0 norm, dead neurons)
  • Learning rate changes
  • Activation histograms
  • Feature visualizations

Visualization Outputs

The framework generates several visualizations:

  1. Feature Density Histogram - Distribution of feature firing rates
  2. Training Curves - Loss and L0 norm over training
  3. Top Features - Features ranked by density/activation
  4. Feature Dashboards - Comprehensive view per feature
  5. Activation Histograms - Distribution of activations per feature

All plots can be saved locally and optionally logged to wandb.
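
Two quantities recur in these plots: feature density (how often a feature fires) and the L0 norm (how many features fire per sample). A minimal sketch of both follows, assuming the encoded features are a list of per-sample vectors and that a feature "fires" when its activation is above zero; the function names are hypothetical.

```python
from typing import List

Matrix = List[List[float]]  # one row per sample, one column per feature


def feature_density(features: Matrix) -> List[float]:
    """Fraction of samples on which each feature is active (> 0)."""
    n_samples = len(features)
    return [
        sum(1 for value in column if value > 0.0) / n_samples
        for column in zip(*features)
    ]


def mean_l0(features: Matrix) -> float:
    """Average number of active features per sample."""
    return sum(sum(1 for v in row if v > 0.0) for row in features) / len(features)


encoded = [
    [0.0, 1.2, 0.0],
    [0.3, 0.0, 0.0],
    [0.0, 2.0, 0.0],
    [0.0, 0.5, 0.0],
]
density = feature_density(encoded)
l0 = mean_l0(encoded)
print(density, l0)
```

Here the third feature never fires, so its density is zero, and every sample activates exactly one feature, giving a mean L0 of 1.0.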

API Reference

The full API documentation with Google-style docstrings is available in the generated docs:

# Regenerate docs (requires pdoc)
.venv/bin/pdoc drrik -o docs --docformat google

# View locally
open docs/drrik.html

ActivationExtractor

extractor = ActivationExtractor(
    model_name="google/gemma-2b",
    dataset_name="wikitext",
    num_samples=1000,
    mlp_layers=[0],
    batch_size=8,
)

activations, metadata = extractor.extract()

SparseAutoencoder

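sae = SparseAutoencoder(
    activation_dim=2048,          # int: width of the extracted activations
    hidden_dim=16384,             # int: number of SAE features
    l1_coefficient=0.01,          # float, default 0.01
)

sae.fit(
    activations,                  # np.ndarray
    batch_size=256,               # int, default 256
    num_epochs=100,               # int, default 100
    learning_rate=1e-4,           # float, default 1e-4
    resample_dead_neurons=True,   # bool, default True
    resample_interval=10000,      # int, default 10000
    dead_threshold=1e-8,          # float, default 1e-8
    window_size=100,              # int, default 100
    wandb_config=None,            # Optional[WandbConfig], default None
    wandb_enabled=False,          # bool, default False
)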

features = sae.encode(activations)
reconstructed = sae.decode(features)
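
The dead_threshold and window_size parameters of fit relate to dead neuron detection over a sliding window of recent batches. One way such detection could work is sketched below; DeadFeatureTracker is a hypothetical class and the exact criterion drrik uses is an assumption.

```python
from collections import deque
from typing import Deque, List


class DeadFeatureTracker:
    """Flag features that stayed below a threshold for a full window of batches."""

    def __init__(self, n_features: int, window_size: int = 100,
                 dead_threshold: float = 1e-8) -> None:
        self.n_features = n_features
        self.dead_threshold = dead_threshold
        # deque(maxlen=...) drops the oldest batch automatically.
        self.window: Deque[List[float]] = deque(maxlen=window_size)

    def update(self, batch_features: List[List[float]]) -> None:
        """Record the peak activation of every feature in one batch."""
        peaks = [max(column) for column in zip(*batch_features)]
        self.window.append(peaks)

    def dead(self) -> List[int]:
        """Return indices of features that never exceeded the threshold."""
        if len(self.window) < (self.window.maxlen or 0):
            return []  # not enough batches seen yet to judge
        return [
            j for j in range(self.n_features)
            if all(peaks[j] <= self.dead_threshold for peaks in self.window)
        ]


tracker = DeadFeatureTracker(n_features=3, window_size=2)
tracker.update([[0.0, 1.0, 0.0]])
early = tracker.dead()            # window not full yet
tracker.update([[0.5, 0.0, 0.0]])
after_two = tracker.dead()        # feature 2 was silent in both batches
tracker.update([[0.0, 0.0, 0.0]])
after_three = tracker.dead()      # oldest batch dropped; feature 1 is now silent too
print(early, after_two, after_three)
```

Waiting until the window is full avoids resampling features that simply had no chance to fire in the first few batches.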

FeatureVisualizer

visualizer = FeatureVisualizer(
    sae=sae,                          # SparseAutoencoder
    activations=activations,          # np.ndarray
    metadata=None,                    # Optional[Dict], default None
    output_dir="./visualizations",    # str, default "./visualizations"
    wandb_config=None,                # Optional[WandbConfig], default None
    log_to_wandb=False,               # bool, default False
)

visualizer.plot_feature_density()
visualizer.plot_top_features(n_features=10)
visualizer.create_feature_dashboard(feature_idx=0)
visualizer.save_all(n_features=10)

SAESteering

from drrik import SAESteering, SparseAutoencoder

sae = SparseAutoencoder.load("sae_model.pt")
steering = SAESteering(
    sae=sae,
    model_name="google/gemma-2b",
    layer=0,               # target MLP layer
    device_map="auto",
    # token is read from HUGGINGFACE_HUB_TOKEN in .env by default
)

# Steered generation
output = steering.generate("The sky is", feature_idx=42, strength=2.0)

# Find top features for an input
features = steering.find_steering_features("text", activations, top_k=20)

# Get raw steering direction vector
direction = steering.get_steering_direction(42, normalize=True)

# Save comparison results
steering.save_steering_analysis(results, "./output", prompt="The sky is")

Testing

Run the test suite:

# Run all tests
pytest

# Skip slow tests
pytest -m "not slow"

License

AGPL-3.0

Contributing

Contributions welcome! Please feel free to submit issues or pull requests.
