# Lutece Vision Space - Streamlit demo app (revision d5b5b3a, 3.31 kB)
import json
import os
import subprocess
import sys

import streamlit as st
import torch
from PIL import Image
from transformers import AutoProcessor, AutoModelForCausalLM, AutoConfig
# Install flash-attn at startup (presumably needed by the model's remote code
# loaded below with trust_remote_code=True - confirm against the model repo).
# FLASH_ATTENTION_SKIP_CUDA_BUILD is passed so the package build is asked to
# skip compiling CUDA kernels.
#
# The variable is merged into a copy of the current environment: passing a
# bare dict as `env` would REPLACE the whole environment and drop PATH, HOME,
# proxy settings, etc. pip is invoked through the running interpreter so the
# package lands in the environment this app actually uses. The call stays
# best-effort (no check=True): a failed install must not stop the app here.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "flash-attn", "--no-build-isolation"],
    env={**os.environ, "FLASH_ATTENTION_SKIP_CUDA_BUILD": "TRUE"},
    check=False,
)
# Build the model/processor pair; st.cache_resource keeps it cached
@st.cache_resource
def load_model_and_processor():
    """Load the Lutece-Vision-Base model together with its processor.

    The configuration is fetched from "microsoft/Florence-2-base-ft" and its
    vision ``model_type`` is set to "davit" before it is handed to both
    loaders. The model is moved to CUDA when available (CPU otherwise) and
    switched to eval mode.

    Returns:
        A ``(model, processor, device)`` tuple.
    """
    backend = "cuda" if torch.cuda.is_available() else "cpu"
    target_device = torch.device(backend)

    florence_config = AutoConfig.from_pretrained("microsoft/Florence-2-base-ft", trust_remote_code=True)
    florence_config.vision_config.model_type = "davit"

    network = AutoModelForCausalLM.from_pretrained(
        "sujet-ai/Lutece-Vision-Base",
        config=florence_config,
        trust_remote_code=True,
    )
    network = network.to(target_device)
    network = network.eval()

    preprocessor = AutoProcessor.from_pretrained(
        "sujet-ai/Lutece-Vision-Base",
        config=florence_config,
        trust_remote_code=True,
    )
    return network, preprocessor, target_device
# Run one question/image pair through the model and return the parsed answer
def generate_answer(model, processor, device, image, prompt):
    """Generate the model's answer to ``prompt`` about ``image``.

    Args:
        model: Object exposing ``generate(...)``.
        processor: Callable that encodes text and image; also provides
            ``batch_decode`` and ``post_process_generation``.
        device: Device the encoded inputs are moved to via ``.to(device)``.
        image: Image exposing ``width`` and ``height``.
        prompt: Question text passed to the processor.

    Returns:
        The entry stored under the "<FinanceQA>" key of the post-processed
        generation result.
    """
    task = "<FinanceQA>"

    encoded = processor(text=prompt, images=image, return_tensors="pt")
    encoded = encoded.to(device)

    # Beam search without sampling, capped at 1024 new tokens.
    generation_kwargs = {
        "input_ids": encoded["input_ids"],
        "pixel_values": encoded["pixel_values"],
        "max_new_tokens": 1024,
        "do_sample": False,
        "num_beams": 3,
    }
    output_ids = model.generate(**generation_kwargs)

    # Special tokens are kept in the decoded text handed to post-processing.
    decoded_batch = processor.batch_decode(output_ids, skip_special_tokens=False)
    first_text = decoded_batch[0]

    dimensions = (image.width, image.height)
    parsed = processor.post_process_generation(first_text, task=task, image_size=dimensions)
    return parsed[task]
# Render a (possibly nested) config mapping as flat markdown lines
def display_config(config, depth=0):
    """Write every entry of ``config`` to the page with ``st.markdown``.

    Nested dicts are rendered as a bold key line followed by their own
    entries, recursing with ``depth + 1``; other values are rendered as
    ``key: value``. Each line is prefixed with one space per nesting level.

    Args:
        config: Mapping of configuration keys to values.
        depth: Current nesting level (0 for the top-level call).
    """
    prefix = ' ' * depth
    for name, entry in config.items():
        if not isinstance(entry, dict):
            st.markdown(f"{prefix}{name}: {entry}")
            continue
        st.markdown(f"{prefix}**{name}**:")
        display_config(entry, depth + 1)
# Streamlit app
def main():
    """Render the Lutece-Vision-Base demo page.

    Configures the page and sidebar, loads the model via
    ``load_model_and_processor()``, then lets the user upload a PNG/JPEG
    document and ask a question about it. The answer from
    ``generate_answer()`` is shown in a success box.

    An upload that cannot be decoded as an image is reported with an error
    message, and an empty question is reported with a warning; neither case
    reaches the model.
    """
    st.set_page_config(
        page_title="Lutece-Vision-Base Demo",
        page_icon="🗼",
        layout="wide",
        initial_sidebar_state="expanded",
    )

    # Title and description
    st.title("🗼 Lutece-Vision-Base Demo")
    st.markdown("Upload a financial document and ask questions about it!")

    # Sidebar with SujetAI watermark
    st.sidebar.image("sujetAI.svg", use_column_width=True)
    st.sidebar.markdown("---")
    st.sidebar.markdown("Our website : [sujet.ai](https://sujet.ai)")

    # Load model and processor
    model, processor, device = load_model_and_processor()

    # File uploader for document
    uploaded_file = st.file_uploader("📄 Upload a financial document", type=["png", "jpg", "jpeg"])
    if uploaded_file is None:
        return

    # The extension filter does not guarantee decodable content, so report
    # an unreadable file instead of letting the exception surface.
    try:
        image = Image.open(uploaded_file).convert('RGB')
    except OSError:
        st.error("The uploaded file could not be read as an image.")
        return
    st.image(image, caption="Uploaded Document", use_column_width=True)

    # Question input
    question = st.text_input("❓ Ask a question about the document", "")

    if st.button("🔍 Generate Answer"):
        if not question.strip():
            # Nothing to ask: skip the generation call.
            st.warning("Please enter a question before generating an answer.")
        else:
            with st.spinner("Generating answer..."):
                answer = generate_answer(model, processor, device, image, question)
            st.success(f"## 💡 {answer}")

    # # Model configuration viewer
    # with st.expander("🔧 Model Configuration"):
    #     config_dict = model.config.to_dict()
    #     display_config(config_dict)
# Script entry point: build the page only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()