|
1 |
| -import codecs |
2 |
| -import os |
3 |
| -import sys |
4 |
| -import time |
5 |
| -import zipfile |
6 |
| - |
7 | 1 | import gradio as gr
|
8 | 2 | import numpy as np
|
9 |
| -import cv2 |
10 |
| -import requests |
11 |
| -import yaml |
12 |
| -from paddle.inference import Config as PredictConfig |
13 |
| -from paddle.inference import create_predictor |
14 |
| - |
15 |
lasttime = time.time()
FLUSH_INTERVAL = 0.1  # minimum seconds between progress-bar flushes


def progress(msg, end=False):
    """Print a same-line (carriage-return) progress message, throttled.

    Args:
        msg (str): Text to display. (Renamed from ``str`` to avoid shadowing
            the builtin; all in-file calls pass it positionally.)
        end (bool): If True, append a newline and force an immediate flush.
    """
    global lasttime
    if end:
        msg += "\n"
        lasttime = 0  # zero the timestamp so the flush below always fires
    if time.time() - lasttime >= FLUSH_INTERVAL:
        sys.stdout.write("\r%s" % msg)
        lasttime = time.time()
        sys.stdout.flush()
28 |
| - |
29 |
| - |
30 |
def _download_file(url, savepath, print_progress=True):
    """Stream a remote file to disk, optionally printing a progress bar.

    Args:
        url (str): Remote file URL.
        savepath (str): Local destination path.
        print_progress (bool): Whether to print progress to stdout.
    """
    # Local import: the module header never imported shutil, so the
    # unknown-content-length branch below raised NameError before this fix.
    import shutil

    if print_progress:
        print("Connecting to {}".format(url))
    r = requests.get(url, stream=True, timeout=15)
    total_length = r.headers.get('content-length')

    if total_length is None:
        # Server sent no Content-Length: copy the raw stream, no progress.
        with open(savepath, 'wb') as f:
            shutil.copyfileobj(r.raw, f)
    else:
        with open(savepath, 'wb') as f:
            dl = 0
            total_length = int(total_length)
            if print_progress:
                print("Downloading %s" % os.path.basename(savepath))
            for data in r.iter_content(chunk_size=4096):
                dl += len(data)
                f.write(data)
                if print_progress:
                    done = int(50 * dl / total_length)
                    progress("[%-50s] %.2f%%" %
                             ('=' * done, float(100 * dl) / total_length))
        if print_progress:
            progress("[%-50s] %.2f%%" % ('=' * 50, 100), end=True)
55 |
| - |
56 |
| - |
57 |
def uncompress(path):
    """Extract every member of the zip archive at *path* into the cwd.

    The original implementation left the ZipFile handle open and computed
    an unused ``rootpath``; a context manager guarantees closure.
    """
    with zipfile.ZipFile(path, 'r') as archive:
        archive.extractall('./')
63 |
| - |
64 |
| - |
65 |
class DeployConfig:
    """Parsed ``deploy.yaml`` describing an exported Paddle inference model.

    Paths in the YAML are resolved relative to the YAML file's directory.
    """

    def __init__(self, path):
        # Builtin open() with an explicit encoding replaces the legacy
        # codecs.open(); identical behavior on Python 3.
        with open(path, 'r', encoding='utf-8') as file:
            self.dic = yaml.load(file, Loader=yaml.FullLoader)
        self._dir = os.path.dirname(path)

    @property
    def model(self):
        """Path to the serialized model topology file."""
        return os.path.join(self._dir, self.dic['Deploy']['model'])

    @property
    def params(self):
        """Path to the model weights (params) file."""
        return os.path.join(self._dir, self.dic['Deploy']['params'])
78 |
| - |
79 |
| - |
80 |
class Predictor:
    """CPU-only Paddle inference wrapper for the exported matting model.

    For the Paddle inference API, see
    https://paddleinference.paddlepaddle.org.cn/product_introduction/summary.html
    """

    def __init__(self, cfg):
        # Parse deploy.yaml, build the prediction config, then the predictor.
        self.cfg = DeployConfig(cfg)
        self._init_base_config()
        self._init_cpu_config()
        self.predictor = create_predictor(self.pred_cfg)

    def _init_base_config(self):
        # Common settings: memory optimization and IR graph optimization on.
        self.pred_cfg = PredictConfig(self.cfg.model, self.cfg.params)
        self.pred_cfg.enable_memory_optim()
        self.pred_cfg.switch_ir_optim(True)

    def _init_cpu_config(self):
        """Configure inference for x86 CPU: GPU off, 10 math-library threads."""
        self.pred_cfg.disable_gpu()
        self.pred_cfg.set_cpu_math_library_num_threads(10)

    def _preprocess(self, img):
        """Scale the short edge to ~512 (each side rounded down to a multiple
        of 32), normalize to [-1, 1], and reorder HWC -> 1xCxHxW."""
        height, width = img.shape[:2]
        ratio = 512 / min(height, width)
        new_h = int(round(height * ratio)) // 32 * 32
        new_w = int(round(width * ratio)) // 32 * 32
        resized = cv2.resize(img, (new_w, new_h))
        normalized = (resized / 255 - 0.5) / 0.5
        return np.transpose(normalized, [2, 0, 1])[np.newaxis, :]

    def run(self, img):
        """Run matting on an HWC image array; return a uint8 alpha matte
        resized back to the input's original resolution."""
        handles = {
            name: self.predictor.get_input_handle(name)
            for name in self.predictor.get_input_names()
        }
        first_out = self.predictor.get_output_names()[0]
        out_handle = self.predictor.get_output_handle(first_out)

        data = img.astype('float32')
        ori_h, ori_w = data.shape[:2]
        handles['img'].copy_from_cpu(self._preprocess(img=data))

        self.predictor.run()

        alpha = out_handle.copy_to_cpu().squeeze()
        alpha = cv2.resize(alpha, (ori_w, ori_h))
        return (alpha * 255).astype('uint8')
