Skip to content

Commit

Permalink
Merge pull request #350 from coderxio/jrlegrand/vsac
Browse files Browse the repository at this point in the history
Jrlegrand/vsac
  • Loading branch information
jrlegrand authored Jan 18, 2025
2 parents 9feb666 + ec3fe45 commit a2aa1a0
Show file tree
Hide file tree
Showing 4 changed files with 134 additions and 122 deletions.
27 changes: 26 additions & 1 deletion airflow/dags/common_dag_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,32 @@ def run_subprocess_command(command:list, cwd:str, success_code:int = 0) -> None:
print(f"Command succeeded with output: {run_results.output}")
else:
raise AirflowException(f"Command failed with return code {run_results.exit_code}: {run_results.output}")


def get_umls_ticket_granting_ticket():
import requests
from airflow.models import Variable
api_key = Variable.get("umls_api")
params = { "apikey": api_key }
headers = { "Content-type": "application/x-www-form-urlencoded", "Accept": "text/plain", "User-Agent":"python" }
url = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
response = requests.post(url, headers=headers, data=params)
response.raise_for_status()
return response.url.split('/')[-1]

def get_umls_service_ticket(tgt: str):
import requests
service = "https://uts-ws.nlm.nih.gov"
params = { "service": service }
headers = { "Content-type": "application/x-www-form-urlencoded", "Accept": "text/plain", "User-Agent":"python" }
url = f"https://utslogin.nlm.nih.gov/cas/v1/tickets/{tgt}"
response = requests.post(url, headers=headers, data=params)
response.raise_for_status()
return response.text

def get_umls_ticket():
tgt = get_umls_ticket_granting_ticket()
st = get_umls_service_ticket(tgt)
return st

@task
def extract(dag_id,url) -> str:
Expand Down
46 changes: 6 additions & 40 deletions airflow/dags/rxnorm/dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,9 @@
from airflow.operators.python import get_current_context
from airflow.providers.postgres.operators.postgres import PostgresOperator
from airflow.hooks.postgres_hook import PostgresHook
from airflow.models import Variable

from common_dag_tasks import transform
from common_dag_tasks import extract, transform, run_subprocess_command


@dag(
Expand All @@ -19,45 +20,10 @@
)
def rxnorm():
dag_id = "rxnorm"
ds_url = "https://download.nlm.nih.gov/umls/kss/rxnorm/RxNorm_full_current.zip"
api_key = Variable.get("umls_api")
ds_url = f"https://uts-ws.nlm.nih.gov/download?url=https://download.nlm.nih.gov/umls/kss/rxnorm/RxNorm_full_current.zip&apiKey={api_key}"

@task
def get_tgt():
import requests
from airflow.models import Variable

api_key = Variable.get("umls_api")

url = "https://utslogin.nlm.nih.gov/cas/v1/api-key"
param = { "apikey": api_key }
headers = { "Content-type": "application/x-www-form-urlencoded" }

tgt_response = requests.post(url, headers=headers, data=param)

first, second = tgt_response.text.split("api-key/")
tgt, fourth = second.split('" method')

return tgt

@task
def get_st(tgt: str):
import requests

url = f"https://utslogin.nlm.nih.gov/cas/v1/tickets/{tgt}"
param = { "service": ds_url }
headers = { "Content-type": "application/x-www-form-urlencoded" }

st_response = requests.post(url, headers=headers, data=param)
st = st_response.text

return st

# Task to download data from web location
@task
def extract(st: str):
data_folder = Path("/opt/airflow/data") / dag_id
data_path = get_dataset(f"{ds_url}?ticket={st}", data_folder)
return data_path
extract_task = extract(dag_id, ds_url)

# Task to load data into source db schema
load = []
Expand All @@ -75,6 +41,6 @@ def extract(st: str):

transform_task = transform(dag_id, models_subdir=['staging', 'intermediate'])

