353 changes: 153 additions & 200 deletions mario/data_extractor.py

Large diffs are not rendered by default.

32 changes: 19 additions & 13 deletions mario/dataset_builder.py
@@ -18,6 +18,7 @@ class Format(Enum):
     EXCEL_PIVOT = 'xlsx'
     CSV = 'csv'
     EXCEL_INFO_SHEET = 'info'
+    HYPER = 'hyper'


 class DatasetBuilder:
@@ -64,30 +65,35 @@ def remove_redundant_hierarchies(self):
             if 'hierarchies' in item.properties:
                 item.set_property('hierarchies', [h for h in item.get_property('hierarchies') if h['hierarchy'] != hierarchy])

-    def build(self, output_format: Format, file_path: str, template_path: str = None):
+    def build(self, output_format: Format, file_path: str, template_path: str = None, **kwargs):
         if output_format == Format.TABLEAU_PACKAGED_DATASOURCE:
-            self.__build_tdsx__(file_path)
+            self.__build_tdsx__(file_path, **kwargs)
         elif output_format == Format.CSV:
-            self.__build_csv__(file_path)
+            self.__build_csv__(file_path, **kwargs)
         elif output_format == Format.EXCEL_PIVOT:
-            self.__build_excel_pivot__(file_path, template_path)
+            self.__build_excel_pivot__(file_path, **kwargs)
         elif output_format == Format.EXCEL_INFO_SHEET:
-            self.__build_excel_info_sheet(file_path, template_path)
+            self.__build_excel_info_sheet(file_path, **kwargs)
+        elif output_format == Format.HYPER:
+            self.__build_hyper__(file_path, **kwargs)
         else:
             raise NotImplementedError

-    def __build_excel_info_sheet(self, file_path: str, template_path: str):
+    def __build_hyper__(self, file_path: str, **kwargs):
+        self.data.save_data_as_hyper(file_path=file_path, **kwargs)
+
+    def __build_excel_info_sheet(self, file_path: str, **kwargs):
         from mario.excel_builder import ExcelBuilder
         excel_builder = ExcelBuilder(
             output_file_path=file_path,
             dataset_specification=self.dataset_specification,
             metadata=self.metadata,
             data_extractor=self.data,
-            template_path=template_path
+            **kwargs
         )
         excel_builder.create_notes_only()

-    def __build_excel_pivot__(self, file_path: str, template_path: str):
+    def __build_excel_pivot__(self, file_path: str, **kwargs):
         from mario.excel_builder import ExcelBuilder
         if self.data.get_total() > 1000000:
             logger.warning("The dataset is larger than 1m rows; this isn't supported in Excel format")
@@ -96,15 +102,15 @@ def __build_excel_pivot__(self, file_path: str, template_path: str):
             dataset_specification=self.dataset_specification,
             metadata=self.metadata,
             data_extractor=self.data,
-            template_path=template_path
+            **kwargs
         )
         excel_builder.create_workbook()

-    def __build_csv__(self, file_path: str, compress_using_gzip=False):
+    def __build_csv__(self, file_path: str, **kwargs):
         # TODO export Info sheet as well - see code in Automation2.0 and TDSA.
-        self.data.save_data_as_csv(file_path=file_path, compress_using_gzip=compress_using_gzip)
+        self.data.save_data_as_csv(file_path=file_path, **kwargs)

-    def __build_tdsx__(self, file_path: str):
+    def __build_tdsx__(self, file_path: str, **kwargs):
         from mario.hyper_utils import get_default_table_and_schema
         from tableau_builder.json_metadata import JsonRepository
         with tempfile.TemporaryDirectory() as temp_folder:
@@ -116,7 +122,7 @@ def __build_tdsx__(self, file_path: str):

             # Move the hyper extract
             data_file_path = os.path.join(temp_folder, 'data.hyper')
-            self.data.save_data_as_hyper(file_path=data_file_path)
+            self.data.save_data_as_hyper(file_path=data_file_path, **kwargs)

             # Get the table and schema name from the extract
             schema, table = get_default_table_and_schema(data_file_path)
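
The net effect in dataset_builder.py is that format-specific settings now travel through **kwargs into per-format option objects (see mario/options.py below), and Hyper becomes a first-class output format. A minimal usage sketch follows; it assumes a DatasetBuilder already wired up with a dataset specification, metadata and data extractor, the constructor argument names are illustrative rather than taken from this diff, and the kwargs shown are the ones defined in mario/options.py on the assumption that the extractor forwards them:

from mario.dataset_builder import DatasetBuilder, Format

# Assumed to already exist: spec, metadata, extractor (names are illustrative)
builder = DatasetBuilder(dataset_specification=spec, metadata=metadata, data=extractor)

# New in this change: direct Hyper output; schema/table are HyperOptions kwargs
builder.build(Format.HYPER, 'output/data.hyper', schema='Extract', table='Extract')

# Existing formats still work; template_path now travels via **kwargs into ExcelOptions
builder.build(Format.EXCEL_PIVOT, 'output/data.xlsx', template_path='excel_template.xlsx')

# CSV with gzip compression, now handled by CsvOptions rather than a named parameter
builder.build(Format.CSV, 'output/data.csv', compress_using_gzip=True)
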
22 changes: 15 additions & 7 deletions mario/excel_builder.py
@@ -8,6 +8,7 @@
 from mario.data_extractor import DataExtractor
 from mario.dataset_specification import DatasetSpecification
 from mario.metadata import Metadata
+from mario.options import ExcelOptions

 logger = logging.getLogger(__name__)
 style_attrs = ["alignment", "border", "fill", "font", "number_format", "protection"]
@@ -33,15 +34,14 @@ def __init__(self,
                  data_extractor: DataExtractor,
                  dataset_specification: DatasetSpecification,
                  metadata: Metadata,
-                 template_path: str
+                 **kwargs
                  ):

         self.data_extractor = data_extractor
         self.dataset_specification = dataset_specification
         self.metadata = metadata
         self.filepath = output_file_path
-        self.template = "excel_template.xlsx"
-        if template_path is not None:
-            self.template = template_path
+        self.options = ExcelOptions(**kwargs)
         self.workbook = None
         self.rows = None
         self.cols = None
@@ -53,7 +53,7 @@ def create_workbook(self, create_notes_page=False):
         Creates a write-only workbook and builds
         content in streaming mode to conserve memory
         """
-        template_workbook = load_workbook(self.template)
+        template_workbook = load_workbook(self.options.template_path)
         template_workbook.save(self.filepath)

         if create_notes_page:
@@ -129,7 +129,7 @@ def create_notes_only(self, filename=None, data_format='CSV'):
         """
         Create a Notes page only, to accompany CSV outputs
         """
-        self.workbook = load_workbook(self.template)
+        self.workbook = load_workbook(self.options.template_path)
         self.__update_notes__(data_format=data_format, ws=self.workbook.get_sheet_by_name('Notes'))
         self.workbook.remove(self.workbook.get_sheet_by_name('Data'))
         self.workbook.remove(self.workbook.get_sheet_by_name('Pivot'))
@@ -139,7 +139,15 @@ def create_notes_only(self, filename=None, data_format='CSV'):
         self.workbook.save(filename=self.filepath)

     def __create_data_page__(self):
-        df: DataFrame = self.data_extractor.get_data_frame()
+
+        if self.options.validate:
+            if not self.data_extractor.validate_data(allow_nulls=self.options.allow_nulls):
+                raise ValueError("Validation error")
+
+        df: DataFrame = self.data_extractor.get_data_frame(
+            minimise=self.options.minimise,
+            include_row_numbers=self.options.include_row_numbers
+        )

         # Reorder the columns to put the measure in col #1
         cols = df.columns.tolist()
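
With options wired through ExcelBuilder, validation can be requested at construction time: __create_data_page__ now calls validate_data() and raises a ValueError before anything is written. A sketch under the same assumptions as above (spec, metadata and extractor already configured elsewhere):

from mario.excel_builder import ExcelBuilder

excel_builder = ExcelBuilder(
    output_file_path='output/report.xlsx',
    data_extractor=extractor,
    dataset_specification=spec,
    metadata=metadata,
    validate=True,          # run validate_data() before writing the data page
    allow_nulls=False,      # treat nulls as a validation failure
    template_path='excel_template.xlsx',
)
excel_builder.create_workbook()
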
82 changes: 81 additions & 1 deletion mario/hyper_utils.py
@@ -3,6 +3,7 @@
 """
 import logging
 from typing import List
+from mario.options import CsvOptions, HyperOptions

 log = logging.getLogger(__name__)

@@ -276,4 +277,83 @@ def concatenate_hypers(hyper_file_path1, hyper_file_path2, output_hyper_file_path):
     df2 = frame_from_hyper(hyper_file_path2, table=TableName(schema, table))

     df_merged = pd.concat([df1, df2], axis=1)
-    frame_to_hyper(df=df_merged, database=output_hyper_file_path, table=TableName(schema, table))
+    frame_to_hyper(df=df_merged, database=output_hyper_file_path, table=TableName(schema, table))
+
+
+def save_hyper_as_hyper(hyper_file, file_path, **kwargs):
+    import shutil
+    shutil.copyfile(hyper_file, file_path)
+
+
+def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs):
+    import pantab
+    import tempfile
+    import shutil
+    import os
+
+    options = CsvOptions(**kwargs)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_hyper = os.path.join(temp_dir, 'temp.hyper')
+        shutil.copyfile(
+            src=hyper_file,
+            dst=temp_hyper
+        )
+
+        schema, table = get_default_table_and_schema(temp_hyper)
+
+        columns = get_column_list(
+            hyper_file_path=temp_hyper,
+            schema=schema,
+            table=table
+        )
+
+        if 'row_number' not in columns:
+            log.debug("Adding row numbers to hyper so we can guarantee ordering")
+            add_row_numbers_to_hyper(
+                input_hyper_file_path=temp_hyper,
+                schema=schema,
+                table=table
+            )
+        else:
+            log.debug("Data source already contains row numbers")
+
+        if options.include_row_numbers is False and 'row_number' in columns:
+            columns.remove('row_number')
+        elif options.include_row_numbers is True and 'row_number' not in columns:
+            columns.append('row_number')
+
+        if options.compress_using_gzip:
+            compression_options = dict(method='gzip')
+            file_path = file_path + '.gz'
+        elif file_path.endswith('.gz'):
+            compression_options = dict(method='gzip')
+        else:
+            compression_options = None
+
+        mode = 'w'
+        header = True
+        offset = 0
+        column_names = ','.join(f'"{column}"' for column in columns)
+
+        sql = f"SELECT {column_names} FROM \"{schema}\".\"{table}\" ORDER BY row_number"
+
+        while True:
+            query = f"{sql} LIMIT {options.chunk_size} OFFSET {offset}"
+            df_chunk = pantab.frame_from_hyper_query(temp_hyper, query)
+            if df_chunk.empty:
+                break
+            df_chunk.to_csv(file_path, index=False, mode=mode, header=header,
+                            compression=compression_options)
+            offset += options.chunk_size
+            if header:
+                header = False
+                mode = "a"
+
+
+def save_dataframe_as_hyper(df, file_path, **kwargs):
+    from tableauhyperapi import TableName
+    import pantab
+    options = HyperOptions(**kwargs)
+    table_name = TableName(options.schema, options.table)
+    pantab.frame_to_hyper(df=df, database=file_path, table=table_name)
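
The new save_hyper_as_csv exports the extract in LIMIT/OFFSET chunks ordered by row_number, so even very large extracts never need to be held in memory in full; row numbers are added to a temporary copy if missing so the ordering is deterministic. A usage sketch, assuming an existing extract at data.hyper (the paths are illustrative):

from mario.hyper_utils import save_hyper_as_csv

# Streams out in chunks of 100,000 rows (the OutputOptions default chunk_size).
# With compress_using_gzip=True, '.gz' is appended and the output is gzipped.
save_hyper_as_csv(
    hyper_file='data.hyper',
    file_path='output/data.csv',
    compress_using_gzip=True,
    include_row_numbers=False,
)
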
42 changes: 42 additions & 0 deletions mario/options.py
@@ -0,0 +1,42 @@
+"""
+Common output options for different data formats along with their default values
+"""
+import logging
+logger = logging.getLogger(__name__)
+
+
+class OutputOptions:
+    def __init__(self, **kwargs):
+        self.minimise = kwargs.get('minimise', False)
+        self.validate = kwargs.get('validate', False)
+        self.allow_nulls = kwargs.get('allow_nulls', True)
+        self.chunk_size = kwargs.get('chunk_size', 100000)
+        self.include_row_numbers = kwargs.get('include_row_numbers', False)
+
+        if self.chunk_size <= 0:
+            raise ValueError("Chunk size must be a positive integer")
+
+        if not self.allow_nulls and not self.validate:
+            logger.error("Inconsistent options chosen: "
+                         "if you choose not to allow nulls, but don't enable validation, "
+                         "the output may still have nulls")
+            raise ValueError("Inconsistent output configuration")
+
+
+class CsvOptions(OutputOptions):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.compress_using_gzip = kwargs.get('compress_using_gzip', False)
+
+
+class HyperOptions(OutputOptions):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.table = kwargs.get('table', 'Extract')
+        self.schema = kwargs.get('schema', 'Extract')
+
+
+class ExcelOptions(OutputOptions):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.template_path = kwargs.get('template_path', 'excel_template.xlsx')
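
OutputOptions rejects inconsistent combinations up front rather than silently producing surprising output; in particular, allow_nulls=False without validate=True raises at construction time, since nothing would actually enforce the no-nulls rule. A short illustration using only the options defined above:

from mario.options import CsvOptions

# Fine: defaults are validate=False, allow_nulls=True, chunk_size=100000
options = CsvOptions(compress_using_gzip=True)

# Raises ValueError("Inconsistent output configuration"): disallowing nulls
# has no effect unless validation is enabled to enforce it
options = CsvOptions(allow_nulls=False)
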
5 changes: 3 additions & 2 deletions mario/query_builder.py
@@ -13,8 +13,9 @@

 def get_formatted_query(query, params):
     formatted_query = copy(query)
-    for key, value in params.items():
-        formatted_query = formatted_query.replace(f'%({key})s', f"'{value}'")
+    if len(params) > 0:
+        for key, value in params.items():
+            formatted_query = formatted_query.replace(f'%({key})s', f"'{value}'")
     return formatted_query
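
For context, get_formatted_query inlines pyformat-style %(name)s parameters into the query text as quoted literals, and the new guard simply skips the substitution loop when no parameters are supplied. For example (the table and parameter names here are hypothetical):

from mario.query_builder import get_formatted_query

sql = "SELECT * FROM enrolments WHERE academic_year = %(year)s"
print(get_formatted_query(sql, {'year': '2023/24'}))
# SELECT * FROM enrolments WHERE academic_year = '2023/24'
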
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

 setup(
     name='mario-pipeline-tools',
-    version='0.51',
+    version='0.52',
     packages=['mario'],
     url='https://github.com/JiscDACT/mario',
     license='all rights reserved',
17 changes: 12 additions & 5 deletions test/test_data_extractor.py
@@ -127,9 +127,10 @@ def test_stream_sql_to_csv_with_compression():
         metadata=metadata,
         configuration=configuration
     )
-    file = tempfile.NamedTemporaryFile(suffix='.csv')
-    gzip_path = extractor.stream_sql_to_csv(file_path=file.name, chunk_size=1000, compress_using_gzip=True)
-    df = pd.read_csv(gzip_path)
+    with tempfile.TemporaryDirectory() as temp_dir:
+        file = os.path.join(temp_dir, 'test.csv')
+        gzip_path = extractor.stream_sql_to_csv(file_path=file, chunk_size=1000, compress_using_gzip=True)
+        df = pd.read_csv(gzip_path)
     assert len(df) == 10194

@@ -624,16 +625,22 @@ def test_partitioning_extractor_with_row_numbers():
     )

     path = os.path.join('output', 'test_partitioning_extractor_with_row_numbers')
+    os.makedirs(path, exist_ok=True)
+
     file = os.path.join(path, 'test.hyper')
     csv_file = os.path.join(path, 'test.csv')

-    os.makedirs(path, exist_ok=True)
+    # drop existing
+    for path in [file, csv_file]:
+        if os.path.exists(path):
+            os.remove(path)

     extractor.stream_sql_to_hyper(
         file_path=file,
         include_row_numbers=True
     )
-    # Check row_number in hyper output
+
+    # Check row_number is in hyper output
     assert 'row_number' in get_column_list(hyper_file_path=file, schema='Extract', table='Extract')

     # Load it up and export a CSV