Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*-
"""export_gradio_spaces.ipynb
Automatically generated by Colaboratory.
Original file is located at
https://colab.research.google.com/drive/1mI9gAr_Vtpl2mZsoZoRc56muGboXdsMi
# Chargement du modèle GPT-2 entraîné
"""
#import pip | |
import subprocess | |
import sys | |
def install(package):
    """Install *package* into the current interpreter's environment with pip.

    Runs ``<python> -m pip install <package>`` and raises
    ``subprocess.CalledProcessError`` if pip exits with a non-zero status.
    """
    pip_command = [sys.executable, "-m", "pip", "install"]
    pip_command.append(package)
    subprocess.run(pip_command, check=True)


# Runtime dependencies, installed in this order when the script starts.
install("typing-extensions")
install("gradio")
install("keras_nlp")
import os
import re
import time

# Keras 3 reads KERAS_BACKEND only once, when `keras` is first imported, so the
# backend must be selected *before* the Keras / KerasNLP imports below.
os.environ["KERAS_BACKEND"] = "tensorflow"  # or "jax" or "torch"

# Import the same `keras` package that keras_nlp builds its models with, so the
# global dtype policy set below actually applies to the GPT-2 model.
import keras  # noqa: E402  (must follow the backend selection above)
import keras_nlp  # noqa: E402

# Compute in float16 while keeping variables in float32 (intended to speed up
# inference on supported GPUs).
keras.mixed_precision.set_global_policy("mixed_float16")

# Tokenizer/packer for GPT-2 Medium; prompts are packed to 128 tokens.
preprocessor = keras_nlp.models.GPT2CausalLMPreprocessor.from_preset(
    "gpt2_medium_en",
    sequence_length=128,
)
gpt2_lm = keras_nlp.models.GPT2CausalLM.from_preset(
    "gpt2_medium_en", preprocessor=preprocessor
)

# Other checkpoints produced during training (swap in to compare). They are
# GPT-2 Base runs, so they presumably also need the "gpt2_base_en" preset above:
#   GPT-2 Base, 1 epoch:   "./aloqas_model_checkpoints/cp.ckpt"
#   GPT-2 Base, 3 epochs:  "./aloqas_model_checkpoints_gpt2base_3_epochs/cp.ckpt"

# GPT-2 Medium, 3 epochs (matches the "gpt2_medium_en" preset).
checkpoint_path = "./aloqas_model_checkpoints_gpt2medium_3_epochs/cp.ckpt"
gpt2_lm.load_weights(checkpoint_path)

# --- Gradio loading and configuration ---
# Experimental
def format_text(text, to_remove):
    """Clean up raw GPT-2 output for display.

    - removes the echoed prompt (its first occurrence) from the answer, so a
      keyword prompt such as "cancer" does not erase that word everywhere;
    - removes space characters before commas and full stops;
    - capitalizes the first letter of each sentence *without* lowercasing the
      rest, so acronyms such as "DNA" or "HIV" survive.

    Sentences are split only on a full stop followed by whitespace or the end
    of the text, so decimals ("3.5"), numbers like "119.000" and URLs are not
    torn apart. Sentences are re-joined with ". ".

    Args:
        text: Raw generated text (expected to start with the prompt).
        to_remove: The prompt to strip from the answer.

    Returns:
        The formatted text.
    """
    text = text.replace(to_remove, '', 1)
    text = re.sub(r' +,', ',', text)   # comma
    text = re.sub(r' +\.', '.', text)  # full stop

    formatted = []
    for sentence in re.split(r'\.(?=\s|$)', text):
        sentence = sentence.strip()
        # Upper-case only the first character; str.capitalize() would
        # lowercase everything else.
        formatted.append(sentence[:1].upper() + sentence[1:])
    return '. '.join(formatted)
# Experimental
def generate_text(prompt):
    """Generate an answer for *prompt* with the fine-tuned GPT-2 model.

    Calls ``gpt2_lm.generate`` with ``max_length=150``, prints the raw and the
    cleaned-up output for debugging, and returns the text produced by
    ``format_text`` (the prompt is passed so it can be stripped from the answer).

    Args:
        prompt: The user's question, sentence or keywords.

    Returns:
        The formatted answer text.
    """
    raw = gpt2_lm.generate(prompt, max_length=150)
    cleaned = format_text(raw, prompt)
    debug_lines = (
        "=====================================",
        f"DEBUG - GPT-2 Output : {raw}\n",
        f"DEBUG - Formatted GPT-2 Output : {cleaned}",
    )
    for line in debug_lines:
        print(line)
    return cleaned
