File size: 5,395 Bytes
8824f88
 
46218ed
a38b23a
8824f88
 
 
 
a38b23a
7e8ce88
 
 
8824f88
0adfa8b
ba0ce5d
8824f88
 
c4bb11b
8824f88
 
05c0b89
 
 
ba0ce5d
8824f88
 
 
 
 
 
c4bb11b
 
 
8824f88
 
46218ed
 
05c0b89
 
 
8824f88
05c0b89
 
 
 
 
8824f88
05c0b89
 
 
 
 
 
 
 
 
 
 
 
 
 
8824f88
05c0b89
 
 
 
8824f88
 
ba0ce5d
8824f88
7e8ce88
17aeee6
 
d355157
e466982
c4bb11b
8824f88
 
 
 
 
 
 
 
 
 
c4bb11b
 
 
7c45d36
8824f88
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1038ed
 
 
 
8824f88
2c31d1f
 
 
a38b23a
2c31d1f
 
 
 
 
 
8824f88
 
 
 
 
 
ba0ce5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
#!/usr/bin/env python
import os
import random
from threading import Thread  # noqa
from typing import Iterator

import gradio as gr
import spaces
import torch  # noqa
from transformers import AutoModelForCausalLM  # noqa
from transformers import AutoTokenizer  # noqa
from transformers import TextIteratorStreamer  # noqa

from chat_interface_preference import ChatInterface

# Upper bound offered by the "Max new tokens" slider.
MAX_MAX_NEW_TOKENS = 2048
# Initial position of the "Max new tokens" slider.
DEFAULT_MAX_NEW_TOKENS = 1024
# Prompts longer than this many tokens are trimmed in generate(); the value can
# be overridden through the MAX_INPUT_TOKEN_LENGTH environment variable.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "8192"))

# The checkpoint is only loaded when CUDA is reported as available; otherwise
# `model_id`, `model` and `tokenizer` are left undefined at module level.
if torch.cuda.is_available():
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # Half-precision weights, with device placement delegated to device_map="auto".
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.06,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream a sampled reply to ``message``, yielding the text produced so far.

    Every call prepends a system prompt of the form ``"Communicate <style>."``
    where the style word is picked at random, followed by the earlier turns and
    the new user message. Generation runs on a background thread and its output
    is consumed through a ``TextIteratorStreamer``.

    Args:
        message: The newest user message.
        chat_history: Earlier ``(user, assistant)`` turns, oldest first.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, once for every chunk the streamer emits.
    """
    style = random.choice(["concise", "explicit", "simple", "complex", "usefull", "helpfull"])

    # Assemble the conversation: system prompt, prior turns, then the new message.
    turns = [{"role": "system", "content": f"Communicate {style}."}]
    for user_text, assistant_text in chat_history:
        turns.append({"role": "user", "content": user_text})
        turns.append({"role": "assistant", "content": assistant_text})
    turns.append({"role": "user", "content": message})

    prompt_ids = tokenizer.apply_chat_template(turns, add_generation_prompt=True, return_tensors="pt")
    prompt_length = prompt_ids.shape[1]
    if prompt_length > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the trailing MAX_INPUT_TOKEN_LENGTH tokens and tell the user.
        prompt_ids = prompt_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    prompt_ids = prompt_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)

    # model.generate runs on its own thread so this generator can read the
    # streamer while tokens are still being produced.
    generation_thread = Thread(
        target=model.generate,
        kwargs={
            "input_ids": prompt_ids,
            "streamer": streamer,
            "max_new_tokens": max_new_tokens,
            "do_sample": True,
            "top_p": top_p,
            "top_k": top_k,
            "temperature": temperature,
            "num_beams": 1,
            "repetition_penalty": repetition_penalty,
        },
    )
    generation_thread.start()

    reply_so_far = ""
    for chunk in streamer:
        reply_so_far += chunk
        yield reply_so_far


# Preference-collecting chat UI built around generate(). The sliders in
# `additional_inputs` are listed in the same order as generate()'s optional
# parameters (max_new_tokens, temperature, top_p, top_k, repetition_penalty).
# NOTE(review): the keyword is spelled `prefence_technique`; presumably this
# matches the signature in chat_interface_preference - confirm before renaming.
chat_interface = ChatInterface(
    fn=generate,
    prefence_technique="dpo",
    min_turns=1,
    max_turns=10,
    repo_id="llm-human-feedback-collector-chat-interface-dpo",
    chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
    cache_examples=False,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.2,
            step=0.05,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["""What word doesn't make sense in this row: "car, airplane, lama, bus"?"""],
        ["Write a news article about the usage of Lama's by the CSI"],
        ["What are great things to cook when getting started with Asian cooking?"],
        ["Who was Anthony Bourdain?"],
    ],
    title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO) 🦾💪🏽",
    # The description is Markdown assembled from several fragments; each
    # fragment ends with a space so the sentences join cleanly.
    description="".join(
        [
            "This is an adaptation of the [`gr.ChatInterface`](https://www.gradio.app/docs/gradio/chatinterface) which also uses the [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler) to allow for human feedback collection. ",
            "Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
            "This demo shows how you might capture human feedback directly from applications within Gradio. ",
            "The captured feedback can directly be used for fine-tuning LLMs within frameworks like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
            "however, it might benefit from additional data curation with something like [Argilla](https://github.com/argilla-io/argilla/) for human feedback and/or [distilabel](https://github.com/argilla-io/distilabel/) for AI feedback. Argilla can even be [deployed for free on Hugging Face Spaces](https://argilla-io.github.io/argilla/latest/getting_started/huggingface-spaces/).",
        ]
    ),
)

# Render the chat UI inside a Blocks app that is given "style.css" as its css.
with gr.Blocks(css="style.css") as demo:
    chat_interface.render()

if __name__ == "__main__":
    # Configure the request queue (max_size=20) first, then launch the app.
    queued_demo = demo.queue(max_size=20)
    queued_demo.launch()