Nishgop committed on
Commit
76d42da
·
verified ·
1 Parent(s): 16898c5

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +75 -0
app.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import gradio as gr
import torch
import json
from io import BytesIO
from PIL import Image, ImageOps
from IPython.display import display, Markdown
from transformers import AutoModelForCausalLM, LlamaTokenizer
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_and_dispatch

# Initialize tokenizer and model.
# The tokenizer is the Vicuna-7B LLaMA tokenizer; the multimodal weights come
# from the THUDM/cogvlm-chat-hf checkpoint (presumably Vicuna is the language
# backbone CogVLM pairs with — confirm against the model card).
tokenizer = LlamaTokenizer.from_pretrained('lmsys/vicuna-7b-v1.5')
# tokenizer = LlamaTokenizer.from_pretrained('vicuna-7b-v1.5')
model = AutoModelForCausalLM.from_pretrained(
    'THUDM/cogvlm-chat-hf',
    load_in_4bit=True,       # 4-bit quantization to reduce GPU memory footprint
    trust_remote_code=True,  # CogVLM ships custom modeling code with the checkpoint
    device_map="auto"        # let accelerate place layers on available devices
).eval()                     # inference mode: disables dropout, etc.
20
def generate_description(image, query, top_p, top_k, output_length, temperature):
    """Run CogVLM on an uploaded image and return the generated text.

    Args:
        image: PIL image supplied by the Gradio upload widget.
        query: user prompt (e.g. "Capture all attributes as JSON").
        top_p: nucleus-sampling probability mass (0.0-1.0).
        top_k: number of highest-probability tokens kept for sampling.
        output_length: passed to ``model.generate`` as ``max_length``.
        temperature: sampling temperature.

    Returns:
        The decoded model response as a string (prompt tokens stripped).
    """
    # Normalize the upload to a fixed size before feature extraction.
    image = image.resize((224, 224), Image.LANCZOS)

    # Build the multimodal conversation features. Kept under a distinct name
    # (the original rebound `inputs`, shadowing these features) so the code
    # reads unambiguously.
    features = model.build_conversation_input_ids(
        tokenizer, query=query, history=[], images=[image]
    )

    # Add a batch dimension and move everything to the GPU, mirroring the
    # official CogVLM inference example.
    model_inputs = {
        'input_ids': features['input_ids'].unsqueeze(0).to('cuda'),
        'token_type_ids': features['token_type_ids'].unsqueeze(0).to('cuda'),
        'attention_mask': features['attention_mask'].unsqueeze(0).to('cuda'),
        'images': [[features['images'][0].to('cuda').to(torch.float16)]],
    }

    # User-controlled generation settings.
    gen_kwargs = {
        "max_length": output_length,
        "do_sample": True,  # sampling must be on for top_p/top_k/temperature to apply
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
    }

    with torch.no_grad():
        outputs = model.generate(**model_inputs, **gen_kwargs)
        # Strip the echoed prompt tokens so only the newly generated answer is
        # decoded (matches the official CogVLM usage; the original returned
        # the prompt concatenated with the answer).
        outputs = outputs[:, model_inputs['input_ids'].shape[1]:]

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
52
# UI layout: image + prompt on the left, sampling controls and output on the
# right. (The original contained a bare `gr.skip` attribute access here — a
# no-op at best and an AttributeError on gradio versions without `skip` —
# which has been removed.)
with gr.Blocks() as app:
    gr.Markdown("# Visual Product DNA - Image to Attribute Extractor")

    with gr.Row():
        with gr.Column():
            image_input = gr.Image(label="Upload Image", type="pil", height=500)
            query_input = gr.Textbox(label="Enter your prompt", value="Capture all attributes as JSON", lines=4)

        with gr.Column():
            top_p_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, value=0.1, label="Creativity (top_p)")
            top_k_slider = gr.Slider(minimum=0, maximum=100, step=1, value=100, label="Coherence (top_k)")
            output_length_slider = gr.Slider(minimum=1, maximum=4096, step=1, value=2048, label="Output Length")
            temperature_slider = gr.Slider(minimum=0.1, maximum=2.0, step=0.01, value=0.1, label="Temperature")
            submit_button = gr.Button("Extract Attributes")
            description_output = gr.Textbox(label="Generated JSON", lines=12)

    # Wire the button to the inference function; slider values are forwarded
    # positionally to generate_description's parameters.
    submit_button.click(
        fn=generate_description,
        inputs=[image_input, query_input, top_p_slider, top_k_slider, output_length_slider, temperature_slider],
        outputs=description_output
    )
75
# share=True publishes a temporary public Gradio link. The original passed
# `input=False`, which is not a parameter of Blocks.launch() and raises a
# TypeError at startup; it has been removed.
app.launch(share=True)