Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
* Support roiextractors 0.5.11 [PR #1236](https://github.com/catalystneuro/neuroconv/pull/1236)
* Added stub_test option to TDTFiberPhotometryInterface [PR #1242](https://github.com/catalystneuro/neuroconv/pull/1242)
* Added ThorImagingInterface for Thor TIFF files with OME metadata [PR #1238](https://github.com/catalystneuro/neuroconv/pull/1238)
* For `PhySortingInterface`, automatically calculate `max_channel` for each unit and add it to the units table. [PR #961](https://github.com/catalystneuro/neuroconv/pull/961)

## Improvements
* Filter out warnings for missing timezone information in continuous integration [PR #1240](https://github.com/catalystneuro/neuroconv/pull/1240)
Expand Down
50 changes: 49 additions & 1 deletion src/neuroconv/datainterfaces/ecephys/phy/phydatainterface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,12 @@
from typing import Optional
from pathlib import Path
from typing import Literal, Optional

import numpy as np
from pydantic import DirectoryPath, validate_call
from pynwb.file import NWBFile

from ..basesortingextractorinterface import BaseSortingExtractorInterface
from ....utils import DeepDict


class PhySortingInterface(BaseSortingExtractorInterface):
Expand All @@ -24,6 +28,23 @@ def get_source_schema(cls) -> dict:
] = "Path to the output Phy folder (containing the params.py)."
return source_schema

def get_max_channel(self):
    """
    Compute the channel with the largest template amplitude for each unit.

    Loads the Phy output arrays from ``folder_path``:

    * ``templates.npy`` — the (whitened) spike templates; assumed shape
      ``(num_templates, num_samples, num_channels)`` — TODO confirm against the Phy docs.
    * ``whitening_mat_inv.npy`` — inverse whitening matrix, applied here to
      recover the unwhitened templates before comparing channel amplitudes.
    * ``channel_map.npy`` — mapping from template channel index to recording
      channel id.

    Each unit's template is selected via the sorting extractor's
    ``original_cluster_id`` property, reduced by taking the maximum over time
    on every channel, and the channel with the largest such maximum is
    reported.

    Returns
    -------
    numpy.ndarray
        One channel id per unit, ordered like the sorting extractor's units.

    Notes
    -----
    The maximum of the raw (signed) template is used, not the maximum of the
    absolute value; for predominantly negative-going spikes the channel with
    the largest absolute deflection may differ. NOTE(review): confirm this is
    the intended definition of "max channel".
    """
    folder_path = Path(self.source_data["folder_path"])

    # Undo Phy's whitening so per-channel amplitudes are comparable.
    templates = np.load(str(folder_path / "templates.npy"))
    channel_map = np.load(str(folder_path / "channel_map.npy")).T
    whitening_mat_inv = np.load(str(folder_path / "whitening_mat_inv.npy"))
    templates_unwh = templates @ whitening_mat_inv

    # Select each unit's template by its original Phy cluster id.
    cluster_ids = self.sorting_extractor.get_property("original_cluster_id")
    templates = templates_unwh[cluster_ids]

    # Reduce over time (axis 1), then pick the strongest channel per unit.
    max_over_time = np.max(templates, axis=1)
    idx_max_channel = np.argmax(max_over_time, axis=1)
    max_channel = channel_map[idx_max_channel].ravel()

    return max_channel
Copy link
Collaborator

@h-mayorquin h-mayorquin Aug 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, two other questions here:

  1. Aren't the templates scaled? I am thinking of the amplitudes.npy . This would mean the argmax operation might not be right.
  2. Also, is max_channel meaningful if spikes are negative? I think here it is consistent, but aren't most people interested in the largest absolute value? On the templates database we use "best channel", which did not have that load.

Now I am aware that you are using machinery that is already there in the add_units_table but I wanted to point this out.

https://phy.readthedocs.io/en/latest/sorting_user_guide/

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Scaled how? In a way that would affect which channel is selected?
  2. yeah, good point, the argmax(abs()) might be better

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I shared this with a friend that uses Phy and he flagged that. Reading the documentation I feel less certain.

We could do a round-trip with spikeinterface artificial data and the export-to-phy functionality to see if your function gets it right.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is some code by Nick Steinmetz that calculates template maxima without using the amplitudes (but they don't do the de-whitening step that you do)

https://github.com/cortex-lab/spikes/blob/master/analysis/templatePositionsAmplitudes.m

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think he does. Look at winv

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

But the block where the max channel is calculated take the temps as they come directly from the input:

https://github.com/cortex-lab/spikes/blob/fcea2b20e736b533e5baf612752b66121a691128/analysis/templatePositionsAmplitudes.m#L64-L71

Am I missing something?


@validate_call
def __init__(
self,
Expand All @@ -44,6 +65,33 @@ def __init__(
"""
super().__init__(folder_path=folder_path, exclude_cluster_groups=exclude_cluster_groups, verbose=verbose)

def add_to_nwbfile(
    self,
    nwbfile: NWBFile,
    metadata: Optional[DeepDict] = None,
    stub_test: bool = False,
    write_ecephys_metadata: bool = False,
    write_as: Literal["units", "processing"] = "units",
    units_name: str = "units",
    units_description: str = "Imported from Phy",
    include_max_channel: bool = True,
):
    """
    Add the Phy sorting data to an in-memory NWBFile.

    Parameters
    ----------
    nwbfile : NWBFile
        The in-memory NWB file the sorting data is added to.
    metadata : DeepDict, optional
        Metadata dictionary; forwarded to the parent
        ``BaseSortingExtractorInterface.add_to_nwbfile``.
    stub_test : bool, default: False
        Forwarded to the parent ``add_to_nwbfile``.
    write_ecephys_metadata : bool, default: False
        Forwarded to the parent ``add_to_nwbfile``.
    write_as : {"units", "processing"}, default: "units"
        Forwarded to the parent ``add_to_nwbfile``.
    units_name : str, default: "units"
        Forwarded to the parent ``add_to_nwbfile``.
    units_description : str, default: "Imported from Phy"
        Forwarded to the parent ``add_to_nwbfile``.
    include_max_channel : bool, default: True
        If True, compute each unit's max channel from the Phy templates (see
        ``get_max_channel``) and set it as a ``max_channel`` property on the
        sorting extractor so it is written as a column of the units table.
        Skipped when a ``max_channel`` property is already present.

    Returns
    -------
    NWBFile
        The same ``nwbfile`` object, with the sorting data added.
    """
    # Compute only when requested, and respect a pre-existing "max_channel"
    # property so the calculation is not repeated or overwritten.
    if include_max_channel and "max_channel" not in self.sorting_extractor.get_property_keys():
        max_channels = self.get_max_channel()
        self.sorting_extractor.set_property("max_channel", max_channels)

    super().add_to_nwbfile(
        nwbfile=nwbfile,
        metadata=metadata,
        stub_test=stub_test,
        write_ecephys_metadata=write_ecephys_metadata,
        write_as=write_as,
        units_name=units_name,
        units_description=units_description,
    )

    return nwbfile

def get_metadata(self):
metadata = super().get_metadata()
# See Kilosort save_to_phy() docstring for more info on these fields: https://github.com/MouseLand/Kilosort/blob/main/kilosort/io.py
Expand Down
17 changes: 17 additions & 0 deletions tests/test_on_data/ecephys/test_sorting_interfaces.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
from datetime import datetime

import numpy as np
from numpy.testing import assert_array_equal
from pynwb import NWBHDF5IO
from spikeinterface.extractors.nwbextractors import read_nwbfile

from neuroconv.datainterfaces import (
BlackrockRecordingInterface,
Expand Down Expand Up @@ -198,6 +200,21 @@ class TestPhySortingInterface(SortingExtractorInterfaceTestMixin):
interface_kwargs = dict(folder_path=str(DATA_PATH / "phy" / "phy_example_0"))
save_directory = OUTPUT_PATH

def check_read_nwb(self, nwbfile_path: str):
    """Run the generic read checks, then verify the Phy-specific max_channel column."""
    super().check_read_nwb(nwbfile_path)

    # Check that the max channel is correctly extracted from the Phy templates.
    max_channel = self.interface.get_max_channel()
    assert_array_equal(max_channel, [1, 2, 5, 5, 6, 21, 13, 13, 21, 21, 22, 22, 24])

    # Check that the max channel was properly added to the sorting extractor.
    assert_array_equal(self.interface.sorting_extractor.get_property("max_channel"), max_channel)

    # Check that the max channels were properly written to the NWB file's units table.
    nwbfile = read_nwbfile(file_path=nwbfile_path, backend="hdf5")
    assert_array_equal(nwbfile.units["max_channel"].data[:], max_channel)

def check_extracted_metadata(self, metadata: dict):
assert metadata["Ecephys"]["UnitProperties"] == [
dict(name="n_spikes", description="Number of spikes recorded from each unit."),
Expand Down
Loading