Spanicin committed
Commit e30d08a
1 Parent(s): 1682f7d

Update videoretalking/third_part/GPEN/face_parse/face_parsing.py

videoretalking/third_part/GPEN/face_parse/face_parsing.py CHANGED
@@ -1,148 +1,148 @@
- '''
- @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
- @author: yangxy (yangtao9009@gmail.com)
- '''
- import os
- import cv2
- import torch
- import numpy as np
- from face_parse.parse_model import ParseNet
- import torch.nn.functional as F
-
- from face_parse.model import BiSeNet
- import torchvision.transforms as transforms
-
- class FaceParse(object):
-     def __init__(self, base_dir='./', model='ParseNet-latest', device='cuda', mask_map = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0]):
-         self.mfile = os.path.join(base_dir, model+'.pth')
-         self.size = 512
-         self.device = device
-
-         '''
-         0: 'background' 1: 'skin' 2: 'nose'
-         3: 'eye_g' 4: 'l_eye' 5: 'r_eye'
-         6: 'l_brow' 7: 'r_brow' 8: 'l_ear'
-         9: 'r_ear' 10: 'mouth' 11: 'u_lip'
-         12: 'l_lip' 13: 'hair' 14: 'hat'
-         15: 'ear_r' 16: 'neck_l' 17: 'neck'
-         18: 'cloth'
-         '''
-         # self.MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
-         #self.#MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [0, 0, 0], [0, 0, 0]]
-         # self.MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
-         self.MASK_COLORMAP = mask_map
-
-         self.load_model()
-
-     def load_model(self):
-         self.faceparse = ParseNet(self.size, self.size, 32, 64, 19, norm_type='bn', relu_type='LeakyReLU', ch_range=[32, 256])
-         self.faceparse.load_state_dict(torch.load(self.mfile, map_location=torch.device('cpu')))
-         self.faceparse.to(self.device)
-         self.faceparse.eval()
-
-     def process(self, im, masks=[0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0]):
-         im = cv2.resize(im, (self.size, self.size))
-         imt = self.img2tensor(im)
-         with torch.no_grad():
-             pred_mask, sr_img_tensor = self.faceparse(imt) # (1, 19, 512, 512)
-         mask = self.tenor2mask(pred_mask, masks)
-
-         return mask
-
-     def process_tensor(self, imt):
-         imt = F.interpolate(imt.flip(1)*2-1, (self.size, self.size))
-         pred_mask, sr_img_tensor = self.faceparse(imt)
-
-         mask = pred_mask.argmax(dim=1)
-         for idx, color in enumerate(self.MASK_COLORMAP):
-             mask = torch.where(mask==idx, color, mask)
-         #mask = mask.repeat(3, 1, 1).unsqueeze(0) #.cpu().float().numpy()
-         mask = mask.unsqueeze(0)
-
-         return mask
-
-     def img2tensor(self, img):
-         img = img[..., ::-1] # BGR to RGB
-         img = img / 255. * 2 - 1
-         img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
-         return img_tensor.float()
-
-     def tenor2mask(self, tensor, masks):
-         if len(tensor.shape) < 4:
-             tensor = tensor.unsqueeze(0)
-         if tensor.shape[1] > 1:
-             tensor = tensor.argmax(dim=1)
-
-         tensor = tensor.squeeze(1).data.cpu().numpy() # (1, 512, 512)
-         color_maps = []
-         for t in tensor:
-             #tmp_img = np.zeros(tensor.shape[1:] + (3,))
-             tmp_img = np.zeros(tensor.shape[1:])
-             for idx, color in enumerate(masks):
-                 tmp_img[t == idx] = color
-             color_maps.append(tmp_img.astype(np.uint8))
-         return color_maps
-
-
-
- class FaceParse_v2(object):
-     def __init__(self, device='cuda', mask_map = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0]):
-         self.mfile = '/apdcephfs/private_quincheng/Expression/face-parsing.PyTorch/res/cp/79999_iter.pth'
-         self.size = 512
-         self.device = device
-
-         '''
-         0: 'background' 1: 'skin' 2: 'nose'
-         3: 'eye_g' 4: 'l_eye' 5: 'r_eye'
-         6: 'l_brow' 7: 'r_brow' 8: 'l_ear'
-         9: 'r_ear' 10: 'mouth' 11: 'u_lip'
-         12: 'l_lip' 13: 'hair' 14: 'hat'
-         15: 'ear_r' 16: 'neck_l' 17: 'neck'
-         18: 'cloth'
-         '''
-         # self.MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
-         #self.#MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [0, 0, 0], [0, 0, 0]]
-         # self.MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
-         self.MASK_COLORMAP = mask_map
-         self.load_model()
-         self.to_tensor = transforms.Compose([
-             transforms.ToTensor(),
-             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
-         ])
-
-     def load_model(self):
-         self.faceparse = BiSeNet(n_classes=19)
-         self.faceparse.load_state_dict(torch.load(self.mfile))
-         self.faceparse.to(self.device)
-         self.faceparse.eval()
-
-     def process(self, im, masks=[0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0]):
-         im = cv2.resize(im[...,::-1], (self.size, self.size))
-         im = self.to_tensor(im)
-         imt = torch.unsqueeze(im, 0).to(self.device)
-         with torch.no_grad():
-             pred_mask = self.faceparse(imt)[0]
-         mask = self.tenor2mask(pred_mask, masks)
-         return mask
-
-     # def img2tensor(self, img):
-     #     img = img[..., ::-1] # BGR to RGB
-     #     img = img / 255. * 2 - 1
-     #     img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
-     #     return img_tensor.float()
-
-     def tenor2mask(self, tensor, masks):
-         if len(tensor.shape) < 4:
-             tensor = tensor.unsqueeze(0)
-         if tensor.shape[1] > 1:
-             tensor = tensor.argmax(dim=1)
-
-         tensor = tensor.squeeze(1).data.cpu().numpy()
-         color_maps = []
-         for t in tensor:
-             #tmp_img = np.zeros(tensor.shape[1:] + (3,))
-             tmp_img = np.zeros(tensor.shape[1:])
-             for idx, color in enumerate(masks):
-                 tmp_img[t == idx] = color
-             color_maps.append(tmp_img.astype(np.uint8))
          return color_maps
 
+ '''
+ @paper: GAN Prior Embedded Network for Blind Face Restoration in the Wild (CVPR2021)
+ @author: yangxy (yangtao9009@gmail.com)
+ '''
+ import os
+ import cv2
+ import torch
+ import numpy as np
+ from videoretalking.third_part.GPEN.face_parse.parse_model import ParseNet
+ import torch.nn.functional as F
+
+ from videoretalking.third_part.GPEN.face_parse.model import BiSeNet
+ import torchvision.transforms as transforms
+
+ class FaceParse(object):
+     def __init__(self, base_dir='./', model='ParseNet-latest', device='cuda', mask_map = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0]):
+         self.mfile = os.path.join(base_dir, model+'.pth')
+         self.size = 512
+         self.device = device
+
+         '''
+         0: 'background' 1: 'skin' 2: 'nose'
+         3: 'eye_g' 4: 'l_eye' 5: 'r_eye'
+         6: 'l_brow' 7: 'r_brow' 8: 'l_ear'
+         9: 'r_ear' 10: 'mouth' 11: 'u_lip'
+         12: 'l_lip' 13: 'hair' 14: 'hat'
+         15: 'ear_r' 16: 'neck_l' 17: 'neck'
+         18: 'cloth'
+         '''
+         # self.MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
+         #self.#MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [0, 0, 0], [0, 0, 0]]
+         # self.MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+         self.MASK_COLORMAP = mask_map
+
+         self.load_model()
+
+     def load_model(self):
+         self.faceparse = ParseNet(self.size, self.size, 32, 64, 19, norm_type='bn', relu_type='LeakyReLU', ch_range=[32, 256])
+         self.faceparse.load_state_dict(torch.load(self.mfile, map_location=torch.device('cpu')))
+         self.faceparse.to(self.device)
+         self.faceparse.eval()
+
+     def process(self, im, masks=[0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0]):
+         im = cv2.resize(im, (self.size, self.size))
+         imt = self.img2tensor(im)
+         with torch.no_grad():
+             pred_mask, sr_img_tensor = self.faceparse(imt) # (1, 19, 512, 512)
+         mask = self.tenor2mask(pred_mask, masks)
+
+         return mask
+
+     def process_tensor(self, imt):
+         imt = F.interpolate(imt.flip(1)*2-1, (self.size, self.size))
+         pred_mask, sr_img_tensor = self.faceparse(imt)
+
+         mask = pred_mask.argmax(dim=1)
+         for idx, color in enumerate(self.MASK_COLORMAP):
+             mask = torch.where(mask==idx, color, mask)
+         #mask = mask.repeat(3, 1, 1).unsqueeze(0) #.cpu().float().numpy()
+         mask = mask.unsqueeze(0)
+
+         return mask
+
+     def img2tensor(self, img):
+         img = img[..., ::-1] # BGR to RGB
+         img = img / 255. * 2 - 1
+         img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
+         return img_tensor.float()
+
+     def tenor2mask(self, tensor, masks):
+         if len(tensor.shape) < 4:
+             tensor = tensor.unsqueeze(0)
+         if tensor.shape[1] > 1:
+             tensor = tensor.argmax(dim=1)
+
+         tensor = tensor.squeeze(1).data.cpu().numpy() # (1, 512, 512)
+         color_maps = []
+         for t in tensor:
+             #tmp_img = np.zeros(tensor.shape[1:] + (3,))
+             tmp_img = np.zeros(tensor.shape[1:])
+             for idx, color in enumerate(masks):
+                 tmp_img[t == idx] = color
+             color_maps.append(tmp_img.astype(np.uint8))
+         return color_maps
+
+
+
+ class FaceParse_v2(object):
+     def __init__(self, device='cuda', mask_map = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0]):
+         self.mfile = '/apdcephfs/private_quincheng/Expression/face-parsing.PyTorch/res/cp/79999_iter.pth'
+         self.size = 512
+         self.device = device
+
+         '''
+         0: 'background' 1: 'skin' 2: 'nose'
+         3: 'eye_g' 4: 'l_eye' 5: 'r_eye'
+         6: 'l_brow' 7: 'r_brow' 8: 'l_ear'
+         9: 'r_ear' 10: 'mouth' 11: 'u_lip'
+         12: 'l_lip' 13: 'hair' 14: 'hat'
+         15: 'ear_r' 16: 'neck_l' 17: 'neck'
+         18: 'cloth'
+         '''
+         # self.MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]]
+         #self.#MASK_COLORMAP = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [255, 153, 51], [0, 204, 0]] = [[0, 0, 0], [204, 0, 0], [76, 153, 0], [204, 204, 0], [51, 51, 255], [204, 0, 204], [0, 255, 255], [255, 204, 204], [102, 51, 0], [255, 0, 0], [102, 204, 0], [255, 255, 0], [0, 0, 153], [0, 0, 204], [255, 51, 153], [0, 204, 204], [0, 51, 0], [0, 0, 0], [0, 0, 0]]
+         # self.MASK_COLORMAP = [0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 255, 0, 0, 0]
+         self.MASK_COLORMAP = mask_map
+         self.load_model()
+         self.to_tensor = transforms.Compose([
+             transforms.ToTensor(),
+             transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
+         ])
+
+     def load_model(self):
+         self.faceparse = BiSeNet(n_classes=19)
+         self.faceparse.load_state_dict(torch.load(self.mfile))
+         self.faceparse.to(self.device)
+         self.faceparse.eval()
+
+     def process(self, im, masks=[0, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 0, 0, 0, 0, 0, 0]):
+         im = cv2.resize(im[...,::-1], (self.size, self.size))
+         im = self.to_tensor(im)
+         imt = torch.unsqueeze(im, 0).to(self.device)
+         with torch.no_grad():
+             pred_mask = self.faceparse(imt)[0]
+         mask = self.tenor2mask(pred_mask, masks)
+         return mask
+
+     # def img2tensor(self, img):
+     #     img = img[..., ::-1] # BGR to RGB
+     #     img = img / 255. * 2 - 1
+     #     img_tensor = torch.from_numpy(img.transpose(2, 0, 1)).unsqueeze(0).to(self.device)
+     #     return img_tensor.float()
+
+     def tenor2mask(self, tensor, masks):
+         if len(tensor.shape) < 4:
+             tensor = tensor.unsqueeze(0)
+         if tensor.shape[1] > 1:
+             tensor = tensor.argmax(dim=1)
+
+         tensor = tensor.squeeze(1).data.cpu().numpy()
+         color_maps = []
+         for t in tensor:
+             #tmp_img = np.zeros(tensor.shape[1:] + (3,))
+             tmp_img = np.zeros(tensor.shape[1:])
+             for idx, color in enumerate(masks):
+                 tmp_img[t == idx] = color
+             color_maps.append(tmp_img.astype(np.uint8))
          return color_maps
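
The commit only qualifies the two intra-package imports with the full `videoretalking.third_part.GPEN` path. For context, a minimal usage sketch of `FaceParse` under the new imports, assuming the repository root (the directory containing `videoretalking/`) is on `sys.path`; the checkpoint directory and image path below are hypothetical:

import cv2

# Fully qualified import introduced by this commit; it resolves only when the
# repository root is on sys.path (e.g. scripts are run from the repo root).
from videoretalking.third_part.GPEN.face_parse.face_parsing import FaceParse

# Hypothetical location -- base_dir must contain ParseNet-latest.pth.
faceparser = FaceParse(base_dir='./checkpoints/face_parse', device='cuda')

im = cv2.imread('face.png')       # BGR uint8 face crop; resized to 512x512 internally
mask = faceparser.process(im)[0]  # uint8 (512, 512) mask: 255 on face classes, 0 elsewhere
cv2.imwrite('face_mask.png', mask)

Before this change, `face_parse.parse_model` and `face_parse.model` were resolved as top-level packages, so the module only imported when `videoretalking/third_part/GPEN` itself was the working directory or on `sys.path`.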