Skip to content

Commit 6e139b6

Browse files
committed
fix(nyz): fix offline mem data fetcher unittest bug
1 parent cacab2e commit 6e139b6

File tree

4 files changed

+11
-11
lines changed

4 files changed

+11
-11
lines changed

ding/entry/tests/test_serial_entry.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -658,7 +658,7 @@ def test_discrete_dt():
658658
from ding.model.template.dt import DecisionTransformer
659659
from ding.policy import DTPolicy
660660
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
661-
offline_data_fetcher_from_mem_c, offline_logger, termination_checker
661+
OfflineMemoryDataFetcher, offline_logger, termination_checker
662662
config = compile_config(config[0], create_cfg=config[1], auto=True)
663663
with task.start(async_mode=False, ctx=OfflineRLContext()):
664664
evaluator_env = BaseEnvManagerV2(
@@ -675,7 +675,7 @@ def test_discrete_dt():
675675

676676
task.use(termination_checker(max_train_iter=1))
677677
task.use(interaction_evaluator(config, policy.eval_mode, evaluator_env))
678-
task.use(offline_data_fetcher_from_mem_c(config, dataset))
678+
task.use(OfflineMemoryDataFetcher(config, dataset))
679679
task.use(trainer(config, policy.learn_mode))
680680
task.use(CkptSaver(policy, config.exp_name, train_freq=100))
681681
task.use(offline_logger(config.exp_name))

ding/framework/middleware/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,4 @@
44
from .ckpt_handler import CkptSaver
55
from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
66
from .barrier import Barrier, BarrierRuntime
7-
from .data_fetcher import offline_data_fetcher_from_mem_c
7+
from .data_fetcher import OfflineMemoryDataFetcher

ding/framework/middleware/data_fetcher.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -2,24 +2,23 @@
22
from threading import Thread, Event
33
from queue import Queue
44
import time
5+
import numpy as np
56
import torch
6-
import torch.distributed as dist
77
from easydict import EasyDict
88
from ding.framework import task
99
from ding.data import Dataset, DataLoader
10-
from ding.utils import get_rank
11-
import numpy as np
10+
from ding.utils import get_rank, get_world_size
1211

1312
if TYPE_CHECKING:
1413
from ding.framework import OfflineRLContext
1514

1615

17-
class offline_data_fetcher_from_mem_c:
16+
class OfflineMemoryDataFetcher:
1817

1918
def __new__(cls, *args, **kwargs):
2019
if task.router.is_active and not task.has_role(task.role.FETCHER):
2120
return task.void()
22-
return super(offline_data_fetcher_from_mem_c, cls).__new__(cls)
21+
return super(OfflineMemoryDataFetcher, cls).__new__(cls)
2322

2423
def __init__(self, cfg: EasyDict, dataset: Dataset):
2524
device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
@@ -30,7 +29,7 @@ def producer(queue, dataset, batch_size, device, event):
3029
torch.set_num_threads(4)
3130
if device != 'cpu':
3231
nonlocal stream
33-
sbatch_size = batch_size * dist.get_world_size()
32+
sbatch_size = batch_size * get_world_size()
3433
rank = get_rank()
3534
idx_list = np.random.permutation(len(dataset))
3635
temp_idx_list = []

dizoo/atari/entry/atari_dt_main.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,8 @@
88
from ding.config import compile_config
99
from ding.framework import task, ding_init
1010
from ding.framework.context import OfflineRLContext
11-
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, offline_data_fetcher_from_mem_c, offline_data_fetcher
11+
from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, \
12+
OfflineMemoryDataFetcher
1213
from ding.utils import set_pkg_seed, DDPContext, to_ddp_config
1314
from dizoo.atari.envs import AtariEnv
1415
from dizoo.atari.config.serial.pong.pong_dt_config import main_config, create_config
@@ -43,7 +44,7 @@ def main():
4344
policy = DTPolicy(cfg.policy, model=model)
4445

4546
task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
46-
task.use(offline_data_fetcher_from_mem_c(cfg, dataset))
47+
task.use(OfflineMemoryDataFetcher(cfg, dataset))
4748
task.use(trainer(cfg, policy.learn_mode))
4849
task.use(termination_checker(max_train_iter=3e4))
4950
task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))

0 commit comments

Comments
 (0)