Skip to content

Commit 974dcd8

Browse files
authored
Merge pull request #10 from jejjohnson/main
Updates to scripts
2 parents e45e921 + 5731a05 commit 974dcd8

10 files changed

+676
-10091
lines changed

environment.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ dependencies:
2828
- pip
2929
- pip:
3030
- toolz
31+
- typer
3132
- einops
3233
- rastervision==0.21.3
3334
# formatting

helio_tools/_src/data/sdo/base.py

+114-71
Original file line numberDiff line numberDiff line change
@@ -4,30 +4,38 @@
44
55
Documentation for DRMS: https://docs.sunpy.org/projects/drms/en/latest/
66
"""
7-
7+
from typing import Optional
88
import argparse
99
import logging
1010
import multiprocessing
1111
import os
1212
from datetime import timedelta, datetime
1313
from urllib import request
1414

15-
15+
import tqdm
16+
import warnings
17+
import os
1618
import drms
1719
import numpy as np
1820
import pandas as pd
1921
from astropy.io import fits
2022
from sunpy.io._fits import header_to_fits
2123
from sunpy.util import MetaDict
24+
from helio_tools._src.utils.time import check_datetime_format
25+
import typer
26+
from loguru import logger
2227

2328
DEFAULT_WAVELENGTHS = [171, 193, 211, 304]
2429

2530

2631
class SDODownloader:
27-
def __init__(self, base_path: str = None,
28-
email: str = None,
29-
wavelengths: list[str | int | float] = DEFAULT_WAVELENGTHS,
30-
n_workers: int = 5) -> None:
32+
def __init__(
33+
self,
34+
base_path: str = None,
35+
email: str = None,
36+
wavelengths: list[str | int | float] = DEFAULT_WAVELENGTHS,
37+
n_workers: int = 5,
38+
) -> None:
3139
"""The SDO Downloader is an efficient way to download data from the SDO database.
3240
3341
Args:
@@ -45,33 +53,38 @@ def __init__(self, base_path: str = None,
4553
self.ds_path = base_path
4654
self.wavelengths = [str(wl) for wl in wavelengths]
4755
self.n_workers = n_workers
48-
[os.makedirs(os.path.join(base_path, wl), exist_ok=True)
49-
for wl in self.wavelengths + ['6173']]
56+
[
57+
os.makedirs(os.path.join(base_path, wl), exist_ok=True)
58+
for wl in self.wavelengths + ["6173"]
59+
]
5060

5161
self.drms_client = drms.Client(email=email)
5262

5363
def downloadDate(self, date: datetime):
54-
"""Download FITS data for a specific date.
55-
"""
64+
"""Download FITS data for a specific date."""
5665
id = date.isoformat()
57-
logging.info('Start download: %s' % id)
58-
time_param = '%sZ' % date.isoformat('_', timespec='seconds')
66+
logging.info("Start download: %s" % id)
67+
time_param = "%sZ" % date.isoformat("_", timespec="seconds")
5968

6069
# query Magnetogram Instrument
61-
ds_hmi = 'hmi.M_720s[%s]{magnetogram}' % time_param
70+
ds_hmi = "hmi.M_720s[%s]{magnetogram}" % time_param
6271
keys_hmi = self.drms_client.keys(ds_hmi)
6372
header_hmi, segment_hmi = self.drms_client.query(
64-
ds_hmi, key=','.join(keys_hmi), seg='magnetogram')
73+
ds_hmi, key=",".join(keys_hmi), seg="magnetogram"
74+
)
6575
if len(header_hmi) != 1 or np.any(header_hmi.QUALITY != 0):
6676
self.fetchDataFallback(date)
6777
return
6878

6979
# query EUV Instrument
70-
ds_euv = 'aia.lev1_euv_12s[%s][%s]{image}' % (
71-
time_param, ','.join(self.wavelengths))
80+
ds_euv = "aia.lev1_euv_12s[%s][%s]{image}" % (
81+
time_param,
82+
",".join(self.wavelengths),
83+
)
7284
keys_euv = self.drms_client.keys(ds_euv)
7385
header_euv, segment_euv = self.drms_client.query(
74-
ds_euv, key=','.join(keys_euv), seg='image')
86+
ds_euv, key=",".join(keys_euv), seg="image"
87+
)
7588
if len(header_euv) != len(self.wavelengths) or np.any(header_euv.QUALITY != 0):
7689
self.fetchDataFallback(date)
7790
return
@@ -84,91 +97,97 @@ def downloadDate(self, date: datetime):
8497

8598
with multiprocessing.Pool(self.n_workers) as p:
8699
p.map(self.download, queue)
87-
logging.info('Finished: %s' % id)
100+
logging.info("Finished: %s" % id)
88101

