from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
import gradio as gr
import spaces
import torch
import re
from threading import Thread
from typing import Iterator
from datetime import datetime
from huggingface_hub import HfApi, hf_hub_download
import json
import os
from gradio_client import Client

# Load the fine-tuned Llama 3 8B chat model and its tokenizer.
model_name = "Woziii/llama-3-8b-chat-me"
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto", torch_dtype=torch.float16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

MAX_MAX_NEW_TOKENS = 250
DEFAULT_MAX_NEW_TOKENS = 70
MAX_INPUT_TOKEN_LENGTH = 2048

is_first_interaction = True


def determine_response_type(message):
    """Classify the user message as calling for a short, long or medium answer, via keyword matching."""
    short_response_keywords = [
        "salut", "bonjour", "ça va", "comment tu vas", "quoi de neuf", "coucou", "hello", "hi",
        "bye", "au revoir", "merci", "d'accord", "ok", "super", "cool", "génial", "wow",
    ]
    long_response_keywords = [
        "présente", "parle moi de", "explique", "raconte", "décris", "dis moi", "détaille",
        "précise", "vision", "t'es qui", "pourquoi", "comment", "quel est", "peux-tu développer",
        "en quoi consiste", "qu'est-ce que", "que penses-tu de", "analyse", "compare",
        "élabore sur", "expérience", "parcours", "formation", "études", "compétences",
        "projets", "réalisations",
    ]
    message_lower = message.lower()
    if any(keyword.lower() in message_lower for keyword in short_response_keywords):
        return "short"
    elif any(keyword.lower() in message_lower for keyword in long_response_keywords):
        return "long"
    else:
        return "medium"


def truncate_to_questions(text, max_questions):
    """Cut the text after at most `max_questions` sentences ending with a question mark."""
    sentences = re.split(r'(?<=[.!?])\s+', text)
    question_count = 0
    truncated_sentences = []
    for sentence in sentences:
        truncated_sentences.append(sentence)
        if re.search(r'\?!?$', sentence.strip()):
            question_count += 1
            if question_count >= max_questions:
                break
    return ' '.join(truncated_sentences)


def post_process_response(response, is_short_response, max_questions=2):
    """Limit the number of questions and, for short responses, keep at most two sentences."""
    truncated_response = truncate_to_questions(response, max_questions)
    if is_short_response:
        sentences = re.split(r'(?<=[.!?])\s+', truncated_response)
        if len(sentences) > 2:
            return ' '.join(sentences[:2]).strip()
    return truncated_response.strip()


def check_coherence(response):
    """Flag a response as incoherent when it contains too many duplicated sentences."""
    sentences = re.split(r'(?<=[.!?])\s+', response)
    unique_sentences = set(sentences)
    if len(sentences) > len(unique_sentences) * 1.1:
        return False
    return True


