Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 11 additions & 3 deletions python/metatomic_torchsim/metatomic_torchsim/_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
System,
load_atomistic_model,
pick_device,
pick_output,
)


Expand Down Expand Up @@ -75,6 +76,7 @@ def __init__(
check_consistency: bool = False,
compute_forces: bool = True,
compute_stress: bool = True,
variants: Optional[Dict[str, Optional[str]]] = None,
) -> None:
"""Initialize the metatomic model wrapper.

Expand All @@ -91,6 +93,9 @@ def __init__(
Useful for debugging but hurts performance.
:param compute_forces: Compute atomic forces via autograd.
:param compute_stress: Compute stress tensors via the strain trick.
:param variants: Dictionary mapping output names to variant names. If not
provided, the default variant is used for all outputs. See
:py:func:`metatomic.torch.pick_output` for details on variant selection.
"""
super().__init__()

Expand Down Expand Up @@ -149,6 +154,10 @@ def __init__(
"Only models with energy outputs can be used with TorchSim."
)

# Resolve output variants
variants = variants or {}
self._energy_key = pick_output("energy", capabilities.outputs, variants.get("energy"))

self._model = model.to(device=self._device)
self._compute_forces = compute_forces
self._compute_stress = compute_stress
Expand All @@ -158,7 +167,7 @@ def __init__(
self._evaluation_options = ModelEvaluationOptions(
length_unit="angstrom",
outputs={
"energy": ModelOutput(quantity="energy", unit="eV", per_atom=False)
self._energy_key: ModelOutput(quantity="energy", unit="eV", per_atom=False)
},
)

Expand Down Expand Up @@ -243,8 +252,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]:
check_consistency=self._check_consistency,
)

energy_values = model_outputs["energy"].block().values

energy_values = model_outputs[self._energy_key].block().values
results: Dict[str, torch.Tensor] = {}
results["energy"] = energy_values.detach().squeeze(-1)

Expand Down
12 changes: 12 additions & 0 deletions python/metatomic_torchsim/tests/test_model_loading.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
dummy_scripted = torch.jit.script(Dummy())
with pytest.raises(TypeError, match="must be 'AtomisticModel'"):
MetatomicModel(model=dummy_scripted, device=DEVICE)



def test_variants_parameter_accepted(lj_model):
"""Variants parameter is accepted even for models without variants."""
# The LJ test model has no variants, but the parameter should be accepted
model = MetatomicModel(model=lj_model, device=DEVICE, variants=None)
assert model._energy_key == "energy"

# Explicit empty variants dict should also work
model = MetatomicModel(model=lj_model, device=DEVICE, variants={})
assert model._energy_key == "energy"