diff --git a/doc/tutorials/0-1_meg_preprocessing.py b/doc/tutorials/0-1_meg_preprocessing.py index cffbd077..b7379288 100644 --- a/doc/tutorials/0-1_meg_preprocessing.py +++ b/doc/tutorials/0-1_meg_preprocessing.py @@ -4,7 +4,7 @@ This tutorial walks through the full MEG processing pipeline step by step: -1. Preprocessing — Downsample, filter, detect bad segments/channels. +1. Preprocessing — Downsample, filter, detect bad segments/channels, ICA artefact rejection. 2. Surface Extraction — Extract skull/scalp surfaces from a structural MRI. 3. Coregistration — Align MEG sensor space to MRI space. 4. Forward Model — Compute the lead field matrix. @@ -161,11 +161,36 @@ # raw = preproc.detect_bad_channels(raw, picks="grad") #%% -# Save preprocessing QC plots (PSD, sum-of-squares time series, channel standard deviations). Check that bad segments and channels are being correctly identified. +# ICA artefact rejection +# ********************** +# +# Automatic ICA artefact rejection to remove ECG and EOG artefacts. +# Two approaches are available: +# +# **Option A: ECG/EOG correlation** — Uses MNE-Python's built-in +# correlation-based detection. +# +# .. code-block:: python +# +# raw, ica, ic_labels = preproc.ica_ecg_eog_correlation(raw, picks="meg") +# +# **Option B: MEGNet automatic labelling** — Uses ``mne-icalabel`` (included +# in the osl-dynamics conda environments) to classify components with a +# pre-trained deep learning model. Note, MEGNet was trained on ``'mag'`` +# sensor topographies. +# +# .. code-block:: python +# +# raw, ica, ic_labels = preproc.ica_label(raw, picks="mag", method="megnet") + +#%% +# Save preprocessing QC plots (PSD, sum-of-squares time series, channel +# standard deviations, ICA component topographies). Check that bad segments, +# channels, and ICA components are being correctly identified. # # .. code-block:: python # -# preproc.save_qc_plots(raw, plots_dir / id, show=True) +# preproc.save_qc_plots(raw, plots_dir / id, show=True, ica=ica, ic_labels=ic_labels) #%% # Save preprocessed data diff --git a/envs/bmrc.yml b/envs/bmrc.yml index d82ed818..25731dbd 100644 --- a/envs/bmrc.yml +++ b/envs/bmrc.yml @@ -8,15 +8,14 @@ dependencies: - aiohttp=3.13.3=py312h5d8c7f2_0 - aiosignal=1.4.0=pyhd8ed1ab_0 - alsa-lib=1.2.15.3=hb03c661_0 - - anyio=4.12.1=pyhcf101f3_0 + - anyio=4.13.0=pyhcf101f3_0 - aom=3.9.1=hac33072_0 - argon2-cffi=25.1.0=pyhd8ed1ab_0 - argon2-cffi-bindings=25.1.0=py312h4c3975b_2 - arrow=1.4.0=pyhcf101f3_0 - asttokens=3.0.1=pyhd8ed1ab_0 - - async-lru=2.2.0=pyhcf101f3_0 - - attr=2.5.2=h39aace5_0 - - attrs=25.4.0=pyhcf101f3_1 + - async-lru=2.3.0=pyhcf101f3_0 + - attrs=26.1.0=pyhcf101f3_0 - babel=2.18.0=pyhcf101f3_1 - backports.zstd=1.3.0=py312h90b7ffd_0 - beautifulsoup4=4.14.3=pyha770c72_0 @@ -38,12 +37,12 @@ dependencies: - certifi=2026.2.25=pyhd8ed1ab_0 - cffi=2.0.0=py312h460c074_1 - charls=2.4.3=hecca717_0 - - charset-normalizer=3.4.5=pyhd8ed1ab_0 + - charset-normalizer=3.4.6=pyhd8ed1ab_0 - comm=0.2.3=pyhe01879c_0 - contourpy=1.3.3=py312h0a2e395_4 - cpython=3.12.13=py312hd8ed1ab_0 - cycler=0.12.1=pyhcf101f3_2 - - cyclopts=4.7.0=pyhcf101f3_0 + - cyclopts=4.10.1=pyhcf101f3_0 - cyrus-sasl=2.1.28=hd9c7081_0 - dav1d=1.2.1=hd590300_0 - dbus=1.16.2=h24cb091_1 @@ -67,19 +66,19 @@ dependencies: - fonts-conda-forge=1=hc364b38_1 - fonttools=4.62.0=py312h8a5da7c_0 - fqdn=1.5.1=pyhd8ed1ab_1 - - freetype=2.14.2=ha770c72_0 + - freetype=2.14.3=ha770c72_0 - fribidi=1.0.16=hb03c661_0 - frozenlist=1.7.0=py312h447239a_0 - fslpy=3.27.0=pyhd8ed1ab_0 - gdk-pixbuf=2.44.4=h2b0a6b4_0 - giflib=5.2.2=hd590300_0 - - gl2ps=1.4.2=hae5d5c5_1 + - gl2ps=1.4.2=h36e74d4_2 - gmp=6.3.0=hac33072_2 - gnutls=3.8.11=h18acefa_1 - graphite2=1.3.14=hecca717_2 - h11=0.16.0=pyhcf101f3_1 - h2=4.3.0=pyhcf101f3_0 - - h5io=0.2.5=pyhecae5ae_0 + - h5io=0.2.5=pyhc455866_0 - h5py=3.13.0=nompi_py312hedeef09_100 - harfbuzz=12.2.0=h15599e2_0 - hdf4=4.2.15=h2a13503_7 @@ -92,9 +91,9 @@ dependencies: - idna=3.11=pyhd8ed1ab_0 - imagecodecs=2026.1.14=py312h40df4bb_1 - imageio=2.37.0=pyhfb79c49_0 - - importlib-metadata=8.7.0=pyhe01879c_1 + - importlib-metadata=8.8.0=pyhcf101f3_0 - importlib_resources=6.5.2=pyhd8ed1ab_0 - - intel-gmmlib=22.9.0=hb700be7_0 + - intel-gmmlib=22.10.0=hb700be7_0 - intel-media-driver=25.3.4=hecca717_0 - ipyevents=2.0.4=pyhbbac1ac_0 - ipykernel=7.2.0=pyha191276_1 @@ -107,7 +106,7 @@ dependencies: - joblib=1.5.3=pyhd8ed1ab_0 - json5=0.13.0=pyhd8ed1ab_0 - jsoncpp=1.9.6=hf42df4d_1 - - jsonpointer=3.0.0=pyhcf101f3_3 + - jsonpointer=3.1.1=pyhcf101f3_0 - jsonschema=4.26.0=pyhcf101f3_0 - jsonschema-specifications=2025.9.1=pyhcf101f3_0 - jsonschema-with-format-nongpl=4.26.0=hcf101f3_0 @@ -132,7 +131,7 @@ dependencies: - lazy-loader=0.5=pyhd8ed1ab_0 - lazy_loader=0.5=pyhd8ed1ab_0 - lcms2=2.18=h0c24ade_0 - - ld_impl_linux-64=2.45.1=default_hbd61a6d_101 + - ld_impl_linux-64=2.45.1=default_hbd61a6d_102 - lerc=4.1.0=hdb68285_0 - level-zero=1.28.2=hb700be7_0 - libabseil=20250512.1=cxx17_hba17884_0 @@ -144,7 +143,7 @@ dependencies: - libbrotlicommon=1.2.0=hb03c661_1 - libbrotlidec=1.2.0=hb03c661_1 - libbrotlienc=1.2.0=hb03c661_1 - - libcap=2.77=h3ff7636_0 + - libcap=2.77=hd0affe5_1 - libcblas=3.11.0=5_h0358290_openblas - libclang-cpp21.1=21.1.0=default_h99862b1_1 - libclang13=21.1.0=default_h746c552_1 @@ -158,8 +157,8 @@ dependencies: - libexpat=2.7.4=hecca717_0 - libffi=3.5.2=h3435931_0 - libflac=1.5.0=he200343_1 - - libfreetype=2.14.2=ha770c72_0 - - libfreetype6=2.14.2=h73754d4_0 + - libfreetype=2.14.3=ha770c72_0 + - libfreetype6=2.14.3=h73754d4_0 - libgcc=15.2.0=he0feb66_18 - libgcc-ng=15.2.0=h69a702a_18 - libgfortran=15.2.0=h69a702a_18 @@ -181,7 +180,7 @@ dependencies: - liblzma=5.8.2=hb03c661_0 - libmicrohttpd=1.0.2=hc2fc477_0 - libnetcdf=4.9.2=nompi_h00e09a9_116 - - libnghttp2=1.67.0=had1ee68_0 + - libnghttp2=1.68.1=h877daf1_0 - libnsl=2.0.1=hb9d3cd8_1 - libntlm=1.8=hb9d3cd8_0 - libogg=1.3.5=hd0c01bc_1 @@ -212,11 +211,11 @@ dependencies: - libssh2=1.11.1=hcf80075_0 - libstdcxx=15.2.0=h934c35e_18 - libstdcxx-ng=15.2.0=hdf11a46_18 - - libsystemd0=257.10=hd0affe5_4 + - libsystemd0=257.13=hd0affe5_0 - libtasn1=4.21.0=hb03c661_0 - libtheora=1.1.1=h4ab18f5_1006 - libtiff=4.7.1=h9d88235_1 - - libudev1=257.10=hd0affe5_4 + - libudev1=257.13=hd0affe5_0 - libunistring=0.9.10=h7f98852_0 - libunwind=1.8.3=h65a8314_0 - liburing=2.12=hb700be7_0 @@ -234,7 +233,7 @@ dependencies: - libxml2=2.13.9=h04c0eec_0 - libxslt=1.1.43=h7a3aeb2_0 - libzip=1.11.2=h6991a6a_0 - - libzlib=1.3.1=hb9d3cd8_2 + - libzlib=1.3.2=h25fd6f3_2 - libzopfli=1.0.3=h9c3ff4c_0 - llvmlite=0.46.0=py312h7424e68_0 - loguru=0.7.3=pyh707e725_0 @@ -270,7 +269,7 @@ dependencies: - notebook-shim=0.2.4=pyhd8ed1ab_1 - numba=0.64.0=py312hd1dde6f_0 - ocl-icd=2.3.3=hb9d3cd8_0 - - opencl-headers=2025.06.13=h5888daf_0 + - opencl-headers=2025.06.13=hecca717_0 - openh264=2.6.0=hc22cd8d_0 - openjpeg=2.5.4=h55fea9a_0 - openjph=0.26.3=h8d634f6_0 @@ -326,7 +325,7 @@ dependencies: - rav1e=0.7.1=h8fae777_3 - readline=8.3=h853b02a_0 - referencing=0.37.0=pyhcf101f3_0 - - requests=2.32.5=pyhcf101f3_1 + - requests=2.33.0=pyhcf101f3_0 - rfc3339-validator=0.1.4=pyhd8ed1ab_1 - rfc3986-validator=0.1.1=pyh9f0ad1d_0 - rfc3987-syntax=1.1.0=pyhe01879c_1 @@ -358,14 +357,14 @@ dependencies: - tinycss2=1.4.0=pyhd8ed1ab_0 - tk=8.6.13=noxft_h366c992_103 - tomli=2.4.0=pyhcf101f3_0 - - tornado=6.5.3=py312h4c3975b_0 + - tornado=6.5.5=py312h4c3975b_0 - tqdm=4.67.3=pyh8f84b5b_0 - traitlets=5.14.3=pyhd8ed1ab_1 - trame=3.12.0=pyhd8ed1ab_0 - - trame-client=3.11.3=pyhd8ed1ab_0 - - trame-common=1.1.2=pyhd8ed1ab_0 + - trame-client=3.11.4=pyhd8ed1ab_0 + - trame-common=1.1.3=pyhd8ed1ab_0 - trame-server=3.10.0=pyhd8ed1ab_0 - - trame-vtk=2.11.1=pyh932262d_0 + - trame-vtk=2.11.5=pyh3504b2d_0 - trame-vuetify=3.2.1=pyhd8ed1ab_0 - typing-extensions=4.15.0=h396c80c_0 - typing_extensions=4.15.0=pyhcf101f3_0 @@ -378,7 +377,7 @@ dependencies: - vtk=9.3.1=osmesa_py312hf4758c4_116 - vtk-base=9.3.1=osmesa_py312hc9bc066_116 - vtk-io-ffmpeg=9.3.1=osmesa_py312hf4758c4_116 - - wayland=1.24.0=hd6090a7_1 + - wayland=1.25.0=hd6090a7_0 - wayland-protocols=1.47=hd8ed1ab_0 - wcwidth=0.6.0=pyhd8ed1ab_0 - webcolors=25.10.0=pyhd8ed1ab_0 @@ -419,7 +418,7 @@ dependencies: - zeromq=4.3.5=h387f397_9 - zfp=1.0.1=h909a3a2_5 - zipp=3.23.0=pyhcf101f3_1 - - zlib=1.3.1=hb9d3cd8_2 + - zlib=1.3.2=h25fd6f3_2 - zlib-ng=2.3.3=hceb46e0_1 - zstd=1.5.7=hb78ec9c_6 - pip: @@ -435,6 +434,8 @@ dependencies: - libclang==18.1.1 - markdown==3.10.2 - ml-dtypes==0.5.4 + - mne-icalabel==0.8.1 + - mpmath==1.3.0 - namex==0.1.0 - numpy==2.1.3 - nvidia-cublas-cu12==12.5.3.2 @@ -449,15 +450,17 @@ dependencies: - nvidia-cusparse-cu12==12.5.1.3 - nvidia-nccl-cu12==2.23.4 - nvidia-nvjitlink-cu12==12.5.82 + - onnxruntime==1.24.4 - opt-einsum==3.4.0 - optree==0.19.0 - protobuf==5.29.6 + - sympy==1.14.0 - tensorboard==2.19.0 - tensorboard-data-server==0.7.2 - tensorflow==2.19.0 - tensorflow-probability==0.25.0 - termcolor==3.3.0 - tf-keras==2.19.0 - - werkzeug==3.1.6 + - werkzeug==3.1.7 - wrapt==2.1.2 - osl-dynamics diff --git a/envs/fsl.yml b/envs/fsl.yml index 6a306830..1bd228b6 100644 --- a/envs/fsl.yml +++ b/envs/fsl.yml @@ -31,5 +31,6 @@ dependencies: - trame-vtk - trame-vuetify - pip: + - mne-icalabel[onnx] - tensorflow==2.19 - tensorflow-probability[tf]==0.25 diff --git a/envs/hbaws.yml b/envs/hbaws.yml index 663933e8..2239a6cf 100644 --- a/envs/hbaws.yml +++ b/envs/hbaws.yml @@ -8,15 +8,14 @@ dependencies: - aiohttp=3.13.3=py312h5d8c7f2_0 - aiosignal=1.4.0=pyhd8ed1ab_0 - alsa-lib=1.2.15.3=hb03c661_0 - - anyio=4.12.1=pyhcf101f3_0 + - anyio=4.13.0=pyhcf101f3_0 - aom=3.9.1=hac33072_0 - argon2-cffi=25.1.0=pyhd8ed1ab_0 - argon2-cffi-bindings=25.1.0=py312h4c3975b_2 - arrow=1.4.0=pyhcf101f3_0 - asttokens=3.0.1=pyhd8ed1ab_0 - - async-lru=2.2.0=pyhcf101f3_0 - - attr=2.5.2=h39aace5_0 - - attrs=25.4.0=pyhcf101f3_1 + - async-lru=2.3.0=pyhcf101f3_0 + - attrs=26.1.0=pyhcf101f3_0 - babel=2.18.0=pyhcf101f3_1 - backports.zstd=1.3.0=py312h90b7ffd_0 - beautifulsoup4=4.14.3=pyha770c72_0 @@ -38,12 +37,12 @@ dependencies: - certifi=2026.2.25=pyhd8ed1ab_0 - cffi=2.0.0=py312h460c074_1 - charls=2.4.3=hecca717_0 - - charset-normalizer=3.4.5=pyhd8ed1ab_0 + - charset-normalizer=3.4.6=pyhd8ed1ab_0 - comm=0.2.3=pyhe01879c_0 - contourpy=1.3.3=py312h0a2e395_4 - cpython=3.12.13=py312hd8ed1ab_0 - cycler=0.12.1=pyhcf101f3_2 - - cyclopts=4.7.0=pyhcf101f3_0 + - cyclopts=4.10.1=pyhcf101f3_0 - cyrus-sasl=2.1.28=hd9c7081_0 - dav1d=1.2.1=hd590300_0 - dbus=1.16.2=h24cb091_1 @@ -67,19 +66,19 @@ dependencies: - fonts-conda-forge=1=hc364b38_1 - fonttools=4.62.0=py312h8a5da7c_0 - fqdn=1.5.1=pyhd8ed1ab_1 - - freetype=2.14.2=ha770c72_0 + - freetype=2.14.3=ha770c72_0 - fribidi=1.0.16=hb03c661_0 - frozenlist=1.7.0=py312h447239a_0 - fslpy=3.27.0=pyhd8ed1ab_0 - gdk-pixbuf=2.44.4=h2b0a6b4_0 - giflib=5.2.2=hd590300_0 - - gl2ps=1.4.2=hae5d5c5_1 + - gl2ps=1.4.2=h36e74d4_2 - gmp=6.3.0=hac33072_2 - gnutls=3.8.11=h18acefa_1 - graphite2=1.3.14=hecca717_2 - h11=0.16.0=pyhcf101f3_1 - h2=4.3.0=pyhcf101f3_0 - - h5io=0.2.5=pyhecae5ae_0 + - h5io=0.2.5=pyhc455866_0 - h5py=3.13.0=nompi_py312hedeef09_100 - harfbuzz=12.2.0=h15599e2_0 - hdf4=4.2.15=h2a13503_7 @@ -92,9 +91,9 @@ dependencies: - idna=3.11=pyhd8ed1ab_0 - imagecodecs=2026.1.14=py312h40df4bb_1 - imageio=2.37.0=pyhfb79c49_0 - - importlib-metadata=8.7.0=pyhe01879c_1 + - importlib-metadata=8.8.0=pyhcf101f3_0 - importlib_resources=6.5.2=pyhd8ed1ab_0 - - intel-gmmlib=22.9.0=hb700be7_0 + - intel-gmmlib=22.10.0=hb700be7_0 - intel-media-driver=25.3.4=hecca717_0 - ipyevents=2.0.4=pyhbbac1ac_0 - ipykernel=7.2.0=pyha191276_1 @@ -107,7 +106,7 @@ dependencies: - joblib=1.5.3=pyhd8ed1ab_0 - json5=0.13.0=pyhd8ed1ab_0 - jsoncpp=1.9.6=hf42df4d_1 - - jsonpointer=3.0.0=pyhcf101f3_3 + - jsonpointer=3.1.1=pyhcf101f3_0 - jsonschema=4.26.0=pyhcf101f3_0 - jsonschema-specifications=2025.9.1=pyhcf101f3_0 - jsonschema-with-format-nongpl=4.26.0=hcf101f3_0 @@ -132,7 +131,7 @@ dependencies: - lazy-loader=0.5=pyhd8ed1ab_0 - lazy_loader=0.5=pyhd8ed1ab_0 - lcms2=2.18=h0c24ade_0 - - ld_impl_linux-64=2.45.1=default_hbd61a6d_101 + - ld_impl_linux-64=2.45.1=default_hbd61a6d_102 - lerc=4.1.0=hdb68285_0 - level-zero=1.28.2=hb700be7_0 - libabseil=20250512.1=cxx17_hba17884_0 @@ -144,7 +143,7 @@ dependencies: - libbrotlicommon=1.2.0=hb03c661_1 - libbrotlidec=1.2.0=hb03c661_1 - libbrotlienc=1.2.0=hb03c661_1 - - libcap=2.77=h3ff7636_0 + - libcap=2.77=hd0affe5_1 - libcblas=3.11.0=5_h0358290_openblas - libclang-cpp21.1=21.1.0=default_h99862b1_1 - libclang13=21.1.0=default_h746c552_1 @@ -158,8 +157,8 @@ dependencies: - libexpat=2.7.4=hecca717_0 - libffi=3.5.2=h3435931_0 - libflac=1.5.0=he200343_1 - - libfreetype=2.14.2=ha770c72_0 - - libfreetype6=2.14.2=h73754d4_0 + - libfreetype=2.14.3=ha770c72_0 + - libfreetype6=2.14.3=h73754d4_0 - libgcc=15.2.0=he0feb66_18 - libgcc-ng=15.2.0=h69a702a_18 - libgfortran=15.2.0=h69a702a_18 @@ -181,7 +180,7 @@ dependencies: - liblzma=5.8.2=hb03c661_0 - libmicrohttpd=1.0.2=hc2fc477_0 - libnetcdf=4.9.2=nompi_h00e09a9_116 - - libnghttp2=1.67.0=had1ee68_0 + - libnghttp2=1.68.1=h877daf1_0 - libnsl=2.0.1=hb9d3cd8_1 - libntlm=1.8=hb9d3cd8_0 - libogg=1.3.5=hd0c01bc_1 @@ -212,11 +211,11 @@ dependencies: - libssh2=1.11.1=hcf80075_0 - libstdcxx=15.2.0=h934c35e_18 - libstdcxx-ng=15.2.0=hdf11a46_18 - - libsystemd0=259.4=h6569c3e_0 + - libsystemd0=260.1=h6569c3e_0 - libtasn1=4.21.0=hb03c661_0 - libtheora=1.1.1=h4ab18f5_1006 - libtiff=4.7.1=h9d88235_1 - - libudev1=259.4=h6569c3e_0 + - libudev1=260.1=h6569c3e_0 - libunistring=0.9.10=h7f98852_0 - libunwind=1.8.3=h65a8314_0 - liburing=2.12=hb700be7_0 @@ -234,7 +233,7 @@ dependencies: - libxml2=2.13.9=h04c0eec_0 - libxslt=1.1.43=h7a3aeb2_0 - libzip=1.11.2=h6991a6a_0 - - libzlib=1.3.1=hb9d3cd8_2 + - libzlib=1.3.2=h25fd6f3_2 - libzopfli=1.0.3=h9c3ff4c_0 - llvmlite=0.46.0=py312h7424e68_0 - loguru=0.7.3=pyh707e725_0 @@ -270,7 +269,7 @@ dependencies: - notebook-shim=0.2.4=pyhd8ed1ab_1 - numba=0.64.0=py312hd1dde6f_0 - ocl-icd=2.3.3=hb9d3cd8_0 - - opencl-headers=2025.06.13=h5888daf_0 + - opencl-headers=2025.06.13=hecca717_0 - openh264=2.6.0=hc22cd8d_0 - openjpeg=2.5.4=h55fea9a_0 - openjph=0.26.3=h8d634f6_0 @@ -326,7 +325,7 @@ dependencies: - rav1e=0.7.1=h8fae777_3 - readline=8.3=h853b02a_0 - referencing=0.37.0=pyhcf101f3_0 - - requests=2.32.5=pyhcf101f3_1 + - requests=2.33.0=pyhcf101f3_0 - rfc3339-validator=0.1.4=pyhd8ed1ab_1 - rfc3986-validator=0.1.1=pyh9f0ad1d_0 - rfc3987-syntax=1.1.0=pyhe01879c_1 @@ -358,14 +357,14 @@ dependencies: - tinycss2=1.4.0=pyhd8ed1ab_0 - tk=8.6.13=noxft_h366c992_103 - tomli=2.4.0=pyhcf101f3_0 - - tornado=6.5.4=py312h961da02_0 + - tornado=6.5.5=py312h4c3975b_0 - tqdm=4.67.3=pyh8f84b5b_0 - traitlets=5.14.3=pyhd8ed1ab_1 - trame=3.12.0=pyhd8ed1ab_0 - - trame-client=3.11.3=pyhd8ed1ab_0 - - trame-common=1.1.2=pyhd8ed1ab_0 + - trame-client=3.11.4=pyhd8ed1ab_0 + - trame-common=1.1.3=pyhd8ed1ab_0 - trame-server=3.10.0=pyhd8ed1ab_0 - - trame-vtk=2.11.1=pyh932262d_0 + - trame-vtk=2.11.5=pyh3504b2d_0 - trame-vuetify=3.2.1=pyhd8ed1ab_0 - typing-extensions=4.15.0=h396c80c_0 - typing_extensions=4.15.0=pyhcf101f3_0 @@ -378,7 +377,7 @@ dependencies: - vtk=9.3.1=osmesa_py312hf4758c4_116 - vtk-base=9.3.1=osmesa_py312hc9bc066_116 - vtk-io-ffmpeg=9.3.1=osmesa_py312hf4758c4_116 - - wayland=1.24.0=hd6090a7_1 + - wayland=1.25.0=hd6090a7_0 - wayland-protocols=1.47=hd8ed1ab_0 - wcwidth=0.6.0=pyhd8ed1ab_0 - webcolors=25.10.0=pyhd8ed1ab_0 @@ -419,7 +418,7 @@ dependencies: - zeromq=4.3.5=h387f397_9 - zfp=1.0.1=h909a3a2_5 - zipp=3.23.0=pyhcf101f3_1 - - zlib=1.3.1=hb9d3cd8_2 + - zlib=1.3.2=h25fd6f3_2 - zlib-ng=2.3.3=hceb46e0_1 - zstd=1.5.7=hb78ec9c_6 - pip: @@ -435,17 +434,21 @@ dependencies: - libclang==18.1.1 - markdown==3.10.2 - ml-dtypes==0.5.4 + - mne-icalabel==0.8.1 + - mpmath==1.3.0 - namex==0.1.0 - numpy==2.1.3 + - onnxruntime==1.24.4 - opt-einsum==3.4.0 - optree==0.19.0 - protobuf==5.29.6 + - sympy==1.14.0 - tensorboard==2.19.0 - tensorboard-data-server==0.7.2 - tensorflow==2.19.0 - tensorflow-probability==0.25.0 - termcolor==3.3.0 - tf-keras==2.19.0 - - werkzeug==3.1.6 + - werkzeug==3.1.7 - wrapt==2.1.2 - osl-dynamics diff --git a/envs/osld-tf-cuda.yml b/envs/osld-tf-cuda.yml index 6fe66b64..00082986 100644 --- a/envs/osld-tf-cuda.yml +++ b/envs/osld-tf-cuda.yml @@ -31,6 +31,7 @@ dependencies: - trame-vtk - trame-vuetify - pip: + - mne-icalabel[onnx] - tensorflow[and-cuda]==2.19 - tensorflow-probability[tf]==0.25 - osl-dynamics diff --git a/envs/osld-tf-macos.yml b/envs/osld-tf-macos.yml index cab7fd0c..b7443b83 100644 --- a/envs/osld-tf-macos.yml +++ b/envs/osld-tf-macos.yml @@ -31,6 +31,7 @@ dependencies: - trame-vtk - trame-vuetify - pip: + - mne-icalabel[onnx] - tensorflow==2.16.1 - tensorflow-probability==0.24 - tf-keras==2.16 diff --git a/envs/osld-tf.yml b/envs/osld-tf.yml index 399cd793..a4fa1249 100644 --- a/envs/osld-tf.yml +++ b/envs/osld-tf.yml @@ -31,6 +31,7 @@ dependencies: - trame-vtk - trame-vuetify - pip: + - mne-icalabel[onnx] - tensorflow==2.19 - tensorflow-probability[tf]==0.25 - osl-dynamics diff --git a/envs/osld.yml b/envs/osld.yml index 30e8d207..71896e3b 100644 --- a/envs/osld.yml +++ b/envs/osld.yml @@ -31,4 +31,5 @@ dependencies: - trame-vtk - trame-vuetify - pip: + - mne-icalabel[onnx] - osl-dynamics diff --git a/examples/meg_preproc/1_preproc.py b/examples/meg_preproc/1_preproc.py index 3a4a4de0..4c115c10 100755 --- a/examples/meg_preproc/1_preproc.py +++ b/examples/meg_preproc/1_preproc.py @@ -30,7 +30,7 @@ def process_session(id, info, logger, **kwargs): raw_file = input_dir / info["subject"] / "meg" / info["file"] raw = mne.io.read_raw_fif(raw_file, preload=True) - raw = raw.crop(tmax=30) + raw = raw.crop(tmax=60) logger.log("Filtering and downsampling...") raw = raw.resample(sfreq=250) @@ -51,8 +51,11 @@ def process_session(id, info, logger, **kwargs): raw = preproc.detect_bad_channels(raw, picks="mag") raw = preproc.detect_bad_channels(raw, picks="grad") + logger.log("Running ICA artefact rejection...") + raw, ica, ic_labels = preproc.ica_label(raw, picks="meg") + logger.log("Saving QC plots...") - preproc.save_qc_plots(raw, plots_dir / id) + preproc.save_qc_plots(raw, plots_dir / id, ica=ica, ic_labels=ic_labels) logger.log("Saving preprocessed data...") preproc_out_dir = output_dir / "preprocessed" diff --git a/examples/meg_preproc/README.md b/examples/meg_preproc/README.md index 31d1d8f6..e18e3d6e 100755 --- a/examples/meg_preproc/README.md +++ b/examples/meg_preproc/README.md @@ -26,7 +26,7 @@ Run the scripts **in order**. Each script processes all sessions in parallel. | Script | Step | Description | |--------|------|-------------| -| `1_preproc.py` | Preprocessing | Downsample (250 Hz), bandpass filter (1-45 Hz), notch filter (50/100 Hz), bad segment and bad channel detection | +| `1_preproc.py` | Preprocessing | Downsample (250 Hz), bandpass filter (1-45 Hz), notch filter (50/100 Hz), bad segment/channel detection, ICA artefact rejection (based on MEGNet) | | `2_surfaces.py` | Surface Extraction | Extract inner skull, outer skull and scalp surfaces from structural MRI using FSL BET | | `3_coreg.py` | Coregistration | Coregister MEG to MRI using Polhemus headshape points | | `4_source_recon_and_parc.py` | Forward Model, Source Reconstruction and Parcellation | Compute forward model (8 mm dipole grid), LCMV beamformer, parcellate voxel data, apply symmetric orthogonalisation | @@ -108,6 +108,7 @@ plots/ │ ├── 1_sum_square.png │ ├── 1_sum_square_exclude_bads.png │ ├── 1_channel_stds.png +│ ├── 1_ica_components.png │ ├── 3_coreg.png │ └── 4_psd_topo.png ├── sub-02_task-rest/ diff --git a/osl_dynamics/meeg/preproc.py b/osl_dynamics/meeg/preproc.py index 4808f265..db15ed2c 100644 --- a/osl_dynamics/meeg/preproc.py +++ b/osl_dynamics/meeg/preproc.py @@ -2,14 +2,20 @@ import json from pathlib import Path -from typing import List, Optional, Union +from typing import Any, List, Optional, Tuple, Union import mne import numpy as np import matplotlib.pyplot as plt +from mne.preprocessing import ICA +from mne_icalabel import label_components from scipy import stats from scipy.ndimage.filters import uniform_filter1d +# ------------------------------------------------------------------------- +# Artefact detection +# ------------------------------------------------------------------------- + def detect_bad_segments( raw: mne.io.Raw, @@ -348,6 +354,383 @@ def _gesd( return out_mask +# ------------------------------------------------------------------------- +# ICA artefact rejection +# ------------------------------------------------------------------------- + + +def ica_label( + raw: mne.io.Raw, + picks: str = "mag", + n_components: int = 20, + method: str = "megnet", + threshold: float = 0.5, + random_state: int = 42, +) -> Tuple[mne.io.Raw, Any, dict]: + """Automatic ICA artefact rejection using mne-icalabel. + + Fits ICA on a bandpass-filtered copy of the data, labels components + using a pre-trained classifier, and removes artefact components from + the original data. + + For MEG data, uses the MEGNet classifier. For EEG data, uses ICLabel. + + Parameters + ---------- + raw : mne.io.Raw + MNE Raw object. + picks : str, optional + Channel type to use for ICA. For Elekta MEG data, use ``"mag"`` + (MEGNet was trained on magnetometer topographies). + n_components : int, optional + Number of ICA components. Should not exceed the data rank. + For MaxFiltered Elekta data (rank ~60), 20 is a safe default. + method : str, optional + Labelling method: ``"megnet"`` for MEG or ``"iclabel"`` for EEG. + threshold : float, optional + Probability threshold (0–1) for excluding a component. Components + labelled as artefact with probability above this threshold are + removed. + random_state : int, optional + Random seed for ICA reproducibility. + + Returns + ------- + raw : mne.io.Raw + Cleaned MNE Raw object. + ica : mne.preprocessing.ICA + Fitted ICA object with ``exclude`` set. + ic_labels : dict + Dictionary with keys ``"labels"`` (list of str) and + ``"y_pred_proba"`` (array of float). + Notes + ----- + For EEG data, use ``picks="eeg"`` and ``method="iclabel"``:: + + raw, ica, ic_labels = preproc.ica_label( + raw, picks="eeg", method="iclabel", n_components=30, + ) + """ + print() + print("ICA artefact rejection") + print("----------------------") + print(f"Method: {method}") + print(f"Picks: {picks}") + print(f"Components: {n_components}") + print(f"Threshold: {threshold}") + + # Filter a copy for ICA fitting (classifiers expect 1-100 Hz) + print("Filtering data copy (1-100 Hz) for ICA fitting...") + raw_fit = raw.copy().filter(l_freq=1.0, h_freq=100.0, verbose=False) + + # ICLabel (EEG) requires average reference + if method == "iclabel": + raw_fit.set_eeg_reference("average", verbose=False) + + # Fit ICA + print("Fitting ICA...") + ica = ICA( + n_components=n_components, + method="infomax", + fit_params=dict(extended=True), + random_state=random_state, + verbose=False, + ) + ica.fit(raw_fit, picks=picks) + + # Label components + print("Labelling components...") + ic_labels = label_components(raw_fit, ica, method=method) + + labels = ic_labels["labels"] + probs = ic_labels["y_pred_proba"] + + # Identify artefact components (exclude everything except brain/other) + # Note: MEGNet returns "brain/other" as a single label, while ICLabel + # returns "brain" and "other" separately + keep_labels = ["brain", "other", "brain/other"] + exclude_idx = [] + for idx, (label, prob) in enumerate(zip(labels, probs)): + if label not in keep_labels: + if prob > threshold: + exclude_idx.append(idx) + print(f" ICA{idx:03d}: {label} ({prob:.2f}) -> excluded") + else: + print( + f" ICA{idx:03d}: {label} ({prob:.2f}) -> kept " + f"(below threshold)" + ) + + # Apply to original data + if len(exclude_idx) > 0: + print(f"Removing {len(exclude_idx)} artefact component(s)...") + ica.exclude = exclude_idx + ica.apply(raw) + else: + print("No artefact components found.") + + return raw, ica, ic_labels + + +def ica_ecg_eog_correlation( + raw: mne.io.Raw, + picks: str = "meg", + n_components: int = 40, + l_freq: float = 1.0, + h_freq: Optional[float] = None, + ecg_method: Optional[str] = "ctps", + ecg_threshold: Union[str, float] = "auto", + eog_measure: str = "correlation", + eog_threshold: float = 0.35, + random_state: int = 42, +) -> Tuple[mne.io.Raw, Any, dict]: + """ICA artefact rejection using ECG/EOG correlation. + + Fits ICA on a high-pass filtered copy of the data, identifies artefact + components by correlating with ECG and EOG signals, and removes them + from the original data. Follows the approach used in osl-ephys. + + Uses ``picks="meg"`` by default so both magnetometers and gradiometers + are denoised. Does not require mne-icalabel. + + Parameters + ---------- + raw : mne.io.Raw + MNE Raw object. + picks : str, optional + Channel type to use for ICA. ``"meg"`` fits on both mags and grads. + n_components : int, optional + Number of ICA components. Should not exceed the data rank. + For MaxFiltered Elekta data (rank ~60), 40 is a safe default. + l_freq : float, optional + High-pass filter frequency for the ICA fitting copy. + h_freq : float, optional + Low-pass filter frequency for the ICA fitting copy. + ecg_method : str, optional + Method for ECG detection: ``"ctps"`` (cross-trial phase + statistics) or ``"correlation"``. Set to ``None`` to skip + ECG detection. + ecg_threshold : str or float, optional + Threshold for ECG component detection. + eog_measure : str, optional + Measure for EOG detection: ``"correlation"`` or ``"zscore"``. + eog_threshold : float, optional + Threshold for EOG component detection. When + ``eog_measure="correlation"``, this is an absolute correlation + threshold (e.g. 0.35). When ``eog_measure="zscore"``, this is + a z-score threshold (e.g. 3.0). + random_state : int, optional + Random seed for ICA reproducibility. + + Returns + ------- + raw : mne.io.Raw + Cleaned MNE Raw object. + ica : mne.preprocessing.ICA + Fitted ICA object with ``exclude`` set. + ic_labels : dict + Dictionary with keys ``"labels"`` (list of str) and + ``"y_pred_proba"`` (array of float), compatible with + :func:`plot_ica_components`. + Notes + ----- + For EEG data, use ``picks="eeg"``. Note that synthetic ECG detection + only works with MEG magnetometers — for EEG data a dedicated ECG + channel must be present, otherwise set ``ecg_method=None``:: + + raw, ica, ic_labels = preproc.ica_ecg_eog_correlation( + raw, picks="eeg", n_components=30, ecg_method=None, + ) + """ + print() + print("ICA artefact rejection (ECG/EOG correlation)") + print("---------------------------------------------") + print(f"Picks: {picks}") + print(f"Components: {n_components}") + + # Filter a copy for ICA fitting + print(f"Filtering data copy ({l_freq}-{h_freq} Hz) for ICA fitting...") + raw_fit = raw.copy().filter(l_freq=l_freq, h_freq=h_freq, verbose=False) + + # Fit ICA + print("Fitting ICA...") + ica = ICA( + n_components=n_components, + method="fastica", + random_state=random_state, + verbose=False, + ) + ica.fit(raw_fit, picks=picks) + + # Detect ECG components + ecg_indices = [] + ecg_scores = np.zeros(ica.n_components_) + if ecg_method is not None: + print("Detecting ECG components...") + try: + ecg_indices, ecg_scores = ica.find_bads_ecg( + raw_fit, + method=ecg_method, + threshold=ecg_threshold, + verbose=False, + ) + for idx in ecg_indices: + print(f" ICA{idx:03d}: ecg (score={ecg_scores[idx]:.2f})") + if not ecg_indices: + print(" No ECG components found.") + except Exception as e: + print(f" ECG detection failed: {e}") + + # Detect EOG components + eog_indices = [] + eog_scores = np.zeros(ica.n_components_) + eog_chs = mne.pick_types(raw_fit.info, eog=True) + if len(eog_chs) == 0: + print("No EOG channel found, skipping EOG detection.") + else: + print("Detecting EOG components...") + try: + eog_indices, eog_scores = ica.find_bads_eog( + raw_fit, + measure=eog_measure, + threshold=eog_threshold, + verbose=False, + ) + # eog_scores can be a list of arrays if multiple EOG channels + if isinstance(eog_scores, list): + eog_scores = np.max(np.abs(eog_scores), axis=0) + for idx in eog_indices: + print(f" ICA{idx:03d}: eog (score={eog_scores[idx]:.2f})") + if not eog_indices: + print(" No EOG components found.") + except Exception as e: + print(f" EOG detection failed: {e}") + + # Combine and exclude + exclude_idx = sorted(set(ecg_indices + eog_indices)) + + # Capture pre-ICA PSD + psd_before_ica = raw.compute_psd(fmax=45) + + if len(exclude_idx) > 0: + print(f"Removing {len(exclude_idx)} artefact component(s)...") + ica.exclude = exclude_idx + ica.apply(raw) + else: + print("No artefact components found.") + + # Build ic_labels dict (same structure as ica_label output) + labels = [] + probs = [] + for idx in range(ica.n_components_): + if idx in ecg_indices and idx in eog_indices: + labels.append("ecg+eog") + probs.append(max(abs(ecg_scores[idx]), abs(eog_scores[idx]))) + elif idx in ecg_indices: + labels.append("ecg") + probs.append(abs(ecg_scores[idx])) + elif idx in eog_indices: + labels.append("eog") + probs.append(abs(eog_scores[idx])) + else: + labels.append("brain") + probs.append(0.0) + + ic_labels = {"labels": labels, "y_pred_proba": np.array(probs)} + + return raw, ica, ic_labels + + +def plot_ica_components( + ica: Any, + ic_labels: dict, +) -> Optional[plt.Figure]: + """Plot excluded ICA component topographies with labels. + + Creates a composite figure showing only the excluded ICA components + with their classification labels and probabilities. + + Parameters + ---------- + ica : mne.preprocessing.ICA + Fitted ICA object. + ic_labels : dict + Dictionary with keys ``"labels"`` and ``"y_pred_proba"``. + + Returns + ------- + fig : matplotlib.figure.Figure or None + The composite figure, or None if no components were excluded. + """ + exclude_idx = ica.exclude + if len(exclude_idx) == 0: + print("No ICA components excluded — nothing to plot.") + return None + + labels = ic_labels["labels"] + probs = ic_labels["y_pred_proba"] + n_excluded = len(exclude_idx) + + # Create composite figure + n_cols = min(n_excluded, 5) + n_rows = (n_excluded + n_cols - 1) // n_cols + fig, axes = plt.subplots( + n_rows, + n_cols, + figsize=(5 * n_cols, 5 * n_rows), + ) + if n_rows == 1 and n_cols == 1: + axes = np.array([[axes]]) + elif n_rows == 1: + axes = axes[np.newaxis, :] + elif n_cols == 1: + axes = axes[:, np.newaxis] + + for i, idx in enumerate(exclude_idx): + row, col = divmod(i, n_cols) + ax = axes[row, col] + + # Plot individual component topography + comp_figs = ica.plot_components(picks=[idx], show=False) + if isinstance(comp_figs, list): + comp_fig = comp_figs[0] + else: + comp_fig = comp_figs + comp_fig.canvas.draw() + buf = comp_fig.canvas.buffer_rgba() + img = np.asarray(buf) + plt.close(comp_fig) + + ax.imshow(img) + ax.set_axis_off() + + label = labels[idx] + prob = probs[idx] + ax.set_title( + f"ICA{idx:03d}: {label} ({prob:.2f})", + fontsize=16, + color="red", + fontweight="bold", + ) + + # Hide unused axes + for i in range(n_excluded, n_rows * n_cols): + row, col = divmod(i, n_cols) + axes[row, col].set_axis_off() + + fig.suptitle( + f"Excluded ICA Components ({n_excluded})", + fontsize=20, + ) + fig.tight_layout() + return fig + + +# ------------------------------------------------------------------------- +# Headshape decimation +# ------------------------------------------------------------------------- + + def decimate_headshape_points( raw: mne.io.Raw, decimate_amount: float = 0.01, @@ -628,19 +1011,30 @@ def _grid_average_decimate( return np.array([np.mean(voxel_dict[key], axis=0) for key in voxel_dict]) +# ------------------------------------------------------------------------- +# QC plots +# ------------------------------------------------------------------------- + + def save_qc_plots( raw: mne.io.Raw, output_dir: Union[str, Path], show: bool = False, + ica: Any = None, + ic_labels: Optional[dict] = None, ) -> None: """Save preprocessing QC plots and summary. Saves the following files to output_dir: - - 1_summary.json: preprocessing summary stats - - 1_psd.png: sensor-level PSD - - 1_sum_square.png: sum-square time series - - 1_sum_square_exclude_bads.png: sum-square excluding bad segments/channels - - 1_channel_stds.png: channel standard deviation distributions + + - ``1_summary.json``: preprocessing summary stats + - ``1_psd.png``: sensor-level PSD + - ``1_sum_square.png``: sum-square time series + - ``1_sum_square_exclude_bads.png``: sum-square excluding bad + segments/channels + - ``1_channel_stds.png``: channel standard deviation distributions + - ``1_ica_components.png``: ICA component topographies (if ``ica`` + and ``ic_labels`` are provided) Parameters ---------- @@ -650,6 +1044,11 @@ def save_qc_plots( Directory to save plots to. show : bool, optional Whether to display the plots interactively. Default is False. + ica : mne.preprocessing.ICA, optional + Fitted ICA object. If provided along with ``ic_labels``, saves + ICA component topography plot. + ic_labels : dict, optional + ICA label dictionary from ``ica_label``. """ output_dir = Path(output_dir) output_dir.mkdir(parents=True, exist_ok=True) @@ -666,6 +1065,13 @@ def save_qc_plots( "bad_channels": raw.info["bads"], "n_bad_channels": len(raw.info["bads"]), } + if ica is not None and ic_labels is not None: + summary["ica_n_components"] = ica.n_components_ + summary["ica_n_excluded"] = len(ica.exclude) + summary["ica_excluded_labels"] = [ + f"{ic_labels['labels'][i]} ({ic_labels['y_pred_proba'][i]:.2f})" + for i in ica.exclude + ] with open(output_dir / "1_summary.json", "w") as f: json.dump(summary, f, indent=2) @@ -695,6 +1101,16 @@ def save_qc_plots( if not show: plt.close("all") + # ICA component topographies (excluded components only) + if ica is not None and ic_labels is not None: + fig = plot_ica_components(ica, ic_labels) + if fig is not None: + fig.savefig( + output_dir / "1_ica_components.png", dpi=150, bbox_inches="tight" + ) + if not show: + plt.close("all") + def plot_sum_square_time_series( raw: mne.io.Raw, diff --git a/osl_dynamics/meeg/report.py b/osl_dynamics/meeg/report.py index ba4b3722..77c012b9 100644 --- a/osl_dynamics/meeg/report.py +++ b/osl_dynamics/meeg/report.py @@ -11,33 +11,67 @@ STEPS = { 1: { "name": "Preprocessing", - "files": [ - "1_psd.png", - "1_sum_square.png", - "1_sum_square_exclude_bads.png", - "1_channel_stds.png", + "subpanels": [ + { + "name": "PSD", + "files": ["1_psd.png"], + }, + { + "name": "Sum-Square", + "files": ["1_sum_square.png"], + }, + { + "name": "Sum-Square (excl. bads)", + "files": ["1_sum_square_exclude_bads.png"], + }, + { + "name": "Channel Stds", + "files": ["1_channel_stds.png"], + }, + { + "name": "ICA Components", + "files": ["1_ica_components.png"], + }, ], - "large": ["1_psd.png"], }, 2: { "name": "Surfaces", - "files": [ - "2_inskull.png", - "2_outskin.png", - "2_outskull.png", + "subpanels": [ + { + "name": "Inner Skull", + "files": ["2_inskull.png"], + }, + { + "name": "Outer Skull", + "files": ["2_outskull.png"], + }, + { + "name": "Outer Skin", + "files": ["2_outskin.png"], + }, + { + "name": "Outer Skin + Nose", + "files": ["2_outskin_plus_nose.png"], + }, ], }, 3: { "name": "Coregistration", - "files": ["3_coreg.png"], - "large": ["3_coreg.png"], + "subpanels": [ + { + "name": "Coregistration", + "files": ["3_coreg.png"], + }, + ], }, 4: { "name": "Source Recon & Parcellation", - "files": [ - "4_psd_topo.png", + "subpanels": [ + { + "name": "Parcellation PSD", + "files": ["4_psd_topo.png"], + }, ], - "large": ["4_psd_topo.png"], }, } @@ -83,8 +117,6 @@ background: #fff; padding: 20px; border-radius: 0 0 6px 6px; - max-height: calc(100vh - 160px); - overflow-y: auto; } .tab-content.active { display: block; @@ -93,7 +125,7 @@ display: flex; align-items: center; gap: 10px; - margin-bottom: 20px; + margin-bottom: 10px; padding: 10px; background: #f0f0f0; border-radius: 6px; @@ -123,23 +155,61 @@ color: #888; margin-left: auto; } +.subpanel-nav { + display: flex; + align-items: center; + gap: 8px; + margin-bottom: 16px; + padding: 8px 10px; + background: #e8f0fe; + border-radius: 6px; +} +.subpanel-nav button { + padding: 4px 12px; + border: 1px solid #b0c4de; + background: #fff; + cursor: pointer; + border-radius: 4px; + font-size: 16px; + line-height: 1; +} +.subpanel-nav button:hover { + background: #dce8f5; +} +.subpanel-nav .subpanel-label { + font-size: 14px; + font-weight: bold; + color: #333; +} +.subpanel-nav .subpanel-counter { + font-size: 12px; + color: #888; + margin-left: auto; +} +.subpanel-nav .hint { + font-size: 11px; + color: #aaa; +} .session-panel { display: none; } .session-panel.active { display: block; } -.session-panel img { +.subpanel { + display: none; +} +.subpanel.active { + display: block; +} +.subpanel img { max-width: 100%; - max-height: 250px; + max-height: calc(80vh - 280px); display: block; margin: 4px auto; border: 1px solid #eee; } -.session-panel img.large { - max-height: 400px; -} -.session-panel iframe { +.subpanel iframe { width: 100%; height: 350px; border: 1px solid #ddd; @@ -207,10 +277,15 @@ JS = """ var sessions = SESSION_LIST; var currentStep = 1; -var currentIdx = {}; // per-step session index - -// Initialise each step to session 0 -STEP_NUMS.forEach(function(s) { currentIdx[s] = 0; }); +var currentIdx = {}; // per-step session index +var currentSubpanel = {}; // per-step subpanel index +var subpanelCounts = SUBPANEL_COUNTS; // {step: count} + +// Initialise each step to session 0, subpanel 0 +STEP_NUMS.forEach(function(s) { + currentIdx[s] = 0; + currentSubpanel[s] = 0; +}); function switchTab(step) { currentStep = step; @@ -226,19 +301,52 @@ if (idx >= sessions.length) idx = sessions.length - 1; currentIdx[step] = idx; - // Hide all session panels for this step var panels = document.querySelectorAll('#tab-' + step + ' .session-panel'); panels.forEach(p => p.classList.remove('active')); - // Show the selected one var panel = document.getElementById('step-' + step + '-session-' + idx); if (panel) panel.classList.add('active'); - // Update input and counter var input = document.getElementById('input-' + step); var counter = document.getElementById('counter-' + step); input.value = sessions[idx]; counter.textContent = (idx + 1) + ' / ' + sessions.length; + + showSubpanel(step, currentSubpanel[step]); +} + +function showSubpanel(step, spIdx) { + var count = subpanelCounts[step] || 1; + if (spIdx < 0) spIdx = 0; + if (spIdx >= count) spIdx = count - 1; + currentSubpanel[step] = spIdx; + + // Hide all subpanels in the current session panel + var sessionIdx = currentIdx[step]; + var panel = document.getElementById('step-' + step + '-session-' + sessionIdx); + if (!panel) return; + var subs = panel.querySelectorAll('.subpanel'); + subs.forEach(s => s.classList.remove('active')); + var target = panel.querySelector('.subpanel[data-sp-idx="' + spIdx + '"]'); + if (target) target.classList.add('active'); + + // Update subpanel nav label and counter + var label = document.getElementById('sp-label-' + step); + var spCounter = document.getElementById('sp-counter-' + step); + if (label && target) { + label.textContent = target.getAttribute('data-sp-name') || ''; + } + if (spCounter) { + spCounter.textContent = (spIdx + 1) + ' / ' + count; + } +} + +function prevSubpanel(step) { + showSubpanel(step, currentSubpanel[step] - 1); +} + +function nextSubpanel(step) { + showSubpanel(step, currentSubpanel[step] + 1); } function prevSession(step) { @@ -258,7 +366,6 @@ return; } } - // Partial match: find first session containing the input for (var i = 0; i < sessions.length; i++) { if (sessions[i].toLowerCase().indexOf(val) !== -1) { showSession(step, i); @@ -268,7 +375,6 @@ } document.addEventListener('keydown', function(e) { - // Ignore if user is typing in the input if (document.activeElement.tagName === 'INPUT') { if (e.key === 'Enter') { e.preventDefault(); @@ -282,6 +388,12 @@ } else if (e.key === 'ArrowRight') { e.preventDefault(); nextSession(currentStep); + } else if (e.key === 'ArrowUp') { + e.preventDefault(); + prevSubpanel(currentStep); + } else if (e.key === 'ArrowDown') { + e.preventDefault(); + nextSubpanel(currentStep); } }); """ @@ -299,14 +411,26 @@ def _build_summary(session_dir: Path) -> str: pct = s["bad_percent"] n_bad_ch = s["n_bad_channels"] bad_chs = ", ".join(s["bad_channels"]) if s["bad_channels"] else "none" - return ( + html = ( f'