Ketengan-Diffusion-Lab committed on
Commit
789acc7
·
verified ·
1 Parent(s): 0a0d7ab

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +63 -0
app.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ import transformers
4
+ from transformers import AutoModelForCausalLM, AutoTokenizer
5
+ from PIL import Image
6
+ import warnings
7
+
8
# Keep the console quiet: only errors from transformers, no progress bars,
# and no Python warnings.
transformers.logging.set_verbosity_error()
transformers.logging.disable_progress_bar()
warnings.filterwarnings('ignore')

# Tensors created without an explicit device land on the GPU by default
# (use 'cpu' instead when no GPU is available).
torch.set_default_device('cuda')

# Hugging Face Hub id of the checkpoint served by this app.
model_name = 'cognitivecomputations/dolphin-vision-7b'

# Load the weights in half precision and let `device_map='auto'` decide
# where they are placed. `trust_remote_code=True` allows the custom code
# shipped in the model repository to run.
# NOTE(review): that remote code is presumably what provides
# `model.process_images`, used by `inference` - confirm.
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    device_map='auto',
    trust_remote_code=True,
)

# The matching tokenizer, which also supplies the chat template.
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
27
+
28
def inference(prompt, image):
    """Generate the model's reply to ``prompt`` about ``image``.

    Args:
        prompt: Text typed by the user; it is placed after an ``<image>``
            placeholder in a single user chat message.
        image: Image from the Gradio image component (configured with
            ``type="pil"``), or ``None`` when nothing was uploaded.

    Returns:
        The newly generated text, decoded without special tokens and
        stripped of surrounding whitespace.

    Raises:
        gr.Error: If ``image`` is ``None``, so the UI shows a readable
            message instead of an opaque failure from image processing.
    """
    if image is None:
        raise gr.Error("Please upload an image before submitting.")

    messages = [
        {"role": "user", "content": f'<image>\n{prompt}'}
    ]
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )

    # Split only at the FIRST '<image>' (the placeholder inserted above).
    # Splitting on every occurrence would silently drop the rest of the
    # prompt and the generation-prompt suffix whenever the user's own text
    # contains the literal string '<image>'.
    before_image, after_image = text.split('<image>', 1)

    # -200 stands in for the image between the two tokenized text halves.
    # NOTE(review): presumably the image-token index expected by the
    # model's remote code - confirm against the model card.
    input_ids = torch.tensor(
        tokenizer(before_image).input_ids + [-200] + tokenizer(after_image).input_ids,
        dtype=torch.long,
    ).unsqueeze(0)

    # Convert the image into the model's input tensor, matching its dtype.
    image_tensor = model.process_images([image], model.config).to(dtype=model.dtype)

    # generate
    output_ids = model.generate(
        input_ids,
        images=image_tensor,
        max_new_tokens=2048,
        use_cache=True,
    )[0]

    # Skip the first `input_ids.shape[1]` positions (the prompt length) so
    # that only the newly generated tokens are decoded.
    return tokenizer.decode(
        output_ids[input_ids.shape[1]:],
        skip_special_tokens=True,
    ).strip()
51
+
52
# Two-column layout: inputs (prompt, image, submit) on the left,
# the model's answer on the right.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            prompt_input = gr.Textbox(
                label="Prompt",
                placeholder="Describe this image in detail",
            )
            # type="pil" hands the upload to `inference` as a PIL image.
            image_input = gr.Image(
                label="Image",
                type="pil",
            )
            submit_button = gr.Button("Submit")
        with gr.Column():
            output_text = gr.Textbox(label="Output")

    # Run `inference` on click and show its return value in the output box.
    submit_button.click(
        fn=inference,
        inputs=[prompt_input, image_input],
        outputs=output_text,
    )

# Start the Gradio server.
demo.launch()