# cszhzleo's picture
# Update app.py
# ee3ffed verified
import streamlit as st
from utils import convert_to_base64, convert_to_html
import requests
import boto3
import sagemaker
import os
import json
# --- Deployment settings, all taken from environment variables -------------
# Each value is None when the corresponding variable is not set.
region = os.environ.get("region")
sm_endpoint_name = os.environ.get("sm_endpoint_name")
access_key = os.environ.get("access_key")
secret_key = os.environ.get("secret_key")
hf_token = os.environ.get("hf_read_access")

# --- AWS handles -------------------------------------------------------------
# One boto3 session built from the explicit credentials above; the SageMaker
# SDK session and the runtime client are both derived from it.
session = boto3.Session(
    region_name=region,
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
)
sess = sagemaker.Session(boto_session=session)
smr = session.client("sagemaker-runtime")
headers = {"Content-Type": "application/json"}

# --- Page header ---------------------------------------------------------------
# set_page_config is kept as the first Streamlit call, as in the original order.
st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")
st.title("Multimodal Model on AWS Inf2")
st.subheader("LLaVA-1.6-Mistral-7B")
st.text(
    " LLaVA (or Large Language and Vision Assistant), an open-source large"
    " multi-modal model. This demo is running on AWS Inferentia2 built with"
    " Llava1.6."
)
def upload_image():
    """Let the user choose one image and return it base64-encoded.

    Renders, in order: a selectbox of preset images, an "OR" separator and a
    file uploader. An uploaded file takes precedence over a selected preset.
    The chosen image is previewed (via ``convert_to_html``) followed by a
    horizontal rule.

    Returns:
        The value produced by ``convert_to_base64`` for the chosen image.

    The Streamlit script run is halted with ``st.stop()`` instead of returning
    when more than one file is uploaded (after showing an error) or when no
    image has been chosen at all.
    """
    placeholder = "–Select–"
    # Display name -> path of the bundled preset image (insertion order is the
    # order shown in the selectbox).
    presets = {
        "view(from internet)": "./images/view.jpg",
        "cat(from internet)": "./images/cat.jpg",
        "paris 2024(from internet)": "./images/olympic.jpg",
        "statue of liberty(from internet)": "./images/usa.jpg",
        "box(from my camera)": "./images/box.jpg",
    }

    user_option = st.selectbox("Select a preset image", [placeholder] + list(presets))
    st.text("OR")
    images = st.file_uploader(
        "Upload an image to chat about",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=True,
    ) or []

    # Explicit validation instead of `assert`: assertions are removed when
    # Python runs with -O, which would silently disable this limit.
    if len(images) > 1:
        st.error("Please upload at most 1 image")
        st.stop()

    # An upload wins over the preset; otherwise fall back to the preset, if any.
    if images:
        chosen = images[0]
    elif user_option != placeholder:
        chosen = presets[user_option]
    else:
        chosen = None

    if chosen is None:
        # Nothing to chat about yet: end this script run here.
        st.stop()

    image_b64 = convert_to_base64(chosen)

    # Preview the single selected image inside a one-column layout.
    (col,) = st.columns(1)
    col.markdown("**Image 1**")
    col.markdown(convert_to_html(image_b64), unsafe_allow_html=True)
    st.markdown("---")
    return image_b64
@st.cache_data(show_spinner=False)
def ask_llm(prompt, byte_image):
    """Query the SageMaker endpoint with a prompt and an image.

    Args:
        prompt: Text sent under the ``"prompt"`` key of the request.
        byte_image: Image sent under the ``"image"`` key of the request
            (the caller in this file passes the result of ``upload_image()``).

    Returns:
        The endpoint's response body, decoded as UTF-8 text.
    """
    sampling = {
        "top_k": 100,
        "top_p": 0.1,
        "temperature": 0.2,
    }
    request_body = json.dumps(
        {"prompt": prompt, "image": byte_image, "parameters": sampling}
    )
    result = smr.invoke_endpoint(
        EndpointName=sm_endpoint_name,
        Body=request_body,
        ContentType="application/json",
    )
    raw_reply = result["Body"].read()
    return raw_reply.decode("utf8")
def app():
    """Render the two-column chat page and answer one question per run.

    The right column (``c2``) hosts the image picker; the left column (``c1``)
    hosts the chat input, the echoed question and the model's answer. The
    script run stops early (``st.stop()``) until a question has been entered;
    ``upload_image()`` likewise stops the run until an image is available.
    """
    st.markdown("---")
    c1, c2 = st.columns(2)

    with c2:
        image_b64 = upload_image()

    with c1:
        question = st.chat_input("Ask a question about this image")
        if not question:
            st.stop()

        with st.chat_message("question"):
            # The question is free-form user input, so it is rendered as
            # regular Markdown; raw HTML is deliberately NOT enabled here.
            st.markdown(question)

        with st.spinner("Thinking..."):
            res = ask_llm(question, image_b64)

        with st.chat_message("response"):
            st.write(res)
# Build the page only when this file is executed as the main script.
if __name__ == "__main__":
    app()