diff --git a/mne/io/egi/egimff.py b/mne/io/egi/egimff.py
index 51efd433455..7d96c9470f9 100644
--- a/mne/io/egi/egimff.py
+++ b/mne/io/egi/egimff.py
@@ -5,6 +5,8 @@
 """EGI NetStation Load Function."""
 
 import datetime
+import fnmatch
+import itertools
 import math
 import os.path as op
 import re
@@ -14,7 +16,7 @@
 import numpy as np
 
 from ..._fiff.constants import FIFF
-from ..._fiff.meas_info import _empty_info, _ensure_meas_date_none_or_dt, create_info
+from ..._fiff.meas_info import create_info
 from ..._fiff.proj import setup_proj
 from ..._fiff.utils import _create_chs, _mult_cal_one
 from ...annotations import Annotations
@@ -22,7 +24,6 @@
 from ...evoked import EvokedArray
 from ...utils import _check_fname, _check_option, _soft_import, logger, verbose, warn
 from ..base import BaseRaw
-from .events import _combine_triggers, _read_events, _triage_include_exclude
 from .general import (
     _block_r,
     _extract,
@@ -35,6 +36,134 @@
 
 REFERENCE_NAMES = ("VREF", "Vertex Reference")
 
+# TODO: Running list
+# - [ ] Add support for reading in the PNS data
+# - [ ] Add tutorial for reading calibration data
+# - [ ] Add support for reading in the channel status (bad channels)
+# - [ ] Replace _read_header with mffpy functions?
+
+
+def _read_mff(input_fname):
+    """Read EGI MFF file."""
+    mff_reader = _get_mff_reader(input_fname)
+    eeg = _get_eeg_data(mff_reader)
+    info = _get_info(mff_reader)
+    annotations = _get_annotations(mff_reader, info)
+    return eeg, info, annotations
+
+
+def _get_mff_reader(input_fname):
+    mffpy = _import_mffpy()
+    mff_reader = mffpy.Reader(input_fname)
+    mff_reader.set_unit("EEG", "V")  # XXX: set PNS unit
+    return mff_reader
+
+
+def _get_montage(mff_reader):
+    mffpy = _import_mffpy()
+    xml_files = mff_reader.directory.files_by_type[".xml"]
+    sensor_fname = fnmatch.filter(xml_files, "sensorLayout")
+    assert len(sensor_fname) == 1  # XXX: remove
+    sensor_fname = sensor_fname[0]
+    with mff_reader.directory.filepointer(sensor_fname) as fp:
+        sensor_layout = mffpy.XML.from_file(fp).get_content()["sensors"]
+    n_eeg_channels = mff_reader.num_channels["EEG"]  # XXX: PNS?
+    ch_pos = dict()
+    hsp = list()
+    for ch in sensor_layout.values():
+        # XXX: the y coordinate seems to be inverted? Need to investigate
+        loc = np.array([ch["x"], -(ch["y"]), ch["z"]]) / 1000
+        if ch["number"] <= n_eeg_channels:
+            assert ch["type"] in [0, 1]  # XXX: remove
+            name = f"E{ch['number']}" if ch["name"] == "None" else ch["name"]
+            ch_pos[name] = loc
+        elif ch["type"] == 2:  # type 2 seems to be headshape points or COM..
+            if ch["name"] == "COM":
+                continue
+            hsp.append(loc)
+    # XXX: this is still wonky. MNE will complain that the head radius is unusually big
+    montage = make_dig_montage(ch_pos=ch_pos, coord_frame="head", hsp=hsp)
+    return montage
+
+
+def _get_info(mff_reader):
+    montage = _get_montage(mff_reader)
+    ch_names = montage.ch_names
+    ch_types = ["eeg"] * len(ch_names)  # XXX: refactor this when adding PNS support
+    meas_date = mff_reader.startdatetime.astimezone(datetime.timezone.utc)
+    sfreq = mff_reader.sampling_rates["EEG"]  # XXX: check PNS sfreq?
+    info = create_info(ch_names=ch_names, sfreq=sfreq, ch_types=ch_types)
+    info.set_montage(montage)
+    info.set_meas_date(meas_date)
+    return info
+
+
+def _get_eeg_data(mff_reader):
+    sfreq = mff_reader.sampling_rates["EEG"]  # XXX: check PNS sfreq
+    n_channels = np.sum(list(mff_reader.num_channels.values()))
+    epochs = mff_reader.epochs
+
+    data_blocks, start_secs, end_secs = [], [], []
+    for epoch in epochs:
+        data_chunk, _ = mff_reader.get_physical_samples_from_epoch(epoch)["EEG"]  # XXX
+        data_blocks.append(data_chunk)
+        start_secs.append(epoch.t0)
+        end_secs.append(epoch.t1)
+
+    first_samp = int(start_secs[0] * sfreq)
+    last_samp = int(end_secs[-1] * sfreq)
+    interval = (1 / sfreq) * 1000
+    all_samps = np.arange(first_samp, last_samp + 1, interval)
+    eeg = np.zeros((n_channels, len(all_samps)), dtype=np.float64)
+    for this_chunk, start, end in zip(data_blocks, start_secs, end_secs):
+        start = int(start * sfreq)
+        end = int(end * sfreq)
+        eeg[:, start:end] = this_chunk
+    return eeg
+
+
+def _get_gap_annotations(mff_reader):
+    epochs = mff_reader.epochs
+    start_secs = [epoch.t0 for epoch in epochs]
+    end_secs = [epoch.t1 for epoch in epochs]
+    gap_durations = np.array(start_secs[1:]) - np.array(end_secs[:-1])
+    descriptions = ["BAD_ACQ_SKIP"] * len(gap_durations)
+    gap_onsets = np.array(end_secs[:-1])
+    gap_annots = Annotations(gap_onsets, gap_durations, descriptions)
+    return gap_annots
+
+
+def _get_event_annotations(mff_reader, mne_info):
+    mffpy = _import_mffpy()
+    xml_files = mff_reader.directory.files_by_type[".xml"]
+    events_xmls = fnmatch.filter(xml_files, "Events*")
+    if not events_xmls:
+        raise RuntimeError("No events found in MFF file.")
+    mff_events = {}
+    for event_file in events_xmls:
+        with mff_reader.directory.filepointer(event_file) as fp:
+            categories = mffpy.XML.from_file(fp)
+            mff_events[event_file] = categories.get_content()["event"]
+    onsets = []
+    descriptions = []
+    mff_events = list(itertools.chain.from_iterable(mff_events.values()))
+    for event in mff_events:
+        onset_dt = event["beginTime"].astimezone(datetime.timezone.utc)
+        ts = (onset_dt - mne_info["meas_date"]).total_seconds()
+        onsets.append(ts)
+        # XXX: we could use event["duration"] but it always seems to be 1000ms?
+        descriptions.append(event["code"])
+    durations = [0] * len(onsets)
+    event_annots = Annotations(onsets, durations, descriptions)
+    return event_annots
+
+
+def _get_annotations(mff_reader, mne_info):
+    event_annots = _get_event_annotations(mff_reader, mne_info)
+    gap_annots = _get_gap_annotations(mff_reader)
+    return event_annots + gap_annots
+
+
 def _read_mff_header(filepath):
     """Read mff header."""
     _soft_import("defusedxml", "reading EGI MFF data")
@@ -380,14 +509,14 @@ class RawMff(BaseRaw):
     def __init__(
         self,
         input_fname,
-        eog=None,
-        misc=None,
-        include=None,
-        exclude=None,
-        preload=False,
-        channel_naming="E%d",
+        eog=None,  # XXX: allow user to specify EOG channels?
+        misc=None,  # XXX: allow user to specify misc channels?
+        include=None,  # XXX: Now we don't create stim channels. Remove this?
+        exclude=None,  # XXX: Ditto. But maybe we can exclude events from annots.
+        preload=False,  # XXX: Make this work again
+        channel_naming="E%d",  # XXX: Do we still need to support this?
         *,
-        events_as_annotations=True,
+        events_as_annotations=True,  # XXX: This is now the only way. Remove?
         verbose=None,
     ):
         """Init the RawMff class."""
@@ -401,183 +530,19 @@ def __init__(
             )
         )
        logger.info(f"Reading EGI MFF Header from {input_fname}...")
-        egi_info = _read_header(input_fname)
-        if eog is None:
-            eog = []
-        if misc is None:
-            misc = np.where(np.array(egi_info["chan_type"]) != "eeg")[0].tolist()
-
-        logger.info(" Reading events ...")
-        egi_events, egi_info, mff_events = _read_events(input_fname, egi_info)
-        cals = _get_eeg_calibration_info(input_fname, egi_info)
-        logger.info(" Assembling measurement info ...")
-        event_codes = egi_info["event_codes"]
-        include = _triage_include_exclude(include, exclude, egi_events, egi_info)
-        if egi_info["n_events"] > 0 and not events_as_annotations:
-            logger.info(' Synthesizing trigger channel "STI 014" ...')
-            if all(ch.startswith("D") for ch in include):
-                # support the DIN format DIN1, DIN2, ..., DIN9, DI10, DI11, ... DI99,
-                # D100, D101, ..., D255 that we get when sending 0-255 triggers on a
-                # parallel port.
-                events_ids = list()
-                for ch in include:
-                    while not ch[0].isnumeric():
-                        ch = ch[1:]
-                    events_ids.append(int(ch))
-            else:
-                events_ids = np.arange(len(include)) + 1
-            egi_info["new_trigger"] = _combine_triggers(
-                egi_events[[c in include for c in event_codes]], remapping=events_ids
-            )
-            self.event_id = dict(
-                zip([e for e in event_codes if e in include], events_ids)
-            )
-            if egi_info["new_trigger"] is not None:
-                egi_events = np.vstack([egi_events, egi_info["new_trigger"]])
-        else:
-            self.event_id = None
-            egi_info["new_trigger"] = None
-        assert egi_events.shape[1] == egi_info["last_samps"][-1]
-
-        meas_dt_utc = egi_info["meas_dt_local"].astimezone(datetime.timezone.utc)
-        info = _empty_info(egi_info["sfreq"])
-        info["meas_date"] = _ensure_meas_date_none_or_dt(meas_dt_utc)
-        info["utc_offset"] = egi_info["utc_offset"]
-        info["device_info"] = dict(type=egi_info["device"])
-
-        # read in the montage, if it exists
-        ch_names, mon = _read_locs(input_fname, egi_info, channel_naming)
-        # Second: Stim
-        ch_names.extend(list(egi_info["event_codes"]))
-        n_extra = len(event_codes) + len(misc) + len(eog) + len(egi_info["pns_names"])
-        if egi_info["new_trigger"] is not None:
-            ch_names.append("STI 014")  # channel for combined events
-            n_extra += 1
-
-        # Third: PNS
-        ch_names.extend(egi_info["pns_names"])
-
-        cals = np.concatenate([cals, np.ones(n_extra)])
-        assert len(cals) == len(ch_names), (len(cals), len(ch_names))
-
-        # Actually create channels as EEG, then update stim and PNS
-        ch_coil = FIFF.FIFFV_COIL_EEG
-        ch_kind = FIFF.FIFFV_EEG_CH
-        chs = _create_chs(ch_names, cals, ch_coil, ch_kind, eog, (), (), misc)
-
-        sti_ch_idx = [
-            i
-            for i, name in enumerate(ch_names)
-            if name.startswith("STI") or name in event_codes
-        ]
-        for idx in sti_ch_idx:
-            chs[idx].update(
-                {
-                    "unit_mul": FIFF.FIFF_UNITM_NONE,
-                    "cal": cals[idx],
-                    "kind": FIFF.FIFFV_STIM_CH,
-                    "coil_type": FIFF.FIFFV_COIL_NONE,
-                    "unit": FIFF.FIFF_UNIT_NONE,
-                }
-            )
-        chs = _add_pns_channel_info(chs, egi_info, ch_names)
-        info["chs"] = chs
-        info._unlocked = False
-        info._update_redundant()
-
-        if mon is not None:
-            info.set_montage(mon, on_missing="ignore")
-
-            ref_idx = np.flatnonzero(np.isin(mon.ch_names, REFERENCE_NAMES))
-            if len(ref_idx):
-                ref_idx = ref_idx.item()
-                ref_coords = info["chs"][int(ref_idx)]["loc"][:3]
-                for chan in info["chs"]:
-                    if chan["kind"] == FIFF.FIFFV_EEG_CH:
-                        chan["loc"][3:6] = ref_coords
-
-        file_bin = op.join(input_fname, egi_info["eeg_fname"])
-        egi_info["egi_events"] = egi_events
-
-        # Check how many channels to read are from EEG
-        keys = ("eeg", "sti", "pns")
-        idx = dict()
-        idx["eeg"] = np.where([ch["kind"] == FIFF.FIFFV_EEG_CH for ch in chs])[0]
-        idx["sti"] = np.where([ch["kind"] == FIFF.FIFFV_STIM_CH for ch in chs])[0]
-        idx["pns"] = np.where(
-            [
-                ch["kind"] in (FIFF.FIFFV_ECG_CH, FIFF.FIFFV_EMG_CH, FIFF.FIFFV_BIO_CH)
-                for ch in chs
-            ]
-        )[0]
-        # By construction this should always be true, but check anyway
-        if not np.array_equal(
-            np.concatenate([idx[key] for key in keys]), np.arange(len(chs))
-        ):
-            raise ValueError(
-                "Currently interlacing EEG and PNS channels is not supported"
-            )
-        egi_info["kind_bounds"] = [0]
-        for key in keys:
-            egi_info["kind_bounds"].append(len(idx[key]))
-        egi_info["kind_bounds"] = np.cumsum(egi_info["kind_bounds"])
-        assert egi_info["kind_bounds"][0] == 0
-        assert egi_info["kind_bounds"][-1] == info["nchan"]
-        first_samps = [0]
-        last_samps = [egi_info["last_samps"][-1] - 1]
-
-        annot = dict(onset=list(), duration=list(), description=list())
-
-        if len(idx["pns"]):
-            # PNS Data is present and should be read:
-            egi_info["pns_filepath"] = op.join(input_fname, egi_info["pns_fname"])
-            # Check for PNS bug immediately
-            pns_samples = np.sum(egi_info["pns_sample_blocks"]["samples_block"])
-            eeg_samples = np.sum(egi_info["samples_block"])
-            if pns_samples == eeg_samples - 1:
-                warn("This file has the EGI PSG sample bug")
-                annot["onset"].append(last_samps[-1] / egi_info["sfreq"])
-                annot["duration"].append(1 / egi_info["sfreq"])
-                annot["description"].append("BAD_EGI_PSG")
-            elif pns_samples != eeg_samples:
-                raise RuntimeError(
-                    "PNS samples (%d) did not match EEG samples (%d)"
-                    % (pns_samples, eeg_samples)
-                )
+        eeg, info, annots = _read_mff(input_fname)
 
         super().__init__(
             info,
-            preload=preload,
-            orig_format="single",
-            filenames=[file_bin],
-            first_samps=first_samps,
-            last_samps=last_samps,
-            raw_extras=[egi_info],
+            preload=eeg,  # XXX: Make eager/lazy loading work again
+            orig_format="single",  # XXX: Check if this is still correct
+            filenames=[input_fname],  # XXX: multiple files? I need an example
+            first_samps=(0,),  # XXX: multiple files?
+            last_samps=None,  # XXX: multiple files?
+            raw_extras=(None,),  # XXX: do we still need this?
             verbose=verbose,
         )
-
-        # Annotate acquisition skips
-        for first, prev_last in zip(
-            egi_info["first_samps"][1:], egi_info["last_samps"][:-1]
-        ):
-            gap = first - prev_last
-            assert gap >= 0
-            if gap:
-                annot["onset"].append((prev_last - 0.5) / egi_info["sfreq"])
-                annot["duration"].append(gap / egi_info["sfreq"])
-                annot["description"].append("BAD_ACQ_SKIP")
-
-        # create events from annotations
-        if events_as_annotations:
-            for code, samples in mff_events.items():
-                if code not in include:
-                    continue
-                annot["onset"].extend(np.array(samples) / egi_info["sfreq"])
-                annot["duration"].extend([0.0] * len(samples))
-                annot["description"].extend([code] * len(samples))
-
-        if len(annot["onset"]):
-            self.set_annotations(Annotations(**annot))
+        self.set_annotations(annots)
 
     def _read_segment_file(self, data, idx, fi, start, stop, cals, mult):
         """Read a chunk of data."""
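For context, a minimal usage sketch (not part of the patch) of how the refactored mffpy-based path would be exercised through the public API. The file name below is a placeholder and the printed fields are only illustrative:

# Hypothetical smoke test for the new _read_mff() code path; "recording.mff"
# is a placeholder path to an EGI MFF recording directory.
import mne

raw = mne.io.read_raw_egi("recording.mff")  # RawMff.__init__ now delegates to _read_mff
print(raw.info["sfreq"])   # sampling rate and montage assembled by _get_info
print(raw.annotations)     # event and BAD_ACQ_SKIP annotations from _get_annotations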