Bashir Gulistani commited on
Commit
6c944f8
·
unverified ·
1 Parent(s): 715929c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +180 -0
app.py ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import cv2
2
+ import gradio as gr
3
+ import os
4
+ from PIL import Image, ImageEnhance
5
+ import numpy as np
6
+ import torch
7
+ from torch.autograd import Variable
8
+ from torchvision import transforms
9
+ import torch.nn.functional as F
10
+ import matplotlib.pyplot as plt
11
+ import warnings
12
+ warnings.filterwarnings("ignore")
13
+
14
+ os.system("git clone https://github.com/xuebinqin/DIS")
15
+ os.system("mv DIS/IS-Net/* .")
16
+
17
+ from data_loader_cache import normalize, im_reader, im_preprocess
18
+ from models import *
19
+
20
+ device = 'cuda' if torch.cuda.is_available() else 'cpu'
21
+
22
+ if not os.path.exists("saved_models"):
23
+ os.mkdir("saved_models")
24
+ os.system("mv isnet.pth saved_models/")
25
+
26
+ class GOSNormalize(object):
27
+ def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
28
+ self.mean = mean
29
+ self.std = std
30
+
31
+ def __call__(self, image):
32
+ image = normalize(image, self.mean, self.std)
33
+ return image
34
+
35
+ transform = transforms.Compose([GOSNormalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0])])
36
+
37
+ def load_image(im_path, hypar):
38
+ im = im_reader(im_path)
39
+ im, im_shp = im_preprocess(im, hypar["cache_size"])
40
+ im = torch.divide(im, 255.0)
41
+ shape = torch.from_numpy(np.array(im_shp))
42
+ return transform(im).unsqueeze(0), shape.unsqueeze(0)
43
+
44
+ def build_model(hypar, device):
45
+ net = hypar["model"]
46
+ if hypar["model_digit"] == "half":
47
+ net.half()
48
+ for layer in net.modules():
49
+ if isinstance(layer, nn.BatchNorm2d):
50
+ layer.float()
51
+
52
+ net.to(device)
53
+ if hypar["restore_model"] != "":
54
+ net.load_state_dict(torch.load(hypar["model_path"] + "/" + hypar["restore_model"], map_location=device))
55
+ net.eval()
56
+ return net
57
+
58
+ def predict(net, inputs_val, shapes_val, hypar, device):
59
+ net.eval()
60
+ if hypar["model_digit"] == "full":
61
+ inputs_val = inputs_val.type(torch.FloatTensor)
62
+ else:
63
+ inputs_val = inputs_val.type(torch.HalfTensor)
64
+
65
+ inputs_val_v = Variable(inputs_val, requires_grad=False).to(device)
66
+ ds_val = net(inputs_val_v)[0]
67
+ pred_val = ds_val[0][0, :, :, :]
68
+ pred_val = torch.squeeze(F.upsample(torch.unsqueeze(pred_val, 0), (shapes_val[0][0], shapes_val[0][1]), mode='bilinear'))
69
+
70
+ ma = torch.max(pred_val)
71
+ mi = torch.min(pred_val)
72
+ pred_val = (pred_val - mi) / (ma - mi)
73
+
74
+ if device == 'cuda': torch.cuda.empty_cache()
75
+ return (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
76
+
77
+ hypar = {}
78
+ hypar["model_path"] = "./saved_models"
79
+ hypar["restore_model"] = "isnet.pth"
80
+ hypar["interm_sup"] = False
81
+ hypar["model_digit"] = "full"
82
+ hypar["seed"] = 0
83
+ hypar["cache_size"] = [1024, 1024]
84
+ hypar["input_size"] = [1024, 1024]
85
+ hypar["crop_size"] = [1024, 1024]
86
+ hypar["model"] = ISNetDIS()
87
+
88
+ net = build_model(hypar, device)
89
+
90
+ def inference(image):
91
+ image_path = image
92
+ image_tensor, orig_size = load_image(image_path, hypar)
93
+ mask = predict(net, image_tensor, orig_size, hypar, device)
94
+ pil_mask = Image.fromarray(mask).convert('L')
95
+ im_rgb = Image.open(image).convert("RGB")
96
+ im_rgba = im_rgb.copy()
97
+ im_rgba.putalpha(pil_mask)
98
+ return [im_rgba, pil_mask]
99
+
100
+ # Functions Added From Team
101
+ def rotate_image(image, degrees):
102
+ img = Image.open(image).rotate(degrees)
103
+ return img
104
+
105
+ def resize_image(image, width, height):
106
+ img = Image.open(image).resize((width, height))
107
+ return img
108
+
109
+ def convert_to_grayscale(image):
110
+ img = Image.open(image).convert('L')
111
+ return img
112
+
113
+ def adjust_brightness(image, brightness_factor):
114
+ img = Image.open(image)
115
+ enhancer = ImageEnhance.Brightness(img)
116
+ img_enhanced = enhancer.enhance(brightness_factor)
117
+ return img_enhanced
118
+
119
+ # Custom CSS Added From Team
120
+ custom_css = """
121
+ body {
122
+ background-color: #f0f0f0;
123
+ }
124
+ .gradio-container {
125
+ max-width: 900px;
126
+ margin: auto;
127
+ background-color: #ffffff;
128
+ padding: 20px;
129
+ border-radius: 12px;
130
+ box-shadow: 0px 4px 16px rgba(0, 0, 0, 0.2);
131
+ }
132
+ button.lg {
133
+ background-color: #4CAF50;
134
+ color: white;
135
+ border: none;
136
+ padding: 10px 20px;
137
+ text-align: center;
138
+ text-decoration: none;
139
+ display: inline-block;
140
+ font-size: 16px;
141
+ margin: 4px 2px;
142
+ transition-duration: 0.4s;
143
+ cursor: pointer;
144
+ border-radius: 8px;
145
+ }
146
+ button.lg:hover {
147
+ background-color: #45a049;
148
+ color: white;
149
+ }
150
+ """
151
+
152
+ # Used Some Codes From Yang's Chatbot
153
+ with gr.Blocks(css=custom_css) as interface:
154
+ gr.Markdown(f"# {title}")
155
+ gr.Markdown("<h1 style='text-align: center;'>🚩 Image Processor with Brightness Adjustment 🚩</h1>")
156
+ with gr.Row():
157
+ with gr.Column():
158
+ input_image = gr.Image(label="Input Image", type='filepath')
159
+ rotate_button = gr.Button("Rotate Image")
160
+ resize_button = gr.Button("Resize Image")
161
+ grayscale_button = gr.Button("Convert to Grayscale")
162
+ brightness_slider = gr.Slider(minimum=0.5, maximum=2.0, step=0.1, value=1.0, label="Adjust Brightness")
163
+ submit_button = gr.Button("Submit", variant="primary")
164
+ clear_button = gr.Button("Clear", variant="secondary")
165
+ with gr.Column():
166
+ output_image = gr.Image(label="Output Image")
167
+ mask_image = gr.Image(label="Mask")
168
+
169
+ # AI Generated: Use Gradio Blocks to organize the interface with buttons
170
+ rotate_button.click(rotate_image, inputs=[input_image, gr.Slider(minimum=0, maximum=360, step=1, value=90, label="Rotation Degrees")], outputs=output_image)
171
+ resize_button.click(resize_image, inputs=[input_image, gr.Number(value=512, label="Width"), gr.Number(value=512, label="Height")], outputs=output_image)
172
+ grayscale_button.click(convert_to_grayscale, inputs=input_image, outputs=output_image)
173
+
174
+ brightness_slider.change(adjust_brightness, inputs=[input_image, brightness_slider], outputs=output_image)
175
+
176
+ submit_button.click(inference, inputs=input_image, outputs=[output_image, mask_image])
177
+
178
+ clear_button.click(lambda: (None, None, None), inputs=None, outputs=[input_image, output_image, mask_image])
179
+
180
+ interface.launch(share=True)