-
Notifications
You must be signed in to change notification settings - Fork 376
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor(lyd): refactor dt_policy in new pipeline and add img input support #693
Conversation
@@ -27,7 +27,7 @@ | |||
embed_dim=128, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove old decision transformer config
self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1) | ||
) | ||
|
||
self.max_env_score = -1.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove this
self.states = torch.zeros((self.eval_batch_size, self.max_eval_ep_len,) + tuple(self.state_dim), dtype=torch.float32, device=self.device) | ||
self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)] | ||
else: | ||
self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
remove rtg_scale
argument
return timesteps, states, actions, rtgs, traj_mask | ||
|
||
|
||
class FixedReplayBuffer(object): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we merge this class into the above class?
Refactor DT to new pipeline.
Add img input support for atari.