Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions docs/src/engines/ase.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,34 @@ How to install the code
The code is available in the ``metatomic-torch`` package, in the
:py:class:`metatomic.torch.ase_calculator.MetatomicCalculator` class.

Supported model inputs
^^^^^^^^^^^^^^^^^^^^^^

The ASE calculator can provide per-atom inputs (e.g. ``"charges"``,
``"momenta"``, ``"velocities"``) as well as the following **system-level**
integer inputs used for model conditioning:

.. list-table::
:header-rows: 1
:widths: 2 3 5

* - Input name
- Default
- How to set
* - ``"charge"``
- ``0``
- ``atoms.info["charge"] = <int>``
* - ``"spin"``
- ``1``
- ``atoms.info["spin"] = <int>``

``"charge"`` is the total charge of the simulation cell in elementary
charges. ``"spin"`` is the spin multiplicity (2S+1) — a singlet is
``spin=1``, a doublet is ``spin=2``, a triplet is ``spin=3``, and so on.
Both values are read as integers from ``atoms.info`` and stored in the
system as the model's floating-point dtype (float32 or float64); the model
converts them back to integers internally for the embedding lookup.
Comment on lines +62 to +63
Copy link
Contributor

@frostedoyster frostedoyster Mar 23, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I wouldn't necessarily say this in the documentation, simply because the implementation will potentially change (I think we're soon going to support integer TensorMaps, right @Luthaf?). I wouldn't say how the model is going to handle them either: the point of metatomic is to be totally generic, each model can handle these inputs how it wants (including ignoring them).

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, integer TensorMap are already mostly supported at runtime and will soon be able to be serialized as well.


How to use the code
^^^^^^^^^^^^^^^^^^^

Expand Down
1 change: 1 addition & 0 deletions docs/src/engines/lammps.rst
Original file line number Diff line number Diff line change
Expand Up @@ -331,6 +331,7 @@ documentation.
non-conservative stress predictions. Overrides the value given to the
``variant`` keyword. Defaults to no variant.


Examples
--------

Expand Down
2 changes: 2 additions & 0 deletions metatomic-torch/src/misc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -427,6 +427,8 @@ inline std::unordered_set<std::string> KNOWN_INPUTS_OUTPUTS = {
"velocities",
"masses",
"charges",
"charge",
"spin",
};

