Drrik (দৃক) is a Sanskrit (সংস্কৃত) word meaning knowledge, eye, and direction. Drrik is a framework for extracting interpretable features from the MLP layers of transformer-based Large Language Models using Sparse Autoencoders (SAEs), inspired by Anthropic's Towards Monosemanticity paper.
# With uv
uv sync
source .venv/bin/activate
# Or with pip: create a virtual environment
python -m venv .venv
source .venv/bin/activate
# Install dependencies
pip install -e .
The CLI provides a simple way to run the full pipeline with a YAML configuration file.
# Generate an example config file
drrik init-config -o config.yml
# Edit config.yml to customize your settings
# Run the full pipeline (extract -> train -> visualize)
drrik run config.yml
# Or run individual steps
drrik extract -c config.yml
drrik train -c config.yml
drrik visualize -c config.yml
Or use the Python API directly:
from drrik import ActivationExtractor, SparseAutoencoder, FeatureVisualizer
# 1. Extract MLP activations
extractor = ActivationExtractor(
model_name="google/gemma-2b", # 2B parameters, fits on 8GB VRAM
dataset_name="wikitext",
mlp_layers=[0],
num_samples=1000,
)
activations, metadata = extractor.extract()
# 2. Train Sparse Autoencoder
sae = SparseAutoencoder(
activation_dim=activations.shape[-1],
hidden_dim=activations.shape[-1] * 8, # 8x expansion
l1_coefficient=0.01,
)
sae.fit(activations, num_epochs=50)
# 3. Visualize features
visualizer = FeatureVisualizer(sae, activations, metadata)
visualizer.save_all(n_features=10)
The drrik CLI provides several commands:
- drrik init-config: Generate an example YAML configuration file
- drrik extract: Extract MLP activations from a model
- drrik train: Train a sparse autoencoder
- drrik visualize: Generate feature visualizations
- drrik run: Run the full pipeline
Each command supports additional options:
drrik extract --config config.yml --output-dir ./outputs --device cuda
drrik train --config config.yml --activations ./outputs/activations.pkl
drrik visualize --config config.yml --n-features 20 --no-wandb
After training an SAE, use the learned feature directions to steer model generation:
from drrik import SAESteering, SparseAutoencoder
sae = SparseAutoencoder.load("sae_model.pt")
steering = SAESteering(sae, model_name="google/gemma-2b", layer=0)
# Generate with a single feature
result = steering.generate(
"The weather is",
feature_idx=42,
strength=3.0,
max_new_tokens=50,
)
# Compare baseline vs steered at multiple strengths
comparison = steering.compare_steering(
"The weather is",
feature_idx=42,
strengths=[0.0, 1.0, 2.0, 5.0],
)
# Combine multiple features
result = steering.generate(
"The weather is",
feature_indices=[10, 42],
strengths=[1.0, 2.0],
)
# Baseline (no steering)
baseline = steering.generate("The weather is", max_new_tokens=50)
The examples/ directory contains runnable scripts:
# Basic usage example
python examples/basic_usage.py
# Advanced pipeline with custom configs
python examples/advanced_pipeline.py
# Load saved activations
python examples/load_saved_activations.py
# Using wandb integration
python examples/with_wandb.py
# SAE-based activation steering
python examples/sae_steering.py
The CLI uses YAML configuration files for easy setup:
# Model configuration
model_name: "google/gemma-2b"
torch_dtype: "float16"
device_map: "cpu"
# Dataset configuration
dataset_name: "wikitext"
dataset_config: "wikitext-2-raw-v1"
split: "train"
num_samples: 1000
max_length: 512
extraction_batch_size: 8
# Activation extraction
mlp_layers: [0]
# Sparse Autoencoder configuration
activation_dim: 2048
hidden_dim: 16384 # 8x expansion
l1_coefficient: 0.01
learning_rate: 0.0001
num_epochs: 50
training_batch_size: 256
validation_split: 0.1
resample_dead_neurons: true
# Visualization
n_features_to_visualize: 10
# Hardware configuration
extraction_device: "cpu"
training_device: "mps"
# Wandb integration (optional)
wandb_enabled: false
wandb_project: "drrik-experiments"
# Output
output_dir: "./drrik_output"
Generate an example config with: drrik init-config -o config.yml
The framework uses Pydantic for configuration. Key configuration classes:
from drrik.config import ActivationExtractorConfig, ModelConfig, DatasetConfig
config = ActivationExtractorConfig(
model=ModelConfig(
model_name="google/gemma-2b",
torch_dtype="float16",
),
dataset=DatasetConfig(
dataset_name="wikitext",
num_samples=1000,
max_length=512,
),
mlp_layers=[0, 1, 2],
batch_size=8,
)
from drrik.config import SparseAutoencoderConfig
sae_config = SparseAutoencoderConfig(
activation_dim=2048,
hidden_dim=4096, # 2x expansion
l1_coefficient=0.01,
learning_rate=1e-4,
resample_dead_neurons=True,
)
For API keys and optional settings, create a .env file (see .env.example):
# HuggingFace Hub token (for gated models)
HUGGINGFACE_HUB_TOKEN=your_token_here
# Weights & Biases API key (optional, for experiment tracking)
WANDB_API_KEY=your_wandb_key_here
# Wandb settings
WANDB_PROJECT=drrik-experiments
WANDB_ENTITY=your_username
WANDB_MODE=online # or 'offline' to log locally, 'disabled' to turn logging off
Supported models: any HuggingFace transformer model with MLP layers.
Important: For Apple Silicon users, models such as gemma-2b contain internal weight matrices that exceed MPS kernel limits. Use device_map: "cpu" for activation extraction and training_device: "mps" for SAE training. Alternatively, use a smaller batch size and a smaller hidden dimension (that is, a smaller expansion factor).
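The hidden dimension is the main lever on SAE size. The sketch below estimates parameter count and fp32 weight memory for several expansion factors at the activation_dim: 2048 used in the example config. It assumes an untied encoder and decoder with one bias vector each, as in the Anthropic paper's setup; drrik's exact parameterization may differ, and sae_footprint is a hypothetical helper, not part of drrik.

```python
def sae_footprint(activation_dim, expansion, bytes_per_param=4):
    """Estimate SAE size: encoder and decoder matrices plus both bias vectors.

    Returns (hidden_dim, n_params, weight_bytes). Optimizer state, gradients and
    activation batches are not counted; they add more memory during training.
    """
    hidden_dim = activation_dim * expansion
    # W_enc (d x m) + W_dec (m x d) + encoder bias (m) + decoder bias (d)
    n_params = 2 * activation_dim * hidden_dim + hidden_dim + activation_dim
    return hidden_dim, n_params, n_params * bytes_per_param


# activation_dim=2048 matches the example config for google/gemma-2b.
for expansion in (2, 4, 8):
    hidden_dim, n_params, n_bytes = sae_footprint(2048, expansion)
    print(f"{expansion}x: hidden_dim={hidden_dim}, params={n_params:,}, "
          f"weights ~{n_bytes / 2**20:.0f} MiB in fp32")
```

At 8x expansion the weights alone come to about 67.1 million parameters (roughly 256 MiB in fp32), and halving the expansion factor roughly halves that.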
Supported datasets: any dataset from HuggingFace Datasets.
The SAE architecture follows the Anthropic paper:
- Overcomplete basis: Hidden dimension > activation dimension
- L1 sparsity: Encourages sparse feature activations
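- Decoder normalization: Decoder weight vectors are kept at unit norm, which prevents scaling collapse (the L1 penalty being sidestepped by shrinking activations while growing decoder weights)
- Pre-encoder bias: The decoder bias is subtracted from the input before encoding, as in the paper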
- Dead neuron resampling: Reinitializes inactive neurons during training using a sliding window of recent batches for robust detection
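To make the list above concrete, here is a minimal NumPy sketch of the paper's formulation: a pre-encoder bias, a ReLU encoder, a unit-norm decoder, an MSE plus L1 loss, and sliding-window detection of dead features. It is not drrik's implementation; all function and class names are hypothetical, the loss averaging is one common choice, and only dead-feature detection is shown (the reinitialization step of resampling is omitted).

```python
from collections import deque

import numpy as np


def normalize_decoder(W_dec):
    """Rescale each feature's decoder row to unit L2 norm."""
    return W_dec / np.linalg.norm(W_dec, axis=1, keepdims=True)


def init_sae(d, m, seed=0):
    """Random parameters for d = activation_dim inputs and m = hidden_dim features (m > d)."""
    rng = np.random.default_rng(seed)
    W_enc = rng.normal(0.0, 1.0 / np.sqrt(d), size=(d, m))
    W_dec = normalize_decoder(rng.normal(size=(m, d)))
    return W_enc, np.zeros(m), W_dec, np.zeros(d)


def sae_forward(x, W_enc, b_enc, W_dec, b_dec):
    """Encode then decode a batch x of shape (batch, d)."""
    x_bar = x - b_dec                           # pre-encoder bias: subtract decoder bias
    f = np.maximum(0.0, x_bar @ W_enc + b_enc)  # ReLU feature activations, (batch, m)
    x_hat = f @ W_dec + b_dec                   # reconstruction, (batch, d)
    return f, x_hat


def sae_loss(x, f, x_hat, l1_coefficient=0.01):
    """Squared reconstruction error plus an L1 penalty on feature activations."""
    mse = np.mean(np.sum((x - x_hat) ** 2, axis=-1))
    l1 = np.mean(np.sum(np.abs(f), axis=-1))
    return mse + l1_coefficient * l1


class DeadFeatureTracker:
    """Marks features that did not fire in any of the last `window_size` batches."""

    def __init__(self, window_size=100):
        self.window = deque(maxlen=window_size)  # oldest batch drops out automatically

    def update(self, f):
        self.window.append((f > 0).any(axis=0))  # per feature: fired in this batch?

    def dead_mask(self):
        return ~np.any(np.stack(list(self.window)), axis=0)


# Toy run: random data stands in for MLP activations (d=16, 8x expansion).
x = np.random.default_rng(1).normal(size=(32, 16))
W_enc, b_enc, W_dec, b_dec = init_sae(16, 128)
f, x_hat = sae_forward(x, W_enc, b_enc, W_dec, b_dec)
loss = sae_loss(x, f, x_hat)
tracker = DeadFeatureTracker(window_size=100)
tracker.update(f)
```

Using a window of recent batches rather than a single batch avoids flagging rare but live features as dead just because they happened not to fire once.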
Use trained SAE features to steer language model generation by adding scaled decoder weight vectors to MLP activations during inference:
- Single-feature steering: Bias output toward one learned feature direction
- Multi-feature steering: Combine multiple feature directions with individual strengths
- Comparison tools: Compare baseline vs steered outputs across strength levels
- Feature analysis: Find top-activating features for a given input
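The steering arithmetic described above can be sketched with plain NumPy arrays. This assumes the decoder is stored as W_dec with shape (hidden_dim, activation_dim), so row i is feature i's direction; steering_direction, steer, and top_features are hypothetical names, not drrik's SAESteering API. In a real model the addition would happen inside the forward pass at the target layer, for example via a PyTorch forward hook (torch.nn.Module.register_forward_hook); whether drrik uses that mechanism is not stated here.

```python
import numpy as np


def steering_direction(W_dec, feature_idx, normalize=True):
    """Decoder row for one feature, optionally scaled to unit length."""
    direction = W_dec[feature_idx]
    return direction / np.linalg.norm(direction) if normalize else direction


def steer(mlp_out, W_dec, feature_indices, strengths, normalize=True):
    """Add strength_i * direction_i at every token position of an MLP output.

    mlp_out has shape (seq_len, activation_dim); a strength of 0.0 is the baseline.
    """
    steered = mlp_out.astype(float).copy()
    for idx, strength in zip(feature_indices, strengths):
        steered += strength * steering_direction(W_dec, idx, normalize)
    return steered


def top_features(x, W_enc, b_enc, b_dec, top_k=5):
    """Indices of features with the largest mean activation over the tokens in x."""
    f = np.maximum(0.0, (x - b_dec) @ W_enc + b_enc)
    return np.argsort(f.mean(axis=0))[::-1][:top_k]


# Toy usage: a random 64-feature decoder over 8-dim activations, 5 token positions.
rng = np.random.default_rng(0)
W_dec = rng.normal(size=(64, 8))
mlp_out = rng.normal(size=(5, 8))
single = steer(mlp_out, W_dec, feature_indices=[42], strengths=[3.0])
multi = steer(mlp_out, W_dec, feature_indices=[10, 42], strengths=[1.0, 2.0])
```

With normalize=True, strength is the exact length of the vector added at each position, which makes strengths comparable across features whose decoder rows have different norms.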
Optional wandb integration for experiment tracking:
from drrik import FeatureVisualizer, WandbConfig, get_settings
settings = get_settings()
wandb_config = WandbConfig(
project="drrik-experiments",
name="my-experiment",
config={"model": "gemma-2b", "expansion": 8},
enabled=settings.use_wandb, # Auto-disables if no API key
)
# Use in training
sae.fit(activations, wandb_config=wandb_config, wandb_enabled=True)
# Use in visualization
visualizer = FeatureVisualizer(
sae=sae,
activations=activations,
wandb_config=wandb_config,
log_to_wandb=True,
)
The framework automatically logs:
- Training metrics (loss, L0 norm, dead neurons)
- Learning rate changes
- Activation histograms
- Feature visualizations
The framework generates several visualizations:
- Feature Density Histogram - Distribution of feature firing rates
- Training Curves - Loss and L0 norm over training
- Top Features - Features ranked by density/activation
- Feature Dashboards - Comprehensive view per feature
- Activation Histograms - Distribution of activations per feature
All plots can be saved locally and optionally logged to wandb.
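The quantities behind the density histogram, the L0 training metric, and density-ranked top features can be computed from an encoded activation matrix as sketched below. The function names are hypothetical and this is not FeatureVisualizer's code; it assumes f holds feature activations with shape (n_tokens, hidden_dim), and uses a log10 x-axis, the scale commonly used for feature density plots.

```python
import numpy as np


def feature_density(f, eps=1e-10):
    """Per-feature firing rate over tokens, plus its log10 for a density histogram.

    Features that never fire land at log10(eps), forming a separate spike of dead features.
    """
    density = (f > 0).mean(axis=0)
    return density, np.log10(density + eps)


def mean_l0(f):
    """Average number of active features per token."""
    return float((f > 0).sum(axis=1).mean())


def top_by_density(f, n_features=10):
    """Feature indices sorted from most to least frequently active."""
    density, _ = feature_density(f)
    return np.argsort(density)[::-1][:n_features]


# Toy activations: 3 tokens x 4 features.
f = np.array([[0.0, 1.2, 0.0, 0.3],
              [0.0, 0.5, 0.0, 0.0],
              [0.0, 0.9, 2.0, 0.0]])
density, log_density = feature_density(f)
counts, edges = np.histogram(log_density, bins=10)  # bar heights and bin edges to plot
```

Here feature 0 never fires (a dead feature), feature 1 fires on every token, and the average L0 is 5/3 active features per token.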
The full API documentation with Google-style docstrings is available in the generated docs:
# Regenerate docs (requires pdoc)
.venv/bin/pdoc drrik -o docs --docformat google
# View locally
open docs/drrik.html # macOS; use xdg-open on Linux
ActivationExtractor:
extractor = ActivationExtractor(
model_name="google/gemma-2b",
dataset_name="wikitext",
num_samples=1000,
mlp_layers=[0],
batch_size=8,
)
activations, metadata = extractor.extract()
SparseAutoencoder:
sae = SparseAutoencoder(
activation_dim: int,
hidden_dim: int,
l1_coefficient: float = 0.01,
)
sae.fit(
activations: np.ndarray,
batch_size: int = 256,
num_epochs: int = 100,
learning_rate: float = 1e-4,
resample_dead_neurons: bool = True,
resample_interval: int = 10000,
dead_threshold: float = 1e-8,
window_size: int = 100,
wandb_config: Optional[WandbConfig] = None,
wandb_enabled: bool = False,
)
features = sae.encode(activations)
reconstructed = sae.decode(features)
FeatureVisualizer:
visualizer = FeatureVisualizer(
sae: SparseAutoencoder,
activations: np.ndarray,
metadata: Optional[Dict] = None,
output_dir: str = "./visualizations",
wandb_config: Optional[WandbConfig] = None,
log_to_wandb: bool = False,
)
visualizer.plot_feature_density()
visualizer.plot_top_features(n_features=10)
visualizer.create_feature_dashboard(feature_idx=0)
visualizer.save_all(n_features=10)
SAESteering:
from drrik import SAESteering, SparseAutoencoder
sae = SparseAutoencoder.load("sae_model.pt")
steering = SAESteering(
sae=sae,
model_name="google/gemma-2b",
layer=0, # target MLP layer
device_map="auto",
# token is read from HUGGINGFACE_HUB_TOKEN in .env by default
)
# Steered generation
output = steering.generate("The sky is", feature_idx=42, strength=2.0)
# Find top features for an input
features = steering.find_steering_features("text", activations, top_k=20)
# Get raw steering direction vector
direction = steering.get_steering_direction(42, normalize=True)
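# Compare strengths, then save the comparison results
results = steering.compare_steering("The sky is", feature_idx=42, strengths=[0.0, 1.0, 2.0])
steering.save_steering_analysis(results, "./output", prompt="The sky is")
Run the test suite: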
# Run all tests
pytest
# Skip slow tests
pytest -m "not slow"
References:
- Towards Monosemanticity: Decomposing Language Models With Dictionary Learning - Anthropic
- nnsight Library - For activation extraction
- Toy Models of Superposition - Anthropic
License: AGPL-3.0
Contributions welcome! Please feel free to submit issues or pull requests.