-
Notifications
You must be signed in to change notification settings - Fork 387
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
feature(zjow): add middleware for ape-x structure pipeline #696
Conversation
@@ -54,7 +54,10 @@ def __init__( | |||
|
|||
def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData: | |||
if meta is None: | |||
meta = {'priority': self.max_priority} | |||
if 'priority' in data: | |||
meta = {'priority': data['priority'].item()} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do item operation in original priority producer
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
meta = {'priority': self.max_priority} | ||
if 'priority' in data: | ||
meta = {'priority': data['priority'].item()} | ||
else: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add unittest for this new if-else branch
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
Parallel.runner(n_parallel_workers=2, startup_interval=0)(periodical_model_exchanger_main) | ||
|
||
|
||
if __name__ == "__main__": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
unittest doesn't need this entry
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
|
||
priority = func_for_priority_calculation(ctx.trajectories) | ||
for i in range(len(priority)): | ||
ctx.trajectories[i]['priority'] = torch.tensor(priority[i], dtype=torch.float32).unsqueeze(-1) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
don't transform it to tensor
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
|
||
def _priority_calculator(ctx: "OnlineRLContext") -> None: | ||
|
||
priority = func_for_priority_calculation(ctx.trajectories) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
priority_calculation_fn
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
from ding.framework import OnlineRLContext | ||
|
||
|
||
def priority_calculator(func_for_priority_calculation: Callable) -> Callable: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
add unittest for this file
self._model_loader = model_loader | ||
self._event_name = event_name | ||
self._period = period | ||
self.mode = mode |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
why not underline here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
|
||
def _cache_state_dict(self, msg: Dict[str, Any]): | ||
# msg: Dict {'id':id,'model':state_dict: Union[object, Storage]} | ||
print(f"node_id[{task.router.node_id}] get model msg") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
use logging
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
fixed
Codecov Report
@@ Coverage Diff @@
## main #696 +/- ##
==========================================
- Coverage 82.06% 80.79% -1.27%
==========================================
Files 586 597 +11
Lines 47515 49631 +2116
==========================================
+ Hits 38991 40101 +1110
- Misses 8524 9530 +1006
Flags with carried forward coverage won't be shown. Click here to find out more.
|
Add middleware priority_calculator for calculate priority in collector.
Add middleware PeriodicalModelExchanger for better control model send and receive process.