Skip to content

Commit 06c1860

Browse files
committed
HMM (without EM)
1 parent cc24864 commit 06c1860

File tree

6 files changed

+593
-24
lines changed

6 files changed

+593
-24
lines changed

pytrain/HMM/HMM.py

+75-8
Original file line numberDiff line numberDiff line change
@@ -5,20 +5,87 @@
55
66
#
77

8-
from numpy import *
8+
import numpy as np
99
from pytrain.lib import convert
1010

1111
class HMM:
    """Hidden Markov Model with supervised (frequency-count) or random
    parameter initialization and Viterbi decoding.

    Parameters are kept as log-probabilities:
      self.a : (m, m) log transition matrix, a[i][j] = log P(state j | state i)
      self.b : (m, n) log emission matrix,  b[i][k] = log P(feature k | state i)
    where m is the number of hidden states and n the observation-vector width.
    (Baum-Welch training for the unlabeled case is still TODO in ``fit``.)
    """

    def __init__(self, mat_data, label_data, hidden_state_labeled=True, hidden_state=-1):
        """
        mat_data   : sequences of (multi-)hot observation vectors;
                     mat_data[s][t] is the vector observed at time t of sequence s.
        label_data : hidden-state labels aligned with mat_data
                     (only consulted when hidden_state_labeled is True).
        hidden_state_labeled : if True, estimate parameters by counting the
                     labelled transitions/emissions; otherwise initialize randomly.
        hidden_state : number of hidden states for the unlabeled case (must be >= 1).

        Raises ValueError if hidden_state_labeled is False and hidden_state < 1.
        """
        self.mat_data = mat_data
        self.label_data = label_data
        # Observation-vector width, taken from the first observation.
        self.n = len(mat_data[0][0])
        if hidden_state_labeled:
            # Collect the distinct labels; their list order fixes the state indices.
            all_labels = []
            for seq_label in label_data:
                all_labels.extend(seq_label)
            self.label_set = list(set(all_labels))
            self.label_idx = {x: i for i, x in enumerate(self.label_set)}
            self.m = len(self.label_set)
            self.make_freqtable()
        else:
            if hidden_state < 1:
                # The old default (-1) silently produced an unusable model
                # (m = -1); fail loudly instead.
                raise ValueError("hidden_state must be >= 1 when states are unlabeled")
            self.m = hidden_state
            self.label_set = list(range(hidden_state))
            self.make_randomtable()

    def make_randomtable(self):
        """Initialize a and b with random row-stochastic tables (stored as logs)."""
        self.a = np.random.random([self.m, self.m])
        self.b = np.random.random([self.m, self.n])
        self.a = np.log(self.a / self.a.sum(axis=1).reshape((self.m, 1)))
        self.b = np.log(self.b / self.b.sum(axis=1).reshape((self.m, 1)))

    def make_freqtable(self):
        """Estimate a and b by counting labelled transitions and emissions.

        A tiny pseudo-count (1e-6) keeps unseen events away from log(0).
        """
        smooth = 0.000001
        self.a = np.zeros([self.m, self.m]) + smooth
        self.b = np.zeros([self.m, self.n]) + smooth
        for seq_idx, seq_label in enumerate(self.label_data):
            if len(seq_label) > 0:
                # Bug fix: also credit the emission at t=0 — previously the
                # first observation of every sequence was dropped because the
                # loop below starts at i=1 and only counts `now`.
                first_label = seq_label[0]
                self.b[self.label_idx[first_label]] += np.asarray(
                    self.mat_data[seq_idx][0], dtype=float)
            for i in range(1, len(seq_label)):
                now = seq_label[i]
                prev = seq_label[i - 1]
                now_ob = np.asarray(self.mat_data[seq_idx][i], dtype=float)
                self.b[self.label_idx[now]] += now_ob
                self.a[self.label_idx[prev]][self.label_idx[now]] += 1
        self.b = np.log(self.b / self.b.sum(axis=1).reshape((self.m, 1)))
        self.a = np.log(self.a / self.a.sum(axis=1).reshape((self.m, 1)))

    def viterbi(self, array_input):
        """Return the most likely hidden-state index sequence for array_input.

        array_input[t] is a (multi-)hot observation vector; every set
        component contributes its emission log-probability.
        Returns [] for an empty input (previously an IndexError).
        """
        t = len(array_input)
        if t == 0:
            self.prob = np.zeros([0, self.m, 2])
            return []
        # self.prob[t][j] :: index 0 is best log-prob, index 1 is backpointer
        self.prob = np.zeros([t, self.m, 2]) - 10000000
        first_ob_idx = np.nonzero(array_input[0])[0]
        first = self.b[:, first_ob_idx].sum(axis=1)
        first_prob = np.transpose(np.tile(first, (self.m, 1)))
        first_prob[:, 1:] = -1  # no predecessor at t=0
        self.prob[0] = first_prob[:, :2]
        for i in range(1, t):
            now_ob_idx = np.nonzero(array_input[i])[0]
            # Hoisted: the emission score depends only on j, not on the
            # predecessor k, so compute it once per time step instead of
            # m times per cell.
            emit = self.b[:, now_ob_idx].sum(axis=1)
            for j in range(self.m):
                max_prob = self.prob[i][j][0]
                max_idx = self.prob[i][j][1]
                for k in range(self.m):
                    now_prob = self.prob[i - 1][k][0] + self.a[k][j] + emit[j]
                    if max_prob < now_prob:
                        max_prob = now_prob
                        max_idx = k
                self.prob[i][j][0] = max_prob
                self.prob[i][j][1] = max_idx
        # Termination: pick the best final state ...
        last_idx = -1
        last_max = -10000000
        for j in range(self.m):
            if self.prob[t - 1][j][0] > last_max:
                last_idx = int(j)
                last_max = self.prob[t - 1][j][0]
        # ... then follow the backpointers.
        trace = []
        for at in range(t - 1, -1, -1):
            trace.append(int(last_idx))
            last_idx = self.prob[at][int(last_idx)][1]
        return trace[::-1]

    def fit(self, toler, epoch):
        # TODO : Baum-welch EM algorithm implementation
        pass

    def predict(self, array_input):
        """Viterbi-decode array_input and map state indices back to labels."""
        seq_of_label = self.viterbi(array_input)
        ret = [self.label_set[x] for x in seq_of_label]
        return ret

