diff --git a/paddlex/inference/models/common/transformers/transformers/__init__.py b/paddlex/inference/models/common/transformers/transformers/__init__.py index 4badddda5f..29cfe096eb 100644 --- a/paddlex/inference/models/common/transformers/transformers/__init__.py +++ b/paddlex/inference/models/common/transformers/transformers/__init__.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved. +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -13,4 +13,5 @@ # limitations under the License. from .configuration_utils import PretrainedConfig +from .hf_state_dict_utils import BatchNormHFStateDictMixin from .model_utils import PretrainedModel diff --git a/paddlex/inference/models/common/transformers/transformers/hf_state_dict_utils.py b/paddlex/inference/models/common/transformers/transformers/hf_state_dict_utils.py new file mode 100644 index 0000000000..bec2216a3f --- /dev/null +++ b/paddlex/inference/models/common/transformers/transformers/hf_state_dict_utils.py @@ -0,0 +1,67 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
class BatchNormHFStateDictMixin:
    """Mixin that renames state-dict keys when exporting to / importing from
    the "HF" key layout.

    A rule is a ``(match_key, old_sub, new_sub)`` tuple. When ``match_key``
    occurs anywhere in a key, every occurrence of ``old_sub`` in that key is
    replaced with ``new_sub``. Only the first matching rule is applied to a
    given key; keys that match no rule are left untouched.

    The host class must provide ``state_dict`` and ``set_state_dict``, and a
    parent class in the MRO must define ``get_hf_state_dict`` /
    ``set_hf_state_dict`` (they may raise ``NotImplementedError``).
    Subclasses extend the rule lists by overriding the two ``_get_*_rules``
    hooks.
    """

    def _get_forward_key_rules(self):
        """Return the rules used by ``get_hf_state_dict`` (model -> HF names)."""
        suffix_pairs = (
            ("_mean", "running_mean"),
            ("_variance", "running_var"),
        )
        # In the default rules the matched text and the replaced text coincide.
        return [(source, source, target) for source, target in suffix_pairs]

    def _get_reverse_key_rules(self):
        """Return the rules used by ``set_hf_state_dict`` (HF -> model names)."""
        suffix_pairs = (
            ("running_mean", "_mean"),
            ("running_var", "_variance"),
        )
        return [(source, source, target) for source, target in suffix_pairs]

    @staticmethod
    def _first_rule_rename(key, rules):
        """Return ``key`` rewritten by the first matching rule.

        Returns ``None`` when no rule matches, so callers can tell "no rule
        applied" apart from "a rule applied but produced the same text".
        """
        return next(
            (
                key.replace(old_sub, new_sub)
                for match_key, old_sub, new_sub in rules
                if match_key in key
            ),
            None,
        )

    def get_hf_state_dict(self, *args, **kwargs):
        """Return this model's state dict with keys renamed by the forward rules.

        ``*args`` / ``**kwargs`` are forwarded both to the parent
        ``get_hf_state_dict`` and to ``self.state_dict``.
        """
        try:
            # The parent's return value is discarded; a parent that merely
            # raises NotImplementedError is tolerated.
            super().get_hf_state_dict(*args, **kwargs)
        except NotImplementedError:
            pass

        params = self.state_dict(*args, **kwargs)
        rules = self._get_forward_key_rules()
        candidates = (
            (self._first_rule_rename(name, rules), name, tensor)
            for name, tensor in params.items()
        )
        return {
            name if renamed is None else renamed: tensor
            for renamed, name, tensor in candidates
        }

    def set_hf_state_dict(self, state_dict, *args, **kwargs):
        """Rename HF-style keys with the reverse rules and load the result.

        ``state_dict`` is modified in place: matching keys are popped and
        re-inserted under their new names. Returns whatever
        ``self.set_state_dict`` returns.
        """
        try:
            # Same tolerance as get_hf_state_dict: only NotImplementedError
            # from the parent is swallowed.
            super().set_hf_state_dict(state_dict, *args, **kwargs)
        except NotImplementedError:
            pass

        rules = self._get_reverse_key_rules()
        # Decide every rename against a snapshot of the keys before touching
        # the dict, because it cannot change size while being iterated.
        pending = [
            (name, renamed)
            for name in list(state_dict.keys())
            for renamed in (self._first_rule_rename(name, rules),)
            if renamed is not None
        ]
        for name, renamed in pending:
            state_dict[renamed] = state_dict.pop(name)
        return self.set_state_dict(state_dict, *args, **kwargs)
