跳到主要内容

稀疏自编码器训练

提供使用SAELens训练和分析稀疏自编码器(SAEs)的指导,以将神经网络激活分解为可解释的特征。适用于发现可解释特征、分析叠加现象或研究语言模型中的单义表示时使用。

技能元数据

来源可选 — 使用 hermes skills install official/mlops/saelens 安装
路径optional-skills/mlops/saelens
版本1.0.0
作者Orchestra Research
许可证MIT
依赖项sae-lens>=6.0.0, transformer-lens>=2.0.0, torch>=2.0.0
平台linux, macos, windows
标签稀疏自编码器, SAE, 机制可解释性, 特征发现, 叠加

title: "SAELens:用于机制解释性的稀疏自编码器" description: "SAELens 是用于训练和分析稀疏自编码器(SAE)的主要库——这是一种将多义神经网络激活分解为稀疏、可解释特征的技术。" slug: "saelens" date: 2024-03-28T10:00:00+08:00 draft: false tags: ["机器学习", "解释性", "稀疏自编码器", "TransformerLens"] categories: ["技术", "AI"]

信息

以下是当此技能触发时,Hermes 加载的完整技能定义。这是当技能激活时,智能体看到的指令。

SAELens:用于机制解释性的稀疏自编码器

SAELens 是用于训练和分析稀疏自编码器(SAE)的主要库——这是一种将多义神经网络激活分解为稀疏、可解释特征的技术。基于 Anthropic 在单义性方面的开创性研究。

GitHubjbloomAus/SAELens(1,100+ 颗星)

问题:多义性与叠加

神经网络中的单个神经元是多义的——它们会在多个语义截然不同的上下文中被激活。发生这种情况是因为模型使用叠加来表示比它们实际拥有的神经元数量更多的特征,这使得解释变得困难。

SAE 通过以下方式解决此问题:将密集激活分解为稀疏的、单义的特征——通常对于任何给定输入,只有少量特征被激活,并且每个特征都对应一个可解释的概念。

何时使用 SAELens

当你需要以下情况时使用 SAELens:

  • 发现模型激活中的可解释特征
  • 理解模型学到了哪些概念
  • 研究叠加和特征几何
  • 执行基于特征的引导或消融
  • 分析与安全相关的特征(欺骗、偏见、有害内容)

在以下情况下考虑替代方案:

  • 你需要基本的激活分析 → 直接使用 TransformerLens
  • 你想进行因果干预实验 → 使用 pyveneTransformerLens
  • 你需要生产环境中的引导 → 考虑直接进行激活工程

安装

pip install sae-lens

要求:Python 3.10+,transformer-lens>=2.0.0

核心概念

SAE 学习什么

SAE 被训练来通过稀疏瓶颈重建模型激活:

输入激活 → 编码器 → 稀疏特征 → 解码器 → 重建激活
(d_model) ↓ (d_sae >> d_model) ↓ (d_model)
稀疏性 重建
惩罚 损失

损失函数MSE(原始, 重建) + L1_系数 × L1(特征)

关键验证(Anthropic 研究)

在“走向单义性”研究中,人类评估者发现 70% 的 SAE 特征确实是可解释的。发现的特征包括:

  • DNA 序列、法律语言、HTTP 请求
  • 希伯来文、营养声明、代码语法
  • 情感、命名实体、语法结构

工作流 1:加载和分析预训练的 SAE

分步指南

from transformer_lens import HookedTransformer
from sae_lens import SAE

# 1. 加载模型和预训练的 SAE
model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, cfg_dict, sparsity = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)

# 2. 获取模型激活
tokens = model.to_tokens("The capital of France is Paris")
_, cache = model.run_with_cache(tokens)
activations = cache["resid_pre", 8] # [batch, pos, d_model]

# 3. 编码为 SAE 特征
sae_features = sae.encode(activations) # [batch, pos, d_sae]
print(f"活跃特征数: {(sae_features > 0).sum()}")

