|
19 | 19 | from typing import cast
|
20 | 20 |
|
21 | 21 | from absl.testing import absltest
|
22 |
| -import gym |
23 | 22 | from iris import checkpoint_util
|
24 | 23 | from iris import coordinator
|
25 | 24 | from iris.algorithms import ars_algorithm
|
26 |
| -from iris.policies import nn_policy |
27 |
| -from iris.workers import rl_worker |
| 25 | +from iris.workers import simple_worker |
28 | 26 | import launchpad as lp
|
29 | 27 | from ml_collections import config_dict
|
30 | 28 | import numpy as np
|
|
35 | 33 | _TEST_CHECKPOINT = "./testdata/test_checkpoint.pkl"
|
36 | 34 |
|
37 | 35 |
|
38 |
| -class TestEnv(gym.Env): |
39 |
| - |
40 |
| - def __init__(self): |
41 |
| - self._ac_dim = 6 |
42 |
| - self._ob_dim = 14 |
43 |
| - self.action_space = gym.spaces.Box( |
44 |
| - -1 * np.ones(self._ac_dim), np.ones(self._ac_dim), dtype=np.float32 |
45 |
| - ) |
46 |
| - self.observation_space = gym.spaces.Box( |
47 |
| - -1 * np.ones(self._ob_dim), np.ones(self._ob_dim), dtype=np.float32 |
48 |
| - ) |
49 |
| - |
50 |
| - def step(self, action): |
51 |
| - del action |
52 |
| - return np.zeros(self._ob_dim), 1.0, False, {} |
53 |
| - |
54 |
| - def reset(self): |
55 |
| - return np.zeros(self._ob_dim) |
56 |
| - |
57 |
| - def render(self, mode: str = "rgb_array"): |
58 |
| - return np.zeros((16, 16)) |
59 |
| - |
60 |
| - |
61 | 36 | def make_bb_program(
|
62 | 37 | num_workers: int,
|
63 | 38 | num_eval_workers: int,
|
@@ -154,17 +129,14 @@ def setUp(self):
|
154 | 129 | eval_rate=1,
|
155 | 130 | num_iterations=400,
|
156 | 131 | num_evals_per_suggestion=1,
|
157 |
| - record_video_during_eval=True, |
158 | 132 | )
|
159 | 133 | ),
|
160 | 134 | worker=config_dict.ConfigDict(
|
161 | 135 | dict(
|
162 |
| - worker_class=rl_worker.RLWorker, |
| 136 | + worker_class=simple_worker.SimpleWorker, |
163 | 137 | worker_args=dict(
|
164 |
| - env=TestEnv, |
165 |
| - policy=nn_policy.FullyConnectedNeuralNetworkPolicy, |
166 |
| - policy_args=dict(hidden_layer_sizes=[64, 64]), |
167 |
| - rollout_length=20, |
| 138 | + blackbox_function=np.sum, |
| 139 | + initial_params=np.zeros(shape=(10,)), |
168 | 140 | ),
|
169 | 141 | )
|
170 | 142 | ),
|
|
0 commit comments