davidberenstein1957 HF staff committed on
Commit
17aeee6
·
1 Parent(s): 5b0592a

Update LLM preference collector

Browse files
Files changed (5) hide show
  1. .gitignore +1 -0
  2. app copy.py +0 -148
  3. app.py +4 -0
  4. chat_interface_preference.py +224 -111
  5. test.py +10 -18
.gitignore CHANGED
@@ -160,3 +160,4 @@ cython_debug/
160
  # and can be added to the global gitignore or merged into this file. For a more nuclear
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
  #.idea/
 
 
160
  # and can be added to the global gitignore or merged into this file. For a more nuclear
161
  # option (not recommended) you can uncomment the following to ignore the entire idea folder.
162
  #.idea/
163
+ feedback
app copy.py DELETED
@@ -1,148 +0,0 @@
1
- #!/usr/bin/env python
2
-
3
- import os
4
- from threading import Thread
5
- from typing import Iterator
6
-
7
- import gradio as gr
8
- import spaces
9
- import torch
10
- from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
11
-
12
- MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
13
-
14
- if torch.cuda.is_available():
15
- model_id = "davidberenstein1957/ultra-feedback-dutch-cleaned-hq-spin-geitje-7b-ultra-sft_iter2"
16
- model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
17
- tokenizer = AutoTokenizer.from_pretrained(model_id)
18
-
19
-
20
- @spaces.GPU
21
- def generate(
22
- message: str,
23
- chat_history: list[tuple[str, str]],
24
- max_new_tokens: int = 1024,
25
- temperature: float = 0.06,
26
- top_p: float = 0.95,
27
- top_k: int = 40,
28
- repetition_penalty: float = 1.2,
29
- ) -> Iterator[str]:
30
- conversation = []
31
- for user, assistant in chat_history:
32
- conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
33
- conversation.append({"role": "user", "content": message})
34
-
35
- input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
36
- if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
37
- input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
38
- gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
39
- input_ids = input_ids.to(model.device)
40
-
41
- streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
42
- generate_kwargs = dict(
43
- {"input_ids": input_ids},
44
- streamer=streamer,
45
- max_new_tokens=max_new_tokens,
46
- do_sample=True,
47
- top_p=top_p,
48
- top_k=top_k,
49
- temperature=temperature,
50
- num_beams=1,
51
- repetition_penalty=repetition_penalty,
52
- )
53
- t = Thread(target=model.generate, kwargs=generate_kwargs)
54
- t.start()
55
-
56
- outputs = []
57
- for text in streamer:
58
- outputs.append(text)
59
- yield "".join(outputs)
60
-
61
-
62
- chat_interface = gr.ChatInterface(
63
- fn=generate,
64
- chatbot=gr.Chatbot(
65
- height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
66
- ),
67
- # textbox=gr.Textbox(value="Typ een bericht…"),
68
- cache_examples=False,
69
- additional_inputs=[
70
- gr.Slider(
71
- label="Max new tokens",
72
- minimum=1,
73
- maximum=MAX_MAX_NEW_TOKENS,
74
- step=1,
75
- value=DEFAULT_MAX_NEW_TOKENS,
76
- ),
77
- gr.Slider(
78
- label="Temperature",
79
- minimum=0.05,
80
- maximum=1.2,
81
- step=0.05,
82
- value=0.2,
83
- ),
84
- gr.Slider(
85
- label="Top-p (nucleus sampling)",
86
- minimum=0.05,
87
- maximum=1.0,
88
- step=0.05,
89
- value=0.9,
90
- ),
91
- gr.Slider(
92
- label="Top-k",
93
- minimum=1,
94
- maximum=1000,
95
- step=1,
96
- value=50,
97
- ),
98
- gr.Slider(
99
- label="Repetition penalty",
100
- minimum=1.0,
101
- maximum=2.0,
102
- step=0.05,
103
- value=1.2,
104
- ),
105
- ],
106
- examples=[
107
- ["""Vraagje: welk woord hoort er niet in dit rijtje thuis: "auto, vliegtuig, geit, bus"?"""],
108
- [
109
- "Schrijf een nieuwsbericht voor De Speld over de inzet van een kudde geiten door het Nederlands Forensisch Instituut"
110
- ],
111
- ["Wat zijn 3 leuke dingen om te doen als ik een weekendje naar Friesland ga?"],
112
- ["Met wie trad clown Bassie op?"],
113
- ["Kan je naar de maan fietsen?"],
114
- ["Wat is het belang van open source taalmodellen?"],
115
- [
116
- """```
117
- Wortelverkopers krijgen miljoenenboete voor ongeoorloofd samenspannen
118
- Door onze economieredactie
119
- 14 dec 2023 om 12:58
120
- Update: 20 uur geleden
121
- 162 reacties
122
- Delen
123
- Toezichthouder ACM heeft een Nederlands wortelkartel aangepakt. Vier telers en verkopers van wortelen krijgen samen ruim 2,5 miljoen euro boete vanwege ongeoorloofde afspraken over het verdelen van de markt.
124
- Het gaat om telers en verkopers Laarakker, VanRijsingen, Veco en Verduyn. De vier bedrijven verkopen waspeen en Parijse wortelen aan conserven- en diepvriesfabrikanten in Nederland, België en Duitsland. Waspeen wordt vaak verkocht in potten of blikken in een mix met erwtjes.
125
- De vier bedrijven hadden in 2018 afgesproken dat ze tien jaar lang niet overal de concurrentie met elkaar zouden aangaan. Zo zou Veco tien jaar lang geen waspeen telen of verkopen. Daarnaast zouden Laarakker, VanRijsingen en Verduyn juist de Parijse wortelen links laten liggen.
126
- Ook betaalden de andere wortelverkopers Veco ter compensatie van de afspraken. Laarakker en Veco maakten ook nog afzonderlijke afspraken over de levering van Parijse wortelen aan Duitse klanten.
127
- Zulke afspraken zijn verboden. Als concurrentie door die samenwerking achterwege blijft en er dus sprake is van een kartel, betalen kopers mogelijk een hogere prijs, stelt de ACM.
128
- Twee van de wortelbedrijven werkten mee door meer informatie over de ongeoorloofde afspraken te delen met de toezichthouder. Daardoor kregen zij een lagere boete.
129
- ```
130
- Vat bovenstaand artikel samen"""
131
- ],
132
- ],
133
- title="🐐🕷️ GEITje 7B Spin Iter 2 🕷️🐐",
134
- description="""\
135
- <a href="https://huggingface.co/davidberenstein1957/ultra-feedback-dutch-cleaned-hq-spin-geitje-7b-ultra-sft_iter2">GEITje 7B SPIN iter 2</a> is een geavanceerde versie van GEITje, verder getraind op uitgebreide chat datasets en ook op een preferentiedataset om beter te aligneren met het gedrag van een gewenste chatbot op basis van het [SPIN algoritme en code](https://github.com/argilla-io/distilabel-spin-dibt/), in dit geval gpt-4-turbo. De data is net iets anders dan de originele dataset want we hebben de schoonmaak voor UltraFeedback gebruikt van Argilla, de uiteindelijke dataset is [hier](https://huggingface.co/datasets/BramVanroy/ultra_feedback_dutch_cleaned) te vinden.
136
- """,
137
- submit_btn="Genereer",
138
- stop_btn="Stop",
139
- retry_btn="🔄 Opnieuw",
140
- undo_btn="↩️ Ongedaan maken",
141
- clear_btn="🗑️ Wissen",
142
- )
143
-
144
- with gr.Blocks(css="style.css") as demo:
145
- chat_interface.render()
146
-
147
- if __name__ == "__main__":
148
- demo.queue(max_size=20).launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
app.py CHANGED
@@ -93,6 +93,10 @@ def generate(
93
 
94
  chat_interface = ChatInterface(
95
  fn=generate,
 
 
 
 
96
  chatbot=gr.Chatbot(
97
  height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
98
  ),
 
93
 
94
  chat_interface = ChatInterface(
95
  fn=generate,
96
+ prefence_techniques="dpo",
97
+ min_turns=1,
98
+ max_turns=10,
99
+ repo_id="geitje-spin",
100
  chatbot=gr.Chatbot(
101
  height=450, label="GEITje-SPIN", show_share_button=True, avatar_images=(None, "geitje-logo.jpg")
102
  ),
chat_interface_preference.py CHANGED
@@ -4,11 +4,14 @@ This file defines a useful high-level abstraction to build Gradio chatbots: Chat
4
 
5
  from __future__ import annotations
6
 
 
7
  import functools
8
  import inspect
 
9
  import random
10
  import re
11
- from typing import AsyncGenerator, Callable, Literal, Union, cast
 
12
 
13
  import anyio
14
  from gradio.blocks import Blocks
@@ -23,17 +26,19 @@ from gradio.components import (
23
  get_component_instance,
24
  )
25
  from gradio.events import Dependency, on
26
- from gradio.helpers import Error, Info, Warning
27
  from gradio.helpers import create_examples as Examples # noqa: N812
28
- from gradio.helpers import special_args
29
  from gradio.layouts import Accordion, Group, Row
30
  from gradio.routes import Request
31
  from gradio.themes import ThemeClass as Theme
32
  from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda
33
  from gradio_client.documentation import document
 
34
 
35
  pattern = re.compile(r'<div class="message-identifier">(.*?)</div>', re.DOTALL)
36
 
 
 
37
 
38
  @document()
39
  class ChatInterface(Blocks):
@@ -59,6 +64,11 @@ class ChatInterface(Blocks):
59
  self,
60
  fn: Callable,
61
  *,
 
 
 
 
 
62
  multimodal: bool = False,
63
  chatbot: Chatbot | None = None,
64
  textbox: Textbox | MultimodalTextbox | None = None,
@@ -75,19 +85,10 @@ class ChatInterface(Blocks):
75
  js: str | None = None,
76
  head: str | None = None,
77
  analytics_enabled: bool | None = None,
78
- submit_btn_one: str | None | Button = "Generate 1",
79
- submit_btn_two: str | None | Button = "Generate 2",
80
- submit_btn_a: str | None | Button = "Log preference 🅰️",
81
- submit_btn_b: str | None | Button = "Log preference 🅱️",
82
- submit_btn_ab: str | None | Button = "Random pick 🅰️=🅱️",
83
- stop_btn: str | None | Button = "Stop",
84
- undo_btn: str | None | Button = "↩️ Undo",
85
- clear_btn: str | None | Button = "🗑️ Clear",
86
  autofocus: bool = True,
87
  concurrency_limit: int | None | Literal["default"] = "default",
88
  fill_height: bool = True,
89
  delete_cache: tuple[int, int] | None = None,
90
- rg_dataset: None,
91
  ):
92
  """
93
  Parameters:
@@ -118,6 +119,39 @@ class ChatInterface(Blocks):
118
  fill_height: If True, the chat interface will expand to the height of window.
119
  delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
120
  """
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
  super().__init__(
122
  analytics_enabled=analytics_enabled,
123
  mode="chat_interface",
@@ -129,6 +163,7 @@ class ChatInterface(Blocks):
129
  fill_height=fill_height,
130
  delete_cache=delete_cache,
131
  )
 
132
  self.css = css
133
  self.multimodal = multimodal
134
  self.concurrency_limit = concurrency_limit
@@ -139,7 +174,15 @@ class ChatInterface(Blocks):
139
 
140
  self.examples = examples
141
  self.cache_examples: bool | None | Literal["lazy"] = cache_examples
142
- self.rg_dataset = rg_dataset
 
 
 
 
 
 
 
 
143
 
144
  if additional_inputs:
145
  if not isinstance(additional_inputs, list):
@@ -173,14 +216,27 @@ class ChatInterface(Blocks):
173
  Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>")
174
  if description:
175
  Markdown(description)
176
-
 
 
 
 
 
177
  if chatbot:
178
  self.chatbot = chatbot.render()
179
  else:
180
  self.chatbot = Chatbot(label="Chatbot", scale=1, height=200 if fill_height else None)
181
 
182
  with Row():
183
- for btn in [submit_btn_a, submit_btn_b, submit_btn_ab, undo_btn, clear_btn]:
 
 
 
 
 
 
 
 
184
  if btn is not None:
185
  if isinstance(btn, Button):
186
  btn.render()
@@ -268,6 +324,8 @@ class ChatInterface(Blocks):
268
  self.submit_btn_a,
269
  self.submit_btn_b,
270
  self.submit_btn_ab,
 
 
271
  self.undo_btn,
272
  self.clear_btn,
273
  self.submit_btn_one,
@@ -308,14 +366,31 @@ class ChatInterface(Blocks):
308
  self._setup_events()
309
  self._setup_api()
310
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
311
  def _setup_events(self) -> None:
312
  submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
313
  submit_triggers_one = (
314
  [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
315
  )
316
- submit_fn_two = functools.partial(submit_fn_one, n_generations=2)
317
- submit_triggers_two = [self.submit_btn_two.click]
318
- for _fn, _triggers in [(submit_fn_one, submit_triggers_one), (submit_fn_two, submit_triggers_two)]:
 
 
 
319
  submit_event = (
320
  on(
321
  _triggers,
@@ -342,18 +417,38 @@ class ChatInterface(Blocks):
342
  )
343
  self._setup_stop_events(_triggers, submit_event)
344
 
345
- partial_fn_a, partial_fn_b, partial_fn_ab = (
346
  functools.partial(self._log_fn, log="a"),
347
  functools.partial(self._log_fn, log="b"),
348
  functools.partial(self._log_fn, log="ab"),
 
 
349
  )
350
  for _fn, _btn in [
351
  (partial_fn_a, self.submit_btn_a),
352
  (partial_fn_b, self.submit_btn_b),
353
  (partial_fn_ab, self.submit_btn_ab),
 
 
354
  ]:
355
- _btn.click(
356
- _fn,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
357
  [self.saved_input, self.chatbot_state],
358
  [self.chatbot, self.saved_input, self.chatbot_state],
359
  show_api=False,
@@ -366,36 +461,9 @@ class ChatInterface(Blocks):
366
  queue=False,
367
  )
368
 
369
- # if self.retry_btn:
370
- # submit_fn_partial = functools.partial(submit_fn, append=True)
371
- # retry_event = (
372
- # self.retry_btn.click(
373
- # self._delete_prev_fn,
374
- # [self.saved_input, self.chatbot_state],
375
- # [self.chatbot, self.saved_input, self.chatbot_state],
376
- # show_api=False,
377
- # queue=False,
378
- # )
379
- # .then(
380
- # self._display_input,
381
- # [self.saved_input, self.chatbot_state],
382
- # [self.chatbot, self.chatbot_state],
383
- # show_api=False,
384
- # queue=False,
385
- # )
386
- # .then(
387
- # submit_fn_partial,
388
- # [self.saved_input, self.chatbot_state] + self.additional_inputs,
389
- # [self.chatbot, self.chatbot_state],
390
- # show_api=False,
391
- # concurrency_limit=cast(Union[int, Literal["default"], None], self.concurrency_limit),
392
- # )
393
- # )
394
- # self._setup_stop_events([self.retry_btn.click], retry_event)
395
-
396
- if self.undo_btn:
397
- self.undo_btn.click(
398
- self._delete_prev_fn,
399
  [self.saved_input, self.chatbot_state],
400
  [self.chatbot, self.saved_input, self.chatbot_state],
401
  show_api=False,
@@ -408,15 +476,6 @@ class ChatInterface(Blocks):
408
  queue=False,
409
  )
410
 
411
- if self.clear_btn:
412
- self.clear_btn.click(
413
- async_lambda(lambda: ([], [], None)),
414
- None,
415
- [self.chatbot, self.chatbot_state, self.saved_input],
416
- queue=False,
417
- show_api=False,
418
- )
419
-
420
  def _setup_stop_events(self, event_triggers: list[Callable], event_to_cancel: Dependency) -> None:
421
  if self.stop_btn and self.is_generator:
422
  if self.submit_btn_one:
@@ -545,6 +604,16 @@ class ChatInterface(Blocks):
545
 
546
  return self.css + "<body>" + conversation + "</body>"
547
 
 
 
 
 
 
 
 
 
 
 
548
  @staticmethod
549
  def _get_chat_message(message, role, turn):
550
  if role == "user":
@@ -575,8 +644,27 @@ class ChatInterface(Blocks):
575
  @staticmethod
576
  def _check_if_two_responses(response):
577
  if response:
578
- if '<div class="message-content">' in response:
579
- return True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
580
 
581
  async def _submit_fn(
582
  self,
@@ -586,21 +674,18 @@ class ChatInterface(Blocks):
586
  n_generations: int = 1,
587
  *args,
588
  ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
589
- if not message:
590
- Warning("Make sure to provide a message next time.")
591
- history = history_with_input[:-1]
592
- return history, history
593
-
594
- _, response = history_with_input[-1]
595
- if self._check_if_two_responses(response):
596
- raise Error("Two options detected: undo, log or random pick continuation.")
597
-
598
  if self.multimodal and isinstance(message, dict):
599
  remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
600
  history = history_with_input[:-remove_input]
601
  else:
602
  history = history_with_input[:-1]
603
 
 
 
 
 
 
 
604
  inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
605
 
606
  async def _get_response():
@@ -636,6 +721,14 @@ class ChatInterface(Blocks):
636
  else:
637
  history = history_with_input[:-1]
638
 
 
 
 
 
 
 
 
 
639
  inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
640
 
641
  try:
@@ -681,7 +774,7 @@ class ChatInterface(Blocks):
681
  generator_two = self.fn(*inputs)
682
  else:
683
  generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
684
- generator_two = SyncToAsyncIterator(generator, self.limiter)
685
  try:
686
  first_response_two = await async_iteration(generator_two)
687
  first_response_two_formatted = self._get_chat_message_comparison(response, first_response_two)
@@ -701,7 +794,7 @@ class ChatInterface(Blocks):
701
  else:
702
  update = history + [[message, None]]
703
  yield update, update
704
- async for response_two in generator:
705
  response_two = self._get_chat_message_comparison(response, response_two)
706
  if self.multimodal and isinstance(message, dict):
707
  update = history + [[message["text"], response_two]]
@@ -717,44 +810,48 @@ class ChatInterface(Blocks):
717
  str | dict[str, list],
718
  list[list[str | tuple | None]],
719
  ]:
720
- if history:
721
- prompt, response = history[-1]
722
- if self._check_if_two_responses(response):
723
- matches = pattern.findall(response)
724
- option_a, option_b = matches[0], matches[1]
725
- if log == "a":
726
- chosen, rejected = option_a, option_b
727
- print(Info("Logged preference: a"))
728
- elif log == "b":
729
- chosen, rejected = option_b, option_a
730
- print(Info("Logged preference: b"))
731
- elif log == "ab":
732
- options = [option_a, option_b]
733
- chosen, rejected = random.choice([options])
734
- print(Info("Picked random response to continue"))
735
-
736
- if log in ["a", "b"] and self.rg_dataset:
737
- import argilla as rg
738
-
739
- self.rg_dataset.records.log(
740
- [
741
- rg.Record(
742
- fields={
743
- "conversation": self._get_conversation_from_history(history),
744
- },
745
- suggestions=[
746
- rg.Suggestion(question_name="chosen", value=chosen),
747
- rg.Suggestion(question_name="rejected", value=rejected),
748
- ],
749
- )
750
- ]
751
- )
752
- history[-1] = [prompt, chosen]
753
- return history, message or "", history
754
- else:
755
- raise Error("Only one option found.")
 
 
 
 
756
  else:
757
- raise Error("No history found")
758
 
759
  async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
760
  inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
@@ -770,7 +867,6 @@ class ChatInterface(Blocks):
770
  *args,
771
  ) -> AsyncGenerator:
772
  inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
773
-
774
  if self.is_async:
775
  generator = self.fn(*inputs)
776
  else:
@@ -794,3 +890,20 @@ class ChatInterface(Blocks):
794
  else:
795
  history = history[:-1]
796
  return history, message or "", history
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
  from __future__ import annotations
6
 
7
+ import datetime
8
  import functools
9
  import inspect
10
+ import json
11
  import random
12
  import re
13
+ import uuid
14
+ from typing import AsyncGenerator, Callable, List, Literal, Union, cast
15
 
16
  import anyio
17
  from gradio.blocks import Blocks
 
26
  get_component_instance,
27
  )
28
  from gradio.events import Dependency, on
29
+ from gradio.helpers import Error, Info, special_args
30
  from gradio.helpers import create_examples as Examples # noqa: N812
 
31
  from gradio.layouts import Accordion, Group, Row
32
  from gradio.routes import Request
33
  from gradio.themes import ThemeClass as Theme
34
  from gradio.utils import SyncToAsyncIterator, async_iteration, async_lambda
35
  from gradio_client.documentation import document
36
+ from huggingface_hub import CommitScheduler
37
 
38
  pattern = re.compile(r'<div class="message-identifier">(.*?)</div>', re.DOTALL)
39
 
40
+ PREFERENCE_TECHNIQUE_MAPPING = {"sft": "prompt", "dpo": "preference", "kto": "vibes"}
41
+
42
 
43
  @document()
44
  class ChatInterface(Blocks):
 
64
  self,
65
  fn: Callable,
66
  *,
67
+ prefence_techniques: str | List[str] | None = None,
68
+ min_turns: int = 1,
69
+ max_turns: int = 1,
70
+ repo_id: None | str,
71
+ repo_private: bool = False,
72
  multimodal: bool = False,
73
  chatbot: Chatbot | None = None,
74
  textbox: Textbox | MultimodalTextbox | None = None,
 
85
  js: str | None = None,
86
  head: str | None = None,
87
  analytics_enabled: bool | None = None,
 
 
 
 
 
 
 
 
88
  autofocus: bool = True,
89
  concurrency_limit: int | None | Literal["default"] = "default",
90
  fill_height: bool = True,
91
  delete_cache: tuple[int, int] | None = None,
 
92
  ):
93
  """
94
  Parameters:
 
119
  fill_height: If True, the chat interface will expand to the height of window.
120
  delete_cache: A tuple corresponding [frequency, age] both expressed in number of seconds. Every `frequency` seconds, the temporary files created by this Blocks instance will be deleted if more than `age` seconds have passed since the file was created. For example, setting this to (86400, 86400) will delete temporary files every day. The cache will be deleted entirely when the server restarts. If None, no cache deletion will occur.
121
  """
122
+ if max_turns < min_turns:
123
+ raise ValueError("`max_turns` should be larger than `min_turns`")
124
+ if any([turn for turn in [max_turns, min_turns] if turn < 1]):
125
+ raise ValueError("`max_turns` should be larger than `min_turns`")
126
+ self.max_turns = max_turns
127
+ self.min_turns = min_turns
128
+ if isinstance(prefence_techniques, str):
129
+ prefence_techniques = [prefence_techniques]
130
+ elif prefence_techniques is None:
131
+ prefence_techniques = ["sft"]
132
+ self.prefence_techniques = [technique.lower() for technique in prefence_techniques]
133
+
134
+ optional_techniques = ["kto", "sft", "spin", "dpo", "simpo", "rlhf", "orpo"]
135
+ if any([technique for technique in self.prefence_techniques if technique not in optional_techniques]):
136
+ raise ValueError(f"Supported techniques are {optional_techniques}")
137
+ submit_btn_one = "Generate"
138
+ submit_btn_two = None
139
+ submit_btn_a = None
140
+ submit_btn_b = None
141
+ submit_btn_ab = None
142
+ submit_btn_good = None
143
+ submit_btn_bad = None
144
+ stop_btn = "Stop"
145
+ undo_btn = "↩️ Undo"
146
+ clear_btn = "🗑️ Log and clear"
147
+ if "kto" in prefence_techniques:
148
+ submit_btn_good = "Log response 👍"
149
+ submit_btn_bad = "Log response 👎"
150
+ if any([technique for technique in ["dpo", "simpo", "rlhf", "orpo"] if technique in self.prefence_techniques]):
151
+ submit_btn_two = "Generate 2"
152
+ submit_btn_a = "Log preference 🅰️"
153
+ submit_btn_b = "Log preference 🅱️"
154
+ submit_btn_ab = "Continue random 🅰️=🅱️"
155
  super().__init__(
156
  analytics_enabled=analytics_enabled,
157
  mode="chat_interface",
 
163
  fill_height=fill_height,
164
  delete_cache=delete_cache,
165
  )
166
+
167
  self.css = css
168
  self.multimodal = multimodal
169
  self.concurrency_limit = concurrency_limit
 
174
 
175
  self.examples = examples
176
  self.cache_examples: bool | None | Literal["lazy"] = cache_examples
177
+ self._set_conversation_id()
178
+ if repo_id:
179
+ self.commit_scheduler = CommitScheduler(
180
+ repo_id=repo_id, folder_path="feedback", repo_type="dataset", private=repo_private, every=1
181
+ )
182
+ else:
183
+ self.commit_scheduler = None
184
+ if self.commit_scheduler:
185
+ self.data_file = self.commit_scheduler.folder_path / f"data_{uuid.uuid4()}.json"
186
 
187
  if additional_inputs:
188
  if not isinstance(additional_inputs, list):
 
216
  Markdown(f"<h1 style='text-align: center; margin-bottom: 1rem'>{self.title}</h1>")
217
  if description:
218
  Markdown(description)
219
+ if self.commit_scheduler:
220
+ Markdown(
221
+ f"## Data is being logged to a datset on the hub: [{self.commit_scheduler.repo_id}](https://huggingface.co/datasets/{self.commit_scheduler.repo_id})"
222
+ )
223
+ Markdown(f"### Techniques: {self.prefence_techniques}")
224
+ Markdown(f"### MIN TURNS: {self.min_turns} - MAX TURN: {self.max_turns}")
225
  if chatbot:
226
  self.chatbot = chatbot.render()
227
  else:
228
  self.chatbot = Chatbot(label="Chatbot", scale=1, height=200 if fill_height else None)
229
 
230
  with Row():
231
+ for btn in [
232
+ submit_btn_a,
233
+ submit_btn_b,
234
+ submit_btn_ab,
235
+ submit_btn_good,
236
+ submit_btn_bad,
237
+ undo_btn,
238
+ clear_btn,
239
+ ]:
240
  if btn is not None:
241
  if isinstance(btn, Button):
242
  btn.render()
 
324
  self.submit_btn_a,
325
  self.submit_btn_b,
326
  self.submit_btn_ab,
327
+ self.submit_btn_good,
328
+ self.submit_btn_bad,
329
  self.undo_btn,
330
  self.clear_btn,
331
  self.submit_btn_one,
 
366
  self._setup_events()
367
  self._setup_api()
368
 
369
+ def _set_conversation_id(self):
370
+ self._conversation_id = str(uuid.uuid4())
371
+
372
+ def _save_feedback(self, item):
373
+ feedback = {
374
+ "timestamp": datetime.datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
375
+ "conversation_id": self._conversation_id,
376
+ }
377
+ feedback.update(item)
378
+ if self.commit_scheduler:
379
+ with self.commit_scheduler.lock:
380
+ with self.data_file.open("a") as f:
381
+ f.write(json.dumps(feedback))
382
+
383
  def _setup_events(self) -> None:
384
  submit_fn_one = self._stream_fn if self.is_generator else self._submit_fn
385
  submit_triggers_one = (
386
  [self.textbox.submit, self.submit_btn_one.click] if self.submit_btn_one else [self.textbox.submit]
387
  )
388
+ submit_tuples = [(submit_fn_one, submit_triggers_one)]
389
+ if self.submit_btn_two:
390
+ submit_fn_two = functools.partial(submit_fn_one, n_generations=2)
391
+ submit_triggers_two = [self.submit_btn_two.click]
392
+ submit_tuples.append((submit_fn_two, submit_triggers_two))
393
+ for _fn, _triggers in submit_tuples:
394
  submit_event = (
395
  on(
396
  _triggers,
 
417
  )
418
  self._setup_stop_events(_triggers, submit_event)
419
 
420
+ partial_fn_a, partial_fn_b, partial_fn_ab, partial_fn_good, partial_fn_bad = (
421
  functools.partial(self._log_fn, log="a"),
422
  functools.partial(self._log_fn, log="b"),
423
  functools.partial(self._log_fn, log="ab"),
424
+ functools.partial(self._log_fn, log="good"),
425
+ functools.partial(self._log_fn, log="bad"),
426
  )
427
  for _fn, _btn in [
428
  (partial_fn_a, self.submit_btn_a),
429
  (partial_fn_b, self.submit_btn_b),
430
  (partial_fn_ab, self.submit_btn_ab),
431
+ (partial_fn_good, self.submit_btn_good),
432
+ (partial_fn_bad, self.submit_btn_bad),
433
  ]:
434
+ if _btn:
435
+ _btn.click(
436
+ _fn,
437
+ [self.saved_input, self.chatbot_state],
438
+ [self.chatbot, self.saved_input, self.chatbot_state],
439
+ show_api=False,
440
+ queue=False,
441
+ ).then(
442
+ async_lambda(lambda x: x),
443
+ [self.saved_input],
444
+ [self.textbox],
445
+ show_api=False,
446
+ queue=False,
447
+ )
448
+
449
+ if self.undo_btn:
450
+ self.undo_btn.click(
451
+ self._delete_prev_fn,
452
  [self.saved_input, self.chatbot_state],
453
  [self.chatbot, self.saved_input, self.chatbot_state],
454
  show_api=False,
 
461
  queue=False,
462
  )
463
 
464
+ if self.clear_btn:
465
+ self.clear_btn.click(
466
+ self._clear_fn,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
467
  [self.saved_input, self.chatbot_state],
468
  [self.chatbot, self.saved_input, self.chatbot_state],
469
  show_api=False,
 
476
  queue=False,
477
  )
478
 
 
 
 
 
 
 
 
 
 
479
  def _setup_stop_events(self, event_triggers: list[Callable], event_to_cancel: Dependency) -> None:
480
  if self.stop_btn and self.is_generator:
481
  if self.submit_btn_one:
 
604
 
605
  return self.css + "<body>" + conversation + "</body>"
606
 
607
+ def _get_conversation_in_openai_format(self, history):
608
+ conversation = []
609
+ for idx, turn in enumerate(history):
610
+ roles = ["user", "assistant"]
611
+ if idx == len(turn) - 1:
612
+ roles = ["user"]
613
+ for role, content in zip(roles, turn):
614
+ conversation.append({"role": role, "content": content})
615
+ return conversation
616
+
617
  @staticmethod
618
  def _get_chat_message(message, role, turn):
619
  if role == "user":
 
644
  @staticmethod
645
  def _check_if_two_responses(response):
646
  if response:
647
+ matches = pattern.findall(response)
648
+ return matches
649
+
650
+ def _check_num_turns(self, history, generate=True):
651
+ if generate:
652
+ if len(history) >= self.max_turns:
653
+ raise Error(
654
+ f"We intend to collect conversations with a maximum of {self.max_turns}, please clear or log info first."
655
+ )
656
+ return history, history
657
+ else:
658
+ if len(history) < self.min_turns:
659
+ raise Error(
660
+ f"We intend to collect conversations with at least of {self.min_turns}, please continue the conversation first."
661
+ )
662
+ return history, history
663
+
664
+ @staticmethod
665
+ def _check_message(message):
666
+ if not message:
667
+ raise Error("Make sure to provide a message next time.")
668
 
669
  async def _submit_fn(
670
  self,
 
674
  n_generations: int = 1,
675
  *args,
676
  ) -> tuple[list[list[str | tuple | None]], list[list[str | tuple | None]]]:
 
 
 
 
 
 
 
 
 
677
  if self.multimodal and isinstance(message, dict):
678
  remove_input = len(message["files"]) + 1 if message["text"] is not None else len(message["files"])
679
  history = history_with_input[:-remove_input]
680
  else:
681
  history = history_with_input[:-1]
682
 
683
+ self._check_message(message)
684
+ self._check_num_turns(history)
685
+ _, response = history_with_input[-1]
686
+ if self._check_if_two_responses(response):
687
+ raise Error("Two options detected: undo, log or random pick continuation.")
688
+
689
  inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
690
 
691
  async def _get_response():
 
721
  else:
722
  history = history_with_input[:-1]
723
 
724
+ self._check_message(message)
725
+ self._check_num_turns(history)
726
+ _, response = history_with_input[-1]
727
+ if self._check_if_two_responses(response):
728
+ raise Error("Two options detected: undo, log or random pick continuation.")
729
+
730
+ _, response = history_with_input[-1]
731
+
732
  inputs, _, _ = special_args(self.fn, inputs=[message, history, *args], request=request)
733
 
734
  try:
 
774
  generator_two = self.fn(*inputs)
775
  else:
776
  generator_two = await anyio.to_thread.run_sync(self.fn, *inputs, limiter=self.limiter)
777
+ generator_two = SyncToAsyncIterator(generator_two, self.limiter)
778
  try:
779
  first_response_two = await async_iteration(generator_two)
780
  first_response_two_formatted = self._get_chat_message_comparison(response, first_response_two)
 
794
  else:
795
  update = history + [[message, None]]
796
  yield update, update
797
+ async for response_two in generator_two:
798
  response_two = self._get_chat_message_comparison(response, response_two)
799
  if self.multimodal and isinstance(message, dict):
800
  update = history + [[message["text"], response_two]]
 
810
  str | dict[str, list],
811
  list[list[str | tuple | None]],
812
  ]:
813
+ self._check_num_turns(history, generate=False)
814
+ history_as_openai_format = self._get_conversation_in_openai_format(history)
815
+ feedback = {"prompt": history_as_openai_format}
816
+
817
+ prompt, response = history[-1]
818
+ matches = self._check_if_two_responses(response)
819
+ if matches and log != "prompt":
820
+ option_a, option_b = matches[0], matches[1]
821
+ if log == "a":
822
+ chosen, rejected = option_a, option_b
823
+ Info("Logged preference: a")
824
+ elif log == "b":
825
+ chosen, rejected = option_b, option_a
826
+ Info("Logged preference: b")
827
+ elif log == "ab":
828
+ options = [option_a, option_b]
829
+ chosen, rejected = random.choice([options])
830
+ Info("Picked random response to continue")
831
+ if log in ["a", "b"] and self.commit_scheduler:
832
+ feedback.update(
833
+ {
834
+ "chosen": [{"content": chosen, "role": "assistant"}],
835
+ "rejected": [{"content": rejected, "role": "assistant"}],
836
+ }
837
+ )
838
+ self._save_feedback(feedback)
839
+ elif log == "ab":
840
+ self._save_feedback(feedback)
841
+ history[-1] = [prompt, chosen]
842
+ return history, message or "", history
843
+ elif log in ["conversation", "good", "bad"]:
844
+ feedback.update({"response": response})
845
+ if log == "good":
846
+ feedback.update({"label": True})
847
+ elif log == "bad":
848
+ feedback.update({"label": False})
849
+ Info("Logged conversation")
850
+ self._save_feedback(feedback)
851
+
852
+ return history, "", history
853
  else:
854
+ raise Error("Error in code w.r.t logging.")
855
 
856
  async def _examples_fn(self, message: str, *args) -> list[list[str | None]]:
857
  inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
 
867
  *args,
868
  ) -> AsyncGenerator:
