File size: 1,356 Bytes
38df1da
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
import numpy as np
# Fix the seed of NumPy's global random state so that any later np.random
# call in this process is repeatable.
# NOTE(review): no code visible in this file draws random numbers, so this
# may be a leftover from the training script — confirm before removing.
np.random.seed(0)
import pickle
from sklearn.compose import ColumnTransformer
from sklearn.datasets import fetch_openml
from sklearn.pipeline import Pipeline
from sklearn.impute import SimpleImputer
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

from sklearn import tree

from fastapi import FastAPI, HTTPException
from fastapi.responses import HTMLResponse
from pydantic import BaseModel
from typing import List

# Request body schema for the /predict/ endpoint.
class InputData(BaseModel):
    # Flat list of numeric feature values for a single sample; the handler
    # reshapes it into one row before passing it to the model.
    data: List[float]
    
# Initialize the FastAPI application
app = FastAPI()

def build_model(path='miarbol.pkl'):
    """Load the pickled model from disk and return it.

    Args:
        path: Location of the pickle file. Defaults to ``'miarbol.pkl'``,
            which is resolved against the current working directory (the
            same file the function always loaded before the parameter
            existed).

    Returns:
        Whatever object was pickled into the file.

    Raises:
        FileNotFoundError: If ``path`` does not exist.
        pickle.UnpicklingError: If the file content is not a valid pickle
            (other exceptions are possible depending on the payload).
    """
    # SECURITY: unpickling can execute arbitrary code embedded in the file,
    # so ``path`` must only ever point at a trusted artifact.
    with open(path, 'rb') as fid:
        return pickle.load(fid)
   
# Load the model once at import time so every request reuses the same object.
# If the pickle file cannot be opened, the import of this module fails here.
miarbol = build_model()

# Prediction route
@app.post("/predict/")
async def predict(data: InputData):
    """Run the loaded model on one sample and return the rounded prediction.

    Args:
        data: Request body whose ``data`` field holds the feature values of
            a single sample.

    Returns:
        ``{"prediction": [...]}`` where the list is the model output after
        ``round()``, converted with ``tolist()``.

    Raises:
        HTTPException: 400 when no feature values were supplied; 500 (with
            the underlying error message as detail) when building the input
            array or running the model fails.
    """
    print(f"Data: {data}")

    # Reject an empty feature list before touching the model. This check
    # lives outside the ``try`` below on purpose: HTTPException is itself an
    # Exception, so raising it inside would be rewritten into a 500.
    if not data.data:
        raise HTTPException(
            status_code=400,
            detail="'data' must contain at least one feature value",
        )

    # NOTE(review): ``miarbol.predict`` runs synchronously inside an async
    # handler; if it is slow, a plain ``def`` route may be preferable.
    try:
        # The model is given a single row: shape (1, num_features).
        input_data = np.array(data.data).reshape(1, -1)
        # ``miarbol`` is only read here, so no ``global`` statement is needed.
        prediction = miarbol.predict(input_data).round()
        result = prediction.tolist()
    except Exception as e:
        # Keep the original cause attached for server-side tracebacks.
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"prediction": result}