Skip to content

Commit

Permalink
fixed circular import and empty ns hostname lookup
Browse files Browse the repository at this point in the history
  • Loading branch information
irmen committed Apr 16, 2020
1 parent b7f80f4 commit df77ec5
Show file tree
Hide file tree
Showing 13 changed files with 97 additions and 96 deletions.
2 changes: 1 addition & 1 deletion Pyro5/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
Pyro - Python Remote Objects. Copyright by Irmen de Jong ([email protected]).
"""

__version__ = "5.9"
__version__ = "5.9.1"
__author__ = "Irmen de Jong"


Expand Down
3 changes: 2 additions & 1 deletion Pyro5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,12 @@

from . import __version__
from .configure import global_config as config
from .core import URI, locate_ns, resolve, type_meta, current_context
from .core import URI, locate_ns, resolve, type_meta
from .client import Proxy, BatchProxy, SerializedBlob
from .server import Daemon, DaemonObject, callback, expose, behavior, oneway, serve
from .nameserver import start_ns, start_ns_loop
from .serializers import SerializerBase
from .callcontext import current_context


__all__ = ["config", "URI", "locate_ns", "resolve", "type_meta", "current_context",
Expand Down
47 changes: 47 additions & 0 deletions Pyro5/callcontext.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import threading
from . import errors


# call context thread local
class _CallContext(threading.local):
def __init__(self):
# per-thread initialization
self.client = None
self.client_sock_addr = None
self.seq = 0
self.msg_flags = 0
self.serializer_id = 0
self.annotations = {}
self.response_annotations = {}
self.correlation_id = None

def to_global(self):
return dict(self.__dict__)

def from_global(self, values):
self.client = values["client"]
self.seq = values["seq"]
self.msg_flags = values["msg_flags"]
self.serializer_id = values["serializer_id"]
self.annotations = values["annotations"]
self.response_annotations = values["response_annotations"]
self.correlation_id = values["correlation_id"]
self.client_sock_addr = values["client_sock_addr"]

def track_resource(self, resource):
"""keep a weak reference to the resource to be tracked for this connection"""
if self.client:
self.client.tracked_resources.add(resource)
else:
raise errors.PyroError("cannot track resource on a connectionless call")

def untrack_resource(self, resource):
"""no longer track the resource for this connection"""
if self.client:
self.client.tracked_resources.discard(resource)
else:
raise errors.PyroError("cannot untrack resource on a connectionless call")


current_context = _CallContext()
"""the context object for the current call. (thread-local)"""
13 changes: 7 additions & 6 deletions Pyro5/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import serpent
import contextlib
from . import config, core, serializers, protocol, errors, socketutil
from .callcontext import current_context
try:
from greenlet import getcurrent as get_ident
except ImportError:
Expand Down Expand Up @@ -71,7 +72,7 @@ def __init__(self, uri, connected_socket=None):
# note: we're not clearing the client annotations dict here.
# that is because otherwise it will be wiped if a new proxy is needed to connect PYRONAME uris.
# clearing the response annotations is okay.
core.current_context.response_annotations = {}
current_context.response_annotations = {}
if connected_socket:
self.__pyroCreateConnection(False, connected_socket)

Expand Down Expand Up @@ -196,12 +197,12 @@ def __pyroSetTimeout(self, timeout):
def _pyroInvoke(self, methodname, vargs, kwargs, flags=0, objectId=None):
"""perform the remote method call communication"""
self.__check_owner()
core.current_context.response_annotations = {}
current_context.response_annotations = {}
if self._pyroConnection is None:
self.__pyroCreateConnection()
serializer = serializers.serializers[self._pyroSerializer or config.SERIALIZER]
objectId = objectId or self._pyroConnection.objectId
annotations = core.current_context.annotations
annotations = current_context.annotations
if vargs and isinstance(vargs[0], SerializedBlob):
# special serialization of a 'blob' that stays serialized
data, flags = self.__serializeBlobArgs(vargs, kwargs, annotations, flags, objectId, methodname, serializer)
Expand Down Expand Up @@ -229,7 +230,7 @@ def _pyroInvoke(self, methodname, vargs, kwargs, flags=0, objectId=None):
log.error(error)
raise errors.SerializeError(error)
if msg.annotations:
core.current_context.response_annotations = msg.annotations
current_context.response_annotations = msg.annotations
if self._pyroRawWireResponse:
return msg
data = serializer.loads(msg.data)
Expand Down Expand Up @@ -285,7 +286,7 @@ def connect_and_handshake(conn):
data = {"handshake": self._pyroHandshake, "object": uri.object}
data = serializer.dumps(data)
msg = protocol.SendingMessage(protocol.MSG_CONNECT, 0, self._pyroSeq, serializer.serializer_id,
data, annotations=core.current_context.annotations)
data, annotations=current_context.annotations)
if config.LOGWIRE:
protocol.log_wiredata(log, "proxy connect sending", msg)
conn.send(msg.data)
Expand Down Expand Up @@ -320,7 +321,7 @@ def connect_and_handshake(conn):
self._pyroValidateHandshake(handshake_response)
log.debug("connected to %s - %s - %s", self._pyroUri, conn.family(), "SSL" if sslContext else "unencrypted")
if msg.annotations:
core.current_context.response_annotations = msg.annotations
current_context.response_annotations = msg.annotations
else:
conn.close()
err = "cannot connect to %s: invalid msg type %d received" % (connect_location, msg.type)
Expand Down
5 changes: 0 additions & 5 deletions Pyro5/compatibility/Pyro4.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,6 @@
Pyro - Python Remote Objects. Copyright by Irmen de Jong ([email protected]).
"""

