Skip to content

Commit

Permalink
Debug annotates
Browse files Browse the repository at this point in the history
  • Loading branch information
renatahodovan committed Dec 11, 2023
1 parent 4314fed commit f561895
Show file tree
Hide file tree
Showing 2 changed files with 68 additions and 28 deletions.
8 changes: 7 additions & 1 deletion grammarinator/runtime/rule.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,7 +311,7 @@ def __repr__(self):
return f'{self.__class__.__name__}({", ".join(parts)})'

def _dbg_(self):
return '{name}\n{children}'.format(name=self.name, children=indent('\n'.join(child._dbg_() for child in self.children), '| '))
return '{name}\n{children}'.format(name=self.name or self.__class__.__name__, children=indent('\n'.join(child._dbg_() for child in self.children), '| '))


class UnparserRule(ParentRule):
Expand Down Expand Up @@ -412,6 +412,9 @@ def __repr__(self):
def __deepcopy__(self, memo):
return UnparserRuleQuantifier(idx=deepcopy(self.idx, memo), start=deepcopy(self.start, memo), stop=deepcopy(self.stop, memo), children=[deepcopy(child, memo) for child in self.children])

def _dbg_(self):
return '{name}:[{idx}]\n{children}'.format(idx=self.idx, name=self.__class__.__name__, children=indent('\n'.join(child._dbg_() for child in self.children), '| '))


class UnparserRuleQuantified(ParentRule):
"""
Expand Down Expand Up @@ -455,3 +458,6 @@ def __deepcopy__(self, memo):
return UnparserRuleAlternative(alt_idx=deepcopy(self.alt_idx, memo),
idx=deepcopy(self.idx, memo),
children=[deepcopy(child, memo) for child in self.children])

def _dbg_(self):
return '{name}:[{alt_idx}/{idx}]\n{children}'.format(name=self.__class__.__name__, alt_idx=self.alt_idx, idx=self.idx, children=indent('\n'.join(child._dbg_() for child in self.children), '| '))
88 changes: 61 additions & 27 deletions grammarinator/tool/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from os.path import abspath, dirname
from shutil import rmtree

from ..runtime import CooldownModel, DefaultModel, ParentRule, RuleSize, UnlexerRule, UnparserRule
from ..runtime import CooldownModel, DefaultModel, ParentRule, RuleSize, UnlexerRule, UnparserRule, UnparserRuleAlternative, UnparserRuleQuantifier

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -293,7 +293,15 @@ def mutate(self):
"""
root, annot = self._select_individual()