869
  inputs, _, _ = special_args(self.fn, inputs=[message, [], *args], request=None)
 
870
  if self.is_async:
871
  generator = self.fn(*inputs)
872
  else:
 
890
  else:
891
  history = history[:-1]
892
  return history, message or "", history
893
+
894
+ async def _clear_fn(
895
+ self,
896
+ message: str | dict[str, list],
897
+ history: list[list[str | tuple | None]],
898
+ ) -> tuple[
899
+ list[list[str | tuple | None]],
900
+ str | dict[str, list],
901
+ list[list[str | tuple | None]],
902
+ ]:
903
+ _, response = history[-1]
904
+ if self._check_if_two_responses(response):
905
+ raise Error("First log preference or continue random.")
906
+ else:
907
+ await self._log_fn(message=message, history=history, log="prompt")
908
+ self._set_conversation_id()
909
+ return [], "", []
test.py CHANGED
@@ -1,34 +1,26 @@
1
  import random
2
 
3
-
4
  from chat_interface_preference import ChatInterface
5
 
6
 
7
  def random_response(message, history, request):
8
  response = random.choice(["Yes", "No"])
 
9
  for char in response:
10
- yield char
 
11
 
12
 
13
  style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
14
- client = rg.Argilla(api_url="https://davidberenstein1957-argilla-gradio.hf.space", api_key="owner.apikey")
15
 
