Skip to content

Commit 6be6c0c

Browse files
grdsdevclaude
andauthored
feat(realtime): add presence enabled flag on join payload (#1229)
Co-authored-by: Claude <[email protected]>
1 parent 6e19a47 commit 6be6c0c

File tree

4 files changed

+140
-2
lines changed

4 files changed

+140
-2
lines changed

src/realtime/src/realtime/_async/channel.py

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,7 @@ def __init__(
8585
else {
8686
"config": {
8787
"broadcast": {"ack": False, "self": False},
88-
"presence": {"key": ""},
88+
"presence": {"key": "", "enabled": False},
8989
"private": False,
9090
}
9191
}
@@ -191,9 +191,16 @@ async def subscribe(
191191
else:
192192
config: RealtimeChannelConfig = self.params["config"]
193193
broadcast = config.get("broadcast")
194-
presence = config.get("presence")
194+
presence = config.get("presence") or RealtimeChannelPresenceConfig(
195+
key="", enabled=False
196+
)
195197
private = config.get("private", False)
196198

199+
presence_enabled = self.presence._has_callback_attached or presence.get(
200+
"enabled", False
201+
)
202+
presence["enabled"] = presence_enabled
203+
197204
config_payload: Dict[str, Any] = {
198205
"config": {
199206
"broadcast": broadcast,
@@ -429,6 +436,13 @@ def on_presence_sync(self, callback: Callable[[], None]) -> AsyncRealtimeChannel
429436
:return: The Channel instance for method chaining.
430437
"""
431438
self.presence.on_sync(callback)
439+
440+
if self.is_joined:
441+
logger.info(
442+
f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel"
443+
)
444+
asyncio.create_task(self._resubscribe())
445+
432446
return self
433447

434448
def on_presence_join(
@@ -441,6 +455,12 @@ def on_presence_join(
441455
:return: The Channel instance for method chaining.
442456
"""
443457
self.presence.on_join(callback)
458+
if self.is_joined:
459+
logger.info(
460+
f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel"
461+
)
462+
asyncio.create_task(self._resubscribe())
463+
444464
return self
445465

446466
def on_presence_leave(
@@ -453,6 +473,11 @@ def on_presence_leave(
453473
:return: The Channel instance for method chaining.
454474
"""
455475
self.presence.on_leave(callback)
476+
if self.is_joined:
477+
logger.info(
478+
f"channel {self.topic} resubscribe due to change in presence callbacks on joined channel"
479+
)
480+
asyncio.create_task(self._resubscribe())
456481
return self
457482

458483
# Broadcast methods
@@ -469,6 +494,11 @@ async def send_broadcast(self, event: str, data: Any) -> None:
469494
)
470495

471496
# Internal methods
497+
498+
async def _resubscribe(self) -> None:
499+
await self.unsubscribe()
500+
await self.subscribe()
501+
472502
def _broadcast_endpoint_url(self):
473503
return f"{http_endpoint_url(self.socket.http_endpoint)}/api/broadcast"
474504

src/realtime/src/realtime/_async/presence.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,14 @@
2121

2222

2323
class AsyncRealtimePresence:
24+
@property
25+
def _has_callback_attached(self) -> bool:
26+
return (
27+
self.on_join_callback is not None
28+
or self.on_leave_callback is not None
29+
or self.on_sync_callback is not None
30+
)
31+
2432
def __init__(self):
2533
self.state: RealtimePresenceState = {}
2634
self.on_join_callback: Optional[PresenceOnJoinCallback] = None

src/realtime/src/realtime/types.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -179,6 +179,7 @@ class RealtimeChannelBroadcastConfig(TypedDict):
179179

180180
class RealtimeChannelPresenceConfig(TypedDict):
181181
key: str
182+
enabled: bool
182183

183184

184185
class RealtimeChannelConfig(TypedDict):

src/realtime/tests/test_presence.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -173,3 +173,102 @@ def test_transform_state_additional_fields():
173173

174174
result = AsyncRealtimePresence._transform_state(state_with_additional_fields)
175175
assert result == expected_output
176+
177+
178+
def test_presence_has_callback_attached():
179+
"""Test that _has_callback_attached property correctly detects presence callbacks."""
180+
presence = AsyncRealtimePresence()
181+
182+
# Initially no callbacks should be attached
183+
assert not presence._has_callback_attached
184+
185+
# After setting sync callback
186+
presence.on_sync(lambda: None)
187+
assert presence._has_callback_attached
188+
189+
# Reset and test with join callback
190+
presence = AsyncRealtimePresence()
191+
presence.on_join(lambda key, current, new: None)
192+
assert presence._has_callback_attached
193+
194+
# Reset and test with leave callback
195+
presence = AsyncRealtimePresence()
196+
presence.on_leave(lambda key, current, left: None)
197+
assert presence._has_callback_attached
198+
199+
200+
def test_presence_config_includes_enabled_field():
201+
"""Test that presence config correctly includes enabled flag."""
202+
from realtime.types import RealtimeChannelPresenceConfig
203+
204+
# Test creating presence config with enabled field
205+
config: RealtimeChannelPresenceConfig = {"key": "user123", "enabled": True}
206+
assert config["key"] == "user123"
207+
assert config["enabled"] == True
208+
209+
# Test with enabled False
210+
config_disabled: RealtimeChannelPresenceConfig = {"key": "", "enabled": False}
211+
assert config_disabled["key"] == ""
212+
assert config_disabled["enabled"] == False
213+
214+
215+
@pytest.mark.asyncio
216+
async def test_presence_enabled_when_callbacks_attached():
217+
"""Test that presence.enabled is set correctly based on callback attachment."""
218+
from unittest.mock import AsyncMock, Mock
219+
220+
socket = AsyncRealtimeClient(f"{URL}/realtime/v1", ANON_KEY)
221+
channel = socket.channel("test")
222+
223+
# Mock the join_push to capture the payload
224+
mock_join_push = Mock()
225+
mock_join_push.receive = Mock(return_value=mock_join_push)
226+
mock_join_push.update_payload = Mock()
227+
mock_join_push.resend = AsyncMock()
228+
channel.join_push = mock_join_push
229+
230+
# Mock socket connection by setting _ws_connection
231+
mock_ws = Mock()
232+
socket._ws_connection = mock_ws
233+
socket._leave_open_topic = AsyncMock()
234+
235+
# Add presence callback before subscription
236+
channel.on_presence_sync(lambda: None)
237+
238+
await channel.subscribe()
239+
240+
# Verify that update_payload was called
241+
assert mock_join_push.update_payload.called
242+
243+
# Get the payload that was passed to update_payload
244+
call_args = mock_join_push.update_payload.call_args
245+
payload = call_args[0][0]
246+
247+
# Verify presence.enabled is True because callback is attached
248+
assert payload["config"]["presence"]["enabled"] == True
249+
250+
251+
@pytest.mark.asyncio
252+
async def test_resubscribe_on_presence_callback_addition():
253+
"""Test that channel resubscribes when presence callbacks are added after joining."""
254+
import asyncio
255+
from unittest.mock import AsyncMock
256+
257+
socket = AsyncRealtimeClient(f"{URL}/realtime/v1", ANON_KEY)
258+
channel = socket.channel("test")
259+
260+
# Mock the channel as joined
261+
channel.state = "joined"
262+
channel._joined_once = True
263+
264+
# Mock resubscribe method
265+
channel._resubscribe = AsyncMock()
266+
267+
# Add presence callbacks after joining
268+
channel.on_presence_sync(lambda: None)
269+
270+
# Wait a bit for async tasks to complete
271+
await asyncio.sleep(0.1)
272+
273+
# Verify resubscribe was called
274+
assert channel._resubscribe.call_count == 1

0 commit comments

Comments
 (0)