This repository has been archived by the owner on Dec 5, 2023. It is now read-only.

Merge pull request #32 from cirrus-geo/support-workflow-chaining
catalog support for workflow chaining
jkeifer authored Dec 29, 2021
2 parents e2a96e3 + 0e1005e commit 970f3a5
Showing 2 changed files with 219 additions and 97 deletions.
115 changes: 83 additions & 32 deletions src/cirrus/lib/catalog.py
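Before the diff, a sketch of the payload shape this change enables: a catalog's 'process' block may now be a list of workflow definitions rather than a single dict, with the first entry treated as the active workflow. The workflow, collection, and item names below are hypothetical, and fields beyond those asserted in __init__ are illustrative:

payload = {
    'type': 'FeatureCollection',
    'features': [{
        'type': 'Feature',
        'id': 'item-1',
        'collection': 'my-collection',
        'properties': {},
    }],
    'process': [
        {   # first entry: the active workflow, exposed as self.process
            'workflow': 'publish',
            'output_options': {'collections': {}},
            'tasks': {},
        },
        {   # later entries: chained workflows, consumed by next_workflows()
            'workflow': 'mirror',
            'output_options': {'collections': {}},
            'tasks': {},
        },
    ],
}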
@@ -8,6 +8,7 @@
import uuid
from datetime import datetime, timezone
from typing import Dict, Optional, List
from copy import deepcopy

from boto3utils import s3
from cirrus.lib.statedb import StateDB
@@ -38,21 +39,33 @@ def __init__(self, *args, update=False, state_item=None, **kwargs):
"""
super(Catalog, self).__init__(*args, **kwargs)

# convert old functions field to tasks
if 'functions' in self['process']:
self['process']['tasks'] = self['process'].pop('functions')

self.logger = get_task_logger(__name__, catalog=self)

if update:
self.update()

# validate process block
# TODO: assert isn't safe for this use if debug is off
assert(self['type'] == 'FeatureCollection')
assert('process' in self)
assert('output_options' in self['process'])
assert('workflow' in self['process'])
assert('tasks' in self['process'])

self.process = (
self['process'][0]
if isinstance(self['process'], list)
else self['process']
)
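# self.process is the active workflow definition: the first entry when
# 'process' is a chained list, or the whole dict for a legacy
# single-workflow payload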

if update:
self.update()

assert('output_options' in self.process)
assert('workflow' in self.process)

# convert old functions field to tasks
if 'functions' in self.process:
self.logger.warning("Deprecated: process 'functions' has been renamed to 'tasks'")
self.process['tasks'] = self.process.pop('functions')

assert('tasks' in self.process)
self.tasks = self.process['tasks']

assert('workflow-' in self['id'])

# TODO - validate with a JSON schema
@@ -99,10 +112,27 @@ def from_payload(cls, payload: Dict, **kwargs) -> Catalog:
cat = payload
return cls(cat, **kwargs)

def get_task(self, task_name, *args, **kwargs):
return self.tasks.get(task_name, *args, **kwargs)

def next_workflows(self):
if isinstance(self['process'], dict) or len(self['process']) <= 1:
return None
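# since this function is a generator, 'return None' here just ends
# iteration; callers always get a generator object, never an actual None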
next_processes = (
[self['process'][1]]
if isinstance(self['process'][1], dict)
else self['process'][1]
)
for process in next_processes:
new = deepcopy(self)
new['process'].pop(0)
new['process'][0] = process
yield new
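# e.g. process == [A, B, C] yields one catalog with process == [B, C];
# process == [A, [B, C]] fans out, yielding one catalog starting at B
# and another starting at C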

def update(self):
if 'collections' in self['process']:
if 'collections' in self.process:
# allow overriding of collections name
collections_str = self['process']['collections']
collections_str = self.process['collections']
else:
# otherwise, get from items
cols = sorted(list(set([i['collection'] for i in self['features'] if 'collection' in i])))
@@ -111,14 +141,14 @@ def update(self):

items_str = '/'.join(sorted(list([i['id'] for i in self['features']])))
if 'id' not in self:
self['id'] = f"{collections_str}/workflow-{self['process']['workflow']}/{items_str}"
self['id'] = f"{collections_str}/workflow-{self.process['workflow']}/{items_str}"

# assign collections to Items given a mapping of Col ID: ID regex
def assign_collections(self):
"""Assign new collections to all Items (features) in Catalog
based on self['process']['output_options']['collections']
based on self.process['output_options']['collections']
"""
collections = self['process']['output_options'].get('collections', {})
collections = self.process['output_options'].get('collections', {})
# loop through all Items in Catalog
for item in self['features']:
# loop through all provided output collections regexs
@@ -144,14 +174,14 @@ def get_payload(self) -> Dict:
return dict(self)

def get_items_by_properties(self, key):
properties = self['process']['item-queries'].get(key, {})
properties = self.process['item-queries'].get(key, {})
features = []
if properties:
for feature in self['features']:
if property_match(feature, properties):
features.append(feature)
else:
msg = f"unable to find item, please check properties parameters"
msg = 'unable to find item, please check properties parameters'
logger.error(msg)
raise Exception(msg)
return features
@@ -161,7 +191,11 @@ def get_item_by_properties(self, key):
if len(features) == 1:
return features[0]
elif len(features) > 1:
msg = f"multiple items returned, please check properties parameters, or use get_items_by_properties"
msg = (
'multiple items returned, '
'please check properties parameters, '
'or use get_items_by_properties'
)
logger.error(msg)
raise Exception(msg)
else:
@@ -179,7 +213,7 @@ def publish_to_s3(self, bucket, public=False) -> List:
Returns:
List: List of s3 URLs to published Items
"""
opts = self['process'].get('output_options', {})
opts = self.process.get('output_options', {})
s3urls = []
for item in self['features']:
# determine URL of data bucket to publish to- always do this
@@ -227,8 +261,8 @@ def publish_to_s3(self, bucket, public=False) -> List:

return s3urls

@classmethod
def sns_attributes(self, item) -> Dict:
@staticmethod
def sns_attributes(item) -> Dict:
"""Create attributes from Item for publishing to SNS
Args:
@@ -280,26 +314,44 @@ def sns_attributes(self, item) -> Dict:
}
return attr

def publish_to_sns(self, topic_arn=PUBLISH_TOPIC_ARN):
def publish_to_sns(self, topic_arn):
"""Publish this catalog to SNS
Args:
topic_arn (str): ARN of SNS Topic.
"""
response = snsclient.publish(
TopicArn=topic_arn,
Message=json.dumps(self),
)
self.logger.debug(f"Published catalog to {topic_arn}")
return response

def publish_items_to_sns(self, topic_arn=PUBLISH_TOPIC_ARN):
"""Publish this catalog's items to SNS
Args:
topic_arn (str, optional): ARN of SNS Topic. Defaults to PUBLISH_TOPIC_ARN.
"""
responses = []
for item in self['features']:
response = snsclient.publish(TopicArn=topic_arn, Message=json.dumps(item),
MessageAttributes=self.sns_attributes(item))
responses.append(snsclient.publish(
TopicArn=topic_arn,
Message=json.dumps(item),
MessageAttributes=self.sns_attributes(item),
))
self.logger.debug(f"Published item to {topic_arn}")
return responses

def process(self) -> str:
def __call__(self) -> str:
"""Add this Catalog to Cirrus and start workflow
Returns:
str: Catalog ID
"""
assert(CATALOG_BUCKET)

arn = os.getenv('BASE_WORKFLOW_ARN') + self['process']['workflow']
arn = os.getenv('BASE_WORKFLOW_ARN') + self.process['workflow']

# start workflow
try:
@@ -319,14 +371,13 @@ def process(self) -> str:

return self['id']
except statedb.db.meta.client.exceptions.ConditionalCheckFailedException:
msg = f"Already in PROCESSING state"
self.logger.warning(msg)
self.logger.warning('Already in PROCESSING state')
return None
except Exception as err:
msg = f"failed starting workflow ({err})"
self.logger.error(msg, exc_info=True)
self.logger.exception(msg)
statedb.set_failed(self['id'], msg)
raise err
raise


class Catalogs(object):
@@ -428,15 +479,15 @@ def process(self, replace=False):
# check existing states
states = self.get_states()
for cat in self.catalogs:
_replace = replace or cat['process'].get('replace', False)
_replace = replace or cat.process.get('replace', False)
# check existing state for Item, if any
state = states.get(cat['id'], '')
# don't try and process these - if they are stuck they should be removed from db
#if state in ['QUEUED', 'PROCESSING']:
# logger.info(f"Skipping {cat['id']}, in {state} state")
# continue
if state in ['FAILED', ''] or _replace:
catid = cat.process()
catid = cat()
if catid is None:
catids.append(catid)
else:

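To close, a rough sketch of how the changed pieces fit together. It assumes a payload shaped like the example at the top and a configured Cirrus environment (CATALOG_BUCKET and BASE_WORKFLOW_ARN set, state DB reachable); in a real deployment the chaining loop is driven by Cirrus itself, not user code:

from cirrus.lib.catalog import Catalog

# build a catalog; update=True derives the id from collections/items
cat = Catalog.from_payload(payload, update=True)

# starting the workflow is now Catalog.__call__, not cat.process()
catid = cat()

# once the active workflow completes, each chained definition becomes a
# catalog of its own; round-tripping through from_payload re-runs
# __init__ so self.process points at the new head of the process list
for next_cat in cat.next_workflows():
    Catalog.from_payload(next_cat.get_payload())()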