feat: JaxLoss bugfix, per-molecule JIT, optimizer routing, MM3 fixes #272
Pull request overview
This PR improves benchmark ergonomics by (1) exposing the existing ObjectiveFunction/JaxLoss L2 regularization knob via the benchmark CLI and (2) reducing log noise during QFUERZA/Seminario estimation by demoting two frequently-triggered warnings to debug.
Changes:
- Add `--regularization FLOAT` to `q2mm-benchmark` and inject it into optimizer configs.
- Demote “No bonds/angles match …” and “Non-atomic-unit Hessian …” from `warning` to `debug`.
- Switch Seminario “no match” logs to lazy `%s` formatting.
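The difference between eager and lazy log formatting can be shown with a small stdlib-only sketch. The logger name and the `CountingRepr` helper are hypothetical and exist only to make the formatting cost visible; they are not taken from `q2mm`.

```python
import logging

logger = logging.getLogger("seminario_sketch")  # hypothetical logger name
logger.setLevel(logging.INFO)  # debug records are discarded at this level


class CountingRepr:
    """Counts how many times it is rendered to a string."""

    def __init__(self):
        self.calls = 0

    def __str__(self):
        self.calls += 1
        return "C-H"


bond = CountingRepr()

# Eager: the f-string is built before logging decides to drop the record.
logger.debug(f"No bonds match {bond}")
eager_calls = bond.calls

# Lazy: logging only interpolates %s if the record will actually be emitted.
logger.debug("No bonds match %s", bond)
lazy_calls = bond.calls - eager_calls
```

With the level set to `INFO`, the eager call still pays for string conversion while the lazy call does not, which matters for messages that fire once per bond or angle.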
Reviewed changes
Copilot reviewed 3 out of 3 changed files in this pull request and generated 2 comments.
| File | Description |
|---|---|
| q2mm/models/seminario.py | Demotes bond/angle “no match” warnings to debug (and uses lazy formatting). |
| q2mm/io/gaussian.py | Demotes the non-AU Hessian warning to debug while still dropping the Hessian. |
| q2mm/diagnostics/cli.py | Adds --regularization flag and injects it into optimizer configurations. |
c394ec3 to 7e3968a
7e3968a to b8281e7
Pull request overview
This PR aims to improve benchmark usability by (1) exposing L2 regularization control for optimization runs (especially for unbounded jaxopt:lbfgs) and (2) reducing console noise by demoting frequently-triggered warnings to debug logging.
Changes:
- Updates JaxOpt optimizer convergence messaging and adds “revert to initial params” behavior when the final score is worse than the initial score.
- Demotes several high-volume `logger.warning(...)` messages to `logger.debug(...)`.
- Adjusts benchmark CLI behavior (optimizer presets, `--max-iter` semantics, and detailed-report output handling) and introduces an auto-regularization default for `jaxopt:lbfgs`.
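The "revert to initial params" behavior described above could look roughly like the following. This is a minimal sketch under stated assumptions: the function name `finalize_run`, its arguments, and the messages are hypothetical, and the real logic in `q2mm/optimizers/jaxopt_opt.py` may differ.

```python
import math


def finalize_run(initial_params, initial_score, final_params, final_score):
    """Return (params, score, message), rolling back when the run got worse.

    Hypothetical helper: lower scores are assumed to be better.
    """
    if not math.isfinite(final_score):
        # NaN/Inf means the optimizer diverged; the final params are unusable.
        return list(initial_params), initial_score, "diverged; reverted to initial params"
    if final_score > initial_score:
        # A finite but worse score is also rolled back.
        return list(initial_params), initial_score, "worse than initial; reverted to initial params"
    return list(final_params), final_score, "converged"


params, score, message = finalize_run([1.0, 2.0], 10.0, [5.0, 9.0], float("nan"))
```

Returning a message alongside the parameters keeps the reason for the rollback visible in benchmark output.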
Reviewed changes
Copilot reviewed 6 out of 6 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| q2mm/optimizers/jaxopt_opt.py | Adds explicit divergence messaging and rollback-on-worse-score logic for JaxOpt runs. |
| q2mm/models/seminario.py | Demotes “no matching bonds/angles” warnings to debug to reduce noise. |
| q2mm/io/openmm.py | Demotes wildcard-type “skipping term” warnings to debug. |
| q2mm/io/gaussian.py | Demotes “non-atomic-unit Hessian” warning to debug. |
| q2mm/diagnostics/cli.py | Updates optimizer presets, changes --max-iter default behavior, and redirects detailed tables to a file when saving results. |
| q2mm/diagnostics/benchmark.py | Makes maxiter optional and injects default iteration limits per optimizer branch; auto-injects L2 regularization for jaxopt:lbfgs. |
2b00d27 to baaeb4a
baaeb4a to 7a2194e
Pull request overview
This PR aims to improve benchmark usability by (a) exposing L2 regularization to stabilize unbounded optimizers and (b) reducing console noise by demoting repetitive warnings.
Changes:
- Adds/expands regularization handling in the benchmark pipeline (including an auto-regularization default for `jaxopt:lbfgs`).
- Adjusts optimizer behavior (Optax can route gradients via `JaxLoss`; JaxOpt detects NaN divergence and may revert to initial params).
- Demotes several high-volume warnings to debug; refines benchmark CLI reporting/output behavior.
Reviewed changes
Copilot reviewed 8 out of 8 changed files in this pull request and generated 10 comments.
| File | Description |
|---|---|
| q2mm/optimizers/optax.py | Adds a JaxLoss-based gradient path for JaxEngine to avoid OOM from large Hessian-parameter Jacobians. |
| q2mm/optimizers/jaxopt_opt.py | Improves convergence messaging, explicitly detects NaN/Inf divergence, and reverts to initial parameters if worse/diverged. |
| q2mm/models/seminario.py | Demotes “no matching bonds/angles” messages from warning to debug. |
| q2mm/io/openmm.py | Demotes wildcard-type skip messages from warning to debug during XML export. |
| q2mm/io/gaussian.py | Demotes non-AU Hessian attachment warning to debug. |
| q2mm/diagnostics/systems.py | Adds metal vdW registry and injects missing metal vdW parameters for certain composed systems. |
| q2mm/diagnostics/cli.py | Updates optimizer labels, alters max-iter semantics, changes optax LR default, writes detailed tables to file, and adds matrix success/fail/skip counts. |
| q2mm/diagnostics/benchmark.py | Makes maxiter optional and auto-applies regularization for jaxopt:lbfgs when not provided. |
7a2194e to 506afd8
Pull request overview
This PR aims to improve benchmark usability by (a) exposing/using regularization to prevent parameter drift in unbounded optimizers and (b) reducing console noise by demoting high-volume warnings to debug logging.
Changes:
- Refactors benchmark CLI optimizer selection to use unique kebab-case keys (plus improved matrix run output and result bookkeeping).
- Adds automatic L2 regularization for `jaxopt:lbfgs` runs in the benchmark execution path and improves JaxOpt divergence handling.
- Demotes several high-volume warnings to debug and updates benchmark scripts to newer optimizer identifiers.
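Key-based optimizer selection can be sketched as a registry that separates a stable CLI key from its display label. The keys `scipy-lbfgsb` and `jaxopt-lbfgs` appear in this PR; the registry layout, labels, config contents, and the `select_optimizers` function are assumptions made for illustration.

```python
# Hypothetical registry: key -> (display label, optimizer config).
OPTIMIZERS = {
    "scipy-lbfgsb": ("L-BFGS-B", {"method": "L-BFGS-B"}),
    "jaxopt-lbfgs": ("JaxOpt L-BFGS", {"method": "jaxopt:lbfgs"}),
}


def select_optimizers(requested_keys):
    """Return (key, label, config) tuples for exact key matches only."""
    selected = [
        (key, *OPTIMIZERS[key]) for key in requested_keys if key in OPTIMIZERS
    ]
    if not selected:
        raise ValueError(f"no matching optimizers for {list(requested_keys)}")
    return selected


chosen = select_optimizers(["scipy-lbfgsb"])
```

Under this scheme a legacy label such as `L-BFGS-B` is not a key, so passing it to the selector fails loudly instead of silently matching several optimizers.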
Reviewed changes
Copilot reviewed 11 out of 11 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| scripts/run_rh_enamide_selected_matrix.sh | Updates some optimizer arguments to new CLI optimizer keys (but still has a few legacy names). |
| scripts/run_rh_benchmarks.sh | Updates benchmark script to use new optimizer keys (e.g., jaxopt-lbfgs, scipy-*). |
| scripts/bench_overnight.sh | Adds an overnight benchmark runner script. |
| q2mm/optimizers/optax.py | Routes gradients through JaxLoss when using JaxEngine to avoid GPU OOM from large Jacobian materialization. |
| q2mm/optimizers/jaxopt_opt.py | Adds explicit NaN/Inf divergence detection and reversion-to-initial behavior when divergence/worsening occurs. |
| q2mm/models/seminario.py | Demotes “no matching bonds/angles” messages from warning to debug. |
| q2mm/io/openmm.py | Demotes wildcard-type “Skipping bond/angle/torsion” messages from warning to debug. |
| q2mm/io/gaussian.py | Demotes non-AU Hessian warning to debug. |
| q2mm/diagnostics/systems.py | Adds optional metal vdW injection when composing Wahlers systems. |
| q2mm/diagnostics/cli.py | Introduces optimizer key/label separation, changes --optimizer filtering semantics, adjusts defaults/help, and redirects detailed tables to files. |
| q2mm/diagnostics/benchmark.py | Makes maxiter optional and adds auto-regularization for jaxopt:lbfgs; adjusts per-optimizer iteration fallbacks. |
Comments suppressed due to low confidence (3)
scripts/run_rh_enamide_selected_matrix.sh:110
- This script still uses the legacy optimizer name `L-BFGS-B`, but `q2mm.diagnostics.cli --optimizer` now filters by optimizer key (e.g. `scipy-lbfgsb`). Update these invocations to the new keys so the script doesn’t fail with “no matching optimizers”.
run_combo 1 "JAX GPU harmonic L-BFGS-B" \
"${benchmark_cmd[@]}" "${common_args[@]}" \
--backend jax --form harmonic --optimizer L-BFGS-B
run_combo 5 "JAX GPU mm3 L-BFGS-B" \
"${benchmark_cmd[@]}" "${common_args[@]}" \
--backend jax --form mm3 --optimizer L-BFGS-B
scripts/run_rh_enamide_selected_matrix.sh:127
- Same issue as above: `--optimizer L-BFGS-B` no longer matches the new optimizer key scheme. Use the appropriate key (likely `scipy-lbfgsb`) so this combo remains runnable.
run_combo 9 "JAX-MD GPU harmonic L-BFGS-B" \
"${benchmark_cmd[@]}" "${common_args[@]}" \
--backend jax-md --form harmonic --optimizer L-BFGS-B
q2mm/diagnostics/cli.py:151
- `--max-iter` is now optional (`None`), but `_run_matrix()` is still annotated/documented as taking an `int` with a default of 10_000. Since `main()` now passes `None` through, update `_run_matrix`’s signature/type/docs to accept `int | None` and reflect the per-optimizer defaults.
def _run_matrix(
backends: list[tuple[str, type, str]],
optimizers: list[tuple[str, str, dict]],
forms: list[tuple[str, str]],
output_dir: Path | None = None,
*,
leaderboard_only: bool = False,
data_dir: Path | None = None,
platform: str | None = None,
system_key: str = "ch3f",
max_iter: int = 10_000,
) -> list:
5f7d2a2 to df9d8fb
Pull request overview
This PR improves benchmark usability by (1) making optimizer selection more consistent and less noisy, and (2) adding guardrails/documentation for GPU optimization workflows (including regularization and divergence handling).
Changes:
- Switch benchmark optimizer selection to stable kebab-case optimizer keys (and update scripts/docs accordingly), while also cleaning console output and adding per-run summary counts.
- Add JAX memory/performance improvements (Optax routes gradients through `JaxLoss`) and improve JaxOpt divergence detection with revert-to-initial behavior.
- Demote several high-volume warnings to debug and add new multi-system GPU benchmark documentation.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 9 comments.
| File | Description |
|---|---|
| scripts/run_rh_enamide_selected_matrix.sh | Updates optimizer CLI values to new key-based names. |
| scripts/run_rh_benchmarks.sh | Updates optimizer CLI values to new key-based names. |
| scripts/bench_overnight.sh | Adds an overnight multi-system benchmarking helper script. |
| q2mm/optimizers/optax.py | Uses JaxLoss gradient path for JaxEngine to avoid GPU OOM from Hessian-parameter Jacobian materialization. |
| q2mm/optimizers/jaxopt_opt.py | Detects NaN/Inf divergence, improves messages, and reverts to initial params when diverged/worse. |
| q2mm/models/seminario.py | Demotes “no bonds/angles match” warnings to debug. |
| q2mm/io/openmm.py | Demotes wildcard-type “skipping” warnings to debug. |
| q2mm/io/gaussian.py | Demotes non-AU Hessian warning to debug. |
| q2mm/diagnostics/systems.py | Adds a metal vdW registry and injects missing metal vdW parameters for selected systems. |
| q2mm/diagnostics/cli.py | Introduces optimizer keys/labels, key-based filtering, cleaner output handling, and summary reporting. |
| q2mm/diagnostics/benchmark.py | Supports per-optimizer default iteration limits and auto-regularizes jaxopt:lbfgs. |
| properdocs.yml | Adds nav entry for the new multi-system benchmark page. |
| docs/how-it-works/optimization-guide.md | Adds GPU optimizer recommendations and links to multi-system benchmarks. |
| docs/comparison.md | Adds published FF comparison section referencing multi-system benchmarks. |
| docs/benchmarks/small-molecules.md | Updates optimizer examples to new key-based interface. |
| docs/benchmarks/published-ff-validation.md | Adds optimization comparison table and links to multi-system benchmarks. |
| docs/benchmarks/multi-system.md | Adds a new multi-system GPU benchmark results page. |
| docs/benchmarks/index.md | Adds the multi-system benchmarks page to the index and guidance. |
| docs/benchmarks/gpu.md | Links to multi-system benchmark results and updates optimizer examples. |
df9d8fb to 879f01a
879f01a to 40f3391
Pull request overview
This PR aims to improve the benchmark CLI/usability by adding regularization control (to prevent unbounded parameter drift in unbounded optimizers), reducing console noise from repetitive warnings, and updating benchmark scripts/docs (including new multi-system GPU benchmark documentation).
Changes:
- Refactors benchmark optimizer selection to use unique kebab-case optimizer keys, updates defaults/output handling, and adds per-optimizer iteration defaults plus auto-regularization for `jaxopt:lbfgs`.
- Improves GPU optimization robustness/perf by routing Optax gradients through `JaxLoss` (JaxEngine) and adding explicit divergence detection + revert behavior to `JaxOptOptimizer`.
- Demotes several noisy warnings to debug and adds/updates documentation + scripts for multi-system GPU benchmarks.
Reviewed changes
Copilot reviewed 19 out of 19 changed files in this pull request and generated 8 comments.
| File | Description |
|---|---|
| scripts/run_rh_enamide_selected_matrix.sh | Updates optimizer arguments to new key-based names. |
| scripts/run_rh_benchmarks.sh | Updates optimizer arguments to new key-based names. |
| scripts/bench_overnight.sh | Adds an overnight multi-system benchmark helper script. |
| q2mm/optimizers/optax.py | Uses JaxLoss path for gradients under JaxEngine to avoid GPU OOM. |
| q2mm/optimizers/jaxopt_opt.py | Detects NaN divergence and reverts to initial params when diverged/worse. |
| q2mm/models/seminario.py | Demotes “no matching bonds/angles” warnings to debug. |
| q2mm/io/openmm.py | Demotes wildcard-type “skipping” warnings to debug. |
| q2mm/io/gaussian.py | Demotes non-AU Hessian warning to debug. |
| q2mm/diagnostics/systems.py | Adds metal vdW injection support for Wahlers systems (Pd registry + loader param). |
| q2mm/diagnostics/cli.py | Switches to key-based optimizer selection; adjusts max-iter behavior; writes detailed tables to files; adds run summary/exit behavior. |
| q2mm/diagnostics/benchmark.py | Makes maxiter optional with per-optimizer defaults; auto-regularizes jaxopt:lbfgs. |
| properdocs.yml | Adds multi-system benchmark page to docs nav. |
| docs/how-it-works/optimization-guide.md | Adds GPU optimizer recommendations and links to multi-system results. |
| docs/comparison.md | Adds published FF comparison metrics section. |
| docs/benchmarks/small-molecules.md | Updates CLI examples to new optimizer keys. |
| docs/benchmarks/published-ff-validation.md | Adds optimization comparison table/section. |
| docs/benchmarks/multi-system.md | New multi-system GPU benchmark results page. |
| docs/benchmarks/index.md | Links multi-system benchmark page from benchmarks index. |
| docs/benchmarks/gpu.md | References multi-system results and updates optimizer names in repro commands. |
40f3391 to ce36150
Pull request overview
This PR stabilizes the benchmark/diagnostics workflow after overnight failures, improves GPU robustness and memory behavior (notably for Optax on JaxEngine), introduces unique optimizer keys for unambiguous CLI selection, and adds/updates benchmark documentation and published-FF evaluation fixtures for multiple transition-state systems.
Changes:
- Make the benchmark CLI/matrix runner more robust (per-optimizer defaults, exact optimizer-key matching, better logging/output, failure summaries, non-zero exit when all combos fail, best-effort JAX OOM recovery).
- Route OptaxOptimizer gradients through `JaxLoss` when running on `JaxEngine` to avoid GPU OOM from Hessian-parameter Jacobian materialization; add JaxOpt NaN/divergence detection with parameter revert.
- Add multi-system benchmark docs + new system pages and fixtures, plus minor warning-noise reductions (demote to debug).
Reviewed changes
Copilot reviewed 28 out of 28 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| test/integration/test_optax.py | Updates expected jac_mode for Optax-on-JAX to the new JaxLoss gradient path label. |
| test/fixtures/published_ff/rh_conjugate_wahlers2022.json | Adds published-FF re-evaluation fixture for Rh 1,4-conjugate system. |
| test/fixtures/published_ff/pd_conjugate_wahlers2021.json | Adds published-FF re-evaluation fixture for Pd 1,4-conjugate system. |
| test/fixtures/published_ff/pd_allyl_wahlers2021.json | Adds published-FF re-evaluation fixture for Pd-allyl system. |
| scripts/run_rh_enamide_selected_matrix.sh | Updates benchmark script to use new unique optimizer keys. |
| scripts/run_rh_benchmarks.sh | Updates benchmark runner script to new optimizer keys and retains venv activation behavior. |
| scripts/bench_overnight.sh | Adds an overnight multi-system GPU benchmark loop script. |
| q2mm/optimizers/optax.py | Switches Optax gradient computation to JaxLoss when using JaxEngine to reduce GPU memory usage; adjusts jac_mode label. |
| q2mm/optimizers/jaxopt_opt.py | Adds explicit NaN/Inf divergence detection and revert-to-initial behavior for unstable runs. |
| q2mm/models/seminario.py | Demotes “no matches” warnings to debug to reduce noise in expected scenarios. |
| q2mm/io/openmm.py | Demotes wildcard-skip warnings to debug to reduce expected-output noise. |
| q2mm/io/gaussian.py | Demotes non-AU Hessian warning to debug to reduce expected-output noise. |
| q2mm/diagnostics/systems.py | Adds metal vdW injection hook for Wahlers loaders (e.g., Pd) to unblock Pd benchmark systems. |
| q2mm/diagnostics/cli.py | Introduces unique optimizer keys + exact matching, per-optimizer iteration defaults, improved matrix logging/output, and failure handling. |
| q2mm/diagnostics/benchmark.py | Changes maxiter handling to allow per-optimizer defaults and adds auto-regularization for unbounded jaxopt:lbfgs. |
| properdocs.yml | Adds new benchmark pages to the docs navigation. |
| docs/how-it-works/optimization-guide.md | Adds GPU optimizer recommendation guidance based on new multi-system results. |
| docs/comparison.md | Adds a published-FF comparison section linking to per-system benchmark pages. |
| docs/benchmarks/small-molecules.md | Updates CLI examples to use new optimizer keys. |
| docs/benchmarks/rh-enamide.md | Refactors Rh-enamide benchmark page to the new single-system GPU summary format + optimizer keys. |
| docs/benchmarks/rh-conjugate.md | Adds a new benchmark page for Rh 1,4-conjugate system. |
| docs/benchmarks/published-ff-validation.md | Adds a cross-system optimization-vs-published comparison section with links. |
| docs/benchmarks/pd-conjugate.md | Adds a new benchmark page for Pd 1,4-conjugate system. |
| docs/benchmarks/pd-allyl.md | Adds a new benchmark page for Pd-allyl system. |
| docs/benchmarks/optimizer-comparison.md | Adds a cross-system optimizer comparison page summarizing the multi-system GPU shootout. |
| docs/benchmarks/index.md | Updates the benchmarks index to reflect the expanded system pages and optimizer comparison. |
| docs/benchmarks/heck-relay.md | Adds a new benchmark page for Heck relay system. |
| docs/benchmarks/gpu.md | Adds a multi-system GPU results summary and updates CLI examples to new optimizer keys. |
Comments suppressed due to low confidence (1)
q2mm/diagnostics/cli.py:151
- `_run_matrix()` now documents `max_iter` as “int | None” (optimizer-specific defaults when None), but the function signature still defaults `max_iter` to 10_000. If any internal callers rely on the default, they’ll bypass the new per-optimizer defaults. Consider changing the parameter to `max_iter: int | None = None` to match the CLI behavior and the docstring.
def _run_matrix(
backends: list[tuple[str, type, str]],
optimizers: list[tuple[str, str, dict]],
forms: list[tuple[str, str]],
output_dir: Path | None = None,
*,
leaderboard_only: bool = False,
data_dir: Path | None = None,
platform: str | None = None,
system_key: str = "ch3f",
max_iter: int = 10_000,
) -> list:
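The suggestion above amounts to treating `None` as "use the optimizer's own default". A minimal sketch of that resolution step follows; the `DEFAULT_MAX_ITER` table, its values, the fallback, and the `resolve_max_iter` name are all hypothetical, and `Optional[int]` is used as the equivalent of `int | None` for older Python versions.

```python
from typing import Optional

# Hypothetical per-optimizer iteration limits; the real values are not
# stated in this review.
DEFAULT_MAX_ITER = {
    "scipy-lbfgsb": 10_000,
    "jaxopt-lbfgs": 500,
}
FALLBACK_MAX_ITER = 1_000  # assumed fallback for keys without an entry


def resolve_max_iter(optimizer_key: str, max_iter: Optional[int] = None) -> int:
    """Return the explicit limit if given, else the per-optimizer default."""
    if max_iter is not None:
        return max_iter
    return DEFAULT_MAX_ITER.get(optimizer_key, FALLBACK_MAX_ITER)


explicit = resolve_max_iter("jaxopt-lbfgs", 50)
implicit = resolve_max_iter("jaxopt-lbfgs")
```

Defaulting the parameter to `None` rather than `10_000` means internal callers that omit the argument get the same per-optimizer behavior as the CLI.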
a5c599c to 460165a
121e1f3 to 3e7fbef
3e7fbef to de33ccd
de33ccd to 7d0d8b0
Remove all benchmark artifacts (forcefields, results JSONs, scripts) from this repository. Canonical benchmark data now lives in the separate ericchansen/q2mm-data repository to keep this repo lean. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
Core implementation changes for the CLI regularization feature:
- JaxLoss per-molecule JIT split: compile each molecule's loss function independently to prevent XLA compilation OOM on multi-molecule systems
- SciPy L-BFGS-B routing through JaxLoss: auto-detect JaxEngine and use analytical gradients with configurable ratio_tol parameter
- Optax/JaxOpt routing through JaxLoss Python dispatch
- MM3 V2 torsion phase correction (γ=180° not 0°)
- Bond-dipole electrostatics in JaxEngine (MM3 P3 column)
- ForceField frozen parameter support and active_mask
- JaxLoss geometry relaxation stabilization (gradient retreat on NaN)
- ScipyOptimizer ratio_tol param (None bypasses check for TS systems)
- CLI: add scipy-lbfgsb-jax key for direct JaxLoss analytical gradients
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
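The torsion phase item can be checked numerically. The sketch assumes the standard MM3 torsion series, whose twofold term is (V2/2)(1 − cos 2φ), and a generic periodic form k(1 + cos(nφ − γ)); the function names are illustrative and not taken from `q2mm`.

```python
import math


def mm3_v2_term(v2, phi):
    """Twofold term of the standard MM3 torsion series."""
    return 0.5 * v2 * (1.0 - math.cos(2.0 * phi))


def periodic_term(k, n, phi, gamma):
    """Generic periodic torsion term k * (1 + cos(n*phi - gamma))."""
    return k * (1.0 + math.cos(n * phi - gamma))


v2 = 3.0  # arbitrary barrier height for the comparison
angles = [i * math.pi / 12.0 for i in range(25)]  # 0 to 2*pi in 15° steps

# Largest mismatch against the MM3 term for each candidate phase.
err_phase_180 = max(
    abs(mm3_v2_term(v2, a) - periodic_term(0.5 * v2, 2, a, math.pi)) for a in angles
)
err_phase_0 = max(
    abs(mm3_v2_term(v2, a) - periodic_term(0.5 * v2, 2, a, 0.0)) for a in angles
)
```

Because 1 − cos 2φ = 1 + cos(2φ − 180°), only the 180° phase reproduces the MM3 term; a 0° phase flips the sign of the cosine and moves the minima by 90°.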
7d0d8b0 to 17a457d
The JaxLoss geometry relaxation used an artificial harmonic restraint (k=100 kcal/mol·Å²) that caused a systematic disagreement with ObjectiveFunction. JaxLoss/ObjectiveFunction ratios were 0.1–0.4 for all TS systems, meaning JaxLoss optimization produced 0% real ObjectiveFunction improvement.
Root cause: the restraint comment claimed it "prevents divergence for TS systems with negative force constants", but Q2MM inverts TS curvature (Limé & Norrby 2015) before Seminario projection, producing all-positive force constants. The MM FF has no negative FCs, so unconstrained minimization is safe.
Changes:
- Remove _GEOM_RESTRAINT_K and harmonic restraint from _relax_coords()
- Remove dead invert_ts_curvature machinery on MM Hessian (was a no-op)
- Remove MoleculeSpec.invert_ts_curvature field and ts_mol_indices param
- Remove stale comments referencing restraint/negative FCs in scipy_opt
- Re-enable ratio_tol=0.15 in CLI scipy-lbfgsb-jax config
- Remove TestJaxLossTsCurvatureInversion (covered by unit tests)
Result: rh-enamide ratio is now 1.047 (PASS), 28.7% real improvement.
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
17a457d to 0c4d236
Summary
Per-molecule JaxLoss JIT compilation, SciPy optimizer routing through JaxLoss analytical gradients, MM3 correctness fixes, and a critical bugfix for JaxLoss geometry relaxation that was producing 0% real improvement.
Key changes
Critical bugfix: JaxLoss harmonic restraint removal
JaxLoss's `_relax_coords()` had an artificial harmonic restraint (k=100 kcal/mol·Å²) that caused a systematic disagreement with ObjectiveFunction. JaxLoss/ObjectiveFunction ratios were 0.1–0.4 for all TS systems, meaning JaxLoss "optimization" produced 0% real ObjectiveFunction improvement.
Root cause: The restraint comment claimed it "prevents divergence for TS systems with negative force constants", but Q2MM inverts TS curvature (Limé & Norrby 2015) before Seminario projection, producing all-positive force constants. The MM FF has no negative FCs, so unconstrained minimization is safe and the restraint was unnecessary.
After fix: Rh-enamide ratio is 1.047 (PASS ✓), with 28.7% real ObjectiveFunction improvement in 8 L-BFGS-B iterations.
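The JaxLoss/ObjectiveFunction ratios quoted here can be read as a simple agreement test. The sketch below assumes "within ±15%" means |ratio − 1| ≤ `ratio_tol` and that `None` disables the check; the function name and the zero-objective handling are hypothetical choices, not the project's code.

```python
def passes_ratio_check(jax_loss_value, objective_value, ratio_tol=0.15):
    """Return True when the two loss values agree within ratio_tol.

    Hypothetical helper: ratio_tol=None skips the check entirely.
    """
    if ratio_tol is None:
        return True
    if objective_value == 0.0:
        # Assumed convention: a zero reference only matches a zero loss.
        return jax_loss_value == 0.0
    ratio = jax_loss_value / objective_value
    return abs(ratio - 1.0) <= ratio_tol


after_fix = passes_ratio_check(1.047, 1.0)   # ratio reported after the fix
before_fix = passes_ratio_check(0.4, 1.0)    # upper end of the 0.1–0.4 range
```

A caller would fall back to finite-difference gradients when the check fails, since a failing ratio means the fast loss is no longer measuring the same quantity as the reference objective.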
JaxLoss & optimizer infrastructure
- Per-molecule JIT: each molecule's `value_and_grad` is JIT-compiled independently to prevent XLA OOM on multi-molecule systems
- Ratio check (`ratio_tol=0.15`): validates JaxLoss agrees with ObjectiveFunction within ±15%, falls back to FD if not
- Returns a `(1e30, zeros)` sentinel instead of propagating bad values
MM3 correctness
CLI & infrastructure
- Unique optimizer keys (`scipy-lbfgsb`, `scipy-lbfgsb-jax`, `jaxopt-lbfgs`, etc.)
- Benchmark artifacts moved to `ericchansen/q2mm-data`
Validation results (GPU, RTX 5090)
Systems with poor Seminario starting FFs (deeply negative R²) fail the ratio check because unconstrained geometry relaxation wanders to wrong minima. The ratio check correctly detects this. Future work: improve starting FFs or add adaptive restraint.
Tests
- Lint (`ruff check` + `ruff format --check`)