Skip to content

Commit

Permalink
Merge pull request #62 from jahooker/main
Browse files Browse the repository at this point in the history
`StarWriter.lines`
  • Loading branch information
jojoelfe authored Jul 23, 2024
2 parents 812c6bb + 4e311c5 commit b6ebefa
Show file tree
Hide file tree
Showing 7 changed files with 127 additions and 83 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
/tests/data/from_list.star
/tests/data/test_overwrite_flag.star
/tests/data/test_write_with_float_format.star
/tests/data/test_overwrite_backup.star
/build/
/dist/
/m2relion/
Expand Down
2 changes: 1 addition & 1 deletion src/starfile/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
from .functions import read, write
from .functions import read, write, to_string
35 changes: 35 additions & 0 deletions src/starfile/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,39 @@ def write(
separator=sep,
quote_character=quote_character,
quote_all_strings=quote_all_strings,
).write()


def to_string(
    data: Union[DataBlock, Dict[str, DataBlock], List[DataBlock]],
    float_format: str = '%.6f',
    sep: str = '\t',
    na_rep: str = '<NA>',
    quote_character: str = '"',
    quote_all_strings: bool = False,
    **kwargs
):
    """Render data as STAR-formatted text instead of writing it to a file.

    Parameters
    ----------
    data: DataBlock | Dict[str, DataBlock] | List[DataBlock]
        Data to represent. DataBlocks are dictionaries or dataframes.
        If a dictionary of datablocks is passed, the keys will be the data
        block names.
    float_format: str
        Float format string which will be passed to pandas.
    sep: str
        Separator between values, will be passed to pandas.
    na_rep: str
        Representation of null values, will be passed to pandas.
    quote_character: str
        Character placed on both sides of a string value that gets quoted.
    quote_all_strings: bool
        Forwarded to the writer together with ``quote_character``.
    **kwargs
        Accepted but not used by this function.

    Returns
    -------
    str
        The lines produced by the writer, each one followed by a newline.
    """
    # No filename is given: the writer is only used as a line generator here.
    writer_options = dict(
        filename=None,
        float_format=float_format,
        na_rep=na_rep,
        separator=sep,
        quote_character=quote_character,
        quote_all_strings=quote_all_strings,
    )
    star_lines = StarWriter(data, **writer_options).lines()
    return ''.join(f'{star_line}\n' for star_line in star_lines)
139 changes: 67 additions & 72 deletions src/starfile/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from datetime import datetime
from pathlib import Path
from typing import TYPE_CHECKING, Union, Dict, List
from typing import TYPE_CHECKING, Union, Dict, List, Generator, Optional
from importlib.metadata import version
import csv

