Skip to content

Commit

Permalink
Eliminate packaging dependency
Browse files Browse the repository at this point in the history
  • Loading branch information
glenn-jocher committed May 31, 2024
1 parent 5acf27a commit 790ad97
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 26 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
print(f'Local version is {v_local}')
print(f'PyPI version is {v_pypi}')
d = [a - b for a, b in zip(v_local, v_pypi)] # diff
increment = (d[0] == d[1] == 0) and (0 < d[2] < 3) # only publish if patch version increments by 1 or 2
increment = True # (d[0] == d[1] == 0) and (0 < d[2] < 3) # only publish if patch version increments by 1 or 2
os.system(f'echo "increment={increment}" >> $GITHUB_OUTPUT')
os.system(f'echo "version={pyproject_version}" >> $GITHUB_OUTPUT')
if increment:
Expand Down
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ build-backend = "setuptools.build_meta"

[project]
name = "ultralytics-thop"
version = "0.0.3" # Placeholder version, needs to be dynamically set
description = "A tool to count the FLOPs of PyTorch model."
dynamic = ["version"]
description = "Ultralytics THOP package for fast computation of PyTorch model FLOPs and parameters."
readme = "README.md"
requires-python = ">=3.8"
license = { file = "LICENSE" }
license = { "text" = "AGPL-3.0" }
keywords = ["FLOPs", "PyTorch", "Model Analysis"] # Optional
authors = [
{ name = "Ligeng Zhu", email = "[email protected]" }
Expand Down Expand Up @@ -57,7 +57,6 @@ classifiers = [
"Operating System :: Microsoft :: Windows",
]
dependencies = [
"packaging",
"torch",
]

Expand Down
2 changes: 1 addition & 1 deletion thop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .utils import clever_format

default_dtype = torch.float64
from .__version__ import __version__
__version__ = "0.2.0"
1 change: 0 additions & 1 deletion thop/__version__.py

This file was deleted.

31 changes: 12 additions & 19 deletions thop/profile.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
from packaging.version import Version

from thop.rnn_hooks import *
from thop.vision.basic_hooks import *
Expand All @@ -7,11 +6,6 @@
# logger.setLevel(logging.INFO)
from .utils import prGreen, prRed, prYellow

if Version(torch.__version__) < Version("1.0.0"):
logging.warning(
"You are using an old version PyTorch {version}, which THOP does NOT support.".format(version=torch.__version__)
)

default_dtype = torch.float64

register_hooks = {
Expand Down Expand Up @@ -59,16 +53,11 @@
nn.LSTM: count_lstm,
nn.Sequential: zero_ops,
nn.PixelShuffle: zero_ops,
nn.SyncBatchNorm: count_normalization,
}

if Version(torch.__version__) >= Version("1.1.0"):
register_hooks.update({nn.SyncBatchNorm: count_normalization})


def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
"""Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total
operations and parameters.
"""
"""Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total operations and parameters."""
handler_collection = []
types_collection = set()
if custom_ops is None:
Expand Down Expand Up @@ -98,14 +87,16 @@ def add_hooks(m):
if m_type in custom_ops: # if defined both op maps, use custom_ops to overwrite.
fn = custom_ops[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
elif m_type in register_hooks:
fn = register_hooks[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
else:
if m_type not in types_collection and report_missing:
prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)
prRed(
f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params."
)

if fn is not None:
handler = m.register_forward_hook(fn)
Expand Down Expand Up @@ -179,14 +170,16 @@ def add_hooks(m: nn.Module):
# if defined both op maps, use custom_ops to overwrite.
fn = custom_ops[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
elif m_type in register_hooks:
fn = register_hooks[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
else:
if m_type not in types_collection and report_missing:
prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)
prRed(
f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params."
)

if fn is not None:
handler_collection[m] = (
Expand Down

0 comments on commit 790ad97

Please sign in to comment.