Skip to content

Commit

Permalink
Fix CombineToNode pickling
Browse files Browse the repository at this point in the history
  • Loading branch information
evhub committed Oct 19, 2024
1 parent c439630 commit c60d601
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions coconut/compiler/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,7 @@ class CombineToNode(Combine, pickleable_obj):
"""Modified Combine to work with the computation graph."""
__slots__ = ()
validation_dict = None
pickling_enabled = False

def _combine(self, original, loc, tokens):
"""Implement the parse action for Combine."""
Expand All @@ -475,25 +476,23 @@ def postParse(self, original, loc, tokens):
"""Create a ComputationNode for Combine."""
return ComputationNode(self._combine, original, loc, tokens, ignore_no_tokens=True, ignore_one_token=True, trim_arity=False)

@classmethod
def reconstitute(self, identifier):
return identifier_to_parse_elem(identifier, self.validation_dict)

def __reduce__(self):
if self.validation_dict is None:
return super(CombineToNode, self).__reduce__()
if self.pickling_enabled:
return (identifier_to_parse_elem, (parse_elem_to_identifier(self, self.validation_dict),))
else:
return (self.reconstitute, (parse_elem_to_identifier(self, self.validation_dict),))
return super(CombineToNode, self).__reduce__()

@classmethod
@contextmanager
def enable_pickling(validation_dict={}):
def enable_pickling(cls, validation_dict=None):
"""Context manager to enable pickling for CombineToNode."""
old_validation_dict, CombineToNode.validation_dict = CombineToNode.validation_dict, validation_dict
old_validation_dict, cls.validation_dict = cls.validation_dict, validation_dict
old_pickling_enabled, cls.pickling_enabled = cls.pickling_enabled, True
try:
yield
finally:
CombineToNode.validation_dict = old_validation_dict
cls.pickling_enabled = old_pickling_enabled
cls.validation_dict = old_validation_dict


if USE_COMPUTATION_GRAPH:
Expand Down

0 comments on commit c60d601

Please sign in to comment.