Update app.py
Browse files
app.py
CHANGED
@@ -1,64 +1,139 @@
|
|
1 |
import gradio as gr
|
2 |
-
from
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
3 |
|
4 |
-
"""
|
5 |
-
For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
|
6 |
-
"""
|
7 |
-
client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
|
8 |
|
|
|
|
|
|
|
|
|
|
|
9 |
|
10 |
-
|
11 |
-
|
12 |
-
history
|
13 |
-
|
14 |
-
|
15 |
-
|
16 |
-
|
17 |
-
):
|
18 |
-
|
|
|
|
|
|
|
19 |
|
20 |
-
|
21 |
-
|
22 |
-
messages.append({"role": "user", "content": val[0]})
|
23 |
-
if val[1]:
|
24 |
-
messages.append({"role": "assistant", "content": val[1]})
|
25 |
|
26 |
-
|
|
|
27 |
|
28 |
-
|
29 |
|
30 |
-
for message in client.chat_completion(
|
31 |
-
messages,
|
32 |
-
max_tokens=max_tokens,
|
33 |
-
stream=True,
|
34 |
-
temperature=temperature,
|
35 |
-
top_p=top_p,
|
36 |
-
):
|
37 |
-
token = message.choices[0].delta.content
|
38 |
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
|
|
42 |
|
43 |
-
|
44 |
-
|
45 |
-
""
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
minimum=0.1,
|
54 |
-
maximum=1.0,
|
55 |
-
value=0.95,
|
56 |
-
step=0.05,
|
57 |
-
label="Top-p (nucleus sampling)",
|
58 |
-
),
|
59 |
-
],
|
60 |
-
)
|
61 |
|
|
|
|
|
|
|
|
|
|
|
62 |
|
63 |
-
|
64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import gradio as gr
|
2 |
+
from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
|
3 |
+
from threading import Thread
|
4 |
+
import re
|
5 |
+
import time
|
6 |
+
from PIL import Image
|
7 |
+
import torch
|
8 |
+
import spaces
|
9 |
+
#import subprocess
|
10 |
+
#subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
|
11 |
|
|
|
|
|
|
|
|
|
12 |
|
13 |
+
processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
|
14 |
+
model = AutoModelForVision2Seq.from_pretrained("HuggingFaceTB/SmolVLM-Instruct",
|
15 |
+
torch_dtype=torch.bfloat16,
|
16 |
+
#_attn_implementation="flash_attention_2"
|
17 |
+
).to("cuda")
|
18 |
|
19 |
+
@spaces.GPU
|
20 |
+
def model_inference(
|
21 |
+
input_dict, history, decoding_strategy, temperature, max_new_tokens,
|
22 |
+
repetition_penalty, top_p
|
23 |
+
):
|
24 |
+
text = input_dict["text"]
|
25 |
+
print(input_dict["files"])
|
26 |
+
if len(input_dict["files"]) > 1:
|
27 |
+
images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
|
28 |
+
elif len(input_dict["files"]) == 1:
|
29 |
+
images = [Image.open(input_dict["files"][0]).convert("RGB")]
|
30 |
+
|
31 |
|
32 |
+
if text == "" and not images:
|
33 |
+
gr.Error("Please input a query and optionally image(s).")
|
|
|
|
|
|
|
34 |
|
35 |
+
if text == "" and images:
|
36 |
+
gr.Error("Please input a text query along the image(s).")
|
37 |
|
38 |
+
|
39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
+
resulting_messages = [
|
42 |
+
{
|
43 |
+
"role": "user",
|
44 |
+
"content": [{"type": "image"} for _ in range(len(images))] + [
|
45 |
+
{"type": "text", "text": text}
|
46 |
+
]
|
47 |
+
}
|
48 |
+
]
|
49 |
+
prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
|
50 |
+
inputs = processor(text=prompt, images=[images], return_tensors="pt")
|
51 |
+
inputs = {k: v.to("cuda") for k, v in inputs.items()}
|
52 |
+
generation_args = {
|
53 |
+
"max_new_tokens": max_new_tokens,
|
54 |
+
"repetition_penalty": repetition_penalty,
|
55 |
|
56 |
+
}
|
57 |
|
58 |
+
assert decoding_strategy in [
|
59 |
+
"Greedy",
|
60 |
+
"Top P Sampling",
|
61 |
+
]
|
62 |
+
if decoding_strategy == "Greedy":
|
63 |
+
generation_args["do_sample"] = False
|
64 |
+
elif decoding_strategy == "Top P Sampling":
|
65 |
+
generation_args["temperature"] = temperature
|
66 |
+
generation_args["do_sample"] = True
|
67 |
+
generation_args["top_p"] = top_p
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
|
69 |
+
generation_args.update(inputs)
|
70 |
+
# Generate
|
71 |
+
streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens= True)
|
72 |
+
generation_args = dict(inputs, streamer=streamer, max_new_tokens=max_new_tokens)
|
73 |
+
generated_text = ""
|
74 |
|
75 |
+
thread = Thread(target=model.generate, kwargs=generation_args)
|
76 |
+
thread.start()
|
77 |
+
thread.join()
|
78 |
+
|
79 |
+
buffer = ""
|
80 |
+
|
81 |
+
|
82 |
+
for new_text in streamer:
|
83 |
+
|
84 |
+
buffer += new_text
|
85 |
+
generated_text_without_prompt = buffer#[len(ext_buffer):]
|
86 |
+
time.sleep(0.01)
|
87 |
+
yield buffer
|
88 |
+
|
89 |
+
|
90 |
+
demo = gr.ChatInterface(fn=model_inference, title="Geoscience AI Interpreter",
|
91 |
+
description="This app take the thin sections, seismic images etc. and interpret them. You just upload an image and text along with it. It works best with single turn conversations, so clear the conversation after a single turn.",
|
92 |
+
textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"), stop_btn="Stop Generation", multimodal=True,
|
93 |
+
additional_inputs=[gr.Radio(["Top P Sampling",
|
94 |
+
"Greedy"],
|
95 |
+
value="Greedy",
|
96 |
+
label="Decoding strategy",
|
97 |
+
#interactive=True,
|
98 |
+
info="Higher values is equivalent to sampling more low-probability tokens.",
|
99 |
+
|
100 |
+
), gr.Slider(
|
101 |
+
minimum=0.0,
|
102 |
+
maximum=5.0,
|
103 |
+
value=0.4,
|
104 |
+
step=0.1,
|
105 |
+
interactive=True,
|
106 |
+
label="Sampling temperature",
|
107 |
+
info="Higher values will produce more diverse outputs.",
|
108 |
+
),
|
109 |
+
gr.Slider(
|
110 |
+
minimum=8,
|
111 |
+
maximum=1024,
|
112 |
+
value=512,
|
113 |
+
step=1,
|
114 |
+
interactive=True,
|
115 |
+
label="Maximum number of new tokens to generate",
|
116 |
+
), gr.Slider(
|
117 |
+
minimum=0.01,
|
118 |
+
maximum=5.0,
|
119 |
+
value=1.2,
|
120 |
+
step=0.01,
|
121 |
+
interactive=True,
|
122 |
+
label="Repetition penalty",
|
123 |
+
info="1.0 is equivalent to no penalty",
|
124 |
+
),
|
125 |
+
gr.Slider(
|
126 |
+
minimum=0.01,
|
127 |
+
maximum=0.99,
|
128 |
+
value=0.8,
|
129 |
+
step=0.01,
|
130 |
+
interactive=True,
|
131 |
+
label="Top P",
|
132 |
+
info="Higher values is equivalent to sampling more low-probability tokens.",
|
133 |
+
)],cache_examples=False
|
134 |
+
)
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
|
139 |
+
demo.launch(debug=True)
|