
Commit

get_units bug
m-beau committed Jan 13, 2021
1 parent 1737893 commit 907d117
Showing 18 changed files with 52 additions and 65 deletions.
6 changes: 3 additions & 3 deletions build/lib/npyx/spk_t.py
@@ -53,7 +53,7 @@ def ids(dp, unit, sav=True, prnt=False, subset_selection='all', again=False):
except:pass
if type(unit) is int:
spike_clusters = np.load(Path(dp,"spike_clusters.npy"))
indices = np.nonzero(spike_clusters==unit)[0].squeeze()
indices = np.nonzero(spike_clusters==unit)[0]
if type(unit) not in [str, np.str_, int]:
print('WARNING unit {} type ({}) not handled!'.format(unit, type(unit)))
return
@@ -114,11 +114,11 @@ def trn(dp, unit, sav=True, prnt=False, subset_selection='all', again=False, enf
if ds_table.shape[0]>1: # If several datasets in prophyler
spike_clusters_samples = np.load(Path(dp, 'merged_clusters_spikes.npy'))
dataset_mask=(spike_clusters_samples[:, 0]==ds_i); unit_mask=(spike_clusters_samples[:, 1]==unt)
train = spike_clusters_samples[dataset_mask&unit_mask, 2].squeeze().astype(np.int64)
train = spike_clusters_samples[dataset_mask&unit_mask, 2].astype(np.int64)
else:
spike_clusters = np.load(Path(ds_table['dp'][0],"spike_clusters.npy"))
spike_samples = np.load(Path(ds_table['dp'][0],'spike_times.npy'))
train = spike_samples[spike_clusters==unt].squeeze()
train = spike_samples[spike_clusters==unt]
else:
try:unit=int(unit)
except:pass
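The dropped `.squeeze()` calls are the substance of this fix: for a unit with a single spike, `squeeze()` collapses the one-element index array into a 0-d scalar, which breaks `len()` and iteration downstream. A minimal standalone sketch with hypothetical data (not part of the commit):

```python
import numpy as np

spike_clusters = np.array([3, 7, 3, 3])  # hypothetical cluster labels; unit 7 fired once

squeezed = np.nonzero(spike_clusters == 7)[0].squeeze()  # old behaviour
kept     = np.nonzero(spike_clusters == 7)[0]            # fixed behaviour

print(squeezed.shape)  # () -> 0-d array; len(squeezed) raises TypeError
print(kept.shape)      # (1,) -> regular 1-D index array
```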
1 change: 1 addition & 0 deletions build/lib/npyx/spk_wvf.py
@@ -403,6 +403,7 @@ def get_ids_subset(dp, unit, n_waveforms, batch_size_waveforms, subset_selection
else:
assert n_waveforms > 0
spike_ids = ids(dp, unit)
assert any(spike_ids)
if subset_selection == 'regular':
# Regular subselection.
if batch_size_waveforms is None or len(spike_ids) <= max(batch_size_waveforms, n_waveforms):
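The added `assert any(spike_ids)` makes `get_ids_subset` fail fast when the requested unit has no spikes, rather than silently building a waveform subset from an empty id array. A hedged sketch of the guarded situation (the path, unit id and error message are illustrative, not from the commit):

```python
from npyx.spk_t import ids

dp, unit = 'path/to/kilosort/output', 999   # hypothetical: unit 999 does not exist
spike_ids = ids(dp, unit)                   # -> empty index array
assert any(spike_ids), f'unit {unit} has no spikes in {dp}'  # raises AssertionError
```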
Binary file removed dist/npyx-1.2.tar.gz
Binary file added dist/npyx-1.21.tar.gz
15 changes: 7 additions & 8 deletions npyx.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
Metadata-Version: 2.1
Name: npyx
Version: 1.2
Version: 1.21
Summary: Python routines dealing with Neuropixels data.
Home-page: https://github.com/Npix-routines/NeuroPyxels
Author: Maxime Beau
@@ -19,7 +19,7 @@ Description: # routines
Python version has to be >=3.7 or there will be import issues!

Useful link to [create a python package from a git repository](https://towardsdatascience.com/build-your-first-open-source-python-project-53471c9942a7)
In particular, to upload

## Installation:
Using a conda environment is very much advised. Instructions here: [manage conda environments](https://docs.conda.io/projects/conda/en/latest/user-guide/tasks/manage-environments.html)
- from a local repository (recommended if plans to work on it/regularly pull upgrades)
@@ -75,22 +75,21 @@ Description: # routines
setup(name='npyx',
version='1.0',... # change to 1.1 or whatev
```
Then re-generate the distribution files for the new version using twine:
Then delete the old distribution files before re-generating them for the new version using twine:
```
rm -r ./dist
rm -r ./build
rm -r ./npyx.egg-info
python setup.py sdist bdist_wheel # this will generate version 1.1 wheel without overwriting version 1.0 wheel in ./dist
```
Before pushing them to PyPI (it will keep track of older versions!)
Before pushing them to PyPI (older versions are saved online!)
```
$ twine upload dist/*
Uploading distributions to https://upload.pypi.org/legacy/
Enter your username: your-username
Enter your password:
Uploading npyx-1.0-py3-none-any.whl
100%|█████████████████████████████████████████████████████████| 156k/156k [00:01<00:00, 154kB/s]
Uploading npyx-1.1-py3-none-any.whl
100%|████████████████████████████████████████████████████████| 156k/156k [00:01<00:00, 96.8kB/s]
Uploading npyx-1.0.tar.gz
100%|█████████████████████████████████████████████████████████| 149k/149k [00:00<00:00, 169kB/s]
Uploading npyx-1.1.tar.gz
100%|█████████████████████████████████████████████████████████| 150k/150k [00:01<00:00, 142kB/s]

Binary file modified npyx/__pycache__/circuitProphyler.cpython-37.pyc
Binary file modified npyx/__pycache__/corr.cpython-37.pyc
Binary file modified npyx/__pycache__/gl.cpython-37.pyc
Binary file modified npyx/__pycache__/plot.cpython-37.pyc
Binary file modified npyx/__pycache__/spk_wvf.cpython-37.pyc
Binary file modified npyx/__pycache__/utils.cpython-37.pyc
45 changes: 19 additions & 26 deletions npyx/circuitProphyler.py
@@ -84,7 +84,7 @@ class Prophyler:
If you want to access unit u:
>>> pro.units[u]
>>> pro.units[u].trn(): equivalent to rtn.npyx.spk_t.trn(dp,u)
>>> pro.units[u].trn(): equivalent to npyx.spk_t.trn(dp,u)
The units can also be accessed through the 'unit' attributes of the graph nodes:
>>> pro.graph.nodes[u]]['unit'].trn() returns the same thing as ds.units[u].trn()
@@ -107,8 +107,7 @@ class Prophyler:


def __init__(self, datapaths, sync_idx3A=2):

# Handle datapaths format
# Handle datapaths format
typ_e=TypeError('''
Datapath should be either a string to a kilosort path:
'path/to/kilosort/output1'
@@ -176,12 +175,13 @@ def __init__(self, datapaths, sync_idx3A=2):
cl_grp.insert(0, 'dataset_i', ds_i)
qualities=qualities.append(cl_grp, ignore_index=True)
qualities.set_index('dataset_i', inplace=True)
qualities_dp=Path(self.dp_pro, 'merged_cluster_group.tsv')#Path(self.dp_pro, 'merged_cluster_group.tsv')
qualities_dp=Path(self.dp_pro, 'merged_cluster_group.tsv')
if op.exists(qualities_dp):
qualities_old=pd.read_csv(qualities_dp, sep=' ', index_col='dataset_i')
# only consider re-spike sorted if cluster indices have been changed, do not if only qualities were changed (spike times are unimpacted by that)
if not np.all(np.isin(qualities_old.loc[:, 'cluster_id'], qualities.loc[:, 'cluster_id'])):
re_spksorted=True
print('New spike-sorting detected.')
qualities.to_csv(qualities_dp, sep=' ')

# If several datasets are fed to the prophyler, align their spike times.
@@ -191,11 +191,8 @@ def __init__(self, datapaths, sync_idx3A=2):
if (not op.exists(Path(self.dp_pro, merge_fname+'.npy'))) or re_spksorted:
print(">>> Loading spike trains of {} datasets...".format(len(self.ds_table.index)))
spike_times, spike_clusters, sync_signals = [], [], []
fileCreateTimes, fileTimeSecs = [], [] # used to assess if files were recorded on the same NI card
for ds_i in self.ds_table.index:
ds=self.ds[ds_i]
fileCreateTimes.append(ds.meta['fileCreateTime'])
fileTimeSecs.append(ds.meta['fileTimeSecs'])
ons, offs = get_npix_sync(ds.dp, output_binary = False, sourcefile='ap', unit='samples')
spike_times.append(np.load(Path(ds.dp, 'spike_times.npy')).flatten())
spike_clusters.append(np.load(Path(ds.dp, 'spike_clusters.npy')).flatten())
@@ -208,10 +205,7 @@ def __init__(self, datapaths, sync_idx3A=2):
for i in range(len(spike_times)): NspikesTotal+=len(spike_times[i])

print(">>> Aligning spike trains of {} datasets...".format(len(self.ds_table.index)))
if (all(e==fileCreateTimes[0] for e in fileCreateTimes) and all(e==fileTimeSecs[0] for e in fileTimeSecs)):
print('>>> All fed datasets were recorded on the same NI card - no alignment necessary, keep spike times as they are!')
else:
spike_times = align_timeseries(spike_times, sync_signals, 30000)
spike_times = align_timeseries(spike_times, sync_signals, 30000)
merged_clusters_spikes=npa(zeros=(NspikesTotal, 3), dtype=np.uint64) # 1:dataset index, 2:unit index
cum_Nspikes=0
for ds_i in self.ds_table.index:
@@ -240,7 +234,6 @@ def __init__(self, datapaths, sync_idx3A=2):
self.units={}
for ds_i in self.ds_table.index:
ds=self.ds[ds_i]
assert any(ds.get_good_units()), f'Aborting - circuit prophyler can only work on a spike-sorted dataset, find good units in {ds.name} before calling it!'
ds.get_peak_positions()
for u, pos in ds.peak_positions.items():
self.peak_positions['{}_{}'.format(ds_i,u)]=pos+npa([100,0])*ds_i # Every dataset is offset by 100um on x
@@ -302,7 +295,7 @@ def connect_graph(self, corr_type='connections', metric='amp_z', cbin=0.5, cwin=
g=self.map_sfc_on_g(g, sfc, criteria)
self.make_directed_graph()
if plotsfcm:
rtn.npyx.plot.plot_sfcm(self.dp_pro, corr_type, metric,
npyx.plot.plot_sfcm(self.dp_pro, corr_type, metric,
cbin, cwin, p_th, n_consec_bins, fract_baseline, W_sd, test,
depth_ticks=True, regions={}, reg_colors={}, again=again, againCCG=againCCG, drop_seq=drop_seq)

@@ -651,7 +644,7 @@ def label_edges(self, prophylerGraph='undigraph', src_graph=None):
ea=self.get_edge_attributes(edge, prophylerGraph=prophylerGraph, src_graph=src_graph) # u1, u2, i unpacked

##TODO - plt ccg from shared directory
rtn.npyx.plot.plot_ccg(self.dp, [ea['uSrc'],ea['uTrg']], ea['criteria']['cbin'], ea['criteria']['cwin'])
npyx.plot.plot_ccg(self.dp, [ea['uSrc'],ea['uTrg']], ea['criteria']['cbin'], ea['criteria']['cwin'])

label=''
while label=='': # if enter is hit
@@ -1123,13 +1116,13 @@ def __init__(self, datapath, probe_name='prb1', dataset_index=0, dataset_name=No
raise "Local channel map comprises channels not found in expected channels given matafile probe type."

def get_units(self):
return rtn.npyx.gl.get_units(self.dp)
return npyx.gl.get_units(self.dp)

def get_good_units(self):
return rtn.npyx.gl.get_units(self.dp, quality='good')
return npyx.gl.get_units(self.dp, quality='good')

def get_peak_channels(self):
self.peak_channels = get_depthSort_peakChans(self.dp, quality='good')# {mainChans[i,0]:mainChans[i,1] for i in range(mainChans.shape[0])}
self.peak_channels = get_depthSort_peakChans(self.dp, use_template=True)# {mainChans[i,0]:mainChans[i,1] for i in range(mainChans.shape[0])}
return self.peak_channels

def get_peak_positions(self):
@@ -1188,36 +1181,36 @@ def get_peak_position(self):
self.peak_position_real=self.ds.peak_positions_real[self.idx==self.ds.peak_positions_real[:,0], 1:].flatten()

def trn(self, rec_section='all'):
return rtn.npyx.spk_t.trn(self.dp, self.idx, rec_section=rec_section)
return npyx.spk_t.trn(self.dp, self.idx, rec_section=rec_section)

def trnb(self, bin_size, rec_section='all'):
return rtn.npyx.spk_t.trnb(self.dp, self.idx, bin_size, rec_section=rec_section)
return npyx.spk_t.trnb(self.dp, self.idx, bin_size, rec_section=rec_section)

def ids(self):
return rtn.npyx.spk_t.ids(self.dp, self.idx)
return npyx.spk_t.ids(self.dp, self.idx)

def isi(self, rec_section='all'):
return rtn.npyx.spk_t.isi(self.dp, self.idx, rec_section=rec_section)
return npyx.spk_t.isi(self.dp, self.idx, rec_section=rec_section)

def acg(self, cbin, cwin, normalize='Hertz', rec_section='all'):
return rtn.npyx.corr.acg(self.dp, self.idx, bin_size=cbin, win_size=cwin, normalize=normalize, rec_section=rec_section)
return npyx.corr.acg(self.dp, self.idx, bin_size=cbin, win_size=cwin, normalize=normalize, rec_section=rec_section)

def ccg(self, U, cbin, cwin, fs=30000, normalize='Hertz', ret=True, sav=True, prnt=True, rec_section='all', again=False):
return rtn.npyx.corr.ccg(self.dp, [self.idx]+list(U), cbin, cwin, fs, normalize, ret, sav, prnt, rec_section, again)
return npyx.corr.ccg(self.dp, [self.idx]+list(U), cbin, cwin, fs, normalize, ret, sav, prnt, rec_section, again)

def wvf(self, n_waveforms=100, t_waveforms=82, wvf_subset_selection='regular', wvf_batch_size=10):
return rtn.npyx.spk_wvf.wvf(self.dp, self.idx, n_waveforms, t_waveforms, wvf_subset_selection, wvf_batch_size, True, True)
return npyx.spk_wvf.wvf(self.dp, self.idx, n_waveforms, t_waveforms, wvf_subset_selection, wvf_batch_size, True, True)

def plot_acg(self, cbin=0.2, cwin=80, normalize='Hertz', color=0, saveDir='~/Downloads', saveFig=True, prnt=False, show=True,
pdf=True, png=False, rec_section='all', labels=True, title=None, ref_per=True, saveData=False, ylim=0):

rtn.npyx.plot.plot_acg(self.dp, self.idx, cbin, cwin, normalize, color, saveDir, saveFig, prnt, show,
npyx.plot.plot_acg(self.dp, self.idx, cbin, cwin, normalize, color, saveDir, saveFig, prnt, show,
pdf, png, rec_section, labels, title, ref_per, saveData, ylim)

def plot_ccg(self, units, cbin=0.2, cwin=80, normalize='Hertz', saveDir='~/Downloads', saveFig=False, prnt=False, show=True,
pdf=False, png=False, rec_section='all', labels=True, std_lines=True, title=None, color=-1, CCG=None, saveData=False, ylim=0):

rtn.npyx.plot.plot_ccg(self.dp, [self.idx]+list(units), cbin, cwin, normalize, saveDir, saveFig, prnt, show,
npyx.plot.plot_ccg(self.dp, [self.idx]+list(units), cbin, cwin, normalize, saveDir, saveFig, prnt, show,
pdf, png, rec_section, labels, std_lines, title, color, CCG, saveData, ylim)

def connections(self):
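Most edits in this file swap the stale `rtn.npyx.*` namespace for direct `npyx.*` calls and simplify dataset merging in `__init__`. For orientation, a minimal usage sketch inferred from the class docstring above (the dataset path is a placeholder, and the `plotsfcm` keyword is assumed from its use inside `connect_graph`):

```python
from npyx.circuitProphyler import Prophyler

pro = Prophyler('path/to/kilosort/output1')   # a list of paths would trigger dataset merging

u = list(pro.units)[0]                        # any unit handled by the prophyler
train = pro.units[u].trn()                    # equivalent to npyx.spk_t.trn(dp, u)
acg = pro.units[u].acg(cbin=0.2, cwin=80)     # autocorrelogram for that unit

# build the significant functional-connectivity graph and plot its matrix
pro.connect_graph(corr_type='connections', metric='amp_z', plotsfcm=True)
```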
4 changes: 2 additions & 2 deletions npyx/corr.py
@@ -1365,7 +1365,7 @@ def ccg_sig_stack(dp, U_src, U_trg, cbin=0.5, cwin=100, name=None,
return sigstack, sigustack

def gen_sfc(dp, corr_type='connections', metric='amp_z', cbin=0.5, cwin=100, p_th=0.02, n_consec_bins=3, fract_baseline=4./5, W_sd=10, test='Poisson_Stark',
again=False, againCCG=False, drop_seq=['sign', 'time', 'max_amplitude'], units=None, name=None, cross_cont_proof=False):
again=False, againCCG=False, drop_seq=['sign', 'time', 'max_amplitude'], units=None, name=None, cross_cont_proof=False, use_template_for_peakchan=False):
'''
Function generating a functional correlation dataframe sfc (Nsig x 2+8 features) and matrix sfcm (Nunits x Nunits)
from a sorted Kilosort output at 'dp' containing 'N' good units
@@ -1440,7 +1440,7 @@ def gen_sfc(dp, corr_type='connections', metric='amp_z', cbin=0.5, cwin=100, p_t
if units is not None:
assert np.all(np.isin(units, get_units(dp))), 'Some of the provided units are not found in this dataset.'
assert name is not None, 'You MUST provide a custom name for the provided list of units to ensure that your results can be saved.'
peakChs = get_depthSort_peakChans(dp, units=units)
peakChs = get_depthSort_peakChans(dp, units=units, use_template=use_template_for_peakchan)
gu = peakChs[:,0]
else:
name='good-all_to_all'
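The new `use_template_for_peakchan` flag is forwarded to `get_depthSort_peakChans` when a custom unit list is supplied. A hedged call sketch (path, unit ids and name are placeholders):

```python
from npyx.corr import gen_sfc

dp = 'path/to/kilosort/output'
sfc, sfcm, peakChs = gen_sfc(dp, corr_type='connections', metric='amp_z',
                             cbin=0.5, cwin=100,
                             units=[12, 45], name='pair_12-45',
                             use_template_for_peakchan=True)
```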
20 changes: 3 additions & 17 deletions npyx/gl.py
@@ -74,30 +74,16 @@ def get_units(dp, quality='all', chan_range=None, again=False):
for ds_i in cl_grp.index.unique():
# np.all(cl_grp.loc[ds_i, 'group'][cl_grp.loc[ds_i, 'cluster_id']==u]==quality)
units += ['{}_{}'.format(ds_i, u) for u in cl_grp.loc[(cl_grp['group']==quality)&(cl_grp.index==ds_i), 'cluster_id']]

else:
try:
np.all(np.isnan(cl_grp['group'])) # Units have not been given a class yet
units=[]
except:
if quality=='all':
units = cl_grp.loc[:, 'cluster_id'].values.astype(np.int64)
if 'unsorted' not in cl_grp['group'].unique():
units1 = cl_grp.loc[:, 'cluster_id'].astype(np.int64)
units=np.unique(np.load(Path(dp,"spike_clusters.npy")))
unsort_u=units[~np.isin(units, units1)]
unsort_df=pd.DataFrame({'cluster_id':unsort_u, 'group':['unsorted']*len(unsort_u)})
cl_grp=cl_grp.append(unsort_df, ignore_index=True)
cl_grp.to_csv(Path(dp, 'cluster_group.tsv'), sep=' ', index=False)
else:
raise ValueError(f'you cannot try to load {quality} units before manually curating a dataset - run phy once and try again.')
units=cl_grp.loc[cl_grp['group']==quality,'cluster_id'].values if quality!='all' else cl_grp['cluster_id'].values

if chan_range is None:
return units

assert len(chan_range)==2, 'chan_range should be a list or array with 2 elements!'

peak_channels=get_depthSort_peakChans(dp, units=[], quality=quality)
# For regular datasets
peak_channels=get_depthSort_peakChans(dp, units=units, quality=quality)
chan_mask=(peak_channels[:,1]>=chan_range[0])&(peak_channels[:,1]<=chan_range[1])
units=peak_channels[chan_mask,0].flatten()

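With the simplified branch above, `get_units` reads curated units straight from cluster_group.tsv, and the `chan_range` filter now computes peak channels for the selected `units` instead of an empty list. A usage sketch with placeholder paths:

```python
from npyx.gl import get_units

dp = 'path/to/kilosort/output'
all_units   = get_units(dp)                                       # every cluster id
good_units  = get_units(dp, quality='good')                       # curated units only
local_units = get_units(dp, quality='good', chan_range=[0, 191])  # filter by peak channel
```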
8 changes: 4 additions & 4 deletions npyx/plot.py
@@ -1410,15 +1410,15 @@ def plot_acg(dp, unit, cbin=0.2, cwin=80, normalize='Hertz', color=0, saveDir='~

def plot_ccg(dp, units, cbin=0.2, cwin=80, normalize='Hertz', saveDir='~/Downloads', saveFig=False, prnt=False, show=True,
_format='pdf', subset_selection='all', labels=True, std_lines=True, title=None, color=-1, CCG=None, saveData=False,
ylim=[0,0], ccg_mn=None, ccg_std=None, again=False, trains=None, ccg_grid=False):
ylim=[0,0], ccg_mn=None, ccg_std=None, again=False, trains=None, ccg_grid=False, use_template=True):
assert type(units) in [list, np.ndarray]
units=list(units)
_, _idx=np.unique(units, return_index=True)
units=npa(units)[np.sort(_idx)].tolist()
assert normalize in ['Counts', 'Hertz', 'Pearson', 'zscore', 'mixte'],"WARNING ccg() 'normalize' argument should be a string in ['Counts', 'Hertz', 'Pearson', 'zscore', 'mixte']."#
if normalize=='mixte' and len(units)==2 and not ccg_grid: normalize='zscore'
saveDir=op.expanduser(saveDir)
bChs=get_depthSort_peakChans(dp, units=units)[:,1].flatten()
bChs=get_depthSort_peakChans(dp, units=units, use_template=use_template)[:,1].flatten()
ylim1, ylim2 = ylim[0], ylim[1]

if CCG is None:
@@ -1608,7 +1608,7 @@ def plot_sfcm(dp, corr_type='connections', metric='amp_z', cbin=0.5, cwin=100,
drop_seq=['sign', 'time', 'max_amplitude'], units=None, name=None,
text=False, markers=False, ticks=True, depth_ticks=False,
regions={}, reg_colors={}, vminmax=[-7,7], figsize=(7,7),
saveFig=False, saveDir=None, again=False, againCCG=False):
saveFig=False, saveDir=None, again=False, againCCG=False, use_template_for_peakchan=False):
'''
Visually represents the connectivity dataframe outputted by 'gen_sfc'.
Each line/row is a good unit.
Expand All @@ -1617,7 +1617,7 @@ def plot_sfcm(dp, corr_type='connections', metric='amp_z', cbin=0.5, cwin=100,
'''
sfc, sfcm, peakChs = gen_sfc(dp, corr_type, metric, cbin, cwin,
p_th, n_consec_bins, fract_baseline, W_sd, test,
again, againCCG, drop_seq, units, name)
again, againCCG, drop_seq, units, name, False, use_template_for_peakchan)
gu = peakChs[:,0]
ch = peakChs[:,1].astype(int)

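Both plotting helpers gain the template-based peak-channel option; `plot_sfcm` forwards `use_template_for_peakchan` to `gen_sfc`. A short call sketch (path and unit ids are placeholders):

```python
from npyx.plot import plot_ccg, plot_sfcm

dp = 'path/to/kilosort/output'
plot_ccg(dp, [12, 45], cbin=0.2, cwin=80, use_template=True)
plot_sfcm(dp, corr_type='connections', metric='amp_z',
          use_template_for_peakchan=False)   # default, matching gen_sfc
```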

0 comments on commit 907d117
