A Modular and Extensible Framework for 3D Talking Head Generation Research
⭐ Star us on GitHub if this project helps your research!
This repository provides a foundational framework for any AI model training project. It serves as a base for accumulating and reusing essential model code, enabling rapid development of custom modules and avoiding reinventing the wheel.
The framework adopts a decoupled trainer architecture that automatically manages the entire pipeline, from data loading to model evaluation, with a robust configuration management system.
By embracing structured programming, the framework divides complex code into independent modules, greatly improving code standardization, maintainability, and readability.
Key Features:
- Modular Architecture: Decoupled components for easy extension and customization
- DiffPoseTalk Model: Implements diffusion-based talking head generation with style encoding
- Unified Training Framework: Trainer-based system with full pipeline automation
- Flexible Configuration: YACS-based hierarchical configuration management
- Experiment Tracking: Built-in TensorBoard and WandB support
- Production Ready: Comprehensive logging, checkpointing, and evaluation tools
Note: This project currently implements state-of-the-art (SOTA) methods for 3D talking head generation, specifically the DiffPoseTalk model. We are actively developing our own research methods to further advance the field.
Note: This project is modified from Dassl, restructured to be more user-friendly. It includes additional modules tailored for 3D talking head research, such as dedicated datasets and FLAME-based rendering components.
- Add DistributedDataParallel (DDP) support (see base/base_trainer.py, lines 185 to 204 at commit 8424b81)
- Develop support for audio-visual dataset collection
- Design and implement an audio-visual data collection workflow
- Provide tools for data annotation and preprocessing
- Integrate with the existing data management and training pipeline
- Implement mesh rendering using pytorch3d.renderer (a hypothetical sketch follows this list)
- Develop a FLAME texture rendering pipeline
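As a preview of the mesh rendering roadmap item above, the sketch below shows how a pytorch3d.renderer pipeline for FLAME meshes could be set up. None of this code exists in the repository yet; the camera, lighting, and texture choices are assumptions, not the planned implementation.

import torch
from pytorch3d.structures import Meshes
from pytorch3d.renderer import (
    FoVPerspectiveCameras, RasterizationSettings, MeshRenderer,
    MeshRasterizer, SoftPhongShader, PointLights, TexturesVertex,
)

def build_flame_renderer(device="cpu", image_size=512):
    # Perspective camera, simple rasterization settings, and a single point light.
    cameras = FoVPerspectiveCameras(device=device)
    raster_settings = RasterizationSettings(image_size=image_size, blur_radius=0.0, faces_per_pixel=1)
    lights = PointLights(device=device, location=[[0.0, 0.0, 3.0]])
    return MeshRenderer(
        rasterizer=MeshRasterizer(cameras=cameras, raster_settings=raster_settings),
        shader=SoftPhongShader(device=device, cameras=cameras, lights=lights),
    )

# Usage sketch: wrap FLAME vertices/faces in a Meshes object and render.
# verts: (B, V, 3) float tensor from the FLAME model; faces: (B, F, 3) long tensor.
# meshes = Meshes(verts=verts, faces=faces, textures=TexturesVertex(torch.ones_like(verts)))
# images = build_flame_renderer()(meshes)  # (B, H, W, 4) RGBA renders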
3DTalkingHeadCodeBase/
├── base/                              # Core base classes
│   ├── base_config.py                 # Configuration base class
│   ├── base_dataset.py                # Dataset base class
│   ├── base_datamanager.py            # Data manager base class
│   ├── base_model.py                  # Model base class
│   ├── base_trainer.py                # Trainer base class
│   └── base_evaluator.py              # Evaluator base class
├── config/                            # Configuration files
│   ├── difftalk_trainer_config.yaml   # DiffPoseTalk trainer config
│   └── style_trainer_config.yaml      # Style encoder trainer config
├── dataset/                           # Dataset implementations
│   └── HDTF_TFHP.py                   # HDTF-TFHP dataset
├── models/                            # Model implementations
│   ├── diffposetalk.py                # DiffPoseTalk model
│   ├── avatar/                        # Avatar related modules
│   │   ├── flame.py                   # FLAME head model
│   │   └── lbs.py                     # Linear blend skinning
│   └── lib/                           # Model components
│       ├── base_models.py             # Transformer, Attention, etc.
│       ├── common.py                  # Common utilities
│       ├── quantizer.py               # Vector quantization
│       ├── audio/                     # Audio feature extractors
│       ├── head/                      # Head model components
│       └── network/                   # Network architectures
├── trainers/                          # Training logic
│   └── diffposetalk_trainer.py        # DiffPoseTalk trainer
├── evaluator/                         # Evaluators
│   └── TalkerEvaluator.py             # Talking head evaluator
├── utils/                             # Utility functions
│   ├── optim/                         # Optimizers and schedulers
│   ├── tools.py                       # General utilities
│   ├── meters.py                      # Metric tracking
│   ├── registry.py                    # Component registration
│   ├── loss.py                        # Loss functions
│   ├── media.py                       # Media utilities
│   └── renderer.py                    # Rendering utilities
├── scripts/                           # Shell scripts
│   ├── style_train.sh                 # Style encoder training script
│   └── talker_train.sh                # Talker training script
├── data/                              # Data directory
│   └── HDTF_TFHP/                     # HDTF-TFHP dataset files
├── output/                            # Training outputs
│   └── HDTF_TFHP/                     # Output for HDTF-TFHP experiments
├── pretrained/                        # Pretrained models
├── train.py                           # Main training entry point
├── environment.yml                    # Conda environment file
└── requirements.txt                   # Python dependencies

Trainer
├── config
│   ├── check_cfg
│   └── system_init
├── data
│   ├── build_data_loader
│   └── DataManager
│       ├── DatasetBase
│       ├── DatasetWrapper
│       ├── show_dataset_summary
│       └── data_analysis
├── model
│   ├── build_model
│   ├── get_model_names
│   ├── register_model
│   └── set_model_mode
├── writer
│   ├── init_writer
│   ├── write_scalar
│   └── close_writer
├── train
│   ├── parse_batch_train
│   ├── before_train
│   ├── train_epoch
│   │   ├── before_epoch
│   │   ├── run_epoch
│   │   └── after_epoch
│   ├── train_iter
│   │   ├── before_iter
│   │   ├── run_iter
│   │   └── after_iter
│   ├── forward_backward
│   └── after_train
├── optim
│   ├── build_optimizer
│   ├── build_lr_scheduler
│   ├── model_backward_and_update
│   │   ├── model_zero_grad
│   │   ├── model_backward
│   │   └── model_update
│   ├── update_lr
│   └── get_current_lr
├── test
│   ├── test
│   └── parse_batch_test
├── evaluator
│   ├── build_evaluator
│   ├── loss
│   │   ├── build_loss_metrics
│   │   ├── fetch_mask
│   │   ├── geometric_losses
│   │   ├── simple_loss
│   │   ├── velocity_loss
│   │   └── smooth_loss
│   ├── FLAME
│   │   ├── get_coef_dict
│   │   ├── coef_dict_to_vertices
│   │   └── save_coef_file
│   └── render
│       ├── setup_mesh_renderer
│       ├── render_and_save
│       └── render_to_video
├── save_load
│   ├── save_model
│   ├── save_checkpoint
│   ├── load_model
│   ├── load_checkpoint
│   ├── load_pretrained_weights
│   ├── resume_model_if_exist
│   └── resume_from_checkpoint
└── tools
    ├── optimizer
    │   ├── RAdam
    │   ├── PlainRAdam
    │   └── AdamW
    ├── scheduler
    │   ├── ConstantWarmupScheduler
    │   ├── LinearWarmupScheduler
    │   └── GradualWarmupScheduler
    ├── loss
    │   ├── calc_vq_loss
    │   ├── calc_logit_loss
    │   └── nt_xent_loss
    ├── media
    │   ├── combine_video_and_audio
    │   ├── combine_frames_and_audio
    │   ├── convert_video
    │   ├── reencode_audio
    │   └── extract_frames
    ├── render
    │   └── PyMeshRenderer             # psbody mesh
    ├── count_num_param
    └── others

# Clone the repository
git clone https://github.com/LZHMS/3DTalkingHeadCodeBase.git
cd 3DTalkingHeadCodeBase
# Create conda environment
conda create -n talkinghead python=3.9
conda activate talkinghead
# Or use the provided environment file
conda env create -f environment.yml
conda activate talkinghead
# Install PyTorch (adjust for your CUDA version)
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1
# Install dependencies
pip install -r requirements.txt

# Train DiffPoseTalk with default configuration
python train.py --config-file config/difftalk_trainer_config.yaml
# Train Style Encoder
python train.py --config-file config/style_trainer_config.yaml
# Train with custom settings
python train.py \
--config-file config/difftalk_trainer_config.yaml \
--gpu 0,1 \
OPTIM.LR 0.0001

# Train style encoder
bash scripts/style_train.sh
# Train talking head model
bash scripts/talker_train.sh

The framework adopts a decoupled trainer-based architecture that separates concerns:
# Automatic pipeline management
trainer = build_trainer(config)
trainer.train()  # Handles the entire training loop

Trainer responsibilities (a sketch of the hook order follows this list):
- ✅ Data loading and preprocessing
- ✅ Model initialization and checkpointing
- ✅ Training loop with gradient updates
- ✅ Validation and evaluation
- ✅ Logging and visualization
- ✅ Learning rate scheduling
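The hook names in the Trainer overview above are invoked in a fixed order during training. The toy class below is only a sketch of that control flow, not the actual base/base_trainer.py implementation; each print statement stands in for the work the real hook performs.

class SketchTrainer:
    # Illustrative stand-in for the trainer base class: shows the order in which
    # the lifecycle hooks fire during trainer.train().
    def __init__(self, max_epoch=3):
        self.max_epoch = max_epoch

    def before_train(self): print("before_train: build data loaders, model, optimizer, writer")
    def before_epoch(self): print("before_epoch")
    def run_epoch(self): print("run_epoch: iterate batches, call forward_backward(batch)")
    def after_epoch(self): print("after_epoch: evaluate, save checkpoint, update learning rate")
    def after_train(self): print("after_train: final test, close writer")

    def train(self):
        self.before_train()
        for epoch in range(self.max_epoch):
            self.before_epoch()
            self.run_epoch()
            self.after_epoch()
        self.after_train()

SketchTrainer().train()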
The most powerful component is the configuration system, which gathers every configuration parameter in the project. With a single YAML file you can configure your own project and quickly set up the training pipeline, as in the following overview config:
# Example configuration
ENV:
SEED: 2025
NAME: StyleEncoder_Trainer
DESCRIPTION: Train the style encoder of DiffPoseTalk.
OUTPUT_DIR: ./output
VERBOSE: True
USE_WANDB: False
WANDB:
KEY: <your wandb key>
ENTITY: 3DVZHao
PROJECT: 3DTalkingHead
NAME: TrainingStyleEncoder
NOTES: Training as the baseline model.
TAGS: Baseline
MODE: online
EXTRA:
STYLE_ENC_CKPT:
DATASET:
NAME: HDTF_TFHP
ROOT: ./data/
HDTF_TFHP:
COEF_STATS: stats_train.npz
TRAIN: train.txt
VAL: val.txt
TEST: test.txt
COEF_FPS: 25 # frames per second for coefficients (sequence fps)
MOTIONS: 100 # number of motions per sample
CROP: random # crop strategy
AUDIO_SR: 16000 # audio sampling rate
DATALOADER:
NUM_WORKERS: 4
TRAIN:
BATCH_SIZE: 32
TEST:
BATCH_SIZE: 64
MODEL:
NAME: StyleEncoder
HEAD:
ROT_REPR: 'aa'
NO_HEAD_POSE: False
BACKBONE:
NAME: TransformerEncoder
IN_DIM: 50
HIDDEN_SIZE: 128
NUM_HIDDEN_LAYERS: 4
NUM_ATTENTION_HEADS: 4
TAIL:
MLP_RATIO: 4
LOSS:
NAME: NTXentLoss
CONTRASTIVE:
TEMPRATURE: 0.1
TRAINER:
NAME: StyleEncoderTrainer
TRAIN:
USE_ITERS: True
MAX_ITERS: 200
PRINT_FREQ: 5
SAVE_FREQ: 20
EVALUATE: True
EVAL_FREQ: 20
OPTIM:
NAME: adam
LR: 0.0001
LR_SCHEDULER: cosine
LR_UPDATE_FREQ: 1
EVALUATE:
EVALUATOR: TDTalkerEvaluator

Another convenient feature is extending your own parameters under ENV.EXTRA, which is an extensible configuration node.
When a parameter is not defined in the base/base_config.py file and you do not want to add it as a global configuration shared across all projects, you can use this mechanism to declare it directly in a custom YAML configuration file.
Note that the STYLE_ENC_CKPT parameter does not appear in the base/base_config.py file.
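Under the hood, this presumably works because the EXTRA node accepts keys that were never declared. The snippet below is only a sketch of that YACS mechanism, assuming base/base_config.py marks ENV.EXTRA with new_allowed; check the file for the actual setup. The checkpoint path is an illustrative value.

from yacs.config import CfgNode as CN

cfg = CN()
cfg.ENV = CN()
cfg.ENV.EXTRA = CN(new_allowed=True)   # undeclared keys under EXTRA are accepted on merge
cfg.merge_from_other_cfg(CN({"ENV": {"EXTRA": {"STYLE_ENC_CKPT": "pretrained/style_enc.pt"}}}))
print(cfg.ENV.EXTRA.STYLE_ENC_CKPT)    # accessible like any declared option

The corresponding entry in your experiment YAML then simply looks like this: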
ENV:
  EXTRA:
    STYLE_ENC_CKPT:

All components in the CodeBase are set up using the registry system. By using the @TRAINER_REGISTRY.register() decorator, we can register all defined modules into a centralized pool. Through the configuration file, we can then select the corresponding module to compose the required project. This approach is highly convenient and reusable!
from base import TRAINER_REGISTRY

@TRAINER_REGISTRY.register()
class CustomTrainer(TrainerBase):
    def __init__(self, config):
        super().__init__(config)
        # Custom initialization

Models can be defined by composing components from models/lib through the HEAD, BACKBONE, and TAIL config sections, so standard modules can be reused across projects.

| Model | Type | Paper | Status |
|---|---|---|---|
| DiffPoseTalk | Diffusion + Style | [Sun et al., 2024] | ✅ |

| Dataset | Description | Subjects | Status |
|---|---|---|---|
| HDTF-TFHP | High-definition talking face with 3D head pose | - | ✅ |
Distributed training allows you to scale your training process across multiple GPUs or machines. This is particularly useful for large-scale models or datasets. The framework provides built-in support for distributed training using PyTorch's torch.distributed module.
python -m torch.distributed.launch \
--nproc_per_node=4 \
train.py --config-file config/difftalk_trainer_config.yaml

Experiment tracking is essential for monitoring and analyzing your training process. The framework supports both TensorBoard for local visualization and WandB for cloud-based experiment tracking. These tools allow you to log metrics, visualize training progress, and compare different experiments.
# Automatic logging
self.write_scalar("train/loss", loss, step)

Model checkpointing ensures that your training progress is saved periodically, allowing you to resume training from the last saved state in case of interruptions. The framework automatically saves the best model and supports resuming from checkpoints.
# Automatic best model saving
# Resume from checkpoint
trainer.resume_model_if_exist("./checkpoint_dir")

To add a custom model, register it with the model registry:

from base import BaseModel, MODEL_REGISTRY
@MODEL_REGISTRY.register()
class YourModel(BaseModel):
    def __init__(self, cfg):
        super().__init__()
        # Initialize your model (layers, submodules, etc.)

    def forward(self, x):
        # Forward pass: compute and return the model output
        output = ...
        return output

To add a custom trainer, register it and implement the model-building and training-step hooks:

from base import TrainerBase, TRAINER_REGISTRY
@TRAINER_REGISTRY.register()
class YourTrainer(TrainerBase):
    def build_model(self):
        # Build your model
        pass

    def forward_backward(self, batch):
        # Training step logic
        pass

To add a custom dataset, register it with the dataset registry:

from base import DatasetBase, DATASET_REGISTRY
@DATASET_REGISTRY.register()
class YourDataset(DatasetBase):
    def __init__(self, cfg):
        # Initialize dataset
        pass

This codebase follows a registry-based modular design where:
- All major components (models, trainers, datasets, evaluators) are registered
- Configuration is centralized and hierarchical
- Training pipeline is fully automated through trainer classes
- Easy to extend with new models and experiments
- Base Classes: All components inherit from base classes in base/
- Registry Pattern: Use @REGISTRY.register() for component discovery
- Configuration-Driven: All hyperparameters managed through YACS config
- Decoupled Training: Trainer handles all training logic separately from model
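As a rough illustration of how these pieces compose, the trainer named in the config is looked up in the registry and instantiated. The lookup below is an assumption modeled on the Dassl-style registry; check utils/registry.py and train.py for the actual interface.

from base import TRAINER_REGISTRY

def build_trainer(cfg):
    # Look up the trainer class registered under cfg.TRAINER.NAME
    # (e.g. "StyleEncoderTrainer") and instantiate it with the full config.
    trainer_cls = TRAINER_REGISTRY.get(cfg.TRAINER.NAME)
    return trainer_cls(cfg)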
If you find this codebase useful for your research, please consider citing:
@software{3DTalkingHeadCodeBase,
author = {Zhihao Li},
title = {3DTalkingHeadCodeBase: A Modular Framework for 3D Talking Head Generation},
year = {2025},
url = {https://github.com/LZHMS/3DTalkingHeadCodeBase},
version = {1.0.0}
}

This project is licensed under the MIT License - see the LICENSE file for details.
- Dassl.pytorch for the foundational training framework architecture
- DiffPoseTalk for diffusion-based methods
- YACS for configuration management
- PyTorch team for the deep learning framework
- The talking head research community
For questions and feedback, please open an issue or contact the maintainers.