File size: 5,321 Bytes
df56e64
 
7e723b7
41e22e7
b908c2d
41e22e7
e4e9f4c
df56e64
e4e9f4c
bf98046
 
 
 
 
 
 
 
 
aa09256
bf98046
 
 
 
 
 
 
471f9fb
 
 
 
 
df56e64
7e723b7
471f9fb
 
 
e4e9f4c
 
471f9fb
41e22e7
471f9fb
 
 
 
 
41e22e7
471f9fb
 
4a289c1
 
 
 
becac5e
 
 
4a289c1
11ea3b5
d16661c
 
becac5e
d16661c
becac5e
 
6f3513c
41e22e7
598fec3
41e22e7
aa09256
f13aa60
aa09256
 
 
 
f13aa60
aa09256
 
 
c76bc9c
 
aa09256
410b6d9
4c26d43
e9a356d
 
 
4a289c1
 
 
c76bc9c
aa09256
41e22e7
cdd8f1e
c0eb133
471f9fb
 
41e22e7
aa09256
 
c0eb133
 
 
41e22e7
d8ee2dd
41e22e7
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
from chromadb.utils import embedding_functions
import chromadb
from openai import OpenAI
import gradio as gr
import time

# Base URL of Anyscale's OpenAI-compatible endpoint; handed to the OpenAI
# client as ``base_url`` when predictions are requested.
anyscale_base_url = "https://api.endpoints.anyscale.com/v1"
# Embedding function backed by a German political-text sentence model.
# NOTE(review): this is constructed at import time (presumably loading the
# model), yet no active code path in predict() uses it -- consider loading it
# lazily once manifesto retrieval is wired up.
multilingual_embeddings = embedding_functions.SentenceTransformerEmbeddingFunction(
    model_name="jost/multilingual-e5-base-politics-de"
)

# Steering targets offered for each prompt-manipulation mode. Keys mirror the
# "Prompt Manipulation" dropdown choices; the direct-steering mode lists German
# parties, the two RAG modes list left/right x authoritarian/libertarian
# ideology quadrants.
options = {
    "None": [],
    "Impersonation (direct steering)": [
        "Die Linke",
        "Bündnis 90/Die Grünen",
        "AfD",
        "CDU/CSU",
    ],
    "Most similar RAG (indirect steering with related context)": [
        "Authoritarian-left",
        "Libertarian-left",
        "Authoritarian-right",
        "Libertarian-right",
    ],
    "Random RAG (indirect steering with randomized context)": [
        "Authoritarian-left",
        "Libertarian-left",
        "Authoritarian-right",
        "Libertarian-right",
    ],
}

# Instruction prepended to the statement in impersonation prompts. English:
# "Answer the following statement with 'Strong disagreement', 'Disagreement',
# 'Agreement' or 'Strong agreement':" (the ``pct`` prefix presumably refers to
# the Political Compass Test -- confirm).
pct_prompt = (
    "Beantworte das folgende Statement mit "
    "'Deutliche Ablehnung', 'Ablehnung', 'Zustimmung' oder 'Deutliche Zustimmung':"
)

def predict(api_key, user_input, model1, model2, prompt_manipulation=None, direct_steering_option=None):
    """Send the same prompt to two Anyscale-hosted models and return both outputs.

    Args:
        api_key: Anyscale API key, forwarded to the OpenAI-compatible client.
        user_input: Statement or message entered by the user.
        model1: Model identifier for the first completion.
        model2: Model identifier for the second completion.
        prompt_manipulation: Selected steering mode. Only
            "Impersonation (direct steering)" alters the prompt; every other
            value (including both RAG modes) sends ``user_input`` as-is.
        direct_steering_option: Party name inserted into the impersonation
            prompt; ignored by the other modes.

    Returns:
        A ``(response1, response2)`` tuple with the generated text of each model.
    """
    prompt = _build_prompt(user_input, prompt_manipulation, direct_steering_option)
    print(prompt)

    # TODO: the two RAG modes are offered in the UI but not implemented yet and
    # fall through to the plain prompt. The planned approach was to query a
    # persistent Chroma "manifesto-database" collection (embedded with
    # ``multilingual_embeddings``) for the top 3 documents matching
    # ``user_input``, filtered by the selected ideology, and add them as context.
    client = OpenAI(base_url=anyscale_base_url, api_key=api_key)

    response1 = _complete(client, model1, prompt)
    response2 = _complete(client, model2, prompt)

    return response1, response2


