Model Card
This is the model card of this implementation.
On this page I imitate the model card format of some reference papers. It may not be perfect, but it should be helpful for understanding the implementation comprehensively.
It contains model architecture details, the techniques used, the system configuration, the loss functions, etc.
Reference: the sample model card in the PaLM 2 technical report.
(WIP)
| Model Summary | |
| --- | --- |
| Summary | Muesli is a model-based RL algorithm. It maps an RGB image to a probability distribution over the action space defined by the environment. It learns from episode data collected through agent-environment interaction so as to maximize the cumulative reward. |
| Input | RGB image tensor; [channels = 3, height = 72 px, width = 96 px] |
| Output | Probability distribution vector; [action_space] |
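
As a quick orientation, here is a minimal shape sketch of the input and output above. The 12 input channels assume the 4-frame observation stacking listed under Techniques below (4 × 3 RGB channels), and the action count assumes LunarLander-v2; both are assumptions, not definitions from the code.

```python
import torch

# Shape sketch only; 12 channels = 4 stacked RGB frames (an assumption based
# on the "Stacking observations: 4" entry under Techniques below).
obs = torch.rand(1, 4 * 3, 72, 96)      # [batch, channels, height, width]
action_space = 4                        # LunarLander-v2 has 4 discrete actions
policy = torch.softmax(torch.zeros(1, action_space), dim=-1)  # output shape
print(obs.shape, policy.shape)
```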
| Model architecture | |
| --- | --- |
| Representation Network | CNN-based observation encoder. (Main role: image -> hidden state) |
| Dynamics Network | Infers the hidden states of future time steps, conditioned on the actions selected in the episode data, using an LSTM. (Main role: hidden state -> future hidden states) |
| Policy Network | Infers a probability distribution over actions from the hidden state. (Main role: hidden state -> softmaxed distribution) |
| Value Network | Infers a probability distribution representing the scalar value from the hidden state. (Main role: hidden state -> softmaxed distribution) |
| Reward Network | Infers a probability distribution representing the scalar reward from the hidden state. (Main role: hidden state -> softmaxed distribution) |
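
A minimal PyTorch sketch of how these five networks could fit together. This is not the repository's actual module definition; the conv stack, `hidden_dim`, the categorical support size, and the one-hot action encoding are all illustrative assumptions.

```python
import torch
import torch.nn as nn

class MuesliNets(nn.Module):
    """Illustrative skeleton of the five components; all sizes are assumptions."""
    def __init__(self, action_space: int, hidden_dim: int = 256, support: int = 601):
        super().__init__()
        # Representation Network: stacked RGB frames -> hidden state
        self.representation = nn.Sequential(
            nn.Conv2d(12, 32, kernel_size=8, stride=4), nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2), nn.ReLU(),
            nn.Flatten(), nn.LazyLinear(hidden_dim),
        )
        # Dynamics Network: (hidden state, one-hot action) -> next hidden state
        self.dynamics = nn.LSTMCell(action_space, hidden_dim)
        # Heads: hidden state -> logits, softmaxed downstream
        self.policy = nn.Linear(hidden_dim, action_space)
        self.value = nn.Linear(hidden_dim, support)   # categorical value head
        self.reward = nn.Linear(hidden_dim, support)  # categorical reward head

    def initial_inference(self, obs):
        h = self.representation(obs)
        c = torch.zeros_like(h)           # fresh LSTM cell state
        return (h, c), self.policy(h), self.value(h)

    def recurrent_inference(self, state, action_onehot):
        h, c = self.dynamics(action_onehot, state)
        return (h, c), self.policy(h), self.value(h), self.reward(h)
```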
| Loss function | |
| --- | --- |
| PG loss | Auxiliary policy-gradient loss (eq. 10 in the Muesli paper; `first_term` in the code). |
| Model loss | Policy component of the model loss (eq. 13 in the Muesli paper, extended to the first time step, i.e. starting from k=0; `L_m` in the code). |
| Value model loss | Cross-entropy loss, as described in the MuZero paper's supplementary materials (`L_v` in the code). |
| Reward model loss | Cross-entropy loss, as described in the MuZero paper's supplementary materials (`L_r` in the code). |
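
The sketch below shows, under assumptions, how the two cross-entropy terms can be computed and how the four terms might be combined. The names `first_term`, `L_m`, `L_v`, and `L_r` follow the table; the uniform weighting of the sum is an assumption, and the estimators behind `first_term` and `L_m` (clipped advantages, CMPO targets) are elided.

```python
import torch
import torch.nn.functional as F

def categorical_ce(pred_logits: torch.Tensor, target_dist: torch.Tensor) -> torch.Tensor:
    """Cross-entropy between a predicted distribution over the value/reward
    support and a target distribution obtained by projecting the scalar
    target onto the same support (MuZero-style)."""
    return -(target_dist * F.log_softmax(pred_logits, dim=-1)).sum(-1).mean()

def total_loss(first_term, L_m, L_v, L_r):
    # first_term: auxiliary policy-gradient loss (eq. 10)
    # L_m: policy component of the model loss (eq. 13), summed from k = 0
    # L_v / L_r: categorical_ce of the value / reward heads
    return first_term + L_m + L_v + L_r   # uniform weighting is an assumption
```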
| Replay Buffer | |
| --- | --- |
| Sampling method | Uniform |
| Replay proportion in a batch | 75% off-policy data + 25% on-policy data |
| Capacity | Not yet implemented |
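
A hedged sketch of the 75% / 25% batch composition described above. Here `replay_buffer` (a list of stored sequences) and `online_queue` (freshly collected sequences) are hypothetical containers, not the repository's actual data structures.

```python
import random

def sample_mixed_batch(replay_buffer, online_queue, batch_size=128):
    n_replay = (batch_size * 3) // 4                 # 96 off-policy sequences
    n_online = batch_size - n_replay                 # 32 on-policy sequences
    batch = random.sample(replay_buffer, n_replay)   # uniform sampling
    batch += [online_queue.pop() for _ in range(n_online)]
    return batch
```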
| Evaluation | |
| --- | --- |
| LunarLander-v2 | Currently just tracks the cumulative return of each episode while gathering data from agent-environment interaction. |
| TBD | Needs to be improved. |
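
A minimal sketch of that evaluation loop, assuming the gymnasium API; the random action is a stand-in for sampling from the policy network's output distribution.

```python
import gymnasium as gym

env = gym.make("LunarLander-v2")
obs, _ = env.reset()
episode_return, done = 0.0, False
while not done:
    action = env.action_space.sample()  # stand-in for the learned policy
    obs, reward, terminated, truncated, _ = env.step(action)
    episode_return += reward            # cumulative return being tracked
    done = terminated or truncated
print(f"cumulative return: {episode_return:.1f}")
```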
| Main Hyperparameters | |
| --- | --- |
| Mini-batch size | 128 |
| Optimization | |
| --- | --- |
| Optimizer | Adam |
| Gradient clipping | Clipped to [-1, 1] |
| Learning rate schedule | Decays to zero |
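
A sketch of these settings in PyTorch. The base learning rate, the linear shape of the decay, and the names `model`, `loss`, and `total_updates` are placeholders/assumptions; the [-1, 1] clipping is implemented as element-wise value clipping.

```python
import torch

optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)  # base lr assumed
scheduler = torch.optim.lr_scheduler.LambdaLR(             # decay to zero;
    optimizer, lambda step: 1.0 - step / total_updates)    # linear is assumed

loss.backward()
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=1.0)  # [-1, 1]
optimizer.step()
scheduler.step()
optimizer.zero_grad()
```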
| Techniques | |
| --- | --- |
| Categorical reparametrization | Used for the value model and reward model (distribution <-> scalar). |
| Retrace | Not yet (the target value should be estimated by the Retrace estimator, but this is not yet implemented). |
| Advantage normalization | Yes |
| Target network | Yes, updated as a moving average of the online network. |
| Pop-Art | Not yet |
| Stacking observations | 4 frames |
| Min-max normalization | Applied to the embedding before the model. |
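
Two of these techniques are easy to sketch. Below is a hedged, MuZero-style implementation of the categorical reparametrization (scalar <-> distribution over a discrete support) and the moving-average target-network update; the support range and the mixing rate `tau` are assumptions.

```python
import torch

def scalar_to_dist(x: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    """Project a batch of scalars onto the two neighbouring support atoms."""
    x = x.clamp(support[0], support[-1])
    idx = torch.searchsorted(support, x.contiguous()).clamp(1, len(support) - 1)
    lo, hi = support[idx - 1], support[idx]
    w_hi = (x - lo) / (hi - lo)                       # linear interpolation
    dist = torch.zeros(*x.shape, len(support))
    dist.scatter_(-1, idx.unsqueeze(-1), w_hi.unsqueeze(-1))
    dist.scatter_(-1, (idx - 1).unsqueeze(-1), (1.0 - w_hi).unsqueeze(-1))
    return dist

def dist_to_scalar(dist: torch.Tensor, support: torch.Tensor) -> torch.Tensor:
    return (dist * support).sum(-1)                   # expectation over support

@torch.no_grad()
def update_target(target_net, online_net, tau: float = 0.01):
    """Moving-average (EMA) update of the target network's parameters."""
    for t, o in zip(target_net.parameters(), online_net.parameters()):
        t.mul_(1.0 - tau).add_(tau * o)

support = torch.linspace(-30.0, 30.0, 601)            # range is an assumption
print(dist_to_scalar(scalar_to_dist(torch.tensor([2.5]), support), support))
```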
| Implementation Frameworks | |
| --- | --- |
| Hardware | Intel Xeon CPU, NVIDIA H100 GPU |
| Software | PyTorch, Python, MS nni, ... |
| Model Usage & Limitations | |
| --- | --- |
| TBD | |
| TBD | |