Skip to content

Commit 849f1a8

Browse files
committed
Remove pkg_resources usage and replace with modern alternatives
- Replace pkg_resources.parse_version() with packaging.version.parse() - Replace pkg_resources.get_distribution() with importlib.metadata.distribution() - Update imports in cebra/integrations/sklearn/cebra.py, cebra/helper.py, tests/test_sklearn.py, tests/test_plot.py - Fixes deprecation warning: pkg_resources is deprecated as an API and slated for removal as early as 2025-11-30 Resolves #271
1 parent 52ead63 commit 849f1a8

File tree

4 files changed

+14
-14
lines changed

4 files changed

+14
-14
lines changed

cebra/helper.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@
3232

3333
import numpy as np
3434
import numpy.typing as npt
35-
import pkg_resources
35+
import packaging.version
3636
import requests
3737
import torch
3838

@@ -75,8 +75,8 @@ def download_file_from_zip_url(url, *, file):
7575

7676
def _is_mps_availabe(torch):
7777
available = False
78-
if pkg_resources.parse_version(
79-
torch.__version__) >= pkg_resources.parse_version("1.12"):
78+
if packaging.version.parse(
79+
torch.__version__) >= packaging.version.parse("1.12"):
8080
if torch.backends.mps.is_available():
8181
if torch.backends.mps.is_built():
8282
available = True
@@ -159,17 +159,17 @@ def requires_package_version(module, version: str):
159159
the required ``version``.
160160
"""
161161

162-
required_version = pkg_resources.parse_version(version)
162+
required_version = packaging.version.parse(version)
163163

164164
def _requires_package_version(function):
165165

166166
@wraps(function)
167167
def wrapper(*args, patched_version=None, **kwargs):
168168
if patched_version is not None:
169-
installed_version = pkg_resources.parse_version(
169+
installed_version = packaging.version.parse(
170170
patched_version) # Use the patched version if provided
171171
else:
172-
installed_version = pkg_resources.parse_version(
172+
installed_version = packaging.version.parse(
173173
module.__version__)
174174

175175
if installed_version < required_version:

cebra/integrations/sklearn/cebra.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import numpy as np
2929
import numpy.typing as npt
3030
import packaging.version
31-
import pkg_resources
31+
import importlib.metadata
3232
import sklearn
3333
import sklearn.utils.validation as sklearn_utils_validation
3434
import torch
@@ -1397,7 +1397,7 @@ def save(self,
13971397
'numpy_version':
13981398
np.__version__,
13991399
'sklearn_version':
1400-
pkg_resources.get_distribution("scikit-learn"
1400+
importlib.metadata.distribution("scikit-learn"
14011401
).version
14021402
}
14031403
}, filename)

tests/test_plot.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@
2525
import matplotlib
2626
import matplotlib.pyplot as plt
2727
import numpy as np
28-
import pkg_resources
28+
import packaging.version
2929
import pytest
3030
import torch
3131
from sklearn.exceptions import NotFittedError
@@ -190,8 +190,8 @@ def test_compare_models_with_different_versions(matplotlib_version):
190190
# minimum version of matplotlib
191191
minimum_version = "3.6"
192192

193-
if pkg_resources.parse_version(
194-
matplotlib_version) < pkg_resources.parse_version(minimum_version):
193+
if packaging.version.parse(
194+
matplotlib_version) < packaging.version.parse(minimum_version):
195195
with pytest.raises(ImportError):
196196
cebra_plot.compare_models(models=fitted_models,
197197
patched_version=matplotlib_version)

tests/test_sklearn.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import _util
2727
import _utils_deprecated
2828
import numpy as np
29-
import pkg_resources
29+
import packaging.version
3030
import pytest
3131
import sklearn.utils.estimator_checks
3232
import torch
@@ -1320,8 +1320,8 @@ def test_check_device():
13201320
with pytest.raises(ValueError):
13211321
cebra_sklearn_utils.check_device(device)
13221322

1323-
if pkg_resources.parse_version(
1324-
torch.__version__) >= pkg_resources.parse_version("1.12"):
1323+
if packaging.version.parse(
1324+
torch.__version__) >= packaging.version.parse("1.12"):
13251325

13261326
device = "mps"
13271327
torch.backends.mps.is_available = lambda: True

0 commit comments

Comments
 (0)