import hashlib
import os
import subprocess
import sys
from threading import Thread
from typing import List

# Keep `spaces` ahead of `torch`, as in the original ordering (Hugging Face
# ZeroGPU complains if CUDA is initialised before `spaces` is imported).
import spaces
import torch
import gradio as gr
from gradio_client.client import DEFAULT_TEMP_DIR
from PIL import Image
from playwright.sync_api import sync_playwright
from transformers import AutoProcessor, AutoModelForCausalLM, TextIteratorStreamer
from transformers.image_transforms import resize, to_channel_dimension_format
from transformers.image_utils import to_numpy_array, PILImageResampling, ChannelDimension

# Best-effort install of flash-attn without compiling its CUDA kernels.
# The flag is merged into the current environment: passing a dict containing
# only that flag would replace the whole environment (PATH, HOME, ...) of the
# child process. `sys.executable -m pip` targets the running interpreter.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)

DEVICE = torch.device("cuda")
PROCESSOR = AutoProcessor.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
)
MODEL = AutoModelForCausalLM.from_pretrained(
    "HuggingFaceM4/VLM_WebSight_finetuned",
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
).to(DEVICE)

# Number of image placeholder tokens the prompt must contain: the resampler's
# latent count when a resampler is used, otherwise one per vision patch.
if MODEL.config.use_resampler:
    image_seq_len = MODEL.config.perceiver_config.resampler_n_latents
else:
    image_seq_len = (
        MODEL.config.vision_config.image_size // MODEL.config.vision_config.patch_size
    ) ** 2

BOS_TOKEN = PROCESSOR.tokenizer.bos_token
# Image-related special tokens that must never appear in the generated text.
# (Assumes the Idefics-style `<image>` / `<fake_token_around_image>` tokens
# exist in this tokenizer — TODO confirm against the model card.)
BAD_WORDS_IDS = PROCESSOR.tokenizer(
    ["<image>", "<fake_token_around_image>"], add_special_tokens=False
).input_ids


## Utils

def convert_to_rgb(image):
    """Return `image` as an RGB PIL image, compositing any transparency onto white.

    A plain `image.convert("RGB")` is only correct for opaque images (e.g. .jpg);
    for transparent images it produces a wrong background. Alpha-compositing onto
    an opaque white canvas handles that case.
    """
    if image.mode == "RGB":
        return image

    image_rgba = image.convert("RGBA")
    background = Image.new("RGBA", image_rgba.size, (255, 255, 255))
    alpha_composite = Image.alpha_composite(background, image_rgba)
    alpha_composite = alpha_composite.convert("RGB")
    return alpha_composite


def custom_transform(x):
    """Image transform passed to `PROCESSOR.image_processor(..., transform=...)`.

    Steps: RGB conversion, bilinear resize to 960x960, rescale by 1/255,
    normalisation with the processor's mean/std, channels-first torch tensor.

    Per the original author's note, the processor otherwise matches the Idefics
    processor except for the (bicubic) interpolation inside siglip, so only the
    transform is redefined here.
    """
    x = convert_to_rgb(x)
    x = to_numpy_array(x)
    x = resize(x, (960, 960), resample=PILImageResampling.BILINEAR)
    x = PROCESSOR.image_processor.rescale(x, scale=1 / 255)
    x = PROCESSOR.image_processor.normalize(
        x,
        mean=PROCESSOR.image_processor.image_mean,
        std=PROCESSOR.image_processor.image_std,
    )
    x = to_channel_dimension_format(x, ChannelDimension.FIRST)
    x = torch.tensor(x)
    return x

## End of Utils


# Example screenshots shown in the gallery; sorted so the order is stable
# (os.listdir order is arbitrary).
IMAGE_GALLERY_PATHS = [
    os.path.join("example_images", ex_image)
    for ex_image in sorted(os.listdir("example_images"))
]


def install_playwright():
    """Download Playwright's browser binaries (best effort; failures are only printed).

    Runs the Playwright CLI through the current interpreter so it does not
    depend on a `playwright` executable being on PATH.
    """
    try:
        subprocess.run([sys.executable, "-m", "playwright", "install"], check=True)
        print("Playwright installation successful.")
    except (subprocess.CalledProcessError, OSError) as e:
        print(f"Error during Playwright installation: {e}")


install_playwright()


def add_file_gallery(
    selected_state: gr.SelectData,
    gallery_list: List[str]
):
    """Open the gallery image the user clicked.

    `selected_state` must keep its `gr.SelectData` annotation: Gradio uses it to
    inject the selection event. NOTE(review): despite the `List[str]` hint,
    `gallery_list` is accessed as a Gradio gallery payload (`.root[i].image.path`)
    — confirm against the installed Gradio version.
    """
    return Image.open(gallery_list.root[selected_state.index].image.path)


