Skip to content

Commit

Permalink
refactorize ExplorationTaskGroup. provide abs method make_task. provi…
Browse files Browse the repository at this point in the history
…de ExplorationTaskGroupData as a data holder
  • Loading branch information
Han Wang committed Feb 2, 2024
1 parent ebe7a17 commit cc8a7a2
Show file tree
Hide file tree
Showing 10 changed files with 128 additions and 86 deletions.
3 changes: 3 additions & 0 deletions dpgen2/exploration/task/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,8 @@
)
from .task import (
ExplorationTask,
)
from .task_group import (
ExplorationTaskGroup,
ExplorationTaskGroupData,
)
2 changes: 2 additions & 0 deletions dpgen2/exploration/task/conf_sampling_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@

from .task import (
ExplorationTask,
)
from .task_group import (
ExplorationTaskGroup,
)

Expand Down
5 changes: 2 additions & 3 deletions dpgen2/exploration/task/customized_lmp_template_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
)
from .task import (
ExplorationTask,
ExplorationTaskGroup,
)


Expand Down Expand Up @@ -147,7 +146,7 @@ def set_lmp(

def make_task(
self,
) -> ExplorationTaskGroup:
) -> "CustomizedLmpTemplateTaskGroup":
if not self.conf_set:
raise RuntimeError("confs are not set")
if not self.lmp_set:
Expand All @@ -166,7 +165,7 @@ def make_task(
def _make_customized_task_group(
self,
conf,
) -> ExplorationTaskGroup:
) -> "CustomizedLmpTemplateTaskGroup":
with tempfile.TemporaryDirectory() as tmpdir:
with set_directory(Path(tmpdir)):
Path(self.input_lmp_conf_name).write_text(conf)
Expand Down
3 changes: 1 addition & 2 deletions dpgen2/exploration/task/lmp_template_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from .task import (
ExplorationTask,
ExplorationTaskGroup,
)


Expand Down Expand Up @@ -60,7 +59,7 @@ def set_lmp(

def make_task(
self,
) -> ExplorationTaskGroup:
) -> "LmpTemplateTaskGroup":
if not self.conf_set:
raise RuntimeError("confs are not set")
if not self.lmp_set:
Expand Down
3 changes: 1 addition & 2 deletions dpgen2/exploration/task/npt_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
)
from .task import (
ExplorationTask,
ExplorationTaskGroup,
)


Expand Down Expand Up @@ -76,7 +75,7 @@ def set_md(

def make_task(
self,
) -> ExplorationTaskGroup:
) -> "NPTTaskGroup":
"""
Make the LAMMPS task group.
Expand Down
5 changes: 4 additions & 1 deletion dpgen2/exploration/task/stage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,10 @@

from .task import (
ExplorationTask,
)
from .task_group import (
ExplorationTaskGroup,
ExplorationTaskGroupData,
)


Expand Down Expand Up @@ -65,7 +68,7 @@ def make_task(
"""

lmp_task_grp = ExplorationTaskGroup()
lmp_task_grp = ExplorationTaskGroupData()
for ii in self.explor_groups:
# lmp_task_grp.add_group(ii.make_task())
lmp_task_grp += ii.make_task()
Expand Down
76 changes: 0 additions & 76 deletions dpgen2/exploration/task/task.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,4 @@
import os
from abc import (
ABC,
abstractmethod,
)
from collections.abc import (
Sequence,
)
Expand Down Expand Up @@ -60,51 +56,6 @@ def files(self) -> Dict:
return self._files


class ExplorationTaskGroup(Sequence):
"""A group of exploration tasks. Implemented as a `list` of `ExplorationTask`."""

def __init__(self):
super().__init__()
self.clear()

def __getitem__(self, ii: int) -> ExplorationTask:
"""Get the `ii`th task"""
return self.task_list[ii]

def __len__(self) -> int:
"""Get the number of tasks in the group"""
return len(self.task_list)

def clear(self) -> None:
self._task_list = []

@property
def task_list(self) -> List[ExplorationTask]:
"""Get the `list` of `ExplorationTask`"""
return self._task_list

def add_task(self, task: ExplorationTask):
"""Add one task to the group."""
self.task_list.append(task)
return self

def add_group(
self,
group: "ExplorationTaskGroup",
):
"""Add another group to the group."""
# see https://www.python.org/dev/peps/pep-0484/#forward-references for forward references
self._task_list = self._task_list + group._task_list
return self

def __add__(
self,
group: "ExplorationTaskGroup",
):
"""Add another group to the group."""
return self.add_group(group)


