Skip to content
Closed
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 66 additions & 48 deletions services/data/postgres_async_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ async def _init(self, db_conf: DBConfiguration, create_triggers=DB_TRIGGER_CREAT

break # Break the retry loop
except Exception as e:
self.logger.exception("Exception occured")
self.logger.exception("Exception occurred")
if retries - i <= 1:
raise e
time.sleep(connection_retry_wait_time_seconds)
Expand Down Expand Up @@ -466,6 +466,10 @@ class AsyncFlowTablePostgres(AsyncPostgresTable):
)
_row_type = FlowRow

@staticmethod
def get_filter_dict(flow_id: str):
    """Return the record filter that identifies a single flow row."""
    return {"flow_id": flow_id}

async def add_flow(self, flow: FlowRow):
dict = {
"flow_id": flow.flow_id,
Expand All @@ -476,7 +480,7 @@ async def add_flow(self, flow: FlowRow):
return await self.create_record(dict)

async def get_flow(self, flow_id: str):
filter_dict = {"flow_id": flow_id}
filter_dict = self.get_filter_dict(flow_id)
return await self.get_records(filter_dict=filter_dict, fetch_single=True)

async def get_all_flows(self):
Expand Down Expand Up @@ -523,9 +527,13 @@ async def add_run(self, run: RunRow):
}
return await self.create_record(dict)

async def get_run(self, flow_id: str, run_id: str, expanded: bool = False):
@staticmethod
def get_filter_dict(flow_id: str, run_id: str):
key, value = translate_run_key(run_id)
filter_dict = {"flow_id": flow_id, key: str(value)}
return {"flow_id": flow_id, key: str(value)}

async def get_run(self, flow_id: str, run_id: str, expanded: bool = False):
    """Fetch a single run row.

    get_filter_dict passes run_id through translate_run_key, so callers
    may supply whichever run identifier form that helper accepts.
    """
    filter_dict = self.get_filter_dict(flow_id, run_id)
    return await self.get_records(filter_dict=filter_dict,
                                  fetch_single=True, expanded=expanded)

Expand All @@ -534,9 +542,7 @@ async def get_all_runs(self, flow_id: str):
return await self.get_records(filter_dict=filter_dict)

async def update_heartbeat(self, flow_id: str, run_id: str):
run_key, run_value = translate_run_key(run_id)
filter_dict = {"flow_id": flow_id,
run_key: str(run_value)}
filter_dict = self.get_filter_dict(flow_id, run_id)
set_dict = {
"last_heartbeat_ts": int(datetime.datetime.utcnow().timestamp())
}
Expand Down Expand Up @@ -589,19 +595,23 @@ async def add_step(self, step_object: StepRow):
}
return await self.create_record(dict)

@staticmethod
def get_filter_dict(flow_id: str, run_id: str, step_name: str):
    """Return the record filter that identifies a single step row.

    The run_id filter key/value pair comes from translate_run_key, so
    the actual column used depends on the run_id format.
    """
    run_id_key, run_id_value = translate_run_key(run_id)
    return {
        "flow_id": flow_id,
        run_id_key: run_id_value,
        "step_name": step_name,
    }

async def get_steps(self, flow_id: str, run_id: str):
    """Fetch all step rows belonging to one run (no single-row fetch)."""
    run_id_key, run_id_value = translate_run_key(run_id)
    filter_dict = {"flow_id": flow_id,
                   run_id_key: run_id_value}
    return await self.get_records(filter_dict=filter_dict)

async def get_step(self, flow_id: str, run_id: str, step_name: str):
run_id_key, run_id_value = translate_run_key(run_id)
filter_dict = {
"flow_id": flow_id,
run_id_key: run_id_value,
"step_name": step_name,
}
filter_dict = self.get_filter_dict(flow_id, run_id, step_name)
return await self.get_records(filter_dict=filter_dict, fetch_single=True)


Expand Down Expand Up @@ -651,36 +661,35 @@ async def add_task(self, task: TaskRow):
}
return await self.create_record(dict)

