File size: 6,948 Bytes
0231e6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
460e2a9
 
 
 
 
 
 
 
 
0231e6a
 
 
 
 
 
 
 
 
 
 
9e82682
 
 
 
 
 
 
 
 
 
 
 
0231e6a
 
 
460e2a9
0231e6a
 
 
 
460e2a9
0231e6a
 
 
 
 
 
 
460e2a9
0231e6a
 
 
 
 
9e82682
0231e6a
 
 
 
9e82682
 
 
 
 
 
0231e6a
9e82682
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0231e6a
 
9e82682
 
 
 
 
 
0231e6a
9e82682
0231e6a
 
 
 
9e82682
 
0231e6a
 
9e82682
0231e6a
 
 
 
 
 
 
f358cdd
460e2a9
9e82682
 
460e2a9
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
import gradio as gr 
from openai import OpenAI
import os
from typing import List
import logging

# Send INFO-level log records (including the request parameters logged per call) to the console.
logging.basicConfig(level=logging.INFO)

import urllib.request

# Together's OpenAI-compatible endpoint. A per-request key from the UI may override the default.
BASE_URL = "https://api.together.xyz/v1"
# None when the TOGETHER_API_KEY environment variable is unset.
DEFAULT_API_KEY = os.getenv("TOGETHER_API_KEY")

# Which URIAL prompt file to fetch from the Re-Align/URIAL repository at import time.
URIAL_VERSION = "inst_1k_v4.help"
# Seconds to wait on the network before failing startup, instead of hanging indefinitely.
URIAL_FETCH_TIMEOUT = 30

urial_url = f"https://raw.githubusercontent.com/Re-Align/URIAL/main/urial_prompts/{URIAL_VERSION}.txt"
# Use the response as a context manager so the HTTP connection is closed deterministically.
with urllib.request.urlopen(urial_url, timeout=URIAL_FETCH_TIMEOUT) as _urial_response:
    urial_prompt = _urial_response.read().decode('utf-8')
# The fetched file fences examples with ``` while the chat template and stop strings use """.
urial_prompt = urial_prompt.replace("```", '"""')
# Strings that end a generated answer: the closing triple quote or the start of a new turn.
stop_str = ['"""', '# Query:', '# Answer:']

def urial_template(urial_prompt, history, message):
    """Render the URIAL few-shot prompt for the next model turn.

    Args:
        urial_prompt: The in-context examples text placed at the top of the prompt.
        history: Prior turns as (user message, assistant message) pairs, rendered
            in order as completed Query/Answer sections.
        message: The new user message, rendered as the final Query.

    Returns:
        The full prompt string, ending with an opened Answer section so the model
        continues with its reply.
    """
    pieces = [urial_prompt, "\n"]
    # Each completed turn is fenced with triple quotes, mirroring the examples.
    pieces.extend(
        f'# Query:\n"""\n{query}\n"""\n\n# Answer:\n"""\n{answer}\n"""\n\n'
        for query, answer in history
    )
    # Leave the final Answer open; generation is stopped at the closing triple quote.
    pieces.append(f'# Query:\n"""\n{message}\n"""\n\n# Answer:\n"""\n')
    return "".join(pieces)
    



def openai_base_request(
    model: str = None,
    temperature: float = 0,
    max_tokens: int = 512,
    top_p: float = 1.0,
    prompt: str = None,
    n: int = 1,
    repetition_penalty: float = 1.0,
    stop: List[str] = None,
    api_key: str = None,
    ):
    """Open a streaming text-completion request against ``BASE_URL``.

    Numeric settings may arrive as strings (e.g. from text inputs); they are
    converted with ``float``/``int`` before being sent.

    Args:
        model: Model identifier passed through to the completions endpoint.
        temperature, max_tokens, top_p: Sampling settings.
        prompt: The full prompt text.
        n: Number of completions requested.
        repetition_penalty: Sent via ``extra_body`` as a non-standard field.
        stop: Stop strings passed to the endpoint.
        api_key: Key to use; ``DEFAULT_API_KEY`` when None.

    Returns:
        The streaming result of ``client.completions.create(..., stream=True)``.
    """
    key = DEFAULT_API_KEY if api_key is None else api_key
    client = OpenAI(api_key=key, base_url=BASE_URL)

    logging.info(f"Requesting chat completion from OpenAI API with model {model}")
    # One INFO line per request setting, in a fixed order.
    for label, value in (
        ("Prompt", prompt),
        ("Temperature", temperature),
        ("Max tokens", max_tokens),
        ("Top-p", top_p),
        ("Repetition penalty", repetition_penalty),
        ("Stop", stop),
    ):
        logging.info(f"{label}: {value}")

    sampling = {
        "temperature": float(temperature),
        "max_tokens": int(max_tokens),
        "top_p": float(top_p),
    }
    return client.completions.create(
        model=model,
        prompt=prompt,
        n=n,
        # repetition_penalty is not a standard OpenAI parameter, so it goes in the raw body.
        extra_body={'repetition_penalty': float(repetition_penalty)},
        stop=stop,
        stream=True,
        **sampling,
    )