# Custom stylesheet injected into the page via gr.Blocks(css=css).
# NOTE(review): the avatar `background-image` rules point at Google Drive paths
# from the original Colab notebook and presumably do not resolve on a Space; the
# avatars actually displayed are set through `avatar_images` on the Chatbot.
css = """
@import url('https://fonts.googleapis.com/css2?family=Poppins&display=swap');

body, html {
    height: 100%;
    margin: 0;
    font-family : 'Poppins', sans-serif;
    /* font-family: 'Arial', sans-serif; */
}

:root{
    --body-background-fill: none !important;
}

body,html {
    background-color: #131722; /* Dark background color */
    color: #ffffff;
}

/* Container for the entire chat interface */
#chat-interface {
    display: flex;
    flex-direction: column;
    max-width: 80%; /* Ensure maximum width */
    height: 100vh;
    justify-content: space-between;
    margin: 0 auto; /* Center the chat interface */
}

/* Container for the chat messages */
#chat-messages {
    flex-grow: 1;
    overflow-y: auto;
    background: none;
    border: 1px solid #627385 !important;
    max-height: 25%;
}

/* Styling for the chatbot bubble messages */
.gr-chatbot .chatbubble {
    max-width: 85%;
    margin-bottom: 12px;
    border-radius: 16px;
    padding: 12px 16px;
    position: relative;
    font-size: 1rem;
}

.gr-chatbot .chatbubble:before {
    content: '';
    position: absolute;
    width: 0;
    height: 0;
    border-style: solid;
}

/* Chatbot message bubble */
.gr-chatbot .bot .chatbubble {
    background-color: #2d3e55; /* Darker bubble background */
}

/* User message bubble */
.gr-chatbot .user .chatbubble {
    background-color: #4CAF50; /* Green bubble background */
}

/* Input area styling */
#input-area {
    display: flex;
    align-items: center;
    padding: 20px;
}

/* Text input field styling */
#input-area .gr-textbox {
    flex: 1;
    margin-right: 12px;
    padding: 12px 16px;
    border: 5px solid #627385;
    border-radius: 16px;
    font-size: 1rem;
}

/* Send button styling */
#input-area button {
    padding: 12px 20px;
    background-color: #4CAF50; /* Green button color */
    border: none;
    border-radius: 16px;
    cursor: pointer;
    font-size: 1rem;
    color: #fff;
}

/* Suggestion buttons styling */
.suggestion-btn {
    background-color: #2d3e55; /* Dark button color */
    color: #ffffff;
    padding: 10px 50px;
    margin: 5px;
    border: 2px solid #627385;
    border-radius: 20px;
    cursor: pointer;
    font-size: 14px;
    display: inline-block;
}

/* Suggestions container */
#suggestions {
    padding: 20px;
}

/* Style the avatar images if needed */
.gr-chatbot .gr-chatbot-avatar-image {
    border-radius: 50%;
}

/* Style for the chatbot avatar */
.gr-chatbot .bot .gr-chatbot-avatar-image {
    background-image: url('/content/drive/MyDrive/img/ALOQAS logo.png');
}

/* Style for the user avatar */
.gr-chatbot .user .gr-chatbot-avatar-image {
    background-image: url('/content/drive/MyDrive/img/pp discord copie.png');
}

/* Additional CSS for layout adjustments */
#header {
    display: flex;
    flex-direction: column;
    max-width: 100%; /* Ensure maximum width */
    gap: 20px;
    justify-content: center;
    align-items: center;
    margin: 0 auto; /* Center the header */
}

#main-title {
    font-size: 2.5em;
    margin-bottom: 0.5em;
    color: #ffffff;
}

#sub-title {
    font-size: 1.5em;
    margin-bottom: 1em;
    color: #ffffff;
}

/* Adjust the chat interface to not grow beyond its container */
#chat-interface {
    flex: 1;
    overflow: auto; /* Add scrolling to the chat interface if needed */
}

.logo {
    width: 150px; /* Width of the logo */
    height: 150px; /* Height of the logo, should be equal to width for a perfect circle */
    background-image: url('https://github.com/LucasAguetai/ALOQAS/blob/main/Ressources/Gif%20Aloqas%20Logo.gif?raw=true');
    background-size: cover; /* Cover the entire area of the div without stretching */
    background-position: center; /* Center the background image within the div */
    border-radius: 50%; /* This will make it circular */
    display: inline-block; /* Allows the div to be inline with text and other inline elements */
    margin-bottom: 1em; /* Space below the logo */
}

.message{
    max-width: 50% !important;
}

.pending{
    max-width: 100% !important;
}

#input-area > *{
    padding: 0px;
    border: 3px solid #627385;
}

#input-area > * > *{
    padding: 0px;
    background-color: #091E37 !important;
}

#input-area > * * {
    border-radius: 0px !important;
}

.dark{
    --background-fill-primary: #091E37 !important;
}

.send{
    max-width: 10px;
    background-color: #627385 !important;
}
"""
import gradio as gr | |
# Stock Gradio Base theme with a slate primary palette; the rest of the look
# comes from the custom `css` stylesheet.
theme = gr.themes.Base(
    primary_hue="slate",
)

