Skip to content

Commit

Permalink
ipf running, ipf stops for max iter and threshold
Browse files Browse the repository at this point in the history
  • Loading branch information
samurain committed May 11, 2014
1 parent f78b317 commit d462233
Showing 1 changed file with 39 additions and 7 deletions.
46 changes: 39 additions & 7 deletions src/representations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import Data
# from sets import Set,ImmutableSet

# Bound on the IPF iteration counter: the fitting loop continues only while
# the counter is below this value.
MAX_IPF_ITER = 100
# IPF convergence threshold: the loop stops once the absolute change in
# sum(q**2) between successive checks is no longer greater than this.
IPF_CONV_THRESH = 0.001


def calculate_entropy_of_ndarray(probability_matrix):
    """Return the Shannon entropy, in bits, of the cells of an array.

    Every strictly positive cell ``p`` contributes ``-p * log2(p)``; cells
    that are not greater than zero contribute nothing.

    Args:
        probability_matrix: array-like of any shape holding the cell values.

    Returns:
        The sum of the per-cell contributions as a NumPy float; ``0.0`` for
        an empty or all-zero input.
    """
    cells = np.asarray(probability_matrix, dtype=float)
    # Select the positive cells first so log2 is never evaluated at zero and
    # an empty input reduces to an empty sum instead of raising.
    positive = cells[cells > 0]
    return np.sum(-positive * np.log2(positive))
Expand Down Expand Up @@ -123,14 +127,42 @@ def __init__(self, component_list, dataset):
# print projection_list

# initialize q:
q = np.zeros(self.var_cards)
q = np.zeros(var_cards)
q[:] = 1./q.size # initialize with equal probs
# while error > threshold
for p in projection_list:
# get q's projection, q_proj
# update elements of q as below
# new_q = old_q * (p/q_proj)
# TODO: IPF goes here.
assert len(component_list) == len(projection_list)
froe2_norm = np.sum(q**2)
cont = True
itr = 1
while (cont and itr < MAX_IPF_ITER):
for i,k in enumerate(component_list):
var_names = [var.name for var in k.var_list]
q_proj = project_q(dataset.variable_names,var_names,q)
q = q * (projection_list[i]/q_proj)
new_froe2_norm = np.sum(q**2)
cont = abs(new_froe2_norm - froe2_norm) > IPF_CONV_THRESH
froe2_norm = new_froe2_norm
itr += 1


def project_q(all_variable_names, variable_list, q):
    """Sum ``q`` over every axis that is not listed in ``variable_list``.

    Args:
        all_variable_names: sequence of variable names; position ``i`` in
            this sequence is treated as axis ``i`` of ``q``.
        variable_list: the variables to keep, given either all as integer
            axis positions (Python or NumPy integers) or all as names that
            appear in ``all_variable_names``.
        q: array-like to be summed.

    Returns:
        An ndarray with the same number of dimensions as ``q``; every
        summed-out axis is kept with length 1.

    Raises:
        TypeError: if ``variable_list`` is neither all integers nor all
            strings.
        ValueError: if a name is not present in ``all_variable_names``.
    """
    if all(isinstance(variable, (int, np.integer))
           for variable in variable_list):
        kept_axes = set(variable_list)
    elif all(isinstance(variable, str) for variable in variable_list):
        kept_axes = {all_variable_names.index(name)
                     for name in variable_list}
    else:
        raise TypeError(
            "invalid variable_list parameter in project_q: expected all "
            "ints or all strs."
        )

    summed_axes = tuple(
        axis for axis in range(len(all_variable_names))
        if axis not in kept_axes
    )

    # Keep dims so the result still broadcasts against q and the position
    # of each aggregated variable stays visible in the shape.
    return np.sum(a=q, axis=summed_axes, keepdims=True)


if __name__ == "__main__":
Expand Down

0 comments on commit d462233

Please sign in to comment.