89102
def download(self, sample: tuple[dict, str, datetime]):
90103
header, segment, t = sample
91104
try:
92-
dir = os.path.join(self.ds_path, '%d' % header['WAVELNTH'])
93-
map_path = os.path.join(dir, '%s.fits' %
94-
t.isoformat('T', timespec='seconds'))
105+
dir = os.path.join(self.ds_path, "%d" % header["WAVELNTH"])
106+
map_path = os.path.join(
107+
dir, "%s.fits" % t.isoformat("T", timespec="seconds")
108+
)
95109
if os.path.exists(map_path):
96110
return map_path
97111
# load map
98-
url = 'http://jsoc.stanford.edu' + segment
112+
url = "http://jsoc.stanford.edu" + segment
99113
request.urlretrieve(url, filename=map_path)
100114

101-
header['DATE_OBS'] = header['DATE__OBS']
115+
header["DATE_OBS"] = header["DATE__OBS"]
102116
header = header_to_fits(MetaDict(header))
103-
with fits.open(map_path, 'update') as f:
117+
with fits.open(map_path, "update") as f:
104118
hdr = f[1].header
105119
for k, v in header.items():
106120
if pd.isna(v):
107121
continue
108122
hdr[k] = v
109-
f.verify('silentfix')
123+
f.verify("silentfix")
110124

111125
return map_path
112126
except Exception as ex:
113-
logging.info('Download failed: %s (requeue)' % header['DATE__OBS'])
127+
logging.info("Download failed: %s (requeue)" % header["DATE__OBS"])
114128
logging.info(ex)
115129
raise ex
116130

117131
def fetchDataFallback(self, date: datetime):
118132
id = date.isoformat()
119133

120-
logging.info('Fallback download: %s' % id)
134+
logging.info("Fallback download: %s" % id)
121135
# query Magnetogram
122136
t = date - timedelta(hours=24)
123-
ds_hmi = 'hmi.M_720s[%sZ/12h@720s]{magnetogram}' % t.replace(
124-
tzinfo=None).isoformat('_', timespec='seconds')
137+
ds_hmi = "hmi.M_720s[%sZ/12h@720s]{magnetogram}" % t.replace(
138+
tzinfo=None
139+
).isoformat("_", timespec="seconds")
125140
keys_hmi = self.drms_client.keys(ds_hmi)
126141
header_tmp, segment_tmp = self.drms_client.query(
127-
ds_hmi, key=','.join(keys_hmi), seg='magnetogram')
128-
assert len(header_tmp) != 0, 'No data found!'
129-
date_str = header_tmp['DATE__OBS'].replace(
130-
'MISSING', '').str.replace('60', '59') # fix date format
131-
date_diff = np.abs(pd.to_datetime(
132-
date_str).dt.tz_localize(None) - date)
142+
ds_hmi, key=",".join(keys_hmi), seg="magnetogram"
143+
)
144+
assert len(header_tmp) != 0, "No data found!"
145+
date_str = (
146+
header_tmp["DATE__OBS"].replace("MISSING", "").str.replace("60", "59")
147+
) # fix date format
148+
date_diff = np.abs(pd.to_datetime(date_str).dt.tz_localize(None) - date)
133149
# sort and filter
134-
header_tmp['date_diff'] = date_diff
135-
header_tmp.sort_values('date_diff')
136-
segment_tmp['date_diff'] = date_diff
137-
segment_tmp.sort_values('date_diff')
150+
header_tmp["date_diff"] = date_diff
151+
header_tmp.sort_values("date_diff")
152+
segment_tmp["date_diff"] = date_diff
153+
segment_tmp.sort_values("date_diff")
138154
cond_tmp = header_tmp.QUALITY == 0
139155
header_tmp = header_tmp[cond_tmp]
140156
segment_tmp = segment_tmp[cond_tmp]
141-
assert len(header_tmp) > 0, 'No valid quality flag found'
157+
assert len(header_tmp) > 0, "No valid quality flag found"
142158
# replace invalid
143-
header_hmi = header_tmp.iloc[0].drop('date_diff')
144-
segment_hmi = segment_tmp.iloc[0].drop('date_diff')
159+
header_hmi = header_tmp.iloc[0].drop("date_diff")
160+
segment_hmi = segment_tmp.iloc[0].drop("date_diff")
145161
############################################################
146162
# query EUV
147163
header_euv, segment_euv = [], []
148164
t = date - timedelta(hours=6)
149165
for wl in self.wavelengths:
150-
euv_ds = 'aia.lev1_euv_12s[%sZ/12h@12s][%s]{image}' % (
151-
t.replace(tzinfo=None).isoformat('_', timespec='seconds'), wl)
166+
euv_ds = "aia.lev1_euv_12s[%sZ/12h@12s][%s]{image}" % (
167+
t.replace(tzinfo=None).isoformat("_", timespec="seconds"),
168+
wl,
169+
)
152170
keys_euv = self.drms_client.keys(euv_ds)
153171
header_tmp, segment_tmp = self.drms_client.query(
154-
euv_ds, key=','.join(keys_euv), seg='image')
155-
assert len(header_tmp) != 0, 'No data found!'
156-
date_str = header_tmp['DATE__OBS'].replace(
157-
'MISSING', '').str.replace('60', '59') # fix date format
158-
date_diff = (pd.to_datetime(
159-
date_str).dt.tz_localize(None) - date).abs()
172+
euv_ds, key=",".join(keys_euv), seg="image"
173+
)
174+
assert len(header_tmp) != 0, "No data found!"
175+
date_str = (
176+
header_tmp["DATE__OBS"].replace("MISSING", "").str.replace("60", "59")
177+
) # fix date format
178+
date_diff = (pd.to_datetime(date_str).dt.tz_localize(None) - date).abs()
160179
# sort and filter
161-
header_tmp['date_diff'] = date_diff
162-
header_tmp.sort_values('date_diff')
163-
segment_tmp['date_diff'] = date_diff
164-
segment_tmp.sort_values('date_diff')
180+
header_tmp["date_diff"] = date_diff
181+
header_tmp.sort_values("date_diff")
182+
segment_tmp["date_diff"] = date_diff
183+
segment_tmp.sort_values("date_diff")
165184
cond_tmp = header_tmp.QUALITY == 0
166185
header_tmp = header_tmp[cond_tmp]
167186
segment_tmp = segment_tmp[cond_tmp]
168-
assert len(header_tmp) > 0, 'No valid quality flag found'
187+
assert len(header_tmp) > 0, "No valid quality flag found"
169188
# replace invalid
170-
header_euv.append(header_tmp.iloc[0].drop('date_diff'))
171-
segment_euv.append(segment_tmp.iloc[0].drop('date_diff'))
189+
header_euv.append(header_tmp.iloc[0].drop("date_diff"))
190+
segment_euv.append(segment_tmp.iloc[0].drop("date_diff"))
172191

