akhaliq HF staff commited on
Commit
eb6ddbd
1 Parent(s): 3f9cf44

update gemini default

Browse files
Files changed (2) hide show
  1. app_experimental.py +237 -0
  2. app_gemini.py +1 -1
app_experimental.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import gradio as gr
3
+ from typing import List, Dict
4
+ import random
5
+ import time
6
+ from utils import get_app
7
+
8
+ # Import all the model registries (keeping existing imports)
9
+ import anthropic_gradio
10
+ import cerebras_gradio
11
+ import dashscope_gradio
12
+ import fireworks_gradio
13
+ import gemini_gradio
14
+ import groq_gradio
15
+ import hyperbolic_gradio
16
+ import mistral_gradio
17
+ import nvidia_gradio
18
+ import openai_gradio
19
+ import perplexity_gradio
20
+ import sambanova_gradio
21
+ import together_gradio
22
+ import xai_gradio
23
+
24
+ # Define MODEL_REGISTRIES dictionary
25
+ MODEL_REGISTRIES = {
26
+ "OpenAI": (openai_gradio.registry, os.getenv("OPENAI_API_KEY")),
27
+ "Anthropic": (anthropic_gradio.registry, os.getenv("ANTHROPIC_API_KEY")),
28
+ "Cerebras": (cerebras_gradio, os.getenv("CEREBRAS_API_KEY")),
29
+ "DashScope": (dashscope_gradio, os.getenv("DASHSCOPE_API_KEY")),
30
+ "Fireworks": (fireworks_gradio, os.getenv("FIREWORKS_API_KEY")),
31
+ "Gemini": (gemini_gradio, os.getenv("GEMINI_API_KEY")),
32
+ "Groq": (groq_gradio, os.getenv("GROQ_API_KEY")),
33
+ "Hyperbolic": (hyperbolic_gradio, os.getenv("HYPERBOLIC_API_KEY")),
34
+ "Mistral": (mistral_gradio, os.getenv("MISTRAL_API_KEY")),
35
+ "NVIDIA": (nvidia_gradio, os.getenv("NVIDIA_API_KEY")),
36
+ "SambaNova": (sambanova_gradio, os.getenv("SAMBANOVA_API_KEY")),
37
+ "Together": (together_gradio, os.getenv("TOGETHER_API_KEY")),
38
+ "XAI": (xai_gradio, os.getenv("XAI_API_KEY")),
39
+ }
40
+
41
+ def get_all_models():
42
+ """Get all available models from the registries."""
43
+ return [
44
+ "OpenAI: gpt-4o", # From app_openai.py
45
+ "Anthropic: claude-3-5-sonnet-20241022", # From app_claude.py
46
+ ]
47
+
48
+ def generate_discussion_prompt(original_question: str, previous_responses: List[str]) -> str:
49
+ """Generate a prompt for models to discuss and build upon previous responses."""
50
+ prompt = f"""You are participating in a multi-AI discussion about this question: "{original_question}"
51
+
52
+ Previous responses from other AI models:
53
+ {chr(10).join(f"- {response}" for response in previous_responses)}
54
+
55
+ Please provide your perspective while:
56
+ 1. Acknowledging key insights from previous responses
57
+ 2. Adding any missing important points
58
+ 3. Respectfully noting if you disagree with anything and explaining why
59
+ 4. Building towards a complete answer
60
+
61
+ Keep your response focused and concise (max 3-4 paragraphs)."""
62
+ return prompt
63
+
64
+ def generate_consensus_prompt(original_question: str, discussion_history: List[str]) -> str:
65
+ """Generate a prompt for final consensus building."""
66
+ return f"""Review this multi-AI discussion about: "{original_question}"
67
+
68
+ Discussion history:
69
+ {chr(10).join(discussion_history)}
70
+
71
+ As a final synthesizer, please:
72
+ 1. Identify the key points where all models agreed
73
+ 2. Explain how any disagreements were resolved
74
+ 3. Present a clear, unified answer that represents our collective best understanding
75
+ 4. Note any remaining uncertainties or caveats
76
+
77
+ Keep the final consensus concise but complete."""
78
+
79
+ def chat_with_openai(model: str, messages: List[Dict], api_key: str) -> str:
80
+ import openai
81
+ client = openai.OpenAI(api_key=api_key)
82
+ response = client.chat.completions.create(
83
+ model=model,
84
+ messages=messages
85
+ )
86
+ return response.choices[0].message.content
87
+
88
+ def chat_with_anthropic(model: str, messages: List[Dict], api_key: str) -> str:
89
+ from anthropic import Anthropic
90
+ client = Anthropic(api_key=api_key)
91
+ # Convert messages to Anthropic format
92
+ prompt = "\n\n".join([f"{m['role']}: {m['content']}" for m in messages])
93
+ response = client.messages.create(
94
+ model=model,
95
+ messages=[{"role": "user", "content": prompt}]
96
+ )
97
+ return response.content[0].text
98
+
99
+ def multi_model_consensus(
100
+ question: str,
101
+ selected_models: List[str],
102
+ rounds: int = 3,
103
+ progress: gr.Progress = gr.Progress()
104
+ ) -> tuple[str, List[Dict]]:
105
+ if not selected_models:
106
+ return "Please select at least one model to chat with.", []
107
+
108
+ chat_history = []
109
+ discussion_history = []
110
+
111
+ # Initial responses
112
+ progress(0, desc="Getting initial responses...")
113
+ initial_responses = []
114
+ for i, model in enumerate(selected_models):
115
+ provider, model_name = model.split(": ", 1)
116
+ registry_fn, api_key = MODEL_REGISTRIES[provider]
117
+
118
+ if not api_key:
119
+ continue
120
+
121
+ try:
122
+ # Load the model using the registry function
123
+ predictor = gr.load(
124
+ name=model_name,
125
+ src=registry_fn,
126
+ token=api_key
127
+ )
128
+
129
+ # Format the request based on the provider
130
+ if provider == "Anthropic":
131
+ response = predictor.predict(
132
+ messages=[{"role": "user", "content": question}],
133
+ max_tokens=1024,
134
+ model=model_name,
135
+ api_name="chat"
136
+ )
137
+ else:
138
+ response = predictor.predict(
139
+ question,
140
+ api_name="chat"
141
+ )
142
+
143
+ initial_responses.append(f"{model}: {response}")
144
+ discussion_history.append(f"Initial response from {model}:\n{response}")
145
+ chat_history.append((f"Initial response from {model}", response))
146
+ except Exception as e:
147
+ chat_history.append((f"Error from {model}", str(e)))
148
+
149
+ # Discussion rounds
150
+ for round_num in range(rounds):
151
+ progress((round_num + 1) / (rounds + 2), desc=f"Discussion round {round_num + 1}...")
152
+ round_responses = []
153
+
154
+ random.shuffle(selected_models) # Randomize order each round
155
+ for model in selected_models:
156
+ provider, model_name = model.split(": ", 1)
157
+ registry, api_key = MODEL_REGISTRIES[provider]
158
+
159
+ if not api_key:
160
+ continue
161
+
162
+ try:
163
+ discussion_prompt = generate_discussion_prompt(question, discussion_history)
164
+ response = registry.chat(
165
+ model=model_name,
166
+ messages=[{"role": "user", "content": discussion_prompt}],
167
+ api_key=api_key
168
+ )
169
+ round_responses.append(f"{model}: {response}")
170
+ discussion_history.append(f"Round {round_num + 1} - {model}:\n{response}")
171
+ chat_history.append((f"Round {round_num + 1} - {model}", response))
172
+ except Exception as e:
173
+ chat_history.append((f"Error from {model} in round {round_num + 1}", str(e)))
174
+
175
+ # Final consensus - use the model that's shown most consistency
176
+ progress(0.9, desc="Building final consensus...")
177
+ # Use the first model for final consensus instead of two models
178
+ model = selected_models[0]
179
+ provider, model_name = model.split(": ", 1)
180
+ registry, api_key = MODEL_REGISTRIES[provider]
181
+
182
+ try:
183
+ consensus_prompt = generate_consensus_prompt(question, discussion_history)
184
+ final_consensus = registry.chat(
185
+ model=model_name,
186
+ messages=[{"role": "user", "content": consensus_prompt}],
187
+ api_key=api_key
188
+ )
189
+ except Exception as e:
190
+ final_consensus = f"Error getting consensus from {model}: {str(e)}"
191
+
192
+ chat_history.append(("Final Consensus", final_consensus))
193
+
194
+ progress(1.0, desc="Done!")
195
+ return chat_history
196
+
197
+ with gr.Blocks() as demo:
198
+ gr.Markdown("# Experimental Multi-Model Consensus Chat")
199
+ gr.Markdown("""Select multiple models to collaborate on answering your question.
200
+ The models will discuss with each other and attempt to reach a consensus.
201
+ Maximum 5 models can be selected at once.""")
202
+
203
+ with gr.Row():
204
+ with gr.Column():
205
+ model_selector = gr.Dropdown(
206
+ choices=get_all_models(),
207
+ multiselect=True,
208
+ label="Select Models (max 5)",
209
+ info="Choose up to 5 models to participate in the discussion",
210
+ value=["OpenAI: gpt-4o", "Anthropic: claude-3-5-sonnet-20241022"], # Updated model names
211
+ max_choices=5
212
+ )
213
+ rounds_slider = gr.Slider(
214
+ minimum=1,
215
+ maximum=5,
216
+ value=3,
217
+ step=1,
218
+ label="Discussion Rounds",
219
+ info="Number of rounds of discussion between models"
220
+ )
221
+
222
+ chatbot = gr.Chatbot(height=600, label="Multi-Model Discussion")
223
+ msg = gr.Textbox(label="Your Question", placeholder="Ask a question for the models to discuss...")
224
+
225
+ def respond(message, selected_models, rounds):
226
+ chat_history = multi_model_consensus(message, selected_models, rounds)
227
+ return chat_history
228
+
229
+ msg.submit(
230
+ respond,
231
+ [msg, model_selector, rounds_slider],
232
+ [chatbot],
233
+ api_name="consensus_chat"
234
+ )
235
+
236
+ if __name__ == "__main__":
237
+ demo.launch()
app_gemini.py CHANGED
@@ -12,7 +12,7 @@ demo = get_app(
12
  "gemini-exp-1114",
13
  "gemini-exp-1121"
14
  ],
15
- default_model="gemini-exp-1121",
16
  src=gemini_gradio.registry,
17
  accept_token=not os.getenv("GEMINI_API_KEY"),
18
  )
 
12
  "gemini-exp-1114",
13
  "gemini-exp-1121"
14
  ],
15
+ default_model="gemini-1.5-pro",
16
  src=gemini_gradio.registry,
17
  accept_token=not os.getenv("GEMINI_API_KEY"),
18
  )