16
- required_settings = rg.Settings(
17
- fields=[rg.TextField(name="conversation")],
18
- questions=[
19
- rg.TextQuestion(name="chosen"),
20
- rg.TextQuestion(name="rejected"),
21
- ],
 
22
  )
23
- name = "test"
24
- if client.datasets(name=name).exists():
25
- dataset: rg.Dataset = client.datasets(name=name)
26
-
27
- else:
28
- dataset = rg.Dataset(name=name, settings=required_settings)
29
- dataset.create()
30
-
31
- demo = ChatInterface(random_response, cache_examples=False, css=style, rg_dataset=dataset)
32
 
33
  if __name__ == "__main__":
34
  demo.launch()
 
1
  import random
2
 
 
3
  from chat_interface_preference import ChatInterface
4
 
5
 
6
def random_response(message, history, request):
    """Stream a randomly chosen "Yes" or "No" as growing prefixes.

    Each yielded value is the full text produced so far (e.g. "Y", "Ye",
    "Yes"), not a single character. The arguments are accepted but unused.
    """
    answer = random.choice(["Yes", "No"])
    # Yield answer[:1], answer[:2], ... up to the whole string.
    for end in range(1, len(answer) + 1):
        yield answer[:end]
12
 
13
 
14
  style = "<style>.user-message,.system-message{display:flex;margin:10px}.user-message .message-content{background-color:#c2e3f7;color:#000}.system-message .message-content{background-color:#f5f5f5;color:#000}.message-content{padding:10px;border-radius:10px;max-width:70%;word-wrap:break-word}.container{display:flex;justify-content:space-between}.column{width:48%}</style>"
 
15
 
16
+ demo = ChatInterface(
17
+ random_response,
18
+ cache_examples=False,
19
+ css=style,
20
+ repo_id="geitje-spin-preference",
21
+ prefence_techniques=["kto"],
22
+ max_turns=10,
23
  )
 
 
 
 
 
 
 
 
 
24
 
25
  if __name__ == "__main__":
26
  demo.launch()