async def get_tasks(self, flow_id: str, run_id: str, step_name: str):
@staticmethod
def get_filter_dict(flow_id: str, run_id: str, step_name: str, task_id: str):
run_id_key, run_id_value = translate_run_key(run_id)
filter_dict = {
task_id_key, task_id_value = translate_task_key(task_id)
return {
"flow_id": flow_id,
run_id_key: run_id_value,
"step_name": step_name,
task_id_key: task_id_value,
}
return await self.get_records(filter_dict=filter_dict)

async def get_task(self, flow_id: str, run_id: str, step_name: str,
task_id: str, expanded: bool = False):
async def get_tasks(self, flow_id: str, run_id: str, step_name: str):
run_id_key, run_id_value = translate_run_key(run_id)
task_id_key, task_id_value = translate_task_key(task_id)
filter_dict = {
"flow_id": flow_id,
run_id_key: run_id_value,
"step_name": step_name,
task_id_key: task_id_value,
}
return await self.get_records(filter_dict=filter_dict)

async def get_task(self, flow_id: str, run_id: str, step_name: str,
                   task_id: str, expanded: bool = False):
    """Fetch a single task row.

    run_id and task_id are translated to the proper filter columns by
    get_filter_dict (via translate_run_key / translate_task_key).
    """
    filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
    return await self.get_records(filter_dict=filter_dict,
                                  fetch_single=True, expanded=expanded)

async def update_heartbeat(self, flow_id: str, run_id: str, step_name: str,
task_id: str):
run_key, run_value = translate_run_key(run_id)
task_key, task_value = translate_task_key(task_id)
filter_dict = {"flow_id": flow_id,
run_key: str(run_value),
"step_name": step_name,
task_key: str(task_value)}
filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
set_dict = {
"last_heartbeat_ts": int(datetime.datetime.utcnow().timestamp())
}
Expand Down Expand Up @@ -757,23 +766,27 @@ async def add_metadata(
}
return await self.create_record(dict)

@staticmethod
def get_filter_dict(flow_id: str, run_id: str, step_name: str, task_id: str):
    """Return the record filter that identifies one task's metadata rows.

    Both run_id and task_id are mapped to their filter key/value pairs
    by the translate_* helpers, so the columns depend on the id format.
    """
    run_id_key, run_id_value = translate_run_key(run_id)
    task_id_key, task_id_value = translate_task_key(task_id)
    return {
        "flow_id": flow_id,
        run_id_key: run_id_value,
        "step_name": step_name,
        task_id_key: task_id_value,
    }

async def get_metadata_in_runs(self, flow_id: str, run_id: str):
    """Fetch every metadata row recorded for a run (all steps/tasks)."""
    run_id_key, run_id_value = translate_run_key(run_id)
    filter_dict = {"flow_id": flow_id,
                   run_id_key: run_id_value}
    return await self.get_records(filter_dict=filter_dict)

async def get_metadata(
self, flow_id: str, run_id: int, step_name: str, task_id: str
self, flow_id: str, run_id: str, step_name: str, task_id: str
):
run_id_key, run_id_value = translate_run_key(run_id)
task_id_key, task_id_value = translate_task_key(task_id)
filter_dict = {
"flow_id": flow_id,
run_id_key: run_id_value,
"step_name": step_name,
task_id_key: task_id_value,
}
filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id)
return await self.get_records(filter_dict=filter_dict)


Expand Down Expand Up @@ -856,7 +869,20 @@ async def add_artifact(
}
return await self.create_record(dict)

async def get_artifacts_in_runs(self, flow_id: str, run_id: int):
@staticmethod
def get_filter_dict(
        flow_id: str, run_id: str, step_name: str, task_id: str, name: str):
    """Return the record filter that identifies a single artifact row.

    run_id and task_id are translated to their filter columns by the
    translate_* helpers.
    """
    run_id_key, run_id_value = translate_run_key(run_id)
    task_id_key, task_id_value = translate_task_key(task_id)
    return {
        "flow_id": flow_id,
        run_id_key: run_id_value,
        "step_name": step_name,
        task_id_key: task_id_value,
        # NOTE(review): key is double-quoted — presumably because "name"
        # must be quoted as a column identifier in SQL; confirm against
        # how get_records builds its WHERE clause.
        '"name"': name,
    }

