"""Gradio app that streams a simulated conversation between two characters.

Each character is played by the same chat model with a different system
prompt.  The conversation is rendered as HTML while it streams and can be
exported afterwards as a plain-text transcript.
"""

import asyncio
import base64
import datetime
import html
import io
import os
import random
import re
import time
from queue import Queue
from threading import Thread

import gradio as gr
import openai
from openai import APIError, APIConnectionError, RateLimitError
from PIL import Image

# Avatar images live in an "avatars" directory next to this script.
current_dir = os.path.dirname(os.path.abspath(__file__))
avatars_dir = os.path.join(current_dir, "avatars")

# Dictionary mapping characters to their avatar image filenames
character_avatars = {
    "Harry Potter": "harry.png",
    "Hermione Granger": "hermione.png",
    "poor Ph.D. student": "phd.png",
    "a super cute red panda": "red_panda.png",
}

# Fallback keys used when the user's key hits a rate limit.  Either entry is
# None when the corresponding environment variable is not set.
BACKUP_API_KEY_0 = os.environ.get('BACKUP_API_KEY_0')
BACKUP_API_KEY_1 = os.environ.get('BACKUP_API_KEY_1')
BACKUP_API_KEYS = [BACKUP_API_KEY_0, BACKUP_API_KEY_1]

predefined_characters = [
    "Harry Potter",
    "Hermione Granger",
    "poor Ph.D. student",
    "a super cute red panda",
]

# NOTE(review): the markup and CSS below (class names, inline styles) were
# reconstructed; confirm them against the intended page design.  The two
# class names in the pattern must stay in sync with
# format_conversation_as_html().
_TRANSCRIPT_PATTERN = re.compile(
    r'<div class="character-name">(.*?)</div>'
    r'.*?<div class="message-text">(.*?)</div>',
    re.DOTALL,
)

_CHAT_STYLE = """
<style>
.chat-container { display: flex; flex-direction: column; gap: 10px; }
.message { display: flex; align-items: flex-start; gap: 8px; max-width: 80%; }
.message.left { align-self: flex-start; }
.message.right { align-self: flex-end; flex-direction: row-reverse; }
.avatar { width: 50px; height: 50px; border-radius: 50%; object-fit: cover; }
.message-content { padding: 8px 12px; border-radius: 10px; border: 1px solid #ccc; }
.character-name { font-weight: bold; margin-bottom: 4px; }
.message-text { white-space: pre-wrap; }
</style>
"""


def get_character(dropdown_value, custom_value):
    """Return the custom name when "Custom" is selected, else the dropdown value."""
    if dropdown_value == "Custom":
        return custom_value
    return dropdown_value


def resize_image(image_path, size=(100, 100)):
    """Return a base64-encoded PNG thumbnail of *image_path*.

    Args:
        image_path: Path of the image file to load.
        size: Maximum (width, height) of the thumbnail.

    Returns:
        The thumbnail as a base64 string, or None if the file does not exist.
    """
    if not os.path.exists(image_path):
        return None
    with Image.open(image_path) as img:
        img.thumbnail(size)
        buffered = io.BytesIO()
        img.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode()


# Pre-compute the avatar thumbnails once at import time; characters whose
# image file is missing simply get no avatar.
resized_avatars = {}
for character, filename in character_avatars.items():
    encoded_avatar = resize_image(os.path.join(avatars_dir, filename))
    if encoded_avatar is not None:
        resized_avatars[character] = encoded_avatar


async def generate_response_stream(messages, user_api_key):
    """Stream one chat completion, yielding the accumulated text so far.

    The user's key is tried first; when it is rate limited, each configured
    backup key is tried in turn.

    Args:
        messages: Chat history in the OpenAI ``messages`` format.
        user_api_key: API key supplied by the user.

    Yields:
        The full response text received so far (not just the latest delta).

    Raises:
        Exception: "Rate limit exceeded" when every key was rate limited.
            Any other client error propagates unchanged.
    """
    # Skip backup keys whose environment variable is unset: passing None to
    # the client fails with an unrelated configuration error.
    api_keys = [user_api_key] + [key for key in BACKUP_API_KEYS if key]

    last_rate_limit_error = None
    for api_key in api_keys:
        client = openai.AsyncOpenAI(
            api_key=api_key,
            base_url="https://api.sambanova.ai/v1",
        )
        try:
            response = await client.chat.completions.create(
                model='Meta-Llama-3.1-405B-Instruct',
                messages=messages,
                temperature=0.7,
                top_p=0.9,
                stream=True,
            )
            full_response = ""
            async for chunk in response:
                # Guard against chunks that carry no choices.
                if not chunk.choices:
                    continue
                delta = chunk.choices[0].delta.content
                if delta:
                    full_response += delta
                    yield full_response
            # Success: do not fall through to the remaining keys.
            return
        except RateLimitError as err:
            # Remember the error and try the next key, if any.
            last_rate_limit_error = err

    raise Exception("Rate limit exceeded") from last_rate_limit_error


