Skip to content

Commit

Permalink
[ENH] Allow multithreading in eddy (#743)
Browse files Browse the repository at this point in the history
Allow multithreading in eddy
  • Loading branch information
mattcieslak authored May 2, 2024
1 parent 3fa1af1 commit 230aed7
Showing 1 changed file with 5 additions and 15 deletions.
20 changes: 5 additions & 15 deletions qsiprep/interfaces/eddy.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,6 +141,10 @@ def _run_interface(self, runtime):
return runtime


class ExtendedEddyInputSpec(fsl.epi.EddyInputSpec):
num_threads = traits.Int(1, usedefault=True, argstr="--nthr=%d")


class ExtendedEddyOutputSpec(fsl.epi.EddyOutputSpec):
shell_PE_translation_parameters = File(
exists=True, desc=("the translation along the PE-direction between the different shells")
Expand Down Expand Up @@ -169,29 +173,15 @@ class ExtendedEddyOutputSpec(fsl.epi.EddyOutputSpec):


class ExtendedEddy(fsl.Eddy):
input_spec = ExtendedEddyInputSpec
output_spec = ExtendedEddyOutputSpec

_num_threads = 1

def __init__(self, **inputs):
super(ExtendedEddy, self).__init__(**inputs)
self.inputs.on_trait_change(self._num_threads_update, "num_threads")
if not isdefined(self.inputs.num_threads):
self.inputs.num_threads = self._num_threads
else:
self._num_threads_update()
self.inputs.on_trait_change(self._use_cuda, "use_cuda")
if isdefined(self.inputs.use_cuda):
self._use_cuda()

def _num_threads_update(self):
self._num_threads = self.inputs.num_threads
if not isdefined(self.inputs.num_threads):
if "OMP_NUM_THREADS" in self.inputs.environ:
del self.inputs.environ["OMP_NUM_THREADS"]
else:
self.inputs.environ["OMP_NUM_THREADS"] = str(self.inputs.num_threads)

def _use_cuda(self):
self._cmd = "eddy_cuda10.2" if self.inputs.use_cuda else "eddy_cpu"

Expand Down

0 comments on commit 230aed7

Please sign in to comment.