diff --git a/pyproject.toml b/pyproject.toml index 6947df8..d0d776a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,7 +42,9 @@ dependencies = [ "matplotlib", "mrcfile", "plotille", - "rich" + "rich", + "starfile", + "pydantic>=2.0" ] # https://peps.python.org/pep-0621/#dependencies-optional-dependencies diff --git a/src/ttfsc/_cli.py b/src/ttfsc/_cli.py index 2845987..ac5da1e 100644 --- a/src/ttfsc/_cli.py +++ b/src/ttfsc/_cli.py @@ -97,6 +97,7 @@ def ttfsc_cli( estimated_resolution_frequency_pixel, correction_from_resolution_angstrom, fsc_values_corrected, + fsc_values_randomized, ) = calculate_noise_injected_fsc( map1_tensor, map2_tensor, @@ -113,6 +114,45 @@ def ttfsc_cli( f"criterion with correction after {correction_from_resolution_angstrom:.2f} Å: " f"{estimated_resolution_angstrom:.2f} Å" ) + if save_starfile: + import pandas as pd + import starfile + from numpy import nan + + from ._starfile_schema import RelionDataGeneral, RelionFSCData + + data_general = RelionDataGeneral( + rlnFinalResolution=estimated_resolution_angstrom, + rlnUnfilteredMapHalf1=map1.name, + rlnUnfilteredMapHalf2=map2.name, + ) + if mask != Masking.none: + data_general.rlnParticleBoxFractionSolventMask = mask_tensor.sum().item() / mask_tensor.numel() + if correct_for_masking: + data_general.rlnRandomiseFrom = correction_from_resolution_angstrom + + fsc_data = [] + for i, (f, r) in enumerate(zip(fsc_values_unmasked, resolution_angstroms)): + fsc_data.append( + RelionFSCData( + rlnSpectralIndex=i, + rlnResolution=r, + rlnAngstromResolution=r, + rlnFourierShellCorrelationCorrected=fsc_values_corrected[i] if correct_for_masking else nan, + rlnFourierShellCorrelationUnmaskedMaps=f, + rlnFourierShellCorrelationMaskedMaps=fsc_values_masked[i] if fsc_values_masked is not None else nan, + rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps=fsc_values_randomized[i] + if correct_for_masking + else nan, + rlnFourierShellCorrelationParticleMaskFraction=1 + - ((1 - f) * data_general.rlnParticleBoxFractionSolventMask) + if mask != Masking.none + else nan, + ) + ) + starfile.write( + {"general": data_general.__dict__, "fsc": pd.DataFrame([f.__dict__ for f in fsc_data])}, save_starfile + ) if plot: from ._plotting import plot_matplotlib, plot_plottile diff --git a/src/ttfsc/_masking.py b/src/ttfsc/_masking.py index afe9a2d..ce41950 100644 --- a/src/ttfsc/_masking.py +++ b/src/ttfsc/_masking.py @@ -20,7 +20,7 @@ def calculate_noise_injected_fsc( estimated_resolution_frequency_pixel: float, correct_from_resolution: Optional[float] = None, correct_from_fraction_of_estimated_resolution: float = 0.5, -): +) -> tuple[float, float, float, torch.tensor, torch.tensor]: from torch_grid_utils import fftfreq_grid map1_tensor_randomized = torch.fft.rfftn(map1_tensor) @@ -75,6 +75,7 @@ def calculate_noise_injected_fsc( estimated_resolution_frequency_pixel, correct_from_resolution, fsc_values_corrected, + fsc_values_masked_randomized, ) @@ -106,7 +107,7 @@ def calculate_masked_fsc( mask_tensor[inside_sphere] = 1 # if requested, a soft edge is added to the mask - mask_tensor = add_soft_edge(mask_tensor, mask_soft_edge_width_pixels) + mask_tensor = torch.tensor(add_soft_edge(mask_tensor, mask_soft_edge_width_pixels)) map1_tensor_masked = map1_tensor * mask_tensor map2_tensor_masked = map2_tensor * mask_tensor diff --git a/src/ttfsc/_starfile_schema.py b/src/ttfsc/_starfile_schema.py new file mode 100644 index 0000000..e3b9215 --- /dev/null +++ b/src/ttfsc/_starfile_schema.py @@ -0,0 +1,29 @@ +from typing import List + +from numpy import nan +from pydantic import BaseModel + + +class RelionDataGeneral(BaseModel): + rlnFinalResolution: float + rlnUnfilteredMapHalf1: str + rlnUnfilteredMapHalf2: str + rlnParticleBoxFractionSolventMask: float = nan + rlnRandomiseFrom: float = nan + rlnMaskName: str = "" + + +class RelionFSCData(BaseModel): + rlnSpectralIndex: int + rlnResolution: float + rlnAngstromResolution: float + rlnFourierShellCorrelationCorrected: float + rlnFourierShellCorrelationParticleMaskFraction: float + rlnFourierShellCorrelationUnmaskedMaps: float + rlnFourierShellCorrelationMaskedMaps: float + rlnCorrectedFourierShellCorrelationPhaseRandomizedMaskedMaps: float + + +class RelionStarfile(BaseModel): + data_general: RelionDataGeneral + fsc_data: List[RelionFSCData]