
Structured Coding for 3D Talking Head Codebase

A Modular and Extensible Framework for 3D Talking Head Generation Research

⭐ Star us on GitHub if this project helps your research!

Python 3.9+ PyTorch License

🎯 Overview

This repository provides a foundational framework for AI model training projects. It serves as a base for accumulating and reusing essential model code, enabling rapid development of custom modules without reinventing the wheel.

The framework adopts a decoupled trainer architecture that automatically manages the entire pipeline, from data loading to model evaluation, backed by a robust configuration management system.

By embracing structured programming, the framework divides complex code into independent modules, greatly improving code standardization, maintainability, and readability.


Key Features:

  • 🔧 Modular Architecture: Decoupled components for easy extension and customization
  • 🎨 DiffPoseTalk Model: Implements diffusion-based talking head generation with style encoding
  • 📊 Unified Training Framework: Trainer-based system with full pipeline automation
  • ⚙️ Flexible Configuration: YACS-based hierarchical configuration management
  • 📈 Experiment Tracking: Built-in TensorBoard and WandB support
  • 🚀 Production Ready: Comprehensive logging, checkpointing, and evaluation tools

Note

This project currently implements state-of-the-art (SOTA) methods for 3D talking head generation, specifically the DiffPoseTalk model. We are actively developing our own research methods to further advance the field.

Note

This project is modified from Dassl, making it more user-friendly and structured. It includes additional modules tailored for 3D Talking Head research, such as datasets for 3D Talking Head studies and FLAME-based rendering components.

πŸ—’οΈ TODO Plan

  • Add DistributedDataParallel (DDP) Support (a setup sketch follows this list)

    # Sketch of the planned helper; assumes torch.distributed has been initialized.
    import logging

    from torch.nn.parallel import DistributedDataParallel as DDP

    logger = logging.getLogger(__name__)

    def wrap_model_with_ddp(self, model, find_unused_parameters=False):
        """Wrap model with DistributedDataParallel.

        Args:
            model: The model to wrap
            find_unused_parameters: Whether to find unused parameters
                (useful for complex models)

        Returns:
            Wrapped model, or the original model if not distributed
        """
        if self.is_distributed:
            model = DDP(
                model,
                device_ids=[self.local_rank],
                output_device=self.local_rank,
                find_unused_parameters=find_unused_parameters,
            )
            logger.info(f"Model wrapped with DistributedDataParallel (local_rank={self.local_rank})")
        return model
  • Develop support for audio-visual dataset collection
    • Design and implement an audio-visual data collection workflow
    • Provide tools for data annotation and preprocessing
    • Integrate with the existing data management and training pipeline
  • Implement Mesh Rendering using pytorch3d.renderer
  • Develop a FLAME texture rendering pipeline
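
For reference, here is a minimal sketch of the distributed setup the DDP helper above assumes. The is_distributed and local_rank attributes are illustrative names taken from that snippet, not confirmed framework API; the environment variables are the ones set by torchrun.

import os

import torch
import torch.distributed as dist

def init_distributed(self):
    """Initialize torch.distributed from the torchrun environment."""
    self.is_distributed = int(os.environ.get("WORLD_SIZE", "1")) > 1
    if self.is_distributed:
        self.local_rank = int(os.environ["LOCAL_RANK"])
        torch.cuda.set_device(self.local_rank)
        dist.init_process_group(backend="nccl")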

πŸ“ Project Structure

3DTalkingHeadCodeBase/
├── base/                       # Core base classes
│   ├── base_config.py          # Configuration base class
│   ├── base_dataset.py         # Dataset base class
│   ├── base_datamanager.py     # Data manager base class
│   ├── base_model.py           # Model base class
│   ├── base_trainer.py         # Trainer base class
│   └── base_evaluator.py       # Evaluator base class
├── config/                     # Configuration files
│   ├── difftalk_trainer_config.yaml  # DiffPoseTalk trainer config
│   └── style_trainer_config.yaml     # Style encoder trainer config
├── dataset/                    # Dataset implementations
│   └── HDTF_TFHP.py            # HDTF-TFHP dataset
├── models/                     # Model implementations
│   ├── diffposetalk.py         # DiffPoseTalk model
│   ├── avatar/                 # Avatar-related modules
│   │   ├── flame.py            # FLAME head model
│   │   └── lbs.py              # Linear blend skinning
│   └── lib/                    # Model components
│       ├── base_models.py      # Transformer, Attention, etc.
│       ├── common.py           # Common utilities
│       ├── quantizer.py        # Vector quantization
│       ├── audio/              # Audio feature extractors
│       ├── head/               # Head model components
│       └── network/            # Network architectures
├── trainers/                   # Training logic
│   └── diffposetalk_trainer.py # DiffPoseTalk trainer
├── evaluator/                  # Evaluators
│   └── TalkerEvaluator.py      # Talking head evaluator
├── utils/                      # Utility functions
│   ├── optim/                  # Optimizers and schedulers
│   ├── tools.py                # General utilities
│   ├── meters.py               # Metric tracking
│   ├── registry.py             # Component registration
│   ├── loss.py                 # Loss functions
│   ├── media.py                # Media utilities
│   └── renderer.py             # Rendering utilities
├── scripts/                    # Shell scripts
│   ├── style_train.sh          # Style encoder training script
│   └── talker_train.sh         # Talker training script
├── data/                       # Data directory
│   └── HDTF_TFHP/              # HDTF-TFHP dataset files
├── output/                     # Training outputs
│   └── HDTF_TFHP/              # Output for HDTF-TFHP experiments
├── pretrained/                 # Pretrained models
├── train.py                    # Main training entry point
├── environment.yml             # Conda environment file
└── requirements.txt            # Python dependencies

πŸ“ Trainer Architecture

Trainer
├── config
│   ├── check_cfg
│   └── system_init
├── data
│   ├── build_data_loader
│   └── DataManager
│       ├── DatasetBase
│       ├── DatasetWrapper
│       ├── show_dataset_summary
│       └── data_analysis
├── model
│   ├── build_model
│   ├── get_model_names
│   ├── register_model
│   └── set_model_mode
├── writer
│   ├── init_writer
│   ├── write_scalar
│   └── close_writer
├── train
│   ├── parse_batch_train
│   ├── before_train
│   ├── train_epoch
│   │   ├── before_epoch
│   │   ├── run_epoch
│   │   └── after_epoch
│   ├── train_iter
│   │   ├── before_iter
│   │   ├── run_iter
│   │   └── after_iter
│   ├── forward_backward
│   └── after_train
├── optim
│   ├── build_optimizer
│   ├── build_lr_scheduler
│   ├── model_backward_and_update
│   │   ├── model_zero_grad
│   │   ├── model_backward
│   │   └── model_update
│   ├── update_lr
│   └── get_current_lr
├── test
│   ├── test
│   └── parse_batch_test
├── evaluator
│   ├── build_evaluator
│   ├── loss
│   │   ├── build_loss_metrics
│   │   ├── fetch_mask
│   │   ├── geometric_losses
│   │   ├── simple_loss
│   │   ├── velocity_loss
│   │   └── smooth_loss
│   ├── FLAME
│   │   ├── get_coef_dict
│   │   ├── coef_dict_to_vertices
│   │   └── save_coef_file
│   └── render
│       ├── setup_mesh_renderer
│       ├── render_and_save
│       └── render_to_video
├── save_load
│   ├── save_model
│   ├── save_checkpoint
│   ├── load_model
│   ├── load_checkpoint
│   ├── load_pretrained_weights
│   ├── resume_model_if_exist
│   └── resume_from_checkpoint
└── tools
    ├── optimizer
    │   ├── RAdam
    │   ├── PlainRAdam
    │   └── AdamW
    ├── scheduler
    │   ├── ConstantWarmupScheduler
    │   ├── LinearWarmupScheduler
    │   └── GradualWarmupScheduler
    ├── loss
    │   ├── calc_vq_loss
    │   ├── calc_logit_loss
    │   └── nt_xent_loss
    ├── media
    │   ├── combine_video_and_audio
    │   ├── combine_frames_and_audio
    │   ├── convert_video
    │   ├── reencode_audio
    │   └── extract_frames
    ├── render
    │   └── PyMeshRenderer     # psbody mesh
    ├── count_num_param
    └── others
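
The epoch and iteration hooks above make it easy to customize behavior without touching the core loop. A minimal sketch, assuming the hook names from the tree are overridable methods of TrainerBase (the self.epoch and self.output_dir attributes and the save_model signature are illustrative assumptions, not confirmed API):

from base import TrainerBase, TRAINER_REGISTRY

@TRAINER_REGISTRY.register()
class LoggingTrainer(TrainerBase):
    """Illustrative trainer that hooks into the epoch lifecycle."""

    def before_epoch(self):
        super().before_epoch()
        # Log the current learning rate at the start of each epoch.
        self.write_scalar("train/lr", self.get_current_lr(), self.epoch)

    def after_epoch(self):
        super().after_epoch()
        # Save a checkpoint every 10 epochs (illustrative policy).
        if (self.epoch + 1) % 10 == 0:
            self.save_model(self.epoch, self.output_dir)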

🚀 Quick Start

Installation

# Clone the repository
git clone https://github.com/LZHMS/3DTalkingHeadCodeBase.git
cd 3DTalkingHeadCodeBase

# Create conda environment
conda create -n talkinghead python=3.9
conda activate talkinghead

# Or use the provided environment file
conda env create -f environment.yml
conda activate talkinghead

# Install PyTorch (adjust for your CUDA version)
pip install torch==2.0.0 torchvision==0.15.1 torchaudio==2.0.1

# Install dependencies
pip install -r requirements.txt

Training

# Train DiffPoseTalk with default configuration
python train.py --config-file config/difftalk_trainer_config.yaml

# Train Style Encoder
python train.py --config-file config/style_trainer_config.yaml

# Train with custom settings
python train.py \
    --config-file config/difftalk_trainer_config.yaml \
    --gpu 0,1 \
    OPTIM.LR 0.0001

Using the Training Scripts

# Train style encoder
bash scripts/style_train.sh

# Train talking head model
bash scripts/talker_train.sh

πŸ—οΈ Architecture

Trainer-Based Training Paradigm

The framework adopts a decoupled trainer-based architecture that separates concerns:

# Automatic pipeline management
trainer = build_trainer(config)
trainer.train()  # Handles entire training loop

Trainer responsibilities:

  • ✅ Data loading and preprocessing
  • ✅ Model initialization and checkpointing
  • ✅ Training loop with gradient updates
  • ✅ Validation and evaluation
  • ✅ Logging and visualization
  • ✅ Learning rate scheduling

Configuration System

The most powerful component is the config system, which gathers every configurable parameter of a project in one place. With a single YAML file you can configure your own project and quickly set up the training pipeline, as in the following overview config:

# Example configuration
ENV:
  SEED: 2025
  NAME: StyleEncoder_Trainer
  DESCRIPTION: Train the style encoder of DiffPoseTalk.
  OUTPUT_DIR: ./output
  VERBOSE: True
  USE_WANDB: False
  WANDB:
    KEY: <your wandb key>
    ENTITY: 3DVZHao
    PROJECT: 3DTalkingHead
    NAME: TrainingStyleEncoder
    NOTES: Training as the baseline model.
    TAGS: Baseline
    MODE: online
  EXTRA:
    STYLE_ENC_CKPT: 

DATASET:
  NAME: HDTF_TFHP
  ROOT: ./data/
  HDTF_TFHP:
    COEF_STATS: stats_train.npz
    TRAIN: train.txt
    VAL: val.txt
    TEST: test.txt
    COEF_FPS: 25      # frames per second for coefficients (sequence fps)
    MOTIONS: 100      # number of motions per sample
    CROP: random    # crop strategy
    AUDIO_SR: 16000   # audio sampling rate

DATALOADER:
  NUM_WORKERS: 4
  TRAIN:
    BATCH_SIZE: 32
  TEST:
    BATCH_SIZE: 64

MODEL:
  NAME: StyleEncoder
  HEAD:
    ROT_REPR: 'aa'
    NO_HEAD_POSE: False
  BACKBONE:
    NAME: TransformerEncoder
    IN_DIM: 50
    HIDDEN_SIZE: 128
    NUM_HIDDEN_LAYERS: 4
    NUM_ATTENTION_HEADS: 4
  TAIL:
    MLP_RATIO: 4

LOSS:
  NAME: NTXentLoss
  CONTRASTIVE:
    TEMPRATURE: 0.1

TRAINER:
  NAME: StyleEncoderTrainer

TRAIN:
  USE_ITERS: True
  MAX_ITERS: 200
  PRINT_FREQ: 5
  SAVE_FREQ: 20
  EVALUATE: True
  EVAL_FREQ: 20

OPTIM:
  NAME: adam
  LR: 0.0001
  LR_SCHEDULER: cosine
  LR_UPDATE_FREQ: 1

EVALUATE:
  EVALUATOR: TDTalkerEvaluator
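
Under the hood the configuration is a YACS CfgNode: defaults defined in base/base_config.py are merged with the experiment YAML and any command-line overrides. A minimal sketch of that flow (the inline defaults here are placeholders, not the real ones):

from yacs.config import CfgNode as CN

# Placeholder defaults; the real ones live in base/base_config.py.
_C = CN()
_C.ENV = CN(new_allowed=True)   # new_allowed permits extendable nodes like ENV.EXTRA
_C.OPTIM = CN()
_C.OPTIM.LR = 0.001

def load_config(config_file, opts):
    """Merge defaults <- experiment YAML <- command-line overrides."""
    cfg = _C.clone()
    cfg.merge_from_file(config_file)   # e.g. config/style_trainer_config.yaml
    cfg.merge_from_list(opts)          # e.g. ["OPTIM.LR", "0.0001"]
    cfg.freeze()                       # lock the config for the run
    return cfg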

Another useful feature is the ability to extend custom parameters under ENV.EXTRA, an extendable configuration node.
When a parameter is not defined in the base/base_config.py file and you do not want to add it as a global configuration across all projects, you can use this node in a custom YAML configuration file.

Note that the STYLE_ENC_CKPT parameter does not appear in the base/base_config.py file.

ENV:
  EXTRA:
    STYLE_ENC_CKPT: 
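
At runtime such extra parameters are read like any other config field. A small sketch, assuming the loaded YACS config is named cfg and using load_pretrained_weights from the trainer's save_load tools (its exact signature is an assumption):

# Read the extendable parameter and use it if set.
style_enc_ckpt = cfg.ENV.EXTRA.STYLE_ENC_CKPT
if style_enc_ckpt:
    trainer.load_pretrained_weights(style_enc_ckpt)  # illustrative usage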

Registry System

All components in the CodeBase are set up using the registry system. By using the @TRAINER_REGISTRY.register() decorator, we can register all defined modules into a centralized pool. Through the configuration file, we can then select the corresponding module to compose the required project. This approach is highly convenient and reusable!

from base import TRAINER_REGISTRY

@TRAINER_REGISTRY.register()
class CustomTrainer(TrainerBase):
    def __init__(self, config):
        super().__init__(config)
        # Custom initialization
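
For intuition, the registry is essentially a name-to-class dictionary consulted at build time. A minimal sketch of the pattern (illustrative, not the exact utils/registry.py implementation):

class Registry:
    """Minimal name -> class registry."""

    def __init__(self, name):
        self._name = name
        self._obj_map = {}

    def register(self):
        def deco(cls):
            assert cls.__name__ not in self._obj_map, f"{cls.__name__} already registered"
            self._obj_map[cls.__name__] = cls
            return cls
        return deco

    def get(self, name):
        return self._obj_map[name]

TRAINER_REGISTRY = Registry("TRAINER")

def build_trainer(cfg):
    # TRAINER.NAME in the YAML (e.g. "StyleEncoderTrainer") selects the class.
    return TRAINER_REGISTRY.get(cfg.TRAINER.NAME)(cfg)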

📊 Supported Models

Models can be defined using the components from models/lib, composed through the HEAD, BACKBONE, and TAIL config sections, so standard modules can be reused across models.

Model          Type                Paper                Status
DiffPoseTalk   Diffusion + Style   [Sun et al., 2024]   ✅

📈 Datasets

Dataset     Description                                       Subjects   Status
HDTF-TFHP   High-definition talking face with 3D head pose    -          ✅

πŸ› οΈ Advanced Features

Distributed Training

Distributed training allows you to scale your training process across multiple GPUs or machines. This is particularly useful for large-scale models or datasets. The framework provides built-in support for distributed training using PyTorch's torch.distributed module.

python -m torch.distributed.launch \
    --nproc_per_node=4 \
    train.py --config-file config/difftalk_trainer_config.yaml
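
Note that torch.distributed.launch is deprecated in recent PyTorch releases; the equivalent torchrun invocation (which also sets the LOCAL_RANK environment variable used in the setup sketch under the TODO plan) is:

torchrun --nproc_per_node=4 \
    train.py --config-file config/difftalk_trainer_config.yaml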

Experiment Tracking

Experiment tracking is essential for monitoring and analyzing your training process. The framework supports both TensorBoard for local visualization and WandB for cloud-based experiment tracking. These tools allow you to log metrics, visualize training progress, and compare different experiments.

# Automatic logging
self.write_scalar("train/loss", loss, step)
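
A sketch of how the init_writer/write_scalar/close_writer hooks from the trainer tree are typically wired to TensorBoard (illustrative, not the exact implementation; WandB logging would hang off the same hooks):

from torch.utils.tensorboard import SummaryWriter

def init_writer(self, log_dir):
    self._writer = SummaryWriter(log_dir=log_dir)

def write_scalar(self, tag, scalar_value, global_step=None):
    if self._writer is not None:
        self._writer.add_scalar(tag, scalar_value, global_step)

def close_writer(self):
    if self._writer is not None:
        self._writer.close()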

Model Checkpointing

Model checkpointing ensures that your training progress is saved periodically, allowing you to resume training from the last saved state in case of interruptions. The framework automatically saves the best model and supports resuming from checkpoints.

# Automatic best model saving
# Resume from checkpoint
trainer.resume_model_if_exist("./checkpoint_dir")
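
A checkpoint typically bundles model, optimizer, and scheduler state together with the current epoch. A minimal sketch of the save side (the field names are illustrative, not the codebase's exact checkpoint format):

import torch

def save_checkpoint(model, optimizer, scheduler, epoch, path):
    # Bundle everything needed to resume training from `epoch`.
    torch.save(
        {
            "state_dict": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "scheduler": scheduler.state_dict(),
            "epoch": epoch,
        },
        path,
    )

# Resuming reverses the process:
# checkpoint = torch.load(path, map_location="cpu")
# model.load_state_dict(checkpoint["state_dict"])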

πŸ“ Adding New Components

Add a New Model

from base import BaseModel, MODEL_REGISTRY

@MODEL_REGISTRY.register()
class YourModel(BaseModel):
    def __init__(self, cfg):
        super().__init__()
        # Initialize your model
    
    def forward(self, x):
        # Forward pass
        return output

Add a New Trainer

from base import TrainerBase, TRAINER_REGISTRY

@TRAINER_REGISTRY.register()
class YourTrainer(TrainerBase):
    def build_model(self):
        # Build your model
        pass
    
    def forward_backward(self, batch):
        # Training step logic
        pass

Add a New Dataset

from base import DatasetBase, DATASET_REGISTRY

@DATASET_REGISTRY.register()
class YourDataset(DatasetBase):
    def __init__(self, cfg):
        # Initialize dataset
        pass

🔧 Development Guide

Project Philosophy

This codebase follows a registry-based modular design where:

  • All major components (models, trainers, datasets, evaluators) are registered
  • Configuration is centralized and hierarchical
  • Training pipeline is fully automated through trainer classes
  • Easy to extend with new models and experiments

Key Design Patterns

  1. Base Classes: All components inherit from base classes in base/
  2. Registry Pattern: Use @REGISTRY.register() for component discovery
  3. Configuration-Driven: All hyperparameters managed through YACS config
  4. Decoupled Training: Trainer handles all training logic separately from model

📖 Citation

If you find this codebase useful for your research, please consider citing:

@software{3DTalkingHeadCodeBase,
  author       = {Zhihao Li},
  title        = {3DTalkingHeadCodeBase: A Modular Framework for 3D Talking Head Generation},
  year         = {2025},
  url          = {https://github.com/LZHMS/3DTalkingHeadCodeBase},
  version      = {1.0.0}
}

📄 License

This project is licensed under the MIT License - see the LICENSE file for details.

πŸ™ Acknowledgments

  • Dassl.pytorch for the foundational training framework architecture
  • DiffPoseTalk for diffusion-based methods
  • YACS for configuration management
  • PyTorch team for the deep learning framework
  • The talking head research community

📧 Contact

For questions and feedback, please open an issue or contact the maintainers.