extract(get_st(get_tgt())) >> load >> transform_task
extract_task >> load >> transform()

rxnorm()
31 changes: 31 additions & 0 deletions airflow/dags/vsac/dag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
from pathlib import Path
import pendulum

from sagerx import get_dataset, read_sql_file, get_sql_list, alert_slack_channel

from airflow.decorators import dag, task

from airflow.operators.python import get_current_context
from airflow.providers.postgres.operators.postgres import PostgresOperator
from airflow.hooks.postgres_hook import PostgresHook

from common_dag_tasks import run_subprocess_command, extract
from vsac.dag_tasks import main_execution



@dag(
schedule="0 3 * * *",
start_date=pendulum.yesterday(),
catchup=False,
)
def vsac():
dag_id = "vsac"
base_url = "https://cts.nlm.nih.gov/fhir"
ds_url = ""

extract_load_task = main_execution()

extract_load_task

vsac()
152 changes: 71 additions & 81 deletions airflow/dags/vsac/vsac.py → airflow/dags/vsac/dag_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,18 @@
import base64
import concurrent.futures
from xml.etree import ElementTree as ET
from tqdm import tqdm

# Constants
BASE_URL_SVS = "https://cts.nlm.nih.gov/fhir"
API_KEY = ''
from common_dag_tasks import get_umls_ticket
from airflow.models import Variable
from airflow.decorators import task
from sagerx import load_df_to_pg

# Function to retrieve tag values for a given tag name
# constants
api_key = Variable.get("umls_api")

