diff --git a/python/metatomic_torchsim/metatomic_torchsim/_model.py b/python/metatomic_torchsim/metatomic_torchsim/_model.py index 09e7026db..00a00aba1 100644 --- a/python/metatomic_torchsim/metatomic_torchsim/_model.py +++ b/python/metatomic_torchsim/metatomic_torchsim/_model.py @@ -24,6 +24,7 @@ System, load_atomistic_model, pick_device, + pick_output, ) @@ -75,6 +76,7 @@ def __init__( check_consistency: bool = False, compute_forces: bool = True, compute_stress: bool = True, + variants: Optional[Dict[str, Optional[str]]] = None, ) -> None: """Initialize the metatomic model wrapper. @@ -91,6 +93,9 @@ def __init__( Useful for debugging but hurts performance. :param compute_forces: Compute atomic forces via autograd. :param compute_stress: Compute stress tensors via the strain trick. + :param variants: Dictionary mapping output names to variant names. If not + provided, the default variant is used for all outputs. See + :py:func:`metatomic.torch.pick_output` for details on variant selection. """ super().__init__() @@ -149,6 +154,10 @@ def __init__( "Only models with energy outputs can be used with TorchSim." ) + # Resolve output variants + variants = variants or {} + self._energy_key = pick_output("energy", capabilities.outputs, variants.get("energy")) + self._model = model.to(device=self._device) self._compute_forces = compute_forces self._compute_stress = compute_stress @@ -158,7 +167,7 @@ def __init__( self._evaluation_options = ModelEvaluationOptions( length_unit="angstrom", outputs={ - "energy": ModelOutput(quantity="energy", unit="eV", per_atom=False) + self._energy_key: ModelOutput(quantity="energy", unit="eV", per_atom=False) }, ) @@ -243,8 +252,7 @@ def forward(self, state: "ts.SimState") -> Dict[str, torch.Tensor]: check_consistency=self._check_consistency, ) - energy_values = model_outputs["energy"].block().values - + energy_values = model_outputs[self._energy_key].block().values results: Dict[str, torch.Tensor] = {} results["energy"] = energy_values.detach().squeeze(-1) diff --git a/python/metatomic_torchsim/tests/test_model_loading.py b/python/metatomic_torchsim/tests/test_model_loading.py index 2dc5ef687..237fc03f9 100644 --- a/python/metatomic_torchsim/tests/test_model_loading.py +++ b/python/metatomic_torchsim/tests/test_model_loading.py @@ -58,3 +58,15 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: dummy_scripted = torch.jit.script(Dummy()) with pytest.raises(TypeError, match="must be 'AtomisticModel'"): MetatomicModel(model=dummy_scripted, device=DEVICE) + + + +def test_variants_parameter_accepted(lj_model): + """Variants parameter is accepted even for models without variants.""" + # The LJ test model has no variants, but the parameter should be accepted + model = MetatomicModel(model=lj_model, device=DEVICE, variants=None) + assert model._energy_key == "energy" + + # Explicit empty variants dict should also work + model = MetatomicModel(model=lj_model, device=DEVICE, variants={}) + assert model._energy_key == "energy" \ No newline at end of file