Skip to content

Commit e5bd150

Browse files
xingyousongcopybara-github
authored andcommitted
Turn on Github CI for coordinator_test (a rough approximation to running whole package E2E)
PiperOrigin-RevId: 728934650
1 parent 47db8f5 commit e5bd150

File tree

3 files changed

+6
-34
lines changed

3 files changed

+6
-34
lines changed

.github/workflows/core_test.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,4 +35,4 @@ jobs:
3535
pip freeze
3636
- name: Test with pytest # TODO(team): Fix tensorflow version conflict.
3737
run: |
38-
# pytest -n auto iris
38+
pytest -n auto iris/coordinator_test.py

iris/coordinator_test.py

+4-32
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,10 @@
1919
from typing import cast
2020

2121
from absl.testing import absltest
22-
import gym
2322
from iris import checkpoint_util
2423
from iris import coordinator
2524
from iris.algorithms import ars_algorithm
26-
from iris.policies import nn_policy
27-
from iris.workers import rl_worker
25+
from iris.workers import simple_worker
2826
import launchpad as lp
2927
from ml_collections import config_dict
3028
import numpy as np
@@ -35,29 +33,6 @@
3533
_TEST_CHECKPOINT = "./testdata/test_checkpoint.pkl"
3634

3735

38-
class TestEnv(gym.Env):
39-
40-
def __init__(self):
41-
self._ac_dim = 6
42-
self._ob_dim = 14
43-
self.action_space = gym.spaces.Box(
44-
-1 * np.ones(self._ac_dim), np.ones(self._ac_dim), dtype=np.float32
45-
)
46-
self.observation_space = gym.spaces.Box(
47-
-1 * np.ones(self._ob_dim), np.ones(self._ob_dim), dtype=np.float32
48-
)
49-
50-
def step(self, action):
51-
del action
52-
return np.zeros(self._ob_dim), 1.0, False, {}
53-
54-
def reset(self):
55-
return np.zeros(self._ob_dim)
56-
57-
def render(self, mode: str = "rgb_array"):
58-
return np.zeros((16, 16))
59-
60-
6136
def make_bb_program(
6237
num_workers: int,
6338
num_eval_workers: int,
@@ -154,17 +129,14 @@ def setUp(self):
154129
eval_rate=1,
155130
num_iterations=400,
156131
num_evals_per_suggestion=1,
157-
record_video_during_eval=True,
158132
)
159133
),
160134
worker=config_dict.ConfigDict(
161135
dict(
162-
worker_class=rl_worker.RLWorker,
136+
worker_class=simple_worker.SimpleWorker,
163137
worker_args=dict(
164-
env=TestEnv,
165-
policy=nn_policy.FullyConnectedNeuralNetworkPolicy,
166-
policy_args=dict(hidden_layer_sizes=[64, 64]),
167-
rollout_length=20,
138+
blackbox_function=np.sum,
139+
initial_params=np.zeros(shape=(10,)),
168140
),
169141
)
170142
),

requirements.txt

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
typing # Version dependent on Python version.
22
absl-py>=1.0.0
3-
numpy>=1.21.5
3+
numpy>=1.21.5,<2.0.0
44

55
# Distributed systems libraries.
66
# NOTE: Requires tensorflow~=2.8.0 to avoid proto issues.

0 commit comments

Comments
 (0)