Skip to content

Commit

Permalink
Create two where_strs, remove unused import
Browse files Browse the repository at this point in the history
  • Loading branch information
vivbak committed Apr 19, 2024
1 parent 3d9c0b6 commit 7ba3371
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 10 deletions.
2 changes: 0 additions & 2 deletions db/python/layers/cohort.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from db.python.connect import Connection
from db.python.layers.base import BaseLayer
from db.python.layers.sequencing_group import SequencingGroupLayer
from db.python.tables.analysis import AnalysisTable
from db.python.tables.cohort import CohortFilter, CohortTable, CohortTemplateFilter
from db.python.tables.project import ProjectId, ProjectPermissionsTable
from db.python.tables.sample import SampleFilter, SampleTable
Expand Down Expand Up @@ -82,7 +81,6 @@ def __init__(self, connection: Connection):
super().__init__(connection)

self.sampt = SampleTable(connection)
self.at = AnalysisTable(connection)
self.ct = CohortTable(connection)
self.pt = ProjectPermissionsTable(connection)
self.sgt = SequencingGroupTable(connection)
Expand Down
29 changes: 21 additions & 8 deletions db/python/tables/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,7 @@ async def query(self, filter_: AnalysisFilter) -> List[AnalysisInternal]:
'or project to filter on'
)

where_str, values = filter_.to_sql(
sg_where_str, sg_values = filter_.to_sql(
{
'id': 'a.id',
'sample_id': 'a_sg.sample_id',
Expand All @@ -241,22 +241,35 @@ async def query(self, filter_: AnalysisFilter) -> List[AnalysisInternal]:
},
)

cohort_where_str, cohort_values = filter_.to_sql(
{
'id': 'a.id',
'project': 'a.project',
'type': 'a.type',
'status': 'a.status',
'meta': 'a.meta',
'output': 'a.output',
'active': 'a.active',
'cohort_id': 'a_c.cohort_id',
},
)

retvals: Dict[int, AnalysisInternal] = {}

if filter_.cohort_id and filter_.sequencing_group_id:
raise ValueError('Cannot filter on both cohort_id and sequencing_group_id')

if filter_.cohort_id:
_query = f"""
_cohort_query = f"""
SELECT a.id as id, a.type as type, a.status as status,
a.output as output, a_c.cohort_id as cohort_id,
a.project as project, a.timestamp_completed as timestamp_completed,
a.active as active, a.meta as meta, a.author as author
FROM analysis a
LEFT JOIN analysis_cohort a_c ON a.id = a_c.analysis_id
WHERE {where_str}
WHERE {cohort_where_str}
"""
rows = await self.connection.fetch_all(_query, values)
rows = await self.connection.fetch_all(_cohort_query, cohort_values)
for row in rows:
key = row['id']
if key in retvals:
Expand All @@ -265,7 +278,7 @@ async def query(self, filter_: AnalysisFilter) -> List[AnalysisInternal]:
retvals[key] = AnalysisInternal.from_db(**dict(row))

if retvals.keys():
_query_sg_ids = f"""
_query_sg_ids = """
SELECT sequencing_group_id, analysis_id
FROM analysis_sequencing_group
WHERE analysis_id IN :analysis_ids
Expand All @@ -287,17 +300,17 @@ async def query(self, filter_: AnalysisFilter) -> List[AnalysisInternal]:
a.active as active, a.meta as meta, a.author as author
FROM analysis a
LEFT JOIN analysis_sequencing_group a_sg ON a.id = a_sg.analysis_id
WHERE {where_str}
WHERE {sg_where_str}
"""
rows = await self.connection.fetch_all(_query, values)
rows = await self.connection.fetch_all(_query, sg_values)
for row in rows:
key = row['id']
if key in retvals:
retvals[key].sequencing_group_ids.append(row['sequencing_group_id'])
else:
retvals[key] = AnalysisInternal.from_db(**dict(row))

_query_cohort_ids = f"""
_query_cohort_ids = """
SELECT analysis_id, cohort_id
FROM analysis_cohort
WHERE analysis_id IN :analysis_ids;
Expand Down

0 comments on commit 7ba3371

Please sign in to comment.