forked from deepmodeling/dpdata
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata_type.py
More file actions
161 lines (133 loc) · 4.69 KB
/
data_type.py
File metadata and controls
161 lines (133 loc) · 4.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
from __future__ import annotations
from enum import Enum, unique
from typing import TYPE_CHECKING
import numpy as np
from dpdata.plugin import Plugin
if TYPE_CHECKING:
from dpdata.system import System
@unique
class Axis(Enum):
    """Symbolic sizes that may appear inside a ``DataType`` shape.

    A member stands for a size that is only known once a concrete system is
    at hand; ``DataType.real_shape`` swaps each member for the matching count
    queried from that system.
    """

    # resolved with system.get_nframes()
    NFRAMES = "nframes"
    # resolved with system.get_natoms()
    NATOMS = "natoms"
    # resolved with system.get_ntypes()
    NTYPES = "ntypes"
    # resolved with system.get_nbonds()
    NBONDS = "nbonds"
class AnyInt(int):
    """Integer wildcard that compares equal to anything.

    ``DataType.real_shape`` turns ``-1`` entries of a declared shape into
    ``AnyInt(-1)``, so comparing the expected shape with an actual shape
    accepts any size along that axis.

    Note that ``__eq__`` returns True for every operand, not only integers.
    """

    def __eq__(self, other):
        return True

    def __ne__(self, other):
        # ``int`` provides its own ``__ne__``. Without this override the
        # inherited one compares numerically, so ``AnyInt(-1) != 5`` was True
        # even though ``AnyInt(-1) == 5`` is also True.
        return False

    # Defining ``__eq__`` already makes the class unhashable. State it
    # explicitly: no hash value can be consistent with "equal to everything".
    __hash__ = None
class DataError(Exception):
    """Raised when system data has the wrong type or shape, or is missing.

    ``DataType.check`` raises it to report data that does not match the
    declared data type.
    """
class DataType:
    """DataType represents a type of data, like coordinates, energies, etc.

    Parameters
    ----------
    name : str
        name of data
    dtype : type or tuple[type]
        data type, e.g. np.ndarray
    shape : tuple[int], optional
        shape of data. Used when data is list or np.ndarray. Entries may be
        ``Axis`` members (resolved against the system), plain ints, or -1
        (matches any size)
    required : bool, default=True
        whether this data is required
    deepmd_name : str, optional
        DeePMD-kit data type name. When not given, it is the same as `name`.
    """

    def __init__(
        self,
        name: str,
        dtype: type | tuple[type, ...],
        shape: tuple[int | Axis, ...] | None = None,
        required: bool = True,
        deepmd_name: str | None = None,
    ) -> None:
        self.name = name
        self.dtype = dtype
        self.shape = shape
        self.required = required
        # Fall back to the internal name when no DeePMD-kit specific name is given.
        self.deepmd_name = name if deepmd_name is None else deepmd_name

    def real_shape(self, system: System) -> tuple[int, ...]:
        """Return the expected concrete shape of this data for a system.

        Each entry of ``self.shape`` is resolved as follows: an ``Axis``
        member is replaced by the matching count queried from ``system``;
        ``-1`` becomes ``AnyInt(-1)``, which compares equal to any integer;
        any other int is kept unchanged.

        Parameters
        ----------
        system : System
            system whose sizes are used to resolve ``Axis`` entries

        Returns
        -------
        tuple[int, ...]
            the expected shape

        Raises
        ------
        RuntimeError
            if an entry of ``self.shape`` is neither an ``Axis`` member nor an int
        """
        assert self.shape is not None
        shape = []
        for ii in self.shape:
            if ii is Axis.NFRAMES:
                shape.append(system.get_nframes())
            elif ii is Axis.NTYPES:
                shape.append(system.get_ntypes())
            elif ii is Axis.NATOMS:
                shape.append(system.get_natoms())
            elif ii is Axis.NBONDS:
                # get_nbonds is expected on BondOrderSystem, not on every
                # System, hence the type: ignore.
                shape.append(system.get_nbonds())  # type: ignore
            elif ii == -1:
                shape.append(AnyInt(-1))
            elif isinstance(ii, int):
                shape.append(ii)
            else:
                raise RuntimeError("Shape is not an int!")
        return tuple(shape)

    def check(self, system: System) -> None:
        """Check if a system has correct data of this type.

        Parameters
        ----------
        system : System
            checked system

        Raises
        ------
        DataError
            data is required but missing, or its type or shape is not correct
        RuntimeError
            a shape is declared but the data is neither a list nor an np.ndarray
        """
        if self.name not in system.data:
            if self.required:
                raise DataError(f"{self.name} not found in data")
            return
        data = system.data[self.name]
        self._check_dtype(data)
        if self.shape is not None:
            self._check_shape(data, self.real_shape(system))

    @staticmethod
    def _type_name(tp) -> str:
        """Return a readable name for a type, a tuple of types, or a union."""
        if isinstance(tp, tuple):
            return " or ".join(DataType._type_name(t) for t in tp)
        # Objects such as ``int | str`` have no ``__name__``; use their repr.
        return getattr(tp, "__name__", repr(tp))

    def _check_dtype(self, data) -> None:
        """Raise DataError if ``data`` is not an instance of ``self.dtype``."""
        # An empty list is accepted in place of an empty np.ndarray.
        if isinstance(data, list) and not data:
            return
        if not isinstance(data, self.dtype):
            raise DataError(
                f"Type of {self.name} is {type(data).__name__}, "
                f"but expected {self._type_name(self.dtype)}"
            )

    def _check_shape(self, data, shape: tuple[int, ...]) -> None:
        """Raise DataError if ``data`` does not match the expected ``shape``."""
        if isinstance(data, np.ndarray):
            # Empty arrays are not shape-checked. Tuple comparison uses
            # element-wise ==, so AnyInt wildcards match any size.
            if data.size and shape != data.shape:
                raise DataError(
                    f"Shape of {self.name} is {data.shape}, but expected {shape}"
                )
        elif isinstance(data, list):
            # Only the length (first axis) of a list is checked. Use == so an
            # AnyInt wildcard matches regardless of how != would be resolved.
            if shape and not (shape[0] == len(data)):
                raise DataError(
                    "Length of %s is %d, but expected %d"  # noqa: UP031
                    % (self.name, len(data), shape[0])
                )
        else:
            raise RuntimeError("Unsupported type to check shape")
# Two independent registries, selected by the ``labeled`` flag of
# register_data_type / get_data_types: the first holds data types for System,
# the second those for LabeledSystem.
__system_data_type_plugin, __labeled_system_data_type_plugin = Plugin(), Plugin()
def register_data_type(data_type: DataType, labeled: bool):
    """Add a data type to the registry of System or LabeledSystem.

    Parameters
    ----------
    data_type : DataType
        data type to be registered, under ``data_type.name``
    labeled : bool
        if True, register it for LabeledSystem, otherwise for System
    """
    if labeled:
        target = __labeled_system_data_type_plugin
    else:
        target = __system_data_type_plugin
    decorator = target.register(data_type.name)
    decorator(data_type)
def get_data_types(labeled: bool):
    """Return every data type registered so far.

    Parameters
    ----------
    labeled : bool
        if True, return the types registered for LabeledSystem, otherwise
        those registered for System

    Returns
    -------
    tuple
        the registered DataType objects, in the order given by the
        registry's ``plugins.values()``
    """
    if labeled:
        registry = __labeled_system_data_type_plugin
    else:
        registry = __system_data_type_plugin
    registered = registry.plugins.values()
    return tuple(registered)