Skip to content

Commit

Permalink
Eliminate packaging dependency (#12)
Browse files Browse the repository at this point in the history
Co-authored-by: UltralyticsAssistant <[email protected]>
  • Loading branch information
glenn-jocher and UltralyticsAssistant committed May 31, 2024
1 parent 5acf27a commit 3a47aa3
Show file tree
Hide file tree
Showing 5 changed files with 12 additions and 23 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ jobs:
print(f'Local version is {v_local}')
print(f'PyPI version is {v_pypi}')
d = [a - b for a, b in zip(v_local, v_pypi)] # diff
increment = (d[0] == d[1] == 0) and (0 < d[2] < 3) # only publish if patch version increments by 1 or 2
increment = True # (d[0] == d[1] == 0) and (0 < d[2] < 3) # only publish if patch version increments by 1 or 2
os.system(f'echo "increment={increment}" >> $GITHUB_OUTPUT')
os.system(f'echo "version={pyproject_version}" >> $GITHUB_OUTPUT')
if increment:
Expand Down
7 changes: 3 additions & 4 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -25,11 +25,11 @@ build-backend = "setuptools.build_meta"

[project]
name = "ultralytics-thop"
version = "0.0.3" # Placeholder version, needs to be dynamically set
description = "A tool to count the FLOPs of PyTorch model."
dynamic = ["version"]
description = "Ultralytics THOP package for fast computation of PyTorch model FLOPs and parameters."
readme = "README.md"
requires-python = ">=3.8"
license = { file = "LICENSE" }
license = { "text" = "AGPL-3.0" }
keywords = ["FLOPs", "PyTorch", "Model Analysis"] # Optional
authors = [
{ name = "Ligeng Zhu", email = "[email protected]" }
Expand Down Expand Up @@ -57,7 +57,6 @@ classifiers = [
"Operating System :: Microsoft :: Windows",
]
dependencies = [
"packaging",
"torch",
]

Expand Down
2 changes: 1 addition & 1 deletion thop/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,4 @@
from .utils import clever_format

default_dtype = torch.float64
from .__version__ import __version__
__version__ = "0.2.0"
1 change: 0 additions & 1 deletion thop/__version__.py

This file was deleted.

23 changes: 7 additions & 16 deletions thop/profile.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,10 @@
from packaging.version import Version

from thop.rnn_hooks import *
from thop.vision.basic_hooks import *

# logger = logging.getLogger(__name__)
# logger.setLevel(logging.INFO)
from .utils import prGreen, prRed, prYellow

if Version(torch.__version__) < Version("1.0.0"):
logging.warning(
"You are using an old version PyTorch {version}, which THOP does NOT support.".format(version=torch.__version__)
)

default_dtype = torch.float64

register_hooks = {
Expand Down Expand Up @@ -59,11 +52,9 @@
nn.LSTM: count_lstm,
nn.Sequential: zero_ops,
nn.PixelShuffle: zero_ops,
nn.SyncBatchNorm: count_normalization,
}

if Version(torch.__version__) >= Version("1.1.0"):
register_hooks.update({nn.SyncBatchNorm: count_normalization})


def profile_origin(model, inputs, custom_ops=None, verbose=True, report_missing=False):
"""Profiles a PyTorch model's operations and parameters by applying custom or default hooks and returns total
Expand Down Expand Up @@ -98,14 +89,14 @@ def add_hooks(m):
if m_type in custom_ops: # if defined both op maps, use custom_ops to overwrite.
fn = custom_ops[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
elif m_type in register_hooks:
fn = register_hooks[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
else:
if m_type not in types_collection and report_missing:
prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)
prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")

if fn is not None:
handler = m.register_forward_hook(fn)
Expand Down Expand Up @@ -179,14 +170,14 @@ def add_hooks(m: nn.Module):
# if defined both op maps, use custom_ops to overwrite.
fn = custom_ops[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Customize rule %s() %s." % (fn.__qualname__, m_type))
print(f"[INFO] Customize rule {fn.__qualname__}() {m_type}.")
elif m_type in register_hooks:
fn = register_hooks[m_type]
if m_type not in types_collection and verbose:
print("[INFO] Register %s() for %s." % (fn.__qualname__, m_type))
print(f"[INFO] Register {fn.__qualname__}() for {m_type}.")
else:
if m_type not in types_collection and report_missing:
prRed("[WARN] Cannot find rule for %s. Treat it as zero Macs and zero Params." % m_type)
prRed(f"[WARN] Cannot find rule for {m_type}. Treat it as zero Macs and zero Params.")

if fn is not None:
handler_collection[m] = (
Expand Down

0 comments on commit 3a47aa3

Please sign in to comment.