# function to retrieve tag values for a given tag name
def get_tag_values(tag_name):
credentials = f"apikey:{API_KEY}".encode('utf-8')
credentials = f"apikey:{api_key}".encode('utf-8')
base64_encoded_credentials = base64.b64encode(credentials).decode('utf-8')
headers = {
"Authorization": f"Basic {base64_encoded_credentials}",
Expand All @@ -22,7 +25,7 @@ def get_tag_values(tag_name):
response.raise_for_status() # This will raise an exception for HTTP errors
return response.text
except requests.exceptions.RequestException as e:
tqdm.write(f"Error fetching data for tag {tag_name}: {e}")
print(f"Error fetching data for tag {tag_name}: {e}")
return None

def parse_xml_for_tag_values(xml_content):
Expand All @@ -40,7 +43,7 @@ def get_latest_version_cms_eMeasureID(values):
return [measure_id + 'v' + version for measure_id, version in latest_versions.items()]

def get_described_value_set_ids(tag_name, tag_value):
credentials = f"apikey:{API_KEY}".encode('utf-8')
credentials = f"apikey:{api_key}".encode('utf-8')
base64_encoded_credentials = base64.b64encode(credentials).decode('utf-8')
headers = {
"Authorization": f"Basic {base64_encoded_credentials}",
Expand All @@ -51,83 +54,44 @@ def get_described_value_set_ids(tag_name, tag_value):
value_set_ids = [value_set.get('ID') for value_set in root.findall('.//ns0:DescribedValueSet', namespaces={'ns0': 'urn:ihe:iti:svs:2008'})]
return value_set_ids

# Function to process each tag value using multithreading
# function to process each tag value using multithreading
def process_tag_values(tag_name, tag_values_list):
described_value_set_ids = []
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
future_to_ids = {executor.submit(get_described_value_set_ids, tag_name, value): value for value in tag_values_list}
for future in tqdm(concurrent.futures.as_completed(future_to_ids), total=len(tag_values_list), desc=f"Processing tag values for {tag_name}"):
for future in concurrent.futures.as_completed(future_to_ids):
described_value_set_ids.extend(future.result())
return described_value_set_ids

# Main execution
tag_names = ['CMS eMeasure ID', 'eMeasure Identifier', 'NQF Number']
tag_values = {}
for tag_name in tqdm(tag_names, desc="Retrieving tag values"):
xml_response = get_tag_values(tag_name)
if xml_response:
tag_values[tag_name] = xml_response

described_value_set_ids = []
for tag_name, xml_content in tag_values.items():
tag_values_list = parse_xml_for_tag_values(xml_content)

if tag_name == 'CMS eMeasure ID':
tag_values_list = get_latest_version_cms_eMeasureID(tag_values_list)

ids = process_tag_values(tag_name, tag_values_list)
described_value_set_ids.extend(ids)

processed_oids = set() # To keep track of processed OIDs and avoid infinite loops

class UMLSFetcher:
def __init__(self, api_key):
self.API_KEY = api_key
self.SERVICE = "https://uts-ws.nlm.nih.gov"
self.TICKET_GRANTING_TICKET = self.get_ticket_granting_ticket()

def get_ticket_granting_ticket(self):
params = {'apikey': self.API_KEY}
headers = {"Content-type": "application/x-www-form-urlencoded", "Accept": "text/plain", "User-Agent":"python" }
response = requests.post("https://utslogin.nlm.nih.gov/cas/v1/api-key", headers=headers, data=params)
response.raise_for_status()
return response.url.split('/')[-1]

def get_service_ticket(self):
params = {'service': self.SERVICE}
headers = {"Content-type": "application/x-www-form-urlencoded", "Accept": "text/plain", "User-Agent":"python" }
response = requests.post(f"https://utslogin.nlm.nih.gov/cas/v1/tickets/{self.TICKET_GRANTING_TICKET}", headers=headers, data=params)
response.raise_for_status()
return response.text
processed_oids = set() # to keep track of processed OIDs and avoid infinite loops

# Use to get descendant codes for value sets that only provide references instead of codes
def get_descendants(self, source, code):
ticket = self.get_service_ticket()
url = f"{self.SERVICE}/rest/content/current/source/{source}/{code}/descendants?ticket={ticket}"
response = requests.get(url, headers={"User-Agent": "python"})
if response.status_code != 200:
tqdm.write(f"Error fetching descendants: {response.status_code} - {response.text}")
return []
return [result['ui'] for result in tqdm(response.json()['result'], desc=f"Retrieving descendants for {code}")]
# use to get descendant codes for value sets that only provide references instead of codes
def get_descendants(self, source, code):
ticket = get_umls_ticket()
url = f"{self.SERVICE}/rest/content/current/source/{source}/{code}/descendants?ticket={ticket}"
response = requests.get(url, headers={"User-Agent": "python"})
if response.status_code != 200:
print(f"Error fetching descendants: {response.status_code} - {response.text}")
return []
return [result['ui'] for result in response.json()['result']]

umls_fetcher = UMLSFetcher(API_KEY)

# Use to retrieve a value set from VSAC
# use to retrieve a value set from VSAC
def retrieve_value_set(oid):
credentials = f"apikey:{API_KEY}".encode('utf-8')
credentials = f"apikey:{api_key}".encode('utf-8')
base64_encoded_credentials = base64.b64encode(credentials).decode('utf-8')
headers = {
"Authorization": f"Basic {base64_encoded_credentials}",
"Accept": "application/fhir+json"
}
response = requests.get(f"{BASE_URL_SVS}/ValueSet/{oid}", headers=headers)
response = requests.get(f"https://cts.nlm.nih.gov/fhir/ValueSet/{oid}", headers=headers)
return response.json()

# Convert the JSON response to a DataFrame
# convert the JSON response to a dataframe
def json_to_dataframe(response_json, current_oid=None, parent_oid=None):
data = []

# Mapping of system URIs to recognizable names
# mapping of system URIs to recognizable names
system_map = {
"http://snomed.info/sct": "SNOMED",
"http://hl7.org/fhir/sid/icd-10-cm": "ICD10CM",
Expand Down Expand Up @@ -165,16 +129,16 @@ def json_to_dataframe(response_json, current_oid=None, parent_oid=None):
"parent_oid": parent_oid if parent_oid else None
})

# Handle the 'descendantOf' filter for value sets that only provide references instead of codes
# handle the 'descendantOf' filter for value sets that only provide references instead of codes
filters = response_json.get('compose', {}).get('include', [{}])[0].get('filter', [])
for filter_ in filters:
if filter_["op"] == "descendantOf":
descendants = umls_fetcher.get_descendants(filter_["system"], filter_["value"])
descendants = get_descendants(filter_["system"], filter_["value"])
for descendant_code in descendants:
data.append({
"valueSetName": response_json["name"],
"code": descendant_code,
"display": "", # Display is empty because the UMLS API doesn't return it
"display": "", # display is empty because the UMLS API doesn't return it
"system": filter_["system"],
"status": response_json["status"],
"oid": current_oid if current_oid else response_json["id"],
Expand All @@ -183,32 +147,58 @@ def json_to_dataframe(response_json, current_oid=None, parent_oid=None):
"lastUpdated": response_json["meta"]["lastUpdated"]
})

# If the value set references other value sets, retrieve those as well
# if the value set references other value sets, retrieve those as well
referenced_value_sets = concept.get('valueSet', [])
for ref_vs in referenced_value_sets:
# Extract OID from the reference URL
# extract OID from the reference URL
oid = ref_vs.split('/')[-1]
if oid not in processed_oids: # Avoid re-processing already processed OIDs
if oid not in processed_oids: # avoid re-processing already processed OIDs
data.extend(json_to_dataframe(retrieve_value_set(oid), current_oid=oid, parent_oid=current_oid))

return data

# Concurrent function that retrieves and processes OIDs
# concurrent function that retrieves and processes OIDs
def retrieve_and_process(oid):
processed_oids.add(oid)
response_json = retrieve_value_set(oid)
return json_to_dataframe(response_json, current_oid=oid)

# Use ThreadPoolExecutor for concurrent requests with tqdm progress bar
all_data = []
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(retrieve_and_process, oid) for oid in described_value_set_ids]
for future in tqdm(concurrent.futures.as_completed(futures), total=len(described_value_set_ids), desc="Processing OIDs"):
all_data.extend(future.result())