std::tuple<bool, std::string, std::string> details::validate_name_and_check_variant(
Expand Down
130 changes: 114 additions & 16 deletions python/metatomic_torch/metatomic/torch/ase_calculator.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,32 @@ def _get_charges(atoms: ase.Atoms) -> np.ndarray:
return atoms.get_initial_charges()


SYSTEM_QUANTITIES = {
"charge": {
"quantity": "charge",
"getter": lambda atoms: np.array([[atoms.info.get("charge", 0)]]),
"unit": "e",
"info_key": "charge",
"default": 0,
},
"spin": {
"quantity": "spin",
"getter": lambda atoms: np.array([[atoms.info.get("spin", 1)]]),
"unit": "",
"info_key": "spin",
"default": 1,
},
}
"""
Per-system scalar inputs provided by ASE via ``atoms.info``.

- ``"charge"``: total system charge in elementary charges, read from
``atoms.info["charge"]``, defaults to ``0``.
- ``"spin"``: spin multiplicity (2S+1), read from
``atoms.info["spin"]``, defaults to ``1``.
"""


ARRAY_QUANTITIES = {
"momenta": {
"quantity": "momentum",
Expand Down Expand Up @@ -284,11 +310,19 @@ def __init__(
outputs,
resolved_variants["non_conservative_forces"],
)
self._nc_stress_key = pick_output(
"non_conservative_stress",
outputs,
resolved_variants["non_conservative_stress"],
has_nc_stress = any(
key == "non_conservative_stress"
or key.startswith("non_conservative_stress/")
for key in outputs.keys()
)
if has_nc_stress:
self._nc_stress_key = pick_output(
"non_conservative_stress",
outputs,
resolved_variants["non_conservative_stress"],
)
else:
self._nc_stress_key = None
else:
self._nc_forces_key = "non_conservative_forces"
self._nc_stress_key = "non_conservative_stress"
Expand All @@ -308,6 +342,15 @@ def __init__(

self._model = model.to(device=self._device)

# Cache which atoms.info keys need change-detection so that check_state
# does only plain Python list iteration on every MD step, avoiding a
# TorchScript JIT dispatch per step to requested_inputs().
self._system_info_watch: List[Tuple[str, int]] = [
(infos["info_key"], infos["default"])
for name, infos in SYSTEM_QUANTITIES.items()
if name in self._model.requested_inputs()
]

self._calculate_uncertainty = (
self._energy_uq_key in self._model.capabilities().outputs
# we require per-atom uncertainties to capture local effects
Expand Down Expand Up @@ -422,6 +465,34 @@ def run_model(
check_consistency=self.parameters["check_consistency"],
)

def check_state(self, atoms: ase.Atoms, tol: float = 1e-15) -> List[str]:
"""Detect system changes, including ``atoms.info`` keys used as model inputs.

ASE's default :py:meth:`~ase.calculators.calculator.Calculator.check_state`
only tracks per-atom arrays (positions, numbers, …) and cell/pbc. Changes
to ``atoms.info["charge"]`` or ``atoms.info["spin"]`` are invisible to it,
causing stale cached results when the charge or spin is updated between calls.

This override appends the name of any ``atoms.info`` key that has changed
since the last calculation to the standard change list, which forces a
fresh calculation.
"""
changes = super().check_state(atoms, tol=tol)
if self.atoms is not None:
for key, default in self._system_info_watch:
old = self.atoms.info.get(key, default)
new = atoms.info.get(key, default)
try:
equal = old == new
# numpy arrays and similar objects return array-like booleans;
# treat anything that is not a plain bool as "changed" to be safe
if not isinstance(equal, bool) or not equal:
changes.append(key)
except Exception:
# comparison raised (e.g. mixed types); assume changed
changes.append(key)
return changes

def calculate(
self,
atoms: ase.Atoms,
Expand Down Expand Up @@ -484,7 +555,8 @@ def calculate(
if self.parameters["do_gradients_with_energy"]:
if calculate_energies or calculate_energy:
calculate_forces = True
calculate_stress = True
if atoms.pbc.all():
calculate_stress = True

with record_function("MetatomicCalculator::prepare_inputs"):
outputs = self._ase_properties_to_metatensor_outputs(
Expand Down Expand Up @@ -634,16 +706,19 @@ def calculate(
forces_values = forces_values.cpu().double()
self.results["forces"] = forces_values.numpy()

if calculate_stress:
if self.parameters["non_conservative"]:
if calculate_stress and atoms.pbc.all():
if self.parameters["non_conservative"] and self._nc_stress_key is not None:
stress_values = outputs[self._nc_stress_key].block().values.detach()
else:
elif not self.parameters["non_conservative"]:
stress_values = strain.grad / atoms.cell.volume
stress_values = stress_values.reshape(3, 3)
stress_values = stress_values.cpu().double()
self.results["stress"] = _full_3x3_to_voigt_6_stress(
stress_values.numpy()
)
else:
stress_values = None
if stress_values is not None:
stress_values = stress_values.reshape(3, 3)
stress_values = stress_values.cpu().double()
self.results["stress"] = _full_3x3_to_voigt_6_stress(
stress_values.numpy()
)

self.additional_outputs = {}
for name in self._additional_output_requests:
Expand Down Expand Up @@ -720,6 +795,11 @@ def compute_energy(
cell = cell @ strain
strains.append(strain)
system = System(types, positions, cell, pbc)
for name, option in self._model.requested_inputs().items():
input_tensormap = _get_ase_input(
atoms, name, option, dtype=self._dtype, device=self._device
)
system.add_data(name, input_tensormap)
systems.append(system)

# Compute the neighbors lists requested by the model
Expand Down Expand Up @@ -803,7 +883,7 @@ def compute_energy(
for f in results_as_numpy_arrays["forces"]
]

if all(atoms.pbc.all() for atoms in atoms_list):
if all(atoms.pbc.all() for atoms in atoms_list) and self._nc_stress_key is not None:
results_as_numpy_arrays["stress"] = [
s
for s in predictions[self._nc_stress_key]
Expand Down Expand Up @@ -870,7 +950,7 @@ def _ase_properties_to_metatensor_outputs(
per_atom=True,
)

if calculate_stress and self.parameters["non_conservative"]:
if calculate_stress and self.parameters["non_conservative"] and self._nc_stress_key is not None:
metatensor_outputs[self._nc_stress_key] = ModelOutput(
quantity="pressure",
unit="eV/Angstrom^3",
Expand All @@ -897,9 +977,27 @@ def _get_ase_input(
dtype: torch.dtype,
device: torch.device,
) -> "TensorMap":
if name in SYSTEM_QUANTITIES:
infos = SYSTEM_QUANTITIES[name]
# shape: (1, 1) — one system, one scalar property
values = torch.tensor(infos["getter"](atoms), dtype=dtype, device=device)
block = TensorBlock(
values,
samples=Labels(["system"], torch.tensor([[0]], device=device)),
components=[],
properties=Labels([infos["quantity"]], torch.tensor([[0]], device=device)),
)
tensor = TensorMap(Labels(["_"], torch.tensor([[0]], device=device)), [block])
tensor.set_info("quantity", infos["quantity"])
tensor.set_info("unit", infos["unit"])
return tensor

if name not in ARRAY_QUANTITIES:
raise ValueError(
f"The model requested '{name}', which is not available in `ase`."
f"The model requested '{name}', which is not available in `ase`. "
"System-level quantities like 'charge' or 'spin' can be "
"set via atoms.info['charge'] and atoms.info['spin'] "
"respectively."
)

infos = ARRAY_QUANTITIES[name]
Expand Down
Loading