173192
queue = []
174193
queue += [(header_hmi.to_dict(), segment_hmi.magnetogram, date)]
@@ -178,24 +197,48 @@ def fetchDataFallback(self, date: datetime):
178197
with multiprocessing.Pool(self.n_workers) as p:
179198
p.map(self.download, queue)
180199

181-
logging.info('Finished: %s' % id)
200+
logging.info("Finished: %s" % id)
201+
182202

203+
def download_sdo_data(
204+
start_date: str = "2022-3-1",
205+
end_date: str = "2023-3-2",
206+
email: Optional[str] = None,
207+
base_path: Optional[str] = None,
208+
n_workers: int = 8,
209+
):
210+
if base_path is None:
211+
base_path = os.path.join(os.path.expanduser("~"), "sdo-path")
183212

184-
def main():
185-
import os
186-
email = os.getenv('SDO_EMAIL')
187-
base_path = os.path.join(os.path.expanduser('~'), 'sdo-data')
213+
logger.info(f"BasePath: {base_path}")
188214

215+
# check datetime object
216+
start_date: datetime = check_datetime_format(start_date, sensor="sodo")
217+
end_date: datetime = check_datetime_format(end_date, sensor="sodo")
218+
219+
logger.info(f"Period: {start_date}-{end_date}")
220+
221+
if email is None:
222+
email = os.getenv("SDO_EMAIL")
223+
logger.info(f"Email: {email}")
189224
downloader_sdo = SDODownloader(
190-
base_path=base_path, email=email, n_workers=8)
225+
base_path=base_path, email=email, n_workers=n_workers
226+
)
227+
228+
dates = [
229+
start_date + i * timedelta(hours=12)
230+
for i in range((end_date - start_date) // timedelta(hours=12))
231+
]
232+
233+
pbar = tqdm.tqdm(dates)
234+
235+
with warnings.catch_warnings():
236+
warnings.simplefilter("ignore")
191237

192-
start_date = datetime(2022, 3, 1)
193-
end_date = datetime(2023, 3, 2)
194-
import tqdm
195-
for d in tqdm.tqdm([start_date + i * timedelta(hours=12) for i in
196-
range((end_date - start_date) // timedelta(hours=12))]):
197-
downloader_sdo.downloadDate(d)
238+
for idate in pbar:
239+
pbar.set_description(f"Date: {idate}")
240+
downloader_sdo.downloadDate(idate)
198241

199242

200-
if __name__ == '__main__':
201-
main()
243+
if __name__ == "__main__":
244+
typer.run(download_sdo_data)

0 commit comments

Comments
 (0)