2020#
2121#
2222import abc
23+ import json
2324from typing import (
2425 TYPE_CHECKING ,
2526 Any ,
@@ -354,15 +355,17 @@ async def get_e2e_device_keys_and_signatures(
354355 )
355356
356357 for batch in batch_iter (signature_query , 50 ):
357- cross_sigs_result = await self .db_pool .runInteraction (
358- "get_e2e_cross_signing_signatures_for_devices" ,
359- self ._get_e2e_cross_signing_signatures_for_devices_txn ,
360- batch ,
358+ cross_sigs_result = (
359+ await self ._get_e2e_cross_signing_signatures_for_devices (batch )
361360 )
362361
363362 # add each cross-signing signature to the correct device in the result dict.
364- for user_id , key_id , device_id , signature in cross_sigs_result :
363+ for (
364+ user_id ,
365+ device_id ,
366+ ), signature_list in cross_sigs_result .items ():
365367 target_device_result = result [user_id ][device_id ]
368+
366369 # We've only looked up cross-signatures for non-deleted devices with key
367370 # data.
368371 assert target_device_result is not None
@@ -373,7 +376,9 @@ async def get_e2e_device_keys_and_signatures(
373376 signing_user_signatures = target_device_signatures .setdefault (
374377 user_id , {}
375378 )
376- signing_user_signatures [key_id ] = signature
379+
380+ for key_id , signature in signature_list :
381+ signing_user_signatures [key_id ] = signature
377382
378383 log_kv (result )
379384 return result
@@ -479,41 +484,83 @@ def get_e2e_device_keys_txn(
479484
480485 return result
481486
482- def _get_e2e_cross_signing_signatures_for_devices_txn (
483- self , txn : LoggingTransaction , device_query : Iterable [Tuple [str , str ]]
484- ) -> List [Tuple [str , str , str , str ]]:
485- """Get cross-signing signatures for a given list of devices
487+ @cached ()
488+ def _get_e2e_cross_signing_signatures_for_device (
489+ self ,
490+ user_id_and_device_id : Tuple [str , str ],
491+ ) -> Sequence [Tuple [str , str ]]:
492+ """
493+ The single-item version of `_get_e2e_cross_signing_signatures_for_devices`.
494+ See @cachedList for why a separate method is needed.
495+ """
496+ raise NotImplementedError ()
497+
498+ @cachedList (
499+ cached_method_name = "_get_e2e_cross_signing_signatures_for_device" ,
500+ list_name = "device_query" ,
501+ )
502+ async def _get_e2e_cross_signing_signatures_for_devices (
503+ self , device_query : Iterable [Tuple [str , str ]]
504+ ) -> Mapping [Tuple [str , str ], Sequence [Tuple [str , str ]]]:
505+ """Get cross-signing signatures for a given list of user IDs and devices.
506+
507+ Args:
508+ An iterable containing tuples of (user ID, device ID).
509+
510+ Returns:
511+ A mapping of results. The keys are the original (user_id, device_id)
512+ tuple, while the value is the matching list of tuples of
513+ (key_id, signature). The value will be an empty list if no
514+ signatures exist for the device.
486515
487- Returns signatures made by the owners of the devices.
516+ Given this method is annotated with `@cachedList`, the return dict's
517+ keys match the tuples within `device_query`, so that cache entries can
518+ be computed from the corresponding values.
488519
489- Returns: a list of results; each entry in the list is a tuple of
490- (user_id, key_id, target_device_id, signature).
520+ As results are cached, the return type is immutable.
491521 """
492- signature_query_clauses = []
493- signature_query_params = []
494522
495- for user_id , device_id in device_query :
496- signature_query_clauses .append (
497- "target_user_id = ? AND target_device_id = ? AND user_id = ?"
523+ def _get_e2e_cross_signing_signatures_for_devices_txn (
524+ txn : LoggingTransaction , device_query : Iterable [Tuple [str , str ]]
525+ ) -> Mapping [Tuple [str , str ], Sequence [Tuple [str , str ]]]:
526+ where_clause_sql , where_clause_params = make_tuple_in_list_sql_clause (
527+ self .database_engine ,
528+ columns = ("target_user_id" , "target_device_id" , "user_id" ),
529+ iterable = [
530+ (user_id , device_id , user_id ) for user_id , device_id in device_query
531+ ],
498532 )
499- signature_query_params .extend ([user_id , device_id , user_id ])
500-
501- signature_sql = """
502- SELECT user_id, key_id, target_device_id, signature
503- FROM e2e_cross_signing_signatures WHERE %s
504- """ % (" OR " .join ("(" + q + ")" for q in signature_query_clauses ))
505-
506- txn .execute (signature_sql , signature_query_params )
507- return cast (
508- List [
509- Tuple [
510- str ,
511- str ,
512- str ,
513- str ,
514- ]
515- ],
516- txn .fetchall (),
533+
534+ signature_sql = f"""
535+ SELECT user_id, key_id, target_device_id, signature
536+ FROM e2e_cross_signing_signatures WHERE { where_clause_sql }
537+ """
538+
539+ txn .execute (signature_sql , where_clause_params )
540+
541+ devices_and_signatures : Dict [Tuple [str , str ], List [Tuple [str , str ]]] = {}
542+
543+ # `@cachedList` requires we return one key for every item in `device_query`.
544+ # Pre-populate `devices_and_signatures` with each key so that none are missing.
545+ #
546+ # If any are missing, they will be cached as `None`, which is not
547+ # what callers expected.
548+ for user_id , device_id in device_query :
549+ devices_and_signatures .setdefault ((user_id , device_id ), [])
550+
551+ # Populate the return dictionary with each found key_id and signature.
552+ for user_id , key_id , target_device_id , signature in txn .fetchall ():
553+ signature_tuple = (key_id , signature )
554+ devices_and_signatures [(user_id , target_device_id )].append (
555+ signature_tuple
556+ )
557+
558+ return devices_and_signatures
559+
560+ return await self .db_pool .runInteraction (
561+ "_get_e2e_cross_signing_signatures_for_devices_txn" ,
562+ _get_e2e_cross_signing_signatures_for_devices_txn ,
563+ device_query ,
517564 )
518565
519566 async def get_e2e_one_time_keys (
@@ -1772,26 +1819,71 @@ async def store_e2e_cross_signing_signatures(
17721819 user_id: the user who made the signatures
17731820 signatures: signatures to add
17741821 """
1775- await self .db_pool .simple_insert_many (
1776- "e2e_cross_signing_signatures" ,
1777- keys = (
1778- "user_id" ,
1779- "key_id" ,
1780- "target_user_id" ,
1781- "target_device_id" ,
1782- "signature" ,
1783- ),
1784- values = [
1785- (
1786- user_id ,
1787- item .signing_key_id ,
1788- item .target_user_id ,
1789- item .target_device_id ,
1790- item .signature ,
1791- )
1822+
1823+ def _store_e2e_cross_signing_signatures (
1824+ txn : LoggingTransaction ,
1825+ signatures : "Iterable[SignatureListItem]" ,
1826+ ) -> None :
1827+ self .db_pool .simple_insert_many_txn (
1828+ txn ,
1829+ "e2e_cross_signing_signatures" ,
1830+ keys = (
1831+ "user_id" ,
1832+ "key_id" ,
1833+ "target_user_id" ,
1834+ "target_device_id" ,
1835+ "signature" ,
1836+ ),
1837+ values = [
1838+ (
1839+ user_id ,
1840+ item .signing_key_id ,
1841+ item .target_user_id ,
1842+ item .target_device_id ,
1843+ item .signature ,
1844+ )
1845+ for item in signatures
1846+ ],
1847+ )
1848+
1849+ to_invalidate = [
1850+ # Each entry is a tuple of arguments to
1851+ # `_get_e2e_cross_signing_signatures_for_device`, which
1852+ # itself takes a tuple. Hence the double-tuple.
1853+ ((user_id , item .target_device_id ),)
17921854 for item in signatures
1793- ],
1794- desc = "add_e2e_signing_key" ,
1855+ ]
1856+
1857+ if to_invalidate :
1858+ # Invalidate the local cache of this worker.
1859+ for cache_key in to_invalidate :
1860+ txn .call_after (
1861+ self ._get_e2e_cross_signing_signatures_for_device .invalidate ,
1862+ cache_key ,
1863+ )
1864+
1865+ # Stream cache invalidate keys over replication.
1866+ #
1867+ # We can only send a primitive per function argument across
1868+ # replication.
1869+ #
1870+ # Encode the array of strings as a JSON string, and we'll unpack
1871+ # it on the other side.
1872+ to_send = [
1873+ (json .dumps ([user_id , item .target_device_id ]),)
1874+ for item in signatures
1875+ ]
1876+
1877+ self ._send_invalidation_to_replication_bulk (
1878+ txn ,
1879+ cache_name = self ._get_e2e_cross_signing_signatures_for_device .__name__ ,
1880+ key_tuples = to_send ,
1881+ )
1882+
1883+ await self .db_pool .runInteraction (
1884+ "add_e2e_signing_key" ,
1885+ _store_e2e_cross_signing_signatures ,
1886+ signatures ,
17951887 )
17961888
17971889
0 commit comments