Skip to content

Commit 3611e9e

Browse files
committed
distribute as package
1 parent 658acc3 commit 3611e9e

File tree

12 files changed

+88
-32
lines changed

12 files changed

+88
-32
lines changed

Diff for: LICENSE

+21
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
MIT License
2+
3+
Copyright (c) 2019 marsggbo
4+
5+
Permission is hereby granted, free of charge, to any person obtaining a copy
6+
of this software and associated documentation files (the "Software"), to deal
7+
in the Software without restriction, including without limitation the rights
8+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9+
copies of the Software, and to permit persons to whom the Software is
10+
furnished to do so, subject to the following conditions:
11+
12+
The above copyright notice and this permission notice shall be included in all
13+
copies or substantial portions of the Software.
14+
15+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21+
SOFTWARE.

Diff for: README.md

+13
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,20 @@
11
# torchline
22
Easy to use Pytorch
33

4+
# Dependences
45

6+
- Python>=3.6
7+
- Pytorch==1.3.1
8+
- torchvision==0.4.2
9+
- yacs==0.1.6
10+
- pytorch-lightning==0.5.3.2
11+
12+
13+
# Install
14+
15+
```bash
16+
pip install torchline
17+
```
518

619
# Structures
720

Diff for: release.md

+9-2
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,22 @@
33

44
## v0.1
55

6-
### 更新信息
6+
### 2019.12.12 更新信息
77
- 基本框架搭建完成
88
- 可正常运行cifar10_demo
99

10+
11+
### 2019.12.13 更新信息
12+
- 实现setup安装
13+
- 完善包之间的引用关系
14+
1015
### todo list
1116

1217
- [ ] 弄清楚logging机制
1318
- [ ] save和load模型,优化器参数
1419
- [ ] skin数据集读取测试
1520
- [ ] 构建skin project
1621
- [ ] 能否预测单张图片?
17-
- [ ] 构建一个简单地API接口
22+
- [ ] 构建一个简单地API接口
23+
- [ ] 进一步完善包导入
24+
- [ ] 完善使用文档

Diff for: requirements.txt

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
torch==1.3.1
2+
torchvision==0.4.2
3+
yacs==0.1.6
4+
pytorch-lightning==0.5.3.2

Diff for: setup.py

+23
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
from setuptools import setup, find_packages
2+
3+
with open('requirements.txt', 'r') as f:
4+
requirements = f.read().splitlines()
5+
6+
setup(
7+
name="torchline", # Replace with your own username
8+
version="0.1",
9+
author="marsggbo",
10+
author_email="[email protected]",
11+
description="A framework for easy to use Pytorch",
12+
long_description='...',
13+
long_description_content_type="text/markdown",
14+
url="https://github.com/marsggbo/torchline",
15+
packages=find_packages(exclude=("tests", "projects")),
16+
install_requires=requirements,
17+
classifiers=[
18+
"Programming Language :: Python :: 3",
19+
"License :: OSI Approved :: MIT License",
20+
"Operating System :: OS Independent",
21+
],
22+
python_requires='>=3.6',
23+
)

Diff for: torchline/__init__.py

+6
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .config import *
2+
from .data import *
3+
from .engine import *
4+
from .losses import *
5+
from .models import *
6+
from .utils import *

Diff for: torchline/config/skin10.yaml

-16
This file was deleted.

Diff for: torchline/data/__init__.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
from .build import build_data
1+
from .build import build_data, DATASET_REGISTRY
22
from .skin100 import Skin100Dataset
33
from .skin10 import Skin10Dataset
44
from .common_datasets import MNIST, CIFAR10
5-
from .transforms import DefaultTransforms, build_transforms
5+
from .transforms import DefaultTransforms, build_transforms
6+
from .autoaugment import ImageNetPolicy, CIFAR10Policy, SVHNPolicy, SubPolicy

Diff for: torchline/engine/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
from .lightning_module_template import LightningTemplateModel

Diff for: torchline/losses/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .build import build_loss_fn
1+
from .build import build_loss_fn, LOSS_FN_REGISTRY
22
from .loss import CrossEntropy

Diff for: torchline/models/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,2 @@
1-
from .build import build_model
1+
from .build import build_model, META_ARCH_REGISTRY
22
from .resnet_models import *

Diff for: torchline/models/resnet_models.py

+6-10
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,8 @@ def __init__(self, cfg):
3636
self.layer3 = nn.Sequential(*model[6])
3737
self.layer4 = nn.Sequential(*model[7])
3838

39-
self.clf = nn.Sequential(
40-
nn.Conv2d(2048, 512, kernel_size=1),
41-
nn.ReLU(inplace=True),
42-
nn.Conv2d(512, self.num_classes, kernel_size=1),
43-
nn.ReLU(inplace=True),
44-
nn.AdaptiveAvgPool2d((1,1))
45-
)
39+
self.avg_pool = nn.AdaptiveAvgPool2d((1,1))
40+
self.clf = nn.Linear(in_features=2048, out_features=self.num_classes)
4641

4742
def extract_features(self, x):
4843
assert len(self.cfg.MODEL.FEATURES) >= 1
@@ -67,13 +62,14 @@ def forward(self, x):
6762
x (tensor): N*c*h*w
6863
6964
return:
70-
img_cls_preds (tensor): N*classes
65+
predictions (tensor): N*classes
7166
'''
7267
bs= x.shape[0]
7368
features = self.extract_features(x)
74-
img_cls_preds = self.clf(features).view(bs, -1)
69+
predictions = self.avg_pool(features).view(bs, -1)
70+
predictions = self.clf(predictions).view(bs, -1)
7571

76-
return img_cls_preds
72+
return predictions
7773

7874

7975

0 commit comments

Comments
 (0)