
# Model card

This is the model card of this implementation.

I imitate some reference papers' model cards on this page. It may not be perfect, but it should be helpful for understanding the implementation comprehensively.

Sample model card of PaLM 2



## Model Summary

|  |  |
| --- | --- |
| **Summary** | Muesli is a model-based RL algorithm. It maps an RGB image to a probability distribution over the action space defined by the environment, and it learns from episode data collected through agent-environment interaction to maximize the cumulative reward. |
| **Input** | RGB image tensor; [channel = 3, height = 72 px, width = 96 px] |
| **Output** | Probability distribution vector; [action_space] |
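
A minimal shape sketch of this input/output contract; the batch size and the 4-action LunarLander-v2 action space are illustrative assumptions, not values fixed by the model card:

```python
import torch

batch_size, action_space = 128, 4                        # LunarLander-v2 has 4 discrete actions
obs = torch.rand(batch_size, 3, 72, 96)                  # RGB image tensor in [0, 1]
policy_logits = torch.randn(batch_size, action_space)    # stand-in for the network output
pi = torch.softmax(policy_logits, dim=-1)                # probability distribution over actions
assert torch.allclose(pi.sum(dim=-1), torch.ones(batch_size))
```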
## Model architecture

### Agent network

"Agent network" is not an official term; here it means the set of networks that are unrolled and optimized (it is also called the "learning network" or "online network"). It includes the networks below; a structural sketch follows this section.

- **Representation network**: the observation encoder, based on a CNN. (Main role: image -> hidden state)
- **Dynamics network**: an LSTM that infers the hidden states of future time steps, conditioned on the actions selected in the episode data. (Main role: hidden state -> future hidden states)
- **Policy network**: infers a probability distribution over actions from the hidden state. (Main role: hidden state -> softmaxed distribution)
- **Value network**: infers a probability distribution over the scalar value from the hidden state. (Main role: hidden state -> softmaxed distribution)
- **Reward network**: infers a probability distribution over the scalar reward from the hidden state. (Main role: hidden state -> softmaxed distribution)

### Target network

The target network has the same components as the agent network. Its parameters are an exponential-moving-average mixture of the agent network's previously updated parameters. It is used for the actor's environment interaction and for the learner's inference, except for unrolling the agent network.
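
A minimal structural sketch of the components listed above and of the EMA target update, assuming frames are stacked along the channel dimension and using illustrative layer sizes (the actual layer shapes are not specified on this page):

```python
import torch
import torch.nn as nn

class AgentNetwork(nn.Module):
    """Illustrative skeleton only; layer shapes and names are assumptions, not the repo's code."""
    def __init__(self, action_space: int, hidden: int = 512, support_size: int = 30):
        super().__init__()
        # Representation network: stacked RGB frames -> hidden state
        self.representation = nn.Sequential(
            nn.Conv2d(3 * 4, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Flatten(), nn.LazyLinear(hidden))
        # Dynamics network: (hidden state, action) -> next hidden state, via an LSTM cell
        self.dynamics = nn.LSTMCell(action_space, hidden)
        # Heads: hidden state -> logits (softmaxed elsewhere)
        self.policy = nn.Linear(hidden, action_space)
        self.value = nn.Linear(hidden, 2 * support_size + 1)
        self.reward = nn.Linear(hidden, 2 * support_size + 1)

def ema_update(target: nn.Module, online: nn.Module, alpha_target: float = 0.01) -> None:
    """Exponential-moving-average update of the target network parameters."""
    with torch.no_grad():
        for t, o in zip(target.parameters(), online.parameters()):
            t.mul_(1.0 - alpha_target).add_(o, alpha=alpha_target)
```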
## Loss function

- **PG loss**: the auxiliary policy-gradient loss (eq. 10 in the Muesli paper; `first_term` in the code).
- **Model loss**: the policy component of the model loss (eq. 13 in the Muesli paper), extended to the first time step, i.e. starting from k = 0 (`L_m` in the code).
- **Value model loss**: cross-entropy loss, as described in the MuZero paper's supplementary material (`L_v` in the code).
- **Reward model loss**: cross-entropy loss, as described in the MuZero paper's supplementary material (`L_r` in the code).
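
A sketch of how these components could be combined into a single objective, using the loss weights listed under Main hyperparameters; only `first_term`, `L_m`, `L_v`, and `L_r` mirror names used in the code, the rest is an assumption:

```python
def total_loss(first_term, L_m, L_v, L_r,
               value_loss_weight: float = 0.25, reward_loss_weight: float = 1.0):
    # Weighted sum of the four components above, all expressed as losses to minimize.
    return first_term + L_m + value_loss_weight * L_v + reward_loss_weight * L_r
```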
## Replay buffer

- **Sampling method**: uniform; randomly pick one sequence (5 transitions) from each randomly selected episode. This works empirically, but it is arbitrary and not the same as in the paper; it may change later. (See the sketch after this list.)
- **Replay proportion in a batch**: 75% off-policy data + 25% on-policy data.
- **Capacity**: not yet implemented.
- **Frame stacking at the episode start**: the first image is added (stacking_frame - 1) times before interaction begins.
- **Unrolling past the episode end**: zero elements are added (unroll_step + 1) times after the last interaction.

The storing methods are quite convoluted and not verified; they may need to be checked and improved.
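
A sketch of the uniform sampling described above, assuming `episodes` is a list of transition lists that are each at least 5 transitions long (the actual buffer layout in the repository may differ):

```python
import random

def sample_batch(episodes, batch_size: int = 128, seq_len: int = 5):
    batch = []
    for _ in range(batch_size):
        episode = random.choice(episodes)                     # pick an episode uniformly
        start = random.randrange(len(episode) - seq_len + 1)  # pick a start index uniformly
        batch.append(episode[start:start + seq_len])          # one sequence of 5 transitions
    return batch
```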
## Evaluation

- **LunarLander-v2**: currently just checks the cumulative returns while gathering data from agent-environment interaction.
- This needs improvement (e.g. averaging over more than 3 random seeds with controlled randomness).
## Techniques

- **Categorical reparametrization**: yes. Used for the value model and reward model (distribution <-> scalar). See the sketch after this list.
- **Advantage normalization**: yes.
- **Target network**: yes, updated by moving average.
- **Mixed prior policy**: yes, mixed with a 0.3% uniform distribution. (Mixing with 3% behavior policy is not used or verified due to my lack of knowledge.) Its role is a regulariser, as described on p. 6 of the Ada paper.
- **Stacking observations**: yes.
- **Min-max normalization**: yes, to [0, 1]. Used to normalize the embedding before policy, value, and reward inference.
- **β-LOO action-dependent baselines**: none.
- **Retrace**: not yet. (The target value should be estimated with the Retrace estimator, but this is not implemented yet.)
- **Vectorized environments**: not yet.
- **Distributed computing (actor-learner decomposition framework)**: not yet.
- **Pop-Art**: not yet.
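
A sketch of the scalar <-> distribution conversion used for the value and reward models, in the style of the categorical representation from the MuZero paper; the symmetric integer support over [-support_size, support_size] is an assumption about this implementation:

```python
import torch

def scalar_to_distribution(x: torch.Tensor, support_size: int = 30) -> torch.Tensor:
    """Spread each scalar over the two nearest support bins (two-hot encoding)."""
    x = x.clamp(-support_size, support_size)
    low = x.floor().long()
    frac = x - x.floor()
    dist = torch.zeros(*x.shape, 2 * support_size + 1)
    dist.scatter_(-1, (low + support_size).unsqueeze(-1), (1 - frac).unsqueeze(-1))
    high = (low + 1).clamp(max=support_size)
    dist.scatter_add_(-1, (high + support_size).unsqueeze(-1), frac.unsqueeze(-1))
    return dist

def distribution_to_scalar(dist: torch.Tensor, support_size: int = 30) -> torch.Tensor:
    """Expected value of the categorical distribution over the integer support."""
    support = torch.arange(-support_size, support_size + 1, dtype=dist.dtype)
    return (dist * support).sum(-1)
```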
## Main hyperparameters

Following Table 5 in the Muesli paper:

| Hyperparameter | Value |
| --- | --- |
| Batch size | 128 sequences |
| Sequence length | 5 frames |
| Model unroll length K | 4 |
| Replay proportion in a batch | 75% |
| Replay buffer capacity | not implemented (~ frames) |
| Initial learning rate | 3e-4 |
| Final learning rate | 0 |
| AdamW weight decay | 0 |
| Discount | 0.997 |
| Target network update rate (α_target) | 0.01 |
| Value loss weight | 0.25 |
| Reward loss weight | 1.0 |
| Retrace estimator samples | not yet |
| KL(π_CMPO, π) estimator samples | none (exact KL used) |
| Variance moving average decay (β_var) | 0.99 |
| Variance offset (ϵ_var) | 1e-12 |
### Added by this implementation

Not exhaustive; only the more important ones:

| Hyperparameter | Value |
| --- | --- |
| Learner iterations | 20 |
| resize_height | 72 |
| resize_width | 96 |
| Hidden state size (of the LSTM) | 512 |
| MLP hidden layer width | 128 |
| Support size (of the categorical reparametrization) | 30 |
| expriment_length (for LunarLander) | 20000 |
| Epsilon for the CE loss | 1e-12 (prevents NaN errors from taking the log of zero; not hyperparameter-tuned yet) |
| Stacking frames | 4 |
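
A hypothetical grouping of the hyperparameters above into one config dict; the key names are illustrative, and the repository may organize them differently:

```python
CONFIG = {
    "batch_size": 128, "sequence_length": 5, "unroll_length_K": 4,
    "replay_proportion": 0.75, "initial_lr": 3e-4, "final_lr": 0.0,
    "weight_decay": 0.0, "discount": 0.997, "alpha_target": 0.01,
    "value_loss_weight": 0.25, "reward_loss_weight": 1.0,
    "beta_var": 0.99, "eps_var": 1e-12,
    "resize_height": 72, "resize_width": 96, "lstm_hidden": 512,
    "mlp_width": 128, "support_size": 30, "stacking_frames": 4,
}
```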
## Optimization

- **Mini-batch**: yes.
- **Optimizer**: AdamW (weight_decay = 0).
- **Gradient clipping**: to [-1, 1].
- **Learning rate schedule**: decays to zero. (See the sketch below.)
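
A sketch of this optimization setup, assuming a linear learning-rate decay to zero over a known number of updates (only "decays to zero" is stated above, so the linear shape is an assumption):

```python
import torch

def make_optimizer(model, total_updates: int, lr: float = 3e-4):
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=0.0)
    scheduler = torch.optim.lr_scheduler.LambdaLR(
        optimizer, lambda step: max(0.0, 1.0 - step / total_updates))  # linear decay to zero
    return optimizer, scheduler

def optimizer_step(optimizer, scheduler, model, loss):
    optimizer.zero_grad()
    loss.backward()
    # element-wise gradient clipping to [-1, 1]
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)
    optimizer.step()
    scheduler.step()
```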
## Implementation frameworks

- **Hardware**: Intel Xeon CPU, NVIDIA H100 GPU
- **Software**: PyTorch, Python, MS NNI, ...

### Computing resource usage

- **GPU memory**: this version uses about 1.6 GB of VRAM per experiment.
- **CPU cores**: 1 core per experiment.
- **Main memory**: this version uses roughly 10 GB to 100 GB of RAM per experiment (possibly because the replay buffer has no capacity limit, or because of a memory leak in the code).
## Model usage & limitations
TBD