"""Gradio demo: turn a chart image into a Vega-Lite spec and re-render it."""

import ast
import json
import logging
from io import BytesIO

import gradio as gr
import torch
import vl_convert as vlc
from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

logger = logging.getLogger(__name__)

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load the processor and model
processor = AutoProcessor.from_pretrained("google/matcha-base")
processor.image_processor.is_vqa = False

model = Pix2StructForConditionalGeneration.from_pretrained("martinsinnona/visdecode_B").to(device)
model.eval()


def generate(image):
    """Run the model on *image* and render the Vega-Lite spec it produces.

    Args:
        image: PIL image supplied by the Gradio input component, or ``None``
            when the user submits without uploading anything.

    Returns:
        A ``(generated_caption, vega_image)`` tuple. ``generated_caption`` is
        the decoded model output; ``vega_image`` is the rendered chart as a
        PIL image, or ``None`` when the model output could not be parsed or
        rendered.

    Raises:
        gr.Error: If no image was provided.
    """
    if image is None:
        # Give the user an actionable message instead of an opaque failure
        # from inside the processor.
        raise gr.Error("Please upload an image first.")

    inputs = processor(images=image, return_tensors="pt", max_patches=1024).to(device)
    generated_ids = model.generate(
        flattened_patches=inputs.flattened_patches,
        attention_mask=inputs.attention_mask,
        max_length=600,
    )
    generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]

    # Generate the Vega image. The model output is free text and may not be a
    # well-formed spec; in that case still return the text so the user can
    # inspect it, rather than failing the whole request.
    # NOTE(review): assumes vl_convert reports invalid specs as ValueError --
    # confirm against the installed version.
    try:
        vega = string_to_vega(generated_caption)
        vega_image = draw_vega(vega)
    except (ValueError, KeyError, TypeError) as err:
        logger.warning("Could not render model output as Vega-Lite: %r", err)
        vega_image = None

    return generated_caption, vega_image


def draw_vega(vega, scale=3):
    """Render a Vega-Lite spec to a PIL image.

    Args:
        vega: JSON-serializable Vega-Lite specification.
        scale: Scale factor forwarded to ``vlc.vegalite_to_png``.

    Returns:
        The rendered chart, opened from the PNG bytes as a PIL image.
    """
    spec = json.dumps(vega, indent=4)
    png_data = vlc.vegalite_to_png(vl_spec=spec, scale=scale)
    return Image.open(BytesIO(png_data))


def _parse_spec(string):
    """Parse the model's textual output into Python data.

    Tries strict JSON first, then a Python literal (single-quoted keys, with
    apostrophes allowed inside double-quoted values), and only then falls back
    to blindly swapping single quotes for double quotes.

    Raises:
        json.JSONDecodeError: If none of the strategies can parse *string*.
    """
    try:
        return json.loads(string)
    except json.JSONDecodeError:
        pass

    try:
        # literal_eval only evaluates literals, so it is safe on model output.
        return ast.literal_eval(string)
    except (ValueError, SyntaxError, TypeError):
        pass

    # Last resort: accepts e.g. single-quoted text that uses JSON's
    # true/false/null, which is neither valid JSON nor a Python literal.
    return json.loads(string.replace("'", "\""))


def string_to_vega(string):
    """Convert the model's textual output into a renderable Vega-Lite spec.

    The data rows are keyed by ``"x"`` and ``"y"``. For each axis, an empty
    encoding field is replaced by the axis name with an empty title; otherwise
    the corresponding key in every data row is renamed to the encoding field.

    Args:
        string: Spec text, either JSON or a Python-style dict literal.

    Returns:
        The parsed spec with encoding fields and data keys made consistent.

    Raises:
        json.JSONDecodeError: If *string* cannot be parsed.
        KeyError: If the spec lacks the expected ``encoding`` entries, or
            lacks ``data``/``values`` when a rename is needed.
    """
    vega = _parse_spec(string)
    encoding = vega["encoding"]

    renames = {}
    for axis in ("x", "y"):
        field = encoding[axis]["field"]
        if field == "":
            encoding[axis]["field"] = axis
            encoding[axis]["title"] = ""
        elif field != axis:
            renames[axis] = field

    if renames:
        # Rebuild each row in a single pass. Renaming in place one axis at a
        # time loses data when the x field is itself called "y": the renamed
        # x value would overwrite the y column before y is processed.
        vega["data"]["values"] = [
            {renames.get(key, key): value for key, value in entry.items()}
            for entry in vega["data"]["values"]
        ]

    return vega


# Create the Gradio interface
iface = gr.Interface(
    fn=generate,
    inputs=gr.Image(type="pil"),
    outputs=[gr.Textbox(), gr.Image(type="pil", label="Generated Vega Image")],
    title="Image to Vega-Lite",
    description="Upload an image to generate vega-lite",
)

# Launch the interface
if __name__ == "__main__":
    iface.launch(share=True)