async def get_artifacts_in_runs(self, flow_id: str, run_id: str):
run_id_key, run_id_value = translate_run_key(run_id)
filter_dict = {
"flow_id": flow_id,
Expand All @@ -865,7 +891,7 @@ async def get_artifacts_in_runs(self, flow_id: str, run_id: int):
return await self.get_records(filter_dict=filter_dict,
ordering=self.ordering)

async def get_artifact_in_steps(self, flow_id: str, run_id: int, step_name: str):
async def get_artifact_in_steps(self, flow_id: str, run_id: str, step_name: str):
run_id_key, run_id_value = translate_run_key(run_id)
filter_dict = {
"flow_id": flow_id,
Expand All @@ -876,7 +902,7 @@ async def get_artifact_in_steps(self, flow_id: str, run_id: int, step_name: str)
ordering=self.ordering)

async def get_artifact_in_task(
self, flow_id: str, run_id: int, step_name: str, task_id: int
self, flow_id: str, run_id: str, step_name: str, task_id: str
):
run_id_key, run_id_value = translate_run_key(run_id)
task_id_key, task_id_value = translate_task_key(task_id)
Expand All @@ -890,16 +916,8 @@ async def get_artifact_in_task(
ordering=self.ordering)

async def get_artifact(
self, flow_id: str, run_id: int, step_name: str, task_id: int, name: str
self, flow_id: str, run_id: str, step_name: str, task_id: str, name: str
):
run_id_key, run_id_value = translate_run_key(run_id)
task_id_key, task_id_value = translate_task_key(task_id)
filter_dict = {
"flow_id": flow_id,
run_id_key: run_id_value,
"step_name": step_name,
task_id_key: task_id_value,
'"name"': name,
}
filter_dict = self.get_filter_dict(flow_id, run_id, step_name, task_id, name)
return await self.get_records(filter_dict=filter_dict,
fetch_single=True, ordering=self.ordering)
139 changes: 139 additions & 0 deletions services/metadata_service/api/tag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
from services.data import TaskRow
from services.data.db_utils import DBResponse
from services.data.postgres_async_db import AsyncPostgresDB
from services.metadata_service.api.utils import format_response, \
handle_exceptions
import json

import asyncio


class TagApi(object):
lock = asyncio.Lock()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

seems to be unused.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I copied from other APIs but let me check if there was a purpose behind it.


def __init__(self, app):
    """Register the tag-update endpoint on the given aiohttp app."""
    app.router.add_route(
        "POST",
        "/tags",
        self.update_tags,
    )
    # Shared DB accessor; the per-type tables hang off this instance.
    self._db = AsyncPostgresDB.get_instance()

def _get_table(self, type):
    """Map an object-type string to its backing postgres table object.

    Raises ValueError for an unknown type; update_tags catches that and
    turns it into a 400 response.
    """
    # NOTE(review): parameter shadows the `type` builtin; harmless here
    # but a rename (e.g. obj_type) would be cleaner if callers allow it.
    if type == 'flow':
        return self._db.flow_table_postgres
    elif type == 'run':
        return self._db.run_table_postgres
    elif type == 'step':
        return self._db.step_table_postgres
    elif type == 'task':
        return self._db.task_table_postgres
    elif type == 'artifact':
        return self._db.artifact_table_postgres
    else:
        raise ValueError("cannot find table for type %s" % type)

@handle_exceptions
@format_response
async def update_tags(self, request):
"""
---
description: Update user-tags for objects
tags:
- Tags
parameters:
- name: "body"
in: "body"
description: "body"
required: true
schema:
type: array
items:
type: object
required:
- object_type
- id
- tag
- operation
properties:
object_type:
type: string
enum: [flow, run, step, task, artifact]
id:
type: string
operation:
type: string
enum: [add, remove]
tag:
type: string
user:
type: string
produces:
- application/json
responses:
"202":
description: successful operation. Return newly registered task
"404":
description: not found
"500":
description: internal server error
"""
body = await request.json()
results = []
for o in body:
try:
Comment on lines +83 to +84
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The multiple updates are not done inside a transaction, so there is a slight peculiarity with the way things might work in case of errors. For example

  1. flow/run/step tag updates succeed
  2. for whatever reason task tag update fails
  3. API responds with an error, even though some records were updated, and some updates were never reached.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, this was a replica of how things were happening inside (in our internal version) and was a conscious choice but maybe something worth reconsidering. @ferras may have some more ideas on that.

table = self._get_table(o['object_type'])
pathspec = o['id'].split('/')
# Do some basic verification
if o['object_type'] == 'flow' and len(pathspec) != 1:
raise ValueError("invalid flow specification: %s" % o['id'])
elif o['object_type'] == 'run' and len(pathspec) != 2:
raise ValueError("invalid run specification: %s" % o['id'])
elif o['object_type'] == 'step' and len(pathspec) != 3:
raise ValueError("invalid step specification: %s" % o['id'])
elif o['object_type'] == 'task' and len(pathspec) != 4:
raise ValueError("invalid task specification: %s" % o['id'])
elif o['object_type'] == 'artifact' and len(pathspec) != 5:
raise ValueError("invalid artifact specification: %s" % o['id'])
obj_filter = table.get_filter_dict(*pathspec)
except ValueError as e:
return DBResponse(response_code=400, body=json.dumps(
{"message": "invalid input: %s" % str(e)}))

# Now we can get the object
obj = await table.get_records(
filter_dict=obj_filter, fetch_single=True, expanded=True)
if obj.response_code != 200:
return DBResponse(response_code=obj.response_code, body=json.dumps(
{"message": "could not get object %s: %s" % (o['id'], obj.body)}))

# At this point do some checks and update the tags
obj = obj.body
modified = False
if o['operation'] == 'add':
# This is the only error we fail hard on; adding a tag that is
# in system tag
if o['tag'] in obj['system_tags']:
return DBResponse(response_code=405, body=json.dumps(
{"message": "tag %s is already a system tag and can't be added to %s"
% (o['tag'], o['id'])}))
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this a necessary guard? the metaflow cli does not guard against adding system tags as user tags with --tag it seems. Should similar checks be done during creation as well then?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's an excellent point and a bug most likely. This patch was done to allow modification of tags after the fact and the idea is to make sure that system tags stay unique and not affected. They are kept separate but we were trying to make sure that system tags were not overwritten in a way although I will admit that this is maybe not very useful. Let me circle back on that.

if o['tag'] not in obj['tags']:
modified = True
obj['tags'].append(o['tag'])
elif o['operation'] == 'remove':
if o['tag'] in obj['tags']:
modified = True
obj['tags'].remove(o['tag'])
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Noticed that the Metaflow client and service do not guard against duplicate tags, so --tag test --tag test would add the test tag twice, due to tags being a simple array. remove will only get rid of one occurrence of a tag. Should this be the way the update works?

alternatively, get rid of all occurrences of a tag:

obj['tags'] = [tag for tag in obj['tags'] if tag != o['tag']]

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You are right again on the creation side of things; the update will guard against duplicates but not the add (I will double check on that side; I think we should just convert it to a set and set that and also check against system tags).

else:
return DBResponse(response_code=400, body=json.dumps(
{"message": "invalid tag operation %s" % o['operation']}))
if modified:
# We save the value back
result = await table.update_row(filter_dict=obj_filter, update_dict={
'tags': "'%s'" % json.dumps(obj['tags'])})
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Worth checking: Previously update_row was only used for heartbeat updating, but using it this way might not be reliable/secure, as tags contain user inputs. Preferred way for psycopg is to use the cur.execute(sql_template, values) which should take care of correctly escaping the values.

Might be worth refactoring the update_row function to use the safe execute syntax.
see https://www.psycopg.org/docs/usage.html#passing-parameters-to-sql-queries for usage.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point. I will update the update_row function to use a template. That seems the better way.

if result.response_code != 200:
return DBResponse(response_code=result.response_code, body=json.dumps(
{"message": "error updating tags for %s: %s" % (o['id'], result.body)}))
results.append(obj)

return DBResponse(response_code=200, body=json.dumps(results))
2 changes: 2 additions & 0 deletions services/metadata_service/api/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ def format_response(func):
@wraps(func)
async def wrapper(*args, **kwargs):
db_response = await func(*args, **kwargs)
if isinstance(db_response, web.Response):
return db_response
return web.Response(status=db_response.response_code,
body=json.dumps(db_response.body),
headers=MultiDict(
Expand Down
2 changes: 2 additions & 0 deletions services/metadata_service/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from .api.task import TaskApi
from .api.artifact import ArtificatsApi
from .api.admin import AuthApi
from .api.tag import TagApi

from .api.metadata import MetadataApi
from services.data.postgres_async_db import AsyncPostgresDB
Expand All @@ -30,6 +31,7 @@ def app(loop=None, db_conf: DBConfiguration = None):
MetadataApi(app)
ArtificatsApi(app)
AuthApi(app)
TagApi(app)
setup_swagger(app)
return app

Expand Down