
# Model card

This is the model card of this implementation.

On this page I imitate a reference paper's model card format. It may not be perfect, but it should be helpful for understanding the implementation comprehensively.

It contains the model architecture details, the techniques used, the system configuration, the loss functions, and so on.

Sample model card of PaLM 2:

*(image: sample model card from the PaLM 2 report)*

WIP

## Model Summary

| Field | Description |
|---|---|
| Summary | Muesli is a model-based RL algorithm. It processes an RGB image into a probability distribution over the action space defined by the environment, and it learns from episode data generated by agent-environment interaction to maximize the cumulative reward. |
| Input | RGB image tensor; `[channels = 3, height = 72 px, width = 96 px]` |
| Output | Probability distribution vector; `[action_space]` |
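
As a quick illustration of this input/output contract, here is a minimal shape check; the batch size and `action_space = 4` (matching LunarLander-v2's discrete actions) are assumptions for the example, and the random logits stand in for the real network:

```python
import torch

# Assumed example values: the batch size is arbitrary; action_space = 4
# matches LunarLander-v2. The random logits stand in for the model's output.
batch, action_space = 8, 4
obs = torch.rand(batch, 3, 72, 96)         # RGB image tensor [3, 72, 96]
logits = torch.randn(batch, action_space)  # placeholder for the network output
policy = torch.softmax(logits, dim=-1)     # probability distribution vector
assert policy.shape == (batch, action_space)
assert torch.allclose(policy.sum(dim=-1), torch.ones(batch))
```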
## Model architecture

| Network | Role |
|---|---|
| Representation Network | The observation encoder, based on a CNN. (Main role: image → hidden state) |
| Dynamics Network | Infers the hidden states of future time steps with an LSTM, conditioned on the actions selected in the episode data. (Main role: hidden state → future hidden states) |
| Policy Network | Infers a probability distribution over actions from a hidden state. (Main role: hidden state → softmaxed distribution) |
| Value Network | Infers a probability distribution over the scalar value from a hidden state. (Main role: hidden state → softmaxed distribution) |
| Reward Network | Infers a probability distribution over the scalar reward from a hidden state. (Main role: hidden state → softmaxed distribution) |
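
The following is a minimal PyTorch sketch of these five components. The layer sizes (`hidden_dim`, `support_size`, the conv stack) are illustrative assumptions, not the repository's actual hyperparameters:

```python
import torch
import torch.nn as nn

hidden_dim, action_space, support_size = 256, 4, 601  # illustrative assumptions

# Representation Network: image -> hidden state
representation = nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=8, stride=4), nn.ReLU(),
    nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
    nn.Flatten(),
    nn.LazyLinear(hidden_dim),
)

# Dynamics Network: (one-hot action, hidden state) -> next hidden state, via LSTM
dynamics = nn.LSTMCell(input_size=action_space, hidden_size=hidden_dim)

# Policy / Value / Reward Networks: hidden state -> logits, softmaxed downstream
policy_head = nn.Linear(hidden_dim, action_space)
value_head = nn.Linear(hidden_dim, support_size)
reward_head = nn.Linear(hidden_dim, support_size)

obs = torch.rand(1, 3, 72, 96)
h = representation(obs)                                  # image -> hidden state
a = nn.functional.one_hot(torch.tensor([0]), action_space).float()
h_next, c_next = dynamics(a, (h, torch.zeros_like(h)))   # unroll one step
pi = torch.softmax(policy_head(h), dim=-1)               # action distribution
```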
## Loss function

| Loss | Description |
|---|---|
| PG loss | The auxiliary policy-gradient loss (eq. 10 in the Muesli paper). |
| Model loss | The policy component of the model loss (eq. 13 in the Muesli paper), extended to the first time step (starting from k = 0). |
| Value model loss | Cross-entropy loss for the value model, which infers the value of the corresponding state; as described in the MuZero paper's supplementary materials. |
| Reward loss | Cross-entropy loss for the reward model, which infers the reward of the corresponding state transition; as described in the MuZero paper's supplementary materials. |
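
Below is a hedged sketch of the cross-entropy used by the value and reward models, assuming a MuZero-style discrete support with two-hot scalar targets; the support range is an illustrative assumption:

```python
import torch
import torch.nn.functional as F

def two_hot(scalar: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    """Project a scalar onto the two nearest atoms of a discrete support
    (the categorical target construction from the MuZero supplementary)."""
    scalar = scalar.clamp(support[0], support[-1])
    idx = torch.searchsorted(support, scalar, right=True).clamp(1, len(support) - 1)
    lo, hi = support[idx - 1], support[idx]
    w_hi = (scalar - lo) / (hi - lo)          # weight on the upper atom
    probs = torch.zeros(*scalar.shape, len(support))
    probs.scatter_(-1, idx.unsqueeze(-1), w_hi.unsqueeze(-1))
    probs.scatter_(-1, (idx - 1).unsqueeze(-1), (1.0 - w_hi).unsqueeze(-1))
    return probs

def distributional_ce(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between a target distribution and predicted logits."""
    return -(target * F.log_softmax(logits, dim=-1)).sum(dim=-1).mean()

support = torch.linspace(-30.0, 30.0, 61)   # illustrative support range
target = two_hot(torch.tensor([2.3]), support)
loss = distributional_ce(torch.randn(1, 61), target)
```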
## Evaluation

| Benchmark | Notes |
|---|---|
| LunarLander-v2 | Checking cumulative returns while gathering data from agent-environment interaction. |
| TBD | TBD |
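
A minimal sketch of that evaluation loop, assuming the gymnasium API; the random action is a stand-in for the trained agent's policy:

```python
import gymnasium as gym

env = gym.make("LunarLander-v2")
obs, _ = env.reset()
episode_return, done = 0.0, False
while not done:
    action = env.action_space.sample()  # stand-in for the trained agent's action
    obs, reward, terminated, truncated, _ = env.step(action)
    episode_return += reward            # cumulative (undiscounted) return
    done = terminated or truncated
print(f"episode return: {episode_return:.1f}")
env.close()
```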
## Implementation Frameworks

| Component | Details |
|---|---|
| Hardware | Intel Xeon, NVIDIA H100 GPU |
| Software | PyTorch, Python, MS nni, ... |