PyTorch implementation of the training procedure described in "Driver Gaze Estimation in the Real World: Overcoming the Eyeglass Challenge".
Parts of the CycleGAN code have been adapted from the PyTorch-CycleGAN repository.
- Clone this repository
- Install Pipenv:
pip3 install pipenv
- Install all requirements and dependencies in a new virtual environment using Pipenv:
cd GPCycleGAN
pipenv install
- Get the links for the desired PyTorch and torchvision wheels from here and install them in the Pipenv virtual environment as follows:
pipenv install https://download.pytorch.org/whl/cu100/torch-1.2.0-cp36-cp36m-manylinux1_x86_64.whl
pipenv install https://download.pytorch.org/whl/cu100/torchvision-0.3.0-cp36-cp36m-linux_x86_64.whl
- Download the complete IR dataset for driver gaze classification using this link.
- Unzip the file.
- Prepare the train, val and test splits as follows:
python prepare_gaze_data.py --dataset-dir=/path/to/lisat_gaze_data
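For intuition, a deterministic train/val/test split of the kind this step produces can be sketched as follows. The ratios, seed, and helper name below are illustrative assumptions, not the actual logic of `prepare_gaze_data.py`:

```python
import random

def split_files(files, val_frac=0.1, test_frac=0.1, seed=0):
    """Illustrative train/val/test split; the real script's ratios may differ."""
    files = sorted(files)        # deterministic order before shuffling
    rng = random.Random(seed)    # fixed seed so the split is reproducible
    rng.shuffle(files)
    n_val = int(len(files) * val_frac)
    n_test = int(len(files) * test_frac)
    val = files[:n_val]
    test = files[n_val:n_val + n_test]
    train = files[n_val + n_test:]
    return train, val, test

train, val, test = split_files([f"img_{i:04d}.png" for i in range(100)])
```

Fixing the seed matters here: the fine-tuning and inference steps below assume the same images stay in the same split across runs.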
The prescribed three-step training procedure for the classification network can be carried out as follows:
pipenv shell # activate virtual environment
python gazenet.py --dataset-root-path=/path/to/lisat_gaze_data/all_data/ --version=1_1 --snapshot=./weights/squeezenet1_1_imagenet.pth --random-transforms
python gpcyclegan.py --dataset-root-path=/path/to/lisat_gaze_data/ --version=1_1 --snapshot-dir=/path/to/trained/gaze-classifier/directory/ --random-transforms
python create_fake_images.py --dataset-root-path=/path/to/lisat_gaze_data/all_data/ --version=1_1 --snapshot-dir=/path/to/trained/gpcyclegan/directory/
cp /path/to/lisat_gaze_data/all_data/mean_std.mat /path/to/fake_data/mean_std.mat # copy over dataset mean/std information to fake data folder
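The mean/std file is copied so that the synthesized (fake) images are normalized with the same statistics as the real training data during fine-tuning. A minimal sketch of that per-channel standardization, assuming `mean_std.mat` stores a scalar mean and std for the single IR channel (my reading, not confirmed by the repo):

```python
def normalize(pixels, mean, std):
    """Standardize pixel values with dataset statistics (z-score)."""
    return [(p - mean) / std for p in pixels]

# Hypothetical single-channel IR statistics; the real values come from mean_std.mat.
mean, std = 0.5, 0.25
out = normalize([0.5, 0.75, 0.25], mean, std)  # -> [0.0, 1.0, -1.0]
```

If the fake data were normalized with different statistics, the fine-tuned classifier would see a shifted input distribution at test time.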
python gazenet-ft.py --dataset-root-path=/path/to/fake_data/ --version=1_1 --snapshot-dir=/path/to/trained/gaze-classifier/directory/ --random-transforms
exit # exit virtual environment
Inference can be carried out using this script as follows:
pipenv shell # activate virtual environment
python infer.py --dataset-root-path=/path/to/lisat_gaze_data/all_data/ --split=test --version=1_1 --snapshot-dir=/path/to/trained/models/directory/
exit # exit virtual environment
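Once `infer.py` produces per-image class scores, the predicted gaze zone is simply the argmax over the classifier's outputs. A hedged sketch of that final step (the zone names below are placeholders, not the dataset's actual labels):

```python
def predict_zone(logits, labels):
    """Return the label with the highest score; softmax is monotonic, so argmax of raw logits suffices."""
    best = max(range(len(logits)), key=lambda i: logits[i])
    return labels[best]

labels = ["forward", "left", "right"]  # placeholder zone names only
print(predict_zone([0.1, 2.3, -0.5], labels))  # -> left
```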
You can download our pre-trained (GPCycleGAN + gaze classifier) weights using this link.
Config files, logs, results, and snapshots from running the above scripts will be stored in the GPCycleGAN/experiments folder by default.