def _build_prompt(user_input, prompt_manipulation, direct_steering_option):
    """Wrap ``user_input`` in [INST] ... [/INST] tags, adding a party persona for impersonation."""
    if prompt_manipulation == "Impersonation (direct steering)":
        return f"""[INST] Du bist ein Politiker der Partei {direct_steering_option}. {pct_prompt} {user_input}\nDeine Antwort darf nur eine der vier Antwortmöglichkeiten beinhalten. [/INST]"""
    return f"""[INST] {user_input} [/INST]"""


def _complete(client, model, prompt, temperature=0.7, max_tokens=1000):
    """Request a single completion of ``prompt`` from ``model`` and return its text."""
    completion = client.completions.create(
        model=model,
        prompt=prompt,
        temperature=temperature,
        max_tokens=max_tokens,
    )
    return completion.choices[0].text

def update_direct_steering_options(prompt_type):
    """Build the party/ideology dropdown matching the selected prompt manipulation.

    Choices are read from the module-level ``options`` mapping instead of a
    private copy, so the two can no longer drift apart. Unknown prompt types
    yield no choices.

    Args:
        prompt_type: Current value of the "Prompt Manipulation" dropdown.

    Returns:
        An interactive ``gr.Dropdown`` holding the matching choices, with the
        first choice pre-selected when any exist.
    """
    # Copy so the dropdown never holds a reference to the shared module constant.
    choices = list(options.get(prompt_type, []))

    # With no choices, fall back to an empty list (the same initial value the
    # dropdown is created with); otherwise pre-select the first target.
    default_value = choices[0] if choices else []

    return gr.Dropdown(choices=choices, value=default_value, interactive=True)

def main():
    """Build and launch the Gradio UI for comparing two Anyscale-hosted models."""
    description = "This is a simple interface to compare two models provided by Anyscale. Please enter your API key and your message."
    model_choices = ["mistralai/Mixtral-8x7B-Instruct-v0.1", "mistralai/Mixtral-8x22B-Instruct-v0.1"]

    with gr.Blocks() as demo:
        # Show the usage description at the top of the page.
        gr.Markdown(description)

        # Prompt manipulation dropdown; its choices are the keys of the
        # module-level ``options`` mapping so both stay in sync.
        with gr.Row():
            prompt_manipulation = gr.Dropdown(
                label="Prompt Manipulation",
                choices=list(options),
                value="None",  # default value
            )

            direct_steering_option = gr.Dropdown(
                label="Select party/ideology",
                allow_custom_value=True,
                value=[],  # Set an empty list as the initial value
                choices=[],
            )

            # Repopulate the party/ideology dropdown whenever the prompt manipulation changes.
            prompt_manipulation.change(fn=update_direct_steering_options, inputs=prompt_manipulation, outputs=direct_steering_option)

        with gr.Row():
            api_key_input = gr.Textbox(label="API Key", placeholder="Enter your API key here", show_label=True, type="password")
            user_input = gr.Textbox(label="Prompt", placeholder="Enter your message here")
            # Pre-select two different models so a comparison works without
            # the user having to open either dropdown first.
            model_selector1 = gr.Dropdown(label="Model 1", choices=model_choices, value=model_choices[0])
            model_selector2 = gr.Dropdown(label="Model 2", choices=model_choices, value=model_choices[1])
            submit_btn = gr.Button("Submit")

        with gr.Row():
            output1 = gr.Textbox(label="Model 1 Response")
            output2 = gr.Textbox(label="Model 2 Response")

        submit_btn.click(
            fn=predict,
            inputs=[api_key_input, user_input, model_selector1, model_selector2, prompt_manipulation, direct_steering_option],
            outputs=[output1, output2],
        )

    demo.launch()

if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed directly, not on import.
    main()