Skip to content

Commit

Permalink
changed allow_ground_atoms to allow_ground_rules
Browse files Browse the repository at this point in the history
  • Loading branch information
dyumanaditya committed Aug 4, 2024
1 parent aae4943 commit 7aa3180
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 22 deletions.
14 changes: 7 additions & 7 deletions pyreason/pyreason.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def __init__(self):
self.__store_interpretation_changes = True
self.__parallel_computing = False
self.__update_mode = 'intersection'
self.__allow_ground_atoms = False
self.__allow_ground_rules = False

@property
def verbose(self) -> bool:
Expand Down Expand Up @@ -170,12 +170,12 @@ def update_mode(self) -> str:
return self.__update_mode

@property
def allow_ground_atoms(self) -> bool:
def allow_ground_rules(self) -> bool:
"""Returns whether rules can have ground atoms or not. Default is False
:return: bool
"""
return self.__allow_ground_atoms
return self.__allow_ground_rules

@verbose.setter
def verbose(self, value: bool) -> None:
Expand Down Expand Up @@ -364,8 +364,8 @@ def update_mode(self, value: str) -> None:
else:
self.__update_mode = value

@allow_ground_atoms.setter
def allow_ground_atoms(self, value: bool) -> None:
@allow_ground_rules.setter
def allow_ground_rules(self, value: bool) -> None:
"""Allow ground atoms to be used in rules when possible. Default is False
:param value: Whether to allow ground atoms or not
Expand All @@ -374,7 +374,7 @@ def allow_ground_atoms(self, value: bool) -> None:
if not isinstance(value, bool):
raise TypeError('value has to be a bool')
else:
self.__allow_ground_atoms = value
self.__allow_ground_rules = value


# VARIABLES
Expand Down Expand Up @@ -660,7 +660,7 @@ def _reason(timesteps, convergence_threshold, convergence_bound_threshold):
annotation_functions = tuple(__annotation_functions)

# Setup logical program
__program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.canonical, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_atoms)
__program = Program(__graph, all_node_facts, all_edge_facts, __rules, __ipl, annotation_functions, settings.reverse_digraph, settings.atom_trace, settings.save_graph_attributes_to_trace, settings.canonical, settings.inconsistency_check, settings.store_interpretation_changes, settings.parallel_computing, settings.update_mode, settings.allow_ground_rules)
__program.available_labels_node = __node_labels
__program.available_labels_edge = __edge_labels
__program.specific_node_labels = __specific_node_labels
Expand Down
22 changes: 11 additions & 11 deletions pyreason/scripts/interpretation/interpretation.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ class Interpretation:
specific_node_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(node_type))
specific_edge_labels = numba.typed.Dict.empty(key_type=label.label_type, value_type=numba.types.ListType(edge_type))

def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_atoms):
def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules):
self.graph = graph
self.ipl = ipl
self.annotation_functions = annotation_functions
Expand All @@ -66,7 +66,7 @@ def __init__(self, graph, ipl, annotation_functions, reverse_graph, atom_trace,
self.inconsistency_check = inconsistency_check
self.store_interpretation_changes = store_interpretation_changes
self.update_mode = update_mode
self.allow_ground_atoms = allow_ground_atoms
self.allow_ground_rules = allow_ground_rules

# For reasoning and reasoning again (contains previous time and previous fp operation cnt)
self.time = 0
Expand Down Expand Up @@ -205,7 +205,7 @@ def _init_facts(facts_node, facts_edge, facts_to_be_applied_node, facts_to_be_ap
return max_time

def _start_fp(self, rules, max_facts_time, verbose, again):
fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_atoms, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again)
fp_cnt, t = self.reason(self.interpretations_node, self.interpretations_edge, self.tmax, self.prev_reasoning_data, rules, self.nodes, self.edges, self.neighbors, self.reverse_neighbors, self.rules_to_be_applied_node, self.rules_to_be_applied_edge, self.edges_to_be_added_node_rule, self.edges_to_be_added_edge_rule, self.rules_to_be_applied_node_trace, self.rules_to_be_applied_edge_trace, self.facts_to_be_applied_node, self.facts_to_be_applied_edge, self.facts_to_be_applied_node_trace, self.facts_to_be_applied_edge_trace, self.ipl, self.rule_trace_node, self.rule_trace_edge, self.rule_trace_node_atoms, self.rule_trace_edge_atoms, self.reverse_graph, self.atom_trace, self.save_graph_attributes_to_rule_trace, self.canonical, self.inconsistency_check, self.store_interpretation_changes, self.update_mode, self.allow_ground_rules, max_facts_time, self.annotation_functions, self._convergence_mode, self._convergence_delta, verbose, again)
self.time = t - 1
# If we need to reason again, store the next timestep to start from
self.prev_reasoning_data[0] = t
Expand All @@ -215,7 +215,7 @@ def _start_fp(self, rules, max_facts_time, verbose, again):

@staticmethod
@numba.njit(cache=True, parallel=False)
def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_atoms, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again):
def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data, rules, nodes, edges, neighbors, reverse_neighbors, rules_to_be_applied_node, rules_to_be_applied_edge, edges_to_be_added_node_rule, edges_to_be_added_edge_rule, rules_to_be_applied_node_trace, rules_to_be_applied_edge_trace, facts_to_be_applied_node, facts_to_be_applied_edge, facts_to_be_applied_node_trace, facts_to_be_applied_edge_trace, ipl, rule_trace_node, rule_trace_edge, rule_trace_node_atoms, rule_trace_edge_atoms, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, update_mode, allow_ground_rules, max_facts_time, annotation_functions, convergence_mode, convergence_delta, verbose, again):
t = prev_reasoning_data[0]
fp_cnt = prev_reasoning_data[1]
max_rules_time = 0
Expand Down Expand Up @@ -520,7 +520,7 @@ def reason(interpretations_node, interpretations_edge, tmax, prev_reasoning_data
# Only go through if the rule can be applied within the given timesteps, or we're running until convergence
delta_t = rule.get_delta()
if t + delta_t <= tmax or tmax == -1 or again:
applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_atoms)
applicable_node_rules, applicable_edge_rules = _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules)

# Loop through applicable rules and add them to the rules to be applied for later or next fp operation
for applicable_rule in applicable_node_rules:
Expand Down Expand Up @@ -731,7 +731,7 @@ def query(self, query, return_bool=True):


@numba.njit(cache=True)
def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_atoms):
def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges, neighbors, reverse_neighbors, atom_trace, allow_ground_rules):
# Extract rule params
rule_type = rule.get_type()
head_variables = rule.get_head_variables()
Expand Down Expand Up @@ -777,7 +777,7 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,

# Get subset of nodes that can be used to ground the variable
# If we allow ground atoms, we can use the nodes directly
if allow_ground_atoms and clause_var_1 in nodes:
if allow_ground_rules and clause_var_1 in nodes:
grounding = numba.typed.List([clause_var_1])
else:
grounding = get_rule_node_clause_grounding(clause_var_1, groundings, nodes)
Expand All @@ -800,7 +800,7 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,

# Get subset of edges that can be used to ground the variables
# If we allow ground atoms, we can use the nodes directly
if allow_ground_atoms and (clause_var_1, clause_var_2) in edges:
if allow_ground_rules and (clause_var_1, clause_var_2) in edges:
grounding = numba.typed.List([(clause_var_1, clause_var_2)])
else:
grounding = get_rule_edge_clause_grounding(clause_var_1, clause_var_2, groundings, groundings_edges, neighbors, reverse_neighbors, nodes)
Expand Down Expand Up @@ -862,7 +862,7 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,
# If there is no grounding for head_var_1, we treat it as a ground atom and add it to the graph
head_var_1_in_nodes = head_var_1 in nodes
add_head_var_node_to_graph = False
if allow_ground_atoms and head_var_1_in_nodes:
if allow_ground_rules and head_var_1_in_nodes:
groundings[head_var_1] = numba.typed.List([head_var_1])
elif head_var_1 not in groundings:
if not head_var_1_in_nodes:
Expand Down Expand Up @@ -955,9 +955,9 @@ def _ground_rule(rule, interpretations_node, interpretations_edge, nodes, edges,
add_head_var_1_node_to_graph = False
add_head_var_2_node_to_graph = False
add_head_edge_to_graph = False
if allow_ground_atoms and head_var_1_in_nodes:
if allow_ground_rules and head_var_1_in_nodes:
groundings[head_var_1] = numba.typed.List([head_var_1])
if allow_ground_atoms and head_var_2_in_nodes:
if allow_ground_rules and head_var_2_in_nodes:
groundings[head_var_2] = numba.typed.List([head_var_2])

if head_var_1 not in groundings:
Expand Down
8 changes: 4 additions & 4 deletions pyreason/scripts/program/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ class Program:
specific_node_labels = []
specific_edge_labels = []

def __init__(self, graph, facts_node, facts_edge, rules, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, parallel_computing, update_mode, allow_ground_atoms):
def __init__(self, graph, facts_node, facts_edge, rules, ipl, annotation_functions, reverse_graph, atom_trace, save_graph_attributes_to_rule_trace, canonical, inconsistency_check, store_interpretation_changes, parallel_computing, update_mode, allow_ground_rules):
self._graph = graph
self._facts_node = facts_node
self._facts_edge = facts_edge
Expand All @@ -23,7 +23,7 @@ def __init__(self, graph, facts_node, facts_edge, rules, ipl, annotation_functio
self._store_interpretation_changes = store_interpretation_changes
self._parallel_computing = parallel_computing
self._update_mode = update_mode
self._allow_ground_atoms = allow_ground_atoms
self._allow_ground_rules = allow_ground_rules
self.interp = None

def reason(self, tmax, convergence_threshold, convergence_bound_threshold, verbose=True):
Expand All @@ -36,9 +36,9 @@ def reason(self, tmax, convergence_threshold, convergence_bound_threshold, verbo

# Instantiate correct interpretation class based on whether we parallelize the code or not. (We cannot parallelize with cache on)
if self._parallel_computing:
self.interp = InterpretationParallel(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode, self._allow_ground_atoms)
self.interp = InterpretationParallel(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode, self._allow_ground_rules)
else:
self.interp = Interpretation(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode, self._allow_ground_atoms)
self.interp = Interpretation(self._graph, self._ipl, self._annotation_functions, self._reverse_graph, self._atom_trace, self._save_graph_attributes_to_rule_trace, self._canonical, self._inconsistency_check, self._store_interpretation_changes, self._update_mode, self._allow_ground_rules)
self.interp.start_fp(self._tmax, self._facts_node, self._facts_edge, self._rules, verbose, convergence_threshold, convergence_bound_threshold)

return self.interp
Expand Down

0 comments on commit 7aa3180

Please sign in to comment.