diff --git a/README.md b/README.md
index 7330e29..89f9d3e 100644
--- a/README.md
+++ b/README.md
@@ -17,6 +17,8 @@
 scripts and larger software packages to provide basic STAR file I/O functions.
 Data is exposed as simple python dictionaries or
 [pandas dataframes](https://pandas.pydata.org/docs/user_guide/dsintro.html#dataframe).
+(The data may be exposed as [polars dataframes](https://docs.pola.rs/py-polars/html/reference/dataframe/index.html)
+if `polars=True` is passed to the `read` function.)
 
 This package was designed principally for compatibility with files generated by
 [RELION](https://www3.mrc-lmb.cam.ac.uk/relion/index.php/Main_Page).
diff --git a/pyproject.toml b/pyproject.toml
index a5ed6c5..afa9c43 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -38,6 +38,7 @@ classifiers = [
 dependencies = [
     "numpy",
     "pandas>=2.1.1",
+    "polars>=0.20",
     "pyarrow",
     "typing-extensions",
 ]
diff --git a/src/starfile/functions.py b/src/starfile/functions.py
index b563f8a..c2510cc 100644
--- a/src/starfile/functions.py
+++ b/src/starfile/functions.py
@@ -1,17 +1,16 @@
 from __future__ import annotations
 
-from typing import TYPE_CHECKING, Dict, List, Union, Optional
+from typing import TYPE_CHECKING, Optional, Union
 
 if TYPE_CHECKING:
-    import pandas as pd
     from os import PathLike
+
 from .parser import StarParser
-from .writer import StarWriter
 from .typing import DataBlock
+from .writer import StarWriter
 
 if TYPE_CHECKING:
-    import pandas as pd
     from os import PathLike
 
@@ -19,8 +18,9 @@ def read(
     filename: PathLike,
     read_n_blocks: Optional[int] = None,
     always_dict: bool = False,
-    parse_as_string: List[str] = []
-) -> Union[DataBlock, Dict[DataBlock]]:
+    parse_as_string: list[str] = [],
+    polars: bool = False,
+) -> Union[DataBlock, dict[str, DataBlock]]:
     """Read data from a STAR file.
 
     Basic data blocks are read as dictionaries. Loop blocks are read as pandas
@@ -40,7 +40,14 @@
     parse_as_string: list[str]
        A list of keys or column names which will not be coerced to numeric values.
+    polars: bool
+        If True, loop blocks are returned as polars rather than pandas dataframes.
     """
-    parser = StarParser(filename, n_blocks_to_read=read_n_blocks, parse_as_string=parse_as_string)
+    parser = StarParser(
+        filename,
+        n_blocks_to_read=read_n_blocks,
+        parse_as_string=parse_as_string,
+        polars=polars,
+    )
     if len(parser.data_blocks) == 1 and always_dict is False:
         return list(parser.data_blocks.values())[0]
     else:
@@ -48,14 +55,14 @@
 
 
 def write(
-    data: Union[DataBlock, Dict[str, DataBlock], List[DataBlock]],
+    data: Union[DataBlock, dict[str, DataBlock], list[DataBlock]],
     filename: PathLike,
-    float_format: str = '%.6f',
-    sep: str = '\t',
-    na_rep: str = '',
+    float_format: int = 6,
+    sep: str = "\t",
+    na_rep: str = "",
     quote_character: str = '"',
     quote_all_strings: bool = False,
-    **kwargs
+    **kwargs,
 ):
     """Write data to disk in the STAR format.
@@ -66,8 +73,8 @@
         If a dictionary of datablocks are passed the keys will be the data block names.
     filename: PathLike
         Path where the file will be saved.
-    float_format: str
-        Float format string which will be passed to pandas.
+    float_format: int
+        Number of decimal places to write floats to.
     sep: str
         Separator between values, will be passed to pandas.
     na_rep: str
diff --git a/src/starfile/parser.py b/src/starfile/parser.py
index c21a8c2..c7683fa 100644
--- a/src/starfile/parser.py
+++ b/src/starfile/parser.py
@@ -1,14 +1,16 @@
 from __future__ import annotations
 
+import shlex
 from collections import deque
+from functools import lru_cache
 from io import StringIO
 from linecache import getline
-import shlex
+from pathlib import Path
+from typing import TYPE_CHECKING, Optional, Union
 
 import numpy as np
 import pandas as pd
-from pathlib import Path
-from typing import TYPE_CHECKING, Union, Optional, Dict, Tuple, List
+import polars as pl
 
 from starfile.typing import DataBlock
@@ -21,14 +23,15 @@ class StarParser:
     n_lines_in_file: int
     n_blocks_to_read: int
     current_line_number: int
-    data_blocks: Dict[DataBlock]
-    parse_as_string: List[str]
+    data_blocks: dict[str, DataBlock]
+    parse_as_string: list[str]
 
     def __init__(
         self,
         filename: PathLike,
         n_blocks_to_read: Optional[int] = None,
-        parse_as_string: List[str] = [],
+        parse_as_string: list[str] = [],
+        polars: bool = False,
     ):
         # set filename, with path checking
         filename = Path(filename)
@@ -42,48 +45,51 @@ def __init__(
         self.n_blocks_to_read = n_blocks_to_read
         self.parse_as_string = parse_as_string
 
+        self.polars = polars
+
         # parse file
         self.current_line_number = 0
         self.parse_file()
 
-    @property
-    def current_line(self) -> str:
-        return getline(str(self.filename), self.current_line_number).strip()
+    @lru_cache(maxsize=25)
+    def _get_line(self, line_number: int) -> str:
+        return " ".join(getline(str(self.filename), line_number).split())
 
     def parse_file(self):
         while self.current_line_number <= self.n_lines_in_file:
             if len(self.data_blocks) == self.n_blocks_to_read:
                 break
-            elif self.current_line.startswith('data_'):
+            elif self._get_line(self.current_line_number).startswith("data_"):
                 block_name, block = self._parse_data_block()
                 self.data_blocks[block_name] = block
             else:
                 self.current_line_number += 1
 
-    def _parse_data_block(self) -> Tuple[str, DataBlock]:
+    def _parse_data_block(self) -> tuple[str, DataBlock]:
         # current line starts with 'data_foo'
-        block_name = self.current_line[5:]  # 'data_foo' -> 'foo'
+        block_name = self._get_line(self.current_line_number)[5:]  # 'data_foo' -> 'foo'
         self.current_line_number += 1
 
         # iterate over file,
         while self.current_line_number <= self.n_lines_in_file:
             self.current_line_number += 1
-            if self.current_line.startswith('loop_'):
+            current_line = self._get_line(self.current_line_number)
+            if current_line.startswith("loop_"):
                 return block_name, self._parse_loop_block()
-            elif self.current_line.startswith('_'):  # line is simple block
+            elif current_line.startswith("_"):  # line is simple block
                 return block_name, self._parse_simple_block()
 
-    def _parse_simple_block(self) -> Dict[str, Union[str, int, float]]:
+    def _parse_simple_block(self) -> dict[str, Union[str, int, float]]:
         block = {}
         while self.current_line_number <= self.n_lines_in_file:
-            if self.current_line.startswith('data'):
+            c = self._get_line(self.current_line_number)
+            if c.startswith("data"):
                 break
-            elif self.current_line.startswith('_'):  # '_foo bar'
-                k, v = shlex.split(self.current_line)
+            elif c.startswith("_"):  # '_foo bar'
+                k, v = shlex.split(c)
                 column_name = k[1:]
-                parse_column_as_string = (
-                    self.parse_as_string is not None
-                    and any(column_name == col for col in self.parse_as_string)
+                parse_column_as_string = self.parse_as_string is not None and any(
+                    column_name == col for col in self.parse_as_string
                 )
                 if parse_column_as_string is True:
                     block[column_name] = v
@@ -92,58 +98,66 @@ def _parse_simple_block(self) -> Dict[str, Union[str, int, float]]:
             self.current_line_number += 1
         return block
 
-    def _parse_loop_block(self) -> pd.DataFrame:
+    def _parse_loop_block(self) -> pd.DataFrame | pl.DataFrame:
         # parse loop header
         loop_column_names = deque()
         self.current_line_number += 1
 
-        while self.current_line.startswith('_'):
-            column_name = self.current_line.split()[0][1:]
+        while self._get_line(self.current_line_number).startswith("_"):
+            column_name = self._get_line(self.current_line_number).split()[0][1:]
             loop_column_names.append(column_name)
             self.current_line_number += 1
 
         # now parse the loop block data
         loop_data = deque()
         while self.current_line_number <= self.n_lines_in_file:
-            if self.current_line.startswith('data_'):
+            current_line = self._get_line(self.current_line_number)
+            if current_line.startswith("data_"):
                 break
-            loop_data.append(self.current_line)
+            previous_line = self._get_line(self.current_line_number - 1)
+            if not (current_line.isspace() and previous_line.isspace()) and (
+                current_line and previous_line
+            ):
+                loop_data.append(current_line)
             self.current_line_number += 1
-        loop_data = '\n'.join(loop_data)
-        if loop_data[-2:] != '\n':
-            loop_data += '\n'
+        loop_data = "\n".join(loop_data)
+        if loop_data[-2:] != "\n":
+            loop_data += "\n"
 
         # put string data into a dataframe
-        if loop_data == '\n':
+        if loop_data == "\n":
             n_cols = len(loop_column_names)
-            df = pd.DataFrame(np.zeros(shape=(0, n_cols)))
+            df = pl.DataFrame(np.zeros(shape=(0, n_cols)))
         else:
-            column_name_to_index = {col: idx for idx, col in enumerate(loop_column_names)}
-            df = pd.read_csv(
+            df = pl.read_csv(
                 StringIO(loop_data.replace("'", '"')),
-                delimiter=r'\s+',
-                header=None,
-                comment='#',
-                dtype={column_name_to_index[k]: str for k in self.parse_as_string if k in loop_column_names},
-                keep_default_na=False,
-                engine='c',
+                separator=" ",
+                has_header=False,
+                comment_prefix="#",
+                dtypes={
+                    k: pl.String for k in self.parse_as_string if k in loop_column_names
+                },
+                truncate_ragged_lines=True,
+                null_values=["", ""],
             )
         df.columns = loop_column_names
 
-        # Numericise all columns in temporary copy
-        df_numeric = df.apply(_apply_numeric)
-
-        # Replace columns that are all NaN with the original columns
-        df_numeric[df_numeric.columns[df_numeric.isna().all()]] = df[df_numeric.columns[df_numeric.isna().all()]]
+        # If the column type is string then use empty strings rather than null
+        df = df.with_columns(pl.col(pl.String).fill_null(""))
 
         # Replace columns that should be strings
         for col in df.columns:
-            df[col] = df_numeric[col] if col not in self.parse_as_string else df[col]
-        return df
+            if col in self.parse_as_string:
+                df = df.with_columns(
+                    pl.col(col).cast(pl.String).fill_null("").alias(col)
+                )
+        if self.polars:
+            return df
+        return df.to_pandas()
 
 
 def count_lines(file: Path) -> int:
-    with open(file, 'rb') as f:
+    with open(file, "rb") as f:
         return sum(1 for _ in f)
@@ -169,10 +183,3 @@ def numericise(value: str) -> Union[str, int, float]:
             # If it's not a float either, leave it as a string
             value = value
     return value
-
-
-def _apply_numeric(col: pd.Series) -> pd.Series:
-    try:
-        return pd.to_numeric(col)
-    except ValueError:
-        return col
diff --git a/src/starfile/typing.py b/src/starfile/typing.py
index c039b57..58544e7 100644
--- a/src/starfile/typing.py
+++ b/src/starfile/typing.py
@@ -1,11 +1,11 @@
 from __future__ import annotations
 
-from typing import Union, Dict
-from typing_extensions import TypeAlias
+from typing import Dict, Union
 
 import pandas as pd
+import polars as pl
+from typing_extensions import TypeAlias
 
 DataBlock: TypeAlias = Union[
-    pd.DataFrame,
-    Dict[str, Union[str, int, float]]
+    pd.DataFrame, pl.DataFrame, Dict[str, Union[str, int, float]]
 ]
diff --git a/src/starfile/writer.py b/src/starfile/writer.py
index 833c689..f24fb72 100644
--- a/src/starfile/writer.py
+++ b/src/starfile/writer.py
@@ -1,15 +1,15 @@
 from __future__ import annotations
 
 from datetime import datetime
-from pathlib import Path
-from typing import TYPE_CHECKING, Union, Dict, List
 from importlib.metadata import version
-import csv
+from pathlib import Path
+from typing import TYPE_CHECKING, Dict, List, Union
 
 import pandas as pd
+import polars as pl
 
-from .utils import TextBuffer
 from .typing import DataBlock
+from .utils import TextBuffer
 
 if TYPE_CHECKING:
     from os import PathLike
@@ -20,11 +20,11 @@ class StarWriter:
 
     def __init__(
         self,
-        data_blocks: Union[DataBlock, Dict[str, DataBlock], List[DataBlock]],
+        data_blocks: Union[DataBlock, dict[str, DataBlock], list[DataBlock]],
         filename: PathLike,
-        float_format: str = '%.6f',
-        separator: str = '\t',
-        na_rep: str = '',
+        float_format: int = 6,
+        separator: str = "\t",
+        na_rep: str = "",
         quote_character: str = '"',
         quote_all_strings: bool = False,
     ):
@@ -43,10 +43,9 @@
         self.write()
 
     def coerce_data_blocks(
-        self,
-        data_blocks: Union[DataBlock, List[DataBlock], Dict[str, DataBlock]]
-    ) -> Dict[str, DataBlock]:
-        if isinstance(data_blocks, pd.DataFrame):
+        self, data_blocks: Union[DataBlock, list[DataBlock], dict[str, DataBlock]]
+    ) -> dict[str, DataBlock]:
+        if isinstance(data_blocks, (pd.DataFrame, pl.DataFrame)):
             return coerce_dataframe(data_blocks)
         elif isinstance(data_blocks, dict):
             return coerce_dict(data_blocks)
@@ -54,11 +53,14 @@
             return coerce_list(data_blocks)
         else:
             raise ValueError(
-                f'Expected \
+                f"Expected \
 {pd.DataFrame}, \
 {Dict[str, pd.DataFrame]} \
 or {List[pd.DataFrame]}, \
-got {type(data_blocks)}'
+or {pl.DataFrame}, \
+{Dict[str, pl.DataFrame]} \
+or {List[pl.DataFrame]}, \
+got {type(data_blocks)}"
             )
 
     def write(self):
@@ -74,9 +76,9 @@ def write_data_blocks(self):
                 block_name=block_name,
                 data=block,
                 quote_character=self.quote_character,
-                quote_all_strings=self.quote_all_strings
+                quote_all_strings=self.quote_all_strings,
             )
-        elif isinstance(block, pd.DataFrame):
+        elif isinstance(block, (pd.DataFrame, pl.DataFrame)):
             write_loop_block(
                 file=self.filename,
                 block_name=block_name,
@@ -85,112 +87,110 @@
                 separator=self.sep,
                 na_rep=self.na_rep,
                 quote_character=self.quote_character,
-                quote_all_strings=self.quote_all_strings
+                quote_all_strings=self.quote_all_strings,
             )
 
     def backup_if_file_exists(self):
         if self.filename.exists():
-            new_name = self.filename.name + '~'
+            new_name = self.filename.name + "~"
             backup_path = self.filename.resolve().parent / new_name
             self.filename.rename(backup_path)
 
 
-def coerce_dataframe(df: pd.DataFrame) -> Dict[str, DataBlock]:
-    return {'': df}
+def coerce_dataframe(df: pd.DataFrame | pl.DataFrame) -> dict[str, DataBlock]:
+    return {"": df}
 
 
 def coerce_dict(
-    data_blocks: Union[DataBlock, Dict[str, DataBlock]]
-) -> Dict[str, DataBlock]:
+    data_blocks: Union[DataBlock, dict[str, DataBlock]]
+) -> dict[str, DataBlock]:
     """Coerce dict into dict of data blocks."""
     # check if data is already Dict[str, DataBlock]
     for k, v in data_blocks.items():
-        if type(v) in (dict, pd.DataFrame):  #
+        if type(v) in (dict, pd.DataFrame, pl.DataFrame):
            return data_blocks
 
     # coerce if not
-    return {'': data_blocks}
+    return {"": data_blocks}
 
 
-def coerce_list(data_blocks: List[DataBlock]) -> Dict[str, DataBlock]:
+def coerce_list(data_blocks: list[DataBlock]) -> dict[str, DataBlock]:
     """Coerces a list of DataFrames into a dict"""
-    return {f'{idx}': df for idx, df in enumerate(data_blocks)}
+    return {f"{idx}": df for idx, df in enumerate(data_blocks)}
 
 
 def write_blank_lines(file: Path, n: int):
-    with open(file, mode='a') as f:
-        f.write('\n' * n)
+    with open(file, mode="a") as f:
+        f.write("\n" * n)
 
 
 def write_package_info(file: Path):
-    date = datetime.now().strftime('%d/%m/%Y')
-    time = datetime.now().strftime('%H:%M:%S')
-    line = f'# Created by the starfile Python package (version {__version__}) at {time} on {date}'
-    with open(file, mode='w+') as f:
-        f.write(f'{line}\n')
+    date = datetime.now().strftime("%d/%m/%Y")
+    time = datetime.now().strftime("%H:%M:%S")
+    line = f"# Created by the starfile Python package (version {__version__}) at {time} on {date}"
+    with open(file, mode="w+") as f:
+        f.write(f"{line}\n")
 
 
 def write_simple_block(
     file: Path,
     block_name: str,
-    data: Dict[str, Union[str, int, float]],
+    data: dict[str, Union[str, int, float]],
     quote_character: str = '"',
-    quote_all_strings: bool = False
-):
+    quote_all_strings: bool = False,
+):
     quoted_data = {
-        k: f"{quote_character}{v}{quote_character}"
-        if isinstance(v, str) and (quote_all_strings or " " in v or v == "")
+        k: f"{quote_character}{v}{quote_character}"
+        if isinstance(v, str) and (quote_all_strings or " " in v or v == "")
         else v
-        for k, v
-        in data.items()
+        for k, v in data.items()
     }
-    formatted_lines = '\n'.join(
-        [
-            f'_{k}\t\t\t{v}'
-            for k, v
-            in quoted_data.items()
-        ]
-    )
-    with open(file, mode='a') as f:
-        f.write(f'data_{block_name}\n\n')
+    formatted_lines = "\n".join([f"_{k}\t\t\t{v}" for k, v in quoted_data.items()])
+    with open(file, mode="a") as f:
+        f.write(f"data_{block_name}\n\n")
         f.write(formatted_lines)
-    f.write('\n\n\n')
+        f.write("\n\n\n")
 
 
 def write_loop_block(
     file: Path,
     block_name: str,
-    df: pd.DataFrame,
-    float_format: str = '%.6f',
-    separator: str = '\t',
-    na_rep: str = '',
+    df: pd.DataFrame | pl.DataFrame,
+    float_format: int = 6,
+    separator: str = "\t",
+    na_rep: str = "",
     quote_character: str = '"',
-    quote_all_strings: bool = False
+    quote_all_strings: bool = False,
 ):
     # write header
     header_lines = [
-        f'_{column_name} #{idx}'
-        for idx, column_name
-        in enumerate(df.columns, 1)
+        f"_{column_name} #{idx}" for idx, column_name in enumerate(df.columns, 1)
     ]
-    with open(file, mode='a') as f:
-        f.write(f'data_{block_name}\n\n')
-        f.write('loop_\n')
-        f.write('\n'.join(header_lines))
-        f.write('\n')
-
-    df = df.map(lambda x: f'{quote_character}{x}{quote_character}'
-                if isinstance(x, str) and (quote_all_strings or " " in x or x == "")
-                else x)
+    with open(file, mode="a") as f:
+        f.write(f"data_{block_name}\n\n")
+        f.write("loop_\n")
+        f.write("\n".join(header_lines))
+        f.write("\n")
 
     # write data
-    df.to_csv(
-        path_or_buf=file,
-        mode='a',
-        sep=separator,
-        header=False,
-        index=False,
-        float_format=float_format,
-        na_rep=na_rep,
-        quoting=csv.QUOTE_NONE
-    )
+    if isinstance(df, pd.DataFrame):
+        df = pl.from_pandas(df)
+
+    df = df.with_columns(
+        pl.col(pl.String).map_elements(
+            lambda x: f"{quote_character}{x}{quote_character}"
+            if (quote_all_strings or " " in x or x == "")
+            else x,
+            return_dtype=pl.String,
+        )
+    )
+
+    with open(file, "a") as fobj:
+        df.write_csv(
+            fobj,
+            separator=separator,
+            include_header=False,
+            float_precision=float_format,
+            null_value=na_rep,
+            quote_style="never",
+        )
 
     write_blank_lines(file, n=2)
diff --git a/tests/constants.py b/tests/constants.py
index 9a9ad68..3e7ae47 100644
--- a/tests/constants.py
+++ b/tests/constants.py
@@ -1,33 +1,40 @@
 from pathlib import Path
 
 import pandas as pd
+import polars as pl
 
 # Directories for test data
-test_data_directory = Path(__file__).parent / 'data'
-relion_tutorial = test_data_directory / 'relion_tutorial'
+test_data_directory = Path(__file__).parent / "data"
+relion_tutorial = test_data_directory / "relion_tutorial"
 
 # Test files
-loop_simple = test_data_directory / 'one_loop.star'
-postprocess = test_data_directory / 'postprocess.star'
-pipeline = test_data_directory / 'default_pipeline.star'
-rln31_style = test_data_directory / 'rln3.1_data_style.star'
-single_line_end_of_multiblock = test_data_directory / 'single_line_end_of_multiblock.star'
-single_line_middle_of_multiblock = test_data_directory / 'single_line_middle_of_multiblock.star'
-optimiser_2d = relion_tutorial / 'run_it025_optimiser_2D.star'
-optimiser_3d = relion_tutorial / 'run_it025_optimiser_3D.star'
-sampling_2d = relion_tutorial / 'run_it025_sampling_2D.star'
-sampling_3d = relion_tutorial / 'run_it025_sampling_3D.star'
-non_existant_file = test_data_directory / 'non_existant_file.star'
-two_single_line_loop_blocks = test_data_directory / 'two_single_line_loop_blocks.star'
-two_basic_blocks = test_data_directory / 'two_basic_blocks.star'
-empty_loop = test_data_directory / 'empty_loop.star'
-basic_single_quote = test_data_directory / 'basic_single_quote.star'
-basic_double_quote = test_data_directory / 'basic_double_quote.star'
-loop_single_quote = test_data_directory / 'loop_single_quote.star'
-loop_double_quote = test_data_directory / 'loop_double_quote.star'
+loop_simple = test_data_directory / "one_loop.star"
+postprocess = test_data_directory / "postprocess.star"
+pipeline = test_data_directory / "default_pipeline.star"
+rln31_style = test_data_directory / "rln3.1_data_style.star"
+single_line_end_of_multiblock = (
+    test_data_directory / "single_line_end_of_multiblock.star"
+)
+single_line_middle_of_multiblock = (
+    test_data_directory / "single_line_middle_of_multiblock.star"
+)
+optimiser_2d = relion_tutorial / "run_it025_optimiser_2D.star"
+optimiser_3d = relion_tutorial / "run_it025_optimiser_3D.star"
+sampling_2d = relion_tutorial / "run_it025_sampling_2D.star"
+sampling_3d = relion_tutorial / "run_it025_sampling_3D.star"
+non_existant_file = test_data_directory / "non_existant_file.star"
+two_single_line_loop_blocks = test_data_directory / "two_single_line_loop_blocks.star"
+two_basic_blocks = test_data_directory / "two_basic_blocks.star"
+empty_loop = test_data_directory / "empty_loop.star"
+basic_single_quote = test_data_directory / "basic_single_quote.star"
+basic_double_quote = test_data_directory / "basic_double_quote.star"
+loop_single_quote = test_data_directory / "loop_single_quote.star"
+loop_double_quote = test_data_directory / "loop_double_quote.star"
 
 # Example DataFrame for testing
-cars = {'Brand': ['Honda_Civic', 'Toyota_Corolla', 'Ford_Focus', 'Audi_A4'],
-        'Price': [22000, 25000, 27000, 35000]
-        }
+cars = {
+    "Brand": ["Honda_Civic", "Toyota_Corolla", "Ford_Focus", "Audi_A4"],
+    "Price": [22000, 25000, 27000, 35000],
+}
 test_df = pd.DataFrame.from_dict(cars)
+test_df_pl = pl.DataFrame(cars)
diff --git a/tests/test_parsing.py b/tests/test_parsing.py
index a94bf9e..1e728d8 100644
--- a/tests/test_parsing.py
+++ b/tests/test_parsing.py
@@ -1,201 +1,229 @@
 import time
 
-import pandas as pd
 import numpy as np
+import pandas as pd
+import polars as pl
 import pytest
 
 from starfile.parser import StarParser
+
 from .constants import (
+    basic_double_quote,
+    basic_single_quote,
+    empty_loop,
+    loop_double_quote,
     loop_simple,
-    postprocess,
-    pipeline,
-    rln31_style,
+    loop_single_quote,
+    non_existant_file,
     optimiser_2d,
     optimiser_3d,
+    pipeline,
+    postprocess,
+    rln31_style,
     sampling_2d,
     sampling_3d,
-    single_line_middle_of_multiblock,
     single_line_end_of_multiblock,
-    non_existant_file,
-    two_single_line_loop_blocks,
+    single_line_middle_of_multiblock,
     two_basic_blocks,
-    empty_loop,
-    basic_single_quote,
-    basic_double_quote,
-    loop_single_quote,
-    loop_double_quote,
+    two_single_line_loop_blocks,
 )
-from .utils import generate_large_star_file, remove_large_star_file, million_row_file
+from .utils import generate_large_star_file, million_row_file, remove_large_star_file
 
 
-def test_instantiation():
+@pytest.mark.parametrize("polars", [True, False])
+def test_instantiation(polars):
     """
     Tests instantiation of the StarFile class
     """
     # instantiation with file which exists
-    s = StarParser(loop_simple)
+    s = StarParser(loop_simple, polars=polars)
 
     # instantiation with non-existant file should fail
     assert non_existant_file.exists() is False
     with pytest.raises(FileNotFoundError):
-        s = StarParser(non_existant_file)
+        s = StarParser(non_existant_file, polars=polars)
 
 
-def test_read_loop_block():
+@pytest.mark.parametrize("polars", [True, False])
+def test_read_loop_block(polars):
     """
     Check that loop block is parsed correctly, data has the correct shape
     """
-    parser = StarParser(loop_simple)
+    parser = StarParser(loop_simple, polars=polars)
 
     # Check that only one object is present
     assert len(parser.data_blocks) == 1
 
     # get dataframe
     df = list(parser.data_blocks.values())[0]
-    assert isinstance(df, pd.DataFrame)
+    assert isinstance(df, pl.DataFrame if polars else pd.DataFrame)
 
     # Check shape of dataframe
     assert df.shape == (16, 12)
 
     # check columns
     expected_columns = [
-        'rlnCoordinateX',
-        'rlnCoordinateY',
-        'rlnCoordinateZ',
-        'rlnMicrographName',
-        'rlnMagnification',
-        'rlnDetectorPixelSize',
-        'rlnCtfMaxResolution',
-        'rlnImageName',
-        'rlnCtfImage',
-        'rlnAngleRot',
-        'rlnAngleTilt',
-        'rlnAnglePsi',
+        "rlnCoordinateX",
+        "rlnCoordinateY",
+        "rlnCoordinateZ",
+        "rlnMicrographName",
+        "rlnMagnification",
+        "rlnDetectorPixelSize",
+        "rlnCtfMaxResolution",
+        "rlnImageName",
+        "rlnCtfImage",
+        "rlnAngleRot",
+        "rlnAngleTilt",
+        "rlnAnglePsi",
     ]
-    assert all(df.columns == expected_columns)
+    if polars:
+        assert df.columns == expected_columns
+    else:
+        assert all(df.columns == expected_columns)
 
 
-def test_read_multiblock_file():
+@pytest.mark.parametrize("polars", [True, False])
+def test_read_multiblock_file(polars):
     """
     Check that multiblock STAR files such as postprocess RELION files
     parse properly
     """
-    parser = StarParser(postprocess)
+    parser = StarParser(postprocess, polars=polars)
     assert len(parser.data_blocks) == 3
 
-    assert 'general' in parser.data_blocks
-    assert isinstance(parser.data_blocks['general'], dict)
-    assert len(parser.data_blocks['general']) == 6
-    columns = list(parser.data_blocks['general'].keys())
+    assert "general" in parser.data_blocks
+    assert isinstance(parser.data_blocks["general"], dict)
+    assert len(parser.data_blocks["general"]) == 6
+    columns = list(parser.data_blocks["general"].keys())
     expected_columns = [
-        'rlnFinalResolution',
-        'rlnBfactorUsedForSharpening',
-        'rlnUnfilteredMapHalf1',
-        'rlnUnfilteredMapHalf2',
-        'rlnMaskName',
-        'rlnRandomiseFrom',
+        "rlnFinalResolution",
+        "rlnBfactorUsedForSharpening",
+        "rlnUnfilteredMapHalf1",
+        "rlnUnfilteredMapHalf2",
+        "rlnMaskName",
+        "rlnRandomiseFrom",
     ]
     assert columns == expected_columns
 
-    assert 'fsc' in parser.data_blocks
-    assert isinstance(parser.data_blocks['fsc'], pd.DataFrame)
-    assert parser.data_blocks['fsc'].shape == (49, 7)
+    assert "fsc" in parser.data_blocks
+    assert isinstance(
+        parser.data_blocks["fsc"], pl.DataFrame if polars else pd.DataFrame
+    )
+    assert parser.data_blocks["fsc"].shape == (49, 7)
 
-    assert 'guinier' in parser.data_blocks
-    assert isinstance(parser.data_blocks['guinier'], pd.DataFrame)
-    assert parser.data_blocks['guinier'].shape == (49, 3)
+    assert "guinier" in parser.data_blocks
+    assert isinstance(
+        parser.data_blocks["guinier"], pl.DataFrame if polars else pd.DataFrame
+    )
+    assert parser.data_blocks["guinier"].shape == (49, 3)
 
 
-def test_read_pipeline():
+@pytest.mark.parametrize("polars", [True, False])
+def test_read_pipeline(polars):
     """
     Check that a pipeline.star file is parsed correctly
     """
-    parser = StarParser(pipeline)
+    parser = StarParser(pipeline, polars=polars)
 
     # Check that data match file contents
-    assert isinstance(parser.data_blocks['pipeline_general'], dict)
-    assert parser.data_blocks['pipeline_processes'].shape == (31, 4)
-    assert parser.data_blocks['pipeline_nodes'].shape == (74, 2)
-    assert parser.data_blocks['pipeline_input_edges'].shape == (48, 2)
-    assert parser.data_blocks['pipeline_output_edges'].shape == (72, 2)
+    assert isinstance(parser.data_blocks["pipeline_general"], dict)
+    assert parser.data_blocks["pipeline_processes"].shape == (31, 4)
+    assert parser.data_blocks["pipeline_nodes"].shape == (74, 2)
+    assert parser.data_blocks["pipeline_input_edges"].shape == (48, 2)
+    assert parser.data_blocks["pipeline_output_edges"].shape == (72, 2)
 
 
-def test_read_rln31():
+@pytest.mark.parametrize("polars", [True, False])
+def test_read_rln31(polars):
     """
     Check that reading of RELION 3.1 style star files works properly
     """
-    s = StarParser(rln31_style)
+    s = StarParser(rln31_style, polars=polars)
     for key, df in s.data_blocks.items():
-        assert isinstance(df, pd.DataFrame)
+        assert isinstance(df, pl.DataFrame if polars else pd.DataFrame)
 
-    assert isinstance(s.data_blocks['block_1'], pd.DataFrame)
-    assert isinstance(s.data_blocks['block_2'], pd.DataFrame)
-    assert isinstance(s.data_blocks['block_3'], pd.DataFrame)
+    assert isinstance(
+        s.data_blocks["block_1"], pl.DataFrame if polars else pd.DataFrame
+    )
+    assert isinstance(
+        s.data_blocks["block_2"], pl.DataFrame if polars else pd.DataFrame
+    )
+    assert isinstance(
+        s.data_blocks["block_3"], pl.DataFrame if polars else pd.DataFrame
+    )
 
 
-def test_read_n_blocks():
+@pytest.mark.parametrize("polars", [True, False])
+def test_read_n_blocks(polars):
     """
     Check that passing read_n_blocks allows reading of only a specified
     number of data blocks from a star file
     """
     # test 1 block
-    s = StarParser(postprocess, n_blocks_to_read=1)
+    s = StarParser(postprocess, n_blocks_to_read=1, polars=polars)
     assert len(s.data_blocks) == 1
 
     # test 2 blocks
-    s = StarParser(postprocess, n_blocks_to_read=2)
+    s = StarParser(postprocess, n_blocks_to_read=2, polars=polars)
     assert len(s.data_blocks) == 2
 
 
-def test_single_line_middle_of_multiblock():
-    s = StarParser(single_line_middle_of_multiblock)
+@pytest.mark.parametrize("polars", [True, False])
+def test_single_line_middle_of_multiblock(polars):
+    s = StarParser(single_line_middle_of_multiblock, polars=polars)
     assert len(s.data_blocks) == 2
 
 
-def test_single_line_end_of_multiblock():
-    s = StarParser(single_line_end_of_multiblock)
+@pytest.mark.parametrize("polars", [True, False])
+def test_single_line_end_of_multiblock(polars):
+    s = StarParser(single_line_end_of_multiblock, polars=polars)
     assert len(s.data_blocks) == 2
 
     # iterate over dataframes, checking keys, names and shapes
     for idx, (key, df) in enumerate(s.data_blocks.items()):
         if idx == 0:
-            assert key == 'block_1'
+            assert key == "block_1"
             assert df.shape == (2, 5)
         if idx == 1:
-            assert key == 'block_2'
+            assert key == "block_2"
             assert df.shape == (1, 5)
 
 
-def test_read_optimiser_2d():
-    parser = StarParser(optimiser_2d)
+@pytest.mark.parametrize("polars", [True, False])
+def test_read_optimiser_2d(polars):
+    parser = StarParser(optimiser_2d, polars=polars)
     assert len(parser.data_blocks) == 1
-    assert len(parser.data_blocks['optimiser_general']) == 84
+    assert len(parser.data_blocks["optimiser_general"]) == 84
 
 
-def test_read_optimiser_3d():
-    parser = StarParser(optimiser_3d)
+@pytest.mark.parametrize("polars", [True, False])
+def test_read_optimiser_3d(polars):
+    parser = StarParser(optimiser_3d, polars=polars)
     assert len(parser.data_blocks) == 1
-    assert len(parser.data_blocks['optimiser_general']) == 84
+    assert len(parser.data_blocks["optimiser_general"]) == 84
 
 
-def test_read_sampling_2d():
-    parser = StarParser(sampling_2d)
+@pytest.mark.parametrize("polars", [True, False])
def test_read_sampling_2d(polars):
+    parser = StarParser(sampling_2d, polars=polars)
     assert len(parser.data_blocks) == 1
-    assert len(parser.data_blocks['sampling_general']) == 12
+    assert len(parser.data_blocks["sampling_general"]) == 12
 
 
-def test_read_sampling_3d():
-    parser = StarParser(sampling_3d)
+@pytest.mark.parametrize("polars", [True, False])
+def test_read_sampling_3d(polars):
+    parser = StarParser(sampling_3d, polars=polars)
     assert len(parser.data_blocks) == 2
-    assert len(parser.data_blocks['sampling_general']) == 15
-    assert parser.data_blocks['sampling_directions'].shape == (192, 2)
+    assert len(parser.data_blocks["sampling_general"]) == 15
+    assert parser.data_blocks["sampling_directions"].shape == (192, 2)
 
 
-def test_parsing_speed():
+@pytest.mark.parametrize("polars", [True, False])
+def test_parsing_speed(polars):
     generate_large_star_file()
     start = time.time()
-    s = StarParser(million_row_file)
+    s = StarParser(million_row_file, polars=polars)
     end = time.time()
     remove_large_star_file()
 
@@ -203,87 +231,103 @@
     assert end - start < 1
 
 
-def test_two_single_line_loop_blocks():
-    parser = StarParser(two_single_line_loop_blocks)
+@pytest.mark.parametrize("polars", [True, False])
+def test_two_single_line_loop_blocks(polars):
+    parser = StarParser(two_single_line_loop_blocks, polars=polars)
     assert len(parser.data_blocks) == 2
 
     np.testing.assert_array_equal(
-        parser.data_blocks['block_0'].columns, [f'val{i}' for i in (1, 2, 3)]
+        parser.data_blocks["block_0"].columns, [f"val{i}" for i in (1, 2, 3)]
     )
-    assert parser.data_blocks['block_0'].shape == (1, 3)
+    assert parser.data_blocks["block_0"].shape == (1, 3)
 
     np.testing.assert_array_equal(
-        parser.data_blocks['block_1'].columns, [f'col{i}' for i in (1, 2, 3)]
+        parser.data_blocks["block_1"].columns, [f"col{i}" for i in (1, 2, 3)]
     )
-    assert parser.data_blocks['block_1'].shape == (1, 3)
+    assert parser.data_blocks["block_1"].shape == (1, 3)
 
 
-def test_two_basic_blocks():
-    parser = StarParser(two_basic_blocks)
+@pytest.mark.parametrize("polars", [True, False])
+def test_two_basic_blocks(polars):
+    parser = StarParser(two_basic_blocks, polars=polars)
     assert len(parser.data_blocks) == 2
 
-    assert 'block_0' in parser.data_blocks
-    b0 = parser.data_blocks['block_0']
+    assert "block_0" in parser.data_blocks
+    b0 = parser.data_blocks["block_0"]
     assert b0 == {
-        'val1': 1.0,
-        'val2': 2.0,
-        'val3': 3.0,
+        "val1": 1.0,
+        "val2": 2.0,
+        "val3": 3.0,
     }
 
-    assert 'block_1' in parser.data_blocks
-    b1 = parser.data_blocks['block_1']
+    assert "block_1" in parser.data_blocks
+    b1 = parser.data_blocks["block_1"]
     assert b1 == {
-        'col1': 'A',
-        'col2': 'B',
-        'col3': 'C',
+        "col1": "A",
+        "col2": "B",
+        "col3": "C",
     }
 
 
-def test_empty_loop_block():
+@pytest.mark.parametrize("polars", [True, False])
+def test_empty_loop_block(polars):
     """Parsing an empty loop block should return an empty dataframe."""
-    parser = StarParser(empty_loop)
+    parser = StarParser(empty_loop, polars=polars)
     assert len(parser.data_blocks) == 1
 
 
-@pytest.mark.parametrize("quote_character, filename", [("'", basic_single_quote),
-                                                       ('"', basic_double_quote),
-                                                       ])
+@pytest.mark.parametrize(
+    "quote_character, filename",
+    [
+        ("'", basic_single_quote),
+        ('"', basic_double_quote),
+    ],
+)
 def test_quote_basic(quote_character, filename):
     parser = StarParser(filename)
     assert len(parser.data_blocks) == 1
-    assert parser.data_blocks['']['no_quote_string'] == "noquote"
-    assert parser.data_blocks['']['quote_string'] == "quote string"
-    assert parser.data_blocks['']['whitespace_string'] == " "
-    assert parser.data_blocks['']['empty_string'] == ""
+    assert parser.data_blocks[""]["no_quote_string"] == "noquote"
+    assert parser.data_blocks[""]["quote_string"] == "quote string"
+    assert parser.data_blocks[""]["whitespace_string"] == " "
+    assert parser.data_blocks[""]["empty_string"] == ""
 
 
-@pytest.mark.parametrize("quote_character, filename", [("'", loop_single_quote),
-                                                       ('"', loop_double_quote),
-                                                       ])
+@pytest.mark.parametrize(
+    "quote_character, filename",
+    [
+        ("'", loop_single_quote),
+        ('"', loop_double_quote),
+    ],
+)
 def test_quote_loop(quote_character, filename):
     import math
+
     parser = StarParser(filename)
     assert len(parser.data_blocks) == 1
-    assert parser.data_blocks[''].loc[0, 'no_quote_string'] == "noquote"
-    assert parser.data_blocks[''].loc[0, 'quote_string'] == "quote string"
-    assert parser.data_blocks[''].loc[0, 'whitespace_string'] == " "
-    assert parser.data_blocks[''].loc[0, 'empty_string'] == ""
-
-    assert parser.data_blocks[''].dtypes['number_and_string'] == object
-    assert parser.data_blocks[''].dtypes['number_and_empty'] == 'float64'
-    assert parser.data_blocks[''].dtypes['number'] == 'float64'
-    assert parser.data_blocks[''].dtypes['empty_string_and_normal_string'] == object
-
-    assert math.isnan(parser.data_blocks[''].loc[1, 'number_and_empty'])
-    assert parser.data_blocks[''].loc[0, 'empty_string_and_normal_string'] == ''
-
-
-def test_parse_as_string():
-    parser = StarParser(postprocess, parse_as_string=['rlnFinalResolution', 'rlnResolution'])
+    assert parser.data_blocks[""].loc[0, "no_quote_string"] == "noquote"
+    assert parser.data_blocks[""].loc[0, "quote_string"] == "quote string"
+    assert parser.data_blocks[""].loc[0, "whitespace_string"] == " "
+    assert parser.data_blocks[""].loc[0, "empty_string"] == ""
+
+    assert parser.data_blocks[""].dtypes["number_and_string"] == object
+    assert parser.data_blocks[""].dtypes["number_and_empty"] == "float64"
+    assert parser.data_blocks[""].dtypes["number"] == "float64"
+    assert parser.data_blocks[""].dtypes["empty_string_and_normal_string"] == object
+
+    assert math.isnan(parser.data_blocks[""].loc[1, "number_and_empty"])
+    assert parser.data_blocks[""].loc[0, "empty_string_and_normal_string"] == ""
+
+
+@pytest.mark.parametrize("polars", [True, False])
+def test_parse_as_string(polars):
+    parser = StarParser(
+        postprocess,
+        parse_as_string=["rlnFinalResolution", "rlnResolution"],
+        polars=polars,
+    )
 
     # check 'rlnFinalResolution' is parsed as string in general (basic) block
-    block = parser.data_blocks['general']
-    assert type(block['rlnFinalResolution']) == str
+    block = parser.data_blocks["general"]
+    assert type(block["rlnFinalResolution"]) == str
 
     # check 'rlnResolution' is parsed as string in fsc (loop) block
-    df = parser.data_blocks['fsc']
-    assert df['rlnResolution'].dtype == 'object'
-
+    df = parser.data_blocks["fsc"]
+    assert df["rlnResolution"].dtype == (pl.String if polars else "object")
diff --git a/tests/test_writing.py b/tests/test_writing.py
index 1396520..14c236e 100644
--- a/tests/test_writing.py
+++ b/tests/test_writing.py
@@ -1,6 +1,6 @@
+import time
 from os.path import join as join_path
 from tempfile import TemporaryDirectory
-import time
 
 import pandas as pd
 import pytest
@@ -8,47 +8,62 @@
 from starfile.parser import StarParser
 from starfile.writer import StarWriter
 
-from .constants import loop_simple, postprocess, test_data_directory, test_df
+from .constants import (
+    loop_simple,
+    postprocess,
+    test_data_directory,
+    test_df,
+    test_df_pl,
+)
 from .utils import generate_large_star_file, remove_large_star_file
 
-def test_write_simple_block():
-    s = StarParser(postprocess)
-    output_file = test_data_directory / 'basic_block.star'
+
+@pytest.mark.parametrize("polars", [True, False])
+def test_write_simple_block(polars):
+    s = StarParser(postprocess, polars=polars)
+    output_file = test_data_directory / "basic_block.star"
     StarWriter(s.data_blocks, output_file)
     assert output_file.exists()
 
 
-def test_write_loop():
-    s = StarParser(loop_simple)
-    output_file = test_data_directory / 'loop_block.star'
+@pytest.mark.parametrize("polars", [True, False])
+def test_write_loop(polars):
+    s = StarParser(loop_simple, polars=polars)
+    output_file = test_data_directory / "loop_block.star"
     StarWriter(s.data_blocks, output_file)
     assert output_file.exists()
 
 
-def test_write_multiblock():
-    s = StarParser(postprocess)
-    output_file = test_data_directory / 'multiblock.star'
+@pytest.mark.parametrize("polars", [True, False])
+def test_write_multiblock(polars):
+    s = StarParser(postprocess, polars=polars)
+    output_file = test_data_directory / "multiblock.star"
     StarWriter(s.data_blocks, output_file)
     assert output_file.exists()
 
 
-def test_from_single_dataframe():
-    output_file = test_data_directory / 'from_df.star'
+@pytest.mark.parametrize("polars", [True, False])
+def test_from_single_dataframe(polars):
+    output_file = test_data_directory / "from_df.star"
 
-    StarWriter(test_df, output_file)
+    if polars:
+        StarWriter(test_df_pl, output_file)
+    else:
+        StarWriter(test_df, output_file)
     assert output_file.exists()
 
-    s = StarParser(output_file)
+    s = StarParser(output_file, polars=polars)
 
 
-def test_create_from_dataframes():
+@pytest.mark.parametrize("polars", [True, False])
+def test_create_from_dataframes(polars):
     dfs = [test_df, test_df]
 
-    output_file = test_data_directory / 'from_list.star'
+    output_file = test_data_directory / "from_list.star"
     StarWriter(dfs, output_file)
     assert output_file.exists()
 
-    s = StarParser(output_file)
+    s = StarParser(output_file, polars=polars)
     assert len(s.data_blocks) == 2
@@ -63,28 +78,36 @@ def test_can_write_non_zero_indexed_one_row_dataframe():
 
     with open(filename) as output_file:
         output = output_file.read()
 
-    expected = (
-        "_A #1\n"
-        "_B #2\n"
-        "_C #3\n"
-        "1\t2\t3"
+    expected = "_A #1\n" "_B #2\n" "_C #3\n" "1\t2\t3"
+    assert expected in output
+
+
+@pytest.mark.parametrize(
+    "quote_character, quote_all_strings, num_quotes",
+    [('"', False, 6), ('"', True, 8), ("'", False, 6), ("'", True, 8)],
+)
+def test_string_quoting_loop_datablock(
+    quote_character, quote_all_strings, num_quotes, tmp_path
+):
+    df = pd.DataFrame(
+        [[1, "nospace", "String with space", " ", ""]],
+        columns=[
+            "a_number",
+            "string_without_space",
+            "string_space",
+            "just_space",
+            "empty_string",
+        ],
     )
-    assert (expected in output)
-
-
-@pytest.mark.parametrize("quote_character, quote_all_strings, num_quotes",
-                         [('"', False, 6),
-                          ('"', True, 8),
-                          ("'", False, 6),
-                          ("'", True, 8)
-                          ])
-def test_string_quoting_loop_datablock(quote_character, quote_all_strings, num_quotes, tmp_path):
-    df = pd.DataFrame([[1,"nospace", "String with space", " ", ""]],
-                      columns=["a_number","string_without_space", "string_space", "just_space", "empty_string"])
     filename = tmp_path / "test.star"
-    StarWriter(df, filename, quote_character=quote_character, quote_all_strings=quote_all_strings)
-
+    StarWriter(
+        df,
+        filename,
+        quote_character=quote_character,
+        quote_all_strings=quote_all_strings,
+    )
+
     # Test for the appropriate number of quotes
     with open(filename) as f:
         star_content = f.read()
@@ -93,6 +116,7 @@ def test_string_quoting_loop_datablock(quote_character, quote_all_strings, num_q
     s = StarParser(filename)
     assert df.equals(s.data_blocks[""])
 
+
 def test_writing_speed():
     start = time.time()
     generate_large_star_file()
@@ -102,24 +126,30 @@
     # Check that execution takes less than a second
     assert end - start < 1
 
-@pytest.mark.parametrize("quote_character, quote_all_strings, num_quotes",
-                         [('"', False, 6),
-                          ('"', True, 8),
-                          ("'", False, 6),
-                          ("'", True, 8)
-                          ])
-def test_string_quoting_simple_datablock(quote_character, quote_all_strings,num_quotes, tmp_path):
+
+@pytest.mark.parametrize(
+    "quote_character, quote_all_strings, num_quotes",
+    [('"', False, 6), ('"', True, 8), ("'", False, 6), ("'", True, 8)],
+)
+def test_string_quoting_simple_datablock(
+    quote_character, quote_all_strings, num_quotes, tmp_path
+):
     o = {
         "a_number": 1,
         "string_without_space": "nospace",
         "string_space": "String with space",
         "just_space": " ",
-        "empty_string": ""
+        "empty_string": "",
    }
     filename = tmp_path / "test.star"
-    StarWriter(o, filename, quote_character=quote_character, quote_all_strings=quote_all_strings)
-
+    StarWriter(
+        o,
+        filename,
+        quote_character=quote_character,
+        quote_all_strings=quote_all_strings,
+    )
+
     # Test for the appropriate number of quotes
     with open(filename) as f:
         star_content = f.read()
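A minimal usage sketch of the option introduced above, assuming this branch is installed; `particles.star` and `particles_copy.star` are hypothetical file names, not files from this repository:

```python
import starfile

# Default behaviour is unchanged: loop blocks are returned as pandas DataFrames.
df = starfile.read("particles.star")

# Opt in to polars with the keyword added in this diff.
df_pl = starfile.read("particles.star", polars=True)

# write() accepts polars DataFrames as well, so data can be round-tripped.
starfile.write(df_pl, "particles_copy.star")
```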