# NOTE(review): the lines "Spaces / Sleeping / Sleeping" were HuggingFace Space
# page chrome captured by the scrape, not part of the program — kept only as a comment.
import os
import subprocess

# Install flash-attn at startup; FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE makes it
# use the prebuilt wheel instead of compiling CUDA kernels (no nvcc on Spaces).
# Merge the flag into the CURRENT environment: passing a bare one-key dict as
# `env=` would wipe PATH/HOME and break the pip subprocess.
# Best-effort on purpose (no check=True): the app can still run without flash-attn.
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    shell=True,
)
# Imports grouped per PEP 8: stdlib / third-party / project-local.
import os

import gradio as gr
import spaces
import torch
from huggingface_hub import snapshot_download
from transformers import T5TokenizerFast

from pix2struct.inference import DocumentQueries, DocumentQuery, ask_generator, generate
from pix2struct.modeling import Pix2StructModel
from pix2struct.processing import extract_patches
# Token for the model repo, if one is configured (e.g. a Space secret).
hub_token = os.environ.get('HUB_TOKEN')

# `token=` replaces the deprecated `use_auth_token=` kwarg in huggingface_hub.
model_path = snapshot_download('artyomxyz/pix2struct-docmatix', token=hub_token)

model = Pix2StructModel.load(model_path)
model.eval()  # inference only: disable dropout / training-mode layers
model = model.to('cuda')

# Pix2Struct reuses the T5 fast tokenizer vocabulary.
tokenizer = T5TokenizerFast.from_pretrained('google/pix2struct-base')
# `import spaces` was previously unused — on ZeroGPU Spaces the handler must be
# decorated with @spaces.GPU to get CUDA access; it is a no-op on other hardware.
@spaces.GPU
def ask(image, questions):
    """Answer newline-separated questions about a single document image.

    Args:
        image: numpy image array from the Gradio image widget.
        questions: questions, one per line.

    Returns:
        Generated answers joined with newlines, in question order.
        Empty string when no non-blank question was given.
    """
    # Skip blank lines so a stray trailing newline doesn't send an empty query.
    question_list = [q.strip() for q in questions.split('\n') if q.strip()]
    if not question_list:
        return ''

    documents = [
        DocumentQueries(
            meta=None,
            patches=extract_patches([image]),
            queries=[
                DocumentQuery(
                    meta=None,
                    generator=ask_generator(tokenizer, question),
                )
                for question in question_list
            ],
        )
    ]

    # inference_mode: no autograd bookkeeping; bfloat16 autocast for speed on CUDA.
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            result = generate(model, documents, device='cuda')

    return '\n'.join(q.output for q in result[0].queries)
# Gradio UI: one document image plus newline-separated questions in,
# newline-separated answers out.
image_input = gr.Image(type='numpy')
questions_input = gr.Textbox(label="Questions (one question per line)")

demo = gr.Interface(
    fn=ask,
    inputs=[image_input, questions_input],
    outputs='text',
)

demo.launch()