Spaces:
Sleeping
Sleeping
VictorSanh
commited on
Commit
•
844c526
1
Parent(s):
5c49818
very big update
Browse files
app.py
CHANGED
@@ -1,24 +1,72 @@
|
|
1 |
import os
|
2 |
import subprocess
|
|
|
3 |
|
|
|
|
|
|
|
4 |
from playwright.sync_api import sync_playwright
|
|
|
|
|
5 |
from typing import List
|
6 |
from PIL import Image
|
7 |
|
8 |
-
import
|
9 |
-
from gradio_client.client import DEFAULT_TEMP_DIR
|
10 |
-
from transformers import AutoProcessor, AutoModelForCausalLM
|
11 |
|
12 |
|
13 |
API_TOKEN = os.getenv("HF_AUTH_TOKEN")
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
#
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
|
24 |
IMAGE_GALLERY_PATHS = [
|
@@ -36,11 +84,13 @@ def install_playwright():
|
|
36 |
|
37 |
install_playwright()
|
38 |
|
|
|
39 |
def add_file_gallery(
|
40 |
selected_state: gr.SelectData,
|
41 |
gallery_list: List[str]
|
42 |
):
|
43 |
-
return
|
|
|
44 |
|
45 |
def render_webpage(
|
46 |
html_css_code,
|
@@ -68,6 +118,22 @@ def render_webpage(
|
|
68 |
def model_inference(
|
69 |
image,
|
70 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
CAR_COMPNAY = """<!DOCTYPE html>
|
72 |
<html lang="en">
|
73 |
<head>
|
@@ -189,8 +255,8 @@ def model_inference(
|
|
189 |
|
190 |
</body>
|
191 |
</html>"""
|
192 |
-
rendered_page = render_webpage(
|
193 |
-
return
|
194 |
|
195 |
|
196 |
generated_html = gr.Code(
|
@@ -216,7 +282,7 @@ with gr.Blocks(title="Img2html", theme=gr.themes.Base(), css=css) as demo:
|
|
216 |
with gr.Row(equal_height=True):
|
217 |
with gr.Column(scale=4, min_width=250) as upload_area:
|
218 |
imagebox = gr.Image(
|
219 |
-
type="
|
220 |
label="Screenshot to extract",
|
221 |
visible=True,
|
222 |
sources=["upload", "clipboard"],
|
@@ -253,7 +319,6 @@ with gr.Blocks(title="Img2html", theme=gr.themes.Base(), css=css) as demo:
|
|
253 |
triggers=[
|
254 |
imagebox.upload,
|
255 |
submit_btn.click,
|
256 |
-
template_gallery.select,
|
257 |
regenerate_btn.click,
|
258 |
],
|
259 |
fn=model_inference,
|
@@ -274,6 +339,10 @@ with gr.Blocks(title="Img2html", theme=gr.themes.Base(), css=css) as demo:
|
|
274 |
inputs=[template_gallery],
|
275 |
outputs=[imagebox],
|
276 |
queue=False,
|
|
|
|
|
|
|
|
|
277 |
)
|
278 |
demo.load(queue=False)
|
279 |
|
|
|
1 |
import os
|
2 |
import subprocess
|
3 |
+
import torch
|
4 |
|
5 |
+
import gradio as gr
|
6 |
+
|
7 |
+
from gradio_client.client import DEFAULT_TEMP_DIR
|
8 |
from playwright.sync_api import sync_playwright
|
9 |
+
from transformers import AutoProcessor, AutoModelForCausalLM
|
10 |
+
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension
|
11 |
from typing import List
|
12 |
from PIL import Image
|
13 |
|
14 |
+
from transformers.image_transforms import resize, to_channel_dimension_format
|
|
|
|
|
15 |
|
16 |
|
17 |
API_TOKEN = os.getenv("HF_AUTH_TOKEN")
|
18 |
+
DEVICE = torch.device("cuda")
|
19 |
+
PROCESSOR = AutoProcessor.from_pretrained(
|
20 |
+
"HuggingFaceM4/img2html",
|
21 |
+
token=API_TOKEN,
|
22 |
+
)
|
23 |
+
MODEL = AutoModelForCausalLM.from_pretrained(
|
24 |
+
"HuggingFaceM4/img2html", #TODO
|
25 |
+
token=API_TOKEN,
|
26 |
+
trust_remote_code=True,
|
27 |
+
torch_dtype=torch.bfloat16,
|
28 |
+
).to(DEVICE)
|
29 |
+
if MODEL.config.use_resampler:
|
30 |
+
image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
|
31 |
+
else:
|
32 |
+
image_seq_len = (
|
33 |
+
MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
|
34 |
+
) ** 2
|
35 |
+
BOS_TOKEN = PROCESSOR.tokenizer.bos_token
|
36 |
+
BAD_WORDS_IDS = PROCESSOR.tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids
|
37 |
+
|
38 |
+
|
39 |
+
## Utils
|
40 |
+
|
41 |
+
def convert_to_rgb(image):
|
42 |
+
# `image.convert("RGB")` would only work for .jpg images, as it creates a wrong background
|
43 |
+
# for transparent images. The call to `alpha_composite` handles this case
|
44 |
+
if image.mode == "RGB":
|
45 |
+
return image
|
46 |
+
|
47 |
+
image_rgba = image.convert("RGBA")
|
48 |
+
background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
|
49 |
+
alpha_composite = Image.alpha_composite(background, image_rgba)
|
50 |
+
alpha_composite = alpha_composite.convert("RGB")
|
51 |
+
return alpha_composite
|
52 |
+
|
53 |
+
# The processor is the same as the Idefics processor except for the BICUBIC interpolation inside siglip,
|
54 |
+
# so this is a hack in order to redefine ONLY the transform method
|
55 |
+
def custom_transform(x):
|
56 |
+
x = convert_to_rgb(x)
|
57 |
+
x = to_numpy_array(x)
|
58 |
+
x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
|
59 |
+
x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
|
60 |
+
x = PROCESSOR.image_processor.normalize(
|
61 |
+
x,
|
62 |
+
mean=PROCESSOR.image_processor.image_mean,
|
63 |
+
std=PROCESSOR.image_processor.image_std
|
64 |
+
)
|
65 |
+
x = to_channel_dimension_format(x, ChannelDimension.FIRST)
|
66 |
+
x = torch.tensor(x)
|
67 |
+
return x
|
68 |
+
|
69 |
+
## End of Utils
|
70 |
|
71 |
|
72 |
IMAGE_GALLERY_PATHS = [
|
|
|
84 |
|
85 |
install_playwright()
|
86 |
|
87 |
+
|
88 |
def add_file_gallery(
|
89 |
selected_state: gr.SelectData,
|
90 |
gallery_list: List[str]
|
91 |
):
|
92 |
+
return Image.open(gallery_list.root[selected_state.index].image.path)
|
93 |
+
|
94 |
|
95 |
def render_webpage(
|
96 |
html_css_code,
|
|
|
118 |
def model_inference(
|
119 |
image,
|
120 |
):
|
121 |
+
if image is None:
|
122 |
+
raise ValueError("`image` is None. It should be a PIL image.")
|
123 |
+
|
124 |
+
inputs = PROCESSOR.tokenizer(
|
125 |
+
f"{BOS_TOKEN}<fake_token_around_image>{'<image>' * image_seq_len}<fake_token_around_image>",
|
126 |
+
return_tensors="pt"
|
127 |
+
)
|
128 |
+
inputs["pixel_values"] = PROCESSOR.image_processor(
|
129 |
+
[image],
|
130 |
+
transform=custom_transform
|
131 |
+
)
|
132 |
+
inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
|
133 |
+
generated_ids = MODEL.generate(**inputs, bad_words_ids=BAD_WORDS_IDS)
|
134 |
+
generated_text = PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]
|
135 |
+
print(generated_text)
|
136 |
+
|
137 |
CAR_COMPNAY = """<!DOCTYPE html>
|
138 |
<html lang="en">
|
139 |
<head>
|
|
|
255 |
|
256 |
</body>
|
257 |
</html>"""
|
258 |
+
rendered_page = render_webpage(generated_text)
|
259 |
+
return generated_text, rendered_page
|
260 |
|
261 |
|
262 |
generated_html = gr.Code(
|
|
|
282 |
with gr.Row(equal_height=True):
|
283 |
with gr.Column(scale=4, min_width=250) as upload_area:
|
284 |
imagebox = gr.Image(
|
285 |
+
type="pil",
|
286 |
label="Screenshot to extract",
|
287 |
visible=True,
|
288 |
sources=["upload", "clipboard"],
|
|
|
319 |
triggers=[
|
320 |
imagebox.upload,
|
321 |
submit_btn.click,
|
|
|
322 |
regenerate_btn.click,
|
323 |
],
|
324 |
fn=model_inference,
|
|
|
339 |
inputs=[template_gallery],
|
340 |
outputs=[imagebox],
|
341 |
queue=False,
|
342 |
+
).success(
|
343 |
+
fn=model_inference,
|
344 |
+
inputs=[imagebox],
|
345 |
+
outputs=[generated_html, rendered_html],
|
346 |
)
|
347 |
demo.load(queue=False)
|
348 |
|