-
Notifications
You must be signed in to change notification settings - Fork 376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(zp): add dreamerv3 algorithm #652
Conversation
Codecov Report
@@ Coverage Diff @@
## main #652 +/- ##
==========================================
+ Coverage 82.06% 82.47% +0.41%
==========================================
Files 586 586
Lines 47515 48047 +532
==========================================
+ Hits 38991 39626 +635
+ Misses 8524 8421 -103
Flags with carried forward coverage won't be shown. Click here to find out more.
|
ding/world_model/utils.py
Outdated
return -loss | ||
|
||
|
||
class ContDist: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
polish all the class name of distributions we implemented, unify their name format
ding/world_model/utils.py
Outdated
|
||
class SymlogDist(): | ||
|
||
def __init__(self, mode, dist='mse', agg='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add python typing lint
ding/world_model/utils.py
Outdated
|
||
def mode(self): | ||
_mode = torch.round(self._dist.mean) | ||
return _mode.detach() + self._dist.mean - self._dist.mean.detach() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why use detach here
ding/model/common/encoder.py
Outdated
class Conv2dSame(torch.nn.Conv2d): | ||
|
||
def calc_same_pad(self, i, k, s, d): | ||
return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add data type for i,k,s,d
Add function notation
ding/model/common/encoder.py
Outdated
class DreamerLayerNorm(nn.Module): | ||
|
||
def __init__(self, ch, eps=1e-03): | ||
super(DreamerLayerNorm, self).__init__() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add data type
import torch | ||
from torch import nn | ||
import torch.nn.functional as F | ||
from torch import distributions as torchd |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
please do not use torchd as short module name. It's not a standard coding format.
It's better to from torch.distributions import XXXXX.
Import only what you need.
ding/world_model/utils.py
Outdated
|
||
class TwoHotDistSymlog(): | ||
|
||
def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
device is set by default. It should not be explicitly set by some argument.
Description
cartpole_balance:
walker_walk:
Related Issue
#669
TODO
Check List