artyomxyz's picture
Update app.py
3b6e19a verified
raw
history blame
1.6 kB
# Install flash-attn at startup (standard HF Spaces pattern); the CUDA build
# step is skipped because a prebuilt kernel is used at runtime.
import os
import subprocess

subprocess.run(
    'pip install flash-attn --no-build-isolation',
    # Extend the current environment rather than replacing it: passing only
    # the flag as `env` would wipe PATH/HOME, and `pip` might not even be
    # found by the child shell.
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
import os
import gradio as gr
from huggingface_hub import snapshot_download
import spaces
import torch
from transformers import T5TokenizerFast
from pix2struct.modeling import Pix2StructModel
from pix2struct.processing import extract_patches
from pix2struct.inference import ask_generator, generate, DocumentQueries, DocumentQuery
# --- One-time model setup (runs at import time) ---
# Token for the (presumably private) model repo, injected via Space secrets.
hub_token = os.environ.get('HUB_TOKEN')
# `token=` is the current name for the deprecated `use_auth_token=` argument.
model_path = snapshot_download('artyomxyz/pix2struct-docmatix', token=hub_token)
model = Pix2StructModel.load(model_path)
model.eval()  # inference only: disable dropout / training-mode behavior
model = model.to('cuda')
# Tokenizer comes from the base Pix2Struct checkpoint, not the fine-tuned repo.
tokenizer = T5TokenizerFast.from_pretrained('google/pix2struct-base')
@spaces.GPU
def ask(image, questions):
    """Answer one or more questions about a document image.

    Args:
        image: numpy image array from the Gradio image component.
        questions: newline-separated questions; blank lines are ignored.

    Returns:
        The model's answers joined with newlines, one per question,
        in the original question order.
    """
    # Strip whitespace and drop empty lines so a trailing newline in the
    # textbox does not produce an empty-prompt query for the model.
    question_list = [q.strip() for q in questions.split('\n') if q.strip()]
    documents = [
        DocumentQueries(
            meta=None,
            patches=extract_patches([image]),
            queries=[
                DocumentQuery(
                    meta=None,
                    generator=ask_generator(tokenizer, question),
                )
                for question in question_list
            ],
        )
    ]
    # inference_mode disables autograd bookkeeping; bf16 autocast cuts memory
    # and is the dtype this family of models is typically run in.
    with torch.inference_mode():
        with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
            result = generate(model, documents, device='cuda')
    return '\n'.join(q.output for q in result[0].queries)
# Gradio UI: an image upload plus a multi-line question box, wired to `ask`.
image_input = gr.Image(type='numpy')
questions_input = gr.Textbox(label="Questions (one question per line)")

demo = gr.Interface(
    fn=ask,
    inputs=[image_input, questions_input],
    outputs='text',
)
demo.launch()