feat: multi-token KL divergence for more robust quality measurement#209
KewkLW wants to merge 2 commits into p-e-w:master
Conversation
Add --kl-tokens N option to generate N tokens when computing KL
divergence instead of only the first token. The first generated token
is often a generic starter ("I", "The") where distributions barely
differ between the original and abliterated model, making single-token
KL a noisy quality signal.
Multi-token KL stacks logprobs across all positions and reshapes to
(prompts * N, vocab) so batchmean reduction naturally averages across
positions. The kl_divergence_scale and kl_divergence_target thresholds
are scaled by N to maintain consistent optimization behavior.
Default is 1 (preserving existing behavior). Recommended: 3.
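The stack-and-reshape trick described above can be sketched in isolation (random logits stand in for the two models' outputs; all variable names here are illustrative, not the PR's actual code):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
n_prompts, n_tokens, vocab = 4, 3, 10

# Random stand-ins for per-position logits from the original and
# abliterated models, shaped (prompts, N, vocab).
orig_logits = torch.randn(n_prompts, n_tokens, vocab)
abl_logits = torch.randn(n_prompts, n_tokens, vocab)

# Stack all positions into the batch dimension: (prompts * N, vocab).
orig_probs = F.softmax(orig_logits, dim=-1).reshape(-1, vocab)
abl_logprobs = F.log_softmax(abl_logits, dim=-1).reshape(-1, vocab)

# reduction="batchmean" divides the summed KL by prompts * N, so the
# result is the average KL divergence across all token positions.
kl = F.kl_div(abl_logprobs, orig_probs, reduction="batchmean")
print(kl.item())
```

Here `batchmean` divides by the first dimension (`prompts * N`), which is what produces the per-position averaging in a single KL call.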
Summary of Changes

Hello, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances model quality measurement by introducing multi-token KL divergence. Previously, KL divergence was calculated only on the first generated token, which often led to noisy and unreliable signals for detecting model damage. By allowing the calculation to span multiple tokens, the system can now capture more meaningful differences in generation behavior, leading to more accurate and robust evaluation of abliterated models. This improvement provides a clearer separation between good and bad parameter combinations during optimization, ultimately yielding higher quality models.

Highlights
Activity
Code Review
This pull request introduces a multi-token KL divergence calculation to improve model quality measurement. The implementation is clever, particularly the reshaping in model.py to average KL divergence. My review has identified a few areas for improvement. A key concern is the scaling of KL divergence thresholds in evaluator.py, which appears to be inconsistent with the use of reduction="batchmean". I've also noted a violation of the repository style guide regarding configuration file updates, some duplicated code, and a minor comment formatting issue. Overall, a great feature addition with a few points to address.
src/heretic/evaluator.py
Outdated
```python
# Scale thresholds by kl_tokens since multi-token KL produces
# proportionally larger absolute values.
kl_tokens = self.settings.kl_tokens
kl_divergence_scale = self.settings.kl_divergence_scale * kl_tokens
kl_divergence_target = self.settings.kl_divergence_target * kl_tokens
```
The KL divergence thresholds are being scaled by kl_tokens. However, the implementation in model.py reshapes the logits so that F.kl_div with reduction="batchmean" computes the average KL divergence across all generated tokens. An average value should not be proportional to the number of tokens. This scaling appears to be incorrect and may lead to the optimizer behaving in unintended ways. The thresholds should likely not be scaled.
As noted in model.py, batchmean already handles the averaging, so the resulting kl_divergence value should be of a similar magnitude regardless of kl_tokens.
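The reviewer's point can be checked numerically with a hand-rolled KL (plain Python, not Heretic code): the `batchmean` result is the mean of the per-position KL values, so its magnitude does not grow in proportion to the number of positions.

```python
import math
import random

def kl(p, q):
    # KL(p || q) for two discrete distributions given as probability lists.
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q))

def rand_dist(k, rng):
    # A random probability distribution over k outcomes.
    w = [rng.random() + 0.1 for _ in range(k)]
    s = sum(w)
    return [x / s for x in w]

rng = random.Random(0)
vocab = 50

# One KL value per generated position; "batchmean" over the stacked
# (prompts * N, vocab) tensor is exactly the mean of these values.
averages = {}
for n_tokens in (1, 3, 9):
    per_position = [kl(rand_dist(vocab, rng), rand_dist(vocab, rng))
                    for _ in range(n_tokens)]
    averages[n_tokens] = sum(per_position) / n_tokens

# The averaged KL stays in the same range as n_tokens grows; it does
# not scale proportionally with n_tokens.
print(averages)
```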
Suggested change:

```diff
- # Scale thresholds by kl_tokens since multi-token KL produces
- # proportionally larger absolute values.
- kl_tokens = self.settings.kl_tokens
- kl_divergence_scale = self.settings.kl_divergence_scale * kl_tokens
- kl_divergence_target = self.settings.kl_divergence_target * kl_tokens
+ kl_divergence_scale = self.settings.kl_divergence_scale
+ kl_divergence_target = self.settings.kl_divergence_target
```
```python
kl_tokens: int = Field(
    default=1,
    description=(
        "Number of tokens to generate when computing KL divergence. "
        "Higher values give a more robust quality signal at the cost of slower evaluation. "
        "The KL divergence is averaged across all token positions. "
        "Recommended: 1 (fastest, default), 3-5 (good tradeoff), >5 (diminishing returns)."
    ),
)
```
According to rule #8 of the repository style guide, when a new setting is added to config.py, it should also be added to config.default.toml. The new kl_tokens setting is missing from config.default.toml.
Please add the following to config.default.toml after winsorization_quantile:
```toml
# Number of tokens to generate when computing KL divergence.
# Higher values give a more robust quality signal at the cost of slower evaluation.
# The KL divergence is averaged across all token positions.
# Recommended: 1 (fastest, default), 3-5 (good tradeoff), >5 (diminishing returns).
kl_tokens = 1
```

References

- When new settings are added in config.py, they should also be added to config.default.toml, set to their default value and with their description as a comment. The order of settings in config.default.toml should match that in config.py. (link)
src/heretic/evaluator.py
Outdated
```python
kl_label = (
    f" * Obtaining {self.settings.kl_tokens}-token probability distributions..."
    if self.settings.kl_tokens > 1
    else " * Obtaining first-token probability distributions..."
)
print(kl_label)
```
src/heretic/model.py
Outdated
```python
# Multi-token: stack all positions, reshape to (prompt * n_tokens, vocab)
# so KL div with batchmean naturally averages across all positions.
```
This comment does not fully adhere to the style guide (rule #4), which requires comments to use correct grammar, start with a capital letter, and end with a period. The second line is a continuation of the first and starts with a lowercase letter. Please rephrase for clarity and to follow the style guide.
Suggested change:

```diff
- # Multi-token: stack all positions, reshape to (prompt * n_tokens, vocab)
- # so KL div with batchmean naturally averages across all positions.
+ # Multi-token: stack all positions and reshape to (prompt * n_tokens, vocab).
+ # This allows `F.kl_div` with `reduction="batchmean"` to average the KL divergence across all positions.
```
References
- Comments should start with a capital letter and end with a period. They should use correct grammar and spelling. (link)
- Remove incorrect KL threshold scaling by kl_tokens. The batchmean reduction already averages across all token positions, so the thresholds do not need adjustment.
- Extract duplicated kl_label logic into _kl_label() helper method.
- Fix comment formatting to match repository style guide.
- Add kl_tokens to config.default.toml.
Tried it. gpt-oss-20b, even with SOM (4 neurons), is impossible to decensor now :( Maybe it's because the "safety reasoning" at the beginning of the responses is eliminated; that could actually be a good thing, however. Hopefully there is a way to make the refusals / KL count lower in the gpt-oss situation. Also, the wait time indeed grew noticeably. Need to test on smaller models to look at the tradeoffs.
Hey, interesting results. I wonder if the false refusal detection PR (#210) would help here. Without it, the optimizer has no way to know it's over-abliterating - it just keeps pushing refusals down even if it's destroying the model to get there. Worth a shot on gpt-oss-20b to see if the Pareto front improves.
So, I tested both PRs and found that on gpt-oss-20b they didn't do much. The false positive rate was only < 4/100, with 1/100 present in all the trials; for gpt-oss in particular it seems not that important. The Pareto changes are quite small. I think, for this PR to be more useful, we need to invent better benchmark scaling, or maybe make the (multi-token) penalty / generation token count inversely proportional to the refusal count (?)
Yeah, the multi-token KL values are fundamentally higher than single-token and the optimizer thresholds (kl_divergence_target=0.01, kl_divergence_scale=1.0) are calibrated for single-token. With multi-token, nearly every trial lands above target so the optimizer spends all its time minimizing KL instead of exploring low-refusal configs. We saw the same thing on Qwen3.5-9B - went from 3/100 refusals KL 0.0366 (kl_tokens=1) to 1/100 refusals KL 0.4770 (kl_tokens=3), with no trials in the sweet spot. We're testing a few approaches on Qwen3.5-9B right now:
Your idea about scaling the penalty inversely with refusals is interesting too - low-refusal trials get more tokens for finer resolution, high-refusal ones get fewer since they're already suboptimal. That could work well combined with the position weighting. Will report back with results once the current Qwen3.5-9B run finishes.
Unless I'm misunderstanding the implementation, this approach makes no sense. The KL divergence is only a meaningful metric if all previous tokens are identical. In that case, the KLD quantifies how different the predictions from two models are. But if the input sequences don't match, the KLD doesn't describe anything. Of course models will make wildly different predictions from different inputs. That's true even if the models are identical. So you can't just generate multiple tokens with both models, get the logprobs for each position, and then compute the pairwise (or averaged) KLD. If the first token generated by one model under greedy decoding differs from that generated by the other (which can happen even under minuscule model differences, if the top 2 tokens have very similar probabilities), the rest of the tokens are determined by those initial tokens, and divergences amplify chaotically. You can get arbitrarily high KLDs with arbitrarily small actual model differences that way.
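The amplification argument can be illustrated with a toy example (hypothetical logit values, not actual model outputs): when the top two tokens are nearly tied, an arbitrarily small perturbation flips the greedy choice, after which the two generations condition on different prefixes and per-position comparisons stop measuring model similarity.

```python
# Two nearly identical next-token logit vectors over a 3-token vocabulary.
logits_orig = [2.000, 1.999, 0.1]  # original model: token 0 barely wins
logits_abl = [1.999, 2.000, 0.1]   # abliterated model: tiny change flips the argmax

greedy_orig = max(range(len(logits_orig)), key=lambda i: logits_orig[i])
greedy_abl = max(range(len(logits_abl)), key=lambda i: logits_abl[i])

# The distributions are almost identical, yet greedy decoding already
# diverges at the first position; every subsequent position is then
# conditioned on a different prefix, so later KLD values reflect
# different inputs rather than model differences.
print(greedy_orig, greedy_abl)  # → 0 1
```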
You're right, and I appreciate the clear explanation. The autoregressive divergence amplification is a fundamental flaw: once the first token differs, the rest of the sequence is conditioned on different inputs and the KLD becomes meaningless noise.

The fix would be teacher-forced KL: feed the reference model's greedy-decoded token sequence into the ablated model and compare logprobs at each position under identical conditioning. That way both models see the same prefix, and the per-position KLD actually measures prediction divergence rather than compounding generation drift. Something like:

```python
import torch
import torch.nn.functional as F

with torch.no_grad():
    # Generate the reference sequence once, with greedy decoding.
    ref_output = ref_model.generate(**inputs, max_new_tokens=n_tokens, do_sample=False)

    # Score the SAME token sequence with both models, keeping only the
    # positions that predict the N generated tokens.
    ref_logits = ref_model(ref_output).logits[:, -n_tokens - 1 : -1, :]
    abl_logits = abl_model(ref_output).logits[:, -n_tokens - 1 : -1, :]

# With identical conditioning, the per-position KLD is meaningful.
vocab = ref_logits.shape[-1]
ref_probs = F.softmax(ref_logits, dim=-1).reshape(-1, vocab)
abl_logprobs = F.log_softmax(abl_logits, dim=-1).reshape(-1, vocab)
kl = F.kl_div(abl_logprobs, ref_probs, reduction="batchmean")
```

That said, I realize this is a bigger change than what I originally proposed, and the single-token approach already works well for the optimizer's needs. Happy to rework this if you think teacher-forced multi-token KL would be worth pursuing, or close this out if single-token is sufficient for the use cases you care about.
Indeed, this would fix the issue. I'm still not entirely sure what the overall goal is though. I have seen no evidence suggesting that single-token KLD is insufficient. Token vocabularies typically consist of 200k+ tokens today, so even comparing a single token's logprobs effectively compares hundreds of thousands of values. This is why KLD is a much better measure of model similarity than perplexity. But I agree that such experiments are valuable, and I think this would make a great scorer plugin, one that I would happily include in Heretic by default (though probably not enable by default). Once #53 is merged, it will be possible to implement it as such. |
Summary
Adds a `--kl-tokens N` option to generate N tokens when computing KL divergence instead of only the first token.

The first generated token is often a generic starter ("I", "The") where distributions barely differ between the original and abliterated model, making single-token KL a noisy signal for detecting model damage. By averaging KL across multiple token positions, the measurement captures actual divergence in generation behavior.
Changes:

- config.py: New `kl_tokens` setting (default 1, preserving existing behavior)
- model.py: `get_logprobs()` generates N tokens and reshapes to `(prompts * N, vocab)` so `batchmean` naturally averages across positions
- evaluator.py: Scales `kl_divergence_scale` and `kl_divergence_target` by N since multi-token KL produces proportionally larger absolute values

Motivation
While abliterating Qwen3.5 models (0.8B through 9B), I found that single-token KL consistently underestimated model damage. Trials that looked fine by first-token KL (< 0.05) would produce noticeably degraded outputs on longer generations. Switching to `--kl-tokens 3` gave a much cleaner separation between good and bad parameter combinations in the Pareto front.

With N=3 on Qwen3.5-0.8B over 100 trials, the optimizer avoided parameter regions that single-token KL allowed through, resulting in better quality at the same refusal suppression level.
Usage
Design decisions

- Scaling thresholds (`scale * N`, `target * N`) keeps the optimizer balanced without needing to manually adjust `kl_divergence_scale` or `kl_divergence_target`
- Reshaping to `(prompts * N, vocab)` lets `batchmean` handle the averaging in a single KL call