Staticaliza commited on
Commit
5268082
1 Parent(s): 0c034e2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -1
app.py CHANGED
@@ -12,8 +12,58 @@ if DEVICE == "auto":
12
  print(f"[SYSTEM] | Using {DEVICE} type compute device.")
13
 
14
  # Variables
 
 
15
  repo = AutoModel.from_pretrained("openbmb/MiniCPM-V-2_6", torch_dtype=torch.bfloat16, trust_remote_code=True)
16
  tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6", trust_remote_code=True)
17
 
18
  # Functions
19
- # Initialize
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  print(f"[SYSTEM] | Using {DEVICE} type compute device.")
13
 
14
  # Variables
15
+ DEFAULT_INPUT = "Describe in one paragraph."
16
+
17
  repo = AutoModel.from_pretrained("openbmb/MiniCPM-V-2_6", torch_dtype=torch.bfloat16, trust_remote_code=True)
18
  tokenizer = AutoTokenizer.from_pretrained("openbmb/MiniCPM-V-2_6", trust_remote_code=True)
19
 
20
  # Functions
21
+ @spaces.GPU(duration=60)
22
+ def generate(image, instruction=DEFAULT_INPUT, sampling=False, temperature=0.7, top_p=0.8, top_k=100, repetition_penalty=1.05, max_tokens=512):
23
+ global model, tokenizer
24
+
25
+ image_rgb = Image.open(image).convert("RGB")
26
+ print(image_rgb, instruction)
27
+
28
+ inputs = [{"role": "user", "content": [image_rgb, instruction]}]
29
+
30
+ parameters = {
31
+ "sampling": sampling,
32
+ "temperature": temperature,
33
+ "top_p": top_p,
34
+ "top_k": top_k,
35
+ "repetition_penalty": repetition_penalty,
36
+ "max_new_tokens": max_tokens
37
+ }
38
+
39
+ output = model.chat(image=None, msgs=inputs, tokenizer=tokenizer, **parameters)
40
+
41
+ return output
42
+
43
+ def cloud():
44
+ print("[CLOUD] | Space maintained.")
45
+
46
+ # Initialize
47
+ with gr.Blocks(css=css) as main:
48
+ with gr.Column():
49
+ gr.Markdown("🪄 Analyze images and caption them.")
50
+
51
+ with gr.Column():
52
+ input = gr.Image(label="Image")
53
+ instruction = gr.Textbox(lines=1, value=DEFAULT_INPUT, label="Instruction")
54
+ sampling = gr.Checkbox(value=False, label="Sampling")
55
+ temperature = gr.Slider(minimum=0, maximum=2, step=0.01, value=0.7, label="Temperature")
56
+ top_p = gr.Slider(minimum=0, maximum=1, step=0.01, value=0.8, label="Top P")
57
+ top_k = gr.Slider(minimum=0, maximum=1000, step=1, value=100, label="Top K")
58
+ repetition_penalty = gr.Slider(minimum=0, maximum=2, step=0.01, value=1.05, label="Repetition Penalty")
59
+ max_tokens = gr.Slider(minimum=1, maximum=4096, step=1, value=512, label="Max Tokens")
60
+ submit = gr.Button("▶")
61
+ maintain = gr.Button("☁️")
62
+
63
+ with gr.Column():
64
+ output = gr.Textbox(lines=1, value="", label="Output")
65
+
66
+ submit.click(fn=generate, inputs=[input, instruction, sampling, temperature, top_p, top_k, repetition_penalty, max_tokens], outputs=[output], queue=False)
67
+ maintain.click(cloud, inputs=[], outputs=[], queue=False)
68
+
69
+ main.launch(show_api=True)