Spaces:
LennardZuendorf committed
Commit 2492536 • 1 Parent(s): 3f2ed3d

chore: updating documentation

Changed files:

- .gitignore +0 -1
- README.md +5 -0
- __init__.py +1 -2
- backend/__init__.py +1 -2
- backend/controller.py +13 -10
- components/iframe/README.md +13 -47
- components/iframe/backend/gradio_iframe/iframe.py +2 -0
- components/iframe/frontend/Index.svelte +3 -0
- components/iframe/frontend/shared/HTML.svelte +8 -0
- explanation/__init__.py +1 -2
- explanation/interpret_shap.py +14 -4
- explanation/markup.py +22 -5
- explanation/visualize.py +17 -7
- main.py +44 -16
- model/__init__.py +1 -2
- model/godel.py +20 -9
- utils/__init__.py +1 -0
- utils/formatting.py +10 -4
- utils/modelling.py +16 -11
.gitignore
CHANGED
@@ -2,4 +2,3 @@
 __pycache__/
 /start-venv.sh
 /components/iframe/dist/
-/components/
README.md
CHANGED
@@ -80,6 +80,7 @@ This project is licensed under the MIT License, see [LICENSE](LICENSE.md) for more
 - University: HTW Berlin
 
 See code for detailed credits; this work is strongly based on:
+
 #### GODEL
 - [HGF Model Page](https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq?text=Hey+my+name+is+Mariama%21+How+are+you%3F)
 - [Paper on HGF](https://huggingface.co/papers/2206.11309)
@@ -88,3 +89,7 @@ See code for detailed credits; this work is strongly based on:
 #### SHAP
 - [Github](https://github.com/shap/shap)
 - [Initial Paper](https://arxiv.org/abs/1705.07874)
+
+#### Custom Component (/components/iframe/)
+
+Based on a Gradio component; see the individual README for the full changelog.
__init__.py
CHANGED
@@ -1,2 +1 @@
-# empty init file for the
-# for fastapi to recognize the module
+# empty init file for the module
backend/__init__.py
CHANGED
@@ -1,2 +1 @@
-# empty init file for the
-# for fastapi to recognize the module
+# empty init file for the modules
backend/controller.py
CHANGED
@@ -1,15 +1,16 @@
 # controller for the application that calls the model and explanation functions
-#
+# returns the updated conversation history and extra elements
 
 # external imports
 import gradio as gr
 
 # internal imports
 from model import godel
-from explanation import interpret_shap as
+from explanation import interpret_shap as shap_int, visualize as viz
 
 
 # main interference function that calls chat functions depending on selections
+# is getting called on every chat submit
 def interference(
     prompt: str,
     history: list,
@@ -17,18 +18,19 @@ def interference(
     system_prompt: str,
     xai_selection: str,
 ):
-    # if no system prompt is given, use a default one
-    if system_prompt
+    # if no proper system prompt is given, use a default one
+    if system_prompt in ('', ' '):
         system_prompt = """
            You are a helpful, respectful and honest assistant.
            Always answer as helpfully as possible, while being safe.
        """
 
-    # if a XAI approach is selected, grab the XAI instance
+    # if an XAI approach is selected, grab the XAI module instance
    if xai_selection in ("SHAP", "Attention"):
+        # matching selection
        match xai_selection.lower():
            case "shap":
-                xai =
+                xai = shap_int
            case "attention":
                xai = viz
            case _:
@@ -37,9 +39,10 @@ def interference(
                    There was an error in the selected XAI Approach.
                    It is "{xai_selection}"
                """)
+                # raise runtime exception
                raise RuntimeError("There was an error in the selected XAI approach.")
 
-        # call the explained chat function
+        # call the explained chat function with the model instance
        prompt_output, history_output, xai_graphic, xai_markup = explained_chat(
            model=godel,
            xai=xai,
@@ -48,7 +51,7 @@ def interference(
            system_prompt=system_prompt,
            knowledge=knowledge,
        )
-    # if no
+    # if no XAI approach is selected, call the vanilla chat function
    else:
        # call the vanilla chat function
        prompt_output, history_output = vanilla_chat(
@@ -78,12 +81,12 @@ def vanilla_chat(
 ):
    # formatting the prompt using the model's format_prompt function
    prompt = model.format_prompt(message, history, system_prompt, knowledge)
+
    # generating an answer using the model's respond function
    answer = model.respond(prompt)
 
    # updating the chat history with the new answer
    history.append((message, answer))
-
    # returning the updated history
    return "", history
 
@@ -94,7 +97,7 @@ def explained_chat(
    # formatting the prompt using the model's format_prompt function
    prompt = model.format_prompt(message, history, system_prompt, knowledge)
 
-    # generating an answer using the
+    # generating an answer using the method's chat function
    answer, xai_graphic, xai_markup = xai.chat_explained(model, prompt)
 
    # updating the chat history with the new answer
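Note: the two explanation modules are used as interchangeable strategies here, since both expose the same `chat_explained(model, prompt)` function. A minimal sketch of the dispatch pattern (module names taken from the imports above):

```python
from explanation import interpret_shap as shap_int, visualize as viz

def pick_xai(selection: str):
    # both modules expose the same chat_explained(model, prompt) interface,
    # so a plain module object works as the strategy
    match selection.lower():
        case "shap":
            return shap_int
        case "attention":
            return viz
        case _:
            raise RuntimeError("There was an error in the selected XAI approach.")

# usage: xai = pick_xai("SHAP")
# answer, graphic, markup = xai.chat_explained(model, prompt)
```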
components/iframe/README.md
CHANGED
@@ -1,51 +1,17 @@
-#
-A custom gradio component to embed an iframe in a gradio interface. This component is based on the [HTML]() component.
-It's currently still a work in progress.
-
-gr.Interface(
-    iFrame(
-        label="iFrame Example",
-        value=("""
-            <iframe width="560"
-            height="315"
-            src="https://www.youtube.com/embed/dQw4w9WgXcQ?si=QfHLpHZsI98oZT1G"
-            title="YouTube video player"
-            frameborder="0"
-            allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share"
-            allowfullscreen>
-            </iframe>"""),
-        show_label=True)
-)
-```
-
-## Roadmap
-
-- [ ] Add manual hand over of other iFrame options.
-- [ ] Explore switch between src and srcdoc through variable.
-
-## Known Issues
-
-**There are many reasons why it's not a good idea to embed websites in an iframe.**
-See [this](https://blog.bitsrc.io/4-security-concerns-with-iframes-every-web-developer-should-know-24c73e6a33e4), or just google "iframe security concerns" for more information. Also, iFrames will use additional computing power and memory, which can slow down the interface.
-
-Also, this component is still a work in progress and not fully tested. Use at your own risk.
-
-### Other Issues
-
-- Height sometimes does not grow according to the inner component.
-- The component is not completely responsive yet and struggles with variable height.
-- ...
+# gradio iFrame
+
+This is a custom gradio component used to display the shap package's text plot, which is interactive HTML and needs a custom wrapper.
+See the custom component examples in the official [docs](https://www.gradio.app/guides/custom-components-in-five-minutes).
+
+# Credit
+CREDIT: based mostly on the Gradio template component, HTML
+see: https://www.gradio.app/docs/html
+
+## Changes
+**Additions/changes are marked. Everything else can be considered the work of others (the Gradio team).**
+
+#### Changed Files/Contributions
+- backend/iframe.py - updated the component to accept custom height/width and added a new example
+- demo/app.py - slightly changed demo file for a better dev experience
+- frontend/index.svelte - slightly changed to accept custom height/width
+- frontend/HTML.svelte - updated to use an iFrame and added a custom function to programmatically set height values
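A hypothetical usage sketch of the component described above (the `height` keyword is the custom addition noted in the changelog; the exact signature may differ):

```python
import gradio as gr
from gradio_iframe import iFrame

with gr.Blocks() as demo:
    # any HTML string works; in this project it is the interactive shap text plot
    iFrame(
        value="<h4>Interactive HTML goes here</h4>",
        height="400px",  # custom height support added by this component
        show_label=False,
    )

demo.launch()
```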
components/iframe/backend/gradio_iframe/iframe.py
CHANGED
@@ -62,10 +62,12 @@ class iFrame(Component):
            value=value,
        )
 
+        # updating component to take custom height and width values
        self.height = height
        self.width = width
 
    def example_inputs(self) -> Any:
+        # setting a custom example
        return """<iframe width="560" height="315" src="https://www.youtube.com/embed/dQw4w9WgXcQ?si=QfHLpHZsI98oZT1G" title="YouTube video player" frameborder="0" allow="accelerometer; autoplay; clipboard-write; encrypted-media; gyroscope; picture-in-picture; web-share" allowfullscreen></iframe>"""
 
    def preprocess(self, payload: str | None) -> str | None:
components/iframe/frontend/Index.svelte
CHANGED
@@ -1,3 +1,5 @@
+<!-- index component that wraps the custom iFrame ("HTML") -->
+
 <script lang="ts">
    import type { Gradio } from "@gradio/utils";
    import HTML from "./shared/HTML.svelte";
@@ -10,6 +12,7 @@
    export let elem_classes: string[] = [];
    export let visible = true;
    export let value = "";
+    // updated to take a custom height
    export let height: string;
    export let width: string = "100%";
    export let loading_status: LoadingStatus;
components/iframe/frontend/shared/HTML.svelte
CHANGED
@@ -1,3 +1,5 @@
+<!-- HTML component that implements the custom iFrame -->
+
 <script lang="ts">
    import { createEventDispatcher } from "svelte";
    export let elem_classes: string[] = [];
@@ -5,6 +7,7 @@
    export let visible = true;
    export let min_height = false;
 
+    // default height and width settings
    export let height = "100%";
    export let width = "100%";
 
@@ -12,10 +15,14 @@
 
    let iframeElement;
 
+    // custom function to update the iFrame height on load of the HTML
    const onLoad = () => {
        try {
+            // grabbing the iFrame document
            const iframeDocument = iframeElement.contentDocument || iframeElement.contentWindow.document;
+            // if the height is not custom, set it individually
            if (height === "100%") {
+                // grabbing the height from the iFrame document
                const height = iframeDocument.documentElement.scrollHeight;
                iframeElement.style.height = `${height}px`;
            }
@@ -33,6 +40,7 @@
    class:hide={!visible}
    class:height={height}
 >
+    <!-- updated to use an iframe instead of HTML, using string values with srcdoc -->
    <iframe
        bind:this={iframeElement}
        title="iframe component"
explanation/__init__.py
CHANGED
@@ -1,2 +1 @@
-# empty init file for the
-# for fastapi to recognize the module
+# empty init file for the modules
explanation/interpret_shap.py
CHANGED
@@ -1,4 +1,5 @@
 # interpret module that implements the interpretability method
+
 # external imports
 from shap import models, maskers, plots, PartitionExplainer
 import torch
@@ -14,14 +15,15 @@ TEXT_MASKER = None
 
 # main explain function that returns a chat with explanations
 def chat_explained(model, prompt):
-    model.set_config()
+    model.set_config({})
 
    # create the shap explainer
    shap_explainer = PartitionExplainer(model.MODEL, model.TOKENIZER)
+
    # get the shap values for the prompt
    shap_values = shap_explainer([prompt])
 
-    # create the explanation graphic and
+    # create the explanation graphic and marked text array
    graphic = create_graphic(shap_values)
    marked_text = markup_text(
        shap_values.data[0], shap_values.values[0], variant="shap"
@@ -29,20 +31,26 @@ def chat_explained(model, prompt):
 
    # create the response text
    response_text = fmt.format_output_text(shap_values.output_names)
+
+    # return response, graphic and marked_text array
    return response_text, graphic, marked_text
 
 
+# function used to wrap the model with a shap model
 def wrap_shap(model):
+    # declaring global variables
    global TEXT_MASKER, TEACHER_FORCING
 
    # set the device to cuda if gpu is available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
    # updating the model settings
    model.set_config({})
 
    # (re)initialize the shap models and masker
+    # creating a shap text_generation model
    text_generation = models.TextGeneration(model.MODEL, model.TOKENIZER)
+    # wrapping the text generation model in a teacher forcing model
    TEACHER_FORCING = models.TeacherForcing(
        text_generation,
        model.TOKENIZER,
@@ -50,13 +58,15 @@ def wrap_shap(model):
        similarity_model=model.MODEL,
        similarity_tokenizer=model.TOKENIZER,
    )
+    # creating the text masker with a whitespace separator
    TEXT_MASKER = maskers.Text(model.TOKENIZER, " ", collapse_mask_token=True)
 
 
 # graphic plotting function that creates a html graphic (as string) for the explanation
 def create_graphic(shap_values):
+
    # create the html graphic using shap text plot function
    graphic_html = plots.text(shap_values, display=False)
 
-    # return the html graphic as string
+    # return the html graphic as string to display in the iFrame
    return str(graphic_html)
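For context, the flow above follows the standard shap text-explainer pattern. A minimal standalone sketch (using a small stand-in checkpoint rather than GODEL; treat as a sketch of the technique, not this repo's exact wiring):

```python
import shap
from transformers import pipeline

# a small seq2seq pipeline as a stand-in for the wrapped GODEL model
generator = pipeline("text2text-generation", model="google/flan-t5-small")

# shap selects a partition-based text explainer for transformers pipelines
explainer = shap.Explainer(generator)
shap_values = explainer(["Instruction: answer politely. Input: How are you?"])

# display=False returns the interactive plot as an HTML string,
# which is what the custom iFrame component renders
html = shap.plots.text(shap_values, display=False)
print(str(html)[:200])
```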
explanation/markup.py
CHANGED
@@ -1,4 +1,4 @@
-# markup module that provides marked up text
+# markup module that provides marked up text as an array
 
 # external imports
 import numpy as np
@@ -8,10 +8,12 @@ from numpy import ndarray
 from utils import formatting as fmt
 
 
+# main function that assigns each text snippet a marked bucket
 def markup_text(input_text: list, text_values: ndarray, variant: str):
+    # naming of the 11 buckets
    bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
 
-    #
+    # flatten the values depending on the source
    # attention is averaged, SHAP summed up
    if variant == "shap":
        text_values = np.transpose(text_values)
@@ -22,34 +24,49 @@ def markup_text(input_text: list, text_values: ndarray, variant: str):
    # Determine the minimum and maximum values
    min_val, max_val = np.min(text_values), np.max(text_values)
 
-    #
+    # separate the threshold calculation for negative and positive values
+    # visualizer negative thresholds are all 0 since attention is always positive
    if variant == "visualizer":
        neg_thresholds = np.linspace(
            0, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
        )[1:]
+    # standard config for 5 negative buckets
    else:
        neg_thresholds = np.linspace(
            min_val, 0, num=(len(bucket_tags) - 1) // 2 + 1, endpoint=False
        )[1:]
+    # creating positive thresholds between 0 and the max value
    pos_thresholds = np.linspace(0, max_val, num=(len(bucket_tags) - 1) // 2 + 1)[1:]
+    # combining thresholds
    thresholds = np.concatenate([neg_thresholds, [0], pos_thresholds])
 
+    # init empty marked text list
    marked_text = []
 
-    #
+    # looping over each text snippet and attribution value
    for text, value in zip(input_text, text_values):
+        # setting initial bucket at lowest
        bucket = "-5"
+
+        # looping over all buckets and their thresholds
        for i, threshold in zip(bucket_tags, thresholds):
+            # updating assigned bucket if value is above threshold
            if value >= threshold:
                bucket = i
+        # finally adding text and bucket assignment to list of tuples
        marked_text.append((text, str(bucket)))
 
+    # returning list of marked text snippets as list of tuples
    return marked_text
 
 
+# function that defines color codes
+# coloring follows SHAP-style coloring for consistency
 def color_codes():
    return {
-        #
+        # -5 to -1: strong light sky blue to lighter sky blue
+        # 0: white (assuming default light mode)
+        # +1 to +5: light pink to strong magenta
        "-5": "#3251a8",  # Strong Light Sky Blue
        "-4": "#5A7FB2",  # Slightly Lighter Sky Blue
        "-3": "#8198BC",  # Intermediate Sky Blue
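The bucketing logic above can be exercised in isolation. A minimal sketch with made-up attribution values (same threshold scheme as `markup_text`):

```python
import numpy as np

bucket_tags = ["-5", "-4", "-3", "-2", "-1", "0", "+1", "+2", "+3", "+4", "+5"]
tokens = ["the", "quick", "brown", "fox", "jumps"]
values = np.array([-0.8, -0.1, 0.0, 0.3, 0.9])

# five negative thresholds, zero, five positive thresholds -> 11 in total
half = (len(bucket_tags) - 1) // 2 + 1
neg = np.linspace(values.min(), 0, num=half, endpoint=False)[1:]
pos = np.linspace(0, values.max(), num=half)[1:]
thresholds = np.concatenate([neg, [0], pos])

marked = []
for tok, val in zip(tokens, values):
    bucket = "-5"
    for tag, th in zip(bucket_tags, thresholds):
        # the highest threshold the value still clears wins
        if val >= th:
            bucket = tag
    marked.append((tok, bucket))

print(marked)  # e.g. ('jumps', '+5') for the largest value
```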
explanation/visualize.py
CHANGED
@@ -1,21 +1,26 @@
 # visualization module that creates an attention visualization
 
 
 # internal imports
 from utils import formatting as fmt
+from model.godel import CONFIG
 from .markup import markup_text
 
 
-#
+# chat function that returns an answer
+# and marked text based on attention
 def chat_explained(model, prompt):
 
-
-
-    # get encoded input and output vectors
+    # get encoded input
    encoder_input_ids = model.TOKENIZER(
        prompt, return_tensors="pt", add_special_tokens=True
    ).input_ids
-
+    # generate output together with attentions of the model
+    decoder_input_ids = model.MODEL.generate(
+        encoder_input_ids, output_attentions=True, **CONFIG
+    )
+
+    # get input and output text as list of strings
    encoder_text = fmt.format_tokens(
        model.TOKENIZER.convert_ids_to_tokens(encoder_input_ids[0])
    )
@@ -24,20 +29,25 @@ def chat_explained(model, prompt):
    )
 
    # get attention values for the input and output vectors
+    # using the already generated input and output
    attention_output = model.MODEL(
        input_ids=encoder_input_ids,
        decoder_input_ids=decoder_input_ids,
        output_attentions=True,
    )
 
+    # averaging attention across layers
    averaged_attention = fmt.avg_attention(attention_output)
 
-    #
+    # format response text for clean output
    response_text = fmt.format_output_text(decoder_text)
+    # setting placeholder for the iFrame graphic
    graphic = (
        "<div style='text-align: center; font-family:arial;'><h4>Attention"
        " Visualization doesn't support an interactive graphic.</h4></div>"
    )
+    # creating marked text using the markup_text function and attention
    marked_text = markup_text(encoder_text, averaged_attention, variant="visualizer")
 
+    # returning response, graphic and marked text array
    return response_text, graphic, marked_text
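For context, the generate-then-forward pattern above is the standard way to recover attention matrices for a finished sequence. A minimal sketch with a small stand-in checkpoint (shapes and field names follow the transformers seq2seq output convention):

```python
import numpy as np
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")

input_ids = tokenizer("How are you?", return_tensors="pt").input_ids
decoder_input_ids = model.generate(input_ids, max_new_tokens=20)

# a second forward pass over the generated sequence exposes the attentions
with torch.no_grad():
    out = model(
        input_ids=input_ids,
        decoder_input_ids=decoder_input_ids,
        output_attentions=True,
    )

# first decoder layer, first batch element: (heads, tgt_len, tgt_len)
attention = out.decoder_attentions[0][0].numpy()
averaged = np.mean(attention, axis=0)  # average over attention heads
print(averaged.shape)
```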
main.py
CHANGED
@@ -14,13 +14,21 @@ from gradio_iframe import iFrame
 from backend.controller import interference
 from explanation.markup import color_codes
 
-
+
+# global variables and js/css
+# creating FastAPI app and getting color codes
 app = FastAPI()
+coloring = color_codes()
+
+
+# defining custom css and js for certain environments
 css = """
 .examples {text-align: start;}
 .seperatedRow {border-top: 1rem solid;}
 """
-js = """
+# custom js to force light mode in custom environments
+if os.environ["HOSTING"].lower() != "spaces":
+    js = """
 function () {
    gradioURL = window.location.href
    if (!gradioURL.endsWith('?__theme=light')) {
@@ -28,7 +36,8 @@ js = """
    }
 }
 """
-
+else:
+    js = ""
 
 
 # different functions to provide frontend abilities
@@ -56,8 +65,8 @@ def xai_info(xai_radio):
    gr.Info("No XAI method was selected.")
 
 
-# ui interface based on Gradio Blocks
-# https://www.gradio.app/docs/interface)
+# ui interface based on Gradio Blocks
+# see https://www.gradio.app/docs/interface
 with gr.Blocks(
    css=css,
    js=js,
@@ -88,6 +97,7 @@ with gr.Blocks(
    """)
    # row with columns for the different settings
    with gr.Row(equal_height=True):
+        # accordion that extends if clicked
        with gr.Accordion(label="Application Settings", open=False):
            # column that takes up 3/4 of the row
            with gr.Column(scale=3):
@@ -95,6 +105,7 @@ with gr.Blocks(
                system_prompt = gr.Textbox(
                    label="System Prompt",
                    info="Set the model's system prompt, dictating how it answers.",
+                    # default system prompt is set to this in the backend
                    placeholder=(
                        "You are a helpful, respectful and honest assistant. Always"
                        " answer as helpfully as possible, while being safe."
@@ -105,26 +116,29 @@ with gr.Blocks(
            # checkbox group to select the xai method
            xai_selection = gr.Radio(
                ["None", "SHAP", "Attention"],
-                label="
-                info="Select a
+                label="Interpretability Settings",
+                info="Select an Interpretability Implementation to use.",
                value="None",
                interactive=True,
                show_label=True,
            )
 
-    # calling info functions on inputs for different settings
+    # calling info functions on inputs/submits for different settings
    system_prompt.submit(system_prompt_info, [system_prompt])
    xai_selection.input(xai_info, [xai_selection])
 
    # row with chatbot ui displaying "conversation" with the model
    with gr.Row(equal_height=True):
+        # group to display components closely together
        with gr.Group(elem_classes="border: 1px solid black;"):
            # accordion to display the normalized input explanation
            with gr.Accordion(label="Input Explanation", open=False):
                gr.Markdown("""
                    The explanations are based on 10 buckets that range between the
                    lowest negative value (1 to 5) and the highest positive attribution value (6 to 10).
-                    **The legend
+                    **The legend shows the color for each bucket.**
+
+                    *HINT*: This works best in light mode.
                """)
                xai_text = gr.HighlightedText(
                    color_map=coloring,
@@ -132,15 +146,19 @@ with gr.Blocks(
                    show_legend=True,
                    show_label=False,
                )
-            # out of the box chatbot component
+            # out of the box chatbot component with avatar images
            # see documentation: https://www.gradio.app/docs/chatbot
            chatbot = gr.Chatbot(
                layout="panel",
                show_copy_button=True,
                avatar_images=("./public/human.jpg", "./public/bot.jpg"),
            )
-            #
+            # extendable components for extra knowledge
            with gr.Accordion(label="Additional Knowledge", open=False):
+                gr.Markdown(
+                    "*Hint:* Add extra knowledge to see GODEL work the best."
+                )
+                # textbox to enter the knowledge
                knowledge_input = gr.Textbox(
                    value="",
                    label="Knowledge",
@@ -149,24 +167,31 @@ with gr.Blocks(
                    show_label=True,
                )
            # textbox to enter the user prompt
+            gr.Markdown(
+                "*Hint:* More complicated questions give better explanation"
+                " insights!"
+            )
            user_prompt = gr.Textbox(
                label="Input Message",
                max_lines=5,
                info="""
                    Ask the ChatBot a question.
-                    Hint: More complicated questions give better explanation insights!
                """,
                show_label=True,
            )
            # row with columns for buttons to submit and clear content
            with gr.Row(elem_classes=""):
-                with gr.Column(
+                with gr.Column():
                    # out of the box clear button which clears the given components
-                    #
+                    # see: https://www.gradio.app/docs/clearbutton
                    clear_btn = gr.ClearButton([user_prompt, chatbot])
-                with gr.Column(
+                with gr.Column():
+                    # submit button that calls the backend functions on click
                    submit_btn = gr.Button("Submit", variant="primary")
+            # row with content examples that get autofilled on click
            with gr.Row(elem_classes="examples"):
+                # examples util component
+                # see: https://www.gradio.app/docs/examples
                gr.Examples(
                    label="Example Questions",
                    examples=[
@@ -235,18 +260,21 @@ with gr.Blocks(
    # final row to show legal information
    ## - credits, data protection and link to the License
    with gr.Tab(label="About"):
+        # load about.md markdown
        gr.Markdown(value=load_md("public/about.md"))
        with gr.Accordion(label="Credits, Data Protection, License"):
+            # load credits and dataprotection markdown
            gr.Markdown(value=load_md("public/credits_dataprotection_license.md"))
 
 # mount function for fastAPI Application
 app = gr.mount_gradio_app(app, ui, path="/")
 
-# launch function
+# launch function to launch the application
 if __name__ == "__main__":
 
    # use standard gradio launch option for hgf spaces
    if os.environ["HOSTING"].lower() == "spaces":
+        # set password to deny public access
        ui.launch(auth=("htw", "berlin@123"))
 
    # otherwise run the application on port 8080 in reload mode
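The mounting pattern at the end of main.py is worth isolating. A minimal sketch of serving a Gradio Blocks UI inside FastAPI (file name and port are illustrative):

```python
from fastapi import FastAPI
import gradio as gr

app = FastAPI()

with gr.Blocks() as ui:
    gr.Markdown("Hello from a mounted Gradio app")

# serves the Blocks UI at the root path of the FastAPI app
app = gr.mount_gradio_app(app, ui, path="/")

# run with: uvicorn main:app --port 8080
```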
model/__init__.py
CHANGED
@@ -1,2 +1 @@
-# empty init file for the
-# for fastapi to recognize the module
+# empty init file for the module
model/godel.py
CHANGED
@@ -6,21 +6,28 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
 # internal imports
 from utils import modelling as mdl
 
-# model and tokenizer instance
+# global model and tokenizer instance (created on initial build)
 TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
 MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
+
+# default model config
 CONFIG = {"max_new_tokens": 50, "min_length": 8, "top_p": 0.9, "do_sample": True}
 
 
-#
-def set_config(config: dict
-
-    config = {}
-
-
-
-
-
+# function to (re)set the config
+def set_config(config: dict):
+    global CONFIG
+
+    # if a config dict is given, update it
+    if config != {}:
+        CONFIG = config
+    else:
+        # hard setting the model config to default
+        # needed for shap
+        MODEL.config.max_new_tokens = 50
+        MODEL.config.min_length = 8
+        MODEL.config.top_p = 0.9
+        MODEL.config.do_sample = True
 
 
 # formatting class to format input for the model
@@ -56,8 +63,12 @@ def format_prompt(message: str, history: list, system_prompt: str, knowledge: str
 # CREDIT: Copied from official inference example on Huggingface
 ## see https://huggingface.co/microsoft/GODEL-v1_1-large-seq2seq
 def respond(prompt):
+    # tokenizing input string
    input_ids = TOKENIZER(f"{prompt}", return_tensors="pt").input_ids
+
+    # generating using config and decoding output
    outputs = MODEL.generate(input_ids, **CONFIG)
    output = TOKENIZER.decode(outputs[0], skip_special_tokens=True)
 
+    # returns the model output string
    return output
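A note on the two configuration paths in `set_config`: kwargs passed as `**CONFIG` only apply to a direct `generate()` call, while the SHAP wrapper triggers generation internally and therefore reads parameters from `MODEL.config`. A minimal sketch of the distinction (this reading of the shap behavior is an assumption; the checkpoint name is the real one from the diff):

```python
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

TOKENIZER = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")
MODEL = AutoModelForSeq2SeqLM.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

input_ids = TOKENIZER("Instruction: be helpful. Input: Hi!", return_tensors="pt").input_ids

# path 1: explicit kwargs, as used by respond()
outputs = MODEL.generate(input_ids, max_new_tokens=50, min_length=8, top_p=0.9, do_sample=True)
print(TOKENIZER.decode(outputs[0], skip_special_tokens=True))

# path 2: defaults written onto the model config, picked up when the shap
# TextGeneration wrapper calls generate() on its own (set_config's else branch)
MODEL.config.max_new_tokens = 50
MODEL.config.min_length = 8
MODEL.config.top_p = 0.9
MODEL.config.do_sample = True
```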
utils/__init__.py
CHANGED
@@ -0,0 +1 @@
+# empty init file for the module
utils/formatting.py
CHANGED
@@ -7,8 +7,10 @@ from numpy import ndarray
 
 
 # function to format the model response nicely
+# takes a list of strings and returns a combined string
 def format_output_text(output: list):
-
+
+    # remove special tokens from the list using the function below
    formatted_output = format_tokens(output)
 
    # start string with first list item if it is not empty
@@ -34,8 +36,10 @@ def format_output_text(output: list):
 
 # format the tokens by removing special tokens and special characters
 def format_tokens(tokens: list):
    # define special tokens to remove
    special_tokens = ["[CLS]", "[SEP]", "[PAD]", "[UNK]", "[MASK]", "▁", "Ġ", "</w>"]
+
+    # initialize empty list
    updated_tokens = []
 
    # loop through tokens
@@ -44,7 +48,7 @@ def format_tokens(tokens: list):
        if t.startswith("▁"):
            t = t.lstrip("▁")
 
-        # loop through special tokens and remove
+        # loop through the special tokens list and remove from the current token if matched
        for s in special_tokens:
            t = t.replace(s, "")
 
@@ -55,15 +59,17 @@ def format_tokens(tokens: list):
    return updated_tokens
 
 
-# function to flatten values into a 2d list by
+# function to flatten shap values into a 2d list by summing them up
 def flatten_attribution(values: ndarray, axis: int = 0):
    return np.sum(values, axis=axis)
 
 
+# function to flatten values into a 2d list by averaging the attention values
 def flatten_attention(values: ndarray, axis: int = 0):
    return np.mean(values, axis=axis)
 
 
+# function to get averaged decoder attention from attention values
 def avg_attention(attention_values):
    attention = attention_values.decoder_attentions[0][0].detach().numpy()
    return np.mean(attention, axis=0)
utils/modelling.py
CHANGED
@@ -1,26 +1,28 @@
-# module for
+# modelling util module providing formatting functions for model functionalities
 
 # external imports
 import gradio as gr
 
 
+# function that limits the prompt to constrain model runtime
+# tries to keep as much as possible, always keeping at least message and system prompt
 def prompt_limiter(
    tokenizer, message: str, history: list, system_prompt: str, knowledge: str = ""
 ):
-    # initializing the prompt history empty
+    # initializing the new prompt history empty
    prompt_history = []
-    # getting the token count for the message, system prompt, and knowledge
+    # getting the current token count for the message, system prompt, and knowledge
    pre_count = (
        token_counter(tokenizer, message)
        + token_counter(tokenizer, system_prompt)
        + token_counter(tokenizer, knowledge)
    )
 
-    # validating the token count
-    # check if token count already too high
+    # validating the token count against the threshold of 1024
+    # check if token count already too high without history
    if pre_count > 1024:
 
-        # check if token count too high even without knowledge
+        # check if token count too high even without knowledge and history
        if (
            token_counter(tokenizer, message) + token_counter(tokenizer, system_prompt)
            > 1024
@@ -32,11 +34,14 @@ def prompt_limiter(
            "Message and system prompt are too long. Please shorten them."
        )
 
-        # show warning and
-        gr.Warning("
+        # show warning and return with empty history and empty knowledge
+        gr.Warning("""
+            Input too long.
+            Knowledge and conversation history have been removed to keep the model running.
+        """)
        return message, prompt_history, system_prompt, ""
 
-    # if token count small enough,
+    # if token count small enough, adding history bit by bit
    if pre_count < 800:
        # setting the count to the precount
        count = pre_count
@@ -46,7 +51,7 @@ def prompt_limiter(
        # iterating through the history
        for conversation in history:
 
            # checking the token count with the current conversation
            count += token_counter(tokenizer, conversation[0]) + token_counter(
                tokenizer, conversation[1]
            )
@@ -57,7 +62,7 @@ def prompt_limiter(
        else:
            break
 
-    # return the message,
+    # return the message, adapted history, system prompt, and knowledge
    return message, prompt_history, system_prompt, knowledge
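The token-budget check in `prompt_limiter` is easy to test in isolation. A minimal sketch (assuming `token_counter` simply wraps the tokenizer, as its call sites suggest):

```python
from transformers import AutoTokenizer

def token_counter(tokenizer, text: str) -> int:
    # count the tokens the model would actually see
    return len(tokenizer(text).input_ids)

tokenizer = AutoTokenizer.from_pretrained("microsoft/GODEL-v1_1-large-seq2seq")

message = "How are you?"
system_prompt = "You are a helpful, respectful and honest assistant."
knowledge = ""

pre_count = (
    token_counter(tokenizer, message)
    + token_counter(tokenizer, system_prompt)
    + token_counter(tokenizer, knowledge)
)
print(pre_count, "tokens;", "over budget" if pre_count > 1024 else "within budget")
```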