Skip to content

Commit e4bb649

Browse files
authored
Code formatting and linting (#28)
* Add flake configuration * Add pre-commit hook configuration * Update codebase to satisfy flake * Update flake configuration * Update conda environment.yml
1 parent bfcf39a commit e4bb649

16 files changed

+87
-177
lines changed

.flake8

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
[flake8]
2+
max-line-length = 88
3+
ignore = E731,W503,E722,E203,F821,F722

.pre-commit-config.yaml

+16
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
1+
repos:
2+
- repo: https://github.com/pre-commit/pre-commit-hooks
3+
rev: v3.1.0
4+
hooks:
5+
- id: check-yaml
6+
- id: end-of-file-fixer
7+
- id: trailing-whitespace
8+
- repo: https://github.com/ambv/black
9+
rev: stable
10+
hooks:
11+
- id: black
12+
language_version: python3.7
13+
- repo: https://gitlab.com/pycqa/flake8
14+
rev: 3.8.3
15+
hooks:
16+
- id: flake8

environment.yml

+33-2
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,11 @@ channels:
77
- defaults
88
dependencies:
99
- _libgcc_mutex=0.1
10+
- alabaster=0.7.12
1011
- appdirs=1.4.3
12+
- argh=0.26.2
1113
- attrs=19.3.0
14+
- babel=2.8.0
1215
- backcall=0.1.0
1316
- beautifulsoup4=4.9.0
1417
- black=19.10b0
@@ -22,9 +25,11 @@ dependencies:
2225
- certifi=2020.6.20
2326
- cffi=1.14.0
2427
- cfgrib=0.9.8.2
28+
- cfgv=3.1.0
2529
- chardet=3.0.4
2630
- click=7.1.2
2731
- cloudpickle=1.4.1
32+
- commonmark=0.9.1
2833
- cryptography=2.9.2
2934
- cudatoolkit=10.2.89
3035
- curl=7.69.1
@@ -38,18 +43,25 @@ dependencies:
3843
- dbus=1.13.14
3944
- decorator=4.4.2
4045
- defusedxml=0.6.0
46+
- distlib=0.3.0
4147
- distributed=2.17.0
48+
- docopt=0.6.2
49+
- docutils=0.16
4250
- eccodes=2.17.0
51+
- editdistance=0.5.3
4352
- entrypoints=0.3
4453
- expat=2.2.6
4554
- fancycompleter=0.8
4655
- fastai=1.0.61
4756
- fastprogress=0.2.2
57+
- filelock=3.0.12
4858
- fire=0.3.1
59+
- flake8=3.8.3
4960
- fontconfig=2.13.0
5061
- freetype=2.9.1
5162
- fribidi=1.0.9
5263
- fsspec=0.7.4
64+
- future=0.18.2
5365
- fzf=0.21.1
5466
- glib=2.63.1
5567
- gmp=6.1.2
@@ -64,7 +76,9 @@ dependencies:
6476
- hdf5=1.10.5
6577
- heapdict=1.0.1
6678
- icu=58.2
79+
- identify=1.4.20
6780
- idna=2.9
81+
- imagesize=1.2.0
6882
- importlib-metadata=1.6.0
6983
- importlib_metadata=1.6.0
7084
- intel-openmp=2020.1
@@ -99,6 +113,7 @@ dependencies:
99113
- libuuid=1.0.3
100114
- libxcb=1.13
101115
- libxml2=2.9.9
116+
- livereload=2.6.2
102117
- locket=0.2.0
103118
- markupsafe=1.1.1
104119
- matplotlib=3.1.3
@@ -116,6 +131,7 @@ dependencies:
116131
- nbformat=5.0.6
117132
- ncurses=6.2
118133
- ninja=1.9.0
134+
- nodeenv=1.4.0
119135
- notebook=6.0.3
120136
- numexpr=2.7.1
121137
- numpy=1.18.1
@@ -136,8 +152,11 @@ dependencies:
136152
- pickleshare=0.7.5
137153
- pillow=7.1.2
138154
- pip=20.0.2
155+
- pipreqs=0.4.10
139156
- pixman=0.38.0
140157
- plac=0.9.6
158+
- port-for=0.4
159+
- pre-commit=2.5.1
141160
- preshed=2.0.1
142161
- prometheus_client=0.7.1
143162
- prompt-toolkit=3.0.4
@@ -167,8 +186,11 @@ dependencies:
167186
- qtconsole=4.7.4
168187
- qtpy=1.9.0
169188
- readline=8.0
189+
- recommonmark=0.6.0
170190
- regex=2020.5.14
171191
- requests=2.23.0
192+
- rope=0.17.0
193+
- rstcheck=3.3.1
172194
- scipy=1.4.1
173195
- send2trash=1.5.0
174196
- setproctitle=1.1.10
@@ -180,6 +202,14 @@ dependencies:
180202
- sortedcontainers=2.1.0
181203
- soupsieve=2.0.1
182204
- spacy=2.1.8
205+
- sphinx=3.1.1
206+
- sphinx-autobuild=0.7.1
207+
- sphinxcontrib-applehelp=1.0.2
208+
- sphinxcontrib-devhelp=1.0.2
209+
- sphinxcontrib-htmlhelp=1.0.3
210+
- sphinxcontrib-jsmath=1.0.1
211+
- sphinxcontrib-qthelp=1.0.3
212+
- sphinxcontrib-serializinghtml=1.1.4
183213
- sqlite=3.31.1
184214
- srsly=0.1.0
185215
- tblib=1.6.0
@@ -197,6 +227,7 @@ dependencies:
197227
- typed-ast=1.4.1
198228
- typing_extensions=3.7.4.1
199229
- urllib3=1.25.8
230+
- virtualenv=20.0.20
200231
- wasabi=0.2.2
201232
- wcwidth=0.1.9
202233
- webencodings=0.5.1
@@ -207,6 +238,7 @@ dependencies:
207238
- xonsh=0.9.18
208239
- xz=5.2.5
209240
- yaml=0.1.7
241+
- yarg=0.1.9
210242
- zeromq=4.3.1
211243
- zict=2.0.0
212244
- zipp=3.1.0
@@ -230,7 +262,6 @@ dependencies:
230262
- ecmwf-api-client==1.5.4
231263
- ffsend==0.1.3
232264
- flask==1.1.2
233-
- future==0.18.2
234265
- gitdb==4.0.5
235266
- gitpython==3.1.3
236267
- google-auth==1.16.0
@@ -244,6 +275,7 @@ dependencies:
244275
- markdown==3.2.2
245276
- msal==1.3.0
246277
- netcdf4==1.5.3
278+
- nonechucks==0.4.0
247279
- nvidia-ml-py3==7.352.0
248280
- oauthlib==3.1.0
249281
- pathtools==0.1.2
@@ -270,4 +302,3 @@ dependencies:
270302
- watchdog==0.10.2
271303
- werkzeug==1.0.1
272304
prefix: /home/esowc/anaconda3/envs/wildfire-dl
273-

src/dataloader/base_loader.py

+3-20
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,13 @@
11
"""
22
Base Dataset class to work with fwi-forcings data.
33
"""
4-
import os
5-
from argparse import ArgumentParser
6-
from collections import OrderedDict
7-
import json
8-
from glob import glob
4+
from collections import defaultdict
95

10-
import xarray as xr
116
import numpy as np
127

13-
148
import torch
15-
import torch.nn as nn
16-
from torch.nn import Sequential, MaxPool2d, ReLU, BatchNorm2d, Conv2d
17-
import torch.nn.functional as F
18-
import torchvision.transforms as transforms
19-
from torch import optim
20-
from torch.utils.data import DataLoader
219
from torch.utils.data import Dataset
2210

23-
# Logging helpers
24-
from pytorch_lightning import _logger as log
25-
from pytorch_lightning.core import LightningModule
26-
import wandb
27-
2811

2912
class ModelDataset(Dataset):
3013
"""
@@ -89,7 +72,7 @@ def training_step(self, model, batch):
8972
"""
9073
# forward pass
9174
x, y_pre = batch
92-
y_hat_pre, aux_y_hat = model(x) if model.aux else model(x), None
75+
y_hat_pre = model(x)
9376
mask = model.data.mask.expand_as(y_pre[0][0])
9477
assert y_pre.shape == y_hat_pre.shape
9578
tensorboard_logs = defaultdict(dict)
@@ -118,7 +101,7 @@ def validation_step(self, model, batch):
118101
"""
119102
# forward pass
120103
x, y_pre = batch
121-
y_hat_pre, aux_y_hat = model(x) if model.aux else model(x), None
104+
y_hat_pre = model(x)
122105
mask = model.data.mask.expand_as(y_pre[0][0])
123106
assert y_pre.shape == y_hat_pre.shape
124107
tensorboard_logs = defaultdict(dict)
1.47 KB
Binary file not shown.

src/dataloader/fwi_forecast.py

+8-23
Original file line numberDiff line numberDiff line change
@@ -1,29 +1,12 @@
11
"""
22
The dataset class to be used with fwi-forcings and fwi-forecast data.
33
"""
4-
import os
5-
from argparse import ArgumentParser
6-
from collections import OrderedDict
7-
import json
84
from glob import glob
95

106
import xarray as xr
11-
import numpy as np
12-
137

148
import torch
15-
import torch.nn as nn
16-
from torch.nn import Sequential, MaxPool2d, ReLU, BatchNorm2d, Conv2d
17-
import torch.nn.functional as F
189
import torchvision.transforms as transforms
19-
from torch import optim
20-
from torch.utils.data import DataLoader
21-
from torch.utils.data import Dataset
22-
23-
# Logging helpers
24-
from pytorch_lightning import _logger as log
25-
from pytorch_lightning.core import LightningModule
26-
import wandb
2710

2811
from dataloader.base_loader import ModelDataset as BaseDataset
2912

@@ -69,9 +52,10 @@ def __init__(
6952
1 <= int(x.split("2019")[1].split("_1200_hr_")[0][:2]) <= 12
7053
and 1 <= int(x.split("2019")[1].split("_1200_hr_")[0][2:]) <= 31
7154
)
72-
assert not (
73-
sum([inp_invalid(x) for x in out_files])
74-
), "Invalid date format for input file(s). The dates should be formatted as YYMMDD."
55+
assert not (sum([inp_invalid(x) for x in inp_files])), (
56+
"Invalid date format for input file(s)."
57+
"The dates should be formatted as YYMMDD."
58+
)
7559
with xr.open_mfdataset(
7660
inp_files, preprocess=preprocess, engine="h5netcdf"
7761
) as ds:
@@ -85,9 +69,10 @@ def __init__(
8569
out_invalid = lambda x: not (
8670
1 <= int(x[-19:-17]) <= 12 and 1 <= int(x[-17:-15]) <= 31
8771
)
88-
assert not (
89-
sum([out_invalid(x) for x in out_files])
90-
), "Invalid date format for output file(s). The dates should be formatted as YYMMDD."
72+
assert not (sum([out_invalid(x) for x in out_files])), (
73+
"Invalid date format for output file(s)."
74+
"The dates should be formatted as YYMMDD."
75+
)
9176
with xr.open_mfdataset(
9277
out_files, preprocess=preprocess, engine="h5netcdf"
9378
) as ds:

src/dataloader/fwi_reanalysis.py

+9-23
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,15 @@
11
"""
22
The dataset class to be used with fwi-forcings and fwi-reanalysis data.
33
"""
4-
import os
5-
from argparse import ArgumentParser
6-
from collections import OrderedDict
7-
import json
84
from glob import glob
95
from collections import defaultdict
106
import pickle
117

128
import xarray as xr
139
import numpy as np
1410

15-
1611
import torch
17-
import torch.nn as nn
18-
from torch.nn import Sequential, MaxPool2d, ReLU, BatchNorm2d, Conv2d
19-
import torch.nn.functional as F
2012
import torchvision.transforms as transforms
21-
from torch import optim
22-
from torch.utils.data import DataLoader
23-
from torch.utils.data import Dataset
24-
25-
# Logging helpers
26-
from pytorch_lightning import _logger as log
27-
from pytorch_lightning.core import LightningModule
28-
import wandb
2913

3014
from dataloader.base_loader import ModelDataset as BaseDataset
3115

@@ -61,7 +45,7 @@ def __init__(
6145

6246
# Number of input and prediction days
6347
assert (
64-
not self.hparams.in_days > 0 and self.hparams.out_days > 0
48+
self.hparams.in_days > 0 and self.hparams.out_days > 0
6549
), "The number of input and output days must be > 0."
6650
self.n_input = self.hparams.in_days
6751
self.n_output = self.hparams.out_days
@@ -120,18 +104,20 @@ def __init__(
120104
out_invalid = lambda x: not (
121105
1 <= int(x[-22:-20]) <= 12 and 1 <= int(x[-20:-18]) <= 31
122106
)
123-
assert not (
124-
sum([out_invalid(x) for x in out_files])
125-
), "Invalid date format for output file(s). The dates should be formatted as YYMMDD."
107+
assert not (sum([out_invalid(x) for x in out_files])), (
108+
"Invalid date format for output file(s)."
109+
"The dates should be formatted as YYMMDD."
110+
)
126111
self.out_files = out_files
127112

128113
inp_invalid = lambda x: not (
129114
1 <= int(x.split("_20")[1][2:].split("_1200_hr_")[0][:2]) <= 12
130115
and 1 <= int(x.split("_20")[1][2:].split("_1200_hr_")[0][2:]) <= 31
131116
)
132-
assert not (
133-
sum([inp_invalid(x) for x in inp_files])
134-
), "Invalid date format for input file(s). The dates should be formatted as YYMMDD."
117+
assert not (sum([inp_invalid(x) for x in inp_files])), (
118+
"Invalid date format for input file(s)."
119+
"The dates should be formatted as YYMMDD."
120+
)
135121
self.inp_files = inp_files
136122

137123
# Consider only ground truth and discard forecast values

src/dataloader/mask.npy

800 KB
Binary file not shown.

src/dataloader/test_set.pkl

4.64 KB
Binary file not shown.

src/model/base_model.py

-14
Original file line numberDiff line numberDiff line change
@@ -1,31 +1,17 @@
11
"""
22
Base model implementing helper methods.
33
"""
4-
import os
5-
from argparse import ArgumentParser
6-
from collections import OrderedDict
7-
import json
8-
from glob import glob
9-
import types
104
import pickle
115
from collections import defaultdict
126

13-
import xarray as xr
14-
import numpy as np
15-
167

178
import torch
18-
import torch.nn as nn
19-
import torch.nn.functional as F
20-
import torchvision.transforms as transforms
219
from torch import optim
2210
from torch.utils.data import DataLoader
23-
from torch.utils.data import Dataset
2411

2512
# Logging helpers
2613
from pytorch_lightning import _logger as log
2714
from pytorch_lightning.core import LightningModule
28-
import wandb
2915

3016

3117
class BaseModel(LightningModule):

0 commit comments

Comments
 (0)