142 |
| - |
143 |
| - |
144 |
def model_inference(image):
    """Run human-matting inference on *image* and return the alpha matte.

    Downloads and unpacks the exported model on first use. The Predictor is
    cached on the function object so the (expensive) engine is built once
    instead of on every request, as the original did.

    Args:
        image (numpy.ndarray): Input HWC image array.

    Returns:
        numpy.ndarray: uint8 alpha matte at the input resolution.
    """
    url = 'https://paddleseg.bj.bcebos.com/matting/models/deploy/ppmatting-hrnet_w18-human_512.zip'
    savepath = './ppmatting-hrnet_w18-human_512.zip'
    if not os.path.exists('./ppmatting-hrnet_w18-human_512'):
        _download_file(url=url, savepath=savepath)
        uncompress(savepath)

    # Build the predictor lazily, exactly once per process.
    if getattr(model_inference, '_predictor', None) is None:
        model_inference._predictor = Predictor(
            cfg='./ppmatting-hrnet_w18-human_512/deploy.yaml')
    return model_inference._predictor.run(image)
| 7 | +IMAGE_DEMO = "./images/idphoto.jpg" |
| 8 | +predictor = build_predictor() |
| 9 | +sizes_play = utils.size_play() |
157 | 10 |
|
158 | 11 |
|
159 |
def clear_all():
    """Reset both image widgets by returning empty (None) values."""
    return (None, None)
def get_output(img, size, bg, download_size):
    """Get the special size and background photo.

    Args:
        img (numpy.ndarray): The image array.
        size (str): The size the user specified (one of ``sizes_play``).
        bg (str): The background color the user specified.
        download_size (str): The size preset used when saving the image.

    Returns:
        tuple: (processed image, saved file for download).
    """
    matte = predictor.run(img)
    composed = utils.bg_replace(img, matte, bg_name=bg)

    composed = utils.adjust_size(composed, sizes_play.index(size))
    saved = utils.download(composed, download_size)
    return composed, saved
167 | 30 |
|
168 |
| - img_in = gr.Image( |
169 |
| - value="https://paddleseg.bj.bcebos.com/matting/demo/human.jpg", |
170 |
| - label="Input") |
171 | 31 |
|
172 |
| - with gr.Row(): |
173 |
| - btn1 = gr.Button("Clear") |
174 |
| - btn2 = gr.Button("Submit") |
def download(img, size):
    """Save *img* with the given size preset via utils; returns None."""
    utils.download(img, size)
    return None
175 | 35 |
|
176 |
| - img_out = gr.Image(label="Output").style(height=200) |
177 | 36 |
|
178 |
| - btn2.click(fn=model_inference, inputs=img_in, outputs=[img_out]) |
179 |
| - btn1.click(fn=clear_all, inputs=None, outputs=[img_in, img_out]) |
# Gradio UI. NOTE: gr.Blocks layout is order-dependent — widget creation
# order below defines the page layout and must not be reordered.
with gr.Blocks() as demo:
    gr.Markdown("""# ID Photo DIY""")

    img_in = gr.Image(value=IMAGE_DEMO, label="Input image")
    gr.Markdown(
        """<font color=Gray>Tips: Please upload photos with good posture, center portrait, crown free, no jewelry, ears and eyebrows exposed.</font>"""
    )
    with gr.Row():
        size = gr.Dropdown(sizes_play, label="Sizes", value=sizes_play[0])
        bg = gr.Radio(
            ["White", "Red", "Blue"], label="Background color", value='White')
        download_size = gr.Radio(
            ["Small", "Middle", "Large"],
            label="File size (affects image quality)",
            value='Large',
            interactive=True)

    with gr.Row():
        btn1 = gr.Button("Clear")
        btn2 = gr.Button("Submit")

    img_out = gr.Image(
        label="Output image", interactive=False).style(height=300)
    f1 = gr.File(label='Image download').style(height=50)
    with gr.Row():
        # Typo fix in user-facing text: "requists" -> "requests".
        gr.Markdown(
            """<font color=Gray>This application is supported by [PaddleSeg](https://github.com/PaddlePaddle/PaddleSeg).
            If you have any questions or feature requests, welcome to raise issues on [GitHub](https://github.com/PaddlePaddle/PaddleSeg/issues). BTW, a star is a great encouragement for us, thanks! ^_^</font>"""
        )

    btn2.click(
        fn=get_output,
        inputs=[img_in, size, bg, download_size],
        outputs=[img_out, f1])
    btn1.click(
        fn=utils.clear_all,
        inputs=None,
        outputs=[img_in, img_out, size, bg, download_size, f1])

    # NOTE(review): this calls .style on the gr.Button CLASS, not an
    # instance — it looks like leftover debugging; confirm before removing.
    gr.Button.style(1)

demo.launch()
0 commit comments