dgdgdgdgd / app.py
Yhhxhfh's picture
Update app.py
f665e00 verified
import os
import platform
from dotenv import load_dotenv
import torch
from transformers import GPT2LMHeadModel, GPT2Tokenizer, Trainer, TrainingArguments, DataCollatorForLanguageModeling
from datasets import load_dataset, concatenate_datasets
from huggingface_hub import login
import time
import uvicorn
from fastapi import FastAPI
import threading
import logging
import warnings
# Ignorar advertencias espec铆ficas si lo deseas (opcional)
warnings.filterwarnings("ignore", category=FutureWarning)
# Configurar logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(levelname)s - %(message)s',
handlers=[
logging.FileHandler("training.log"),
logging.StreamHandler()
]
)
# Cargar las variables de entorno
load_dotenv()
huggingface_token = os.getenv('HUGGINGFACE_TOKEN')
if huggingface_token is None:
raise ValueError("HUGGINGFACE_TOKEN no encontrado en las variables de entorno.")
# Iniciar sesi贸n en Hugging Face
login(token=huggingface_token)
# Definir la aplicaci贸n FastAPI
app = FastAPI()
@app.get("/")
async def root():
return {"message": "Modelo entrenado y en ejecuci贸n."}
def load_and_train():
model_name = 'gpt2'
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
model = GPT2LMHeadModel.from_pretrained(model_name, return_dict=True)
# Asignar el pad_token al eos_token
tokenizer.pad_token = tokenizer.eos_token
# Redimensionar las embeddings del modelo para incluir el pad_token
model.resize_token_embeddings(len(tokenizer))
# Verificar dispositivo
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
logging.info(f"Entrenando en: {device}")
# Determinar cache_dir
if platform.system() == "Linux":
cache_dir = '/dev/shm'
else:
cache_dir = './cache'
# Crear el directorio de cach茅 si no existe
os.makedirs(cache_dir, exist_ok=True)
# Intentar cargar los datasets con manejo de errores
try:
dataset_humanizado = load_dataset('daily_dialog', split='train', cache_dir=cache_dir, trust_remote_code=True)
dataset_codigo = load_dataset('code_search_net', split='train', cache_dir=cache_dir, trust_remote_code=True)
except Exception as e:
logging.error(f"Error al cargar los datasets: {e}")
# Intentar cargar un dataset alternativo
time.sleep(60) # Esperar 60 segundos antes de reintentar
try:
dataset_humanizado = load_dataset('alternative_dataset', split='train', cache_dir=cache_dir, trust_remote_code=True)
dataset_codigo = load_dataset('alternative_code_dataset', split='train', cache_dir=cache_dir, trust_remote_code=True)
except Exception as e:
logging.error(f"Error al cargar el dataset alternativo: {e}")
return
logging.info("Daily Dialog columnas: %s", dataset_humanizado.column_names)
logging.info("Code Search Net columnas: %s", dataset_codigo.column_names)
# Combinar los datasets en memoria
combined_dataset = concatenate_datasets([dataset_humanizado, dataset_codigo])
logging.info("Dataset combinado columnas: %s", combined_dataset.column_names)
# Funci贸n para crear un campo 'text' estandarizado
def concatenate_text_fields(examples):
"""
Crea un nuevo campo 'text' concatenando los campos de texto disponibles en cada ejemplo.
Prioriza 'dialog', luego 'whole_func_string', y luego 'func_documentation_string'.
Si ninguno est谩 presente, asigna una cadena vac铆a.
Args:
examples (dict): Diccionario con listas de valores para cada columna.
Returns:
dict: Diccionario con el nuevo campo 'text'.
"""
texts = []
# Determinar el tama帽o del lote
num_examples = len(next(iter(examples.values()))) # Obtener el tama帽o del lote
for i in range(num_examples):
text = ''
# Procesar 'dialog'
if 'dialog' in examples and i < len(examples['dialog']) and isinstance(examples['dialog'][i], str) and examples['dialog'][i]:
text = examples['dialog'][i]
# Procesar 'whole_func_string'
elif 'whole_func_string' in examples and i < len(examples['whole_func_string']) and isinstance(examples['whole_func_string'][i], str) and examples['whole_func_string'][i]:
text = examples['whole_func_string'][i]
# Procesar 'func_documentation_string'
elif 'func_documentation_string' in examples and i < len(examples['func_documentation_string']) and isinstance(examples['func_documentation_string'][i], str) and examples['func_documentation_string'][i]:
text = examples['func_documentation_string'][i]
# Puedes a帽adir m谩s campos si es necesario
texts.append(text)
examples['text'] = texts
return examples
# Crear el campo 'text'
combined_dataset = combined_dataset.map(concatenate_text_fields, batched=True)
# Funci贸n de tokenizaci贸n basada en el campo 'text'
def tokenize_function(examples):
tokenized = tokenizer(
examples['text'],
truncation=True,
padding='max_length',
max_length=512
)
tokenized['labels'] = tokenized['input_ids'].copy()
return tokenized
# Tokenizar el dataset
tokenized_dataset = combined_dataset.map(
tokenize_function,
batched=True
)
# Configurar el Data Collator
data_collator = DataCollatorForLanguageModeling(
tokenizer=tokenizer,
mlm=False # Para modelado de lenguaje causal
)
# Configurar argumentos de entrenamiento
training_args = TrainingArguments(
output_dir=os.path.join(cache_dir, 'results'), # Almacenar temporalmente en RAM
per_device_train_batch_size=4,
per_device_eval_batch_size=4,
num_train_epochs=1,
learning_rate=1e-5,
logging_steps=100,
save_total_limit=1,
seed=42,
weight_decay=0.01,
warmup_ratio=0.1,
evaluation_strategy="epoch",
lr_scheduler_type="linear",
save_strategy="epoch", # Guardar solo al final de cada epoch
logging_dir=os.path.join(cache_dir, 'logs'), # Directorio de logs
)
trainer = Trainer(
model=model,
args=training_args,
train_dataset=tokenized_dataset,
data_collator=data_collator,
)
while True:
try:
trainer.train()
# Subir el modelo a Hugging Face desde la RAM
model.push_to_hub(
'Yhhxhfh/nombre_de_tu_modelo',
commit_message="Actualizaci贸n del modelo",
add_to_git_credential=False # Desactivar la configuraci贸n autom谩tica de credenciales de Git
)
tokenizer.push_to_hub(
'Yhhxhfh/nombre_de_tu_modelo',
commit_message="Actualizaci贸n del tokenizador",
add_to_git_credential=False # Desactivar la configuraci贸n autom谩tica de credenciales de Git
)
logging.info("Modelo y tokenizador subidos exitosamente.")
time.sleep(0) # Esperar 0 segundos antes de la siguiente iteraci贸n
except Exception as e:
logging.error(f"Error durante el entrenamiento: {e}. Reiniciando el proceso de entrenamiento...")
time.sleep(0) # Esperar 0 segundos antes de reintentar
if __name__ == "__main__":
# Correr FastAPI en un hilo separado
threading.Thread(target=lambda: uvicorn.run(app, host="0.0.0.0", port=7860), daemon=True).start()
load_and_train()