Podtekatel commited on
Commit
d44e389
1 Parent(s): bd5330b

Update images and search

Browse files

Update token recieving
Update token recieving
Update model to inference

app.py CHANGED
@@ -1,30 +1,50 @@
 
 
1
  import gradio as gr
 
 
2
  from huggingface_hub import hf_hub_url, cached_download
3
- import dlib
4
 
 
 
 
 
 
5
  def load_model():
6
  REPO_ID = "MalchuL/JJBAGAN"
7
  FILENAME = "198_jjba_8_k_2_099_ep.onnx"
8
 
9
  global model
 
10
 
11
- model = cached_download(
12
- hf_hub_url(REPO_ID, FILENAME)
13
  )
 
14
 
 
15
  return model
16
 
17
- def inference(img):
18
 
19
- return img
 
 
 
 
20
 
21
 
22
  title = "JJStyleTransfer"
23
  description = "Gradio Demo for JoJo Bizzare Adventures 5 season style transfer. To use it, simply upload your image, or click one of the examples to load them."
24
- article = "Github Repo Pytorch "
25
- examples = [['demo/karin.jpg'],
26
- ['demo/tucker.png'],
27
- ['demo/biden.jpg']]
 
 
 
 
 
28
 
29
  demo = gr.Interface(
30
  fn=inference,
 
1
+ import os
2
+
3
  import gradio as gr
4
+ import numpy as np
5
+ from PIL import Image
6
  from huggingface_hub import hf_hub_url, cached_download
 
7
 
8
+ from inference.face_detector import StatRetinaFaceDetector
9
+ from inference.model_pipeline import VSNetModelPipeline
10
+ from inference.onnx_model import ONNXModel
11
+
12
+ MODEL_IMG_SIZE = 256
13
  def load_model():
14
  REPO_ID = "MalchuL/JJBAGAN"
15
  FILENAME = "198_jjba_8_k_2_099_ep.onnx"
16
 
17
  global model
18
+ global pipeline
19
 
20
+ model_path = cached_download(
21
+ hf_hub_url(REPO_ID, FILENAME), use_auth_token=os.getenv('HF_TOKEN')
22
  )
23
+ model = ONNXModel(model_path)
24
 
25
+ pipeline = VSNetModelPipeline(model, StatRetinaFaceDetector(MODEL_IMG_SIZE), background_resize=1024, no_detected_resize=1024)
26
  return model
27
 
28
+ load_model()
29
 
30
+ def inference(img):
31
+ img = np.array(img)
32
+ out_img = pipeline(img)
33
+ out_img = Image.fromarray(out_img)
34
+ return out_img
35
 
36
 
37
  title = "JJStyleTransfer"
38
  description = "Gradio Demo for JoJo Bizzare Adventures 5 season style transfer. To use it, simply upload your image, or click one of the examples to load them."
39
+ article = "There is one of my successful experiments on style transfer. I used my own pipeline, generator model and private dataset to train this model<br>" \
40
+ "" \
41
+ "" \
42
+ "" \
43
+ "" \
44
+ "If you want use this app or integrate this model into your app please contact with me at email '[email protected]'"
45
+
46
+ imgs_folder = 'demo'
47
+ examples = [[os.path.join(imgs_folder, img_filename)] for img_filename in os.listdir(imgs_folder)]
48
 
49
  demo = gr.Interface(
50
  fn=inference,
demo/IMG1.jpg ADDED
demo/IMG2.jpg ADDED
demo/{tucker.png → IMG3.png} RENAMED
File without changes
demo/biden.jpg DELETED
Binary file (291 kB)
 
demo/karin.jpg DELETED
Binary file (29.1 kB)
 
inference/__init__.py ADDED
File without changes
inference/box_utils.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ def convert_to_square(bboxes):
5
+ """Convert bounding boxes to a square form.
6
+ Arguments:
7
+ bboxes: a float numpy array of shape [n, 4].
8
+ Returns:
9
+ a float numpy array of shape [4],
10
+ squared bounding boxes.
11
+ """
12
+
13
+ square_bboxes = np.zeros_like(bboxes)
14
+ x1, y1, x2, y2 = bboxes
15
+ h = y2 - y1 + 1.0
16
+ w = x2 - x1 + 1.0
17
+ max_side = np.maximum(h, w)
18
+ square_bboxes[0] = x1 + w * 0.5 - max_side * 0.5
19
+ square_bboxes[1] = y1 + h * 0.5 - max_side * 0.5
20
+ square_bboxes[2] = square_bboxes[0] + max_side - 1.0
21
+ square_bboxes[3] = square_bboxes[1] + max_side - 1.0
22
+ return square_bboxes
23
+
24
+
25
+ def scale_box(box, scale):
26
+ x1, y1, x2, y2 = box
27
+ center_x, center_y = (x1 + x2) / 2, (y1 + y2) / 2
28
+ w, h = x2 - x1, y2 - y1
29
+ new_w, new_h = w * scale, h * scale
30
+ y1, y2, x1, x2 = center_y - new_h / 2, center_y + new_h / 2, center_x - new_w / 2, center_x + new_w / 2,
31
+ return np.array((x1, y1, x2, y2))
inference/center_crop.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+
4
+ # From albumentations
5
+ def center_crop(img: np.ndarray, crop_height: int, crop_width: int):
6
+ height, width = img.shape[:2]
7
+ if height < crop_height or width < crop_width:
8
+ raise ValueError(
9
+ "Requested crop size ({crop_height}, {crop_width}) is "
10
+ "larger than the image size ({height}, {width})".format(
11
+ crop_height=crop_height, crop_width=crop_width, height=height, width=width
12
+ )
13
+ )
14
+ x1, y1, x2, y2 = get_center_crop_coords(height, width, crop_height, crop_width)
15
+ img = img[y1:y2, x1:x2]
16
+ return img
17
+
18
+
19
+ def get_center_crop_coords(height: int, width: int, crop_height: int, crop_width: int):
20
+ y1 = (height - crop_height) // 2
21
+ y2 = y1 + crop_height
22
+ x1 = (width - crop_width) // 2
23
+ x2 = x1 + crop_width
24
+ return x1, y1, x2, y2
inference/face_detector.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from abc import ABC, abstractmethod
3
+ from typing import List
4
+
5
+ import cv2
6
+ import numpy as np
7
+ from retinaface import RetinaFace
8
+ from retinaface.model import retinaface_model
9
+
10
+ from .box_utils import convert_to_square
11
+
12
+
13
+ class FaceDetector(ABC):
14
+ def __init__(self, target_size):
15
+ self.target_size = target_size
16
+ @abstractmethod
17
+ def detect_crops(self, img, *args, **kwargs) -> List[np.ndarray]:
18
+ """
19
+ Img is a numpy ndarray in range [0..255], uint8 dtype, RGB type
20
+ Returns ndarray with [x1, y1, x2, y2] in row
21
+ """
22
+ pass
23
+
24
+ @abstractmethod
25
+ def postprocess_crops(self, crops, *args, **kwargs) -> List[np.ndarray]:
26
+ return crops
27
+
28
+ def sort_faces(self, crops):
29
+ sorted_faces = sorted(crops, key=lambda x: -(x[2] - x[0]) * (x[3] - x[1]))
30
+ sorted_faces = np.stack(sorted_faces, axis=0)
31
+ return sorted_faces
32
+
33
+ def fix_range_crops(self, img, crops):
34
+ H, W, _ = img.shape
35
+ final_crops = []
36
+ for crop in crops:
37
+ x1, y1, x2, y2 = crop
38
+ x1 = max(min(round(x1), W), 0)
39
+ y1 = max(min(round(y1), H), 0)
40
+ x2 = max(min(round(x2), W), 0)
41
+ y2 = max(min(round(y2), H), 0)
42
+ new_crop = [x1, y1, x2, y2]
43
+ final_crops.append(new_crop)
44
+ final_crops = np.array(final_crops, dtype=np.int)
45
+ return final_crops
46
+
47
+ def crop_faces(self, img, crops) -> List[np.ndarray]:
48
+ cropped_faces = []
49
+ for crop in crops:
50
+ x1, y1, x2, y2 = crop
51
+ face_crop = img[y1:y2, x1:x2, :]
52
+ cropped_faces.append(face_crop)
53
+ return cropped_faces
54
+
55
+ def unify_and_merge(self, cropped_images):
56
+ return cropped_images
57
+
58
+ def __call__(self, img):
59
+ return self.detect_faces(img)
60
+
61
+ def detect_faces(self, img):
62
+ crops = self.detect_crops(img)
63
+ if crops is None or len(crops) == 0:
64
+ return [], []
65
+ crops = self.sort_faces(crops)
66
+ updated_crops = self.postprocess_crops(crops)
67
+ updated_crops = self.fix_range_crops(img, updated_crops)
68
+ cropped_faces = self.crop_faces(img, updated_crops)
69
+ unified_faces = self.unify_and_merge(cropped_faces)
70
+ return unified_faces, updated_crops
71
+
72
+
73
+ class StatRetinaFaceDetector(FaceDetector):
74
+ def __init__(self, target_size=None):
75
+ super().__init__(target_size)
76
+ self.model = retinaface_model.build_model()
77
+ #self.relative_offsets = [0.3258, 0.5225, 0.3258, 0.1290]
78
+ self.relative_offsets = [0.3619, 0.5830, 0.3619, 0.1909]
79
+
80
+ def postprocess_crops(self, crops, *args, **kwargs) -> np.ndarray:
81
+ final_crops = []
82
+ x1_offset, y1_offset, x2_offset, y2_offset = self.relative_offsets
83
+ for crop in crops:
84
+ x1, y1, x2, y2 = crop
85
+ w, h = x2 - x1, y2 - y1
86
+ x1 -= w * x1_offset
87
+ y1 -= h * y1_offset
88
+ x2 += w * x2_offset
89
+ y2 += h * y2_offset
90
+ crop = np.array([x1, y1, x2, y2], dtype=crop.dtype)
91
+ crop = convert_to_square(crop)
92
+ final_crops.append(crop)
93
+ final_crops = np.stack(final_crops, axis=0)
94
+ return final_crops
95
+
96
+ def detect_crops(self, img, *args, **kwargs):
97
+ faces = RetinaFace.detect_faces(img, model=self.model)
98
+ crops = []
99
+ for naem, face in faces.items():
100
+ x1, y1, x2, y2 = face['facial_area']
101
+ crop = np.array([x1, y1, x2, y2])
102
+ crops.append(crop)
103
+ if len(crops) > 0:
104
+ crops = np.stack(crops, axis=0)
105
+ return crops
106
+
107
+ def unify_and_merge(self, cropped_images):
108
+ if self.target_size is None:
109
+ return cropped_images
110
+ else:
111
+ resized_images = []
112
+ for cropped_image in cropped_images:
113
+ resized_image = cv2.resize(cropped_image, (self.target_size, self.target_size),
114
+ interpolation=cv2.INTER_LINEAR)
115
+ resized_images.append(resized_image)
116
+
117
+ resized_images = np.stack(resized_images, axis=0)
118
+ return resized_images
119
+
inference/model_pipeline.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+
3
+ import cv2
4
+ import numpy as np
5
+
6
+ from .center_crop import center_crop
7
+ from .face_detector import FaceDetector
8
+
9
+
10
+ class VSNetModelPipeline:
11
+ def __init__(self, model, face_detector: FaceDetector, background_resize=720, no_detected_resize=256):
12
+ self.background_resize = background_resize
13
+ self.no_detected_resize = no_detected_resize
14
+ self.model = model
15
+ self.face_detector = face_detector
16
+ self.mask = self.create_circular_mask(face_detector.target_size, face_detector.target_size, power=1 / 4)
17
+
18
+ @staticmethod
19
+ def create_circular_mask(h, w, center=None, power=None):
20
+
21
+ if center is None: # use the middle of the image
22
+ center = (int(w / 2), int(h / 2))
23
+
24
+ Y, X = np.ogrid[:h, :w]
25
+ dist_from_center = np.sqrt((X - center[0]) ** 2 + (Y - center[1]) ** 2)
26
+ print(dist_from_center.max())
27
+ dist_from_center = np.clip(dist_from_center, a_min=0, a_max=max(h / 2, w / 2))
28
+ dist_from_center = 1 - dist_from_center / np.max(dist_from_center)
29
+ if power is not None:
30
+ dist_from_center = np.power(dist_from_center, power)
31
+ dist_from_center = np.stack([dist_from_center] * 3, axis=2)
32
+ # mask = dist_from_center <= radius
33
+ return dist_from_center
34
+
35
+ @staticmethod
36
+ def resize_size(image, size=720, always_apply=True):
37
+ h, w, c = np.shape(image)
38
+ if min(h, w) > size or always_apply:
39
+ if h < w:
40
+ h, w = int(size * h / w), size
41
+ else:
42
+ h, w = size, int(size * w / h)
43
+ image = cv2.resize(image, (w, h), interpolation=cv2.INTER_AREA)
44
+ return image
45
+
46
+ def normalize(self, img):
47
+ img = img.astype(np.float32) / 255 * 2 - 1
48
+ return img
49
+
50
+ def denormalize(self, img):
51
+ return (img + 1) / 2
52
+
53
+ def divide_crop(self, img, must_divided=32):
54
+ h, w, _ = img.shape
55
+ h = h // must_divided * must_divided
56
+ w = w // must_divided * must_divided
57
+
58
+ img = center_crop(img, h, w)
59
+ return img
60
+
61
+ def merge_crops(self, faces_imgs, crops, full_image):
62
+ for face, crop in zip(faces_imgs, crops):
63
+ x1, y1, x2, y2 = crop
64
+ W, H = x2 - x1, y2 - y1
65
+ result_face = cv2.resize(face, (W, H), interpolation=cv2.INTER_LINEAR)
66
+ face_mask = cv2.resize(self.mask, (W, H), interpolation=cv2.INTER_LINEAR)
67
+ input_face = full_image[y1: y2, x1: x2]
68
+ full_image[y1: y2, x1: x2] = (result_face * face_mask + input_face * (1 - face_mask)).astype(np.uint8)
69
+ return full_image
70
+
71
+ def __call__(self, img):
72
+ return self.process_image(img)
73
+
74
+ def process_image(self, img):
75
+ img = self.resize_size(img, size=self.background_resize)
76
+ img = self.divide_crop(img)
77
+
78
+ face_crops, coords = self.face_detector(img)
79
+
80
+ if len(face_crops) > 0:
81
+ start_time = time.time()
82
+ faces = self.normalize(face_crops)
83
+ faces = faces.transpose(0, 3, 1, 2)
84
+ out_faces = self.model(faces)
85
+ out_faces = self.denormalize(out_faces)
86
+ out_faces = out_faces.transpose(0, 2, 3, 1)
87
+ out_faces = np.clip(out_faces * 255, 0, 255).astype(np.uint8)
88
+ end_time = time.time()
89
+ print(f'Face FPS {1 / (end_time - start_time)}')
90
+ else:
91
+ out_faces = []
92
+ img = self.resize_size(img, size=self.no_detected_resize)
93
+ img = self.divide_crop(img)
94
+
95
+ start_time = time.time()
96
+ full_image = self.normalize(img)
97
+ full_image = np.expand_dims(full_image, 0).transpose(0, 3, 1, 2)
98
+ full_image = self.model(full_image)
99
+ full_image = self.denormalize(full_image)
100
+ full_image = full_image.transpose(0, 2, 3, 1)
101
+ full_image = np.clip(full_image * 255, 0, 255).astype(np.uint8)
102
+ end_time = time.time()
103
+ print(f'Background FPS {1 / (end_time - start_time)}')
104
+
105
+ result_image = self.merge_crops(out_faces, coords, full_image[0])
106
+ return result_image
inference/onnx_model.py ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import onnxruntime
3
+
4
+
5
+ class ONNXModel:
6
+ def __init__(self, onnx_mode_path):
7
+ self.path = onnx_mode_path
8
+ self.ort_session = onnxruntime.InferenceSession(str(self.path))
9
+ self.input_name = self.ort_session.get_inputs()[0].name
10
+
11
+ def __call__(self, img):
12
+ ort_inputs = {self.input_name: img.astype(dtype=np.float32)}
13
+ ort_outs = self.ort_session.run(None, ort_inputs)[0]
14
+ return ort_outs
packages.txt CHANGED
@@ -1,2 +1 @@
1
- python3-opencv
2
- dlib
 
1
+ python3-opencv
 
requirements.txt CHANGED
@@ -1,6 +1,5 @@
1
- joblib
2
  huggingface_hub
3
  onnxruntime
4
  numpy
5
  gradio
6
- dlib
 
 
1
  huggingface_hub
2
  onnxruntime
3
  numpy
4
  gradio
5
+ retina-face