Skip to content

Commit f6bef74

Browse files
authored
Update APP by ID photo DIY (#5707)
1 parent 72375b2 commit f6bef74

File tree

7 files changed

+359
-168
lines changed

7 files changed

+359
-168
lines changed

modelcenter/PP-Matting/APP1/app.py

+63-167
Original file line numberDiff line numberDiff line change
@@ -1,182 +1,78 @@
1-
import codecs
2-
import os
3-
import sys
4-
import time
5-
import zipfile
6-
71
import gradio as gr
82
import numpy as np
9-
import cv2
10-
import requests
11-
import yaml
12-
from paddle.inference import Config as PredictConfig
13-
from paddle.inference import create_predictor
14-
15-
lasttime = time.time()
16-
FLUSH_INTERVAL = 0.1
17-
18-
19-
def progress(str, end=False):
20-
global lasttime
21-
if end:
22-
str += "\n"
23-
lasttime = 0
24-
if time.time() - lasttime >= FLUSH_INTERVAL:
25-
sys.stdout.write("\r%s" % str)
26-
lasttime = time.time()
27-
sys.stdout.flush()
28-
29-
30-
def _download_file(url, savepath, print_progress=True):
31-
if print_progress:
32-
print("Connecting to {}".format(url))
33-
r = requests.get(url, stream=True, timeout=15)
34-
total_length = r.headers.get('content-length')
35-
36-
if total_length is None:
37-
with open(savepath, 'wb') as f:
38-
shutil.copyfileobj(r.raw, f)
39-
else:
40-
with open(savepath, 'wb') as f:
41-
dl = 0
42-
total_length = int(total_length)
43-
starttime = time.time()
44-
if print_progress:
45-
print("Downloading %s" % os.path.basename(savepath))
46-
for data in r.iter_content(chunk_size=4096):
47-
dl += len(data)
48-
f.write(data)
49-
if print_progress:
50-
done = int(50 * dl / total_length)
51-
progress("[%-50s] %.2f%%" %
52-
('=' * done, float(100 * dl) / total_length))
53-
if print_progress:
54-
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
55-
56-
57-
def uncompress(path):
58-
files = zipfile.ZipFile(path, 'r')
59-
filelist = files.namelist()
60-
rootpath = filelist[0]
61-
for file in filelist:
62-
files.extract(file, './')
63-
64-
65-
class DeployConfig:
66-
def __init__(self, path):
67-
with codecs.open(path, 'r', 'utf-8') as file:
68-
self.dic = yaml.load(file, Loader=yaml.FullLoader)
69-
self._dir = os.path.dirname(path)
70-
71-
@property
72-
def model(self):
73-
return os.path.join(self._dir, self.dic['Deploy']['model'])
74-
75-
@property
76-
def params(self):
77-
return os.path.join(self._dir, self.dic['Deploy']['params'])
78-
79-
80-
class Predictor:
81-
def __init__(self, cfg):
82-
"""
83-
Prepare for prediction.
84-
The usage and docs of paddle inference, please refer to
85-
https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html
86-
"""
87-
self.cfg = DeployConfig(cfg)
88-
89-
self._init_base_config()
90-
91-
self._init_cpu_config()
92-
93-
self.predictor = create_predictor(self.pred_cfg)
94-
95-
def _init_base_config(self):
96-
self.pred_cfg = PredictConfig(self.cfg.model, self.cfg.params)
97-
self.pred_cfg.enable_memory_optim()
98-
self.pred_cfg.switch_ir_optim(True)
99-
100-
def _init_cpu_config(self):
101-
"""
102-
Init the config for x86 cpu.
103-
"""
104-
self.pred_cfg.disable_gpu()
105-
self.pred_cfg.set_cpu_math_library_num_threads(10)
106-
107-
def _preprocess(self, img):
108-
# resize short edge to 512.
109-
h, w = img.shape[:2]
110-
short_edge = min(h, w)
111-
scale = 512 / short_edge
112-
h_resize = int(round(h * scale)) // 32 * 32
113-
w_resize = int(round(w * scale)) // 32 * 32
114-
img = cv2.resize(img, (w_resize, h_resize))
115-
img = (img / 255 - 0.5) / 0.5
116-
img = np.transpose(img, [2, 0, 1])[np.newaxis, :]
117-
return img
118-
119-
def run(self, img):
120-
input_names = self.predictor.get_input_names()
121-
input_handle = {}
122-
123-
for i in range(len(input_names)):
124-
input_handle[input_names[i]] = self.predictor.get_input_handle(
125-
input_names[i])
126-
output_names = self.predictor.get_output_names()
127-
output_handle = self.predictor.get_output_handle(output_names[0])
128-
129-
img_inputs = img.astype('float32')
130-
ori_h, ori_w = img_inputs.shape[:2]
131-
img_inputs = self._preprocess(img=img_inputs)
132-
input_handle['img'].copy_from_cpu(img_inputs)
133-
134-
self.predictor.run()
135-
136-
results = output_handle.copy_to_cpu()
137-
alpha = results.squeeze()
138-
alpha = cv2.resize(alpha, (ori_w, ori_h))
139-
alpha = (alpha * 255).astype('uint8')
140-
141-
return alpha
142-
143-
144-
def model_inference(image):
145-
# Download inference model
146-
url = 'https://paddleseg.bj.bcebos.com/matting/models/deploy/ppmatting-hrnet_w18-human_512.zip'
147-
savepath = './ppmatting-hrnet_w18-human_512.zip'
148-
if not os.path.exists('./ppmatting-hrnet_w18-human_512'):
149-
_download_file(url=url, savepath=savepath)
150-
uncompress(savepath)
1513

152-
# Inference
153-
predictor = Predictor(cfg='./ppmatting-hrnet_w18-human_512/deploy.yaml')
154-
alpha = predictor.run(image)
4+
import utils
5+
from predict import build_predictor
1556

156-
return alpha
7+
IMAGE_DEMO = "./images/idphoto.jpg"
8+
predictor = build_predictor()
9+
sizes_play = utils.size_play()
15710

15811

159-
def clear_all():
160-
return None, None
12+
def get_output(img, size, bg, download_size):
13+
"""
14+
Get the special size and background photo.
16115
16+
Args:
17+
img(numpy:ndarray): The image array.
18+
size(str): The size user specified.
19+
bg(str): The background color user specified.
20+
download_size(str): The size for image saving.
16221
163-
with gr.Blocks() as demo:
164-
gr.Markdown("Objective Detection")
22+
"""
23+
alpha = predictor.run(img)
24+
res = utils.bg_replace(img, alpha, bg_name=bg)
16525

166-
with gr.Column(scale=1, min_width=100):
26+
size_index = sizes_play.index(size)
27+
res = utils.adjust_size(res, size_index)
28+
res_download = utils.download(res, download_size)
29+
return res, res_download
16730

168-
img_in = gr.Image(
169-
value="https://paddleseg.bj.bcebos.com/matting/demo/human.jpg",
170-
label="Input")
17131

172-
with gr.Row():
173-
btn1 = gr.Button("Clear")
174-
btn2 = gr.Button("Submit")
32+
def download(img, size):
33+
utils.download(img, size)
34+
return None
17535

176-
img_out = gr.Image(label="Output").style(height=200)
17736

178-
btn2.click(fn=model_inference, inputs=img_in, outputs=[img_out])
179-
btn1.click(fn=clear_all, inputs=None, outputs=[img_in, img_out])
37+
with gr.Blocks() as demo:
38+
gr.Markdown("""# ID Photo DIY""")
39+
40+
img_in = gr.Image(value=IMAGE_DEMO, label="Input image")
41+
gr.Markdown(
42+
"""<font color=Gray>Tips: Please upload photos with good posture, center portrait, crown free, no jewelry, ears and eyebrows exposed.</font>"""
43+
)
44+
with gr.Row():
45+
size = gr.Dropdown(sizes_play, label="Sizes", value=sizes_play[0])
46+
bg = gr.Radio(
47+
["White", "Red", "Blue"], label="Background color", value='White')
48+
download_size = gr.Radio(
49+
["Small", "Middle", "Large"],
50+
label="File size (affects image quality)",
51+
value='Large',
52+
interactive=True)
53+
54+
with gr.Row():
55+
btn1 = gr.Button("Clear")
56+
btn2 = gr.Button("Submit")
57+
58+
img_out = gr.Image(
59+
label="Output image", interactive=False).style(height=300)
60+
f1 = gr.File(label='Image download').style(height=50)
61+
with gr.Row():
62+
gr.Markdown(
63+
"""<font color=Gray>This application is supported by [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg).
64+
If you have any questions or feature requists, welcome to raise issues on [GitHub](https://github.com/PaddlePaddle/PaddleSeg/issues). BTW, a star is a great encouragement for us, thanks! ^_^</font>"""
65+
)
66+
67+
btn2.click(
68+
fn=get_output,
69+
inputs=[img_in, size, bg, download_size],
70+
outputs=[img_out, f1])
71+
btn1.click(
72+
fn=utils.clear_all,
73+
inputs=None,
74+
outputs=[img_in, img_out, size, bg, download_size, f1])
75+
18076
gr.Button.style(1)
18177

182-
demo.launch(share=True)
78+
demo.launch()
+55
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
import os
2+
import sys
3+
import time
4+
5+
import requests
6+
import zipfile
7+
8+
FLUSH_INTERVAL = 0.1
9+
lasttime = time.time()
10+
11+
12+
def progress(str, end=False):
13+
global lasttime
14+
if end:
15+
str += "\n"
16+
lasttime = 0
17+
if time.time() - lasttime >= FLUSH_INTERVAL:
18+
sys.stdout.write("\r%s" % str)
19+
lasttime = time.time()
20+
sys.stdout.flush()
21+
22+
23+
def download_file(url, savepath, print_progress=True):
24+
if print_progress:
25+
print("Connecting to {}".format(url))
26+
r = requests.get(url, stream=True, timeout=15)
27+
total_length = r.headers.get('content-length')
28+
29+
if total_length is None:
30+
with open(savepath, 'wb') as f:
31+
shutil.copyfileobj(r.raw, f)
32+
else:
33+
with open(savepath, 'wb') as f:
34+
dl = 0
35+
total_length = int(total_length)
36+
starttime = time.time()
37+
if print_progress:
38+
print("Downloading %s" % os.path.basename(savepath))
39+
for data in r.iter_content(chunk_size=4096):
40+
dl += len(data)
41+
f.write(data)
42+
if print_progress:
43+
done = int(50 * dl / total_length)
44+
progress("[%-50s] %.2f%%" %
45+
('=' * done, float(100 * dl) / total_length))
46+
if print_progress:
47+
progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
48+
49+
50+
def uncompress(path):
51+
files = zipfile.ZipFile(path, 'r')
52+
filelist = files.namelist()
53+
rootpath = filelist[0]
54+
for file in filelist:
55+
files.extract(file, './')
71.1 KB
Loading
Loading

0 commit comments

Comments
 (0)