@spaces.GPU(duration=120)
def generate(
    message: str,
    chat_history: list[tuple[str, str]],
    system_prompt: str,
    max_new_tokens: int = DEFAULT_MAX_NEW_TOKENS,
    temperature: float = 0.7,
    top_p: float = 0.95,
) -> Iterator[str]:
    global is_first_interaction

    # On the very first exchange, warn the user that the model is still in alpha.
    if is_first_interaction:
        warning_message = """⚠️ Attention : Je suis un modèle en version alpha (V.0.0.3.5) et je peux générer des réponses incohérentes ou inexactes. Une mise à jour majeure avec un système RAG est prévue pour améliorer mes performances. Merci de votre compréhension !
😊 """
        yield warning_message
        is_first_interaction = False

    # Adapt the token budget to the expected response length.
    response_type = determine_response_type(message)
    if response_type == "short":
        max_new_tokens = max(70, max_new_tokens)
    elif response_type == "long":
        max_new_tokens = min(200, max_new_tokens)
    else:
        max_new_tokens = max(100, max_new_tokens)

    # Build the conversation: system prompt enriched with the knowledge base
    # (LUCAS_KNOWLEDGE_BASE is expected to be defined elsewhere in this file),
    # then the user side of the last five turns, then the new message.
    conversation = []
    enhanced_system_prompt = f"{system_prompt}\n\n{LUCAS_KNOWLEDGE_BASE}"
    conversation.append({"role": "system", "content": enhanced_system_prompt})
    for user, _ in chat_history[-5:]:
        conversation.append({"role": "user", "content": user})
    conversation.append({"role": "user", "content": message})

    input_ids = tokenizer.apply_chat_template(conversation, return_tensors="pt")
    if input_ids.shape[1] > MAX_INPUT_TOKEN_LENGTH:
        # Keep only the most recent tokens when the prompt is too long.
        input_ids = input_ids[:, -MAX_INPUT_TOKEN_LENGTH:]
        gr.Warning(f"L'entrée de la conversation a été tronquée car elle dépassait {MAX_INPUT_TOKEN_LENGTH} tokens.")
    input_ids = input_ids.to(model.device)

    # Stream tokens from a background generation thread.
    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
    generate_kwargs = dict(
        input_ids=input_ids,
        streamer=streamer,
        max_new_tokens=max_new_tokens,
        do_sample=True,
        top_p=top_p,
        temperature=temperature,
        num_beams=1,
    )
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()

    outputs = []
    for text in streamer:
        outputs.append(text)
        partial_output = post_process_response("".join(outputs), response_type == "short")
        # Abort long answers that start repeating themselves.
        if response_type == "long" and not check_coherence(partial_output):
            yield "Je m'excuse, ma réponse manquait de cohérence. Pouvez-vous reformuler votre question ?"
            return
        yield partial_output

    yield post_process_response("".join(outputs), response_type == "short")


def vote(data: gr.LikeData, history):
    """Store like/dislike feedback in feedback.json inside the model repository."""
    user_input = history[-1][0] if history else ""
    feedback = {
        "timestamp": datetime.now().isoformat(),
        "user_input": user_input,
        "bot_response": data.value,
        "liked": data.liked,
    }
    api = HfApi()
    token = os.environ.get("HF_TOKEN")
    repo_id = "Woziii/llama-3-8b-chat-me"
    file_name = "feedback.json"
    try:
        # Download the existing feedback file if present, otherwise start from an empty list.
        try:
            file_path = hf_hub_download(repo_id=repo_id, filename=file_name, token=token)
            with open(file_path, "r", encoding="utf-8") as file:
                current_feedback = json.load(file)
            if not isinstance(current_feedback, list):
                current_feedback = []
        except Exception as e:
            print(f"Erreur lors du téléchargement du fichier : {str(e)}")
            current_feedback = []

        current_feedback.append(feedback)
        updated_content = json.dumps(current_feedback, ensure_ascii=False, indent=2)

        # Write the updated feedback to a temporary file and push it back to the Hub.
        temp_file_path = "/tmp/feedback.json"
        with open(temp_file_path, "w", encoding="utf-8") as temp_file:
            temp_file.write(updated_content)
        api.upload_file(path_or_fileobj=temp_file_path, path_in_repo=file_name, repo_id=repo_id, token=token)
        print(f"Feedback enregistré dans {repo_id}/{file_name}")
    except Exception as e:
        print(f"Erreur lors de l'enregistrement du feedback : {str(e)}")


theme = gr.themes.Default().set(
    body_background_fill="#f0f0f0",
    button_primary_background_fill="#4a90e2",
    button_primary_background_fill_hover="#3a7bc8",
    button_primary_text_color="white",
)

css = """
.gradio-container { font-family: 'Arial', sans-serif; }
.chatbot-message { padding: 10px; border-radius: 15px; margin-bottom: 10px; }
.user-message { background-color: #e6f3ff; }
.bot-message { background-color: #f0f0f0; }
.thought-bubble { background-color: #ffd700; border-radius: 10px; padding: 5px; margin-top: 5px; font-style: italic; }
"""


def format_message(message):
    if isinstance(message, str):
        return message
    elif isinstance(message, dict):
        if "thought" in message:
            return f"{message['content']}\n