"""FastAPI service exposing per-device LightGBM energy-consumption models.

Request bodies are JSON; the data endpoints expect a list of readings matching
``DataPoint`` (keys are matched case-insensitively). The service offers
recursive forecasts, anomaly flags, engineered-feature dumps and zipped PNG
plots.
"""
import io
import json
import warnings
import zipfile
from datetime import datetime

import lightgbm as lgb
import matplotlib.pyplot as plt
import pandas as pd
from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, ValidationError, validator

from data import get_all_features

# Silences every warning process-wide.
# NOTE(review): consider narrowing this to the specific noisy categories.
warnings.filterwarnings("ignore")

# Multiplier applied to every raw model output.
# NOTE(review): the origin of this factor is not visible here -- confirm.
PREDICTION_SCALE = 0.8

# A reading is flagged as anomalous when its absolute deviation from the
# scaled prediction exceeds this value (presumably kWh, as the plots label
# values in kWh -- confirm).
ANOMALY_THRESHOLD = 3

# Forecast step is one hour, so horizons are expressed in hours.
HOURS_PER_DAY = 24


class DataPoint(BaseModel):
    """One input row: a timestamp plus one float reading per device and a total."""

    timestamp: datetime
    storage_charge: float
    heat_pump: float
    circulation_pump: float
    air_conditioning: float
    ventilation: float
    dishwasher: float
    washing_machine: float
    refrigerator: float
    freezer: float
    cooling_aggregate: float
    facility: float
    total: float

    @validator("timestamp", pre=True)
    def parse_timestamp(cls, value):
        """Parse string timestamps with ``datetime.fromisoformat``; pass others through."""
        if isinstance(value, str):
            return datetime.fromisoformat(value)
        return value


app = FastAPI()

devices = [
    "storage_charge",
    "heat_pump",
    "circulation_pump",
    "air_conditioning",
    "ventilation",
    "dishwasher",
    "washing_machine",
    "refrigerator",
    "freezer",
    "cooling_aggregate",
    "facility",
]

# One LightGBM booster per device, loaded once at import time.
models = {
    device: lgb.Booster(model_file=f"models/model_{device}.txt")
    for device in devices
}


def lowercase_keys_and_copy_values(list_of_dicts):
    """Return new dicts with every key lower-cased; values are shallow-copied references."""
    return [{key.lower(): value for key, value in d.items()} for d in list_of_dicts]


@app.get("/")
def greet_json():
    return {"Hello": str(models)}


async def get_data(request: Request):
    """Parse and validate the request body into a DataFrame.

    The body must be a non-empty JSON list of objects. Keys are lower-cased,
    then each object is validated against ``DataPoint``.

    Returns:
        ``(df, predictions)``: ``df`` has one row per input object, and
        ``predictions`` maps every device name to an empty list.

    Raises:
        HTTPException: 422 if the body is not valid JSON, is not a non-empty
            list of objects, or fails ``DataPoint`` validation.
    """
    try:
        raw = await request.json()
    except ValueError as e:  # json.JSONDecodeError subclasses ValueError
        raise HTTPException(
            status_code=422, detail=f"Request body is not valid JSON: {e}"
        ) from e

    if not isinstance(raw, list) or not raw or not all(isinstance(item, dict) for item in raw):
        raise HTTPException(
            status_code=422,
            detail="Request body must be a non-empty JSON list of objects.",
        )

    rows = lowercase_keys_and_copy_values(raw)
    try:
        data_points = [DataPoint(**item) for item in rows]
    except ValidationError as e:
        raise HTTPException(status_code=422, detail=str(e)) from e

    df = pd.DataFrame([point.dict() for point in data_points])
    predictions = {device: [] for device in devices}
    return df, predictions


def _render_plot(x, y, name):
    """Return PNG bytes of one line chart of *y* against *x*; always closes the figure."""
    fig = plt.figure()
    try:
        plt.plot(x, y)
        plt.xticks(rotation=60)
        plt.xlabel("Hour")
        plt.ylabel("kWh")
        plt.title(f"Energy consumption of {name}")
        buf = io.BytesIO()
        fig.savefig(buf, format="png", bbox_inches="tight")
    finally:
        # Without this every request leaks figures held by pyplot's registry.
        plt.close(fig)
    return buf.getvalue()