async def simulate_conversation_stream(character1, character2, initial_message, num_turns, api_key):
    """Simulate the conversation, yielding the rendered HTML as it grows.

    Character 1 opens with *initial_message*; the characters then alternate,
    starting with character 2, for ``2 * num_turns - 1`` further messages.
    Each character keeps its own history in which its own lines are
    "assistant" messages and the other character's lines are "user" messages.

    Args:
        character1: Name of the character who speaks first.
        character2: Name of the replying character.
        initial_message: Opening line spoken by character 1.
        num_turns: Number of turns per character.
        api_key: API key supplied by the user.

    Yields:
        The conversation so far, rendered by format_conversation_as_html().
        On an error the last message is replaced by a "System" error message
        and the conversation stops.
    """
    messages_character_1 = [
        {"role": "system", "content": f"Avoid overly verbose answer in your response. Act as {character1}."},
        {"role": "assistant", "content": initial_message},
    ]
    messages_character_2 = [
        {"role": "system", "content": f"Avoid overly verbose answer in your response. Act as {character2}."},
        {"role": "user", "content": initial_message},
    ]

    conversation = [{"character": character1, "content": initial_message}]
    yield format_conversation_as_html(conversation)

    # The slider value may arrive as a float; range() needs an int.
    total_messages = int(num_turns) * 2
    for turn_num in range(total_messages - 1):
        character2_speaks = turn_num % 2 == 0
        if character2_speaks:
            current_character = character2
            speaker_history, listener_history = messages_character_2, messages_character_1
        else:
            current_character = character1
            speaker_history, listener_history = messages_character_1, messages_character_2

        # Placeholder that is filled in while the response streams.
        conversation.append({"character": current_character, "content": ""})
        full_response = ""
        try:
            async for response in generate_response_stream(speaker_history, api_key):
                full_response = response
                conversation[-1]["content"] = full_response
                yield format_conversation_as_html(conversation)

            # Only record the message once it has streamed successfully.
            listener_history.append({"role": "user", "content": full_response})
            speaker_history.append({"role": "assistant", "content": full_response})
        except Exception as e:
            # Replace the current message with the error and stop.
            conversation[-1]["character"] = "System"
            conversation[-1]["content"] = f"Error: {str(e)}"
            yield format_conversation_as_html(conversation)
            break


def stream_conversation(character1, character2, initial_message, num_turns, api_key, queue):
    """Run the simulation in the calling thread and feed its output to *queue*.

    Every HTML snapshot is put on the queue; a final ``None`` marks the end of
    the stream.  The sentinel is sent even if the simulation fails, so the
    consumer never blocks forever.
    """
    async def run_simulation():
        try:
            async for snapshot in simulate_conversation_stream(
                character1, character2, initial_message, num_turns, api_key
            ):
                queue.put(snapshot)
        except Exception as e:
            # The consumer renders queue items as HTML, so escape the text.
            queue.put(html.escape(f"Error: {str(e)}"))
        finally:
            queue.put(None)  # Signal that the conversation is complete

    # asyncio.run creates a fresh event loop for this thread and always
    # closes it, including when the coroutine raises.
    asyncio.run(run_simulation())


def validate_api_key(api_key):
    """Return ``(is_valid, message)``; a key is valid when it is non-blank."""
    if not api_key or not api_key.strip():
        return False, "API key is required. Please enter a valid API key."
    return True, ""


def update_api_key_status(api_key):
    """Return the status markup shown under the API key box ("" when valid)."""
    is_valid, message = validate_api_key(api_key)
    if is_valid:
        return ""
    # NOTE(review): reconstructed markup; confirm the intended tag and styling.
    return f"<p style='color: red;'>{message}</p>"


def chat_interface(character1_dropdown, character1_custom, character2_dropdown, character2_custom,
                   initial_message, num_turns, api_key):
    """Gradio handler: yield HTML snapshots of the conversation as it streams.

    The async simulation runs in a worker thread with its own event loop and
    hands snapshots over through a queue.
    """
    character1 = get_character(character1_dropdown, character1_custom)
    character2 = get_character(character2_dropdown, character2_custom)

    queue = Queue()
    thread = Thread(
        target=stream_conversation,
        args=(character1, character2, initial_message, num_turns, api_key, queue),
        daemon=True,  # do not keep the process alive for an abandoned stream
    )
    thread.start()

    # queue.get() blocks until the next snapshot; None ends the stream.
    for result in iter(queue.get, None):
        yield result

    thread.join()


