Skip to content

Commit

Permalink
Test template based approach
Browse files Browse the repository at this point in the history
import key functions and classes in artacs.__init__
  • Loading branch information
agricolab committed Sep 24, 2018
1 parent b05c740 commit 3ff58d7
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 12 deletions.
2 changes: 2 additions & 0 deletions artacs/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
from .kernel import create_kernel, apply_kernel, CombKernel
from .template import StepwiseRemover
5 changes: 3 additions & 2 deletions artacs/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def filter_1d(indata, fs:int, freq:int, kernel:ndarray):
return filtered

# %%
def filter_2d(indata:ndarray, fs:int, freq:int, kernel:ndarray):
def apply_kernel(indata:ndarray, fs:int, freq:int, kernel:ndarray):
''' filter a two-dimensional dataset with a predefined kernel
args
Expand Down Expand Up @@ -277,6 +277,7 @@ def filter_2d(indata:ndarray, fs:int, freq:int, kernel:ndarray):
filtered[idx,:] = filter_1d(chandata, fs, freq, kernel)

return filtered

#%%
class CombKernel():

Expand All @@ -298,7 +299,7 @@ def _update_kernel(self):
self._right_mode)

def __call__(self, indata:np.array) -> ndarray:
return filter_2d(indata=indata, freq=self._freq, fs=self._fs,
return apply_kernel(indata=indata, freq=self._freq, fs=self._fs,
kernel=self._kernel)

def __repr__(self):
Expand Down
31 changes: 23 additions & 8 deletions artacs/template.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,14 @@
"""
import numpy as np
import artacs.tools as tools
import logging
logger = logging.Logger(__name__)
# %%
class StepwiseRemover():

def __init__(self, fs=1000, freq=None, period_steps=2,
epsilon=0.01, max_iterations=10):
epsilon=0.01, max_iterations=10, verbose=True):
self.verbose = verbose
self.true_fs = fs
self.freq = freq
if freq is not None:
Expand All @@ -35,15 +38,15 @@ def inbound_resample(self, indata):
'resample so that (artifact_period* artifact_frequency) is an integer'
if self.resample_flag:
period = int(np.ceil(self.true_period))
fs = int(period * self.freq)
fs = int(np.ceil(period * self.freq))
data = tools.resample_by_fs(indata,
up=fs,
down=self.true_fs,
axis=0)
self.sample_count = indata.shape[0]
else:
data = indata
fs = self.fs
fs = self.true_fs
period = int(self.true_period)
return data, period, fs

Expand All @@ -62,21 +65,33 @@ def prepare_data(self, indata):
seeds = self.calc_seeds(period)
return data, period, fs, seeds

def __call__(self, indata):
return self.process(indata)

def process(self, indata):
'process all channels of a dataset'
if self.true_period is None:
print('Invalid period length, skipping artifact removal')
return indata

num_channels, num_samples = indata.shape
if len(indata.shape) == 1:
num_channels, num_samples = 1, indata.shape[0]
indata = np.atleast_2d(indata)
elif len(indata.shape) == 2:
num_channels, num_samples = indata.shape
else:
raise ValueError('Unspecified dimensionality of the dataset')
outdata = np.empty((indata.shape))
outdata.fill(np.nan)
print('[',end='')
if self.verbose:
print('[',end='')
for chan_idx, chan_data in enumerate(indata):
outdata[chan_idx,:] = self.process_channel(chan_data)
print('.',end='')
print(']',end='\n')
return outdata
if self.verbose:
print('.',end='')
if self.verbose:
print(']',end='\n')
return np.squeeze(outdata)

def process_channel(self, indata):
'process a single channels of data'
Expand Down
30 changes: 28 additions & 2 deletions test/test.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@
'''
from artacs.kernel import create_kernel, _estimate_prms_from_kernel, filter_1d, filter_2d
from artacs.kernel import create_kernel, _estimate_prms_from_kernel, filter_1d, apply_kernel
from artacs.kernel import CombKernel
from artacs.template import StepwiseRemover
import numpy as np
#%%
def test_kernel():
Expand Down Expand Up @@ -93,6 +94,8 @@ def test_kernel():
assert np.all(np.isclose(filtered[period*2:], 0, atol=1e-04))

# test multi-channel filter
# we only check for everything after the first period to account for the
# settle-in duration of the one-step comb filter
fs = 1000
freq = 10
period = int(np.ceil(fs/freq))
Expand All @@ -102,7 +105,7 @@ def test_kernel():
data = np.vstack((data, data))
kernel = create_kernel(freq, fs, 1,
left_mode='uniform', right_mode ='none')
filtered = filter_2d(data, fs, freq, kernel)
filtered = apply_kernel(data, fs, freq, kernel)
for chan in filtered:
assert np.all(np.isclose(chan[period:], 0, 1e-10))

Expand All @@ -118,6 +121,29 @@ def test_kernel():
assert np.all(np.isclose(chan[period:], 0, 1e-10))


s = StepwiseRemover(fs=fs, freq=freq)
duration_in_s = 2
t = np.linspace(1/fs, duration_in_s, num=fs*duration_in_s)
data = np.sin(2*np.pi*freq*t)
data = np.vstack((data, data))
filtered = s(data)
for chan in filtered:
assert np.all(np.isclose(chan[:], 0, 1e-10))

fs = 5000
duration_in_s = 2
for freq in range(10, 21, 1):
s = StepwiseRemover(fs=fs, freq=freq, period_steps=10)
t = np.linspace(1/fs, duration_in_s, num=fs*duration_in_s)
nse = np.random.randn(t.shape[0])
data = 1000*np.sin(2*np.pi*freq*t) + nse
filtered = s(data)
print('{:3.0f} '.format(freq),end='')
r2 = np.square(np.corrcoef(filtered, nse)[0,1])
print(np.max(np.abs(filtered)), r2)
assert np.all(r2>0.7071)


print('Test successful')

# %%
Expand Down

0 comments on commit 3ff58d7

Please sign in to comment.