py_safe_slice_[:] if not return_numpy and device == "expected": weight = weight._copy_to( @@ -1841,13 +1845,12 @@ def from_pretrained( ): raise NotImplementedError else: - try: - transpose_weight_keys = model.get_transpose_weight_keys() - except NotImplementedError: - if convert_from_hf: - raise ValueError("`convert_from_hf=True` is not supported") - else: - transpose_weight_keys = None + transpose_weight_keys = None + if convert_from_hf: + try: + transpose_weight_keys = model.get_transpose_weight_keys() + except NotImplementedError: + pass state_dict = load_state_dict( resolved_archive_file, convert_from_hf=convert_from_hf, diff --git a/paddlex/inference/models/image_unwarping/modeling/__init__.py b/paddlex/inference/models/image_unwarping/modeling/__init__.py new file mode 100644 index 0000000000..6f006b11e6 --- /dev/null +++ b/paddlex/inference/models/image_unwarping/modeling/__init__.py @@ -0,0 +1,15 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .uvdoc import UVDocNet diff --git a/paddlex/inference/models/image_unwarping/modeling/uvdoc.py b/paddlex/inference/models/image_unwarping/modeling/uvdoc.py new file mode 100644 index 0000000000..83f095f729 --- /dev/null +++ b/paddlex/inference/models/image_unwarping/modeling/uvdoc.py @@ -0,0 +1,388 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. 
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
import paddle.nn as nn
import paddle.nn.functional as F

from ...common.transformers.transformers import (
    BatchNormHFStateDictMixin,
    PretrainedConfig,
    PretrainedModel,
)

# NOTE: the key rules at the bottom of this file address parameters by
# "<attribute>.<index>.<param>" paths, so sub-layer attribute names and the
# order of layers inside every nn.Sequential below must not be changed.


def conv3x3(in_channels, out_channels, kernel_size, stride=1):
    """Build a biased ``nn.Conv2D`` padded by ``kernel_size // 2``.

    Despite the name, the kernel size is whatever the caller passes.
    """
    return nn.Conv2D(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=kernel_size // 2,
    )


def dilated_conv_bn_act(in_channels, out_channels, act_fn, BatchNorm, dilation):
    """Build ``Sequential(3x3 dilated conv without bias, BatchNorm, act_fn)``.

    Args:
        in_channels: Input channel count of the convolution.
        out_channels: Output channel count of the convolution and BatchNorm.
        act_fn: Activation layer instance appended as-is (not copied).
        BatchNorm: Callable taking a channel count and returning a norm layer.
        dilation: Dilation rate; also used as the padding.
    """
    return nn.Sequential(
        nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            bias_attr=False,
            kernel_size=3,
            stride=1,
            padding=dilation,
            dilation=dilation,
        ),
        BatchNorm(out_channels),
        act_fn,
    )


def dilated_conv(in_channels, out_channels, kernel_size, dilation, stride=1):
    """Build a one-element ``Sequential`` holding a biased dilated ``nn.Conv2D``.

    Padding is ``dilation * (kernel_size // 2)``.
    """
    return nn.Sequential(
        nn.Conv2D(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=dilation * (kernel_size // 2),
            dilation=dilation,
        )
    )


class ResidualBlockWithDilation(nn.Layer):
    """Two-convolution residual block: conv-BN-ReLU, conv-BN, add, ReLU.

    The block uses plain convolutions when ``stride != 1`` or ``is_top`` is
    true, and convolutions with dilation 3 otherwise.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        BatchNorm,
        kernel_size,
        stride=1,
        downsample=None,
        is_activation=True,
        is_top=False,
    ):
        """Create the block.

        Args:
            in_channels: Channels of the block input.
            out_channels: Channels produced by both convolutions.
            BatchNorm: Callable taking a channel count and returning a norm layer.
            kernel_size: Kernel size of both convolutions.
            stride: Stride of the first convolution in the non-dilated variant.
            downsample: Optional layer applied to the input to form the
                residual branch; the raw input is used when ``None``.
            is_activation: Stored on the instance; not read anywhere in this block.
            is_top: Forces the non-dilated variant even when ``stride == 1``.
        """
        super().__init__()
        self.stride = stride
        self.downsample = downsample
        self.is_activation = is_activation
        self.is_top = is_top
        if self.stride != 1 or self.is_top:
            self.conv1 = conv3x3(in_channels, out_channels, kernel_size, self.stride)
            self.conv2 = conv3x3(out_channels, out_channels, kernel_size)
        else:
            self.conv1 = dilated_conv(
                in_channels, out_channels, kernel_size, dilation=3
            )
            self.conv2 = dilated_conv(
                out_channels, out_channels, kernel_size, dilation=3
            )
        self.bn1 = BatchNorm(out_channels)
        self.relu = nn.ReLU()
        self.bn2 = BatchNorm(out_channels)

    def forward(self, x):
        """Return ``relu(bn2(conv2(relu(bn1(conv1(x))))) + residual)``."""
        residual = x
        if self.downsample is not None:
            residual = self.downsample(x)
        out1 = self.relu(self.bn1(self.conv1(x)))
        out2 = self.bn2(self.conv2(out1))
        out2 += residual
        out = self.relu(out2)
        return out


class ResnetStraight(nn.Layer):
    """Three sequential stages of residual blocks (``layer1`` .. ``layer3``).

    Only the first three entries of ``map_num``, ``block_nums`` and ``stride``
    are consumed.
    """

    def __init__(
        self,
        num_filter,
        map_num,
        BatchNorm,
        block_nums=(3, 4, 6, 3),
        block=ResidualBlockWithDilation,
        kernel_size=5,
        stride=(1, 1, 2, 2),
    ):
        """Create the three stages.

        Args:
            num_filter: Base channel count.
            map_num: Per-stage channel multipliers applied to ``num_filter``.
            BatchNorm: Callable taking a channel count and returning a norm layer.
            block_nums: Number of blocks in each stage.
            block: Block class used to populate the stages.
            kernel_size: Kernel size handed to every block.
            stride: Stride of the first block in each stage.
        """
        super().__init__()
        self.in_channels = num_filter * map_num[0]
        # Defaults are immutable tuples; copy into per-instance lists so the
        # attributes keep their list type without sharing state across
        # instances or with the caller.
        self.stride = list(stride)
        self.relu = nn.ReLU()
        self.block_nums = list(block_nums)
        self.kernel_size = kernel_size
        self.layer1 = self.blocklayer(
            block,
            num_filter * map_num[0],
            self.block_nums[0],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[0],
        )
        self.layer2 = self.blocklayer(
            block,
            num_filter * map_num[1],
            self.block_nums[1],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[1],
        )
        self.layer3 = self.blocklayer(
            block,
            num_filter * map_num[2],
            self.block_nums[2],
            BatchNorm,
            kernel_size=self.kernel_size,
            stride=self.stride[2],
        )

    def blocklayer(
        self, block, out_channels, block_nums, BatchNorm, kernel_size, stride=1
    ):
        """Build one stage of ``block_nums`` blocks as an ``nn.Sequential``.

        The first block receives ``stride`` and ``is_top=True``; it also gets
        a conv + BatchNorm projection on its residual branch when the stride
        or the channel count changes. The remaining blocks keep the channel
        count and use the default stride. Updates ``self.in_channels`` to
        ``out_channels`` as a side effect.
        """
        downsample = None
        if stride != 1 or self.in_channels != out_channels:
            downsample = nn.Sequential(
                conv3x3(
                    self.in_channels,
                    out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                ),
                BatchNorm(out_channels),
            )
        layers = [
            block(
                self.in_channels,
                out_channels,
                BatchNorm,
                kernel_size,
                stride,
                downsample,
                is_top=True,
            )
        ]
        self.in_channels = out_channels
        for _ in range(1, block_nums):
            layers.append(
                block(
                    out_channels,
                    out_channels,
                    BatchNorm,
                    kernel_size,
                    is_activation=True,
                    is_top=False,
                )
            )
        return nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through ``layer1``, ``layer2`` and ``layer3`` in order."""
        out1 = self.layer1(x)
        out2 = self.layer2(out1)
        out3 = self.layer3(out2)
        return out3


class UVDocNet(BatchNormHFStateDictMixin, PretrainedModel):
    """Network that predicts a 2-channel sampling grid and resamples the input.

    The forward pass resizes the input to a fixed size, runs a convolutional
    head, three residual stages, six parallel dilated-convolution bridges and
    a 2-channel output head, then uses the upsampled result as the grid for
    ``F.grid_sample`` on the original-resolution image.
    """

    config_class = PretrainedConfig

    def __init__(self, config: PretrainedConfig):
        """Build all sub-layers; sizes are fixed here, not read from ``config``."""
        super().__init__(config)

        self.num_filter = 32
        self.in_channels = 3
        self.kernel_size = 5
        self.stride = [1, 2, 2, 2]
        BatchNorm = nn.BatchNorm2D
        # A single activation instance is shared by the head, bridges and
        # bridge_concat.
        act_fn = nn.ReLU()
        map_num = [1, 2, 4, 8, 16]

        self.resnet_head = nn.Sequential(
            nn.Conv2D(
                in_channels=self.in_channels,
                out_channels=self.num_filter * map_num[0],
                bias_attr=False,
                kernel_size=self.kernel_size,
                stride=2,
                padding=self.kernel_size // 2,
            ),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
            nn.Conv2D(
                in_channels=self.num_filter * map_num[0],
                out_channels=self.num_filter * map_num[0],
                bias_attr=False,
                kernel_size=self.kernel_size,
                stride=2,
                padding=self.kernel_size // 2,
            ),
            BatchNorm(self.num_filter * map_num[0]),
            act_fn,
        )

        self.resnet_down = ResnetStraight(
            self.num_filter,
            map_num,
            BatchNorm,
            block_nums=[3, 4, 6, 3],
            block=ResidualBlockWithDilation,
            kernel_size=self.kernel_size,
            stride=self.stride,
        )

        map_num_i = 2
        bridge_channels = self.num_filter * map_num[map_num_i]

        def make_bridge(dilations):
            # One conv-BN-act unit per dilation rate, wrapped in an outer
            # Sequential so single- and multi-unit bridges nest identically.
            return nn.Sequential(
                *[
                    dilated_conv_bn_act(
                        bridge_channels,
                        bridge_channels,
                        act_fn,
                        BatchNorm,
                        dilation=d,
                    )
                    for d in dilations
                ]
            )

        self.bridge_1 = make_bridge([1])
        self.bridge_2 = make_bridge([2])
        self.bridge_3 = make_bridge([5])
        self.bridge_4 = make_bridge([8, 3, 2])
        self.bridge_5 = make_bridge([12, 7, 4])
        self.bridge_6 = make_bridge([18, 12, 6])

        # 1x1 convolution that fuses the six concatenated bridge outputs.
        self.bridge_concat = nn.Sequential(
            nn.Conv2D(
                in_channels=bridge_channels * 6,
                out_channels=self.num_filter * map_num[2],
                bias_attr=False,
                kernel_size=1,
                stride=1,
                padding=0,
            ),
            BatchNorm(self.num_filter * map_num[2]),
            act_fn,
        )

        # Index 2 of this Sequential is the PReLU that the custom key rules
        # below refer to.
        self.out_point_positions2D = nn.Sequential(
            nn.Conv2D(
                in_channels=self.num_filter * map_num[2],
                out_channels=self.num_filter * map_num[0],
                bias_attr=False,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
            BatchNorm(self.num_filter * map_num[0]),
            nn.PReLU(),
            nn.Conv2D(
                in_channels=self.num_filter * map_num[0],
                out_channels=2,
                kernel_size=self.kernel_size,
                stride=1,
                padding=self.kernel_size // 2,
                padding_mode="reflect",
            ),
        )

    def forward(self, x):
        """Resample the input image with the predicted grid.

        Args:
            x: Indexable container; only ``x[0]`` is used and it is converted
                with ``paddle.to_tensor``. The tensor must be 4-D, since
                ``shape[2:]`` is unpacked into height and width.

        Returns:
            A one-element list holding the resampled image as a NumPy array.
        """
        x = paddle.to_tensor(x[0])

        image = x
        h_ori, w_ori = x.shape[2:]
        # The network always runs at a fixed 712 x 488 resolution.
        x = F.upsample(x, size=(712, 488), mode="bilinear", align_corners=True)
        resnet_head = self.resnet_head(x)
        resnet_down = self.resnet_down(resnet_head)

        bridges = [
            self.bridge_1(resnet_down),
            self.bridge_2(resnet_down),
            self.bridge_3(resnet_down),
            self.bridge_4(resnet_down),
            self.bridge_5(resnet_down),
            self.bridge_6(resnet_down),
        ]
        bridge_concat = paddle.concat(x=bridges, axis=1)
        bridge = self.bridge_concat(bridge_concat)
        out_point_positions2D = self.out_point_positions2D(bridge)

        # Bring the 2-channel prediction back to the input resolution, then
        # move channels last, the layout passed to grid_sample.
        bm_up = F.upsample(
            out_point_positions2D,
            size=(h_ori, w_ori),
            mode="bilinear",
            align_corners=True,
        )
        bm = bm_up.transpose([0, 2, 3, 1])
        out = F.grid_sample(image, bm, align_corners=True)

        return [out.cpu().numpy()]

    def _get_forward_key_rules(self):
        """Extend the default rules: export the PReLU ``_weight`` key as ``weight``."""
        default_rules = super()._get_forward_key_rules()
        custom_rules = [("out_point_positions2D.2._weight", "_weight", "weight")]
        return default_rules + custom_rules

    def _get_reverse_key_rules(self):
        """Extend the default rules: import the PReLU ``weight`` key as ``_weight``."""
        default_rules = super()._get_reverse_key_rules()
        custom_rules = [("out_point_positions2D.2.weight", "weight", "_weight")]
        return default_rules + custom_rules
""" super().__init__(*args, **kwargs) + self.device = kwargs.get("device", None) self.preprocessors, self.infer, self.postprocessors = self._build() def _build_batch_sampler(self) -> ImageBatchSampler: @@ -66,8 +68,16 @@ def _build(self) -> Tuple: preprocessors["Normalize"] = Normalize(mean=0.0, std=1.0, scale=1.0 / 255) preprocessors["ToCHW"] = ToCHWImage() preprocessors["ToBatch"] = ToBatch() + if self._use_static_model: + infer = self.create_static_infer() + else: + from .modeling import UVDocNet - infer = self.create_static_infer() + with TemporaryDeviceChanger(self.device): + infer = UVDocNet.from_pretrained( + self.model_dir, use_safetensors=True, convert_from_hf=True + ) + infer.eval() postprocessors = {"DocTrPostProcess": DocTrPostProcess()} return preprocessors, infer, postprocessors @@ -86,7 +96,8 @@ def process(self, batch_data: List[Union[str, np.ndarray]]) -> Dict[str, Any]: batch_imgs = self.preprocessors["Normalize"](imgs=batch_raw_imgs) batch_imgs = self.preprocessors["ToCHW"](imgs=batch_imgs) x = self.preprocessors["ToBatch"](imgs=batch_imgs) - batch_preds = self.infer(x=x) + with TemporaryDeviceChanger(self.device): + batch_preds = self.infer(x=x) batch_warp_preds = self.postprocessors["DocTrPostProcess"](batch_preds) return {