Commit
·
5579546
1
Parent(s):
088561c
small code updates
Browse files- app copy.py +1 -5
- app.py +1 -0
- chat_interface_preference.py +5 -9
- test.py +0 -2
app copy.py
CHANGED
@@ -9,10 +9,6 @@ import spaces
|
|
9 |
import torch
|
10 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
11 |
|
12 |
-
from chat_interface_preference import ChatInterface
|
13 |
-
|
14 |
-
MAX_MAX_NEW_TOKENS = 2048
|
15 |
-
DEFAULT_MAX_NEW_TOKENS = 1024
|
16 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
|
17 |
|
18 |
if torch.cuda.is_available():
|
@@ -63,7 +59,7 @@ def generate(
|
|
63 |
yield "".join(outputs)
|
64 |
|
65 |
|
66 |
-
chat_interface = ChatInterface(
|
67 |
fn=generate,
|
68 |
chatbot=gr.Chatbot(
|
69 |
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
|
|
9 |
import torch
|
10 |
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
11 |
|
|
|
|
|
|
|
|
|
12 |
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
|
13 |
|
14 |
if torch.cuda.is_available():
|
|
|
59 |
yield "".join(outputs)
|
60 |
|
61 |
|
62 |
+
chat_interface = gr.ChatInterface(
|
63 |
fn=generate,
|
64 |
chatbot=gr.Chatbot(
|
65 |
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
app.py
CHANGED
@@ -29,6 +29,7 @@ if torch.cuda.is_available():
|
|
29 |
model_id = "Qwen/Qwen2-0.5B-Instruct"
|
30 |
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
|
31 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
|
|
32 |
style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
|
33 |
|
34 |
client = rg.Argilla(api_url="https://davidberenstein1957-argilla-gradio.hf.space", api_key="owner.apikey")
|
|
|
29 |
model_id = "Qwen/Qwen2-0.5B-Instruct"
|
30 |
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
|
31 |
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
32 |
+
|
33 |
style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
|
34 |
|
35 |
client = rg.Argilla(api_url="https://davidberenstein1957-argilla-gradio.hf.space", api_key="owner.apikey")
|
chat_interface_preference.py
CHANGED
@@ -23,8 +23,9 @@ from gradio.components import (
|
|
23 |
get_component_instance,
|
24 |
)
|
25 |
from gradio.events import Dependency, on
|
26 |
-
from gradio.helpers import Error, Info, Warning
|
27 |
from gradio.helpers import create_examples as Examples # noqa: N812
|
|
|
28 |
from gradio.layouts import Accordion, Group, Row
|
29 |
from gradio.routes import Request
|
30 |
from gradio.themes import ThemeClass as Theme
|
@@ -629,10 +630,6 @@ class ChatInterface(Blocks):
|
|
629 |
n_generations: int = 1,
|
630 |
*args,
|
631 |
) -> AsyncGenerator:
|
632 |
-
_, response = history_with_input[-1]
|
633 |
-
if self._check_if_two_responses(response):
|
634 |
-
raise Error("Two options detected: undo, log or random pick continuation.")
|
635 |
-
|
636 |
if self.multimodal and isinstance(message, dict):
|
637 |
remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
|
638 |
history = history_with_input[:-remove_input]
|
@@ -646,14 +643,13 @@ class ChatInterface(Blocks):
|
|
646 |
generator = self.fn(*inputs)
|
647 |
else:
|
648 |
generator = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
|
649 |
-
|
650 |
first_response = await async_iteration(generator)
|
651 |
if n_generations == 2:
|
652 |
first_response_formatted = self._get_chat_message_comparison(first_response, "")
|
653 |
if self.multimodal and isinstance(message, dict):
|
654 |
for x in message["files"]:
|
655 |
history.append([(x,), None])
|
656 |
-
|
657 |
update = history + [[message["text"], first_response_formatted]]
|
658 |
yield update, update
|
659 |
else:
|
@@ -670,10 +666,10 @@ class ChatInterface(Blocks):
|
|
670 |
if n_generations == 2:
|
671 |
response_formatted = self._get_chat_message_comparison(response, "")
|
672 |
if self.multimodal and isinstance(message, dict):
|
673 |
-
update = history + [[message["text"],
|
674 |
yield update, update
|
675 |
else:
|
676 |
-
update = history + [[message,
|
677 |
yield update, update
|
678 |
|
679 |
if n_generations == 2:
|
|
|
23 |
get_component_instance,
|
24 |
)
|
25 |
from gradio.events import Dependency, on
|
26 |
+
from gradio.helpers import Error, Info, Warning
|
27 |
from gradio.helpers import create_examples as Examples # noqa: N812
|
28 |
+
from gradio.helpers import special_args
|
29 |
from gradio.layouts import Accordion, Group, Row
|
30 |
from gradio.routes import Request
|
31 |
from gradio.themes import ThemeClass as Theme
|
|
|
630 |
n_generations: int = 1,
|
631 |
*args,
|
632 |
) -> AsyncGenerator:
|
|
|
|
|
|
|
|
|
633 |
if self.multimodal and isinstance(message, dict):
|
634 |
remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
|
635 |
history = history_with_input[:-remove_input]
|
|
|
643 |
generator = self.fn(*inputs)
|
644 |
else:
|
645 |
generator = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
|
646 |
+
generator = SyncToAsyncIterator(generator, self.limiter)
|
647 |
first_response = await async_iteration(generator)
|
648 |
if n_generations == 2:
|
649 |
first_response_formatted = self._get_chat_message_comparison(first_response, "")
|
650 |
if self.multimodal and isinstance(message, dict):
|
651 |
for x in message["files"]:
|
652 |
history.append([(x,), None])
|
|
|
653 |
update = history + [[message["text"], first_response_formatted]]
|
654 |
yield update, update
|
655 |
else:
|
|
|
666 |
if n_generations == 2:
|
667 |
response_formatted = self._get_chat_message_comparison(response, "")
|
668 |
if self.multimodal and isinstance(message, dict):
|
669 |
+
update = history + [[message["text"], response]]
|
670 |
yield update, update
|
671 |
else:
|
672 |
+
update = history + [[message, response]]
|
673 |
yield update, update
|
674 |
|
675 |
if n_generations == 2:
|
test.py
CHANGED
@@ -1,8 +1,6 @@
|
|
1 |
import random
|
2 |
|
3 |
|
4 |
-
import argilla as rg
|
5 |
-
|
6 |
from chat_interface_preference import ChatInterface
|
7 |
|
8 |
|
|
|
1 |
import random
|
2 |
|
3 |
|
|
|
|
|
4 |
from chat_interface_preference import ChatInterface
|
5 |
|
6 |
|