File size: 5,395 Bytes
8824f88 46218ed a38b23a 8824f88 a38b23a 7e8ce88 8824f88 0adfa8b ba0ce5d 8824f88 c4bb11b 8824f88 05c0b89 ba0ce5d 8824f88 c4bb11b 8824f88 46218ed 05c0b89 8824f88 05c0b89 8824f88 05c0b89 8824f88 05c0b89 8824f88 ba0ce5d 8824f88 7e8ce88 17aeee6 d355157 e466982 c4bb11b 8824f88 c4bb11b 7c45d36 8824f88 b1038ed 8824f88 2c31d1f a38b23a 2c31d1f 8824f88 ba0ce5d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 |
#!/usr/bin/env python
import os
import random
from threading import Thread # noqa
from typing import Iterator
import gradio as gr
import spaces
import torch # noqa
from transformers import AutoModelForCausalLM # noqa
from transformers import AutoTokenizer # noqa
from transformers import TextIteratorStreamer # noqa
from chat_interface_preference import ChatInterface
# Generation-length limits: the largest value the "Max new tokens" slider
# offers, and the value that slider starts at.
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024

# Prompt-length cap in tokens; can be overridden through the environment
# variable of the same name.
MAX_INPUT_TOKEN_LENGTH = int(os.environ.get("MAX_INPUT_TOKEN_LENGTH", "8192"))

# The checkpoint and its tokenizer are loaded only when a CUDA device is
# visible.
# NOTE(review): without CUDA, `model` and `tokenizer` are never bound, so
# `generate` would fail with a NameError on first use — confirm that a
# CPU-only run is not expected to work.
if torch.cuda.is_available():
    model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.float16,
        device_map="auto",
    )
    tokenizer = AutoTokenizer.from_pretrained(model_id)
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.06,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    """Stream the model's reply to ``message`` given the earlier chat turns.

    A style word is drawn at random and placed in the system prompt. The
    history and the new message are then rendered with the tokenizer's chat
    template, and ``model.generate`` runs on a background thread while this
    generator relays the text it produces.

    Args:
        message: The newest user message.
        chat_history: Earlier turns as ``(user, assistant)`` pairs.
        max_new_tokens: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        top_k: Forwarded to ``model.generate``.
        repetition_penalty: Forwarded to ``model.generate``.

    Yields:
        The reply accumulated so far, re-emitted after every chunk received
        from the streamer.

    NOTE(review): the signature defaults differ from the slider defaults
    declared elsewhere in this file (e.g. temperature 0.06 vs. 0.7) — confirm
    which values are intended.
    """
    style = random.choice(["concise", "explicit", "simple", "complex", "usefull", "helpfull"])

    # Build the conversation: system prompt, alternating past turns, new message.
    conversation = [{"role": "system", "content": f"Communicate {style}."}]
    for user_turn, assistant_turn in chat_history:
        conversation.append({"role": "user", "content": user_turn})
        conversation.append({"role": "assistant", "content": assistant_turn})
    conversation.append({"role": "user", "content": message})

    prompt_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    if prompt_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Over-long prompts keep only their trailing MAX_INPUT_TOKEN_LENGTH tokens.
        prompt_ids = prompt_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    prompt_ids = prompt_ids.to(model.device)

    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generation_args = {
        "input_ids": prompt_ids,
        "streamer": streamer,
        "max_new_tokens": max_new_tokens,
        "do_sample": True,
        "top_p": top_p,
        "top_k": top_k,
        "temperature": temperature,
        "num_beams": 1,
        "repetition_penalty": repetition_penalty,
    }

    # Generation runs on its own thread so the streamer can be drained here.
    worker = Thread(target=model.generate, kwargs=generation_args)
    worker.start()

    reply = ""
    for chunk in streamer:
        reply += chunk
        yield reply
# Preference-collecting chat UI built around `generate`.
# The five sliders below are listed in the same order as `generate`'s
# parameters after (message, chat_history); presumably they are passed to it
# positionally — verify against ChatInterface before reordering.
chat_interface = ChatInterface(
    fn=generate,
    # NOTE(review): this keyword's spelling has to match the signature of
    # ChatInterface in chat_interface_preference — confirm there before renaming.
    prefence_technique="dpo",
    min_turns=1,
    max_turns=10,
    repo_id="llm-human-feedback-collector-chat-interface-dpo",
    chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
    cache_examples=False,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.2,
            step=0.05,
            value=0.7,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    # Sample prompts offered in the UI.
    examples=[
        ["""What word doesn't make sense in this row: "car, airplane, lama, bus"?"""],
        ["Write a news article about the usage of Lama's by the CSI"],
        ["What are great things to cook when getting started with Asian cooking?"],
        ["Who was Anthony Bourdain?"],
    ],
    title="💪🏽🦾 Human Feedback Collector | Meta-Llama-3.1-8B-Instruct | (DPO) 🦾💪🏽",
    # Markdown blurb; the fragments are concatenated without a separator, so
    # each one carries its own trailing space.
    description="".join(
        [
            "This is an adaptation of the [`gr.ChatInterface`](https://www.gradio.app/docs/gradio/chatinterface) which also uses the [`huggingface_hub.CommitScheduler`](https://huggingface.co/docs/huggingface_hub/main/en/package_reference/hf_api#huggingface_hub.CommitScheduler) to allow for human feedback collection. ",
            "Another cool tool for capturing Gradio interactions is the [`gr.HuggingFaceDatasetSaver`](https://www.gradio.app/guides/using-flagging#the-hugging-face-dataset-saver-callback). ",
            "This demo shows how you might capture human feedback directly from applications within Gradio. ",
            "The captured feedback can directly be used for fine-tuning LLMs within frameworks like [transformers](https://github.com/huggingface/transformers), [TRL](https://github.com/huggingface/trl) or [AutoTrain](https://huggingface.co/autotrain), ",
            "however, it might benefit from additional data curation with something like [Argilla](https://github.com/argilla-io/argilla/) for human feedback and/or [distilabel](https://github.com/argilla-io/distilabel/) for AI feedback. Argilla can even be [deployed for free on Hugging Face Spaces](https://argilla-io.github.io/argilla/latest/getting_started/huggingface-spaces/).",
        ]
    ),
)
# Page container for the chat interface; "style.css" is handed to Blocks as
# its css argument.
demo = gr.Blocks(css="style.css")
with demo:
    chat_interface.render()

if __name__ == "__main__":
    # Enable the request queue with max_size=20, then start the app.
    demo.queue(max_size=20)
    demo.launch()
|