File size: 3,882 Bytes
422e63e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17fd63b
 
422e63e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c7a7d18
66d2a7a
 
566a3b7
422e63e
 
c7a7d18
422e63e
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import os, transformers, peft, torch, gradio as gr

# Hugging Face Hub ids: the Llama-2 13B base checkpoint and the LoRA adapter
# applied on top of it (presumably trained on "Conversations With Tyler
# Swift" transcripts, judging by the repo name).
base_model = "h2oai/h2ogpt-4096-llama2-13b"
lora_model = "pcalhoun/Llama-2-13b-Conversations-With-Tyler-Swift"

# The tokenizer comes straight from the base checkpoint; the adapter does not
# ship its own.
tokenizer = transformers.AutoTokenizer.from_pretrained(base_model)

# Load the base weights quantized to 8-bit (torch_dtype=float16 for the parts
# that stay unquantized), then wrap them with the LoRA adapter. From here on
# `model` is the adapter-wrapped PeftModel used for generation.
model = peft.PeftModel.from_pretrained(
    transformers.AutoModelForCausalLM.from_pretrained(
        base_model,
        load_in_8bit=True,
        torch_dtype=torch.float16,
    ),
    lora_model,
    torch_dtype=torch.float16,
)

def generate(prompt, extra_eos=None, max_length=4096):
    """Run the LoRA-adapted model on *prompt* and return the decoded text.

    Args:
        prompt: Full prompt string, e.g. as built by ``create_next_prompt``.
        extra_eos: Currently unused; kept only so existing callers that pass
            it keep working. Defaults to ``None`` rather than a shared
            mutable ``[]``.
        max_length: Upper bound on the total sequence length (prompt tokens
            plus generated tokens) handed to ``model.generate``.

    Returns:
        The first decoded sequence. It includes the prompt itself, and special
        tokens such as ``<s>``/``</s>`` are left in place because
        ``batch_decode`` is called without ``skip_special_tokens``.
    """
    inputs = tokenizer(prompt, return_tensors="pt", return_attention_mask=False)
    # Tokenizer output is a CPU tensor, while the 8-bit model lives on an
    # accelerator. Move the inputs to wherever the model is; `.to()` is a
    # no-op when the devices already match.
    inputs = inputs.to(model.device)
    outputs = model.generate(**inputs, max_length=max_length)
    return tokenizer.batch_decode(outputs)[0]

def create_next_prompt(title_string, description_string="", conversation_messages=None):
    """Build the model prompt for the next generation step.

    The prompt follows the "CONVERSATIONS WITH TYLER SWIFT" transcript
    layout: a header line, a ``TITLE:`` line, a ``DESCRIPTION:`` section and
    one ``### SPEAKER: message`` line per conversation turn.

    Args:
        title_string: Podcast title; surrounding whitespace is stripped.
        description_string: Episode description. When empty or whitespace
            only, the prompt stops right after ``DESCRIPTION:`` so the model
            writes the description itself. In that case
            *conversation_messages* is ignored.
        conversation_messages: Sequence of ``{"speaker": str, "message": str}``
            dicts, oldest first. ``None`` (the default) means no turns yet.

    Returns:
        The prompt string. Turns spoken by Tyler Swift are closed with
        ``</s><s>``. Unless the last turn was his, the prompt ends with an
        open ``### TYLER SWIFT:`` cue for the model to complete.
    """
    parts = [
        "<s>### CONVERSATIONS WITH TYLER SWIFT ###\n\n",
        f"TITLE: {title_string.strip()}\n\n",
        "DESCRIPTION:",
    ]
    if not description_string or not description_string.strip():
        return "".join(parts)

    # deconstruct_returned_text treats the first "\n\n" after DESCRIPTION: as
    # the end of the description. Remove every blank line, not just one pass
    # of "\n\n" -> "\n", so that runs of 3+ newlines cannot survive.
    description = "\n".join(
        line for line in description_string.split("\n") if line
    ).strip()
    parts.append(f" {description}\n\n")

    messages = conversation_messages or []
    for message_data in messages:
        speaker = message_data["speaker"].upper()
        parts.append(f"### {speaker}: {message_data['message'].strip()}")
        if speaker == "TYLER SWIFT":
            parts.append("</s><s>")
        parts.append("\n")

    # Cue Tyler Swift to speak next unless he just did.
    if not messages or messages[-1]["speaker"].upper() != "TYLER SWIFT":
        parts.append("### TYLER SWIFT:")
    return "".join(parts)