# Canned example questions rendered as one-click buttons. Each string is also
# sent verbatim to the model as the prompt by the suggestion handlers.
suggestion = [
    question
    for question in (
        "What are the latest advancements in cancer research ?",
        "What is the impact of diet on heart disease according to recent studies ?",
        "What are the usual causes of lung diseases ?",
    )
]
def _finish_sentence(text):
    """Return *text* without trailing whitespace, ending in sentence punctuation.

    A full stop is appended only when the text does not already end with
    ".", "!" or "?", so answers never end in ".." or ". .".
    """
    text = text.rstrip()
    if text.endswith((".", "!", "?")):
        return text
    return text + "."


def _answer(prompt, chat_history):
    """Generate a reply to *prompt* and append ``(prompt, reply)`` to the history.

    The history list is mutated in place and also returned, since the handlers
    are wired with the Chatbot as their output.
    """
    response = generate_text(prompt)
    chat_history.append((prompt, _finish_sentence(response)))
    return chat_history


def respond(message, chat_history):
    """Handle a question submitted from the text box.

    Returns:
        ``("", chat_history)``. The empty string clears the text box, and the
        history includes the new exchange. Blank input is ignored: no
        generation is run and the history is returned unchanged.
    """
    if not (message and message.strip()):
        return "", chat_history
    return "", _answer(message, chat_history)


def suggestion1(chat_history):
    """Answer the first canned question in ``suggestion``."""
    return _answer(suggestion[0], chat_history)


def suggestion2(chat_history):
    """Answer the second canned question in ``suggestion``."""
    return _answer(suggestion[1], chat_history)


def suggestion3(chat_history):
    """Answer the third canned question in ``suggestion``."""
    return _answer(suggestion[2], chat_history)
# Avatars shown next to chat messages (order expected by Gradio: user, bot).
user_profile_image = "https://huggingface.co/spaces/ALOQAS/aloqas-gradio/resolve/main/img/user.png"
bot_profile_image = "https://huggingface.co/spaces/ALOQAS/aloqas-gradio/resolve/main/img/bot.png"

# Page layout: header with logo, chat history, suggestion buttons, input box.
with gr.Blocks(theme=theme, css=css) as demo:
    gr.Markdown("""
<div id='header'>
<h1 id='main-title'>ALOQAS</h1>
<div class='logo'></div>
</div>
""")
    with gr.Column(elem_id="chat-interface"):
        chat = gr.Chatbot(
            elem_id="chat-messages",
            show_label=False,
            avatar_images=(user_profile_image, bot_profile_image),
            value=[[None, "Hi there ! I'm ALOQAS, a chatbot trained on over 119.000 PubMed scientific papers. Ask me anything about medicine or scientific research !"]],
        )
        with gr.Row(elem_id="suggestions"):
            # Keep references to the buttons themselves (not the event objects
            # returned by .click) and wire each one to its canned question.
            sugg1 = gr.Button(suggestion[0], elem_classes="suggestion-btn")
            sugg2 = gr.Button(suggestion[1], elem_classes="suggestion-btn")
            sugg3 = gr.Button(suggestion[2], elem_classes="suggestion-btn")
            sugg1.click(suggestion1, inputs=[chat], outputs=[chat])
            sugg2.click(suggestion2, inputs=[chat], outputs=[chat])
            sugg3.click(suggestion3, inputs=[chat], outputs=[chat])
        with gr.Row(elem_id="input-area"):
            text_input = gr.Textbox(placeholder="Type a question, a sentence or keywords to ALOQAS...", show_label=False)
            # On Enter: clear the textbox and append the exchange to the chat.
            text_input.submit(respond, inputs=[text_input, chat], outputs=[text_input, chat])

demo.launch(share=False)