diff --git a/src/cirrus/lib/catalog.py b/src/cirrus/lib/catalog.py
index 4076323..b472dc2 100644
--- a/src/cirrus/lib/catalog.py
+++ b/src/cirrus/lib/catalog.py
@@ -8,6 +8,7 @@ import uuid
 from datetime import datetime, timezone
 from typing import Dict, Optional, List
+from copy import deepcopy
 
 from boto3utils import s3
 from cirrus.lib.statedb import StateDB
@@ -38,21 +39,33 @@ def __init__(self, *args, update=False, state_item=None, **kwargs):
         """
         super(Catalog, self).__init__(*args, **kwargs)
 
-        # convert old functions field to tasks
-        if 'functions' in self['process']:
-            self['process']['tasks'] = self['process'].pop('functions')
-
         self.logger = get_task_logger(__name__, catalog=self)
 
-        if update:
-            self.update()
-
-        # validate process block
+        # TODO: asserts are stripped under python -O, so not safe for validation here
         assert(self['type'] == 'FeatureCollection')
         assert('process' in self)
-        assert('output_options' in self['process'])
-        assert('workflow' in self['process'])
-        assert('tasks' in self['process'])
+
+        self.process = (
+            self['process'][0]
+            if isinstance(self['process'], list)
+            else self['process']
+        )
+
+        if update:
+            self.update()
+
+        assert('output_options' in self.process)
+        assert('workflow' in self.process)
+
+        # convert old functions field to tasks
+        if 'functions' in self.process:
+            self.logger.warning("Deprecated: process 'functions' has been renamed to 'tasks'")
+            self.process['tasks'] = self.process.pop('functions')
+
+        assert('tasks' in self.process)
+        self.tasks = self.process['tasks']
+
         assert('workflow-' in self['id'])
 
         # TODO - validate with a JSON schema
@@ -99,10 +112,35 @@ def from_payload(cls, payload: Dict, **kwargs) -> Catalog:
             cat = payload
         return cls(cat, **kwargs)
 
+    def get_task(self, task_name, *args, **kwargs):
+        """Return the configuration for the named task, if present."""
+        return self.tasks.get(task_name, *args, **kwargs)
+
+    def next_workflows(self):
+        """Yield a new Catalog for each workflow chained after the current one.
+
+        Yields nothing when 'process' is a bare dict or a single-element
+        list. If the next element is itself a list of process blocks, the
+        chain forks and one Catalog is yielded per block.
+        """
+        if isinstance(self['process'], dict) or len(self['process']) <= 1:
+            return
+        next_processes = (
+            [self['process'][1]]
+            if isinstance(self['process'][1], dict)
+            else self['process'][1]
+        )
+        for process in next_processes:
+            new = deepcopy(self)
+            # drop the completed process block and promote this one to the head
+            new['process'].pop(0)
+            new['process'][0] = process
+            yield new
+
     def update(self):
-        if 'collections' in self['process']:
+        if 'collections' in self.process:
             # allow overriding of collections name
-            collections_str = self['process']['collections']
+            collections_str = self.process['collections']
         else:
             # otherwise, get from items
             cols = sorted(list(set([i['collection'] for i in self['features'] if 'collection' in i])))
@@ -111,14 +141,14 @@ def update(self):
         items_str = '/'.join(sorted(list([i['id'] for i in self['features']])))
         if 'id' not in self:
-            self['id'] = f"{collections_str}/workflow-{self['process']['workflow']}/{items_str}"
+            self['id'] = f"{collections_str}/workflow-{self.process['workflow']}/{items_str}"
 
     # assign collections to Items given a mapping of Col ID: ID regex
     def assign_collections(self):
         """Assign new collections to all Items (features) in Catalog
-            based on self['process']['output_options']['collections']
+            based on self.process['output_options']['collections']
         """
-        collections = self['process']['output_options'].get('collections', {})
+        collections = self.process['output_options'].get('collections', {})
         # loop through all Items in Catalog
         for item in self['features']:
             # loop through all provided output collections regexs
@@ -144,14 +174,14 @@ def get_payload(self) -> Dict:
         return dict(self)
 
     def get_items_by_properties(self, key):
-        properties = self['process']['item-queries'].get(key, {})
+        properties = self.process['item-queries'].get(key, {})
         features = []
         if properties:
             for feature in self['features']:
                 if property_match(feature, properties):
                     features.append(feature)
         else:
-            msg = f"unable to find item, please check properties parameters"
+            msg = 'unable to find item, please check properties parameters'
             logger.error(msg)
             raise Exception(msg)
         return features
@@ -161,7 +191,11 @@ def get_item_by_properties(self, key):
         if len(features) == 1:
             return features[0]
         elif len(features) > 1:
-            msg = f"multiple items returned, please check properties parameters, or use get_items_by_properties"
+            msg = (
+                'multiple items returned, '
+                'please check properties parameters, '
+                'or use get_items_by_properties'
+            )
             logger.error(msg)
             raise Exception(msg)
         else:
@@ -179,7 +213,7 @@ def publish_to_s3(self, bucket, public=False) -> List:
         Returns:
             List: List of s3 URLs to published Items
         """
-        opts = self['process'].get('output_options', {})
+        opts = self.process.get('output_options', {})
         s3urls = []
         for item in self['features']:
             # determine URL of data bucket to publish to- always do this
@@ -227,8 +261,8 @@ def publish_to_s3(self, bucket, public=False) -> List:
 
         return s3urls
 
-    @classmethod
-    def sns_attributes(self, item) -> Dict:
+    @staticmethod
+    def sns_attributes(item) -> Dict:
         """Create attributes from Item for publishing to SNS
 
         Args:
@@ -280,18 +314,36 @@ def sns_attributes(self, item) -> Dict:
         }
         return attr
 
-    def publish_to_sns(self, topic_arn=PUBLISH_TOPIC_ARN):
+    def publish_to_sns(self, topic_arn):
         """Publish this catalog to SNS
 
+        Args:
+            topic_arn (str): ARN of SNS Topic.
+        """
+        response = snsclient.publish(
+            TopicArn=topic_arn,
+            Message=json.dumps(self),
+        )
+        self.logger.debug(f"Published catalog to {topic_arn}")
+        return response
+
+    def publish_items_to_sns(self, topic_arn=PUBLISH_TOPIC_ARN):
+        """Publish this catalog's items to SNS
+
         Args:
             topic_arn (str, optional): ARN of SNS Topic. Defaults to PUBLISH_TOPIC_ARN.
""" + responses = [] for item in self['features']: - response = snsclient.publish(TopicArn=topic_arn, Message=json.dumps(item), - MessageAttributes=self.sns_attributes(item)) + responses.append(snsclient.publish( + TopicArn=topic_arn, + Message=json.dumps(item), + MessageAttributes=self.sns_attributes(item), + )) self.logger.debug(f"Published item to {topic_arn}") + return responses - def process(self) -> str: + def __call__(self) -> str: """Add this Catalog to Cirrus and start workflow Returns: @@ -299,7 +351,7 @@ def process(self) -> str: """ assert(CATALOG_BUCKET) - arn = os.getenv('BASE_WORKFLOW_ARN') + self['process']['workflow'] + arn = os.getenv('BASE_WORKFLOW_ARN') + self.process['workflow'] # start workflow try: @@ -319,14 +371,13 @@ def process(self) -> str: return self['id'] except statedb.db.meta.client.exceptions.ConditionalCheckFailedException: - msg = f"Already in PROCESSING state" - self.logger.warning(msg) + self.logger.warning('Already in PROCESSING state') return None except Exception as err: msg = f"failed starting workflow ({err})" - self.logger.error(msg, exc_info=True) + self.logger.exception(msg) statedb.set_failed(self['id'], msg) - raise err + raise class Catalogs(object): @@ -428,7 +479,7 @@ def process(self, replace=False): # check existing states states = self.get_states() for cat in self.catalogs: - _replace = replace or cat['process'].get('replace', False) + _replace = replace or cat.process.get('replace', False) # check existing state for Item, if any state = states.get(cat['id'], '') # don't try and process these - if they are stuck they should be removed from db @@ -436,7 +487,7 @@ def process(self, replace=False): # logger.info(f"Skipping {cat['id']}, in {state} state") # continue if state in ['FAILED', ''] or _replace: - catid = cat.process() + catid = cat() if catid is None: catids.append(catid) else: diff --git a/tests/test_catalog.py b/tests/test_catalog.py index 2b2d204..a41c21b 100644 --- a/tests/test_catalog.py +++ b/tests/test_catalog.py @@ -1,70 +1,141 @@ import copy -import os import json -import unittest +import pytest + +from pathlib import Path from cirrus.lib.catalog import Catalog -testpath = os.path.dirname(__file__) - - -class TestClassMethods(unittest.TestCase): - - def open_fixture(self, filename='test-catalog.json'): - with open(os.path.join(testpath, 'fixtures', filename)) as f: - data = json.loads(f.read()) - return data - - def test_open_catalog(self): - data = self.open_fixture() - cat = Catalog(**data) - assert(cat['id'] == "sentinel-s2-l2a/workflow-cog-archive/S2B_17HQD_20201103_0_L2A") - - def test_update_catalog(self): - data = self.open_fixture() - del data['id'] - del data['features'][0]['links'] - cat = Catalog(**data, update=True) - assert(cat['id'] == "sentinel-s2-l2a/workflow-cog-archive/S2B_17HQD_20201103_0_L2A") - - def test_from_payload(self): - data = self.open_fixture('sqs-payload.json') - cat = Catalog.from_payload(data, update=True) - assert(len(cat['features']) == 1) - assert(cat['id'] == 'sentinel-s2-l2a-aws/workflow-publish-sentinel/tiles-17-H-QD-2020-11-3-0') - - def test_assign_collections(self): - cat = Catalog(self.open_fixture()) - cat['process']['output_options']['collections'] = {'test': '.*'} - cat.assign_collections() - assert(cat['features'][0]['collection'] == 'test') - - def test_sns_attributes(self): - cat = Catalog(self.open_fixture()) - attr = Catalog.sns_attributes(cat['features'][0]) - assert(attr['cloud_cover']['StringValue'] == '51.56') - assert(attr['datetime']['StringValue'] == 
-
-    def test_get_items_by_properties(self):
-        data = self.open_fixture()
-        data['process']['item-queries'] = {
-            'test': {'platform':'sentinel-2b'},
-            'empty-test': {'platform': 'test-platform'}
-        }
-        cat = Catalog.from_payload(data)
-        assert(cat.get_items_by_properties("test") == data['features'])
-        assert(cat.get_items_by_properties("empty-test") == [])
-
-    def test_get_item_by_properties(self):
-        data = self.open_fixture()
-        data['process']['item-queries'] = {
-            'feature1': {'platform':'sentinel-2b'},
-            'feature2': {'platform': 'test-platform'}
-        }
-        feature1 = copy.deepcopy(data['features'][0])
-        feature2 = copy.deepcopy(data['features'][0])
-        feature2['properties']['platform'] = 'test-platform'
-        data['features'] = [feature1, feature2]
-        cat = Catalog.from_payload(data)
-        assert(cat.get_item_by_properties("feature1") == feature1)
-        assert(cat.get_item_by_properties("feature2") == feature2)
+fixtures = Path(__file__).parent.joinpath('fixtures')
+
+
+def read_json_fixture(filename):
+    with fixtures.joinpath(filename).open() as f:
+        return json.load(f)
+
+
+@pytest.fixture()
+def base_payload():
+    return read_json_fixture('test-catalog.json')
+
+
+@pytest.fixture()
+def sqs_payload():
+    return read_json_fixture('sqs-payload.json')
+
+
+def test_open_catalog(base_payload):
+    cat = Catalog(**base_payload)
+    assert cat['id'] == \
+        "sentinel-s2-l2a/workflow-cog-archive/S2B_17HQD_20201103_0_L2A"
+
+
+def test_update_catalog(base_payload):
+    del base_payload['id']
+    del base_payload['features'][0]['links']
+    cat = Catalog(**base_payload, update=True)
+    assert cat['id'] == \
+        "sentinel-s2-l2a/workflow-cog-archive/S2B_17HQD_20201103_0_L2A"
+
+
+def test_from_payload(sqs_payload):
+    cat = Catalog.from_payload(sqs_payload, update=True)
+    assert len(cat['features']) == 1
+    assert cat['id'] == \
+        'sentinel-s2-l2a-aws/workflow-publish-sentinel/tiles-17-H-QD-2020-11-3-0'
+
+
+def test_assign_collections(base_payload):
+    cat = Catalog(base_payload)
+    cat['process']['output_options']['collections'] = {'test': '.*'}
+    cat.assign_collections()
+    assert cat['features'][0]['collection'] == 'test'
+
+
+def test_sns_attributes(base_payload):
+    cat = Catalog(base_payload)
+    attr = Catalog.sns_attributes(cat['features'][0])
+    assert attr['cloud_cover']['StringValue'] == '51.56'
+    assert attr['datetime']['StringValue'] == '2020-11-03T15:22:26Z'
+
+
+def test_get_items_by_properties(base_payload):
+    base_payload['process']['item-queries'] = {
+        'test': {'platform': 'sentinel-2b'},
+        'empty-test': {'platform': 'test-platform'}
+    }
+    cat = Catalog.from_payload(base_payload)
+    assert cat.get_items_by_properties("test") == base_payload['features']
+    assert cat.get_items_by_properties("empty-test") == []
+
+
+def test_get_item_by_properties(base_payload):
+    base_payload['process']['item-queries'] = {
+        'feature1': {'platform': 'sentinel-2b'},
+        'feature2': {'platform': 'test-platform'}
+    }
+    feature1 = copy.deepcopy(base_payload['features'][0])
+    feature2 = copy.deepcopy(base_payload['features'][0])
+    feature2['properties']['platform'] = 'test-platform'
+    base_payload['features'] = [feature1, feature2]
+    cat = Catalog.from_payload(base_payload)
+    assert cat.get_item_by_properties("feature1") == feature1
+    assert cat.get_item_by_properties("feature2") == feature2
+
+
+def test_next_workflows_no_list(base_payload):
+    workflows = list(Catalog.from_payload(base_payload).next_workflows())
+    assert len(workflows) == 0
+
+
+def test_next_workflows_list_of_one(base_payload):
+    base_payload['process'] = [base_payload['process']]
+    workflows = list(Catalog.from_payload(base_payload).next_workflows())
+    assert len(workflows) == 0
+
+
+def test_next_workflows_list_of_four(base_payload):
+    length = 4
+    list_payload = copy.deepcopy(base_payload)
+    list_payload['process'] = [base_payload['process']] * length
+
+    # We should now have something like this:
+    #
+    # payload
+    #   process:
+    #     - wf1
+    #     - wf2
+    #     - wf3
+    #     - wf4
+    workflows = list(Catalog.from_payload(list_payload).next_workflows())
+
+    # When we call next_workflows, we find one next workflow (wf2)
+    # with two to follow. So the length of the list returned should be
+    # one, a workflow payload with a process array of length 3.
+    assert len(workflows) == 1
+    assert workflows[0]['process'] == [base_payload['process']] * (length-1)
+
+
+def test_next_workflows_list_of_three_fork(base_payload):
+    length = 3
+    list_payload = copy.deepcopy(base_payload)
+    list_payload['process'] = [base_payload['process']] * length
+    list_payload['process'][1] = [base_payload['process']] * 2
+
+    # We should now have something like this:
+    #
+    # payload
+    #   process:
+    #     - wf1
+    #     - [wf2a, wf2b]
+    #     - wf3
+    workflows = list(Catalog.from_payload(list_payload).next_workflows())
+
+    # When we call next_workflows, we find two next workflows
+    # (wf2a and wf2b), each with one to follow (wf3). So the list
+    # returned should have length two, each entry a workflow payload
+    # with a process array of length 2.
+    assert len(workflows) == 2
+    assert workflows[0]['process'] == [base_payload['process']] * (length-1)
+    assert workflows[1]['process'] == [base_payload['process']] * (length-1)
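
Usage example: the list-style 'process' block introduced above chains workflows, and next_workflows() peels one block off the chain for each workflow that completes. This is a minimal sketch, assuming cirrus.lib imports cleanly in your environment; the collection, workflow names, and item ID below are hypothetical:

    from cirrus.lib.catalog import Catalog

    payload = {
        'type': 'FeatureCollection',
        # IDs follow the '<collections>/workflow-<name>/<items>' convention
        'id': 'my-collection/workflow-first/item-1',
        'features': [],
        'process': [
            # wf1: the block the current workflow executes
            {'workflow': 'first', 'output_options': {}, 'tasks': {}},
            # wf2: dispatched once wf1 completes
            {'workflow': 'second', 'output_options': {}, 'tasks': {}},
        ],
    }

    cat = Catalog.from_payload(payload)
    for next_cat in cat.next_workflows():
        # wf1 has been popped off, so wf2 is now the head of the chain;
        # read it from the dict rather than the cached .process attribute,
        # which on the deep copy still points at the old head block
        print(next_cat['process'][0]['workflow'])  # -> second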