#!/usr/bin/env python
import os
from typing import Iterator
import gradio as gr
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
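# from threading import Thread  # needed only if the commented-out streaming generation in generate() is re-enabled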
from chat_interface_preference import ChatInterface
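
# Generation limits; MAX_INPUT_TOKEN_LENGTH can be overridden via the environment variable of the same name.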
MAX_MAX_NEW_TOKENS = 2048
DEFAULT_MAX_NEW_TOKENS = 1024
MAX_INPUT_TOKEN_LENGTH = int(os.getenv("MAX_INPUT_TOKEN_LENGTH", "8192"))
if torch.cuda.is_available():
    # model_id = "meta-llama/Meta-Llama-3.1-8B-Instruct"
    # model = AutoModelForCausalLM.from_pretrained(model_id, torch_dtype=torch.float16, device_map="auto")
    # tokenizer = AutoTokenizer.from_pretrained(model_id)
    pass
style = None
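# The bare reference below presumably just keeps the otherwise-unused transformers imports
# around while the model-loading code above is commented out (assumption about intent).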
AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
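

# spaces.GPU requests a ZeroGPU device for the duration of each call when the Space
# runs on Hugging Face ZeroGPU hardware; it has no effect elsewhere.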
@spaces.GPU
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    max_new_tokens: int = 1024,
    temperature: float = 0.06,
    top_p: float = 0.95,
    top_k: int = 40,
    repetition_penalty: float = 1.2,
) -> Iterator[str]:
    # conversation = []
    # for user, assistant in chat_history:
    #     conversation.extend([{"role": "user", "content": user}, {"role": "assistant", "content": assistant}])
    # conversation.append({"role": "user", "content": message})

    # input_ids = tokenizer.apply_chat_template(conversation, add_generation_prompt=True, return_tensors="pt")
    # if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
    #     input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
    #     gr.Warning(f"Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
    # input_ids = input_ids.to(model.device)

    # streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    # generate_kwargs = dict(
    #     {"input_ids": input_ids},
    #     streamer=streamer,
    #     max_new_tokens=max_new_tokens,
    #     do_sample=True,
    #     top_p=top_p,
    #     top_k=top_k,
    #     temperature=temperature,
    #     num_beams=1,
    #     repetition_penalty=repetition_penalty,
    # )
    # t = Thread(target=model.generate, kwargs=generate_kwargs)
    # t.start()

    # outputs = []
    # for text in streamer:
    #     outputs.append(text)
    #     yield "".join(outputs)

    # Placeholder response while the model-backed generation above is commented out.
    for char in "help":
        yield "help"

chat_interface = ChatInterface(
    fn=generate,
    prefence_techniques="dpo",
    min_turns=1,
    max_turns=10,
    repo_id="llm-human-feedback-collector-chat-interface-dpo",
    chatbot=gr.Chatbot(height=450, label="Meta-Llama-3.1-8B-Instruct", show_share_button=True),
    cache_examples=False,
    additional_inputs=[
        gr.Slider(
            label="Max new tokens",
            minimum=1,
            maximum=MAX_MAX_NEW_TOKENS,
            step=1,
            value=DEFAULT_MAX_NEW_TOKENS,
        ),
        gr.Slider(
            label="Temperature",
            minimum=0.05,
            maximum=1.2,
            step=0.05,
            value=0.2,
        ),
        gr.Slider(
            label="Top-p (nucleus sampling)",
            minimum=0.05,
            maximum=1.0,
            step=0.05,
            value=0.9,
        ),
        gr.Slider(
            label="Top-k",
            minimum=1,
            maximum=1000,
            step=1,
            value=50,
        ),
        gr.Slider(
            label="Repetition penalty",
            minimum=1.0,
            maximum=2.0,
            step=0.05,
            value=1.2,
        ),
    ],
    examples=[
        ["""What word doesn't make sense in this row: "car, airplane, llama, bus"?"""],
        ["Write a news article about the usage of llamas by the CSI"],
        ["What are great things to cook when getting started with Asian cooking?"],
        ["Who was Anthony Bourdain?"],
    ],
    title="💪🏽🦾 LLM human-feedback collector ChatInterface (DPO) 🦾💪🏽",
    description="""This is an adaptation of gr.ChatInterface that allows for human feedback collection for SFT, DPO and KTO.""",
)
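
# Wrap the chat interface in a Blocks app so the custom style.css is applied when the demo launches.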
with gr.Blocks(css="style.css") as demo:
    chat_interface.render()

if __name__ == "__main__":
    demo.queue(max_size=20).launch()