Skip to content

Commit

Permalink
Update publish.yml (#22)
Browse files Browse the repository at this point in the history
Co-authored-by: UltralyticsAssistant <[email protected]>
  • Loading branch information
glenn-jocher and UltralyticsAssistant committed Jun 9, 2024
1 parent 54128bd commit d39b010
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 9 deletions.
15 changes: 8 additions & 7 deletions .github/workflows/publish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -40,26 +40,27 @@ jobs:
import thop
import ultralytics
from ultralytics.utils.checks import check_latest_pypi_version
v_local = tuple(map(int, thop.__version__.split('.')))
v_pypi = tuple(map(int, check_latest_pypi_version('ultralytics-thop').split('.')))
print(f'Local version is {v_local}')
print(f'PyPI version is {v_pypi}')
d = [a - b for a, b in zip(v_local, v_pypi)] # diff
increment_patch = (d[0] == d[1] == 0) and (0 < d[2] < 3) # publish if patch version increments by 1 or 2
increment_minor = (d[0] == 0) and (d[1] == 1) and v_local[2] == 0 # publish if minor version increments
increment = increment_patch or increment_minor
os.system(f'echo "increment={increment}" >> $GITHUB_OUTPUT')
os.system(f'echo "version={thop.__version__}" >> $GITHUB_OUTPUT')
if increment:
print('Local version is higher than PyPI version. Publishing new version to PyPI ✅.')
id: check_pypi
- name: Publish new tag
if: steps.check_pypi.outputs.increment == 'True'
run: |
COMMIT_MESSAGE=$(git log -1 --pretty=%B)
git config --global user.name "UltralyticsAssistant"
git config --global user.email "[email protected]"
git tag -a "v${{ steps.check_pypi.outputs.version }}" -m "$COMMIT_MESSAGE"
git push origin "v${{ steps.check_pypi.outputs.version }}"
- name: Publish to PyPI
continue-on-error: true
if: (github.event_name == 'push' || github.event.inputs.pypi == 'true') && steps.check_pypi.outputs.increment == 'True'
Expand Down
8 changes: 6 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import torch

model = resnet50()
input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ))
macs, params = profile(model, inputs=(input,))
```

### Define Custom Rules for Third-Party Modules
Expand All @@ -50,16 +50,19 @@ You can define custom rules for unsupported modules:
```python
import torch.nn as nn


class YourModule(nn.Module):
# your definition
pass


def count_your_model(model, x, y):
# your rule here
pass


input = torch.randn(1, 3, 224, 224)
macs, params = profile(model, inputs=(input, ), custom_ops={YourModule: count_your_model})
macs, params = profile(model, inputs=(input,), custom_ops={YourModule: count_your_model})
```

### Improve Output Readability
Expand All @@ -68,6 +71,7 @@ Use `thop.clever_format` for a more readable output:

```python
from thop import clever_format

macs, params = clever_format([macs, params], "%.3f")
```

Expand Down

0 comments on commit d39b010

Please sign in to comment.