-
Notifications
You must be signed in to change notification settings - Fork 0
/
viterbi.py
177 lines (135 loc) · 4.72 KB
/
viterbi.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
import numpy as np
import sys
# Parse the problem description named by the first command-line argument.
# Layout, as consumed below:
#   line 0                   -> "<rows> <cols>"
#   lines 1 .. rows          -> whitespace-separated map cells, one map row per line
#   line rows + 1            -> not read by this script
#                               (presumably the observation count -- TODO confirm)
#   lines rows + 2 .. end-1  -> one observation string per line
#   last line                -> error rate (float)
file_path = sys.argv[1]
# The context manager closes the file even if reading raises part-way through.
with open(file_path, 'r') as file:
    lines = file.readlines()
# Drop blank lines at the end of the file so lines[-1] really is the error rate.
while lines and not lines[-1].strip():
    lines.pop()
if not lines:
    raise ValueError(f"{file_path}: input file is empty")
first_line = lines[0].strip()
map_size = [int(x) for x in first_line.split()]
raw_map = [lines[i].strip().split() for i in range(1, map_size[0] + 1)]
observations = [lines[i].strip() for i in range(map_size[0] + 2, len(lines) - 1)]
error_rate = float(lines[-1].strip())
def count_zero(map):
    """Return how many cells of the 2-D grid `map` equal the string '0'.

    `map` is iterated as rows of cells. Only the exact string '0' is counted.
    The parameter keeps its original name for callers, although it shadows
    the builtin inside this function.
    """
    return sum(1 for row in map for cell in row if cell == '0')
# K is the number of '0' cells in the map; each gets the same initial probability.
K = count_zero(raw_map)
if K == 0:
    # Without this check the division below fails with a bare ZeroDivisionError.
    raise ValueError("map contains no '0' cells")
init_prob = 1/K
# Initial probability grid: init_prob on every '0' cell, 0.0 on every other cell.
# It is built numerically in one step, so no probability is ever formatted as
# text and parsed back, which a float/string np.where mix would do.
init_map = np.where(np.array(raw_map) == '0', init_prob, 0.0)
# Every 4-tuple of 0/1 values, ordered with the last position varying fastest.
permutations = [
    (i, j, k, l)
    for i in range(2)
    for j in range(2)
    for k in range(2)
    for l in range(2)
]
# Observation alphabet: 1-based index -> 4-tuple.
O = dict(enumerate(permutations, start=1))
N = len(O)
# State space: 1-based state id -> 1-based (row, column) of every cell whose
# initial probability is non-zero, numbered in row-major order.
S = {}
for row_idx, row in enumerate(init_map):
    for col_idx, cell_prob in enumerate(row):
        if cell_prob != 0:
            S[len(S) + 1] = (row_idx + 1, col_idx + 1)
Pi = [1/K] * K  # initial state probability: uniform over the K states
T = len(observations)  # number of observations
# Reverse lookup (4-tuple -> 1-based observation index), built once instead of
# re-materialising and linearly searching O's keys and values per observation.
_obs_index = {pattern: idx for idx, pattern in O.items()}
# Each observation line is read character by character as integers and then
# replaced by its index in O.
Y = []
for line in observations:
    pattern = tuple(int(ch) for ch in line)
    if pattern not in _obs_index:
        raise ValueError(f"unrecognised observation {line!r}")
    Y.append(_obs_index[pattern])
def get_state_key(state):
    """Return the first key in S whose value equals `state`, or None if absent."""
    return next((key for key, cell in S.items() if cell == state), None)
def transition_probability(state_key):
    """Return {neighbour state key: probability} for the state `state_key`.

    The candidates are the cells directly N, S, W and E of the state's cell.
    Those present in S share the probability mass equally. The result is
    empty when none of the four is in S.
    """
    row, col = S[state_key]
    # Coordinate -> key; the first key wins, like a front-to-back scan of S.
    key_of = {}
    for key, cell in S.items():
        key_of.setdefault(cell, key)
    candidates = ((row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1))  # N, S, W, E
    reachable = [cell for cell in candidates if cell in key_of]
    return {key_of[cell]: 1/len(reachable) for cell in reachable}
