Validate token-pruning ratio and unify keep-count resolution #342
Open
SuperMarioYL wants to merge 2 commits into
Every token_compressor pruning strategy converts a drop ratio into an absolute keep count via round(num_vision_tokens * (1 - ratio)), each with its own ad-hoc 'retain at least one token' guard (missing entirely in hiprune) and none validating that ratio lies in [0, 1]. An out-of-range ratio therefore flows straight into the keep count: a value > 1 makes (1 - ratio) negative, yielding a negative num_to_keep that crashes torch.empty / torch.topk, while a value < 0 over-counts. The shipped config configs/qwen2_5_vl/pruning/vision_selector_r0.9.yaml even carries ratio: 9 (the filename says r0.9), so this is a live defect, not a hypothetical one. Add a shared resolve_num_tokens_to_keep(ratio, n) helper that validates the ratio, computes the keep count, and applies a single uniform retain-one rule, then route all eight strategies through it. Fix the config typo to 0.9 and add a CPU-only regression test for the helper's contract.
Summary
Every token-compressor pruning strategy turns a drop `ratio` into an absolute keep count via `int(round(num_vision_tokens * (1 - ratio)))`. Today each strategy does this inline, with its own ad-hoc "retain at least one token" guard, and none of them validate that `ratio` actually lies in `[0, 1]`.

That gap is a real defect:

- `ratio > 1` makes `1 - ratio` negative, so `num_to_keep` goes negative and flows straight into `torch.empty(num_to_keep, ...)` (divprune) / `torch.topk(..., k=min(num_to_keep, n))` (attention_based, vispruner) / the pivot math in dart, which raises an opaque CUDA/CPU error deep inside torch, far from the actual cause.
- `ratio < 0` over-counts and silently changes the kept-token set.
- The retain-one guard is inconsistent: `hiprune` doesn't apply it at all, so a benign rounding-to-zero there drops the entire image.

This isn't hypothetical: the shipped config `configs/qwen2_5_vl/pruning/vision_selector_r0.9.yaml` carries `ratio: 9` even though the filename says `r0.9`, so running it today crashes instead of pruning to 10%.
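The failure modes described above can be reproduced with plain arithmetic, without torch. The sketch below is illustrative only: `naive_keep_count` is a hypothetical name for the inline formula quoted in this summary, and the token count of 1024 is an assumed example value, not taken from the repository.

```python
# Illustrative arithmetic only: the unvalidated keep-count formula quoted above.
# `naive_keep_count` is a hypothetical name, not a function in the repository.
def naive_keep_count(num_vision_tokens: int, ratio: float) -> int:
    return int(round(num_vision_tokens * (1 - ratio)))


n = 1024  # assumed number of vision tokens, for illustration

print(naive_keep_count(n, 0.9))   # 102: the intended ~10% keep count
print(naive_keep_count(n, 9))     # -8192: the shipped typo gives a negative count
print(naive_keep_count(n, -0.5))  # 1536: a negative ratio asks for more tokens than exist
print(naive_keep_count(3, 0.9))   # 0: rounds to zero when no retain-one guard is applied
```

A negative count like `-8192` is what would then reach calls such as `torch.empty` or `torch.topk`.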
Changes
- Add `resolve_num_tokens_to_keep(ratio, num_vision_tokens)` in `algorithm/utils/utils.py` that:
  - rejects a `bool` / `NaN` / out-of-`[0, 1]` ratio with a clear `[TokenCompressor Error] 'ratio' must be in [0.0, 1.0], got <x>` message;
  - computes `round(num_vision_tokens * (1 - ratio))`;
  - applies a single uniform "retain at least one token when `ratio < 1.0`" rule.
- Route all eight strategies (`basic`, `attention_based`, `dart`, `divprune`, `hiprune`, `idpruner`, `scope`, `vision_selector`) through the helper, removing the duplicated inline blocks. Behaviour is unchanged for valid ratios; `hiprune` gains the retain-one guard for consistency.
- Fix the `ratio: 9` → `0.9` typo in the vision_selector config so the example actually runs.
- Add a CPU-only regression test (`tests/test_token_pruning_ratio.py`) that pins the contract; it stubs the `torch` import so it needs neither a GPU nor model weights.
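For reviewers who want the helper's contract at a glance, here is a minimal sketch of what a function with this signature could look like. It is reconstructed from the description in this PR, not copied from the diff: the handling of non-numeric types and of `num_vision_tokens == 0` are assumptions, and the exact rendering of `<x>` in the error message is a guess.

```python
import math


def resolve_num_tokens_to_keep(ratio, num_vision_tokens):
    """Sketch of the described contract; not the code from the PR diff."""
    # bool is a subclass of int in Python, so it must be rejected explicitly.
    # Rejecting other non-numeric types here is an assumption.
    is_number = isinstance(ratio, (int, float)) and not isinstance(ratio, bool)
    if not is_number or math.isnan(ratio) or not (0.0 <= ratio <= 1.0):
        raise ValueError(
            f"[TokenCompressor Error] 'ratio' must be in [0.0, 1.0], got {ratio!r}"
        )

    num_to_keep = int(round(num_vision_tokens * (1 - ratio)))

    # Uniform retain-one rule: unless the caller asked to drop everything
    # (ratio == 1.0), never round a non-empty input down to zero tokens.
    # Skipping the guard for an empty input is an assumption.
    if ratio < 1.0 and num_vision_tokens > 0:
        num_to_keep = max(num_to_keep, 1)
    return num_to_keep


print(resolve_num_tokens_to_keep(0.9, 1024))  # 102
print(resolve_num_tokens_to_keep(0.9, 3))     # 1 (guard lifts 0 to 1)
print(resolve_num_tokens_to_keep(1.0, 10))    # 0 (explicit drop-all)
```

Under this sketch, the shipped `ratio: 9` would fail immediately with the descriptive message rather than deep inside a torch call.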
Testing
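The regression test is described as stubbing the `torch` import. One common way to do that in Python is to register a placeholder module in `sys.modules` before importing the code under test. The sketch below shows only that general technique; it is an assumption about the approach, and the actual test file may do this differently.

```python
import sys
import types

# Register an empty placeholder under the name "torch" *before* importing
# any module that runs `import torch` at import time.
torch_stub = types.ModuleType("torch")
sys.modules["torch"] = torch_stub

import torch  # resolved from sys.modules, so the real package is never loaded

print(torch is torch_stub)  # True
```

In a pytest suite, `monkeypatch.setitem(sys.modules, "torch", torch_stub)` achieves the same effect and restores the original entry after the test.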
Formatting matches the repo pre-commit hooks: `black --line-length=99`, `isort --profile=black`, and `flake8` (E203, E704, W503, W504 ignored) are all clean.

Notes
Out-of-range ratios were never a working path (they either crashed or silently corrupted the kept-token set), so raising a descriptive `ValueError` instead is strictly safer and not a behavioural regression for any valid configuration.