
Model card

This is the model card for this implementation.

On this page I imitate the model cards from some reference papers. It may not be perfect, but it should help in understanding the implementation comprehensively.

It covers the model architecture details, the techniques used, the system configuration, the loss functions, and so on.

Sample model card of PaLM 2


(WIP)

| Model Summary | |
| --- | --- |
| Summary | Muesli is a model-based RL algorithm. It maps an RGB image observation to a probability distribution over the action space defined by the environment, and it learns from episode data collected through agent-environment interaction to maximize the cumulative reward. |
| Input | RGB image tensor; [channels = 3, height = 72 px, width = 96 px] |
| Output | Probability distribution vector; [action_space] |
| Model architecture | |
| --- | --- |
| Representation Network | CNN-based observation encoder. (Main role: image -> hidden state) |
| Dynamics Network | LSTM that infers the hidden states of future time steps, conditioned on the actions selected in the episode data. (Main role: hidden state -> future hidden states) |
| Policy Network | Infers a probability distribution over actions from the hidden state. (Main role: hidden state -> softmaxed distribution) |
| Value Network | Infers a probability distribution representing the scalar value from the hidden state. (Main role: hidden state -> softmaxed distribution) |
| Reward Network | Infers a probability distribution representing the scalar reward from the hidden state. (Main role: hidden state -> softmaxed distribution) |
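The sketch below shows, in PyTorch, how these components could be wired together. The sizes and names (`HIDDEN_DIM`, `ACTION_SPACE`, `SUPPORT_SIZE`) are illustrative assumptions, not values taken from the repository, and the actual layer layout in the code may differ.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration; the real implementation may differ.
HIDDEN_DIM = 256
ACTION_SPACE = 4        # e.g. LunarLander-v2
SUPPORT_SIZE = 61       # categorical support for the value/reward heads

class RepresentationNetwork(nn.Module):
    """CNN observation encoder: stacked RGB frames -> hidden state."""
    def __init__(self, in_channels=3 * 4):   # 4 stacked RGB observations
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, 32, kernel_size=3, stride=2), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=3, stride=2), nn.ReLU(),
            nn.Flatten(),
        )
        self.fc = nn.LazyLinear(HIDDEN_DIM)

    def forward(self, obs):                   # obs: [B, 12, 72, 96]
        return self.fc(self.conv(obs))        # -> [B, HIDDEN_DIM]

class DynamicsNetwork(nn.Module):
    """LSTM cell that rolls the hidden state forward, conditioned on actions."""
    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTMCell(ACTION_SPACE, HIDDEN_DIM)

    def forward(self, action_onehot, state):  # state: tuple (h, c)
        return self.lstm(action_onehot, state)  # -> (h', c') for the next step

class PolicyNetwork(nn.Module):
    """Hidden state -> softmaxed distribution over actions."""
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(HIDDEN_DIM, ACTION_SPACE)

    def forward(self, h):
        return torch.softmax(self.head(h), dim=-1)

class ValueNetwork(nn.Module):
    """Hidden state -> softmaxed distribution over the value support."""
    def __init__(self):
        super().__init__()
        self.head = nn.Linear(HIDDEN_DIM, SUPPORT_SIZE)

    def forward(self, h):
        return torch.softmax(self.head(h), dim=-1)

# The Reward Network has the same shape as the Value Network
# (a softmaxed distribution over a reward support).
```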
| Loss function | |
| --- | --- |
| PG loss | Auxiliary policy-gradient loss. (eq. 10 in the Muesli paper) (`first_term` in the code) |
| Model loss | The policy component of the model loss. (eq. 13 in the Muesli paper, extended to the first time step, i.e. starting from k=0) (`L_m` in the code) |
| Value model loss | Cross-entropy loss. (As described in the MuZero paper's supplementary materials) (`L_v` in the code) |
| Reward model loss | Cross-entropy loss. (As described in the MuZero paper's supplementary materials) (`L_r` in the code) |
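As a rough illustration, the value/reward model losses are cross-entropy terms between predicted and target categorical distributions, and the policy-gradient part multiplies a stop-gradient advantage estimate by the log-probability of the taken action. The function names below are hypothetical, and the policy-gradient term is shown in a simplified form that omits the clipping and importance-weighting details of eq. 10; refer to the paper and the code for the exact formulation.

```python
import torch

def categorical_cross_entropy(pred_dist, target_dist, eps=1e-12):
    """MuZero-style value/reward model loss: cross-entropy between the
    predicted softmax distribution and the target categorical distribution.
    Both tensors have shape [batch, support_size]."""
    return -(target_dist * torch.log(pred_dist + eps)).sum(dim=-1).mean()

def pg_surrogate(log_pi_a, advantage_hat):
    """Simplified policy-gradient surrogate:
    -E[ stop_grad(adv_hat) * log pi(a|s) ]."""
    return -(advantage_hat.detach() * log_pi_a).mean()
```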
| Replay Buffer | |
| --- | --- |
| Sampling method | Uniform |
| Replay proportion in a batch | 75% off-policy data + 25% on-policy data |
| Capacity | Not yet implemented |
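A minimal sketch of a uniform-sampling buffer with the 75%/25% off-/on-policy mix, assuming whole episodes are stored as buffer entries. The class and method names are illustrative, and (as noted above) a capacity limit is not part of the current implementation.

```python
import random
from collections import deque

class SimpleReplayBuffer:
    """Uniform-sampling episode buffer (sketch only)."""
    def __init__(self, capacity=None):
        # capacity=None means unbounded; a real capacity limit is not yet implemented.
        self.episodes = deque(maxlen=capacity)

    def add(self, episode):
        self.episodes.append(episode)

    def sample_batch(self, batch_size, latest_episodes):
        # 75% off-policy data drawn uniformly from the buffer,
        # 25% on-policy data drawn from the most recent episodes.
        n_on = batch_size // 4
        n_off = batch_size - n_on
        off_policy = random.choices(list(self.episodes), k=n_off)
        on_policy = random.choices(list(latest_episodes), k=n_on)
        return off_policy + on_policy
```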
| Evaluation | |
| --- | --- |
| LunarLander-v2 | Cumulative returns are only checked while gathering data from agent-environment interaction |
| TBD | Needs to be improved |
| Main Hyperparameters | |
| --- | --- |
| Mini-batch size | 128 |
| Optimization | |
| --- | --- |
| Optimizer | Adam |
| Gradient clipping | [-1, 1] |
| Learning rate schedule | Decays to zero |
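A hedged sketch of this optimization setup in PyTorch. The learning rate, the total number of updates, the stand-in model, and the use of `LambdaLR` for a linear decay to zero are assumptions for illustration only.

```python
import torch
import torch.nn as nn

model = nn.Linear(256, 4)      # stands in for the full set of networks
TOTAL_UPDATES = 100_000        # assumed training length

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
# Linear decay of the learning rate to zero over training.
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda step: max(0.0, 1.0 - step / TOTAL_UPDATES))

def update(loss):
    optimizer.zero_grad()
    loss.backward()
    # Element-wise gradient clipping to the range [-1, 1].
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
    optimizer.step()
    scheduler.step()
```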
| Techniques | |
| --- | --- |
| Categorical reparametrization | Used for the value model and the reward model. (distribution <-> scalar) |
| Retrace | Not yet (the target value should be estimated with the Retrace estimator, but this is not implemented yet) |
| Advantage normalization | Yes |
| Target network | Yes, updated as a moving average |
| Pop-Art | Not yet |
| Stacking observations | 4 |
| Min-max normalization | Applied to the embedding before it enters the model |
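To illustrate two of these techniques, the sketch below shows a possible categorical reparametrization (scalar <-> distribution, in the MuZero style) and a min-max normalization of the hidden state. The support range and size are assumptions, not values taken from the code.

```python
import torch

def scalar_to_two_hot(x, support_min=-30.0, support_max=30.0, bins=61):
    """Categorical reparametrization: project a scalar tensor [B] onto a
    discrete support as a 'two-hot' distribution [B, bins].
    The support range/size here are assumed values."""
    support = torch.linspace(support_min, support_max, bins)
    step = (support_max - support_min) / (bins - 1)
    x = x.clamp(support_min, support_max)
    idx = ((x - support_min) / step).floor().long().clamp(max=bins - 2)
    upper_w = (x - support[idx]) / step          # weight on the upper bin
    dist = torch.zeros(*x.shape, bins)
    dist.scatter_(-1, idx.unsqueeze(-1), (1 - upper_w).unsqueeze(-1))
    dist.scatter_(-1, (idx + 1).unsqueeze(-1), upper_w.unsqueeze(-1))
    return dist

def distribution_to_scalar(dist, support_min=-30.0, support_max=30.0, bins=61):
    """Inverse mapping: expected value under the categorical distribution."""
    support = torch.linspace(support_min, support_max, bins)
    return (dist * support).sum(dim=-1)

def min_max_normalize(h, eps=1e-8):
    """Min-max normalization of the embedding before it enters the model:
    rescale each hidden state to [0, 1]."""
    h_min = h.min(dim=-1, keepdim=True).values
    h_max = h.max(dim=-1, keepdim=True).values
    return (h - h_min) / (h_max - h_min + eps)
```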
| Implementation Frameworks | |
| --- | --- |
| Hardware | Intel Xeon CPU, NVIDIA H100 GPU |
| Software | PyTorch, Python, MS nni, ... |
Model Usage & Limitations

TBD