diff --git a/opal/core/search/extract.py b/opal/core/search/extract.py index f981c5b63..6a7719bb2 100644 --- a/opal/core/search/extract.py +++ b/opal/core/search/extract.py @@ -1,301 +1,66 @@ """ Utilities for extracting data from Opal applications """ -from collections import OrderedDict -import csv import datetime import functools -import json -import logging import os import tempfile import zipfile from django.template import loader -from django.core.serializers.json import DjangoJSONEncoder -from six import text_type - -from collections import defaultdict -from opal.models import Episode from opal.core.subrecords import ( episode_subrecords, subrecords ) +from opal.core.search.extract_rule import ModelRule, ExtractRule -class CsvColumn(object): - """ A custom column class that will render a custom value - - * name is similar to api_name on a field - if it matches and existig field api name in the extract - fields this will override it - * value is that takes in whatever arguments - are passed to get_row - * display name is what is used in the header - """ - def __init__(self, name, value=None, display_name=None): - self.name = name - self.value = value - - if value: - self.value = value - else: - self.value = lambda renderer, obj: getattr(obj, self.name) - - if display_name: - self.display_name = display_name - else: - self.display_name = self.name.title() - - -class CsvRenderer(object): - """ - An Abstract base class of the other csv renderers - """ - - # overrides of model fields for the csv columns - non_field_csv_columns = [] - - def __init__(self, model, queryset, user, fields=None): - self.queryset = queryset - self.model = model - self.user = user - if fields: - self.fields = fields - else: - self.fields = self.get_field_names_to_render() - - def get_non_field_csv_column_names(self): - return [csv_column.name for csv_column in self.non_field_csv_columns] - - def get_non_field_csv_columns(self, field_name): - return next( - i for i in 
self.non_field_csv_columns if i.name == field_name - ) - - def get_field_names_to_render(self): - field_names = self.model._get_fieldnames_to_extract() - field_names.remove("consistency_token") - result = self.get_non_field_csv_column_names() - non_field_csv_columns_set = set(result) - for field_name in field_names: - if field_name not in non_field_csv_columns_set: - result.append(field_name) - - return result - - def get_field_title(self, field_name): - return self.model._get_field_title(field_name) - - def get_headers(self): - result = [] - for field in self.fields: - if field in self.get_non_field_csv_column_names(): - result.append( - self.get_non_field_csv_columns(field).display_name - ) - else: - result.append(self.get_field_title(field)) - return result - - def serialize_dict(self, some_dict): - return json.dumps(some_dict, cls=DjangoJSONEncoder) - - def serialize_list(self, some_list): - """ - Complex datatypes (ie anything that involves a dict) - should be json serialized, otherwise, return a - semicolon seperated list - """ - if len(some_list) and isinstance(some_list[0], dict): - return json.dumps(some_list, cls=DjangoJSONEncoder) - else: - return "; ".join(text_type(i) for i in some_list) - - def serialize_value(self, some_value): - if isinstance(some_value, list): - return self.serialize_list(some_value) - if isinstance(some_value, dict): - return self.serialize_dict(some_value) - return text_type(some_value) - - def get_field_value(self, field_name, data): - return self.serialize_value(data[field_name]) - - def get_row(self, instance, *args, **kwargs): - as_dict = instance.to_dict(user=self.user) - - result = [] - - for field in self.fields: - if field in self.get_non_field_csv_column_names(): - some_fn = self.get_non_field_csv_columns(field).value - result.append( - some_fn(self, instance, *args, **kwargs) - ) - else: - result.append(self.get_field_value(field, as_dict)) - return result - - def get_rows(self): - for instance in self.queryset: - yield 
self.get_row(instance) - - def count(self): - return self.queryset.count() - - def write_to_file(self, file_name): - logging.info("writing for {}".format(self.model)) - - with open(file_name, "w") as csv_file: - writer = csv.writer(csv_file) - writer.writerow(self.get_headers()) - for row in self.get_rows(): - writer.writerow(row) - - logging.info("finished writing for {}".format(self.model)) - - -class EpisodeCsvRenderer(CsvRenderer): - non_field_csv_columns = ( - CsvColumn( - "tagging", - value=lambda renderer, instance: text_type(";".join( - instance.get_tag_names(renderer.user, historic=True) - )) - ), - CsvColumn("start"), - CsvColumn("end"), - CsvColumn("created"), - CsvColumn("updated"), - CsvColumn("created_by_id", display_name="Created By"), - CsvColumn("updated_by_id", display_name="Updated By"), - CsvColumn("patient_id", display_name="Patient"), - ) - - -class PatientSubrecordCsvRenderer(CsvRenderer): - non_field_csv_columns = ( - CsvColumn( - "episode_id", - display_name="Episode", - value=lambda renderer, instance, episode_id: text_type(episode_id) - ), - ) - - def __init__(self, model, episode_queryset, user, fields=None): - self.patient_to_episode = defaultdict(list) - - for episode in episode_queryset: - self.patient_to_episode[episode.patient_id].append(episode.id) - - queryset = model.objects.filter( - patient__in=list(self.patient_to_episode.keys())) - - super(PatientSubrecordCsvRenderer, self).__init__( - model, queryset, user, fields - ) - - def get_field_names_to_render(self): - field_names = super( - PatientSubrecordCsvRenderer, self - ).get_field_names_to_render() - field_names.remove("id") - return field_names - - def get_rows(self): - for sub in self.queryset: - for episode_id in self.patient_to_episode[sub.patient_id]: - yield self.get_row(sub, episode_id) - - -def field_to_dict(subrecord, field_name): - return dict( - display_name=subrecord._get_field_title(field_name), - description=subrecord.get_field_description(field_name), - 
type_display_name=subrecord.get_human_readable_type(field_name), - ) - - -def get_data_dictionary(): - schema = {} - for subrecord in subrecords(): - if getattr(subrecord, '_exclude_from_extract', False): - continue - - field_names = subrecord._get_fieldnames_to_extract() - record_schema = [field_to_dict(subrecord, i) for i in field_names] - schema[subrecord.get_display_name()] = record_schema - field_names = Episode._get_fieldnames_to_extract() - schema["Episode"] = [field_to_dict(Episode, i) for i in field_names] - return OrderedDict(sorted(schema.items(), key=lambda t: t[0])) - - -def write_data_dictionary(file_name): - dictionary = get_data_dictionary() +def write_data_dictionary(root_dir, schema): + file_name = "data_dictionary.html" + full_file_name = os.path.join(root_dir, file_name) t = loader.get_template("extract_data_schema.html") - ctx = dict(schema=dictionary) + # sort the dictionary by the name of the rule/subrecord + schema = { + k: v for k, v in sorted(schema.items(), key=lambda item: item[0]) + } + ctx = dict(schema=schema) rendered = t.render(ctx) - with open(file_name, "w") as f: + with open(full_file_name, "w") as f: f.write(rendered) - - -class EpisodeSubrecordCsvRenderer(CsvRenderer): - non_field_csv_columns = ( - CsvColumn( - "patient_id", - display_name="Patient", - value=lambda self, instance: text_type(instance.episode.patient_id) - ), - ) - - def get_field_names_to_render(self): - field_names = super( - EpisodeSubrecordCsvRenderer, self - ).get_field_names_to_render() - field_names.remove("id") - return field_names - - def __init__(self, model, episode_queryset, user, fields=None): - queryset = model.objects.filter(episode__in=episode_queryset) - - super(EpisodeSubrecordCsvRenderer, self).__init__( - model, queryset, user, fields - ) + return file_name def generate_csv_files(root_dir, episodes, user): - """ Generate the files and return a tuple of absolute_file_name, file_name - """ file_names = [] - - file_name = "data_dictionary.html" - 
full_file_name = os.path.join(root_dir, file_name) - write_data_dictionary(full_file_name) - file_names.append((full_file_name, file_name,)) - - file_name = "episodes.csv" - full_file_name = os.path.join(root_dir, file_name) - renderer = EpisodeCsvRenderer(Episode, episodes, user) - renderer.write_to_file(full_file_name) - file_names.append((full_file_name, file_name,)) - + data_dictionary_dict = {} for subrecord in subrecords(): if getattr(subrecord, '_exclude_from_extract', False): continue - file_name = '{0}.csv'.format(subrecord.get_api_name()) - full_file_name = os.path.join(root_dir, file_name) if subrecord in episode_subrecords(): - renderer = EpisodeSubrecordCsvRenderer( - subrecord, episodes, user + rule = ModelRule( + episodes, user, subrecord, "episode__id" ) else: - renderer = PatientSubrecordCsvRenderer( - subrecord, episodes, user + rule = ModelRule( + episodes, user, subrecord, "patient__episode__id" ) - if renderer.count(): - renderer.write_to_file(full_file_name) - file_names.append((full_file_name, file_name,)) - + file_name = rule.write_to_file(root_dir) + if file_name: + file_names.append(file_name) + data_dictionary_dict.update(rule.get_data_dictionary()) + + for rule_cls in ExtractRule.list(): + rule = rule_cls(episodes, user) + file_name = rule.write_to_file(root_dir) + data_dictionary_dict.update(rule.get_data_dictionary()) + if not file_name: + continue + if isinstance(file_name, list): + file_names.extend(file_name) + else: + file_names.append(file_name) + file_names.append( + write_data_dictionary(root_dir, data_dictionary_dict) + ) return file_names @@ -314,10 +79,12 @@ def zip_archive(episodes, description, user): os.mkdir(root_dir) zip_relative_file_path = functools.partial(os.path.join, zipfolder) file_names = generate_csv_files(root_dir, episodes, user) - for full_file_name, file_name in file_names: + for file_name in file_names: + full_file_name = os.path.join(root_dir, file_name) + zip_relative_file_path = os.path.join(zipfolder, 
file_name)
             z.write(
                 full_file_name,
-                zip_relative_file_path(file_name)
+                zip_relative_file_path
             )
     return target
diff --git a/opal/core/search/extract_rule.py b/opal/core/search/extract_rule.py
new file mode 100644
index 000000000..2c1c296f9
--- /dev/null
+++ b/opal/core/search/extract_rule.py
@@ -0,0 +1,315 @@
+import csv
+import os
+from collections import defaultdict
+from opal.core.discoverable import DiscoverableFeature
+from opal.models import Episode
+import json
+from django.utils.module_loading import import_string
+from django.conf import settings
+from django.core.serializers.json import DjangoJSONEncoder
+from opal.utils import AbstractBase
+
+
+def default_base_fields(user):
+    return {
+        "patient_id": "Patient",
+        "id": "Episode",
+    }
+
+
+EXTRACT_BASE_FIELDS = getattr(settings, "EXTRACT_BASE_FIELDS", None)
+
+if EXTRACT_BASE_FIELDS:
+    base_fields_function = import_string(EXTRACT_BASE_FIELDS)
+else:
+    base_fields_function = default_base_fields
+
+
+class ExtractRule(DiscoverableFeature):
+    module_name = "extract_rule"
+
+    def __init__(self, episode_list, user):
+        self.episode_list = episode_list
+        self.user = user
+
+    def write_to_file(self, directory):
+        """
+        Write to a file and return the name of the file written to.
+        If multiple files have been written to, return a list.
+        """
+        pass
+
+    def get_data_dictionary(self):
+        """
+        return {
+            {{ display_name }}: [{
+                display_name: {{ field display name }},
+                description: {{ field description }},
+                type_display_name: {{ field type }},
+            }]
+        }
+        """
+        return {}
+
+
+class ModelRule(ExtractRule, AbstractBase):
+    """
+    Extract a queryset into a file.
+
+    Takes in an episode queryset, user, model and a path that links the
+    model to an episode id.
+
+    e.g. for a patient subrecord Allergy
+
+    it would take (episode_qs, user, Allergy, 'patient__episode__id').
+ + The output is a csv with the field names of a combination of + the base fields that should appear in every csv, by default episode id + and the fields from model._get_fieldnames_to_extract. + + The method is a combination of three things... + 1. model_id_to_episode_ids_dict a dict of the instance id to episode ids + 2. base_field_dict a dict of episode_ids to the values that should appear on every row + (by default episode id) + 3. the output of get_instance_dict(instance) that returns serializes an instance to a dict + """ + additional_fields = [] + + def __init__(self, episode_list, user, model, path_to_episode_id): + self.episode_list = episode_list + self.user = user + self.model = model + self.path_to_episode_id = path_to_episode_id + self.model_field_name_to_display_name = self.get_fields_name_to_display_name() + + # a dictionary of model id to epis + self.model_id_to_episode_ids_dict = self.get_model_id_to_episode_ids() + + # the model queryset + self.queryset = self.model.objects.filter( + id__in=self.model_id_to_episode_ids_dict.keys() + ) + + # the additional fields that will be attatched to every row + # e.g. 
patient__demographics__surname
+        self.base_fields = base_fields_function(user)
+
+        # a dict of episode id to base fields
+        self.base_field_dict = self.get_base_field_dict()
+        self.validate_additional_fields()
+
+    def validate_additional_fields(self):
+        for field in self.additional_fields:
+            extract_method = "extract_{}".format(field)
+            get_field_description_method = "get_{}_description".format(field)
+            if not hasattr(self.model, field) and not hasattr(self, extract_method):
+                err_message = " ".join((
+                    "{} is not an attribute on the model",
+                    "please implement a {} on the extract rule {}",
+                )).format(field, extract_method, self)
+                raise NotImplementedError(err_message)
+
+            if not hasattr(self, get_field_description_method):
+                err_message = " ".join((
+                    "Please implement a {} method on the model",
+                    "that returns a dictionary with 'display_name',",
+                    "'description' and 'type_display_name'"
+                )).format(get_field_description_method)
+                raise NotImplementedError(err_message)
+
+    @property
+    def file_name(self):
+        """
+        The name of the file that will be extracted, including the .csv suffix
+        """
+
+        return "{}.csv".format(self.model.get_api_name())
+
+    def get_base_field_dict(self):
+        """
+        A dictionary of episode id to fields that should exist
+        in all csvs for that episode id.
+ """ + fields = self.base_fields.keys() + if "id" not in fields: + fields = list(fields) + fields.append("id") + episode_qs = Episode.objects.filter( + id__in=[i.id for i in self.episode_list] + ) + values_list = episode_qs.values(*fields) + result = {} + for value_row in values_list: + result[value_row["id"]] = { + self.base_fields[k]: v for k, v in value_row.items() + } + return result + + def get_field_value(self, data): + """ + Serialize a value from a model field to what appears in the csv + """ + if isinstance(data, (list, dict,)): + return json.dumps(data, cls=DjangoJSONEncoder) + return str(data) + + def get_fields_name_to_display_name(self): + """ + Return a dictionary of the instance field name to the display name + """ + result = {} + field_names = self.model._get_fieldnames_to_extract() + fields_to_ignore = { + "consistency_token", "episode_id", "patient_id", "id" + } + for field_name in field_names: + if field_name in fields_to_ignore: + continue + result[field_name] = self.model._get_field_title(field_name) + + for field_name in self.additional_fields: + description_getter = "get_{}_description".format(field_name) + result[field_name] = getattr(self, description_getter)()["display_name"] + + return result + + def get_model_id_to_episode_ids(self): + """ + Returns a list of dictionaries + {id: [[ model_id]], 'episode_id': [[ episode_id]]}. + + Not that for a patient with multiple patient subrecords for the same type + and multiple episodes we end up with a cartesian join ie multiple + rows per episode and multiple rows per instance. + + e.g. 
Jane has two episodes, ICU(id 1) and Infection service (id 2) + and two allergies, Aciclovir(id 3) and Amphotericin(id 4) + + the output for allergies will be + + [ + {'episode_id': 1, 'id': 3}, + {'episode_id': 2, 'id': 4}, + {'episode_id': 1, 'id': 3}, + {'episode_id': 2, 'id': 4}, + ] + + I am not sure this is necessarily correct + """ + filter_arg = "{}__in".format(self.path_to_episode_id) + list_of_dicts = self.model.objects.filter( + **{ + filter_arg: [i.id for i in self.episode_list] + } + ).values('id', self.path_to_episode_id) + result = defaultdict(list) + for some_dict in list_of_dicts: + result[some_dict["id"]].append( + some_dict[self.path_to_episode_id] + ) + return result + + def get_rows_for_instance(self, instance): + """ + Get the rows that should appear in the csv for a given instance. + This is the combination of the base fields that appear in + every csv and the fields for the instance + """ + episode_ids = self.model_id_to_episode_ids_dict[instance.id] + model_row = self.get_instance_dict(instance) + rows = [] + for episode_id in episode_ids: + row = self.base_field_dict[episode_id].copy() + row.update(model_row) + rows.append(row) + return rows + + def get_instance_dict(self, instance): + """ + Return a serialized form of the instance without the base fields. 
+ """ + result = {} + for field, display_name in self.model_field_name_to_display_name.items(): + getter = getattr(self, 'extract_{}'.format(field), None) + if getter: + value = getter(instance) + else: + value = self.get_field_value(getattr(instance, field)) + result[display_name] = value + return result + + def get_rows(self): + rows = [] + for instance in self.queryset: + rows.extend(self.get_rows_for_instance(instance)) + return rows + + def sort_headers(self, headers): + base_field_display_names = self.base_fields.values() + return list(self.base_fields.values()) + [ + i for i in headers if i not in base_field_display_names + ] + + def write_to_file(self, directory): + """ + Writes what is generated out to a file + """ + file_name = self.file_name + full_file_name = os.path.join(directory, file_name) + rows = self.get_rows() + if not rows: + return + with open(full_file_name, "w") as f: + fieldnames = self.sort_headers(rows[0].keys()) + writer = csv.DictWriter(f, fieldnames=fieldnames) + writer.writeheader() + for row in rows: + writer.writerow(row) + return file_name + + def get_field_description(self, field_name): + get_method = "get_{}_description".format( + field_name + ) + if hasattr(self, get_method): + return getattr(self, get_method)() + return { + "display_name": self.model._get_field_title(field_name), + "description": self.model.get_field_description(field_name), + "type_display_name": self.model.get_human_readable_type(field_name), + } + + def get_data_dictionary(self): + fields = self.model_field_name_to_display_name.keys() + fields_descriptions = [ + self.get_field_description(field) for field in fields + ] + fields_descriptions = sorted( + fields_descriptions, key=lambda x: x["display_name"] + ) + display_name = getattr(self, "display_name", None) + + if not display_name: + display_name = self.model.get_display_name() + return {display_name: fields_descriptions} + + +class EpisodeRule(ModelRule): + file_name = "episodes.csv" + additional_fields 
= ["tagging"] + display_name = "Episode" + + def __init__(self, episode_list, user): + super().__init__(episode_list, user, Episode, "id") + + def get_tagging_description(self): + return { + "display_name": "Tagging", + "description": "A list of tags that the episode has been tagged with", + "type_display_name": "semicolon seperated list" + } + + def extract_tagging(self, episode): + return ";".join( + episode.get_tag_names(self.user, historic=True) + )