Skip to content

Commit

Permalink
change to BaseExplorationTaskGroup
Browse files Browse the repository at this point in the history
  • Loading branch information
Han Wang committed Feb 2, 2024
1 parent f281d1d commit 98c13bf
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 16 deletions.
2 changes: 1 addition & 1 deletion dpgen2/exploration/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,6 @@
ExplorationTask,
)
from .task_group import (
BaseExplorationTaskGroup,
ExplorationTaskGroup,
ExplorationTaskGroupData,
)
4 changes: 2 additions & 2 deletions dpgen2/exploration/task/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
ExplorationTask,
)
from .task_group import (
BaseExplorationTaskGroup,
ExplorationTaskGroup,
ExplorationTaskGroupData,
)


Expand Down Expand Up @@ -68,7 +68,7 @@ def make_task(
"""

lmp_task_grp = ExplorationTaskGroupData()
lmp_task_grp = BaseExplorationTaskGroup()
for ii in self.explor_groups:
# lmp_task_grp.add_group(ii.make_task())
lmp_task_grp += ii.make_task()
Expand Down
16 changes: 5 additions & 11 deletions dpgen2/exploration/task/task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
)


class ExplorationTaskGroup(Sequence):
class BaseExplorationTaskGroup(Sequence):
"""A group of exploration tasks. Implemented as a `list` of `ExplorationTask`."""

def __init__(self):
Expand Down Expand Up @@ -60,21 +60,15 @@ def __add__(
"""Add another group to the group."""
return self.add_group(group)

@abstractmethod
def make_task(self) -> "ExplorationTaskGroup":
"""Make the task group."""
pass


class ExplorationTaskGroupData(ExplorationTaskGroup):
"""Data-only exploration task group."""

class ExplorationTaskGroup(ABC, BaseExplorationTaskGroup):
def __init__(self):
super().__init__()

def make_task(self):
@abstractmethod
def make_task(self) -> "ExplorationTaskGroup":
"""Make the task group."""
raise NotImplementedError("This class is not supposed to supply make_task")
pass


class FooTask(ExplorationTask):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_prep_run_lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,8 +69,8 @@
train_task_pattern,
)
from dpgen2.exploration.task import (
BaseExplorationTaskGroup,
ExplorationTask,
ExplorationTaskGroupData,
)
from dpgen2.op.prep_lmp import (
PrepLmp,
Expand All @@ -90,7 +90,7 @@


def make_task_group_list(ngrp, ntask_per_grp):
tgrp = ExplorationTaskGroupData()
tgrp = BaseExplorationTaskGroup()
for ii in range(ngrp):
for jj in range(ntask_per_grp):
tt = ExplorationTask()
Expand Down

0 comments on commit 98c13bf

Please sign in to comment.