Bashir Gulistani
commited on
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import gradio as gr
|
3 |
+
import os
|
4 |
+
from PIL import Image, ImageEnhance
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from torch.autograd import Variable
|
8 |
+
from torchvision import transforms
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import warnings
|
12 |
+
warnings.filterwarnings("ignore")
|
13 |
+
|
14 |
+
os.system("git clone https://github.com/xuebinqin/DIS")
|
15 |
+
os.system("mv DIS/IS-Net/* .")
|
16 |
+
|
17 |
+
from data_loader_cache import normalize, im_reader, im_preprocess
|
18 |
+
from models import *
|
19 |
+
|
20 |
+
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
21 |
+
|
22 |
+
if not os.path.exists("saved_models"):
|
23 |
+
os.mkdir("saved_models")
|
24 |
+
os.system("mv isnet.pth saved_models/")
|
25 |
+
|
26 |
+
class GOSNormalize(object):
|
27 |
+
def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
|
28 |
+
self.mean = mean
|
29 |
+
self.std = std
|
30 |
+
|
31 |
+
def __call__(self, image):
|
32 |
+
image = normalize(image, self.mean, self.std)
|
33 |
+
return image
|
34 |
+
|
35 |
+
transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
|
36 |
+
|
37 |
+
def load_image(im_path, hypar):
|
38 |
+
im = im_reader(im_path)
|
39 |
+
im, im_shp = im_preprocess(im, hypar["cache_size"])
|
40 |
+
im = torch.divide(im, 255.0)
|
41 |
+
shape = torch.from_numpy(np.array(im_shp))
|
42 |
+
return transform(im).unsqueeze(0), shape.unsqueeze(0)
|
43 |
+
|
44 |
+
def build_model(hypar, device):
|
45 |
+
net = hypar["model"]
|
46 |
+
if hypar["model_digit"] == "half":
|
47 |
+
net.half()
|
48 |
+
for layer in net.modules():
|
49 |
+
if isinstance(layer, nn.BatchNorm2d):
|
50 |
+
layer.float()
|
51 |
+
|
52 |
+
net.to(device)
|
53 |
+
if hypar["restore_model"] != "":
|
54 |
+
net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"], map_location=device))
|
55 |
+
net.eval()
|
56 |
+
return net
|
57 |
+
|
58 |
+
def predict(net, inputs_val, shapes_val, hypar, device):
|
59 |
+
net.eval()
|
60 |
+
if hypar["model_digit"] == "full":
|
61 |
+
inputs_val = inputs_val.type(torch.FloatTensor)
|
62 |
+
else:
|
63 |
+
inputs_val = inputs_val.type(torch.HalfTensor)
|
64 |
+
|
65 |
+
inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
|
66 |
+
ds_val = net(inputs_val_v)[0]
|
67 |
+
pred_val = ds_val[0][0, :, :, :]
|
68 |
+
pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear'))
|
69 |
+
|
70 |
+
ma = torch.max(pred_val)
|
71 |
+
mi = torch.min(pred_val)
|
72 |
+
pred_val = (pred_val - mi) / (ma - mi)
|
73 |
+
|
74 |
+
if device == 'cuda': torch.cuda.empty_cache()
|
75 |
+
return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
|
76 |
+
|
77 |
+
hypar = {}
|
78 |
+
hypar["model_path"] = "./saved_models"
|
79 |
+
hypar["restore_model"] = "isnet.pth"
|
80 |
+
hypar["interm_sup"] = False
|
81 |
+
hypar["model_digit"] = "full"
|
82 |
+
hypar["seed"] = 0
|
83 |
+
hypar["cache_size"] = [1024, 1024]
|
84 |
+
hypar["input_size"] = [1024, 1024]
|
85 |
+
hypar["crop_size"] = [1024, 1024]
|
86 |
+
hypar["model"] = ISNetDIS()
|
87 |
+
|
88 |
+
net = build_model(hypar, device)
|
89 |
+
|
90 |
+
def inference(image):
|
91 |
+
image_path = image
|
92 |
+
image_tensor, orig_size = load_image(image_path, hypar)
|
93 |
+
mask = predict(net, image_tensor, orig_size, hypar, device)
|
94 |
+
pil_mask = Image.fromarray(mask).convert('L')
|
95 |
+
im_rgb = Image.open(image).convert("RGB")
|
96 |
+
im_rgba = im_rgb.copy()
|
97 |
+
im_rgba.putalpha(pil_mask)
|
98 |
+
return [im_rgba, pil_mask]
|
99 |
+
|
100 |
+
# Functions Added From Team
|
101 |
+
def rotate_image(image, degrees):
|
102 |
+
img = Image.open(image).rotate(degrees)
|
103 |
+
return img
|
104 |
+
|
105 |
+
def resize_image(image, width, height):
|
106 |
+
img = Image.open(image).resize((width, height))
|
107 |
+
return img
|
108 |
+
|
109 |
+
def convert_to_grayscale(image):
|
110 |
+
img = Image.open(image).convert('L')
|
111 |
+
return img
|
112 |
+
|
113 |
+
def adjust_brightness(image, brightness_factor):
|
114 |
+
img = Image.open(image)
|
115 |
+
enhancer = ImageEnhance.Brightness(img)
|
116 |
+
img_enhanced = enhancer.enhance(brightness_factor)
|
117 |
+
return img_enhanced
|
118 |
+
|
119 |
+
# Custom CSS Added From Team
|
120 |
+
custom_css = """
|
121 |
+
body {
|
122 |
+
background-color: #f0f0f0;
|
123 |
+
}
|
124 |
+
.gradio-container {
|
125 |
+
max-width: 900px;
|
126 |
+
margin: auto;
|
127 |
+
background-color: #ffffff;
|
128 |
+
padding: 20px;
|
129 |
+
border-radius: 12px;
|
130 |
+
box-shadow: 0px 4px 16px rgba(0, 0, 0, 0.2);
|
131 |
+
}
|
132 |
+
button.lg {
|
133 |
+
background-color: #4CAF50;
|
134 |
+
color: white;
|
135 |
+
border: none;
|
136 |
+
padding: 10px 20px;
|
137 |
+
text-align: center;
|
138 |
+
text-decoration: none;
|
139 |
+
display: inline-block;
|
140 |
+
font-size: 16px;
|
141 |
+
margin: 4px 2px;
|
142 |
+
transition-duration: 0.4s;
|
143 |
+
cursor: pointer;
|
144 |
+
border-radius: 8px;
|
145 |
+
}
|
146 |
+
button.lg:hover {
|
147 |
+
background-color: #45a049;
|
148 |
+
color: white;
|
149 |
+
}
|
150 |
+
"""
|
151 |
+
|
152 |
+
# Used Some Codes From Yang's Chatbot
|
153 |
+
with gr.Blocks(css=custom_css) as interface:
|
154 |
+
gr.Markdown(f"# {title}")
|
155 |
+
gr.Markdown("<h1 style='text-align: center;'>🚩 Image Processor with Brightness Adjustment 🚩</h1>")
|
156 |
+
with gr.Row():
|
157 |
+
with gr.Column():
|
158 |
+
input_image = gr.Image(label="Input Image", type='filepath')
|
159 |
+
rotate_button = gr.Button("Rotate Image")
|
160 |
+
resize_button = gr.Button("Resize Image")
|
161 |
+
grayscale_button = gr.Button("Convert to Grayscale")
|
162 |
+
brightness_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Adjust Brightness")
|
163 |
+
submit_button = gr.Button("Submit", variant="primary")
|
164 |
+
clear_button = gr.Button("Clear", variant="secondary")
|
165 |
+
with gr.Column():
|
166 |
+
output_image = gr.Image(label="Output Image")
|
167 |
+
mask_image = gr.Image(label="Mask")
|
168 |
+
|
169 |
+
# AI Generated: Use Gradio Blocks to organize the interface with buttons
|
170 |
+
rotate_button.click(rotate_image, inputs=[input_image, gr.Slider(minimum=0, maximum=360, step=1, value=90, label="Rotation Degrees")], outputs=output_image)
|
171 |
+
resize_button.click(resize_image, inputs=[input_image, gr.Number(value=512, label="Width"), gr.Number(value=512, label="Height")], outputs=output_image)
|
172 |
+
grayscale_button.click(convert_to_grayscale, inputs=input_image, outputs=output_image)
|
173 |
+
|
174 |
+
brightness_slider.change(adjust_brightness, inputs=[input_image, brightness_slider], outputs=output_image)
|
175 |
+
|
176 |
+
submit_button.click(inference, inputs=input_image, outputs=[output_image, mask_image])
|
177 |
+
|
178 |
+
clear_button.click(lambda: (None, None, None), inputs=None, outputs=[input_image, output_image, mask_image])
|
179 |
+
|
180 |
+
interface.launch(share=True)
|