# Hugging Face Space source: app.py by pcalhoun
# Commit 66d2a7a ("Update app.py"), 3.88 kB
import os, transformers, peft, torch, gradio as gr
# Hub repository ids of the base checkpoint and of the LoRA adapter that is
# applied on top of it.
base_model = "h2oai/h2ogpt-4096-llama2-13b"
lora_model = "pcalhoun/Llama-2-13b-Conversations-With-Tyler-Swift"

# Load the base weights quantized to 8 bits. Newer transformers releases
# deprecate the bare `load_in_8bit=True` keyword in favour of passing a
# BitsAndBytesConfig through `quantization_config`; the effect is the same.
model = transformers.AutoModelForCausalLM.from_pretrained(
    base_model,
    quantization_config=transformers.BitsAndBytesConfig(load_in_8bit=True),
    torch_dtype=torch.float16,
)
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)

# Wrap the base model with the LoRA adapter; from here on `model` is the
# PeftModel that generate() calls.
model = peft.PeftModel.from_pretrained(
    model,
    lora_model,
    torch_dtype=torch.float16,
)
def generate(prompt, extra_eos=None, max_length=4096):
    """Run the LoRA-adapted model on *prompt* and return the decoded text.

    Args:
        prompt: Complete prompt string (as built by create_next_prompt).
        extra_eos: Currently unused; accepted so existing callers keep working.
            Defaults to None rather than a shared mutable list.
        max_length: Cap on the total token count (prompt + new tokens), passed
            straight to ``model.generate``. The default 4096 presumably matches
            the context size suggested by the base model's name -- confirm.

    Returns:
        The first decoded sequence. It includes the prompt text and any special
        tokens such as "<s>"/"</s>", because batch_decode is called without
        ``skip_special_tokens``.
    """
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
    # Put the prompt tensors on the same device as the model weights instead of
    # relying on generate() / accelerate hooks to move them implicitly.
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.batch_decode(outputs)[0]
def create_next_prompt(title_string, description_string="", conversation_messages=None):
    """Build the prompt for the next generation step.

    Layout (the same one deconstruct_returned_text parses back)::

        <s>### CONVERSATIONS WITH TYLER SWIFT ###
        TITLE: <title>
        DESCRIPTION: <description>
        <blank line>
        ### SPEAKER: message        (Tyler Swift lines end with "</s><s>")
        ...
        ### TYLER SWIFT:            (only when Tyler Swift did not speak last)

    Args:
        title_string: Podcast title; surrounding whitespace is stripped.
        description_string: Episode description. When empty (or None), the
            prompt ends right after "DESCRIPTION:" so the model continues from
            there, and *conversation_messages* is ignored.
        conversation_messages: Sequence of {"speaker": str, "message": str}
            dicts, oldest first. None (the default) means no messages.

    Returns:
        The prompt string.
    """
    host = "TYLER SWIFT"
    prompt = (
        "<s>### CONVERSATIONS WITH TYLER SWIFT ###\n"
        f"TITLE: {title_string.strip()}\n"
        "DESCRIPTION:"
    )
    if not description_string:
        return prompt

    # The description must not contain a blank line: the parser treats the
    # first "\n\n" as the end of the description. A single replace() pass
    # leaves "\n\n" behind for runs of 3+ newlines, so collapse until none remain.
    description = description_string
    while "\n\n" in description:
        description = description.replace("\n\n", "\n")
    prompt += " " + description.strip() + "\n\n"

    messages = conversation_messages or []
    if not messages:
        # Nothing said yet: cue Tyler Swift to open the conversation.
        return prompt + f"### {host}:"

    for message_data in messages:
        speaker = message_data["speaker"].upper()
        prompt += f"### {speaker}: {message_data['message'].strip()}"
        if speaker == host:
            # Tyler Swift turns are closed with end/start-of-sequence markers.
            prompt += "</s><s>"
        prompt += "\n"

    if messages[-1]["speaker"].upper() != host:
        # Someone else spoke last: cue Tyler Swift's reply.
        prompt += f"### {host}:"
    return prompt