# UI model labels (the "Base LLM name" radio values, plus Qwen1.5-72B) mapped to
# the model identifiers sent to the completions endpoint.
_MODEL_IDS = {
    "Llama-3-8B": "meta-llama/Llama-3-8b-hf",
    "Llama-3-70B": "meta-llama/Llama-3-70b-hf",
    "Llama-2-7B": "meta-llama/Llama-2-7b-hf",
    "Llama-2-70B": "meta-llama/Llama-2-70b-hf",
    "Mistral-7B-v0.1": "mistralai/Mistral-7B-v0.1",
    "mistralai/Mixtral-8x22B": "mistralai/Mixtral-8x22B",
    "Qwen1.5-72B": "Qwen/Qwen1.5-72B",
    "Yi-34B": "zero-one-ai/Yi-34B",
}


def _earliest_stop(text, stops):
    """Return the index of the first occurrence of any string in *stops* within *text*, or -1."""
    positions = [pos for pos in (text.find(s) for s in stops) if pos != -1]
    return min(positions, default=-1)


def _hide_partial_stop(text):
    """Drop a trailing newline plus one or two double quotes (a possible partial triple-quote stop)."""
    if text.endswith('\n""'):
        return text[:-2]
    if text.endswith('\n"'):
        return text[:-1]
    return text


def respond(
    message,
    history: list[tuple[str, str]],
    max_tokens,
    temperature,
    top_p,
    rp,
    model_name,
    together_api_key
):
    """Stream a URIAL-prompted completion from a base LLM for the chat UI.

    Args:
        message: The new user message.
        history: Prior (user, assistant) turns, rendered into the prompt in order.
        max_tokens, temperature, top_p: Sampling settings forwarded to
            ``openai_base_request``, which converts them with ``int``/``float``.
        rp: Repetition penalty from the UI. NOTE(review): overridden with 1.0
            below, so the UI value has no effect; confirm this is intended.
        model_name: UI label of the base model; must be a key of ``_MODEL_IDS``.
        together_api_key: Optional user key; used only when exactly 64
            characters long, otherwise ``DEFAULT_API_KEY`` is used.

    Yields:
        The accumulated response after each streamed token. Text from the first
        stop string onward is excluded, and a trailing partial closing quote is hidden.

    Raises:
        ValueError: If ``model_name`` is not a known label.
    """
    rp = 1.0  # NOTE(review): discards the UI value; kept to preserve existing behavior.
    prompt = urial_template(urial_prompt, history, message)
    _model_name = _MODEL_IDS.get(model_name)
    if _model_name is None:
        raise ValueError("Invalid model name")

    if together_api_key and len(together_api_key) == 64:
        api_key = together_api_key
    else:
        api_key = DEFAULT_API_KEY

    request = openai_base_request(prompt=prompt, model=_model_name,
                                  temperature=temperature,
                                  max_tokens=max_tokens,
                                  top_p=top_p,
                                  repetition_penalty=rp,
                                  stop=stop_str, api_key=api_key)

    # `response` always holds the raw accumulated text. Only the yielded copy hides a
    # possibly-partial closing quote, so a genuine quote after a newline is never lost
    # and a triple quote split across tokens is still detected.
    response = ""
    for chunk in request:
        # NOTE(review): relies on each streamed choice carrying a dict-like `delta`
        # with a "content" entry, as the original code did; None is treated as "".
        token = chunk.choices[0].delta["content"] or ""
        combined = response + token
        cut = _earliest_stop(combined, stop_str)
        if cut != -1:
            # Keep whatever the model produced before the stop marker, then stop.
            yield _hide_partial_stop(combined[:cut])
            break
        response = combined
        yield _hide_partial_stop(response)
 
