xor / app.py
RaulHuarote's picture
Update app.py
0e923cd verified
from keras.api.models import Sequential
from keras.api.layers import InputLayer, Dense
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import numpy as np
from typing import List
class InputData(BaseModel):
data: List[float] # Lista de caracter铆sticas num茅ricas (flotantes)
app = FastAPI()
# Funci贸n para construir el modelo manualmente
def build_model():
model = Sequential(
[
InputLayer(
input_shape=(2,), name="dense_2_input"
), # Ajusta el tama帽o de entrada seg煤n tu modelo
Dense(16, activation="relu", name="dense_2"),
Dense(1, activation="sigmoid", name="dense_3"),
]
)
model.load_weights(
"model.h5"
) # Aseg煤rate de que los nombres de las capas coincidan para que los pesos se carguen correctamente
model.compile(
loss="mean_squared_error", optimizer="adam", metrics=["binary_accuracy"]
)
return model
model = build_model() # Construir el modelo al iniciar la aplicaci贸n
# Ruta de predicci贸n
@app.post("/predict/")
async def predict(data: InputData):
print(f"Data: {data}")
global model
try:
# Convertir la lista de entrada a un array de NumPy para la predicci贸n
input_data = np.array(data.data).reshape(
1, -1
) # Asumiendo que la entrada debe ser de forma (1, num_features)
prediction = model.predict(input_data).round()
return {"prediction": prediction.tolist()}
except Exception as e:
raise HTTPException(status_code=500, detail=str(e))