Kiddo commited on
Commit
b6fc9a7
·
1 Parent(s): e90ad78

Vision Processor

Browse files
Files changed (1) hide show
  1. app.py +78 -2
app.py CHANGED
@@ -1,4 +1,80 @@
1
  import streamlit as st
 
 
 
2
 
3
- x = st.slider('Select a value')
4
- st.write(x, 'squared is', x * x)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import torch
3
+ from PIL import Image
4
+ from transformers import AutoProcessor, AutoModelForImageTextToText
5
 
6
+ # Set page configuration
7
+ st.set_page_config(page_title="Llama 3.2 Vision Model", page_icon="???")
8
+
9
+ # Title and description
10
+ st.title("Llama 3.2 Vision Model Inference")
11
+ st.write("Upload an image and provide a prompt to get model insights!")
12
+
13
+ # Load model and processor (consider caching to improve performance)
14
+ @st.cache_resource
15
+ def load_model():
16
+ try:
17
+ processor = AutoProcessor.from_pretrained("meta-llama/Llama-3.2-90B-Vision-Instruct")
18
+ model = AutoModelForImageTextToText.from_pretrained("meta-llama/Llama-3.2-90B-Vision-Instruct")
19
+ return processor, model
20
+ except Exception as e:
21
+ st.error(f"Error loading model: {e}")
22
+ return None, None
23
+
24
+ # Inference function
25
+ def generate_response(image, prompt):
26
+ processor, model = load_model()
27
+
28
+ if not processor or not model:
29
+ return "Model could not be loaded."
30
+
31
+ try:
32
+ # Prepare inputs
33
+ inputs = processor(images=image, text=prompt, return_tensors="pt")
34
+
35
+ # Generate response
36
+ with torch.no_grad():
37
+ outputs = model.generate(**inputs)
38
+
39
+ # Decode the response
40
+ response = processor.decode(outputs[0], skip_special_tokens=True)
41
+ return response
42
+
43
+ except Exception as e:
44
+ st.error(f"Error during inference: {e}")
45
+ return "An error occurred during image processing."
46
+
47
+ # Sidebar for user inputs
48
+ st.sidebar.header("Image and Prompt")
49
+
50
+ # Image uploader
51
+ uploaded_file = st.sidebar.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
52
+
53
+ # Prompt input
54
+ prompt = st.sidebar.text_input("Enter your prompt:",
55
+ placeholder="Describe what you want to know about the image")
56
+
57
+ # Main content area
58
+ if uploaded_file is not None:
59
+ # Display uploaded image
60
+ image = Image.open(uploaded_file)
61
+ st.image(image, caption="Uploaded Image", use_column_width=True)
62
+
63
+ # Generate button
64
+ if st.sidebar.button("Generate Response"):
65
+ if prompt:
66
+ # Show loading spinner
67
+ with st.spinner("Generating response..."):
68
+ response = generate_response(image, prompt)
69
+
70
+ # Display response
71
+ st.subheader("Model Response")
72
+ st.write(response)
73
+ else:
74
+ st.warning("Please enter a prompt!")
75
+ else:
76
+ st.info("Upload an image and enter a prompt to get started!")
77
+
78
+ # Additional error handling and information
79
+ st.sidebar.markdown("---")
80
+ st.sidebar.info("Note: Model performance depends on image quality and prompt specificity.")