def deconstruct_returned_text(text):
    """Split decoded model output into (title, description, conversation_text).

    *text* is expected to follow the create_next_prompt layout: a banner line,
    a "TITLE:" line, a "DESCRIPTION:" paragraph that ends at the first blank
    line, then the "### SPEAKER: ..." lines.

    The "<s>"/"</s>" markers left in the decoded output are removed from all
    three fields (previously only from the conversation, so they could leak
    into the title/description). The conversation is also cut at the first
    "<<" -- NOTE(review): presumably a marker the model emits after a
    transcript ends; confirm.

    Returns:
        (title, description, conversation_text), each stripped of surrounding
        whitespace; sections that are missing come back as "".
    """
    def remove_markers(s):
        return s.replace("</s>", "").replace("<s>", "").strip()

    # Skip the banner line ("<s>...### CONVERSATIONS WITH TYLER SWIFT ###").
    _, _, rest = text.partition("\n")
    title_line, _, rest = rest.strip().partition("\n")
    # The description runs up to the first blank line.
    description_block, _, rest = rest.strip().partition("\n\n")

    title = remove_markers(title_line.replace("TITLE:", ""))
    description = remove_markers(description_block.replace("DESCRIPTION:", ""))
    conversation_text = remove_markers(rest).split("<<")[0].strip()
    return title, description, conversation_text
def _parse_conversation(conversation_text):
    """Turn the Conversation textbox contents into create_next_prompt messages."""
    messages = []
    for raw_line in (conversation_text or "").split("\n"):
        # Tolerate indentation / stray "\r" around user-edited lines.
        line = raw_line.strip()
        if "CONVERSATIONS WITH TYLER SWIFT" in line or not line.startswith("###"):
            continue
        # Speaker is everything before the first ":", message everything after.
        speaker, _, message = line.partition(":")
        messages.append({
            "speaker": speaker.replace("###", "").strip(),
            "message": message.strip().replace("</s>", "").replace("<s>", "").strip(),
        })
    return messages


def generate_next(title, description, conversation_text):
    """Gradio click handler: extend the transcript by one model generation.

    Args:
        title: Contents of the "Podcast Title" textbox.
        description: Contents of the "Description" textbox.
        conversation_text: Contents of the "Conversation" textbox; lines that
            start with "###" are parsed as "### SPEAKER: message".

    Returns:
        New (title, description, conversation_text) values for the same three
        textboxes. A blank (empty or whitespace-only) title short-circuits with
        the placeholder "Set a Title First" and leaves the other fields as-is.
        If that placeholder is submitted unchanged, a fixed sample title is used.
    """
    if not title or not title.strip():
        return "Set a Title First", description, conversation_text
    if title == "Set a Title First":
        title = "On Olympic Curling Injuries and Glaswegian Alcoholism"
    conversation = _parse_conversation(conversation_text)
    prompt = create_next_prompt(title, description, conversation)
    generated_text = generate(prompt)
    # Raw model output goes to stdout for debugging.
    print("GENERATED TEXT:", generated_text)
    return deconstruct_returned_text(generated_text)
# Gradio UI: three editable textboxes act as both the inputs and the outputs of
# generate_next, so each click feeds the current transcript back to the model.
with gr.Blocks() as demo:
    gr.Markdown(
        "Come up with a podcast title and then click "
        "'Generate Additional CWTS Podcast Transcript Lines'.\n"
        "Conversation lines are separated by a linebreak followed by three "
        "hashes. This should be generated automatically. New lines can be "
        "added by clicking 'Generate...' again."
    )
    title = gr.Textbox(label="Podcast Title")
    description = gr.Textbox(label="Description")
    conversation_text = gr.Textbox(label="Conversation")
    transcript_fields = [title, description, conversation_text]
    generate_button = gr.Button(
        value="Generate Additional CWTS Podcast Transcript Lines"
    )
    generate_button.click(
        fn=generate_next,
        inputs=transcript_fields,
        outputs=transcript_fields,
    )

demo.launch()