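# FastAPI chatbot: serves a minimal HTML chat page and streams text generated
# on CPU by a Hugging Face causal language model.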
from fastapi import FastAPI, HTTPException, Request
from pydantic import BaseModel, Field
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from typing import Optional, List
import asyncio
from fastapi.responses import StreamingResponse, HTMLResponse, JSONResponse
import uvicorn
import psutil
app = FastAPI()
dispositivo = torch.device("cpu")
# Reject new requests while host CPU or RAM usage is above these percentages.
CPU_LIMIT = 30.0
RAM_LIMIT = 30.0
html_code = """
<!DOCTYPE html>
<html>
<head>
    <title>Chatbot</title>
    <style>
        body { font-family: Arial, sans-serif; margin: 50px; }
        #chat { border: 1px solid #ccc; padding: 10px; height: 300px; overflow-y: scroll; }
        #input { width: 80%; padding: 10px; }
        #send { padding: 10px; }
    </style>
</head>
<body>
    <h1>Chatbot</h1>
    <div id="chat"></div>
    <input type="text" id="input" placeholder="Escribe tu mensaje...">
    <button id="send">Enviar</button>
    <script>
        const sendButton = document.getElementById('send');
        const inputBox = document.getElementById('input');
        const chatBox = document.getElementById('chat');
        let history = [];

        sendButton.addEventListener('click', () => {
            const message = inputBox.value;
            if (message.trim() === '') return;
            history.push(`Tú: ${message}`);
            chatBox.innerHTML += `<div><strong>Tú:</strong> ${message}</div>`;
            inputBox.value = '';
            fetch('/generar', {
                method: 'POST',
                headers: { 'Content-Type': 'application/json' },
                body: JSON.stringify({ texto: message, history: history })
            })
            .then(response => {
                if (!response.ok) {
                    throw new Error(`HTTP ${response.status}`);
                }
                if (!response.body) {
                    throw new Error('No soporta streaming');
                }
                const reader = response.body.getReader();
                const decoder = new TextDecoder();
                let botMessage = '';

                // Create a single bot message element and update it as chunks
                // arrive, instead of appending a new <div> per chunk.
                const botDiv = document.createElement('div');
                botDiv.innerHTML = '<strong>Bot:</strong> ';
                const botText = document.createElement('span');
                botDiv.appendChild(botText);
                chatBox.appendChild(botDiv);

                function read() {
                    reader.read().then(({ done, value }) => {
                        if (done) {
                            history.push(`Bot: ${botMessage}`);
                            chatBox.scrollTop = chatBox.scrollHeight;
                            return;
                        }
                        botMessage += decoder.decode(value, { stream: true });
                        botText.textContent = botMessage;
                        chatBox.scrollTop = chatBox.scrollHeight;
                        read();
                    }).catch(error => {
                        botText.textContent = `Error: ${error}`;
                        chatBox.scrollTop = chatBox.scrollHeight;
                    });
                }
                read();
            })
            .catch(error => {
                chatBox.innerHTML += `<div><strong>Bot:</strong> Error: ${error}</div>`;
                chatBox.scrollTop = chatBox.scrollHeight;
            });
        });
    </script>
</body>
</html>
"""
class Entrada(BaseModel):
    texto: str = Field(..., example="Hola, ¿cómo estás?")
    history: Optional[List[str]] = Field(default_factory=list)
    top_p: Optional[float] = Field(0.95, ge=0.0, le=1.0)
    top_k: Optional[int] = Field(50, ge=0)
    temperature: Optional[float] = Field(1.0, gt=0.0)
    max_length: Optional[int] = Field(100, ge=10, le=1000)
    chunk_size: Optional[int] = Field(10, ge=1)
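# Middleware: refuse new requests while the host is above the CPU/RAM limits.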
@app.middleware("http")
async def limitar_recursos(request: Request, call_next):
    cpu = psutil.cpu_percent(interval=0.1)
    ram = psutil.virtual_memory().percent
    if cpu > CPU_LIMIT or ram > RAM_LIMIT:
        # An HTTPException raised inside middleware is not converted by FastAPI's
        # exception handlers, so return the 503 response directly.
        return JSONResponse(
            status_code=503,
            content={"detail": "Servidor sobrecargado. Intenta de nuevo más tarde."},
        )
    response = await call_next(request)
    return response
@app.on_event("startup")
def cargar_modelo():
    global tokenizador, modelo, eos_token, pad_token
    tokenizador = AutoTokenizer.from_pretrained("Yhhxhfh/dgdggd")
    modelo = AutoModelForCausalLM.from_pretrained(
        "Yhhxhfh/dgdggd",
        torch_dtype=torch.float32,
        device_map="cpu",
    )
    modelo.eval()
    # Some tokenizers ship without a pad token; reuse the EOS token so that
    # generate() does not warn or fail when padding is needed.
    if tokenizador.pad_token is None:
        tokenizador.pad_token = tokenizador.eos_token
    eos_token = tokenizador.eos_token
    pad_token = tokenizador.pad_token
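# Generate the full completion first (generate() is not incremental here), then
# yield it back to the client in chunk_size-character pieces.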
async def generar_stream(prompt, top_p, top_k, temperature, max_length, chunk_size):
    input_ids = tokenizador.encode(prompt, return_tensors="pt").to(dispositivo)
    with torch.no_grad():
        outputs = modelo.generate(
            input_ids,
            max_length=input_ids.shape[1] + max_length,
            do_sample=True,
            top_p=top_p,
            top_k=top_k,
            temperature=temperature,
            no_repeat_ngram_size=2,
            pad_token_id=tokenizador.pad_token_id,
            eos_token_id=tokenizador.eos_token_id,
        )
    # Drop the prompt tokens and decode only the newly generated continuation.
    generated_ids = outputs[0][input_ids.shape[1]:]
    generated_text = tokenizador.decode(generated_ids, skip_special_tokens=True)
    for i in range(0, len(generated_text), chunk_size):
        yield generated_text[i:i + chunk_size]
        # Yield control to the event loop so each chunk is flushed to the client.
        await asyncio.sleep(0)
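# POST /generar: build a chat-style prompt from the history plus the new
# message and stream the model's reply back as plain text.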
@app.post("/generar")
async def generar_texto(entrada: Entrada):
    try:
        prompt = "\n".join(entrada.history + [f"Tú: {entrada.texto}", "Bot:"])

        async def stream():
            async for chunk in generar_stream(
                prompt,
                entrada.top_p,
                entrada.top_k,
                entrada.temperature,
                entrada.max_length,
                entrada.chunk_size,
            ):
                yield chunk

        return StreamingResponse(stream(), media_type="text/plain")
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
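# GET /: serve the embedded chat page.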
@app.get("/", response_class=HTMLResponse)
async def get_home():
    return html_code
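# 7860 is the default port exposed by Hugging Face Spaces.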
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=7860)
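# Example request against the streaming endpoint (a sketch, assuming the server
# is running locally on port 7860 as configured above; -N disables curl's
# output buffering so chunks appear as they arrive):
#
#   curl -N -X POST http://localhost:7860/generar \
#        -H 'Content-Type: application/json' \
#        -d '{"texto": "Hola, ¿cómo estás?", "max_length": 50}'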