taesiri committed
Commit 942f16c
1 Parent(s): 936d897
Files changed (2):
  1. app.py +199 -44
  2. requirements.txt +2 -1
app.py CHANGED
@@ -7,11 +7,24 @@ from peft import PeftModel
 from huggingface_hub import login
 import spaces
 import json
+import matplotlib.pyplot as plt
+import io
+import base64
+
+
+def check_environment():
+    required_vars = ["HF_TOKEN"]
+    missing_vars = [var for var in required_vars if var not in os.environ]
+
+    if missing_vars:
+        raise ValueError(
+            f"Missing required environment variables: {', '.join(missing_vars)}\n"
+            "Please set the HF_TOKEN environment variable with your Hugging Face token"
+        )
 
 
 # Login to Hugging Face
-if "HF_TOKEN" not in os.environ:
-    raise ValueError("Please set the HF_TOKEN environment variable with your Hugging Face token")
+check_environment()
 login(token=os.environ["HF_TOKEN"])
 
 # Load model and processor (do this outside the inference function to avoid reloading)
@@ -28,71 +41,213 @@ model = PeftModel.from_pretrained(model, lora_weights_path)
 model.tie_weights()
 
 
+def parse_json_response(json_str):
+    if not json_str:
+        return None
+
+    try:
+        # Handle potential JSON string escaping
+        json_str = json_str.strip()
+        if json_str.startswith('"') and json_str.endswith('"'):
+            json_str = json_str[1:-1]
+
+        first_parse = json.loads(json_str)
+        json_object = (
+            json.loads(first_parse) if isinstance(first_parse, str) else first_parse
+        )
+
+        # Validate expected keys
+        required_keys = [
+            "description",
+            "scene_description",
+            "character_list",
+            "object_list",
+        ]
+        if not all(key in json_object for key in required_keys):
+            print("Missing required keys in JSON response")
+            return None
+
+        return json_object
+    except json.JSONDecodeError as e:
+        print(f"JSON parsing error: {e}")
+        return None
+    except Exception as e:
+        print(f"Unexpected error during JSON parsing: {e}")
+        return None
+
+
+def create_color_palette_image(colors):
+    if not colors or not isinstance(colors, list):
+        return None
+
+    try:
+        # Validate color format
+        for color in colors:
+            if not isinstance(color, str) or not color.startswith("#"):
+                return None
+
+        # Create figure and axis
+        fig, ax = plt.subplots(figsize=(10, 2))
+
+        # Create rectangles for each color
+        for i, color in enumerate(colors):
+            ax.add_patch(plt.Rectangle((i, 0), 1, 1, facecolor=color))
+
+        # Set the view limits and aspect ratio
+        ax.set_xlim(0, len(colors))
+        ax.set_ylim(0, 1)
+        ax.set_xticks([])
+        ax.set_yticks([])
+
+        # Save to bytes buffer
+        buf = io.BytesIO()
+        plt.savefig(buf, format="png", bbox_inches="tight", dpi=100)
+        plt.close("all")  # Close all figures to prevent memory leaks
+        plt.close(fig)  # Explicitly close the current figure
+
+        # Return the palette as a PIL image
+        buf.seek(0)
+        return Image.open(buf)  # gr.Image expects a PIL image, not a raw BytesIO buffer
+    except Exception as e:
+        print(f"Error creating color palette: {e}")
+        return None
+
+
 @spaces.GPU
 def inference(image):
+    if image is None:
+        return ("Please provide an image", "", None, None, None, "", None, "", "No image provided", "Error")  # one value per output
+
+    if not isinstance(image, Image.Image):
+        try:
+            image = Image.fromarray(image)
+        except Exception as e:
+            print(f"Image conversion error: {e}")
+            return ("Invalid image format", "", None, None, None, "", None, "", "Invalid image format", "Error")
+
     # Prepare input
     messages = [
-        {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": "Describe the image in JSON"}]}
+        {
+            "role": "user",
+            "content": [
+                {"type": "image"},
+                {"type": "text", "text": "Describe the image in JSON"},
+            ],
+        }
     ]
     input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
-    inputs = processor(image, input_text, add_special_tokens=False, return_tensors="pt").to(model.device)
-
-    # Run inference
-    with torch.no_grad():
-        output = model.generate(**inputs, max_new_tokens=2048)
-
+    try:
+        # Move inputs to the correct device
+        inputs = processor(
+            image, input_text, add_special_tokens=False, return_tensors="pt"
+        ).to(model.device)
+
+        # Run inference, then clear the CUDA cache
+        with torch.no_grad():
+            output = model.generate(**inputs, max_new_tokens=2048)
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    except Exception as e:
+        print(f"Inference error: {e}")
+        return ("Error during inference", "", None, None, None, "", None, "", "Error during inference", "Error")
+
    # Decode output
     result = processor.decode(output[0], skip_special_tokens=True)
     json_str = result.strip().split("assistant\n")[1].strip()
-
-    try:
-        # First JSON parse to handle escaped JSON string
-        first_parse = json.loads(json_str)
-
-        try:
-            # Second JSON parse to get the actual JSON object
-            json_object = json.loads(first_parse)
-            # Return indented JSON string with 2 spaces
-            return json.dumps(json_object, indent=2)
-        except json.JSONDecodeError:
-            # If second parse fails, return the result of first parse indented
-            if isinstance(first_parse, (dict, list)):
-                return json.dumps(first_parse, indent=2)
-            return first_parse
-
-    except json.JSONDecodeError:
-        # If both JSON parses fail, return original string
-        return json_str
-
-    return None  # In case of unexpected errors
-
-# Create Gradio interface using Blocks
+
+    parsed_json = parse_json_response(json_str)
+    if parsed_json:
+        # Create color palette visualization
+        colors = parsed_json.get("color_palette", [])
+        color_image = create_color_palette_image(colors)
+
+        return (
+            parsed_json.get("description", "Not available"),
+            parsed_json.get("scene_description", "Not available"),
+            json.dumps(parsed_json.get("character_list", []), indent=2),
+            json.dumps(parsed_json.get("object_list", []), indent=2),
+            json.dumps(parsed_json.get("texture_details", []), indent=2),
+            parsed_json.get("lighting_details", "Not available"),
+            color_image,
+            json_str,
+            "",  # Error box
+            "Analysis complete",  # Status
+        )
+    return ("Error parsing response", "", None, None, None, "", None, json_str, "Failed to parse JSON", "Error")
+
+
+# Update Gradio interface
 with gr.Blocks() as demo:
-    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-Medium Demo")
+    gr.Markdown("# BugsBunny-LLama-3.2-11B-Base-Medium Demo")
 
     with gr.Row():
-        # Container for the image takes full width
         with gr.Column(scale=1):
             image_input = gr.Image(
                 type="pil",
                 label="Upload Image",
                 elem_id="large-image",
-                height=500,  # Increased height for larger display
+                height=500,
             )
 
-    with gr.Row():
-        # Container for the text output takes full width
-        with gr.Column(scale=1):
-            text_output = gr.Textbox(
-                label="Response",
-                elem_id="response-text",
+    with gr.Tabs():
+        with gr.Tab("Structured Results"):
+            with gr.Column(scale=1):
+                description_output = gr.Textbox(
+                    label="Description",
+                    lines=4,
+                )
+                scene_output = gr.Textbox(
+                    label="Scene Description",
+                    lines=2,
+                )
+                characters_output = gr.JSON(
+                    label="Characters",
+                )
+                objects_output = gr.JSON(
+                    label="Objects",
+                )
+                textures_output = gr.JSON(
+                    label="Texture Details",
+                )
+                lighting_output = gr.Textbox(
+                    label="Lighting Details",
+                    lines=2,
+                )
+                color_palette_output = gr.Image(
+                    label="Color Palette",
+                    height=100,
+                )
+
+        with gr.Tab("Raw Output"):
+            raw_output = gr.Textbox(
+                label="Raw JSON Response",
                 lines=25,
-                max_lines=10,
+                max_lines=30,
             )
 
-    # Button to trigger the analysis
     submit_btn = gr.Button("Analyze Image", variant="primary")
-    submit_btn.click(fn=inference, inputs=[image_input], outputs=[text_output])
+    error_box = gr.Textbox(label="Error Messages", visible=False)
+
+    with gr.Row():
+        status_text = gr.Textbox(label="Status", value="Ready", interactive=False)
 
+    submit_btn.click(
+        fn=inference,
+        inputs=[image_input],
+        outputs=[
+            description_output,
+            scene_output,
+            characters_output,
+            objects_output,
+            textures_output,
+            lighting_output,
+            color_palette_output,
+            raw_output,
+            error_box,
+            status_text,
+        ],
+        api_name="analyze",
+    )
 
 demo.launch()
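
Note on the parsing change: the model sometimes emits a JSON object that has itself been serialized as a JSON string, so a single json.loads yields a str rather than a dict; parse_json_response therefore parses a second time when needed. A minimal standalone sketch of the same idea (the sample payloads are illustrative, not actual model output):

    import json

    def parse_twice(raw: str):
        # A double-encoded payload parses to a str first, then to a dict.
        first = json.loads(raw)
        return json.loads(first) if isinstance(first, str) else first

    print(parse_twice('{"description": "a red car"}'))            # plain JSON -> dict
    print(parse_twice('"{\\"description\\": \\"a red car\\"}"'))  # double-encoded -> dict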
requirements.txt CHANGED
@@ -9,4 +9,5 @@ accelerate
 huggingface_hub[cli]
 hf-transfer
 pillow
-gradio
+gradio
+matplotlib
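
Since matplotlib now runs inside a server process with no display, it can be worth pinning a non-interactive backend before pyplot is first imported. This is an optional hardening step, not part of this commit:

    # Optional (not in this commit): force a headless backend for server use.
    import matplotlib

    matplotlib.use("Agg")  # must run before importing matplotlib.pyplot
    import matplotlib.pyplot as plt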