class FooTask(ExplorationTask):
def __init__(
self,
Expand All @@ -118,30 +69,3 @@ def __init__(
conf_name: conf_cont,
inpu_name: inpu_cont,
}


class FooTaskGroup(ExplorationTaskGroup):
def __init__(self, numb_task):
super().__init__()
# TODO: confirm the following is correct
self.tlist = ExplorationTaskGroup()
for ii in range(numb_task):
self.tlist.add_task(
FooTask(
f"conf.{ii}",
f"this is conf.{ii}",
f"input.{ii}",
f"this is input.{ii}",
)
)

@property
def task_list(self):
return self.tlist


if __name__ == "__main__":
grp = FooTaskGroup(3)
for ii in grp:
fcs = ii.files()
print(fcs)
104 changes: 104 additions & 0 deletions dpgen2/exploration/task/task_group.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from abc import (
ABC,
abstractmethod,
)
from collections.abc import (
Sequence,
)
from typing import (
Dict,
List,
Tuple,
)

from .task import (
ExplorationTask,
)


class ExplorationTaskGroup(Sequence):
"""A group of exploration tasks. Implemented as a `list` of `ExplorationTask`."""

def __init__(self):
super().__init__()
self.clear()

def __getitem__(self, ii: int) -> ExplorationTask:
"""Get the `ii`th task"""
return self.task_list[ii]

def __len__(self) -> int:
"""Get the number of tasks in the group"""
return len(self.task_list)

def clear(self) -> None:
self._task_list = []

@property
def task_list(self) -> List[ExplorationTask]:
"""Get the `list` of `ExplorationTask`"""
return self._task_list

def add_task(self, task: ExplorationTask):
"""Add one task to the group."""
self.task_list.append(task)
return self

def add_group(
self,
group: "ExplorationTaskGroup",
):
"""Add another group to the group."""
# see https://www.python.org/dev/peps/pep-0484/#forward-references for forward references
self._task_list = self._task_list + group._task_list
return self

def __add__(
self,
group: "ExplorationTaskGroup",
):
"""Add another group to the group."""
return self.add_group(group)

@abstractmethod
def make_task(self) -> "ExplorationTaskGroup":
"""Make the task group."""
pass


class ExplorationTaskGroupData(ExplorationTaskGroup):
"""Data-only exploration task group."""

def __init__(self):
super().__init__()

def make_task(self):
"""Make the task group."""
raise NotImplementedError("This class is not supposed to supply make_task")


class FooTaskGroup(ExplorationTaskGroup):
def __init__(self, numb_task):
super().__init__()
# TODO: confirm the following is correct
self.tlist = ExplorationTaskGroup()
for ii in range(numb_task):
self.tlist.add_task(
FooTask(

Check failure on line 87 in dpgen2/exploration/task/task_group.py

View workflow job for this annotation

GitHub Actions / pyright

"FooTask" is not defined (reportUndefinedVariable)
f"conf.{ii}",
f"this is conf.{ii}",
f"input.{ii}",
f"this is input.{ii}",
)
)

@property
def task_list(self):
return self.tlist


if __name__ == "__main__":
grp = FooTaskGroup(3)
for ii in grp:
fcs = ii.files()
print(fcs)
9 changes: 9 additions & 0 deletions tests/mocked_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -789,6 +789,9 @@ def __init__(self):
)
self.add_task(tt)

def make_task(self):
raise NotImplementedError


class MockedExplorationTaskGroup1(ExplorationTaskGroup):
def __init__(self):
Expand All @@ -801,6 +804,9 @@ def __init__(self):
)
self.add_task(tt)

def make_task(self):
raise NotImplementedError


class MockedExplorationTaskGroup2(ExplorationTaskGroup):
def __init__(self):
Expand All @@ -813,6 +819,9 @@ def __init__(self):
)
self.add_task(tt)

def make_task(self):
raise NotImplementedError


class MockedStage(ExplorationStage):
def make_task(self):
Expand Down
4 changes: 2 additions & 2 deletions tests/test_prep_run_lmp.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@
)
from dpgen2.exploration.task import (
ExplorationTask,
ExplorationTaskGroup,
ExplorationTaskGroupData,
)
from dpgen2.op.prep_lmp import (
PrepLmp,
Expand All @@ -90,7 +90,7 @@


def make_task_group_list(ngrp, ntask_per_grp):
tgrp = ExplorationTaskGroup()
tgrp = ExplorationTaskGroupData()
for ii in range(ngrp):
for jj in range(ntask_per_grp):
tt = ExplorationTask()
Expand Down

0 comments on commit cc8a7a2

Please sign in to comment.