diff --git a/outrank/core_ranking.py b/outrank/core_ranking.py index 6d1e855..a406c26 100644 --- a/outrank/core_ranking.py +++ b/outrank/core_ranking.py @@ -116,12 +116,21 @@ def mixed_rank_graph( out_time_struct['encoding_columns'] = end_enc_timer - start_enc_timer combinations = get_combinations_from_columns(all_columns, args) - #combinations = prior_combinations_sample(combinations, args) - #random.shuffle(combinations) reference_model_features = {} if is_prior_heuristic(args): - reference_model_features = [(" AND ").join(tuple(sorted(item.split(",")))) for item in extract_features_from_reference_JSON(args.reference_model_JSON, full_feature_space = True)] + reference_model_features = [(" AND ").join(tuple(sorted(item.split(",")))) for item in extract_features_from_reference_JSON(args.reference_model_JSON, all_features=True)] + combinations = [comb for comb in combinations if comb[0] not in reference_model_features and comb[1] not in reference_model_features] + print(combinations) + print("\n\n") + + combinations = prior_combinations_sample(combinations, args) + print(GLOBAL_PRIOR_COMB_COUNTS) + print(combinations) + print("\n\n") + random.shuffle(combinations) + print(combinations) + print("\n\n") if args.heuristic == 'Constant': final_constant_imp = [] @@ -196,30 +205,20 @@ def compute_combined_features( model_combinations = [] full_combination_space = [] - if is_prior_heuristic(args): + + if args.reference_model_JSON != '': model_combinations = extract_features_from_reference_JSON(args.reference_model_JSON, combined_features_only = True) model_combinations = [tuple(sorted(combination.split(','))) for combination in model_combinations] - if args.interaction_order > 1: - full_combination_space = list( - itertools.combinations(all_columns, interaction_order), - ) - else: - if args.reference_model_JSON != '': - model_combinations = extract_features_from_reference_JSON(args.reference_model_JSON, combined_features_only = True) - model_combinations = [tuple(sorted(combination.split(','))) for combination in model_combinations] - full_combination_space = model_combinations - else: + full_combination_space = model_combinations + + if args.interaction_order > 1: full_combination_space = list( itertools.combinations(all_columns, interaction_order), ) + full_combination_space = prior_combinations_sample(full_combination_space, args) - if args.combination_number_upper_bound: - random.shuffle(full_combination_space) - full_combination_space = full_combination_space[ - : args.combination_number_upper_bound - ] - if is_prior_heuristic(args): - full_combination_space = full_combination_space + [tuple for tuple in model_combinations if tuple not in full_combination_space] + if is_prior_heuristic(args): + full_combination_space = full_combination_space + [tuple for tuple in model_combinations if tuple not in full_combination_space] com_counter = 0