Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 30 additions & 18 deletions fastdtw/fastdtw.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
pass


def fastdtw(x, y, radius=1, dist=None):
def fastdtw(x, y, radius=1, dist=None, dist_only=False):
''' return the approximate distance between 2 time series with O(N)
time and memory complexity

Expand All @@ -32,13 +32,17 @@ def fastdtw(x, y, radius=1, dist=None):
dist is an int of value p > 0, then the p-norm will be used. If
dist is a function then dist(x[i], y[j]) will be used. If dist is
None then abs(x[i] - y[j]) will be used.
dist_only : bool
If you are only interested in the distance metric for DTW, set the
value too True and it will not return a path

Returns
-------
distance : float
the approximate distance between the 2 time series
path : list
list of indexes for the inputs x and y
path : list or None
list of indexes for the inputs x and y
(If dist_only is set to True then the path will be None)

Examples
--------
Expand All @@ -50,7 +54,7 @@ def fastdtw(x, y, radius=1, dist=None):
(2.0, [(0, 0), (1, 0), (2, 1), (3, 2), (4, 2)])
'''
x, y, dist = __prep_inputs(x, y, dist)
return __fastdtw(x, y, radius, dist)
return __fastdtw(x, y, radius, dist, dist_only)


def __difference(a, b):
Expand All @@ -61,18 +65,18 @@ def __norm(p):
return lambda a, b: np.linalg.norm(np.atleast_1d(a) - np.atleast_1d(b), p)


def __fastdtw(x, y, radius, dist):
def __fastdtw(x, y, radius, dist, dist_only):
min_time_size = radius + 2

if len(x) < min_time_size or len(y) < min_time_size:
return dtw(x, y, dist=dist)
return dtw(x, y, dist=dist, dist_only=dist_only)

x_shrinked = __reduce_by_half(x)
y_shrinked = __reduce_by_half(y)
distance, path = \
__fastdtw(x_shrinked, y_shrinked, radius=radius, dist=dist)
window = __expand_window(path, len(x), len(y), radius)
return __dtw(x, y, window, dist=dist)
return __dtw(x, y, window, dist=dist, dist_only=dist_only)


def __prep_inputs(x, y, dist):
Expand All @@ -95,7 +99,7 @@ def __prep_inputs(x, y, dist):
return x, y, dist


def dtw(x, y, dist=None):
def dtw(x, y, dist=None, dist_only=False):
''' return the distance between 2 time series without approximation

Parameters
Expand All @@ -109,13 +113,17 @@ def dtw(x, y, dist=None):
dist is an int of value p > 0, then the p-norm will be used. If
dist is a function then dist(x[i], y[j]) will be used. If dist is
None then abs(x[i] - y[j]) will be used.
dist_only : bool
If you are only interested in the distance metric for DTW, set the
value too True and it will not return a path

Returns
-------
distance : float
the approximate distance between the 2 time series
path : list
list of indexes for the inputs x and y
path : list or None
list of indexes for the inputs x and y
(If dist_only is set to True then the path will be None)

Examples
--------
Expand All @@ -127,10 +135,10 @@ def dtw(x, y, dist=None):
(2.0, [(0, 0), (1, 0), (2, 1), (3, 2), (4, 2)])
'''
x, y, dist = __prep_inputs(x, y, dist)
return __dtw(x, y, None, dist)
return __dtw(x, y, None, dist, dist_only)


def __dtw(x, y, window, dist):
def __dtw(x, y, window, dist, dist_only):
len_x, len_y = len(x), len(y)
if window is None:
window = [(i, j) for i in range(len_x) for j in range(len_y)]
Expand All @@ -141,12 +149,16 @@ def __dtw(x, y, window, dist):
dt = dist(x[i-1], y[j-1])
D[i, j] = min((D[i-1, j][0]+dt, i-1, j), (D[i, j-1][0]+dt, i, j-1),
(D[i-1, j-1][0]+dt, i-1, j-1), key=lambda a: a[0])
path = []
i, j = len_x, len_y
while not (i == j == 0):
path.append((i-1, j-1))
i, j = D[i, j][1], D[i, j][2]
path.reverse()

path = None
if not dist_only:
path = []
i, j = len_x, len_y
while not (i == j == 0):
path.append((i-1, j-1))
i, j = D[i, j][1], D[i, j][2]
path.reverse()

return (D[len_x, len_y][0], path)


Expand Down
16 changes: 13 additions & 3 deletions tests/test_fastdtw.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@

import numpy as np

from fastdtw._fastdtw import fastdtw as fastdtw_c
from fastdtw._fastdtw import dtw as dtw_c
# from fastdtw._fastdtw import fastdtw as fastdtw_c
# from fastdtw._fastdtw import dtw as dtw_c
from fastdtw.fastdtw import fastdtw as fastdtw_p
from fastdtw.fastdtw import dtw as dtw_p

Expand All @@ -29,7 +29,7 @@ def test_1d_fastdtw(self):

def test_1d_dtw(self):
distance_c = dtw_c(self.x_1d, self.y_1d)[0]
distance_p = dtw_p(self.x_1d, self.y_1d)[0]
c
self.assertEqual(distance_c, 2)
self.assertEqual(distance_c, distance_p)

Expand All @@ -54,6 +54,16 @@ def test_default_dist(self):
self.assertEqual(d1, d2)
self.assertEqual(d1, d3)
self.assertEqual(d1, d4)

def test_fastdtw_dist_only(self):
d1, p1 = fastdtw_p([[1,2]], [[2,2],[1,1]], dist_only=True)
d2, p2 = dtw_p([[1,2]], [[2,2],[1,1]], dist_only=True)
# Check paths were not generated
self.assertEqual(p1, None)
self.assertEqual(p2, None)
# Check distances
self.assertEqual(d1, 2)
self.assertEqual(d2, 2)

if __name__ == '__main__':
unittest.main()