Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
*.egg-info
doc/_build/
.cache
*.pyc
.coverage
.pytest_cache
32 changes: 28 additions & 4 deletions pyasdf/asdf_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,6 @@
import prov
import prov.model


# Minimum compatibility wrapper between Python 2 and 3.
try:
filter = itertools.ifilter
Expand Down Expand Up @@ -811,7 +810,13 @@ def _get_waveform(self, waveform_name, starttime=None, endtime=None):
data = self.__file["Waveforms"]["%s.%s" % (network, station)][
waveform_name]

tr = obspy.Trace(data=data[idx_start: idx_end])
if "mask" in data.attrs and data.attrs["mask"] != np.bool(False):
_data = np.ma.masked_values(data[idx_start: idx_end],
data.attrs["mask"])
else:
_data = data[idx_start: idx_end]

tr = obspy.Trace(data=_data)
tr.stats.starttime = data_starttime
tr.stats.sampling_rate = data.attrs["sampling_rate"]
tr.stats.network = network
Expand Down Expand Up @@ -1167,6 +1172,8 @@ def add_waveforms(self, waveform, tag, event_id=None, origin_id=None,

# Actually add the data.
for trace in waveform:
if isinstance(trace.data, np.ma.masked_array):
self.__set_masked_array_fill_value(trace)
# Complicated multi-step process but it enables one to use
# parallel I/O with the same functions.
info = self._add_trace_get_collective_information(
Expand All @@ -1179,6 +1186,16 @@ def add_waveforms(self, waveform, tag, event_id=None, origin_id=None,
self._add_trace_write_collective_information(info)
self._add_trace_write_independent_information(info, trace)

def __set_masked_array_fill_value(self, trace):
    """
    Set the fill value of a trace's masked data to a dtype-wide sentinel.

    The sentinel is chosen at the extreme end of the dtype's value range
    so that it is unlikely to coincide with a genuine sample:

    * signed integers: the smallest representable value,
    * unsigned integers: the largest representable value (the smallest
      one is ``0``, which is a perfectly ordinary sample value and would
      make real zeros indistinguishable from masked samples),
    * floating point: the most negative finite value.

    The fill value is set in place on ``trace.data``.

    :param trace: Object whose ``data`` attribute is a numpy masked array.
    :raises NotImplementedError: If the dtype of the data is neither an
        integer nor a floating point type.
    """
    dtype = trace.data.dtype
    kind = dtype.kind

    if kind == "i":
        fill_value = np.iinfo(dtype).min
    elif kind == "u":
        # iinfo(...).min is 0 for unsigned types - use the upper bound
        # instead so valid zero samples are never mistaken for the mask.
        fill_value = np.iinfo(dtype).max
    elif kind == "f":
        fill_value = np.finfo(dtype).min
    else:
        raise NotImplementedError(
            "fill value for dtype %s not defined" % dtype)

    trace.data.set_fill_value(fill_value)

def __parse_and_validate_tag(self, tag):
tag = tag.strip()
if tag.lower() == "stationxml":
Expand Down Expand Up @@ -1295,7 +1312,7 @@ def _add_trace_write_independent_information(self, info, trace):
:param trace:
:return:
"""
self._waveform_group[info["data_name"]][:] = trace.data
self._waveform_group[info["data_name"]][:] = np.ma.filled(trace.data)

def _add_trace_write_collective_information(self, info):
"""
Expand Down Expand Up @@ -1367,6 +1384,12 @@ def _add_trace_get_collective_information(
else:
fletcher32 = True

# Determine appropriate mask value.
if not isinstance(trace.data, np.ma.masked_array):
_mask = np.bool(False)
else:
_mask = trace.data.fill_value

info = {
"station_name": station_name,
"data_name": group_name,
Expand All @@ -1384,7 +1407,8 @@ def _add_trace_get_collective_information(
# Starttime is the epoch time in nanoseconds.
"starttime":
int(round(trace.stats.starttime.timestamp * 1.0E9)),
"sampling_rate": trace.stats.sampling_rate
"sampling_rate": trace.stats.sampling_rate,
"mask": _mask
}
}

Expand Down
57 changes: 57 additions & 0 deletions pyasdf/tests/test_asdf_data_set.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,60 @@ def test_data_set_creation(tmpdir):
assert cat_file == cat_asdf


def test_masked_data_creation(tmpdir):
    """
    Masked (gappy) waveforms must survive a round trip through the file.

    Two disjoint windows of the same recording are merged into a single
    trace, which leaves a masked gap between them. This is done once with
    the raw integer samples and once with band-pass filtered (floating
    point) samples. Both traces are written to a new data set and read
    back; header, mask and all unmasked samples have to be identical.
    """
    asdf_filename = os.path.join(tmpdir.strpath, "test.h5")
    data_path = os.path.join(data_dir, "small_sample_data_set")

    data_set = ASDFDataSet(asdf_filename)

    filename = os.path.join(data_path, "AE.113A..BHZ.mseed")

    # Two windows separated by a ten minute gap.
    ts1 = obspy.UTCDateTime("2013-05-24T05:40:00")
    te1 = obspy.UTCDateTime("2013-05-24T06:00:00")
    ts2 = obspy.UTCDateTime("2013-05-24T06:10:00")
    te2 = obspy.UTCDateTime("2013-05-24T06:50:00")

    st_file_raw = obspy.read(filename)

    # Merging both windows into one trace masks the gap between them.
    st_file_masked = st_file_raw.copy().trim(starttime=ts1, endtime=te1)\
        + st_file_raw.copy().trim(starttime=ts2, endtime=te2)
    st_file_masked.merge()

    # This will cast dtype from int to float. Filtering needs contiguous
    # data, thus split before and merge again afterwards.
    st_file_masked_filtered = st_file_masked.copy()
    st_file_masked_filtered = st_file_masked_filtered.split()
    st_file_masked_filtered.filter("bandpass", freqmin=0.1, freqmax=10)
    st_file_masked_filtered.merge()

    data_set.add_waveforms(st_file_masked, tag="masked")
    data_set.add_waveforms(st_file_masked_filtered, tag="masked_filtered")

    st_asdf_masked = data_set.waveforms["AE.113A"]["masked"]
    st_asdf_masked_filtered = data_set.waveforms["AE.113A"]["masked_filtered"]

    trfm = st_file_masked[0]
    trfmf = st_file_masked_filtered[0]
    tram = st_asdf_masked[0]
    tramf = st_asdf_masked_filtered[0]

    # Remove the header entries that legitimately differ between a trace
    # built from the MiniSEED file and one read from the ASDF file.
    for tr in (trfm, trfmf):
        del tr.stats.mseed
        del tr.stats._format
        del tr.stats.processing

    for tr in (tram, tramf):
        del tr.stats.asdf
        del tr.stats._format

    for from_file, from_asdf in ((trfm, tram), (trfmf, tramf)):
        assert from_file.stats == from_asdf.stats
        # The data read back has to be masked, not a plain array.
        assert isinstance(from_asdf.data, np.ma.masked_array)
        # getmaskarray() always yields a full boolean array, even if the
        # mask happens to be stored as the scalar `nomask`.
        np.testing.assert_array_equal(
            np.ma.getmaskarray(from_file.data),
            np.ma.getmaskarray(from_asdf.data))
        # compressed() returns only the unmasked samples.
        np.testing.assert_array_equal(
            from_file.data.compressed(),
            from_asdf.data.compressed())


def test_equality_checks(example_data_set):
"""
Tests the equality operations.
Expand Down Expand Up @@ -3020,20 +3074,23 @@ def test_get_waveform_attributes(example_data_set):
'event_ids': [
'smi:service.iris.edu/fdsnws/event/1/query?'
'eventid=4218658'],
'mask': np.bool(False),
'sampling_rate': 40.0,
'starttime': 1369374000000000000},
'AE.113A..BHN__2013-05-24T05:40:00__'
'2013-05-24T06:50:00__raw_recording': {
'event_ids': [
'smi:service.iris.edu/fdsnws/event/1/query?'
'eventid=4218658'],
'mask': np.bool(False),
'sampling_rate': 40.0,
'starttime': 1369374000000000000},
'AE.113A..BHZ__2013-05-24T05:40:00__'
'2013-05-24T06:50:00__raw_recording': {
'event_ids': [
'smi:service.iris.edu/fdsnws/event/1/query?'
'eventid=4218658'],
'mask': np.bool(False),
'sampling_rate': 40.0,
'starttime': 1369374000000000000}
}
Expand Down