diff --git a/mario/data_extractor.py b/mario/data_extractor.py
index 5a596ea..31d5d46 100644
--- a/mario/data_extractor.py
+++ b/mario/data_extractor.py
@@ -1,5 +1,4 @@
 import logging
-import shutil
 import tempfile

 import pandas as pd
@@ -7,6 +6,7 @@
 from mario.dataset_specification import DatasetSpecification
 from mario.metadata import Metadata
+from mario.options import CsvOptions, HyperOptions

 logger = logging.getLogger(__name__)

@@ -142,20 +142,37 @@ def get_total(self, measure=None):

     def validate_data(self, allow_nulls=True):
         from mario.validation import DataFrameValidator
+        if self._data is None:
+            self.get_data_frame()
         validator = DataFrameValidator(
             self.dataset_specification,
             self.metadata,
-            data=self.get_data_frame()
+            data=self._data
         )
         return validator.validate_data(allow_nulls)

-    def get_data_frame(self, minimise=True) -> DataFrame:
+    def get_data_frame(self, minimise=True, include_row_numbers=False) -> DataFrame:
         if self._data is None:
             self.__load__()
         if minimise:
             self.__minimise_data__()
+        if include_row_numbers:
+            self.__add_row_numbers__()
+        else:
+            self.__drop_row_numbers__()
         return self._data

+    def __add_row_numbers__(self):
+        if self._data is None:
+            self.__load__()
+        self._data['row_number'] = range(len(self._data))
+
+    def __drop_row_numbers__(self):
+        if self._data is None:
+            self.__load__()
+        if 'row_number' in self._data.columns:
+            self._data = self._data.drop(columns=['row_number'])
+
     def save_query(self, file_path: str, formatted: bool = False):
         """
         Output the query used
@@ -172,22 +189,39 @@ def save_query(self, file_path: str, formatted: bool = False):
         with open(file_path, mode='w') as file:
             file.write(sql)

-    def save_data_as_csv(self, file_path: str, minimise=True, compress_using_gzip=False):
+    def save_data_as_csv(self, file_path: str, **kwargs):
+        options = CsvOptions(**kwargs)
         if self._data is None:
             self.__load__()
-        if minimise:
+        if options.include_row_numbers:
+            self.__add_row_numbers__()
+        else:
+            self.__drop_row_numbers__()
+        if options.minimise:
             self.__minimise_data__()
-        self._data.to_csv(file_path)
+        if options.validate:
+            self.validate_data(allow_nulls=options.allow_nulls)
+        if options.compress_using_gzip:
+            compression_options = dict(method='gzip')
+            file_path = file_path + '.gz'
+        else:
+            compression_options = None
+        self._data.to_csv(file_path, index=False, compression=compression_options)

-    def save_data_as_hyper(self, file_path: str, table: str = 'Extract', schema: str = 'Extract', minimise=True):
-        import pantab
-        from tableauhyperapi import TableName
+    def save_data_as_hyper(self, file_path: str, **kwargs):
+        from mario.hyper_utils import save_dataframe_as_hyper
+        options = HyperOptions(**kwargs)
         if self._data is None:
             self.__load__()
-        table_name = TableName(schema, table)
-        if minimise:
+        if options.include_row_numbers:
+            self.__add_row_numbers__()
+        else:
+            self.__drop_row_numbers__()
+        if options.minimise:
             self.__minimise_data__()
-        pantab.frame_to_hyper(df=self._data, database=file_path, table=table_name)
+        if options.validate:
+            self.validate_data(allow_nulls=options.allow_nulls)
+        save_dataframe_as_hyper(df=self._data, file_path=file_path, **kwargs)


 class HyperFile(DataExtractor):
@@ -239,75 +273,34 @@ def get_total(self, measure=None):
             measure=measure
         )

-    def save_data_as_hyper(self, file_path: str, table: str = 'Extract', schema: str = 'Extract', minimise=False):
-        if minimise:
+    def save_data_as_hyper(self, file_path: str, **kwargs):
+        from mario.hyper_utils import save_hyper_as_hyper, add_row_numbers_to_hyper, get_default_table_and_schema
+        options = HyperOptions(**kwargs)
+        if options.minimise:
             self.__minimise_data__()
-        shutil.copyfile(self.configuration.file_path, file_path)
-
-    def save_data_as_csv(self,
-                         file_path: str,
-                         minimise=True,
-                         compress_using_gzip=False,
-                         chunk_size=100000
-                         ):
-        from mario.hyper_utils import get_default_table_and_schema, add_row_numbers_to_hyper, get_column_list
-        import pantab
-        import tempfile
-        import shutil
-        import os
-
-        with tempfile.TemporaryDirectory() as temp_dir:
-            temp_hyper = os.path.join(temp_dir, 'temp.hyper')
-            shutil.copyfile(
-                src=self.configuration.file_path,
-                dst=temp_hyper
-            )
-
-            schema, table = get_default_table_and_schema(temp_hyper)
-
-            columns = get_column_list(
-                hyper_file_path=temp_hyper,
+        if options.validate:
+            self.validate_data(allow_nulls=options.allow_nulls)
+        if options.include_row_numbers:
+            schema, table = get_default_table_and_schema(self.configuration.file_path)
+            add_row_numbers_to_hyper(
+                input_hyper_file_path=self.configuration.file_path,
                 schema=schema,
                 table=table
             )
+        save_hyper_as_hyper(hyper_file=self.configuration.file_path, file_path=file_path, **kwargs)

-            logger.debug("Adding row numbers to hyper so we can guarantee ordering")
-            if 'row_number' not in columns:
-                add_row_numbers_to_hyper(
-                    input_hyper_file_path=temp_hyper,
-                    schema=schema,
-                    table=table
-                )
-            else:
-                # Remove it so we don't include it in the output
-                columns.remove('row_number')
-
-            if compress_using_gzip:
-                compression_options = dict(method='gzip')
-                file_path = file_path + '.gz'
-            elif file_path.endswith('.gz'):
-                compression_options = dict(method='gzip')
-            else:
-                compression_options = None
-
-            mode = 'w'
-            header = True
-            offset = 0
-            column_names = ','.join(f'"{column}"' for column in columns)
-
-            sql = f"SELECT {column_names} FROM \"{schema}\".\"{table}\" ORDER BY row_number"
-
-            while True:
-                query = f"{sql} LIMIT {chunk_size} OFFSET {offset}"
-                df_chunk = pantab.frame_from_hyper_query(temp_hyper, query)
-                if df_chunk.empty:
-                    break
-                df_chunk.to_csv(file_path, index=False, mode=mode, header=header,
-                                compression=compression_options)
-                offset += chunk_size
-                if header:
-                    header = False
-                    mode = "a"
+    def save_data_as_csv(self, file_path: str, **kwargs):
+        from mario.hyper_utils import save_hyper_as_csv
+        options = CsvOptions(**kwargs)
+        if options.minimise:
+            self.__minimise_data__()
+        if options.validate:
+            self.validate_data(allow_nulls=options.allow_nulls)
+        save_hyper_as_csv(
+            hyper_file=self.configuration.file_path,
+            file_path=file_path,
+            **kwargs
+        )


 class StreamingDataExtractor(DataExtractor):
@@ -324,11 +317,11 @@ def __init__(self,
         super().__init__(configuration, dataset_specification, metadata)
         self._data = None

-    def get_data_frame(self, minimise=True) -> DataFrame:
+    def get_data_frame(self, minimise=True, include_row_numbers=False) -> DataFrame:
         if self._data is None:
             raise NotImplementedError("Dataframe is not available when using a streaming extractor")
         else:
-            return super().get_data_frame(minimise=minimise)
+            return super().get_data_frame(minimise=minimise, include_row_numbers=include_row_numbers)

     def validate_data(self, allow_nulls=True):
         if self._data is None:
@@ -351,18 +344,19 @@ def get_connection(self):
         connection = engine.connect().execution_options(stream_results=True)
         return connection

-    def save_data_as_csv(self, file_path: str, minimise=False, compress_using_gzip=False):
-        if minimise:
+    def save_data_as_csv(self, file_path: str, **kwargs):
+        options = CsvOptions(**kwargs)
+        if options.minimise:
+            raise NotImplementedError('Cannot minimise data when using streaming')
-        self.stream_sql_to_csv(file_path=file_path, compress_using_gzip=compress_using_gzip)
+        self.stream_sql_to_csv(file_path=file_path, **kwargs)

-    def save_data_as_hyper(self, file_path: str, table: str = 'Extract', schema: str = 'Extract', minimise=False):
-        if minimise:
+    def save_data_as_hyper(self, file_path: str, **kwargs):
+        options = HyperOptions(**kwargs)
+        if options.minimise:
             raise NotImplementedError('Cannot minimise data when using streaming')
         self.stream_sql_to_hyper(
             file_path=file_path,
-            table=table,
-            schema=schema
+            **kwargs
         )

     def get_total(self, measure=None):
@@ -386,54 +380,39 @@ def get_total(self, measure=None):
         totals_df = pd.read_sql(totals_query[0], self.get_connection(), params=totals_query[1])
         return totals_df.iat[0, 0]

-    def stream_sql_to_hyper(self,
-                            file_path: str,
-                            table: str = 'Extract',
-                            schema: str = 'Extract',
-                            validate: bool = False,
-                            allow_nulls: bool = True,
-                            minimise: bool = False,
-                            chunk_size: int = 100000):
+    def stream_sql_to_hyper(self, file_path: str, **kwargs):
         """
         Write From SQL to .hyper using streaming. No data is held in memory
         apart from chunks of rows as they are read.
         Optionally, data can be validated as it is read.
         """
+        options = HyperOptions(**kwargs)
         self.__build_query__()
         logger.info("Executing query")
         from tableauhyperapi import TableName
         from pantab import frame_to_hyper
         connection = self.get_connection()
-        table_name = TableName(schema, table)
-        for df in pd.read_sql(self._query[0], connection, chunksize=chunk_size):
-            if validate or minimise:
+        table_name = TableName(options.schema, options.table)
+        row_counter = 0
+        for df in pd.read_sql(self._query[0], connection, chunksize=options.chunk_size):
+            if options.validate or options.minimise or options.include_row_numbers:
                 self._data = df
-                if validate:
-                    self.validate_data(allow_nulls=allow_nulls)
-                if minimise:
+                if options.validate:
+                    self.validate_data(allow_nulls=options.allow_nulls)
+                if options.minimise:
                     self.__minimise_data__()
                 df = self._data
+                self._data = None
+            if options.include_row_numbers:
+                df['row_number'] = range(row_counter, row_counter + len(df))
+                row_counter += len(df)  # Update the counter
             frame_to_hyper(df, database=file_path, table=table_name, table_mode='a')

-    def stream_sql_to_csv(self,
-                          file_path,
-                          validate: bool = False,
-                          allow_nulls: bool = True,
-                          chunk_size: int = 100000,
-                          compress_using_gzip: bool = False,
-                          minimise: bool = False
-                          ):
-        """
-        Write From SQL to CSV using streaming. No data is held in memory
-        apart from chunks of rows as they are read.
-        Optionally, data can be validated as it is read.
- """ - self.__build_query__() - logger.info("Executing query") - connection = self.get_connection() - - if compress_using_gzip: + def stream_sql_query_to_csv(self, file_path, query, connection, row_counter=0, **kwargs) -> int: + from mario.query_builder import get_formatted_query + options = CsvOptions(**kwargs) + if options.compress_using_gzip: compression_options = dict(method='gzip') file_path = file_path + '.gz' else: @@ -441,19 +420,46 @@ def stream_sql_to_csv(self, mode = 'w' header = True - for df in pd.read_sql(self._query[0], connection, chunksize=chunk_size): - if validate or minimise: + + for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=options.chunk_size): + if options.validate or options.minimise: self._data = df - if validate: - self.validate_data(allow_nulls=allow_nulls) - if minimise: + if options.validate: + self.validate_data(allow_nulls=options.allow_nulls) + if options.minimise: self.__minimise_data__() df = self._data + self._data = None + if options.include_row_numbers: + df['row_number'] = range(row_counter, row_counter + len(df)) + row_counter += len(df) # Update the counter df.to_csv(file_path, mode=mode, header=header, index=False, compression=compression_options) if header: header = False mode = "a" + return row_counter + + def stream_sql_to_csv(self, file_path, **kwargs): + """ + Write From SQL to CSV using streaming. No data is held in memory + apart from chunks of rows as they are read. + Optionally, data can be validated as it is read. + """ + self.__build_query__() + options = CsvOptions(**kwargs) + logger.info("Executing query") + connection = self.get_connection() + + self.stream_sql_query_to_csv( + file_path=file_path, + query=self._query, + connection=connection, + row_counter=0, + **kwargs + ) + if options.compress_using_gzip: + file_path = file_path + '.gz' return file_path def stream_sql_to_csv_using_bcp(self, @@ -534,7 +540,7 @@ def __init__(self, self._data = dataframe -class PartitioningExtractor(DataExtractor): +class PartitioningExtractor(StreamingDataExtractor): """ A data extractor that loads from SQL in batches using a specified constraint to partition by @@ -549,16 +555,6 @@ def __init__(self, self._data = None self.partition_column = partition_column - def get_connection(self): - from sqlalchemy import create_engine - - if self.configuration.hook is not None: - return self.configuration.hook.get_conn() - - engine = create_engine(self.configuration.connection_string) - connection = engine.connect().execution_options(stream_results=True) - return connection - def __get_partition_values__(self): for constraint in self.dataset_specification.constraints: if constraint.item == self.partition_column: @@ -606,85 +602,40 @@ def __load_from_sql__(self): for value in self.__get_partition_values__(): self.__load_from_sql_using_partition__(partition_value=value) - def get_data_frame(self, minimise=True) -> DataFrame: + def get_data_frame(self, minimise=True, include_row_numbers=False) -> DataFrame: raise NotImplementedError() def get_total(self, measure=None): raise NotImplementedError() - def save_data_as_hyper(self, - file_path: str, - table: str = 'Extract', - schema: str = 'Extract', - minimise=True): - self.stream_sql_to_hyper(file_path=file_path, table=table, schema=schema, minimise=minimise) + def save_data_as_hyper(self, file_path: str, **kwargs): + self.stream_sql_to_hyper(file_path=file_path, **kwargs) - def save_data_as_csv(self, file_path: str, minimise=True, compress_using_gzip=False): - 
self.stream_sql_to_csv( - file_path=file_path, - minimise=minimise, - compress_using_gzip=compress_using_gzip - ) + def save_data_as_csv(self, file_path: str, **kwargs): + self.stream_partition_sql_to_csv(file_path=file_path, **kwargs) - def stream_sql_to_csv(self, - file_path, - validate: bool = False, - allow_nulls: bool = True, - chunk_size: int = 100000, - compress_using_gzip: bool = False, - minimise: bool = False, - include_row_numbers: bool = False - ): + def stream_partition_sql_to_csv(self, file_path, **kwargs): """ Write From SQL to CSV using streaming. No data is held in memory apart from chunks of rows as they are read. Optionally, data can be validated as it is read. """ - from mario.query_builder import get_formatted_query logger.info("Executing query") connection = self.get_connection() - - if compress_using_gzip: - compression_options = dict(method='gzip') - file_path = file_path + '.gz' - else: - compression_options = None - - mode = 'w' - header = True - row_counter = 0 # Initialize global row counter - for partition_value in self.__get_partition_values__(): query = self.__build_query_using_partition__(partition_value=partition_value) - for df in pd.read_sql(get_formatted_query(query[0],query[1]), connection, chunksize=chunk_size): - if validate or minimise: - self._data = df - if validate: - self.validate_data(allow_nulls=allow_nulls) - if minimise: - self.__minimise_data__() - df = self._data - if include_row_numbers: - df['row_number'] = range(row_counter, row_counter + len(df)) - row_counter += len(df) # Update the counter - df.to_csv(file_path, mode=mode, header=header, index=False, compression=compression_options) - if header: - header = False - mode = "a" + row_counter = self.stream_sql_query_to_csv( + connection=connection, + query=query, + file_path=file_path, + row_counter=row_counter, + **kwargs + ) return file_path - def stream_sql_to_hyper(self, - file_path: str, - table: str = 'Extract', - schema: str = 'Extract', - validate: bool = False, - allow_nulls: bool = True, - minimise: bool = False, - chunk_size: int = 100000, - include_row_numbers: bool = False - ): + def stream_sql_to_hyper(self, file_path: str, **kwargs): """ Write From SQL to .hyper using streaming. No data is held in memory apart from chunks of rows as they are read. 
@@ -696,28 +647,30 @@ def stream_sql_to_hyper,
         from pantab import frame_to_hyper
         from mario.query_builder import get_formatted_query

+        options = HyperOptions(**kwargs)
         connection = self.get_connection()
-        table_name = TableName(schema, table)
+        table_name = TableName(options.schema, options.table)
         row_counter = 0  # Initialize global row counter
         for partition_value in self.__get_partition_values__():
             query = self.__build_query_using_partition__(partition_value=partition_value)
-            for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=chunk_size):
-                if validate or minimise:
+            for df in pd.read_sql(get_formatted_query(query[0], query[1]), connection, chunksize=options.chunk_size):
+                if options.validate or options.minimise:
                     self._data = df
-                    if validate:
-                        self.validate_data(allow_nulls=allow_nulls)
-                    if minimise:
+                    if options.validate:
+                        self.validate_data(allow_nulls=options.allow_nulls)
+                    if options.minimise:
                         self.__minimise_data__()
                     df = self._data
-                if include_row_numbers:
+                    self._data = None
+                if options.include_row_numbers:
                     df['row_number'] = range(row_counter, row_counter + len(df))
                     row_counter += len(df)  # Update the counter
                 if len(df) == 0:
                     logger.warning(f"No rows found for partition with value '{partition_value}'")
                 else:
-                    logger.info(f"Saving {chunk_size} rows to file")
+                    logger.info(f"Saving {options.chunk_size} rows to file")
                     frame_to_hyper(df, database=file_path, table=table_name, table_mode='a')
diff --git a/mario/dataset_builder.py b/mario/dataset_builder.py
index 0f7302d..dca9aa6 100644
--- a/mario/dataset_builder.py
+++ b/mario/dataset_builder.py
@@ -18,6 +18,7 @@ class Format(Enum):
     EXCEL_PIVOT = 'xlsx'
     CSV = 'csv'
     EXCEL_INFO_SHEET = 'info'
+    HYPER = 'hyper'


 class DatasetBuilder:
@@ -64,30 +65,35 @@ def remove_redundant_hierarchies(self):
             if 'hierarchies' in item.properties:
                 item.set_property('hierarchies', [h for h in item.get_property('hierarchies') if h['hierarchy'] != hierarchy])

-    def build(self, output_format: Format, file_path: str, template_path: str = None):
+    def build(self, output_format: Format, file_path: str, template_path: str = None, **kwargs):
+        # Keep the existing template_path argument working by forwarding it as an option
+        if template_path is not None:
+            kwargs.setdefault('template_path', template_path)
         if output_format == Format.TABLEAU_PACKAGED_DATASOURCE:
-            self.__build_tdsx__(file_path)
+            self.__build_tdsx__(file_path, **kwargs)
         elif output_format == Format.CSV:
-            self.__build_csv__(file_path)
+            self.__build_csv__(file_path, **kwargs)
         elif output_format == Format.EXCEL_PIVOT:
-            self.__build_excel_pivot__(file_path, template_path)
+            self.__build_excel_pivot__(file_path, **kwargs)
         elif output_format == Format.EXCEL_INFO_SHEET:
-            self.__build_excel_info_sheet(file_path, template_path)
+            self.__build_excel_info_sheet(file_path, **kwargs)
+        elif output_format == Format.HYPER:
+            self.__build_hyper__(file_path, **kwargs)
         else:
             raise NotImplementedError

-    def __build_excel_info_sheet(self, file_path: str, template_path: str):
+    def __build_hyper__(self, file_path: str, **kwargs):
+        self.data.save_data_as_hyper(file_path=file_path, **kwargs)
+
+    def __build_excel_info_sheet(self, file_path: str, **kwargs):
         from mario.excel_builder import ExcelBuilder
         excel_builder = ExcelBuilder(
             output_file_path=file_path,
             dataset_specification=self.dataset_specification,
             metadata=self.metadata,
             data_extractor=self.data,
-            template_path=template_path
+            **kwargs
         )
         excel_builder.create_notes_only()

-    def __build_excel_pivot__(self, file_path: str, template_path: str):
+    def __build_excel_pivot__(self, file_path: str, **kwargs):
         from mario.excel_builder import ExcelBuilder
         if self.data.get_total() > 1000000:
             logger.warning("The dataset is larger than 1m rows; this isn't supported in Excel format")
@@ -96,15 +102,15 @@ def __build_excel_pivot__(self, file_path: str, template_path: str):
             dataset_specification=self.dataset_specification,
             metadata=self.metadata,
             data_extractor=self.data,
-            template_path=template_path
+            **kwargs
         )
         excel_builder.create_workbook()

-    def __build_csv__(self, file_path: str, compress_using_gzip=False):
+    def __build_csv__(self, file_path: str, **kwargs):
         # TODO export Info sheet as well - see code in Automation2.0 and TDSA.
-        self.data.save_data_as_csv(file_path=file_path, compress_using_gzip=compress_using_gzip)
+        self.data.save_data_as_csv(file_path=file_path, **kwargs)

-    def __build_tdsx__(self, file_path: str):
+    def __build_tdsx__(self, file_path: str, **kwargs):
         from mario.hyper_utils import get_default_table_and_schema
         from tableau_builder.json_metadata import JsonRepository
         with tempfile.TemporaryDirectory() as temp_folder:
@@ -116,7 +122,7 @@ def __build_tdsx__(self, file_path: str):

             # Move the hyper extract
             data_file_path = os.path.join(temp_folder, 'data.hyper')
-            self.data.save_data_as_hyper(file_path=data_file_path)
+            self.data.save_data_as_hyper(file_path=data_file_path, **kwargs)

             # Get the table and schema name from the extract
             schema, table = get_default_table_and_schema(data_file_path)
diff --git a/mario/excel_builder.py b/mario/excel_builder.py
index 9670ded..dfe6b7c 100644
--- a/mario/excel_builder.py
+++ b/mario/excel_builder.py
@@ -8,6 +8,7 @@
 from mario.data_extractor import DataExtractor
 from mario.dataset_specification import DatasetSpecification
 from mario.metadata import Metadata
+from mario.options import ExcelOptions

 logger = logging.getLogger(__name__)
 style_attrs = ["alignment", "border", "fill", "font", "number_format", "protection"]
@@ -33,15 +34,14 @@ def __init__(self,
                  data_extractor: DataExtractor,
                  dataset_specification: DatasetSpecification,
                  metadata: Metadata,
-                 template_path: str
+                 **kwargs
                  ):
+        self.data_extractor = data_extractor
         self.dataset_specification = dataset_specification
         self.metadata = metadata
         self.filepath = output_file_path
-        self.template = "excel_template.xlsx"
-        if template_path is not None:
-            self.template = template_path
+        self.options = ExcelOptions(**kwargs)
         self.workbook = None
         self.rows = None
         self.cols = None
@@ -53,7 +53,7 @@ def create_workbook(self, create_notes_page=False):
         Creates a write-only workbook and builds content in streaming
         mode to conserve memory
         """
-        template_workbook = load_workbook(self.template)
+        template_workbook = load_workbook(self.options.template_path)
         template_workbook.save(self.filepath)

         if create_notes_page:
@@ -129,7 +129,7 @@ def create_notes_only(self, filename=None, data_format='CSV'):
         """
         Create a Notes page only, to accompany CSV outputs
         """
-        self.workbook = load_workbook(self.template)
+        self.workbook = load_workbook(self.options.template_path)
         self.__update_notes__(data_format=data_format, ws=self.workbook.get_sheet_by_name('Notes'))
         self.workbook.remove(self.workbook.get_sheet_by_name('Data'))
         self.workbook.remove(self.workbook.get_sheet_by_name('Pivot'))
@@ -139,7 +139,15 @@ def create_notes_only(self, filename=None, data_format='CSV'):
         self.workbook.save(filename=self.filepath)

     def __create_data_page__(self):
-        df: DataFrame = self.data_extractor.get_data_frame()
+
+        if self.options.validate:
+            if not self.data_extractor.validate_data(allow_nulls=self.options.allow_nulls):
+                raise ValueError("Validation error")
+
+        df: DataFrame = self.data_extractor.get_data_frame(
+            minimise=self.options.minimise,
+            include_row_numbers=self.options.include_row_numbers
+        )

         # Reorder the columns to put the measure in col #1
         cols = df.columns.tolist()
diff --git a/mario/hyper_utils.py b/mario/hyper_utils.py
index 3dc8400..708a81e 100644
--- a/mario/hyper_utils.py
+++ b/mario/hyper_utils.py
@@ -3,6 +3,7 @@
 """
 import logging
 from typing import List
+from mario.options import CsvOptions, HyperOptions

 log = logging.getLogger(__name__)

@@ -276,4 +277,83 @@ def concatenate_hypers(hyper_file_path1, hyper_file_path2, output_hyper_file_pat
     df2 = frame_from_hyper(hyper_file_path2, table=TableName(schema, table))

     df_merged = pd.concat([df1, df2], axis=1)
-    frame_to_hyper(df=df_merged, database=output_hyper_file_path, table=TableName(schema, table))
\ No newline at end of file
+    frame_to_hyper(df=df_merged, database=output_hyper_file_path, table=TableName(schema, table))
+
+
+def save_hyper_as_hyper(hyper_file, file_path, **kwargs):
+    import shutil
+    shutil.copyfile(hyper_file, file_path)
+
+
+def save_hyper_as_csv(hyper_file: str, file_path: str, **kwargs):
+    import pantab
+    import tempfile
+    import shutil
+    import os
+
+    options = CsvOptions(**kwargs)
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        temp_hyper = os.path.join(temp_dir, 'temp.hyper')
+        shutil.copyfile(
+            src=hyper_file,
+            dst=temp_hyper
+        )
+
+        schema, table = get_default_table_and_schema(temp_hyper)
+
+        columns = get_column_list(
+            hyper_file_path=temp_hyper,
+            schema=schema,
+            table=table
+        )
+
+        if 'row_number' not in columns:
+            log.debug("Adding row numbers to hyper so we can guarantee ordering")
+            add_row_numbers_to_hyper(
+                input_hyper_file_path=temp_hyper,
+                schema=schema,
+                table=table
+            )
+        else:
+            log.debug("Data source already contains row numbers")
+
+        if options.include_row_numbers is False and 'row_number' in columns:
+            columns.remove('row_number')
+        elif options.include_row_numbers is True and 'row_number' not in columns:
+            columns.append('row_number')
+
+        if options.compress_using_gzip:
+            compression_options = dict(method='gzip')
+            file_path = file_path + '.gz'
+        elif file_path.endswith('.gz'):
+            compression_options = dict(method='gzip')
+        else:
+            compression_options = None
+
+        mode = 'w'
+        header = True
+        offset = 0
+        column_names = ','.join(f'"{column}"' for column in columns)
+
+        sql = f"SELECT {column_names} FROM \"{schema}\".\"{table}\" ORDER BY row_number"
+
+        while True:
+            query = f"{sql} LIMIT {options.chunk_size} OFFSET {offset}"
+            df_chunk = pantab.frame_from_hyper_query(temp_hyper, query)
+            if df_chunk.empty:
+                break
+            df_chunk.to_csv(file_path, index=False, mode=mode, header=header,
+                            compression=compression_options)
+            offset += options.chunk_size
+            if header:
+                header = False
+                mode = "a"
+
+
+def save_dataframe_as_hyper(df, file_path, **kwargs):
+    from tableauhyperapi import TableName
+    import pantab
+    options = HyperOptions(**kwargs)
+    table_name = TableName(options.schema, options.table)
+    pantab.frame_to_hyper(df=df, database=file_path, table=table_name)
\ No newline at end of file
diff --git a/mario/options.py b/mario/options.py
new file mode 100644
index 0000000..7d06c6f
--- /dev/null
+++ b/mario/options.py
@@ -0,0 +1,42 @@
+"""
+Common output options for different data formats along with their default values
+"""
+import logging
+logger = logging.getLogger(__name__)
+
+
+class OutputOptions:
+    def __init__(self, **kwargs):
+        self.minimise = kwargs.get('minimise', False)
+        self.validate = kwargs.get('validate', False)
+        self.allow_nulls = kwargs.get('allow_nulls', True)
+        self.chunk_size = kwargs.get('chunk_size', 100000)
+        self.include_row_numbers = kwargs.get('include_row_numbers', False)
+
+        if self.chunk_size <= 0:
+            raise ValueError("Chunk size must be a positive integer")
+
+        if not self.allow_nulls and not self.validate:
+            logger.error("Inconsistent options chosen: "
+                         "if you choose not to allow nulls, but don't enable validation, "
+                         "the output may still have nulls")
+            raise ValueError("Inconsistent output configuration")
+
+
+class CsvOptions(OutputOptions):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.compress_using_gzip = kwargs.get('compress_using_gzip', False)
+
+
+class HyperOptions(OutputOptions):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.table = kwargs.get('table', 'Extract')
+        self.schema = kwargs.get('schema', 'Extract')
+
+
+class ExcelOptions(OutputOptions):
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self.template_path = kwargs.get('template_path', 'excel_template.xlsx')
diff --git a/mario/query_builder.py b/mario/query_builder.py
index 429a555..4093b23 100644
--- a/mario/query_builder.py
+++ b/mario/query_builder.py
@@ -13,8 +13,9 @@ def get_formatted_query(query, params):

     formatted_query = copy(query)
-    for key, value in params.items():
-        formatted_query = formatted_query.replace(f'%({key})s', f"'{value}'")
+    if params:
+        for key, value in params.items():
+            formatted_query = formatted_query.replace(f'%({key})s', f"'{value}'")
     return formatted_query
diff --git a/setup.py b/setup.py
index 9314c46..fc4d6a5 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,7 @@
 setup(
     name='mario-pipeline-tools',
-    version='0.51',
+    version='0.52',
     packages=['mario'],
     url='https://github.com/JiscDACT/mario',
     license='all rights reserved',
diff --git a/test/test_data_extractor.py b/test/test_data_extractor.py
index 81e5896..d41870f 100644
--- a/test/test_data_extractor.py
+++ b/test/test_data_extractor.py
@@ -127,9 +127,10 @@ def test_stream_sql_to_csv_with_compression():
         metadata=metadata,
         configuration=configuration
     )
-    file = tempfile.NamedTemporaryFile(suffix='.csv')
-    gzip_path = extractor.stream_sql_to_csv(file_path=file.name, chunk_size=1000, compress_using_gzip=True)
-    df = pd.read_csv(gzip_path)
+    with tempfile.TemporaryDirectory() as temp_dir:
+        file = os.path.join(temp_dir, 'test.csv')
+        gzip_path = extractor.stream_sql_to_csv(file_path=file, chunk_size=1000, compress_using_gzip=True)
+        df = pd.read_csv(gzip_path)
     assert len(df) == 10194


@@ -624,16 +625,22 @@ def test_partitioning_extractor_with_row_numbers():
     )

     path = os.path.join('output', 'test_partitioning_extractor_with_row_numbers')
+    os.makedirs(path, exist_ok=True)
+
     file = os.path.join(path, 'test.hyper')
     csv_file = os.path.join(path, 'test.csv')
-    os.makedirs(path, exist_ok=True)
+    # drop any existing outputs
+    for existing in [file, csv_file]:
+        if os.path.exists(existing):
+            os.remove(existing)

     extractor.stream_sql_to_hyper(
         file_path=file,
         include_row_numbers=True
     )
-    # Check row_number in hyper output
+
+    # Check row_number is in hyper output
     assert 'row_number' in get_column_list(hyper_file_path=file, schema='Extract', table='Extract')

     # Load it up and export a CSV
diff --git a/test/test_dataset_builder.py b/test/test_dataset_builder.py
new file mode 100644
index 0000000..bd5f67f
--- /dev/null
+++ b/test/test_dataset_builder.py
@@ -0,0 +1,188 @@
+import shutil
+
+import pandas as pd
+import pantab
+from tableauhyperapi import TableName
+
+from mario.dataset_specification import dataset_from_json, Constraint
+from mario.metadata import metadata_from_json
+from mario.dataset_builder import DatasetBuilder, Format
+from mario.data_extractor import Configuration, HyperFile, DataExtractor, DataFrameExtractor, \
+    StreamingDataExtractor, PartitioningExtractor
+import os
+import pytest
+
+from mario.query_builder import ViewBasedQueryBuilder
+
+
+def setup_dataset_builder_test(test):
+    dataset = dataset_from_json(os.path.join('test', 'dataset.json'))
+    metadata = metadata_from_json(os.path.join('test', 'metadata.json'))
+    output_path = os.path.join('output', test)
+    shutil.rmtree(output_path, ignore_errors=True)
+    os.makedirs(output_path)
+    shutil.copyfile(src=os.path.join('test', 'orders_with_nulls.hyper'), dst=os.path.join(output_path, 'orders_with_nulls.hyper'))
+    return output_path, dataset, metadata
+
+
+def run_consistency_checks(builder, output_path):
+    """
+    Generates outputs in various formats with various output options to check
+    they are being created consistently
+    :param builder:
+    :param output_path:
+    :return:
+    """
+
+    # CSV
+    path = os.path.join(output_path, 'output.csv')
+
+    # Build as a gzipped csv file using validation
+    builder.build(file_path=path, output_format=Format.CSV, compress_using_gzip=True, validate=True, allow_nulls=True)
+    df = pd.read_csv(path + '.gz')
+    assert len(df) > 0
+
+    # Build without allowing NULLs
+    with pytest.raises(ValueError):
+        builder.build(file_path=path, output_format=Format.CSV, validate=True, allow_nulls=False)
+
+    # Build with row numbers
+    builder.build(file_path=path, output_format=Format.CSV, allow_nulls=True, include_row_numbers=True)
+    df = pd.read_csv(path)
+    assert 'row_number' in df.columns
+
+    # Build without row numbers
+    builder.build(file_path=path, output_format=Format.CSV, allow_nulls=True, include_row_numbers=False)
+    df = pd.read_csv(path)
+    assert 'row_number' not in df.columns
+
+    # Now onto Hypers
+    table_name = TableName('Extract', 'Extract')
+    path = os.path.join(output_path, 'output.hyper')
+    if os.path.exists(path):
+        os.remove(path)
+
+    # Build hyper without allowing NULLs
+    with pytest.raises(ValueError):
+        builder.build(file_path=path, output_format=Format.HYPER, validate=True, allow_nulls=False)
+
+    # Build hyper with row numbers
+    shutil.copyfile(src=os.path.join('test', 'orders_with_nulls.hyper'), dst=os.path.join(output_path, 'orders_with_nulls.hyper'))
+    builder.build(file_path=path, output_format=Format.HYPER, allow_nulls=True, include_row_numbers=True)
+    df = pantab.frame_from_hyper(source=path, table=table_name)
+    assert 'row_number' in df.columns
+    os.remove(path)
+
+    # Build hyper without row numbers
+    shutil.copyfile(src=os.path.join('test', 'orders_with_nulls.hyper'), dst=os.path.join(output_path, 'orders_with_nulls.hyper'))
+    builder.build(file_path=path, output_format=Format.HYPER, allow_nulls=True, include_row_numbers=False)
+    df = pantab.frame_from_hyper(source=path, table=table_name)
+    assert 'row_number' not in df.columns
+
+    # Excel
+    path = os.path.join(output_path, 'output.xlsx')
+
+    #
+    # Currently the Excel builder relies on get_data_frame(), which isn't supported
+    # with streaming-based extractors, so we need to check whether it's supported
+    #
+    streaming = False
+    try:
+        builder.data.get_data_frame()
+    except NotImplementedError:
+        streaming = True
+
+    if not streaming:
+        # Build without allowing NULLs
+        with pytest.raises(ValueError):
+            builder.build(file_path=path, output_format=Format.EXCEL_PIVOT, validate=True, allow_nulls=False)
+
+        # Build with row numbers
+        builder.build(file_path=path, output_format=Format.EXCEL_PIVOT, allow_nulls=True, include_row_numbers=True)
+        df = pd.read_excel(path)
+        assert 'row_number' in df.columns
+
+        # Build without row numbers
+        builder.build(file_path=path, output_format=Format.EXCEL_PIVOT, allow_nulls=True, include_row_numbers=False)
+        df = pd.read_excel(path)
+        assert 'row_number' not in df.columns
+
+        # Build with defaults
+        builder.build(file_path=path, output_format=Format.EXCEL_PIVOT)
+        df = pd.read_excel(path)
+        assert 'row_number' not in df.columns
+
+
+def test_dataset_builder_with_hyperfile_extractor():
+    output_path, dataset, metadata = setup_dataset_builder_test('test_dataset_builder_with_hyperfile_extractor')
+
+    configuration = Configuration(file_path=os.path.join(output_path, 'orders_with_nulls.hyper'))
+    extractor = HyperFile(configuration=configuration, dataset_specification=dataset, metadata=metadata)
+    builder = DatasetBuilder(dataset_specification=dataset, metadata=metadata, data=extractor)
+
+    run_consistency_checks(builder, output_path)
+
+
+def test_dataset_builder_with_dataframe_extractor():
+    from tableauhyperapi import TableName
+    output_path, dataset, metadata = setup_dataset_builder_test('test_dataset_builder_with_dataframe_extractor')
+
+    configuration = Configuration(file_path=os.path.join(output_path, 'orders_with_nulls.hyper'))
+    df = pantab.frame_from_hyper(source=configuration.file_path, table=TableName('Extract', 'Extract'))
+    extractor = DataFrameExtractor(dataframe=df, dataset_specification=dataset, metadata=metadata)
+    builder = DatasetBuilder(dataset_specification=dataset, metadata=metadata, data=extractor)
+
+    run_consistency_checks(builder, output_path)
+
+
+def test_dataset_builder_with_default_extractor():
+    output_path, dataset, metadata = setup_dataset_builder_test('test_dataset_builder_with_default_extractor')
+
+    configuration = Configuration(file_path=os.path.join(output_path, 'orders_with_nulls.hyper'))
+    extractor = DataExtractor(configuration=configuration, dataset_specification=dataset, metadata=metadata)
+    builder = DatasetBuilder(dataset_specification=dataset, metadata=metadata, data=extractor)
+
+    run_consistency_checks(builder, output_path)
+
+
+def test_dataset_builder_with_streaming_extractor():
+    # Skip this test if we don't have a connection string
+    if not os.environ.get('CONNECTION_STRING'):
+        pytest.skip("Skipping SQL test as no database configured")
+
+    output_path, dataset, metadata = setup_dataset_builder_test('test_dataset_builder_with_streaming_extractor')
+
+    configuration = Configuration(
+        connection_string=os.environ.get('CONNECTION_STRING'),
+        schema="dev",
+        view="superstore_with_nulls",
+        query_builder=ViewBasedQueryBuilder
+    )
+    extractor = StreamingDataExtractor(configuration=configuration, dataset_specification=dataset, metadata=metadata)
+    builder = DatasetBuilder(dataset_specification=dataset, metadata=metadata, data=extractor)
+
+    run_consistency_checks(builder, output_path)
+
+
+def test_dataset_builder_with_partitioning_extractor():
+    # Skip this test if we don't have a connection string
+    if not os.environ.get('CONNECTION_STRING'):
+        pytest.skip("Skipping SQL test as no database configured")
+
+    output_path, dataset, metadata = setup_dataset_builder_test('test_dataset_builder_with_partitioning_extractor')
+
+    constraint = Constraint()
+    constraint.item = 'Category'
+    constraint.allowed_values = ['Furniture', 'Office Supplies']
+    dataset.constraints.append(constraint)
+
+    configuration = Configuration(
+        connection_string=os.environ.get('CONNECTION_STRING'),
+        schema="dev",
+        view="superstore_with_nulls",
+        query_builder=ViewBasedQueryBuilder
+    )
+    extractor = PartitioningExtractor(configuration=configuration, dataset_specification=dataset, metadata=metadata, partition_column='Category')
+    builder = DatasetBuilder(dataset_specification=dataset, metadata=metadata, data=extractor)
+
+    run_consistency_checks(builder, output_path)
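
Usage note: the keyword-argument options API introduced by this change is exercised end-to-end in test/test_dataset_builder.py. The following is a minimal sketch of how a caller would use it, assuming `dataset`, `metadata` and `extractor` are set up as in those tests and that the output file names are purely illustrative:

from mario.dataset_builder import DatasetBuilder, Format

builder = DatasetBuilder(dataset_specification=dataset, metadata=metadata, data=extractor)

# Gzipped, validated CSV. CsvOptions defaults (mario/options.py): minimise=False,
# validate=False, allow_nulls=True, chunk_size=100000, include_row_numbers=False,
# compress_using_gzip=False.
builder.build(file_path='output.csv', output_format=Format.CSV,
              compress_using_gzip=True, validate=True, allow_nulls=True)

# Hyper extract with a row_number column added to guarantee ordering.
builder.build(file_path='output.hyper', output_format=Format.HYPER,
              include_row_numbers=True)

# Passing allow_nulls=False without validate=True raises ValueError at option
# construction; with validate=True it raises during validation if nulls are present.
builder.build(file_path='output.csv', output_format=Format.CSV,
              validate=True, allow_nulls=False)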