-
Notifications
You must be signed in to change notification settings - Fork 387
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feature(zjow): add middleware for ape-x structure pipeline (#696)
* Add priority collected in collector; Add Periodical model exchanger middleware * polish code
- Loading branch information
Showing
11 changed files
with
318 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from typing import TYPE_CHECKING, Callable | ||
from ding.framework import task | ||
if TYPE_CHECKING: | ||
from ding.framework import OnlineRLContext | ||
|
||
|
||
def priority_calculator(priority_calculation_fn: Callable) -> Callable:
    """
    Overview:
        The middleware that calculates the priority of the collected data and writes it back
        onto the collected trajectories.
    Arguments:
        - priority_calculation_fn (:obj:`Callable`): The function that calculates the priority of the
            collected data. It is called once per middleware invocation with ``ctx.trajectories`` and
            must return an iterable whose i-th item is the priority of ``ctx.trajectories[i]``.
    Returns:
        - middleware (:obj:`Callable`): A function taking the context; it stores each computed priority
            under the ``'priority'`` key of the trajectory at the same index. When the task router is
            active and the current task does not have the collector role, ``task.void()`` is returned
            instead.
    """

    # Role gate: with an active router, only tasks holding the COLLECTOR role get the real middleware.
    # NOTE(review): ``task.void()`` is presumably a no-op placeholder middleware - confirm in ding.framework.
    if task.router.is_active and not task.has_role(task.role.COLLECTOR):
        return task.void()

    def _priority_calculator(ctx: "OnlineRLContext") -> None:
        priorities = priority_calculation_fn(ctx.trajectories)
        # Pair each priority with the trajectory at the same position; a result longer than
        # ``ctx.trajectories`` still raises IndexError, exactly as index-based access did.
        for index, value in enumerate(priorities):
            ctx.trajectories[index]['priority'] = value

    return _priority_calculator
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
# Unit test for the priority_calculator middleware | ||
|
||
import unittest | ||
import pytest | ||
import numpy as np | ||
from unittest.mock import Mock, patch | ||
from ding.framework import OnlineRLContext, OfflineRLContext | ||
from ding.framework import task, Parallel | ||
from ding.framework.middleware.functional import priority_calculator | ||
|
||
|
||
class MockPolicy(Mock):
    """Mock-based policy stand-in that exposes a random priority function for the test."""

    def priority_fun(self, data):
        """Return one value drawn by ``np.random.rand`` for every item in ``data``."""
        sample_count = len(data)
        return np.random.rand(sample_count)
|
||
|
||
@pytest.mark.unittest
def test_priority_calculator():
    """Check that the middleware attaches a float ``'priority'`` to every trajectory in the context."""
    mock_policy = MockPolicy()
    context = OnlineRLContext()

    # Build ten random transitions; the field order matches the random draws made per transition.
    transitions = []
    for _ in range(10):
        transitions.append(
            {
                'obs': np.random.rand(2, 2),
                'next_obs': np.random.rand(2, 2),
                'reward': np.random.rand(1),
                'info': {},
            }
        )
    context.trajectories = transitions

    middleware = priority_calculator(priority_calculation_fn=mock_policy.priority_fun)
    middleware(context)

    # The middleware must neither add nor drop trajectories, only annotate them.
    assert len(context.trajectories) == 10
    assert all(isinstance(item['priority'], float) for item in context.trajectories)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters