diff --git a/qsiprep/interfaces/eddy.py b/qsiprep/interfaces/eddy.py index 36130a36..be2567ad 100644 --- a/qsiprep/interfaces/eddy.py +++ b/qsiprep/interfaces/eddy.py @@ -141,6 +141,10 @@ def _run_interface(self, runtime): return runtime +class ExtendedEddyInputSpec(fsl.epi.EddyInputSpec): + num_threads = traits.Int(1, usedefault=True, argstr="--nthr=%d") + + class ExtendedEddyOutputSpec(fsl.epi.EddyOutputSpec): shell_PE_translation_parameters = File( exists=True, desc=("the translation along the PE-direction between the different shells") @@ -169,29 +173,15 @@ class ExtendedEddyOutputSpec(fsl.epi.EddyOutputSpec): class ExtendedEddy(fsl.Eddy): + input_spec = ExtendedEddyInputSpec output_spec = ExtendedEddyOutputSpec - _num_threads = 1 - def __init__(self, **inputs): super(ExtendedEddy, self).__init__(**inputs) - self.inputs.on_trait_change(self._num_threads_update, "num_threads") - if not isdefined(self.inputs.num_threads): - self.inputs.num_threads = self._num_threads - else: - self._num_threads_update() self.inputs.on_trait_change(self._use_cuda, "use_cuda") if isdefined(self.inputs.use_cuda): self._use_cuda() - def _num_threads_update(self): - self._num_threads = self.inputs.num_threads - if not isdefined(self.inputs.num_threads): - if "OMP_NUM_THREADS" in self.inputs.environ: - del self.inputs.environ["OMP_NUM_THREADS"] - else: - self.inputs.environ["OMP_NUM_THREADS"] = str(self.inputs.num_threads) - def _use_cuda(self): self._cmd = "eddy_cuda10.2" if self.inputs.use_cuda else "eddy_cpu"