Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions tests/test_extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,9 @@ def test_store_model_extras_canonical_keys_not_stored(
"stress": torch.randn(state.n_systems, 3, 3),
}
)
assert not state._system_extras # noqa: SLF001
assert not state._atom_extras # noqa: SLF001
for key in ("energy", "forces", "stress"):
assert key not in state._system_extras # noqa: SLF001
assert key not in state._atom_extras # noqa: SLF001

def test_store_model_extras_per_system(self, si_double_sim_state: ts.SimState):
"""Tensors with leading dim == n_systems go into system_extras."""
Expand Down
14 changes: 7 additions & 7 deletions tests/test_nbody.py
Original file line number Diff line number Diff line change
Expand Up @@ -481,9 +481,9 @@ def test_build_triplets_device(device: str) -> None:

result = build_triplets(edge_index, n_atoms)

assert result["trip_in"].device == dev
assert result["trip_out"].device == dev
assert result["center_atom"].device == dev
assert result["trip_in"].device.type == dev.type
assert result["trip_out"].device.type == dev.type
assert result["center_atom"].device.type == dev.type


@pytest.mark.parametrize(
Expand All @@ -507,10 +507,10 @@ def test_build_quadruplets_device(device: str) -> None:
internal_cell_offsets,
)

assert result["quad_c_to_a_edge"].device == dev
assert result["quad_d_to_b_trip_idx"].device == dev
assert result["d_to_b_edge"].device == dev
assert result["c_to_a_edge"].device == dev
assert result["quad_c_to_a_edge"].device.type == dev.type
assert result["quad_d_to_b_trip_idx"].device.type == dev.type
assert result["d_to_b_edge"].device.type == dev.type
assert result["c_to_a_edge"].device.type == dev.type


def test_build_triplets_jit_script() -> None:
Expand Down
2 changes: 2 additions & 0 deletions torch_sim/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,6 +680,8 @@ def get_cart_deformed_cell(state: SimState, axis: int = 0, size: float = 1.0) ->
masses=state.masses,
pbc=state.pbc,
atomic_numbers=state.atomic_numbers,
_system_extras=state._system_extras,
_atom_extras=state._atom_extras,
)


Expand Down
30 changes: 22 additions & 8 deletions torch_sim/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,16 +95,22 @@ def state_to_atoms(

# Write system extras to atoms.info
# charge/spin stored as int scalars for FairChem compatibility
if system_extras_keys is not None:
for key in system_extras_keys:
val = state.system_extras[key][sys_idx].detach().cpu().numpy()
atoms.info[key] = val
_sys_keys = (
system_extras_keys
if system_extras_keys is not None
else list(state.system_extras)
)
for key in _sys_keys:
val = state.system_extras[key][sys_idx].detach().cpu().numpy()
atoms.info[key] = val

# Write atom extras to atoms.arrays
if atom_extras_keys is not None:
for key in atom_extras_keys:
val = state.atom_extras[key][mask].detach().cpu().numpy()
atoms.arrays[key] = val
_atom_keys = (
atom_extras_keys if atom_extras_keys is not None else list(state.atom_extras)
)
for key in _atom_keys:
val = state.atom_extras[key][mask].detach().cpu().numpy()
atoms.arrays[key] = val

atoms_list.append(atoms)

Expand Down Expand Up @@ -314,8 +320,16 @@ def atoms_to_state(
raise ValueError("All systems must have the same periodic boundary conditions")

_system_extras: dict[str, torch.Tensor] = {}

# charge and spin always default to 0 for backward compatibility
for key in ("charge", "spin"):
vals = np.array([float(at.info.get(key, 0.0)) for at in atoms_list])
_system_extras[key] = torch.tensor(vals, dtype=dtype, device=device)

if system_extras_keys:
for key in system_extras_keys:
if key in _system_extras:
continue
vals = [at.info.get(key) for at in atoms_list]
non_none_vals = [v for v in vals if v is not None]
if len(non_none_vals) == len(vals):
Expand Down
4 changes: 2 additions & 2 deletions torch_sim/models/fairchem.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def forward(self, state: SimState, **_kwargs: object) -> dict[str, torch.Tensor]
pbc=pbc_np if cell is not None else False,
)

charge = sim_state.charge
spin = sim_state.spin
charge = getattr(sim_state, "charge", None)
spin = getattr(sim_state, "spin", None)
atoms.info["charge"] = charge[idx].item() if charge is not None else 0.0
atoms.info["spin"] = spin[idx].item() if spin is not None else 0.0

Expand Down
4 changes: 2 additions & 2 deletions torch_sim/models/mace.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,8 +304,8 @@ def forward( # noqa: C901
edge_index=edge_index,
unit_shifts=unit_shifts,
shifts=shifts,
total_charge=state.charge,
total_spin=state.spin,
total_charge=getattr(state, "charge", None),
total_spin=getattr(state, "spin", None),
)

# Get model output
Expand Down
20 changes: 17 additions & 3 deletions torch_sim/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -974,6 +974,11 @@ def get_attrs_for_scope(
for attr_name in attr_names:
yield attr_name, getattr(state, attr_name)

if scope == "per-system":
yield from state._system_extras.items() # noqa: SLF001
elif scope == "per-atom":
yield from state._atom_extras.items() # noqa: SLF001


def _filter_attrs_by_index(
state: SimState,
Expand Down Expand Up @@ -1029,11 +1034,15 @@ def _filter_attrs_by_index(
c.system_idx = new_system_idx[c.system_idx] # ty: ignore[invalid-assignment]

for name, val in get_attrs_for_scope(state, "per-atom"):
if name in state.atom_extras:
continue
filtered_attrs[name] = (
system_remap[val[atom_indices]] if name == "system_idx" else val[atom_indices]
)

for name, val in get_attrs_for_scope(state, "per-system"):
if name in state.system_extras:
continue
filtered_attrs[name] = (
val[system_indices] if isinstance(val, torch.Tensor) else val
)
Expand Down Expand Up @@ -1065,11 +1074,14 @@ def _split_state[T: SimState](state: T) -> list[T]:

split_per_atom = {}
for attr_name, attr_value in get_attrs_for_scope(state, "per-atom"):
if attr_name != "system_idx":
split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0)
if attr_name == "system_idx" or attr_name in state.atom_extras:
continue
split_per_atom[attr_name] = torch.split(attr_value, system_sizes, dim=0)

split_per_system = {}
for attr_name, attr_value in get_attrs_for_scope(state, "per-system"):
if attr_name in state.system_extras:
continue
if isinstance(attr_value, torch.Tensor):
split_per_system[attr_name] = torch.split(attr_value, 1, dim=0)
else: # Non-tensor attributes are replicated for each split
Expand Down Expand Up @@ -1277,13 +1289,15 @@ def concatenate_states[T: SimState]( # noqa: C901, PLR0915

# Collect per-atom properties
for prop, val in get_attrs_for_scope(state, "per-atom"):
if prop == "system_idx":
if prop == "system_idx" or prop in state.atom_extras:
# skip system_idx, it will be handled below
continue
per_atom_tensors[prop].append(val)

# Collect per-system properties
for prop, val in get_attrs_for_scope(state, "per-system"):
if prop in state.system_extras:
continue
per_system_tensors[prop].append(val)

# Collect extras
Expand Down
Loading