async def get_plots(request: Request, mode):
    """Render one PNG chart per device plus ``total`` and return them zipped.

    Args:
        request: body must be JSON that ``pd.DataFrame`` accepts, with a
            column for every entry in ``devices`` and a ``total`` column.
        mode: ``1`` selects matplotlib's ``dark_background`` style; any other
            value selects ``default``.

    Returns:
        A ``StreamingResponse`` for ``plots.zip``, containing ``<column>.png``
        for each plotted column.
    """
    df = pd.DataFrame(await request.json())
    style = "dark_background" if mode == 1 else "default"
    hours = list(range(1, len(df) + 1))

    zip_buf = io.BytesIO()
    # style.context restores the previous rcParams afterwards instead of
    # leaving the style applied globally for later requests.
    with plt.style.context(style), zipfile.ZipFile(zip_buf, "w", zipfile.ZIP_DEFLATED) as archive:
        for column in devices + ["total"]:
            archive.writestr(f"{column}.png", _render_plot(hours, df[column], column))
    zip_buf.seek(0)
    return StreamingResponse(
        zip_buf,
        media_type="application/zip",
        headers={"Content-Disposition": "attachment; filename=plots.zip"},
    )


async def get_prediction(request, H):
    """Recursively forecast ``H`` one-hour steps for every device.

    Each step rebuilds features from the history, including rows predicted in
    earlier steps. It then predicts from the last feature row of each device
    (scaled by ``PREDICTION_SCALE``) and appends the predictions as a new row
    one hour after the last timestamp.

    Returns:
        ``{"dataframe": <JSON>}``. The JSON uses the ``DataFrame.to_json``
        default column orientation and has one column per device plus
        ``total``, the sum of the device predictions at each step.
    """
    df, predictions = await get_data(request)
    predictions["total"] = []
    one_hour = pd.to_timedelta(1, unit="h")

    for _ in range(H):
        features = get_all_features(df, devices)
        new_row = {}
        step_total = 0
        for device in devices:
            pred = models[device].predict(features[device].iloc[-1]) * PREDICTION_SCALE
            predictions[device].append(pred[0])
            step_total += pred[0]
            new_row[device] = pred
        predictions["total"].append(step_total)
        new_row["timestamp"] = df.iloc[-1]["timestamp"] + one_hour
        # NOTE(review): the appended row carries no "total" value (NaN);
        # confirm get_all_features does not depend on that column.
        df = pd.concat([df, pd.DataFrame(new_row)], ignore_index=True)

    return {"dataframe": pd.DataFrame(predictions).to_json()}


async def get_anomalies(request):
    """Flag readings whose deviation from the scaled prediction exceeds ``ANOMALY_THRESHOLD``.

    Adds one boolean ``is_anomaly_<device>`` column per device.

    Returns:
        ``{"dataframe": <JSON records>}``.
    """
    df, _ = await get_data(request)
    features = get_all_features(df, devices)
    for device in devices:
        expected = models[device].predict(features[device].iloc[-1]) * PREDICTION_SCALE
        # NOTE(review): ``expected`` comes from the last feature row only, so
        # every row of df[device] is compared against that single value.
        # Verify this is intended rather than a per-row comparison.
        deviation = (df[device] - expected).abs()
        df[f"is_anomaly_{device}"] = deviation > ANOMALY_THRESHOLD
    return {"dataframe": df.to_json(orient="records")}


async def get_statistics(request):
    """Return the engineered features per device, JSON-encoded twice (kept for compatibility)."""
    df, _ = await get_data(request)
    features = get_all_features(df, devices)
    per_device = {name: frame.to_json() for name, frame in features.items()}
    return {"dataframe": json.dumps(per_device, indent=4)}


async def _run_or_500(coro):
    """Await *coro*; convert unexpected exceptions into HTTP 500.

    ``HTTPException``s raised on purpose (e.g. 422 for bad input) pass
    through unchanged.
    """
    try:
        return await coro
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e


@app.post("/anomalies")
async def anomalies(request: Request):
    return await _run_or_500(get_anomalies(request))


# "/statistcks" (typo) is kept so existing clients keep working.
@app.post("/statistics")
@app.post("/statistcks")
async def statistcks(request: Request):
    return await _run_or_500(get_statistics(request))


@app.post("/predict/day")
async def predict_day(request: Request):
    return await _run_or_500(get_prediction(request, HOURS_PER_DAY))


@app.post("/predict/three_day")
async def predict_three_day(request: Request):
    return await _run_or_500(get_prediction(request, 3 * HOURS_PER_DAY))


@app.post("/predict/week")
async def predict_week(request: Request):
    return await _run_or_500(get_prediction(request, 7 * HOURS_PER_DAY))


@app.post("/plots/dark")
async def plots_dark(request: Request):
    return await _run_or_500(get_plots(request, 1))


@app.post("/plots/light")
async def plots_light(request: Request):
    return await _run_or_500(get_plots(request, 0))