mengting committed
Commit 5c17f58 · 1 parent: d859277
initial

Files changed:
- .gitattributes +1 -0
- pre_trained/unet_denoise/checkpoint-70000/config.json +68 -0
- pre_trained/unet_denoise/checkpoint-70000/diffusion_pytorch_model.safetensors +3 -0
- pre_trained/unet_id/checkpoint-70000/config.json +68 -0
- pre_trained/unet_id/checkpoint-70000/diffusion_pytorch_model.safetensors +3 -0
- utils/checkpoints/net_seg_res18.pth +3 -0
- utils/checkpoints/third_party/BFM_model_front.mat +3 -0
- utils/checkpoints/third_party/d3dfr_res50_nofc.pth +3 -0
- utils/third_party/__pycache__/model_resnet_d3dfr.cpython-39.pyc +0 -0
- utils/third_party/d3dfr/__pycache__/bfm.cpython-39.pyc +0 -0
- utils/third_party/d3dfr/bfm.py +473 -0
- utils/third_party/d3dfr_res50_nofc.pth +3 -0
- utils/third_party/insightface_backbone_conv.py +237 -0
- utils/third_party/model_parsing.py +381 -0
- utils/third_party/model_resnet_d3dfr.py +554 -0
- utils/third_party_files/79999_iter.pth +3 -0
- utils/third_party_files/BFM_model_front.mat +3 -0
- utils/third_party_files/d3dfr_res50_nofc.pth +3 -0
- utils/third_party_files/insightface_glint360k.pth +3 -0
- utils/third_party_files/models/antelopev2/1k3d68.onnx +3 -0
- utils/third_party_files/models/antelopev2/2d106det.onnx +3 -0
- utils/third_party_files/models/antelopev2/antelopev2.zip +3 -0
- utils/third_party_files/models/antelopev2/genderage.onnx +3 -0
- utils/third_party_files/models/antelopev2/glintr100.onnx +3 -0
- utils/third_party_files/models/antelopev2/scrfd_10g_bnkps.onnx +3 -0
- utils/third_party_files/resnet18-5c106cde.pth +3 -0
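Every large binary listed above is stored through Git LFS rather than in the Git objects themselves (the .gitattributes change below adds *.mat to the tracked patterns). Purely as an illustration, an individual file can be fetched with huggingface_hub; the repo id below is a placeholder, not something stated in this commit:

from huggingface_hub import hf_hub_download

# Placeholder repo id -- substitute the actual namespace/name of this repository.
config_path = hf_hub_download(
    repo_id="your-namespace/your-repo",
    filename="pre_trained/unet_denoise/checkpoint-70000/config.json",
)
print(config_path)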
.gitattributes
CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+*.mat filter=lfs diff=lfs merge=lfs -text
pre_trained/unet_denoise/checkpoint-70000/config.json
ADDED
@@ -0,0 +1,68 @@
{
  "_class_name": "UNetDenoise2DConditionModel",
  "_diffusers_version": "0.25.1",
  "_name_or_path": "/scratch/project_462000772/wmengting/diffusion_models/stable-diffusion-v1-5",
  "act_fn": "silu",
  "addition_embed_type": null,
  "addition_embed_type_num_heads": 64,
  "addition_time_embed_dim": null,
  "attention_head_dim": 8,
  "attention_type": "default",
  "block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "center_input_sample": false,
  "class_embed_type": null,
  "class_embeddings_concat": false,
  "conv_in_kernel": 3,
  "conv_out_kernel": 3,
  "cross_attention_dim": 768,
  "cross_attention_norm": null,
  "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "dropout": 0.0,
  "dual_cross_attention": false,
  "encoder_hid_dim": null,
  "encoder_hid_dim_type": null,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 12,
  "layers_per_block": 2,
  "mid_block_only_cross_attention": null,
  "mid_block_scale_factor": 1,
  "mid_block_type": "UNetMidBlock2DCrossAttn",
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_attention_heads": null,
  "num_class_embeds": null,
  "only_cross_attention": false,
  "out_channels": 4,
  "projection_class_embeddings_input_dim": null,
  "resnet_out_scale_factor": 1.0,
  "resnet_skip_time_act": false,
  "resnet_time_scale_shift": "default",
  "reverse_transformer_layers_per_block": null,
  "sample_size": 64,
  "time_cond_proj_dim": null,
  "time_embedding_act_fn": null,
  "time_embedding_dim": null,
  "time_embedding_type": "positional",
  "timestep_post_act": null,
  "transformer_layers_per_block": 1,
  "up_block_types": [
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D"
  ],
  "upcast_attention": false,
  "use_linear_projection": false
}
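Aside from the custom class name, the settings appear to mirror the Stable Diffusion v1.5 UNet, except in_channels is 12 while out_channels stays 4, so the denoising UNet consumes three times as many input channels as it predicts. Reading 12 as several 4-channel latent maps concatenated for conditioning is an interpretation, not something the config states. A minimal sketch that just inspects those fields, assuming the checkpoint folder from this commit is available locally:

import json

# Local path assumption: the checkpoint folder added in this commit.
with open("pre_trained/unet_denoise/checkpoint-70000/config.json") as f:
    cfg = json.load(f)

# 12 input channels vs. 4 output channels: extra conditioning channels are fed
# alongside the 4-channel noisy latent (interpretation, see above).
print(cfg["_class_name"], cfg["in_channels"], cfg["out_channels"])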
pre_trained/unet_denoise/checkpoint-70000/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a5eb8f49aba84c10f607630f28785b55037cde350e5ec90563d1ddd200dd64c8
size 3438592680
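These three lines are a Git LFS pointer, not the 3.4 GB weight file itself; the real payload is fetched by git lfs pull or by the hub client. A small sketch, assuming the resolved file is already on disk, for checking it against the pointer's oid and size:

import hashlib
from pathlib import Path

path = Path("pre_trained/unet_denoise/checkpoint-70000/diffusion_pytorch_model.safetensors")

sha = hashlib.sha256()
with path.open("rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # stream in 1 MiB chunks
        sha.update(chunk)

print(path.stat().st_size)  # should equal the "size" field of the pointer
print(sha.hexdigest())      # should equal the hex digest after "oid sha256:"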
pre_trained/unet_id/checkpoint-70000/config.json
ADDED
@@ -0,0 +1,68 @@
{
  "_class_name": "UNetID2DConditionModel",
  "_diffusers_version": "0.25.1",
  "_name_or_path": "/scratch/project_462000772/wmengting/diffusion_models/stable-diffusion-v1-5",
  "act_fn": "silu",
  "addition_embed_type": null,
  "addition_embed_type_num_heads": 64,
  "addition_time_embed_dim": null,
  "attention_head_dim": 8,
  "attention_type": "default",
  "block_out_channels": [
    320,
    640,
    1280,
    1280
  ],
  "center_input_sample": false,
  "class_embed_type": null,
  "class_embeddings_concat": false,
  "conv_in_kernel": 3,
  "conv_out_kernel": 3,
  "cross_attention_dim": 768,
  "cross_attention_norm": null,
  "down_block_types": [
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "CrossAttnDownBlock2D",
    "DownBlock2D"
  ],
  "downsample_padding": 1,
  "dropout": 0.0,
  "dual_cross_attention": false,
  "encoder_hid_dim": null,
  "encoder_hid_dim_type": null,
  "flip_sin_to_cos": true,
  "freq_shift": 0,
  "in_channels": 4,
  "layers_per_block": 2,
  "mid_block_only_cross_attention": null,
  "mid_block_scale_factor": 1,
  "mid_block_type": "UNetMidBlock2DCrossAttn",
  "norm_eps": 1e-05,
  "norm_num_groups": 32,
  "num_attention_heads": null,
  "num_class_embeds": null,
  "only_cross_attention": false,
  "out_channels": 4,
  "projection_class_embeddings_input_dim": null,
  "resnet_out_scale_factor": 1.0,
  "resnet_skip_time_act": false,
  "resnet_time_scale_shift": "default",
  "reverse_transformer_layers_per_block": null,
  "sample_size": 64,
  "time_cond_proj_dim": null,
  "time_embedding_act_fn": null,
  "time_embedding_dim": null,
  "time_embedding_type": "positional",
  "timestep_post_act": null,
  "transformer_layers_per_block": 1,
  "up_block_types": [
    "UpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D",
    "CrossAttnUpBlock2D"
  ],
  "upcast_attention": false,
  "use_linear_projection": false
}
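The identity UNet ships an almost identical config; the quickest way to see exactly where the two UNets differ (expected: _class_name and in_channels, 12 vs 4) is to diff the two JSON files. A sketch under the assumption that both checkpoint folders from this commit are present locally:

import json

def load(path):
    with open(path) as f:
        return json.load(f)

denoise = load("pre_trained/unet_denoise/checkpoint-70000/config.json")
ident = load("pre_trained/unet_id/checkpoint-70000/config.json")

# Print every key on which the two configs disagree.
for key in sorted(set(denoise) | set(ident)):
    if denoise.get(key) != ident.get(key):
        print(f"{key}: {denoise.get(key)!r} vs {ident.get(key)!r}")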
pre_trained/unet_id/checkpoint-70000/diffusion_pytorch_model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f1e441ce5eed311efe690e31674d7ba31283bb5d4d6ee6c522f6001c6ce42faa
size 3438167536
utils/checkpoints/net_seg_res18.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:041ab78a4f8f756cd7e93df0d2840d03162e46b9c463c144f7fdf3ee3e6c4233
size 57429148
utils/checkpoints/third_party/BFM_model_front.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9f127eb18c3d022acbdbfcf1b6b353d01a6e01785d675a67cc31a3826a5be0f
size 127170280
utils/checkpoints/third_party/d3dfr_res50_nofc.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52c54b90304a06c16b6813910c26faff1a907d4f8bd69a71ad4ecff43b41a090
size 96449126
utils/third_party/__pycache__/model_resnet_d3dfr.cpython-39.pyc
ADDED
Binary file (14.3 kB).
utils/third_party/d3dfr/__pycache__/bfm.cpython-39.pyc
ADDED
Binary file (12.6 kB).
utils/third_party/d3dfr/bfm.py
ADDED
@@ -0,0 +1,473 @@
import os
import torch
# import torch.nn as nn
from scipy.io import loadmat
import numpy as np
import torch.nn.functional as F

# CURRENT_PATH = os.path.dirname(os.path.realpath(__file__))


def perspective_projection(focal, center):
    # return p.T (N, 3) @ (3, 3)
    return np.array([
        focal, 0, center,
        0, focal, center,
        0, 0, 1
    ]).reshape([3, 3]).astype(np.float32).transpose()


class SH:
    def __init__(self):
        self.a = [np.pi, 2 * np.pi / np.sqrt(3.), 2 * np.pi / np.sqrt(8.)]
        self.c = [1/np.sqrt(4 * np.pi), np.sqrt(3.) / np.sqrt(4 * np.pi), 3 * np.sqrt(5.) / np.sqrt(12 * np.pi)]


class BFM(torch.nn.Module):
    # BFM 3D face model
    def __init__(self,
                 recenter=True,
                 camera_distance=10.,
                 init_lit=np.array([0.8, 0, 0, 0, 0, 0, 0, 0, 0]),
                 focal=1015.,
                 image_size=224,
                 bfm_model_path='pretrained/BFM_model_front.mat'
                 ):
        super().__init__()
        model = loadmat(bfm_model_path)
        # self.bfm_uv = loadmat(os.path.join(CURRENT_PATH, 'BFM/BFM_UV.mat'))
        # print(model.keys())
        # mean face shape. [3*N,1]
        # self.meanshape = torch.from_numpy(model['meanshape'])
        self.register_buffer('meanshape', torch.from_numpy(model['meanshape']).float())

        if recenter:
            meanshape = self.meanshape.view(-1, 3)
            meanshape = meanshape - torch.mean(meanshape, dim=0, keepdim=True)
            self.meanshape = meanshape.view(-1, 1)

        # identity basis. [3*N,80]
        # self.idBase = torch.from_numpy(model['idBase'])
        self.register_buffer('idBase', torch.from_numpy(model['idBase']).float())
        # self.idBase = nn.Parameter(torch.from_numpy(model['idBase']).float())
        # self.exBase = torch.from_numpy(model['exBase'].astype(
        #     np.float32)) # expression basis. [3*N,64]
        self.register_buffer('exBase', torch.from_numpy(model['exBase']).float())
        # self.exBase = nn.Parameter(torch.from_numpy(model['exBase']).float())
        # mean face texture. [3*N,1] (0-255)
        # self.meantex = torch.from_numpy(model['meantex'])
        self.register_buffer('meantex', torch.from_numpy(model['meantex']).float())
        # texture basis. [3*N,80]
        # self.texBase = torch.from_numpy(model['texBase'])
        self.register_buffer('texBase', torch.from_numpy(model['texBase']).float())
        # self.texBase = nn.Parameter(torch.from_numpy(model['texBase']).float())

        # triangle indices for each vertex that lies in. starts from 0. [N,8]
        self.register_buffer('point_buf', torch.from_numpy(model['point_buf']).long()-1)
        # self.point_buf = model['point_buf'].astype(np.int32)
        # vertex indices in each triangle. starts from 0. [F,3]
        self.register_buffer('face_buf', torch.from_numpy(model['tri']).long()-1)
        # self.tri = model['tri'].astype(np.int32)
        # vertex indices of 68 facial landmarks. starts from 0. [68]
        self.register_buffer('keypoints', torch.from_numpy(model['keypoints']).long().view(68)-1)
        # self.keypoints = model['keypoints'].astype(np.int32)[0]
        # print(self.keypoints)
        # print('keypoints', self.keypoints)

        # vertex indices for small face region to compute photometric error. starts from 0.
        # self.front_mask = np.squeeze(model['frontmask2_idx']).astype(np.int64) - 1
        self.register_buffer('front_mask', torch.from_numpy(np.squeeze(model['frontmask2_idx'])).long()-1)
        # vertex indices for each face from small face region. starts from 0. [f,3]
        # self.front_face_buf = model['tri_mask2'].astype(np.int64) - 1
        self.register_buffer('front_face_buf', torch.from_numpy(np.squeeze(model['tri_mask2'])).long() - 1)
        # vertex indices for pre-defined skin region to compute reflectance loss
        # self.skin_mask = np.squeeze(model['skinmask'])
        self.register_buffer('skin_mask', torch.from_numpy(np.squeeze(model['skinmask'])))

        # keypoints_222 = []
        # with open(os.path.join(CURRENT_PATH, 'BFM/D3DFR_222.txt'), 'r') as f:
        #     for line in f.readlines():
        #         idx = int(line.strip())
        #         keypoints_222.append(max(idx, 0))
        # self.register_buffer('keypoints_222', torch.from_numpy(np.array(keypoints_222)).long())

        # (1) right eye outer corner, (2) right eye inner corner, (3) left eye inner corner, (4) left eye outer corner,
        # (5) nose bottom, (6) right mouth corner, (7) left mouth corner
        self.register_buffer('keypoints_7', self.keypoints[[36, 39, 42, 45, 33, 48, 54]])

        # self.persc_proj = torch.from_numpy(perspective_projection(focal, center)).float()
        self.register_buffer('persc_proj', torch.from_numpy(perspective_projection(focal, image_size/2)))
        self.camera_distance = camera_distance
        self.image_size = image_size
        self.SH = SH()
        # self.init_lit = init_lit.reshape([1, 1, -1]).astype(np.float32)
        self.register_buffer('init_lit', torch.from_numpy(init_lit.reshape([1, 1, -1]).astype(np.float32)))

        # (1) right eye outer corner, (2) right eye inner corner, (3) left eye inner corner, (4) left eye outer corner,
        # (5) nose bottom, (6) right mouth corner, (7) left mouth corner
        # print(self.keypoints[[36, 39, 42, 45, 33, 48, 54]])

        # Lm3D = loadmat(os.path.join(CURRENT_PATH, 'BFM/similarity_Lm3D_all.mat'))
        # Lm3D = Lm3D['lm']
        # # print(Lm3D)
        #
        # # calculate 5 facial landmarks using 68 landmarks
        # lm_idx = np.array([31, 37, 40, 43, 46, 49, 55]) - 1
        # Lm3D = np.stack([Lm3D[lm_idx[0], :], np.mean(Lm3D[lm_idx[[1, 2]], :], 0), np.mean(
        #     Lm3D[lm_idx[[3, 4]], :], 0), Lm3D[lm_idx[5], :], Lm3D[lm_idx[6], :]], axis=0)
        # Lm3D = Lm3D[[1, 2, 0, 3, 4], :]
        # self.Lm3D = Lm3D
        # print(Lm3D.shape)

    def split_coeff(self, coeff):
        # input: coeff with shape [1,258]
        id_coeff = coeff[:, 0:80]  # identity(shape) coeff of dim 80
        ex_coeff = coeff[:, 80:144]  # expression coeff of dim 64
        tex_coeff = coeff[:, 144:224]  # texture(albedo) coeff of dim 80
        gamma = coeff[:, 227:254]  # lighting coeff for 3 channel SH function of dim 27
        angles = coeff[:, 224:227]  # Euler angles (x,y,z) for rotation of dim 3
        translation = coeff[:, 254:257]  # translation coeff of dim 3

        return id_coeff, ex_coeff, tex_coeff, gamma, angles, translation

    def split_coeff_orderly(self, coeff):
        # input: coeff with shape [1,258]
        id_coeff = coeff[:, 0:80]  # identity(shape) coeff of dim 80
        ex_coeff = coeff[:, 80:144]  # expression coeff of dim 64
        tex_coeff = coeff[:, 144:224]  # texture(albedo) coeff of dim 80
        angles = coeff[:, 224:227]  # Euler angles (x,y,z) for rotation of dim 3
        gamma = coeff[:, 227:254]  # lighting coeff for 3 channel SH function of dim 27
        translation = coeff[:, 254:257]  # translation coeff of dim 3

        return id_coeff, ex_coeff, tex_coeff, angles, gamma, translation

    def compute_exp_deform(self, exp_coeff):
        exp_part = torch.einsum('ij,aj->ai', self.exBase, exp_coeff)
        return exp_part

    def compute_id_deform(self, id_coeff):
        id_part = torch.einsum('ij,aj->ai', self.idBase, id_coeff)
        return id_part

    def compute_shape_from_coeff(self, coeff):
        id_coeff = coeff[:, 0:80]
        ex_coeff = coeff[:, 80:144]
        batch_size = coeff.shape[0]
        id_part = torch.einsum('ij,aj->ai', self.idBase, id_coeff)  # B, n
        exp_part = torch.einsum('ij,aj->ai', self.exBase, ex_coeff)  # B, n
        face_shape = id_part + exp_part + self.meanshape.view(1, -1)
        return face_shape.view(batch_size, -1, 3)

    def compute_shape(self, id_coeff, exp_coeff):
        """
        Return:
            face_shape        -- torch.tensor, size (B, N, 3)
        Parameters:
            id_coeff          -- torch.tensor, size (B, 80), identity coeffs
            id_relative_scale -- torch.tensor, size (B, 1), identity coeffs
            exp_coeff         -- torch.tensor, size (B, 64), expression coeffs
        """
        batch_size = id_coeff.shape[0]
        id_part = torch.einsum('ij,aj->ai', self.idBase, id_coeff)  # B, n
        exp_part = torch.einsum('ij,aj->ai', self.exBase, exp_coeff)  # B, n
        face_shape = id_part + exp_part + self.meanshape.view(1, -1)
        return face_shape.view(batch_size, -1, 3)

    def compute_texture(self, tex_coeff, normalize=True):
        """
        Return:
            face_texture      -- torch.tensor, size (B, N, 3), in RGB order, range (0, 1.)
        Parameters:
            tex_coeff         -- torch.tensor, size (B, 80)
        """
        batch_size = tex_coeff.shape[0]
        face_texture = torch.einsum('ij,aj->ai', self.texBase, tex_coeff) + self.meantex
        if normalize:
            face_texture = face_texture / 255.
        return face_texture.view(batch_size, -1, 3)

    def compute_norm(self, face_shape):
        """
        Return:
            vertex_norm       -- torch.tensor, size (B, N, 3)
        Parameters:
            face_shape        -- torch.tensor, size (B, N, 3)
        """

        v1 = face_shape[:, self.face_buf[:, 0]]
        v2 = face_shape[:, self.face_buf[:, 1]]
        v3 = face_shape[:, self.face_buf[:, 2]]
        e1 = v1 - v2
        e2 = v2 - v3
        face_norm = torch.cross(e1, e2, dim=-1)
        face_norm = F.normalize(face_norm, dim=-1, p=2)
        face_norm = torch.cat([face_norm, torch.zeros(face_norm.shape[0], 1, 3).to(self.meanshape)], dim=1)

        vertex_norm = torch.sum(face_norm[:, self.point_buf], dim=2)
        vertex_norm = F.normalize(vertex_norm, dim=-1, p=2)
        return vertex_norm

    def compute_color(self, face_texture, face_norm, gamma):
        """
        Return:
            face_color        -- torch.tensor, size (B, N, 3), range (0, 1.)
        Parameters:
            face_texture      -- torch.tensor, size (B, N, 3), from texture model, range (0, 1.)
            face_norm         -- torch.tensor, size (B, N, 3), rotated face normal
            gamma             -- torch.tensor, size (B, 27), SH coeffs
        """
        batch_size = gamma.shape[0]
        v_num = face_texture.shape[1]
        a, c = self.SH.a, self.SH.c
        gamma = gamma.reshape([batch_size, 3, 9])
        gamma = gamma + self.init_lit
        gamma = gamma.permute(0, 2, 1)
        Y = torch.cat([
            a[0] * c[0] * torch.ones_like(face_norm[..., :1]).to(self.meanshape),
            -a[1] * c[1] * face_norm[..., 1:2],
            a[1] * c[1] * face_norm[..., 2:],
            -a[1] * c[1] * face_norm[..., :1],
            a[2] * c[2] * face_norm[..., :1] * face_norm[..., 1:2],
            -a[2] * c[2] * face_norm[..., 1:2] * face_norm[..., 2:],
            0.5 * a[2] * c[2] / np.sqrt(3.) * (3 * face_norm[..., 2:] ** 2 - 1),
            -a[2] * c[2] * face_norm[..., :1] * face_norm[..., 2:],
            0.5 * a[2] * c[2] * (face_norm[..., :1] ** 2 - face_norm[..., 1:2] ** 2)
        ], dim=-1)
        r = Y @ gamma[..., :1]
        g = Y @ gamma[..., 1:2]
        b = Y @ gamma[..., 2:]
        face_color = torch.cat([r, g, b], dim=-1) * face_texture
        return face_color

    def compute_rotation(self, angles):
        """
        Return:
            rot               -- torch.tensor, size (B, 3, 3) pts @ trans_mat
        Parameters:
            angles            -- torch.tensor, size (B, 3), radian
        """

        batch_size = angles.shape[0]
        ones = torch.ones([batch_size, 1]).to(self.meanshape)
        zeros = torch.zeros([batch_size, 1]).to(self.meanshape)
        x, y, z = angles[:, :1], angles[:, 1:2], angles[:, 2:],

        rot_x = torch.cat([
            ones, zeros, zeros,
            zeros, torch.cos(x), -torch.sin(x),
            zeros, torch.sin(x), torch.cos(x)
        ], dim=1).reshape([batch_size, 3, 3])

        rot_y = torch.cat([
            torch.cos(y), zeros, torch.sin(y),
            zeros, ones, zeros,
            -torch.sin(y), zeros, torch.cos(y)
        ], dim=1).reshape([batch_size, 3, 3])

        rot_z = torch.cat([
            torch.cos(z), -torch.sin(z), zeros,
            torch.sin(z), torch.cos(z), zeros,
            zeros, zeros, ones
        ], dim=1).reshape([batch_size, 3, 3])

        rot = rot_z @ rot_y @ rot_x
        return rot.permute(0, 2, 1)

    def to_camera(self, face_shape):
        face_shape[..., -1] = self.camera_distance - face_shape[..., -1]
        return face_shape

    def to_image(self, face_shape):
        """
        Return:
            face_proj         -- torch.tensor, size (B, N, 2), y direction is opposite to v direction
        Parameters:
            face_shape        -- torch.tensor, size (B, N, 3)
        """
        # to image_plane
        face_proj = face_shape @ self.persc_proj

        # print(face_proj.shape)
        face_proj = face_proj[..., :2] / face_proj[..., 2:]

        return face_proj

    def rotate(self, face_shape, rot):
        """
        Return:
            face_shape        -- torch.tensor, size (B, N, 3) pts @ rot + trans
        Parameters:
            face_shape        -- torch.tensor, size (B, N, 3)
            rot               -- torch.tensor, size (B, 3, 3)

        """
        return face_shape @ rot

    def get_landmarks7(self, face_proj):
        """
        Return:
            face_lms          -- torch.tensor, size (B, 68, 2)
        Parameters:
            face_proj         -- torch.tensor, size (B, N, 2)
        """
        return face_proj[:, self.keypoints_7, :]

    def get_landmarks68(self, face_proj):
        """
        Return:
            face_lms          -- torch.tensor, size (B, 68, 2)
        Parameters:
            face_proj         -- torch.tensor, size (B, N, 2)
        """
        return face_proj[:, self.keypoints, :]

    def get_landmarks222(self, face_proj):
        """
        Return:
            face_lms          -- torch.tensor, size (B, 68, 2)
        Parameters:
            face_proj         -- torch.tensor, size (B, N, 2)
        """
        return face_proj[:, self.keypoints_222, :]

    def compute_for_render(self, coeffs):
        """
        Return:
            face_vertex       -- torch.tensor, size (B, N, 3), in camera coordinate
            face_color        -- torch.tensor, size (B, N, 3), in RGB order
            landmark          -- torch.tensor, size (B, 68, 2), y direction is opposite to v direction
        Parameters:
            coeffs            -- torch.tensor, size (B, 258)
        """
        id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeffs)
        # id_relative_scale = id_relative_scale.clamp(0.9,1.1)
        face_shape = self.compute_shape(id_coeff, ex_coeff)
        # face_shape_noexp = self.compute_shape(id_coeff, torch.zeros_like(ex_coeff))
        # print(face_shape.size())
        rotation = self.compute_rotation(angles)
        # print('rotation')

        face_shape_rotated = self.rotate(face_shape, rotation)
        face_shape_transformed = face_shape_rotated + translation.unsqueeze(1)
        face_vertex = self.to_camera(face_shape_transformed)
        face_proj = self.to_image(face_vertex)

        # face_shape_transformed_noexp = self.transform(face_shape_noexp, rotation, translation, scale_xyz)
        # face_vertex_noexp = self.to_camera(face_shape_transformed_noexp)

        landmark68 = self.get_landmarks68(face_proj)
        # landmark_face = face_proj[:,self.front_mask[::32], :]
        landmark68[:, :, 1] = self.image_size - 1 - landmark68[:, :, 1]

        face_texture = self.compute_texture(tex_coeff)
        face_norm_roted = self.compute_norm(face_shape_rotated)
        # face_norm_roted = face_norm @ rotation
        face_color = self.compute_color(face_texture, face_norm_roted, gamma)

        # face_norm_noexp = self.compute_norm(face_shape_noexp)
        # face_norm_noexp_roted = face_norm_noexp @ rotation
        # face_color_noexp = self.compute_color(face_texture, face_norm_noexp_roted, gamma)

        return face_shape, face_vertex, face_color, face_texture, landmark68

    def get_lm68(self, coeffs):
        id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeffs)
        ex_coeff = torch.zeros_like(ex_coeff)
        # id_relative_scale = id_relative_scale.clamp(0.9,1.1)
        face_shape = self.compute_shape(id_coeff, ex_coeff)
        # face_shape_noexp = self.compute_shape(id_coeff, torch.zeros_like(ex_coeff))
        # print(face_shape.size())
        rotation = self.compute_rotation(angles)
        # print('rotation')

        face_shape_rotated = self.rotate(face_shape, rotation)
        face_shape_transformed = face_shape_rotated + translation.unsqueeze(1)
        face_vertex = self.to_camera(face_shape_transformed)
        face_proj = self.to_image(face_vertex)

        landmark68 = self.get_landmarks68(face_proj)
        # landmark_face = face_proj[:,self.front_mask[::32], :]
        landmark68[:, :, 1] = self.image_size - 1 - landmark68[:, :, 1]
        return landmark68, ex_coeff

    def get_coeffs(self, coeffs):
        id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeffs)
        return id_coeff, ex_coeff, tex_coeff, gamma, angles, translation

    def get_vertex(self, coeffs):
        id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeffs)
        # id_relative_scale = id_relative_scale.clamp(0.9,1.1)
        face_shape = self.compute_shape(id_coeff, ex_coeff)
        # face_shape_noexp = self.compute_shape(id_coeff, torch.zeros_like(ex_coeff))
        # print(face_shape.size())
        rotation = self.compute_rotation(angles)
        # print('rotation')

        face_shape_rotated = self.rotate(face_shape, rotation)
        face_shape_transformed = face_shape_rotated + translation.unsqueeze(1)
        face_vertex = self.to_camera(face_shape_transformed)
        face_proj = self.to_image(face_vertex)

        return face_proj

    def forward(self, coeffs):
        face_shape, face_vertex, face_color, face_texture, landmark68 = self.compute_for_render(coeffs)
        return face_shape, face_vertex, face_color, face_texture, landmark68

    def save_obj(self, coeff, obj_name):
        # The image size is 224 * 224
        # face reconstruction with coeff and BFM model
        id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeff)

        # compute face shape
        face_shape = self.compute_shape(id_coeff, ex_coeff).cpu().detach().numpy()[0]
        face_tri = self.face_buf.cpu().numpy()

        with open(obj_name, 'w') as fobj:
            for i in range(face_shape.shape[0]):
                fobj.write(
                    'v ' + str(face_shape[i][0]) + ' ' + str(face_shape[i][1]) + ' ' + str(face_shape[i][2]) + '\n')

            # start from 1
            for i in range(face_tri.shape[0]):
                fobj.write('f ' + str(face_tri[i][0] + 1) + ' ' + str(face_tri[i][1] + 1) + ' ' + str(
                    face_tri[i][2] + 1) + '\n')

        # lm7 = face_shape[[2215, 5828, 10455, 14066, 8204, 5522, 10795], :]
        # with open(obj_name[:-3]+'txt', 'w') as f:
        #     for point in lm7:
        #         f.write('{} {} {}\n'.format(point[0], point[1], point[2]))

    def save_neutral_obj(self, coeff, obj_name):
        # The image size is 224 * 224
        # face reconstruction with coeff and BFM model
        id_coeff, ex_coeff, tex_coeff, gamma, angles, translation = self.split_coeff(coeff)

        # compute face shape
        face_shape = self.compute_shape(id_coeff, ex_coeff*0).cpu().numpy()[0]
        face_tri = self.face_buf.cpu().numpy()

        with open(obj_name, 'w') as fobj:
            for i in range(face_shape.shape[0]):
                fobj.write(
                    'v ' + str(face_shape[i][0]) + ' ' + str(face_shape[i][1]) + ' ' + str(face_shape[i][2]) + '\n')

            # start from 1
            for i in range(face_tri.shape[0]):
                fobj.write('f ' + str(face_tri[i][0] + 1) + ' ' + str(face_tri[i][1] + 1) + ' ' + str(
                    face_tri[i][2] + 1) + '\n')

        # lm7 = face_shape[[2215, 5828, 10455, 14066, 8204, 5522, 10795], :]
        # with open(obj_name[:-3]+'txt', 'w') as f:
        #     for point in lm7:
        #         f.write('{} {} {}\n'.format(point[0], point[1], point[2]))

    # def clip(self, g_ratio=0.1, t_ratio=0.1):
    #     self.idBase.data = torch.minimum(torch.maximum(self.idBase_org * (1 - g_ratio), self.idBase.data), self.idBase_org * (1 + g_ratio))
    #     self.exBase.data = self.exBase_org  # torch.minimum(torch.maximum(self.exBase_org * (1 - 0.001), self.exBase.data), self.exBase_org * (1 + 0.001))
    #     self.texBase.data = torch.minimum(torch.maximum(self.texBase_org * (1 - t_ratio), self.texBase.data), self.texBase_org * (1 + t_ratio))
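bfm.py appears to follow the Deep3DFaceRecon-style BFM wrapper: a coefficient vector is split into identity/expression/texture/pose/lighting parts, turned into a mesh, shaded with second-order spherical harmonics, and projected to 224x224 image coordinates. A minimal usage sketch; it assumes the utils package is importable from the repository root and points bfm_model_path at the BFM_model_front.mat added in this commit:

import torch
from utils.third_party.d3dfr.bfm import BFM

# Path assumption: the .mat shipped under utils/checkpoints/third_party/ in this commit.
bfm = BFM(bfm_model_path="utils/checkpoints/third_party/BFM_model_front.mat")

# split_coeff reads indices 0..256 (80 id + 64 expression + 80 texture + 3 angles
# + 27 SH lighting + 3 translation); an all-zero vector yields the mean, front-facing face.
coeffs = torch.zeros(1, 258)
face_shape, face_vertex, face_color, face_texture, landmark68 = bfm(coeffs)
print(face_shape.shape, face_color.shape, landmark68.shape)

# bfm.save_obj(coeffs, "mean_face.obj")  # optionally write the mesh as a Wavefront OBJ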
utils/third_party/d3dfr_res50_nofc.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:62e7e6a6bc4e16fb567643182ccabf55f8222746269cce3392109a6c592babc1
size 288887131
utils/third_party/insightface_backbone_conv.py
ADDED
@@ -0,0 +1,237 @@
import os
import torch
from torch import nn

__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200', 'getarcface']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x, return_id512=False):

        bz = x.shape[0]
        # with torch.cuda.amp.autocast(self.fp16):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        if not return_id512:
            return x.view(bz, 512, -1).permute(0, 2, 1).contiguous()
        else:
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            # x = self.dropout(x)
            # x = self.fc(x.float() if self.fp16 else x)
            x = self.fc(x)
            x = self.features(x)
            return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)


def getarcface(pretrained=None):
    model = iresnet100().eval()
    for param in model.parameters():
        param.requires_grad = False

    if pretrained is not None and os.path.exists(pretrained):
        info = model.load_state_dict(torch.load(pretrained))
        print(info)
    return model


if __name__ == '__main__':
    ckpt = 'pretrained/insightface_glint360k.pth'
    arcface = iresnet100().eval()
    info = arcface.load_state_dict(torch.load(ckpt))
    print(info)

    id = arcface(torch.randn(1, 3, 128, 128))
    print(id.shape)

    # import cv2
    # import numpy as np
    # im1_crop256 = cv2.imread('happy.jpg')
    # im2_crop256 = cv2.imread('angry.jpg')

    # im1_crop112 = cv2.resize(im1_crop256, (128,128))[0:112,8:120,:]
    # im2_crop112 = cv2.resize(im2_crop256, (128,128))[0:112,8:120,:]

    # cv2.imwrite('1_112.jpg', im1_crop112)
    # cv2.imwrite('2_112.jpg', im2_crop112)

    # # [-1,1] rgb
    # im1_crop112_tensor = torch.from_numpy(im1_crop112[:,:,[2,1,0]].transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)/127.5-1
    # im2_crop112_tensor = torch.from_numpy(im2_crop112[:,:,[2,1,0]].transpose(2, 0, 1).astype(np.float32)).unsqueeze(0)/127.5-1

    # im1_id = arcface(im1_crop112_tensor)
    # im2_id = arcface(im2_crop112_tensor)

    # loss_cos = torch.mean(1-torch.cosine_similarity(im1_id, im2_id, dim=1))

    # print(loss_cos)
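insightface_backbone_conv.py is an ArcFace-style IResNet backbone; getarcface builds a frozen iresnet100 and, if given a path that exists, loads the insightface_glint360k.pth weights also added in this commit. Depending on return_id512 the forward pass returns either a grid of 512-d local tokens or a single 512-d identity embedding. A usage sketch; the 112x112 crop size is what makes the fc head's 7x7 assumption line up, and the [-1, 1] RGB preprocessing follows the commented-out demo in the file:

import torch
from utils.third_party.insightface_backbone_conv import getarcface

# Weight path from this commit; if it is missing, getarcface silently keeps random weights.
arcface = getarcface(pretrained="utils/third_party_files/insightface_glint360k.pth")

face = torch.randn(1, 3, 112, 112)            # stand-in for a [-1, 1] normalized RGB face crop
tokens = arcface(face)                         # (1, 49, 512) spatial identity tokens
embedding = arcface(face, return_id512=True)   # (1, 512) global identity embedding
print(tokens.shape, embedding.shape)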
utils/third_party/model_parsing.py
ADDED
@@ -0,0 +1,381 @@
#!/usr/bin/python
# -*- encoding: utf-8 -*-

import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
import torch.utils.model_zoo as modelzoo

resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'


def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=1, bias=False)


class BasicBlock(nn.Module):
    def __init__(self, in_chan, out_chan, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = conv3x3(in_chan, out_chan, stride)
        self.bn1 = nn.BatchNorm2d(out_chan)
        self.conv2 = conv3x3(out_chan, out_chan)
        self.bn2 = nn.BatchNorm2d(out_chan)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = None
        if in_chan != out_chan or stride != 1:
            self.downsample = nn.Sequential(
                nn.Conv2d(in_chan, out_chan,
                          kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_chan),
            )

    def forward(self, x):
        residual = self.conv1(x)
        residual = F.relu(self.bn1(residual))
        residual = self.conv2(residual)
        residual = self.bn2(residual)

        shortcut = x
        if self.downsample is not None:
            shortcut = self.downsample(x)

        out = shortcut + residual
        out = self.relu(out)
        return out


def create_layer_basic(in_chan, out_chan, bnum, stride=1):
    layers = [BasicBlock(in_chan, out_chan, stride=stride)]
    for i in range(bnum-1):
        layers.append(BasicBlock(out_chan, out_chan, stride=1))
    return nn.Sequential(*layers)


class Resnet18(nn.Module):
    def __init__(self):
        super(Resnet18, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
        self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
        self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
        self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
        self.init_weight()

    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(self.bn1(x))
        x = self.maxpool(x)

        x = self.layer1(x)
        feat8 = self.layer2(x)  # 1/8
        feat16 = self.layer3(feat8)  # 1/16
        feat32 = self.layer4(feat16)  # 1/32
        return feat8, feat16, feat32

    def init_weight(self):
        if os.path.isfile('./checkpoints/third_party/resnet18-5c106cde.pth'):
            state_dict = torch.load('./checkpoints/third_party/resnet18-5c106cde.pth')
        else:
            state_dict = modelzoo.load_url(resnet18_url)
        self_state_dict = self.state_dict()
        for k, v in state_dict.items():
            if 'fc' in k: continue
            self_state_dict.update({k: v})
        self.load_state_dict(self_state_dict)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class ConvBNReLU(nn.Module):
    def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
        super(ConvBNReLU, self).__init__()
        self.conv = nn.Conv2d(in_chan,
                              out_chan,
                              kernel_size=ks,
                              stride=stride,
                              padding=padding,
                              bias=False)
        self.bn = nn.BatchNorm2d(out_chan)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = F.relu(self.bn(x))
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class BiSeNetOutput(nn.Module):
    def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
        super(BiSeNetOutput, self).__init__()
        self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
        self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
        self.init_weight()

    def forward(self, x):
        x = self.conv(x)
        x = self.conv_out(x)
        return x

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class AttentionRefinementModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(AttentionRefinementModule, self).__init__()
        self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
        self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size=1, bias=False)
        self.bn_atten = nn.BatchNorm2d(out_chan)
        self.sigmoid_atten = nn.Sigmoid()
        self.init_weight()

    def forward(self, x):
        feat = self.conv(x)
        atten = F.avg_pool2d(feat, feat.size()[2:])
        atten = self.conv_atten(atten)
        atten = self.bn_atten(atten)
        atten = self.sigmoid_atten(atten)
        out = torch.mul(feat, atten)
        return out

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)


class ContextPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(ContextPath, self).__init__()
        self.resnet = Resnet18()
        self.arm16 = AttentionRefinementModule(256, 128)
        self.arm32 = AttentionRefinementModule(512, 128)
        self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
        self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)

        self.init_weight()

    def forward(self, x):
        H0, W0 = x.size()[2:]
        feat8, feat16, feat32 = self.resnet(x)
        H8, W8 = feat8.size()[2:]
        H16, W16 = feat16.size()[2:]
        H32, W32 = feat32.size()[2:]

        avg = F.avg_pool2d(feat32, feat32.size()[2:])
        avg = self.conv_avg(avg)
        avg_up = F.interpolate(avg, (H32, W32), mode='nearest')

        feat32_arm = self.arm32(feat32)
        feat32_sum = feat32_arm + avg_up
        feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
        feat32_up = self.conv_head32(feat32_up)

        feat16_arm = self.arm16(feat16)
        feat16_sum = feat16_arm + feat32_up
        feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
        feat16_up = self.conv_head16(feat16_up)

        return feat8, feat16_up, feat32_up  # x8, x8, x16

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, (nn.Linear, nn.Conv2d)):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


### This is not used, since I replace this with the resnet feature with the same size
class SpatialPath(nn.Module):
    def __init__(self, *args, **kwargs):
        super(SpatialPath, self).__init__()
        self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
        self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
        self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
        self.init_weight()

    def forward(self, x):
        feat = self.conv1(x)
        feat = self.conv2(feat)
        feat = self.conv3(feat)
        feat = self.conv_out(feat)
        return feat

    def init_weight(self):
        for ly in self.children():
            if isinstance(ly, nn.Conv2d):
                nn.init.kaiming_normal_(ly.weight, a=1)
                if not ly.bias is None: nn.init.constant_(ly.bias, 0)

    def get_params(self):
        wd_params, nowd_params = [], []
        for name, module in self.named_modules():
            if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
                wd_params.append(module.weight)
                if not module.bias is None:
                    nowd_params.append(module.bias)
            elif isinstance(module, nn.BatchNorm2d):
                nowd_params += list(module.parameters())
        return wd_params, nowd_params


class FeatureFusionModule(nn.Module):
    def __init__(self, in_chan, out_chan, *args, **kwargs):
        super(FeatureFusionModule, self).__init__()
        self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
        self.conv1 = nn.Conv2d(out_chan,
                               out_chan//4,
                               kernel_size=1,
                               stride=1,
                               padding=0,
                               bias=False)
        self.conv2 = nn.Conv2d(out_chan//4,
                               out_chan,
                               kernel_size=1,
                               stride=1,
| 286 |
+
padding = 0,
|
| 287 |
+
bias = False)
|
| 288 |
+
self.relu = nn.ReLU(inplace=True)
|
| 289 |
+
self.sigmoid = nn.Sigmoid()
|
| 290 |
+
self.init_weight()
|
| 291 |
+
|
| 292 |
+
def forward(self, fsp, fcp):
|
| 293 |
+
fcat = torch.cat([fsp, fcp], dim=1)
|
| 294 |
+
feat = self.convblk(fcat)
|
| 295 |
+
atten = F.avg_pool2d(feat, feat.size()[2:])
|
| 296 |
+
atten = self.conv1(atten)
|
| 297 |
+
atten = self.relu(atten)
|
| 298 |
+
atten = self.conv2(atten)
|
| 299 |
+
atten = self.sigmoid(atten)
|
| 300 |
+
feat_atten = torch.mul(feat, atten)
|
| 301 |
+
feat_out = feat_atten + feat
|
| 302 |
+
return feat_out
|
| 303 |
+
|
| 304 |
+
def init_weight(self):
|
| 305 |
+
for ly in self.children():
|
| 306 |
+
if isinstance(ly, nn.Conv2d):
|
| 307 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 308 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 309 |
+
|
| 310 |
+
def get_params(self):
|
| 311 |
+
wd_params, nowd_params = [], []
|
| 312 |
+
for name, module in self.named_modules():
|
| 313 |
+
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
| 314 |
+
wd_params.append(module.weight)
|
| 315 |
+
if not module.bias is None:
|
| 316 |
+
nowd_params.append(module.bias)
|
| 317 |
+
elif isinstance(module, nn.BatchNorm2d):
|
| 318 |
+
nowd_params += list(module.parameters())
|
| 319 |
+
return wd_params, nowd_params
|
| 320 |
+
|
| 321 |
+
|
| 322 |
+
class BiSeNet(nn.Module):
|
| 323 |
+
def __init__(self, n_classes, *args, **kwargs):
|
| 324 |
+
super(BiSeNet, self).__init__()
|
| 325 |
+
self.cp = ContextPath()
|
| 326 |
+
## here self.sp is deleted
|
| 327 |
+
self.ffm = FeatureFusionModule(256, 256)
|
| 328 |
+
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
| 329 |
+
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
| 330 |
+
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
| 331 |
+
self.init_weight()
|
| 332 |
+
|
| 333 |
+
def forward(self, x):
|
| 334 |
+
H, W = x.size()[2:]
|
| 335 |
+
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
| 336 |
+
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
| 337 |
+
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
| 338 |
+
|
| 339 |
+
feat_out = self.conv_out(feat_fuse)
|
| 340 |
+
feat_out16 = self.conv_out16(feat_cp8)
|
| 341 |
+
feat_out32 = self.conv_out32(feat_cp16)
|
| 342 |
+
|
| 343 |
+
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
| 344 |
+
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
| 345 |
+
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
| 346 |
+
return feat_out, feat_out16, feat_out32
|
| 347 |
+
|
| 348 |
+
def init_weight(self):
|
| 349 |
+
for ly in self.children():
|
| 350 |
+
if isinstance(ly, nn.Conv2d):
|
| 351 |
+
nn.init.kaiming_normal_(ly.weight, a=1)
|
| 352 |
+
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
| 353 |
+
|
| 354 |
+
def get_params(self):
|
| 355 |
+
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
| 356 |
+
for name, child in self.named_children():
|
| 357 |
+
child_wd_params, child_nowd_params = child.get_params()
|
| 358 |
+
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
| 359 |
+
lr_mul_wd_params += child_wd_params
|
| 360 |
+
lr_mul_nowd_params += child_nowd_params
|
| 361 |
+
else:
|
| 362 |
+
wd_params += child_wd_params
|
| 363 |
+
nowd_params += child_nowd_params
|
| 364 |
+
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def get_face_parsing(save_pth = 'third_party/pretrained/79999_iter.pth'):
|
| 368 |
+
net = BiSeNet(n_classes=19)
|
| 369 |
+
net.load_state_dict(torch.load(save_pth))
|
| 370 |
+
return net
|
| 371 |
+
|
| 372 |
+
|
| 373 |
+
if __name__ == "__main__":
|
| 374 |
+
net = BiSeNet(19)
|
| 375 |
+
net.cuda()
|
| 376 |
+
net.eval()
|
| 377 |
+
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
| 378 |
+
out, out16, out32 = net(in_ten)
|
| 379 |
+
print(out.shape)
|
| 380 |
+
|
| 381 |
+
net.get_params()
|
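For reference, a minimal inference sketch for the face-parsing network defined above. The 512x512 crop size, the random placeholder input, and the CUDA device are assumptions for illustration only; the checkpoint path is simply the default of get_face_parsing, and real inputs should be normalized face crops.

    import torch

    # hypothetical usage sketch; assumes a CUDA device and the default checkpoint path above
    net = get_face_parsing('third_party/pretrained/79999_iter.pth').cuda().eval()
    img = torch.rand(1, 3, 512, 512).cuda()     # placeholder for a normalized face crop
    with torch.no_grad():
        out, out16, out32 = net(img)            # main head plus two auxiliary heads, all at input resolution
        parsing = out.argmax(dim=1)             # per-pixel label index over the 19 classes
    print(parsing.shape)                        # torch.Size([1, 512, 512])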
utils/third_party/model_resnet_d3dfr.py
ADDED
@@ -0,0 +1,554 @@
import os
import torch.nn as nn
from torch.nn import Linear
from torch.nn import Conv2d
from torch.nn import BatchNorm1d
from torch.nn import BatchNorm2d
from torch.nn import ReLU
from torch.nn import Dropout
try:
    from torch.hub import load_state_dict_from_url
except ImportError:
    from torch.utils.model_zoo import load_url as load_state_dict_from_url
from torch.nn import MaxPool2d
from torch.nn import Sequential
from torch.nn import Module
import torch
from torch import Tensor
from typing import Type, Any, Callable, Union, List, Optional


model_urls = {
    'resnet50': 'https://download.pytorch.org/models/resnet50-0676ba61.pth',
}

def filter_state_dict(state_dict, remove_name='fc'):
    new_state_dict = {}
    for key in state_dict:
        if remove_name in key:
            continue
        new_state_dict[key] = state_dict[key]
    return new_state_dict

def conv3x3(in_planes, out_planes, stride=1):
    """3x3 convolution with padding"""
    return Conv2d(in_planes,
                  out_planes,
                  kernel_size=3,
                  stride=stride,
                  padding=1,
                  bias=False)


def conv1x1(in_planes, out_planes, stride=1, bias=False):
    """1x1 convolution"""
    return Conv2d(in_planes,
                  out_planes,
                  kernel_size=1,
                  stride=stride,
                  bias=bias)

def conv3x3_(in_planes: int, out_planes: int, stride: int = 1, groups: int = 1, dilation: int = 1) -> nn.Conv2d:
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
                     padding=dilation, groups=groups, bias=False, dilation=dilation)


def conv1x1_(in_planes: int, out_planes: int, stride: int = 1, bias: bool = False) -> nn.Conv2d:
    """1x1 convolution"""
    return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=bias)


class Bottleneck(Module):
    expansion = 4

    def __init__(self, inplanes, planes, stride=1, downsample=None):
        super(Bottleneck, self).__init__()
        self.conv1 = conv1x1(inplanes, planes)
        self.bn1 = BatchNorm2d(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn2 = BatchNorm2d(planes)
        self.conv3 = conv1x1(planes, planes * self.expansion)
        self.bn3 = BatchNorm2d(planes * self.expansion)
        self.relu = ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class Bottleneck_(nn.Module):
    # Bottleneck in torchvision places the stride for downsampling at the 3x3 convolution (self.conv2),
    # while the original implementation places the stride at the first 1x1 convolution (self.conv1),
    # according to "Deep Residual Learning for Image Recognition" https://arxiv.org/abs/1512.03385.
    # This variant is also known as ResNet V1.5 and improves accuracy according to
    # https://ngc.nvidia.com/catalog/model-scripts/nvidia:resnet_50_v1_5_for_pytorch.

    expansion: int = 4

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(Bottleneck_, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        width = int(planes * (base_width / 64.)) * groups
        # Both self.conv2 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv1x1_(inplanes, width)
        self.bn1 = norm_layer(width)
        self.conv2 = conv3x3_(width, width, stride, groups, dilation)
        self.bn2 = norm_layer(width)
        self.conv3 = conv1x1_(width, planes * self.expansion)
        self.bn3 = norm_layer(planes * self.expansion)
        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)
        out = self.relu(out)

        out = self.conv3(out)
        out = self.bn3(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class BasicBlock(nn.Module):
    expansion: int = 1

    def __init__(
        self,
        inplanes: int,
        planes: int,
        stride: int = 1,
        downsample: Optional[nn.Module] = None,
        groups: int = 1,
        base_width: int = 64,
        dilation: int = 1,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(BasicBlock, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        # Both self.conv1 and self.downsample layers downsample the input when stride != 1
        self.conv1 = conv3x3(inplanes, planes, stride)
        self.bn1 = norm_layer(planes)
        self.relu = nn.ReLU(inplace=True)
        self.conv2 = conv3x3(planes, planes)
        self.bn2 = norm_layer(planes)
        self.downsample = downsample
        self.stride = stride

    def forward(self, x: Tensor) -> Tensor:
        identity = x

        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu(out)

        out = self.conv2(out)
        out = self.bn2(out)

        if self.downsample is not None:
            identity = self.downsample(x)

        out += identity
        out = self.relu(out)

        return out

class ResNet(Module):
    """ResNet backbone"""
    def __init__(self, input_size, block, layers, zero_init_residual=True):
        """Args:
            input_size: input_size of backbone
            block: block function
            layers: layers in each block
        """
        super(ResNet, self).__init__()
        assert input_size[0] in [112, 224], \
            "input_size should be [112, 112] or [224, 224]"
        self.inplanes = 64
        self.conv1 = Conv2d(3, 64,
                            kernel_size=7,
                            stride=2,
                            padding=3,
                            bias=False)
        self.bn1 = BatchNorm2d(64)
        self.relu = ReLU(inplace=True)
        self.maxpool = MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block, 64, layers[0])
        self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
        self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
        self.layer4 = self._make_layer(block, 512, layers[3], stride=2)

        self.bn_o1 = BatchNorm2d(2048)
        self.dropout = Dropout()
        if input_size[0] == 112:
            self.fc = Linear(2048 * 4 * 4, 512)
        else:
            self.fc = Linear(2048 * 7 * 7, 512)
        self.bn_o2 = BatchNorm1d(512)

        # initialize_weights(self.modules)
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck):
                    nn.init.constant_(m.bn3.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1):
        downsample = None
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                BatchNorm2d(planes * block.expansion),
            )

        layers = []
        layers.append(block(self.inplanes, planes, stride, downsample))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(block(self.inplanes, planes))

        return Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.bn_o1(x)
        x = self.dropout(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        x = self.bn_o2(x)

        return x


class resNet(nn.Module):  # original torchvision-style ResNet

    def __init__(
        self,
        block_: Type[Union[BasicBlock, Bottleneck_]],
        layers: List[int],
        num_classes: int = 1000,
        zero_init_residual: bool = False,
        use_last_fc: bool = False,
        groups: int = 1,
        width_per_group: int = 64,
        replace_stride_with_dilation: Optional[List[bool]] = None,
        norm_layer: Optional[Callable[..., nn.Module]] = None
    ) -> None:
        super(resNet, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        self._norm_layer = norm_layer

        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            # each element in the tuple indicates if we should replace
            # the 2x2 stride with a dilated convolution instead
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.use_last_fc = use_last_fc
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=7, stride=2, padding=3,
                               bias=False)
        self.bn1 = norm_layer(self.inplanes)
        self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
        self.layer1 = self._make_layer(block_, 64, layers[0])
        self.layer2 = self._make_layer(block_, 128, layers[1], stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block_, 256, layers[2], stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block_, 512, layers[3], stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

        if self.use_last_fc:
            self.fc = nn.Linear(512 * block_.expansion, num_classes)

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        # Zero-initialize the last BN in each residual branch,
        # so that the residual branch starts with zeros, and each residual block behaves like an identity.
        # This improves the model by 0.2~0.3% according to https://arxiv.org/abs/1706.02677
        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, Bottleneck_):
                    nn.init.constant_(m.bn3.weight, 0)  # type: ignore[arg-type]
                elif isinstance(m, BasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)  # type: ignore[arg-type]

    def _make_layer(self, block_: Type[Union[BasicBlock, Bottleneck_]], planes: int, blocks: int,
                    stride: int = 1, dilate: bool = False) -> nn.Sequential:
        norm_layer = self._norm_layer
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block_.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block_.expansion, stride),
                norm_layer(planes * block_.expansion),
            )

        layers = []
        layers.append(block_(self.inplanes, planes, stride, downsample, self.groups,
                             self.base_width, previous_dilation, norm_layer))
        self.inplanes = planes * block_.expansion
        for _ in range(1, blocks):
            layers.append(block_(self.inplanes, planes, groups=self.groups,
                                 base_width=self.base_width, dilation=self.dilation,
                                 norm_layer=norm_layer))

        return nn.Sequential(*layers)

    def _forward_impl(self, x: Tensor) -> Tensor:
        # See note [TorchScript super()]
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.maxpool(x)

        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)

        x = self.avgpool(x)
        if self.use_last_fc:
            x = torch.flatten(x, 1)
            x = self.fc(x)
        return x

    def forward(self, x: Tensor) -> Tensor:
        return self._forward_impl(x)

def ResNet_50(input_size, **kwargs):
    """Constructs a ResNet-50 model."""
    model = ResNet(input_size, Bottleneck, [3, 4, 6, 3], **kwargs)

    return model


class ResNet50_nofc(Module):
    """ResNet-50 backbone without the final classification fc layer."""
    def __init__(self, input_size, output_dim, use_last_fc=False, init_path=None):
        """Args:
            input_size: input size of the backbone
            output_dim: total number of regressed coefficients
            init_path: optional path to an ImageNet-pretrained checkpoint
        """
        super(ResNet50_nofc, self).__init__()
        assert input_size[0] in [112, 224, 256], \
            "input_size should be [112, 112], [224, 224] or [256, 256]"
        func, last_dim = func_dict['resnet50']
        self.use_last_fc = use_last_fc
        backbone = func(use_last_fc=use_last_fc, num_classes=output_dim)
        if init_path and os.path.isfile(init_path):
            state_dict = filter_state_dict(torch.load(init_path, map_location='cpu'))
            backbone.load_state_dict(state_dict)
            print("Loading init recon %s from %s" % ('resnet50', init_path))
        self.backbone = backbone
        if not use_last_fc:
            # attribute name kept as in the released checkpoint ("fianl_layers")
            self.fianl_layers = nn.ModuleList([
                conv1x1(last_dim, 80, bias=True),   # id
                conv1x1(last_dim, 64, bias=True),   # exp
                conv1x1(last_dim, 80, bias=True),   # tex
                conv1x1(last_dim, 3, bias=True),    # angle
                conv1x1(last_dim, 27, bias=True),   # gamma
                conv1x1(last_dim, 2, bias=True),    # tx, ty
                conv1x1(last_dim, 1, bias=True),    # tz
                conv1x1(last_dim, 4, bias=True)     # pupil
            ])
            for m in self.fianl_layers:
                nn.init.constant_(m.weight, 0.)
                nn.init.constant_(m.bias, 0.)

    def forward(self, x):
        x = self.backbone(x)
        if not self.use_last_fc:
            output = []
            for layer in self.fianl_layers:
                output.append(layer(x))
            x = torch.flatten(torch.cat(output, dim=1), 1)
        return x


def _resnet(
    arch: str,
    block: Type[Union[BasicBlock, Bottleneck_]],
    layers: List[int],
    pretrained: bool,
    progress: bool,
    **kwargs: Any
) -> resNet:
    model = resNet(block, layers, **kwargs)
    if pretrained:
        state_dict = load_state_dict_from_url(model_urls[arch],
                                              progress=progress)
        model.load_state_dict(state_dict)
    return model

def resnet50(pretrained: bool = False, progress: bool = True, **kwargs: Any) -> resNet:
    r"""ResNet-50 model from
    `"Deep Residual Learning for Image Recognition" <https://arxiv.org/pdf/1512.03385.pdf>`_.

    Args:
        pretrained (bool): If True, returns a model pre-trained on ImageNet
        progress (bool): If True, displays a progress bar of the download to stderr
    """
    return _resnet('resnet50', Bottleneck_, [3, 4, 6, 3], pretrained, progress,
                   **kwargs)


func_dict = {
    'resnet50': (resnet50, 2048),
}


class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x


def fuse(conv, bn):
    # fold a BatchNorm2d layer into the preceding Conv2d for inference
    w = conv.weight
    mean = bn.running_mean
    var_sqrt = torch.sqrt(bn.running_var + bn.eps)

    beta = bn.weight
    gamma = bn.bias

    if conv.bias is not None:
        b = conv.bias
    else:
        b = mean.new_zeros(mean.shape)

    w = w * (beta / var_sqrt).reshape([conv.out_channels, 1, 1, 1])
    b = (b - mean) / var_sqrt * beta + gamma

    fused_conv = nn.Conv2d(
        conv.in_channels,
        conv.out_channels,
        conv.kernel_size,
        conv.stride,
        conv.padding,
        conv.dilation,
        conv.groups,
        bias=True,
        padding_mode=conv.padding_mode
    )
    fused_conv.weight = nn.Parameter(w)
    fused_conv.bias = nn.Parameter(b)
    return fused_conv


def fuse_module(m):
    # recursively replace every Conv2d + BatchNorm2d pair with a single fused Conv2d
    children = list(m.named_children())
    conv = None
    conv_name = None
    for name, child in children:
        if isinstance(child, nn.BatchNorm2d) and conv:
            bc = fuse(conv, child)
            m._modules[conv_name] = bc
            m._modules[name] = Identity()
            conv = None
        elif isinstance(child, nn.Conv2d):
            conv = child
            conv_name = name
        else:
            fuse_module(child)


def getd3dfr_res50(pretrained="./d3dfr_res50_nofc.pth"):
    model = ResNet50_nofc([256, 256], 257 + 4, use_last_fc=False)
    for param in model.parameters():
        param.requires_grad = False
    if pretrained is not None and os.path.exists(pretrained):
        checkpoint_no_module = {}
        checkpoint = torch.load(pretrained, map_location=lambda storage, loc: storage)
        for k, v in checkpoint.items():
            if k.startswith('module'):
                k = k[7:]
            checkpoint_no_module[k] = v
        info = model.load_state_dict(checkpoint_no_module, strict=False)

        print(pretrained, info)
    model = model.eval()
    fuse_module(model)
    return model

if __name__ == '__main__':
    model = getd3dfr_res50()
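A minimal sketch of how the fused D3DFR regressor above might be used. The random 256x256 input is a placeholder for an aligned face crop, and the slicing simply follows the order of the conv1x1 heads in ResNet50_nofc (80 id, 64 exp, 80 tex, 3 angle, 27 gamma, 2 tx/ty, 1 tz, 4 pupil = 261 values); it is illustrative, not a prescribed API.

    import torch

    # hypothetical usage sketch; expects d3dfr_res50_nofc.pth in the working directory
    model = getd3dfr_res50('./d3dfr_res50_nofc.pth')   # eval-mode model with conv+bn fused
    face = torch.rand(1, 3, 256, 256)                  # placeholder for an aligned face crop
    with torch.no_grad():
        coeffs = model(face)                           # shape (1, 261)
    id_c  = coeffs[:, 0:80]
    exp_c = coeffs[:, 80:144]
    tex_c = coeffs[:, 144:224]
    angle = coeffs[:, 224:227]
    gamma = coeffs[:, 227:254]
    trans = coeffs[:, 254:257]                         # tx, ty, tz
    pupil = coeffs[:, 257:261]
    print(coeffs.shape)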
utils/third_party_files/79999_iter.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:468e13ca13a9b43cc0881a9f99083a430e9c0a38abd935431d1c28ee94b26567
size 53289463
utils/third_party_files/BFM_model_front.mat
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:a9f127eb18c3d022acbdbfcf1b6b353d01a6e01785d675a67cc31a3826a5be0f
size 127170280
utils/third_party_files/d3dfr_res50_nofc.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:52c54b90304a06c16b6813910c26faff1a907d4f8bd69a71ad4ecff43b41a090
size 96449126
utils/third_party_files/insightface_glint360k.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f631718e783448b41631e15073bdc622eaeef56509bbad4e5085f23bd32db83
size 261223796
utils/third_party_files/models/antelopev2/1k3d68.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:df5c06b8a0c12e422b2ed8947b8869faa4105387f199c477af038aa01f9a45cc
size 143607619
utils/third_party_files/models/antelopev2/2d106det.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f001b856447c413801ef5c42091ed0cd516fcd21f2d6b79635b1e733a7109dbf
size 5030888
utils/third_party_files/models/antelopev2/antelopev2.zip
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:7353a5fdca5a90e11d2792e0236032b2fe42adc1ea23eaef5cf8c8b57e7e9393
size 360743453
utils/third_party_files/models/antelopev2/genderage.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4fde69b1c810857b88c64a335084f1c3fe8f01246c9a191b48c7bb756d6652fb
size 1322532
utils/third_party_files/models/antelopev2/glintr100.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:4ab1d6435d639628a6f3e5008dd4f929edf4c4124b1a7169e1048f9fef534cdf
size 260665334
utils/third_party_files/models/antelopev2/scrfd_10g_bnkps.onnx
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5838f7fe053675b1c7a08b633df49e7af5495cee0493c7dcf6697200b85b5b91
size 16923827
utils/third_party_files/resnet18-5c106cde.pth
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5c106cde386e87d4033832f2996f5493238eda96ccf559d1d62760c4de0613f8
size 46827520