def deconstruct_returned_text(text):
    """Split decoded model output back into ``(title, description, conversation)``.

    The expected layout is the one ``create_next_prompt`` emits:
    1. a header line, which is discarded;
    2. a ``TITLE:`` line;
    3. a ``DESCRIPTION:`` paragraph terminated by a blank line;
    4. the conversation.

    ``</s>`` and ``<s>`` markers are removed from the conversation, which is
    then cut at the first ``<<``. Any section that is missing comes back as
    an empty string. All three values are whitespace-stripped.
    """
    _, _, remainder = text.partition("\n")  # drop the header line
    title_line, _, remainder = remainder.strip().partition("\n")
    title = title_line.replace("TITLE:", "").strip()

    description_block, _, remainder = remainder.strip().partition("\n\n")
    description = description_block.replace("DESCRIPTION:", "").strip()

    untagged = remainder.strip().replace("</s>", "").replace("<s>", "")
    conversation_text, _, _ = untagged.partition("<<")
    return title, description, conversation_text.strip()

def _parse_conversation(conversation_text):
    """Turn the Conversation textbox contents into speaker/message dicts.

    Parsing rules:
    - Every line that starts with ``###`` begins a new turn, read as
      ``### SPEAKER: message``.
    - Any other non-blank line continues the previous turn's message and is
      appended to it with a newline, so multi-line messages are no longer
      silently dropped.
    - Lines before the first turn, blank lines and any line containing the
      transcript header are ignored.
    - ``<s>``/``</s>`` markers are removed from messages.
    """
    conversation = []
    for line in conversation_text.split("\n"):
        if "CONVERSATIONS WITH TYLER SWIFT" in line:
            continue
        if line.startswith("###"):
            header, _, message = line.partition(":")
            conversation.append({
                "speaker": header.replace("###", "").strip(),
                "message": message.replace("</s>", "").replace("<s>", "").strip(),
            })
        elif conversation:
            continuation = line.replace("</s>", "").replace("<s>", "").strip()
            if continuation:
                previous = conversation[-1]["message"]
                conversation[-1]["message"] = (
                    f"{previous}\n{continuation}" if previous else continuation
                )
    return conversation


def generate_next(title, description, conversation_text):
    """Gradio callback: extend the transcript with one model generation.

    Args:
        title: Contents of the "Podcast Title" textbox.
        description: Contents of the "Description" textbox.
        conversation_text: Contents of the "Conversation" textbox. Each turn
            starts on a line beginning with ``###``.

    Returns:
        ``(title, description, conversation_text)`` for the three textboxes.
        - If *title* is empty or whitespace, the inputs are returned unchanged
          except that the title becomes the "Set a Title First" prompt.
        - If the title is still that placeholder, a default title is used.
        - Otherwise the values are parsed back out of the model's output.
    """
    if not title or not title.strip():
        return "Set a Title First", description, conversation_text
    if title == "Set a Title First":
        title = "On Olympic Curling Injuries and Glaswegian Alcoholism"
    conversation = _parse_conversation(conversation_text)
    prompt = create_next_prompt(title, description, conversation)
    generated_text = generate(prompt)
    print("GENERATED TEXT:", generated_text)
    return deconstruct_returned_text(generated_text)


# Gradio UI: three textboxes (title, description, conversation) and one
# button. The button feeds the current textbox values to generate_next and
# writes its results back into the same three textboxes.
with gr.Blocks() as demo:
    gr.Markdown("""Come up with a podcast title and then click 'Generate Additional CWTS Podcast Transcript Lines'.

Conversation lines are separated by a linebreak followed by three hashes. This should be generated automatically. New lines can be added by clicking 'Generate...' again.""")
    title = gr.Textbox(label="Podcast Title")
    description = gr.Textbox(label="Description")
    conversation_text = gr.Textbox(label="Conversation")
    generate_button = gr.Button(value="Generate Additional CWTS Podcast Transcript Lines")
    generate_button.click(
        fn=generate_next,
        inputs=[title, description, conversation_text],
        outputs=[title, description, conversation_text],
    )

# Enable the queue before launching, for two reasons:
# - A single generation on a 13B model can easily exceed the timeout that
#   applies to un-queued Gradio requests.
# - The queue processes events one at a time by default, which suits a
#   single shared model.
demo.queue()
demo.launch()