diff --git a/README.md b/README.md index acd61de..5f14140 100644 --- a/README.md +++ b/README.md @@ -38,8 +38,7 @@ Explore MemBrain-seg, use it for your needs, and let us know how it works for yo Preliminary [documentation](https://github.com/teamtomo/membrain-seg/blob/main/docs/index.md) is available, but far from perfect. Please let us know if you encounter any issues, and we are more than happy to help (and get feedback what does not work yet). ``` -[1] Isensee, F., Jaeger, P. F., Kohl, S. A., Petersen, J., & Maier-Hein, K. H. (2020). nnU-Net: a self-configuring method -for deep learning-based biomedical image segmentation. Nature Methods, 1-9. +[1] Isensee, F., Jaeger, P.F., Kohl, S.A.A., Petersen, J., Maier-Hein, K.H., 2021. nnU-Net: a self-configuring method for deep learning-based biomedical image segmentation. Nature Methods 18, 203-211. https://doi.org/10.1038/s41592-020-01008-z ``` # Installation @@ -52,8 +51,9 @@ Please find more detailed instructions [here](./docs/Usage/Segmentation.md). ## Preprocessing Currently, we provide the following two [preprocessing](https://github.com/teamtomo/membrain-seg/tree/main/src/tomo_preprocessing) options: -- pixel size matching: Rescale your tomogram to match the training pixel sizes +- Pixel size matching: Rescale your tomogram to match the training pixel sizes - Fourier amplitude matching: Scale Fourier components to match the "style" of different tomograms +- Deconvolution: denoises the tomogram by applying the deconvolution filter from Warp For more information, see the [Preprocessing](./docs/Usage/Preprocessing.md) subsection. diff --git a/src/membrain_seg/tomo_preprocessing/readme.md b/src/membrain_seg/tomo_preprocessing/README.md similarity index 71% rename from src/membrain_seg/tomo_preprocessing/readme.md rename to src/membrain_seg/tomo_preprocessing/README.md index f68456e..8809e6a 100644 --- a/src/membrain_seg/tomo_preprocessing/readme.md +++ b/src/membrain_seg/tomo_preprocessing/README.md @@ -29,6 +29,7 @@ already some rules of thumb: 2. The Fourier amplitude matching only works in some cases, depending on the CTFs of input and target tomograms. Our current recommendation is: If you're not satisfied with MemBrain's segmentation performance, why not give the amplitude matching a shot? +3. Deconvolution can help with the segmentation performance if your tomogram has not already been denoised somehow (e.g. using cryo-CARE, IsoNet or Warp). Deconvolving an already denoised tomogram is not recommended, it will most likely make things worse. More detailed guidelines are in progress! @@ -53,7 +54,8 @@ For help on a specific command, use: `tomo_preprocessing extract_spectrum --input_path --output_path ` - **match_spectrum**: Match amplitude of Fourier spectrum from input tomogram to target spectrum. Example: `tomo_preprocessing match_spectrum --input --target --output ` - +- **deconvolve**: Denoises the tomogram by deconvolving the contrast transfer function. Example: +`tomo_preprocessing deconvolve --input --output --df ` ### **Pixel Size Matching** Pixel size matching is recommended when your tomogram pixel sizes differs strongly from the training pixel size range (roughly 10-14Å). You can perform it using the command @@ -78,4 +80,15 @@ Fourier amplitude matching is performed in two steps: This extracts the radially averaged Fourier spectrum and stores it into a .tsv file. 2. Matching of the input tomogram to the extracted spectrum: `tomo_preprocessing match_spectrum --input --target --output ` -Now, the input tomograms Fourier components are re-scaled based on the equalization kernel computed from the input tomogram's radially averaged Fourier intensities, and the previously extracted .tsv file. \ No newline at end of file +Now, the input tomograms Fourier components are re-scaled based on the equalization kernel computed from the input tomogram's radially averaged Fourier intensities, and the previously extracted .tsv file. + +### **Deconvolution** + +Deconvolution is a denoising method that works by "removing" the effects of the contrast transfer function (CTF) from the tomogram. This is based on an ad-hoc model of the spectral signal-to-noise-ratio (SSNR) in the data, following the implementation in the Warp package [1]. Effectively what the filter does is to boost the very low frequencies, thus enhancing the tomogram contrast, while low-pass filtering beyond the first zero-crossing of the CTF. +For the filter to work, you need to provide the CTF parameters, namely a defocus value for the tomogram, as well as the acceleration voltage, spherical aberration and amplitude contrast, if those differ from the defaults. This is typically the defocus value of the zero tilt. It does not need to be super accurate, a roughly correct value already produces decent results. While the defaults usually work well, you can play with the filter parameters, namely the deconvolution strength and the falloff, to fine-tune the results. +Example detailed command: +`tomo_preprocessing deconvolve --input --output --df 45000 --ampcon 0.07 --cs 2.7 --kv 300 --strength 1.0 --falloff 1.0` + +``` +[1] Tegunov, D., Cramer, P., 2019. Real-time cryo-electron microscopy data preprocessing with Warp. Nature Methods 16, 1146–1152. https://doi.org/10.1038/s41592-019-0580-y +``` \ No newline at end of file diff --git a/src/membrain_seg/tomo_preprocessing/__init__.py b/src/membrain_seg/tomo_preprocessing/__init__.py index be78817..48c4bd4 100644 --- a/src/membrain_seg/tomo_preprocessing/__init__.py +++ b/src/membrain_seg/tomo_preprocessing/__init__.py @@ -7,4 +7,5 @@ # These imports are necessary to register CLI commands. Do not remove! from .amplitude_spectrum_matching._cli import extract, match_spectrum # noqa: F401 from .cli import cli # noqa: F401 +from .deconvolution._cli import deconvolve # noqa: F401 from .pixel_size_matching._cli import match_pixel_size, match_seg_to_tomo # noqa: F401 diff --git a/src/membrain_seg/tomo_preprocessing/deconvolution/__init__.py b/src/membrain_seg/tomo_preprocessing/deconvolution/__init__.py new file mode 100644 index 0000000..8d291f9 --- /dev/null +++ b/src/membrain_seg/tomo_preprocessing/deconvolution/__init__.py @@ -0,0 +1 @@ +"""Empty init.""" diff --git a/src/membrain_seg/tomo_preprocessing/deconvolution/_cli.py b/src/membrain_seg/tomo_preprocessing/deconvolution/_cli.py new file mode 100644 index 0000000..060ed80 --- /dev/null +++ b/src/membrain_seg/tomo_preprocessing/deconvolution/_cli.py @@ -0,0 +1,87 @@ +from typer import Option + +from ..cli import OPTION_PROMPT_KWARGS as PKWARGS +from ..cli import cli +from .deconvolve import deconvolve as run_deconvolve + + +@cli.command(name="deconvolve", no_args_is_help=True) +def deconvolve( + input: str = Option( # noqa: B008 + None, help="Tomogram to deconvolve (.mrc/.rec format)", **PKWARGS + ), + output: str = Option( # noqa: B008 + None, + help="Output location for deconvolved tomogram (.mrc/.rec format)", + **PKWARGS, + ), + pixel_size: float = Option( # noqa: B008 + None, + help="Input pixel size (optional). If not specified, it will be read from the \ +tomogram's header. ATTENTION: This can lead to severe errors if the header pixel size \ +is not correct.", + ), + df: float = Option( # noqa: B008 + 50000, + help="The defocus value to be used for deconvolution, in Angstroms. This is \ +typically the defocus of the zero tilt. Underfocus is positive.", + **PKWARGS, + ), + # df2: float = Option( + # None, + # help="Defocus 2 (or Defocus V in some notations) in Angstroms. Defocus axis \ + # orthogonal to the U axis. Only mandatory for astigmatic data.", + # ), + # ast: float = Option( + # 0.0, + # help="Angle for astigmatic data (in degrees). Astigmatism is currently not \ + # used in deconvolution (only the axis of largest defocus is considered), but maybe\ + # some better model in the future will use it?", + # ), + ampcon: float = Option( # noqa: B008 + 0.07, + help="Amplitude contrast fraction (between 0.0 and 1.0).", + ), + cs: float = Option( # noqa: B008 + 2.7, + help="Spherical aberration (in mm).", + ), + kv: float = Option( # noqa: B008 + 300.0, + help="Acceleration voltage of the TEM (in kV).", + ), + strength: float = Option( # noqa: B008 + 1.0, + help="Strength parameter for the denoising filter.", + ), + falloff: float = Option( # noqa: B008 + 1.0, + help="Falloff parameter for the denoising filter.", + ), + hp_fraction: float = Option( # noqa: B008 + 0.02, + help="Fraction of Nyquist frequency to be cut off on the lower end (since it \ +will be boosted the most)", + ), + skip_lowpass: bool = Option( # noqa: B008 + False, + help="The denoising filter by default will have a smooth low-pass effect that \ +enforces filtering out any information beyond the first zero of the CTF. Use this \ +option to skip this filter i.e. potentially include information beyond the first CTF \ +zero (not recommended).", + ), +): + """Deconvolve the input tomogram using the Warp deconvolution filter.""" + run_deconvolve( + input, + output, + df, + ampcon, + cs, + kv, + pixel_size, + strength, + falloff, + hp_fraction, + skip_lowpass, + ) diff --git a/src/membrain_seg/tomo_preprocessing/deconvolution/deconv_utils.py b/src/membrain_seg/tomo_preprocessing/deconvolution/deconv_utils.py new file mode 100644 index 0000000..9dbf452 --- /dev/null +++ b/src/membrain_seg/tomo_preprocessing/deconvolution/deconv_utils.py @@ -0,0 +1,518 @@ +# Derived from: Python utilities for Focus +# Author: Ricardo Righetto +# E-mail: ricardo.righetto@unibas.ch +# https://github.com/C-CINA/focustools/ + +import warnings + +import numpy as np + +warnings.filterwarnings("ignore", category=RuntimeWarning) + +pi = np.pi # global PI + +# TO-DO: +# Port heavy calculations to Torch or something more efficient than pure NumPy? +# Original implementation used numexpr (see commented code) but that would add one more +# dependency to MemBrain and does not give significant speedups, at least not with the +# defaults. +# For now, we stick to pure Numpy. + + +def RadialIndices( + imsize: tuple = (128, 128), + rounding: bool = True, + normalize: bool = False, + rfft: bool = False, + xyz: tuple = (0, 0, 0), + nozero: bool = True, + nozeroval: float = 1e-3, +): + """ + Generates a 1D/2D/3D array whose values are the distance counted from the origin. + + Parameters + ---------- + imsize : tuple + The shape of the input ndarray. + rounding : bool + Whether the radius values should be rounded to the nearest integer ensuring \ +"perfect" radial symmetry. + normalize : bool + Whether the radius values should be normalized to the range [0.0,11.0]. + rfft : bool + Whether to return an array consistent with np.fft.rfftn i.e. exploiting the \ +Hermitian symmetry of the Fourier transform of real data. + xyz : tuple + Shifts to be applied to the origin specified as (x_shift, y_shift, z_shift). \ +Useful when applying phase shifts. + nozero : bool + Whether the value of the origin (corresponding to the zero frequency or DC \ +component in the Fourier transform) should be set to a small value instead of zero. + nozeroval : float + The value to put at the origin if nozero is True. + + Returns + ------- + rmesh : ndarray + Array whose values are the distance from the origin. + amesh : ndarray + Array whose values are the angle from the x- axis (2D) or from the x,y plane \ +(3D) + + Raises + ------ + ValueError + If imsize with more than 3 dimensions is given. + + Notes + ----- + This function is compliant with NumPy fft.fftfreq() and fft.rfftfreq(). + """ + imsize = np.array(imsize) + + # if np.isscalar(imsize): + # imsize = [imsize, imsize] + + if len(imsize) > 3: + raise ValueError( + "Object should have 2 or 3 dimensions: len(imsize) = %d " % len(imsize) + ) + + xyz = np.flipud(xyz) + + m = np.mod(imsize, 2) # Check if dimensions are odd or even + + if len(imsize) == 1: + # The definition below is consistent with numpy np.fft.fftfreq and + # np.fft.rfftfreq: + + if not rfft: + xmesh = np.mgrid[ + -imsize[0] // 2 + m[0] - xyz[0] : (imsize[0] - 1) // 2 + 1 - xyz[0] + ] + + else: + xmesh = np.mgrid[0 - xyz[0] : imsize[0] // 2 + 1 - xyz[0]] + + rmesh = np.sqrt(xmesh * xmesh) + + amesh = np.zeros(xmesh.shape) + + n = 1 # Normalization factor + + if len(imsize) == 2: + # The definition below is consistent with numpy np.fft.fftfreq and + # np.fft.rfftfreq: + + if not rfft: + [xmesh, ymesh] = np.mgrid[ + -imsize[0] // 2 + m[0] - xyz[0] : (imsize[0] - 1) // 2 + 1 - xyz[0], + -imsize[1] // 2 + m[1] - xyz[1] : (imsize[1] - 1) // 2 + 1 - xyz[1], + ] + + else: + [xmesh, ymesh] = np.mgrid[ + -imsize[0] // 2 + m[0] - xyz[0] : (imsize[0] - 1) // 2 + 1 - xyz[0], + 0 - xyz[1] : imsize[1] // 2 + 1 - xyz[1], + ] + xmesh = np.fft.ifftshift(xmesh) + + rmesh = np.sqrt(xmesh * xmesh + ymesh * ymesh) + + amesh = np.arctan2(ymesh, xmesh) + + n = 2 # Normalization factor + + if len(imsize) == 3: + # The definition below is consistent with numpy np.fft.fftfreq and + # np.fft.rfftfreq: + + if not rfft: + [xmesh, ymesh, zmesh] = np.mgrid[ + -imsize[0] // 2 + m[0] - xyz[0] : (imsize[0] - 1) // 2 + 1 - xyz[0], + -imsize[1] // 2 + m[1] - xyz[1] : (imsize[1] - 1) // 2 + 1 - xyz[1], + -imsize[2] // 2 + m[2] - xyz[2] : (imsize[2] - 1) // 2 + 1 - xyz[2], + ] + + else: + [xmesh, ymesh, zmesh] = np.mgrid[ + -imsize[0] // 2 + m[0] - xyz[0] : (imsize[0] - 1) // 2 + 1 - xyz[0], + -imsize[1] // 2 + m[1] - xyz[1] : (imsize[1] - 1) // 2 + 1 - xyz[1], + 0 - xyz[2] : imsize[2] // 2 + 1 - xyz[2], + ] + xmesh = np.fft.ifftshift(xmesh) + ymesh = np.fft.ifftshift(ymesh) + + rmesh = np.sqrt(xmesh * xmesh + ymesh * ymesh + zmesh * zmesh) + + amesh = np.arccos(zmesh / rmesh) + + n = 3 # Normalization factor + + if rounding: + rmesh = np.round(rmesh) + + if normalize: + a = np.sum(imsize * imsize) + + rmesh = rmesh / (np.sqrt(a) / np.sqrt(n)) + + if nozero: + # Replaces the "zero radius" by a small value to prevent division by zero in + # other programs + idx = rmesh == 0 + rmesh[idx] = nozeroval + + return rmesh, np.nan_to_num(amesh) + + +def CTF( + imsize: tuple = (128, 128), + df1: float = 50000.0, + df2: float = None, + ast: float = 0.0, + ampcon: float = 0.07, + Cs: float = 2.7, + kV: float = 300.0, + apix: float = 1.0, + B: float = 0.0, + rfft: bool = True, +): + """ + Generates 1D, 2D or 3D contrast transfer function (CTF) of a TEM. + + Parameters + ---------- + imsize : tuple + The shape of the input ndarray. + df1 : float + Defocus 1 (or Defocus U in some notations) in Angstroms. Principal defocus \ +axis. Underfocus is positive. + df2 : float + Defocus 2 (or Defocus V in some notations) in Angstroms. Defocus axis \ +orthogonal to the U axis. Only mandatory for astigmatic data. + ast : float + Angle for astigmatic data (in degrees). + ampcon : float + Amplitude contrast fraction (between 0.0 and 1.0). + Cs : float + Spherical aberration (in mm). + kV : float + Acceleration voltage of the TEM (in kV). + apix : float + Input pixel size in Angstroms. + B : float + B-factor in Angstroms**2. + rfft : bool + Whether to return an array consistent with np.fft.rfftn i.e. exploiting the \ +Hermitian symmetry of the Fourier transform of real data. + + Returns + ------- + CTFim : ndarray + Array containing the CTF. + + Notes + ----- + Follows the CTF definition from Mindell & Grigorieff, JSB (2003) \ +(https://doi.org/10.1016/S1047-8477(03)00069-8), which is adopted in FREALIGN/\ +cisTEM, RELION and many other packages. + """ + if not np.isscalar(imsize) and len(imsize) == 1: + imsize = imsize[0] + + Cs *= 1e7 # Convert Cs to Angstroms + + if df2 is None or np.isscalar(imsize): + df2 = df1 + + # NOTATION FOR DEFOCUS1, DEFOCUS2, ASTIGMATISM BELOW IS INVERTED DUE TO NUMPY + # CONVENTION: + # df1, df2 = df2, df1 + + ast *= -pi / 180.0 + + WL = ElectronWavelength(kV) + + w1 = np.sqrt(1 - ampcon * ampcon) + w2 = ampcon + + import warnings + + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=RuntimeWarning) + + if np.isscalar(imsize): + if rfft: + rmesh = np.fft.rfftfreq(imsize) + else: + rmesh = np.fft.fftfreq(imsize) + amesh = 0.0 + + else: + rmesh, amesh = RadialIndices(imsize, normalize=True, rfft=rfft) + + rmesh2 = rmesh**2 / apix**2 + + # From Mindell & Grigorieff, JSB 2003: + df = 0.5 * (df1 + df2 + (df1 - df2) * np.cos(2.0 * (amesh - ast))) + + Xr = np.nan_to_num(pi * WL * rmesh2 * (df - 0.5 * WL * WL * rmesh2 * Cs)) + + CTFim = -w1 * np.sin(Xr) - w2 * np.cos(Xr) + + if B != 0.0: # Apply B-factor only if necessary: + CTFim = CTFim * np.exp(-B * (rmesh2) / 4) + + return CTFim + + +def CorrectCTF( + img=None, + df1: float = 50000.0, + df2: float = None, + ast: float = 0.0, + ampcon: float = 0.07, + invert_contrast: bool = False, + Cs: float = 2.7, + kV: float = 300.0, + apix: float = 1.0, + phase_flip: bool = False, + ctf_multiply: bool = False, + wiener_filter: bool = False, + C: float = 1.0, + return_ctf: bool = False, +): + """ + Applies different types of CTF correction to a 2D or 3D image. + + Parameters + ---------- + img : + The input image to be corrected. + df1 : float + Defocus 1 (or Defocus U in some notations) in Angstroms. Principal defocus \ +axis. Underfocus is positive. + df2 : float + Defocus 2 (or Defocus V in some notations) in Angstroms. Defocus axis \ +orthogonal to the U axis. Only mandatory for astigmatic data. + ast : float + Angle for astigmatic data (in degrees). + ampcon : float + Amplitude contrast fraction (between 0.0 and 1.0). + invert_contrast: bool + Whether to invert the contrast of the input image. + Cs : float + Spherical aberration (in mm). + kV : float + Acceleration voltage of the TEM (in kV). + apix : float + Input pixel size in Angstroms. + phase_flip : bool + Correct CTF by phase-flipping only. This corrects phases but leaves the \ +amplitudes unchanged. + ctf_multiply : bool + Correct CTF by multiplying the input image FT with the CTF. This corrects \ +phases and dampens amplitudes even further near the CTF zeros. + wiener_filter : bool + Correct CTF by applying a Wiener filter. This corrects phases and attempts to \ +restore amplitudes to their original values based on an ad-hoc spectral signal-to-noise\ +ratio (SSNR) value. That means, at frequencies where the SNR is low amplitudes will be\ +restored conservatively, while where the SNR is high these frequencies will be boosted. + C : float + Wiener filter constant (per frequency). A scalar can be given to use the same \ +constant for all frequencies, whereas an 1D array (radial average) can be given to \ +restore each frequency using a different SNR value. + return_ctf : bool + Whether to return an array containing the CTF itself alongside the corrected \ +image. + + Returns + ------- + CTFcor : list + A list containing the corrected image(s), the CTF (if return_ctf is True) and \ +a string indicating the type of correction(s) applied. + + Raises + ------ + ValueError + If any entry of the Wiener filter constant is lower than or equal to zero. + + """ + if df2 is None: + df2 = df1 + + # Direct CTF correction would invert the image contrast. By default we don't do + # that, hence the negative sign: + CTFim = -CTF(img.shape, df1, df2, ast, ampcon, Cs, kV, apix, 0.0, rfft=True) + + if invert_contrast: + CTFim = -CTFim + + pass + + FT = np.fft.rfftn(img) + + if phase_flip: # Phase-flipping + s = np.sign(CTFim) + CTFcor = np.fft.irfftn(FT * s) + + elif ctf_multiply: # CTF multiplication + CTFcor = np.fft.irfftn(FT * CTFim) + + elif wiener_filter: # Wiener filtering + if np.any(C <= 0.0): + raise ValueError( + "Error: Wiener filter contain value(s) less than or equal to zero!" + ) + + CTFcor = np.fft.irfftn(FT * CTFim / (CTFim * CTFim + C)) + + if return_ctf: + return CTFcor, CTFim + + else: + return CTFcor + + +def AdhocSSNR( + imsize: tuple = (128, 128), + apix: float = 1.0, + df: float = 50000.0, + ampcon: float = 0.07, + Cs: float = 2.7, + kV: float = 300.0, + S: float = 1.0, + F: float = 1.0, + hp_frac: float = 0.02, + lp: bool = True, +): + """ + An ad hoc SSNR model for cryo-EM data as proposed by Dimitry Tegunov [1,2]. + + Parameters + ---------- + imsize : tuple + The shape of the input array. + apix : float + Input pixel size in Angstroms. + df : float + Average defocus in Angstroms. + ampcon : float + Amplitude contrast fraction (between 0.0 and 1.0). + Cs : float + Spherical aberration (in mm). + kV : float + Acceleration voltage of the TEM (in kV). + S : float + Strength of the deconvolution to be applied. + F : float + Strength of the SSNR falloff. + hp_frac : float + fraction of Nyquist frequency to be cut off on the lower end (since it will \ +be boosted the most). + lp : bool + Whether to low-pass all information beyond the first zero of the CTF. + + Returns + ------- + ssnr : ndarray + Array containing the radial ad hoc SSNR. + + Notes + ----- + This SSNR model ignores astigmatism. + + References + ---------- + [1] Tegunov & Cramer, Nat. Meth. (2019). https://doi.org/10.1038/s41592-019-0580-y + [2] https://github.com/dtegunov/tom_deconv/blob/master/tom_deconv.m + """ + rmesh = RadialIndices(imsize, rounding=False, normalize=True, rfft=True)[0] / apix + + # The ad hoc SSNR exponential falloff + falloff = np.exp(-100 * rmesh * F) * 10 ** (3 * S) + + # The cosine-shaped high-pass filter. It starts at zero frequency and reaches 1.0 + # at hp_freq (fraction of the Nyquist frequency) + a = np.minimum(1.0, rmesh * apix / hp_frac) + highpass = 1.0 - np.cos(a * pi / 2) + + if lp: + # Ensure the filter will reach zero at the first zero of the CTF + first_zero_res = FirstZeroCTF(df=df, ampcon=ampcon, Cs=Cs, kV=kV) + a = np.minimum(1.0, rmesh / first_zero_res) + + lowpass = np.cos(a * pi / 2) + + # Composite filter + ssnr = highpass * falloff * lowpass + + else: + ssnr = highpass * falloff # Composite filter + + return np.abs(ssnr) + + +def ElectronWavelength(kV: float = 300.0): + """ + Calculates electron wavelength given acceleration voltage. + + Parameters + ---------- + kV : float + Acceleration voltage of the TEM (in kV). + + Returns + ------- + WL : float + A scalar value containing the electron wavelength in Angstroms. + """ + WL = 12.2639 / np.sqrt(kV * 1e3 + 0.97845 * kV * kV) + + return WL + + +def FirstZeroCTF( + df: float = 50000.0, ampcon: float = 0.07, Cs: float = 2.7, kV: float = 300.0 +): + """ + The frequency at which the CTF first crosses zero. + + Parameters + ---------- + df : float + Average defocus in Angstroms. + ampcon : float + Amplitude contrast fraction (between 0.0 and 1.0). + Cs : float + Spherical aberration (in mm). + kV : float + Acceleration voltage of the TEM (in kV). + + Returns + ------- + g : float + A scalar containing the resolution in Angstroms corresponding to the first \ +zero crossing of the CTF. + + Notes + ----- + Finds the resolution at the first zero of the CTF + Wolfram Alpha, solving for -w1 * sinXr - w2 * cosXr = 0 + https://www.wolframalpha.com/input/?i=solve+%CF%80*L*(g%5E2)*(d-1%2F(\ + 2*(L%5E2)*(g%5E2)*C))%3Dn+%CF%80+-+tan%5E(-1)(c%2Fa)+for+g + """ + Cs *= 1e7 # Convert Cs to Angstroms + + w1 = np.sqrt(1 - ampcon * ampcon) + w2 = ampcon + + WL = ElectronWavelength(kV) + + g = np.sqrt(-2 * Cs * WL * np.arctan2(w2, w1) + 2 * pi * Cs * WL + pi) / ( + np.sqrt(2 * pi * Cs * df) * WL + ) + + return g diff --git a/src/membrain_seg/tomo_preprocessing/deconvolution/deconvolve.py b/src/membrain_seg/tomo_preprocessing/deconvolution/deconvolve.py new file mode 100644 index 0000000..f7b215f --- /dev/null +++ b/src/membrain_seg/tomo_preprocessing/deconvolution/deconvolve.py @@ -0,0 +1,135 @@ +from membrain_seg.segmentation.dataloading.data_utils import ( + load_tomogram, + store_tomogram, +) +from membrain_seg.tomo_preprocessing.deconvolution.deconv_utils import ( + AdhocSSNR, + CorrectCTF, +) + + +def deconvolve( + mrcin: str, + mrcout: str, + df: float = 50000.0, + ampcon: float = 0.07, + Cs: float = 2.7, + kV: float = 300.0, + apix: float = None, + strength: float = 1.0, + falloff: float = 1.0, + hp_frac: float = 0.02, + skip_lowpass: bool = True, +) -> None: + """ + Deconvolve the input tomogram using the Warp deconvolution filter. + + Parameters + ---------- + mrcin : str + The file path to the input tomogram to be processed. + mrcout : str + The file path where the processed tomogram will be stored. + df: float + "The defocus value to be used for deconvolution, in Angstroms. This is \ +typically the defocus of the zero tilt. Underfocus is positive." + ampcon: float + Amplitude contrast fraction (between 0.0 and 1.0). + Cs: float + Spherical aberration (in mm). + kV: float + Acceleration voltage of the TEM (in kV). + apix: float + Input pixel size (optional). If not specified, it will be read from the \ +tomogram's header. ATTENTION: This can lead to severe errors if the header pixel \ +size is not correct. + strength: float + Strength parameter for the denoising filter. + falloff: float + Falloff parameter for the denoising filter. + hp_frac : float + fraction of Nyquist frequency to be cut off on the lower end (since it will \ +be boosted the most). + skip_lowpass: bool + The denoising filter by default will have a smooth low-pass effect that \ +enforces filtering out any information beyond the first zero of the CTF. Use this \ +option to skip this filter (i.e. potentially include information beyond the first CTF \ +zero). + + Returns + ------- + None + + Raises + ------ + FileNotFoundError + If the file specified in `mrcin` does not exist. + + Notes + ----- + This function reads the input tomogram and applies the deconvolution filter on it \ +following the Warp implementation (https://doi.org/10.1038/s41592-019-0580-y), then \ +stores the processed tomogram to the specified output path. The deconvolution process \ +is controlled by several parameters including the tomogram defocus, acceleration \ +voltage, spherical aberration, strength and falloff. The implementation here is based \ +on that of the focustools package: https://github.com/C-CINA/focustools/ + """ + tomo = load_tomogram(mrcin) + + if apix is None: + apix = tomo.voxel_size.x + + # if df2 is None: + # df2 = df1 + + print( + "\nDeconvolving input tomogram:\n", + mrcin, + "\noutput will be written as:\n", + mrcout, + "\nusing:", + f"\npixel_size: {apix:.3f}", + f"\ndf: {df:.1f}", + f"\nkV: {kV:.1f}", + f"\nCs: {Cs:.1f}", + f"\nstrength: {strength:.3f}", + f"\nfalloff: {falloff:.3f}", + f"\nhp_fraction: {hp_frac:.3f}", + f"\nskip_lowpass: {skip_lowpass}\n", + ) + print("Deconvolution can take a few minutes, please wait...") + + ssnr = AdhocSSNR( + imsize=tomo.data.shape, + apix=apix, + df=df, + ampcon=ampcon, + Cs=Cs, + kV=kV, + S=strength, + F=falloff, + hp_frac=hp_frac, + lp=not skip_lowpass, + ) + + wiener_constant = 1 / ssnr + + deconvtomo = CorrectCTF( + tomo.data, + df1=df, + ast=0.0, + ampcon=ampcon, + invert_contrast=False, + Cs=Cs, + kV=kV, + apix=apix, + phase_flip=False, + ctf_multiply=False, + wiener_filter=True, + C=wiener_constant, + return_ctf=False, + ) + + store_tomogram(mrcout, deconvtomo, voxel_size=apix) + + print("\nDone!")