Skip to content

Commit f9b65a1

Browse files
committed
Fix typing problems found by MyPy
See #68
1 parent 55387fc commit f9b65a1

File tree

4 files changed

+36
-25
lines changed

4 files changed

+36
-25
lines changed

xcengine/cli.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
import os
99
import pathlib
1010
import subprocess
11-
import sys
1211
import tempfile
12+
from typing import TypedDict
1313

1414
import click
1515
import yaml
@@ -86,7 +86,7 @@ def make_script(
8686
output_dir=output_dir, clear_output=clear
8787
)
8888
if batch or server:
89-
args = ["python3", output_dir / "execute.py"]
89+
args: list[str | pathlib.Path] = ["python3", output_dir / "execute.py"]
9090
if batch:
9191
args.append("--batch")
9292
if server:
@@ -102,7 +102,8 @@ def image_cli():
102102

103103

104104
@image_cli.command(
105-
help="Build, and optionally run, a compute engine as a Docker image"
105+
help="Build a compute engine as a Docker image, optionally generating an "
106+
"Application Package"
106107
)
107108
@click.option(
108109
"-b",
@@ -144,7 +145,11 @@ def build(
144145
) -> None:
145146
if environment is None:
146147
LOGGER.info("No environment file specified on command line.")
147-
init_args = dict(notebook=notebook, environment=environment, tag=tag)
148+
class InitArgs(TypedDict):
149+
notebook: pathlib.Path
150+
environment: pathlib.Path
151+
tag: str
152+
init_args = InitArgs(notebook=notebook, environment=environment, tag=tag)
148153
if build_dir:
149154
image_builder = ImageBuilder(build_dir=build_dir, **init_args)
150155
os.makedirs(build_dir, exist_ok=True)
@@ -156,11 +161,9 @@ def build(
156161
)
157162
image = image_builder.build()
158163
if eoap:
159-
160164
class IndentDumper(yaml.Dumper):
161165
def increase_indent(self, flow=False, indentless=False):
162166
return super(IndentDumper, self).increase_indent(flow, False)
163-
164167
eoap.write_text(
165168
yaml.dump(
166169
image_builder.create_cwl(),
@@ -212,7 +215,7 @@ def increase_indent(self, flow=False, indentless=False):
212215
def run(
213216
ctx: click.Context,
214217
batch: bool,
215-
server: False,
218+
server: bool,
216219
port: int,
217220
from_saved: bool,
218221
keep: bool,

xcengine/core.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import os
88
import shutil
99
import signal
10-
import socket
1110
import sys
1211
import tarfile
1312
import subprocess
@@ -263,6 +262,7 @@ def export_conda_env() -> dict:
263262
)
264263
pip_inspect = PipInspector()
265264
if pip_map:
265+
assert pip_index is not None
266266
nonlocals = []
267267
for pkg in pip_map["pip"]:
268268
if pip_inspect.is_local(pkg):
@@ -340,7 +340,7 @@ def __init__(
340340
self,
341341
image: Image | str,
342342
output_dir: pathlib.Path | None,
343-
client: docker.DockerClient = None,
343+
client: docker.DockerClient | None = None,
344344
):
345345
self._client = client
346346
match image:
@@ -383,7 +383,7 @@ def run(
383383
)
384384
+ (["--from-saved"] if from_saved else [])
385385
)
386-
run_args = dict(
386+
run_args: dict[str, Any] = dict(
387387
image=self.image, command=command, remove=False, detach=True
388388
)
389389
if host_port is not None:
@@ -429,6 +429,7 @@ def _tar_strip(member, path):
429429
def extract_output_from_container(self, container: Container) -> None:
430430
# This assumes the image-defined CWD, so it won't work in EOAP mode,
431431
# but EOAP has its own protocol for data stage-in/out anyway.
432+
assert self.output_dir is not None
432433
bits, stat = container.get_archive("/home/mambauser/output")
433434
reader = io.BufferedReader(ChunkStream(bits))
434435
with tarfile.open(name=None, mode="r|", fileobj=reader) as tar_fh:
@@ -463,7 +464,7 @@ class PipInspector:
463464
local filesystem.
464465
"""
465466

466-
def __init__(self):
467+
def __init__(self) -> None:
467468
environment = os.environ.copy()
468469
for varname in "FORCE_COLOR", "CLICOLOR", "CLICOLOR_FORCE":
469470
environment.pop(varname, None)

xcengine/parameters.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,11 @@
22
import os
33
import pathlib
44
import typing
5-
from typing import Any
5+
from typing import Any, ClassVar
66

77
import pystac
88
import xarray as xr
99
import yaml
10-
from typing import ClassVar
1110

1211
LOGGER = logging.getLogger(__name__)
1312
logging.basicConfig(level=logging.INFO)
@@ -24,7 +23,7 @@ class NotebookParameters:
2423
def __init__(
2524
self,
2625
params: dict[str, tuple[type, Any]],
27-
config: dict[str, Any] = None,
26+
config: dict[str, Any] | None = None,
2827
):
2928
self.params = params
3029
self.config = {} if config is None else config
@@ -73,7 +72,7 @@ def extract_variables(
7372
cls, code: str, setup_code: str | None = None
7473
) -> dict[str, tuple[type, Any]]:
7574
if setup_code is None:
76-
locals_ = {}
75+
locals_: dict[str, object] = {}
7776
old_locals = {}
7877
else:
7978
exec(setup_code, globals(), locals_ := {})
@@ -135,13 +134,13 @@ def to_yaml(self) -> str:
135134

136135
def read_params_combined(
137136
self, cli_args: list[str] | None
138-
) -> dict[str, str]:
137+
) -> dict[str, Any]:
139138
params = self.read_params_from_env()
140139
if cli_args:
141140
params.update(self.read_params_from_cli(cli_args))
142141
return params
143142

144-
def read_params_from_env(self) -> dict[str, str]:
143+
def read_params_from_env(self) -> dict[str, Any]:
145144
values = {}
146145
for param_name, (type_, _) in self.params.items():
147146
env_var_name = "xce_" + param_name
@@ -154,7 +153,7 @@ def read_params_from_env(self) -> dict[str, str]:
154153
)
155154
return values
156155

157-
def read_params_from_cli(self, args: list[str]) -> dict[str, str]:
156+
def read_params_from_cli(self, args: list[str]) -> dict[str, Any]:
158157
values = {}
159158
for param_name, (type_, _) in self.params.items():
160159
arg_name = "--" + param_name.replace("_", "-")
@@ -216,7 +215,7 @@ def read_staged_in_dataset(
216215
),
217216
)
218217
)
219-
asset = next(a for a in item.assets.values() if "data" in a.roles)
218+
asset = next(a for a in item.assets.values() if "data" in (a.roles or []))
220219
return xr.open_dataset(stage_in_path / asset.href)
221220

222221
@staticmethod

xcengine/util.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,16 @@
11
# Copyright (c) 2024-2025 by Brockmann Consult GmbH
22
# Permissions are hereby granted under the terms of the MIT License:
33
# https://opensource.org/licenses/MIT.
4+
45
from collections import namedtuple
56
from datetime import datetime
67
import pathlib
78
import shutil
9+
from typing import NamedTuple, Any, Mapping
810

911
import pystac
1012
import xarray as xr
13+
from xarray import Dataset
1114

1215

1316
def clear_directory(directory: pathlib.Path) -> None:
@@ -19,7 +22,7 @@ def clear_directory(directory: pathlib.Path) -> None:
1922

2023

2124
def write_stac(
22-
datasets: dict[str, xr.Dataset], stac_root: pathlib.Path
25+
datasets: Mapping[str, xr.Dataset], stac_root: pathlib.Path
2326
) -> None:
2427
catalog_path = stac_root / "catalog.json"
2528
if catalog_path.exists():
@@ -57,9 +60,13 @@ def write_stac(
5760
media_type="application/x-netcdf" if output_format == "netcdf" else "application/vnd.zarr",
5861
title=ds.attrs.get("title", ds_name),
5962
)
60-
bb = namedtuple("Bounds", ["left", "bottom", "right", "top"])(
61-
0, -90, 360, 90
62-
) # TODO determine and set actual bounds here
63+
class Bounds(NamedTuple):
64+
left: float
65+
bottom: float
66+
right: float
67+
top: float
68+
# TODO determine and set actual bounds here
69+
bb = Bounds(0, -90, 360, 90)
6370
item = pystac.Item(
6471
id=ds_name,
6572
geometry={
@@ -85,8 +92,8 @@ def write_stac(
8592

8693

8794
def save_datasets(
88-
datasets, output_path: pathlib.Path, eoap_mode: bool
89-
) -> dict[str, xr.Dataset]:
95+
datasets: Mapping[str, Dataset], output_path: pathlib.Path, eoap_mode: bool
96+
) -> dict[str, pathlib.Path]:
9097
saved_datasets = {}
9198
# EOAP doesn't require an "output" subdirectory (output can go anywhere
9299
# in the CWD) but it's used by xcetool's built-in runner.
@@ -98,6 +105,7 @@ def save_datasets(
98105
suffix = "nc" if output_format == "netcdf" else "zarr"
99106
dataset_path = output_subpath / f"{ds_id}.{suffix}"
100107
saved_datasets[ds_id] = dataset_path
108+
101109
if output_format == "netcdf":
102110
ds.to_netcdf(dataset_path)
103111
else:

0 commit comments

Comments
 (0)