mengting committed
Commit 5c17f58 · 1 Parent(s): d859277
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.mat filter=lfs diff=lfs merge=lfs -text
pre_trained/unet_denoise/checkpoint-70000/config.json ADDED
@@ -0,0 +1,68 @@
1
+ {
2
+ "_class_name": "UNetDenoise2DConditionModel",
3
+ "_diffusers_version": "0.25.1",
4
+ "_name_or_path": "/scratch/project_462000772/wmengting/diffusion_models/stable-diffusion-v1-5",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "attention_type": "default",
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "center_input_sample": false,
18
+ "class_embed_type": null,
19
+ "class_embeddings_concat": false,
20
+ "conv_in_kernel": 3,
21
+ "conv_out_kernel": 3,
22
+ "cross_attention_dim": 768,
23
+ "cross_attention_norm": null,
24
+ "down_block_types": [
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D",
27
+ "CrossAttnDownBlock2D",
28
+ "DownBlock2D"
29
+ ],
30
+ "downsample_padding": 1,
31
+ "dropout": 0.0,
32
+ "dual_cross_attention": false,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "in_channels": 12,
38
+ "layers_per_block": 2,
39
+ "mid_block_only_cross_attention": null,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "out_channels": 4,
48
+ "projection_class_embeddings_input_dim": null,
49
+ "resnet_out_scale_factor": 1.0,
50
+ "resnet_skip_time_act": false,
51
+ "resnet_time_scale_shift": "default",
52
+ "reverse_transformer_layers_per_block": null,
53
+ "sample_size": 64,
54
+ "time_cond_proj_dim": null,
55
+ "time_embedding_act_fn": null,
56
+ "time_embedding_dim": null,
57
+ "time_embedding_type": "positional",
58
+ "timestep_post_act": null,
59
+ "transformer_layers_per_block": 1,
60
+ "up_block_types": [
61
+ "UpBlock2D",
62
+ "CrossAttnUpBlock2D",
63
+ "CrossAttnUpBlock2D",
64
+ "CrossAttnUpBlock2D"
65
+ ],
66
+ "upcast_attention": false,
67
+ "use_linear_projection": false
68
+ }
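Note: `_class_name` refers to `UNetDenoise2DConditionModel`, a UNet subclass defined in this repository rather than in stock diffusers, and `in_channels` is 12 instead of the 4 used by the Stable Diffusion v1-5 UNet it was initialized from. A minimal loading sketch, assuming that class subclasses `diffusers.UNet2DConditionModel` and keeps its `from_pretrained` API (the import path below is hypothetical):

    import torch
    # Hypothetical import path; the class ships with this repo, not with diffusers.
    from models.unet_denoise import UNetDenoise2DConditionModel

    unet_denoise = UNetDenoise2DConditionModel.from_pretrained(
        "pre_trained/unet_denoise/checkpoint-70000",
        torch_dtype=torch.float16,
    )
    # in_channels=12 (vs. 4 in stock SD v1-5) suggests extra conditioning latents
    # are concatenated with the noisy 4-channel latent along the channel dimension.
    assert unet_denoise.config.in_channels == 12 and unet_denoise.config.out_channels == 4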
pre_trained/unet_denoise/checkpoint-70000/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a5eb8f49aba84c10f607630f28785b55037cde350e5ec90563d1ddd200dd64c8
3
+ size 3438592680
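The weights themselves are committed as a Git LFS pointer (oid + size); once pulled, the safetensors file can be inspected without instantiating the model. A small sketch, assuming the `safetensors` package is installed (the `conv_in.weight` key follows the usual diffusers UNet state-dict naming and is an assumption here):

    from safetensors import safe_open

    path = "pre_trained/unet_denoise/checkpoint-70000/diffusion_pytorch_model.safetensors"
    with safe_open(path, framework="pt", device="cpu") as f:
        print(len(f.keys()))                         # number of tensors stored
        print(f.get_tensor("conv_in.weight").shape)  # expected torch.Size([320, 12, 3, 3])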
pre_trained/unet_id/checkpoint-70000/config.json ADDED
@@ -0,0 +1,68 @@
1
+ {
2
+ "_class_name": "UNetID2DConditionModel",
3
+ "_diffusers_version": "0.25.1",
4
+ "_name_or_path": "/scratch/project_462000772/wmengting/diffusion_models/stable-diffusion-v1-5",
5
+ "act_fn": "silu",
6
+ "addition_embed_type": null,
7
+ "addition_embed_type_num_heads": 64,
8
+ "addition_time_embed_dim": null,
9
+ "attention_head_dim": 8,
10
+ "attention_type": "default",
11
+ "block_out_channels": [
12
+ 320,
13
+ 640,
14
+ 1280,
15
+ 1280
16
+ ],
17
+ "center_input_sample": false,
18
+ "class_embed_type": null,
19
+ "class_embeddings_concat": false,
20
+ "conv_in_kernel": 3,
21
+ "conv_out_kernel": 3,
22
+ "cross_attention_dim": 768,
23
+ "cross_attention_norm": null,
24
+ "down_block_types": [
25
+ "CrossAttnDownBlock2D",
26
+ "CrossAttnDownBlock2D",
27
+ "CrossAttnDownBlock2D",
28
+ "DownBlock2D"
29
+ ],
30
+ "downsample_padding": 1,
31
+ "dropout": 0.0,
32
+ "dual_cross_attention": false,
33
+ "encoder_hid_dim": null,
34
+ "encoder_hid_dim_type": null,
35
+ "flip_sin_to_cos": true,
36
+ "freq_shift": 0,
37
+ "in_channels": 4,
38
+ "layers_per_block": 2,
39
+ "mid_block_only_cross_attention": null,
40
+ "mid_block_scale_factor": 1,
41
+ "mid_block_type": "UNetMidBlock2DCrossAttn",
42
+ "norm_eps": 1e-05,
43
+ "norm_num_groups": 32,
44
+ "num_attention_heads": null,
45
+ "num_class_embeds": null,
46
+ "only_cross_attention": false,
47
+ "out_channels": 4,
48
+ "projection_class_embeddings_input_dim": null,
49
+ "resnet_out_scale_factor": 1.0,
50
+ "resnet_skip_time_act": false,
51
+ "resnet_time_scale_shift": "default",
52
+ "reverse_transformer_layers_per_block": null,
53
+ "sample_size": 64,
54
+ "time_cond_proj_dim": null,
55
+ "time_embedding_act_fn": null,
56
+ "time_embedding_dim": null,
57
+ "time_embedding_type": "positional",
58
+ "timestep_post_act": null,
59
+ "transformer_layers_per_block": 1,
60
+ "up_block_types": [
61
+ "UpBlock2D",
62
+ "CrossAttnUpBlock2D",
63
+ "CrossAttnUpBlock2D",
64
+ "CrossAttnUpBlock2D"
65
+ ],
66
+ "upcast_attention": false,
67
+ "use_linear_projection": false
68
+ }
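This config mirrors the denoising UNet above; the only material differences are the class name (`UNetID2DConditionModel`) and `in_channels` (4 here versus 12 for the denoising UNet). A quick sanity-check sketch over the two checked-in config files:

    import json

    with open("pre_trained/unet_denoise/checkpoint-70000/config.json") as f:
        cfg_denoise = json.load(f)
    with open("pre_trained/unet_id/checkpoint-70000/config.json") as f:
        cfg_id = json.load(f)

    # Expected output: {'_class_name', 'in_channels'}
    print({k for k in cfg_denoise if cfg_denoise[k] != cfg_id.get(k)})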
pre_trained/unet_id/checkpoint-70000/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1e441ce5eed311efe690e31674d7ba31283bb5d4d6ee6c522f6001c6ce42faa
3
+ size 3438167536
utils/checkpoints/net_seg_res18.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:041ab78a4f8f756cd7e93df0d2840d03162e46b9c463c144f7fdf3ee3e6c4233
3
+ size 57429148
utils/checkpoints/third_party/BFM_model_front.mat ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9f127eb18c3d022acbdbfcf1b6b353d01a6e01785d675a67cc31a3826a5be0f
3
+ size 127170280
utils/checkpoints/third_party/d3dfr_res50_nofc.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52c54b90304a06c16b6813910c26faff1a907d4f8bd69a71ad4ecff43b41a090
3
+ size 96449126
utils/third_party/__pycache__/model_resnet_d3dfr.cpython-39.pyc ADDED
Binary file (14.3 kB).
 
utils/third_party/d3dfr/__pycache__/bfm.cpython-39.pyc ADDED
Binary file (12.6 kB).
 
utils/third_party/d3dfr/bfm.py ADDED
@@ -0,0 +1,473 @@
1
+ import os
2
+ import torch
3
+ # import torch.nn as nn
4
+ from scipy.io import loadmat
5
+ import numpy as np
6
+ import torch.nn.functional as F
7
+
8
+ # CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))
9
+
10
+
11
+ def perspective_projection(focal, center):
12
+ # return p.T (N, 3) @ (3, 3)
13
+ return np.array([
14
+ focal, 0, center,
15
+ 0, focal, center,
16
+ 0, 0, 1
17
+ ]).reshape([3, 3]).astype(np.float32).transpose()
18
+
19
+
20
+ class SH:
21
+ def __init__(self):
22
+ self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
23
+ self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]
24
+
25
+
26
+ class BFM(torch.nn.Module):
27
+ # BFM 3D face model
28
+ def __init__(self,
29
+ recenter=True,
30
+ camera_distance=10.,
31
+ init_lit=np.array([0.8, 0, 0, 0, 0, 0, 0, 0, 0]),
32
+ focal=1015.,
33
+ image_size=224,
34
+ bfm_model_path='pretrained/BFM_model_front.mat'
35
+ ):
36
+ super().__init__()
37
+ model = loadmat(bfm_model_path)
38
+ # self.bfm_uv = loadmat(os.path.join(CURRENT_PATH, 'BFM/BFM_UV.mat'))
39
+ # print(model.keys())
40
+ # mean face shape. [3*N,1]
41
+ # self.meanshape = torch.from_numpy(model['meanshape'])
42
+ self.register_buffer('meanshape', torch.from_numpy(model['meanshape']).float())
43
+
44
+ if recenter:
45
+ meanshape = self.meanshape.view(-1, 3)
46
+ meanshape = meanshape - torch.mean(meanshape, dim=0, keepdim=True)
47
+ self.meanshape = meanshape.view(-1, 1)
48
+
49
+ # identity basis. [3*N,80]
50
+ # self.idBase = torch.from_numpy(model['idBase'])
51
+ self.register_buffer('idBase', torch.from_numpy(model['idBase']).float())
52
+ # self.idBase = nn.Parameter(torch.from_numpy(model['idBase']).float())
53
+ # self.exBase = torch.from_numpy(model['exBase'].astype(
54
+ # np.float32)) # expression basis. [3*N,64]
55
+ self.register_buffer('exBase', torch.from_numpy(model['exBase']).float())
56
+ # self.exBase = nn.Parameter(torch.from_numpy(model['exBase']).float())
57
+ # mean face texture. [3*N,1] (0-255)
58
+ # self.meantex = torch.from_numpy(model['meantex'])
59
+ self.register_buffer('meantex', torch.from_numpy(model['meantex']).float())
60
+ # texture basis. [3*N,80]
61
+ # self.texBase = torch.from_numpy(model['texBase'])
62
+ self.register_buffer('texBase', torch.from_numpy(model['texBase']).float())
63
+ # self.texBase = nn.Parameter(torch.from_numpy(model['texBase']).float())
64
+
65
+ # triangle indices for each vertex that lies in. starts from 0. [N,8]
66
+ self.register_buffer('point_buf', torch.from_numpy(model['point_buf']).long()-1)
67
+ # self.point_buf = model['point_buf'].astype(np.int32)
68
+ # vertex indices in each triangle. starts from 0. [F,3]
69
+ self.register_buffer('face_buf', torch.from_numpy(model['tri']).long()-1)
70
+ # self.tri = model['tri'].astype(np.int32)
71
+ # vertex indices of 68 facial landmarks. starts from 0. [68]
72
+ self.register_buffer('keypoints', torch.from_numpy(model['keypoints']).long().view(68)-1)
73
+ # self.keypoints = model['keypoints'].astype(np.int32)[0]
74
+ # print(self.keypoints)
75
+ # print('keypoints', self.keypoints)
76
+
77
+ # vertex indices for small face region to compute photometric error. starts from 0.
78
+ # self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
79
+ self.register_buffer('front_mask', torch.from_numpy(np.squeeze(model['frontmask2_idx'])).long()-1)
80
+ # vertex indices for each face from small face region. starts from 0. [f,3]
81
+ # self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
82
+ self.register_buffer('front_face_buf', torch.from_numpy(np.squeeze(model['tri_mask2'])).long() - 1)
83
+ # vertex indices for pre-defined skin region to compute reflectance loss
84
+ # self.skin_mask = np.squeeze(model['skinmask'])
85
+ self.register_buffer('skin_mask', torch.from_numpy(np.squeeze(model['skinmask'])))
86
+
87
+
88
+ # keypoints_222 = []
89
+ # with open(os.path.join(CURRENT_PATH, 'BFM/D3DFR_222.txt'), 'r') as f:
90
+ # for line in f.readlines():
91
+ # idx = int(line.strip())
92
+ # keypoints_222.append(max(idx, 0))
93
+ # self.register_buffer('keypoints_222', torch.from_numpy(np.array(keypoints_222)).long())
94
+
95
+ # (1) right eye outer corner, (2) right eye inner corner, (3) left eye inner corner, (4) left eye outer corner,
96
+ # (5) nose bottom, (6) right mouth corner, (7) left mouth corner
97
+ self.register_buffer('keypoints_7', self.keypoints[[36, 39, 42, 45, 33, 48, 54]])
98
+
99
+ # self.persc_proj = torch.from_numpy(perspective_projection(focal, center)).float()
100
+ self.register_buffer('persc_proj', torch.from_numpy(perspective_projection(focal, image_size/2)))
101
+ self.camera_distance = camera_distance
102
+ self.image_size = image_size
103
+ self.SH = SH()
104
+ # self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
105
+ self.register_buffer('init_lit', torch.from_numpy(init_lit.reshape([1, 1, -1]).astype(np.float32)))
106
+
107
+ # (1) right eye outer corner, (2) right eye inner corner, (3) left eye inner corner, (4) left eye outer corner,
108
+ # (5) nose bottom, (6) right mouth corner, (7) left mouth corner
109
+ # print(self.keypoints[[36, 39, 42, 45, 33, 48, 54]])
110
+
111
+ # Lm3D = loadmat(os.path.join(CURRENT_PATH, 'BFM/similarity_Lm3D_all.mat'))
112
+ # Lm3D = Lm3D['lm']
113
+ # # print(Lm3D)
114
+ #
115
+ # # calculate 5 facial landmarks using 68 landmarks
116
+ # lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
117
+ # Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
118
+ # Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
119
+ # Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
120
+ # self.Lm3D = Lm3D
121
+ # print(Lm3D.shape)
122
+
123
+ def split_coeff(self, coeff):
124
+ # input: coeff with shape [1,258]
125
+ id_coeff = coeff[:, 0:80] # identity(shape) coeff of dim 80
126
+ ex_coeff = coeff[:, 80:144] # expression coeff of dim 64
127
+ tex_coeff = coeff[:, 144:224] # texture(albedo) coeff of dim 80
128
+ gamma = coeff[:, 227:254] # lighting coeff for 3 channel SH function of dim 27
129
+ angles = coeff[:, 224:227] # ruler angles(x,y,z) for rotation of dim 3
130
+ translation = coeff[:, 254:257] # translation coeff of dim 3
131
+
132
+ return id_coeff, ex_coeff, tex_coeff, gamma, angles, translation
133
+
134
+ def split_coeff_orderly(self, coeff):
135
+ # input: coeff with shape [1,258]
136
+ id_coeff = coeff[:, 0:80] # identity(shape) coeff of dim 80
137
+ ex_coeff = coeff[:, 80:144] # expression coeff of dim 64
138
+ tex_coeff = coeff[:, 144:224] # texture(albedo) coeff of dim 80
139
+ angles = coeff[:, 224:227] # ruler angles(x,y,z) for rotation of dim 3
140
+ gamma = coeff[:, 227:254] # lighting coeff for 3 channel SH function of dim 27
141
+ translation = coeff[:, 254:257] # translation coeff of dim 3
142
+
143
+ return id_coeff, ex_coeff, tex_coeff, angles, gamma, translation
144
+
145
+ def compute_exp_deform(self, exp_coeff):
146
+ exp_part = torch.einsum('ij,aj->ai', self.exBase, exp_coeff)
147
+ return exp_part
148
+
149
+ def compute_id_deform(self, id_coeff):
150
+ id_part = torch.einsum('ij,aj->ai', self.idBase, id_coeff)
151
+ return id_part
152
+
153
+ def compute_shape_from_coeff(self, coeff):
154
+ id_coeff = coeff[:, 0:80]
155
+ ex_coeff = coeff[:, 80:144]
156
+ batch_size = coeff.shape[0]
157
+ id_part = torch.einsum('ij,aj->ai', self.idBase, id_coeff) #B, n
158
+ exp_part = torch.einsum('ij,aj->ai', self.exBase, ex_coeff) #B, n
159
+ face_shape = id_part + exp_part + self.meanshape.view(1, -1)
160
+ return face_shape.view(batch_size, -1, 3)
161
+
162
+ def compute_shape(self, id_coeff, exp_coeff):
163
+ """
164
+ Return:
165
+ face_shape -- torch.tensor, size (B, N, 3)
166
+ Parameters:
167
+ id_coeff -- torch.tensor, size (B, 80), identity coeffs
168
+ id_relative_scale -- torch.tensor, size (B, 1), identity coeffs
169
+ exp_coeff -- torch.tensor, size (B, 64), expression coeffs
170
+ """
171
+ batch_size = id_coeff.shape[0]
172
+ id_part = torch.einsum('ij,aj->ai', self.idBase, id_coeff) #B, n
173
+ exp_part = torch.einsum('ij,aj->ai', self.exBase, exp_coeff) #B, n
174
+ face_shape = id_part + exp_part + self.meanshape.view(1, -1)
175
+ return face_shape.view(batch_size, -1, 3)
176
+
177
+ def compute_texture(self, tex_coeff, normalize=True):
178
+ """
179
+ Return:
180
+ face_texture -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
181
+ Parameters:
182
+ tex_coeff -- torch.tensor, size (B, 80)
183
+ """
184
+ batch_size = tex_coeff.shape[0]
185
+ face_texture = torch.einsum('ij,aj->ai', self.texBase, tex_coeff) + self.meantex
186
+ if normalize:
187
+ face_texture = face_texture / 255.
188
+ return face_texture.view(batch_size, -1, 3)
189
+
190
+ def compute_norm(self, face_shape):
191
+ """
192
+ Return:
193
+ vertex_norm -- torch.tensor, size (B, N, 3)
194
+ Parameters:
195
+ face_shape -- torch.tensor, size (B, N, 3)
196
+ """
197
+
198
+ v1 = face_shape[:, self.face_buf[:, 0]]
199
+ v2 = face_shape[:, self.face_buf[:, 1]]
200
+ v3 = face_shape[:, self.face_buf[:, 2]]
201
+ e1 = v1 - v2
202
+ e2 = v2 - v3
203
+ face_norm = torch.cross(e1, e2, dim=-1)
204
+ face_norm = F.normalize(face_norm, dim=-1, p=2)
205
+ face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.meanshape)], dim=1)
206
+
207
+ vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
208
+ vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
209
+ return vertex_norm
210
+
211
+ def compute_color(self, face_texture, face_norm, gamma):
212
+ """
213
+ Return:
214
+ face_color -- torch.tensor, size (B, N, 3), range (0, 1.)
215
+ Parameters:
216
+ face_texture -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
217
+ face_norm -- torch.tensor, size (B, N, 3), rotated face normal
218
+ gamma -- torch.tensor, size (B, 27), SH coeffs
219
+ """
220
+ batch_size = gamma.shape[0]
221
+ v_num = face_texture.shape[1]
222
+ a, c = self.SH.a, self.SH.c
223
+ gamma = gamma.reshape([batch_size, 3, 9])
224
+ gamma = gamma + self.init_lit
225
+ gamma = gamma.permute(0, 2, 1)
226
+ Y = torch.cat([
227
+ a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.meanshape),
228
+ -a[1] * c[1] * face_norm[..., 1:2],
229
+ a[1] * c[1] * face_norm[..., 2:],
230
+ -a[1] * c[1] * face_norm[..., :1],
231
+ a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
232
+ -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
233
+ 0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
234
+ -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
235
+ 0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
236
+ ], dim=-1)
237
+ r = Y @ gamma[..., :1]
238
+ g = Y @ gamma[..., 1:2]
239
+ b = Y @ gamma[..., 2:]
240
+ face_color = torch.cat([r, g, b], dim=-1) * face_texture
241
+ return face_color
242
+
243
+ def compute_rotation(self, angles):
244
+ """
245
+ Return:
246
+ rot -- torch.tensor, size (B, 3, 3) pts @ trans_mat
247
+ Parameters:
248
+ angles -- torch.tensor, size (B, 3), radian
249
+ """
250
+
251
+ batch_size = angles.shape[0]
252
+ ones = torch.ones([batch_size, 1]).to(self.meanshape)
253
+ zeros = torch.zeros([batch_size, 1]).to(self.meanshape)
254
+ x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],
255
+
256
+ rot_x = torch.cat([
257
+ ones, zeros, zeros,
258
+ zeros, torch.cos(x), -torch.sin(x),
259
+ zeros, torch.sin(x), torch.cos(x)
260
+ ], dim=1).reshape([batch_size, 3, 3])
261
+
262
+ rot_y = torch.cat([
263
+ torch.cos(y), zeros, torch.sin(y),
264
+ zeros, ones, zeros,
265
+ -torch.sin(y), zeros, torch.cos(y)
266
+ ], dim=1).reshape([batch_size, 3, 3])
267
+
268
+ rot_z = torch.cat([
269
+ torch.cos(z), -torch.sin(z), zeros,
270
+ torch.sin(z), torch.cos(z), zeros,
271
+ zeros, zeros, ones
272
+ ], dim=1).reshape([batch_size, 3, 3])
273
+
274
+ rot = rot_z @ rot_y @ rot_x
275
+ return rot.permute(0, 2, 1)
276
+
277
+ def to_camera(self, face_shape):
278
+ face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
279
+ return face_shape
280
+
281
+ def to_image(self, face_shape):
282
+ """
283
+ Return:
284
+ face_proj -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
285
+ Parameters:
286
+ face_shape -- torch.tensor, size (B, N, 3)
287
+ """
288
+ # to image_plane
289
+ face_proj = face_shape @ self.persc_proj
290
+
291
+ # print(face_proj.shape)
292
+ face_proj = face_proj[..., :2] / face_proj[..., 2:]
293
+
294
+ return face_proj
295
+
296
+ def rotate(self, face_shape, rot):
297
+ """
298
+ Return:
299
+ face_shape -- torch.tensor, size (B, N, 3) pts @ rot + trans
300
+ Parameters:
301
+ face_shape -- torch.tensor, size (B, N, 3)
302
+ rot -- torch.tensor, size (B, 3, 3)
303
+
304
+ """
305
+ return face_shape @ rot
306
+
307
+ def get_landmarks7(self, face_proj):
308
+ """
309
+ Return:
310
+ face_lms -- torch.tensor, size (B, 7, 2)
311
+ Parameters:
312
+ face_proj -- torch.tensor, size (B, N, 2)
313
+ """
314
+ return face_proj[:, self.keypoints_7, :]
315
+
316
+ def get_landmarks68(self, face_proj):
317
+ """
318
+ Return:
319
+ face_lms -- torch.tensor, size (B, 68, 2)
320
+ Parameters:
321
+ face_proj -- torch.tensor, size (B, N, 2)
322
+ """
323
+ return face_proj[:, self.keypoints, :]
324
+
325
+ def get_landmarks222(self, face_proj):
326
+ """
327
+ Return:
328
+ face_lms -- torch.tensor, size (B, 222, 2); requires the keypoints_222 buffer, whose registration is commented out in __init__
329
+ Parameters:
330
+ face_proj -- torch.tensor, size (B, N, 2)
331
+ """
332
+ return face_proj[:, self.keypoints_222, :]
333
+
334
+ def compute_for_render(self, coeffs):
335
+ """
336
+ Return:
337
+ face_vertex -- torch.tensor, size (B, N, 3), in camera coordinate
338
+ face_color -- torch.tensor, size (B, N, 3), in RGB order
339
+ landmark -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
340
+ Parameters:
341
+ coeffs -- torch.tensor, size (B, 258)
342
+ """
343
+ id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeffs)
344
+ # id_relative_scale = id_relative_scale.clamp(0.9,1.1)
345
+ face_shape = self.compute_shape(id_coeff, ex_coeff)
346
+ # face_shape_noexp = self.compute_shape(id_coeff, torch.zeros_like(ex_coeff))
347
+ # print(face_shape.size())
348
+ rotation = self.compute_rotation(angles)
349
+ # print('rotation')
350
+
351
+ face_shape_rotated = self.rotate(face_shape, rotation)
352
+ face_shape_transformed = face_shape_rotated + translation.unsqueeze(1)
353
+ face_vertex = self.to_camera(face_shape_transformed)
354
+ face_proj = self.to_image(face_vertex)
355
+
356
+ # face_shape_transformed_noexp = self.transform(face_shape_noexp, rotation, translation, scale_xyz)
357
+ # face_vertex_noexp = self.to_camera(face_shape_transformed_noexp)
358
+
359
+ landmark68 = self.get_landmarks68(face_proj)
360
+ # landmark_face = face_proj[:,self.front_mask[::32], :]
361
+ landmark68[:, :, 1] = self.image_size - 1 - landmark68[:, :, 1]
362
+
363
+ face_texture = self.compute_texture(tex_coeff)
364
+ face_norm_roted = self.compute_norm(face_shape_rotated)
365
+ # face_norm_roted = face_norm @ rotation
366
+ face_color = self.compute_color(face_texture, face_norm_roted, gamma)
367
+
368
+ # face_norm_noexp = self.compute_norm(face_shape_noexp)
369
+ # face_norm_noexp_roted = face_norm_noexp @ rotation
370
+ # face_color_noexp = self.compute_color(face_texture, face_norm_noexp_roted, gamma)
371
+
372
+ return face_shape, face_vertex, face_color, face_texture, landmark68
373
+
374
+ def get_lm68(self, coeffs):
375
+ id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeffs)
376
+ ex_coeff = torch.zeros_like(ex_coeff)
377
+ # id_relative_scale = id_relative_scale.clamp(0.9,1.1)
378
+ face_shape = self.compute_shape(id_coeff, ex_coeff)
379
+ # face_shape_noexp = self.compute_shape(id_coeff, torch.zeros_like(ex_coeff))
380
+ # print(face_shape.size())
381
+ rotation = self.compute_rotation(angles)
382
+ # print('rotation')
383
+
384
+ face_shape_rotated = self.rotate(face_shape, rotation)
385
+ face_shape_transformed = face_shape_rotated + translation.unsqueeze(1)
386
+ face_vertex = self.to_camera(face_shape_transformed)
387
+ face_proj = self.to_image(face_vertex)
388
+
389
+ landmark68 = self.get_landmarks68(face_proj)
390
+ # landmark_face = face_proj[:,self.front_mask[::32], :]
391
+ landmark68[:, :, 1] = self.image_size - 1 - landmark68[:, :, 1]
392
+ return landmark68, ex_coeff
393
+
394
+ def get_coeffs(self, coeffs):
395
+ id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeffs)
396
+ return id_coeff, ex_coeff, tex_coeff, gamma, angles, translation
397
+
398
+ def get_vertex(self, coeffs):
399
+ id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeffs)
400
+ # id_relative_scale = id_relative_scale.clamp(0.9,1.1)
401
+ face_shape = self.compute_shape(id_coeff, ex_coeff)
402
+ # face_shape_noexp = self.compute_shape(id_coeff, torch.zeros_like(ex_coeff))
403
+ # print(face_shape.size())
404
+ rotation = self.compute_rotation(angles)
405
+ # print('rotation')
406
+
407
+ face_shape_rotated = self.rotate(face_shape, rotation)
408
+ face_shape_transformed = face_shape_rotated + translation.unsqueeze(1)
409
+ face_vertex = self.to_camera(face_shape_transformed)
410
+ face_proj = self.to_image(face_vertex)
411
+
412
+ return face_proj
413
+
414
+
415
+ def forward(self, coeffs):
416
+ face_shape, face_vertex, face_color, face_texture, landmark68 = self.compute_for_render(coeffs)
417
+ return face_shape, face_vertex, face_color, face_texture, landmark68
418
+
419
+
420
+ def save_obj(self, coeff, obj_name):
421
+ # The image size is 224 * 224
422
+ # face reconstruction with coeff and BFM model
423
+ id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeff)
424
+
425
+ # compute face shape
426
+ face_shape = self.compute_shape(id_coeff, ex_coeff).cpu().detach().numpy()[0]
427
+ face_tri = self.face_buf.cpu().numpy()
428
+
429
+ with open(obj_name, 'w') as fobj:
430
+ for i in range(face_shape.shape[0]):
431
+ fobj.write(
432
+ 'v ' + str(face_shape[i][0]) + ' ' + str(face_shape[i][1]) + ' ' + str(face_shape[i][2]) + '\n')
433
+
434
+ # start from 1
435
+ for i in range(face_tri.shape[0]):
436
+ fobj.write('f ' + str(face_tri[i][0] + 1) + ' ' + str(face_tri[i][1] + 1) + ' ' + str(
437
+ face_tri[i][2] + 1) + '\n')
438
+
439
+ # lm7 = face_shape[[2215, 5828, 10455, 14066, 8204, 5522, 10795], :]
440
+ # with open(obj_name[:-3]+'txt', 'w') as f:
441
+ # for point in lm7:
442
+ # f.write('{} {} {}\n'.format(point[0], point[1], point[2]))
443
+
444
+ def save_neutral_obj(self, coeff, obj_name):
445
+ # The image size is 224 * 224
446
+ # face reconstruction with coeff and BFM model
447
+ id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeff)
448
+
449
+ # compute face shape
450
+ face_shape = self.compute_shape(id_coeff, ex_coeff*0).cpu().detach().numpy()[0]
451
+ face_tri = self.face_buf.cpu().numpy()
452
+
453
+ with open(obj_name, 'w') as fobj:
454
+ for i in range(face_shape.shape[0]):
455
+ fobj.write(
456
+ 'v ' + str(face_shape[i][0]) + ' ' + str(face_shape[i][1]) + ' ' + str(face_shape[i][2]) + '\n')
457
+
458
+ # start from 1
459
+ for i in range(face_tri.shape[0]):
460
+ fobj.write('f ' + str(face_tri[i][0] + 1) + ' ' + str(face_tri[i][1] + 1) + ' ' + str(
461
+ face_tri[i][2] + 1) + '\n')
462
+
463
+ # lm7 = face_shape[[2215, 5828, 10455, 14066, 8204, 5522, 10795], :]
464
+ # with open(obj_name[:-3]+'txt', 'w') as f:
465
+ # for point in lm7:
466
+ # f.write('{} {} {}\n'.format(point[0], point[1], point[2]))
467
+
468
+ # def clip(self, g_ratio=0.1, t_ratio=0.1):
469
+ # self.idBase.data = torch.minimum(torch.maximum(self.idBase_org * (1 - g_ratio), self.idBase.data), self.idBase_org * (1 + g_ratio))
470
+ # self.exBase.data = self.exBase_org #torch.minimum(torch.maximum(self.exBase_org * (1 - 0.001), self.exBase.data), self.exBase_org * (1 + 0.001))
471
+ # self.texBase.data = torch.minimum(torch.maximum(self.texBase_org * (1 - t_ratio), self.texBase.data), self.texBase_org * (1 + t_ratio))
472
+
473
+
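A short usage sketch for the BFM module above, assuming the LFS-tracked `BFM_model_front.mat` under `utils/checkpoints/third_party/` has been pulled and that `utils.third_party.d3dfr.bfm` is importable as a package (both assumptions, not shown in this commit):

    import torch
    from utils.third_party.d3dfr.bfm import BFM

    bfm = BFM(bfm_model_path="utils/checkpoints/third_party/BFM_model_front.mat")
    # 257 coefficients: 80 id + 64 exp + 80 tex + 3 angles + 27 SH lighting + 3 translation.
    coeffs = torch.zeros(1, 257)  # all-zero coefficients reconstruct the mean face
    face_shape, face_vertex, face_color, face_texture, landmark68 = bfm(coeffs)
    print(face_shape.shape, landmark68.shape)  # (1, N, 3) and (1, 68, 2)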
utils/third_party/d3dfr_res50_nofc.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:62e7e6a6bc4e16fb567643182ccabf55f8222746269cce3392109a6c592babc1
3
+ size 288887131
utils/third_party/insightface_backbone_conv.py ADDED
@@ -0,0 +1,237 @@
1
+ import os
2
+ import torch
3
+ from torch import nn
4
+
5
+ __all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200', 'getarcface']
6
+
7
+
8
+ def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
9
+ """3x3 convolution with padding"""
10
+ return nn.Conv2d(in_planes,
11
+ out_planes,
12
+ kernel_size=3,
13
+ stride=stride,
14
+ padding=dilation,
15
+ groups=groups,
16
+ bias=False,
17
+ dilation=dilation)
18
+
19
+
20
+ def conv1x1(in_planes, out_planes, stride=1):
21
+ """1x1 convolution"""
22
+ return nn.Conv2d(in_planes,
23
+ out_planes,
24
+ kernel_size=1,
25
+ stride=stride,
26
+ bias=False)
27
+
28
+
29
+ class IBasicBlock(nn.Module):
30
+ expansion = 1
31
+ def __init__(self, inplanes, planes, stride=1, downsample=None,
32
+ groups=1, base_width=64, dilation=1):
33
+ super(IBasicBlock, self).__init__()
34
+ if groups != 1 or base_width != 64:
35
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
36
+ if dilation > 1:
37
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
38
+ self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
39
+ self.conv1 = conv3x3(inplanes, planes)
40
+ self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
41
+ self.prelu = nn.PReLU(planes)
42
+ self.conv2 = conv3x3(planes, planes, stride)
43
+ self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
44
+ self.downsample = downsample
45
+ self.stride = stride
46
+
47
+ def forward(self, x):
48
+ identity = x
49
+ out = self.bn1(x)
50
+ out = self.conv1(out)
51
+ out = self.bn2(out)
52
+ out = self.prelu(out)
53
+ out = self.conv2(out)
54
+ out = self.bn3(out)
55
+ if self.downsample is not None:
56
+ identity = self.downsample(x)
57
+ out += identity
58
+ return out
59
+
60
+
61
+ class IResNet(nn.Module):
62
+ fc_scale = 7 * 7
63
+ def __init__(self,
64
+ block, layers, dropout=0, num_features=512, zero_init_residual=False,
65
+ groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
66
+ super(IResNet, self).__init__()
67
+ self.fp16 = fp16
68
+ self.inplanes = 64
69
+ self.dilation = 1
70
+ if replace_stride_with_dilation is None:
71
+ replace_stride_with_dilation = [False, False, False]
72
+ if len(replace_stride_with_dilation) != 3:
73
+ raise ValueError("replace_stride_with_dilation should be None "
74
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
75
+ self.groups = groups
76
+ self.base_width = width_per_group
77
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
78
+ self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
79
+ self.prelu = nn.PReLU(self.inplanes)
80
+ self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
81
+ self.layer2 = self._make_layer(block,
82
+ 128,
83
+ layers[1],
84
+ stride=2,
85
+ dilate=replace_stride_with_dilation[0])
86
+ self.layer3 = self._make_layer(block,
87
+ 256,
88
+ layers[2],
89
+ stride=2,
90
+ dilate=replace_stride_with_dilation[1])
91
+ self.layer4 = self._make_layer(block,
92
+ 512,
93
+ layers[3],
94
+ stride=2,
95
+ dilate=replace_stride_with_dilation[2])
96
+ self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
97
+ self.dropout = nn.Dropout(p=dropout, inplace=True)
98
+ self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
99
+ self.features = nn.BatchNorm1d(num_features, eps=1e-05)
100
+ nn.init.constant_(self.features.weight, 1.0)
101
+ self.features.weight.requires_grad = False
102
+
103
+ for m in self.modules():
104
+ if isinstance(m, nn.Conv2d):
105
+ nn.init.normal_(m.weight, 0, 0.1)
106
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
107
+ nn.init.constant_(m.weight, 1)
108
+ nn.init.constant_(m.bias, 0)
109
+
110
+ if zero_init_residual:
111
+ for m in self.modules():
112
+ if isinstance(m, IBasicBlock):
113
+ nn.init.constant_(m.bn2.weight, 0)
114
+
115
+ def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
116
+ downsample = None
117
+ previous_dilation = self.dilation
118
+ if dilate:
119
+ self.dilation *= stride
120
+ stride = 1
121
+ if stride != 1 or self.inplanes != planes * block.expansion:
122
+ downsample = nn.Sequential(
123
+ conv1x1(self.inplanes, planes * block.expansion, stride),
124
+ nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
125
+ )
126
+ layers = []
127
+ layers.append(
128
+ block(self.inplanes, planes, stride, downsample, self.groups,
129
+ self.base_width, previous_dilation))
130
+ self.inplanes = planes * block.expansion
131
+ for _ in range(1, blocks):
132
+ layers.append(
133
+ block(self.inplanes,
134
+ planes,
135
+ groups=self.groups,
136
+ base_width=self.base_width,
137
+ dilation=self.dilation))
138
+
139
+ return nn.Sequential(*layers)
140
+
141
+ def forward(self, x, return_id512=False):
142
+
143
+ bz = x.shape[0]
144
+ # with torch.cuda.amp.autocast(self.fp16):
145
+ x = self.conv1(x)
146
+ x = self.bn1(x)
147
+ x = self.prelu(x)
148
+ x = self.layer1(x)
149
+ x = self.layer2(x)
150
+ x = self.layer3(x)
151
+ x = self.layer4(x)
152
+ if not return_id512:
153
+ return x.view(bz,512,-1).permute(0,2,1).contiguous()
154
+ else:
155
+ x = self.bn2(x)
156
+ x = torch.flatten(x, 1)
157
+ # x = self.dropout(x)
158
+ # x = self.fc(x.float() if self.fp16 else x)
159
+ x = self.fc(x)
160
+ x = self.features(x)
161
+ return x
162
+
163
+
164
+
165
+ def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
166
+ model = IResNet(block, layers, **kwargs)
167
+ if pretrained:
168
+ raise ValueError()
169
+ return model
170
+
171
+
172
+ def iresnet18(pretrained=False, progress=True, **kwargs):
173
+ return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
174
+ progress, **kwargs)
175
+
176
+
177
+ def iresnet34(pretrained=False, progress=True, **kwargs):
178
+ return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
179
+ progress, **kwargs)
180
+
181
+
182
+ def iresnet50(pretrained=False, progress=True, **kwargs):
183
+ return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
184
+ progress, **kwargs)
185
+
186
+
187
+ def iresnet100(pretrained=False, progress=True, **kwargs):
188
+ return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
189
+ progress, **kwargs)
190
+
191
+
192
+ def iresnet200(pretrained=False, progress=True, **kwargs):
193
+ return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
194
+ progress, **kwargs)
195
+
196
+
197
+ def getarcface(pretrained=None):
198
+ model = iresnet100().eval()
199
+ for param in model.parameters():
200
+ param.requires_grad=False
201
+
202
+ if pretrained is not None and os.path.exists(pretrained):
203
+ info = model.load_state_dict(torch.load(pretrained))
204
+ print(info)
205
+ return model
206
+
207
+
208
+ if __name__=='__main__':
209
+ ckpt = 'pretrained/insightface_glint360k.pth'
210
+ arcface = iresnet100().eval()
211
+ info = arcface.load_state_dict(torch.load(ckpt))
212
+ print(info)
213
+
214
+ id = arcface(torch.randn(1,3,128,128))
215
+ print(id.shape)
216
+
217
+ # import cv2
218
+ # import numpy as np
219
+ # im1_crop256 = cv2.imread('happy.jpg')
220
+ # im2_crop256 = cv2.imread('angry.jpg')
221
+
222
+ # im1_crop112 = cv2.resize(im1_crop256, (128,128))[0:112,8:120,:]
223
+ # im2_crop112 = cv2.resize(im2_crop256, (128,128))[0:112,8:120,:]
224
+
225
+ # cv2.imwrite('1_112.jpg', im1_crop112)
226
+ # cv2.imwrite('2_112.jpg', im2_crop112)
227
+
228
+ # # [-1,1] rgb
229
+ # im1_crop112_tensor = torch.from_numpy(im1_crop112[:,:,[2,1,0]].transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)/127.5-1
230
+ # im2_crop112_tensor = torch.from_numpy(im2_crop112[:,:,[2,1,0]].transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)/127.5-1
231
+
232
+ # im1_id = arcface(im1_crop112_tensor)
233
+ # im2_id = arcface(im2_crop112_tensor)
234
+
235
+ # loss_cos = torch.mean(1-torch.cosine_similarity(im1_id, im2_id, dim=1))
236
+
237
+ # print(loss_cos)
utils/third_party/model_parsing.py ADDED
@@ -0,0 +1,381 @@
1
+ #!/usr/bin/python
2
+ # -*- encoding: utf-8 -*-
3
+
4
+ import os
5
+ import torch
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+ import torchvision
9
+ import torch.utils.model_zoo as modelzoo
10
+
11
+ resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
12
+
13
+
14
+ def conv3x3(in_planes, out_planes, stride=1):
15
+ """3x3 convolution with padding"""
16
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
17
+ padding=1, bias=False)
18
+
19
+
20
+ class BasicBlock(nn.Module):
21
+ def __init__(self, in_chan, out_chan, stride=1):
22
+ super(BasicBlock, self).__init__()
23
+ self.conv1 = conv3x3(in_chan, out_chan, stride)
24
+ self.bn1 = nn.BatchNorm2d(out_chan)
25
+ self.conv2 = conv3x3(out_chan, out_chan)
26
+ self.bn2 = nn.BatchNorm2d(out_chan)
27
+ self.relu = nn.ReLU(inplace=True)
28
+ self.downsample = None
29
+ if in_chan != out_chan or stride != 1:
30
+ self.downsample = nn.Sequential(
31
+ nn.Conv2d(in_chan, out_chan,
32
+ kernel_size=1, stride=stride, bias=False),
33
+ nn.BatchNorm2d(out_chan),
34
+ )
35
+
36
+ def forward(self, x):
37
+ residual = self.conv1(x)
38
+ residual = F.relu(self.bn1(residual))
39
+ residual = self.conv2(residual)
40
+ residual = self.bn2(residual)
41
+
42
+ shortcut = x
43
+ if self.downsample is not None:
44
+ shortcut = self.downsample(x)
45
+
46
+ out = shortcut + residual
47
+ out = self.relu(out)
48
+ return out
49
+
50
+
51
+ def create_layer_basic(in_chan, out_chan, bnum, stride=1):
52
+ layers = [BasicBlock(in_chan, out_chan, stride=stride)]
53
+ for i in range(bnum-1):
54
+ layers.append(BasicBlock(out_chan, out_chan, stride=1))
55
+ return nn.Sequential(*layers)
56
+
57
+
58
+ class Resnet18(nn.Module):
59
+ def __init__(self):
60
+ super(Resnet18, self).__init__()
61
+ self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
62
+ bias=False)
63
+ self.bn1 = nn.BatchNorm2d(64)
64
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
65
+ self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
66
+ self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
67
+ self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
68
+ self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
69
+ self.init_weight()
70
+
71
+ def forward(self, x):
72
+ x = self.conv1(x)
73
+ x = F.relu(self.bn1(x))
74
+ x = self.maxpool(x)
75
+
76
+ x = self.layer1(x)
77
+ feat8 = self.layer2(x) # 1/8
78
+ feat16 = self.layer3(feat8) # 1/16
79
+ feat32 = self.layer4(feat16) # 1/32
80
+ return feat8, feat16, feat32
81
+
82
+ def init_weight(self):
83
+ if os.path.isfile('./checkpoints/third_party/resnet18-5c106cde.pth'):
84
+ state_dict = torch.load('./checkpoints/third_party/resnet18-5c106cde.pth')
85
+ else:
86
+ state_dict = modelzoo.load_url(resnet18_url)
87
+ self_state_dict = self.state_dict()
88
+ for k, v in state_dict.items():
89
+ if 'fc' in k: continue
90
+ self_state_dict.update({k: v})
91
+ self.load_state_dict(self_state_dict)
92
+
93
+ def get_params(self):
94
+ wd_params, nowd_params = [], []
95
+ for name, module in self.named_modules():
96
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
97
+ wd_params.append(module.weight)
98
+ if not module.bias is None:
99
+ nowd_params.append(module.bias)
100
+ elif isinstance(module, nn.BatchNorm2d):
101
+ nowd_params += list(module.parameters())
102
+ return wd_params, nowd_params
103
+
104
+
105
+
106
+ class ConvBNReLU(nn.Module):
107
+ def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
108
+ super(ConvBNReLU, self).__init__()
109
+ self.conv = nn.Conv2d(in_chan,
110
+ out_chan,
111
+ kernel_size = ks,
112
+ stride = stride,
113
+ padding = padding,
114
+ bias = False)
115
+ self.bn = nn.BatchNorm2d(out_chan)
116
+ self.init_weight()
117
+
118
+ def forward(self, x):
119
+ x = self.conv(x)
120
+ x = F.relu(self.bn(x))
121
+ return x
122
+
123
+ def init_weight(self):
124
+ for ly in self.children():
125
+ if isinstance(ly, nn.Conv2d):
126
+ nn.init.kaiming_normal_(ly.weight, a=1)
127
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
128
+
129
+ class BiSeNetOutput(nn.Module):
130
+ def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
131
+ super(BiSeNetOutput, self).__init__()
132
+ self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
133
+ self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
134
+ self.init_weight()
135
+
136
+ def forward(self, x):
137
+ x = self.conv(x)
138
+ x = self.conv_out(x)
139
+ return x
140
+
141
+ def init_weight(self):
142
+ for ly in self.children():
143
+ if isinstance(ly, nn.Conv2d):
144
+ nn.init.kaiming_normal_(ly.weight, a=1)
145
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
146
+
147
+ def get_params(self):
148
+ wd_params, nowd_params = [], []
149
+ for name, module in self.named_modules():
150
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
151
+ wd_params.append(module.weight)
152
+ if not module.bias is None:
153
+ nowd_params.append(module.bias)
154
+ elif isinstance(module, nn.BatchNorm2d):
155
+ nowd_params += list(module.parameters())
156
+ return wd_params, nowd_params
157
+
158
+
159
+ class AttentionRefinementModule(nn.Module):
160
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
161
+ super(AttentionRefinementModule, self).__init__()
162
+ self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
163
+ self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
164
+ self.bn_atten = nn.BatchNorm2d(out_chan)
165
+ self.sigmoid_atten = nn.Sigmoid()
166
+ self.init_weight()
167
+
168
+ def forward(self, x):
169
+ feat = self.conv(x)
170
+ atten = F.avg_pool2d(feat, feat.size()[2:])
171
+ atten = self.conv_atten(atten)
172
+ atten = self.bn_atten(atten)
173
+ atten = self.sigmoid_atten(atten)
174
+ out = torch.mul(feat, atten)
175
+ return out
176
+
177
+ def init_weight(self):
178
+ for ly in self.children():
179
+ if isinstance(ly, nn.Conv2d):
180
+ nn.init.kaiming_normal_(ly.weight, a=1)
181
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
182
+
183
+
184
+ class ContextPath(nn.Module):
185
+ def __init__(self, *args, **kwargs):
186
+ super(ContextPath, self).__init__()
187
+ self.resnet = Resnet18()
188
+ self.arm16 = AttentionRefinementModule(256, 128)
189
+ self.arm32 = AttentionRefinementModule(512, 128)
190
+ self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
191
+ self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
192
+ self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
193
+
194
+ self.init_weight()
195
+
196
+ def forward(self, x):
197
+ H0, W0 = x.size()[2:]
198
+ feat8, feat16, feat32 = self.resnet(x)
199
+ H8, W8 = feat8.size()[2:]
200
+ H16, W16 = feat16.size()[2:]
201
+ H32, W32 = feat32.size()[2:]
202
+
203
+ avg = F.avg_pool2d(feat32, feat32.size()[2:])
204
+ avg = self.conv_avg(avg)
205
+ avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
206
+
207
+ feat32_arm = self.arm32(feat32)
208
+ feat32_sum = feat32_arm + avg_up
209
+ feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
210
+ feat32_up = self.conv_head32(feat32_up)
211
+
212
+ feat16_arm = self.arm16(feat16)
213
+ feat16_sum = feat16_arm + feat32_up
214
+ feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
215
+ feat16_up = self.conv_head16(feat16_up)
216
+
217
+ return feat8, feat16_up, feat32_up # x8, x8, x16
218
+
219
+ def init_weight(self):
220
+ for ly in self.children():
221
+ if isinstance(ly, nn.Conv2d):
222
+ nn.init.kaiming_normal_(ly.weight, a=1)
223
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
224
+
225
+ def get_params(self):
226
+ wd_params, nowd_params = [], []
227
+ for name, module in self.named_modules():
228
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
229
+ wd_params.append(module.weight)
230
+ if not module.bias is None:
231
+ nowd_params.append(module.bias)
232
+ elif isinstance(module, nn.BatchNorm2d):
233
+ nowd_params += list(module.parameters())
234
+ return wd_params, nowd_params
235
+
236
+
237
+ ### This is not used, since it is replaced with the resnet feature of the same size
238
+ class SpatialPath(nn.Module):
239
+ def __init__(self, *args, **kwargs):
240
+ super(SpatialPath, self).__init__()
241
+ self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
242
+ self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
243
+ self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
244
+ self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
245
+ self.init_weight()
246
+
247
+ def forward(self, x):
248
+ feat = self.conv1(x)
249
+ feat = self.conv2(feat)
250
+ feat = self.conv3(feat)
251
+ feat = self.conv_out(feat)
252
+ return feat
253
+
254
+ def init_weight(self):
255
+ for ly in self.children():
256
+ if isinstance(ly, nn.Conv2d):
257
+ nn.init.kaiming_normal_(ly.weight, a=1)
258
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
259
+
260
+ def get_params(self):
261
+ wd_params, nowd_params = [], []
262
+ for name, module in self.named_modules():
263
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
264
+ wd_params.append(module.weight)
265
+ if not module.bias is None:
266
+ nowd_params.append(module.bias)
267
+ elif isinstance(module, nn.BatchNorm2d):
268
+ nowd_params += list(module.parameters())
269
+ return wd_params, nowd_params
270
+
271
+
272
+ class FeatureFusionModule(nn.Module):
273
+ def __init__(self, in_chan, out_chan, *args, **kwargs):
274
+ super(FeatureFusionModule, self).__init__()
275
+ self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
276
+ self.conv1 = nn.Conv2d(out_chan,
277
+ out_chan//4,
278
+ kernel_size = 1,
279
+ stride = 1,
280
+ padding = 0,
281
+ bias = False)
282
+ self.conv2 = nn.Conv2d(out_chan//4,
283
+ out_chan,
284
+ kernel_size = 1,
285
+ stride = 1,
286
+ padding = 0,
287
+ bias = False)
288
+ self.relu = nn.ReLU(inplace=True)
289
+ self.sigmoid = nn.Sigmoid()
290
+ self.init_weight()
291
+
292
+ def forward(self, fsp, fcp):
293
+ fcat = torch.cat([fsp, fcp], dim=1)
294
+ feat = self.convblk(fcat)
295
+ atten = F.avg_pool2d(feat, feat.size()[2:])
296
+ atten = self.conv1(atten)
297
+ atten = self.relu(atten)
298
+ atten = self.conv2(atten)
299
+ atten = self.sigmoid(atten)
300
+ feat_atten = torch.mul(feat, atten)
301
+ feat_out = feat_atten + feat
302
+ return feat_out
303
+
304
+ def init_weight(self):
305
+ for ly in self.children():
306
+ if isinstance(ly, nn.Conv2d):
307
+ nn.init.kaiming_normal_(ly.weight, a=1)
308
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
309
+
310
+ def get_params(self):
311
+ wd_params, nowd_params = [], []
312
+ for name, module in self.named_modules():
313
+ if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
314
+ wd_params.append(module.weight)
315
+ if not module.bias is None:
316
+ nowd_params.append(module.bias)
317
+ elif isinstance(module, nn.BatchNorm2d):
318
+ nowd_params += list(module.parameters())
319
+ return wd_params, nowd_params
320
+
321
+
322
+ class BiSeNet(nn.Module):
323
+ def __init__(self, n_classes, *args, **kwargs):
324
+ super(BiSeNet, self).__init__()
325
+ self.cp = ContextPath()
326
+ ## here self.sp is deleted
327
+ self.ffm = FeatureFusionModule(256, 256)
328
+ self.conv_out = BiSeNetOutput(256, 256, n_classes)
329
+ self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
330
+ self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
331
+ self.init_weight()
332
+
333
+ def forward(self, x):
334
+ H, W = x.size()[2:]
335
+ feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
336
+ feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
337
+ feat_fuse = self.ffm(feat_sp, feat_cp8)
338
+
339
+ feat_out = self.conv_out(feat_fuse)
340
+ feat_out16 = self.conv_out16(feat_cp8)
341
+ feat_out32 = self.conv_out32(feat_cp16)
342
+
343
+ feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
344
+ feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
345
+ feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
346
+ return feat_out, feat_out16, feat_out32
347
+
348
+ def init_weight(self):
349
+ for ly in self.children():
350
+ if isinstance(ly, nn.Conv2d):
351
+ nn.init.kaiming_normal_(ly.weight, a=1)
352
+ if not ly.bias is None: nn.init.constant_(ly.bias, 0)
353
+
354
+ def get_params(self):
355
+ wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
356
+ for name, child in self.named_children():
357
+ child_wd_params, child_nowd_params = child.get_params()
358
+ if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
359
+ lr_mul_wd_params += child_wd_params
360
+ lr_mul_nowd_params += child_nowd_params
361
+ else:
362
+ wd_params += child_wd_params
363
+ nowd_params += child_nowd_params
364
+ return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
365
+
366
+
367
+ def get_face_parsing(save_pth = 'third_party/pretrained/79999_iter.pth'):
368
+ net = BiSeNet(n_classes=19)
369
+ net.load_state_dict(torch.load(save_pth))
370
+ return net
371
+
372
+
373
+ if __name__ == "__main__":
374
+ net = BiSeNet(19)
375
+ net.cuda()
376
+ net.eval()
377
+ in_ten = torch.randn(16, 3, 640, 480).cuda()
378
+ out, out16, out32 = net(in_ten)
379
+ print(out.shape)
380
+
381
+ net.get_params()
utils/third_party/model_resnet_d3dfr.py ADDED
@@ -0,0 +1,554 @@
1
+ import os
2
+ import torch.nn as nn
3
+ from torch.nn import Linear
4
+ from torch.nn import Conv2d
5
+ from torch.nn import BatchNorm1d
6
+ from torch.nn import BatchNorm2d
7
+ from torch.nn import ReLU
8
+ from torch.nn import Dropout
9
+ try:
10
+ from torch.hub import load_state_dict_from_url
11
+ except ImportError:
12
+ from torch.utils.model_zoo import load_url as load_state_dict_from_url
13
+ from torch.nn import MaxPool2d
14
+ from torch.nn import Sequential
15
+ from torch.nn import Module
16
+ import torch
17
+ from torch import Tensor
18
+ from typing import Type, Any, Callable, Union, List, Optional
19
+
20
+
21
+ model_urls = {
22
+ 'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
23
+ }
24
+
25
+ def filter_state_dict(state_dict, remove_name='fc'):
26
+ new_state_dict = {}
27
+ for key in state_dict:
28
+ if remove_name in key:
29
+ continue
30
+ new_state_dict[key] = state_dict[key]
31
+ return new_state_dict
32
+
33
+ def conv3x3(in_planes, out_planes, stride=1):
34
+ """ 3x3 convolution with padding
35
+ """
36
+ return Conv2d(in_planes,
37
+ out_planes,
38
+ kernel_size=3,
39
+ stride=stride,
40
+ padding=1,
41
+ bias=False)
42
+
43
+
44
+ def conv1x1(in_planes, out_planes, stride=1, bias=False):
45
+ """ 1x1 convolution
46
+ """
47
+ return Conv2d(in_planes,
48
+ out_planes,
49
+ kernel_size=1,
50
+ stride=stride,
51
+ bias=bias)
52
+
53
+ def conv3x3_(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
54
+ """3x3 convolution with padding"""
55
+ return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
56
+ padding=dilation, groups=groups, bias=False, dilation=dilation)
57
+
58
+
59
+ def conv1x1_(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
60
+ """1x1 convolution"""
61
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)
62
+
63
+
64
+ class Bottleneck(Module):
65
+ expansion = 4
66
+
67
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
68
+ super(Bottleneck, self).__init__()
69
+ self.conv1 = conv1x1(inplanes, planes)
70
+ self.bn1 = BatchNorm2d(planes)
71
+ self.conv2 = conv3x3(planes, planes, stride)
72
+ self.bn2 = BatchNorm2d(planes)
73
+ self.conv3 = conv1x1(planes, planes * self.expansion)
74
+ self.bn3 = BatchNorm2d(planes * self.expansion)
75
+ self.relu = ReLU(inplace=True)
76
+ self.downsample = downsample
77
+ self.stride = stride
78
+
79
+ def forward(self, x):
80
+ identity = x
81
+
82
+ out = self.conv1(x)
83
+ out = self.bn1(out)
84
+ out = self.relu(out)
85
+
86
+ out = self.conv2(out)
87
+ out = self.bn2(out)
88
+ out = self.relu(out)
89
+
90
+ out = self.conv3(out)
91
+ out = self.bn3(out)
92
+
93
+ if self.downsample is not None:
94
+ identity = self.downsample(x)
95
+
96
+ out += identity
97
+ out = self.relu(out)
98
+
99
+ return out
100
+
101
+ class Bottleneck_(nn.Module):
102
+ # Bottleneck in torchvision places the stride for downsampling at 3x3 convolution(self.conv2)
103
+ # while original implementation places the stride at the first 1x1 convolution(self.conv1)
104
+ # according to "Deep residual learning for image recognition"https://arxiv.org/abs/1512.03385.
105
+ # This variant is also known as ResNet V1.5 and improves accuracy according to
106
+ # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.
107
+
108
+ expansion: int = 4
109
+
110
+ def __init__(
111
+ self,
112
+ inplanes: int,
113
+ planes: int,
114
+ stride: int = 1,
115
+ downsample: Optional[nn.Module] = None,
116
+ groups: int = 1,
117
+ base_width: int = 64,
118
+ dilation: int = 1,
119
+ norm_layer: Optional[Callable[..., nn.Module]] = None
120
+ ) -> None:
121
+ super(Bottleneck_, self).__init__()
122
+ if norm_layer is None:
123
+ norm_layer = nn.BatchNorm2d
124
+ width = int(planes * (base_width / 64.)) * groups
125
+ # Both self.conv2 and self.downsample layers downsample the input when stride != 1
126
+ self.conv1 = conv1x1_(inplanes, width)
127
+ self.bn1 = norm_layer(width)
128
+ self.conv2 = conv3x3_(width, width, stride, groups, dilation)
129
+ self.bn2 = norm_layer(width)
130
+ self.conv3 = conv1x1_(width, planes * self.expansion)
131
+ self.bn3 = norm_layer(planes * self.expansion)
132
+ self.relu = nn.ReLU(inplace=True)
133
+ self.downsample = downsample
134
+ self.stride = stride
135
+
136
+ def forward(self, x: Tensor) -> Tensor:
137
+ identity = x
138
+
139
+ out = self.conv1(x)
140
+ out = self.bn1(out)
141
+ out = self.relu(out)
142
+
143
+ out = self.conv2(out)
144
+ out = self.bn2(out)
145
+ out = self.relu(out)
146
+
147
+ out = self.conv3(out)
148
+ out = self.bn3(out)
149
+
150
+ if self.downsample is not None:
151
+ identity = self.downsample(x)
152
+
153
+ out += identity
154
+ out = self.relu(out)
155
+
156
+ return out
157
+
158
+ class BasicBlock(nn.Module):
159
+ expansion: int = 1
160
+
161
+ def __init__(
162
+ self,
163
+ inplanes: int,
164
+ planes: int,
165
+ stride: int = 1,
166
+ downsample: Optional[nn.Module] = None,
167
+ groups: int = 1,
168
+ base_width: int = 64,
169
+ dilation: int = 1,
170
+ norm_layer: Optional[Callable[..., nn.Module]] = None
171
+ ) -> None:
172
+ super(BasicBlock, self).__init__()
173
+ if norm_layer is None:
174
+ norm_layer = nn.BatchNorm2d
175
+ if groups != 1 or base_width != 64:
176
+ raise ValueError('BasicBlock only supports groups=1 and base_width=64')
177
+ if dilation > 1:
178
+ raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
179
+ # Both self.conv1 and self.downsample layers downsample the input when stride != 1
180
+ self.conv1 = conv3x3(inplanes, planes, stride)
181
+ self.bn1 = norm_layer(planes)
182
+ self.relu = nn.ReLU(inplace=True)
183
+ self.conv2 = conv3x3(planes, planes)
184
+ self.bn2 = norm_layer(planes)
185
+ self.downsample = downsample
186
+ self.stride = stride
187
+
188
+ def forward(self, x: Tensor) -> Tensor:
189
+ identity = x
190
+
191
+ out = self.conv1(x)
192
+ out = self.bn1(out)
193
+ out = self.relu(out)
194
+
195
+ out = self.conv2(out)
196
+ out = self.bn2(out)
197
+
198
+ if self.downsample is not None:
199
+ identity = self.downsample(x)
200
+
201
+ out += identity
202
+ out = self.relu(out)
203
+
204
+ return out
205
+
206
+ class ResNet(Module):
207
+ """ ResNet backbone
208
+ """
209
+ def __init__(self, input_size, block, layers, zero_init_residual=True):
210
+ """ Args:
211
+ input_size: spatial input size of the backbone ([112, 112] or [224, 224])
212
+ block: residual block class (e.g. Bottleneck)
213
+ layers: number of blocks in each of the four stages
214
+ """
215
+ super(ResNet, self).__init__()
216
+ assert input_size[0] in [112, 224], \
217
+ "input_size should be [112, 112] or [224, 224]"
218
+ self.inplanes = 64
219
+ self.conv1 = Conv2d(3, 64,
220
+ kernel_size=7,
221
+ stride=2,
222
+ padding=3,
223
+ bias=False)
224
+ self.bn1 = BatchNorm2d(64)
225
+ self.relu = ReLU(inplace=True)
226
+ self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
227
+ self.layer1 = self._make_layer(block, 64, layers[0])
228
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
229
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
230
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
231
+
232
+ self.bn_o1 = BatchNorm2d(2048)
233
+ self.dropout = Dropout()
234
+ if input_size[0] == 112:
235
+ self.fc = Linear(2048 * 4 * 4, 512)
236
+ else:
237
+ self.fc = Linear(2048 * 7 * 7, 512)
238
+ self.bn_o2 = BatchNorm1d(512)
239
+
240
+ # initialize_weights(self.modules)
241
+ if zero_init_residual:
242
+ for m in self.modules():
243
+ if isinstance(m, Bottleneck):
244
+ nn.init.constant_(m.bn3.weight, 0)
245
+
246
+ def _make_layer(self, block, planes, blocks, stride=1):
247
+ downsample = None
248
+ if stride != 1 or self.inplanes != planes * block.expansion:
249
+ downsample = Sequential(
250
+ conv1x1(self.inplanes, planes * block.expansion, stride),
251
+ BatchNorm2d(planes * block.expansion),
252
+ )
253
+
254
+ layers = []
255
+ layers.append(block(self.inplanes, planes, stride, downsample))
256
+ self.inplanes = planes * block.expansion
257
+ for _ in range(1, blocks):
258
+ layers.append(block(self.inplanes, planes))
259
+
260
+ return Sequential(*layers)
261
+
262
+ def forward(self, x):
263
+ x = self.conv1(x)
264
+ x = self.bn1(x)
265
+ x = self.relu(x)
266
+ x = self.maxpool(x)
267
+
268
+ x = self.layer1(x)
269
+ x = self.layer2(x)
270
+ x = self.layer3(x)
271
+ x = self.layer4(x)
272
+
273
+ x = self.bn_o1(x)
274
+ x = self.dropout(x)
275
+ x = x.view(x.size(0), -1)
276
+ x = self.fc(x)
277
+ x = self.bn_o2(x)
278
+
279
+ return x
280
+
281
+
282
+ class resNet(nn.Module): # original torchvision-style ResNet
283
+
284
+ def __init__(
285
+ self,
286
+ block_: Type[Union[BasicBlock, Bottleneck_]],
287
+ layers: List[int],
288
+ num_classes: int = 1000,
289
+ zero_init_residual: bool = False,
290
+ use_last_fc: bool = False,
291
+ groups: int = 1,
292
+ width_per_group: int = 64,
293
+ replace_stride_with_dilation: Optional[List[bool]] = None,
294
+ norm_layer: Optional[Callable[..., nn.Module]] = None
295
+ ) -> None:
296
+ super(resNet, self).__init__()
297
+ if norm_layer is None:
298
+ norm_layer = nn.BatchNorm2d
299
+ self._norm_layer = norm_layer
300
+
301
+ self.inplanes = 64
302
+ self.dilation = 1
303
+ if replace_stride_with_dilation is None:
304
+ # each element in the tuple indicates if we should replace
305
+ # the 2x2 stride with a dilated convolution instead
306
+ replace_stride_with_dilation = [False, False, False]
307
+ if len(replace_stride_with_dilation) != 3:
308
+ raise ValueError("replace_stride_with_dilation should be None "
309
+ "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
310
+ self.use_last_fc = use_last_fc
311
+ self.groups = groups
312
+ self.base_width = width_per_group
313
+ self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
314
+ bias=False)
315
+ self.bn1 = norm_layer(self.inplanes)
316
+ self.relu = nn.ReLU(inplace=True)
317
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
318
+ self.layer1 = self._make_layer(block_, 64, layers[0])
319
+ self.layer2 = self._make_layer(block_, 128, layers[1], stride=2,
320
+ dilate=replace_stride_with_dilation[0])
321
+ self.layer3 = self._make_layer(block_, 256, layers[2], stride=2,
322
+ dilate=replace_stride_with_dilation[1])
323
+ self.layer4 = self._make_layer(block_, 512, layers[3], stride=2,
324
+ dilate=replace_stride_with_dilation[2])
325
+ self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
326
+
327
+ if self.use_last_fc:
328
+ self.fc = nn.Linear(512 * block_.expansion, num_classes)
329
+
330
+ for m in self.modules():
331
+ if isinstance(m, nn.Conv2d):
332
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
333
+ elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
334
+ nn.init.constant_(m.weight, 1)
335
+ nn.init.constant_(m.bias, 0)
336
+
337
+ # Zero-initialize the last BN in each residual branch,
338
+ # so that the residual branch starts with zeros, and each residual block behaves like an identity.
339
+ # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
340
+ if zero_init_residual:
341
+ for m in self.modules():
342
+ if isinstance(m, Bottleneck_):
343
+ nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]
344
+ elif isinstance(m, BasicBlock):
345
+ nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]
346
+
347
+ def _make_layer(self, block_: Type[Union[BasicBlock, Bottleneck_]], planes: int, blocks: int,
348
+ stride: int = 1, dilate: bool = False) -> nn.Sequential:
349
+ norm_layer = self._norm_layer
350
+ downsample = None
351
+ previous_dilation = self.dilation
352
+ if dilate:
353
+ self.dilation *= stride
354
+ stride = 1
355
+ if stride != 1 or self.inplanes != planes * block_.expansion:
356
+ downsample = nn.Sequential(
357
+ conv1x1(self.inplanes, planes * block_.expansion, stride),
358
+ norm_layer(planes * block_.expansion),
359
+ )
360
+
361
+ layers = []
362
+ layers.append(block_(self.inplanes, planes, stride, downsample, self.groups,
363
+ self.base_width, previous_dilation, norm_layer))
364
+ self.inplanes = planes * block_.expansion
365
+ for _ in range(1, blocks):
366
+ layers.append(block_(self.inplanes, planes, groups=self.groups,
367
+ base_width=self.base_width, dilation=self.dilation,
368
+ norm_layer=norm_layer))
369
+
370
+ return nn.Sequential(*layers)
371
+
372
+ def _forward_impl(self, x: Tensor) -> Tensor:
373
+ # See note [TorchScript super()]
374
+ x = self.conv1(x)
375
+ x = self.bn1(x)
376
+ x = self.relu(x)
377
+ x = self.maxpool(x)
378
+
379
+ x = self.layer1(x)
380
+ x = self.layer2(x)
381
+ x = self.layer3(x)
382
+ x = self.layer4(x)
383
+
384
+ x = self.avgpool(x)
385
+ if self.use_last_fc:
386
+ x = torch.flatten(x, 1)
387
+ x = self.fc(x)
388
+ return x
389
+
390
+ def forward(self, x: Tensor) -> Tensor:
391
+ return self._forward_impl(x)
392
+
393
+ def ResNet_50(input_size, **kwargs):
394
+ """ Constructs a ResNet-50 model.
395
+ """
396
+ model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)
397
+
398
+ return model
399
+
400
+
401
+ class ResNet50_nofc(Module):
402
+ """ ResNet backbone
403
+ """
404
+ def __init__(self, input_size, output_dim, use_last_fc=False, init_path=None):
405
+ """ Args:
406
+ input_size: spatial input size of the backbone ([112, 112], [224, 224] or [256, 256])
407
+ output_dim: dimension of the regressed coefficient vector (passed as num_classes to the backbone)
408
+ use_last_fc: if True, keep a single fully connected head; otherwise use per-coefficient 1x1 heads
+ init_path: optional path to a pretrained backbone checkpoint
409
+ """
410
+ super(ResNet50_nofc, self).__init__()
411
+ assert input_size[0] in [112, 224, 256], \
412
+ "input_size should be [112, 112] or [224, 224]"
413
+ func, last_dim = func_dict['resnet50']
414
+ self.use_last_fc = use_last_fc
415
+ backbone = func(use_last_fc=use_last_fc, num_classes=output_dim)
416
+ if init_path and os.path.isfile(init_path):
417
+ state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
418
+ backbone.load_state_dict(state_dict)
419
+ print("Loading init recon %s from %s"%('resnet50', init_path))
420
+ self.backbone = backbone
421
+ if not use_last_fc:
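+ # Per-coefficient 1x1 heads applied to the pooled backbone feature; the output widths
+ # sum to 261 = 257 D3DFR coefficients + 4 pupil values (cf. getd3dfr_res50 below).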
422
+ self.fianl_layers = nn.ModuleList([
423
+ conv1x1(last_dim, 80, bias=True), # id
424
+ conv1x1(last_dim, 64, bias=True), # exp
425
+ conv1x1(last_dim, 80, bias=True), # tex
426
+ conv1x1(last_dim, 3, bias=True), # angle
427
+ conv1x1(last_dim, 27, bias=True), # gamma
428
+ conv1x1(last_dim, 2, bias=True), # tx, ty
429
+ conv1x1(last_dim, 1, bias=True), # tz
430
+ conv1x1(last_dim, 4, bias=True) # pupil
431
+ ])
432
+ for m in self.fianl_layers:
433
+ nn.init.constant_(m.weight, 0.)
434
+ nn.init.constant_(m.bias, 0.)
435
+
436
+
437
+ def forward(self, x):
438
+ x = self.backbone(x)
439
+ if not self.use_last_fc:
440
+ output = []
441
+ for layer in self.fianl_layers:
442
+ output.append(layer(x))
443
+ x = torch.flatten(torch.cat(output, dim=1), 1)
444
+ return x
445
+
446
+
447
+ def _resnet(
448
+ arch: str,
449
+ block: Type[Union[BasicBlock, Bottleneck_]],
450
+ layers: List[int],
451
+ pretrained: bool,
452
+ progress: bool,
453
+ **kwargs: Any
454
+ ) -> resNet:
455
+ model = resNet(block, layers, **kwargs)
456
+ if pretrained:
457
+ state_dict = load_state_dict_from_url(model_urls[arch],
458
+ progress=progress)
459
+ model.load_state_dict(state_dict)
460
+ return model
461
+
462
+ def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> resNet:
463
+ r"""ResNet-50 model from
464
+ `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.
465
+
466
+ Args:
467
+ pretrained (bool): If True, returns a model pre-trained on ImageNet
468
+ progress (bool): If True, displays a progress bar of the download to stderr
469
+ """
470
+ return _resnet('resnet50', Bottleneck_, [3, 4, 6, 3], pretrained, progress,
471
+ **kwargs)
472
+
473
+
474
+ func_dict = {
475
+ 'resnet50': (resnet50, 2048),
476
+ }
477
+
478
+
479
+ class Identity(nn.Module):
480
+ def __init__(self):
481
+ super(Identity, self).__init__()
482
+
483
+ def forward(self, x):
484
+ return x
485
+
486
+
487
+ def fuse(conv, bn):
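+ # Fold a BatchNorm2d into the preceding Conv2d for inference:
+ #   w' = w * bn.weight / sqrt(running_var + eps)
+ #   b' = (b - running_mean) / sqrt(running_var + eps) * bn.weight + bn.bias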
488
+ w = conv.weight
489
+ mean = bn.running_mean
490
+ var_sqrt = torch.sqrt(bn.running_var + bn.eps)
491
+
492
+ beta = bn.weight
493
+ gamma = bn.bias
494
+
495
+ if conv.bias is not None:
496
+ b = conv.bias
497
+ else:
498
+ b = mean.new_zeros(mean.shape)
499
+
500
+ w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
501
+ b = (b - mean) / var_sqrt * beta + gamma
502
+
503
+ fused_conv = nn.Conv2d(
504
+ conv.in_channels,
505
+ conv.out_channels,
506
+ conv.kernel_size,
507
+ conv.stride,
508
+ conv.padding,
509
+ conv.dilation,
510
+ conv.groups,
511
+ bias=True,
512
+ padding_mode=conv.padding_mode
513
+ )
514
+ fused_conv.weight = nn.Parameter(w)
515
+ fused_conv.bias = nn.Parameter(b)
516
+ return fused_conv
517
+
518
+
519
+ def fuse_module(m):
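+ # Recursively walk the module tree; whenever a BatchNorm2d directly follows a Conv2d among a
+ # module's children, replace the pair with a single fused convolution and an Identity placeholder.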
520
+ children = list(m.named_children())
521
+ conv = None
522
+ conv_name = None
523
+ for name, child in children:
524
+ if isinstance(child, nn.BatchNorm2d) and conv:
525
+ bc = fuse(conv, child)
526
+ m._modules[conv_name] = bc
527
+ m._modules[name] = Identity()
528
+ conv = None
529
+ elif isinstance(child, nn.Conv2d):
530
+ conv = child
531
+ conv_name = name
532
+ else:
533
+ fuse_module(child)
534
+
535
+
536
+ def getd3dfr_res50(pretrained="./d3dfr_res50_nofc.pth"):
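+ # Build the frozen D3DFR ResNet-50 coefficient regressor: load the checkpoint if present
+ # (stripping any DataParallel 'module.' prefix), switch to eval mode and fuse Conv+BN pairs.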
537
+ model = ResNet50_nofc([256, 256], 257+4, use_last_fc=False)
538
+ for param in model.parameters():
539
+ param.requires_grad = False
540
+ if pretrained is not None and os.path.exists(pretrained):
541
+ checkpoint_no_module = {}
542
+ checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)
543
+ for k, v in checkpoint.items():
544
+ if k.startswith('module'):
545
+ k = k[7:]
546
+ checkpoint_no_module[k] = v
547
+ info = model.load_state_dict(checkpoint_no_module, strict=False)
548
+
549
+ print(pretrained, info)
550
+ model = model.eval()
551
+ fuse_module(model)
552
+ return model
553
+ if __name__ == '__main__':
554
+ model = getd3dfr_res50()
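+ # A minimal usage sketch (kept as comments; the input below is hypothetical and assumes the
+ # default pretrained weights are available):
+ # x = torch.randn(1, 3, 256, 256)
+ # coeffs = model(x)  # (1, 261): id 80 + exp 64 + tex 80 + angle 3 + gamma 27 + tx/ty 2 + tz 1 + pupil 4
+ # print(coeffs.shape)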
utils/third_party_files/79999_iter.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
3
+ size 53289463
utils/third_party_files/BFM_model_front.mat ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a9f127eb18c3d022acbdbfcf1b6b353d01a6e01785d675a67cc31a3826a5be0f
3
+ size 127170280
utils/third_party_files/d3dfr_res50_nofc.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:52c54b90304a06c16b6813910c26faff1a907d4f8bd69a71ad4ecff43b41a090
3
+ size 96449126
utils/third_party_files/insightface_glint360k.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5f631718e783448b41631e15073bdc622eaeef56509bbad4e5085f23bd32db83
3
+ size 261223796
utils/third_party_files/models/antelopev2/1k3d68.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
3
+ size 143607619
utils/third_party_files/models/antelopev2/2d106det.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
3
+ size 5030888
utils/third_party_files/models/antelopev2/antelopev2.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7353a5fdca5a90e11d2792e0236032b2fe42adc1ea23eaef5cf8c8b57e7e9393
3
+ size 360743453
utils/third_party_files/models/antelopev2/genderage.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
3
+ size 1322532
utils/third_party_files/models/antelopev2/glintr100.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf
3
+ size 260665334
utils/third_party_files/models/antelopev2/scrfd_10g_bnkps.onnx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
3
+ size 16923827
utils/third_party_files/resnet18-5c106cde.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5c106cde386e87d4033832f2996f5493238eda96ccf559d1d62760c4de0613f8
3
+ size 46827520