Commit f5b1333

update for direct pose manipulation
1 parent 2aefed5 commit f5b1333

12 files changed: +958 -59 lines changed

LICENSE

-21
This file was deleted.

README.md

+38 -19
@@ -4,9 +4,7 @@ This is an unofficial implementation of the paper ["Surrogate Gradient Field for
 
 ![sgf_result](./docs/sgf_result.jpg)
 
-### (Jul. 16, 2021) Current issues in the result (TODO, working on)
-- the ID of face changes
-- how to fix? Will add more supervision (binary attributes)
+The author leveraged diverse labels (e.g., age, gender, smile, ...) using [MS Face API](https://azure.microsoft.com/en-us/services/cognitive-services/face/). In the experiment, I only used pose values in a soft manner (0.0 ~ 1.0) for my own research. Empirically, the result shows a smooth transition compared to the manipulation learned by hard labels. I believe adding more labels as the authors did in their work will make the transition more robust (e.g., id or characteristics of the input image is sustained while manipulating it).
 
 
 ## Requirements
@@ -35,7 +33,7 @@ pip install numba
 ## Run SGF
 To see the manipulation result:
 ```
-python sgf.py --G_path 'path/to/generator.pkl' --SE_path 'path/to/se.pth' --AUX_path 'path/to/aux.pth' --save_result 1
+python sgf_pose.py --G_path 'path/to/generator.pkl' --SE_path 'path/to/se.pth' --AUX_path 'path/to/aux.pth' --save_result 1
 ```
 
 ---
@@ -52,7 +50,7 @@ python generate.py --outdir=data/test/images --seeds=100500,101000 --resize 256
 ```
 
 ### Step 2: Label images [`c`]
-- Label images using Azure Face API / open source Face landmark detection algorithm
+- Label images using Azure Face API / open source Face landmark detection algorithm to infer pose (yaw, roll, pitch)
 ```
 python face_align.py --indir train
 python face_align.py --indir val
@@ -64,29 +62,42 @@ python face_align.py --indir test
 python face_align.py --indir test --plot 1
 ```
 
+- Next, infer the face pose values (e.g., yaw, roll, pitch)
+```
+python pose_estimation.py --image_dir data/train/
+python pose_estimation.py --image_dir data/val/
+python pose_estimation.py --image_dir data/test/
+```
+
+- If you want to see the pose result
+```
+python pose_estimation.py --image_dir data/test/ --save_img 1
+```
 
 ### Step 3: Fine-tune Squeeze and Excitation Network using images [`x`] and labels [`c`]
 - Used is SE ResNet 50 pretrained on VGG Face2 dataset
 ```
-python finetune.py --pretrained_path 'path/to/model.pkl'
-python finetune.py --mode test --model_path 'path/to/model.pth'
+python finetune_pose.py
+python finetune_pose.py --mode test --model_path path/to/model.pth
 ```
 
+
 ### Step 4: Train Auxiliary (FC-layer) Network [`mapping: (z, c) -> z`]
 - 6 FC layers for Z space, and 15 layers for W space
 - AdaIN is used to mix features (`z` and `c`) in the same way as StyleGAN v1
 - Refer to Appendix B in the paper
 
 ```
-python fc_layer.py --ckpt_dir 'path/to/save_dir'
-python fc_layer.py --mode test --ckpt_dir 'path/to/save_dir' --ckpt_fname 'filename.pth'
+python fc_layer_pose.py --ckpt_dir 'path/to/save_dir'
+python fc_layer_pose.py --mode test --ckpt_dir 'path/to/save_dir' --ckpt_fname 'filename.pth'
 ```
 
+
 ### Step 5: Calculate gradient in the surrogate gradient field and update [`z`]
 - Refer to Algo 1 in the original paper
 - Manipulate C to suit your purpose
 ```
-python sgf.py --G_path 'path/to/generator.pkl' --SE_path 'path/to/se.pth' --AUX_path 'path/to/aux.pth' --save_result 1
+python sgf_pose.py --G_path 'path/to/generator.pkl' --SE_path 'path/to/se.pth' --AUX_path 'path/to/aux.pth' --save_result 1
 ```
 
 
@@ -96,12 +107,20 @@ Many thanks to the first author of the original paper, [Minjun Li](https://minju
 ## References
 - [Li, M., Jin, Y., & Zhu, H. (2021). Surrogate Gradient Field for Latent Space Manipulation. In Proceedings of the IEEE/CVF Conference on Computer Vision ](https://arxiv.org/abs/2104.09065)
 
-Also, the implementaion is based on many works:
-- [Face Alignment](https://arxiv.org/abs/1703.07332)
-- [Official code](https://github.com/1adrianb/face-alignment)
-- [StyleGAN2](https://arxiv.org/abs/1912.04958)
-- [Official code](https://github.com/NVlabs/stylegan2-ada-pytorch)
-- [SENet](https://arxiv.org/abs/1709.01507?spm=a2c41.13233144.0.0) & [VGG Face2 dataset](https://arxiv.org/abs/1710.08092)
-- [Official code](https://github.com/ox-vgg/vgg_face2)
-- [Pytorch code](https://github.com/cydonia999/VGGFace2-pytorch)
-- [AdaIN](https://arxiv.org/abs/1703.06868)
+
+## Credits
+
+**StyleGAN2-ADA:**
+https://github.com/NVlabs/stylegan2-ada-pytorch
+Copyright (c) 2021, NVIDIA Corporation
+NVIDIA Source Code License https://github.com/NVlabs/stylegan2-ada-pytorch/blob/main/LICENSE.txt
+
+**Face Alignment:**
+https://github.com/1adrianb/face-alignment
+Copyright (c) 2017, Adrian Bulat
+License (BSD 3-Clause) https://github.com/1adrianb/face-alignment/blob/master/LICENSE
+
+**VGG Face2 Datset & Squeeze and Excitation Network:**
+https://github.com/cydonia999/VGGFace2-pytorch
+Copyright (c) 2018 cydonia
+License (MIT) https://github.com/cydonia999/VGGFace2-pytorch/blob/master/LICENSE

README_landmark.md

+50
@@ -0,0 +1,50 @@
+# To train on landmark labels
+
+
+### Step 1: Sample image generation using StyleGAN2 [`x`]
+- Generate 100K samples images using StyleGAN2 to train SENet
+```
+python generate.py --outdir=data/train/images --seeds=0,100000 --resize 256
+python generate.py --outdir=data/val/images --seeds=100000,100500 --resize 256
+python generate.py --outdir=data/test/images --seeds=100500,101000 --resize 256
+```
+
+### Step 2: Label images [`c`]
+- Label images using Azure Face API / open source Face landmark detection algorithm
+```
+python face_align.py --indir train
+python face_align.py --indir val
+python face_align.py --indir test
+```
+
+- If you want to see the landmark result
+```
+python face_align.py --indir test --plot 1
+```
+
+
+### Step 3: Fine-tune Squeeze and Excitation Network using images [`x`] and labels [`c`]
+- Used is SE ResNet 50 pretrained on VGG Face2 dataset
+```
+python finetune.py --pretrained_path 'path/to/model.pkl'
+python finetune.py --mode test --model_path 'path/to/model.pth'
+```
+
+
+### Step 4: Train Auxiliary (FC-layer) Network [`mapping: (z, c) -> z`]
+- 6 FC layers for Z space, and 15 layers for W space
+- AdaIN is used to mix features (`z` and `c`) in the same way as StyleGAN v1
+- Refer to Appendix B in the paper
+
+```
+python fc_layer.py --ckpt_dir 'path/to/save_dir'
+python fc_layer.py --mode test --ckpt_dir 'path/to/save_dir' --ckpt_fname 'filename.pth'
+```
+
+
+### Step 5: Calculate gradient in the surrogate gradient field and update [`z`]
+- Refer to Algo 1 in the original paper
+- Manipulate C to suit your purpose
+```
+python sgf.py --G_path 'path/to/generator.pkl' --SE_path 'path/to/se.pth' --AUX_path 'path/to/aux.pth' --save_result 1
+```
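
Note on Step 4 of the landmark README: it says AdaIN mixes `z` and `c` as in StyleGAN v1. For reference, AdaIN normalizes a set of features and then applies a scale and bias derived from the conditioning input. The sketch below shows only that operation on plain Python lists. It assumes normalization over the feature dimension of a single vector, and it takes `scale` and `bias` as ready-made arguments, because this commit does not show how `fc_layer.py` derives them from `c`. All names and values are illustrative, not taken from the repo.

```python
import math


def adain(features, scale, bias, eps=1e-5):
    """Normalize features to zero mean and (near) unit variance, then scale and shift.

    Assumption: statistics are computed over the feature dimension of one vector.
    """
    n = len(features)
    mean = sum(features) / n
    var = sum((f - mean) ** 2 for f in features) / n
    std = math.sqrt(var + eps)  # eps avoids division by zero for constant input
    return [s * (f - mean) / std + b for f, s, b in zip(features, scale, bias)]


h = [1.0, 2.0, 3.0, 4.0]            # made-up hidden features computed from z
style_scale = [2.0, 2.0, 2.0, 2.0]  # placeholder: would come from c via a learned layer
style_bias = [0.5, 0.5, 0.5, 0.5]   # placeholder: would come from c via a learned layer
mixed = adain(h, style_scale, style_bias)
```

With a uniform scale and bias as above, the output keeps the shape of `h` but its mean moves to the bias and its spread to roughly the scale, which is how the condition `c` would steer the features.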

datasets/sg2.py

+4 -4
@@ -6,7 +6,7 @@
 
 class StyleGAN2_Data(datasets.ImageFolder):
 
-    def __init__(self, root='data', split='train', latent_dim=512):
+    def __init__(self, root='data', split='train', lname='landmarks', latent_dim=512):
         super(StyleGAN2_Data, self).__init__(root)
 
         assert os.path.exists(root), "root: {} not found.".format(root)
@@ -20,14 +20,14 @@ def __init__(self, root='data', split='train', latent_dim=512):
 
         # self.labels = np.load(os.path.join(root, split, 'npy', 'landmarks.npy'))
         if split == 'train' or split == 'train_all':
-            self.labels_original = np.load(os.path.join(root, split, 'npy', 'landmarks.npy'))
+            self.labels_original = np.load(os.path.join(root, split, 'npy', f'{lname}.npy'))
             self.labels = self.scale_label(self.labels_original)
 
         elif split == 'val' or 'test':
-            self.labels_original = np.load(os.path.join(root, 'train_all', 'npy', 'landmarks.npy'))
+            self.labels_original = np.load(os.path.join(root, 'train_all', 'npy', f'{lname}.npy'))
             self.scale_label(self.labels_original)
 
-            self.labels_original = np.load(os.path.join(root, split, 'npy', 'landmarks.npy'))
+            self.labels_original = np.load(os.path.join(root, split, 'npy', f'{lname}.npy'))
             self.labels = self.scale_val_label(self.labels_original)
         else:
             raise ValueError(f"split was not set correctly split = ['train', 'val', 'test'] not {split}")
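
Note on the unchanged `elif` line in this file: Python parses `split == 'val' or 'test'` as `(split == 'val') or 'test'`, and the non-empty string `'test'` is always truthy, so the final `else` branch with its `ValueError` can never run. The sketch below is not part of the commit. It shows a membership test that does reject unknown splits, combined with the `lname`-based path construction this diff introduces. The helper name `label_path` and the label name `'pose'` are hypothetical.

```python
import os

VALID_SPLITS = ('train', 'train_all', 'val', 'test')


def label_path(root, split, lname='landmarks'):
    """Build the .npy label path for a split (hypothetical helper, not in the repo)."""
    if split not in VALID_SPLITS:
        raise ValueError(f"split must be one of {VALID_SPLITS}, not {split!r}")
    return os.path.join(root, split, 'npy', f'{lname}.npy')


# The pitfall: this expression is truthy for any value of split,
# because the right-hand operand 'test' is a non-empty string.
split = 'bogus'
always_true = bool(split == 'val' or 'test')

# 'pose' is an assumed label file name; the commit only shows the default 'landmarks'.
pose_val_path = label_path('data', 'val', lname='pose')
```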

datasets/vggface2_sg2.py

+5 -5
@@ -8,32 +8,32 @@
 
 class StyleGAN2_Data(datasets.ImageFolder):
 
-    def __init__(self, root='data/', split='train', transform=None, scale_size=-1):
+    def __init__(self, root='data/', split='train', lname='landmarks', transform=None, scale_size=-1):
         super(StyleGAN2_Data, self).__init__(root)
 
         assert os.path.exists(root), "root: {} not found.".format(root)
 
         self.root = os.path.join(root, split)
         self.split = split
         self.transform = transform
-        self.scaler = MinMaxScaler(feature_range = (-1, 1))
+        self.scaler = MinMaxScaler(feature_range = (0, 1))
         self.scale_size = scale_size
 
         if split == 'train':
-            self.labels_original = np.load(os.path.join(root, 'train', 'npy', 'landmarks.npy'))
+            self.labels_original = np.load(os.path.join(root, 'train', 'npy', f'{lname}.npy'))
             if scale_size > 0:
                 self.labels = self.scale_label(self.labels_original / scale_size * INPUT_SIZE)
             else:
                 self.labels = self.scale_label(self.labels_original)
 
         elif split == 'val' or 'test':
-            self.labels_original = np.load(os.path.join(root, 'train', 'npy', 'landmarks.npy'))
+            self.labels_original = np.load(os.path.join(root, 'train', 'npy', f'{lname}.npy'))
             if scale_size > 0:
                 self.scale_label(self.labels_original / scale_size * INPUT_SIZE)
             else:
                 self.scale_label(self.labels_original)
 
-            self.labels_original = np.load(os.path.join(root, split, 'npy', f'{lname}.npy'))
+            self.labels_original = np.load(os.path.join(root, split, 'npy', f'{lname}.npy'))
             if scale_size > 0:
                 self.labels = self.scale_val_label(self.labels_original / scale_size * INPUT_SIZE)
             else:
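
Note on the `feature_range` change from `(-1, 1)` to `(0, 1)`: it rescales every label the network is trained on, which matches the README's description of soft pose values in 0.0 ~ 1.0. As a rough illustration of min-max scaling with statistics taken from the training labels and reused for validation labels, here is a pure-Python sketch. It is not scikit-learn's `MinMaxScaler` and not the repo's `scale_label` / `scale_val_label`, whose bodies are not shown in this diff. The function names and the sample (yaw, pitch) values are made up.

```python
def fit_min_max(rows):
    """Return per-column minimums and maximums computed from the training rows."""
    cols = list(zip(*rows))
    return [min(c) for c in cols], [max(c) for c in cols]


def apply_min_max(rows, mins, maxs, feature_range=(0.0, 1.0)):
    """Scale rows into feature_range using previously fitted mins and maxs."""
    lo, hi = feature_range
    scaled = []
    for row in rows:
        out = []
        for v, mn, mx in zip(row, mins, maxs):
            span = mx - mn
            unit = (v - mn) / span if span else 0.0  # constant column maps to lo
            out.append(unit * (hi - lo) + lo)
        scaled.append(out)
    return scaled


train = [[-30.0, 0.0], [0.0, 10.0], [30.0, 20.0]]  # made-up (yaw, pitch) in degrees
val = [[15.0, 5.0]]

mins, maxs = fit_min_max(train)                        # fit on training labels only
train_01 = apply_min_max(train, mins, maxs)            # new range (0, 1)
val_01 = apply_min_max(val, mins, maxs)                # validation reuses train statistics
val_pm1 = apply_min_max(val, mins, maxs, (-1.0, 1.0))  # old range (-1, 1)
```

Because validation labels reuse the training minimum and maximum, a validation value outside the training range will land outside the target range.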

docs/sgf_result.jpg

-30.7 KB