# 4. 查找每个位置的前几个特征
for pos in range(tokens.shape[1]):
top_features = sae_features[0, pos].topk(5)
token = model.to_str_tokens(tokens[0, pos:pos+1])[0]
print(f"词元 '{token}': 特征 {top_features.indices.tolist()}")

# 5. 重建激活
reconstructed = sae.decode(sae_features)
reconstruction_error = (activations - reconstructed).norm()

可用的预训练 SAE

发布版本模型
gpt2-small-res-jbGPT-2 Small多个残差流
gemma-2b-resGemma 2B残差流
HuggingFace 上的各种模型搜索标签 saelens各种

检查清单

  • 使用 TransformerLens 加载模型
  • 为目标层加载匹配的 SAE
  • 将激活编码为稀疏特征
  • 识别每个词元的前几个激活特征
  • 验证重建质量

工作流 2:训练自定义 SAE

分步指南

from sae_lens import SAE, LanguageModelSAERunnerConfig, SAETrainingRunner

# 1. 配置训练
cfg = LanguageModelSAERunnerConfig(
# 模型
model_name="gpt2-small",
hook_name="blocks.8.hook_resid_pre",
hook_layer=8,
d_in=768, # 模型维度

# SAE 架构
architecture="standard", # 或 "gated", "topk"
d_sae=768 * 8, # 8 倍扩展因子
activation_fn="relu",

# 训练
lr=4e-4,
l1_coefficient=8e-5, # 稀疏性惩罚
l1_warm_up_steps=1000,
train_batch_size_tokens=4096,
training_tokens=100_000_000,

# 数据
dataset_path="monology/pile-uncopyrighted",
context_size=128,

# 日志记录
log_to_wandb=True,
wandb_project="sae-training",

# 检查点
checkpoint_path="checkpoints",
n_checkpoints=5,
)

# 2. 训练
trainer = SAETrainingRunner(cfg)
sae = trainer.run()

# 3. 评估
print(f"L0 (平均活跃特征数): {trainer.metrics['l0']}")
print(f"CE 损失恢复率: {trainer.metrics['ce_loss_score']}")

关键超参数

参数典型值影响
d_sae4-16× d_model更多特征,更高容量
l1_coefficient5e-5 到 1e-4值越高 = 稀疏性越强,精度越低
lr1e-4 到 1e-3标准优化器学习率
l1_warm_up_steps500-2000防止特征过早“死亡”

评估指标

指标目标值含义
L050-200每个词元的平均活跃特征数
CE 损失分数80-95%与原始模型相比恢复的交叉熵比例
死亡特征<5%从未激活的特征
解释方差>90%重建质量

检查清单

  • 选择目标层和钩子点
  • 设置扩展因子(d_sae = 4-16× d_model)
  • 为期望的稀疏性调整 L1 系数
  • 启用 L1 预热以防止特征“死亡”
  • 在训练期间监控指标(W&B)
  • 验证 L0 和 CE 损失恢复
  • 检查死亡特征比例

工作流 3:特征分析和引导

分析单个特征

from transformer_lens import HookedTransformer
from sae_lens import SAE
import torch

model = HookedTransformer.from_pretrained("gpt2-small", device="cuda")
sae, _, _ = SAE.from_pretrained(
release="gpt2-small-res-jb",
sae_id="blocks.8.hook_resid_pre",
device="cuda"
)

# 查找激活特定特征的内容
feature_idx = 1234
test_texts = [
"The scientist conducted an experiment",
"I love chocolate cake",
"The code compiles successfully",
"Paris is beautiful in spring",
]

for text in test_texts:
tokens = model.to_tokens(text)
_, cache = model.run_with_cache(tokens)
features = sae.encode(cache["resid_pre", 8])
activation = features[0, :, feature_idx].max().item()
print(f"{activation:.3f}: {text}")

特征引导

