From d462233f80c11ef0614c0eb818c2dc7d45c88492 Mon Sep 17 00:00:00 2001 From: samurain Date: Sun, 11 May 2014 12:21:48 +0900 Subject: [PATCH] ipf running, ipf stops for max iter and threshold --- src/representations.py | 46 +++++++++++++++++++++++++++++++++++------- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/src/representations.py b/src/representations.py index d369328..aab1eaa 100644 --- a/src/representations.py +++ b/src/representations.py @@ -4,6 +4,10 @@ import Data # from sets import Set,ImmutableSet +MAX_IPF_ITER = 100 +IPF_CONV_THRESH = 0.001 + + def calculate_entropy_of_ndarray(probability_matrix): calc_cell_entropies = np.vectorize(lambda x: -x*np.log2(x) if x > 0 else 0.) return np.sum(calc_cell_entropies(probability_matrix)) @@ -123,14 +127,42 @@ def __init__(self, component_list, dataset): # print projection_list # initialize q: - q = np.zeros(self.var_cards) + q = np.zeros(var_cards) q[:] = 1./q.size # initialize with equal probs - # while error > threshold - for p in projection_list: - # get q's projection, q_proj - # update elements of q as below - # new_q = old_q * (p/q_proj) - # TODO: IPF goes here. + assert len(component_list) == len(projection_list) + froe2_norm = np.sum(q**2) + cont = True + itr = 1 + while (cont and itr < MAX_IPF_ITER): + for i,k in enumerate(component_list): + var_names = [var.name for var in k.var_list] + q_proj = project_q(dataset.variable_names,var_names,q) + q = q * (projection_list[i]/q_proj) + new_froe2_norm = np.sum(q**2) + cont = abs(new_froe2_norm - froe2_norm) > IPF_CONV_THRESH + froe2_norm = new_froe2_norm + itr += 1 + + +def project_q(all_variable_names,variable_list,q): + if all(isinstance(variable,int) for variable in variable_list): + pointer_list = variable_list + elif all(isinstance(variable,str) for variable in variable_list): + # TODO convert to int + pointer_list = [all_variable_names.index(var) + for var in variable_list] + else: + raise Exception("invalid variable_list parmeter in \ + extract_component.") + + unwanted_variables = [v for v in range(len(all_variable_names)) + if not v in pointer_list] + + # Keep dims so math works out more easily later and to track + # which variables are aggregated. + return np.sum(a=q, + axis=tuple(unwanted_variables), + keepdims=True) if __name__ == "__main__":