Commit
·
17aeee6
1
Parent(s):
5b0592a
Update LLM preference collector
Browse files- .gitignore +1 -0
- app copy.py +0 -148
- app.py +4 -0
- chat_interface_preference.py +224 -111
- test.py +10 -18
.gitignore
CHANGED
@@ -160,3 +160,4 @@ cython_debug/
|
|
160 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
#.idea/
|
|
|
|
160 |
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
161 |
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
162 |
#.idea/
|
163 |
+
feedback
|
app copy.py
DELETED
@@ -1,148 +0,0 @@
|
|
1 |
-
#!/usr/bin/env python
|
2 |
-
|
3 |
-
import os
|
4 |
-
from threading import Thread
|
5 |
-
from typing import Iterator
|
6 |
-
|
7 |
-
import gradio as gr
|
8 |
-
import spaces
|
9 |
-
import torch
|
10 |
-
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
|
11 |
-
|
12 |
-
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
|
13 |
-
|
14 |
-
if torch.cuda.is_available():
|
15 |
-
model_id = "davidberenstein1957/ultra-feedback-dutch-cleaned-hq-spin-geitje-7b-ultra-sft_iter2"
|
16 |
-
model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
|
17 |
-
tokenizer = AutoTokenizer.from_pretrained(model_id)
|
18 |
-
|
19 |
-
|
20 |
-
@spaces.GPU
|
21 |
-
def generate(
|
22 |
-
message: str,
|
23 |
-
chat_history: list[tuple[str, str]],
|
24 |
-
max_new_tokens: int = 1024,
|
25 |
-
temperature: float = 0.06,
|
26 |
-
top_p: float = 0.95,
|
27 |
-
top_k: int = 40,
|
28 |
-
repetition_penalty: float = 1.2,
|
29 |
-
) -> Iterator[str]:
|
30 |
-
conversation = []
|
31 |
-
for user, assistant in chat_history:
|
32 |
-
conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
|
33 |
-
conversation.append({"role": "user", "content": message})
|
34 |
-
|
35 |
-
input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
|
36 |
-
if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
|
37 |
-
input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
|
38 |
-
gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
|
39 |
-
input_ids = input_ids.to(model.device)
|
40 |
-
|
41 |
-
streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
|
42 |
-
generate_kwargs = dict(
|
43 |
-
{"input_ids": input_ids},
|
44 |
-
streamer=streamer,
|
45 |
-
max_new_tokens=max_new_tokens,
|
46 |
-
do_sample=True,
|
47 |
-
top_p=top_p,
|
48 |
-
top_k=top_k,
|
49 |
-
temperature=temperature,
|
50 |
-
num_beams=1,
|
51 |
-
repetition_penalty=repetition_penalty,
|
52 |
-
)
|
53 |
-
t = Thread(target=model.generate, kwargs=generate_kwargs)
|
54 |
-
t.start()
|
55 |
-
|
56 |
-
outputs = []
|
57 |
-
for text in streamer:
|
58 |
-
outputs.append(text)
|
59 |
-
yield "".join(outputs)
|
60 |
-
|
61 |
-
|
62 |
-
chat_interface = gr.ChatInterface(
|
63 |
-
fn=generate,
|
64 |
-
chatbot=gr.Chatbot(
|
65 |
-
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
66 |
-
),
|
67 |
-
# textbox=gr.Textbox(value="Typ een bericht…"),
|
68 |
-
cache_examples=False,
|
69 |
-
additional_inputs=[
|
70 |
-
gr.Slider(
|
71 |
-
label="Max new tokens",
|
72 |
-
minimum=1,
|
73 |
-
maximum=MAX_MAX_NEW_TOKENS,
|
74 |
-
step=1,
|
75 |
-
value=DEFAULT_MAX_NEW_TOKENS,
|
76 |
-
),
|
77 |
-
gr.Slider(
|
78 |
-
label="Temperature",
|
79 |
-
minimum=0.05,
|
80 |
-
maximum=1.2,
|
81 |
-
step=0.05,
|
82 |
-
value=0.2,
|
83 |
-
),
|
84 |
-
gr.Slider(
|
85 |
-
label="Top-p (nucleus sampling)",
|
86 |
-
minimum=0.05,
|
87 |
-
maximum=1.0,
|
88 |
-
step=0.05,
|
89 |
-
value=0.9,
|
90 |
-
),
|
91 |
-
gr.Slider(
|
92 |
-
label="Top-k",
|
93 |
-
minimum=1,
|
94 |
-
maximum=1000,
|
95 |
-
step=1,
|
96 |
-
value=50,
|
97 |
-
),
|
98 |
-
gr.Slider(
|
99 |
-
label="Repetition penalty",
|
100 |
-
minimum=1.0,
|
101 |
-
maximum=2.0,
|
102 |
-
step=0.05,
|
103 |
-
value=1.2,
|
104 |
-
),
|
105 |
-
],
|
106 |
-
examples=[
|
107 |
-
["""Vraagje: welk woord hoort er niet in dit rijtje thuis: "auto, vliegtuig, geit, bus"?"""],
|
108 |
-
[
|
109 |
-
"Schrijf een nieuwsbericht voor De Speld over de inzet van een kudde geiten door het Nederlands Forensisch Instituut"
|
110 |
-
],
|
111 |
-
["Wat zijn 3 leuke dingen om te doen als ik een weekendje naar Friesland ga?"],
|
112 |
-
["Met wie trad clown Bassie op?"],
|
113 |
-
["Kan je naar de maan fietsen?"],
|
114 |
-
["Wat is het belang van open source taalmodellen?"],
|
115 |
-
[
|
116 |
-
"""```
|
117 |
-
Wortelverkopers krijgen miljoenenboete voor ongeoorloofd samenspannen
|
118 |
-
Door onze economieredactie
|
119 |
-
14 dec 2023 om 12:58
|
120 |
-
Update: 20 uur geleden
|
121 |
-
162 reacties
|
122 |
-
Delen
|
123 |
-
Toezichthouder ACM heeft een Nederlands wortelkartel aangepakt. Vier telers en verkopers van wortelen krijgen samen ruim 2,5 miljoen euro boete vanwege ongeoorloofde afspraken over het verdelen van de markt.
|
124 |
-
Het gaat om telers en verkopers Laarakker, VanRijsingen, Veco en Verduyn. De vier bedrijven verkopen waspeen en Parijse wortelen aan conserven- en diepvriesfabrikanten in Nederland, België en Duitsland. Waspeen wordt vaak verkocht in potten of blikken in een mix met erwtjes.
|
125 |
-
De vier bedrijven hadden in 2018 afgesproken dat ze tien jaar lang niet overal de concurrentie met elkaar zouden aangaan. Zo zou Veco tien jaar lang geen waspeen telen of verkopen. Daarnaast zouden Laarakker, VanRijsingen en Verduyn juist de Parijse wortelen links laten liggen.
|
126 |
-
Ook betaalden de andere wortelverkopers Veco ter compensatie van de afspraken. Laarakker en Veco maakten ook nog afzonderlijke afspraken over de levering van Parijse wortelen aan Duitse klanten.
|
127 |
-
Zulke afspraken zijn verboden. Als concurrentie door die samenwerking achterwege blijft en er dus sprake is van een kartel, betalen kopers mogelijk een hogere prijs, stelt de ACM.
|
128 |
-
Twee van de wortelbedrijven werkten mee door meer informatie over de ongeoorloofde afspraken te delen met de toezichthouder. Daardoor kregen zij een lagere boete.
|
129 |
-
```
|
130 |
-
Vat bovenstaand artikel samen"""
|
131 |
-
],
|
132 |
-
],
|
133 |
-
title="🐐🕷️ GEITje 7B Spin Iter 2 🕷️🐐",
|
134 |
-
description="""\
|
135 |
-
<a href="https://huggingface.co/davidberenstein1957/ultra-feedback-dutch-cleaned-hq-spin-geitje-7b-ultra-sft_iter2">GEITje 7B SPIN iter 2</a> is een geavanceerde versie van GEITje, verder getraind op uitgebreide chat datasets en ook op een preferentiedataset om beter te aligneren met het gedrag van een gewenste chatbot op basis van het [SPIN algoritme en code](https://github.com/argilla-io/distilabel-spin-dibt/), in dit geval gpt-4-turbo. De data is net iets anders dan de originele dataset want we hebben de schoonmaak voor UltraFeedback gebruikt van Argilla, de uiteindelijke dataset is [hier](https://huggingface.co/datasets/BramVanroy/ultra_feedback_dutch_cleaned) te vinden.
|
136 |
-
""",
|
137 |
-
submit_btn="Genereer",
|
138 |
-
stop_btn="Stop",
|
139 |
-
retry_btn="🔄 Opnieuw",
|
140 |
-
undo_btn="↩️ Ongedaan maken",
|
141 |
-
clear_btn="🗑️ Wissen",
|
142 |
-
)
|
143 |
-
|
144 |
-
with gr.Blocks(css="style.css") as demo:
|
145 |
-
chat_interface.render()
|
146 |
-
|
147 |
-
if __name__ == "__main__":
|
148 |
-
demo.queue(max_size=20).launch()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
app.py
CHANGED
@@ -93,6 +93,10 @@ def generate(
|
|
93 |
|
94 |
chat_interface = ChatInterface(
|
95 |
fn=generate,
|
|
|
|
|
|
|
|
|
96 |
chatbot=gr.Chatbot(
|
97 |
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
98 |
),
|
|
|
93 |
|
94 |
chat_interface = ChatInterface(
|
95 |
fn=generate,
|
96 |
+
prefence_techniques="dpo",
|
97 |
+
min_turns=1,
|
98 |
+
max_turns=10,
|
99 |
+
repo_id="geitje-spin",
|
100 |
chatbot=gr.Chatbot(
|
101 |
height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
|
102 |
),
|
chat_interface_preference.py
CHANGED
@@ -4,11 +4,14 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
|
|
4 |
|
5 |
from __future__ import annotations
|
6 |
|
|
|
7 |
import functools
|
8 |
import inspect
|
|
|
9 |
import random
|
10 |
import re
|
11 |
-
|
|
|
12 |
|
13 |
import anyio
|
14 |
from gradio.blocks import Blocks
|
@@ -23,17 +26,19 @@ from gradio.components import (
|
|
23 |
get_component_instance,
|
24 |
)
|
25 |
from gradio.events import Dependency, on
|
26 |
-
from gradio.helpers import Error, Info,
|
27 |
from gradio.helpers import create_examples as Examples # noqa: N812
|
28 |
-
from gradio.helpers import special_args
|
29 |
from gradio.layouts import Accordion, Group, Row
|
30 |
from gradio.routes import Request
|
31 |
from gradio.themes import ThemeClass as Theme
|
32 |
from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda
|
33 |
from gradio_client.documentation import document
|
|
|
34 |
|
35 |
pattern = re.compile(r'<div class="message-identifier">(.*?)</div>', re.DOTALL)
|
36 |
|
|
|
|
|
37 |
|
38 |
@document()
|
39 |
class ChatInterface(Blocks):
|
@@ -59,6 +64,11 @@ class ChatInterface(Blocks):
|
|
59 |
self,
|
60 |
fn: Callable,
|
61 |
*,
|
|
|
|
|
|
|
|
|
|
|
62 |
multimodal: bool = False,
|
63 |
chatbot: Chatbot | None = None,
|
64 |
textbox: Textbox | MultimodalTextbox | None = None,
|
@@ -75,19 +85,10 @@ class ChatInterface(Blocks):
|
|
75 |
js: str | None = None,
|
76 |
head: str | None = None,
|
77 |
analytics_enabled: bool | None = None,
|
78 |
-
submit_btn_one: str | None | Button = "Generate 1",
|
79 |
-
submit_btn_two: str | None | Button = "Generate 2",
|
80 |
-
submit_btn_a: str | None | Button = "Log preference 🅰️",
|
81 |
-
submit_btn_b: str | None | Button = "Log preference 🅱️",
|
82 |
-
submit_btn_ab: str | None | Button = "Random pick 🅰️=🅱️",
|
83 |
-
stop_btn: str | None | Button = "Stop",
|
84 |
-
undo_btn: str | None | Button = "↩️ Undo",
|
85 |
-
clear_btn: str | None | Button = "🗑️ Clear",
|
86 |
autofocus: bool = True,
|
87 |
concurrency_limit: int | None | Literal["default"] = "default",
|
88 |
fill_height: bool = True,
|
89 |
delete_cache: tuple[int, int] | None = None,
|
90 |
-
rg_dataset: None,
|
91 |
):
|
92 |
"""
|
93 |
Parameters:
|
@@ -118,6 +119,39 @@ class ChatInterface(Blocks):
|
|
118 |
fill_height: If True, the chat interface will expand to the height of window.
|
119 |
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
|
120 |
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
super().__init__(
|
122 |
analytics_enabled=analytics_enabled,
|
123 |
mode="chat_interface",
|
@@ -129,6 +163,7 @@ class ChatInterface(Blocks):
|
|
129 |
fill_height=fill_height,
|
130 |
delete_cache=delete_cache,
|
131 |
)
|
|
|
132 |
self.css = css
|
133 |
self.multimodal = multimodal
|
134 |
self.concurrency_limit = concurrency_limit
|
@@ -139,7 +174,15 @@ class ChatInterface(Blocks):
|
|
139 |
|
140 |
self.examples = examples
|
141 |
self.cache_examples: bool | None | Literal["lazy"] = cache_examples
|
142 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
143 |
|
144 |
if additional_inputs:
|
145 |
if not isinstance(additional_inputs, list):
|
@@ -173,14 +216,27 @@ class ChatInterface(Blocks):
|
|
173 |
Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>")
|
174 |
if description:
|
175 |
Markdown(description)
|
176 |
-
|
|
|
|
|
|
|
|
|
|
|
177 |
if chatbot:
|
178 |
self.chatbot = chatbot.render()
|
179 |
else:
|
180 |
self.chatbot = Chatbot(label="Chatbot", scale=1, height=200 if fill_height else None)
|
181 |
|
182 |
with Row():
|
183 |
-
for btn in [
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
184 |
if btn is not None:
|
185 |
if isinstance(btn, Button):
|
186 |
btn.render()
|
@@ -268,6 +324,8 @@ class ChatInterface(Blocks):
|
|
268 |
self.submit_btn_a,
|
269 |
self.submit_btn_b,
|
270 |
self.submit_btn_ab,
|
|
|
|
|
271 |
self.undo_btn,
|
272 |
self.clear_btn,
|
273 |
self.submit_btn_one,
|
@@ -308,14 +366,31 @@ class ChatInterface(Blocks):
|
|
308 |
self._setup_events()
|
309 |
self._setup_api()
|
310 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
311 |
def _setup_events(self) -> None:
|
312 |
submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
|
313 |
submit_triggers_one = (
|
314 |
[self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
|
315 |
)
|
316 |
-
|
317 |
-
|
318 |
-
|
|
|
|
|
|
|
319 |
submit_event = (
|
320 |
on(
|
321 |
_triggers,
|
@@ -342,18 +417,38 @@ class ChatInterface(Blocks):
|
|
342 |
)
|
343 |
self._setup_stop_events(_triggers, submit_event)
|
344 |
|
345 |
-
partial_fn_a, partial_fn_b, partial_fn_ab = (
|
346 |
functools.partial(self._log_fn, log="a"),
|
347 |
functools.partial(self._log_fn, log="b"),
|
348 |
functools.partial(self._log_fn, log="ab"),
|
|
|
|
|
349 |
)
|
350 |
for _fn, _btn in [
|
351 |
(partial_fn_a, self.submit_btn_a),
|
352 |
(partial_fn_b, self.submit_btn_b),
|
353 |
(partial_fn_ab, self.submit_btn_ab),
|
|
|
|
|
354 |
]:
|
355 |
-
_btn
|
356 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
357 |
[self.saved_input, self.chatbot_state],
|
358 |
[self.chatbot, self.saved_input, self.chatbot_state],
|
359 |
show_api=False,
|
@@ -366,36 +461,9 @@ class ChatInterface(Blocks):
|
|
366 |
queue=False,
|
367 |
)
|
368 |
|
369 |
-
|
370 |
-
|
371 |
-
|
372 |
-
# self.retry_btn.click(
|
373 |
-
# self._delete_prev_fn,
|
374 |
-
# [self.saved_input, self.chatbot_state],
|
375 |
-
# [self.chatbot, self.saved_input, self.chatbot_state],
|
376 |
-
# show_api=False,
|
377 |
-
# queue=False,
|
378 |
-
# )
|
379 |
-
# .then(
|
380 |
-
# self._display_input,
|
381 |
-
# [self.saved_input, self.chatbot_state],
|
382 |
-
# [self.chatbot, self.chatbot_state],
|
383 |
-
# show_api=False,
|
384 |
-
# queue=False,
|
385 |
-
# )
|
386 |
-
# .then(
|
387 |
-
# submit_fn_partial,
|
388 |
-
# [self.saved_input, self.chatbot_state] + self.additional_inputs,
|
389 |
-
# [self.chatbot, self.chatbot_state],
|
390 |
-
# show_api=False,
|
391 |
-
# concurrency_limit=cast(Union[int, Literal["default"], None], self.concurrency_limit),
|
392 |
-
# )
|
393 |
-
# )
|
394 |
-
# self._setup_stop_events([self.retry_btn.click], retry_event)
|
395 |
-
|
396 |
-
if self.undo_btn:
|
397 |
-
self.undo_btn.click(
|
398 |
-
self._delete_prev_fn,
|
399 |
[self.saved_input, self.chatbot_state],
|
400 |
[self.chatbot, self.saved_input, self.chatbot_state],
|
401 |
show_api=False,
|
@@ -408,15 +476,6 @@ class ChatInterface(Blocks):
|
|
408 |
queue=False,
|
409 |
)
|
410 |
|
411 |
-
if self.clear_btn:
|
412 |
-
self.clear_btn.click(
|
413 |
-
async_lambda(lambda: ([], [], None)),
|
414 |
-
None,
|
415 |
-
[self.chatbot, self.chatbot_state, self.saved_input],
|
416 |
-
queue=False,
|
417 |
-
show_api=False,
|
418 |
-
)
|
419 |
-
|
420 |
def _setup_stop_events(self, event_triggers: list[Callable], event_to_cancel: Dependency) -> None:
|
421 |
if self.stop_btn and self.is_generator:
|
422 |
if self.submit_btn_one:
|
@@ -545,6 +604,16 @@ class ChatInterface(Blocks):
|
|
545 |
|
546 |
return self.css + "<body>" + conversation + "</body>"
|
547 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
548 |
@staticmethod
|
549 |
def _get_chat_message(message, role, turn):
|
550 |
if role == "user":
|
@@ -575,8 +644,27 @@ class ChatInterface(Blocks):
|
|
575 |
@staticmethod
|
576 |
def _check_if_two_responses(response):
|
577 |
if response:
|
578 |
-
|
579 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
580 |
|
581 |
async def _submit_fn(
|
582 |
self,
|
@@ -586,21 +674,18 @@ class ChatInterface(Blocks):
|
|
586 |
n_generations: int = 1,
|
587 |
*args,
|
588 |
) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
|
589 |
-
if not message:
|
590 |
-
Warning("Make sure to provide a message next time.")
|
591 |
-
history = history_with_input[:-1]
|
592 |
-
return history, history
|
593 |
-
|
594 |
-
_, response = history_with_input[-1]
|
595 |
-
if self._check_if_two_responses(response):
|
596 |
-
raise Error("Two options detected: undo, log or random pick continuation.")
|
597 |
-
|
598 |
if self.multimodal and isinstance(message, dict):
|
599 |
remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
|
600 |
history = history_with_input[:-remove_input]
|
601 |
else:
|
602 |
history = history_with_input[:-1]
|
603 |
|
|
|
|
|
|
|
|
|
|
|
|
|
604 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
605 |
|
606 |
async def _get_response():
|
@@ -636,6 +721,14 @@ class ChatInterface(Blocks):
|
|
636 |
else:
|
637 |
history = history_with_input[:-1]
|
638 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
639 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
640 |
|
641 |
try:
|
@@ -681,7 +774,7 @@ class ChatInterface(Blocks):
|
|
681 |
generator_two = self.fn(*inputs)
|
682 |
else:
|
683 |
generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
|
684 |
-
generator_two = SyncToAsyncIterator(
|
685 |
try:
|
686 |
first_response_two = await async_iteration(generator_two)
|
687 |
first_response_two_formatted = self._get_chat_message_comparison(response, first_response_two)
|
@@ -701,7 +794,7 @@ class ChatInterface(Blocks):
|
|
701 |
else:
|
702 |
update = history + [[message, None]]
|
703 |
yield update, update
|
704 |
-
async for response_two in
|
705 |
response_two = self._get_chat_message_comparison(response, response_two)
|
706 |
if self.multimodal and isinstance(message, dict):
|
707 |
update = history + [[message["text"], response_two]]
|
@@ -717,44 +810,48 @@ class ChatInterface(Blocks):
|
|
717 |
str | dict[str, list],
|
718 |
list[list[str | tuple | None]],
|
719 |
]:
|
720 |
-
|
721 |
-
|
722 |
-
|
723 |
-
|
724 |
-
|
725 |
-
|
726 |
-
|
727 |
-
|
728 |
-
|
729 |
-
|
730 |
-
|
731 |
-
|
732 |
-
|
733 |
-
|
734 |
-
|
735 |
-
|
736 |
-
|
737 |
-
|
738 |
-
|
739 |
-
|
740 |
-
|
741 |
-
|
742 |
-
|
743 |
-
|
744 |
-
|
745 |
-
|
746 |
-
|
747 |
-
|
748 |
-
|
749 |
-
|
750 |
-
|
751 |
-
|
752 |
-
|
753 |
-
|
754 |
-
|
755 |
-
|
|
|
|
|
|
|
|
|
756 |
else:
|
757 |
-
raise Error("
|
758 |
|
759 |
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
|
760 |
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
@@ -770,7 +867,6 @@ class ChatInterface(Blocks):
|
|
770 |
*args,
|
771 |
) -> AsyncGenerator:
|
772 |
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
773 |
-
|
774 |
if self.is_async:
|
775 |
generator = self.fn(*inputs)
|
776 |
else:
|
@@ -794,3 +890,20 @@ class ChatInterface(Blocks):
|
|
794 |
else:
|
795 |
history = history[:-1]
|
796 |
return history, message or "", history
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
4 |
|
5 |
from __future__ import annotations
|
6 |
|
7 |
+
import datetime
|
8 |
import functools
|
9 |
import inspect
|
10 |
+
import json
|
11 |
import random
|
12 |
import re
|
13 |
+
import uuid
|
14 |
+
from typing import AsyncGenerator, Callable, List, Literal, Union, cast
|
15 |
|
16 |
import anyio
|
17 |
from gradio.blocks import Blocks
|
|
|
26 |
get_component_instance,
|
27 |
)
|
28 |
from gradio.events import Dependency, on
|
29 |
+
from gradio.helpers import Error, Info, special_args
|
30 |
from gradio.helpers import create_examples as Examples # noqa: N812
|
|
|
31 |
from gradio.layouts import Accordion, Group, Row
|
32 |
from gradio.routes import Request
|
33 |
from gradio.themes import ThemeClass as Theme
|
34 |
from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda
|
35 |
from gradio_client.documentation import document
|
36 |
+
from huggingface_hub import CommitScheduler
|
37 |
|
38 |
pattern = re.compile(r'<div class="message-identifier">(.*?)</div>', re.DOTALL)
|
39 |
|
40 |
+
PREFERENCE_TECHNIQUE_MAPPING = {"sft": "prompt", "dpo": "preference", "kto": "vibes"}
|
41 |
+
|
42 |
|
43 |
@document()
|
44 |
class ChatInterface(Blocks):
|
|
|
64 |
self,
|
65 |
fn: Callable,
|
66 |
*,
|
67 |
+
prefence_techniques: str | List[str] | None = None,
|
68 |
+
min_turns: int = 1,
|
69 |
+
max_turns: int = 1,
|
70 |
+
repo_id: None | str,
|
71 |
+
repo_private: bool = False,
|
72 |
multimodal: bool = False,
|
73 |
chatbot: Chatbot | None = None,
|
74 |
textbox: Textbox | MultimodalTextbox | None = None,
|
|
|
85 |
js: str | None = None,
|
86 |
head: str | None = None,
|
87 |
analytics_enabled: bool | None = None,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
autofocus: bool = True,
|
89 |
concurrency_limit: int | None | Literal["default"] = "default",
|
90 |
fill_height: bool = True,
|
91 |
delete_cache: tuple[int, int] | None = None,
|
|
|
92 |
):
|
93 |
"""
|
94 |
Parameters:
|
|
|
119 |
fill_height: If True, the chat interface will expand to the height of window.
|
120 |
delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
|
121 |
"""
|
122 |
+
if max_turns < min_turns:
|
123 |
+
raise ValueError("`max_turns` should be larger than `min_turns`")
|
124 |
+
if any([turn for turn in [max_turns, min_turns] if turn < 1]):
|
125 |
+
raise ValueError("`max_turns` should be larger than `min_turns`")
|
126 |
+
self.max_turns = max_turns
|
127 |
+
self.min_turns = min_turns
|
128 |
+
if isinstance(prefence_techniques, str):
|
129 |
+
prefence_techniques = [prefence_techniques]
|
130 |
+
elif prefence_techniques is None:
|
131 |
+
prefence_techniques = ["sft"]
|
132 |
+
self.prefence_techniques = [technique.lower() for technique in prefence_techniques]
|
133 |
+
|
134 |
+
optional_techniques = ["kto", "sft", "spin", "dpo", "simpo", "rlhf", "orpo"]
|
135 |
+
if any([technique for technique in self.prefence_techniques if technique not in optional_techniques]):
|
136 |
+
raise ValueError(f"Supported techniques are {optional_techniques}")
|
137 |
+
submit_btn_one = "Generate"
|
138 |
+
submit_btn_two = None
|
139 |
+
submit_btn_a = None
|
140 |
+
submit_btn_b = None
|
141 |
+
submit_btn_ab = None
|
142 |
+
submit_btn_good = None
|
143 |
+
submit_btn_bad = None
|
144 |
+
stop_btn = "Stop"
|
145 |
+
undo_btn = "↩️ Undo"
|
146 |
+
clear_btn = "🗑️ Log and clear"
|
147 |
+
if "kto" in prefence_techniques:
|
148 |
+
submit_btn_good = "Log response 👍"
|
149 |
+
submit_btn_bad = "Log response 👎"
|
150 |
+
if any([technique for technique in ["dpo", "simpo", "rlhf", "orpo"] if technique in self.prefence_techniques]):
|
151 |
+
submit_btn_two = "Generate 2"
|
152 |
+
submit_btn_a = "Log preference 🅰️"
|
153 |
+
submit_btn_b = "Log preference 🅱️"
|
154 |
+
submit_btn_ab = "Continue random 🅰️=🅱️"
|
155 |
super().__init__(
|
156 |
analytics_enabled=analytics_enabled,
|
157 |
mode="chat_interface",
|
|
|
163 |
fill_height=fill_height,
|
164 |
delete_cache=delete_cache,
|
165 |
)
|
166 |
+
|
167 |
self.css = css
|
168 |
self.multimodal = multimodal
|
169 |
self.concurrency_limit = concurrency_limit
|
|
|
174 |
|
175 |
self.examples = examples
|
176 |
self.cache_examples: bool | None | Literal["lazy"] = cache_examples
|
177 |
+
self._set_conversation_id()
|
178 |
+
if repo_id:
|
179 |
+
self.commit_scheduler = CommitScheduler(
|
180 |
+
repo_id=repo_id, folder_path="feedback", repo_type="dataset", private=repo_private, every=1
|
181 |
+
)
|
182 |
+
else:
|
183 |
+
self.commit_scheduler = None
|
184 |
+
if self.commit_scheduler:
|
185 |
+
self.data_file = self.commit_scheduler.folder_path / f"data_{uuid.uuid4()}.json"
|
186 |
|
187 |
if additional_inputs:
|
188 |
if not isinstance(additional_inputs, list):
|
|
|
216 |
Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>")
|
217 |
if description:
|
218 |
Markdown(description)
|
219 |
+
if self.commit_scheduler:
|
220 |
+
Markdown(
|
221 |
+
f"## Data is being logged to a datset on the hub: [{self.commit_scheduler.repo_id}](https://huggingface.co/datasets/{self.commit_scheduler.repo_id})"
|
222 |
+
)
|
223 |
+
Markdown(f"### Techniques: {self.prefence_techniques}")
|
224 |
+
Markdown(f"### MIN TURNS: {self.min_turns} - MAX TURN: {self.max_turns}")
|
225 |
if chatbot:
|
226 |
self.chatbot = chatbot.render()
|
227 |
else:
|
228 |
self.chatbot = Chatbot(label="Chatbot", scale=1, height=200 if fill_height else None)
|
229 |
|
230 |
with Row():
|
231 |
+
for btn in [
|
232 |
+
submit_btn_a,
|
233 |
+
submit_btn_b,
|
234 |
+
submit_btn_ab,
|
235 |
+
submit_btn_good,
|
236 |
+
submit_btn_bad,
|
237 |
+
undo_btn,
|
238 |
+
clear_btn,
|
239 |
+
]:
|
240 |
if btn is not None:
|
241 |
if isinstance(btn, Button):
|
242 |
btn.render()
|
|
|
324 |
self.submit_btn_a,
|
325 |
self.submit_btn_b,
|
326 |
self.submit_btn_ab,
|
327 |
+
self.submit_btn_good,
|
328 |
+
self.submit_btn_bad,
|
329 |
self.undo_btn,
|
330 |
self.clear_btn,
|
331 |
self.submit_btn_one,
|
|
|
366 |
self._setup_events()
|
367 |
self._setup_api()
|
368 |
|
369 |
+
def _set_conversation_id(self):
|
370 |
+
self._conversation_id = str(uuid.uuid4())
|
371 |
+
|
372 |
+
def _save_feedback(self, item):
|
373 |
+
feedback = {
|
374 |
+
"timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
|
375 |
+
"conversation_id": self._conversation_id,
|
376 |
+
}
|
377 |
+
feedback.update(item)
|
378 |
+
if self.commit_scheduler:
|
379 |
+
with self.commit_scheduler.lock:
|
380 |
+
with self.data_file.open("a") as f:
|
381 |
+
f.write(json.dumps(feedback))
|
382 |
+
|
383 |
def _setup_events(self) -> None:
|
384 |
submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
|
385 |
submit_triggers_one = (
|
386 |
[self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
|
387 |
)
|
388 |
+
submit_tuples = [(submit_fn_one, submit_triggers_one)]
|
389 |
+
if self.submit_btn_two:
|
390 |
+
submit_fn_two = functools.partial(submit_fn_one, n_generations=2)
|
391 |
+
submit_triggers_two = [self.submit_btn_two.click]
|
392 |
+
submit_tuples.append((submit_fn_two, submit_triggers_two))
|
393 |
+
for _fn, _triggers in submit_tuples:
|
394 |
submit_event = (
|
395 |
on(
|
396 |
_triggers,
|
|
|
417 |
)
|
418 |
self._setup_stop_events(_triggers, submit_event)
|
419 |
|
420 |
+
partial_fn_a, partial_fn_b, partial_fn_ab, partial_fn_good, partial_fn_bad = (
|
421 |
functools.partial(self._log_fn, log="a"),
|
422 |
functools.partial(self._log_fn, log="b"),
|
423 |
functools.partial(self._log_fn, log="ab"),
|
424 |
+
functools.partial(self._log_fn, log="good"),
|
425 |
+
functools.partial(self._log_fn, log="bad"),
|
426 |
)
|
427 |
for _fn, _btn in [
|
428 |
(partial_fn_a, self.submit_btn_a),
|
429 |
(partial_fn_b, self.submit_btn_b),
|
430 |
(partial_fn_ab, self.submit_btn_ab),
|
431 |
+
(partial_fn_good, self.submit_btn_good),
|
432 |
+
(partial_fn_bad, self.submit_btn_bad),
|
433 |
]:
|
434 |
+
if _btn:
|
435 |
+
_btn.click(
|
436 |
+
_fn,
|
437 |
+
[self.saved_input, self.chatbot_state],
|
438 |
+
[self.chatbot, self.saved_input, self.chatbot_state],
|
439 |
+
show_api=False,
|
440 |
+
queue=False,
|
441 |
+
).then(
|
442 |
+
async_lambda(lambda x: x),
|
443 |
+
[self.saved_input],
|
444 |
+
[self.textbox],
|
445 |
+
show_api=False,
|
446 |
+
queue=False,
|
447 |
+
)
|
448 |
+
|
449 |
+
if self.undo_btn:
|
450 |
+
self.undo_btn.click(
|
451 |
+
self._delete_prev_fn,
|
452 |
[self.saved_input, self.chatbot_state],
|
453 |
[self.chatbot, self.saved_input, self.chatbot_state],
|
454 |
show_api=False,
|
|
|
461 |
queue=False,
|
462 |
)
|
463 |
|
464 |
+
if self.clear_btn:
|
465 |
+
self.clear_btn.click(
|
466 |
+
self._clear_fn,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
467 |
[self.saved_input, self.chatbot_state],
|
468 |
[self.chatbot, self.saved_input, self.chatbot_state],
|
469 |
show_api=False,
|
|
|
476 |
queue=False,
|
477 |
)
|
478 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
479 |
def _setup_stop_events(self, event_triggers: list[Callable], event_to_cancel: Dependency) -> None:
|
480 |
if self.stop_btn and self.is_generator:
|
481 |
if self.submit_btn_one:
|
|
|
604 |
|
605 |
return self.css + "<body>" + conversation + "</body>"
|
606 |
|
607 |
+
def _get_conversation_in_openai_format(self, history):
|
608 |
+
conversation = []
|
609 |
+
for idx, turn in enumerate(history):
|
610 |
+
roles = ["user", "assistant"]
|
611 |
+
if idx == len(turn) - 1:
|
612 |
+
roles = ["user"]
|
613 |
+
for role, content in zip(roles, turn):
|
614 |
+
conversation.append({"role": role, "content": content})
|
615 |
+
return conversation
|
616 |
+
|
617 |
@staticmethod
|
618 |
def _get_chat_message(message, role, turn):
|
619 |
if role == "user":
|
|
|
644 |
@staticmethod
|
645 |
def _check_if_two_responses(response):
|
646 |
if response:
|
647 |
+
matches = pattern.findall(response)
|
648 |
+
return matches
|
649 |
+
|
650 |
+
def _check_num_turns(self, history, generate=True):
|
651 |
+
if generate:
|
652 |
+
if len(history) >= self.max_turns:
|
653 |
+
raise Error(
|
654 |
+
f"We intend to collect conversations with a maximum of {self.max_turns}, please clear or log info first."
|
655 |
+
)
|
656 |
+
return history, history
|
657 |
+
else:
|
658 |
+
if len(history) < self.min_turns:
|
659 |
+
raise Error(
|
660 |
+
f"We intend to collect conversations with at least of {self.min_turns}, please continue the conversation first."
|
661 |
+
)
|
662 |
+
return history, history
|
663 |
+
|
664 |
+
@staticmethod
|
665 |
+
def _check_message(message):
|
666 |
+
if not message:
|
667 |
+
raise Error("Make sure to provide a message next time.")
|
668 |
|
669 |
async def _submit_fn(
|
670 |
self,
|
|
|
674 |
n_generations: int = 1,
|
675 |
*args,
|
676 |
) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
677 |
if self.multimodal and isinstance(message, dict):
|
678 |
remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
|
679 |
history = history_with_input[:-remove_input]
|
680 |
else:
|
681 |
history = history_with_input[:-1]
|
682 |
|
683 |
+
self._check_message(message)
|
684 |
+
self._check_num_turns(history)
|
685 |
+
_, response = history_with_input[-1]
|
686 |
+
if self._check_if_two_responses(response):
|
687 |
+
raise Error("Two options detected: undo, log or random pick continuation.")
|
688 |
+
|
689 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
690 |
|
691 |
async def _get_response():
|
|
|
721 |
else:
|
722 |
history = history_with_input[:-1]
|
723 |
|
724 |
+
self._check_message(message)
|
725 |
+
self._check_num_turns(history)
|
726 |
+
_, response = history_with_input[-1]
|
727 |
+
if self._check_if_two_responses(response):
|
728 |
+
raise Error("Two options detected: undo, log or random pick continuation.")
|
729 |
+
|
730 |
+
_, response = history_with_input[-1]
|
731 |
+
|
732 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
733 |
|
734 |
try:
|
|
|
774 |
generator_two = self.fn(*inputs)
|
775 |
else:
|
776 |
generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
|
777 |
+
generator_two = SyncToAsyncIterator(generator_two, self.limiter)
|
778 |
try:
|
779 |
first_response_two = await async_iteration(generator_two)
|
780 |
first_response_two_formatted = self._get_chat_message_comparison(response, first_response_two)
|
|
|
794 |
else:
|
795 |
update = history + [[message, None]]
|
796 |
yield update, update
|
797 |
+
async for response_two in generator_two:
|
798 |
response_two = self._get_chat_message_comparison(response, response_two)
|
799 |
if self.multimodal and isinstance(message, dict):
|
800 |
update = history + [[message["text"], response_two]]
|
|
|
810 |
str | dict[str, list],
|
811 |
list[list[str | tuple | None]],
|
812 |
]:
|
813 |
+
self._check_num_turns(history, generate=False)
|
814 |
+
history_as_openai_format = self._get_conversation_in_openai_format(history)
|
815 |
+
feedback = {"prompt": history_as_openai_format}
|
816 |
+
|
817 |
+
prompt, response = history[-1]
|
818 |
+
matches = self._check_if_two_responses(response)
|
819 |
+
if matches and log != "prompt":
|
820 |
+
option_a, option_b = matches[0], matches[1]
|
821 |
+
if log == "a":
|
822 |
+
chosen, rejected = option_a, option_b
|
823 |
+
Info("Logged preference: a")
|
824 |
+
elif log == "b":
|
825 |
+
chosen, rejected = option_b, option_a
|
826 |
+
Info("Logged preference: b")
|
827 |
+
elif log == "ab":
|
828 |
+
options = [option_a, option_b]
|
829 |
+
chosen, rejected = random.choice([options])
|
830 |
+
Info("Picked random response to continue")
|
831 |
+
if log in ["a", "b"] and self.commit_scheduler:
|
832 |
+
feedback.update(
|
833 |
+
{
|
834 |
+
"chosen": [{"content": chosen, "role": "assistant"}],
|
835 |
+
"rejected": [{"content": rejected, "role": "assistant"}],
|
836 |
+
}
|
837 |
+
)
|
838 |
+
self._save_feedback(feedback)
|
839 |
+
elif log == "ab":
|
840 |
+
self._save_feedback(feedback)
|
841 |
+
history[-1] = [prompt, chosen]
|
842 |
+
return history, message or "", history
|
843 |
+
elif log in ["conversation", "good", "bad"]:
|
844 |
+
feedback.update({"response": response})
|
845 |
+
if log == "good":
|
846 |
+
feedback.update({"label": True})
|
847 |
+
elif log == "bad":
|
848 |
+
feedback.update({"label": False})
|
849 |
+
Info("Logged conversation")
|
850 |
+
self._save_feedback(feedback)
|
851 |
+
|
852 |
+
return history, "", history
|
853 |
else:
|
854 |
+
raise Error("Error in code w.r.t logging.")
|
855 |
|
856 |
async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
|
857 |
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
|
|
867 |
*args,
|
868 |
) -> AsyncGenerator:
|
869 |
inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
|
|
|
870 |
if self.is_async:
|
871 |
generator = self.fn(*inputs)
|
872 |
else:
|
|
|
890 |
else:
|
891 |
history = history[:-1]
|
892 |
return history, message or "", history
|
893 |
+
|
894 |
+
async def _clear_fn(
|
895 |
+
self,
|
896 |
+
message: str | dict[str, list],
|
897 |
+
history: list[list[str | tuple | None]],
|
898 |
+
) -> tuple[
|
899 |
+
list[list[str | tuple | None]],
|
900 |
+
str | dict[str, list],
|
901 |
+
list[list[str | tuple | None]],
|
902 |
+
]:
|
903 |
+
_, response = history[-1]
|
904 |
+
if self._check_if_two_responses(response):
|
905 |
+
raise Error("First log preference or continue random.")
|
906 |
+
else:
|
907 |
+
await self._log_fn(message=message, history=history, log="prompt")
|
908 |
+
self._set_conversation_id()
|
909 |
+
return [], "", []
|
test.py
CHANGED
@@ -1,34 +1,26 @@
|
|
1 |
import random
|
2 |
|
3 |
-
|
4 |
from chat_interface_preference import ChatInterface
|
5 |
|
6 |
|
7 |
def random_response(message, history, request):
|
8 |
response = random.choice(["Yes", "No"])
|
|
|
9 |
for char in response:
|
10 |
-
|
|
|
11 |
|
12 |
|
13 |
style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
|
14 |
-
client = rg.Argilla(api_url="https://davidberenstein1957-argilla-gradio.hf.space", api_key="owner.apikey")
|
15 |
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
],
|
|
|
22 |
)
|
23 |
-
name = "test"
|
24 |
-
if client.datasets(name=name).exists():
|
25 |
-
dataset: rg.Dataset = client.datasets(name=name)
|
26 |
-
|
27 |
-
else:
|
28 |
-
dataset = rg.Dataset(name=name, settings=required_settings)
|
29 |
-
dataset.create()
|
30 |
-
|
31 |
-
demo = ChatInterface(random_response, cache_examples=False, css=style, rg_dataset=dataset)
|
32 |
|
33 |
if __name__ == "__main__":
|
34 |
demo.launch()
|
|
|
1 |
import random
|
2 |
|
|
|
3 |
from chat_interface_preference import ChatInterface
|
4 |
|
5 |
|
6 |
def random_response(message, history, request):
|
7 |
response = random.choice(["Yes", "No"])
|
8 |
+
response_total = ""
|
9 |
for char in response:
|
10 |
+
response_total += char
|
11 |
+
yield response_total
|
12 |
|
13 |
|
14 |
style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
|
|
|
15 |
|
16 |
+
demo = ChatInterface(
|
17 |
+
random_response,
|
18 |
+
cache_examples=False,
|
19 |
+
css=style,
|
20 |
+
repo_id="geitje-spin-preference",
|
21 |
+
prefence_techniques=["kto"],
|
22 |
+
max_turns=10,
|
23 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
24 |
|
25 |
if __name__ == "__main__":
|
26 |
demo.launch()
|