# the symbols that were available in Pyro4 as Pyro4.* :
# from Pyro4.core import URI, Proxy, Daemon, callback, batch, asyncproxy, oneway, expose, behavior, current_context
# from Pyro4.core import _locateNS as locateNS, _resolve as resolve
# from Pyro4.futures import Future

import sys
import ipaddress

Expand Down
56 changes: 4 additions & 52 deletions Pyro5/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,14 +9,13 @@
import contextlib
import ipaddress
import socket
import threading
import random
import serpent
from typing import Union, Optional
from . import config, errors, socketutil, serializers, nameserver


__all__ = ["URI", "DAEMON_NAME", "NAMESERVER_NAME", "current_context", "resolve", "locate_ns", "type_meta"]
__all__ = ["URI", "DAEMON_NAME", "NAMESERVER_NAME", "resolve", "locate_ns", "type_meta"]

log = logging.getLogger("Pyro5.core")

Expand Down Expand Up @@ -203,13 +202,11 @@ def resolve(uri: Union[str, URI], delay_time: float = 0.0) -> URI:
raise errors.PyroError("invalid uri protocol")


from . import client # circular import...


def locate_ns(host: Union[str, ipaddress.IPv4Address, ipaddress.IPv6Address] = "",
port: Optional[int] = None, broadcast: bool = True) -> client.Proxy:
port: Optional[int] = None, broadcast: bool = True) -> "client.Proxy":
"""Get a proxy for a name server somewhere in the network."""
if host == "":
from . import client
if not host:
# first try localhost if we have a good chance of finding it there
if config.NS_HOST in ("localhost", "::1") or config.NS_HOST.startswith("127."):
if ":" in config.NS_HOST: # ipv6
Expand Down Expand Up @@ -295,48 +292,3 @@ def type_meta(class_or_object, prefix="class:"):
if hasattr(class_or_object, "__class__"):
return type_meta(class_or_object.__class__)
return frozenset()


# call context thread local
class _CallContext(threading.local):
def __init__(self):
# per-thread initialization
self.client = None
self.client_sock_addr = None
self.seq = 0
self.msg_flags = 0
self.serializer_id = 0
self.annotations = {}
self.response_annotations = {}
self.correlation_id = None

