Skip to content

Commit 9aa1377

Browse files
committed
fixes
1 parent a6bb296 commit 9aa1377

File tree

4 files changed

+58
-66
lines changed

4 files changed

+58
-66
lines changed

tests/codec_test.py

Lines changed: 23 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1,54 +1,34 @@
11
import os
2-
import base64
32
from pathlib import Path
4-
from typing import Type, TypeVar, Any
3+
from typing import Type, TypeVar
54

65
import capnp
76
from google.protobuf.message import Message
87

98
from xconn.client import connect_anonymous
10-
from xconn import codec
11-
from xconn.types import Event
9+
from xconn.codec import Codec
10+
from xconn.types import Event, OutgoingDataMessage, IncomingDataMessage
1211
from tests.schemas.profile_pb2 import ProfileCreate, ProfileGet
1312

1413

15-
class String(str):
16-
pass
14+
T = TypeVar("T", bound=Message)
1715

1816

19-
class Base64Codec(codec.Codec[String]):
20-
def name(self) -> str:
21-
return "base64"
22-
23-
def encode(self, obj: String) -> str:
24-
return base64.b64encode(obj.encode("utf-8")).decode("utf-8")
25-
26-
def decode(self, data: str, out_type: Type[String]) -> String:
27-
return out_type(base64.b64decode(data.encode("utf-8")).decode())
28-
29-
30-
def test_base64_codec():
31-
encoder = Base64Codec()
32-
encoded = encoder.encode(String("hello"))
33-
assert isinstance(encoded, str)
34-
35-
decoded = encoder.decode(encoded, String)
36-
assert isinstance(decoded, String)
37-
assert decoded == "hello"
38-
39-
40-
class ProtobufCodec(codec.Codec[Message]):
17+
class ProtobufCodec(Codec[T]):
4118
def name(self) -> str:
4219
return "protobuf"
4320

44-
def encode(self, obj: Message) -> bytes:
45-
return obj.SerializeToString()
21+
def encode(self, obj: T) -> OutgoingDataMessage:
22+
payload = obj.SerializeToString()
23+
return OutgoingDataMessage(args=[payload], kwargs={}, details={})
4624

47-
def decode(self, data: bytes, out_type: Type[Message]) -> Message:
48-
msg = out_type()
49-
msg.ParseFromString(data)
25+
def decode(self, msg: IncomingDataMessage, out_type: Type[T]) -> T:
26+
if len(msg.args) == 0 or not isinstance(msg.args[0], bytes):
27+
raise ValueError("ProtobufCodec: cannot decode, expected first arg to be bytes")
5028

51-
return msg
29+
obj = out_type()
30+
obj.ParseFromString(msg.args[0])
31+
return obj
5232

5333

5434
def test_rpc_object_protobuf():
@@ -78,20 +58,6 @@ def inv_handler(profile: ProfileCreate) -> ProfileGet:
7858
session.leave()
7959

8060

81-
def test_pubsub_object():
82-
session = connect_anonymous("ws://localhost:8080/ws", "realm1")
83-
session.set_payload_codec(Base64Codec())
84-
85-
def event_handler(event: Event):
86-
assert event.args[0] == "hello"
87-
88-
session.subscribe_object("io.xconn.object", event_handler, String)
89-
90-
session.publish_object("io.xconn.object", String("hello"))
91-
92-
session.leave()
93-
94-
9561
def test_pubsub_protobuf():
9662
session = connect_anonymous("ws://localhost:8080/ws", "realm1")
9763
session.set_payload_codec(ProtobufCodec())
@@ -192,15 +158,19 @@ def get_profile_handler() -> ProfileGet:
192158
UserGet = user_capnp.UserGet
193159

194160

195-
class CapnpProtoCodec(codec.Codec[T]):
161+
class CapnpProtoCodec(Codec[T]):
196162
def name(self) -> str:
197163
return "capnproto"
198164

