Skip to content

Commit 6963c8e

Browse files
authored
[Feat]: Add UVDoc dynamic graph model (#4793)
* [Feat]: Add UVDoc dynamic graph model * Fix delete self.dtype * Fix delete out_point_positions3D * Fix add pdparam and safetensors PReLU weight/_weight jugement * Fix optimize code * Fix optimize code * Fix optimize code * Fix optimize code
1 parent c868624 commit 6963c8e

File tree

6 files changed

+496
-11
lines changed

6 files changed

+496
-11
lines changed

paddlex/inference/models/common/transformers/transformers/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
22
#
33
# Licensed under the Apache License, Version 2.0 (the "License");
44
# you may not use this file except in compliance with the License.
@@ -13,4 +13,5 @@
1313
# limitations under the License.
1414

1515
from .configuration_utils import PretrainedConfig
16+
from .hf_state_dict_utils import BatchNormHFStateDictMixin
1617
from .model_utils import PretrainedModel
Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
16+
class BatchNormHFStateDictMixin:
17+
def _get_forward_key_rules(self):
18+
return [
19+
("_mean", "_mean", "running_mean"),
20+
("_variance", "_variance", "running_var"),
21+
]
22+
23+
def _get_reverse_key_rules(self):
24+
return [
25+
("running_mean", "running_mean", "_mean"),
26+
("running_var", "running_var", "_variance"),
27+
]
28+
29+
def get_hf_state_dict(self, *args, **kwargs):
30+
31+
try:
32+
super().get_hf_state_dict(*args, **kwargs)
33+
except NotImplementedError:
34+
pass
35+
36+
model_state_dict = self.state_dict(*args, **kwargs)
37+
hf_state_dict = {}
38+
rules = self._get_forward_key_rules()
39+
40+
for old_key, value in model_state_dict.items():
41+
new_key = old_key
42+
for match_key, old_sub, new_sub in rules:
43+
if match_key in old_key:
44+
new_key = old_key.replace(old_sub, new_sub)
45+
break
46+
hf_state_dict[new_key] = value
47+
return hf_state_dict
48+
49+
def set_hf_state_dict(self, state_dict, *args, **kwargs):
50+
51+
try:
52+
super().set_hf_state_dict(state_dict, *args, **kwargs)
53+
except NotImplementedError:
54+
pass
55+
56+
key_mapping = {}
57+
rules = self._get_reverse_key_rules()
58+
59+
for old_key in list(state_dict.keys()):
60+
for match_key, old_sub, new_sub in rules:
61+
if match_key in old_key:
62+
key_mapping[old_key] = old_key.replace(old_sub, new_sub)
63+
break
64+
65+
for old_key, new_key in key_mapping.items():
66+
state_dict[new_key] = state_dict.pop(old_key)
67+
return self.set_state_dict(state_dict, *args, **kwargs)

paddlex/inference/models/common/transformers/transformers/model_utils.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -203,7 +203,11 @@ def _transpose_hf_weight(key, weight):
203203
else:
204204
weight = tp_fn(py_safe_slice_)
205205
else:
206-
weight = py_safe_slice_[:]
206+
# HACK
207+
if len(py_safe_slice_.get_shape()) == 0:
208+
logging.debug("Ignore empty shape this moment")
209+
else:
210+
weight = py_safe_slice_[:]
207211

208212
if not return_numpy and device == "expected":
209213
weight = weight._copy_to(
@@ -1843,13 +1847,12 @@ def from_pretrained(
18431847
):
18441848
raise NotImplementedError
18451849
else:
1846-
try:
1847-
transpose_weight_keys = model.get_transpose_weight_keys()
1848-
except NotImplementedError:
1849-
if convert_from_hf:
1850-
raise ValueError("`convert_from_hf=True` is not supported")
1851-
else:
1852-
transpose_weight_keys = None
1850+
transpose_weight_keys = None
1851+
if convert_from_hf:
1852+
try:
1853+
transpose_weight_keys = model.get_transpose_weight_keys()
1854+
except NotImplementedError:
1855+
pass
18531856
state_dict = load_state_dict(
18541857
resolved_archive_file,
18551858
convert_from_hf=convert_from_hf,
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from .uvdoc import UVDocNet

0 commit comments

Comments
 (0)