def to_global(self):
return dict(self.__dict__)

def from_global(self, values):
self.client = values["client"]
self.seq = values["seq"]
self.msg_flags = values["msg_flags"]
self.serializer_id = values["serializer_id"]
self.annotations = values["annotations"]
self.response_annotations = values["response_annotations"]
self.correlation_id = values["correlation_id"]
self.client_sock_addr = values["client_sock_addr"]

def track_resource(self, resource):
"""keep a weak reference to the resource to be tracked for this connection"""
if self.client:
self.client.tracked_resources.add(resource)
else:
raise errors.PyroError("cannot track resource on a connectionless call")

def untrack_resource(self, resource):
"""no longer track the resource for this connection"""
if self.client:
self.client.tracked_resources.discard(resource)
else:
raise errors.PyroError("cannot untrack resource on a connectionless call")


current_context = _CallContext()
"""the context object for the current call. (thread-local)"""
2 changes: 1 addition & 1 deletion Pyro5/nsc.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def cmd_yplookup_any():
def main(args=None):
from argparse import ArgumentParser
parser = ArgumentParser(description="Pyro name server control utility.")
parser.add_argument("-n", "--host", dest="host", help="hostname of the NS")
parser.add_argument("-n", "--host", dest="host", help="hostname of the NS", default="")
parser.add_argument("-p", "--port", dest="port", type=int, help="port of the NS (or bc-port if host isn't specified)")
parser.add_argument("-u", "--unixsocket", help="Unix domain socket name of the NS")
parser.add_argument("-v", "--verbose", action="store_true", dest="verbose", help="verbose output")
Expand Down
8 changes: 3 additions & 5 deletions Pyro5/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
import zlib
import uuid
from . import config, errors
from .callcontext import current_context


log = logging.getLogger("Pyro5.protocol")
Expand Down Expand Up @@ -66,9 +67,6 @@
_empty_correlation_id = b"\0" * 16


from . import core # circular import...


class SendingMessage:
"""Wire protocol message that will be sent."""

