Spaces:
Running
Running
import streamlit as st | |
from utils import convert_to_base64, convert_to_html | |
import requests | |
import boto3 | |
import sagemaker | |
import os | |
import json | |
# --- Deployment settings, all taken from environment variables -------------
# Each value is None when the corresponding variable is unset.
region, sm_endpoint_name, access_key, secret_key, hf_token = (
    os.getenv(var_name)
    for var_name in (
        "region",
        "sm_endpoint_name",
        "access_key",
        "secret_key",
        "hf_read_access",
    )
)

# --- AWS clients ------------------------------------------------------------
session = boto3.Session(
    region_name=region,
    aws_access_key_id=access_key,
    aws_secret_access_key=secret_key,
)
# NOTE(review): `sess`, `headers` and `hf_token` are not referenced elsewhere
# in this script; kept in case other modules import them.
sess = sagemaker.Session(boto_session=session)
smr = session.client("sagemaker-runtime")  # used by ask_llm to call the endpoint
headers = {"Content-Type": "application/json"}

# --- Page chrome (set_page_config must be the first Streamlit call) --------
st.set_page_config(page_title="AWS Inferentia2 Demo", layout="wide")
st.title("Multimodal Model on AWS Inf2")
st.subheader("LLaVA-1.6-Mistral-7B")
def upload_image():
    """Let the user pick a preset image or upload one, preview it, and return it.

    Renders a preset-image selectbox followed by a file uploader. An uploaded
    file takes precedence over the preset selection. The chosen image is
    converted with ``convert_to_base64`` and previewed with ``convert_to_html``.

    Returns:
        The base64 form of the chosen image, as produced by ``convert_to_base64``.

    Note:
        Calls ``st.stop()`` to end the current script run when more than one
        file is uploaded or when no image has been chosen yet, so a value is
        only returned once exactly one image is available.
    """
    # Display label -> local path of the bundled sample images.
    presets = {
        "view(from internet)": "./images/view.jpg",
        "cat(from internet)": "./images/cat.jpg",
        "paris 2024(from internet)": "./images/olympic.jpg",
        "statue of liberty(from internet)": "./images/usa.jpg",
        "box(from my camera)": "./images/box.jpg",
    }
    placeholder = "–Select–"
    user_option = st.selectbox(
        "Select a preset image", [placeholder] + list(presets)
    )

    st.text("OR")
    # `or []` guards against a None result so the length check cannot fail.
    uploads = st.file_uploader(
        "Upload an image to chat about",
        type=["png", "jpg", "jpeg"],
        accept_multiple_files=True,
    ) or []

    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # which would silently disable the one-image limit.
    if len(uploads) > 1:
        st.error("Please upload at most 1 image")
        st.stop()

    # An uploaded file wins over the preset choice.
    if uploads:
        image = uploads[0]
    elif user_option != placeholder:
        image = presets[user_option]
    else:
        st.stop()  # no image chosen yet: end this script run here
        return None

    image_b64 = convert_to_base64(image)

    (col,) = st.columns(1)
    col.markdown("**Image 1**")
    col.markdown(convert_to_html(image_b64), unsafe_allow_html=True)
    st.markdown("---")
    return image_b64
def ask_llm(prompt, byte_image, *, top_k=100, top_p=0.1, temperature=0.2,
            client=None, endpoint_name=None):
    """Send a prompt plus an image to the SageMaker endpoint and return the reply.

    Args:
        prompt: The user's question text.
        byte_image: Image payload, sent unchanged under the ``"image"`` key
            (in this app, the value returned by ``upload_image``).
        top_k: Sampling parameter forwarded in ``"parameters"``.
        top_p: Sampling parameter forwarded in ``"parameters"``.
        temperature: Sampling parameter forwarded in ``"parameters"``.
        client: Object exposing ``invoke_endpoint`` (e.g. a boto3
            ``sagemaker-runtime`` client). Defaults to the module-level ``smr``.
        endpoint_name: Name of the endpoint to invoke. Defaults to the
            module-level ``sm_endpoint_name``.

    Returns:
        The response body decoded as UTF-8 text.
    """
    # Resolve defaults lazily so callers may inject their own client/endpoint.
    if client is None:
        client = smr
    if endpoint_name is None:
        endpoint_name = sm_endpoint_name

    payload = {
        "prompt": prompt,
        "image": byte_image,
        "parameters": {
            "top_k": top_k,
            "top_p": top_p,
            "temperature": temperature,
        },
    }
    response = client.invoke_endpoint(
        EndpointName=endpoint_name,
        Body=json.dumps(payload),
        ContentType="application/json",
    )
    return response["Body"].read().decode("utf8")
def app():
    """Lay out the demo page and answer one question about the chosen image.

    The right column hosts the image picker (``upload_image``) and the left
    column the chat. The script run is halted with ``st.stop()`` until the
    user has supplied a question; the answer comes from ``ask_llm``.
    """
    st.markdown("---")
    chat_col, image_col = st.columns(2)
    with image_col:
        image_b64 = upload_image()
    with chat_col:
        question = st.chat_input("Ask a question about this image")
    if not question:
        st.stop()
    with chat_col:
        with st.chat_message("question"):
            # The question is untrusted user text: render it as plain Markdown,
            # never with unsafe_allow_html, so typed HTML cannot be injected.
            st.markdown(question)
        with st.spinner("Thinking..."):
            res = ask_llm(question, image_b64)
        with st.chat_message("response"):
            st.write(res)
# Script entry point: build the page when this file is executed directly.
if __name__ == "__main__":
    app()