def steer_with_feature(model, sae, prompt, feature_idx, strength=5.0):
"""将 SAE 特征方向添加到残差流中。"""
tokens = model.to_tokens(prompt)

# 从解码器获取特征方向
feature_direction = sae.W_dec[feature_idx] # [d_model]

def steering_hook(activation, hook):
# 在所有位置添加缩放后的特征方向
activation += strength * feature_direction
return activation

# 使用引导进行生成
output = model.generate(
tokens,
max_new_tokens=50,
fwd_hooks=[("blocks.8.hook_resid_pre", steering_hook)]
)
return model.to_string(output[0])

特征归因

# 哪些特征对特定输出影响最大?
tokens = model.to_tokens("The capital of France is")
_, cache = model.run_with_cache(tokens)

# 获取最后位置的特征
features = sae.encode(cache["resid_pre", 8])[0, -1] # [d_sae]

# 获取每个特征的 logit 归因
# 特征贡献 = 特征激活 × 解码器权重 × 解嵌入
W_dec = sae.W_dec # [d_sae, d_model]
W_U = model.W_U # [d_model, vocab]

# 对 "Paris" logit 的贡献
paris_token = model.to_single_token(" Paris")
feature_contributions = features * (W_dec @ W_U[:, paris_token])

top_features = feature_contributions.topk(10)
print("用于预测 'Paris' 的前几个特征:")
for idx, val in zip(top_features.indices, top_features.values):
print(f" 特征 {idx.item()}: {val.item():.3f}")

常见问题与解决方案

问题:死亡特征比例过高

# 错误做法:没有预热,特征过早“死亡”
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4,
l1_warm_up_steps=0, # 不好!
)

# 正确做法:预热 L1 惩罚
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=8e-5,
l1_warm_up_steps=1000, # 逐渐增加
use_ghost_grads=True, # 复活“死亡”特征
)

问题:重建质量差(CE 恢复率低)

# 减少稀疏性惩罚
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=5e-5, # 值越低 = 重建质量越好
d_sae=768 * 16, # 更高容量
)

问题:特征不可解释

# 增加稀疏性(更高的 L1)
cfg = LanguageModelSAERunnerConfig(
l1_coefficient=1e-4, # 值越高 = 稀疏性越强,越可解释
)
# 或者使用 TopK 架构
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn_kwargs={"k": 50}, # 恰好 50 个活跃特征
)

问题:训练期间内存错误

cfg = LanguageModelSAERunnerConfig(
train_batch_size_tokens=2048, # 减少批处理大小
store_batch_size_prompts=4, # 缓冲区中更少的提示
n_batches_in_buffer=8, # 更小的激活缓冲区
)

与 Neuronpedia 的集成

neuronpedia.org 浏览预训练的 SAE 特征:

# 特征通过 SAE ID 索引
# 示例:gpt2-small 第 8 层 特征 1234
# → neuronpedia.org/gpt2-small/8-res-jb/1234

关键类参考

用途
SAE稀疏自编码器模型
LanguageModelSAERunnerConfig训练配置
SAETrainingRunner训练循环管理器
ActivationsStore激活值收集与批处理
HookedSAETransformerTransformerLens + SAE 集成

参考文档

详细的 API 文档、教程和高级用法,请参阅 references/ 文件夹:

文件内容
references/README.md概述与快速入门指南
references/api.mdSAE、TrainingSAE、配置的完整 API 参考
references/tutorials.md训练、分析、引导的分步教程

外部资源

教程

论文

官方文档

SAE 架构

架构描述用例
标准ReLU + L1 惩罚项通用
门控学习到的门控机制更好的稀疏性控制
TopK恰好 K 个活跃特征一致的稀疏性
# TopK SAE (恰好 50 个特征活跃)
cfg = LanguageModelSAERunnerConfig(
architecture="topk",
activation_fn="topk",
activation_fn_kwargs={"k": 50},
)