|
import os, transformers, peft, torch, gradio as gr |
|
|
|
# Hub identifier of the base checkpoint; the tokenizer is loaded from the same id.
base_model = "h2oai/h2ogpt-4096-llama2-13b"

# Load the causal LM, requesting 8-bit weights and float16 as the torch dtype.
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model, load_in_8bit=True, torch_dtype=torch.float16
)
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)

# Hub identifier of the LoRA adapter applied on top of the base model.
lora_model = "pcalhoun/Llama-2-13b-Conversations-With-Tyler-Swift"

# Rebind `model` to the PEFT wrapper so later code generates through the adapter.
model = peft.PeftModel.from_pretrained(
    model, lora_model, torch_dtype=torch.float16
)
|
|
|
def generate(prompt, extra_eos=None):
    """Run the module-level model on ``prompt`` and return the decoded text.

    Args:
        prompt: Full prompt string to tokenize and continue.
        extra_eos: Currently unused; kept so existing callers that pass it
            keep working. Defaults to ``None`` rather than a shared mutable
            list.

    Returns:
        The first decoded sequence from ``tokenizer.batch_decode``. Decoding
        is done without ``skip_special_tokens``, so any special-token text
        the tokenizer emits is left in the returned string.
    """
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
    # Keep the input tensors on the same device as the model weights so
    # generation does not start from a device mismatch.
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_length=4096)
    return tokenizer.batch_decode(outputs)[0]
|
|
|
def create_next_prompt(title_string, description_string="", conversation_messages=None):
    """Build the prompt that asks the model for the next part of a transcript.

    The prompt always starts with the header line, a blank line and the
    ``TITLE:`` field, followed by a blank line and ``DESCRIPTION:``.

    * With no description, the prompt stops right after ``DESCRIPTION:`` and
      any conversation messages are ignored.
    * With a description but no messages, the prompt ends with a
      ``### TYLER SWIFT:`` cue.
    * With messages, each one is rendered as ``### SPEAKER: message``; Tyler
      Swift turns are followed by ``</s><s>``. The ``### TYLER SWIFT:`` cue is
      appended only when the last message is not already a Tyler Swift turn.

    Args:
        title_string: Podcast title; surrounding whitespace is stripped.
        description_string: Episode description. Blank lines inside it are
            removed, because a blank line is the separator between prompt
            fields.
        conversation_messages: Optional sequence of dicts with ``"speaker"``
            and ``"message"`` keys. Defaults to ``None`` (no messages).

    Returns:
        The assembled prompt string.
    """
    header = (
        "<s>### CONVERSATIONS WITH TYLER SWIFT ###\n\n"
        f"TITLE: {title_string.strip()}\n\n"
        "DESCRIPTION:"
    )
    if not description_string:
        return header

    # Drop every empty line, not just one level of doubled newlines: three or
    # more consecutive newlines would otherwise leave a blank line behind and
    # split the description field in two when the text is parsed back.
    description = "\n".join(
        line for line in description_string.split("\n") if line
    ).strip()
    parts = [header, " ", description, "\n\n"]

    messages = conversation_messages or []
    for entry in messages:
        speaker = entry["speaker"].upper()
        parts.append(f"### {speaker}: {entry['message'].strip()}")
        if speaker == "TYLER SWIFT":
            parts.append("</s><s>")
        parts.append("\n")

    # Cue the model to speak as Tyler Swift unless he already had the last turn.
    if not messages or messages[-1]["speaker"].upper() != "TYLER SWIFT":
        parts.append("### TYLER SWIFT:")
    return "".join(parts)
|
|
|
def _strip_label(field, label):
    """Return ``field`` without a leading ``label``, whitespace-trimmed."""
    if field.startswith(label):
        field = field[len(label):]
    return field.strip()


def deconstruct_returned_text(text):
    """Split generated text back into title, description and conversation.

    The first line of ``text`` (the header line) is discarded. The next
    non-blank line is taken as the title, the text up to the following blank
    line as the description, and everything after that as the conversation.

    Only a *leading* ``TITLE:`` / ``DESCRIPTION:`` label is removed, so a
    title or description that happens to contain those words keeps them.
    In the conversation, ``</s>`` and ``<s>`` markers are removed and anything
    from the first ``<<`` onward is dropped.

    Args:
        text: Decoded model output.

    Returns:
        A ``(title, description, conversation_text)`` tuple of strings.
    """
    # Discard the header line.
    _, _, remainder = text.partition("\n")
    remainder = remainder.strip()

    title_line, _, remainder = remainder.partition("\n")
    title = _strip_label(title_line, "TITLE:")
    remainder = remainder.strip()

    # The description runs up to the first blank line.
    description_block, _, remainder = remainder.partition("\n\n")
    description = _strip_label(description_block, "DESCRIPTION:")
    remainder = remainder.strip()

    cleaned = remainder.replace("</s>", "").replace("<s>", "")
    conversation_text = cleaned.split("<<")[0].strip()
    return title, description, conversation_text
|
|
|
def generate_next(title, description, conversation_text):
    """Gradio callback: extend the transcript by one generation step.

    If the title is missing (empty, whitespace-only or ``None``), nothing is
    generated and the title field is set to ``"Set a Title First"``. If the
    user clicks again with that placeholder still in place, a built-in example
    title is used instead.

    Otherwise the conversation textbox is parsed line by line: lines starting
    with ``###`` become ``{"speaker", "message"}`` entries (the text before
    the first ``:`` is the speaker), the header line is skipped, and all other
    lines are ignored. A prompt is built from these entries, sent to the
    model, and the generated text is split back into the three fields.

    Args:
        title: Current contents of the title textbox.
        description: Current contents of the description textbox.
        conversation_text: Current contents of the conversation textbox.

    Returns:
        A ``(title, description, conversation_text)`` tuple for the three
        textboxes.
    """
    # A blank title would produce an empty "TITLE:" field in the prompt.
    if not title or not title.strip():
        return "Set a Title First", description, conversation_text
    if title == "Set a Title First":
        title = "On Olympic Curling Injuries and Glaswegian Alcoholism"

    conversation = []
    for line in conversation_text.split("\n"):
        if "CONVERSATIONS WITH TYLER SWIFT" in line:
            continue
        if not line.startswith("###"):
            continue
        # Split on the first colon only, so colons inside the message survive.
        label, _, body = line.partition(":")
        conversation.append({
            "speaker": label.replace("###", "").strip(),
            "message": body.replace("</s>", "").replace("<s>", "").strip(),
        })

    prompt = create_next_prompt(title, description, conversation)
    generated_text = generate(prompt)
    print("GENERATED TEXT:", generated_text)
    return deconstruct_returned_text(generated_text)
|
|
|
|
|
# Build the UI: three textboxes that are both the inputs and the outputs of
# generate_next, so each click overwrites them with the latest generation.
with gr.Blocks() as demo:
    gr.Markdown(
        "Come up with a podcast title and then click 'Generate Additional CWTS Podcast Transcript Lines'.\n"
        "\n"
        "Conversation lines are separated by a linebreak followed by three hashes. This should be generated automatically. New lines can be added by clicking 'Generate...' again."
    )
    title = gr.Textbox(label="Podcast Title")
    description = gr.Textbox(label="Description")
    conversation_text = gr.Textbox(label="Conversation")
    generate_button = gr.Button(
        value="Generate Additional CWTS Podcast Transcript Lines"
    )
    generate_button.click(
        fn=generate_next,
        inputs=[title, description, conversation_text],
        outputs=[title, description, conversation_text],
    )

# Start the Gradio app.
demo.launch()
|
|
|
|
|
|