Skip to content

Commit

Permalink
[MNT] Add threads limit for numba and TF (#1604)
Browse files Browse the repository at this point in the history
* Add threads limit for numba and TF

* pre-commit fix

* fix pre-commit

* Fix precommit

* Add env variable setting and move imports

* Update numba import

---------

Co-authored-by: Antoine GUILLAUME <[email protected]>
  • Loading branch information
baraline and baraline authored Jun 16, 2024
1 parent d4c2cf5 commit 6228c09
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 3 deletions.
2 changes: 1 addition & 1 deletion aeon/utils/numba/general.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def generate_new_default_njit_func(
base_func_py = base_func
else:
raise TypeError(
"Expected base_func to be of callable or CPUDispatcher type (numba "
"Expected base_func to be of type callable or CPUDispatcher type (numba "
f"function), but got {type(base_func)}"
)
signature = inspect.signature(base_func_py)
Expand Down
32 changes: 30 additions & 2 deletions conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

__maintainer__ = []

from aeon.testing import test_config


def pytest_addoption(parser):
"""Pytest command line parser options adder."""
Expand All @@ -28,5 +26,35 @@ def pytest_addoption(parser):

def pytest_configure(config):
"""Pytest configuration preamble."""
import os

# Must be called before any numpy imports
os.environ["MKL_NUM_THREADS"] = "1"
os.environ["NUMEXPR_NUM_THREADS"] = "1"
os.environ["OMP_NUM_THREADS"] = "1"
os.environ["OPENBLAS_NUM_THREADS"] = "1"
os.environ["VECLIB_MAXIMUM_THREADS"] = "1"

import numba

from aeon.testing import test_config
from aeon.utils.validation._dependencies import _check_soft_dependencies

numba.set_num_threads(1)

if _check_soft_dependencies("tensorflow", severity="none"):
from tensorflow.config.threading import (
set_inter_op_parallelism_threads,
set_intra_op_parallelism_threads,
)

set_inter_op_parallelism_threads(1)
set_intra_op_parallelism_threads(1)

if _check_soft_dependencies("torch", severity="none"):
import torch

torch.set_num_threads(1)

if config.getoption("--prtesting") in [True, "True", "true"]:
test_config.PR_TESTING = True

0 comments on commit 6228c09

Please sign in to comment.