davidberenstein1957 HF staff committed on
Commit
cdbc268
1 Parent(s): c32ec54

Update chat_interface_preference.py

Browse files
Files changed (1) hide show
  1. chat_interface_preference.py +19 -13
chat_interface_preference.py CHANGED
@@ -21,6 +21,7 @@ from gradio.components import (
21
  Component,
22
  Markdown,
23
  MultimodalTextbox,
 
24
  State,
25
  Textbox,
26
  get_component_instance,
@@ -184,12 +185,16 @@ class ChatInterface(Blocks):
184
  if self.commit_scheduler:
185
  self.data_file = self.commit_scheduler.folder_path / f"data_{uuid.uuid4()}.json"
186
 
 
 
187
  if additional_inputs:
188
  if not isinstance(additional_inputs, list):
189
- additional_inputs = [additional_inputs]
 
 
190
  self.additional_inputs = [get_component_instance(i) for i in additional_inputs] # type: ignore
191
  else:
192
- self.additional_inputs = []
193
  if additional_inputs_accordion_name is not None:
194
  print(
195
  "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
@@ -382,10 +387,11 @@ class ChatInterface(Blocks):
382
 
383
  def _setup_events(self) -> None:
384
  submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
 
385
  submit_triggers_one = (
386
  [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
387
  )
388
- submit_tuples = [(submit_fn_one, submit_triggers_one)]
389
  if self.submit_btn_two:
390
  submit_fn_two = functools.partial(submit_fn_one, n_generations=2)
391
  submit_triggers_two = [self.submit_btn_two.click]
@@ -671,8 +677,8 @@ class ChatInterface(Blocks):
671
  message: str | dict[str, list],
672
  history_with_input: list[list[str | tuple | None]],
673
  request: Request,
674
- n_generations: int = 1,
675
  *args,
 
676
  ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
677
  if self.multimodal and isinstance(message, dict):
678
  remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
@@ -712,23 +718,20 @@ class ChatInterface(Blocks):
712
  message: str | dict[str, list],
713
  history_with_input: list[list[str | tuple | None]],
714
  request: Request,
715
- n_generations: int = 1,
716
  *args,
 
717
  ) -> AsyncGenerator:
718
  if self.multimodal and isinstance(message, dict):
719
  remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
720
  history = history_with_input[:-remove_input]
721
  else:
722
  history = history_with_input[:-1]
723
-
724
  self._check_message(message)
725
  self._check_num_turns(history)
726
  _, response = history_with_input[-1]
727
  if self._check_if_two_responses(response):
728
  raise Error("Two options detected: undo, log or random pick continuation.")
729
 
730
- _, response = history_with_input[-1]
731
-
732
  inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
733
 
734
  try:
@@ -900,10 +903,13 @@ class ChatInterface(Blocks):
900
  str | dict[str, list],
901
  list[list[str | tuple | None]],
902
  ]:
903
- _, response = history[-1]
904
- if self._check_if_two_responses(response):
905
- raise Error("First log preference or continue random.")
 
 
 
 
 
906
  else:
907
- await self._log_fn(message=message, history=history, log="prompt")
908
- self._set_conversation_id()
909
  return [], "", []
 
21
  Component,
22
  Markdown,
23
  MultimodalTextbox,
24
+ Slider,
25
  State,
26
  Textbox,
27
  get_component_instance,
 
185
  if self.commit_scheduler:
186
  self.data_file = self.commit_scheduler.folder_path / f"data_{uuid.uuid4()}.json"
187
 
188
+ slider_placeholder = Slider(label="generation", minimum=1.0, maximum=2.0, step=0.05, value=1.2, render=False)
189
+
190
  if additional_inputs:
191
  if not isinstance(additional_inputs, list):
192
+ additional_inputs = [slider_placeholder] + [additional_inputs]
193
+ else:
194
+ additional_inputs = [slider_placeholder] + additional_inputs
195
  self.additional_inputs = [get_component_instance(i) for i in additional_inputs] # type: ignore
196
  else:
197
+ self.additional_inputs = [slider_placeholder]
198
  if additional_inputs_accordion_name is not None:
199
  print(
200
  "The `additional_inputs_accordion_name` parameter is deprecated and will be removed in a future version of Gradio. Use the `additional_inputs_accordion` parameter instead."
 
387
 
388
  def _setup_events(self) -> None:
389
  submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
390
+ submit_fn_one_partial = functools.partial(submit_fn_one, n_generations=1)
391
  submit_triggers_one = (
392
  [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
393
  )
394
+ submit_tuples = [(submit_fn_one_partial, submit_triggers_one)]
395
  if self.submit_btn_two:
396
  submit_fn_two = functools.partial(submit_fn_one, n_generations=2)
397
  submit_triggers_two = [self.submit_btn_two.click]
 
677
  message: str | dict[str, list],
678
  history_with_input: list[list[str | tuple | None]],
679
  request: Request,
 
680
  *args,
681
+ n_generations: int,
682
  ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
683
  if self.multimodal and isinstance(message, dict):
684
  remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
 
718
  message: str | dict[str, list],
719
  history_with_input: list[list[str | tuple | None]],
720
  request: Request,
 
721
  *args,
722
+ n_generations: int,
723
  ) -> AsyncGenerator:
724
  if self.multimodal and isinstance(message, dict):
725
  remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
726
  history = history_with_input[:-remove_input]
727
  else:
728
  history = history_with_input[:-1]
 
729
  self._check_message(message)
730
  self._check_num_turns(history)
731
  _, response = history_with_input[-1]
732
  if self._check_if_two_responses(response):
733
  raise Error("Two options detected: undo, log or random pick continuation.")
734
 
 
 
735
  inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
736
 
737
  try:
 
903
  str | dict[str, list],
904
  list[list[str | tuple | None]],
905
  ]:
906
+ if history:
907
+ _, response = history[-1]
908
+ if self._check_if_two_responses(response):
909
+ raise Error("First log preference or continue random.")
910
+ else:
911
+ await self._log_fn(message=message, history=history, log="prompt")
912
+ self._set_conversation_id()
913
+ return [], "", []
914
  else:
 
 
915
  return [], "", []