result_df = pd.DataFrame(all_data)
# main execution
@task
def main_execution():
tag_names = ['CMS eMeasure ID', 'eMeasure Identifier', 'NQF Number']
tag_values = {}
for tag_name in tag_names:
xml_response = get_tag_values(tag_name)
if xml_response:
tag_values[tag_name] = xml_response
print('got tag_names')

described_value_set_ids = []
for tag_name, xml_content in tag_values.items():
tag_values_list = parse_xml_for_tag_values(xml_content)

if tag_name == 'CMS eMeasure ID':
tag_values_list = get_latest_version_cms_eMeasureID(tag_values_list)

ids = process_tag_values(tag_name, tag_values_list)
described_value_set_ids.extend(ids)
print('got described_value_set_ids')

# use ThreadPoolExecutor for concurrent requests
all_data = []
with concurrent.futures.ThreadPoolExecutor(max_workers=10) as executor:
futures = [executor.submit(retrieve_and_process, oid) for oid in described_value_set_ids]
for future in concurrent.futures.as_completed(futures):
all_data.extend(future.result())
if len(all_data) % 100 == 0:
print(f'MILESTONE {len(all_data)}')

result_df = pd.DataFrame(all_data)

# Reset the index of the final dataframe and remove duplicates
result_df = result_df.drop_duplicates(subset=['valueSetName', 'code', 'display'])
# reset the index of the final dataframe and remove duplicates
result_df = result_df.drop_duplicates(subset=['valueSetName', 'code', 'display'])

result_df
load_df_to_pg(result_df, "sagerx_lake", "vsac", "replace")

0 comments on commit a2aa1a0

Please sign in to comment.