Skip to content

Commit d949888

Browse files
orionleeceb8
authored and committed
MAST query result cache: Observations.query_criteria()
1 parent 62a34f0 commit d949888

File tree

4 files changed

+100
-12
lines changed

4 files changed

+100
-12
lines changed

astroquery/mast/discovery_portal.py

+46-9
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
import uuid
1111
import json
1212
import time
13+
import re
1314

1415
import numpy as np
1516

@@ -18,7 +19,7 @@
1819
from astropy.table import Table, vstack, MaskedColumn
1920
from astropy.utils import deprecated
2021

21-
from ..query import BaseQuery, QueryWithLogin
22+
from ..query import BaseQuery, QueryWithLogin, AstroQuery, to_cache
2223
from ..utils import async_to_sync
2324
from ..utils.class_or_instance import class_or_instance
2425
from ..exceptions import InputWarning, NoResultsWarning, RemoteServiceError
@@ -211,7 +212,39 @@ def _request(self, method, url, params=None, data=None, headers=None,
211212

212213
return all_responses
213214

214-
def _get_col_config(self, service, fetch_name=None):
215+
def _request_w_cache(self, method, url, data=None, headers=None, retrieve_all=True,
                     cache=False, cache_opts=None):
    """
    Issue a request through ``self._request``, optionally consulting the local cache.

    Only a subset of the parameters of the underlying ``_request()`` is exposed,
    to play nice with existing mocks. Per the original author's note, the
    caching follows the pattern of ``BaseQuery._request()``, which uses an
    ``AstroQuery`` object.

    Parameters
    ----------
    method, url, data, headers, retrieve_all :
        Forwarded unchanged to ``self._request``; also handed to
        ``self._get_cacher`` to identify the cache entry.
    cache : bool, optional
        If False (default), the request is always performed and nothing is
        read from or written to the cache.
    cache_opts : dict, optional
        Accepted for forward compatibility; not used by this method.

    Returns
    -------
    response
        The cached object when there is a (truthy) cache entry, otherwise
        whatever ``self._request`` returns.
    """
    if cache:
        query = self._get_cacher(method, url, data, headers, retrieve_all)
        cached = query.from_cache(self.cache_location)
        if cached:
            return cached

        # Cache miss: perform the request and store the result for next time.
        fresh = self._request(method, url, data=data, headers=headers, retrieve_all=retrieve_all)
        to_cache(fresh, query.request_file(self.cache_location))
        return fresh

    return self._request(method, url, data=data, headers=headers, retrieve_all=retrieve_all)
229+
230+
def _get_cacher(self, method, url, data, headers, retrieve_all):
    """
    Return an object that can cache the HTTP request based on the supplied arguments.

    The returned ``AstroQuery`` serves as the cache key. Its ``data`` is
    derived from the request payload and is not the payload itself: the
    cacheBreaker value is removed and ``retrieve_all`` is appended.

    Parameters
    ----------
    method, url, headers :
        Passed to ``AstroQuery`` as-is.
    data : str or None
        URL-encoded request string. ``None`` is keyed as an empty string.
    retrieve_all : bool
        Included in the key so that requests differing only in this flag
        get separate cache entries.

    Returns
    -------
    AstroQuery
    """
    # re.sub() raises TypeError on None, and None is the caller's default for ``data``.
    payload = "" if data is None else data

    # cacheBreaker parameter (to underlying MAST service) is not relevant (and breaks)
    # local caching: remove it from the cache key. The value is matched non-greedily
    # so that only the cacheBreaker value itself is dropped, even if other fields
    # follow it in the request string.
    key_data = re.sub(r'^(.+)cacheBreaker%22%3A%20%22.+?%22', r'\1', payload)

    # Include retrieve_all as part of the cache key by appending it to data;
    # it cannot be passed as a separate keyword, as it will be rejected by AstroQuery.
    key_data += " retrieve_all={}".format(retrieve_all)

    return AstroQuery(method, url, data=key_data, headers=headers)
246+
247+
def _get_col_config(self, service, fetch_name=None, cache=False):
215248
"""
216249
Gets the columnsConfig entry for given service and stores it in `self._column_configs`.
217250
@@ -247,7 +280,7 @@ def _get_col_config(self, service, fetch_name=None):
247280
if more:
248281
mashup_request = {'service': all_name, 'params': {}, 'format': 'extjs'}
249282
req_string = _prepare_service_request_string(mashup_request)
250-
response = self._request("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers)
283+
response = self._request_w_cache("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers, cache=cache)
251284
json_response = response[0].json()
252285

253286
self._column_configs[service].update(json_response['data']['Tables'][0]
@@ -301,7 +334,7 @@ def _parse_result(self, responses, verbose=False):
301334
return all_results
302335

303336
@class_or_instance
304-
def service_request_async(self, service, params, pagesize=None, page=None, **kwargs):
337+
def service_request_async(self, service, params, pagesize=None, page=None, cache=False, cache_opts=None, **kwargs):
305338
"""
306339
Given a Mashup service and parameters, builds and executes a Mashup query.
307340
See documentation `here <https://mast.stsci.edu/api/v0/class_mashup_1_1_mashup_request.html>`__
@@ -321,6 +354,10 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
321354
Default None.
322355
Can be used to override the default behavior of all results being returned to obtain
323356
a specific page of results.
357+
cache : Boolean, optional
358+
If True, try to use the cached query result.
359+
cache_opts : dict, optional
360+
cache options, details TBD, e.g., cache expiration policy, etc.
324361
**kwargs :
325362
See MashupRequest properties
326363
`here <https://mast.stsci.edu/api/v0/class_mashup_1_1_mashup_request.html>`__
@@ -334,7 +371,7 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
334371
# setting self._current_service
335372
if service not in self._column_configs.keys():
336373
fetch_name = kwargs.pop('fetch_name', None)
337-
self._get_col_config(service, fetch_name)
374+
self._get_col_config(service, fetch_name, cache)
338375
self._current_service = service
339376

340377
# setting up pagination
@@ -360,12 +397,12 @@ def service_request_async(self, service, params, pagesize=None, page=None, **kwa
360397
mashup_request[prop] = value
361398

362399
req_string = _prepare_service_request_string(mashup_request)
363-
response = self._request("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers,
364-
retrieve_all=retrieve_all)
400+
response = self._request_w_cache("POST", self.MAST_REQUEST_URL, data=req_string, headers=headers,
401+
retrieve_all=retrieve_all, cache=cache, cache_opts=cache_opts)
365402

366403
return response
367404

368-
def build_filter_set(self, column_config_name, service_name=None, **filters):
405+
def build_filter_set(self, column_config_name, service_name=None, cache=False, **filters):
369406
"""
370407
Takes user input dictionary of filters and returns a filterlist that the Mashup can understand.
371408
@@ -393,7 +430,7 @@ def build_filter_set(self, column_config_name, service_name=None, **filters):
393430
service_name = column_config_name
394431

395432
if not self._column_configs.get(service_name):
396-
self._get_col_config(service_name, fetch_name=column_config_name)
433+
self._get_col_config(service_name, fetch_name=column_config_name, cache=cache)
397434

398435
caom_col_config = self._column_configs[service_name]
399436

astroquery/mast/observations.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def query_object_async(self, objectname, *, radius=0.2*u.deg, pagesize=None, pag
255255
return self.query_region_async(coordinates, radius=radius, pagesize=pagesize, page=page)
256256

257257
@class_or_instance
258-
def query_criteria_async(self, *, pagesize=None, page=None, **criteria):
258+
def query_criteria_async(self, *, pagesize=None, page=None, cache=False, cache_opts=None, **criteria):
259259
"""
260260
Given a set of criteria, returns a list of MAST observations.
261261
Valid criteria are returned by ``get_metadata("observations")``
@@ -300,7 +300,7 @@ def query_criteria_async(self, *, pagesize=None, page=None, **criteria):
300300
params = {"columns": "*",
301301
"filters": mashup_filters}
302302

303-
return self._portal_api_connection.service_request_async(service, params)
303+
return self._portal_api_connection.service_request_async(service, params, cache=cache, cache_opts=cache_opts)
304304

305305
def query_region_count(self, coordinates, *, radius=0.2*u.deg, pagesize=None, page=None):
306306
"""

astroquery/mast/tests/test_mast.py

+50
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,25 @@ def patch_post(request):
7676
return mp
7777

7878

79+
# Number of times post_mockreturn() has produced a mock response since the
# last reset; the cache tests use it to tell cache hits from real (mock) calls.
_num_mockreturn = 0


def _get_num_mockreturn():
    """Return the current value of the mock-call counter."""
    # Read-only access: no ``global`` declaration is needed.
    return _num_mockreturn


def _reset_mockreturn_counter():
    """Reset the mock-call counter to 0."""
    global _num_mockreturn
    _num_mockreturn = 0


def _inc_num_mockreturn():
    """Increment the mock-call counter by one and return the new value."""
    global _num_mockreturn
    _num_mockreturn += 1
    return _num_mockreturn
96+
97+
7998
def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwargs):
8099
if "columnsconfig" in url:
81100
if "Mast.Catalogs.Tess.Cone" in data:
@@ -102,6 +121,9 @@ def post_mockreturn(self, method="POST", url=None, data=None, timeout=10, **kwar
102121
with open(filename, 'rb') as infile:
103122
content = infile.read()
104123

124+
# For cache tests
125+
_inc_num_mockreturn()
126+
105127
# returning as list because this is what the mast _request function does
106128
return [MockResponse(content)]
107129

@@ -367,6 +389,34 @@ def test_query_observations_criteria_async(patch_post):
367389
assert isinstance(responses, list)
368390

369391

392+
def test_query_observations_criteria_async_cache(patch_post, tmp_path, monkeypatch):
    """
    Check that ``query_criteria_async(cache=True)`` is served from the cache on
    the second identical call, and that ``cache=False`` bypasses the cache.

    Cache hits are detected through the mock-call counter: a hit must not
    trigger any additional (mock) HTTP call.
    """
    # Isolate the on-disk cache. The cache key ignores cacheBreaker, so entries
    # left behind by an earlier test run would otherwise turn the expected
    # first-call miss into a hit and break the "> 0" assertion below.
    # NOTE(review): assumes cache_location is a settable attribute of the portal
    # connection object -- confirm against BaseQuery.
    monkeypatch.setattr(mast.Observations._portal_api_connection, "cache_location", str(tmp_path))

    def run_query(use_cache):
        # Fresh argument literals on every call, as in three independent user queries.
        return mast.Observations.query_criteria_async(dataproduct_type=["image"],
                                                      proposal_pi="Ost*",
                                                      s_dec=[43.5, 45.5], cache=use_cache)

    _reset_mockreturn_counter()
    assert 0 == _get_num_mockreturn(), "Mock HTTP call counter reset to 0"

    responses_cache_miss = run_query(True)
    assert isinstance(responses_cache_miss, list)
    num_mockreturn_after_first_call = _get_num_mockreturn()
    assert num_mockreturn_after_first_call > 0, "Cache miss, some underlying HTTP call"

    responses_cache_hit = run_query(True)
    # assert the cached response is the same
    assert len(responses_cache_hit) == len(responses_cache_miss)
    assert responses_cache_hit[0].text == responses_cache_miss[0].text
    # ensure the response really comes from the cache
    assert num_mockreturn_after_first_call == _get_num_mockreturn(), \
        'Cache hit: should reach cache only, i.e., no HTTP call'

    responses_no_cache = run_query(False)
    assert isinstance(responses_no_cache, list)
    assert _get_num_mockreturn() > num_mockreturn_after_first_call, "Cache off , some underlying HTTP call"
418+
419+
370420
def test_observations_query_criteria(patch_post):
371421
# without position
372422
result = mast.Observations.query_criteria(dataproduct_type=["image"],

astroquery/query.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,8 @@ def from_cache(self, cache_location):
110110
try:
111111
with open(request_file, "rb") as f:
112112
response = pickle.load(f)
113-
if not isinstance(response, requests.Response):
113+
if not isinstance(response, requests.Response) and not isinstance(response, list):
114+
# MAST query response is a list of Response
114115
response = None
115116
except FileNotFoundError:
116117
response = None

0 commit comments

Comments
 (0)