
Commit 620d360

Develop pipeline config
1 parent d5c6e07 commit 620d360

9 files changed: +219 -45 lines


data-pipeline/src/data_pipeline/config.py

+49-16
@@ -1,5 +1,6 @@
 import os
 import attr
+from enum import Enum
 from pathlib import Path

 DATA_ENV = os.getenv("DATA_ENV", "local")
@@ -31,25 +32,57 @@ def make_local_folder(self):
         Path(self.root).mkdir(parents=True, exist_ok=True)


-# @attr.define
-# class GnomadV4
-#     gnomad_v4_exome_variants_sites_ht_path: str = "external_datasets/mock_v4_release.ht"
+class ComputeEnvironment(Enum):
+    local = "local"
+    cicd = "cicd"
+    dataproc = "dataproc"


-@attr.define
-class PipelineConfig:
-    data_paths: DataPaths
-    compute_env: str = "local"
-    data_env: str = "tiny"
+class DataEnvironment(Enum):
+    tiny = "tiny"
+    full = "full"


-config = PipelineConfig(
-    data_env="local",
-    data_paths=DataPaths.create(os.path.join("data")),
-)
+def is_valid_fn(cls):
+    def is_valid(instance, attribute, value):
+        if not isinstance(value, cls):
+            raise ValueError(f"Expected {cls} enum, got {type(value)}")

+    return is_valid

-if DATA_ENV == "dataproc":
-    config = PipelineConfig(
-        data_paths=DataPaths.create(os.path.join("gs://gnomad-matt-data-pipeline")),
-    )
+
+@attr.define
+class PipelineConfig:
+    name: str
+    input_paths: DataPaths
+    output_paths: DataPaths
+    data_env: DataEnvironment = attr.field(validator=is_valid_fn(DataEnvironment))
+    compute_env: ComputeEnvironment = attr.field(validator=is_valid_fn(ComputeEnvironment))
+
+    @classmethod
+    def create(
+        cls,
+        name: str,
+        input_root: str,
+        output_root: str,
+        data_env=DataEnvironment.tiny,
+        compute_env=ComputeEnvironment.local,
+    ):
+        input_paths = DataPaths.create(input_root)
+        output_paths = DataPaths.create(output_root)
+        return cls(name, input_paths, output_paths, data_env, compute_env)
+
+
+# config = PipelineConfig.create(
+#     name=
+#     input_root="data_in",
+#     output_root="data_out",
+#     compute_env=ComputeEnvironment.local,
+#     data_env=DataEnvironment.tiny,
+# )
+
+
+# if DATA_ENV == "dataproc":
+#     config = PipelineConfig(
+#         output_path=DataPaths.create(os.path.join("gs://gnomad-matt-data-pipeline")),
+#     )
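
For reference, a minimal usage sketch of the new config API introduced above. It is not part of the commit; the name and root paths are illustrative values copied from elsewhere in the diff.

from data_pipeline.config import ComputeEnvironment, DataEnvironment, PipelineConfig

# Build a config through the new factory; the enum members shown are the defaults.
config = PipelineConfig.create(
    name="gnomad_v4_variants",
    input_root="data_in",
    output_root="data_out",
    data_env=DataEnvironment.tiny,
    compute_env=ComputeEnvironment.local,
)

# The attrs validators reject anything that is not the expected enum; a plain string
# raises: ValueError: Expected <enum 'DataEnvironment'> enum, got <class 'str'>
# PipelineConfig.create(name="x", input_root="data_in", output_root="data_out", data_env="tiny")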

data-pipeline/src/data_pipeline/helpers/logging.py

+3-5
@@ -6,16 +6,14 @@


 def create_logger():
-    config = {
-        "handlers": [
+    logger.configure(
+        handlers=[
             {
                 "sink": sys.stdout,
                 "format": "<level>{time:YYYY-MM-DDTHH:mm}</level> | <level>{level: <8}</level> | <cyan>{name}</cyan>:<cyan>{function}</cyan>:<cyan>{line}</cyan> - <level>{message}</level>", # noqa
             },
         ]
-    }
-
-    logger.configure(**config)
+    )
     logger.level("CONFIG", no=38, icon="🐍")

     # clear log file after each run

data-pipeline/src/data_pipeline/helpers/write_schemas.py

+1-1
@@ -27,7 +27,7 @@ def describe_handler(text):


 for pipeline in pipelines:
-    pipeline_name = pipeline.name
+    pipeline_name = pipeline.config.name
     task_names = pipeline.get_all_task_names()
     out_dir = os.path.join(SCHEMA_PATH, pipeline_name)

data-pipeline/src/data_pipeline/pipeline.py

+24-18
@@ -6,13 +6,13 @@
 import subprocess
 import tempfile
 import time
-from typing import List, Optional, Union
+from typing import Callable, List, Optional, Union
 import attr
 from collections import OrderedDict

 import hail as hl

-from data_pipeline.config import config
+from data_pipeline.config import PipelineConfig

 logger = logging.getLogger("gnomad_data_pipeline")
 logger.setLevel(logging.INFO)
@@ -57,23 +57,24 @@ def modified_time(path):
     return file_system.modified_time(check_path)


-_pipeline_config = {}
+# _pipeline_config = {}

-_pipeline_config["output_root"] = config.data_paths.root
+# _pipeline_config["output_root"] = config.output_paths.root


 @attr.define
 class DownloadTask:
+    _config: PipelineConfig
     _name: str
     _url: str
     _output_path: str

     @classmethod
-    def create(cls, name, url, output_path):
-        return cls(name, url, output_path)
+    def create(cls, config: PipelineConfig, name: str, url: str, output_path: str):
+        return cls(config, name, url, output_path)

     def get_output_path(self):
-        return _pipeline_config["output_root"] + self._output_path
+        return self._config.output_paths.root + self._output_path

     def should_run(self):
         output_path = self.get_output_path()
@@ -82,6 +83,9 @@ def should_run(self):

         return (False, None)

+    def get_inputs(self):
+        raise NotImplementedError("Method not valid for DownloadTask")
+
     def run(self, force=False):
         output_path = self.get_output_path()
         should_run, reason = (True, "Forced") if force else self.should_run()
@@ -106,17 +110,19 @@ def run(self, force=False):

 @attr.define
 class Task:
+    _config: PipelineConfig
     _name: str
-    _task_function: str
+    _task_function: Callable
     _output_path: str
     _inputs: dict
     _params: dict

     @classmethod
     def create(
         cls,
+        config: PipelineConfig,
         name: str,
-        task_function: str,
+        task_function: Callable,
         output_path: str,
         inputs: Optional[dict] = None,
         params: Optional[dict] = None,
@@ -125,10 +131,10 @@ def create(
             inputs = {}
         if params is None:
             params = {}
-        return cls(name, task_function, output_path, inputs, params)
+        return cls(config, name, task_function, output_path, inputs, params)

     def get_output_path(self):
-        return _pipeline_config["output_root"] + self._output_path
+        return self._config.output_paths.root + self._output_path

     def get_inputs(self):
         paths = {}
@@ -138,7 +144,7 @@ def get_inputs(self):
                 paths.update({k: v.get_output_path()})
             else:
                 logger.info(v)
-                paths.update({k: os.path.join(config.data_paths.root, v)})
+                paths.update({k: os.path.join(self._config.output_paths.root, v)})

         return paths

@@ -173,14 +179,14 @@ def run(self, force=False):

 @attr.define
 class Pipeline:
-    name: str
+    config: PipelineConfig
     _tasks: OrderedDict = OrderedDict()
     _outputs: dict = {}

     def add_task(
         self,
         name: str,
-        task_function: str,
+        task_function: Callable,
         output_path: str,
         inputs: Optional[dict] = None,
         params: Optional[dict] = None,
@@ -189,12 +195,12 @@ def add_task(
             inputs = {}
         if params is None:
             params = {}
-        task = Task.create(name, task_function, output_path, inputs, params)
+        task = Task.create(self.config, name, task_function, output_path, inputs, params)
         self._tasks[name] = task
         return task

     def add_download_task(self, name, *args, **kwargs) -> DownloadTask:
-        task = DownloadTask.create(name, *args, **kwargs)
+        task = DownloadTask.create(self.config, name, *args, **kwargs)
         self._tasks[name] = task
         return task

@@ -232,8 +238,8 @@ def run_pipeline(pipeline):
     group.add_argument("--force-all", action="store_true")
     args = parser.parse_args()

-    if args.output_root:
-        _pipeline_config["output_root"] = args.output_root.rstrip("/")
+    # if args.output_root:
+    #     _pipeline_config["output_root"] = args.output_root.rstrip("/")

     pipeline_args = {}
     if args.force_all:
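
A short sketch (not in the commit) of how path resolution now flows through the per-pipeline PipelineConfig instead of the module-level _pipeline_config dict; the task name, function, and paths below are illustrative.

from data_pipeline.config import PipelineConfig
from data_pipeline.pipeline import Pipeline

config = PipelineConfig.create(name="example", input_root="data_in", output_root="data_out")
pipeline = Pipeline(config=config)

task = pipeline.add_task(
    name="prepare_example",                 # hypothetical task name
    task_function=lambda input_path: None,  # placeholder callable
    output_path="/example/prepared.ht",
    inputs={"input_path": "external_datasets/mock_v4_release.ht"},
)

# Task.get_output_path() prepends the config's output root:
#   "data_out" + "/example/prepared.ht" -> "data_out/example/prepared.ht"
# Task.get_inputs() joins non-task inputs against the same root:
#   {"input_path": "data_out/external_datasets/mock_v4_release.ht"}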

data-pipeline/src/data_pipeline/pipelines/gnomad_v4_coverage.py

+4-1
@@ -1,9 +1,12 @@
 from data_pipeline.pipeline import Pipeline, run_pipeline
+from data_pipeline.config import PipelineConfig

 from data_pipeline.data_types.coverage import prepare_coverage


-pipeline = Pipeline(name="gnomad_v4_coverage")
+pipeline = Pipeline(
+    config=PipelineConfig.create(name="gnomad_v4_variants", input_root="data_in", output_root="data_out")
+)

 pipeline.add_task(
     name="prepare_gnomad_v4_exome_coverage",

data-pipeline/src/data_pipeline/pipelines/gnomad_v4_variants.py

+11-4
@@ -1,15 +1,23 @@
+from data_pipeline.config import PipelineConfig
 from data_pipeline.pipeline import Pipeline, run_pipeline

-from data_pipeline.datasets.gnomad_v4.gnomad_v4_variants import prepare_gnomad_v4_variants
+from data_pipeline.datasets.gnomad_v4.gnomad_v4_variants import (
+    prepare_gnomad_v4_variants,
+)

-from data_pipeline.data_types.variant import annotate_variants, annotate_transcript_consequences
+from data_pipeline.data_types.variant import (
+    annotate_variants,
+    annotate_transcript_consequences,
+)

 # from data_pipeline.pipelines.gnomad_v4_coverage import pipeline as coverage_pipeline

 # from data_pipeline.pipelines.genes import pipeline as genes_pipeline


-pipeline = Pipeline(name="gnomad_v4_variants")
+pipeline = Pipeline(
+    config=PipelineConfig.create(name="gnomad_v4_variants", input_root="data_in", output_root="data_out")
+)

 pipeline.add_task(
     name="prepare_gnomad_v4_exome_variants",
@@ -18,7 +26,6 @@
     inputs={
         "input_path": "external_datasets/mock_v4_release.ht",
     },
-    # params={"sequencing_type": "exome"},
 )

 # pipeline.add_task(
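
The entry point of these pipeline modules is not shown in this hunk; presumably it calls run_pipeline, which parses the CLI flags handled in pipeline.py above (for example --force-all). A sketch, under that assumption:

# Assumed module footer (not visible in this diff):
if __name__ == "__main__":
    run_pipeline(pipeline)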
(new file added by this commit; path not shown in this view)

+63 -0
@@ -0,0 +1,63 @@
+# from loguru import logger
+import os
+import pytest
+import tempfile
+
+from data_pipeline.config import ComputeEnvironment, DataEnvironment, DataPaths, PipelineConfig
+
+# from data_pipeline.pipeline import Pipeline
+
+
+@pytest.fixture
+def input_tmp():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        with open(os.path.join(temp_dir, "sample_tiny.txt"), "w") as f:
+            f.write("tiny dataset")
+        with open(os.path.join(temp_dir, "sample_full.txt"), "w") as f:
+            f.write("full dataset")
+        yield temp_dir
+
+
+@pytest.fixture
+def output_tmp():
+    with tempfile.TemporaryDirectory() as temp_dir:
+        yield temp_dir
+
+
+@pytest.mark.only
+def test_config_created(input_tmp, output_tmp):
+    config = PipelineConfig.create(name="test", input_root=input_tmp, output_root=output_tmp)
+    assert isinstance(config, PipelineConfig)
+    assert isinstance(config.input_paths, DataPaths)
+    assert isinstance(config.output_paths, DataPaths)
+    assert isinstance(config.compute_env, ComputeEnvironment)
+    assert isinstance(config.data_env, DataEnvironment)
+
+
+@pytest.mark.only
+def test_config_read_input_file(input_tmp, output_tmp):
+    config = PipelineConfig.create(
+        name="test",
+        input_root=input_tmp,
+        output_root=output_tmp,
+    )
+    sample = os.path.join(config.input_paths.root, "sample_tiny.txt")
+    with open(sample, "r") as f:
+        assert f.read() == "tiny dataset"
+
+
+# @pytest.mark.only
+# def test_pipeline_tasks(ht_1_fixture: TestHt, ht_2_fixture: TestHt):
+#     def task_1_fn():
+#         pass
+
+#     pipeline = Pipeline("p1")
+
+#     pipeline.add_task(
+#         name="task_1_join_hts",
+#         task_function=task_1_fn,
+#         output_path="/gnomad_v4/gnomad_v4_exome_variants_base.ht",
+#         inputs={
+#             "input_ht_1": ht_1_fixture.path,
+#         },
+#     )
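
If the commented-out task test at the end were revived against the new config-based Pipeline, it might look roughly like the sketch below. This is not part of the commit; it assumes the commented Pipeline import is re-enabled, reuses the temp-dir fixtures above, and substitutes an illustrative input mapping for the missing ht fixtures.

def test_pipeline_tasks(input_tmp, output_tmp):
    def task_1_fn():
        pass

    config = PipelineConfig.create(name="p1", input_root=input_tmp, output_root=output_tmp)
    pipeline = Pipeline(config=config)  # requires: from data_pipeline.pipeline import Pipeline

    task = pipeline.add_task(
        name="task_1_join_hts",
        task_function=task_1_fn,
        output_path="/gnomad_v4/gnomad_v4_exome_variants_base.ht",
        inputs={"input_ht_1": "sample_tiny.txt"},  # illustrative; the original used ht fixtures
    )

    # The task resolves its output path against the config's output root.
    assert task.get_output_path() == output_tmp + "/gnomad_v4/gnomad_v4_exome_variants_base.ht"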
