import os
import logging
import urllib.request
from typing import List

import gradio as gr
from openai import OpenAI

# Log request details to the console.
logging.basicConfig(level=logging.INFO)

BASE_URL = "https://api.together.xyz/v1"
DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")

# Download the URIAL prompt: a fixed block of in-context examples that steers
# a *base* (untuned) LLM to respond like an aligned chat model.
URIAL_VERSION = "inst_1k_v4.help"
urial_url = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
urial_prompt = urllib.request.urlopen(urial_url).read().decode("utf-8")
# The URIAL format delimits turns with triple quotes, so normalize any ``` fences.
urial_prompt = urial_prompt.replace("```", '"""')
# Cut generation off as soon as the model starts a new turn.
stop_str = ['"""', '# Query:', '# Answer:']

def urial_template(urial_prompt, history, message):
    """Render the URIAL examples, the chat history, and the new message as a
    single base-model prompt that ends at an open answer block."""
    current_prompt = urial_prompt + "\n"
    for user_msg, ai_msg in history:
        current_prompt += f'# Query:\n"""\n{user_msg}\n"""\n\n# Answer:\n"""\n{ai_msg}\n"""\n\n'
    current_prompt += f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n'
    return current_prompt
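
# For illustration, a rendered single-turn prompt looks roughly like this
# (the URIAL few-shot examples come first, then the live turn):
#
#   <...URIAL examples...>
#
#   # Query:
#   """
#   What is the capital of France?
#   """
#
#   # Answer:
#   """
#   <the base model continues here until it emits a stop string>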
    



def openai_base_request(
    model: str = None,
    temperature: float = 0,
    max_tokens: int = 512,
    top_p: float = 1.0,
    prompt: str = None,
    n: int = 1,
    repetition_penalty: float = 1.0,
    stop: List[str] = None,
    api_key: str = None,
):
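    """Stream a raw (non-chat) completion from Together's OpenAI-compatible endpoint."""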
    if api_key is None:
        api_key = DEFAULT_API_KEY
    client = OpenAI(api_key=api_key, base_url=BASE_URL)
    logging.info(f"Requesting completion from the Together API with model {model}")
    logging.info(f"Prompt: {prompt}")
    logging.info(f"Temperature: {temperature}")
    logging.info(f"Max tokens: {max_tokens}")
    logging.info(f"Top-p: {top_p}")
    logging.info(f"Repetition penalty: {repetition_penalty}")
    logging.info(f"Stop: {stop}")

    # `repetition_penalty` is not part of the standard OpenAI signature, so it
    # is passed through `extra_body` for Together's endpoint.
    request = client.completions.create(
        model=model,
        prompt=prompt,
        temperature=float(temperature),
        max_tokens=int(max_tokens),
        top_p=float(top_p),
        n=n,
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True,
    )

    return request




def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key
):  
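    """Gradio callback: build the URIAL prompt from the chat history and
    stream the base model's answer token by token."""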
    # Honor the repetition penalty set in the UI; it arrives as a string from
    # the Textbox and is cast to float inside openai_base_request.
    prompt = urial_template(urial_prompt, history, message)
    if model_name == "Llama-3-8B":
        _model_name = "meta-llama/Llama-3-8b-hf"
    elif model_name == "Llama-3-70B":
        _model_name = "meta-llama/Llama-3-70b-hf"
    else:
        raise ValueError(f"Invalid model name: {model_name}")

    # Together API keys are 64 characters long; anything else falls back to
    # the default key from the environment.
    if together_api_key and len(together_api_key) == 64:
        api_key = together_api_key
    else:
        api_key = DEFAULT_API_KEY

    request = openai_base_request(
        prompt=prompt,
        model=_model_name,
        temperature=temperature,
        max_tokens=max_tokens,
        top_p=top_p,
        repetition_penalty=rp,
        stop=stop_str,
        api_key=api_key,
    )
    
    response = ""
    for msg in request:
        # Together streams completion chunks with a chat-style `delta` payload.
        token = msg.choices[0].delta["content"]
        response += token
        # Stop streaming as soon as any stop marker appears in the output.
        if any(_stop in response for _stop in stop_str):
            break
        yield response
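
# Assemble the demo UI: model choice and API key on the left, sampling
# controls on the right, and a streaming chat interface below.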
 
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            gr.Label("Welcome to the URIAL Chatbot!")
            model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B"], value="Llama-3-8B", label="Base model name")
            together_api_key = gr.Textbox(label="Together API Key", placeholder="Enter your Together API key, or leave blank to use the default key.", type="password")
        with gr.Column():
            with gr.Row():
                max_tokens = gr.Textbox(value=1024, label="Max tokens")
                temperature = gr.Textbox(value=0.5, label="Temperature")
            with gr.Row():
                top_p = gr.Textbox(value=0.9, label="Top-p")
                rp = gr.Textbox(value=1.1, label="Repetition penalty")

    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
    )
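    # Enlarge the chat panel beyond Gradio's default height.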
    chat.chatbot.height = 600


if __name__ == "__main__":
    demo.launch()