Skip to content

Commit b596faa

Browse files
Cache _get_e2e_cross_signing_signatures_for_devices (#18899)
1 parent 6f9fab1 commit b596faa

File tree

5 files changed

+196
-59
lines changed

5 files changed

+196
-59
lines changed

changelog.d/18899.feature

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Add an in-memory cache to `_get_e2e_cross_signing_signatures_for_devices` to reduce DB load.

synapse/storage/database.py

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2653,15 +2653,22 @@ def make_in_list_sql_clause(
26532653

26542654

26552655
# These overloads ensure that `columns` and `iterable` values have the same length.
2656-
# Suppress "Single overload definition, multiple required" complaint.
2657-
@overload # type: ignore[misc]
2656+
@overload
26582657
def make_tuple_in_list_sql_clause(
26592658
database_engine: BaseDatabaseEngine,
26602659
columns: Tuple[str, str],
26612660
iterable: Collection[Tuple[Any, Any]],
26622661
) -> Tuple[str, list]: ...
26632662

26642663

2664+
@overload
2665+
def make_tuple_in_list_sql_clause(
2666+
database_engine: BaseDatabaseEngine,
2667+
columns: Tuple[str, str, str],
2668+
iterable: Collection[Tuple[Any, Any, Any]],
2669+
) -> Tuple[str, list]: ...
2670+
2671+
26652672
def make_tuple_in_list_sql_clause(
26662673
database_engine: BaseDatabaseEngine,
26672674
columns: Tuple[str, ...],

synapse/storage/databases/main/cache.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121

2222

2323
import itertools
24+
import json
2425
import logging
2526
from typing import TYPE_CHECKING, Any, Collection, Iterable, List, Optional, Tuple
2627

@@ -62,6 +63,12 @@
6263
# As above, but for invalidating room caches on room deletion
6364
DELETE_ROOM_CACHE_NAME = "dr_cache_fake"
6465

66+
# This cache takes a list of tuples as its first argument, which requires
67+
# special handling.
68+
GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME = (
69+
"_get_e2e_cross_signing_signatures_for_device"
70+
)
71+
6572
# How long between cache invalidation table cleanups, once we have caught up
6673
# with the backlog.
6774
REGULAR_CLEANUP_INTERVAL_MS = Config.parse_duration("1h")
@@ -270,6 +277,33 @@ def process_replication_rows(
270277
# room membership.
271278
#
272279
# self._membership_stream_cache.all_entities_changed(token) # type: ignore[attr-defined]
280+
elif (
281+
row.cache_func
282+
== GET_E2E_CROSS_SIGNING_SIGNATURES_FOR_DEVICE_CACHE_NAME
283+
):
284+
# "keys" is a list of strings, where each string is a
285+
# JSON-encoded representation of the tuple keys, i.e.
286+
# keys: ['["@userid:domain", "DEVICEID"]','["@userid2:domain", "DEVICEID2"]']
287+
#
288+
# This is a side-effect of not being able to send nested
289+
# information over replication.
290+
for json_str in row.keys:
291+
try:
292+
user_id, device_id = json.loads(json_str)
293+
except (json.JSONDecodeError, TypeError):
294+
logger.error(
295+
"Failed to deserialise cache key as valid JSON: %s",
296+
json_str,
297+
)
298+
continue
299+
300+
# Invalidate each key.
301+
#
302+
# Note: .invalidate takes a tuple of arguments, hence the need
303+
# to nest our tuple in another tuple.
304+
self._get_e2e_cross_signing_signatures_for_device.invalidate( # type: ignore[attr-defined]
305+
((user_id, device_id),)
306+
)
273307
else:
274308
self._attempt_to_invalidate_cache(row.cache_func, row.keys)
275309

synapse/storage/databases/main/end_to_end_keys.py

Lines changed: 147 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#
2121
#
2222
import abc
23+
import json
2324
from typing import (
2425
TYPE_CHECKING,
2526
Any,
@@ -354,15 +355,17 @@ async def get_e2e_device_keys_and_signatures(
354355
)
355356

356357
for batch in batch_iter(signature_query, 50):
357-
cross_sigs_result = await self.db_pool.runInteraction(
358-
"get_e2e_cross_signing_signatures_for_devices",
359-
self._get_e2e_cross_signing_signatures_for_devices_txn,
360-
batch,
358+
cross_sigs_result = (
359+
await self._get_e2e_cross_signing_signatures_for_devices(batch)
361360
)
362361

363362
# add each cross-signing signature to the correct device in the result dict.
364-
for user_id, key_id, device_id, signature in cross_sigs_result:
363+
for (
364+
user_id,
365+
device_id,
366+
), signature_list in cross_sigs_result.items():
365367
target_device_result = result[user_id][device_id]
368+
366369
# We've only looked up cross-signatures for non-deleted devices with key
367370
# data.
368371
assert target_device_result is not None
@@ -373,7 +376,9 @@ async def get_e2e_device_keys_and_signatures(
373376
signing_user_signatures = target_device_signatures.setdefault(
374377
user_id, {}
375378
)
376-
signing_user_signatures[key_id] = signature
379+
380+
for key_id, signature in signature_list:
381+
signing_user_signatures[key_id] = signature
377382

378383
log_kv(result)
379384
return result
@@ -479,41 +484,83 @@ def get_e2e_device_keys_txn(
479484

480485
return result
481486

482-
def _get_e2e_cross_signing_signatures_for_devices_txn(
483-
self, txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
484-
) -> List[Tuple[str, str, str, str]]:
485-
"""Get cross-signing signatures for a given list of devices
487+
@cached()
488+
def _get_e2e_cross_signing_signatures_for_device(
489+
self,
490+
user_id_and_device_id: Tuple[str, str],
491+
) -> Sequence[Tuple[str, str]]:
492+
"""
493+
The single-item version of `_get_e2e_cross_signing_signatures_for_devices`.
494+
See @cachedList for why a separate method is needed.
495+
"""
496+
raise NotImplementedError()
497+
498+
@cachedList(
499+
cached_method_name="_get_e2e_cross_signing_signatures_for_device",
500+
list_name="device_query",
501+
)
502+
async def _get_e2e_cross_signing_signatures_for_devices(
503+
self, device_query: Iterable[Tuple[str, str]]
504+
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
505+
"""Get cross-signing signatures for a given list of user IDs and devices.
506+
507+
Args:
508+
An iterable containing tuples of (user ID, device ID).
509+
510+
Returns:
511+
A mapping of results. The keys are the original (user_id, device_id)
512+
tuple, while the value is the matching list of tuples of
513+
(key_id, signature). The value will be an empty list if no
514+
signatures exist for the device.
486515
487-
Returns signatures made by the owners of the devices.
516+
Given this method is annotated with `@cachedList`, the return dict's
517+
keys match the tuples within `device_query`, so that cache entries can
518+
be computed from the corresponding values.
488519
489-
Returns: a list of results; each entry in the list is a tuple of
490-
(user_id, key_id, target_device_id, signature).
520+
As results are cached, the return type is immutable.
491521
"""
492-
signature_query_clauses = []
493-
signature_query_params = []
494522

495-
for user_id, device_id in device_query:
496-
signature_query_clauses.append(
497-
"target_user_id = ? AND target_device_id = ? AND user_id = ?"
523+
def _get_e2e_cross_signing_signatures_for_devices_txn(
524+
txn: LoggingTransaction, device_query: Iterable[Tuple[str, str]]
525+
) -> Mapping[Tuple[str, str], Sequence[Tuple[str, str]]]:
526+
where_clause_sql, where_clause_params = make_tuple_in_list_sql_clause(
527+
self.database_engine,
528+
columns=("target_user_id", "target_device_id", "user_id"),
529+
iterable=[
530+
(user_id, device_id, user_id) for user_id, device_id in device_query
531+
],
498532
)
499-
signature_query_params.extend([user_id, device_id, user_id])
500-
501-
signature_sql = """
502-
SELECT user_id, key_id, target_device_id, signature
503-
FROM e2e_cross_signing_signatures WHERE %s
504-
""" % (" OR ".join("(" + q + ")" for q in signature_query_clauses))
505-
506-
txn.execute(signature_sql, signature_query_params)
507-
return cast(
508-
List[
509-
Tuple[
510-
str,
511-
str,
512-
str,
513-
str,
514-
]
515-
],
516-
txn.fetchall(),
533+
534+
signature_sql = f"""
535+
SELECT user_id, key_id, target_device_id, signature
536+
FROM e2e_cross_signing_signatures WHERE {where_clause_sql}
537+
"""
538+
539+
txn.execute(signature_sql, where_clause_params)
540+
541+
devices_and_signatures: Dict[Tuple[str, str], List[Tuple[str, str]]] = {}
542+
543+
# `@cachedList` requires we return one key for every item in `device_query`.
544+
# Pre-populate `devices_and_signatures` with each key so that none are missing.
545+
#
546+
# If any are missing, they will be cached as `None`, which is not
547+
# what callers expected.
548+
for user_id, device_id in device_query:
549+
devices_and_signatures.setdefault((user_id, device_id), [])
550+
551+
# Populate the return dictionary with each found key_id and signature.
552+
for user_id, key_id, target_device_id, signature in txn.fetchall():
553+
signature_tuple = (key_id, signature)
554+
devices_and_signatures[(user_id, target_device_id)].append(
555+
signature_tuple
556+
)
557+
558+
return devices_and_signatures
559+
560+
return await self.db_pool.runInteraction(
561+
"_get_e2e_cross_signing_signatures_for_devices_txn",
562+
_get_e2e_cross_signing_signatures_for_devices_txn,
563+
device_query,
517564
)
518565

519566
async def get_e2e_one_time_keys(
@@ -1772,26 +1819,71 @@ async def store_e2e_cross_signing_signatures(
17721819
user_id: the user who made the signatures
17731820
signatures: signatures to add
17741821
"""
1775-
await self.db_pool.simple_insert_many(
1776-
"e2e_cross_signing_signatures",
1777-
keys=(
1778-
"user_id",
1779-
"key_id",
1780-
"target_user_id",
1781-
"target_device_id",
1782-
"signature",
1783-
),
1784-
values=[
1785-
(
1786-
user_id,
1787-
item.signing_key_id,
1788-
item.target_user_id,
1789-
item.target_device_id,
1790-
item.signature,
1791-
)
1822+
1823+
def _store_e2e_cross_signing_signatures(
1824+
txn: LoggingTransaction,
1825+
signatures: "Iterable[SignatureListItem]",
1826+
) -> None:
1827+
self.db_pool.simple_insert_many_txn(
1828+
txn,
1829+
"e2e_cross_signing_signatures",
1830+
keys=(
1831+
"user_id",
1832+
"key_id",
1833+
"target_user_id",
1834+
"target_device_id",
1835+
"signature",
1836+
),
1837+
values=[
1838+
(
1839+
user_id,
1840+
item.signing_key_id,
1841+
item.target_user_id,
1842+
item.target_device_id,
1843+
item.signature,
1844+
)
1845+
for item in signatures
1846+
],
1847+
)
1848+
1849+
to_invalidate = [
1850+
# Each entry is a tuple of arguments to
1851+
# `_get_e2e_cross_signing_signatures_for_device`, which
1852+
# itself takes a tuple. Hence the double-tuple.
1853+
((user_id, item.target_device_id),)
17921854
for item in signatures
1793-
],
1794-
desc="add_e2e_signing_key",
1855+
]
1856+
1857+
if to_invalidate:
1858+
# Invalidate the local cache of this worker.
1859+
for cache_key in to_invalidate:
1860+
txn.call_after(
1861+
self._get_e2e_cross_signing_signatures_for_device.invalidate,
1862+
cache_key,
1863+
)
1864+
1865+
# Stream cache invalidate keys over replication.
1866+
#
1867+
# We can only send a primitive per function argument across
1868+
# replication.
1869+
#
1870+
# Encode the array of strings as a JSON string, and we'll unpack
1871+
# it on the other side.
1872+
to_send = [
1873+
(json.dumps([user_id, item.target_device_id]),)
1874+
for item in signatures
1875+
]
1876+
1877+
self._send_invalidation_to_replication_bulk(
1878+
txn,
1879+
cache_name=self._get_e2e_cross_signing_signatures_for_device.__name__,
1880+
key_tuples=to_send,
1881+
)
1882+
1883+
await self.db_pool.runInteraction(
1884+
"add_e2e_signing_key",
1885+
_store_e2e_cross_signing_signatures,
1886+
signatures,
17951887
)
17961888

17971889

synapse/util/caches/descriptors.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -579,9 +579,12 @@ def cachedList(
579579
Used to do batch lookups for an already created cache. One of the arguments
580580
is specified as a list that is iterated through to lookup keys in the
581581
original cache. A new tuple consisting of the (deduplicated) keys that weren't in
582-
the cache gets passed to the original function, which is expected to results
582+
the cache gets passed to the original function, which is expected to result
583583
in a map of key to value for each passed value. The new results are stored in the
584-
original cache. Note that any missing values are cached as None.
584+
original cache.
585+
586+
Note that any values in the input that end up being missing from both the
587+
cache and the returned dictionary will be cached as `None`.
585588
586589
Args:
587590
cached_method_name: The name of the single-item lookup method.

0 commit comments

Comments
 (0)