Mrahsanahmad committed on
Commit 369d822
1 Parent(s): f822c09

try subplotting

Files changed (2):
  1. app.py +169 -4
  2. app.py.orig +122 -0
app.py CHANGED
@@ -1,7 +1,172 @@
  import gradio as gr
- def greet(name):
-     return "Hello " + name + "!!"
- iface = gr.Interface(fn=greet, inputs="text", outputs="text")
- iface.launch()
+ import gc  # needed by the gc.collect() call in infer() below
+ import numpy as np
+ import torch
+ import jax
+ import jax.numpy as jnp
+ from flax.jax_utils import replicate
+ from flax.training.common_utils import shard
+ from PIL import Image
+ from segment_anything import SamPredictor, sam_model_registry, SamAutomaticMaskGenerator
+ from diffusers import (
+     FlaxStableDiffusionControlNetPipeline,
+     FlaxControlNetModel,
+ )
+ from transformers import pipeline

+ import colorsys

+ sam_checkpoint = "sam_vit_h_4b8939.pth"
+ model_type = "vit_h"
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ # sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
+ # sam.to(device=device)
+ # predictor = SamPredictor(sam)
+ # mask_generator = SamAutomaticMaskGenerator(sam)
+
+ generator = pipeline(model="facebook/sam-vit-base", task="mask-generation", points_per_batch=256)
+ # image_url = "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png"
+
+ controlnet, controlnet_params = FlaxControlNetModel.from_pretrained(
+     "SAMControlNet/sd-controlnet-sam-seg", dtype=jnp.float32
+ )
+
+ pipe, params = FlaxStableDiffusionControlNetPipeline.from_pretrained(
+     "runwayml/stable-diffusion-v1-5",
+     controlnet=controlnet,
+     revision="flax",
+     dtype=jnp.bfloat16,
+ )
+
+ params["controlnet"] = controlnet_params
+ p_params = replicate(params)
+
+
+ with gr.Blocks() as demo:
+     gr.Markdown("# Ahsan's version of WildSynth: Synthetic Wildlife Data Generation")
+     gr.Markdown(
+         """
+         ## Work in Progress
+         ### About
+         We have trained a JAX ControlNet model for semantic segmentation on wildlife animal images.
+
+         To create the training data we used the [Wildlife Animals Images](https://www.kaggle.com/datasets/anshulmehtakaggl/wildlife-animals-images) dataset.
+         We created the segmentation masks with the help of [Grounded SAM](https://github.com/IDEA-Research/Grounded-Segment-Anything), using the animal names
+         as input prompts for detection and more accurate segmentation.
+
+         ### How To Use
+         Upload an image to generate its segmentation mask, type a prompt (and optionally a negative prompt), then press Submit to synthesize a new image conditioned on the mask.
+         """
+     )
+     with gr.Row():
+         input_img = gr.Image(label="Input", type="pil")
+         mask_img = gr.Image(label="Mask", interactive=False)
+         output_img = gr.Image(label="Output", interactive=False)
+
+     with gr.Row():
+         prompt_text = gr.Textbox(lines=1, label="Prompt")
+         negative_prompt_text = gr.Textbox(lines=1, label="Negative Prompt")
+
+     with gr.Row():
+         submit = gr.Button("Submit")
+         clear = gr.Button("Clear")
+
+     def generate_mask(image):
+         outputs = generator(image, points_per_batch=256)
+         # Combine the per-mask images into a single RGBA overlay: gr.Image can
+         # display one array, but not the stack the previous version returned.
+         h, w = outputs["masks"][0].shape[-2:]
+         overlay = np.zeros((h, w, 4))
+         for mask in outputs["masks"]:
+             color = np.concatenate([np.random.random(3), np.array([1.0])], axis=0)
+             overlay = np.maximum(overlay, mask.reshape(h, w, 1) * color.reshape(1, 1, -1))
+
+         return (overlay * 255).astype(np.uint8)
+
+     def infer(
+         image, prompts, negative_prompts, num_inference_steps=50, seed=4, num_samples=4
+     ):
+         try:
+             rng = jax.random.PRNGKey(int(seed))
+             num_inference_steps = int(num_inference_steps)
+             # The mask arrives as an RGBA array; convert() drops the alpha channel
+             # (Image.fromarray(image, mode="RGB") would reject a 4-channel array).
+             image = Image.fromarray(image).convert("RGB")
+             # Ensure at least one sample per device so shard() can split the batch.
+             num_samples = max(jax.device_count(), int(num_samples))
+             p_rng = jax.random.split(rng, jax.device_count())
+
+             prompt_ids = pipe.prepare_text_inputs([prompts] * num_samples)
+             negative_prompt_ids = pipe.prepare_text_inputs(
+                 [negative_prompts] * num_samples
+             )
+             processed_image = pipe.prepare_image_inputs([image] * num_samples)
+
+             prompt_ids = shard(prompt_ids)
+             negative_prompt_ids = shard(negative_prompt_ids)
+             processed_image = shard(processed_image)
+
+             output = pipe(
+                 prompt_ids=prompt_ids,
+                 image=processed_image,
+                 params=p_params,
+                 prng_seed=p_rng,
+                 num_inference_steps=num_inference_steps,
+                 neg_prompt_ids=negative_prompt_ids,
+                 jit=True,
+             ).images
+
+             del negative_prompt_ids
+             del processed_image
+             del prompt_ids
+
+             output = output.reshape((num_samples,) + output.shape[-3:])
+             final_image = [np.array(x * 255, dtype=np.uint8) for x in output]
+             print(output.shape)
+             del output
+
+         except Exception as e:
+             print("Error: " + str(e))
+             final_image = [np.zeros((512, 512, 3), dtype=np.uint8)] * num_samples
+         finally:
+             gc.collect()
+         # output_img is a single Image component, so return one sample.
+         return final_image[0]
+
+     def _clear(img, mask, out, prompt, neg_prompt):
+         # Match the five components wired to clear.click() below; the previous
+         # eight-value signature did not match that wiring and would have failed.
+         img = None
+         mask = None
+         out = None
+         prompt = ""
+         neg_prompt = ""
+         return img, mask, out, prompt, neg_prompt
+
+     input_img.change(
+         generate_mask,
+         inputs=[input_img],
+         outputs=[mask_img],
+     )
+     submit.click(
+         infer,
+         inputs=[mask_img, prompt_text, negative_prompt_text],
+         outputs=[output_img],
+     )
+     clear.click(
+         _clear,
+         inputs=[
+             input_img,
+             mask_img,
+             output_img,
+             prompt_text,
+             negative_prompt_text,
+         ],
+         outputs=[
+             input_img,
+             mask_img,
+             output_img,
+             prompt_text,
+             negative_prompt_text,
+         ],
+     )
+
+ if __name__ == "__main__":
+     demo.queue()
+     demo.launch()
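
Note: the rewritten infer() leans on Flax's data-parallel helpers: replicate() copies the pipeline parameters to every local device, shard() splits the batched inputs across those devices, and the jitted pipeline call then runs one shard per device. A minimal, self-contained sketch of that pattern (toy parameters and shapes, not the commit's actual model):

import jax
import jax.numpy as jnp
from flax.jax_utils import replicate
from flax.training.common_utils import shard

params = {"w": jnp.ones((4, 4))}        # stand-in for the pipeline parameters
batch = jnp.arange(32.0).reshape(8, 4)  # batch size must divide evenly across devices

p_params = replicate(params)            # adds a leading device axis to every leaf
p_batch = shard(batch)                  # (devices, 8 // devices, 4)

@jax.pmap                               # runs once per device, in parallel
def apply(params, x):
    return x @ params["w"]

out = apply(p_params, p_batch)          # (devices, 8 // devices, 4)
print(out.shape)

This is also why infer() bumps num_samples up to at least jax.device_count() before calling shard().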
app.py.orig ADDED
@@ -0,0 +1,122 @@
+ import gradio as gr
+
+ import torch
+ from PIL import Image
+ import requests
+ from transformers import SamModel, SamProcessor
+ import numpy as np
+ import matplotlib.pyplot as plt  # needed by the plotting helpers below
+
+
+ def show_mask(mask, ax, random_color=False):
+     if random_color:
+         color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
+     else:
+         color = np.array([30/255, 144/255, 255/255, 0.6])
+     h, w = mask.shape[-2:]
+     mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
+     ax.imshow(mask_image)
+
+
+ def show_box(box, ax):
+     x0, y0 = box[0], box[1]
+     w, h = box[2] - box[0], box[3] - box[1]
+     ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))
+
+
+ def show_boxes_on_image(raw_image, boxes):
+     plt.figure(figsize=(10, 10))
+     plt.imshow(raw_image)
+     for box in boxes:
+         show_box(box, plt.gca())
+     plt.axis('on')
+     plt.show()
+
+
+ def show_points_on_image(raw_image, input_points, input_labels=None):
+     plt.figure(figsize=(10, 10))
+     plt.imshow(raw_image)
+     input_points = np.array(input_points)
+     if input_labels is None:
+         labels = np.ones_like(input_points[:, 0])
+     else:
+         labels = np.array(input_labels)
+     show_points(input_points, labels, plt.gca())
+     plt.axis('on')
+     plt.show()
+
+ def show_points_and_boxes_on_image(raw_image, boxes, input_points, input_labels=None):
+     plt.figure(figsize=(10, 10))
+     plt.imshow(raw_image)
+     input_points = np.array(input_points)
+     if input_labels is None:
+         labels = np.ones_like(input_points[:, 0])
+     else:
+         labels = np.array(input_labels)
+     show_points(input_points, labels, plt.gca())
+     for box in boxes:
+         show_box(box, plt.gca())
+     plt.axis('on')
+     plt.show()
+
+
+ def show_points(coords, labels, ax, marker_size=375):
+     pos_points = coords[labels == 1]
+     neg_points = coords[labels == 0]
+     ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+     ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
+
+
+ def apply_masks_on_image(raw_image, masks, scores):
+     if len(masks.shape) == 4:
+         masks = masks.squeeze()
+     if scores.shape[0] == 1:
+         scores = scores.squeeze()
+
+     nb_predictions = scores.shape[-1]
+     fig, axes = plt.subplots(1, nb_predictions, figsize=(15, 15))
+     # plt.subplots returns a bare Axes (not an array) when nb_predictions == 1.
+     axes = np.atleast_1d(axes)
+
+     for i, (mask, score) in enumerate(zip(masks, scores)):
+         mask = mask.cpu().detach()
+         axes[i].imshow(np.array(raw_image))
+         show_mask(mask, axes[i])
+         axes[i].title.set_text(f"Mask {i+1}, Score: {score.item():.3f}")
+         axes[i].axis("off")
+     plt.show()
+
+ def segment(raw_image):
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model = SamModel.from_pretrained("facebook/sam-vit-huge").to(device)
+     processor = SamProcessor.from_pretrained("facebook/sam-vit-huge")
+
+     # The Interface below supplies a PIL image directly, so no URL fetch is needed
+     # (previously: requests.get on e.g. "https://huggingface.co/ybelkada/segment-anything/resolve/main/assets/car.png").
+     raw_image = raw_image.convert("RGB")
+     input_points = [[[450, 600]]]  # 2D location of a window in the image
+
+     inputs = processor(raw_image, input_points=input_points, return_tensors="pt").to(device)
+     outputs = model(**inputs)
+
+     masks = processor.image_processor.post_process_masks(
+         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
+     )
+     scores = outputs.iou_scores
+     # Tensors are not JSON-serializable; convert before handing the result to Gradio.
+     return {"Masks": [m.tolist() for m in masks], "Scores": scores.tolist()}
+
+ gr.Interface(fn=segment,  # the function defined above is segment, not predict
+              inputs=gr.Image(type="pil"),
+              # The original dict-style outputs spec is not a valid Gradio component list;
+              # a single JSON component can carry the masks/scores dictionary instead.
+              outputs=gr.JSON(label="Masks and Scores"),
+              ).launch()
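
Note: app.py replaces the raw SamModel/SamProcessor flow preserved in app.py.orig with the higher-level transformers "mask-generation" pipeline. A minimal sketch of what that pipeline returns, assuming the same facebook/sam-vit-base checkpoint as the commit (the input file name here is hypothetical):

from PIL import Image
from transformers import pipeline

generator = pipeline(task="mask-generation", model="facebook/sam-vit-base")
image = Image.open("wildlife_sample.jpg").convert("RGB")  # hypothetical local image

outputs = generator(image, points_per_batch=256)
# outputs["masks"]: one (H, W) boolean mask per detected region
# outputs["scores"]: the matching IoU confidence scores
print(len(outputs["masks"]), outputs["masks"][0].shape)

generate_mask() in app.py consumes exactly this structure, coloring each mask and overlaying them into a single RGBA image.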