Commit
•
0adfa8b
1
Parent(s):
97c603a
Add package argilla manually override within code
Browse files- app.py +12 -9
- requirements.txt +0 -1
app.py
CHANGED
@@ -1,4 +1,13 @@
|
|
1 |
#!/usr/bin/env python
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
import os
|
4 |
from threading import Thread
|
@@ -10,14 +19,14 @@ import spaces
|
|
10 |
import torch
|
11 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
12 |
|
13 |
-
from
|
14 |
|
15 |
MAX_MAX_NEW_TOKENS = 2048
|
16 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
17 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
|
18 |
|
19 |
if torch.cuda.is_available():
|
20 |
-
model_id = "
|
21 |
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
|
22 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
23 |
style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
|
@@ -87,9 +96,8 @@ chat_interface = ChatInterface(
|
|
87 |
chatbot=gr.Chatbot(
|
88 |
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
89 |
),
|
90 |
-
|
91 |
rg_dataset=dataset,
|
92 |
-
# textbox=gr.Textbox(value="Typ een bericht…"),
|
93 |
cache_examples=False,
|
94 |
additional_inputs=[
|
95 |
gr.Slider(
|
@@ -159,11 +167,6 @@ Vat bovenstaand artikel samen"""
|
|
159 |
description="""\
|
160 |
<a href="https://huggingface.co/davidberenstein1957/ultra-feedback-dutch-cleaned-hq-spin-geitje-7b-ultra-sft_iter2">GEITje 7B SPIN iter 2</a> is een geavanceerde versie van GEITje, verder getraind op uitgebreide chat datasets en ook op een preferentiedataset om beter te aligneren met het gedrag van een gewenste chatbot op basis van het [SPIN algoritme en code](https://github.com/argilla-io/distilabel-spin-dibt/), in dit geval gpt-4-turbo. De data is net iets anders dan de originele dataset want we hebben de schoonmaak voor UltraFeedback gebruikt van Argilla, de uiteindelijke dataset is [hier](https://huggingface.co/datasets/BramVanroy/ultra_feedback_dutch_cleaned) te vinden.
|
161 |
""",
|
162 |
-
submit_btn="Genereer",
|
163 |
-
stop_btn="Stop",
|
164 |
-
retry_btn="🔄 Opnieuw",
|
165 |
-
undo_btn="↩️ Ongedaan maken",
|
166 |
-
clear_btn="🗑️ Wissen",
|
167 |
)
|
168 |
|
169 |
with gr.Blocks(css="style.css") as demo:
|
|
|
1 |
#!/usr/bin/env python
|
2 |
+
if True:
|
3 |
+
import subprocess
|
4 |
+
import sys
|
5 |
+
|
6 |
+
def install_package(package_name):
|
7 |
+
subprocess.check_call([sys.executable, "-m", "pip", "install", package_name])
|
8 |
+
|
9 |
+
# Example usage:
|
10 |
+
install_package("argilla==2.0.0rc1")
|
11 |
|
12 |
import os
|
13 |
from threading import Thread
|
|
|
19 |
import torch
|
20 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
21 |
|
22 |
+
from chat_interface_preference import ChatInterface
|
23 |
|
24 |
MAX_MAX_NEW_TOKENS = 2048
|
25 |
DEFAULT_MAX_NEW_TOKENS = 1024
|
26 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
|
27 |
|
28 |
if torch.cuda.is_available():
|
29 |
+
model_id = "Qwen/Qwen2-0.5B-Instruct"
|
30 |
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
|
31 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
32 |
style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
|
|
|
96 |
chatbot=gr.Chatbot(
|
97 |
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
98 |
),
|
99 |
+
css=style,
|
100 |
rg_dataset=dataset,
|
|
|
101 |
cache_examples=False,
|
102 |
additional_inputs=[
|
103 |
gr.Slider(
|
|
|
167 |
description="""\
|
168 |
<a href="https://huggingface.co/davidberenstein1957/ultra-feedback-dutch-cleaned-hq-spin-geitje-7b-ultra-sft_iter2">GEITje 7B SPIN iter 2</a> is een geavanceerde versie van GEITje, verder getraind op uitgebreide chat datasets en ook op een preferentiedataset om beter te aligneren met het gedrag van een gewenste chatbot op basis van het [SPIN algoritme en code](https://github.com/argilla-io/distilabel-spin-dibt/), in dit geval gpt-4-turbo. De data is net iets anders dan de originele dataset want we hebben de schoonmaak voor UltraFeedback gebruikt van Argilla, de uiteindelijke dataset is [hier](https://huggingface.co/datasets/BramVanroy/ultra_feedback_dutch_cleaned) te vinden.
|
169 |
""",
|
|
|
|
|
|
|
|
|
|
|
170 |
)
|
171 |
|
172 |
with gr.Blocks(css="style.css") as demo:
|
requirements.txt
CHANGED
@@ -6,4 +6,3 @@ sentencepiece==0.2.0
|
|
6 |
spaces==0.28.3
|
7 |
torch==2.0.1
|
8 |
transformers==4.41.2
|
9 |
-
argilla==2.0.0rc1 --install-option="--no-deps"
|
|
|
6 |
spaces==0.28.3
|
7 |
torch==2.0.1
|
8 |
transformers==4.41.2
|
|