Skip to content

Commit

Permalink
Remove Query's Synchronous usage
Browse files Browse the repository at this point in the history
First step to remove our asyncio metaprogramming...

  #77

Our Query class now provides a run method for synchronous users, and run_async
for asyncio. This also adds a stop method that can cancel our download.
  • Loading branch information
atagar committed Nov 8, 2020
1 parent 757a614 commit 7ce8a5e
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 37 deletions.
114 changes: 90 additions & 24 deletions stem/descriptor/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@

from stem.descriptor import Compression
from stem.util import log, str_tools
from stem.util.asyncio import Synchronous
from typing import Any, AsyncIterator, Dict, List, Optional, Sequence, Tuple, Union
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Sequence, Tuple, Union

# Tor has a limited number of descriptors we can fetch explicitly by their
# fingerprint or hashes due to a limit on the url length by squid proxies.
Expand Down Expand Up @@ -227,7 +226,7 @@ def get_detached_signatures(**query_args: Any) -> 'stem.descriptor.remote.Query'
return get_instance().get_detached_signatures(**query_args)


class Query(Synchronous):
class Query(object):
"""
Asynchronous request for descriptor content from a directory authority or
mirror. These can either be made through the
Expand Down Expand Up @@ -369,7 +368,6 @@ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoin
super(Query, self).__init__()

if not resource.startswith('/'):
self.stop()
raise ValueError("Resources should start with a '/': %s" % resource)

if resource.endswith('.z'):
Expand All @@ -380,7 +378,6 @@ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoin
elif isinstance(compression, stem.descriptor._Compression):
compression = [compression] # caller provided only a single option
else:
self.stop()
raise ValueError('Compression should be a list of stem.descriptor.Compression, was %s (%s)' % (compression, type(compression).__name__))

if Compression.ZSTD in compression and not Compression.ZSTD.available:
Expand All @@ -404,7 +401,6 @@ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoin
if isinstance(endpoint, (stem.ORPort, stem.DirPort)):
self.endpoints.append(endpoint)
else:
self.stop()
raise ValueError("Endpoints must be an stem.ORPort or stem.DirPort. '%s' is a %s." % (endpoint, type(endpoint).__name__))

self.resource = resource
Expand All @@ -428,6 +424,12 @@ def __init__(self, resource: str, descriptor_type: Optional[str] = None, endpoin
self._downloader_task = None # type: Optional[asyncio.Task]
self._downloader_lock = threading.RLock()

# background thread if outside an asyncio context

self._loop = None # type: Optional[asyncio.AbstractEventLoop]
self._loop_thread = None # type: Optional[threading.Thread]
self._loop_lock = threading.RLock()

if start:
self.start()

Expand All @@ -441,9 +443,38 @@ def start(self) -> None:

with self._downloader_lock:
if self._downloader_task is None:
self._downloader_task = self._loop.create_task(Query._download_descriptors(self, self.retries, self.timeout))
with self._loop_lock:
if self._loop is None:
try:
self._loop = asyncio.get_running_loop()
except RuntimeError:
self._loop = asyncio.new_event_loop()
self._loop_thread = threading.Thread(
name = 'stem.descriptor.remote query',
target = self._loop.run_forever,
daemon = True,
)

self._loop_thread.start()

self._downloader_task = self._loop.create_task(self._download_descriptors(self.retries, self.timeout))

def stop(self) -> None:
"""
Aborts our download if it's in progress, and cleans up underlying
resources.
"""

async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
with self._downloader_lock:
if self._downloader_task and not self._downloader_task.done():
self._downloader_task.cancel()

with self._loop_lock:
if self._loop_thread and self._loop_thread.is_alive():
self._loop.call_soon_threadsafe(self._loop.stop)
self._loop_thread.join()

def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor']:
"""
Blocks until our request is complete then provides the descriptors. If we
haven't yet started our request then this does so.
Expand All @@ -461,12 +492,43 @@ async def run(self, suppress: bool = False) -> List['stem.descriptor.Descriptor'
* :class:`~stem.DownloadFailed` if our request fails
"""

try:
return [desc async for desc in self._run(suppress)]
finally:
self.stop()
if not self.downloaded and not self.error:
with self._loop_lock:
if self._loop is None:
self.start()

async def run_wrapper():
return [desc async for desc in self.run_async(suppress = True)]

asyncio.run_coroutine_threadsafe(run_wrapper(), self._loop).result()

self.stop()

if self.error:
if suppress:
return []

raise self.error
else:
return list(self.downloaded)

async def run_async(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
"""
Asynchronous counterpart of :func:`stem.descriptor.remote.Query.run`
:param suppress: avoids raising exceptions if **True**
:returns: iterator for the requested :class:`~stem.descriptor.__init__.Descriptor` instances
:raises:
Using the iterator can fail with the following if **suppress** is
**False**...
* **ValueError** if the descriptor contents is malformed
* :class:`~stem.DownloadTimeout` if our request timed out
* :class:`~stem.DownloadFailed` if our request fails
"""

async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor]:
with self._downloader_lock:
if not self.downloaded and not self.error:
if not self._downloader_task:
Expand All @@ -477,17 +539,21 @@ async def _run(self, suppress: bool) -> AsyncIterator[stem.descriptor.Descriptor
except Exception as exc:
self.error = exc

if self.error:
if suppress:
return
if self.error:
if suppress:
return

raise self.error
else:
for desc in self.downloaded:
yield desc
raise self.error
else:
for desc in self.downloaded:
yield desc

def __iter__(self) -> Iterator[stem.descriptor.Descriptor]:
  """
  Synchronous iteration over our requested descriptors. Blocks until the
  download finishes; errors are suppressed (run is invoked with suppress
  set to **True**, so a failed request simply yields nothing).
  """

  yield from self.run(True)

async def __aiter__(self) -> AsyncIterator[stem.descriptor.Descriptor]:
async for desc in self._run(True):
async for desc in self.run_async(True):
yield desc

def _pick_endpoint(self, use_authority: bool = False) -> stem.Endpoint:
Expand Down Expand Up @@ -620,7 +686,7 @@ def use_directory_mirrors(self) -> stem.descriptor.networkstatus.NetworkStatusDo
directories = [auth for auth in stem.directory.Authority.from_cache().values() if auth.nickname not in DIR_PORT_BLACKLIST]
new_endpoints = set([stem.DirPort(directory.address, directory.dir_port) for directory in directories])

consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0] # type: ignore
consensus = list(self.get_consensus(document_handler = stem.descriptor.DocumentHandler.DOCUMENT).run())[0]

for desc in consensus.routers.values():
if stem.Flag.V2DIR in desc.flags and desc.dir_port:
Expand All @@ -630,7 +696,7 @@ def use_directory_mirrors(self) -> stem.descriptor.networkstatus.NetworkStatusDo

self._endpoints = list(new_endpoints)

return consensus
return consensus # type: ignore

def their_server_descriptor(self, **query_args: Any) -> 'stem.descriptor.remote.Query':
"""
Expand Down Expand Up @@ -785,7 +851,7 @@ def get_consensus(self, authority_v3ident: Optional[str] = None, microdescriptor
# authority key certificates

if consensus_query.validate and consensus_query.document_handler == stem.descriptor.DocumentHandler.DOCUMENT:
consensus = list(consensus_query.run())[0] # type: ignore
consensus = list(consensus_query.run())[0]
key_certs = self.get_key_certificates(**query_args).run()

try:
Expand Down
53 changes: 40 additions & 13 deletions test/unit/descriptor/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Unit tests for stem.descriptor.remote.
"""

import time
import unittest

import stem
Expand Down Expand Up @@ -87,12 +88,50 @@ def test_initial_startup(self):

query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)
self.assertTrue(query._downloader_task is None)
query.stop()

query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = True)
self.assertTrue(query._downloader_task is not None)
query.stop()

def test_stop(self):
"""
Stop a complete, in-process, and unstarted query.
"""

# stop a completed query

with mock_download(TEST_DESCRIPTOR):
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31')
self.assertTrue(query._loop_thread.is_alive())

query.run() # complete the query
self.assertFalse(query._loop_thread.is_alive())
self.assertFalse(query._downloader_task.cancelled())

query.stop() # nothing to do
self.assertFalse(query._loop_thread.is_alive())
self.assertFalse(query._downloader_task.cancelled())

# stop an in-process query

def pause(*args):
# keeps the download "in progress" long enough for stop() to cancel it;
# stop() joins the loop thread so the test doesn't actually wait 5s
time.sleep(5)

with patch('stem.descriptor.remote.Query._download_from', Mock(side_effect = pause)):
query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31')

query.stop() # terminates in-process query
self.assertFalse(query._loop_thread.is_alive())
self.assertTrue(query._downloader_task.cancelled())

# stop an unstarted query (start = False means no loop thread or task exist)

query = stem.descriptor.remote.get_server_descriptors('9695DFC35FFEB861329B9F1AB04C46397020CE31', start = False)

query.stop() # nothing to do
self.assertTrue(query._loop_thread is None)
self.assertTrue(query._downloader_task is None)

@mock_download(TEST_DESCRIPTOR)
def test_download(self):
"""
Expand All @@ -115,8 +154,6 @@ def test_download(self):
self.assertEqual('9695DFC35FFEB861329B9F1AB04C46397020CE31', desc.fingerprint)
self.assertEqual(TEST_DESCRIPTOR.rstrip(), desc.get_bytes())

reply.stop()

def test_response_header_code(self):
"""
When successful Tor provides a '200 OK' status, but we should accept other 2xx
Expand Down Expand Up @@ -165,13 +202,11 @@ def test_reply_header_data(self):
descriptors = list(query)
self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
query.stop()

def test_gzip_url_override(self):
query = stem.descriptor.remote.Query(TEST_RESOURCE + '.z', compression = Compression.PLAINTEXT, start = False)
self.assertEqual([stem.descriptor.Compression.GZIP], query.compression)
self.assertEqual(TEST_RESOURCE, query.resource)
query.stop()

@mock_download(read_resource('compressed_identity'), encoding = 'identity')
def test_compression_plaintext(self):
Expand All @@ -187,7 +222,6 @@ def test_compression_plaintext(self):
)

descriptors = list(query)
query.stop()

self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
Expand All @@ -206,7 +240,6 @@ def test_compression_gzip(self):
)

descriptors = list(query)
query.stop()

self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
Expand All @@ -227,7 +260,6 @@ def test_compression_zstd(self):
)

descriptors = list(query)
query.stop()

self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
Expand All @@ -248,7 +280,6 @@ def test_compression_lzma(self):
)

descriptors = list(query)
query.stop()

self.assertEqual(1, len(descriptors))
self.assertEqual('moria1', descriptors[0].nickname)
Expand Down Expand Up @@ -300,8 +331,6 @@ def test_malformed_content(self):

self.assertRaises(ValueError, query.run)

query.stop()

def test_query_with_invalid_endpoints(self):
invalid_endpoints = {
'hello': "'h' is a str.",
Expand Down Expand Up @@ -330,5 +359,3 @@ def test_can_iterate_multiple_times(self):
self.assertEqual(1, len(list(query)))
self.assertEqual(1, len(list(query)))
self.assertEqual(1, len(list(query)))

query.stop()

0 comments on commit 7ce8a5e

Please sign in to comment.