Skip to content

Commit

Permalink
update docs
Browse files Browse the repository at this point in the history
  • Loading branch information
agricolab committed Mar 20, 2022
1 parent f06d488 commit 9fbb4fa
Show file tree
Hide file tree
Showing 81 changed files with 500 additions and 551 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,6 @@ data/*
*.pyc
.spyproject/
*.egg-info
doc/build/*
.vscode
doc/source/_autosummary/*
272 changes: 159 additions & 113 deletions artacs/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
Conf Proc IEEE Eng Med Biol Soc 2015, 3436–3439.
https://doi.org/10.1109/EMBC.2015.7319131
[2]: Guggenberger, R., & Gharabaghi, A. (2021). Comb filters for the removal of transcranial current stimulation artifacts from single channel EEG recordings. Current Directions in Biomedical Engineering, 7(2), 383–386. https://doi.org/10.1515/cdbme-2021-2097
Application
-----------
Expand All @@ -51,9 +53,7 @@
Object-oriented implementation
------------------------------
There exists also an object-oriented implementation as follows::
kernel =
There exists also an object-oriented implementation (see :class:`~.CombKernel`)
"""
Expand All @@ -63,9 +63,10 @@
from warnings import warn
from scipy import signal
from artacs.tools import resample_by_fs, resample_by_count

# %%
def _estimate_prms_from_kernel(kernel:ndarray) -> Tuple[int, int]:
'''estimate period and width from kernel
def _estimate_prms_from_kernel(kernel: ndarray) -> Tuple[int, int]:
"""estimate period and width from kernel
args
----
Expand All @@ -79,64 +80,77 @@ def _estimate_prms_from_kernel(kernel:ndarray) -> Tuple[int, int]:
width:int
number of weight values in each direction
'''


period = np.unique(np.diff(np.where(kernel!=0)[0]))
"""

period = np.unique(np.diff(np.where(kernel != 0)[0]))
if len(period) != 1:
raise ValueError('Multiple or no periods recognized in kernel' +
'Was the kernel correctly constructed?')

raise ValueError(
"Multiple or no periods recognized in kernel"
+ "Was the kernel correctly constructed?"
)

period = int(period)
left_width = (kernel[:kernel.shape[0]//2]<0).sum()
right_width = (kernel[kernel.shape[0]//2:]<0).sum()
left_width = (kernel[: kernel.shape[0] // 2] < 0).sum()
right_width = (kernel[kernel.shape[0] // 2 :] < 0).sum()

if left_width == 0:
width = right_width
elif right_width == 0:
width = left_width
elif right_width == left_width:
width = left_width # both need to be equal, so it doesn't matter
width = left_width # both need to be equal, so it doesn't matter
else:
raise ValueError('Kernel presents with unclear direction.' +
'Was the kernel correctly constructed?')

raise ValueError(
"Kernel presents with unclear direction."
+ "Was the kernel correctly constructed?"
)

return period, width


# %%
def _weigh_exp(width:int) -> ndarray:
'create exponential weights'
weights = signal.exponential(((width-1)*2)+1)[:width]
weights /= (weights.sum())
def _weigh_exp(width: int) -> ndarray:
"create exponential weights"
weights = signal.exponential(((width - 1) * 2) + 1)[:width]
weights /= weights.sum()
return weights

def _weigh_linear(width:int) -> ndarray:
'create linear weights'
weights = np.linspace(1/width, 1, num=width)
weights /= (weights.sum())

def _weigh_linear(width: int) -> ndarray:
"create linear weights"
weights = np.linspace(1 / width, 1, num=width)
weights /= weights.sum()
return weights

def _weigh_gaussian(width:int, sigma:float=1) -> ndarray:
'create gaussian weights'
weights = signal.gaussian(width*2, sigma)[0:width]
weights /= ( weights.sum())

def _weigh_gaussian(width: int, sigma: float = 1) -> ndarray:
"create gaussian weights"
weights = signal.gaussian(width * 2, sigma)[0:width]
weights /= weights.sum()
return weights

def _weigh_uniform(width:int) -> ndarray:
'create uniform weights'

def _weigh_uniform(width: int) -> ndarray:
"create uniform weights"
weights = np.ones(width)
weights /= ( weights.sum())
weights /= weights.sum()
return weights

def _weigh_not(width:int) -> ndarray:
'create zero weights'

def _weigh_not(width: int) -> ndarray:
"create zero weights"
weights = np.zeros(width)
return weights

def create_kernel(freq:int, fs:int, width:int,
left_mode:str='uniform',
right_mode:str='uniform') -> ndarray:
'''create kernel from parameters

def create_kernel(
freq: int,
fs: int,
width: int,
left_mode: str = "uniform",
right_mode: str = "uniform",
) -> ndarray:
"""create kernel from parameters
args
----
Expand All @@ -160,39 +174,43 @@ def create_kernel(freq:int, fs:int, width:int,
:func:`~.filter_1d`
'''
in_period = fs/freq
period = int(np.ceil(in_period))
"""
in_period = fs / freq
period = int(np.ceil(in_period))
if in_period != period:
warn ('Only integer periods are natively supported.' +
'Will auto-resample to higher sampling rate')
fs = int(period * freq)

weighfoos = {'uniform':_weigh_uniform,
'uni':_weigh_uniform,
'none':_weigh_not,
'zero':_weigh_not,
'gauss':_weigh_gaussian,
'normal':_weigh_gaussian,
'linear':_weigh_linear,
'exp':_weigh_exp,
'exponential':_weigh_exp
}

warn(
"Only integer periods are natively supported."
+ "Will auto-resample to higher sampling rate"
)
fs = int(period * freq)

weighfoos = {
"uniform": _weigh_uniform,
"uni": _weigh_uniform,
"none": _weigh_not,
"zero": _weigh_not,
"gauss": _weigh_gaussian,
"normal": _weigh_gaussian,
"linear": _weigh_linear,
"exp": _weigh_exp,
"exponential": _weigh_exp,
}

left_weights = weighfoos[left_mode.lower()](width)
right_weights = weighfoos[right_mode.lower()](width)[::-1]
norm = left_weights.sum() + right_weights.sum()
weights = np.hstack((-left_weights/norm, 1.0, -right_weights/norm))
midpoint = period*width
kernel = np.zeros((midpoint*2)+1)
kernel[::period] = weights

weights = np.hstack((-left_weights / norm, 1.0, -right_weights / norm))

midpoint = period * width
kernel = np.zeros((midpoint * 2) + 1)
kernel[::period] = weights
return kernel


# %%
def filter_1d(indata, fs:int, freq:int, kernel:ndarray):
''' filter a one-dimensional dataset with a predefined kernel
def filter_1d(indata, fs: int, freq: int, kernel: ndarray):
""" filter a one-dimensional dataset with a predefined kernel
args
----
Expand All @@ -214,44 +232,46 @@ def filter_1d(indata, fs:int, freq:int, kernel:ndarray):
.. seealso::
:func:`~.filter_2d`
'''
"""
in_samples = indata.shape[0]
in_period = fs/freq
#if sampling rate of signal and artifact are not integer divisible,
in_period = fs / freq

# if sampling rate of signal and artifact are not integer divisible,
# we have to resample the data
resample_flag = ( in_period != int(np.ceil(in_period)) )
resample_flag = in_period != int(np.ceil(in_period))
if resample_flag:
old_fs = fs
period = int(np.ceil(in_period))
fs = int(period * freq)
data = resample_by_fs(indata, up=fs, down=old_fs)
fs = int(period * freq)
data = resample_by_fs(indata, up=fs, down=old_fs)
else:
data = indata
period = int(in_period)


# if the kernel period is not matching the artifact period,
period = int(in_period)

# if the kernel period is not matching the artifact period,
# filtering would be off
kperiod, kwidth = _estimate_prms_from_kernel(kernel)
kperiod, kwidth = _estimate_prms_from_kernel(kernel)
if kperiod != period:
raise ValueError('Kernel is not matching artifact frequency. ' +
'Was the kernel correctly constructed?')
raise ValueError(
"Kernel is not matching artifact frequency. "
+ "Was the kernel correctly constructed?"
)

#-------------------------------------------------------------------------
fdata = np.convolve(data, kernel[::-1], 'same')
#-------------------------------------------------------------------------
# -------------------------------------------------------------------------
fdata = np.convolve(data, kernel[::-1], "same")
# -------------------------------------------------------------------------
if resample_flag:
filtered = resample_by_count(fdata, in_samples)
filtered = np.asanyarray(filtered)
filtered = np.asanyarray(filtered)
else:
filtered = fdata
filtered = fdata

return filtered



# %%
def apply_kernel(indata:ndarray, fs:int, freq:int, kernel:ndarray):
''' filter a two-dimensional dataset with a predefined kernel
def apply_kernel(indata: ndarray, fs: int, freq: int, kernel: ndarray):
""" filter a two-dimensional dataset with a predefined kernel
args
----
Expand All @@ -273,28 +293,35 @@ def apply_kernel(indata:ndarray, fs:int, freq:int, kernel:ndarray):
.. seealso::
:func:`~.filter_1d`
'''
"""
filtered = np.zeros(indata.shape)
for idx, chandata in enumerate(indata):
filtered[idx,:] = filter_1d(chandata, fs, freq, kernel)
for idx, chandata in enumerate(indata):
filtered[idx, :] = filter_1d(chandata, fs, freq, kernel)

return filtered


#%%
class CombKernel():
'''Object-oriented comb kernel filter
class CombKernel:
"""Object-oriented comb kernel filter
Example to create and apply classical comb kernel::
kernel = CombKernel(freq=20, fs=1000, width=1,
left_mode='uniform', right_mode ='none')
kernel.apply(artifacted_signal)
'''
def __init__(self, freq:int, fs:int, width:int,
left_mode:str='uniform',
right_mode:str='uniform') -> None:

"""

def __init__(
self,
freq: int,
fs: int,
width: int,
left_mode: str = "uniform",
right_mode: str = "uniform",
) -> None:

self._freq = freq
self._fs = fs
self._width = width
Expand All @@ -303,19 +330,38 @@ def __init__(self, freq:int, fs:int, width:int,
self._update_kernel()

def _update_kernel(self):
self._kernel = create_kernel(self._freq, self._fs,
self._width,
self._left_mode,
self._right_mode)

def apply(self, indata:ndarray):
return apply_kernel(indata=indata, freq=self._freq, fs=self._fs,
kernel=self._kernel)
self._kernel = create_kernel(
self._freq,
self._fs,
self._width,
self._left_mode,
self._right_mode,
)

def apply(self, indata: ndarray):
""" apply the kernel to a two-dimensional signal
args
----
indata:ndarray
two-dimensional artifacted signal, dimensions are channel x samples
def __call__(self, indata:ndarray) -> ndarray:
returns
-------
filtered:ndarray
two-dimensional signal with artifact removed
"""
return apply_kernel(
indata=indata, freq=self._freq, fs=self._fs, kernel=self._kernel
)

def __call__(self, indata: ndarray) -> ndarray:
return self.apply(indata)

def __repr__(self):
return (f'KernelFilter({self._freq}, {self._fs}, {self._width}, ' +
f"'{self._left_mode}', '{self._right_mode}')")

return (
f"KernelFilter({self._freq}, {self._fs}, {self._width}, "
+ f"'{self._left_mode}', '{self._right_mode}')"
)

Loading

0 comments on commit 9fbb4fa

Please sign in to comment.