"""Insightly: ask natural-language questions about an uploaded image.

A ViLT visual-question-answering model picks a short answer label for the
image/question pair, and an OpenAI LLM (through a LangChain ``LLMChain``)
turns that label into a structured response shown in the Streamlit UI.
"""

import os

import streamlit as st
import torch
from dotenv import load_dotenv
from langchain.chains import LLMChain
from langchain.llms import OpenAI
from langchain.prompts import PromptTemplate
from PIL import Image
from streamlit_extras.add_vertical_space import add_vertical_space
from transformers import ViltForQuestionAnswering, ViltProcessor

# Load environment variables before the OpenAI client below is constructed.
load_dotenv()

# st.set_page_config must be the first Streamlit command the script executes;
# calling it after st.markdown / st.sidebar calls raises StreamlitAPIException.
st.set_page_config(page_title="Insightly")

VQA_MODEL_NAME = "dandelin/vilt-b32-finetuned-vqa"


@st.cache_resource
def _load_vqa():
    """Load the ViLT processor and VQA model once per server process.

    Streamlit re-executes the whole script on every interaction; caching the
    resource avoids reloading the model weights on each rerun.
    """
    vqa_processor = ViltProcessor.from_pretrained(VQA_MODEL_NAME)
    vqa_model = ViltForQuestionAnswering.from_pretrained(VQA_MODEL_NAME)
    return vqa_processor, vqa_model


processor, model = _load_vqa()
llm = OpenAI(temperature=0.2)
prompt = PromptTemplate(
    input_variables=["question", "elements"],
    template="""Please generate a structured response using the following information:
\n\n
#Question: {question}
#Response: {elements}
\n\n
Your structured response:""",
)

# Add custom CSS to increase space between images
# NOTE(review): the CSS payload is empty in this copy of the file — restore it or drop the call.
st.markdown(
    """ """,
    unsafe_allow_html=True,
)

# Add link to the sidebar
# NOTE(review): these link snippets are empty in this copy of the file.
st.sidebar.markdown("", unsafe_allow_html=True)
st.sidebar.markdown("", unsafe_allow_html=True)
st.sidebar.markdown("", unsafe_allow_html=True)


def process_query(image, query):
    """Answer ``query`` about ``image`` and return an LLM-structured response.

    Args:
        image: PIL image the question refers to. Non-RGB images (e.g. RGBA or
            grayscale PNGs) are converted to RGB, which the ViLT processor
            expects.
        query: Natural-language question about the image.

    Returns:
        The text produced by the LLM chain from the question and the VQA
        model's top answer label.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    encoding = processor(image, query, return_tensors="pt")
    # Inference only: skip autograd bookkeeping to save memory and time.
    with torch.no_grad():
        logits = model(**encoding).logits
    idx = logits.argmax(-1).item()
    answer_label = model.config.id2label[idx]
    chain = LLMChain(llm=llm, prompt=prompt)
    return chain.run(question=query, elements=answer_label)


# Sidebar contents
with st.sidebar:
    st.image("https://i.ibb.co/bX6GdqG/insightly-wbg.png", use_column_width=True)


def main():
    """Render the image upload and question UI, then display the answer."""
    st.title("Chat With Images 🖼️")
    uploaded_file = st.file_uploader('Upload your Image', type=['png', 'jpeg', 'jpg'])
    if uploaded_file is None:
        return

    image = Image.open(uploaded_file)
    st.image(image, caption='Uploaded Image.', width=300)

    if st.button('Remove this image'):
        st.markdown(
            """""", unsafe_allow_html=True
        )
        # Stop before the question box: otherwise pressing "Remove" re-runs the
        # VQA model and the OpenAI call for the image being discarded.
        st.stop()

    query = st.text_input('Type your question here')
    if query:
        with st.spinner('Processing...'):
            answer = process_query(image, query)
        st.write(answer)


if __name__ == "__main__":
    main()