diff --git a/docs/src/engines/ase.rst b/docs/src/engines/ase.rst index 2bcdd8400..cab3722f5 100644 --- a/docs/src/engines/ase.rst +++ b/docs/src/engines/ase.rst @@ -34,6 +34,34 @@ How to install the code The code is available in the ``metatomic-torch`` package, in the :py:class:`metatomic.torch.ase_calculator.MetatomicCalculator` class. +Supported model inputs +^^^^^^^^^^^^^^^^^^^^^^ + +The ASE calculator can provide per-atom inputs (e.g. ``"charges"``, +``"momenta"``, ``"velocities"``) as well as the following **system-level** +integer inputs used for model conditioning: + +.. list-table:: + :header-rows: 1 + :widths: 2 3 5 + + * - Input name + - Default + - How to set + * - ``"charge"`` + - ``0`` + - ``atoms.info["charge"] = `` + * - ``"spin"`` + - ``1`` + - ``atoms.info["spin"] = `` + +``"charge"`` is the total charge of the simulation cell in elementary +charges. ``"spin"`` is the spin multiplicity (2S+1) — a singlet is +``spin=1``, a doublet is ``spin=2``, a triplet is ``spin=3``, and so on. +Both values are read as integers from ``atoms.info`` and stored in the +system as the model's floating-point dtype (float32 or float64); the model +converts them back to integers internally for the embedding lookup. + How to use the code ^^^^^^^^^^^^^^^^^^^ diff --git a/docs/src/engines/lammps.rst b/docs/src/engines/lammps.rst index 661b3053b..b7fa84a93 100644 --- a/docs/src/engines/lammps.rst +++ b/docs/src/engines/lammps.rst @@ -331,6 +331,7 @@ documentation. non-conservative stress predictions. Overrides the value given to the ``variant`` keyword. Defaults to no variant. + Examples -------- diff --git a/metatomic-torch/src/misc.cpp b/metatomic-torch/src/misc.cpp index ab4a24982..7da67692a 100644 --- a/metatomic-torch/src/misc.cpp +++ b/metatomic-torch/src/misc.cpp @@ -427,6 +427,8 @@ inline std::unordered_set KNOWN_INPUTS_OUTPUTS = { "velocities", "masses", "charges", + "charge", + "spin", }; std::tuple details::validate_name_and_check_variant( diff --git a/python/metatomic_torch/metatomic/torch/ase_calculator.py b/python/metatomic_torch/metatomic/torch/ase_calculator.py index a4919d654..87d9cd2aa 100644 --- a/python/metatomic_torch/metatomic/torch/ase_calculator.py +++ b/python/metatomic_torch/metatomic/torch/ase_calculator.py @@ -58,6 +58,32 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray: return atoms.get_initial_charges() +SYSTEM_QUANTITIES = { + "charge": { + "quantity": "charge", + "getter": lambda atoms: np.array([[atoms.info.get("charge", 0)]]), + "unit": "e", + "info_key": "charge", + "default": 0, + }, + "spin": { + "quantity": "spin", + "getter": lambda atoms: np.array([[atoms.info.get("spin", 1)]]), + "unit": "", + "info_key": "spin", + "default": 1, + }, +} +""" +Per-system scalar inputs provided by ASE via ``atoms.info``. + +- ``"charge"``: total system charge in elementary charges, read from + ``atoms.info["charge"]``, defaults to ``0``. +- ``"spin"``: spin multiplicity (2S+1), read from + ``atoms.info["spin"]``, defaults to ``1``. +""" + + ARRAY_QUANTITIES = { "momenta": { "quantity": "momentum", @@ -284,11 +310,19 @@ def __init__( outputs, resolved_variants["non_conservative_forces"], ) - self._nc_stress_key = pick_output( - "non_conservative_stress", - outputs, - resolved_variants["non_conservative_stress"], + has_nc_stress = any( + key == "non_conservative_stress" + or key.startswith("non_conservative_stress/") + for key in outputs.keys() ) + if has_nc_stress: + self._nc_stress_key = pick_output( + "non_conservative_stress", + outputs, + resolved_variants["non_conservative_stress"], + ) + else: + self._nc_stress_key = None else: self._nc_forces_key = "non_conservative_forces" self._nc_stress_key = "non_conservative_stress" @@ -308,6 +342,15 @@ def __init__( self._model = model.to(device=self._device) + # Cache which atoms.info keys need change-detection so that check_state + # does only plain Python list iteration on every MD step, avoiding a + # TorchScript JIT dispatch per step to requested_inputs(). + self._system_info_watch: List[Tuple[str, int]] = [ + (infos["info_key"], infos["default"]) + for name, infos in SYSTEM_QUANTITIES.items() + if name in self._model.requested_inputs() + ] + self._calculate_uncertainty = ( self._energy_uq_key in self._model.capabilities().outputs # we require per-atom uncertainties to capture local effects @@ -422,6 +465,34 @@ def run_model( check_consistency=self.parameters["check_consistency"], ) + def check_state(self, atoms: ase.Atoms, tol: float = 1e-15) -> List[str]: + """Detect system changes, including ``atoms.info`` keys used as model inputs. + + ASE's default :py:meth:`~ase.calculators.calculator.Calculator.check_state` + only tracks per-atom arrays (positions, numbers, …) and cell/pbc. Changes + to ``atoms.info["charge"]`` or ``atoms.info["spin"]`` are invisible to it, + causing stale cached results when the charge or spin is updated between calls. + + This override appends the name of any ``atoms.info`` key that has changed + since the last calculation to the standard change list, which forces a + fresh calculation. + """ + changes = super().check_state(atoms, tol=tol) + if self.atoms is not None: + for key, default in self._system_info_watch: + old = self.atoms.info.get(key, default) + new = atoms.info.get(key, default) + try: + equal = old == new + # numpy arrays and similar objects return array-like booleans; + # treat anything that is not a plain bool as "changed" to be safe + if not isinstance(equal, bool) or not equal: + changes.append(key) + except Exception: + # comparison raised (e.g. mixed types); assume changed + changes.append(key) + return changes + def calculate( self, atoms: ase.Atoms, @@ -484,7 +555,8 @@ def calculate( if self.parameters["do_gradients_with_energy"]: if calculate_energies or calculate_energy: calculate_forces = True - calculate_stress = True + if atoms.pbc.all(): + calculate_stress = True with record_function("MetatomicCalculator::prepare_inputs"): outputs = self._ase_properties_to_metatensor_outputs( @@ -634,16 +706,19 @@ def calculate( forces_values = forces_values.cpu().double() self.results["forces"] = forces_values.numpy() - if calculate_stress: - if self.parameters["non_conservative"]: + if calculate_stress and atoms.pbc.all(): + if self.parameters["non_conservative"] and self._nc_stress_key is not None: stress_values = outputs[self._nc_stress_key].block().values.detach() - else: + elif not self.parameters["non_conservative"]: stress_values = strain.grad / atoms.cell.volume - stress_values = stress_values.reshape(3, 3) - stress_values = stress_values.cpu().double() - self.results["stress"] = _full_3x3_to_voigt_6_stress( - stress_values.numpy() - ) + else: + stress_values = None + if stress_values is not None: + stress_values = stress_values.reshape(3, 3) + stress_values = stress_values.cpu().double() + self.results["stress"] = _full_3x3_to_voigt_6_stress( + stress_values.numpy() + ) self.additional_outputs = {} for name in self._additional_output_requests: @@ -720,6 +795,11 @@ def compute_energy( cell = cell @ strain strains.append(strain) system = System(types, positions, cell, pbc) + for name, option in self._model.requested_inputs().items(): + input_tensormap = _get_ase_input( + atoms, name, option, dtype=self._dtype, device=self._device + ) + system.add_data(name, input_tensormap) systems.append(system) # Compute the neighbors lists requested by the model @@ -803,7 +883,7 @@ def compute_energy( for f in results_as_numpy_arrays["forces"] ] - if all(atoms.pbc.all() for atoms in atoms_list): + if all(atoms.pbc.all() for atoms in atoms_list) and self._nc_stress_key is not None: results_as_numpy_arrays["stress"] = [ s for s in predictions[self._nc_stress_key] @@ -870,7 +950,7 @@ def _ase_properties_to_metatensor_outputs( per_atom=True, ) - if calculate_stress and self.parameters["non_conservative"]: + if calculate_stress and self.parameters["non_conservative"] and self._nc_stress_key is not None: metatensor_outputs[self._nc_stress_key] = ModelOutput( quantity="pressure", unit="eV/Angstrom^3", @@ -897,9 +977,27 @@ def _get_ase_input( dtype: torch.dtype, device: torch.device, ) -> "TensorMap": + if name in SYSTEM_QUANTITIES: + infos = SYSTEM_QUANTITIES[name] + # shape: (1, 1) — one system, one scalar property + values = torch.tensor(infos["getter"](atoms), dtype=dtype, device=device) + block = TensorBlock( + values, + samples=Labels(["system"], torch.tensor([[0]], device=device)), + components=[], + properties=Labels([infos["quantity"]], torch.tensor([[0]], device=device)), + ) + tensor = TensorMap(Labels(["_"], torch.tensor([[0]], device=device)), [block]) + tensor.set_info("quantity", infos["quantity"]) + tensor.set_info("unit", infos["unit"]) + return tensor + if name not in ARRAY_QUANTITIES: raise ValueError( - f"The model requested '{name}', which is not available in `ase`." + f"The model requested '{name}', which is not available in `ase`. " + "System-level quantities like 'charge' or 'spin' can be " + "set via atoms.info['charge'] and atoms.info['spin'] " + "respectively." ) infos = ARRAY_QUANTITIES[name] diff --git a/python/metatomic_torch/tests/ase_calculator.py b/python/metatomic_torch/tests/ase_calculator.py index 0b84aa565..0b8e35ce0 100644 --- a/python/metatomic_torch/tests/ase_calculator.py +++ b/python/metatomic_torch/tests/ase_calculator.py @@ -951,6 +951,187 @@ def test_additional_input(atoms): assert np.allclose(values, expected) +def test_system_level_input(atoms): + """charge and spin are per-system integer inputs read from atoms.info.""" + inputs = { + "charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "spin": ModelOutput(quantity="spin", unit="", per_atom=False), + } + outputs = {("extra::" + n): inputs[n] for n in inputs} + capabilities = ModelCapabilities( + outputs=outputs, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + + model = AtomisticModel( + AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities + ) + atoms.info["charge"] = -2 + atoms.info["spin"] = 3 + calculator = MetatomicCalculator(model, check_consistency=False) + results = calculator.run_model(atoms, outputs) + + charge_tensor = results["extra::charge"] + assert charge_tensor[0].samples.names == ["system"] + assert charge_tensor[0].values.dtype == torch.float64 # matches model dtype + assert int(charge_tensor[0].values[0, 0]) == -2 + + spin_tensor = results["extra::spin"] + assert spin_tensor[0].samples.names == ["system"] + assert spin_tensor[0].values.dtype == torch.float64 # matches model dtype + assert int(spin_tensor[0].values[0, 0]) == 3 + + +def test_system_level_input_defaults(atoms): + """charge defaults to 0 and spin to 1 when not set in atoms.info.""" + inputs = { + "charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "spin": ModelOutput(quantity="spin", unit="", per_atom=False), + } + outputs = {("extra::" + n): inputs[n] for n in inputs} + capabilities = ModelCapabilities( + outputs=outputs, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + + model = AtomisticModel( + AdditionalInputModel(inputs).eval(), ModelMetadata(), capabilities + ) + # ensure the keys are absent + atoms.info.pop("charge", None) + atoms.info.pop("spin", None) + calculator = MetatomicCalculator(model, check_consistency=False) + results = calculator.run_model(atoms, outputs) + + assert int(results["extra::charge"][0].values[0, 0]) == 0 + assert int(results["extra::spin"][0].values[0, 0]) == 1 + + +class ChargeSpinEnergyModel(torch.nn.Module): + """Minimal energy model whose output depends on charge and spin. + + Returns energy = charge_value + 10 * spin_value so that different + charge/spin inputs always produce different energies. + """ + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return { + "charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "spin": ModelOutput(quantity="spin", unit="", per_atom=False), + } + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + system = systems[0] + charge = float(system.get_data("charge").block(0).values[0, 0]) + spin = float(system.get_data("spin").block(0).values[0, 0]) + energy_value = charge + 10.0 * spin + block = TensorBlock( + values=torch.tensor([[energy_value]], dtype=torch.float64), + samples=Labels("system", torch.tensor([[0]])), + components=torch.jit.annotate(List[Labels], []), + properties=Labels("energy", torch.tensor([[0]])), + ) + return {"energy": TensorMap(Labels("_", torch.tensor([[0]])), [block])} + + +def test_system_level_input_changes_energy(atoms): + """Different charge/spin values must produce different energies.""" + capabilities = ModelCapabilities( + outputs={"energy": ModelOutput(per_atom=False)}, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + model = AtomisticModel( + ChargeSpinEnergyModel().eval(), ModelMetadata(), capabilities + ) + + # --- varying charge --- + atoms.info["spin"] = 1 + atoms.info["charge"] = 0 + calc = MetatomicCalculator(model, check_consistency=False) + atoms.calc = calc + e_neutral = atoms.get_potential_energy() + + atoms.info["charge"] = 2 + atoms.calc.reset() + e_charged = atoms.get_potential_energy() + + assert e_neutral != e_charged, "Different charges must give different energies" + + # --- varying spin --- + atoms.info["charge"] = 0 + atoms.info["spin"] = 1 + atoms.calc.reset() + e_singlet = atoms.get_potential_energy() + + atoms.info["spin"] = 3 + atoms.calc.reset() + e_triplet = atoms.get_potential_energy() + + assert e_singlet != e_triplet, "Different spins must give different energies" + + # --- cache invalidation: check_state detects atoms.info changes --- + atoms.info["charge"] = 0 + atoms.info["spin"] = 1 + atoms.calc.reset() + e_before = atoms.get_potential_energy() + + atoms.info["charge"] = 1 # change without explicit reset + e_after = atoms.get_potential_energy() + + assert e_before != e_after, ( + "check_state must invalidate cache when atoms.info['charge'] changes" + ) + + +def test_system_level_input_export_roundtrip(atoms, tmp_path): + """Export a charge/spin model to disk and reload via MetatomicCalculator. + + Covers the full pipeline: build → export → save(".pt") → load from file → + run with atoms.info["charge"]/["spin"]. This is the path exercised by + end-users who load a saved model, so it must work end-to-end. + """ + capabilities = ModelCapabilities( + outputs={"energy": ModelOutput(per_atom=False)}, + atomic_types=[28], + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ) + model = AtomisticModel( + ChargeSpinEnergyModel().eval(), ModelMetadata(), capabilities + ) + model_path = str(tmp_path / "charge_spin_model.pt") + model.save(model_path) + + atoms.info["charge"] = 0 + atoms.info["spin"] = 1 + calc = MetatomicCalculator(model_path, check_consistency=True) + atoms.calc = calc + e_neutral = atoms.get_potential_energy() + + atoms.info["charge"] = 2 + atoms.calc.reset() + e_charged = atoms.get_potential_energy() + + assert e_neutral != e_charged, ( + "Loaded model must produce charge-dependent energies" + ) + + @pytest.mark.parametrize("device,dtype", ALL_DEVICE_DTYPE) def test_mixed_pbc(model, device, dtype): """Test that the calculator works on a mixed-PBC system""" diff --git a/python/metatomic_torch/tests/symmetrized_ase_calculator.py b/python/metatomic_torch/tests/symmetrized_ase_calculator.py index 3e1e83a96..06d004305 100644 --- a/python/metatomic_torch/tests/symmetrized_ase_calculator.py +++ b/python/metatomic_torch/tests/symmetrized_ase_calculator.py @@ -485,6 +485,105 @@ def test_choose_quadrature_rules(): assert n_gamma == 2 * L + 1 +# -- Charge / spin conditioning tests ---------------------------------------- + + +class _ChargeSpinAnisoModel(torch.nn.Module): + """Minimal model whose energy = charge + 10*spin + P1(cos θ). + + The P1(cos θ) term is orientation-dependent and cancels exactly under O(3) + rotational averaging (Lebedev l_max >= 1). What remains is + ``charge + 10*spin``, which is rotation-invariant. + + This lets us verify two things in a single test: + - charge/spin values from ``atoms.info`` reach the model in every rotated + copy (the bug fixed in ``compute_energy``). + - the orientation-dependent part is correctly averaged away. + """ + + def requested_inputs(self) -> Dict[str, ModelOutput]: + return { + "charge": ModelOutput(quantity="charge", unit="e", per_atom=False), + "spin": ModelOutput(quantity="spin", unit="", per_atom=False), + } + + def forward( + self, + systems: List[System], + outputs: Dict[str, ModelOutput], + selected_atoms: Optional[Labels] = None, + ) -> Dict[str, TensorMap]: + energies: List[torch.Tensor] = [] + for system in systems: + charge = system.get_data("charge").block(0).values[0, 0] + spin = system.get_data("spin").block(0).values[0, 0] + + # Orientation-dependent P1 term (averages to zero under O(3)) + b = _body_axis_from_system(system).to(dtype=charge.dtype) + zhat = torch.tensor( + [0.0, 0.0, 1.0], dtype=charge.dtype, device=charge.device + ) + P1 = torch.dot(b, zhat) + + energies.append((charge + 10.0 * spin + P1).reshape(1, 1)) + + values = torch.cat(energies, dim=0) + block = TensorBlock( + values=values, + samples=Labels( + "system", + torch.arange( + len(systems), dtype=torch.int32 + ).reshape(-1, 1), + ), + components=torch.jit.annotate(List[Labels], []), + properties=Labels("energy", torch.tensor([[0]])), + ) + return {"energy": TensorMap(Labels("_", torch.tensor([[0]])), [block])} + + +def _charge_spin_calculator() -> mta.ase_calculator.MetatomicCalculator: + """Wrap _ChargeSpinAnisoModel in a MetatomicCalculator.""" + atomistic_model = mta.AtomisticModel( + _ChargeSpinAnisoModel().eval(), + mta.ModelMetadata(), + mta.ModelCapabilities( + outputs={"energy": mta.ModelOutput(per_atom=False)}, + atomic_types=list(range(1, 10)), + interaction_range=0.0, + supported_devices=["cpu"], + dtype="float64", + ), + ) + calc = mta.ase_calculator.MetatomicCalculator(atomistic_model) + return calc + + +@pytest.mark.parametrize("charge,spin", [(0.0, 1.0), (2.0, 1.0), (-1.0, 3.0)]) +def test_symmetrized_calculator_passes_charge_spin( + dimer: Atoms, charge: float, spin: float +) -> None: + """SymmetrizedCalculator must pass charge/spin to each rotated evaluation. + + The model returns ``charge + 10*spin + P1(cos θ)``. After O(3) averaging + the P1 term cancels, so the result must equal ``charge + 10*spin`` exactly. + If charge/spin were silently dropped, every evaluation would use the default + values (0 and 1) and the test would fail for non-default inputs. + """ + dimer.info["charge"] = charge + dimer.info["spin"] = spin + + base = _charge_spin_calculator() + calc = SymmetrizedCalculator(base, l_max=3, include_inversion=True) + dimer.calc = calc + energy = dimer.get_potential_energy() + + expected = charge + 10.0 * spin + assert np.isclose(energy, expected, atol=1e-8), ( + f"Expected energy={expected} for charge={charge}, spin={spin}, got {energy}" + ) + + def test_get_quadrature_properties(): """Check properties of the quadrature returned by _get_quadrature.""" from metatomic.torch.ase_calculator import _get_quadrature