-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathtest_grid_sampler.py
More file actions
252 lines (202 loc) · 7.9 KB
/
test_grid_sampler.py
File metadata and controls
252 lines (202 loc) · 7.9 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
#!/usr/bin/env python3
"""
Grid Sampler CUDA 测试用例
"""
import numpy as np
import sys
import os
# 添加当前目录到Python路径
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
try:
import grid_sampler_cuda
except ImportError as e:
print(f"无法导入grid_sampler_cuda模块: {e}")
print("请先编译CUDA扩展")
sys.exit(1)
def test_basic_functionality():
"""测试基本功能"""
print("=== 测试基本功能 ===")
# 创建测试数据
input_tensor = grid_sampler_cuda.create_test_input(batch=1, channels=3, height=4, width=4)
grid_tensor = grid_sampler_cuda.create_test_grid(batch=1, height=2, width=2)
print(f"输入张量形状: {input_tensor.shape}")
print(f"网格张量形状: {grid_tensor.shape}")
print(f"输入数据:\n{input_tensor}")
print(f"网格数据:\n{grid_tensor}")
# 测试双线性插值
output = grid_sampler_cuda.grid_sampler(
input_tensor, grid_tensor,
mode="bilinear",
padding_mode="zeros",
align_corners=False
)
print(f"输出张量形状: {output.shape}")
print(f"输出数据:\n{output}")
print("✓ 基本功能测试通过\n")
def test_different_modes():
"""测试不同插值模式"""
print("=== 测试不同插值模式 ===")
# 创建简单的测试数据
input_data = np.array([[[[1, 2], [3, 4]]]], dtype=np.float32) # [1, 1, 2, 2]
grid_data = np.array([[[[-1, -1], [1, 1]]]], dtype=np.float32) # [1, 2, 2, 2]
print(f"输入数据:\n{input_data}")
print(f"网格数据:\n{grid_data}")
# 测试双线性插值
output_bilinear = grid_sampler_cuda.grid_sampler(
input_data, grid_data, mode="bilinear"
)
print(f"双线性插值输出:\n{output_bilinear}")
# 测试最近邻插值
output_nearest = grid_sampler_cuda.grid_sampler(
input_data, grid_data, mode="nearest"
)
print(f"最近邻插值输出:\n{output_nearest}")
print("✓ 不同模式测试通过\n")
def test_different_padding_modes():
"""测试不同填充模式"""
print("=== 测试不同填充模式 ===")
# 创建测试数据
input_data = np.array([[[[1, 2], [3, 4]]]], dtype=np.float32) # [1, 1, 2, 2]
# 创建超出边界的网格
grid_data = np.array([[[[-2, -2], [2, 2]]]], dtype=np.float32) # [1, 2, 2, 2]
print(f"输入数据:\n{input_data}")
print(f"网格数据 (超出边界):\n{grid_data}")
# 测试零填充
output_zeros = grid_sampler_cuda.grid_sampler(
input_data, grid_data, padding_mode="zeros"
)
print(f"零填充输出:\n{output_zeros}")
# 测试边界填充
output_border = grid_sampler_cuda.grid_sampler(
input_data, grid_data, padding_mode="border"
)
print(f"边界填充输出:\n{output_border}")
# 测试反射填充
output_reflection = grid_sampler_cuda.grid_sampler(
input_data, grid_data, padding_mode="reflection"
)
print(f"反射填充输出:\n{output_reflection}")
print("✓ 不同填充模式测试通过\n")
def test_align_corners():
"""测试align_corners参数"""
print("=== 测试align_corners参数 ===")
# 创建测试数据
input_data = np.array([[[[1, 2], [3, 4]]]], dtype=np.float32) # [1, 1, 2, 2]
grid_data = np.array([[[[-1, -1], [1, 1]]]], dtype=np.float32) # [1, 2, 2, 2]
print(f"输入数据:\n{input_data}")
print(f"网格数据:\n{grid_data}")
# 测试align_corners=False
output_align_false = grid_sampler_cuda.grid_sampler(
input_data, grid_data, align_corners=False
)
print(f"align_corners=False 输出:\n{output_align_false}")
# 测试align_corners=True
output_align_true = grid_sampler_cuda.grid_sampler(
input_data, grid_data, align_corners=True
)
print(f"align_corners=True 输出:\n{output_align_true}")
print("✓ align_corners测试通过\n")
def test_batch_processing():
"""测试批处理"""
print("=== 测试批处理 ===")
# 创建多批次数据
batch_size = 2
input_data = grid_sampler_cuda.create_test_input(
batch=batch_size, channels=2, height=3, width=3
)
grid_data = grid_sampler_cuda.create_test_grid(
batch=batch_size, height=2, width=2
)
print(f"批次大小: {batch_size}")
print(f"输入形状: {input_data.shape}")
print(f"网格形状: {grid_data.shape}")
output = grid_sampler_cuda.grid_sampler(input_data, grid_data)
print(f"输出形状: {output.shape}")
print("✓ 批处理测试通过\n")
def test_performance():
"""性能测试"""
print("=== 性能测试 ===")
import time
# 创建较大的测试数据
batch_size = 4
channels = 64
height = 128
width = 128
output_height = 64
output_width = 64
print(f"测试数据大小: [{batch_size}, {channels}, {height}, {width}]")
print(f"输出大小: [{batch_size}, {channels}, {output_height}, {output_width}]")
# 创建输入数据
input_data = np.random.randn(batch_size, channels, height, width).astype(np.float32)
# 创建网格数据
grid_data = np.random.randn(batch_size, output_height, output_width, 2).astype(np.float32)
# 将网格坐标限制在[-1, 1]范围内
grid_data = np.clip(grid_data, -1, 1)
# 预热
_ = grid_sampler_cuda.grid_sampler(input_data, grid_data)
# 性能测试
num_iterations = 10
start_time = time.time()
for _ in range(num_iterations):
output = grid_sampler_cuda.grid_sampler(input_data, grid_data)
end_time = time.time()
avg_time = (end_time - start_time) / num_iterations
print(f"平均执行时间: {avg_time:.4f} 秒")
print(f"输出形状: {output.shape}")
print("✓ 性能测试完成\n")
def test_error_handling():
"""测试错误处理"""
print("=== 测试错误处理 ===")
try:
# 测试错误的输入维度
input_data = np.random.randn(1, 3, 4, 4, 4).astype(np.float32) # 5D张量
grid_data = np.random.randn(1, 2, 2, 2).astype(np.float32)
grid_sampler_cuda.grid_sampler(input_data, grid_data)
print("❌ 应该抛出维度错误")
except Exception as e:
print(f"✓ 正确捕获维度错误: {e}")
try:
# 测试错误的网格维度
input_data = np.random.randn(1, 3, 4, 4).astype(np.float32)
grid_data = np.random.randn(1, 2, 2, 3).astype(np.float32) # 错误的最后一维
grid_sampler_cuda.grid_sampler(input_data, grid_data)
print("❌ 应该抛出网格维度错误")
except Exception as e:
print(f"✓ 正确捕获网格维度错误: {e}")
try:
# 测试批次维度不匹配
input_data = np.random.randn(2, 3, 4, 4).astype(np.float32)
grid_data = np.random.randn(1, 2, 2, 2).astype(np.float32) # 不同的批次大小
grid_sampler_cuda.grid_sampler(input_data, grid_data)
print("❌ 应该抛出批次维度错误")
except Exception as e:
print(f"✓ 正确捕获批次维度错误: {e}")
try:
# 测试无效的模式
input_data = np.random.randn(1, 3, 4, 4).astype(np.float32)
grid_data = np.random.randn(1, 2, 2, 2).astype(np.float32)
grid_sampler_cuda.grid_sampler(input_data, grid_data, mode="invalid_mode")
print("❌ 应该抛出模式错误")
except Exception as e:
print(f"✓ 正确捕获模式错误: {e}")
print("✓ 错误处理测试通过\n")
def main():
"""运行所有测试"""
print("开始运行Grid Sampler CUDA测试...\n")
try:
test_basic_functionality()
test_different_modes()
test_different_padding_modes()
test_align_corners()
test_batch_processing()
test_performance()
test_error_handling()
print("🎉 所有测试通过!")
except Exception as e:
print(f"❌ 测试失败: {e}")
import traceback
traceback.print_exc()
return 1
return 0
if __name__ == "__main__":
exit(main())