Model Card
Itomigna2 edited this page Mar 12, 2024
·
11 revisions
This is the model card for this implementation.
It follows the format of the model cards found in reference papers. It may not be perfect, but it should help readers understand the implementation as a whole.
It covers the model architecture details, the techniques used, the system configuration, the loss functions, and so on.
Sample model card of PaLM 2
(WIP)
Model Summary | |
---|---|
Summary | Muesli is a model-based RL algorithm. This implementation maps an RGB image to a probability distribution over the action space defined by the environment. It learns from episode data collected through agent-environment interaction to maximize the cumulative reward. |
Input | RGB image tensor; [channel = 3, height = 72 pixels, width = 96 pixels] |
Output | Probability distribution vector; [action_space] |
Model architecture | |
---|---|
Agent Network | |
This is not an official term; it refers to the networks that are unrolled and optimized. It includes the networks below. | |
Representation Network | The CNN-based observation encoder. (Main role: image -> hidden state) |
Dynamics Network | Uses an LSTM to infer the hidden states of future time steps, conditioned on the actions selected in the episode data. (Main role: hidden state -> future hidden states) |
Policy Network | Infers a probability distribution over actions from the hidden state. (Main role: hidden state -> softmax distribution) |
Value Network | Infers a probability distribution that encodes the scalar value from the hidden state. (Main role: hidden state -> softmax distribution) |
Reward Network | Infers a probability distribution that encodes the scalar reward from the hidden state. (Main role: hidden state -> softmax distribution) |
Target Network | |
It has the same elements as the agent network. Its parameters are an exponential moving average of the agent network's previously updated parameters. It is used for the actor's environment interaction and for the learner's inference, except when unrolling the agent network. |
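
The exponential moving average update of the target network can be sketched as follows. This is a minimal illustration, not the repository's code: the function name `ema_update`, the plain-dict parameter format, and the mixing coefficient `alpha` are all assumptions, since this page does not state the coefficient that is actually used.

```python
def ema_update(target_params, agent_params, alpha):
    """Move each target parameter a fraction `alpha` toward the agent's value.

    `alpha` is a hypothetical mixing coefficient; parameters are plain floats
    here so that the sketch runs without any deep learning framework.
    """
    return {
        name: (1.0 - alpha) * target_params[name] + alpha * agent_params[name]
        for name in target_params
    }


# Toy example with two scalar "parameters".
target = {"w": 0.0, "b": 1.0}
agent = {"w": 1.0, "b": 3.0}
target = ema_update(target, agent, alpha=0.1)
```

With a small `alpha` the target network changes slowly, which is the usual reason for acting and bootstrapping from it instead of from the freshly updated agent network.
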
Loss function | |
---|---|
PG loss | Auxiliary policy-gradient loss. (Eq. 10 in the Muesli paper) (first_term in the code) |
Model loss | The policy component of the model loss. (Eq. 13 in the Muesli paper, extended to the first time step so that it starts from k=0) (L_m in the code) |
Value model loss | Cross-entropy loss. (Same as described in the MuZero paper's supplementary materials) (L_v in the code) |
Reward model loss | Cross-entropy loss. (Same as described in the MuZero paper's supplementary materials) (L_r in the code) |
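
To make the cross-entropy value and reward losses more concrete, here is a minimal sketch of how a scalar target can be turned into a categorical target and compared with predicted logits. The helper names (`two_hot`, `cross_entropy`), the support `[-2, 2]`, and the example numbers are assumptions for illustration only. The invertible scaling transform that the MuZero paper applies before this step is left out, and the code in this repository may differ in its details.

```python
import math


def two_hot(x, support):
    """Encode scalar x as a distribution over a sorted support.

    The probability mass is split between the two neighbouring bins so that
    the expectation of the distribution equals x (after clamping to the support).
    """
    x = min(max(x, support[0]), support[-1])
    probs = [0.0] * len(support)
    for i in range(len(support) - 1):
        lo, hi = support[i], support[i + 1]
        if lo <= x <= hi:
            w_hi = (x - lo) / (hi - lo)
            probs[i] += 1.0 - w_hi
            probs[i + 1] += w_hi
            break
    return probs


def log_softmax(logits):
    """Numerically stable log-softmax."""
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def cross_entropy(target_probs, logits):
    """Cross-entropy between a target distribution and softmax(logits)."""
    return -sum(t * lp for t, lp in zip(target_probs, log_softmax(logits)))


support = [-2.0, -1.0, 0.0, 1.0, 2.0]       # hypothetical support
target = two_hot(0.25, support)             # mass on the bins 0.0 and 1.0
loss = cross_entropy(target, [0.0] * 5)     # uniform prediction
```
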
Replay Buffer | |
---|---|
Sampling method | Uniform: randomly pick 1 sequence (5 transitions) from each randomly selected episode. This works empirically, but it is arbitrary and differs from the paper; it may be changed later. |
Replay proportion in a batch | Off-policy data 75% + on-policy data 25% |
Capacity | Not yet implemented |
Treating frame stacking | Add the start image (stacking_frame-1) times before the start of the interaction. |
Treating unroll over episode length | Add zero elements (unroll_step+1) times after the end of the interaction. |
The storing methods are fairly convoluted and have not been verified. They need to be checked and improved. |
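
The padding and sampling rules in the table above can be sketched roughly as follows. Everything here is a simplified assumption rather than the repository's storing code: an episode is just a list of per-step feature lists, and the function names, the `replay_fraction` argument, and the example sizes are hypothetical. The sketch also assumes every episode is at least `seq_len` steps long.

```python
import random


def pad_episode(frames, stacking_frame, unroll_step):
    """Repeat the first frame (stacking_frame - 1) times at the front and
    append (unroll_step + 1) zero elements at the end."""
    zero = [0.0] * len(frames[0])
    front = [frames[0]] * (stacking_frame - 1)
    back = [zero] * (unroll_step + 1)
    return front + list(frames) + back


def sample_sequence(episodes, seq_len, rng):
    """Pick one episode uniformly, then one window of seq_len steps from it."""
    episode = rng.choice(episodes)
    start = rng.randrange(len(episode) - seq_len + 1)
    return episode[start:start + seq_len]


def make_batch(replay_episodes, fresh_episodes, batch_size, seq_len, rng,
               replay_fraction=0.75):
    """Mix off-policy (replay) and on-policy (fresh) sequences in one batch."""
    n_replay = int(batch_size * replay_fraction)
    batch = [sample_sequence(replay_episodes, seq_len, rng)
             for _ in range(n_replay)]
    batch += [sample_sequence(fresh_episodes, seq_len, rng)
              for _ in range(batch_size - n_replay)]
    return batch


rng = random.Random(0)
episode = [[float(t)] for t in range(1, 11)]   # 10 one-element "frames"
padded = pad_episode(episode, stacking_frame=4, unroll_step=5)
batch = make_batch([padded], [padded], batch_size=128, seq_len=5, rng=rng)
```
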
Evaluation | |
---|---|
LunarLander-v2 | Only the cumulative returns are checked while gathering data from agent-environment interaction. |
Needs improvement (results should probably be averaged over more than 3 random seeds with controlled randomness). |
Techniques | |
---|---|
Categorical reparametrization | Used for value model and reward model. (distribution <-> scalar) |
Retrace | Not yet. (The target value should be estimated with the Retrace estimator, but this is not yet implemented.) |
Advantage normalization | Yes. |
Target Network | Yes. Exponential moving average update. |
Mixed prior policy | Yes. Mixed with 0.03% uniform distribution. (Mixing with 0.3% behavior policy is neither used nor verified, due to my lack of knowledge.) It acts as a regulariser, as described on p. 6 of the Ada paper. |
Pop-Art | Not yet. |
Stacking observations | Yes. |
Min-max normalization | Yes. [0,1]. It is applied to the embedding before inferring the policy, value, and reward. |
Vectorized environment | Not yet. |
Distributed computing (actor-learner decomposition framework) | Not yet. |
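
Two of the techniques listed above are small enough to sketch directly: mixing the prior policy with a uniform distribution, and min-max normalizing an embedding to [0,1]. This is a hedged illustration only. The function names are hypothetical, plain Python lists stand in for tensors, and returning zeros for a constant embedding is an assumption made here to avoid division by zero.

```python
def mix_with_uniform(policy, uniform_weight):
    """Blend a policy with the uniform distribution over the same actions.

    uniform_weight = 0.0003 corresponds to the 0.03% stated in the table.
    """
    n = len(policy)
    return [(1.0 - uniform_weight) * p + uniform_weight / n for p in policy]


def min_max_normalize(hidden):
    """Rescale a vector linearly so that its minimum is 0 and its maximum is 1."""
    lo, hi = min(hidden), max(hidden)
    if hi == lo:
        # Assumption: a constant vector is mapped to all zeros.
        return [0.0 for _ in hidden]
    return [(h - lo) / (hi - lo) for h in hidden]


# A deterministic policy gets a small non-zero probability on every action.
prior = mix_with_uniform([1.0, 0.0, 0.0, 0.0], uniform_weight=0.0003)

# A toy embedding is mapped into [0, 1].
normalized = min_max_normalize([-2.0, 0.0, 6.0])
```

Mixing with the uniform distribution keeps every action's prior probability strictly positive, which is consistent with the regularising role described in the table.
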
Implementation Frameworks | |
---|---|
Hardware | Intel Xeon, NVIDIA H100 GPU |
Software | PyTorch, Python, MS nni, ... |
Main Hyperparameters | |
---|---|
Mini-batch size | 128 |
Learner iteration | 20 |
Discount rate | 0.997 |
Learning rate | 0.0003 |
Optimization | |
---|---|
Optimizer | Adam |
Gradient clipping | [-1,1] |
Learning rate schedule | Decay to zero |
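
The schedule and clipping settings above can be illustrated with a small framework-free sketch. Two points are assumptions, because the table does not specify them: the decay is taken to be linear in the step count, and the clipping is taken to be element-wise on gradient values (not on the gradient norm). The function names and `total_steps` are hypothetical.

```python
def linear_decay_lr(base_lr, step, total_steps):
    """Learning rate that falls linearly from base_lr at step 0 to 0 at total_steps."""
    remaining = max(0.0, 1.0 - step / total_steps)
    return base_lr * remaining


def clip_gradients(grads, limit=1.0):
    """Clamp each gradient value into [-limit, limit]."""
    return [max(-limit, min(limit, g)) for g in grads]


base_lr = 0.0003                                   # learning rate from the hyperparameter table
lr_start = linear_decay_lr(base_lr, 0, 1000)       # full learning rate
lr_mid = linear_decay_lr(base_lr, 500, 1000)       # half of it
lr_end = linear_decay_lr(base_lr, 1000, 1000)      # zero
clipped = clip_gradients([-3.0, 0.5, 2.0])
```

In PyTorch, similar behaviour is available through `torch.optim.lr_scheduler.LambdaLR` and `torch.nn.utils.clip_grad_value_`; whether this implementation uses those exact utilities is not stated on this page.
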
Model Usage & Limitations | |
---|---|
TBD | |
TBD |