Quickstart

Go from zero to your first mechreward-driven GRPO training run in under 10 minutes.

1 · Install

pip install mechreward
# Optional extras
pip install "mechreward[sae]"    # sae_lens integration
pip install "mechreward[trl]"    # GRPOTrainer hook
pip install "mechreward[all]"

2 · Load a validated SAE + feature pack

Every pack in the catalog has passed Stage Gate 1 (Spearman ρ ≥ 0.30 on held-out data) and can be loaded by its short name.

import mechreward as mr

sae = mr.load_sae(release="caiovicentino1/Qwen3.5-4B-SAE-L18-topk", sae_id="layer_18")

pack = mr.load_pack("qwen3.5-4b/reasoning_pack")
# 10 helpful + 10 harmful features, validated at Spearman ρ = 0.540 on a 100-question held-out GSM8K set

3 · Build the reward

reward = mr.FeatureReward.from_pack(
    pack,
    sae=sae,
    aggregation="per_token",  # per-token dense, matches G3 recipe
)

composite = mr.CompositeReward(
    rewards=[
        reward,
        mr.OutcomeReward(verifier=mr.verifiers.gsm8k_exact_match),
    ],
    weights=[0.1, 1.0],  # λ_mech=0.1, validated at G2
)
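Under the hood this is just a weighted sum: the per-token mech score is averaged over the completion, then combined with the outcome reward at λ_mech = 0.1. A minimal sketch of the arithmetic in plain Python (the activation values below are made-up toy numbers, not real SAE outputs):

```python
# Toy per-token SAE activations for one completion (hypothetical values).
# Per token: r_mech = sum(helpful activations) - sum(harmful activations),
# then average over tokens; the composite applies lambda_mech = 0.1.

helpful_acts = [[0.8, 0.2], [0.5, 0.0], [0.9, 0.4]]  # one row per token
harmful_acts = [[0.1, 0.0], [0.3, 0.2], [0.0, 0.0]]

per_token = [sum(h) - sum(x) for h, x in zip(helpful_acts, harmful_acts)]
r_mech = sum(per_token) / len(per_token)

r_outcome = 1.0  # verifier says the final answer is correct
lambda_mech = 0.1
reward = lambda_mech * r_mech + 1.0 * r_outcome
```

Because the outcome term carries weight 1.0 and the mech term only 0.1, the mech signal shapes the policy without overriding correctness.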

4 · Plug into GRPO

from datasets import load_dataset
from trl import GRPOConfig, GRPOTrainer

gsm8k_train = load_dataset("openai/gsm8k", "main", split="train")

trainer = GRPOTrainer(
    model="Qwen/Qwen3.5-4B",
    args=GRPOConfig(
        output_dir="./out",
        num_generations=4,
        learning_rate=3e-6,   # CRITICAL: the Qwen3.5-4B G3 run stalled at 1e-6
    ),
    train_dataset=gsm8k_train,
    reward_funcs=composite,
)

trainer.train()

5 · Stage Gates — don't skip them

Before launching a full G3 training run, always verify the signal on your setup with G1 and G2. Skipping them cost us 15 h of wasted compute on Qwen3.5-4B before we fixed the LR.

  • G1 (correlation pre-test): generate 100 baseline responses, compute Spearman ρ between R_mech and outcome. Need ρ ≥ 0.30 before proceeding.
  • G2 (tiny RL ablation): compare R0 (outcome only) vs R1 (outcome + SAE) vs R2 (outcome + raw direction) on 100 steps. Need R1 ≥ R0 + 2 pp.
  • G3 (full RL): target ≥80% of benchmark, hack rate < 30%, MMLU regression < 2 pp.
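The G1 gate is a plain rank-correlation check. A dependency-free sketch of computing Spearman ρ (the scores and outcomes below are made-up; in practice R_mech comes from your FeatureReward and the outcome from the verifier):

```python
def ranks(xs):
    # Average ranks (1-based), handling ties.
    order = sorted(range(len(xs)), key=lambda i: xs[i])
    r = [0.0] * len(xs)
    i = 0
    while i < len(order):
        j = i
        while j + 1 < len(order) and xs[order[j + 1]] == xs[order[i]]:
            j += 1
        avg = (i + j) / 2 + 1  # mean of tied positions i..j, 1-based
        for k in range(i, j + 1):
            r[order[k]] = avg
        i = j + 1
    return r

def spearman(xs, ys):
    # Pearson correlation of the rank vectors.
    rx, ry = ranks(xs), ranks(ys)
    mx, my = sum(rx) / len(rx), sum(ry) / len(ry)
    cov = sum((a - mx) * (b - my) for a, b in zip(rx, ry))
    sx = sum((a - mx) ** 2 for a in rx) ** 0.5
    sy = sum((b - my) ** 2 for b in ry) ** 0.5
    return cov / (sx * sy)

# Made-up mech rewards and binary outcomes for 8 baseline responses.
r_mech = [0.9, 0.2, 0.7, 0.1, 0.8, 0.3, 0.6, 0.4]
outcome = [1, 0, 1, 0, 1, 0, 1, 0]
rho = spearman(r_mech, outcome)
assert rho >= 0.30, "G1 gate failed: do not proceed to G2"
```

With 100 real baseline responses this is exactly the ρ ≥ 0.30 check from G1; scipy.stats.spearmanr gives the same number if you prefer a library call.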

6 · Engineering notes that matter

  • Qwen3.5 / Qwen3.6 are multimodal: use AutoModelForImageTextToText, not AutoModelForCausalLM, and freeze the vision tower.
  • Prompt format must match the SAE feature-discovery distribution. Chat template vs raw prompt silently breaks mech signal.
  • model.disable_adapter() context manager replaces a separate ref_model (saves 8 GB).
  • Computing log_softmax in bf16 halves logits memory — necessary at a 248k vocab, even on 4B models.
  • The fla (flash-linear-attention) and causal-conv1d kernels are not optional on GDN models (roughly 10× slower without them).
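The bf16 note is back-of-envelope arithmetic: at a 248k vocabulary the logits tensor dominates memory during log-prob computation. A quick check (sequence length and generation count are illustrative numbers, not recommendations):

```python
# Memory for the full logits tensor: batch x seq_len x vocab x bytes/element.
vocab = 248_000
seq_len = 2048        # illustrative
batch = 4             # illustrative, e.g. num_generations

fp32_bytes = batch * seq_len * vocab * 4   # float32: 4 bytes/element
bf16_bytes = batch * seq_len * vocab * 2   # bfloat16: 2 bytes/element

print(f"fp32 logits: {fp32_bytes / 2**30:.1f} GiB")
print(f"bf16 logits: {bf16_bytes / 2**30:.1f} GiB")
```

Halving the element size halves this entire tensor, which at these shapes is a multi-GiB saving per forward pass.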

Links