Spaces:
Running
on
A10G
Running
on
A10G
Shanshan Wang
commited on
Commit
β’
bcfef20
1
Parent(s):
d6bfd67
Track binary files with Git LFS
Browse files- .gitattributes +1 -0
- app.py +69 -4
- assets/rental_application.png +3 -0
.gitattributes
CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
assets/handwritten-note-example.jpg filter=lfs diff=lfs merge=lfs -text
|
|
|
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
assets/handwritten-note-example.jpg filter=lfs diff=lfs merge=lfs -text
|
37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -30,6 +30,29 @@ example_prompts = [
|
|
30 |
]
|
31 |
|
32 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
def load_model_and_set_image_function(model_name):
|
34 |
# Get the model path from the model_paths dictionary
|
35 |
model_path = model_paths[model_name]
|
@@ -245,10 +268,34 @@ def regenerate_response(chatbot,
|
|
245 |
def clear_all():
|
246 |
return [], None, None, "" # Clear chatbot, state, reset image_input
|
247 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
248 |
# Build the Gradio interface
|
249 |
with gr.Blocks() as demo:
|
250 |
-
gr.
|
251 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
252 |
state= gr.State()
|
253 |
model_state = gr.State()
|
254 |
|
@@ -258,7 +305,12 @@ with gr.Blocks() as demo:
|
|
258 |
label="Select Model",
|
259 |
value="H2OVL-Mississippi-2B"
|
260 |
)
|
261 |
-
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
with gr.Row(equal_height=True):
|
264 |
# First column with image input
|
@@ -293,7 +345,7 @@ with gr.Blocks() as demo:
|
|
293 |
inputs=None,
|
294 |
outputs=[chatbot, state]
|
295 |
)
|
296 |
-
|
297 |
|
298 |
# Reset chatbot and state when image input changes
|
299 |
image_input.change(
|
@@ -343,6 +395,18 @@ with gr.Blocks() as demo:
|
|
343 |
label="Tile Number (default: 6)"
|
344 |
)
|
345 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
346 |
with gr.Row():
|
347 |
submit_button = gr.Button("Submit")
|
348 |
regenerate_button = gr.Button("Regenerate")
|
@@ -394,6 +458,7 @@ with gr.Blocks() as demo:
|
|
394 |
gr.Examples(
|
395 |
examples=[
|
396 |
["assets/handwritten-note-example.jpg", "Read the text on the image"],
|
|
|
397 |
["assets/receipt.jpg", "Extract the text from the image."],
|
398 |
["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
|
399 |
["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
|
|
|
30 |
]
|
31 |
|
32 |
|
33 |
+
# Function to handle task type logic
|
34 |
+
def handle_task_type(task_type, model_name):
|
35 |
+
max_new_tokens = 1024 # Default value
|
36 |
+
if task_type == "OCR":
|
37 |
+
max_new_tokens = 3072 # Adjust for OCR
|
38 |
+
return max_new_tokens
|
39 |
+
|
40 |
+
# Function to handle task type logic and default question
|
41 |
+
def handle_task_type_and_prompt(task_type, model_name):
|
42 |
+
max_new_tokens = handle_task_type(task_type, model_name)
|
43 |
+
default_question = example_prompts[0] if task_type == "OCR" else None
|
44 |
+
return max_new_tokens, default_question
|
45 |
+
|
46 |
+
def update_task_type_on_model_change(model_name):
|
47 |
+
# Set default task type and max_new_tokens based on the model
|
48 |
+
if '2b' in model_name.lower():
|
49 |
+
return "Document extractor", handle_task_type("Document extractor", model_name)
|
50 |
+
elif '0.8b' in model_name.lower():
|
51 |
+
return "OCR", handle_task_type("OCR", model_name)
|
52 |
+
else:
|
53 |
+
return "Chat", handle_task_type("Chat", model_name)
|
54 |
+
|
55 |
+
|
56 |
def load_model_and_set_image_function(model_name):
|
57 |
# Get the model path from the model_paths dictionary
|
58 |
model_path = model_paths[model_name]
|
|
|
268 |
def clear_all():
|
269 |
return [], None, None, "" # Clear chatbot, state, reset image_input
|
270 |
|
271 |
+
|
272 |
+
title_html = """
|
273 |
+
<h1> <span class="gradient-text" id="text">H2OVL-Mississippi</span><span class="plain-text">: Lightweight Vision Language Models for OCR and Doc AI tasks</span></h1>
|
274 |
+
<a href="https://huggingface.co/collections/h2oai/h2ovl-mississippi-66e492da45da0a1b7ea7cf39">[π Hugging Face]</a>
|
275 |
+
<a href="https://arxiv.org/abs/2410.13611">[π Paper]</a>
|
276 |
+
<a href="https://huggingface.co/spaces/h2oai/h2ovl-mississippi-benchmarks">[π Benchmarks]</a>
|
277 |
+
"""
|
278 |
+
|
279 |
+
|
280 |
+
|
281 |
# Build the Gradio interface
|
282 |
with gr.Blocks() as demo:
|
283 |
+
gr.HTML(title_html)
|
284 |
+
gr.HTML("""
|
285 |
+
<style>
|
286 |
+
.gradient-text {
|
287 |
+
font-size: 36px !important;
|
288 |
+
font-weight: bold !important;
|
289 |
+
}
|
290 |
+
.plain-text {
|
291 |
+
font-size: 32px !important;
|
292 |
+
}
|
293 |
+
h1 {
|
294 |
+
margin-bottom: 20px !important;
|
295 |
+
}
|
296 |
+
</style>
|
297 |
+
""")
|
298 |
+
|
299 |
state= gr.State()
|
300 |
model_state = gr.State()
|
301 |
|
|
|
305 |
label="Select Model",
|
306 |
value="H2OVL-Mississippi-2B"
|
307 |
)
|
308 |
+
|
309 |
+
task_type_dropdown = gr.Dropdown(
|
310 |
+
choices=["OCR", "Document extractor", "Chat"],
|
311 |
+
label="Select Task Type",
|
312 |
+
value="Document extractor"
|
313 |
+
)
|
314 |
|
315 |
with gr.Row(equal_height=True):
|
316 |
# First column with image input
|
|
|
345 |
inputs=None,
|
346 |
outputs=[chatbot, state]
|
347 |
)
|
348 |
+
|
349 |
|
350 |
# Reset chatbot and state when image input changes
|
351 |
image_input.change(
|
|
|
395 |
label="Tile Number (default: 6)"
|
396 |
)
|
397 |
|
398 |
+
model_dropdown.change(
|
399 |
+
fn=update_task_type_on_model_change,
|
400 |
+
inputs=[model_dropdown],
|
401 |
+
outputs=[task_type_dropdown, max_new_tokens_input]
|
402 |
+
)
|
403 |
+
|
404 |
+
task_type_dropdown.change(
|
405 |
+
fn=handle_task_type_and_prompt,
|
406 |
+
inputs=[task_type_dropdown, model_dropdown],
|
407 |
+
outputs=[max_new_tokens_input, user_input]
|
408 |
+
)
|
409 |
+
|
410 |
with gr.Row():
|
411 |
submit_button = gr.Button("Submit")
|
412 |
regenerate_button = gr.Button("Regenerate")
|
|
|
458 |
gr.Examples(
|
459 |
examples=[
|
460 |
["assets/handwritten-note-example.jpg", "Read the text on the image"],
|
461 |
+
["assets/rental_application.png", "Read the text and provide word by word ocr for the document. <doc>"],
|
462 |
["assets/receipt.jpg", "Extract the text from the image."],
|
463 |
["assets/driver_license.png", "Extract the text from the image and fill the following json {'license_number':'',\n'full_name':'',\n'date_of_birth':'',\n'address':'',\n'issue_date':'',\n'expiration_date':'',\n}"],
|
464 |
["assets/invoice.png", "Please extract the following fields, and return the result in JSON format: supplier_name, supplier_address, customer_name, customer_address, invoice_number, invoice_total_amount, invoice_tax_amount"],
|
assets/rental_application.png
ADDED
Git LFS Details
|