diff --git a/.github/workflows/check_formatting.yml b/.github/workflows/check_formatting.yml new file mode 100644 index 0000000..3f106e8 --- /dev/null +++ b/.github/workflows/check_formatting.yml @@ -0,0 +1,44 @@ +name: Check formatting + +on: [push, pull_request] + +permissions: + contents: read + +jobs: + main: + runs-on: ubuntu-latest + steps: + + - uses: actions/checkout@v4 + + - uses: actions/setup-python@v5 + with: + python-version: "3.11" + + - name: Check EOF + uses: pre-commit/action@v3.0.0 + with: + extra_args: end-of-file-fixer + + - name: Check trailing whitespace + uses: pre-commit/action@v3.0.0 + with: + extra_args: trailing-whitespace + + - name: Black + uses: psf/black@stable + with: + options: "--check" + src: "./src" + jupyter: true + + - name: isort + uses: isort/isort-action@v1 + with: + configuration: --profile=black --check-only --diff + + - name: Ruff + uses: chartboost/ruff-action@v1 + with: + args: "check" diff --git a/.github/workflows/full_tests.yml b/.github/workflows/full_tests.yml index c74c8ae..4906926 100644 --- a/.github/workflows/full_tests.yml +++ b/.github/workflows/full_tests.yml @@ -21,7 +21,6 @@ jobs: run: | python -m pip install --upgrade pip pip install .[test] - pip install pytest - name: Pytest run: | pytest -v diff --git a/.gitignore b/.gitignore index 5be366e..fb19775 100644 --- a/.gitignore +++ b/.gitignore @@ -1,14 +1,793 @@ + +### Custom ### + +### C ### +# Prerequisites +*.d + +# Object files +*.o +*.ko +*.obj +*.elf + +# Linker output +*.ilk +*.map +*.exp + +# Precompiled Headers +*.gch +*.pch + +# Libraries +*.lib +*.a +*.la +*.lo + +# Shared objects (inc. Windows DLLs) +*.dll +*.so +*.so.* +*.dylib + +# Executables +*.exe +*.out +*.app +*.i*86 +*.x86_64 +*.hex + +# Debug files +*.dSYM/ +*.su +*.idb +*.pdb + +# Kernel Module Compile Results +*.mod* +*.cmd +.tmp_versions/ +modules.order +Module.symvers +Mkfile.old +dkms.conf + +### C++ ### +# Prerequisites + +# Compiled Object files +*.slo + +# Fortran module files +*.mod +*.smod + +# Compiled Static libraries +*.lai + +### CMake ### +CMakeLists.txt.user +CMakeCache.txt +CMakeFiles +CMakeScripts +Testing +Makefile +cmake_install.cmake +install_manifest.txt +compile_commands.json +CTestTestfile.cmake +_deps + +### CMake Patch ### +# External projects +*-prefix/ + +### Java ### +# Compiled class file +*.class + +# Log file +*.log + +# BlueJ files +*.ctxt + +# Mobile Tools for Java (J2ME) +.mtj.tmp/ + +# Package Files # +*.jar +*.war +*.nar +*.ear +*.zip +*.tar.gz +*.rar + +# virtual machine crash logs, see http://www.java.com/en/download/help/error_hotspot.xml +hs_err_pid* +replay_pid* + +### Julia ### +# Files generated by invoking Julia with --code-coverage +*.jl.cov +*.jl.*.cov + +# Files generated by invoking Julia with --track-allocation +*.jl.mem + +# System-specific files and directories generated by the BinaryProvider and BinDeps packages +# They contain absolute paths specific to the host computer, and so should not be committed +deps/deps.jl +deps/build.log +deps/downloads/ +deps/usr/ +deps/src/ + +# Build artifacts for creating documentation generated by the Documenter package +docs/build/ +docs/site/ + +# File generated by Pkg, the package manager, based on a corresponding Project.toml +# It records a fixed state of all packages used by the project. As such, it should not be +# committed for packages, but should be committed for applications that require a static +# environment. 
+Manifest.toml + +### JupyterNotebooks ### +# gitignore template for Jupyter Notebooks +# website: http://jupyter.org/ + +.ipynb_checkpoints +*/.ipynb_checkpoints/* + +# IPython +profile_default/ +ipython_config.py + +# Remove previous ipynb_checkpoints +# git rm -r .ipynb_checkpoints/ + +### LaTeX ### +## Core latex/pdflatex auxiliary files: +*.aux +*.lof +*.lot +*.fls +*.toc +*.fmt +*.fot +*.cb +*.cb2 +.*.lb + +## Intermediate documents: +*.dvi +*.xdv +*-converted-to.* +# these rules might exclude image files for figures etc. +# *.ps +# *.eps +# *.pdf + +## Generated if empty string is given at "Please type another file name for output:" +.pdf + +## Bibliography auxiliary files (bibtex/biblatex/biber): +*.bbl +*.bcf +*.blg +*-blx.aux +*-blx.bib +*.run.xml + +## Build tool auxiliary files: +*.fdb_latexmk +*.synctex +*.synctex(busy) +*.synctex.gz +*.synctex.gz(busy) +*.pdfsync + +## Build tool directories for auxiliary files +# latexrun +latex.out/ + +## Auxiliary and intermediate files from other packages: +# algorithms +*.alg +*.loa + +# achemso +acs-*.bib + +# amsthm +*.thm + +# beamer +*.nav +*.pre +*.snm +*.vrb + +# changes +*.soc + +# comment +*.cut + +# cprotect +*.cpt + +# elsarticle (documentclass of Elsevier journals) +*.spl + +# endnotes +*.ent + +# fixme +*.lox + +# feynmf/feynmp +*.mf +*.mp +*.t[1-9] +*.t[1-9][0-9] +*.tfm + +#(r)(e)ledmac/(r)(e)ledpar +*.end +*.?end +*.[1-9] +*.[1-9][0-9] +*.[1-9][0-9][0-9] +*.[1-9]R +*.[1-9][0-9]R +*.[1-9][0-9][0-9]R +*.eledsec[1-9] +*.eledsec[1-9]R +*.eledsec[1-9][0-9] +*.eledsec[1-9][0-9]R +*.eledsec[1-9][0-9][0-9] +*.eledsec[1-9][0-9][0-9]R + +# glossaries +*.acn +*.acr +*.glg +*.glo +*.gls +*.glsdefs +*.lzo +*.lzs +*.slg +*.sls + +# uncomment this for glossaries-extra (will ignore makeindex's style files!) +# *.ist + +# gnuplot +*.gnuplot +*.table + +# gnuplottex +*-gnuplottex-* + +# gregoriotex +*.gaux +*.glog +*.gtex + +# htlatex +*.4ct +*.4tc +*.idv +*.lg +*.trc +*.xref + +# hyperref +*.brf + +# knitr +*-concordance.tex +# TODO Uncomment the next line if you use knitr and want to ignore its generated tikz files +# *.tikz +*-tikzDictionary + +# listings +*.lol + +# luatexja-ruby +*.ltjruby + +# makeidx +*.idx +*.ilg +*.ind + +# minitoc +*.maf +*.mlf +*.mlt +*.mtc[0-9]* +*.slf[0-9]* +*.slt[0-9]* +*.stc[0-9]* + +# minted +_minted* +*.pyg + +# morewrites +*.mw + +# newpax +*.newpax + +# nomencl +*.nlg +*.nlo +*.nls + +# pax +*.pax + +# pdfpcnotes +*.pdfpc + +# sagetex +*.sagetex.sage +*.sagetex.py +*.sagetex.scmd + +# scrwfile +*.wrt + +# svg +svg-inkscape/ + +# sympy +*.sout +*.sympy +sympy-plots-for-*.tex/ + +# pdfcomment +*.upa +*.upb + +# pythontex +*.pytxcode +pythontex-files-*/ + +# tcolorbox +*.listing + +# thmtools +*.loe + +# TikZ & PGF +*.dpth +*.md5 +*.auxlock + +# titletoc +*.ptc + +# todonotes +*.tdo + +# vhistory +*.hst +*.ver + +# easy-todo +*.lod + +# xcolor +*.xcp + +# xmpincl +*.xmpi + +# xindy +*.xdy + +# xypic precompiled matrices and outlines +*.xyc +*.xyd + +# endfloat +*.ttt +*.fff + +# Latexian +TSWLatexianTemp* + +## Editors: +# WinEdt +*.bak +*.sav + +# Texpad +.texpadtmp + +# LyX +*.lyx~ + +# Kile +*.backup + +# gummi +.*.swp + +# KBibTeX +*~[0-9]* + +# TeXnicCenter +*.tps + +# auto folder when using emacs and auctex +./auto/* +*.el + +# expex forward references with \gathertags +*-tags.tex + +# standalone packages +*.sta + +# Makeindex log files +*.lpz + +# xwatermark package +*.xwm + +# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib +# option is specified. 
Footnotes are the stored in a file with suffix Notes.bib. +# Uncomment the next line to have this generated file ignored. +#*Notes.bib + +### LaTeX Patch ### +# LIPIcs / OASIcs +*.vtc + +# glossaries +*.glstex + +### Linux ### +*~ + +# temporary files which can be created if a process still has a handle open of a deleted file +.fuse_hidden* + +# KDE directory preferences +.directory + +# Linux trash folder which might appear on any partition or disk +.Trash-* + +# .nfs files are created when an open file is removed but is still being accessed +.nfs* + +### macOS ### +# General +.DS_Store +.AppleDouble +.LSOverride + +# Icon must end with two \r +Icon + + +# Thumbnails +._* + +# Files that might appear in the root of a volume +.DocumentRevisions-V100 +.fseventsd +.Spotlight-V100 +.TemporaryItems +.Trashes +.VolumeIcon.icns +.com.apple.timemachine.donotpresent + +# Directories potentially created on remote AFP share +.AppleDB +.AppleDesktop +Network Trash Folder +Temporary Items +.apdisk + +### macOS Patch ### +# iCloud generated files +*.icloud + +### MATLAB ### +# Windows default autosave extension +*.asv + +# OSX / *nix default autosave extension +*.m~ + +# Compiled MEX binaries (all platforms) +*.mex* + +# Packaged app and toolbox files +*.mlappinstall +*.mltbx + +# Generated helpsearch folders +helpsearch*/ + +# Simulink code generation folders +slprj/ +sccprj/ + +# Matlab code generation folders +codegen/ + +# Simulink autosave extension +*.autosave + +# Simulink cache files +*.slxc + +# Octave session info +octave-workspace + +### Node ### +# Logs +logs +npm-debug.log* +yarn-debug.log* +yarn-error.log* +lerna-debug.log* +.pnpm-debug.log* + +# Diagnostic reports (https://nodejs.org/api/report.html) +report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json + +# Runtime data +pids +*.pid +*.seed +*.pid.lock + +# Directory for instrumented libs generated by jscoverage/JSCover +lib-cov + +# Coverage directory used by tools like istanbul +coverage +*.lcov + +# nyc test coverage +.nyc_output + +# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files) +.grunt + +# Bower dependency directory (https://bower.io/) +bower_components + +# node-waf configuration +.lock-wscript + +# Compiled binary addons (https://nodejs.org/api/addons.html) +build/Release + +# Dependency directories +node_modules/ +jspm_packages/ + +# Snowpack dependency directory (https://snowpack.dev/) +web_modules/ + +# TypeScript cache +*.tsbuildinfo + +# Optional npm cache directory +.npm + +# Optional eslint cache +.eslintcache + +# Optional stylelint cache +.stylelintcache + +# Microbundle cache +.rpt2_cache/ +.rts2_cache_cjs/ +.rts2_cache_es/ +.rts2_cache_umd/ + +# Optional REPL history +.node_repl_history + +# Output of 'npm pack' +*.tgz + +# Yarn Integrity file +.yarn-integrity + +# dotenv environment variable files +.env +.env.development.local +.env.test.local +.env.production.local +.env.local + +# parcel-bundler cache (https://parceljs.org/) +.cache +.parcel-cache + +# Next.js build output +.next +out + +# Nuxt.js build / generate output +.nuxt +dist + +# Gatsby files +.cache/ +# Comment in the public line in if your project uses Gatsby and not Next.js +# https://nextjs.org/blog/next-9-1#public-directory-support +# public + +# vuepress build output +.vuepress/dist + +# vuepress v2.x temp and cache directory +.temp + +# Docusaurus cache and generated files +.docusaurus + +# Serverless directories +.serverless/ + +# FuseBox cache +.fusebox/ + +# DynamoDB Local files +.dynamodb/ + +# TernJS port file 
+.tern-port + +# Stores VSCode versions used for testing VSCode extensions +.vscode-test + +# yarn v2 +.yarn/cache +.yarn/unplugged +.yarn/build-state.yml +.yarn/install-state.gz +.pnp.* + +### Node Patch ### +# Serverless Webpack directories +.webpack/ + +# Optional stylelint cache + +# SvelteKit build / generate output +.svelte-kit + +### PyCharm ### +# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider +# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839 + +# User-specific stuff +.idea/**/workspace.xml +.idea/**/tasks.xml +.idea/**/usage.statistics.xml +.idea/**/dictionaries +.idea/**/shelf + +# AWS User-specific +.idea/**/aws.xml + +# Generated files +.idea/**/contentModel.xml + +# Sensitive or high-churn files +.idea/**/dataSources/ +.idea/**/dataSources.ids +.idea/**/dataSources.local.xml +.idea/**/sqlDataSources.xml +.idea/**/dynamic.xml +.idea/**/uiDesigner.xml +.idea/**/dbnavigator.xml + +# Gradle +.idea/**/gradle.xml +.idea/**/libraries + +# CMake +cmake-build-*/ + +# Mongo Explorer plugin +.idea/**/mongoSettings.xml + +# File-based project format +*.iws + +# IntelliJ +out/ + +# mpeltonen/sbt-idea plugin +.idea_modules/ + +# JIRA plugin +atlassian-ide-plugin.xml + +# Cursive Clojure plugin +.idea/replstate.xml + +# SonarLint plugin +.idea/sonarlint/ + +# Crashlytics plugin (for Android Studio and IntelliJ) +com_crashlytics_export_strings.xml +crashlytics.properties +crashlytics-build.properties +fabric.properties + +# Editor-based Rest Client +.idea/httpRequests + +# Android studio 3.1+ serialized cache file +.idea/caches/build_file_checksums.ser + +### PyCharm Patch ### +# Comment Reason: https://github.com/joeblau/gitignore.io/issues/186#issuecomment-215987721 + +# *.iml +# modules.xml +# .idea/misc.xml +# *.ipr + +# Sonarlint plugin +# https://plugins.jetbrains.com/plugin/7973-sonarlint +.idea/**/sonarlint/ + +# SonarQube Plugin +# https://plugins.jetbrains.com/plugin/7238-sonarqube-community-plugin +.idea/**/sonarIssues.xml + +# Markdown Navigator plugin +# https://plugins.jetbrains.com/plugin/7896-markdown-navigator-enhanced +.idea/**/markdown-navigator.xml +.idea/**/markdown-navigator-enh.xml +.idea/**/markdown-navigator/ + +# Cache file creation bug +# See https://youtrack.jetbrains.com/issue/JBR-2257 +.idea/$CACHE_FILE$ + +# CodeStream plugin +# https://plugins.jetbrains.com/plugin/12206-codestream +.idea/codestream.xml + +# Azure Toolkit for IntelliJ plugin +# https://plugins.jetbrains.com/plugin/8053-azure-toolkit-for-intellij +.idea/**/azureSettings.xml + +### Python ### # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] *$py.class # C extensions -*.so # Distribution / packaging .Python -env/ build/ develop-eggs/ dist/ @@ -20,9 +799,12 @@ lib64/ parts/ sdist/ var/ +wheels/ +share/python-wheels/ *.egg-info/ .installed.cfg *.egg +MANIFEST # PyInstaller # Usually these files are written by a python script from a template @@ -37,21 +819,25 @@ pip-delete-this-directory.txt # Unit test / coverage reports htmlcov/ .tox/ +.nox/ .coverage .coverage.* -.cache nosetests.xml coverage.xml -*,cover +*.cover +*.py,cover .hypothesis/ +.pytest_cache/ +cover/ # Translations *.mo *.pot # Django stuff: -*.log local_settings.py +db.sqlite3 +db.sqlite3-journal # Flask stuff: instance/ @@ -64,33 +850,171 @@ instance/ docs/_build/ # PyBuilder +.pybuilder/ target/ -# IPython Notebook -.ipynb_checkpoints +# Jupyter Notebook + +# IPython # pyenv -.python-version +# For a library or package, you 
might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock -# celery beat schedule file +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff celerybeat-schedule +celerybeat.pid -# dotenv -.env +# SageMath parsed files +*.sage.py -# virtualenv +# Environments +.venv +env/ venv/ ENV/ +env.bak/ +venv.bak/ # Spyder project settings .spyderproject +.spyproject # Rope project settings .ropeproject +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +### Python Patch ### +# Poetry local configuration file - https://python-poetry.org/docs/configuration/#local-configuration +poetry.toml + +# ruff +.ruff_cache/ + +# LSP config files +pyrightconfig.json + +### Rust ### +# Generated by Cargo +# will have compiled files and executables +debug/ + +# Remove Cargo.lock from gitignore if creating an executable, leave it for libraries +# More information here https://doc.rust-lang.org/cargo/guide/cargo-toml-vs-cargo-lock.html +Cargo.lock + +# These are backup files generated by rustfmt +**/*.rs.bk + +# MSVC Windows builds of rustc generate these, which store debugging information + ### VisualStudioCode ### .vscode/* +!.vscode/settings.json +!.vscode/tasks.json +!.vscode/launch.json +!.vscode/extensions.json +!.vscode/*.code-snippets + +# Local History for Visual Studio Code +.history/ + +# Built Visual Studio Code Extensions +*.vsix ### VisualStudioCode Patch ### # Ignore all local history of files .history +.ionide + +### Windows ### +# Windows thumbnail cache files +Thumbs.db +Thumbs.db:encryptable +ehthumbs.db +ehthumbs_vista.db + +# Dump file +*.stackdump + +# Folder config file +[Dd]esktop.ini + +# Recycle Bin used on file shares +$RECYCLE.BIN/ + +# Windows Installer files +*.cab +*.msi +*.msix +*.msm +*.msp + +# Windows shortcuts +*.lnk + +### Xcode ### +## User settings +xcuserdata/ + +## Xcode 8 and earlier +*.xcscmblueprint +*.xccheckout + +### Xcode Patch ### +*.xcodeproj/* +!*.xcodeproj/project.pbxproj +!*.xcodeproj/xcshareddata/ +!*.xcodeproj/project.xcworkspace/ +!*.xcworkspace/contents.xcworkspacedata +/*.gcno +**/xcshareddata/WorkspaceSettings.xcsettings diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..7be81c7 --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,36 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.5.0 + hooks: + - id: fix-encoding-pragma + exclude: tests/test_data + - id: trailing-whitespace + exclude: tests/test_data + - id: end-of-file-fixer + exclude: tests/test_data + - id: check-docstring-first + - id: debug-statements + - id: check-toml + - id: check-yaml + exclude: tests/test_data + - id: requirements-txt-fixer + - id: detect-private-key + - id: check-merge-conflict + + - repo: https://github.com/psf/black + rev: 24.4.2 + hooks: + - id: black + exclude: tests/test_data + - id: black-jupyter + + - repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + args: ["--profile", "black"] + + - repo: https://github.com/astral-sh/ruff-pre-commit + rev: v0.4.4 + hooks: + - id: ruff diff --git a/README.md b/README.md index 5417d71..5946af0 100644 --- a/README.md +++ b/README.md @@ -1,47 +1,42 @@ # Expipe plugin CINPLA -Expipe plugin for CINPLA laboratory +Expipe plugin for the CINPLA laboratory. ## Installation -You can install the package with pip: +`expipe-plugin-cinpla` can be installed by running -```bash ->>> pip install expipe-plugin-cinpla -``` + $ pip install expipe-plugin-cinpla -or from sources: +It requires Python 3.10+ to run. + +If you want the latest features and can't wait for the next release, install from GitHub: + + $ pip install git+https://github.com/CINPLA/expipe-plugin-cinpla.git -```bash -git clone -cd expipe-plugin-cinpla -pip install -e . -``` ## Usage -The starting point is a valid `expipe` project. Refer to the [expipe docs]() to read more on how -to create one. +The starting point is a valid `expipe` project. 
Refer to the [expipe docs](https://expipe.readthedocs.io/en/latest/) to read more on how to create one. The recommended usage is via Jupyter Notebook / Lab, using the interactive widgets to Register, Process, Curate, and View your actions. To launch the interactive browser, you can run: + ```python from expipe_plugin_cinpla import display_browser project_path = "path-to-my-project" display_browser(project_path) - ``` ![expipe browser](docs/images/browser.png) - ## Updating old projects The current version uses Neurodata Without Borders as backend instead of Exdir. If you have an existing @@ -60,3 +55,40 @@ convert_old_project(old_project_path, new_project_path, probe_path) To check out other options, use `convert_old_project?` + +## How to contribute + +### Set up development environment + +First, we recommend creating a virtual environment with an up-to-date `pip`: + +* Using [venv](https://packaging.python.org/en/latest/key_projects/#venv): + + $ python3.11 -m venv <env_name> + $ source <env_name>/bin/activate + $ python3 -m pip install --upgrade pip + +* Using [conda](https://conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html): + + $ conda create -n <env_name> python=3.11 pip + $ conda activate <env_name> + +Then install `expipe-plugin-cinpla` in editable mode from source: + + $ git clone https://github.com/CINPLA/expipe-plugin-cinpla.git + $ cd expipe-plugin-cinpla + $ python3 -m pip install -e ".[full]" + + +### pre-commit +We use [pre-commit](https://pre-commit.com/) to run Git hooks on every commit to identify simple issues such as trailing whitespace or not complying with the required formatting. Our pre-commit configuration is specified in the `.pre-commit-config.yaml` file. + +To set up the Git hook scripts specified in `.pre-commit-config.yaml`, run + + $ pre-commit install + +> **NOTE:** If `pre-commit` identifies formatting issues in the committed code, the pre-commit Git hooks will reformat the code. If code is reformatted, it will show up in your unstaged changes. Stage them and recommit to successfully commit your changes.
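+For example, a commit that trips the Black hook is rejected and the offending file is reformatted in place; staging the reformatted file and committing again then succeeds (a sketch of a typical session, with output abbreviated and a hypothetical file name):
+
+    $ git commit -m "add plotting helper"
+    black....................................................Failed
+    - hook id: black
+    - files were modified by this hook
+
+    $ git add src/expipe_plugin_cinpla/plotting.py
+    $ git commit -m "add plotting helper"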
+ +It is also possible to run the pre-commit hooks without attempting a commit: + + $ pre-commit run --all-files diff --git a/docs/conf.py b/docs/conf.py index 7f625c8..70a3c88 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -22,7 +22,6 @@ # sys.path.insert(0, os.path.abspath('.')) import os -import re # import expipe_plugin_cinpla diff --git a/environment.yml b/environment.yml index e6c8fa7..39966db 100644 --- a/environment.yml +++ b/environment.yml @@ -3,7 +3,7 @@ name: cinpla channels: - defaults dependencies: - - python=3.10 + - python=3.11 - pip - pip: - expipe-plugin-cinpla diff --git a/notebooks/convert_project.py b/notebooks/convert_project.py index ec98fe8..4b7ed65 100644 --- a/notebooks/convert_project.py +++ b/notebooks/convert_project.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from pathlib import Path import expipe_plugin_cinpla diff --git a/notebooks/expipe-plugin-cinpla-demo.ipynb b/notebooks/expipe-plugin-cinpla-demo.ipynb index b37a790..177c469 100644 --- a/notebooks/expipe-plugin-cinpla-demo.ipynb +++ b/notebooks/expipe-plugin-cinpla-demo.ipynb @@ -50,7 +50,9 @@ "metadata": {}, "outputs": [], "source": [ - "expipe_plugin_cinpla.convert_old_project(old_project_path, new_project_path, probe_path=\"tetrode_32_openephys.json\", debug_n_actions=5, overwrite=True)" + "expipe_plugin_cinpla.convert_old_project(\n", + " old_project_path, new_project_path, probe_path=\"tetrode_32_openephys.json\", debug_n_actions=5, overwrite=True\n", + ")" ] }, { diff --git a/probes/tetrode_32_openephys.json b/probes/tetrode_32_openephys.json index fc1b2b2..37d723f 100644 --- a/probes/tetrode_32_openephys.json +++ b/probes/tetrode_32_openephys.json @@ -787,4 +787,4 @@ ] } ] -} \ No newline at end of file +} diff --git a/pyproject.toml b/pyproject.toml index 6673ad2..05bcbec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -2,13 +2,13 @@ name = "expipe_plugin_cinpla" version = "0.1.5" authors = [ - { name="Mikkel Lepperod", email="mikkel@simula.no" }, - { name="Alessio Buccino", email="alessiop.buccino@gmail.com" }, + { name = "Mikkel Lepperod", email = "mikkel@simula.no" }, + { name = "Alessio Buccino", email = "alessiop.buccino@gmail.com" }, ] description = "Expipe plugins for the CINPLA lab." 
readme = "README.md" -requires-python = ">=3.9" +requires-python = ">=3.10" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", @@ -16,29 +16,26 @@ classifiers = [ ] dependencies = [ - "expipe>=0.6.0", + "expipe>=0.6.0", "neuroconv>=0.4.6", "pyopenephys>=1.2.0", - "spikeinterface[full,widgets]>=0.100.0", + "spikeinterface[full,widgets]>=0.100.0,<0.101.0", + "scikit-learn<1.5.0", "pynwb>=2.5.0", - "neuroconv>=0.4.6", "ipywidgets>=8.1.1", "nwbwidgets>=0.11.3", - "tbb>=2021.11.0", + "tbb>=2021.11.0; platform_system != 'Darwin'", "pynapple>=0.5.1", - "spython>=0.3.13", ] [project.urls] homepage = "https://github.com/CINPLA/expipe-plugin-cinpla" repository = "https://github.com/CINPLA/expipe-plugin-cinpla" - [build-system] requires = ["setuptools>=62.0"] build-backend = "setuptools.build_meta" - [tool.setuptools] include-package-data = true @@ -48,30 +45,73 @@ include = ["expipe_plugin_cinpla*"] namespaces = false [project.optional-dependencies] - -dev = [ - "pytest", - "pytest-cov", - "pytest-dependency", - "black" +dev = ["pre-commit", "black[jupyter]", "isort", "ruff"] +test = ["pytest", "pytest-cov", "pytest-dependency", "mountainsort5"] +docs = ["sphinx-gallery", "sphinx_rtd_theme"] +full = [ + "expipe_plugin_cinpla[dev]", + "expipe_plugin_cinpla[test]", + "expipe_plugin_cinpla[docs]", ] -test = [ - "pytest", - "pytest-cov", - "pytest-dependency", - "mountainsort5" -] +[tool.coverage.run] +omit = ["tests/*"] -docs = [ - "sphinx-gallery", - "sphinx_rtd_theme", -] +[tool.black] +line-length = 120 -[tool.coverage.run] -omit = [ - "tests/*", +[tool.ruff] +# Exclude a variety of commonly ignored directories. +exclude = [ + ".bzr", + ".direnv", + ".eggs", + ".git", + ".hg", + ".mypy_cache", + ".nox", + ".pants.d", + ".pytype", + ".ruff_cache", + ".svn", + ".tox", + ".venv", + "__pypackages__", + "_build", + "buck-out", + "build", + "dist", + "node_modules", + "venv", ] -[tool.black] +# In addition to the standard set of exclusions, omit: +extend-exclude = ["tests/test_data"] + +# Same as Black. line-length = 120 +indent-width = 4 + +# Assume Python 3.11. +target-version = "py311" + +[tool.ruff.lint] +# Enable Pyflakes (`F`) and a subset of the pycodestyle (`E`) codes by default. +# Unlike Flake8, Ruff doesn't enable pycodestyle warnings (`W`) or +# McCabe complexity (`C901`) by default. +select = ["E4", "E7", "E9", "F"] +ignore = [] + +# Allow fix for all enabled rules (when `--fix`) is provided. +fixable = ["ALL"] +unfixable = [] + +# Allow unused variables when underscore-prefixed. 
+dummy-variable-rgx = "^(_+|(_+[a-zA-Z0-9_]*[a-zA-Z0-9]+?))$" + +[tool.ruff.lint.per-file-ignores] +"src/expipe_plugin_cinpla/cli/utils.py" = ["F403"] +"src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py" = ["F821"] +"src/expipe_plugin_cinpla/widgets/utils.py" = ["F841"] # TODO: fix warning +"tests/test_cli.py" = ["F841"] # TODO: fix warning +"tests/test_script.py" = ["F841"] # TODO: fix warning diff --git a/requirements.txt b/requirements.txt index bd0846c..35d5ce0 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,10 +1,9 @@ expipe>=0.5.1 -neuroconv>=0.4.6 ipywidgets>=8.1.1 jupyter_contrib_nbextensions>=0.7.0 jupyterlab==3.6.5 +neuroconv>=0.4.6 +pynwb>=2.6.0 pyopenephys>=1.2.0 -tbb>=2021.11.0 spikeinterface[full,widgets]>=0.100.0 -pynwb>=2.6.0 -neuroconv>=0.4.6 \ No newline at end of file +tbb>=2021.11.0 diff --git a/setup.py b/setup.py index d0453b9..7a58127 100644 --- a/setup.py +++ b/setup.py @@ -1,43 +1,6 @@ -# # -*- coding: utf-8 -*- -# from setuptools import setup, find_packages - -# long_description = open("README.md").read() - -# with open("requirements.txt", mode='r') as f: -# install_requires = f.read().split('\n') - -# install_requires = [e for e in install_requires if len(e) > 0] - -# d = {} -# exec(open("expipe_plugin_cinpla/version.py").read(), None, d) -# version = d['version'] -# pkg_name = "expipe-pligin-cinpla" - -# setup( -# name=pkg_name, -# packages=find_packages(), -# version=version, -# include_package_data=True, -# author="CINPLA", -# author_email="", -# maintainer="Mikkel Elle Lepperød, Alessio Buccino", -# maintainer_email="mikkel@simula.no", -# platforms=["Linux", "Windows"], -# description="Expipe plugins for the CINPLA lab", -# url="https://github.com/CINPLA/expipe-plugin-cinpla", -# long_description_content_type="text/markdown", -# install_requires=install_requires, -# long_description=long_description, -# classifiers=['Intended Audience :: Science/Research', -# 'License :: OSI Approved :: GNU General Public License v2 (GPLv2)', -# 'Natural Language :: English', -# 'Programming Language :: Python :: 3', -# 'Topic :: Scientific/Engineering'], -# python_requires='>=3.9', -# ) +# -*- coding: utf-8 -*- import setuptools - if __name__ == "__main__": setuptools.setup() diff --git a/src/expipe_plugin_cinpla/__init__.py b/src/expipe_plugin_cinpla/__init__.py index 0e0c0fb..63fbebe 100644 --- a/src/expipe_plugin_cinpla/__init__.py +++ b/src/expipe_plugin_cinpla/__init__.py @@ -1,7 +1,8 @@ -from .cli import CinplaPlugin -from .widgets import display_browser -from .scripts import convert_old_project - +# -*- coding: utf-8 -*- import importlib.metadata +from .cli import CinplaPlugin  # noqa +from .scripts import convert_old_project  # noqa +from .widgets import display_browser  # noqa + __version__ = importlib.metadata.version("expipe_plugin_cinpla") diff --git a/src/expipe_plugin_cinpla/cli/__init__.py b/src/expipe_plugin_cinpla/cli/__init__.py index c526aa8..6fca81e 100644 --- a/src/expipe_plugin_cinpla/cli/__init__.py +++ b/src/expipe_plugin_cinpla/cli/__init__.py @@ -1 +1,2 @@ -from .main import CinplaPlugin +# -*- coding: utf-8 -*- +from .main import CinplaPlugin  # noqa diff --git a/src/expipe_plugin_cinpla/cli/main.py b/src/expipe_plugin_cinpla/cli/main.py index cb699f7..cf2b0a3 100644 --- a/src/expipe_plugin_cinpla/cli/main.py +++ b/src/expipe_plugin_cinpla/cli/main.py @@ -1,9 +1,9 @@ +# -*- coding: utf-8 -*- import click - from expipe.cliutils.plugin import IPlugin -from .register import attach_to_register from .process import
attach_to_process +from .register import attach_to_register class CinplaPlugin(IPlugin): diff --git a/src/expipe_plugin_cinpla/cli/process.py b/src/expipe_plugin_cinpla/cli/process.py index 97b3108..04cca4b 100644 --- a/src/expipe_plugin_cinpla/cli/process.py +++ b/src/expipe_plugin_cinpla/cli/process.py @@ -1,5 +1,7 @@ -import click +# -*- coding: utf-8 -*- from pathlib import Path + +import click import ruamel.yaml as yaml from expipe_plugin_cinpla.imports import project diff --git a/src/expipe_plugin_cinpla/cli/register.py b/src/expipe_plugin_cinpla/cli/register.py index 08d11b2..43dae56 100644 --- a/src/expipe_plugin_cinpla/cli/register.py +++ b/src/expipe_plugin_cinpla/cli/register.py @@ -1,14 +1,17 @@ -import click -from pathlib import Path +# -*- coding: utf-8 -*- from datetime import datetime +from pathlib import Path +import click import expipe +from expipe_plugin_cinpla.cli.utils import ( + validate_adjustment, + validate_angle, + validate_depth, + validate_position, +) from expipe_plugin_cinpla.scripts import register -from expipe_plugin_cinpla.cli.utils import validate_depth, validate_position, validate_angle, validate_adjustment - - -import spikeinterface.sorters as ss def attach_to_register(cli): diff --git a/src/expipe_plugin_cinpla/cli/utils.py b/src/expipe_plugin_cinpla/cli/utils.py index 08b2ec8..d982f3d 100644 --- a/src/expipe_plugin_cinpla/cli/utils.py +++ b/src/expipe_plugin_cinpla/cli/utils.py @@ -1,4 +1,7 @@ +# -*- coding: utf-8 -*- import collections +import copy + import click from expipe_plugin_cinpla.imports import * @@ -98,12 +101,12 @@ def optional_choice(ctx, param, value): assert isinstance(options, list) if value is None: if param.required: - raise ValueError('Missing option "{}"'.format(param.opts)) + raise ValueError(f'Missing option "{param.opts}"') return value if param.multiple: if len(value) == 0: if param.required: - raise ValueError('Missing option "{}"'.format(param.opts)) + raise ValueError(f'Missing option "{param.opts}"') return value if len(options) == 0: return value @@ -113,8 +116,8 @@ def optional_choice(ctx, param, value): value, ] for val in value: - if not val in options: - raise ValueError('Value "{}" not in "{}".'.format(val, options)) + if val not in options: + raise ValueError(f'Value "{val}" not in "{options}".') else: if param.multiple: return value diff --git a/src/expipe_plugin_cinpla/scripts/data_processing.py b/src/expipe_plugin_cinpla/data_loader.py similarity index 90% rename from src/expipe_plugin_cinpla/scripts/data_processing.py rename to src/expipe_plugin_cinpla/data_loader.py index 72a11f5..3b65564 100644 --- a/src/expipe_plugin_cinpla/scripts/data_processing.py +++ b/src/expipe_plugin_cinpla/data_loader.py @@ -1,13 +1,13 @@ +# -*- coding: utf-8 -*- """Utils for loading data from NWB files""" -import numpy as np -import quantities as pq import neo +import numpy as np +import quantities as pq import spikeinterface as si import spikeinterface.extractors as se -from pynwb import NWBHDF5IO -from .utils import _get_data_path +from .scripts.utils import _get_data_path def get_data_path(action): @@ -92,6 +92,8 @@ def load_leds(data_path): x1, y1, t1, x2, y2, t2, stop_time: tuple The x and y positions of the red and green LEDs, the timestamps and the stop time """ + from pynwb import NWBHDF5IO + io = NWBHDF5IO(str(data_path), "r") nwbfile = io.read() @@ -130,6 +132,12 @@ def load_lfp(data_path, channel_group=None, lim=None): LFP: neo.AnalogSignal The LFP signal """ + # from pynwb import NWBHDF5IO + + # get the session start time + 
# TODO: are io and nwbfile needed? + # io = NWBHDF5IO(str(data_path), "r") + # nwbfile = io.read() recording_lfp = se.read_nwb_recording( str(data_path), electrical_series_path="processing/ecephys/LFP/ElectricalSeriesLFP" ) @@ -190,10 +198,11 @@ def load_epochs(data_path, label_column=None): epochs: neo.Epoch The trials as NEO epochs """ + from pynwb import NWBHDF5IO + with NWBHDF5IO(str(data_path), "r") as io: nwbfile = io.read() trials = nwbfile.trials.to_dataframe() - start_times = trials["start_time"].values * pq.s stop_times = trials["stop_time"].values * pq.s durations = stop_times - start_times @@ -248,6 +257,12 @@ def load_spiketrains(data_path, channel_group=None, lim=None): spiketrains: list of NEO spike trains The spike trains """ + # from pynwb import NWBHDF5IO + + # get the session start time + # TODO: are io and nwbfile needed? + # io = NWBHDF5IO(str(data_path), "r") + # nwbfile = io.read() recording = se.read_nwb_recording(str(data_path), electrical_series_path="acquisition/ElectricalSeries") sorting = se.read_nwb_sorting(str(data_path), electrical_series_path="acquisition/ElectricalSeries") @@ -260,9 +275,11 @@ def load_spiketrains(data_path, channel_group=None, lim=None): unit_id for unit_index, unit_id in enumerate(sorting.unit_ids) if groups[unit_index] == channel_group ] sptr = [] - # build neo pbjects + # build neo objects for unit in unit_ids: - times = sorting.get_unit_spike_train(unit, return_times=True) * pq.s + spike_times = sorting.get_unit_spike_train(unit, return_times=True) + # subtract the session start time + spike_times = spike_times * pq.s if lim is None: times = recording.get_times() * pq.s t_start = times[0] @@ -270,12 +287,13 @@ def load_spiketrains(data_path, channel_group=None, lim=None): else: t_start = pq.Quantity(lim[0], "s") t_stop = pq.Quantity(lim[1], "s") - mask = (times >= t_start) & (times <= t_stop) - times = times[mask] + mask = (spike_times >= t_start) & (spike_times <= t_stop) + spike_times = spike_times[mask] st = neo.SpikeTrain( - times=times, t_start=t_start, t_stop=t_stop, sampling_rate=sorting.sampling_frequency * pq.Hz + times=spike_times, t_start=t_start, t_stop=t_stop, sampling_rate=sorting.sampling_frequency * pq.Hz ) + st.annotations.update({"name": unit}) for p in sorting.get_property_keys(): st.annotations.update({p: sorting.get_unit_property(unit, p)}) sptr.append(st) @@ -313,7 +331,7 @@ def load_unit_annotations(data_path, channel_group=None): ] for unit in unit_ids: - annotations = {} + annotations = {"name": unit} for p in sorting.get_property_keys(): annotations.update({p: sorting.get_unit_property(unit, p)}) units.append(annotations) diff --git a/src/expipe_plugin_cinpla/imports.py b/src/expipe_plugin_cinpla/imports.py index 1038c0e..3b4af5e 100644 --- a/src/expipe_plugin_cinpla/imports.py +++ b/src/expipe_plugin_cinpla/imports.py @@ -1,18 +1,7 @@ -# import click -# from expipe.cliutils.misc import lazy_import - -import expipe +# -*- coding: utf-8 -*- from pathlib import Path -# @lazy_import -# def expipe(): -# import expipe -# return expipe - -# @lazy_import -# def pathlib(): -# import pathlib -# return pathlib +import expipe local_root, _ = expipe.config._load_local_config(Path.cwd()) if local_root is not None: @@ -23,171 +12,3 @@ class P: config = {} project = P - - -# @lazy_import -# def pd(): -# import pandas as pd -# return pd - -# @lazy_import -# def dt(): -# import datetime as dt -# return dt - -# @lazy_import -# def yaml(): -# import yaml -# return yaml - -# @lazy_import -# def ipywidgets(): -# import ipywidgets 
-# return ipywidgets - -# @lazy_import -# def pyopenephys(): -# import pyopenephys -# return pyopenephys - -# # @lazy_import -# # def openephys_io(): -# # from expipe_io_neuro.openephys import openephys as openephys_io -# # return openephys_io - -# @lazy_import -# def pyintan(): -# import pyintan -# return pyintan - -# @lazy_import -# def pyxona(): -# import pyxona -# return pyxona - -# @lazy_import -# def platform(): -# import platform -# return platform - -# @lazy_import -# def csv(): -# import csv -# return csv - -# @lazy_import -# def json(): -# import json -# return json - -# # @lazy_import -# # def axona(): -# # from expipe_io_neuro import axona -# # return axona - -# @lazy_import -# def os(): -# import os -# return os - -# @lazy_import -# def shutil(): -# import shutil -# return shutil - -# @lazy_import -# def datetime(): -# import datetime -# return datetime - -# @lazy_import -# def subprocess(): -# import subprocess -# return subprocess - -# @lazy_import -# def tarfile(): -# import tarfile -# return tarfile - -# @lazy_import -# def paramiko(): -# import paramiko -# return paramiko - -# @lazy_import -# def getpass(): -# import getpass -# return getpass - -# @lazy_import -# def tqdm(): -# from tqdm import tqdm -# return tqdm - -# @lazy_import -# def scp(): -# import scp -# return scp - -# @lazy_import -# def neo(): -# import neo -# return neo - -# @lazy_import -# def exdir(): -# import exdir -# import exdir.plugins.quantities -# return exdir - -# @lazy_import -# def pq(): -# import quantities as pq -# return pq - -# @lazy_import -# def logging(): -# import logging -# return logging - -# @lazy_import -# def np(): -# import numpy as np -# return np - -# @lazy_import -# def copy(): -# import copy -# return copy - -# @lazy_import -# def scipy(): -# import scipy -# import scipy.io -# return scipy - -# @lazy_import -# def glob(): -# import glob -# return glob - -# @lazy_import -# def el(): -# import elephant as el -# return el - -# @lazy_import -# def sys(): -# import sys -# return sys - -# @lazy_import -# def pprint(): -# import pprint -# return pprint - -# @lazy_import -# def collections(): -# import collections -# return collections diff --git a/src/expipe_plugin_cinpla/nwbutils/cinplanwbconverter.py b/src/expipe_plugin_cinpla/nwbutils/cinplanwbconverter.py index 090bb2e..1eea5a6 100644 --- a/src/expipe_plugin_cinpla/nwbutils/cinplanwbconverter.py +++ b/src/expipe_plugin_cinpla/nwbutils/cinplanwbconverter.py @@ -1,6 +1,7 @@ -from probeinterface import ProbeGroup +# -*- coding: utf-8 -*- from neuroconv import NWBConverter from neuroconv.datainterfaces import OpenEphysRecordingInterface +from probeinterface import ProbeGroup from .interfaces.openephystrackinginterface import OpenEphysTrackingInterface diff --git a/src/expipe_plugin_cinpla/nwbutils/interfaces/__init__.py b/src/expipe_plugin_cinpla/nwbutils/interfaces/__init__.py index ab3bebe..4f94197 100644 --- a/src/expipe_plugin_cinpla/nwbutils/interfaces/__init__.py +++ b/src/expipe_plugin_cinpla/nwbutils/interfaces/__init__.py @@ -1 +1,2 @@ -from .openephystrackinginterface import OpenEphysTrackingInterface +# -*- coding: utf-8 -*- +from .openephystrackinginterface import OpenEphysTrackingInterface # noqa diff --git a/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py b/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py index 84b20f6..f11cb0a 100644 --- a/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py +++ 
b/src/expipe_plugin_cinpla/nwbutils/interfaces/openephystrackinginterface.py @@ -1,14 +1,11 @@ +# -*- coding: utf-8 -*- import warnings + import numpy as np import pyopenephys - -from pynwb.behavior import ( - Position, - SpatialSeries, -) - from neuroconv import BaseDataInterface from neuroconv.utils import FolderPathType +from pynwb.behavior import Position, SpatialSeries class OpenEphysTrackingInterface(BaseDataInterface): @@ -78,17 +75,18 @@ def add_to_nwbfile( rising = rising[:-1] if len(rising) == len(falling): - nwbfile.add_trial_column( - name="channel", - description="Open Ephys channel", - ) - nwbfile.add_trial_column( - name="processor", - description="Open Ephys processor that recorded the event", - ) + if nwbfile.trials is None: + nwbfile.add_trial_column( + name="channel", + description="Open Ephys channel", + ) + nwbfile.add_trial_column( + name="processor", + description="Open Ephys processor that recorded the event", + ) start_times = times[rising].rescale("s").magnitude stop_times = times[falling].rescale("s").magnitude - for start, stop in zip(start_times, stop_times): + for start, stop in zip(start_times, stop_times, strict=False): nwbfile.add_trial( start_time=start, stop_time=stop, diff --git a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py index c402c6a..eba833e 100644 --- a/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py +++ b/src/expipe_plugin_cinpla/nwbutils/nwbwidgetsunitviewer.py @@ -1,14 +1,10 @@ +# -*- coding: utf-8 -*- from functools import partial -import numpy as np -import ipywidgets as widgets -from ipywidgets import Layout, interactive_output -from pynwb.misc import Units -from pynwb.behavior import SpatialSeries +import ipywidgets as widgets import matplotlib.pyplot as plt - -from nwbwidgets.view import default_neurodata_vis_spec - +import numpy as np +from ipywidgets import Layout, interactive_output color_wheel = plt.rcParams["axes.prop_cycle"].by_key()["color"] @@ -16,7 +12,7 @@ class UnitWaveformsWidget(widgets.VBox): def __init__( self, - units: Units, + units: "pynwb.misc.Units", ): super().__init__() @@ -55,7 +51,7 @@ def on_unit_change(self, change): self.unit_group_text.value = f"Group: {unit_group}" -def show_unit_waveforms(units: Units, unit_index=None, ax=None): +def show_unit_waveforms(units: "pynwb.misc.Units", unit_index=None, ax=None): """ TODO: add docstring @@ -105,8 +101,8 @@ class UnitRateMapWidget(widgets.VBox): def __init__( self, - units: Units, - spatial_series: SpatialSeries = None, + units: "pynwb.misc.Units", + spatial_series: "SpatialSeries" = None, ): super().__init__() @@ -186,6 +182,8 @@ def on_unit_change(self, change): self.unit_group_text.value = f"Group: {unit_group}" def get_spatial_series(self): + from pynwb.behavior import SpatialSeries + spatial_series = dict() nwbfile = self.units.get_ancestor("NWBFile") for item in nwbfile.all_children(): @@ -265,6 +263,9 @@ def show_unit_rate_maps(self, unit_index=None, spatial_series_selector=None, num def get_custom_spec(): + from nwbwidgets.view import default_neurodata_vis_spec + from pynwb.misc import Units + custom_neurodata_vis_spec = default_neurodata_vis_spec.copy() # remove irrelevant widgets diff --git a/src/expipe_plugin_cinpla/scripts/__init__.py b/src/expipe_plugin_cinpla/scripts/__init__.py index 3ca085f..085d970 100644 --- a/src/expipe_plugin_cinpla/scripts/__init__.py +++ b/src/expipe_plugin_cinpla/scripts/__init__.py @@
-1 +1,2 @@ -from .convert_old_project import convert_old_project +# -*- coding: utf-8 -*- +from .convert_old_project import convert_old_project # noqa diff --git a/src/expipe_plugin_cinpla/scripts/convert_old_project.py b/src/expipe_plugin_cinpla/scripts/convert_old_project.py index edb2461..a0884d1 100644 --- a/src/expipe_plugin_cinpla/scripts/convert_old_project.py +++ b/src/expipe_plugin_cinpla/scripts/convert_old_project.py @@ -1,14 +1,14 @@ +# -*- coding: utf-8 -*- import shutil +import time from datetime import datetime, timedelta from pathlib import Path -import time import expipe -from .utils import _get_data_path -from .register import convert_to_nwb, register_entity -from .process import process_ecephys from .curation import SortingCurator +from .process import process_ecephys +from .register import convert_to_nwb, register_entity def convert_old_project( @@ -128,6 +128,20 @@ def convert_old_project( delimiter = "*" * len(process_msg) print(f"\n{delimiter}\n{process_msg}\n{delimiter}\n") old_action = old_actions[action_id] + + old_action_folder = old_project.path / "actions" / action_id + new_action_folder = new_project.path / "actions" / action_id + old_data_folder = old_action_folder / "data" + new_data_folder = new_action_folder / "data" + # main.exdir + old_exdir_folder = old_data_folder / "main.exdir" + + if exist_ok and not new_action_folder.is_dir(): + # Copy action that previously failed + print(f">>> Re-copying action {action_id} to new project\n") + shutil.copytree( + old_action_folder, new_action_folder, ignore=shutil.ignore_patterns("main.exdir", ".git") + ) new_action = new_project.actions[action_id] # replace file in attributes.yaml @@ -136,18 +150,12 @@ def convert_old_project( attributes_str = attributes_str.replace("main.exdir", "main.nwb") attributes_file.write_text(attributes_str) - old_data_folder = old_project.path / "actions" / action_id / "data" - new_data_folder = new_project.path / "actions" / action_id / "data" - - # main.exdir - old_exdir_folder = old_data_folder / "main.exdir" - # find open-ephys folder acquisition_folder = old_exdir_folder / "acquisition" openephys_folders = [p for p in acquisition_folder.iterdir() if p.is_dir()] if len(openephys_folders) != 1: print(f"Found {len(openephys_folders)} openephys folders in {acquisition_folder}!") - continue + raise ValueError("Expected to find exactly one openephys folder") openephys_path = openephys_folders[0] # here we assume the following action name: {entity_id}-{date}-{session} entity_id = action_id.split("-")[0] @@ -236,7 +244,7 @@ def convert_old_project( t_stop_all = time.perf_counter() print(f"\nTotal time: {t_stop_all - t_start_all:.2f} s") - done_msg = f"ALL DONE!" + done_msg = "ALL DONE!" 
delimeter = "*" * len(done_msg) print(f"\n{delimeter}\n{done_msg}\n{delimeter}\n") print(f"Successful: {len(actions_to_convert) - len(actions_failed)}\n") diff --git a/src/expipe_plugin_cinpla/scripts/curation.py b/src/expipe_plugin_cinpla/scripts/curation.py index f1cf693..d50bece 100644 --- a/src/expipe_plugin_cinpla/scripts/curation.py +++ b/src/expipe_plugin_cinpla/scripts/curation.py @@ -1,20 +1,16 @@ -import shutil +# -*- coding: utf-8 -*- import json -from pathlib import Path -import numpy as np -from pynwb import NWBHDF5IO -from pynwb.testing.mock.file import mock_NWBFile +import shutil import warnings -import spikeinterface.full as si -import spikeinterface.extractors as se -import spikeinterface.postprocessing as spost -import spikeinterface.qualitymetrics as sqm -import spikeinterface.curation as sc - -from spikeinterface.extractors.nwbextractors import _retrieve_unit_table_pynwb +import numpy as np +import spikeinterface as si -from .utils import _get_data_path, add_units_from_waveform_extractor, compute_and_set_unit_groups +from .utils import ( + _get_data_path, + add_units_from_waveform_extractor, + compute_and_set_unit_groups, +) warnings.filterwarnings("ignore", category=ResourceWarning) warnings.filterwarnings("ignore", category=DeprecationWarning) @@ -75,6 +71,8 @@ def check_sortings_equal(self, sorting1, sorting2): return True def load_raw_sorting(self, sorter): + import spikeinterface.extractors as se + raw_units_path = f"processing/ecephys/RawUnits-{sorter}" try: sorting_raw = se.read_nwb_sorting( @@ -82,22 +80,28 @@ def load_raw_sorting(self, sorter): unit_table_path=raw_units_path, electrical_series_path="acquisition/ElectricalSeries", ) + return sorting_raw except Exception as e: print(f"Could not load raw sorting for {sorter}. Using None: {e}") - sorting_raw = None - return sorting_raw + return None def load_raw_units(self, sorter): + from pynwb import NWBHDF5IO + from spikeinterface.extractors.nwbextractors import _retrieve_unit_table_pynwb + raw_units_path = f"processing/ecephys/RawUnits-{sorter}" self.io = NWBHDF5IO(self.nwb_path_main, "r") nwbfile = self.io.read() try: units = _retrieve_unit_table_pynwb(nwbfile, raw_units_path) - except: - units = None - return units + return units + except Exception as e: + print(f"Could not load raw units for {sorter}. Using None: {e}") + return None def load_main_units(self): + from pynwb import NWBHDF5IO + self.io = NWBHDF5IO(self.nwb_path_main, "r") nwbfile = self.io.read() return nwbfile.units @@ -106,6 +110,8 @@ def construct_curated_units(self): if len(self.curated_we.unit_ids) == 0: print("No units left after curation.") return + from pynwb import NWBHDF5IO + self.io = NWBHDF5IO(self.nwb_path_main, "r") nwbfile = self.io.read() add_units_from_waveform_extractor( @@ -136,11 +142,18 @@ def apply_curation(self, sorter, curated_sorting): print(f"No curation was performed for {sorter}. 
Using raw sorting") self.curated_we = None else: + import spikeinterface.curation as sc + import spikeinterface.postprocessing as spost + import spikeinterface.qualitymetrics as sqm + recording = self.load_processed_recording(sorter) - # if not sort by group, extract dense and estimate group - if "group" not in curated_sorting.get_property_keys(): - compute_and_set_unit_groups(curated_sorting, recording) + # remove excess spikes + print("Removing excess spikes from curated sorting") + curated_sorting = sc.remove_excess_spikes(curated_sorting, recording=recording) + + # if "group" is not available or some missing groups, extract dense and estimate group + compute_and_set_unit_groups(curated_sorting, recording) print("Extracting waveforms on curated sorting") self.curated_we = si.extract_waveforms( @@ -148,7 +161,7 @@ def apply_curation(self, sorter, curated_sorting): curated_sorting, folder=None, mode="memory", - max_spikes_per_unit=100, + max_spikes_per_unit=None, sparse=True, method="by_property", by_property="group", @@ -171,6 +184,8 @@ def apply_curation(self, sorter, curated_sorting): print("Done applying curation") def load_from_phy(self, sorter): + import spikeinterface.extractors as se + phy_path = self.si_path / sorter / "phy" sorting_phy = se.read_phy(phy_path, exclude_cluster_groups=["noise"]) @@ -189,11 +204,13 @@ def get_sortingview_link(self, sorter): visualization_json = self.si_path / sorter / "sortingview_links.json" if not visualization_json.is_file(): return "Sorting view link not found." - with open(visualization_json, "r") as f: + with open(visualization_json) as f: sortingview_links = json.load(f) return sortingview_links["raw"] def apply_sortingview_curation(self, sorter, curated_link): + import spikeinterface.curation as sc + sorting_raw = self.load_raw_sorting(sorter) assert sorting_raw is not None, f"Could not load raw sorting for {sorter}." 
sorting_raw = sorting_raw.save(format="memory") @@ -208,7 +225,7 @@ def apply_sortingview_curation(self, sorter, curated_link): uri = curation_str[curation_str.find("sha1://") : -2] sorting_curated = sc.apply_sortingview_curation(sorting_raw, uri_or_json=uri) # exclude noise - good_units = sorting_curated.unit_ids[sorting_curated.get_property("noise") == False] + good_units = sorting_curated.unit_ids[sorting_curated.get_property("noise") == False] # noqa E712 # create single property for SUA and MUA sorting_curated = sorting_curated.select_units(good_units) self.apply_curation(sorter, sorting_curated) @@ -243,6 +260,7 @@ def save_to_nwb(self): if self.curated_we is None: print("No curation was performed.") return + from pynwb import NWBHDF5IO # trick to get rid of Units first with NWBHDF5IO(self.nwb_path_main, mode="r") as read_io: diff --git a/src/expipe_plugin_cinpla/scripts/process.py b/src/expipe_plugin_cinpla/scripts/process.py index e293397..2313668 100644 --- a/src/expipe_plugin_cinpla/scripts/process.py +++ b/src/expipe_plugin_cinpla/scripts/process.py @@ -1,14 +1,13 @@ -import shutil +# -*- coding: utf-8 -*- import contextlib -import time import json -import os +import shutil +import time + import numpy as np from expipe_plugin_cinpla.scripts import utils -from ..nwbutils.cinplanwbconverter import CinplaNWBConverter - def process_ecephys( project, @@ -33,18 +32,17 @@ def process_ecephys( verbose=True, ): import warnings + import spikeinterface as si + import spikeinterface.exporters as sexp import spikeinterface.extractors as se - import spikeinterface.preprocessing as spre - import spikeinterface.sorters as ss import spikeinterface.postprocessing as spost + import spikeinterface.preprocessing as spre import spikeinterface.qualitymetrics as sqm - import spikeinterface.exporters as sexp + import spikeinterface.sorters as ss import spikeinterface.widgets as sw - - from pynwb import NWBHDF5IO - from neuroconv.tools.spikeinterface import add_recording + from pynwb import NWBHDF5IO from .utils import add_units_from_waveform_extractor, compute_and_set_unit_groups @@ -226,15 +224,18 @@ def process_ecephys( **spikesorter_params, ) except Exception as e: - try: - shutil.rmtree(output_folder) - except: - if verbose: - print(f"\tCould not tmp processing folder: {output_folder}") + shutil.rmtree(output_folder) raise Exception(f"Spike sorting failed:\n\n{e}") if verbose: print(f"\tFound {len(sorting.get_unit_ids())} units!") + # remove units with less than n_components spikes + num_spikes = sorting.count_num_spikes_per_unit() + selected_units = sorting.unit_ids[np.array(list(num_spikes.values())) >= n_components] + n_too_few_spikes = int(len(sorting.unit_ids) - len(selected_units)) + print(f"\tRemoved {n_too_few_spikes} units with less than {n_components} spikes") + sorting = sorting.select_units(selected_units) + # extract waveforms if verbose: print("\nPostprocessing") @@ -252,6 +253,7 @@ def process_ecephys( ms_after=ms_after, sparsity_temp_folder=si_folder / "tmp", sparse=True, + max_spikes_per_unit=None, method="by_property", by_property="group", ) @@ -269,7 +271,7 @@ def process_ecephys( if verbose: print("\tExporting to phy") - phy_folder = output_base_folder / f"phy" + phy_folder = output_base_folder / "phy" if phy_folder.is_dir(): shutil.rmtree(phy_folder) sexp.export_to_phy( @@ -383,7 +385,7 @@ def process_ecephys( if not provenance_file.is_file(): (output_base_folder / "recording_cmr").mkdir(parents=True, exist_ok=True) recording_cmr.dump_to_json(output_base_folder / "recording_cmr" 
/ "provenance.json") - with open(output_base_folder / "recording_cmr" / "provenance.json", "r") as f: + with open(output_base_folder / "recording_cmr" / "provenance.json") as f: provenance = json.load(f) provenance_str = json.dumps(provenance) provenance_str = provenance_str.replace("main_tmp.nwb", "main.nwb") @@ -393,9 +395,9 @@ def process_ecephys( shutil.rmtree(output_base_folder / "recording_cmr") try: nwb_path_tmp.unlink() - except: + except Exception as e: print(f"Could not remove: {nwb_path_tmp}") - raise Exception + raise e if verbose: print("\tSaved to NWB: ", nwb_path) diff --git a/src/expipe_plugin_cinpla/scripts/register.py b/src/expipe_plugin_cinpla/scripts/register.py index 3440992..e186ca7 100644 --- a/src/expipe_plugin_cinpla/scripts/register.py +++ b/src/expipe_plugin_cinpla/scripts/register.py @@ -1,21 +1,21 @@ +# -*- coding: utf-8 -*- import shutil -import warnings import time -import numpy as np -from pathlib import Path +import warnings from datetime import datetime -import pytz -import quantities as pq - -import pyopenephys -import probeinterface as pi +from pathlib import Path import expipe +import numpy as np +import probeinterface as pi +import pyopenephys +import pytz +import quantities as pq def convert_to_nwb(project, action, openephys_path, probe_path, entity_id, user, include_events, overwrite): - from .utils import _make_data_path from ..nwbutils.cinplanwbconverter import CinplaNWBConverter + from .utils import _make_data_path nwb_path = _make_data_path(action, overwrite) @@ -162,8 +162,9 @@ def register_openephys_recording( if delete_raw_data: try: shutil.rmtree(openephys_path) - except: + except Exception as e: print("Could not remove: ", openephys_path) + raise e ### Adjustment ### @@ -181,7 +182,11 @@ def register_openephys_recording( def register_adjustment(project, entity_id, date, adjustment, user, depth, yes): - from expipe_plugin_cinpla.scripts.utils import position_to_dict, get_depth_from_surgery, query_yes_no + from expipe_plugin_cinpla.scripts.utils import ( + get_depth_from_surgery, + position_to_dict, + query_yes_no, + ) user = user or project.config.get("username") if user is None: @@ -203,7 +208,7 @@ def register_adjustment(project, entity_id, date, adjustment, user, depth, yes): try: action = project.actions[action_id] init = False - except KeyError as e: + except KeyError: action = project.create_action(action_id) init = True @@ -213,7 +218,7 @@ def register_adjustment(project, entity_id, date, adjustment, user, depth, yes): if name.endswith("adjustment"): deltas.append(int(name.split("_")[0])) index = max(deltas) + 1 - prev_depth = action.modules["{:03d}_adjustment".format(max(deltas))].contents["depth"] + prev_depth = action.modules[f"{max(deltas):03d}_adjustment"].contents["depth"] if init: if len(depth) > 0: prev_depth = position_to_dict(depth) @@ -221,14 +226,14 @@ def register_adjustment(project, entity_id, date, adjustment, user, depth, yes): prev_depth = get_depth_from_surgery(project=project, entity_id=entity_id) index = 0 - name = "{:03d}_adjustment".format(index) + name = f"{index:03d}_adjustment" if not isinstance(prev_depth, dict): print("Unable to retrieve previous depth.") return adjustment_dict = {key: dict() for key in prev_depth} current = {key: dict() for key in prev_depth} for key, probe, val, unit in adjustment: - pos_key = "probe_{}".format(probe) + pos_key = f"probe_{probe}" adjustment_dict[key][pos_key] = pq.Quantity(val, unit) for key, val in prev_depth.items(): for pos_key in prev_depth[key]: @@ -243,13 +248,13 @@ 
def last_probe(x): correct = query_yes_no( "Correct adjustment?: \n" + " ".join( - "{} {} = {}\n".format(key, pos_key, val[pos_key]) + f"{key} {pos_key} = {val[pos_key]}\n" for key, val in adjustment_dict.items() for pos_key in sorted(val, key=lambda x: last_probe(x)) ) + "New depth: \n" + " ".join( - "{} {} = {}\n".format(key, pos_key, val[pos_key]) + f"{key} {pos_key} = {val[pos_key]}\n" for key, val in current.items() for pos_key in sorted(val, key=lambda x: last_probe(x)) ), @@ -262,13 +267,13 @@ def last_probe(x): print( "Registering adjustment: \n" + " ".join( - "{} {} = {}\n".format(key, pos_key, val[pos_key]) + f"{key} {pos_key} = {val[pos_key]}\n" for key, val in adjustment_dict.items() for pos_key in sorted(val, key=lambda x: last_probe(x)) ) + " New depth: \n" + " ".join( - "{} {} = {}\n".format(key, pos_key, val[pos_key]) + f"{key} {pos_key} = {val[pos_key]}\n" for key, val in current.items() for pos_key in sorted(val, key=lambda x: last_probe(x)) ) @@ -300,10 +305,7 @@ def register_annotation( templates, correct_depth_answer, ): - from expipe_plugin_cinpla.scripts.utils import ( - register_templates, - register_depth, - ) + from expipe_plugin_cinpla.scripts.utils import register_depth, register_templates user = user or project.config.get("username") action = project.actions[action_id] @@ -333,9 +335,7 @@ def register_annotation( print("Registering message", message) action.create_message(text=message, user=user, datetime=datetime.now()) if depth: - correct_depth = register_depth( - project=project, action=action, depth=depth, answer=correct_depth_answer, overwrite=True - ) + _ = register_depth(project=project, action=action, depth=depth, answer=correct_depth_answer, overwrite=True) ### Entity ### @@ -389,7 +389,7 @@ def register_entity( if isinstance(val, (str, float, int)): entity.modules["register"][key]["value"] = val elif isinstance(val, tuple): - if not None in val: + if None not in val: entity.modules["register"][key] = pq.Quantity(val[0], val[1]) elif isinstance(val, type(None)): pass @@ -456,15 +456,15 @@ def register_surgery( for key, probe, x, y, z, unit in position: action.modules[key] = {} - probe_key = "probe_{}".format(probe) + probe_key = f"probe_{probe}" action.modules[key][probe_key] = {} - print("Registering position " + "{} {}: x={}, y={}, z={} {}".format(key, probe, x, y, z, unit)) + print("Registering position " + f"{key} {probe}: x={x}, y={y}, z={z} {unit}") action.modules[key][probe_key]["position"] = pq.Quantity([x, y, z], unit) for key, probe, ang, unit in angle: - probe_key = "probe_{}".format(probe) + probe_key = f"probe_{probe}" if probe_key not in action.modules[key]: action.modules[key][probe_key] = {} - print("Registering angle " + "{} {}: angle={} {}".format(key, probe, ang, unit)) + print("Registering angle " + f"{key} {probe}: angle={ang} {unit}") action.modules[key][probe_key]["angle"] = pq.Quantity(ang, unit) diff --git a/src/expipe_plugin_cinpla/scripts/utils.py b/src/expipe_plugin_cinpla/scripts/utils.py index cc75495..b3c8a55 100644 --- a/src/expipe_plugin_cinpla/scripts/utils.py +++ b/src/expipe_plugin_cinpla/scripts/utils.py @@ -1,12 +1,12 @@ -import sys +# -*- coding: utf-8 -*- import shutil -from datetime import datetime +import sys +from datetime import datetime, timedelta from pathlib import Path, PureWindowsPath -import numpy as np - -import quantities as pq import expipe +import numpy as np +import quantities as pq nwb_main_groups = ["acquisition", "analysis", "processing", "epochs", "general"] tmp_phy_folders = [".klustakwik2", 
".phy", ".spikedetect"] @@ -47,14 +47,14 @@ def query_yes_no(question, default="yes", answer=None): def deltadate(adjustdate, regdate): - delta = regdate - adjustdate if regdate > adjustdate else datetime.timedelta.max + delta = regdate - adjustdate if regdate > adjustdate else timedelta.max return delta def position_to_dict(depth): position = {d[0]: dict() for d in depth} for key, num, val, unit in depth: - probe_key = "probe_{}".format(num) + probe_key = f"probe_{num}" position[key][probe_key] = pq.Quantity(val, unit) return position @@ -85,7 +85,6 @@ def write_python(path, dict): def get_depth_from_surgery(project, entity_id): - index = 0 surgery = project.actions[entity_id + "-surgery-implantation"] position = {} for key, module in surgery.modules.items(): @@ -97,7 +96,7 @@ def get_depth_from_surgery(project, entity_id): for key, groups in position.items(): for group, pos in groups.items(): if not isinstance(pos, pq.Quantity): - raise ValueError("Depth of implant " + '"{} {} = {}"'.format(key, group, pos) + " not recognized") + raise ValueError("Depth of implant " + f'"{key} {group} = {pos}"' + " not recognized") position[key][group] = pos.astype(float)[2] # index 2 = z return position @@ -106,7 +105,7 @@ def get_depth_from_adjustment(project, action, entity_id): DTIME_FORMAT = expipe.core.datetime_format try: adjustments = project.actions[entity_id + "-adjustment"] - except KeyError as e: + except KeyError: return None, None adjusts = {} for adjust in adjustments.modules.values(): @@ -130,7 +129,7 @@ def register_depth(project, action, depth=None, answer=None, overwrite=False): adjustdate = None else: curr_depth, adjustdate = get_depth_from_adjustment(project, action, action.entities[0]) - print("Adjust date time: {}\n".format(adjustdate)) + print(f"Adjust date time: {adjustdate}\n") if curr_depth is None: print("Cannot find current depth from adjustments.") return False @@ -140,7 +139,7 @@ def last_num(x): print( "".join( - "Depth: {} {} = {}\n".format(key, probe_key, val[probe_key]) + f"Depth: {key} {probe_key} = {val[probe_key]}\n" for key, val in curr_depth.items() for probe_key in sorted(val, key=lambda x: last_num(x)) ) @@ -176,19 +175,18 @@ def _make_data_path(action, overwrite, suffix=".nwb"): def _get_data_path(action): - if "main" not in action.data: - return try: + if "main" not in action.data: + return data_path = action.data_path("main") - except: - data_path = Path("None") - pass - if not data_path.is_dir(): - action_path = action._backend.path - project_path = action_path.parent.parent - # data_path = action.data['main'] - data_path = project_path / str(Path(PureWindowsPath(action.data["main"]))) - return data_path + if not data_path.is_dir(): + action_path = action._backend.path + project_path = action_path.parent.parent + # data_path = action.data['main'] + data_path = project_path / str(Path(PureWindowsPath(action.data["main"]))) + return data_path + except Exception: + return None def register_templates(action, templates, overwrite=False): @@ -290,7 +288,23 @@ def generate_phy_restore_files(phy_folder): def compute_and_set_unit_groups(sorting, recording): import spikeinterface as si - we_mem = si.extract_waveforms(recording, sorting, folder=None, mode="memory", sparse=False) - extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") - unit_groups = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))] - sorting.set_property("group", unit_groups) + if len(np.unique(recording.get_channel_groups())) == 1: + 
sorting.set_property("group", np.zeros(len(sorting.unit_ids), dtype="int64")) + else: + if "group" not in sorting.get_property_keys(): + we_mem = si.extract_waveforms(recording, sorting, folder=None, mode="memory", sparse=False) + extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") + unit_groups = recording.get_channel_groups()[np.array(list(extremum_channel_indices.values()))] + sorting.set_property("group", unit_groups) + else: + unit_groups = sorting.get_property("group") + # if there are units without group, we need to compute them + unit_indices_without_group = np.where(unit_groups == "nan")[0] + unit_ids_without_group = np.array(sorting.unit_ids)[unit_indices_without_group] + if len(unit_ids_without_group) > 0: + sorting_no_group = sorting.select_units(unit_ids=unit_ids_without_group) + we_mem = si.extract_waveforms(recording, sorting_no_group, folder=None, mode="memory", sparse=False) + extremum_channel_indices = si.get_template_extremum_channel(we_mem, outputs="index") + # index by position, not by unit id, since unit ids need not be integer indices + unit_groups[unit_indices_without_group] = recording.get_channel_groups()[ + np.array(list(extremum_channel_indices.values())) + ] + sorting.set_property("group", unit_groups) diff --git a/src/expipe_plugin_cinpla/tools/__init__.py b/src/expipe_plugin_cinpla/tools/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/expipe_plugin_cinpla/tools/data_processing.py b/src/expipe_plugin_cinpla/tools/data_processing.py new file mode 100644 index 0000000..eca4ee6 --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/data_processing.py @@ -0,0 +1,508 @@ +# -*- coding: utf-8 -*- +# This is work in progress. +import pathlib +import warnings + +import expipe +import numpy as np +import spatial_maps as sp + +from expipe_plugin_cinpla.data_loader import ( + get_channel_groups, + get_duration, + load_epochs, + load_leds, + load_lfp, + load_spiketrains, +) + + +def view_active_channels(action, sorter): + path = action.data_path() + sorter_path = path / "spikeinterface" / sorter / "phy" + return np.load(sorter_path / "channel_map_si.npy") + + +def _cut_to_same_len(*args): + out = [] + lens = [] + for arg in args: + lens.append(len(arg)) + minlen = min(lens) + for arg in args: + out.append(arg[:minlen]) + return out + + +def velocity_filter(x, y, t, threshold): + """ + Removes samples where the running speed exceeds threshold + Parameters + ---------- + x : quantities.Quantity array in m + 1d vector of x positions + y : quantities.Quantity array in m + 1d vector of y positions + t : quantities.Quantity array in s + 1d vector of times at x, y positions + threshold : float + speed threshold; samples with speed at or above it are removed + """ + assert len(x) == len(y) == len(t), "x, y, t must have same length" + vel = np.gradient([x, y], axis=1) / np.gradient(t) + speed = np.linalg.norm(vel, axis=0) + speed_mask = speed < threshold + speed_mask = np.append(speed_mask, 0) + x = x[np.where(speed_mask)] + y = y[np.where(speed_mask)] + t = t[np.where(speed_mask)] + return x, y, t + + +def interp_filt_position(x, y, tm, fs=100, f_cut=10): + """ + rapid head movements will contribute to velocity artifacts, + these can be removed by low-pass filtering + see http://www.ncbi.nlm.nih.gov/pmc/articles/PMC1876586/ + code adapted from Espen Hagen + Parameters + ---------- + x : quantities.Quantity array in m + 1d vector of x positions + y : quantities.Quantity array in m + 1d vector of y positions + tm : quantities.Quantity array in s + 1d vector of times at x, y positions + fs : quantities scalar in Hz + sampling rate of the interpolated positions + f_cut : float + cut-off frequency for the low-pass filter + Returns + ------- + out : interpolated and low-pass filtered x, y and the resampled t + """ + import scipy.signal as ss + + assert len(x) == len(y) == len(tm), "x, y, t 
must have same length" + t = np.arange(tm.min(), tm.max() + 1.0 / fs, 1.0 / fs) + x = np.interp(t, tm, x) + y = np.interp(t, tm, y) + # rapid head movements will contribute to velocity artifacts, + # these can be removed by low-pass filtering + # see http://www.ncbi.nlm.nih.gov/pmc/articles/PMC1876586/ + # code adapted from Espen Hagen + b, a = ss.butter(N=1, Wn=f_cut * 2 / fs) + # zero phase shift filter + x = ss.filtfilt(b, a, x) + y = ss.filtfilt(b, a, y) + # we tolerate small interpolation errors + x[(x > -1e-3) & (x < 0.0)] = 0.0 + y[(y > -1e-3) & (y < 0.0)] = 0.0 + + return x, y, t + + +def rm_nans(*args): + """ + Removes nan from all corresponding arrays + Parameters + ---------- + args : arrays, lists or quantities which should have removed nans in + all the same indices + Returns + ------- + out : args with removed nans + """ + nan_indices = [] + for arg in args: + nan_indices.extend(np.where(np.isnan(arg))[0].tolist()) + nan_indices = np.unique(nan_indices) + out = [] + for arg in args: + out.append(np.delete(arg, nan_indices)) + return out + + +def filter_xy_zero(x, y, t): + (idxs,) = np.where((x == 0) & (y == 0)) + return [np.delete(a, idxs) for a in [x, y, t]] + + +def filter_xy_box_size(x, y, t, box_size): + (idxs,) = np.where((x > box_size[0]) | (x < 0) | (y > box_size[1]) | (y < 0)) + return [np.delete(a, idxs) for a in [x, y, t]] + + +def filter_t_zero_duration(x, y, t, duration): + (idxs,) = np.where((t < 0) | (t > duration)) + return [np.delete(a, idxs) for a in [x, y, t]] + + +def load_head_direction(data_path, sampling_rate, low_pass_frequency, box_size): + from head_direction.head import head_direction + + x1, y1, t1, x2, y2, t2, stop_time = load_leds(data_path) + + x1, y1, t1 = rm_nans(x1, y1, t1) + x2, y2, t2 = rm_nans(x2, y2, t2) + + x1, y1, t1 = filter_t_zero_duration(x1, y1, t1, stop_time) + x2, y2, t2 = filter_t_zero_duration(x2, y2, t2, stop_time) + + # OE saves 0.0 when signal is lost, these can be removed + x1, y1, t1 = filter_xy_zero(x1, y1, t1) + x2, y2, t2 = filter_xy_zero(x2, y2, t2) + + # x1, y1, t1 = filter_xy_box_size(x1, y1, t1, box_size) + # x2, y2, t2 = filter_xy_box_size(x2, y2, t2, box_size) + + x1, y1, t1 = interp_filt_position(x1, y1, t1, fs=sampling_rate, f_cut=low_pass_frequency) + x2, y2, t2 = interp_filt_position(x2, y2, t2, fs=sampling_rate, f_cut=low_pass_frequency) + + x1, y1, t1, x2, y2, t2 = _cut_to_same_len(x1, y1, t1, x2, y2, t2) + + check_valid_tracking(x1, y1, box_size) + check_valid_tracking(x2, y2, box_size) + + angles, times = head_direction(x1, y1, x2, y2, t1) + + return angles, times + + +def check_valid_tracking(x, y, box_size): + if np.isnan(x).any() or np.isnan(y).any(): + raise ValueError( + "nans found in position, " + "x nans = %i, y nans = %i" % (sum(np.isnan(x)), sum(np.isnan(y))) + ) + + if x.min() < 0 or x.max() > box_size[0] or y.min() < 0 or y.max() > box_size[1]: + warnings.warn( + "Invalid values found " + + f"outside box: min [x, y] = [{x.min()}, {y.min()}], " + + f"max [x, y] = [{x.max()}, {y.max()}]" + ) + + +def load_tracking(data_path, sampling_rate, low_pass_frequency, box_size, velocity_threshold=5): + x1, y1, t1, x2, y2, t2, stop_time = load_leds(data_path) + x1, y1, t1 = rm_nans(x1, y1, t1) + x2, y2, t2 = rm_nans(x2, y2, t2) + + x1, y1, t1 = filter_t_zero_duration(x1, y1, t1, stop_time) + x2, y2, t2 = filter_t_zero_duration(x2, y2, t2, stop_time) + + # select the LED with the most remaining samples (fewest NaNs) + if len(x1) > len(x2): + x, y, t = x1, y1, t1 + else: + x, y, t = x2, y2, t2 + + # OE saves 0.0 when signal is lost, these 
can be removed + x, y, t = filter_xy_zero(x, y, t) + + # x, y, t = filter_xy_box_size(x, y, t, box_size) + + # remove velocity artifacts + x, y, t = velocity_filter(x, y, t, velocity_threshold) + + x, y, t = interp_filt_position(x, y, t, fs=sampling_rate, f_cut=low_pass_frequency) + + check_valid_tracking(x, y, box_size) + + vel = np.gradient([x, y], axis=1) / np.gradient(t) + speed = np.linalg.norm(vel, axis=0) + x, y, t, speed = np.array(x), np.array(y), np.array(t), np.array(speed) + return x, y, t, speed + + +def sort_by_cluster_id(spike_trains): + if len(spike_trains) == 0: + return spike_trains + if "name" not in spike_trains[0].annotations: + print("Unable to get cluster_id, save with phy to create") + return spike_trains + sorted_sptrs = sorted(spike_trains, key=lambda x: str(x.annotations["name"])) + return sorted_sptrs + + +def get_unit_id(unit): + return str(int(unit.annotations["name"])) + + +class Template: + def __init__(self, sptr): + self.data = np.array(sptr.annotations["waveform_mean"]) + self.sampling_rate = float(sptr.sampling_rate) + + +class Data: + def __init__(self, project, stim_mask=False, baseline_duration=None, stim_channels=None, **kwargs): + self.project_path = project.path + self.params = kwargs + self.project = expipe.get_project(self.project_path) + self.actions = self.project.actions + self._spike_trains = {} + self._templates = {} + self._stim_times = {} + self._unit_names = {} + self._tracking = {} + self._head_direction = {} + self._lfp = {} + self._occupancy = {} + self._rate_maps = {} + self._tracking_split = {} + self._rate_maps_split = {} + self._prob_dist = {} + self._spatial_bins = None + self.stim_mask = stim_mask + self.baseline_duration = baseline_duration + self._channel_groups = {} + self.stim_channels = stim_channels + + def channel_groups(self, action_id): + if action_id not in self._channel_groups: + self._channel_groups[action_id] = get_channel_groups(self.data_path(action_id)) + return self._channel_groups[action_id] + + def data_path(self, action_id): + return pathlib.Path(self.project_path) / "actions" / action_id / "data" / "main.nwb" + + def get_lim(self, action_id): + stim_times = self.stim_times(action_id) + if stim_times is None: + if self.baseline_duration is None: + return [0, float(get_duration(self.data_path(action_id)).magnitude)] + else: + return [0, float(self.baseline_duration)] + stim_times = np.array(stim_times) + return [stim_times.min(), stim_times.max()] + + def duration(self, action_id): + return get_duration(self.data_path(action_id)) + + def tracking(self, action_id): + if action_id not in self._tracking: + x, y, t, speed = load_tracking( + self.data_path(action_id), + sampling_rate=self.params["position_sampling_rate"], + low_pass_frequency=self.params["position_low_pass_frequency"], + box_size=self.params["box_size"], + ) + if self.stim_mask: + t1, t2 = self.get_lim(action_id) + mask = (t >= t1) & (t <= t2) + x = x[mask] + y = y[mask] + t = t[mask] + speed = speed[mask] + self._tracking[action_id] = {"x": x, "y": y, "t": t, "v": speed} + return self._tracking[action_id] + + @property + def spatial_bins(self): + if self._spatial_bins is None: + box_size_, bin_size_ = sp.maps._adjust_bin_size( + box_size=self.params["box_size"], bin_size=self.params["bin_size"] + ) + xbins, ybins = sp.maps._make_bins(box_size_, bin_size_) + self._spatial_bins = (xbins, ybins) + self.box_size_, self.bin_size_ = box_size_, bin_size_ + return self._spatial_bins + + def occupancy(self, action_id): + if action_id not in self._occupancy: + xbins, ybins = 
self.spatial_bins + + occupancy_map = sp.maps._occupancy_map( + self.tracking(action_id)["x"], + self.tracking(action_id)["y"], + self.tracking(action_id)["t"], + xbins, + ybins, + ) + threshold = self.params.get("occupancy_threshold") + if threshold is not None: + occupancy_map[occupancy_map <= threshold] = 0 + self._occupancy[action_id] = occupancy_map + return self._occupancy[action_id] + + def prob_dist(self, action_id): + if action_id not in self._prob_dist: + xbins, ybins = self.spatial_bins + prob_dist = sp.stats.prob_dist( + self.tracking(action_id)["x"], self.tracking(action_id)["y"], bins=(xbins, ybins) + ) + self._prob_dist[action_id] = prob_dist + return self._prob_dist[action_id] + + def tracking_split(self, action_id): + if action_id not in self._tracking_split: + x, y, t, v = map(self.tracking(action_id).get, ["x", "y", "t", "v"]) + + t_split = t[-1] / 2 + mask_1 = t < t_split + mask_2 = t >= t_split + x1, y1, t1, v1 = x[mask_1], y[mask_1], t[mask_1], v[mask_1] + x2, y2, t2, v2 = x[mask_2], y[mask_2], t[mask_2], v[mask_2] + + self._tracking_split[action_id] = { + "x1": x1, + "y1": y1, + "t1": t1, + "v1": v1, + "x2": x2, + "y2": y2, + "t2": t2, + "v2": v2, + } + return self._tracking_split[action_id] + + def spike_train_split(self, action_id, channel_group, unit_name): + spikes = self.spike_train(action_id, channel_group, unit_name) + t_split = self.duration(action_id) / 2 + spikes_1 = spikes[spikes < t_split] + spikes_2 = spikes[spikes >= t_split] + return spikes_1, spikes_2, t_split + + def rate_map_split(self, action_id, channel_group, unit_name, smoothing): + make_rate_map = False + if action_id not in self._rate_maps_split: + self._rate_maps_split[action_id] = {} + if channel_group not in self._rate_maps_split[action_id]: + self._rate_maps_split[action_id][channel_group] = {} + if unit_name not in self._rate_maps_split[action_id][channel_group]: + self._rate_maps_split[action_id][channel_group][unit_name] = {} + if smoothing not in self._rate_maps_split[action_id][channel_group][unit_name]: + make_rate_map = True + + if make_rate_map: + xbins, ybins = self.spatial_bins + x, y, t = map(self.tracking(action_id).get, ["x", "y", "t"]) + spikes = self.spike_train(action_id, channel_group, unit_name) + t_split = t[-1] / 2 + mask_1 = t < t_split + mask_2 = t >= t_split + x_1, y_1, t_1 = x[mask_1], y[mask_1], t[mask_1] + x_2, y_2, t_2 = x[mask_2], y[mask_2], t[mask_2] + spikes_1 = spikes[spikes < t_split] + spikes_2 = spikes[spikes >= t_split] + occupancy_map_1 = sp.maps._occupancy_map(x_1, y_1, t_1, xbins, ybins) + occupancy_map_2 = sp.maps._occupancy_map(x_2, y_2, t_2, xbins, ybins) + + spike_map_1 = sp.maps._spike_map(x_1, y_1, t_1, spikes_1, xbins, ybins) + spike_map_2 = sp.maps._spike_map(x_2, y_2, t_2, spikes_2, xbins, ybins) + + smooth_spike_map_1 = sp.maps.smooth_map(spike_map_1, bin_size=self.bin_size_, smoothing=smoothing) + smooth_spike_map_2 = sp.maps.smooth_map(spike_map_2, bin_size=self.bin_size_, smoothing=smoothing) + smooth_occupancy_map_1 = sp.maps.smooth_map(occupancy_map_1, bin_size=self.bin_size_, smoothing=smoothing) + smooth_occupancy_map_2 = sp.maps.smooth_map(occupancy_map_2, bin_size=self.bin_size_, smoothing=smoothing) + + rate_map_1 = smooth_spike_map_1 / smooth_occupancy_map_1 + rate_map_2 = smooth_spike_map_2 / smooth_occupancy_map_2 + self._rate_maps_split[action_id][channel_group][unit_name][smoothing] = [rate_map_1, rate_map_2] + + return self._rate_maps_split[action_id][channel_group][unit_name][smoothing] + + def rate_map(self, 
action_id, channel_group, unit_name, smoothing): + make_rate_map = False + if action_id not in self._rate_maps: + self._rate_maps[action_id] = {} + if channel_group not in self._rate_maps[action_id]: + self._rate_maps[action_id][channel_group] = {} + if unit_name not in self._rate_maps[action_id][channel_group]: + self._rate_maps[action_id][channel_group][unit_name] = {} + if smoothing not in self._rate_maps[action_id][channel_group][unit_name]: + make_rate_map = True + + if make_rate_map: + xbins, ybins = self.spatial_bins + + spike_map = sp.maps._spike_map( + self.tracking(action_id)["x"], + self.tracking(action_id)["y"], + self.tracking(action_id)["t"], + self.spike_train(action_id, channel_group, unit_name), + xbins, + ybins, + ) + + smooth_spike_map = sp.maps.smooth_map(spike_map, bin_size=self.bin_size_, smoothing=smoothing) + smooth_occupancy_map = sp.maps.smooth_map( + self.occupancy(action_id), bin_size=self.bin_size_, smoothing=smoothing + ) + rate_map = smooth_spike_map / smooth_occupancy_map + self._rate_maps[action_id][channel_group][unit_name][smoothing] = rate_map + + return self._rate_maps[action_id][channel_group][unit_name][smoothing] + + def head_direction(self, action_id): + if action_id not in self._head_direction: + a, t = load_head_direction( + self.data_path(action_id), + sampling_rate=self.params["position_sampling_rate"], + low_pass_frequency=self.params["position_low_pass_frequency"], + box_size=self.params["box_size"], + ) + if self.stim_mask: + t1, t2 = self.get_lim(action_id) + mask = (t >= t1) & (t <= t2) + a = a[mask] + t = t[mask] + self._head_direction[action_id] = {"a": a, "t": t} + return self._head_direction[action_id] + + def lfp(self, action_id, channel_group, clean_memory=False): + lim = self.get_lim(action_id) if self.stim_mask else None + if clean_memory: + return load_lfp(self.data_path(action_id), channel_group, lim) + if action_id not in self._lfp: + self._lfp[action_id] = {} + if channel_group not in self._lfp[action_id]: + self._lfp[action_id][channel_group] = load_lfp(self.data_path(action_id), channel_group, lim) + return self._lfp[action_id][channel_group] + + def template(self, action_id, channel_group, unit_id): + self.spike_trains(action_id) + return Template(self._spike_trains[action_id][channel_group][unit_id]) + + def spike_train(self, action_id, channel_group, unit_id): + self.spike_trains(action_id) + return self._spike_trains[action_id][channel_group][unit_id] + + def spike_trains(self, action_id, channel_group=None): + if action_id not in self._spike_trains: + self._spike_trains[action_id] = {} + lim = self.get_lim(action_id) if self.stim_mask else None + + sts = load_spiketrains(self.data_path(action_id), lim=lim) + for st in sts: + group = st.annotations["group"] + if group not in self._spike_trains[action_id]: + self._spike_trains[action_id][group] = {} + self._spike_trains[action_id][group][int(get_unit_id(st))] = st + if channel_group is None: + return self._spike_trains[action_id] + else: + return self._spike_trains[action_id][channel_group] + + def unit_names(self, action_id, channel_group): + # TODO load_unit_annotations has not been ported to NWB yet + # units = load_unit_annotations(self.data_path(action_id), channel_group=channel_group) + raise NotImplementedError("unit_names is not implemented yet") + + def stim_times(self, action_id): + if action_id not in self._stim_times: + try: + trials = load_epochs(self.data_path(action_id), label_column="channel") + if len(set(trials.labels)) > 1: + stim_times = trials.times[trials.labels == self.stim_channels[action_id]] + else: + stim_times = 
trials.times + stim_times = np.sort(np.abs(np.array(stim_times))) + # there are some 0 times and inf times, remove those + # stim_times = stim_times[stim_times >= 1e-20] + self._stim_times[action_id] = stim_times + except AttributeError as e: + if str(e) == "'NoneType' object has no attribute 'to_dataframe'": + self._stim_times[action_id] = None + else: + raise e + + return self._stim_times[action_id] diff --git a/src/expipe_plugin_cinpla/tools/registration.py b/src/expipe_plugin_cinpla/tools/registration.py new file mode 100644 index 0000000..80bd395 --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/registration.py @@ -0,0 +1,17 @@ +# -*- coding: utf-8 -*- +import os +import pathlib +import shutil + + +def store_notebook(action, notebook_path): + notebook_path = pathlib.Path(notebook_path) + action.data["notebook"] = notebook_path.name + notebook_output_path = action.data_path("notebook") + shutil.copy(notebook_path, notebook_output_path) + # As HTML + os.system(f"jupyter nbconvert --to html {notebook_path}") + html_path = notebook_path.with_suffix(".html") + action.data["html"] = html_path.name + html_output_path = action.data_path("html") + shutil.copy(html_path, html_output_path) diff --git a/src/expipe_plugin_cinpla/tools/track_units_tools.py b/src/expipe_plugin_cinpla/tools/track_units_tools.py new file mode 100644 index 0000000..b512d4e --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/track_units_tools.py @@ -0,0 +1,205 @@ +# -*- coding: utf-8 -*- +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +from matplotlib import gridspec +from scipy.optimize import linear_sum_assignment + + +def dissimilarity(template_0, template_1): + """ + Dissimilarity between the mean waveforms of two units. + Parameters + ---------- + template_0 : np.array + Mean waveform of the first unit (samples x channels). + template_1 : np.array + Mean waveform of the second unit, same shape as template_0. + Returns + ------- + diss : float + Mean absolute difference between the two templates after both are + normalized by their joint maximum absolute value. + """ + max_val = np.max([np.max(np.abs(template_0)), np.max(np.abs(template_1))]) + + t_i_lin = template_0.ravel() + t_j_lin = template_1.ravel() + + return np.mean(np.abs(t_i_lin / max_val - t_j_lin / max_val)) + # return np.mean(np.abs(t_i_lin - t_j_lin)) + + +def dissimilarity_weighted(templates_0, templates_1): + """ + Weighted dissimilarity between the mean waveforms of two units. + Parameters + ---------- + templates_0 : np.array + Mean waveform of the first unit (samples x channels). + templates_1 : np.array + Mean waveform of the second unit, same shape as templates_0. + Returns + ------- + diss : float + Root-sum-square difference across channels, averaged over samples, + after both templates are normalized by their joint maximum absolute value. 
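+ + Notes + ----- + Identical templates give 0. As a worked example, two constant four-channel + templates of 0 and 1 give sqrt(4) = 2.0 (root-sum-square over the four + channels, averaged over samples). 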
+ """ + + max_val = np.max([np.max(np.abs(templates_0)), np.max(np.abs(templates_1))]) + + templates_0 /= max_val + templates_1 /= max_val + # root sum square, averaged over channels + weighted = np.sqrt( + np.sum([(templates_0[:, i] - templates_1[:, i]) ** 2 for i in range(templates_0.shape[1])], axis=0) + ).mean() + return weighted + + +def make_possible_match(dissimilarity_scores, max_dissimilarity): + """ + Given an agreement matrix and a max_dissimilarity threhold. + Return as a dict all possible match for each spiketrain in each side. + + Note : this is symmetric. + + + Parameters + ---------- + dissimilarity_scores: pd.DataFrame + + max_dissimilarity: float + + + Returns + ----------- + best_match_12: pd.Series + + best_match_21: pd.Series + + """ + unit1_ids = np.array(dissimilarity_scores.index) + unit2_ids = np.array(dissimilarity_scores.columns) + + # threhold the matrix + scores = dissimilarity_scores.values.copy() + scores[scores > max_dissimilarity] = np.inf + + possible_match_12 = {} + for i1, u1 in enumerate(unit1_ids): + inds_match = np.isfinite(scores[i1, :]) + possible_match_12[u1] = unit2_ids[inds_match] + + possible_match_21 = {} + for i2, u2 in enumerate(unit2_ids): + inds_match = np.isfinite(scores[:, i2]) + possible_match_21[u2] = unit1_ids[inds_match] + + return possible_match_12, possible_match_21 + + +def make_best_match(dissimilarity_scores, max_dissimilarity): + """ + Given an agreement matrix and a max_dissimilarity threhold. + return a dict a best match for each units independently of others. + + Note : this is symmetric. + + Parameters + ---------- + dissimilarity_scores: pd.DataFrame + + max_dissimilarity: float + + + Returns + ----------- + best_match_12: pd.Series + + best_match_21: pd.Series + + + """ + unit1_ids = np.array(dissimilarity_scores.index) + unit2_ids = np.array(dissimilarity_scores.columns) + + scores = dissimilarity_scores.values.copy() + + best_match_12 = pd.Series(index=unit1_ids, dtype="int64") + for i1, u1 in enumerate(unit1_ids): + ind_min = np.argmin(scores[i1, :]) + if scores[i1, ind_min] <= max_dissimilarity: + best_match_12[u1] = unit2_ids[ind_min] + else: + best_match_12[u1] = -1 + + best_match_21 = pd.Series(index=unit2_ids, dtype="int64") + for i2, u2 in enumerate(unit2_ids): + ind_min = np.argmin(scores[:, i2]) + if scores[ind_min, i2] <= max_dissimilarity: + best_match_21[u2] = unit1_ids[ind_min] + else: + best_match_21[u2] = -1 + + return best_match_12, best_match_21 + + +def make_hungarian_match(dissimilarity_scores, max_dissimilarity): + """ + Given an agreement matrix and a max_dissimilarity threhold. + return the "optimal" match with the "hungarian" algo. + This use internally the scipy.optimze.linear_sum_assignment implementation. 
+ + Parameters + ---------- + dissimilarity_scores: pd.DataFrame + + max_dissimilarity: float + + + Returns + ------- + hungarian_match_12: pd.Series + + hungarian_match_21: pd.Series + + """ + unit1_ids = np.array(dissimilarity_scores.index) + unit2_ids = np.array(dissimilarity_scores.columns) + + # the max_dissimilarity threshold is applied after the assignment + scores = dissimilarity_scores.values.copy() + + inds1, inds2 = linear_sum_assignment(scores) + + hungarian_match_12 = pd.Series(index=unit1_ids, dtype="int64") + hungarian_match_12[:] = -1 + hungarian_match_21 = pd.Series(index=unit2_ids, dtype="int64") + hungarian_match_21[:] = -1 + + for i1, i2 in zip(inds1, inds2, strict=False): + u1 = unit1_ids[i1] + u2 = unit2_ids[i2] + if dissimilarity_scores.at[u1, u2] < max_dissimilarity: + hungarian_match_12[u1] = u2 + hungarian_match_21[u2] = u1 + + return hungarian_match_12, hungarian_match_21 + + +def plot_template(template, fig, gs, axs=None, **kwargs): + nrc = template.shape[1] + if axs is None: + gs0 = gridspec.GridSpecFromSubplotSpec(1, nrc, subplot_spec=gs) + axs = [fig.add_subplot(gs0[0])] + axs.extend([fig.add_subplot(gs0[i], sharey=axs[0], sharex=axs[0]) for i in range(1, nrc)]) + for c in range(nrc): + axs[c].plot(template[:, c], **kwargs) + if c > 0: + plt.setp(axs[c].get_yticklabels(), visible=False) + return axs diff --git a/src/expipe_plugin_cinpla/tools/trackunitcomparison.py b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py new file mode 100644 index 0000000..1997b0a --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/trackunitcomparison.py @@ -0,0 +1,130 @@ +# -*- coding: utf-8 -*- +from pathlib import Path + +import numpy as np +import pandas as pd + +from expipe_plugin_cinpla.data_loader import ( + get_channel_groups, + get_data_path, + load_spiketrains, +) + +from .track_units_tools import ( + dissimilarity_weighted, + make_best_match, + make_hungarian_match, + make_possible_match, +) + + +class TrackingSession: + """ + Tracks units between two actions by comparing the mean waveform templates + of their units, one channel group at a time. + """ + + def __init__( + self, + action_id_0, + action_id_1, + actions, + channel_groups=None, + max_dissimilarity=10, + dissimilarity_function=None, + verbose=False, + ): + data_path_0 = get_data_path(actions[action_id_0]) + data_path_1 = get_data_path(actions[action_id_1]) + + self._actions = actions + self.action_id_0 = action_id_0 + self.action_id_1 = action_id_1 + self.channel_groups = channel_groups + self.action_ids = [action_id_0, action_id_1] + self.max_dissimilarity = max_dissimilarity + self.dissimilarity_function = dissimilarity_function + self._verbose = verbose + + if self.channel_groups is None: + self.channel_groups = get_channel_groups(data_path_0) + self.matches = {} + self.templates = {} + self.unit_ids = {} + for chan in self.channel_groups: + self.matches[chan] = dict() + self.templates[chan] = list() + self.unit_ids[chan] = list() + + self.units_0 = load_spiketrains(data_path_0) + self.units_1 = load_spiketrains(data_path_1) + for channel_group in self.channel_groups: + us_0 = [u for u in self.units_0 if u.annotations["group"] == channel_group] + us_1 = [u for u in self.units_1 if u.annotations["group"] == channel_group] + + self.unit_ids[channel_group] = [ + [int(u.annotations["name"]) for u in us_0], + [int(u.annotations["name"]) for u in us_1], + ] + self.templates[channel_group] = [ + [u.annotations["waveform_mean"] for u in us_0], + [u.annotations["waveform_mean"] for u in us_1], + ] + if len(us_0) > 0 and len(us_1) > 0: + self._do_dissimilarity(channel_group) + 
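# matching (possible/best/hungarian) is derived from the dissimilarity matrix + 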
self._do_matching(channel_group) + elif self._verbose: + print(f"Found no units in {channel_group}") + + def save_dissimilarity_matrix(self, path=None): + path = path or Path.cwd() + for channel_group in self.channel_groups: + if "dissimilarity_scores" not in self.matches[channel_group]: + continue + filename = f"{self.action_id_0}_{self.action_id_1}_{channel_group}" + self.matches[channel_group]["dissimilarity_scores"].to_csv(path / (filename + ".csv")) + + @property + def session_0_name(self): + return self.action_ids[0] + + @property + def session_1_name(self): + return self.action_ids[1] + + def make_dissimilary_matrix(self, channel_group): + templates_0, templates_1 = self.templates[channel_group] + diss_matrix = np.zeros((len(templates_0), len(templates_1))) + + unit_ids_0, unit_ids_1 = self.unit_ids[channel_group] + + for i, w0 in enumerate(templates_0): + for j, w1 in enumerate(templates_1): + diss_matrix[i, j] = dissimilarity_weighted(w0, w1) + + diss_matrix = pd.DataFrame(diss_matrix, index=unit_ids_0, columns=unit_ids_1) + + return diss_matrix + + def _do_dissimilarity(self, channel_group): + if self._verbose: + print("Dissimilarity scores...") + + # dissimilarity score for each pair of units + self.matches[channel_group]["dissimilarity_scores"] = self.make_dissimilary_matrix(channel_group) + + def _do_matching(self, channel_group): + if self._verbose: + print("Matching...") + + ( + self.matches[channel_group]["possible_match_01"], + self.matches[channel_group]["possible_match_10"], + ) = make_possible_match(self.matches[channel_group]["dissimilarity_scores"], self.max_dissimilarity) + self.matches[channel_group]["best_match_01"], self.matches[channel_group]["best_match_10"] = make_best_match( + self.matches[channel_group]["dissimilarity_scores"], self.max_dissimilarity + ) + ( + self.matches[channel_group]["hungarian_match_01"], + self.matches[channel_group]["hungarian_match_10"], + ) = make_hungarian_match(self.matches[channel_group]["dissimilarity_scores"], self.max_dissimilarity) diff --git a/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py new file mode 100644 index 0000000..7ba2ec5 --- /dev/null +++ b/src/expipe_plugin_cinpla/tools/trackunitmulticomparison.py @@ -0,0 +1,293 @@ +# -*- coding: utf-8 -*- +import datetime +import uuid +from collections import defaultdict +from pathlib import Path + +import matplotlib.pylab as plt +import networkx as nx +import numpy as np +import yaml +from matplotlib import gridspec +from tqdm import tqdm + +from expipe_plugin_cinpla.data_loader import ( + get_channel_groups, + get_data_path, + load_spiketrains, +) + +from .track_units_tools import plot_template +from .trackunitcomparison import TrackingSession + + +class TrackMultipleSessions: + def __init__( + self, + actions, + action_list=None, + channel_groups=None, + max_dissimilarity=None, + max_timedelta=None, + verbose=False, + progress_bar=None, + data_path=None, + ): + self.data_path = Path.cwd() if data_path is None else Path(data_path) + self.data_path.mkdir(parents=True, exist_ok=True) + self.action_list = [a for a in actions] if action_list is None else action_list + self._actions = actions + self.channel_groups = channel_groups + self.max_dissimilarity = max_dissimilarity or np.inf + self.max_timedelta = max_timedelta or datetime.timedelta.max + self._verbose = verbose + self._pbar = tqdm if progress_bar is None else progress_bar + self._templates = {} + if self.channel_groups is None: + 
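# fall back to the channel groups found in the first action's data file + 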
dp = get_data_path(self._actions[self.action_list[0]]) + self.channel_groups = get_channel_groups(dp) + if len(self.channel_groups) == 0: + print("Unable to locate channel groups, please provide a working action_list") + + def do_matching(self): + # do pairwise matching + if self._verbose: + print("Multicomparison step 1: pairwise comparison") + + self.comparisons = [] + N = len(self.action_list) + pbar = self._pbar(total=int((N**2 - N) / 2)) + for i in range(N): + for j in range(i + 1, N): + if self._verbose: + print(" Comparing: ", self.action_list[i], " and ", self.action_list[j]) + comp = TrackingSession( + self.action_list[i], + self.action_list[j], + actions=self._actions, + max_dissimilarity=np.inf, + channel_groups=self.channel_groups, + verbose=self._verbose, + ) + # comp.save_dissimilarity_matrix() + self.comparisons.append(comp) + pbar.update(1) + pbar.close() + + def make_graphs_from_matches(self): + if self._verbose: + print("Multicomparison step 2: make graphs") + + self.graphs = {} + + for ch in self.channel_groups: + if self._verbose: + print("Processing channel", ch) + self.graphs[ch] = nx.Graph() + + # nodes + for comp in self.comparisons: + # if same node is added twice it's only created once + for i, action_id in enumerate(comp.action_ids): + for u in comp.unit_ids[ch][i]: + node_name = action_id + "_" + str(int(u)) + self.graphs[ch].add_node(node_name, action_id=action_id, unit_id=int(u)) + + # edges + for comp in self.comparisons: + if "hungarian_match_01" not in comp.matches[ch]: + continue + for u1 in comp.unit_ids[ch][0]: + u2 = comp.matches[ch]["hungarian_match_01"][u1] + if u2 != -1: + score = comp.matches[ch]["dissimilarity_scores"].loc[u1, u2] + node1_name = comp.action_id_0 + "_" + str(int(u1)) + node2_name = comp.action_id_1 + "_" + str(int(u2)) + self.graphs[ch].add_edge(node1_name, node2_name, weight=float(score)) + + # the graph is symmetrical + self.graphs[ch] = self.graphs[ch].to_undirected() + + def compute_time_delta_edges(self): + """ + adds a time delta to each of the edges + """ + for graph in self.graphs.values(): + for n0, n1 in graph.edges(): + action_id_0 = graph.nodes[n0]["action_id"] + action_id_1 = graph.nodes[n1]["action_id"] + time_delta = abs(self._actions[action_id_0].datetime - self._actions[action_id_1].datetime) + graph.add_edge(n0, n1, time_delta=time_delta) + + def compute_depth_delta_edges(self): + """ + adds a depth delta to each of the edges + """ + for ch, graph in self.graphs.items(): + ch_num = int(ch[-1]) + for n0, n1 in graph.edges(): + action_id_0 = graph.nodes[n0]["action_id"] + action_id_1 = graph.nodes[n1]["action_id"] + loc_0 = self._actions[action_id_0].modules["channel_group_location"][ch_num] + loc_1 = self._actions[action_id_1].modules["channel_group_location"][ch_num] + assert loc_0 == loc_1 + depth_0 = self._actions[action_id_0].modules["depth"][loc_0]["probe_0"] + depth_1 = self._actions[action_id_1].modules["depth"][loc_1]["probe_0"] + depth_0 = float(depth_0.rescale("um")) + depth_1 = float(depth_1.rescale("um")) + depth_delta = abs(depth_0 - depth_1) + graph.add_edge(n0, n1, depth_delta=depth_delta) + + def remove_edges_above_threshold(self, key="weight", threshold=0.05): + """ + key: weight, depth_delta, time_delta + """ + for ch in self.graphs: + graph = self.graphs[ch] + edges_to_remove = [] + for sub_graph in nx.connected_components(graph): + for node_id in sub_graph: + for n1, n2, d in graph.edges(node_id, data=True): + if d[key] > threshold and n2 in sub_graph: # mark the edge for removal + edge = 
set((n1, n2)) + if edge not in edges_to_remove: + edges_to_remove.append(edge) + for n1, n2 in edges_to_remove: + graph.remove_edge(n1, n2) + self.graphs[ch] = graph + + def remove_edges_with_duplicate_actions(self): + for graph in self.graphs.values(): + edges_to_remove = [] + for sub_graph in nx.connected_components(graph): + sub_graph_action_ids = {node: graph.nodes[node]["action_id"] for node in sub_graph} + action_ids = np.array(list(sub_graph_action_ids.values())) + node_ids = np.array(list(sub_graph_action_ids.keys())) + unique_action_ids, action_id_counts = np.unique(action_ids, return_counts=True) + if len(unique_action_ids) != len(action_ids): + duplicates = unique_action_ids[action_id_counts > 1] + + for duplicate in duplicates: + (idxs,) = np.where(action_ids == duplicate) + weights = {} + for node_id in node_ids[idxs]: + weights[node_id] = np.mean( + [ + d["weight"] + for n1, n2, d in graph.edges(node_id, data=True) + if n2 in sub_graph_action_ids + ] + ) + min_weight = np.min(list(weights.values())) + for node_id, weight in weights.items(): + if weight > min_weight: # drop the edges of every duplicate except the best match + for n1, n2 in graph.edges(node_id): + if n2 in sub_graph_action_ids: + edge = set((n1, n2)) + if edge not in edges_to_remove: + edges_to_remove.append(edge) + for n1, n2 in edges_to_remove: + graph.remove_edge(n1, n2) + + def save_graphs(self): + for ch, graph in self.graphs.items(): + with open(self.data_path / f"graph-group-{ch}.yaml", "w") as f: + yaml.dump(graph, f) + + def load_graphs(self): + self.graphs = {} + for path in self.data_path.iterdir(): + if path.name.startswith("graph-group") and path.suffix == ".yaml": + ch = path.stem.split("-")[-1] + with open(path) as f: + self.graphs[ch] = yaml.load(f, Loader=yaml.Loader) + + def identify_units(self): + if self._verbose: + print("Multicomparison step 3: extract agreement from graphs") + self.identified_units = {} + for ch, graph in self.graphs.items(): + # extract agreement from the graph + self._new_units = {} + for node_set in nx.connected_components(graph): + unit_id = str(uuid.uuid4()) + edges = graph.edges(node_set, data=True) + + if len(edges) == 0: + average_dissimilarity = None + else: + average_dissimilarity = np.mean([d["weight"] for _, _, d in edges]) + + original_ids = defaultdict(list) + for node in node_set: + original_ids[graph.nodes[node]["action_id"]].append(graph.nodes[node]["unit_id"]) + + self._new_units[unit_id] = { + "average_dissimilarity": average_dissimilarity, + "original_unit_ids": original_ids, + } + + self.identified_units[ch] = self._new_units + + def load_template(self, action_id, channel_group, unit_id): + group_unit_hash = str(channel_group) + "_" + str(unit_id) + if action_id in self._templates: + return self._templates[action_id].get(group_unit_hash) + + action = self._actions[action_id] + + data_path = get_data_path(action) + + spike_trains = load_spiketrains(data_path) + + self._templates[action_id] = {} + for sptr in spike_trains: + group_unit_hash_ = sptr.annotations["group"] + "_" + str(int(sptr.annotations["name"])) + self._templates[action_id][group_unit_hash_] = sptr.annotations["waveform_mean"] + + # None if the unit has no stored template (handled by plot_matches) + return self._templates[action_id].get(group_unit_hash) + + def plot_matches(self, chan_group=None, figsize=(10, 3), step_color=True): + """ + Plot the matched templates, one identified unit per row. + + Parameters + ---------- + chan_group : optional + channel group to plot; all channel groups if None + figsize : tuple + base figure size; the height is scaled by the number of units + """ + if chan_group is None: + ch_groups = self.identified_units.keys() + else: + ch_groups = [chan_group] + for ch_group in ch_groups: + identified_units = self.identified_units[ch_group] + units = [ + 
(unit["original_unit_ids"], unit["average_dissimilarity"]) + for unit in identified_units.values() + if len(unit["original_unit_ids"]) > 1 + ] + num_units = sum([len(u) for u in units]) + if num_units == 0: + print(f"Zero units found on channel group {ch_group}") + continue + fig = plt.figure(figsize=(figsize[0], figsize[1] * num_units)) + gs = gridspec.GridSpec(num_units, 1) + id_ax = 0 + for unit, avg_dsim in units: + axs = None + for action_id, unit_ids in unit.items(): + for unit_id in unit_ids: + label = f"{action_id} Unit {unit_id} {avg_dsim:.2f}" + template = self.load_template(action_id, ch_group, unit_id) + if template is None: + print(f'Unable to plot "{unit_id}" from action "{action_id}" ch group "{ch_group}"') + continue + # print(f'plotting {action_id}, {ch_group}, {unit_id}') + axs = plot_template(template, fig=fig, gs=gs[id_ax], axs=axs, label=label) + id_ax += 1 + plt.legend(loc="center left", bbox_to_anchor=(1, 0.5)) + fig.suptitle("Channel group " + str(ch_group)) + plt.tight_layout(rect=[0, 0.03, 1, 0.98]) diff --git a/src/expipe_plugin_cinpla/utils.py b/src/expipe_plugin_cinpla/utils.py index e346d67..ec9311d 100644 --- a/src/expipe_plugin_cinpla/utils.py +++ b/src/expipe_plugin_cinpla/utils.py @@ -1,3 +1,4 @@ +# -*- coding: utf-8 -*- from expipe.backends.filesystem import yaml_dump diff --git a/src/expipe_plugin_cinpla/widgets/__init__.py b/src/expipe_plugin_cinpla/widgets/__init__.py index f386856..fb60c64 100644 --- a/src/expipe_plugin_cinpla/widgets/__init__.py +++ b/src/expipe_plugin_cinpla/widgets/__init__.py @@ -1 +1,2 @@ -from .browser import display_browser +# -*- coding: utf-8 -*- +from .browser import display_browser # noqa diff --git a/src/expipe_plugin_cinpla/widgets/browser.py b/src/expipe_plugin_cinpla/widgets/browser.py index b70db22..f689321 100644 --- a/src/expipe_plugin_cinpla/widgets/browser.py +++ b/src/expipe_plugin_cinpla/widgets/browser.py @@ -1,16 +1,17 @@ -import IPython.display as ipd +# -*- coding: utf-8 -*- import expipe +import IPython.display as ipd +from .curation import CurationView +from .process import process_ecephys_view from .register import ( - register_openephys_view, register_adjustment_view, register_annotate_view, register_entity_view, - register_surgery_view, + register_openephys_view, register_perfuse_view, + register_surgery_view, ) -from .process import process_ecephys_view -from .curation import CurationView from .viewer import NwbViewer diff --git a/src/expipe_plugin_cinpla/widgets/curation.py b/src/expipe_plugin_cinpla/widgets/curation.py index 57a05a4..b6b4a43 100644 --- a/src/expipe_plugin_cinpla/widgets/curation.py +++ b/src/expipe_plugin_cinpla/widgets/curation.py @@ -1,15 +1,14 @@ -import ipywidgets -import pandas as pd +# -*- coding: utf-8 -*- from collections import OrderedDict -import expipe -import expipe.config +import ipywidgets +import pandas as pd from expipe_plugin_cinpla.scripts import curation from expipe_plugin_cinpla.scripts.utils import _get_data_path -from .utils import BaseViewWithLog, required_values_filled -from ..utils import dump_project_config +from ..utils import dump_project_config +from .utils import BaseViewWithLog, required_values_filled default_qms = [ dict(name="isi_violations_ratio", sign="<", threshold=0.5), @@ -47,7 +46,11 @@ class CurationView(BaseViewWithLog): def __init__(self, project): from nwbwidgets import nwb2widget from pynwb.misc import Units - from ..nwbutils.nwbwidgetsunitviewer import UnitWaveformsWidget, UnitRateMapWidget + + from ..nwbutils.nwbwidgetsunitviewer import ( 
+ UnitRateMapWidget, + UnitWaveformsWidget, + ) custom_raw_unit_vis = { Units: OrderedDict({"Raw Waveforms": UnitWaveformsWidget, "Rate Maps": UnitRateMapWidget}) @@ -70,7 +73,7 @@ def __init__(self, project): si_path = data_path.parent / "spikeinterface" if si_path.is_dir(): actions_processed.append(action_name) - + actions_processed = sorted(actions_processed) actions_list = ipywidgets.Select( options=actions_processed, rows=10, description="Actions: ", disabled=False, layout={"width": "300px"} ) diff --git a/src/expipe_plugin_cinpla/widgets/process.py b/src/expipe_plugin_cinpla/widgets/process.py index 4d60634..6897fef 100644 --- a/src/expipe_plugin_cinpla/widgets/process.py +++ b/src/expipe_plugin_cinpla/widgets/process.py @@ -1,5 +1,8 @@ +# -*- coding: utf-8 -*- import ast + from expipe_plugin_cinpla.scripts import process + from .utils import BaseViewWithLog metric_names = [ @@ -25,8 +28,9 @@ def process_ecephys_view(project): import ipywidgets import spikeinterface.sorters as ss - from .utils import SearchSelectMultiple, required_values_filled, ParameterSelectList + from ..scripts.utils import _get_data_path + from .utils import ParameterSelectList, SearchSelectMultiple, required_values_filled all_actions = project.actions @@ -41,7 +45,7 @@ def process_ecephys_view(project): action_names.append(f"{action_name} -- (P)") else: action_names.append(f"{action_name} -- (U)") - + action_names = sorted(action_names) action_ids = SearchSelectMultiple(action_names, description="*Actions") overwrite = ipywidgets.Checkbox(description="Overwrite", value=True) diff --git a/src/expipe_plugin_cinpla/widgets/register.py b/src/expipe_plugin_cinpla/widgets/register.py index be0e8b1..7b7aa75 100644 --- a/src/expipe_plugin_cinpla/widgets/register.py +++ b/src/expipe_plugin_cinpla/widgets/register.py @@ -1,18 +1,17 @@ +# -*- coding: utf-8 -*- from pathlib import Path + from expipe_plugin_cinpla.scripts import register -from .utils import BaseViewWithLog + from ..utils import dump_project_config +from .utils import BaseViewWithLog ### Open Ephys recording ### def register_openephys_view(project): import ipywidgets - from .utils import ( - MultiInput, - required_values_filled, - none_if_empty, - split_tags, - ) + + from .utils import MultiInput, none_if_empty, required_values_filled, split_tags # left column layout_auto = ipywidgets.Layout(width="300px") @@ -31,7 +30,7 @@ def register_openephys_view(project): # buttons depth = MultiInput(["Key", "Probe", "Depth", "Unit"], "Add depth") register_depth = ipywidgets.Checkbox(description="Register depth", value=False) - include_events = ipywidgets.Checkbox(description="Include events", value=False) + include_events = ipywidgets.Checkbox(description="Include events", value=True) register_depth_from_adjustment = ipywidgets.Checkbox(description="Find adjustments", value=True) register_depth_from_adjustment.layout.visibility = "hidden" @@ -127,12 +126,8 @@ def on_register(change): ### Adjustment ### def register_adjustment_view(project): import ipywidgets - from .utils import ( - DateTimePicker, - MultiInput, - required_values_filled, - SearchSelect, - ) + + from .utils import DateTimePicker, MultiInput, SearchSelect, required_values_filled entity_id = SearchSelect(options=project.entities, description="*Entities") user = ipywidgets.Text(placeholder="*User", value=project.config.get("username")) @@ -140,9 +135,9 @@ def register_adjustment_view(project): adjustment = MultiInput(["*Key", "*Probe", "*Adjustment", "*Unit"], "Add adjustment") depth = MultiInput(["Key", 
"Probe", "Depth", "Unit"], "Add depth") depth_from_surgery = ipywidgets.Checkbox(description="Get depth from surgery", value=True) - register = ipywidgets.Button(description="Register") + register_button = ipywidgets.Button(description="Register") - fields = ipywidgets.VBox([user, date, adjustment, register]) + fields = ipywidgets.VBox([user, date, adjustment, register_button]) main_box = ipywidgets.VBox([depth_from_surgery, ipywidgets.HBox([fields, entity_id])]) def on_manual_depth(change): @@ -158,6 +153,9 @@ def on_manual_depth(change): depth_from_surgery.observe(on_manual_depth, names="value") + view = BaseViewWithLog(main_box=main_box, project=project) + + @view.output.capture() def on_register(change): if not required_values_filled(entity_id, user, adjustment): return @@ -171,18 +169,20 @@ def on_register(change): yes=True, ) - register.on_click(on_register) - return main_box + register_button.on_click(on_register) + + return view ### Annotation ### def register_annotate_view(project): import ipywidgets + from .utils import ( DateTimePicker, MultiInput, - required_values_filled, SearchSelectMultiple, + required_values_filled, split_tags, ) @@ -196,11 +196,14 @@ def register_annotate_view(project): message = ipywidgets.Text(placeholder="Message") tag = ipywidgets.Text(placeholder="Tags (; to separate)") templates = SearchSelectMultiple(project.templates, description="Templates") - register = ipywidgets.Button(description="Register") + register_button = ipywidgets.Button(description="Register") - fields = ipywidgets.VBox([user, date, location, message, action_type, tag, depth, entity_id, register]) + fields = ipywidgets.VBox([user, date, location, message, action_type, tag, depth, entity_id, register_button]) main_box = ipywidgets.VBox([ipywidgets.HBox([fields, action_id, templates])]) + view = BaseViewWithLog(main_box=main_box, project=project) + + @view.output.capture() def on_register(change): if not required_values_filled(action_id, user): return @@ -221,20 +224,20 @@ def on_register(change): correct_depth_answer=True, ) - register.on_click(on_register) - return main_box + register_button.on_click(on_register) + return view ### Entity ### def register_entity_view(project): import ipywidgets + from .utils import ( DatePicker, SearchSelectMultiple, - required_values_filled, none_if_empty, + required_values_filled, split_tags, - make_output_and_show, ) entity_id = ipywidgets.Text(placeholder="*Entity id") @@ -248,8 +251,8 @@ def register_entity_view(project): templates = SearchSelectMultiple(project.templates, description="Templates") overwrite = ipywidgets.Checkbox(description="Overwrite", value=False) - register = ipywidgets.Button(description="Register") - fields = ipywidgets.VBox([entity_id, user, species, sex, location, birthday, message, tag, register]) + register_button = ipywidgets.Button(description="Register") + fields = ipywidgets.VBox([entity_id, user, species, sex, location, birthday, message, tag, register_button]) main_box = ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, templates])]) view = BaseViewWithLog(main_box=main_box, project=project) @@ -273,21 +276,22 @@ def on_register(change): templates=templates.value, ) - register.on_click(on_register) + register_button.on_click(on_register) return view ### Surgery ### def register_surgery_view(project): import ipywidgets + from .utils import ( DatePicker, MultiInput, + SearchSelect, SearchSelectMultiple, - required_values_filled, none_if_empty, + required_values_filled, split_tags, - SearchSelect, ) entity_id = 
     entity_id = SearchSelect(options=project.entities, description="*Entities")
@@ -307,9 +311,9 @@ def register_surgery_view(project):
     angle = MultiInput(["*Key", "*Probe", "*Angle", "*Unit"], "Add angle")
     templates = SearchSelectMultiple(project.templates, description="Templates")
     overwrite = ipywidgets.Checkbox(description="Overwrite", value=False)
-    register = ipywidgets.Button(description="Register")
+    register_button = ipywidgets.Button(description="Register")

-    fields = ipywidgets.VBox([user, location, date, weight, position, angle, message, procedure, tag, register])
+    fields = ipywidgets.VBox([user, location, date, weight, position, angle, message, procedure, tag, register_button])
     main_box = ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, ipywidgets.VBox([entity_id, templates])])])

     view = BaseViewWithLog(main_box=main_box, project=project)
@@ -336,21 +340,20 @@ def on_register(change):
             tags=tags,
         )

-    register.on_click(on_register)
+    register_button.on_click(on_register)
     return view


 ### PERFUSION ###
 def register_perfuse_view(project):
     import ipywidgets
+
     from .utils import (
         DatePicker,
-        MultiInput,
+        SearchSelect,
         SearchSelectMultiple,
-        required_values_filled,
         none_if_empty,
-        split_tags,
-        SearchSelect,
+        required_values_filled,
     )

     entity_id = SearchSelect(options=project.entities, description="*Entities")
@@ -367,8 +370,8 @@ def register_perfuse_view(project):
     templates = SearchSelectMultiple(project.templates, description="Templates")
     overwrite = ipywidgets.Checkbox(description="Overwrite", value=False)

-    register = ipywidgets.Button(description="Register")
-    fields = ipywidgets.VBox([user, location, date, weight, message, register])
+    register_button = ipywidgets.Button(description="Register")
+    fields = ipywidgets.VBox([user, location, date, weight, message, register_button])
     main_box = ipywidgets.VBox([overwrite, ipywidgets.HBox([fields, entity_id, templates])])

     view = BaseViewWithLog(main_box=main_box, project=project)
@@ -389,5 +392,5 @@ def on_register(change):
             message=none_if_empty(message.value),
         )

-    register.on_click(on_register)
+    register_button.on_click(on_register)
     return view
diff --git a/src/expipe_plugin_cinpla/widgets/utils.py b/src/expipe_plugin_cinpla/widgets/utils.py
index 35d1e0c..39f9ceb 100644
--- a/src/expipe_plugin_cinpla/widgets/utils.py
+++ b/src/expipe_plugin_cinpla/widgets/utils.py
@@ -1,9 +1,11 @@
-import ipywidgets
-import numpy as np
+# -*- coding: utf-8 -*-
 import datetime as dt
-import expipe
 import warnings

+import expipe
+import ipywidgets
+import numpy as np
+
 warnings.filterwarnings("ignore", category=DeprecationWarning)
@@ -217,7 +219,7 @@ def value(self):
         for ch in self.children:
             keys.append(ch.description)
             values.append(ch.value)
-        return dict(zip(keys, values))
+        return dict(zip(keys, values, strict=False))


 class DateTimePicker(ipywidgets.HBox):
@@ -277,9 +279,10 @@ def __init__(self, filetype=None, initialdir=None, *args, **kwargs):

     @staticmethod
     def select_file(self):
-        from tkfilebrowser import askopenfilename
         from tkinter import Tk

+        from tkfilebrowser import askopenfilename
+
         # Create Tk root
         root = Tk()
         # Hide the main window
@@ -294,7 +297,7 @@ def select_file(self):
             name = ft[1:].capitalize()
             result = askopenfilename(
                 defaultextension=ft,
-                filetypes=[("{} file".format(name), "*{}".format(ft)), ("All files", "*.*")],
+                filetypes=[(f"{name} file", f"*{ft}"), ("All files", "*.*")],
                 initialdir=self.initialdir,
             )
         self.file = result if len(result) > 0 else ""
@@ -347,9 +350,10 @@ def on_text_change(change):

     @staticmethod
     def select_file(self):
-        from tkfilebrowser import askopenfilenames
         from tkinter import Tk

+        from tkfilebrowser import askopenfilenames
+
         # Create Tk root
         root = Tk()
         # Hide the main window
@@ -364,7 +368,7 @@ def select_file(self):
             name = ft[1:].capitalize()
             self.files = askopenfilenames(
                 defaultextension=ft,
-                filetypes=[("{} file".format(name), "*{}".format(ft)), ("All files", "*.*")],
+                filetypes=[(f"{name} file", f"*{ft}"), ("All files", "*.*")],
                 initialdir=self.initialdir,
             )
         else:
@@ -399,9 +403,10 @@ def __init__(self, initialdir=None, *args, **kwargs):

     @staticmethod
     def select_directories(self):
-        from tkfilebrowser import askopendirnames
         from tkinter import Tk

+        from tkfilebrowser import askopendirnames
+
         # Create Tk root
         root = Tk()
         # Hide the main window
diff --git a/src/expipe_plugin_cinpla/widgets/viewer.py b/src/expipe_plugin_cinpla/widgets/viewer.py
index e9937ea..2cc710b 100644
--- a/src/expipe_plugin_cinpla/widgets/viewer.py
+++ b/src/expipe_plugin_cinpla/widgets/viewer.py
@@ -1,7 +1,6 @@
-import ipywidgets
-from pynwb import NWBHDF5IO
-
+# -*- coding: utf-8 -*-
 import expipe
+import ipywidgets

 from ..nwbutils.nwbwidgetsunitviewer import get_custom_spec
 from ..scripts.utils import _get_data_path
@@ -32,11 +31,13 @@ def get_options(self):
             data_path = _get_data_path(action)
             if data_path is not None and data_path.name == "main.nwb":
                 options.append(action_name)
+        options = sorted(options)
         return options

     def on_change(self, change):
         if change["type"] == "change" and change["name"] == "value":
             from nwbwidgets import nwb2widget
+            from pynwb import NWBHDF5IO

             action_id = change["new"]
             if action_id is None:
diff --git a/tests/conftest.py b/tests/conftest.py
index 8e45c3e..c8c4092 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -1,8 +1,10 @@
-import pytest
-import expipe
+# -*- coding: utf-8 -*-
 import shutil
 from pathlib import Path

+import expipe
+import pytest
+
 from expipe_plugin_cinpla.utils import dump_project_config

 TEST_DATA_PATH = Path(__file__).parent / "test_data"
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 50377ff..839d4d6 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -1,15 +1,15 @@
-import pytest
+# -*- coding: utf-8 -*-
 import time
+
 import click
-from pathlib import Path
-from click.testing import CliRunner
-import quantities as pq
 import numpy as np
+import pytest
+import quantities as pq
+import spikeinterface.extractors as se
+from click.testing import CliRunner

 from expipe_plugin_cinpla.cli import CinplaPlugin

-import spikeinterface.extractors as se
-

 @click.group()
 @click.pass_context
@@ -23,7 +23,6 @@ def cli(ctx):
 def run_command(command_list, inp=None):
     runner = CliRunner()
     command_list = [str(c) for c in command_list]
-    # print(" ".join(command_list))
     result = runner.invoke(cli, command_list, input=inp)
     if result.exit_code != 0:
         print(result.output)
diff --git a/tests/test_convert_old_project.py b/tests/test_convert_old_project.py
index a9307d3..9e5ab99 100644
--- a/tests/test_convert_old_project.py
+++ b/tests/test_convert_old_project.py
@@ -1,9 +1,11 @@
+# -*- coding: utf-8 -*-
 from pathlib import Path
-from pynwb import NWBHDF5IO

 import expipe
-from expipe_plugin_cinpla.scripts.utils import _get_data_path
+from pynwb import NWBHDF5IO
+
 from expipe_plugin_cinpla import convert_old_project
+from expipe_plugin_cinpla.scripts.utils import _get_data_path

 test_folder = Path(__file__).parent
 old_project_path = test_folder / "test_data" / "old_project"
diff --git a/tests/test_script.py b/tests/test_script.py
index e0d1b43..0ee6808 100644
--- a/tests/test_script.py
+++ b/tests/test_script.py
@@ -1,8 +1,14 @@
-import pytest
+# -*- coding: utf-8 -*-
 from datetime import datetime
-from expipe_plugin_cinpla.scripts.register import register_entity, register_openephys_recording
-from expipe_plugin_cinpla.scripts.process import process_ecephys
+
+import pytest
+
 from expipe_plugin_cinpla.scripts.curation import SortingCurator
+from expipe_plugin_cinpla.scripts.process import process_ecephys
+from expipe_plugin_cinpla.scripts.register import (
+    register_entity,
+    register_openephys_recording,
+)


 @pytest.mark.dependency()
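
A note on the widget refactor above, since BaseViewWithLog itself is defined elsewhere in the package and never appears in this diff: each register_*_view now returns a BaseViewWithLog instead of a bare main_box, and wraps its on_register callback with @view.output.capture(). The point of the pattern is that ipywidgets button callbacks otherwise swallow prints and tracebacks; routing them into an ipywidgets.Output widget makes them visible under the form. The stand-in class below is a hypothetical sketch of that shape, not the package's actual implementation:

# Sketch only: BaseViewWithLog is assumed to wrap an ipywidgets.Output;
# the real class in expipe_plugin_cinpla may differ.
import ipywidgets


class BaseViewWithLog(ipywidgets.VBox):  # hypothetical stand-in
    def __init__(self, main_box, project):
        output = ipywidgets.Output()  # log area rendered under the form
        super().__init__(children=[main_box, output])
        self.project = project
        self.output = output


def demo_view(project=None):
    register_button = ipywidgets.Button(description="Register")
    main_box = ipywidgets.VBox([register_button])
    view = BaseViewWithLog(main_box=main_box, project=project)

    @view.output.capture()  # prints and tracebacks land in view.output
    def on_register(change):
        print("registering...")

    register_button.on_click(on_register)
    return view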
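
The dict(zip(keys, values, strict=False)) change in widgets/utils.py deserves a line of context: zip() gained a strict keyword in Python 3.10, and linters such as Ruff's flake8-bugbear rule B905 (presumably what prompted this edit) ask callers to choose explicitly. Here keys and values are both built from the same self.children loop, so their lengths always match; strict=False simply pins the historical truncating behavior:

# zip(strict=...) in Python >= 3.10: strict=False truncates to the shortest
# input (the pre-3.10 behavior); strict=True raises on a length mismatch.
keys = ["Key", "Probe", "Depth"]
values = ["stim", 1]  # deliberately one short

print(dict(zip(keys, values, strict=False)))  # {'Key': 'stim', 'Probe': 1}

try:
    dict(zip(keys, values, strict=True))
except ValueError as err:
    print(err)  # zip() argument 2 is shorter than argument 1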
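
Most of the remaining hunks are mechanical import reordering, and they all converge on the same layout, consistent with isort's black profile: one block each for stdlib, third-party, and first-party imports, separated by blank lines, with names inside a from-import sorted by isort's default order_by_type rule (CONSTANTS, then CamelCase classes, then lowercase names), which is why SearchSelectMultiple sorts above none_if_empty in the hunks above. A condensed illustration using modules that actually appear in this diff:

# Target import layout (isort, black profile):
import datetime as dt  # 1) stdlib, alphabetical
import warnings

import expipe  # 2) third-party, alphabetical
import ipywidgets
import numpy as np

from expipe_plugin_cinpla.scripts.register import (  # 3) first-party
    register_entity,  # lowercase names sort after any CamelCase names
    register_openephys_recording,
)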