diff --git a/coconut/compiler/util.py b/coconut/compiler/util.py index c1fd1cda..d146f7e9 100644 --- a/coconut/compiler/util.py +++ b/coconut/compiler/util.py @@ -462,6 +462,7 @@ class CombineToNode(Combine, pickleable_obj): """Modified Combine to work with the computation graph.""" __slots__ = () validation_dict = None + pickling_enabled = False def _combine(self, original, loc, tokens): """Implement the parse action for Combine.""" @@ -475,25 +476,23 @@ def postParse(self, original, loc, tokens): """Create a ComputationNode for Combine.""" return ComputationNode(self._combine, original, loc, tokens, ignore_no_tokens=True, ignore_one_token=True, trim_arity=False) - @classmethod - def reconstitute(self, identifier): - return identifier_to_parse_elem(identifier, self.validation_dict) - def __reduce__(self): - if self.validation_dict is None: - return super(CombineToNode, self).__reduce__() + if self.pickling_enabled: + return (identifier_to_parse_elem, (parse_elem_to_identifier(self, self.validation_dict),)) else: - return (self.reconstitute, (parse_elem_to_identifier(self, self.validation_dict),)) + return super(CombineToNode, self).__reduce__() @classmethod @contextmanager - def enable_pickling(validation_dict={}): + def enable_pickling(cls, validation_dict=None): """Context manager to enable pickling for CombineToNode.""" - old_validation_dict, CombineToNode.validation_dict = CombineToNode.validation_dict, validation_dict + old_validation_dict, cls.validation_dict = cls.validation_dict, validation_dict + old_pickling_enabled, cls.pickling_enabled = cls.pickling_enabled, True try: yield finally: - CombineToNode.validation_dict = old_validation_dict + cls.pickling_enabled = old_pickling_enabled + cls.validation_dict = old_validation_dict if USE_COMPUTATION_GRAPH: