PROYECTO_2024 / app.py
C2MV's picture
Update app.py
6220229 verified
# app.py
import gradio as gr
import os
from models import load_embedding_model, load_yi_coder_model
from pinecone_utils import connect_to_pinecone, vector_search # Ahora debería funcionar correctamente
from ui import build_interface
from config import SIMILARITY_THRESHOLD_DEFAULT, SYSTEM_PROMPT, MAX_LENGTH_DEFAULT
from decorators import gpu_decorator
import torch
########################
from utils import process_tags_chat
########################
# Cargar modelos
embedding_model = load_embedding_model()
tokenizer, yi_coder_model, yi_coder_device = load_yi_coder_model()
# Conectar a Pinecone
index = connect_to_pinecone()
# Función para generar código utilizando Yi-Coder
@gpu_decorator(duration=100)
def generate_code(system_prompt, user_prompt, max_length):
device = yi_coder_device
model = yi_coder_model
tokenizer_ = tokenizer # Ya lo tenemos cargado
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}
]
# Aplicar la plantilla de chat y preparar el texto
text = tokenizer_.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
model_inputs = tokenizer_([text], return_tensors="pt").to(device)
with torch.no_grad():
generated_ids = model.generate(
model_inputs.input_ids,
max_new_tokens=max_length,
eos_token_id=tokenizer_.eos_token_id
)
# Extraer solo la parte generada
generated_ids = [
output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
]
response = tokenizer_.batch_decode(generated_ids, skip_special_tokens=True)[0]
return response
# Función para combinar búsqueda vectorial y Yi-Coder
@gpu_decorator(duration=100)
def combined_function(user_prompt, similarity_threshold, selected_option, system_prompt, max_length):
def get_partial_message(response):
"""Obtiene el contenido después de 'Respuesta:' si está presente en la respuesta."""
if "Respuesta:" in response:
return response.split("Respuesta:")[1].strip() # Tomar solo el texto después de 'Respuesta:'
else:
return response # Devolver la respuesta completa si no contiene 'Respuesta:'
if selected_option == "Solo Búsqueda Vectorial":
# Realizar búsqueda vectorial
search_results = vector_search(user_prompt, embedding_model, index)
if search_results:
content = search_results[0]['content']
partial_message = get_partial_message(content)
return partial_message, None
else:
return "No se encontraron resultados en Pinecone.", None
elif selected_option == "Solo Yi-Coder":
# Generar respuesta usando Yi-Coder
yi_coder_response = generate_code(system_prompt, user_prompt, max_length)
partial_message = get_partial_message(yi_coder_response)
return partial_message, None
elif selected_option == "Ambos (basado en umbral de similitud)":
# Realizar búsqueda vectorial
search_results = vector_search(user_prompt, embedding_model, index)
if search_results:
top_result = search_results[0]
if top_result['score'] >= similarity_threshold:
content = top_result['content']
partial_message = get_partial_message(content)
return partial_message, None
else:
yi_coder_response = generate_code(system_prompt, user_prompt, max_length)
partial_message = get_partial_message(yi_coder_response)
return partial_message, None
else:
yi_coder_response = generate_code(system_prompt, user_prompt, max_length)
partial_message = get_partial_message(yi_coder_response)
return partial_message, None
else:
return "Opción no válida.", None
# Funciones para el procesamiento de entradas y actualización de imágenes
def process_input(message, history, selected_option, similarity_threshold, system_prompt, max_length):
response, image = combined_function(message, similarity_threshold, selected_option, system_prompt, max_length)
history.append((message, response))
return history, history, image
def update_image(image_url):
"""
Retorna los datos binarios de la imagen para ser mostrados en Gradio.
Args:
image_url (str): Ruta de la imagen.
Returns:
bytes o None: Datos binarios de la imagen si existe, de lo contrario None.
"""
if image_url and os.path.exists(image_url):
try:
with open(image_url, "rb") as img_file:
image_data = img_file.read()
return image_data
except Exception as e:
print(f"Error al leer la imagen: {e}")
return None
else:
print("No se encontró una imagen válida.")
return None
def send_preset_question(question, history, selected_option, similarity_threshold, system_prompt, max_length):
return process_input(question, history, selected_option, similarity_threshold, system_prompt, max_length)
# Construir y lanzar la interfaz
demo = build_interface(process_input, send_preset_question, update_image)
if __name__ == "__main__":
demo.launch()