Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion clpy/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '2.1.0rc1'
__version__ = '2.1.0.1'
2 changes: 2 additions & 0 deletions clpy/random/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,8 @@ def get_random_state():
seed = os.getenv('CLPY_SEED')
if seed is None:
seed = os.getenv('CHAINER_SEED')
if seed is not None:
seed = numpy.uint64(int(seed))
rs = RandomState(seed)
rs = _random_states.setdefault(dev.id, rs)
return rs
Expand Down
2 changes: 1 addition & 1 deletion docker/python2/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ RUN apt-get update -y && \
python-pip && \
rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

RUN pip install cupy==2.1.0
RUN pip install cupy==2.1.0.1
2 changes: 1 addition & 1 deletion docker/python3/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,4 +6,4 @@ RUN apt-get update -y && \
python3-pip && \
rm -rf /var/lib/apt/lists/* /var/cache/apt/archives/*

RUN pip3 install cupy==2.1.0
RUN pip3 install cupy==2.1.0.1
54 changes: 40 additions & 14 deletions tests/clpy_tests/random_tests/test_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,15 +551,12 @@ def test_thread_safe(self):
class TestGetRandomState2(unittest.TestCase):

def setUp(self):
self.rs_tmp = generator.RandomState
generator.RandomState = mock.Mock()
self.rs_dict = generator._random_states
generator._random_states = {}
self.clpy_seed = os.getenv('CLPY_SEED')
self.chainer_seed = os.getenv('CHAINER_SEED')

def tearDown(self, *args):
generator.RandomState = self.rs_tmp
generator._random_states = self.rs_dict
if self.clpy_seed is None:
os.environ.pop('CLPY_SEED', None)
Expand All @@ -573,26 +570,55 @@ def tearDown(self, *args):
def test_get_random_state_no_clpy_no_chainer_seed(self):
os.environ.pop('CLPY_SEED', None)
os.environ.pop('CHAINER_SEED', None)
generator.get_random_state()
generator.RandomState.assert_called_with(None)
rvs0 = self._get_rvs_reset()
rvs1 = self._get_rvs_reset()

self._check_different(rvs0, rvs1)

def test_get_random_state_no_cupy_with_chainer_seed(self):
rvs0 = self._get_rvs(generator.RandomState(5))

def test_get_random_state_no_clpy_with_chainer_seed(self):
os.environ.pop('CLPY_SEED', None)
os.environ['CHAINER_SEED'] = '5'
generator.get_random_state()
generator.RandomState.assert_called_with('5')
rvs1 = self._get_rvs_reset()

self._check_same(rvs0, rvs1)

def test_get_random_state_with_cupy_no_chainer_seed(self):
rvs0 = self._get_rvs(generator.RandomState(6))

def test_get_random_state_with_clpy_no_chainer_seed(self):
os.environ['CLPY_SEED'] = '6'
os.environ.pop('CHAINER_SEED', None)
generator.get_random_state()
generator.RandomState.assert_called_with('6')
rvs1 = self._get_rvs_reset()

self._check_same(rvs0, rvs1)

def test_get_random_state_with_cupy_with_chainer_seed(self):
rvs0 = self._get_rvs(generator.RandomState(7))

def test_get_random_state_with_clpy_with_chainer_seed(self):
os.environ['CLPY_SEED'] = '7'
os.environ['CHAINER_SEED'] = '8'
generator.get_random_state()
generator.RandomState.assert_called_with('7')
rvs1 = self._get_rvs_reset()

self._check_same(rvs0, rvs1)

def _get_rvs(self, rs):
rvu = rs.rand(4)
rvn = rs.randn(4)
return rvu, rvn

def _get_rvs_reset(self):
generator.reset_states()
return self._get_rvs(generator.get_random_state())

def _check_same(self, rvs0, rvs1):
for rv0, rv1 in zip(rvs0, rvs1):
testing.assert_array_equal(rv0, rv1)

def _check_different(self, rvs0, rvs1):
for rv0, rv1 in zip(rvs0, rvs1):
for r0, r1 in zip(rv0, rv1):
self.assertNotEqual(r0, r1)


class TestCheckAndGetDtype(unittest.TestCase):
Expand Down