Quickstart
Go from zero to your first mechreward-driven GRPO training run in under 10 minutes.
1 · Install
```bash
pip install mechreward

# Optional extras
pip install "mechreward[sae]"   # sae_lens integration
pip install "mechreward[trl]"   # GRPOTrainer hook
pip install "mechreward[all]"
```
2 · Load a validated SAE + feature pack
Every pack in the catalog has passed Stage Gate 1 (ρ ≥ 0.30 on held-out data). You can load one by its short name.
```python
import mechreward as mr

sae = mr.load_sae(
    release="caiovicentino1/Qwen3.5-4B-SAE-L18-topk",
    sae_id="layer_18",
)
pack = mr.load_pack("qwen3.5-4b/reasoning_pack")
# 10 helpful + 10 harmful features, validated at
# Spearman ρ=0.540 on 100Q held-out GSM8K
```

3 · Build the reward
```python
reward = mr.FeatureReward.from_pack(
    pack,
    sae=sae,
    aggregation="per_token",  # per-token dense, matches G3 recipe
)

composite = mr.CompositeReward(
    rewards=[
        reward,
        mr.OutcomeReward(verifier=mr.verifiers.gsm8k_exact_match),
    ],
    weights=[0.1, 1.0],  # λ_mech=0.1, validated at G2
)
```

4 · Plug into GRPO
```python
from trl import GRPOConfig, GRPOTrainer

trainer = GRPOTrainer(
    model="Qwen/Qwen3.5-4B",
    args=GRPOConfig(
        output_dir="./out",
        num_generations=4,
        learning_rate=3e-6,  # CRITICAL: Qwen3.5-4B G3 showed 1e-6 stalls
    ),
    train_dataset=gsm8k_train,
    reward_funcs=composite,
)
trainer.train()
```

5 · Stage Gates — don't skip them
Before launching a full G3 training run, always verify the signal on your setup with G1 and G2. Skipping them cost us 15 h of wasted compute on Qwen3.5-4B before we fixed the LR.
- G1 (correlation pre-test): generate 100 baseline responses, compute Spearman ρ between R_mech and outcome. Need ρ ≥ 0.30 before proceeding.
- G2 (tiny RL ablation): compare R0 (outcome only) vs R1 (outcome + SAE) vs R2 (outcome + raw direction) on 100 steps. Need R1 ≥ R0 + 2 pp.
- G3 (full RL): target ≥80% of benchmark, hack rate < 30%, MMLU regression < 2 pp.
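The G1 pre-test is cheap to script: score the 100 baseline responses with the mech reward, then check Spearman ρ against outcome correctness. A minimal sketch, assuming you already have the two score arrays (the eight-element data below is a synthetic stand-in for the 100 responses):

```python
from scipy.stats import spearmanr

def g1_pretest(mech_rewards, outcomes, threshold=0.30):
    """Return (rho, passed): Spearman rho between R_mech and outcomes."""
    rho, _pvalue = spearmanr(mech_rewards, outcomes)
    return rho, rho >= threshold

# Synthetic stand-in: R_mech per response, plus 0/1 outcome labels.
mech = [0.1, 0.9, 0.8, 0.2, 0.7, 0.3, 0.95, 0.05]
correct = [0, 1, 1, 0, 1, 0, 1, 0]

rho, passed = g1_pretest(mech, correct)
```

If `passed` is False, fix the pack or the prompt distribution before spending any RL compute.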
6 · Engineering notes that matter
- Qwen3.5 / Qwen3.6 are multimodal — use `AutoModelForImageTextToText`, not `AutoModelForCausalLM`. Freeze the vision tower.
- Prompt format must match the SAE feature-discovery distribution. Chat template vs raw prompt silently breaks the mech signal.
- The `model.disable_adapter()` context manager replaces a separate `ref_model` (saves 8 GB).
- bf16 `log_softmax` halves logits memory — needed for vocab=248 k on 4 B models.
- fla + causal-conv1d are not optional on GDN models (10× slower without).
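The bf16 `log_softmax` note is easy to sanity-check with back-of-envelope arithmetic: logits memory scales as vocab × tokens × bytes per element, so dropping from fp32 (4 bytes) to bf16 (2 bytes) halves the buffer. A quick sketch with the 248 k vocab and illustrative token counts:

```python
def logits_gib(vocab_size, num_tokens, bytes_per_elem):
    """Memory (GiB) for a logits tensor of shape [num_tokens, vocab_size]."""
    return vocab_size * num_tokens * bytes_per_elem / 2**30

# 248k vocab, 4 generations x 1024 tokens each (illustrative):
fp32 = logits_gib(248_000, 4 * 1024, 4)  # ~3.78 GiB
bf16 = logits_gib(248_000, 4 * 1024, 2)  # ~1.89 GiB
```

At longer sequences or more generations per prompt, this buffer alone can dominate activation memory on a 4 B model.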