Commit
•
7e8ce88
1
Parent(s):
447570b
Update preference technique
Browse files- app.py +4 -6
- chat_interface_preference.py +12 -17
app.py
CHANGED
@@ -7,11 +7,9 @@ from typing import Iterator
|
|
7 |
import gradio as gr
|
8 |
import spaces
|
9 |
import torch # noqa
|
10 |
-
from transformers import
|
11 |
-
|
12 |
-
|
13 |
-
TextIteratorStreamer, # noqa
|
14 |
-
)
|
15 |
|
16 |
from chat_interface_preference import ChatInterface
|
17 |
|
@@ -70,7 +68,7 @@ def generate(
|
|
70 |
|
71 |
chat_interface = ChatInterface(
|
72 |
fn=generate,
|
73 |
-
|
74 |
min_turns=1,
|
75 |
max_turns=10,
|
76 |
repo_id="llm-human-feedback-collector-chat-interface-dpo",
|
|
|
7 |
import gradio as gr
|
8 |
import spaces
|
9 |
import torch # noqa
|
10 |
+
from transformers import AutoModelForCausalLM # noqa
|
11 |
+
from transformers import AutoTokenizer # noqa
|
12 |
+
from transformers import TextIteratorStreamer # noqa
|
|
|
|
|
13 |
|
14 |
from chat_interface_preference import ChatInterface
|
15 |
|
|
|
68 |
|
69 |
chat_interface = ChatInterface(
|
70 |
fn=generate,
|
71 |
+
prefence_technique="dpo",
|
72 |
min_turns=1,
|
73 |
max_turns=10,
|
74 |
repo_id="llm-human-feedback-collector-chat-interface-dpo",
|
chat_interface_preference.py
CHANGED
@@ -11,7 +11,7 @@ import json
|
|
11 |
import random
|
12 |
import re
|
13 |
import uuid
|
14 |
-
from typing import AsyncGenerator, Callable,
|
15 |
|
16 |
import anyio
|
17 |
from gradio.blocks import Blocks
|
@@ -27,9 +27,8 @@ from gradio.components import (
|
|
27 |
get_component_instance,
|
28 |
)
|
29 |
from gradio.events import Dependency, on
|
30 |
-
from gradio.helpers import Error, Info
|
31 |
from gradio.helpers import create_examples as Examples # noqa: N812
|
32 |
-
from gradio.helpers import special_args
|
33 |
from gradio.layouts import Accordion, Group, Row
|
34 |
from gradio.routes import Request
|
35 |
from gradio.themes import ThemeClass as Theme
|
@@ -66,7 +65,7 @@ class ChatInterface(Blocks):
|
|
66 |
self,
|
67 |
fn: Callable,
|
68 |
*,
|
69 |
-
|
70 |
min_turns: int = 1,
|
71 |
max_turns: int = 1,
|
72 |
repo_id: None | str,
|
@@ -127,14 +126,9 @@ class ChatInterface(Blocks):
|
|
127 |
raise ValueError("`max_turns` should be larger than `min_turns`")
|
128 |
self.max_turns = max_turns
|
129 |
self.min_turns = min_turns
|
130 |
-
if isinstance(prefence_techniques, str):
|
131 |
-
prefence_techniques = [prefence_techniques]
|
132 |
-
elif prefence_techniques is None:
|
133 |
-
prefence_techniques = ["sft"]
|
134 |
-
self.prefence_techniques = [technique.lower() for technique in prefence_techniques]
|
135 |
|
136 |
optional_techniques = ["kto", "sft", "spin", "dpo", "simpo", "rlhf", "orpo"]
|
137 |
-
if
|
138 |
raise ValueError(f"Supported techniques are {optional_techniques}")
|
139 |
submit_btn_one = "Generate"
|
140 |
submit_btn_two = None
|
@@ -146,11 +140,12 @@ class ChatInterface(Blocks):
|
|
146 |
stop_btn = "Stop"
|
147 |
undo_btn = "↩️ Undo"
|
148 |
clear_btn = "🗑️ Clear"
|
149 |
-
|
|
|
150 |
submit_btn_good = "The response 👍"
|
151 |
submit_btn_bad = "The response 👎"
|
152 |
-
if
|
153 |
-
|
154 |
submit_btn_a = "A is better than B"
|
155 |
submit_btn_b = "B is better than A"
|
156 |
submit_btn_ab = "A and B are similar"
|
@@ -368,7 +363,7 @@ class ChatInterface(Blocks):
|
|
368 |
self.saved_input = State()
|
369 |
self.chatbot_state = State(self.chatbot.value) if self.chatbot.value else State([])
|
370 |
|
371 |
-
self._setup_events()
|
372 |
self._setup_api()
|
373 |
|
374 |
def _set_conversation_id(self):
|
@@ -385,15 +380,15 @@ class ChatInterface(Blocks):
|
|
385 |
with self.data_file.open("a") as f:
|
386 |
f.write(json.dumps(feedback))
|
387 |
|
388 |
-
def _setup_events(self) -> None:
|
389 |
submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
|
390 |
-
submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=
|
391 |
submit_triggers_one = (
|
392 |
[self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
|
393 |
)
|
394 |
submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
|
395 |
if self.submit_btn_two:
|
396 |
-
submit_fn_two = functools.partial(submit_fn_one, n_generations=
|
397 |
submit_triggers_two = [self.submit_btn_two.click]
|
398 |
submit_tuples.append((submit_fn_two, submit_triggers_two))
|
399 |
for _fn, _triggers in submit_tuples:
|
|
|
11 |
import random
|
12 |
import re
|
13 |
import uuid
|
14 |
+
from typing import AsyncGenerator, Callable, Literal, Union, cast
|
15 |
|
16 |
import anyio
|
17 |
from gradio.blocks import Blocks
|
|
|
27 |
get_component_instance,
|
28 |
)
|
29 |
from gradio.events import Dependency, on
|
30 |
+
from gradio.helpers import Error, Info, special_args
|
31 |
from gradio.helpers import create_examples as Examples # noqa: N812
|
|
|
32 |
from gradio.layouts import Accordion, Group, Row
|
33 |
from gradio.routes import Request
|
34 |
from gradio.themes import ThemeClass as Theme
|
|
|
65 |
self,
|
66 |
fn: Callable,
|
67 |
*,
|
68 |
+
prefence_technique: str = None,
|
69 |
min_turns: int = 1,
|
70 |
max_turns: int = 1,
|
71 |
repo_id: None | str,
|
|
|
126 |
raise ValueError("`max_turns` should be larger than `min_turns`")
|
127 |
self.max_turns = max_turns
|
128 |
self.min_turns = min_turns
|
|
|
|
|
|
|
|
|
|
|
129 |
|
130 |
optional_techniques = ["kto", "sft", "spin", "dpo", "simpo", "rlhf", "orpo"]
|
131 |
+
if prefence_technique not in optional_techniques:
|
132 |
raise ValueError(f"Supported techniques are {optional_techniques}")
|
133 |
submit_btn_one = "Generate"
|
134 |
submit_btn_two = None
|
|
|
140 |
stop_btn = "Stop"
|
141 |
undo_btn = "↩️ Undo"
|
142 |
clear_btn = "🗑️ Clear"
|
143 |
+
n_generations = 1
|
144 |
+
if "kto" == prefence_technique:
|
145 |
submit_btn_good = "The response 👍"
|
146 |
submit_btn_bad = "The response 👎"
|
147 |
+
if prefence_technique in ["dpo", "simpo", "rlhf", "orpo"]:
|
148 |
+
n_generations = 2
|
149 |
submit_btn_a = "A is better than B"
|
150 |
submit_btn_b = "B is better than A"
|
151 |
submit_btn_ab = "A and B are similar"
|
|
|
363 |
self.saved_input = State()
|
364 |
self.chatbot_state = State(self.chatbot.value) if self.chatbot.value else State([])
|
365 |
|
366 |
+
self._setup_events(n_generations)
|
367 |
self._setup_api()
|
368 |
|
369 |
def _set_conversation_id(self):
|
|
|
380 |
with self.data_file.open("a") as f:
|
381 |
f.write(json.dumps(feedback))
|
382 |
|
383 |
+
def _setup_events(self, n_generations) -> None:
|
384 |
submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
|
385 |
+
submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=n_generations)
|
386 |
submit_triggers_one = (
|
387 |
[self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
|
388 |
)
|
389 |
submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
|
390 |
if self.submit_btn_two:
|
391 |
+
submit_fn_two = functools.partial(submit_fn_one, n_generations=n_generations)
|
392 |
submit_triggers_two = [self.submit_btn_two.click]
|
393 |
submit_tuples.append((submit_fn_two, submit_triggers_two))
|
394 |
for _fn, _triggers in submit_tuples:
|