File size: 4,215 Bytes
7e430ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4009180
7e430ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4009180
7e430ac
 
 
 
 
 
 
 
 
 
 
4009180
7e430ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# App code based on: https://github.com/petergro-hub/ComicInpainting
# Model based on: https://github.com/saic-mdal/lama

import numpy as np
import pandas as pd
import streamlit as st
import os
from datetime import datetime
from PIL import Image
from streamlit_drawable_canvas import st_canvas
from io import BytesIO
from copy import deepcopy

from src.core import process_inpaint


def image_download_button(pil_image, filename: str, fmt: str, label="Download"):
    """Render a Streamlit download button serving *pil_image* as an image file.

    Args:
        pil_image: Pillow image to encode.
        filename: Download file name without extension; ``.{fmt}`` is appended.
        fmt: Output format, ``"jpg"`` or ``"png"`` (case sensitive).
        label: Text shown on the button.

    Returns:
        Whatever ``st.download_button`` returns (its clicked state).

    Raises:
        ValueError: If *fmt* is not one of the supported formats.
    """
    # fmt -> (Pillow format name, MIME type)
    formats = {
        "jpg": ("JPEG", "image/jpeg"),
        "png": ("PNG", "image/png"),
    }
    if fmt not in formats:
        raise ValueError(
            f"Unknown image format {fmt!r} "
            f"(Available: {', '.join(formats)} - case sensitive)"
        )
    pil_format, mime = formats[fmt]

    # JPEG stores only L, RGB or CMYK data; Pillow refuses to write e.g. RGBA
    # or P images as JPEG, so convert (dropping any alpha channel) first.
    if pil_format == "JPEG" and pil_image.mode not in ("L", "RGB", "CMYK"):
        pil_image = pil_image.convert("RGB")

    buf = BytesIO()
    pil_image.save(buf, format=pil_format)

    return st.download_button(
        label=label,
        data=buf.getvalue(),
        file_name=f"{filename}.{fmt}",
        mime=mime,
    )



# Seed per-session keys the first time this session runs the script; Streamlit
# re-executes the whole file on every interaction, so existing values are kept.
for _state_key, _initial in (
    ("button_id", ""),
    ("color_to_label", {}),
    ("reuse_image", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial
del _state_key, _initial
def set_image(img):
    """Store *img* in session state under ``reuse_image``.

    When set, the upload handler builds its input from this stored image
    instead of re-reading the uploaded file's bytes.
    """
    st.session_state["reuse_image"] = img

st.title("AI Photo Object Removal")

# Read the demo image inside a context manager so the file handle is closed
# right away instead of lingering until garbage collection on every rerun.
with open("assets/demo.png", "rb") as demo_file:
    st.image(demo_file.read())

st.markdown(
    """
    So you want to remove an object in your photo? You don't need to learn photo editing skills.
    **Just draw over the parts of the image you want to remove, then our AI will remove them.**
    """
)
uploaded_file = st.file_uploader("Choose image", accept_multiple_files=False, type=["png", "jpg", "jpeg"])

if uploaded_file is not None:

    # Prefer an image stored in session state via `set_image` over the raw
    # upload; otherwise decode the uploaded bytes as RGBA.
    if st.session_state.reuse_image is not None:
        img_input = Image.fromarray(st.session_state.reuse_image)
    else:
        bytes_data = uploaded_file.getvalue()
        img_input = Image.open(BytesIO(bytes_data)).convert("RGBA")

    stroke_width = st.slider("Brush size", 1, 100, 50)

    st.write("**Now draw (brush) the part of image that you want to remove.**")

    # Canvas size logic: the drawing canvas is displayed at most
    # `streamlit_width` pixels wide.
    canvas_bg = deepcopy(img_input)
    aspect_ratio = canvas_bg.width / canvas_bg.height
    streamlit_width = 720

    # Max width is 720. Resize the height to maintain its aspect ratio. Clamp
    # the height to at least 1 px: for extremely wide images the integer
    # division reaches 0 and Pillow rejects zero-sized resize targets.
    if canvas_bg.width > streamlit_width:
        canvas_height = max(1, int(streamlit_width / aspect_ratio))
        canvas_bg = canvas_bg.resize((streamlit_width, canvas_height))

    canvas_result = st_canvas(
        stroke_color="rgba(255, 0, 255, 1)",
        stroke_width=stroke_width,
        background_image=canvas_bg,
        width=canvas_bg.width,
        height=canvas_bg.height,
        drawing_mode="freedraw",
        key="compute_arc_length",
    )

    if canvas_result.image_data is not None:
        # Resize the drawn layer back to the original image size so the mask
        # aligns with `img_input` (the 4-value assignments below require the
        # canvas data to be RGBA).
        im = np.array(
            Image.fromarray(canvas_result.image_data.astype(np.uint8)).resize(img_input.size)
        )
        # Pure black pixels are untouched canvas; pure magenta (255, 0, 255)
        # pixels are brush strokes (the stroke colour above). Both masks are
        # computed before either assignment mutates `im`.
        background = (im[:, :, 0] == 0) & (im[:, :, 1] == 0) & (im[:, :, 2] == 0)
        drawing = (im[:, :, 0] == 255) & (im[:, :, 1] == 0) & (im[:, :, 2] == 255)
        # Background becomes opaque black, strokes fully transparent; presumably
        # process_inpaint derives the removal mask from alpha — verify in src.core.
        im[background] = [0, 0, 0, 255]
        im[drawing] = [0, 0, 0, 0]  # RGBA

        if st.button('Submit'):

            with st.spinner("AI is doing the magic!"):
                output = process_inpaint(np.array(img_input), im)
                img_output = Image.fromarray(output).convert("RGB")

            st.write("AI has finished the job!")
            st.image(img_output)
            # TODO: re-enable an "Edit again (Re-use this image)" button, e.g.
            # st.button(..., on_click=set_image, args=(output,)), once re-use
            # of the inpainted result is supported end to end.

            uploaded_name = os.path.splitext(uploaded_file.name)[0]
            image_download_button(
                pil_image=img_output,
                filename=uploaded_name,
                fmt="jpg",
                label="Download Image"
            )

            st.info("**TIP**: If the result is not perfect, you can download it then "
                    "upload then remove the artifacts.")