# app.py
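"""Gradio app that answers user prompts via Pinecone vector search, the
Yi-Coder model, or both, gated by a similarity threshold."""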
import gradio as gr
import os
import torch

from models import load_embedding_model, load_yi_coder_model
from pinecone_utils import connect_to_pinecone, vector_search
from ui import build_interface
from config import SIMILARITY_THRESHOLD_DEFAULT, SYSTEM_PROMPT, MAX_LENGTH_DEFAULT
from decorators import gpu_decorator
from utils import process_tags_chat
# Load models
embedding_model = load_embedding_model()
tokenizer, yi_coder_model, yi_coder_device = load_yi_coder_model()

# Connect to Pinecone
index = connect_to_pinecone()
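# Note: vector_search is expected to return a list of dicts with 'content'
# and 'score' keys; combined_function below relies on that shape.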
# Generate code using Yi-Coder
def generate_code(system_prompt, user_prompt, max_length):
    """Run the Yi-Coder chat model and return only the newly generated text."""
    device = yi_coder_device
    model = yi_coder_model
    tokenizer_ = tokenizer  # already loaded above

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt}
    ]

    # Apply the chat template and prepare the prompt text
    text = tokenizer_.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    model_inputs = tokenizer_([text], return_tensors="pt").to(device)

    with torch.no_grad():
        generated_ids = model.generate(
            model_inputs.input_ids,
            max_new_tokens=max_length,
            eos_token_id=tokenizer_.eos_token_id
        )

    # Keep only the generated part (strip the prompt tokens)
    generated_ids = [
        output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
    ]
    response = tokenizer_.batch_decode(generated_ids, skip_special_tokens=True)[0]
    return response
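# Example (illustrative prompt; SYSTEM_PROMPT and MAX_LENGTH_DEFAULT come from config):
#   generate_code(SYSTEM_PROMPT, "Write a function that reverses a string.", MAX_LENGTH_DEFAULT)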
# Combine vector search and Yi-Coder generation
def combined_function(user_prompt, similarity_threshold, selected_option, system_prompt, max_length):
    def get_partial_message(response):
        """Return the text after 'Respuesta:' if present, otherwise the full response."""
        if "Respuesta:" in response:
            return response.split("Respuesta:")[1].strip()
        return response

    if selected_option == "Solo Búsqueda Vectorial":  # vector search only
        search_results = vector_search(user_prompt, embedding_model, index)
        if search_results:
            return get_partial_message(search_results[0]['content']), None
        return "No se encontraron resultados en Pinecone.", None  # no results found in Pinecone

    elif selected_option == "Solo Yi-Coder":  # Yi-Coder only
        yi_coder_response = generate_code(system_prompt, user_prompt, max_length)
        return get_partial_message(yi_coder_response), None

    elif selected_option == "Ambos (basado en umbral de similitud)":  # both, gated by the similarity threshold
        search_results = vector_search(user_prompt, embedding_model, index)
        if search_results and search_results[0]['score'] >= similarity_threshold:
            # The top match is similar enough: answer from the index
            return get_partial_message(search_results[0]['content']), None
        # No results, or the score is below the threshold: fall back to Yi-Coder
        yi_coder_response = generate_code(system_prompt, user_prompt, max_length)
        return get_partial_message(yi_coder_response), None

    return "Opción no válida.", None  # invalid option
# Callbacks for input processing and image updates
def process_input(message, history, selected_option, similarity_threshold, system_prompt, max_length):
    response, image = combined_function(message, similarity_threshold, selected_option, system_prompt, max_length)
    history.append((message, response))
    return history, history, image
def update_image(image_url):
    """
    Return binary image data to display in Gradio.

    Args:
        image_url (str): Path to the image file.

    Returns:
        bytes or None: Binary image data if the file exists, otherwise None.
    """
    if image_url and os.path.exists(image_url):
        try:
            with open(image_url, "rb") as img_file:
                return img_file.read()
        except Exception as e:
            print(f"Error reading image: {e}")
            return None
    else:
        print("No valid image found.")
        return None
def send_preset_question(question, history, selected_option, similarity_threshold, system_prompt, max_length):
    return process_input(question, history, selected_option, similarity_threshold, system_prompt, max_length)
# Build and launch the interface
demo = build_interface(process_input, send_preset_question, update_image)

if __name__ == "__main__":
    demo.launch()
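# To run locally: `python app.py`; on Hugging Face Spaces, app.py is picked up automatically.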