-
Notifications
You must be signed in to change notification settings - Fork 37
/
app.py
49 lines (45 loc) · 1.57 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import os
import subprocess
from PIL import Image
import streamlit as st
from mirnet.inference import Inferer
# Local filename of the pretrained MIRNet weights. The download step below is
# skipped when this file exists, so `Inferer.download_weights` is expected to
# write exactly this file -- NOTE(review): confirm against mirnet.inference.
_WEIGHTS_PATH = 'low_light_weights_best.h5'
# Identifier handed to `Inferer.download_weights`; presumably a remote
# (e.g. Google Drive) file id -- confirm in mirnet.inference.
_WEIGHTS_FILE_ID = '1sUlRD5MTRKKGxtqyYDpTv7T3jOW6aVAL'


def main():
    """Run the MIRNet low-light image enhancement Streamlit page.

    Renders the page title, prepares the model (downloading weights on the
    first run), then shows every image uploaded through the sidebar next to
    its MIRNet-enhanced version, one original/predicted pair per row.
    """
    st.markdown(
        '<h1 align="center">Low-light Image Enhancement using MIRNet</h1><hr>',
        unsafe_allow_html=True
    )
    inferer = _load_inferer()
    uploaded_files = st.sidebar.file_uploader(
        'Please Upload your Low-light Images',
        accept_multiple_files=True
    )
    # `or []` treats a None return the same as "no uploads yet", so we never
    # call len()/iterate on None.
    for uploaded_file in uploaded_files or []:
        _render_comparison(inferer, uploaded_file)
    # The weights file is intentionally left on disk: Streamlit re-executes
    # this script on every widget interaction, so deleting it here would
    # trigger a fresh download on each rerun.


def _load_inferer():
    """Return an `Inferer` whose MIRNet model is built from `_WEIGHTS_PATH`.

    Downloads the weights first when `_WEIGHTS_PATH` is missing, and reports
    each step as plain text in the sidebar.
    """
    inferer = Inferer()
    if not os.path.exists(_WEIGHTS_PATH):
        st.sidebar.text('Downloading Model weights...')
        inferer.download_weights(_WEIGHTS_FILE_ID)
        st.sidebar.text('Done')
    st.sidebar.text('Building MIRNet Model...')
    # Architecture hyper-parameters; presumably these must match the
    # checkpoint stored in _WEIGHTS_PATH -- confirm before changing them.
    inferer.build_model(
        num_rrg=3, num_mrb=2, channels=64,
        weights_path=_WEIGHTS_PATH
    )
    st.sidebar.text('Done')
    return inferer


def _two_columns():
    """Return two side-by-side Streamlit columns.

    `st.beta_columns` was promoted to `st.columns` and the beta name was
    later removed, so prefer the stable API and fall back only on old
    Streamlit releases that lack it.
    """
    make_columns = getattr(st, 'columns', None) or st.beta_columns
    return make_columns(2)


def _render_comparison(inferer, uploaded_file):
    """Enhance one uploaded image and show it beside the original.

    Columns are created per upload so that the original and predicted images
    of the same file share a row. The horizontal rule that follows then
    actually separates one pair from the next.

    Args:
        inferer: Model wrapper providing `infer_streamlit(pil_image)`, which
            returns an `(original_image, output_image)` pair.
        uploaded_file: One item returned by `st.sidebar.file_uploader`.
    """
    pil_image = Image.open(uploaded_file)
    original_image, output_image = inferer.infer_streamlit(pil_image)
    col_1, col_2 = _two_columns()
    with col_1:
        st.image(
            original_image, use_column_width=True,
            caption='Original Image'
        )
    with col_2:
        st.image(
            output_image, use_column_width=True,
            caption='Predicted Image'
        )
    st.markdown('---')
if __name__ == '__main__':
    # Build the page only when this file is executed as a script, never as a
    # side effect of importing it.
    main()