pvyas96 committed
Commit
a5fed7d
1 Parent(s): 2cec925

Update app.py

Files changed (1)
  1. app.py +126 -51
app.py CHANGED
@@ -1,64 +1,139 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer
+ from threading import Thread
+ import re
+ import time
+ from PIL import Image
+ import torch
+ import spaces
+ # import subprocess
+ # subprocess.run('pip install flash-attn --no-build-isolation', env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"}, shell=True)
+
+ processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
+ model = AutoModelForVision2Seq.from_pretrained(
+     "HuggingFaceTB/SmolVLM-Instruct",
+     torch_dtype=torch.bfloat16,
+     # _attn_implementation="flash_attention_2",
+ ).to("cuda")
+
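+ # ZeroGPU decorator: on Hugging Face Spaces this allocates a GPU for the
+ # duration of each call rather than holding one permanently.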
+ @spaces.GPU
+ def model_inference(
+     input_dict, history, decoding_strategy, temperature, max_new_tokens,
+     repetition_penalty, top_p,
+ ):
+     text = input_dict["text"]
+     # Open every attached file; with no attachments this yields an empty
+     # list, so `images` is always defined for the checks below.
+     images = [Image.open(image).convert("RGB") for image in input_dict["files"]]
+
+     if text == "" and not images:
+         raise gr.Error("Please input a query and optionally image(s).")
+     if text == "" and images:
+         raise gr.Error("Please input a text query along with the image(s).")
+
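+     # Emit one image placeholder per attached file ahead of the text so the
+     # chat template inserts a matching image token for each input image.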
+     resulting_messages = [
+         {
+             "role": "user",
+             "content": [{"type": "image"} for _ in range(len(images))] + [
+                 {"type": "text", "text": text}
+             ],
+         }
+     ]
+     prompt = processor.apply_chat_template(resulting_messages, add_generation_prompt=True)
+     inputs = processor(text=prompt, images=[images], return_tensors="pt")
+     inputs = {k: v.to("cuda") for k, v in inputs.items()}
+
+     generation_args = {
+         "max_new_tokens": max_new_tokens,
+         "repetition_penalty": repetition_penalty,
+     }
+
+     assert decoding_strategy in [
+         "Greedy",
+         "Top P Sampling",
+     ]
+     if decoding_strategy == "Greedy":
+         generation_args["do_sample"] = False
+     elif decoding_strategy == "Top P Sampling":
+         generation_args["temperature"] = temperature
+         generation_args["do_sample"] = True
+         generation_args["top_p"] = top_p
+
+     generation_args.update(inputs)
+
+     # Run generate() in a background thread and stream tokens out as they
+     # arrive; keep the sampling settings assembled above instead of
+     # rebuilding the kwargs from scratch.
+     streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
+     generation_args["streamer"] = streamer
+
+     thread = Thread(target=model.generate, kwargs=generation_args)
+     thread.start()
+
+     buffer = ""
+     for new_text in streamer:
+         buffer += new_text
+         time.sleep(0.01)
+         yield buffer
+     thread.join()
+
+ demo = gr.ChatInterface(
+     fn=model_inference,
+     title="Geoscience AI Interpreter",
+     description="This app interprets geoscience images such as thin sections and seismic sections. Upload an image together with a text query. It works best in single-turn conversations, so clear the conversation after each turn.",
+     textbox=gr.MultimodalTextbox(label="Query Input", file_types=["image"], file_count="multiple"),
+     stop_btn="Stop Generation",
+     multimodal=True,
+     additional_inputs=[
+         gr.Radio(
+             ["Top P Sampling", "Greedy"],
+             value="Greedy",
+             label="Decoding strategy",
+             info="Greedy picks the most likely token at each step; Top P samples from the smallest set of tokens whose probabilities sum to top_p.",
+         ),
+         gr.Slider(
+             minimum=0.0,
+             maximum=5.0,
+             value=0.4,
+             step=0.1,
+             interactive=True,
+             label="Sampling temperature",
+             info="Higher values will produce more diverse outputs.",
+         ),
+         gr.Slider(
+             minimum=8,
+             maximum=1024,
+             value=512,
+             step=1,
+             interactive=True,
+             label="Maximum number of new tokens to generate",
+         ),
+         gr.Slider(
+             minimum=0.01,
+             maximum=5.0,
+             value=1.2,
+             step=0.01,
+             interactive=True,
+             label="Repetition penalty",
+             info="1.0 is equivalent to no penalty.",
+         ),
+         gr.Slider(
+             minimum=0.01,
+             maximum=0.99,
+             value=0.8,
+             step=0.01,
+             interactive=True,
+             label="Top P",
+             info="Higher values sample more low-probability tokens.",
+         ),
+     ],
+     cache_examples=False,
+ )
+
+ demo.launch(debug=True)
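
Note for reviewers: the new app.py hinges on the TextIteratorStreamer pattern, so here is a minimal, self-contained sketch of it outside Gradio. The model and processor ids match the commit; the image path and prompt are placeholders, and a CUDA device is assumed (as the Space's .to("cuda") already requires).

    # Minimal sketch: stream tokens from SmolVLM while generate() runs in a thread.
    import torch
    from threading import Thread
    from PIL import Image
    from transformers import AutoProcessor, AutoModelForVision2Seq, TextIteratorStreamer

    processor = AutoProcessor.from_pretrained("HuggingFaceTB/SmolVLM-Instruct")
    model = AutoModelForVision2Seq.from_pretrained(
        "HuggingFaceTB/SmolVLM-Instruct", torch_dtype=torch.bfloat16
    ).to("cuda")

    image = Image.open("thin_section.png").convert("RGB")  # placeholder input file
    messages = [{
        "role": "user",
        "content": [{"type": "image"},
                    {"type": "text", "text": "Describe this thin section."}],
    }]
    prompt = processor.apply_chat_template(messages, add_generation_prompt=True)
    inputs = processor(text=prompt, images=[[image]], return_tensors="pt")
    inputs = {k: v.to("cuda") for k, v in inputs.items()}

    streamer = TextIteratorStreamer(processor, skip_prompt=True, skip_special_tokens=True)
    thread = Thread(target=model.generate,
                    kwargs=dict(inputs, streamer=streamer, max_new_tokens=256))
    thread.start()
    for piece in streamer:               # consume while generation is still running
        print(piece, end="", flush=True)
    thread.join()                        # join only after the stream is drained

The ordering matters: TextIteratorStreamer only queues decoded text, so the consuming loop must run while generate() is alive; joining the thread before draining the stream blocks until generation finishes and the caller loses live streaming.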