199-
def encode(self, obj: Any) -> bytes:
200-
return obj.to_bytes_packed()
165+
def encode(self, obj: T) -> OutgoingDataMessage:
166+
payload = obj.to_bytes_packed()
167+
return OutgoingDataMessage(args=[payload], kwargs={}, details={})
168+
169+
def decode(self, msg: IncomingDataMessage, out_type: Type[T]) -> T:
170+
if len(msg.args) == 0 or not isinstance(msg.args[0], bytes):
171+
raise ValueError("CapnpProtoCodec: cannot decode, expected first arg to be bytes")
201172

202-
def decode(self, data: bytes, out_type: Type[T]) -> T:
203-
return out_type.from_bytes_packed(data)
173+
return out_type.from_bytes_packed(msg.args[0])
204174

205175

206176
def test_rpc_object_capnproto():

xconn/codec.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
from typing import Any, Generic, Type, TypeVar
1+
from typing import Generic, Type, TypeVar
2+
3+
from xconn.types import IncomingDataMessage, OutgoingDataMessage
24

35
T = TypeVar("T")
46

@@ -7,10 +9,10 @@ class Codec(Generic[T]):
79
def name(self) -> str:
810
raise NotImplementedError
911

10-
def encode(self, obj: Any) -> bytes | str:
12+
def encode(self, obj: T) -> OutgoingDataMessage:
1113
"""Serialize a Python object to bytes."""
1214
raise NotImplementedError
1315

14-
def decode(self, data: bytes | str, out_type: Type[T]) -> T:
15-
"""Deserialize bytes into an instance of out_type."""
16+
def decode(self, msg: IncomingDataMessage, out_type: Type[T]) -> T:
17+
"""Deserialize the incoming message into an instance of out_type."""
1618
raise NotImplementedError

xconn/session.py

Lines changed: 15 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import inspect
44
from concurrent.futures import Future
55
from threading import Thread
6-
from typing import Callable, Any, TypeVar, Type
6+
from typing import Callable, Any, TypeVar, Type, overload
77
from dataclasses import dataclass
88

99
from wampproto import messages, session, uris
@@ -180,13 +180,25 @@ def _process_incoming_message(self, msg: messages.Message):
180180
def set_payload_codec(self, codec: Codec) -> None:
181181
self._payload_codec = codec
182182

183-
def call_object(self, procedure: str, request: TReq = None, return_type: Type[TRes] = None) -> TRes | None:
183+
@overload
184+
def call_object(self, procedure: str, request: TReq, return_type: Type[TRes]) -> TRes:
185+
...
186+
187+
@overload
188+
def call_object(self, procedure: str, request: None = None, return_type: None = None) -> None:
189+
...
190+
191+
@overload
192+
def call_object(self, procedure: str, request: None, return_type: Type[TRes]) -> TRes:
193+
...
194+
195+
def call_object(self, procedure: str, request: TReq = None, return_type: Type[TRes] | None = None) -> TRes | None:
184196
if self._payload_codec is None:
185197
raise ValueError("no payload codec set")
186198

187199
if request is not None:
188200
encoded = self._payload_codec.encode(request)
189-
result = self.call(procedure, [encoded])
201+
result = self.call(procedure, args=encoded.args, kwargs=encoded.kwargs, options=encoded.details)
190202
else:
191203
result = self.call(procedure)
192204

xconn/types.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -26,26 +26,34 @@ class UnsubscribeRequest:
2626

2727

2828
@dataclass
29-
class Result:
29+
class OutgoingDataMessage:
3030
args: list | None = None
3131
kwargs: dict | None = None
3232
details: dict | None = None
3333

3434

3535
@dataclass
36-
class Invocation:
37-
args: list | None
38-
kwargs: dict | None
39-
details: dict | None
36+
class Result(OutgoingDataMessage):
37+
pass
4038

4139

4240
@dataclass
43-
class Event:
41+
class IncomingDataMessage:
4442
args: list | None
4543
kwargs: dict | None
4644
details: dict | None
4745

4846

47+
@dataclass
48+
class Invocation(IncomingDataMessage):
49+
pass
50+
51+
52+
@dataclass()
53+
class Event(IncomingDataMessage):
54+
pass
55+
56+
4957
@dataclass
5058
class TransportConfig:
5159
# max wait time for connection to be established

0 commit comments

Comments
 (0)