# Client-side script passed as `js=` to gr.Blocks below. It looks up the element with
# id 'api_key' (the API-key Textbox is created with elem_id="api_key") and appends a
# line break, a link to the Together API-keys page, and a note about free credits.
# Presumably Gradio runs this function when the page loads; confirm for the Gradio
# version in use.
# NOTE(review): despite the in-script comment, the link and text are appended to the
# 'api_key' element itself, not to the 'block-info' span. The span is only used as an
# existence check.
# NOTE(review): if no element with id 'api_key' exists, `apiKeyDiv.querySelector`
# throws before the console.error fallback is reached; consider a null check.
js_code_label = """
function addApiKeyLink() {
    // Select the div with id 'api_key'
    const apiKeyDiv = document.getElementById('api_key');

    // Find the span within that div with data-testid 'block-info'
    const blockInfoSpan = apiKeyDiv.querySelector('span[data-testid="block-info"]');

    // Create the new link element
    const newLink = document.createElement('a');
    newLink.href = 'https://api.together.ai/settings/api-keys';
    newLink.textContent = ' View your keys here.';
    newLink.target = '_blank'; // Open link in new tab
    newLink.style = 'color: #007bff; text-decoration: underline;';

    // Create the additional text
    const additionalText = document.createTextNode(' (new account will have free credits to use.)');

    // Append the link and additional text to the span
    if (blockInfoSpan) {
        // add a br 
        apiKeyDiv.appendChild(document.createElement('br'));
        apiKeyDiv.appendChild(newLink);
        apiKeyDiv.appendChild(additionalText);
    } else {
        console.error('Span with data-testid "block-info" not found');
    }
}
"""
# Build the UI: a header row (description + model picker, then API key + sampling
# settings) above a chat interface driven by `respond`. Component placement follows
# the nesting and order of the `with` blocks below.
with gr.Blocks(gr.themes.Soft(), js=js_code_label) as demo:
    with gr.Row():
        with gr.Column():
            # Header text and links.
            gr.Markdown("""# 💬 BaseChat: Chat with Base LLMs with URIAL
                        [Paper](https://arxiv.org/abs/2312.01552) | [Website](https://allenai.github.io/re-align/) | [GitHub](https://github.com/Re-Align/urial) | Contact: [Yuchen Lin](https://yuchenlin.xyz/)

                        **Talk with __BASE__ LLMs which are not fine-tuned at all.**
                        """)
            # Labels must match the ones `respond` knows; respond also accepts
            # "Qwen1.5-72B", which is not offered here.
            model_name = gr.Radio(["Llama-3-8B", "Llama-3-70B", "Mistral-7B-v0.1", "mistralai/Mixtral-8x22B", "Yi-34B", "Llama-2-7B", "Llama-2-70B"], value="Llama-3-8B", label="Base LLM name")
        with gr.Column():
            # elem_id="api_key" is the hook js_code_label uses to add the "View your keys" link.
            # respond only uses this value when it is exactly 64 characters long.
            together_api_key = gr.Textbox(label="🔑 Together APIKey", placeholder="Enter your Together API Key. Leave it blank if you want to use the default API key.", type="password", elem_id="api_key")
            with gr.Column():
                with gr.Row():
                    # Sampling settings are free-text boxes; openai_base_request
                    # converts them with int()/float() at request time.
                    max_tokens = gr.Textbox(value=1024, label="Max tokens")
                    temperature = gr.Textbox(value=0.5, label="Temperature")
            # with gr.Column():
            #     with gr.Row():
                    top_p = gr.Textbox(value=0.9, label="Top-p")
                    # NOTE(review): respond currently overrides this with 1.0, so the
                    # value entered here is not sent.
                    rp = gr.Textbox(value=1.1, label="Repetition penalty")
            

    # additional_inputs are passed to respond positionally after (message, history),
    # so this order must match respond's parameter order.
    chat = gr.ChatInterface(
        respond,
        additional_inputs=[max_tokens, temperature, top_p, rp, model_name, together_api_key],
        # additional_inputs_accordion="⚙️ Parameters",
        # fill_height=True, 
    )
    # NOTE(review): set after construction; assumes Gradio reads this attribute when
    # building the page config. Confirm it takes effect.
    chat.chatbot.height = 550

    


# Start the Gradio server when the file is run as a script.
if __name__ == "__main__":
    demo.launch()