#!/usr/bin/env python
import os
from typing import Iterator

import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

from chat_interface_preference import ChatInterface
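# chat_interface_preference is a local module: a ChatInterface variant that records
# human preference feedback (SFT/DPO/KTO) on top of the standard Gradio chat UI.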

MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))

if torch.cuda.is_available():
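    # Model loading is intentionally stubbed out in this demo; uncomment the
    # lines below to load the real model when a GPU is available.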
    # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    # tokenizer = AutoTokenizer.from_pretrained(model_id)
    pass


@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.06,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
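    """Stream a generated reply for `message`, given `chat_history` and sampling parameters."""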
    # conversation = []
    # for user, assistant in chat_history:
    #     conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    # conversation.append({"role": "user", "content": message})

    # input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    #     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    # input_ids = input_ids.to(model.device)

    # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # generate_kwargs = dict(
    #     {"input_ids": input_ids},
    #     streamer=streamer,
    #     max_new_tokens=max_new_tokens,
    #     do_sample=True,
    #     top_p=top_p,
    #     top_k=top_k,
    #     temperature=temperature,
    #     num_beams=1,
    #     repetition_penalty=repetition_penalty,
    # )
    # t = Thread(target=model.generate, kwargs=generate_kwargs)
    # t.start()

    # outputs = []
    # for text in streamer:
    #     outputs.append(text)
    #     yield "".join(outputs)

    for char in "help":
        yield "help"


chat_interface = ChatInterface(
    fn=generate,
    prefence_techniques="dpo",
    min_turns=1,
    max_turns=10,
    repo_id="llm-human-feedback-collector-chat-interface-dpo",
    chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
    cache_examples=False,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.2,
            step=0.05,
            value=0.2,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["""What word doesn't make sense in this row: "car, airplane, lama, bus"?"""],
        ["Write a news article about the usage of Lama's by the CSI"],
        ["What are great things cook when getting started with Asian cooking?"],
        ["Who was Anthony Bourdain?"],
    ],
    title="💪🏽🦾 LLM human-feedback collector ChatInterface (DPO) 🦾💪🏽",
    description="""This is an adaptation of the gr.ChatInferface which allows for human feedback collection for SFT, DPO and KTO.""",
)

with gr.Blocks(css="style.css") as demo:
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()