-
Notifications
You must be signed in to change notification settings - Fork 0
/
levenshtein_distance_with_trie.py
127 lines (102 loc) · 3.85 KB
/
levenshtein_distance_with_trie.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/python
#By Steve Hanov, 2011. Released to the public domain
import time
import sys
from collections import defaultdict
class LevenshteinTrie:
    """Weighted Levenshtein search over a fixed set of words stored in a trie.

    Every trie node extends the dynamic-programming table by one row, so
    words sharing a prefix share the rows computed for that prefix. A branch
    is abandoned as soon as every entry of its row exceeds the cost limit.

    Per-character weights are passed as keyword arguments mapping a single
    character to a ``(delete, insert, substitute)`` tuple. For example,
    ``LevenshteinTrie(words, s=(0, 0, 1))`` makes adding or dropping an
    "s" free. Characters without an explicit weight use ``(1, 1, 1)``.

    Edit model: the stored (trie) word is the source and the search word is
    the target.
      * Dropping a trie-word character costs that character's *delete* weight.
      * Adding a search-word character costs that character's *insert* weight.
      * Replacing one character with a different one costs the larger of the
        two *substitute* weights.
    The pruning in ``searchRecursive`` assumes all weights are non-negative.
    """

    def __init__(self, food_words, distance_cache=None, **weight_dict):
        """Build the trie from ``food_words``.

        :param food_words: iterable of words to index.
        :param distance_cache: optional dict stored as ``self.distance_cache``.
            This class itself never reads or writes it.
        :param weight_dict: per-character ``(delete, insert, substitute)``
            weights overriding the default ``(1, 1, 1)``.
        """
        self.w = defaultdict(lambda: (1, 1, 1))
        if weight_dict:
            self.w.update(weight_dict)
        self.distance_cache = {} if distance_cache is None else distance_cache
        self.trie = TrieNode()
        self.wordcount = 0  # counts every insert, duplicates included
        for word in food_words:
            self.wordcount += 1
            self.trie.insert(word)
        # Kept as a public attribute; not used by the search itself.
        self.nodecount = 0

    def search(self, word, maxCost):
        """Return ``(trie_word, cost)`` pairs with weighted distance <= ``maxCost``.

        :param word: the search (target) word.
        :param maxCost: inclusive upper bound on the weighted edit distance.
        :returns: list of ``(stored_word, distance)`` tuples, in trie
            traversal order.
        """
        # Row 0 compares the empty trie prefix with each prefix of ``word``:
        # every search-word character must be inserted.
        currentRow = [0]
        for char in word:
            currentRow.append(currentRow[-1] + self.w[char][1])
        results = []
        # The root stands for the empty prefix. It only holds a word if ""
        # was inserted; the recursion below never revisits the root.
        if self.trie.word is not None and currentRow[-1] <= maxCost:
            results.append((self.trie.word, currentRow[-1]))
        for letter, child in self.trie.children.items():
            self.searchRecursive(child, letter, word, currentRow,
                                 results, maxCost)
        return results

    def searchRecursive(self, node, letter, word, previousRow, results, maxCost):
        """Add the DP row for ``letter`` (reached at ``node``), then recurse.

        :param node: trie node reached by appending ``letter`` to the parent
            prefix.
        :param letter: the trie character this row represents.
        :param word: the search (target) word.
        :param previousRow: the fully computed row of the parent prefix.
        :param results: output list; matches are appended in place.
        :param maxCost: inclusive distance limit.
        """
        w = self.w
        letterDelete = w[letter][0]
        letterSub = w[letter][2]
        # Column 0 compares this trie prefix with the empty prefix of
        # ``word``: the trie characters are all deleted.
        currentRow = [previousRow[0] + letterDelete]
        for column in range(1, len(word) + 1):
            char = word[column - 1]
            # Drop ``letter`` from the trie word: charged to ``letter``.
            deleteCost = previousRow[column] + letterDelete
            # Add ``char`` from the search word: charged to ``char``.
            insertCost = currentRow[column - 1] + w[char][1]
            if char == letter:
                replaceCost = previousRow[column - 1]
            else:
                replaceCost = previousRow[column - 1] + max(w[char][2], letterSub)
            currentRow.append(min(insertCost, deleteCost, replaceCost))

        # The last column is the full distance between this trie prefix and
        # ``word``; report it if a stored word ends here.
        if currentRow[-1] <= maxCost and node.word is not None:
            results.append((node.word, currentRow[-1]))

        # With non-negative weights, extending the prefix can only keep or
        # raise the row minimum, so a row entirely above the limit cannot
        # lead to a match.
        if min(currentRow) <= maxCost:
            for nextLetter, child in node.children.items():
                self.searchRecursive(child, nextLetter, word, currentRow,
                                     results, maxCost)


class TrieNode:
    """A trie node.

    ``children`` maps the next character to a child node. ``word`` holds
    the complete word that ends at this node, or ``None`` if no inserted
    word ends here.
    """

    def __init__(self):
        self.word = None
        self.children = {}

    def insert(self, word):
        """Add ``word`` below this node, creating intermediate nodes as needed."""
        node = self
        for letter in word:
            child = node.children.get(letter)
            if child is None:
                child = node.children[letter] = TrieNode()
            node = child
        node.word = word
def get_levenshtein_distance_object(food_words, setting='system1'):
    """Return a LevenshteinTrie over ``food_words`` using a named weight preset.

    Weights are ``(delete, insert, substitute)`` per character:
      * ``'system1'``: vowels cost 3 to add or drop but 1 to substitute;
        "s" is free to add or drop.
      * ``'system2'``: plain unit-cost Levenshtein distance.
      * ``'system3'``: unit costs, except "s" is free to add or drop.

    :raises ValueError: if ``setting`` is not one of the presets above.
    """
    vowel = (3, 3, 1)
    free_s = (0, 0, 1)
    presets = {
        'system1': dict(a=vowel, e=vowel, i=vowel, o=vowel, u=vowel, s=free_s),
        'system2': {},
        'system3': dict(s=free_s),
    }
    if setting not in presets:
        raise ValueError("unknown Levenshtein setting %r; expected one of %s"
                         % (setting, ", ".join(sorted(presets))))
    return LevenshteinTrie(food_words, **presets[setting])
if __name__ == '__main__':
    lt = get_levenshtein_distance_object(
        food_words='cat dog gumbo gambol'.split(), setting='system2')
    # Time only the search; building the trie above is excluded so the
    # reported figure matches the "Search took" label.
    start = time.time()
    results = lt.search('gumbo', 5000)
    end = time.time()
    # print(...) with a single argument behaves the same on Python 2 and 3.
    for result in results:
        print(result)
    print("Search took %g s" % (end - start))