Expand All @@ -86,9 +84,9 @@ def __init__(self, msgtype, flags, seq, serializer_id, payload, annotations=None
total_size = len(payload) + annotations_size
if total_size > config.MAX_MESSAGE_SIZE:
raise errors.ProtocolError("message too large ({:d}, max={:d})".format(total_size, config.MAX_MESSAGE_SIZE))
if core.current_context.correlation_id:
if current_context.correlation_id:
flags |= FLAGS_CORR_ID
self.corr_id = core.current_context.correlation_id.bytes
self.corr_id = current_context.correlation_id.bytes
else:
self.corr_id = _empty_correlation_id
header_data = struct.pack(_header_format, b"PYRO", PROTOCOL_VERSION, msgtype, serializer_id, flags, seq,
Expand Down
34 changes: 17 additions & 17 deletions Pyro5/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import ipaddress
from typing import Callable, Tuple, Union, Optional, Dict, Any, Sequence, Set
from . import config, core, errors, serializers, socketutil, protocol, client

from .callcontext import current_context

__all__ = ["Daemon", "DaemonObject", "callback", "expose", "behavior", "oneway", "serve"]

Expand Down Expand Up @@ -175,7 +175,7 @@ def get_next_stream_item(self, streamId):
client, timestamp, linger_timestamp, stream = self.daemon.streaming_responses[streamId]
if client is None:
# reset client connection association (can be None if proxy disconnected)
self.daemon.streaming_responses[streamId] = (core.current_context.client, timestamp, 0, stream)
self.daemon.streaming_responses[streamId] = (current_context.client, timestamp, 0, stream)
try:
return next(stream)
except Exception:
Expand Down Expand Up @@ -331,9 +331,9 @@ def _handshake(self, conn, denied_reason=None):
if config.LOGWIRE:
protocol.log_wiredata(log, "daemon handshake received", msg)
if msg.flags & protocol.FLAGS_CORR_ID:
core.current_context.correlation_id = uuid.UUID(bytes=msg.corr_id)
current_context.correlation_id = uuid.UUID(bytes=msg.corr_id)
else:
core.current_context.correlation_id = uuid.uuid4()
current_context.correlation_id = uuid.uuid4()
serializer_id = msg.serializer_id
serializer = serializers.serializers_by_id[serializer_id]
data = serializer.loads(msg.data)
Expand Down Expand Up @@ -397,9 +397,9 @@ def handleRequest(self, conn):
request_seq = msg.seq
request_serializer_id = msg.serializer_id
if msg.flags & protocol.FLAGS_CORR_ID:
core.current_context.correlation_id = uuid.UUID(bytes=msg.corr_id)
current_context.correlation_id = uuid.UUID(bytes=msg.corr_id)
else:
core.current_context.correlation_id = uuid.uuid4()
current_context.correlation_id = uuid.uuid4()
if config.LOGWIRE:
protocol.log_wiredata(log, "daemon wiredata received", msg)
if msg.type == protocol.MSG_PING:
Expand All @@ -416,16 +416,16 @@ def handleRequest(self, conn):
else:
# normal deserialization of remote call arguments
objId, method, vargs, kwargs = serializer.loadsCall(msg.data)
core.current_context.client = conn
current_context.client = conn
try:
# store, because on oneway calls, socket will be disconnected:
core.current_context.client_sock_addr = conn.sock.getpeername()
current_context.client_sock_addr = conn.sock.getpeername()
except socket.error:
core.current_context.client_sock_addr = None # sometimes getpeername() doesn't work...
core.current_context.seq = msg.seq
core.current_context.annotations = msg.annotations
core.current_context.msg_flags = msg.flags
core.current_context.serializer_id = msg.serializer_id
current_context.client_sock_addr = None # sometimes getpeername() doesn't work...
current_context.seq = msg.seq
current_context.annotations = msg.annotations
current_context.msg_flags = msg.flags
current_context.serializer_id = msg.serializer_id
del msg # invite GC to collect the object, don't wait for out-of-scope
obj = self.objectsById.get(objId)
if obj is not None:
Expand Down Expand Up @@ -489,7 +489,7 @@ def handleRequest(self, conn):
response_flags |= protocol.FLAGS_BATCH
msg = protocol.SendingMessage(protocol.MSG_RESULT, response_flags, request_seq, serializer.serializer_id, data,
annotations=self.__annotations())
core.current_context.response_annotations = {}
current_context.response_annotations = {}
if config.LOGWIRE:
protocol.log_wiredata(log, "daemon wiredata sending", msg)
conn.send(msg.data)
Expand Down Expand Up @@ -753,7 +753,7 @@ def combine(self, daemon):
self.transportServer.combine_loop(daemon.transportServer)

def __annotations(self):
annotations = core.current_context.response_annotations
annotations = current_context.response_annotations
annotations.update(self.annotations())
return annotations

Expand Down Expand Up @@ -965,15 +965,15 @@ class _OnewayCallThread(threading.Thread):
def __init__(self, pyro_method, vargs, kwargs, pyro_daemon, pyro_client_sock):
super(_OnewayCallThread, self).__init__(target=self._methodcall, name="oneway-call")
self.daemon = True
self.parent_context = core.current_context.to_global()
self.parent_context = current_context.to_global()
self.pyro_daemon = pyro_daemon
self.pyro_client_sock = pyro_client_sock
self.pyro_method = pyro_method
self.pyro_vargs = vargs
self.pyro_kwars = kwargs

def run(self):
core.current_context.from_global(self.parent_context)
current_context.from_global(self.parent_context)
super(_OnewayCallThread, self).run()

def _methodcall(self):
Expand Down
Loading

0 comments on commit df77ec5

Please sign in to comment.