LennardZuendorf committed
Commit a597c76 • 1 parent: f301e04

feat/fixing: correcting bug, updating documentation (final?)
Changed files:
- README.md (+2 -2)
- backend/controller.py (+3 -2)
- components/iframe/README.md (+2 -2)
- explanation/attention.py (+3 -0)
- explanation/interpret_captum.py (+1 -1)
- explanation/markup.py (+2 -2)
- explanation/plotting.py (+6 -6)
- main.py (+40 -12)
- model/godel.py (+1 -1)
- model/mistral.py (+16 -5)
- utils/formatting.py (+1 -1)
- utils/modelling.py (+5 -2)
README.md CHANGED
@@ -21,7 +21,7 @@ This is the UI showcase for my thesis about the interpretability of LLM based ch
 
 ### 🔗 Links:
 
-**[
+**[GitHub Repository](https://github.com/LennardZuendorf/thesis-webapp)**
 
 **[Huggingface Spaces Showcase](https://huggingface.co/spaces/lennardzuendorf/thesis-webapp-docker)**
 
@@ -86,7 +86,7 @@ See code for in detailed credits, work is strongly based on:
 
 #### SHAP
 - [Github](https://github.com/shap/shap)
-- [
+- [Initial Paper](https://arxiv.org/abs/1705.07874)
 
 #### Custom Component (/components/iframe/)
 
backend/controller.py CHANGED
@@ -43,6 +43,7 @@ def explained_chat(
     # message, history, system_prompt, knowledge
     # )
     prompt = model.format_prompt(message, history, system_prompt, knowledge)
+    print(f"Formatted prompt: {prompt}")
 
     # generating an answer using the methods chat function
     answer, xai_graphic, xai_markup, xai_plot = xai.chat_explained(model, prompt)
@@ -73,10 +74,10 @@ def interference(
     # if a model is selected, grab the model instance
     if model_selection.lower() == "mistral":
         model = mistral
-        print("
+        print("Identified model as Mistral")
     else:
         model = godel
-        print("
+        print("Identified model as GODEL")
 
     # if a XAI approach is selected, grab the XAI module instance
     # and call the explained chat function
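For context, here is a self-contained sketch of the selection-and-logging pattern this hunk touches. The `_Stub` objects stand in for the project's real `mistral`/`godel` modules so the snippet runs on its own; only the `format_prompt` call and the two log lines mirror the diff, everything else is illustrative.

```python
# Sketch of the model dispatch + logging in backend/controller.py (stand-ins only).
class _Stub:
    def __init__(self, name):
        self.name = name

    def format_prompt(self, message, history, system_prompt, knowledge=""):
        # placeholder for the real per-model prompt formatting
        return f"[{self.name}] {system_prompt} | {message}"

mistral, godel = _Stub("Mistral"), _Stub("GODEL")

def pick_model(model_selection: str):
    # pick the model module based on the UI selection, logging the choice
    if model_selection.lower() == "mistral":
        model = mistral
        print("Identified model as Mistral")
    else:
        model = godel
        print("Identified model as GODEL")
    return model

model = pick_model("Mistral")
prompt = model.format_prompt("Does money buy happiness?", [], "Be helpful.")
print(f"Formatted prompt: {prompt}")
```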
components/iframe/README.md CHANGED
@@ -1,7 +1,7 @@
 # gradio iFrame
 
 This is a custom gradio component used to display the shap package text plot. Which is interactive HTML and needs a custom wrapper.
-See custom component examples at
+See custom component examples at official [docu](https://www.gradio.app/guides/custom-components-in-five-minutes)
 
 # Credit
 CREDIT: based mostly of Gradio template component, HTML
@@ -14,4 +14,4 @@ see: https://www.gradio.app/docs/html
 - backend/iframe.py - updating component to accept custom height/width and added new example
 - demo/app.py - slightly changed demo file for better dev experience
 - frontend/index.svelte - slightly changed to accept custom height/width
-- frontend/HTML.svelte - updated to use iFrame and added custom function to
+- frontend/HTML.svelte - updated to use iFrame and added custom function to programmatically set height values
explanation/attention.py CHANGED
@@ -11,6 +11,8 @@ from .markup import markup_text
 # and marked text based on attention
 def chat_explained(model, prompt):
 
+    print(f"Running explained chat with prompt {prompt}.")
+
     # get encoded input
     input_ids = model.TOKENIZER(
         prompt, return_tensors="pt", add_special_tokens=True
@@ -56,6 +58,7 @@ def chat_explained(model, prompt):
         " Visualization doesn't support an interactive graphic.</h4></div>"
     )
     # creating marked text using markup_text function and attention
+    print(f"Creating marked text with {input_text}.")
     marked_text = markup_text(input_text, averaged_attention, variant="visualizer")
 
     # returning response, graphic and marked text array
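The hunk only adds logging, but the surrounding flow (tokenize the prompt, average attention, mark up the input text) is easier to picture with a toy example. The tensor shape below is illustrative rather than the exact layout the model returns; it only demonstrates the head-averaging idea behind `averaged_attention`.

```python
import numpy as np

# Toy attention tensor shaped (heads, target, source); real model outputs differ.
num_heads, seq_len = 4, 6
attention = np.random.rand(num_heads, seq_len, seq_len)

# average over heads, then over target positions -> one score per source token
per_token = attention.mean(axis=0).mean(axis=0)

tokens = ["Does", "money", "buy", "happiness", "?", "</s>"]
for tok, val in zip(tokens, per_token):
    print(f"{tok:>10s}: {val:.3f}")
```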
explanation/interpret_captum.py CHANGED
@@ -47,7 +47,7 @@ def chat_explained(model, prompt):
     # getting response text, graphic placeholder and marked text object
     response_text = fmt.format_output_text(attribution_result.output_tokens)
     graphic = """<div style='text-align: center; font-family:arial;'><h4>
-
+        Interpretation with Captum doesn't support an interactive graphic.</h4></div>
     """
     # create the explanation marked text array
     marked_text = markup_text(input_tokens, values, variant="captum")
explanation/markup.py CHANGED
@@ -21,7 +21,7 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
     elif variant == "visualizer":
         text_values = fmt.flatten_attention(text_values)
 
-    #
+    # determine the minimum and maximum values
     min_val, max_val = np.min(text_values), np.max(text_values)
 
     # separate the threshold calculation for negative and positive values
@@ -69,7 +69,7 @@ def color_codes():
     return {
         # -5 to -1: Strong Light Sky Blue to Lighter Sky Blue
         # 0: white (assuming default light mode)
-        # +1 to +5 light pink to
+        # +1 to +5 light pink to strong magenta
        "-5": "#008bfb",
        "-4": "#68a1fd",
        "-3": "#96b7fe",
explanation/plotting.py CHANGED
@@ -7,24 +7,24 @@ import matplotlib.pyplot as plt
 
 def plot_seq(seq_values: list, method: str = ""):
 
-    #
+    # separate the tokens and their corresponding importance values
     tokens, importance = zip(*seq_values)
 
-    #
+    # convert importance values to numpy array for conditional coloring
     importance = np.array(importance)
 
-    #
+    # determine the colors based on the sign of the importance values
     colors = ["#ff0051" if val > 0 else "#008bfb" for val in importance]
 
-    #
+    # create a bar plot
     plt.figure(figsize=(len(tokens) * 0.9, np.max(importance)))
     x_positions = range(len(tokens))  # Positions for the bars
 
-    #
+    # creating vertical bar plot
     bar_width = 0.8
     plt.bar(x_positions, importance, color=colors, align="center", width=bar_width)
 
-    #
+    # annotating each bar with its value
     padding = 0.1  # Padding for text annotation
     for x, (y, color) in enumerate(zip(importance, colors)):
         sign = "+" if y > 0 else ""
main.py CHANGED
@@ -26,7 +26,7 @@ css = """
 .examples {text-align: start;}
 .seperatedRow {border-top: 1rem solid;}",
 """
-# custom js to force
+# custom js to force light mode in custom environments
 if os.environ["HOSTING"].lower() != "spaces":
     js = """
     function () {
@@ -52,6 +52,12 @@ def load_md(path):
 
 # function to display the system prompt info
 def system_prompt_info(sys_prompt_txt):
+    if sys_prompt_txt == "":
+        sys_prompt_txt = """
+            You are a helpful, respectful and honest assistant.
+            Always answer as helpfully as possible, while being safe.
+        """
+
     # display the system prompt using the Gradio Info component
     gr.Info(f"The system prompt was set to:\n {sys_prompt_txt}")
 
@@ -71,7 +77,7 @@ def model_info(model_radio):
 
 
 # ui interface based on Gradio Blocks
-# see https://www.gradio.app/docs/interface
+# see https://www.gradio.app/docs/interface
 with gr.Blocks(
     css=css,
     js=js,
@@ -171,11 +177,11 @@
             show_copy_button=True,
             avatar_images=("./public/human.jpg", "./public/bot.jpg"),
         )
-        #
+        # extendable components for extra knowledge
        with gr.Accordion(label="Additional Knowledge", open=False):
            gr.Markdown("""
                *Hint:* Add extra knowledge to see GODEL work the best.
-                Knowledge doesn't work
+                Knowledge doesn't work with Mistral and will be ignored.
                """)
            # textbox to enter the knowledge
            knowledge_input = gr.Textbox(
@@ -217,8 +223,8 @@
                "Does money buy happiness?",
                "",
                (
-                    "Respond from the perspective of
-                    " life
+                    "Respond from the perspective of billionaire heir"
+                    " living his best life with his father's money."
                ),
                "Mistral",
                "None",
@@ -227,8 +233,8 @@
                "Does money buy happiness?",
                "",
                (
-                    "Respond from the perspective of
-                    " life
+                    "Respond from the perspective of billionaire heir"
+                    " living his best life with his father's money."
                ),
                "Mistral",
                "SHAP",
@@ -251,14 +257,36 @@
            [
                "Does money buy happiness?",
                (
-                    "
-                    "
-                    "
+                    "Some studies have found a correlation between income"
+                    " and happiness, but this relationship often has"
+                    " diminishing returns. From a psychological standpoint,"
+                    " it's not just having money, but how it is used that"
+                    " influences happiness."
                ),
                "",
                "GODEL",
                "SHAP",
            ],
+            [
+                "Does money buy happiness?",
+                (
+                    "Some studies have found a correlation between income"
+                    " and happiness, but this relationship often has"
+                    " diminishing returns. From a psychological standpoint,"
+                    " it's not just having money, but how it is used that"
+                    " influences happiness."
+                ),
+                "",
+                "GODEL",
+                "Attention",
+            ],
+            [
+                "Does money buy happiness?",
+                "",
+                "",
+                "GODEL",
+                "Attention",
+            ],
        ],
        inputs=[
            user_prompt,
@@ -332,7 +360,7 @@
     # load about.md markdown
     gr.Markdown(value=load_md("public/about.md"))
     with gr.Accordion(label="Credits, Data Protection, License"):
-        # load credits and
+        # load credits and data protection markdown
         gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
 # mount function for fastAPI Application
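The trailing comment refers to mounting the Blocks UI on FastAPI. A minimal sketch of that step, assuming the standard `gr.mount_gradio_app` helper and placeholder names (`app`, `ui`) rather than the repo's actual variables:

```python
import gradio as gr
from fastapi import FastAPI

# Placeholder FastAPI app and Blocks UI; the real project builds a much larger UI.
app = FastAPI()

with gr.Blocks() as ui:
    gr.Markdown("placeholder UI")

# gradio's helper mounts a Blocks app onto an existing FastAPI application
app = gr.mount_gradio_app(app, ui, path="/")

# run with: uvicorn main:app --host 0.0.0.0 --port 7860
```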
model/godel.py CHANGED
@@ -6,7 +6,7 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, GenerationConfig
 # internal imports
 from utils import modelling as mdl
 
-# global model and tokenizer instance (created on
+# global model and tokenizer instance (created on initial build)
 TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 
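For reference, a minimal generation round-trip with the globally loaded checkpoint. The prompt string and generation settings are rough placeholders; the repo's own `format_prompt` builds the actual GODEL prompt (instruction, knowledge, dialog context).

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Same checkpoint the module loads globally; everything else is illustrative.
TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

prompt = "Instruction: respond helpfully to the dialog. Does money buy happiness?"
input_ids = TOKENIZER(prompt, return_tensors="pt").input_ids
output_ids = MODEL.generate(input_ids, max_length=128)
print(TOKENIZER.decode(output_ids[0], skip_special_tokens=True))
```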
model/mistral.py CHANGED
@@ -9,7 +9,8 @@ import gradio as gr
 from utils import modelling as mdl
 from utils import formatting as fmt
 
-# global model and tokenizer instance (created on
+# global model and tokenizer instance (created on initial build)
+# determine if GPU is available and load model accordingly
 device = mdl.get_device()
 if device == torch.device("cuda"):
     n_gpus, max_memory, bnb_config = mdl.gpu_loading_config()
@@ -17,13 +18,15 @@ if device == torch.device("cuda"):
     MODEL = AutoModelForCausalLM.from_pretrained(
         "mistralai/Mistral-7B-Instruct-v0.2",
         quantization_config=bnb_config,
-        device_map="auto",
+        device_map="auto",
         max_memory={i: max_memory for i in range(n_gpus)},
     )
 
+# otherwise, load model on CPU
 else:
     MODEL = AutoModelForCausalLM.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
     MODEL.to(device)
+# load tokenizer
 TOKENIZER = AutoTokenizer.from_pretrained("mistralai/Mistral-7B-Instruct-v0.2")
 
 # default model config
@@ -48,12 +51,13 @@ def set_config(config_dict: dict):
     CONFIG.update(**config_dict)
 
 
-# advanced formatting function that takes into
-# CREDIT:
+# advanced formatting function that takes into account a conversation history
+# CREDIT: adapted from the Mistral AI Instruct chat template
 # see https://github.com/chujiezheng/chat_templates/
 def format_prompt(message: str, history: list, system_prompt: str, knowledge: str = ""):
     prompt = ""
 
+    # send information to the ui if knowledge is not empty
     if knowledge != "":
         gr.Info("""
             Mistral doesn't support additional knowledge, it's gonna be ignored.
@@ -94,7 +98,7 @@ def format_answer(answer: str):
 
     # checking if proper history got returned
     if len(segments) > 1:
-        # return text after the last ['/INST'] -
+        # return text after the last ['/INST'] - response to last message
         formatted_answer = segments[-1].strip()
     else:
         # return warning and full answer if not enough [/INST] tokens found
@@ -108,7 +112,11 @@ def format_answer(answer: str):
     return formatted_answer
 
 
+# response class calling the model and returning the model output message
+# CREDIT: Copied from official interference example on Huggingface
+# see https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.2
 def respond(prompt: str):
+    # setting config to default
     set_config({})
 
     # tokenizing inputs and configuring model
@@ -117,6 +125,9 @@ def respond(prompt: str):
     # generating text with tokenized input, returning output
     output_ids = MODEL.generate(input_ids, generation_config=CONFIG)
     output_text = TOKENIZER.batch_decode(output_ids)
+
+    # formatting output text with special function
     output_text = fmt.format_output_text(output_text)
 
+    # returning the model output string
     return format_answer(output_text)
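The comment fix in `format_answer` ("response to last message") points at how answers are extracted: the decoded Mistral output still contains the prompt, so everything after the final `[/INST]` marker is taken as the reply. A sketch of that logic, with the fallback branch paraphrased rather than copied from the repo:

```python
# Sketch of the [/INST]-splitting extraction used by format_answer.
def extract_last_response(decoded_output: str) -> str:
    segments = decoded_output.split("[/INST]")
    if len(segments) > 1:
        # text after the last [/INST] is the response to the latest message
        return segments[-1].strip()
    # not enough [/INST] markers found: fall back to the full output
    return decoded_output.strip()

raw = "<s>[INST] Does money buy happiness? [/INST] Up to a point, research suggests."
print(extract_last_response(raw))  # -> "Up to a point, research suggests."
```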
utils/formatting.py CHANGED
@@ -100,7 +100,7 @@ def avg_attention(attention_values, model: str):
 
     # removing the last dimension and transposing to get the correct shape
     attention = attention[:, :, :, 0]
-    attention = attention.transpose
+    attention = attention.transpose()
 
     # return the averaged attention values
     return np.mean(attention, axis=1)
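This one-character fix matters: without the parentheses, `attention.transpose` stores the bound method itself instead of the transposed array, and the later `np.mean(attention, axis=1)` cannot operate on it. A minimal reproduction:

```python
import numpy as np

arr = np.arange(6).reshape(2, 3)

broken = arr.transpose      # a bound method object, not an array
fixed = arr.transpose()     # the transposed (3, 2) array

print(type(broken))         # not an ndarray
print(fixed.shape)          # (3, 2)

# np.mean(broken, axis=1) would raise, since a method carries no numeric data;
# np.mean(fixed, axis=1) works as intended.
```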
utils/modelling.py CHANGED
@@ -45,7 +45,7 @@ def prompt_limiter(
 
     # if token count small enough, adding history bit by bit
     if pre_count < 800:
-        # setting the count to the
+        # setting the count to the pre-count
         count = pre_count
         # reversing the history to prioritize recent conversations
         history.reverse()
@@ -76,6 +76,7 @@ def token_counter(tokenizer, text: str):
     return len(tokens[0])
 
 
+# function to determine the device to use
 def get_device():
     if torch.cuda.is_available():
         device = torch.device("cuda")
@@ -85,7 +86,9 @@ def get_device():
     return device
 
 
-#
+# function to set device config
+# CREDIT: Adapted from captum llama 2 example
+# see https://captum.ai/tutorials/Llama2_LLM_Attribution
 def gpu_loading_config(max_memory: str = "15000MB"):
     n_gpus = torch.cuda.device_count()
 
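The new comment in `prompt_limiter` ("setting the count to the pre-count") hints at the packing strategy its other comments describe: count the new message first, then add history pairs from newest to oldest while a token budget holds. A hypothetical sketch of that idea, with a toy token counter standing in for the real tokenizer and names (`limit_history`, `budget`) that are not the repo's:

```python
# Hypothetical budget-limited history packing, illustrating the prompt_limiter idea.
def limit_history(message: str, history: list, count_tokens, budget: int = 800):
    kept = []
    count = count_tokens(message)
    for user_msg, bot_msg in reversed(history):  # newest conversations first
        pair_tokens = count_tokens(user_msg) + count_tokens(bot_msg)
        if count + pair_tokens >= budget:
            break
        kept.append((user_msg, bot_msg))
        count += pair_tokens
    kept.reverse()  # restore chronological order
    return kept, count

# toy token counter: whitespace word count stands in for a real tokenizer
history = [("Hi", "Hello!"), ("Does money buy happiness?", "Only up to a point.")]
kept, used = limit_history("Why only up to a point?", history, lambda t: len(t.split()))
print(kept, used)
```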