options = self._filter_nodes((node for nodes in annot.nodes_by_name.values() for node in nodes), root, annot)
# Filter items from the nodes of the selected tree that can be regenerated
# within the current maximum depth and token limit (except '<INVALID>' and
# immutable nodes).
options = [node for nodes in annot.rules_by_name.values() for node in nodes
if (node.parent is not None
and node.name != '<INVALID>'
and node.name not in self._generator_factory._immutable_rules
and annot.node_levels[node] + self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0)).depth < self._limit.depth
and annot.token_counts[root] - annot.token_counts[node] + self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0)).tokens < self._limit.tokens)]
if options:
mutated_node = random.choice(options)
reserve = RuleSize(depth=annot.node_levels[mutated_node],
Expand All @@ -318,24 +326,31 @@ def recombine(self):
:rtype: Rule
"""
recipient_root, recipient_annot = self._select_individual()
donor_root, donor_annot = self._select_individual()
_, donor_annot = self._select_individual()

common_types = sorted(set(recipient_annot.nodes_by_name.keys()).intersection(set(donor_annot.nodes_by_name.keys())))
recipient_options = self._filter_nodes((node for rule_name in common_types for node in recipient_annot.nodes_by_name[rule_name]), recipient_root, recipient_annot)
recipient_lookup = dict(recipient_annot.rules_by_name)
recipient_lookup.update(recipient_annot.quants_by_name)
recipient_lookup.update(recipient_annot.alts_by_name)

donor_lookup = dict(donor_annot.rules_by_name)
donor_lookup.update(donor_annot.quants_by_name)
donor_lookup.update(donor_annot.alts_by_name)
common_types = sorted((set(recipient_lookup.keys()) - {(x, ) for x in self._generator_factory._immutable_rules} - {('<INVALID>', )}) & set(donor_lookup.keys()))

recipient_options = [(rule_name, node) for rule_name in common_types for node in recipient_lookup[rule_name] if node.parent]
# Shuffle suitable nodes with sample.
for recipient_node in random.sample(recipient_options, k=len(recipient_options)):
donor_options = donor_annot.nodes_by_name[recipient_node.name]
for rule_name, recipient_node in random.sample(recipient_options, k=len(recipient_options)):
donor_options = donor_lookup[rule_name]
for donor_node in random.sample(donor_options, k=len(donor_options)):
# Make sure that the output tree won't exceed the depth limit.
if (recipient_annot.node_levels[recipient_node] + donor_annot.node_depths[donor_node] <= self._limit.depth
and recipient_annot.token_counts[recipient_root] - recipient_annot.token_counts[recipient_node] + donor_annot.token_counts[donor_node] < self._limit.tokens):
recipient_node = recipient_node.replace(donor_node)
return recipient_node.root
recipient_node.replace(donor_node)
return recipient_root

# If selection strategy fails, we practically cause the whole donor tree
# If selection strategy fails, we practically cause the whole recipient tree
# to be the result of recombination.
logger.debug('Could not find node pairs to recombine.')
return donor_root
return recipient_root

def _select_individual(self):
root, annot = self._population.select_individual()
Expand All @@ -348,29 +363,33 @@ def _add_individual(self, root, path):
# superfluous here, but we have no way of knowing that in advance
self._population.add_individual(root, Annotations(root), path)

def _filter_nodes(self, nodes, root, annot):
    """
    Select the items of ``nodes`` that can be regenerated within the current
    maximum depth and token limit.

    Nodes without a parent, nodes whose name is listed among the immutable
    rules of the generator factory, nodes without a name and ``'<INVALID>'``
    nodes are never selected.

    :param nodes: Iterable of candidate nodes.
    :param root: Root of the tree the candidates belong to; its token count
        is the baseline of the token limit check.
    :param annot: Annotations of the tree, providing ``node_levels`` and
        ``token_counts``.
    :return: The candidates that passed every check, in input order.
    :rtype: list
    """
    selected = []
    for node in nodes:
        if node.parent is None:
            continue
        if node.name in self._generator_factory._immutable_rules:
            continue
        if node.name in [None, '<INVALID>']:
            continue

        # Size registered for the rule of this node; rules without a
        # registered size count as zero depth and zero tokens.
        rule_size = self._generator_factory._rule_sizes.get(node.name, RuleSize(0, 0))
        if not annot.node_levels[node] + rule_size.depth < self._limit.depth:
            continue

        # Tokens of the tree that lie outside of the subtree of this node.
        tokens_outside = annot.token_counts[root] - annot.token_counts[node]
        if tokens_outside + rule_size.tokens < self._limit.tokens:
            selected.append(node)
    return selected


class Annotations:

def __init__(self, root):
def _annotate(current, level):
nonlocal current_rule_name
self.node_levels[current] = level

if isinstance(current, (UnlexerRule, UnparserRule)):
if current.name:
if current.name not in self.nodes_by_name:
self.nodes_by_name[current.name] = []
self.nodes_by_name[current.name].append(current)
current_rule_name = (current.name,)
if current_rule_name not in self.rules_by_name:
self.rules_by_name[current_rule_name] = []
self.rules_by_name[current_rule_name].append(current)
else:
current_rule_name = None
elif current_rule_name:
if isinstance(current, UnparserRuleQuantifier):
node_name = current_rule_name + ('q', current.idx,)
if node_name not in self.quants_by_name:
self.quants_by_name[node_name] = []
self.quants_by_name[node_name].append(current)
elif isinstance(current, UnparserRuleAlternative):
node_name = current_rule_name + ('a', current.alt_idx,)
if node_name not in self.alts_by_name:
self.alts_by_name[node_name] = []
self.alts_by_name[node_name].append(current)

self.node_depths[current] = 0
self.token_counts[current] = 0
Expand All @@ -380,8 +399,23 @@ def _annotate(current, level):
self.node_depths[current] = max(self.node_depths[current], self.node_depths[child] + 1)
self.token_counts[current] += self.token_counts[child] if isinstance(child, ParentRule) else child.size.tokens + 1

self.nodes_by_name = {}
current_rule_name = None
self.rules_by_name = {}
self.alts_by_name = {}
self.quants_by_name = {}
self.node_levels = {}
self.node_depths = {}
self.token_counts = {}
_annotate(root, 0)

@property
def rules(self):
    """
    Flat list of every node stored in ``rules_by_name``.

    :return: The nodes of all name buckets, concatenated in the iteration
        order of the dictionary.
    :rtype: list
    """
    flattened = []
    for bucket in self.rules_by_name.values():
        flattened.extend(bucket)
    return flattened

@property
def alts(self):
    """
    Flat list of every node stored in ``alts_by_name``.

    :return: The nodes of all name buckets, concatenated in the iteration
        order of the dictionary.
    :rtype: list
    """
    flattened = []
    for bucket in self.alts_by_name.values():
        flattened.extend(bucket)
    return flattened

@property
def quants(self):
    """
    Flat list of every node stored in ``quants_by_name``.

    :return: The nodes of all name buckets, concatenated in the iteration
        order of the dictionary.
    :rtype: list
    """
    flattened = []
    for bucket in self.quants_by_name.values():
        flattened.extend(bucket)
    return flattened

0 comments on commit f561895

Please sign in to comment.