-
Notifications
You must be signed in to change notification settings - Fork 0
/
main.py
93 lines (72 loc) · 2.14 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
#!/usr/bin/env python
# This file is a script entry point only; refuse to be imported as a module.
if __name__ != '__main__':
    raise Exception("Do not import me!")
import chainer
import logging
import numpy as np
from contextlib import contextmanager
from tqdm import tqdm
from chainer.cuda import to_cpu
from chainer.dataset.convert import concat_examples
from part_estimation.core import Data
from part_estimation.core import Model
from part_estimation.core import Propagator
from part_estimation.core import ExtractionPipeline
from part_estimation.core import VisualizationPipeline
from part_estimation.utils import arguments
from cluster_parts.core import BoundingBoxPartExtractor
from cluster_parts.core import Corrector
from cluster_parts.utils import ClusterInitType
@contextmanager
def outputs(args):
    """Context manager for the extraction output files.

    If ``args.extract`` is a non-empty list of paths, opens each path for
    writing and yields the list of open file objects, guaranteeing they are
    closed on exit even if the body raises.  Otherwise logs a warning and
    yields ``(None, None)``.

    NOTE(review): the two branches yield different shapes (a list vs. a
    2-tuple of Nones); the tuple is kept as-is for backward compatibility,
    but callers only reach the else-branch when extraction is disabled.
    """
    if args.extract:
        files = [open(path, "w") for path in args.extract]
        try:
            yield files
        finally:
            # Bug fix: the original did `[out.close for out in outputs]`,
            # which builds a list of bound methods without calling them,
            # so the files were never closed.
            for f in files:
                f.close()
    else:
        logging.warning("Extraction is disabled!")
        yield None, None
def main(args):
    """Build the model, propagator and part extractor from CLI args, then
    run either the extraction pipeline (writing to the files named in
    ``args.extract``) or the visualization pipeline.

    Mutates ``args`` by aliasing ``feature_model`` to ``model_type``
    before handing it to the Data/Model factories.
    """
    args.feature_model = args.model_type

    clf = Model.load_svm(args.trained_svm, args.visualize_coefs)
    scaler, it, *model_args = Data.new(args, clf)
    model, prepare = Model.new(args, *model_args)

    logging.info("Using following feature composition: {}".format(args.feature_composition))

    propagator = Propagator(model, clf,
        scaler=scaler,
        topk=args.topk,
        swap_channels=args.swap_channels,
        n_jobs=1,
    )

    extractor = BoundingBoxPartExtractor(
        corrector=Corrector(gamma=args.gamma, sigma=args.sigma),
        K=args.K,
        fit_object=args.fit_object,
        thresh_type=args.thresh_type,
        cluster_init=ClusterInitType.MAXIMAS,
        feature_composition=args.feature_composition,
    )

    kwargs = dict(
        model=model,
        extractor=extractor,
        propagator=propagator,
        iterator=it,
        prepare=prepare,
        device=args.gpu[0],
    )

    if args.extract:
        # Bug fix: pipeline.run() was originally called AFTER the `with`
        # block exited, i.e. after the output files are closed. The run
        # must happen while the files are still open.
        with outputs(args) as files:
            pipeline = ExtractionPipeline(
                files=files,
                **kwargs
            )
            pipeline.run()
    else:
        pipeline = VisualizationPipeline(
            **kwargs
        )
        pipeline.run()
# Script setup: make chainer's cv_resize use PIL, and turn every numpy
# floating-point warning into a hard error so silent NaNs cannot slip by.
chainer.global_config.cv_resize_backend = "PIL"
np.seterr(all="raise")

# Run inference with chainer's train-mode flag disabled for the whole call.
with chainer.using_config("train", False):
    main(arguments.parse_args())