From d39b01005a4f5ab6e1991df9449b6030952717fb Mon Sep 17 00:00:00 2001 From: Glenn Jocher Date: Mon, 10 Jun 2024 01:21:02 +0200 Subject: [PATCH] Update publish.yml (#22) Co-authored-by: UltralyticsAssistant --- .github/workflows/publish.yml | 15 ++++++++------- README.md | 8 ++++++-- 2 files changed, 14 insertions(+), 9 deletions(-) diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 07ebb5a..4ab74bc 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -40,26 +40,27 @@ jobs: import thop import ultralytics from ultralytics.utils.checks import check_latest_pypi_version - v_local = tuple(map(int, thop.__version__.split('.'))) v_pypi = tuple(map(int, check_latest_pypi_version('ultralytics-thop').split('.'))) - print(f'Local version is {v_local}') print(f'PyPI version is {v_pypi}') - d = [a - b for a, b in zip(v_local, v_pypi)] # diff - increment_patch = (d[0] == d[1] == 0) and (0 < d[2] < 3) # publish if patch version increments by 1 or 2 increment_minor = (d[0] == 0) and (d[1] == 1) and v_local[2] == 0 # publish if minor version increments - increment = increment_patch or increment_minor - os.system(f'echo "increment={increment}" >> $GITHUB_OUTPUT') os.system(f'echo "version={thop.__version__}" >> $GITHUB_OUTPUT') - if increment: print('Local version is higher than PyPI version. Publishing new version to PyPI ✅.') id: check_pypi + - name: Publish new tag + if: steps.check_pypi.outputs.increment == 'True' + run: | + COMMIT_MESSAGE=$(git log -1 --pretty=%B) + git config --global user.name "UltralyticsAssistant" + git config --global user.email "web@ultralytics.com" + git tag -a "v${{ steps.check_pypi.outputs.version }}" -m "$COMMIT_MESSAGE" + git push origin "v${{ steps.check_pypi.outputs.version }}" - name: Publish to PyPI continue-on-error: true if: (github.event_name == 'push' || github.event.inputs.pypi == 'true') && steps.check_pypi.outputs.increment == 'True' diff --git a/README.md b/README.md index fce0e45..1c2fefc 100644 --- a/README.md +++ b/README.md @@ -40,7 +40,7 @@ import torch model = resnet50() input = torch.randn(1, 3, 224, 224) -macs, params = profile(model, inputs=(input, )) +macs, params = profile(model, inputs=(input,)) ``` ### Define Custom Rules for Third-Party Modules @@ -50,16 +50,19 @@ You can define custom rules for unsupported modules: ```python import torch.nn as nn + class YourModule(nn.Module): # your definition pass + def count_your_model(model, x, y): # your rule here pass + input = torch.randn(1, 3, 224, 224) -macs, params = profile(model, inputs=(input, ), custom_ops={YourModule: count_your_model}) +macs, params = profile(model, inputs=(input,), custom_ops={YourModule: count_your_model}) ``` ### Improve Output Readability @@ -68,6 +71,7 @@ Use `thop.clever_format` for a more readable output: ```python from thop import clever_format + macs, params = clever_format([macs, params], "%.3f") ```