Skip to content

Commit f3a6a7c

Browse files
View and Set N Jobs (#1029)
* [Style] isort preference applied * [Style] isort preference applied * reset n_jobs to default after testing * simplify imports
1 parent ac413b6 commit f3a6a7c

File tree

2 files changed

+39
-0
lines changed

2 files changed

+39
-0
lines changed

tests/units/utilities/test_distribution.py

+14
Original file line numberDiff line numberDiff line change
@@ -16,9 +16,23 @@
1616
LocalDaskDistributor,
1717
MultiprocessingDistributor,
1818
)
19+
from tsfresh.utilities.profiling import get_n_jobs, set_n_jobs
1920

2021

2122
class MultiprocessingDistributorTestCase(TestCase):
23+
def test_n_jobs(self):
24+
curr = get_n_jobs()
25+
self.assertEqual(curr, get_n_jobs())
26+
27+
set_n_jobs(2)
28+
self.assertEqual(2, get_n_jobs())
29+
30+
set_n_jobs(4)
31+
self.assertEqual(4, get_n_jobs())
32+
33+
set_n_jobs(curr)
34+
self.assertEqual(curr, get_n_jobs())
35+
2236
def test_partition(self):
2337

2438
distributor = MultiprocessingDistributor(n_workers=1)

tsfresh/utilities/profiling.py

+25
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,8 @@
1010
import logging
1111
import pstats
1212

13+
from tsfresh import defaults
14+
1315
_logger = logging.getLogger(__name__)
1416

1517

@@ -66,3 +68,26 @@ def end_profiling(profiler, filename, sorting=None):
6668
"[calculate_ts_features] Finished profiling of time series feature extraction"
6769
)
6870
f.write(s.getvalue())
71+
72+
73+
def get_n_jobs():
74+
"""
75+
Get the number of jobs to use for parallel processing.
76+
77+
:return: The number of jobs to use for parallel processing.
78+
:rtype: int
79+
"""
80+
return defaults.N_PROCESSES
81+
82+
83+
def set_n_jobs(n_jobs):
84+
"""
85+
Set the number of jobs to use for parallel processing.
86+
87+
:param n_jobs: The number of jobs to use for parallel processing.
88+
:type n_jobs: int
89+
90+
:return: None
91+
:rtype: None
92+
"""
93+
defaults.N_PROCESSES = n_jobs

0 commit comments

Comments
 (0)