File size: 1,597 Bytes
98159fd
3b6e19a
 
759290e
 
 
 
45a5093
4bab72e
 
759290e
 
45a5093
4bab72e
98159fd
44d7e8f
759290e
 
 
b618d36
 
 
45a5093
 
 
4bab72e
 
45a5093
 
 
 
 
 
 
4bab72e
45a5093
4bab72e
45a5093
 
 
4bab72e
b618d36
 
4bab72e
44d7e8f
759290e
45a5093
759290e
 
4bab72e
759290e
 
 
44d7e8f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
import os
import subprocess
import sys

# Install flash-attn at startup, into the same interpreter that runs this app.
# FLASH_ATTENTION_SKIP_CUDA_BUILD=TRUE is an environment flag read by
# flash-attn's build; presumably it skips compiling the CUDA extension here
# (NOTE(review): confirm against the flash-attn installation docs).
# `--no-build-isolation` lets the build use packages already installed in this
# environment (e.g. torch) instead of a fresh isolated build environment.
_flash_attn_install = subprocess.run(
    [sys.executable, '-m', 'pip', 'install', 'flash-attn', '--no-build-isolation'],
    # Merge the flag into the inherited environment: passing a bare dict as
    # `env` would REPLACE the environment and drop PATH, HOME, CUDA_HOME, ...
    env={**os.environ, 'FLASH_ATTENTION_SKIP_CUDA_BUILD': 'TRUE'},
    # Best-effort, as before: a failed install does not abort startup.
    check=False,
)
if _flash_attn_install.returncode != 0:
    # Surface the failure instead of swallowing it silently.
    print(
        f"warning: flash-attn install failed "
        f"(exit code {_flash_attn_install.returncode}); continuing anyway",
        file=sys.stderr,
    )

import os

import gradio as gr
from huggingface_hub import snapshot_download
import spaces
import torch
from transformers import T5TokenizerFast

from pix2struct.modeling import Pix2StructModel
from pix2struct.processing import extract_patches
from pix2struct.inference import ask_generator, generate, DocumentQueries, DocumentQuery


# Optional Hugging Face access token taken from the environment; None means
# anonymous access. NOTE(review): presumably needed because the model repo is
# private/gated — confirm.
hub_token = os.environ.get('HUB_TOKEN')
# Fetch the model repository and get its local directory. `token=` replaces the
# deprecated `use_auth_token=` keyword of huggingface_hub.
model_path = snapshot_download('artyomxyz/pix2struct-docmatix', token=hub_token)

# Load the weights from the downloaded directory, switch to eval mode and move
# to the GPU. NOTE(review): assumes Pix2StructModel behaves like a torch
# nn.Module (eval()/to() semantics) — confirm in pix2struct.modeling.
model = Pix2StructModel.load(model_path)
model.eval()
model = model.to('cuda')
# Tokenizer passed to ask_generator() to turn each question into model input.
tokenizer = T5TokenizerFast.from_pretrained('google/pix2struct-base')

@spaces.GPU
def ask(image, questions):
    """Answer one or more free-text questions about a single document image.

    Args:
        image: Image from the ``gr.Image(type='numpy')`` input (a NumPy
            array), or ``None`` when the user did not upload anything.
        questions: Free text with one question per line. Blank lines are
            ignored, and surrounding whitespace (including a ``\\r`` from CRLF
            input) is stripped from each question.

    Returns:
        The generated answers joined with newlines, one per non-blank
        question, in the order ``generate`` returns the queries.

    Raises:
        gr.Error: If no image was provided or no non-blank question was
            entered; Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload a document image.")

    # splitlines() handles '\n' and '\r\n'; skipping blank lines avoids
    # spending GPU time generating answers for empty prompts (e.g. the empty
    # "question" produced by a trailing newline).
    question_list = [
        line.strip() for line in (questions or '').splitlines() if line.strip()
    ]
    if not question_list:
        raise gr.Error("Please enter at least one question (one per line).")

    # A single document: its patches are extracted once and every question is
    # attached to it as a separate query.
    documents = [
        DocumentQueries(
            meta=None,
            patches=extract_patches([image]),
            queries=[
                DocumentQuery(meta=None, generator=ask_generator(tokenizer, question))
                for question in question_list
            ],
        )
    ]

    # No autograd tracking; forward passes run under bfloat16 autocast on CUDA.
    with torch.inference_mode(), torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        result = generate(model, documents, device='cuda')

    # Exactly one document was submitted, so result[0] holds its answered queries.
    return '\n'.join(query.output for query in result[0].queries)

# UI wiring: a document image and a multi-line question box are passed to
# `ask`, and its newline-joined answers are shown as plain text.
# Components are created in the same order as the inputs they feed.
_document_image = gr.Image(type='numpy')
_question_box = gr.Textbox(label="Questions (one question per line)")

demo = gr.Interface(ask, [_document_image, _question_box], outputs='text')

# Start the Gradio web server.
demo.launch()