def render_webpage(
    html_css_code,
):
    """Render an HTML/CSS string in headless Chromium and return a full-page screenshot.

    The PNG is written to ``DEFAULT_TEMP_DIR`` under a SHA-256 digest of the
    markup, so identical markup maps to the same file in every process (the
    built-in ``hash`` of a str is salted per process). The browser and context
    are closed even if loading or screenshotting fails.
    """
    digest = hashlib.sha256(html_css_code.encode("utf-8")).hexdigest()
    output_path_screenshot = os.path.join(DEFAULT_TEMP_DIR, f"{digest}.png")

    with sync_playwright() as p:
        browser = p.chromium.launch(headless=True)
        try:
            context = browser.new_context(
                user_agent=(
                    "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/107.0.0.0"
                    " Safari/537.36"
                )
            )
            try:
                page = context.new_page()
                page.set_content(html_css_code)
                page.wait_for_load_state("networkidle")
                page.screenshot(path=output_path_screenshot, full_page=True)
            finally:
                context.close()
        finally:
            browser.close()

    return Image.open(output_path_screenshot)
# Text side of the model input: BOS, then one `<image>` placeholder per visual
# token wrapped in `<fake_token_around_image>` markers (Idefics-style layout —
# TODO confirm against the model card). It only depends on module-level
# values, so it is built once.
IMAGE_PROMPT = (
    f"{BOS_TOKEN}<fake_token_around_image>"
    f"{'<image>' * image_seq_len}<fake_token_around_image>"
)


@spaces.GPU(duration=300)
def model_inference(
    image,
):
    """Generate HTML/CSS from a screenshot, streaming the text, then render it.

    Yields ``(generated_text, rendered_image)`` pairs. While tokens stream,
    ``rendered_image`` is ``None``. After generation ends (EOS or ``max_length``),
    one final pair carries a screenshot of the rendered page.

    Raises:
        ValueError: if ``image`` is None.
        Exception: any error raised by ``MODEL.generate`` is re-raised here.
    """
    if image is None:
        raise ValueError("`image` is None. It should be a PIL image.")

    inputs = PROCESSOR.tokenizer(
        IMAGE_PROMPT,
        return_tensors="pt",
        add_special_tokens=False,
    )
    inputs["pixel_values"] = PROCESSOR.image_processor(
        [image], transform=custom_transform
    )
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    streamer = TextIteratorStreamer(
        PROCESSOR.tokenizer,
        skip_prompt=True,
    )
    generation_kwargs = dict(
        inputs,
        bad_words_ids=BAD_WORDS_IDS,
        max_length=4096,
        streamer=streamer,
    )

    # generate() runs in a worker thread so this generator can consume the
    # streamer. If generate() raised, the streamer would never get its end
    # signal and the loop below would block forever, so the error is recorded,
    # the stream is closed, and the error is re-raised after the loop.
    errors = []

    def _generate():
        try:
            MODEL.generate(**generation_kwargs)
        except Exception as exc:  # re-raised in the calling thread below
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    # Special tokens are not skipped by the streamer, so strip the EOS marker.
    eos_token = PROCESSOR.tokenizer.eos_token
    generated_text = ""
    for new_text in streamer:
        if eos_token:
            new_text = new_text.replace(eos_token, "")
        generated_text += new_text
        yield generated_text, None

    thread.join()
    if errors:
        raise errors[0]

    # Render exactly once, on the complete document (this also covers runs that
    # stop at max_length without emitting EOS).
    yield generated_text, render_webpage(generated_text)


generated_html = gr.Code(
    label="Extracted HTML",
    elem_id="generated_html",
)
rendered_html = gr.Image(
    label="Rendered HTML",
    show_download_button=False,
    show_share_button=False,
)

css = """
.gradio-container{max-width: 1000px!important}
h1{display: flex;align-items: center;justify-content: center;gap: .25em}
*{transition: width 0.5s ease, flex-grow 0.5s ease}
"""


with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as demo:
    gr.Markdown(
        "Since the model used for this demo *does not generate images*, it is more effective to input standalone website elements or sites with minimal image content."
    )
    with gr.Row(equal_height=True):
        with gr.Column(scale=4, min_width=250):
            imagebox = gr.Image(
                type="pil",
                label="Screenshot to extract",
                visible=True,
                sources=["upload", "clipboard"],
            )
            with gr.Group():
                with gr.Row():
                    submit_btn = gr.Button(
                        value="▶️ Submit", visible=True, min_width=120
                    )
                    clear_btn = gr.ClearButton(
                        [imagebox, generated_html, rendered_html], value="🧹 Clear", min_width=120
                    )
                    regenerate_btn = gr.Button(
                        value="🔄 Regenerate", visible=True, min_width=120
                    )
        with gr.Column(scale=4):
            rendered_html.render()

    with gr.Row():
        generated_html.render()

    with gr.Row():
        template_gallery = gr.Gallery(
            value=IMAGE_GALLERY_PATHS,
            label="Templates Gallery",
            allow_preview=False,
            columns=5,
            elem_id="gallery",
            show_share_button=False,
            height=400,
        )

    # One listener for upload, submit and regenerate. Do not also attach a
    # separate regenerate_btn.click listener: each click would run inference twice.
    gr.on(
        triggers=[
            imagebox.upload,
            submit_btn.click,
            regenerate_btn.click,
        ],
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )
    # Picking a template loads it into the image box, then runs inference on it.
    template_gallery.select(
        fn=add_file_gallery,
        inputs=[template_gallery],
        outputs=[imagebox],
    ).success(
        fn=model_inference,
        inputs=[imagebox],
        outputs=[generated_html, rendered_html],
    )

demo.queue(max_size=40, api_open=False)
demo.launch(max_threads=400)