Commit
•
cdbc268
1
Parent(s):
c32ec54
Update chat_interface_preference.py
Browse files- chat_interface_preference.py +19 -13
chat_interface_preference.py
CHANGED
@@ -21,6 +21,7 @@ from gradio.components import (
|
|
21 |
Component,
|
22 |
Markdown,
|
23 |
MultimodalTextbox,
|
|
|
24 |
State,
|
25 |
Textbox,
|
26 |
get_component_instance,
|
@@ -184,12 +185,16 @@ class ChatInterface(Blocks):
|
|
184 |
if self.commit_scheduler:
|
185 |
self.data_file = self.commit_scheduler.folder_path / f"data_{uuid.uuid4()}.json"
|
186 |
|
|
|
|
|
187 |
if additional_inputs:
|
188 |
if not isinstance(additional_inputs, list):
|
189 |
-
additional_inputs = [additional_inputs]
|
|
|
|
|
190 |
self.additional_inputs = [get_component_instance(i) for i in additional_inputs] # type: ignore
|
191 |
else:
|
192 |
-
self.additional_inputs = []
|
193 |
if additional_inputs_accordion_name is not None:
|
194 |
print(
|
195 |
"The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
|
@@ -382,10 +387,11 @@ class ChatInterface(Blocks):
|
|
382 |
|
383 |
def _setup_events(self) -> None:
|
384 |
submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
|
|
|
385 |
submit_triggers_one = (
|
386 |
[self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
|
387 |
)
|
388 |
-
submit_tuples = [(
|
389 |
if self.submit_btn_two:
|
390 |
submit_fn_two = functools.partial(submit_fn_one, n_generations=2)
|
391 |
submit_triggers_two = [self.submit_btn_two.click]
|
@@ -671,8 +677,8 @@ class ChatInterface(Blocks):
|
|
671 |
message: str | dict[str, list],
|
672 |
history_with_input: list[list[str | tuple | None]],
|
673 |
request: Request,
|
674 |
-
n_generations: int = 1,
|
675 |
*args,
|
|
|
676 |
) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
|
677 |
if self.multimodal and isinstance(message, dict):
|
678 |
remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
|
@@ -712,23 +718,20 @@ class ChatInterface(Blocks):
|
|
712 |
message: str | dict[str, list],
|
713 |
history_with_input: list[list[str | tuple | None]],
|
714 |
request: Request,
|
715 |
-
n_generations: int = 1,
|
716 |
*args,
|
|
|
717 |
) -> AsyncGenerator:
|
718 |
if self.multimodal and isinstance(message, dict):
|
719 |
remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
|
720 |
history = history_with_input[:-remove_input]
|
721 |
else:
|
722 |
history = history_with_input[:-1]
|
723 |
-
|
724 |
self._check_message(message)
|
725 |
self._check_num_turns(history)
|
726 |
_, response = history_with_input[-1]
|
727 |
if self._check_if_two_responses(response):
|
728 |
raise Error("Two options detected: undo, log or random pick continuation.")
|
729 |
|
730 |
-
_, response = history_with_input[-1]
|
731 |
-
|
732 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
733 |
|
734 |
try:
|
@@ -900,10 +903,13 @@ class ChatInterface(Blocks):
|
|
900 |
str | dict[str, list],
|
901 |
list[list[str | tuple | None]],
|
902 |
]:
|
903 |
-
|
904 |
-
|
905 |
-
|
|
|
|
|
|
|
|
|
|
|
906 |
else:
|
907 |
-
await self._log_fn(message=message, history=history, log="prompt")
|
908 |
-
self._set_conversation_id()
|
909 |
return [], "", []
|
|
|
21 |
Component,
|
22 |
Markdown,
|
23 |
MultimodalTextbox,
|
24 |
+
Slider,
|
25 |
State,
|
26 |
Textbox,
|
27 |
get_component_instance,
|
|
|
185 |
if self.commit_scheduler:
|
186 |
self.data_file = self.commit_scheduler.folder_path / f"data_{uuid.uuid4()}.json"
|
187 |
|
188 |
+
slider_placeholder = Slider(label="generation", minimum=1.0, maximum=2.0, step=0.05, value=1.2, render=False)
|
189 |
+
|
190 |
if additional_inputs:
|
191 |
if not isinstance(additional_inputs, list):
|
192 |
+
additional_inputs = [slider_placeholder] + [additional_inputs]
|
193 |
+
else:
|
194 |
+
additional_inputs = [slider_placeholder] + additional_inputs
|
195 |
self.additional_inputs = [get_component_instance(i) for i in additional_inputs] # type: ignore
|
196 |
else:
|
197 |
+
self.additional_inputs = [slider_placeholder]
|
198 |
if additional_inputs_accordion_name is not None:
|
199 |
print(
|
200 |
"The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
|
|
|
387 |
|
388 |
def _setup_events(self) -> None:
|
389 |
submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
|
390 |
+
submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=1)
|
391 |
submit_triggers_one = (
|
392 |
[self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
|
393 |
)
|
394 |
+
submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
|
395 |
if self.submit_btn_two:
|
396 |
submit_fn_two = functools.partial(submit_fn_one, n_generations=2)
|
397 |
submit_triggers_two = [self.submit_btn_two.click]
|
|
|
677 |
message: str | dict[str, list],
|
678 |
history_with_input: list[list[str | tuple | None]],
|
679 |
request: Request,
|
|
|
680 |
*args,
|
681 |
+
n_generations: int,
|
682 |
) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
|
683 |
if self.multimodal and isinstance(message, dict):
|
684 |
remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
|
|
|
718 |
message: str | dict[str, list],
|
719 |
history_with_input: list[list[str | tuple | None]],
|
720 |
request: Request,
|
|
|
721 |
*args,
|
722 |
+
n_generations: int,
|
723 |
) -> AsyncGenerator:
|
724 |
if self.multimodal and isinstance(message, dict):
|
725 |
remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
|
726 |
history = history_with_input[:-remove_input]
|
727 |
else:
|
728 |
history = history_with_input[:-1]
|
|
|
729 |
self._check_message(message)
|
730 |
self._check_num_turns(history)
|
731 |
_, response = history_with_input[-1]
|
732 |
if self._check_if_two_responses(response):
|
733 |
raise Error("Two options detected: undo, log or random pick continuation.")
|
734 |
|
|
|
|
|
735 |
inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
|
736 |
|
737 |
try:
|
|
|
903 |
str | dict[str, list],
|
904 |
list[list[str | tuple | None]],
|
905 |
]:
|
906 |
+
if history:
|
907 |
+
_, response = history[-1]
|
908 |
+
if self._check_if_two_responses(response):
|
909 |
+
raise Error("First log preference or continue random.")
|
910 |
+
else:
|
911 |
+
await self._log_fn(message=message, history=history, log="prompt")
|
912 |
+
self._set_conversation_id()
|
913 |
+
return [], "", []
|
914 |
else:
|
|
|
|
|
915 |
return [], "", []
|