diff --git a/pyproject.toml b/pyproject.toml index 76dbb35..5e77e85 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,7 @@ build-backend = "setuptools.build_meta" [project] name = "ultralytics-thop" -version = "0.0.2" # Placeholder version, needs to be dynamically set +version = "0.0.3" # Placeholder version, needs to be dynamically set description = "A tool to count the FLOPs of PyTorch model." readme = "README.md" requires-python = ">=3.8" @@ -57,6 +57,7 @@ classifiers = [ "Operating System :: Microsoft :: Windows", ] dependencies = [ + "packaging", "torch", ] diff --git a/thop/profile.py b/thop/profile.py index d8451a9..55f2f05 100644 --- a/thop/profile.py +++ b/thop/profile.py @@ -1,4 +1,4 @@ -from distutils.version import LooseVersion +from packaging.version import Version from thop.rnn_hooks import * from thop.vision.basic_hooks import * @@ -7,7 +7,7 @@ # logger.setLevel(logging.INFO) from .utils import prGreen, prRed, prYellow -if LooseVersion(torch.__version__) < LooseVersion("1.0.0"): +if Version(torch.__version__) < Version("1.0.0"): logging.warning( "You are using an old version PyTorch {version}, which THOP does NOT support.".format(version=torch.__version__) ) @@ -61,7 +61,7 @@ nn.PixelShuffle: zero_ops, } -if LooseVersion(torch.__version__) >= LooseVersion("1.1.0"): +if Version(torch.__version__) >= Version("1.1.0"): register_hooks.update({nn.SyncBatchNorm: count_normalization})