Skip to content

Commit 521284b

Browse files
committed
test(nyz): fix dreamer unittest bugs
1 parent ce5e50c commit 521284b

File tree

2 files changed

+5
-3
lines changed

2 files changed

+5
-3
lines changed

ding/world_model/tests/test_dreamer.py renamed to ding/world_model/tests/test_dreamerv3.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from ding.utils import deep_merge_dicts
88

99
# arguments
10-
state_size = [3, 64, 64]
10+
state_size = [[3, 64, 64]]
1111
action_size = [6, 1]
1212
args = list(product(*[state_size, action_size]))
1313

ding/world_model/tests/test_world_model.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
import os
23
import torch
34
from easydict import EasyDict
45
from ding.world_model.base_world_model import DreamWorldModel, DynaWorldModel
@@ -10,8 +11,8 @@ class TestDynaWorldModel:
1011

1112
@pytest.mark.parametrize('buffer_type', [NaiveReplayBuffer, EpisodeReplayBuffer])
1213
def test_fill_img_buffer(self, buffer_type):
13-
env_buffer = buffer_type(buffer_type.default_config(), None, 'exp_name', 'env_buffer_for_test')
14-
img_buffer = buffer_type(buffer_type.default_config(), None, 'exp_name', 'img_buffer_for_test')
14+
env_buffer = buffer_type(buffer_type.default_config(), None, 'dyna_exp_name', 'env_buffer_for_test')
15+
img_buffer = buffer_type(buffer_type.default_config(), None, 'dyna_exp_name', 'img_buffer_for_test')
1516
fake_config = EasyDict(
1617
train_freq=250, # w.r.t environment step
1718
eval_freq=250, # w.r.t environment step
@@ -74,6 +75,7 @@ def step(self, obs, action):
7475
)
7576

7677
super(FakeModel, fake_model).fill_img_buffer(policy, env_buffer, img_buffer, 0, 0)
78+
os.popen("rm -rf dyna_exp_name")
7779

7880

7981
@pytest.mark.unittest

0 commit comments

Comments
 (0)