Spaces:
Sleeping
Sleeping
Andrei Boiarov
committed on
Commit
•
859c3ef
1
Parent(s):
3d04a5c
Update app file
Browse files- .gitignore +2 -1
- app.py +61 -4
- requirements.txt +2 -0
.gitignore
CHANGED
@@ -1 +1,2 @@
|
|
1 |
-
.idea/
|
|
|
|
1 |
+
.idea/
|
2 |
+
flagged/
|
app.py
CHANGED
@@ -1,7 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
|
3 |
-
|
4 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
|
6 |
-
|
7 |
-
|
|
|
1 |
+
from transformers import ViTFeatureExtractor, ViTMAEForPreTraining
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from PIL import Image
|
5 |
+
|
6 |
import gradio as gr
|
7 |
|
8 |
+
# Single checkpoint id used for both the preprocessing config and the weights.
_CHECKPOINT = 'andrewbo29/vit-mae-base-formula1'

feature_extractor = ViTFeatureExtractor.from_pretrained(_CHECKPOINT)
model = ViTMAEForPreTraining.from_pretrained(_CHECKPOINT)

# Per-channel normalisation statistics taken from the feature extractor;
# prep_image() uses them to map normalised pixels back to the 0-255 range.
imagenet_mean = np.array(feature_extractor.image_mean)
imagenet_std = np.array(feature_extractor.image_std)
|
13 |
+
|
14 |
+
|
15 |
+
def prep_image(image, mean=None, std=None):
    """Undo channel normalisation and return a displayable 8-bit image array.

    Computes ``(image * std + mean) * 255``, clips the result to [0, 255]
    and converts it to a ``uint8`` NumPy array.

    Args:
        image: torch tensor whose last axis broadcasts against ``mean`` and
            ``std`` (the caller passes channel-last images).
        mean: per-channel mean; defaults to the module-level ``imagenet_mean``.
        std: per-channel std; defaults to the module-level ``imagenet_std``.

    Returns:
        ``numpy.ndarray`` of dtype ``uint8`` with the same shape as ``image``.
    """
    if mean is None:
        mean = imagenet_mean
    if std is None:
        std = imagenet_std

    mean_t = torch.as_tensor(mean, device=image.device)
    std_t = torch.as_tensor(std, device=image.device)

    pixels = torch.clip((image * std_t + mean_t) * 255, 0, 255)
    # uint8 rather than int32: the values are already limited to 0-255, and
    # uint8 is the dtype image consumers treat as "use as-is" without rescaling.
    return pixels.to(torch.uint8).cpu().numpy()
|
17 |
+
|
18 |
+
|
19 |
+
def reconstruct(img):
    """Run the MAE model on one image and build the three gallery views.

    Args:
        img: image as a NumPy array (what the ``gr.Image`` input delivers),
            or ``None`` when no image was supplied.

    Returns:
        A list of ``(array, caption)`` pairs for the gallery: the masked
        input, the raw reconstruction, and the reconstruction pasted
        together with the visible patches.

    Raises:
        gr.Error: if ``img`` is ``None``.
    """
    if img is None:
        # Without this guard Image.fromarray(None) fails with an opaque
        # traceback; surface a readable message in the UI instead.
        raise gr.Error("Please upload an image first.")

    image = Image.fromarray(img)
    pixel_values = feature_extractor(image, return_tensors='pt').pixel_values

    # Inference only: no autograd graph is needed, so skip building one.
    with torch.no_grad():
        outputs = model(pixel_values)
        y = model.unpatchify(outputs.logits)
        y = torch.einsum('nchw->nhwc', y).cpu()

        # visualize the mask
        mask = outputs.mask
        mask = mask.unsqueeze(-1).repeat(1, 1, model.config.patch_size ** 2 * 3)  # (N, H*W, p*p*3)
        mask = model.unpatchify(mask)  # 1 is removing, 0 is keeping
        mask = torch.einsum('nchw->nhwc', mask).cpu()

        x = torch.einsum('nchw->nhwc', pixel_values).cpu()

        # masked image
        im_masked = x * (1 - mask)

        # MAE reconstruction pasted with visible patches
        im_paste = x * (1 - mask) + y * mask

    out_masked = prep_image(im_masked[0])
    out_rec = prep_image(y[0])
    out_rec_vis = prep_image(im_paste[0])

    return [(out_masked, 'masked'), (out_rec, 'reconstruction'), (out_rec_vis, 'reconstruction + visible')]
|
46 |
+
|
47 |
+
|
48 |
+
# UI layout: an image input next to a trigger button, with a gallery below
# that receives the three views returned by reconstruct().
# NOTE(review): the label/button wording reads like a text-to-image template;
# confirm the intended text.
with gr.Blocks() as demo:
    with gr.Column(variant="panel"):
        with gr.Row():
            img = gr.Image(label="Enter your prompt", container=False)
            btn = gr.Button("Generate image", scale=0)

        gallery = gr.Gallery(
            label="Generated images",
            show_label=False,
            elem_id="gallery",
            columns=[3],
            rows=[1],
            object_fit="contain",
            height='auto',
            container=True,
        )

    # Clicking the button feeds the image to reconstruct() and shows its
    # result in the gallery.
    btn.click(reconstruct, img, gallery)

if __name__ == "__main__":
    demo.launch()
|
requirements.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
torch
|
2 |
+
transformers
|