diff --git a/dpgen2/exploration/task/__init__.py b/dpgen2/exploration/task/__init__.py index 602618d7..19562255 100644 --- a/dpgen2/exploration/task/__init__.py +++ b/dpgen2/exploration/task/__init__.py @@ -25,6 +25,6 @@ ExplorationTask, ) from .task_group import ( + BaseExplorationTaskGroup, ExplorationTaskGroup, - ExplorationTaskGroupData, ) diff --git a/dpgen2/exploration/task/stage.py b/dpgen2/exploration/task/stage.py index dd281f26..a039f947 100644 --- a/dpgen2/exploration/task/stage.py +++ b/dpgen2/exploration/task/stage.py @@ -16,8 +16,8 @@ ExplorationTask, ) from .task_group import ( + BaseExplorationTaskGroup, ExplorationTaskGroup, - ExplorationTaskGroupData, ) @@ -68,7 +68,7 @@ def make_task( """ - lmp_task_grp = ExplorationTaskGroupData() + lmp_task_grp = BaseExplorationTaskGroup() for ii in self.explor_groups: # lmp_task_grp.add_group(ii.make_task()) lmp_task_grp += ii.make_task() diff --git a/dpgen2/exploration/task/task_group.py b/dpgen2/exploration/task/task_group.py index 964a9cd3..f9d320e7 100644 --- a/dpgen2/exploration/task/task_group.py +++ b/dpgen2/exploration/task/task_group.py @@ -16,7 +16,7 @@ ) -class ExplorationTaskGroup(Sequence): +class BaseExplorationTaskGroup(Sequence): """A group of exploration tasks. Implemented as a `list` of `ExplorationTask`.""" def __init__(self): @@ -60,21 +60,15 @@ def __add__( """Add another group to the group.""" return self.add_group(group) - @abstractmethod - def make_task(self) -> "ExplorationTaskGroup": - """Make the task group.""" - pass - - -class ExplorationTaskGroupData(ExplorationTaskGroup): - """Data-only exploration task group.""" +class ExplorationTaskGroup(ABC, BaseExplorationTaskGroup): def __init__(self): super().__init__() - def make_task(self): + @abstractmethod + def make_task(self) -> "ExplorationTaskGroup": """Make the task group.""" - raise NotImplementedError("This class is not supposed to supply make_task") + pass class FooTask(ExplorationTask): diff --git a/tests/test_prep_run_lmp.py b/tests/test_prep_run_lmp.py index 6960878f..1aca5000 100644 --- a/tests/test_prep_run_lmp.py +++ b/tests/test_prep_run_lmp.py @@ -69,8 +69,8 @@ train_task_pattern, ) from dpgen2.exploration.task import ( + BaseExplorationTaskGroup, ExplorationTask, - ExplorationTaskGroupData, ) from dpgen2.op.prep_lmp import ( PrepLmp, @@ -90,7 +90,7 @@ def make_task_group_list(ngrp, ntask_per_grp): - tgrp = ExplorationTaskGroupData() + tgrp = BaseExplorationTaskGroup() for ii in range(ngrp): for jj in range(ntask_per_grp): tt = ExplorationTask()