Expand All @@ -21,7 +21,7 @@ class StarWriter:
def __init__(
self,
data_blocks: Union[DataBlock, Dict[str, DataBlock], List[DataBlock]],
filename: PathLike,
filename: Optional[PathLike] = None,
float_format: str = '%.6f',
separator: str = '\t',
na_rep: str = '<NA>',
Expand All @@ -31,16 +31,17 @@ def __init__(
# coerce data
self.data_blocks = self.coerce_data_blocks(data_blocks)

# write
self.filename = Path(filename)
if filename is not None:
self.filename = Path(filename)
else:
self.filename = None
self.float_format = float_format
self.sep = separator
self.na_rep = na_rep
self.quote_character = quote_character
self.quote_all_strings = quote_all_strings
self.buffer = TextBuffer()
self.backup_if_file_exists()
self.write()

def coerce_data_blocks(
self,
Expand All @@ -61,35 +62,43 @@ def coerce_data_blocks(
got {type(data_blocks)}'
)

def lines(self) -> Generator[str, None, None]:
    """Yield the STAR output line by line.

    The package-info comment comes first, then two empty lines, then
    every line produced by ``data_block_generator``.
    """
    preamble = (package_info(), '', '')
    yield from preamble
    yield from self.data_block_generator()

def write(self):
write_package_info(self.filename)
write_blank_lines(self.filename, n=2)
self.write_data_blocks()
if self.filename is None:
raise ValueError('Cannot write nameless file!')
with open(self.filename, mode='w+') as f:
f.writelines(line + '\n' for line in self.lines())

def write_data_blocks(self):
def data_block_generator(self) -> Generator[str, None, None]:
for block_name, block in self.data_blocks.items():
if isinstance(block, dict):
write_simple_block(
file=self.filename,
for line in simple_block(
block_name=block_name,
data=block,
quote_character=self.quote_character,
quote_all_strings=self.quote_all_strings
)
):
yield line
elif isinstance(block, pd.DataFrame):
write_loop_block(
file=self.filename,
for line in loop_block(
block_name=block_name,
df=block,
float_format=self.float_format,
separator=self.sep,
na_rep=self.na_rep,
quote_character=self.quote_character,
quote_all_strings=self.quote_all_strings
)
):
yield line

def backup_if_file_exists(self):
if self.filename.exists():
if self.filename and self.filename.exists():
new_name = self.filename.name + '~'
backup_path = self.filename.resolve().parent / new_name
if backup_path.exists():
Expand Down Expand Up @@ -118,81 +127,67 @@ def coerce_list(data_blocks: List[DataBlock]) -> Dict[str, DataBlock]:
return {f'{idx}': df for idx, df in enumerate(data_blocks)}


def write_blank_lines(file: Path, n: int):
with open(file, mode='a') as f:
f.write('\n' * n)


def write_package_info(file: Path):
def package_info():
date = datetime.now().strftime('%d/%m/%Y')
time = datetime.now().strftime('%H:%M:%S')
line = f'# Created by the starfile Python package (version {__version__}) at {time} on {date}'
with open(file, mode='w+') as f:
f.write(f'{line}\n')
return f'# Created by the starfile Python package (version {__version__}) at {time} on {date}'


def quote(x, *,
          quote_character: str = '"',
          quote_all_strings: bool = False) -> str:
    """Wrap a string value in quote characters when it needs quoting.

    A string is wrapped when ``quote_all_strings`` is true, when it
    contains a space, or when it is empty. Every other value, including
    any non-string, is returned unchanged.
    """
    # Non-strings are passed through untouched.
    if not isinstance(x, str):
        return x

    needs_quotes = quote_all_strings or ' ' in x or not x
    if not needs_quotes:
        return x

    return f'{quote_character}{x}{quote_character}'

def write_simple_block(
file: Path,

def simple_block(
block_name: str,
data: Dict[str, Union[str, int, float]],
quote_character: str = '"',
quote_all_strings: bool = False
):
quoted_data = {
k: f"{quote_character}{v}{quote_character}"
if isinstance(v, str) and (quote_all_strings or " " in v or v == "")
else v
for k, v
in data.items()
}
formatted_lines = '\n'.join(
[
f'_{k}\t\t\t{v}'
for k, v
in quoted_data.items()
]
)
with open(file, mode='a') as f:
f.write(f'data_{block_name}\n\n')
f.write(formatted_lines)
f.write('\n\n\n')


def write_loop_block(
file: Path,
) -> Generator[str, None, None]:

yield f'data_{block_name}'
yield ''
for k, v in data.items():
yield f'_{k}\t\t\t{quote(v, quote_character=quote_character, quote_all_strings=quote_all_strings)}'
yield ''
yield ''


def loop_block(
block_name: str,
df: pd.DataFrame,
float_format: str = '%.6f',
separator: str = '\t',
na_rep: str = '<NA>',
quote_character: str = '"',
quote_all_strings: bool = False
):
# write header
header_lines = [
f'_{column_name} #{idx}'
for idx, column_name
in enumerate(df.columns, 1)
]
with open(file, mode='a') as f:
f.write(f'data_{block_name}\n\n')
f.write('loop_\n')
f.write('\n'.join(header_lines))
f.write('\n')

df = df.map(lambda x: f'{quote_character}{x}{quote_character}'
if isinstance(x, str) and (quote_all_strings or " " in x or x == "")
else x)

# write data
df.to_csv(
path_or_buf=file,
) -> Generator[str, None, None]:

# Header
yield f'data_{block_name}'
yield ''
yield 'loop_'
for idx, column_name in enumerate(df.columns, 1):
yield f'_{column_name} #{idx}'

# Data
for line in df.map(lambda x:
quote(x,
quote_character=quote_character,
quote_all_strings=quote_all_strings)
).to_csv(
mode='a',
sep=separator,
header=False,
index=False,
float_format=float_format,
na_rep=na_rep,
quoting=csv.QUOTE_NONE
)
write_blank_lines(file, n=2)
).split('\n'):
yield line

yield ''
yield ''
8 changes: 8 additions & 0 deletions tests/test_functional_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,11 @@ def test_read_non_existent_file():

with pytest.raises(FileNotFoundError):
starfile.read(f)


def test_generate_string():
    """`to_string` must produce the same STAR text that `write` puts on disk."""
    star_string = starfile.to_string(test_df)
    output_file = test_data_directory / "test_write.star"
    starfile.write(test_df, output_file, overwrite=True)
    with open(output_file, "r") as f:
        file_contents = f.read()

    # Both outputs start with the "# Created by ..." comment, which embeds
    # the wall-clock time down to the second. The two outputs are generated
    # at slightly different moments, so comparing that line verbatim would
    # fail whenever the clock ticks over between the two calls.
    string_header, _, string_body = star_string.partition('\n')
    file_header, _, file_body = file_contents.partition('\n')

    header_prefix = '# Created by the starfile Python package'
    assert string_header.startswith(header_prefix)
    assert file_header.startswith(header_prefix)

    # Everything after the timestamped header must match exactly.
    assert file_body == string_body
4 changes: 2 additions & 2 deletions tests/test_read_write_round_trip.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def test_round_trip_postprocess(tmp_path):
assert _actual == _expected


def test_write_read_write_read():
filename = 'tmp.star'
def test_write_read_write_read(tmp_path):
filename = tmp_path / 'tmp.star'
df_a = pd.DataFrame({'a': [0, 1], 'b': [2, 3]})
starfile.write(df_a, filename)

Expand Down
21 changes: 13 additions & 8 deletions tests/test_writing.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,28 +14,28 @@
def test_write_simple_block():
s = StarParser(postprocess)
output_file = test_data_directory / 'basic_block.star'
StarWriter(s.data_blocks, output_file)
StarWriter(s.data_blocks, output_file).write()
assert output_file.exists()


def test_write_loop():
s = StarParser(loop_simple)
output_file = test_data_directory / 'loop_block.star'
StarWriter(s.data_blocks, output_file)
StarWriter(s.data_blocks, output_file).write()
assert output_file.exists()


def test_write_multiblock():
s = StarParser(postprocess)
output_file = test_data_directory / 'multiblock.star'
StarWriter(s.data_blocks, output_file)
StarWriter(s.data_blocks, output_file).write()
assert output_file.exists()


def test_from_single_dataframe():
output_file = test_data_directory / 'from_df.star'

StarWriter(test_df, output_file)
StarWriter(test_df, output_file).write()
assert output_file.exists()

s = StarParser(output_file)
Expand All @@ -45,7 +45,7 @@ def test_create_from_dataframes():
dfs = [test_df, test_df]

output_file = test_data_directory / 'from_list.star'
StarWriter(dfs, output_file)
StarWriter(dfs, output_file).write()
assert output_file.exists()

s = StarParser(output_file)
Expand All @@ -59,7 +59,7 @@ def test_can_write_non_zero_indexed_one_row_dataframe():

with TemporaryDirectory() as directory:
filename = join_path(directory, "test.star")
StarWriter(df, filename)
StarWriter(df, filename).write()
with open(filename) as output_file:
output = output_file.read()

Expand All @@ -83,7 +83,7 @@ def test_string_quoting_loop_datablock(quote_character, quote_all_strings, num_q
columns=["a_number","string_without_space", "string_space", "just_space", "empty_string"])

filename = tmp_path / "test.star"
StarWriter(df, filename, quote_character=quote_character, quote_all_strings=quote_all_strings)
StarWriter(df, filename, quote_character=quote_character, quote_all_strings=quote_all_strings).write()

# Test for the appropriate number of quotes
with open(filename) as f:
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_string_quoting_simple_datablock(quote_character, quote_all_strings,num_
}

filename = tmp_path / "test.star"
StarWriter(o, filename, quote_character=quote_character, quote_all_strings=quote_all_strings)
StarWriter(o, filename, quote_character=quote_character, quote_all_strings=quote_all_strings).write()

# Test for the appropriate number of quotes
with open(filename) as f:
Expand All @@ -127,3 +127,8 @@ def test_string_quoting_simple_datablock(quote_character, quote_all_strings,num_

s = StarParser(filename)
assert o == s.data_blocks[""]


def test_no_filename_error():
    """Writing with a writer that was given no filename raises ValueError."""
    with pytest.raises(ValueError):
        nameless_writer = StarWriter(test_df)
        nameless_writer.write()

0 comments on commit b6ebefa

Please sign in to comment.