def format_conversation_as_html(conversation):
    """Render *conversation* as an HTML chat view.

    Args:
        conversation: List of ``{"character": ..., "content": ...}`` dicts.

    Returns:
        An HTML string.  Messages alternate left/right by position, and
        characters with a preloaded avatar get an inline base64 image.
        Names and message text are HTML-escaped because they come from user
        input and model output.
    """
    parts = [_CHAT_STYLE, '<div class="chat-container">']
    for i, message in enumerate(conversation):
        align = "left" if i % 2 == 0 else "right"
        name = html.escape(str(message["character"]))
        content = html.escape(str(message["content"]))

        parts.append(f'<div class="message {align}">')
        avatar_data = resized_avatars.get(message["character"])
        if avatar_data:
            parts.append(
                f'<img src="data:image/png;base64,{avatar_data}" '
                f'alt="{name} avatar" class="avatar">'
            )
        parts.append(
            '<div class="message-content">'
            f'<div class="character-name">{name}</div>'
            f'<div class="message-text">{content}</div>'
            '</div>'
        )
        parts.append('</div>')
    parts.append("</div>")
    return "\n".join(parts)


def format_chat_for_download(html_chat):
    """Convert the rendered HTML conversation to "Speaker: message" lines.

    Returns an empty string when there is no conversation yet.
    """
    if not html_chat:
        return ""
    lines = []
    for speaker, message in _TRANSCRIPT_PATTERN.findall(html_chat):
        # Undo the escaping applied when the HTML was rendered.
        speaker = html.unescape(speaker).strip()
        message = html.unescape(message).strip()
        lines.append(f"{speaker}: {message}")
    return "\n".join(lines)


def save_chat_to_file(chat_content):
    """Write *chat_content* to a timestamped file under ./downloads.

    Returns:
        The path of the file that was written.
    """
    # Create a downloads directory if it doesn't exist
    downloads_dir = os.path.join(os.getcwd(), "downloads")
    os.makedirs(downloads_dir, exist_ok=True)

    timestamp = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")
    file_path = os.path.join(downloads_dir, f"chat_{timestamp}.txt")

    with open(file_path, "w", encoding="utf-8") as f:
        f.write(chat_content)

    return file_path


with gr.Blocks() as app:
    gr.Markdown("# Character Chat Generator")
    gr.Markdown("Powered by [LLama3.1-405B-Instruct](https://huggingface.co/meta-llama/Meta-Llama-3.1-405B-Instruct) on [SambaNova Cloud](https://cloud.sambanova.ai/apis)")

    api_key = gr.Textbox(
        label="Enter your Sambanova Cloud API Key\n(To get one, go to https://cloud.sambanova.ai/apis)",
        type="password",
    )
    api_key_status = gr.Markdown()

    with gr.Column():
        character1_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 1")
        character1_custom = gr.Textbox(label="Custom Character 1 (if selected above)", visible=False)

    with gr.Column():
        character2_dropdown = gr.Dropdown(choices=predefined_characters + ["Custom"], label="Select Character 2")
        character2_custom = gr.Textbox(label="Custom Character 2 (if selected above)", visible=False)

    initial_message = gr.Textbox(label="Initial message (for Character 1)")
    num_turns = gr.Slider(minimum=1, maximum=10, step=1, value=5, label="Number of conversation turns")

    generate_btn = gr.Button("Generate Conversation")
    output = gr.HTML(label="Generated Conversation")

    def show_custom_input(choice):
        """Show the custom-name textbox only when "Custom" is selected."""
        return gr.update(visible=choice == "Custom")

    character1_dropdown.change(show_custom_input, inputs=character1_dropdown, outputs=character1_custom)
    character2_dropdown.change(show_custom_input, inputs=character2_dropdown, outputs=character2_custom)

    api_key.change(update_api_key_status, inputs=[api_key], outputs=[api_key_status])

    generate_btn.click(
        chat_interface,
        inputs=[character1_dropdown, character1_custom, character2_dropdown,
                character2_custom, initial_message, num_turns, api_key],
        outputs=output,
    )

    gr.Markdown("## Download Chat History")
    download_btn = gr.Button("Download Conversation")
    download_output = gr.File(label="Download")

    def download_conversation(html_chat):
        """Export the rendered conversation to a text file and return its path."""
        chat_content = format_chat_for_download(html_chat)
        return save_chat_to_file(chat_content)

    download_btn.click(
        download_conversation,
        inputs=output,
        outputs=download_output,
    )

    # Show the "API key is required" hint as soon as the page loads.
    app.load(lambda: update_api_key_status(""), outputs=[api_key_status])

if __name__ == "__main__":
    app.launch()