yunyangx committed on
Commit 070c43b
1 Parent(s): 4e4a175

fix a typo in the file name

Files changed (1)
  1. utils/tools_gradio.py +193 -0
utils/tools_gradio.py ADDED
@@ -0,0 +1,193 @@
+ import cv2
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ from PIL import Image
+
+
+ def fast_process(
+     annotations,
+     image,
+     device,
+     scale,
+     better_quality=False,
+     mask_random_color=True,
+     bbox=None,
+     points=None,
+     use_retina=True,
+     withContours=True,
+ ):
+     if isinstance(annotations[0], dict):
+         annotations = [annotation["segmentation"] for annotation in annotations]
+
+     original_h = image.height
+     original_w = image.width
+     if better_quality:
+         if isinstance(annotations[0], torch.Tensor):
+             annotations = np.array(annotations.cpu())
+         for i, mask in enumerate(annotations):
+             mask = cv2.morphologyEx(
+                 mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8)
+             )
+             annotations[i] = cv2.morphologyEx(
+                 mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8)
+             )
+     if device == "cpu":
+         annotations = np.array(annotations)
+         inner_mask = fast_show_mask(
+             annotations,
+             plt.gca(),
+             random_color=mask_random_color,
+             bbox=bbox,
+             retinamask=use_retina,
+             target_height=original_h,
+             target_width=original_w,
+         )
+     else:
+         if isinstance(annotations[0], np.ndarray):
+             annotations = np.array(annotations)
+             annotations = torch.from_numpy(annotations)
+         inner_mask = fast_show_mask_gpu(
+             annotations,
+             plt.gca(),
+             random_color=mask_random_color,
+             bbox=bbox,
+             retinamask=use_retina,
+             target_height=original_h,
+             target_width=original_w,
+         )
+     if isinstance(annotations, torch.Tensor):
+         annotations = annotations.cpu().numpy()
+
+     if withContours:
+         contour_all = []
+         temp = np.zeros((original_h, original_w, 1))
+         for i, mask in enumerate(annotations):
+             if type(mask) == dict:
+                 mask = mask["segmentation"]
+             annotation = mask.astype(np.uint8)
+             if use_retina == False:
+                 annotation = cv2.resize(
+                     annotation,
+                     (original_w, original_h),
+                     interpolation=cv2.INTER_NEAREST,
+                 )
+             contours, _ = cv2.findContours(
+                 annotation, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
+             )
+             for contour in contours:
+                 contour_all.append(contour)
+         cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2 // scale)
+         color = np.array([0 / 255, 0 / 255, 255 / 255, 0.9])
+         contour_mask = temp / 255 * color.reshape(1, 1, -1)
+
+     image = image.convert("RGBA")
+     overlay_inner = Image.fromarray((inner_mask * 255).astype(np.uint8), "RGBA")
+     image.paste(overlay_inner, (0, 0), overlay_inner)
+
+     if withContours:
+         overlay_contour = Image.fromarray((contour_mask * 255).astype(np.uint8), "RGBA")
+         image.paste(overlay_contour, (0, 0), overlay_contour)
+
+     return image
+
+
+ # CPU post process
+ def fast_show_mask(
+     annotation,
+     ax,
+     random_color=False,
+     bbox=None,
+     retinamask=True,
+     target_height=960,
+     target_width=960,
+ ):
+     mask_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     weight = annotation.shape[2]
+     # annotation is sorted by area
+     areas = np.sum(annotation, axis=(1, 2))
+     sorted_indices = np.argsort(areas)[::1]
+     annotation = annotation[sorted_indices]
+
+     index = (annotation != 0).argmax(axis=0)
+     if random_color == True:
+         color = np.random.random((mask_sum, 1, 1, 3))
+     else:
+         color = np.ones((mask_sum, 1, 1, 3)) * np.array(
+             [30 / 255, 144 / 255, 255 / 255]
+         )
+     transparency = np.ones((mask_sum, 1, 1, 1)) * 0.6
+     visual = np.concatenate([color, transparency], axis=-1)
+     mask_image = np.expand_dims(annotation, -1) * visual
+
+     mask = np.zeros((height, weight, 4))
+
+     h_indices, w_indices = np.meshgrid(
+         np.arange(height), np.arange(weight), indexing="ij"
+     )
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+
+     mask[h_indices, w_indices, :] = mask_image[indices]
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(
+             plt.Rectangle(
+                 (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
+             )
+         )
+
+     if retinamask == False:
+         mask = cv2.resize(
+             mask, (target_width, target_height), interpolation=cv2.INTER_NEAREST
+         )
+
+     return mask
+
+
+ def fast_show_mask_gpu(
+     annotation,
+     ax,
+     random_color=False,
+     bbox=None,
+     retinamask=True,
+     target_height=960,
+     target_width=960,
+ ):
+     device = annotation.device
+     mask_sum = annotation.shape[0]
+     height = annotation.shape[1]
+     weight = annotation.shape[2]
+     areas = torch.sum(annotation, dim=(1, 2))
+     sorted_indices = torch.argsort(areas, descending=False)
+     annotation = annotation[sorted_indices]
+     # find the first non-zero subscript for each position
+     index = (annotation != 0).to(torch.long).argmax(dim=0)
+     if random_color == True:
+         color = torch.rand((mask_sum, 1, 1, 3)).to(device)
+     else:
+         color = torch.ones((mask_sum, 1, 1, 3)).to(device) * torch.tensor(
+             [30 / 255, 144 / 255, 255 / 255]
+         ).to(device)
+     transparency = torch.ones((mask_sum, 1, 1, 1)).to(device) * 0.6
+     visual = torch.cat([color, transparency], dim=-1)
+     mask_image = torch.unsqueeze(annotation, -1) * visual
+     # index
+     mask = torch.zeros((height, weight, 4)).to(device)
+     h_indices, w_indices = torch.meshgrid(torch.arange(height), torch.arange(weight))
+     indices = (index[h_indices, w_indices], h_indices, w_indices, slice(None))
+     # make updates based on indices
+     mask[h_indices, w_indices, :] = mask_image[indices]
+     mask_cpu = mask.cpu().numpy()
+     if bbox is not None:
+         x1, y1, x2, y2 = bbox
+         ax.add_patch(
+             plt.Rectangle(
+                 (x1, y1), x2 - x1, y2 - y1, fill=False, edgecolor="b", linewidth=1
+             )
+         )
+     if retinamask == False:
+         mask_cpu = cv2.resize(
+             mask_cpu, (target_width, target_height), interpolation=cv2.INTER_NEAREST
+         )
+     return mask_cpu
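
For context, a minimal usage sketch of the new helper is shown below. It is not part of the commit: the image path and the placeholder `masks` array are assumptions standing in for per-object segmentations produced elsewhere (e.g. by a SAM-style model), and only the `fast_process` signature is taken from the file above.

# Illustrative sketch, not part of the commit.
# Assumes `masks` is an (N, H, W) uint8 array of binary masks from some
# segmentation model; here a single square placeholder mask is used.
import numpy as np
from PIL import Image

from utils.tools_gradio import fast_process

image = Image.open("example.jpg").convert("RGB")   # hypothetical input image
masks = np.zeros((1, image.height, image.width), dtype=np.uint8)
masks[0, 100:200, 100:200] = 1                     # placeholder mask

overlaid = fast_process(
    annotations=masks,
    image=image,
    device="cpu",        # CPU path -> fast_show_mask
    scale=1,
    better_quality=False,
    mask_random_color=True,
    use_retina=True,     # keep masks at their native resolution
    withContours=True,
)
overlaid.save("example_overlay.png")               # RGBA image with mask + contour overlay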