# Transition matrix (K x K): Tm[a][b] is the probability of moving from state
# a+1 to state b+1. State keys are 1-based, matrix indices 0-based. Entries
# for non-neighbouring states stay 0.
Tm = np.zeros((K, K))
for src in range(1, K+1):
    for dst, prob in transition_probability(src).items():
        Tm[src-1][dst-1] = prob
def true_observation(state_key):
    """Return the error-free [N, S, W, E] reading for the state `state_key`.

    Each entry is 1 when the adjacent cell in that direction is not in S
    (an obstacle), and 0 otherwise.
    """
    row, col = S[state_key]
    open_cells = set(S.values())
    adjacent = ((row - 1, col), (row + 1, col), (row, col - 1), (row, col + 1))  # N, S, W, E
    return [0 if cell in open_cells else 1 for cell in adjacent]
def count_different_items(list1, list2):
    """Return the number of positions at which `list1` and `list2` differ.

    Items are paired with zip, so extra trailing items of the longer
    sequence are ignored.
    """
    mismatches = 0
    for left, right in zip(list1, list2):
        if left != right:
            mismatches += 1
    return mismatches
# Emission matrix (K x N): Em[s][o] is the probability of seeing observation
# o+1 in state s+1. Each of the four directions independently matches the
# error-free reading with probability (1 - error_rate) and differs from it
# with probability error_rate.
Em = np.zeros((K, N))
for state_key in range(1, K+1):
    expected = true_observation(state_key)  # [N, S, W, E]
    for obs_key in range(1, N+1):
        wrong_bits = count_different_items(expected, O[obs_key])
        Em[state_key-1][obs_key-1] = (1-error_rate)**(4 - wrong_bits) * error_rate**wrong_bits
def viterbi_forward(O, S, pi, Y, Tm, Em):
    """Fill the Viterbi trellis of best-path probabilities.

    trellis[i, t] is the largest value, over all state sequences that end in
    state i (0-based) at step t, of the product of initial, transition and
    emission probabilities along that sequence.

    Args:
        O: observation alphabet. Not read here; kept for the call signature.
        S: state space; only its length K is used.
        pi: length-K sequence of initial state probabilities.
        Y: list of 1-based observation indices, one per time step.
        Tm: K x K transition matrix; Tm[k, i] is the probability of k -> i.
        Em: K x N emission matrix; Em[i, o] is the probability of
            observation o+1 in state i.

    Returns:
        A (K, T) float array with T = len(Y). It has zero columns when Y is
        empty and zero rows when S is empty.
    """
    K = len(S)
    T = len(Y)
    trellis = np.zeros((K, T))
    if K == 0 or T == 0:
        # Nothing to fill; avoids indexing Y[0] or reducing over an empty axis.
        return trellis
    pi = np.asarray(pi)
    Tm = np.asarray(Tm)
    Em = np.asarray(Em)
    # Initialization
    trellis[:, 0] = pi * Em[:, Y[0] - 1]
    # Recursion
    for t in range(1, T):
        # scores[k, i] = trellis[k, t-1] * Tm[k, i] * Em[i, Y[t]-1]; the
        # multiplication order matches the scalar formula term for term.
        scores = trellis[:, t - 1, None] * Tm * Em[:, Y[t] - 1]
        trellis[:, t] = scores.max(axis=0)
    return trellis
# Run the forward pass, then scatter each time step's per-state values back
# onto a grid of the map's shape. Cells that are not states stay 0.
trellis = viterbi_forward(O, S, Pi, Y, Tm, Em)
transposed_trellis = np.transpose(trellis)  # one row per time step
result = []
for step_scores in transposed_trellis:
    grid = np.zeros(map_size)
    for offset, score in enumerate(step_scores):
        row, col = S[offset + 1]  # state keys and coordinates are 1-based
        grid[row - 1][col - 1] = score
    result.append(grid)
# The grids are passed positionally, one array per time step.
np.savez("output.npz", *result)