feat: multi-token KL divergence for more robust quality measurement#209

Open
KewkLW wants to merge 2 commits into p-e-w:master from KewkLW:multi-token-kl

Conversation

KewkLW commented Mar 4, 2026

Summary

Adds a --kl-tokens N option to generate N tokens when computing KL divergence instead of only the first token.

The first generated token is often a generic starter ("I", "The") where distributions barely differ between the original and abliterated model, making single-token KL a noisy signal for detecting model damage. By averaging KL across multiple token positions, the measurement captures actual divergence in generation behavior.

Changes:

  • config.py: New kl_tokens setting (default 1, preserving existing behavior)
  • model.py: get_logprobs() generates N tokens and reshapes to (prompts * N, vocab) so batchmean naturally averages across positions
  • evaluator.py: Scales kl_divergence_scale and kl_divergence_target by N since multi-token KL produces proportionally larger absolute values

Motivation

While abliterating Qwen3.5 models (0.8B through 9B), I found that single-token KL consistently underestimated model damage. Trials that looked fine by first-token KL (< 0.05) would produce noticeably degraded outputs on longer generations. Switching to --kl-tokens 3 gave a much cleaner separation between good and bad parameter combinations in the Pareto front.

With N=3 on Qwen3.5-0.8B over 100 trials, the optimizer avoided parameter regions that single-token KL allowed through, resulting in better quality at the same refusal suppression level.

Usage

# Default behavior (unchanged)
heretic --model Qwen/Qwen3.5-9B

# 3-token KL (recommended)
heretic --model Qwen/Qwen3.5-9B --kl-tokens 3

Design decisions

  • Default is 1 so this is fully backwards-compatible and opt-in
  • Threshold scaling (scale * N, target * N) keeps the optimizer balanced without needing to manually adjust kl_divergence_scale or kl_divergence_target
  • Reshape instead of loop: stacking all token logits and reshaping to (prompts * N, vocab) lets batchmean handle the averaging in a single KL call
  • Recommended N=3: good tradeoff between signal quality and speed. N=5 is slightly better, N>5 shows diminishing returns
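
The reshape-and-batchmean design can be illustrated without torch. The following is a minimal pure-Python sketch (all distributions are invented toy numbers, not real model outputs) showing that `batchmean` over the flattened `(prompts * N, vocab)` rows is simply the average of the per-position KL values:

```python
import math

def kl_div_row(ref_probs, abl_logprobs):
    # KL(ref || abl) for one row, matching F.kl_div's convention:
    # sum over the vocab of ref * (log ref - abl_logprob).
    return sum(p * (math.log(p) - lq) for p, lq in zip(ref_probs, abl_logprobs))

# Two prompts x three token positions over a toy two-word vocabulary,
# flattened to six rows, as the (prompts * N, vocab) reshape would produce.
ref = [[0.6, 0.4], [0.7, 0.3], [0.5, 0.5], [0.8, 0.2], [0.9, 0.1], [0.4, 0.6]]
abl = [[0.5, 0.5], [0.6, 0.4], [0.5, 0.5], [0.7, 0.3], [0.8, 0.2], [0.5, 0.5]]
abl_logprobs = [[math.log(p) for p in row] for row in abl]

per_row = [kl_div_row(r, a) for r, a in zip(ref, abl_logprobs)]
batchmean = sum(per_row) / len(per_row)  # what reduction="batchmean" computes
```

Because `batchmean` divides by the total number of rows, stacking more positions refines the average rather than mechanically inflating its magnitude.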

Add --kl-tokens N option to generate N tokens when computing KL
divergence instead of only the first token. The first generated token
is often a generic starter ("I", "The") where distributions barely
differ between the original and abliterated model, making single-token
KL a noisy quality signal.

Multi-token KL stacks logprobs across all positions and reshapes to
(prompts * N, vocab) so batchmean reduction naturally averages across
positions. The kl_divergence_scale and kl_divergence_target thresholds
are scaled by N to maintain consistent optimization behavior.

Default is 1 (preserving existing behavior). Recommended: 3.
@gemini-code-assist (Contributor) commented:

Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the model quality measurement by introducing multi-token KL divergence. Previously, KL divergence was calculated only on the first generated token, which often led to noisy and unreliable signals for detecting model damage. By allowing the calculation to span multiple tokens, the system can now capture more meaningful differences in generation behavior, leading to more accurate and robust evaluation of abliterated models. This improvement provides a clearer separation between good and bad parameter combinations during optimization, ultimately yielding higher quality models.

Highlights

  • Multi-token KL Divergence: Implemented a new --kl-tokens N option to calculate KL divergence over N generated tokens instead of just the first, providing a more robust quality signal for model damage detection.
  • Configurability: Introduced a kl_tokens setting in config.py with a default of 1 for backward compatibility, allowing users to specify the number of tokens for KL divergence calculation.
  • Dynamic Threshold Scaling: Adjusted kl_divergence_scale and kl_divergence_target in evaluator.py by multiplying them with kl_tokens to maintain balance for the optimizer, as multi-token KL produces proportionally larger absolute values.
  • Efficient Logprob Generation: Modified model.py to generate N tokens and efficiently reshape their logit distributions for batch processing, enabling batchmean to average KL divergence across all token positions.


Changelog
  • src/heretic/config.py
    • Added kl_tokens field to Settings with a default of 1, allowing configuration of the number of tokens for KL divergence calculation.
  • src/heretic/evaluator.py
    • Updated print statements to reflect the number of tokens being used for KL divergence calculation.
    • Modified get_score method to scale kl_divergence_scale and kl_divergence_target by kl_tokens to account for the increased magnitude of multi-token KL values.
  • src/heretic/model.py
    • Updated get_logprobs method to generate n_tokens instead of a single token.
    • Implemented logic to stack and reshape logits from multiple tokens into a (prompts * n_tokens, vocab) tensor, facilitating batch processing for KL divergence calculation.
    • Maintained backward compatibility for n_tokens = 1.

@gemini-code-assist bot left a comment

Code Review

This pull request introduces a multi-token KL divergence calculation to improve model quality measurement. The implementation is clever, particularly the reshaping in model.py to average KL divergence. My review has identified a few areas for improvement. A key concern is the scaling of KL divergence thresholds in evaluator.py, which appears to be inconsistent with the use of reduction="batchmean". I've also noted a violation of the repository style guide regarding configuration file updates, some duplicated code, and a minor comment formatting issue. Overall, a great feature addition with a few points to address.

Comment on lines +120 to +124
# Scale thresholds by kl_tokens since multi-token KL produces
# proportionally larger absolute values.
kl_tokens = self.settings.kl_tokens
kl_divergence_scale = self.settings.kl_divergence_scale * kl_tokens
kl_divergence_target = self.settings.kl_divergence_target * kl_tokens

Severity: high

The KL divergence thresholds are being scaled by kl_tokens. However, the implementation in model.py reshapes the logits so that F.kl_div with reduction="batchmean" computes the average KL divergence across all generated tokens. An average value should not be proportional to the number of tokens. This scaling appears to be incorrect and may lead to the optimizer behaving in unintended ways. The thresholds should likely not be scaled.

As noted in model.py, batchmean already handles the averaging, so the resulting kl_divergence value should be of a similar magnitude regardless of kl_tokens.

Suggested change
# Scale thresholds by kl_tokens since multi-token KL produces
# proportionally larger absolute values.
kl_tokens = self.settings.kl_tokens
kl_divergence_scale = self.settings.kl_divergence_scale * kl_tokens
kl_divergence_target = self.settings.kl_divergence_target * kl_tokens
kl_divergence_scale = self.settings.kl_divergence_scale
kl_divergence_target = self.settings.kl_divergence_target

Comment on lines +218 to +226
kl_tokens: int = Field(
default=1,
description=(
"Number of tokens to generate when computing KL divergence. "
"Higher values give a more robust quality signal at the cost of slower evaluation. "
"The KL divergence is averaged across all token positions. "
"Recommended: 1 (fastest, default), 3-5 (good tradeoff), >5 (diminishing returns)."
),
)

Severity: medium

According to rule #8 of the repository style guide, when a new setting is added to config.py, it should also be added to config.default.toml. The new kl_tokens setting is missing from config.default.toml.

Please add the following to config.default.toml after winsorization_quantile:

# Number of tokens to generate when computing KL divergence.
# Higher values give a more robust quality signal at the cost of slower evaluation.
# The KL divergence is averaged across all token positions.
# Recommended: 1 (fastest, default), 3-5 (good tradeoff), >5 (diminishing returns).
kl_tokens = 1
References
  1. When new settings are added in config.py, they should also be added to config.default.toml, set to their default value and with their description as a comment. The order of settings in config.default.toml should match that in config.py. (link)

Comment on lines +101 to +106
kl_label = (
f" * Obtaining {self.settings.kl_tokens}-token probability distributions..."
if self.settings.kl_tokens > 1
else " * Obtaining first-token probability distributions..."
)
print(kl_label)

Severity: medium

The logic to generate the kl_label string is duplicated from the __init__ method (lines 31-36). To improve maintainability and avoid code duplication, consider refactoring this logic into a private helper method within the Evaluator class.
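
A possible shape for that refactor, as a sketch only: the `Settings` stand-in below is a hypothetical minimal substitute for the real heretic settings class, keeping just the field this helper reads.

```python
from dataclasses import dataclass

@dataclass
class Settings:
    # Stand-in for the real heretic Settings; only the field used here.
    kl_tokens: int = 1

class Evaluator:
    def __init__(self, settings: Settings):
        self.settings = settings

    def _kl_label(self) -> str:
        # Single source of truth for the progress message, replacing the
        # duplicated conditional in __init__ and get_score.
        if self.settings.kl_tokens > 1:
            return f" * Obtaining {self.settings.kl_tokens}-token probability distributions..."
        return " * Obtaining first-token probability distributions..."
```

Both call sites would then reduce to `print(self._kl_label())`.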

Comment on lines +675 to +676
# Multi-token: stack all positions, reshape to (prompt * n_tokens, vocab)
# so KL div with batchmean naturally averages across all positions.

Severity: medium

This comment does not fully adhere to the style guide (rule #4), which requires comments to use correct grammar, start with a capital letter, and end with a period. The second line is a continuation of the first and starts with a lowercase letter. Please rephrase for clarity and to follow the style guide.

Suggested change
# Multi-token: stack all positions, reshape to (prompt * n_tokens, vocab)
# so KL div with batchmean naturally averages across all positions.
# Multi-token: stack all positions and reshape to (prompt * n_tokens, vocab).
# This allows `F.kl_div` with `reduction="batchmean"` to average the KL divergence across all positions.
References
  1. Comments should start with a capital letter and end with a period. They should use correct grammar and spelling. (link)

- Remove incorrect KL threshold scaling by kl_tokens. The batchmean
  reduction already averages across all token positions, so the
  thresholds do not need adjustment.
- Extract duplicated kl_label logic into _kl_label() helper method.
- Fix comment formatting to match repository style guide.
- Add kl_tokens to config.default.toml.
@kabachuha commented:

Tried it.

gpt-oss-20b is now impossible to decensor, even with SOM (4 neurons) :(

? Which trial do you want to use? (Use arrow keys)
   [Trial 168] Refusals:  3/100, KL divergence: 2.1387
   [Trial 166] Refusals:  5/100, KL divergence: 1.9938
   [Trial 171] Refusals: 13/100, KL divergence: 1.9518
   [Trial 160] Refusals: 15/100, KL divergence: 1.6948
   [Trial  78] Refusals: 17/100, KL divergence: 1.6125
   [Trial 192] Refusals: 20/100, KL divergence: 1.1870
 » [Trial 110] Refusals: 24/100, KL divergence: 1.0808
   [Trial 103] Refusals: 29/100, KL divergence: 0.9377
   [Trial 126] Refusals: 31/100, KL divergence: 0.9283
   [Trial 186] Refusals: 32/100, KL divergence: 0.7197
   [Trial 181] Refusals: 58/100, KL divergence: 0.5101
   [Trial 139] Refusals: 82/100, KL divergence: 0.3237
   [Trial 120] Refusals: 98/100, KL divergence: 0.1868
   Run additional trials
   Exit program

Maybe it's because the "safety reasoning" at the beginning of the responses is eliminated; that could actually be a good thing, however.

Hopefully there is a way to bring the refusal / KL counts down in the gpt-oss situation.

Also, the wait time did grow noticeably. I need to test on smaller models to look at the tradeoffs.

KewkLW (Author) commented Mar 4, 2026

Hey, interesting results. I wonder if the false refusal detection PR (#210) would help here. Without it the optimizer has no way to know it's over-abliterating - it just keeps pushing refusals down even if it's destroying the model to get there.

With --detect-false-refusals it penalizes parameter combos that cause the model to start refusing harmless prompts too, which should steer it away from those high-KL configs. The two PRs were designed to work together so you might get better results combining them.

Worth a shot on gpt-oss-20b to see if the Pareto front improves.

kabachuha commented Mar 4, 2026

So, I tested both PRs together and found that on gpt-oss-20b they didn't do much. The false-positive rate was only < 4/100, with 1/100 present in all the trials; for gpt-oss in particular it seems not that important. The changes to the Pareto front are quite small.

I think that for this PR to be more useful, we need to invent better benchmark scaling, or maybe a scheme where the (multi-token) penalty / generated token count is inversely proportional to the refusal count (?)

KewkLW (Author) commented Mar 4, 2026

Yeah, the multi-token KL values are fundamentally higher than single-token, and the optimizer thresholds (kl_divergence_target=0.01, kl_divergence_scale=1.0) are calibrated for single-token. With multi-token, nearly every trial lands above target, so the optimizer spends all its time minimizing KL instead of exploring low-refusal configs. We saw the same thing on Qwen3.5-9B: it went from 3/100 refusals at KL 0.0366 (kl_tokens=1) to 1/100 refusals at KL 0.4770 (kl_tokens=3), with no trials in the sweet spot.

We're testing a few approaches on Qwen3.5-9B right now:

  1. Two-phase optimization - run phase 1 with kl_tokens=1 for fast exploration to find the promising parameter region, then re-evaluate the top N trials with multi-token KL for finer resolution. Avoids the calibration problem since single-token thresholds stay valid during search.

  2. Position-weighted KL - when doing multi-token, weight earlier tokens more heavily (exponential decay like w_t = 0.5^t). The first token is most constrained by the prompt, later tokens compound small differences via autoregressive amplification. This should dampen the KL inflation without losing the multi-token signal.

  3. Baseline calibration - before optimization starts, measure the unmodified model's multi-token KL (should be ~0 but quantization noise gives a floor). Use that to auto-set the target and scale.
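
Approach 2 is simple enough to sketch in a few lines of pure Python. This is only an illustration of the weighting scheme described above; the function name and the default decay of 0.5 are hypothetical, not code from this PR.

```python
def weighted_kl(per_position_kl, decay=0.5):
    # Weight earlier positions more heavily: w_t = decay ** t.
    # Returning a *weighted average* keeps the result on roughly the
    # single-token scale instead of growing with the number of positions.
    weights = [decay ** t for t in range(len(per_position_kl))]
    total = sum(w * kl for w, kl in zip(weights, per_position_kl))
    return total / sum(weights)
```

For example, `weighted_kl([0.05, 0.2, 0.8])` counts the first-token KL most heavily, damping the inflation contributed by later, autoregressively amplified positions.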

Your idea about scaling penalty inversely with refusals is interesting too - low-refusal trials get more tokens for finer resolution, high-refusal ones get fewer since they're already suboptimal. That could work well combined with the position weighting.

Will report back with results once the current Qwen3.5-9B run finishes.

p-e-w (Owner) commented Mar 5, 2026

Unless I'm misunderstanding the implementation, this approach makes no sense.

The KL divergence is only a meaningful metric if all previous tokens are identical. In that case, the KLD quantifies how different the predictions from two models are. But if the input sequences don't match, the KLD doesn't describe anything. Of course models will make wildly different predictions from different inputs. That's true even if the models are identical.

So you can't just generate multiple tokens with both models, get the logprobs for each position, and then compute the pairwise (or averaged) KLD. If the first token generated by one model under greedy decoding differs from that generated by the other (which can happen even under minuscule model differences, if the top 2 tokens have very similar probabilities), the rest of the tokens are determined by those initial tokens, and divergences amplify chaotically. You can get arbitrarily high KLDs with arbitrarily small actual model differences that way.
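
This amplification is easy to demonstrate with a toy example (the token strings and probabilities below are invented for illustration): two models whose next-token distributions differ by a fraction of a percent can still pick different greedy tokens, after which their continuations condition on different prefixes.

```python
# Next-token distributions from two hypothetical, nearly identical models.
orig = {"Sure": 0.501, "I": 0.499}
abl = {"Sure": 0.499, "I": 0.501}

greedy_orig = max(orig, key=orig.get)  # "Sure"
greedy_abl = max(abl, key=abl.get)     # "I"

# The distributions are almost identical, yet the greedy tokens differ,
# so every subsequent position compares predictions for unrelated contexts.
diverged = greedy_orig != greedy_abl
```

From this point on, the per-position KLDs measure prefix divergence, not model similarity.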

KewkLW (Author) commented Mar 6, 2026

You're right, and I appreciate the clear explanation. The autoregressive divergence amplification is a fundamental flaw. Once the first token differs, the rest of the sequence is conditioned on different inputs and the KLD becomes meaningless noise.

The fix would be teacher-forced KL. Feed the reference model's greedy decoded token sequence as input to the ablated model and compare logprobs at each position with identical conditioning. That way both models see the same prefix and the per-position KLD actually measures prediction divergence rather than compounding generation drift.

Something like:

# Generate the reference sequence once (greedy decoding)
with torch.no_grad():
    ref_output = ref_model.generate(**inputs, max_new_tokens=n_tokens, do_sample=False)

    # Score the SAME token sequence with both models (teacher forcing).
    # Logits at positions -n_tokens-1 : -1 predict the generated tokens.
    ref_logits = ref_model(ref_output).logits[:, -n_tokens - 1 : -1, :]
    abl_logits = abl_model(ref_output).logits[:, -n_tokens - 1 : -1, :]

    # Per-position KLD is now meaningful: both models saw identical prefixes
    ref_probs = F.softmax(ref_logits, dim=-1).flatten(end_dim=-2)
    abl_logprobs = F.log_softmax(abl_logits, dim=-1).flatten(end_dim=-2)
    kl = F.kl_div(abl_logprobs, ref_probs, reduction="batchmean")

That said, I realize this is a bigger change than what I originally proposed and the single-token approach already works well for the optimizer's needs. Happy to rework this if you think teacher-forced multi-token KL would be worth pursuing, or close this out if single-token is sufficient for the use cases you care about.

p-e-w (Owner) commented Mar 6, 2026

The fix would be teacher-forced KL. Feed the reference model's greedy decoded token sequence as input to the ablated model and compare logprobs at each position with identical conditioning.

Indeed, this would fix the issue. I'm still not entirely sure what the overall goal is though. I have seen no evidence suggesting that single-token KLD is insufficient. Token vocabularies typically consist of 200k+ tokens today, so even comparing a single token's logprobs effectively compares hundreds of thousands of values. This is why KLD is a much better measure of model similarity than perplexity.

But I agree that such experiments are valuable, and I think this would make a great scorer plugin, one that I would happily include in Heretic by default (though probably not enable by default). Once #53 is merged, it will be possible to implement it as such.
