-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathpython_nms.py
More file actions
62 lines (56 loc) · 1.94 KB
/
python_nms.py
File metadata and controls
62 lines (56 loc) · 1.94 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
import torch
import numpy as np
def python_nms(boxes, scores, nms_thresh, max_count=-1):
    """Perform greedy non-maximum suppression on the CPU using NumPy.

    Boxes are visited in descending score order. Each box that has not been
    suppressed yet is kept, and it suppresses every remaining lower-scoring
    box whose IoU with it is >= ``nms_thresh``.

    Args:
        boxes (Tensor): ``xyxy`` boxes in absolute coordinates (relative
            coordinates are not supported), shape ``(n, 4)``. Widths and
            heights use the inclusive-pixel convention ``x2 - x1 + 1`` /
            ``y2 - y1 + 1``.
        scores (Tensor): per-box scores, shape ``(n,)``.
        nms_thresh (float): IoU threshold; a box is suppressed when its IoU
            with a higher-scoring kept box is >= this value. Pairs whose IoU
            is undefined (zero union) never suppress each other.
        max_count (int): if > 0, only the ``max_count`` highest-scoring kept
            boxes are returned.

    Returns:
        Tensor: int64 indices into ``boxes`` of the kept boxes, ordered by
        decreasing score, on the same device as ``boxes``.
    """
    if boxes.numel() == 0:
        # Match the device of the non-empty path.
        return torch.empty((0,), dtype=torch.long, device=boxes.device)
    # Use numpy to run nms. Running nms in PyTorch code on CPU is really slow.
    origin_device = boxes.device
    # detach() so inputs that require grad can still be converted to NumPy.
    boxes_np = boxes.detach().cpu().numpy()
    scores_np = scores.detach().cpu().numpy()
    x1 = boxes_np[:, 0]
    y1 = boxes_np[:, 1]
    x2 = boxes_np[:, 2]
    y2 = boxes_np[:, 3]
    areas = (x2 - x1 + 1) * (y2 - y1 + 1)
    # Candidate indices, highest score first.
    order = np.argsort(scores_np)[::-1]
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # Every later survivor scores lower, so we can stop once we have enough.
        if max_count > 0 and len(keep) >= max_count:
            break
        rest = order[1:]
        xx1 = np.maximum(x1[i], x1[rest])
        yy1 = np.maximum(y1[i], y1[rest])
        xx2 = np.minimum(x2[i], x2[rest])
        yy2 = np.minimum(y2[i], y2[rest])
        w = np.maximum(xx2 - xx1 + 1, 0)
        h = np.maximum(yy2 - yy1 + 1, 0)
        inter = w * h
        with np.errstate(divide="ignore", invalid="ignore"):
            ovr = inter / (areas[i] + areas[rest] - inter)
        # Written as ~(>=) rather than (<) so a NaN IoU keeps the box.
        order = rest[~(ovr >= nms_thresh)]
    keep = np.asarray(keep, dtype=np.int64)
    return torch.from_numpy(keep).to(origin_device)