Vitrous committed on
Commit
7ab6c0e
1 Parent(s): 90313b2

Update app.py

Files changed (1)
  1. app.py +110 -62
app.py CHANGED
@@ -1,63 +1,111 @@
  import gradio as gr
- from huggingface_hub import InferenceClient
-
- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")
-
-
- def respond(
-     message,
-     history: list[tuple[str, str]],
-     system_message,
-     max_tokens,
-     temperature,
-     top_p,
- ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
-     messages.append({"role": "user", "content": message})
-
-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
-         stream=True,
-         temperature=temperature,
-         top_p=top_p,
-     ):
-         token = message.choices[0].delta.content
-
-         response += token
-         yield response
-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-
-
- if __name__ == "__main__":
-     demo.launch()
+ import plotly.express as px
+ import os
+ import torch
+ from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
+
+ # Set environment variables for GPU usage and memory allocation
+ os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'expandable_segments:True'
+ torch.cuda.empty_cache()
+ torch.cuda.set_per_process_memory_fraction(0.8)  # Adjust the fraction as needed
+
+ # Define device
+ device = "cuda"  # The device to load the model onto
+
+ # System message (placeholder, adjust as needed)
+ system_message = ""
+
+ # Load the model and tokenizer
+ def hermes_model():
+     tokenizer = AutoTokenizer.from_pretrained("TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ")
+     model = AutoModelForCausalLM.from_pretrained(
+         "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ", low_cpu_mem_usage=True, device_map="auto"
+     )
+     return model, tokenizer
+
+ model, tokenizer = hermes_model()
+
+ # Function to generate a response from the model
+ def chat_response(msg_prompt: str) -> str:
+     """
+     Generates a response from the model given a prompt.
+
+     Args:
+         msg_prompt (str): The user's message prompt.
+
+     Returns:
+         str: The model's response.
+     """
+     generation_params = {
+         "do_sample": True,
+         "temperature": 0.7,
+         "top_p": 0.95,
+         "top_k": 40,
+         "max_new_tokens": 512,
+         "repetition_penalty": 1.1,
+     }
+     pipe = pipeline("text-generation", model=model, tokenizer=tokenizer, **generation_params)
+     try:
+         prompt_template = f'''system
+ {system_message}
+ user
+ {msg_prompt}
+ assistant
+ '''
+         pipe_output = pipe(prompt_template)[0]['generated_text']
+
+         # Separate assistant's response from the output
+         response_lines = pipe_output.split('assistant')
+         assistant_response = response_lines[-1].strip() if len(response_lines) > 1 else pipe_output.strip()
+
+         return assistant_response
+     except Exception as e:
+         return str(e)
+
+ # Function to generate a random plot
+ def random_plot():
+     df = px.data.iris()
+     fig = px.scatter(df, x="sepal_width", y="sepal_length", color="species",
+                      size='petal_length', hover_data=['petal_width'])
+     return fig
+
+ # Function to handle likes/dislikes (for demonstration purposes)
+ def print_like_dislike(x: gr.LikeData):
+     print(x.index, x.value, x.liked)
+
+ # Function to add messages to the chat history
+ def add_message(history, message):
+     for x in message["files"]:
+         history.append(((x,), None))
+     if message["text"] is not None:
+         history.append((message["text"], None))
+     return history, gr.update(value=None, interactive=True)
+
+ # Function to simulate the bot response
+ def bot(history):
+     user_message = history[-1][0]
+     bot_response = chat_response(user_message)
+     history[-1][1] = bot_response
+     return history
+
+ fig = random_plot()
+
+ # Gradio interface setup
+ with gr.Blocks(fill_height=True) as demo:
+     chatbot = gr.Chatbot(elem_id="chatbot", bubble_full_width=False, scale=1)
+
+     chat_input = gr.MultimodalTextbox(
+         interactive=True,
+         file_count="multiple",
+         placeholder="Enter message or upload file...",
+         show_label=False
+     )
+
+     chat_msg = chat_input.submit(add_message, [chatbot, chat_input], [chatbot, chat_input])
+     bot_msg = chat_msg.then(bot, chatbot, chatbot)
+     bot_msg.then(lambda: gr.update(interactive=True), None, [chat_input])
+
+     chatbot.like(print_like_dislike, None, None)
+
+ demo.queue()
+ demo.launch()
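
A caveat on the GPU setup in the new app.py: torch.cuda.set_per_process_memory_fraction(0.8) initializes CUDA at import time, so the Space errors out on CPU-only hardware, and the AWQ weights need a GPU in any case. A minimal sketch of failing fast with a readable message, assuming that is the preferred behaviour on a CPU runtime, could look like this:

# Sketch only: fail fast when no GPU is present, since both the AWQ weights
# and the CUDA memory-fraction call require one.
if not torch.cuda.is_available():
    raise RuntimeError(
        "TheBloke/CapybaraHermes-2.5-Mistral-7B-AWQ needs a CUDA GPU; "
        "run this Space on a GPU runtime."
    )
torch.cuda.empty_cache()
torch.cuda.set_per_process_memory_fraction(0.8)  # adjust the fraction as needed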
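
On the prompt format: CapybaraHermes-2.5 is a ChatML-tuned model, while chat_response builds a plain system/user/assistant prompt and recovers the reply by splitting on the word "assistant". A sketch of an alternative, assuming the tokenizer ships a chat template (exposed via tokenizer.apply_chat_template in recent transformers releases), would be:

# Sketch only: let the tokenizer build the prompt the model was trained on,
# falling back to the hand-written template from app.py if none is defined.
messages = [
    {"role": "system", "content": system_message},
    {"role": "user", "content": msg_prompt},
]
try:
    prompt_template = tokenizer.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True
    )
except Exception:
    prompt_template = f"system\n{system_message}\nuser\n{msg_prompt}\nassistant\n"

Since the text-generation pipeline returns the prompt followed by the completion, the reply can then typically be taken as pipe_output[len(prompt_template):] rather than by splitting on "assistant".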
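
Finally, the MultimodalTextbox accepts file uploads and add_message stores them as (path,) tuples in the history, but chat_response only understands text, so a file-only turn would hand a tuple to the prompt. A small sketch of guarding bot for that case, assuming uploads should simply be acknowledged rather than processed, might be:

# Sketch only: keep non-text turns out of the text-generation prompt.
def bot(history):
    user_message = history[-1][0]
    if not isinstance(user_message, str):
        history[-1][1] = "File received. This demo only generates text replies."
        return history
    history[-1][1] = chat_response(user_message)
    return history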