Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Function feature with modulepickle #9

Open
wants to merge 4 commits into
base: modulepickle
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 24 additions & 4 deletions pyzoo/zoo/util/modulepickle.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import hashlib
import sys
from logging import getLogger
import types

__all__ = ('extend', 'extend_ray', 'extend_cloudpickle')

Expand Down Expand Up @@ -74,11 +75,16 @@ def compress(packagename, path):
return tar.getvalue()


def import_compressed(name, package, class_name):
def import_compressed(name, package, class_name,is_anyfunc):
res = package.load(name)
if getattr(res, class_name, None):
class_type = getattr(res, class_name)
return class_type.__new__(class_type)
obj_type = getattr(res, class_name)

return obj_type.__new__(obj_type) if not is_anyfunc else types.FunctionType(getattr(obj_type, "__code__", ""),
getattr(obj_type, "__globals__", ""),
name=getattr(obj_type, "__name__", ""),
argdefs=getattr(obj_type, "__defaults__", ""),
closure=getattr(obj_type, "__closure__", ""))
else:
return res

Expand All @@ -101,6 +107,9 @@ def is_local(module):
if path is None:
return False

# if your zoo is not installed by whl,
# to debug codes you may exclude your az path from loacl path
# to avoid infinite resursion.
if path.startswith(python_lib_path):
return False

Expand Down Expand Up @@ -157,7 +166,18 @@ def reducer_override(self, obj):
else:
print("get local {} in save_module, path is {}".format(module.__name__, module.__file__))
package = self.compress_package(packagename(module), get_path(module))
args = (module.__name__, package, obj.__class__.__name__)

try:
# todo:Should check class type first
is_anyfunc=isinstance(obj, types.FunctionType)
except TypeError: # t is not a class (old Boost; see SF #502085)
is_anyfunc = False

if is_anyfunc:
args = (module.__name__, package, obj.__name__,is_anyfunc)
else:
args = (module.__name__, package, obj.__class__.__name__,is_anyfunc)

return import_compressed, args, obj.__dict__
return super().reducer_override(obj)

Expand Down