pytrain/lib/nlp.py

+11-2
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,8 @@ def extract_vocabulary(self, documents):
6161
for doc in documents:
6262
if str(type(doc).__name__) == 'str':
6363
doc = self.split2words(doc)
64-
lowerdoc = [ x.lower() for x in doc ]
6564
ndoc = []
66-
for w in lowerdoc:
65+
for w in doc:
6766
if w not in self.stopwords:
6867
ndoc.append(w)
6968
vocabulary = vocabulary | set(ndoc)
@@ -86,3 +85,13 @@ def bag_of_word2vector(self, vocabulary, sentence):
8685
if word in vocabulary:
8786
voca_vector[vocabulary.index(word)] += 1
8887
return voca_vector
88+
89+
def set_of_wordseq2matrix(self, vocabulary, wordlist):
    """Encode a word sequence as a list of one-hot vectors over vocabulary.

    vocabulary : ordered list of known words; its order fixes the columns.
    wordlist   : sequence of words to encode.
    Returns a list with one vector (len == len(vocabulary)) per word;
    out-of-vocabulary words yield an all-zero vector.
    """
    # Build a first-occurrence index map once instead of scanning the
    # vocabulary twice per word (`in` + `.index`). setdefault preserves
    # list.index semantics when the vocabulary contains duplicates.
    index_of = {}
    for i, w in enumerate(vocabulary):
        index_of.setdefault(w, i)
    word_mat = []
    for word in wordlist:
        word_vector = [0] * len(vocabulary)
        if word in index_of:
            word_vector[index_of[word]] = 1
        word_mat.append(word_vector)
    return word_mat
97+

0 commit comments

Comments
 (0)