Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions neo/core/spiketrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def _check_time_in_range(value, t_start, t_stop, view=False):
def _check_waveform_dimensions(spiketrain):
'''
Verify that waveform is compliant with the waveform definition as
quantity array 3D (spike, channel_index, time)
quantity array 3D (time, spike, channel_index)
'''

if not spiketrain.size:
Expand All @@ -87,10 +87,10 @@ def _check_waveform_dimensions(spiketrain):
if (waveforms is None) or (not waveforms.size):
return

if waveforms.shape[0] != len(spiketrain):
if waveforms.shape[1] != len(spiketrain):
raise ValueError("Spiketrain length (%s) does not match to number of "
"waveforms present (%s)" % (len(spiketrain),
waveforms.shape[0]))
waveforms.shape[1]))


def _new_spiketrain(cls, signal, t_stop, units=None, dtype=None,
Expand Down Expand Up @@ -161,7 +161,7 @@ class SpikeTrain(BaseNeo, pq.Quantity):
:class:`SpikeTrain` began. This will be converted to the
same units as :attr:`times`.
Default: 0.0 seconds.
:waveforms: (quantity array 3D (spike, channel_index, time))
:waveforms: (quantity array 3D (time, spike, channel_index))
The waveforms of each spike.
:sampling_rate: (quantity scalar) Number of samples per unit time
for the waveforms.
Expand All @@ -184,7 +184,7 @@ class SpikeTrain(BaseNeo, pq.Quantity):
read-only.
(:attr:`t_stop` - :attr:`t_start`)
:spike_duration: (quantity scalar) Duration of a waveform, read-only.
(:attr:`waveform`.shape[2] * :attr:`sampling_period`)
(:attr:`waveform`.shape[0] * :attr:`sampling_period`)
:right_sweep: (quantity scalar) Time from the trigger times of the
spikes to the end of the waveforms, read-only.
(:attr:`left_sweep` + :attr:`spike_duration`)
Expand Down Expand Up @@ -219,9 +219,7 @@ def __new__(cls, times, t_stop, units=None, dtype=None, copy=True,
This is called whenever a new :class:`SpikeTrain` is created from the
constructor, but not when slicing.
'''
if len(times) != 0 and waveforms is not None and len(times) != \
waveforms.shape[0]:
# len(times)!=0 has been used to workaround a bug occuring during neo import
if len(times) != 0 and waveforms is not None and len(times) != waveforms.shape[1]:
raise ValueError(
"the number of waveforms should be equal to the number of spikes")

Expand Down Expand Up @@ -435,7 +433,7 @@ def sort(self):
# sort the waveforms by the times
sort_indices = np.argsort(self)
if self.waveforms is not None and self.waveforms.any():
self.waveforms = self.waveforms[sort_indices]
self.waveforms = self.waveforms[:, sort_indices, :]

# now sort the times
# We have sorted twice, but `self = self[sort_indices]` introduces
Expand Down Expand Up @@ -492,7 +490,7 @@ def __getitem__(self, i):
'''
obj = super(SpikeTrain, self).__getitem__(i)
if hasattr(obj, 'waveforms') and obj.waveforms is not None:
obj.waveforms = obj.waveforms.__getitem__(i)
obj.waveforms = obj.waveforms.__getitem__([slice(None), i, slice(None)])
return obj

def __setitem__(self, i, value):
Expand Down Expand Up @@ -570,7 +568,7 @@ def time_slice(self, t_start, t_stop):
new_st.t_start = max(_t_start, self.t_start)
new_st.t_stop = min(_t_stop, self.t_stop)
if self.waveforms is not None:
new_st.waveforms = self.waveforms[indices]
new_st.waveforms = self.waveforms[:, indices, :]

return new_st

Expand Down Expand Up @@ -627,8 +625,8 @@ def merge(self, other):
sampling_rate=self.sampling_rate,
left_sweep=self.left_sweep, **kwargs)
if all(wfs):
wfs_stack = np.vstack((self.waveforms, other.waveforms))
wfs_stack = wfs_stack[sorting]
wfs_stack = np.concatenate((self.waveforms, other.waveforms), axis=1)
wfs_stack = wfs_stack[:, sorting, :]
train.waveforms = wfs_stack
train.segment = self.segment
if train.segment is not None:
Expand Down Expand Up @@ -661,11 +659,11 @@ def spike_duration(self):
'''
Duration of a waveform.

(:attr:`waveform`.shape[2] * :attr:`sampling_period`)
(:attr:`waveform`.shape[0] * :attr:`sampling_period`)
'''
if self.waveforms is None or self.sampling_rate is None:
return None
return self.waveforms.shape[2] / self.sampling_rate
return self.waveforms.shape[0] / self.sampling_rate

@property
def sampling_period(self):
Expand Down
1 change: 1 addition & 0 deletions neo/io/brainwaresrcio.py
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,7 @@ def _combine_spiketrains(self, spiketrains):
# get the maximum time
t_stop = times[-1] * 2.

waveforms = np.moveaxis(waveforms, 2, 0)
waveforms = pq.Quantity(waveforms, units=pq.mV, copy=False)

train = SpikeTrain(times=times, copy=False,
Expand Down
Loading