|
import os |
|
import base64 |
|
import io |
|
import csv |
|
import json |
|
import datetime |
|
import pandas as pd |
|
from datetime import timedelta |
|
from fastapi import FastAPI,WebSocket, Depends, HTTPException, status, UploadFile, File |
|
from fastapi.responses import FileResponse, JSONResponse |
|
from fastapi.requests import Request |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from fastapi.security import OAuth2PasswordRequestForm |
|
from .auth import ( |
|
get_current_user, |
|
create_access_token, |
|
verify_password, |
|
get_password_hash, |
|
) |
|
from .db.models import User, Token, FileUpload |
|
from .db.database import get_user_by_username, create_user, save_file, get_user_files |
|
from .websocket import handle_websocket |
|
from .llm_models import invoke_general_model, invoke_customer_search |
|
|
|
# Application instance that every route decorator in this module attaches to.
app = FastAPI()

# CORS is fully open: any origin, method and header, with credentials allowed.
# NOTE(review): wildcard origins combined with credentials is very permissive;
# confirm this is intended outside local development.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)

# Directory that contains this module (symlinks resolved).
current_dir = os.path.dirname(os.path.realpath(__file__))

# Built frontend bundle, located two directory levels above this module.
frontend_dir = os.path.join(
    current_dir, os.pardir, os.pardir, "frontend", "dist"
)
|
|
|
@app.get("/")
async def serve_react_root() -> FileResponse:
    """Serve ``index.html`` from ``frontend_dir`` for the site root."""
    index_path = os.path.join(frontend_dir, "index.html")
    return FileResponse(index_path)
|
|
|
@app.post("/register", response_model=Token)
async def register(user: User) -> Token:
    """Create a new user account and return a bearer token for it.

    Args:
        user: Submitted account data; ``username`` and ``password`` are read,
            and ``password`` is overwritten with its hash before saving.

    Returns:
        A ``Token`` whose access token expires after 30 minutes.

    Raises:
        HTTPException: 400 if the username already exists, 500 if
            ``create_user`` reports failure.
    """
    existing_user = await get_user_by_username(user.username)
    if existing_user:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Username already registered",
        )

    # Store the hash in place of the submitted password before persisting.
    user.password = get_password_hash(user.password)

    created = await create_user(user)
    if not created:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Could not create user",
        )

    token_value = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=30),
    )
    return Token(access_token=token_value, token_type="bearer")
|
|
|
@app.post("/token", response_model=Token)
async def login(form_data: OAuth2PasswordRequestForm = Depends()) -> Token:
    """Exchange a username/password form for a bearer token.

    Args:
        form_data: OAuth2 password form carrying ``username`` and ``password``.

    Returns:
        A ``Token`` whose access token expires after 30 minutes.

    Raises:
        HTTPException: 401 when the user is not found or the password does
            not verify against the stored value.
    """
    user = await get_user_by_username(form_data.username)

    # The password is only checked when a user record was actually found.
    authenticated = user and verify_password(form_data.password, user.password)
    if not authenticated:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect username or password",
            headers={"WWW-Authenticate": "Bearer"},
        )

    token_value = create_access_token(
        data={"sub": user.username},
        expires_delta=timedelta(minutes=30),
    )
    return Token(access_token=token_value, token_type="bearer")
|
|
|
@app.post("/upload")
async def upload_file(
    file: UploadFile = File(...),
    current_user: User = Depends(get_current_user),
) -> dict:
    """Parse an uploaded CSV file and store its rows for the current user.

    Args:
        file: The uploaded file; its bytes are decoded as UTF-8 and parsed
            as CSV.
        current_user: Authenticated user; ``username`` keys the stored file.

    Returns:
        ``{"message": "File uploaded successfully"}`` on success.

    Raises:
        HTTPException: 400 when the upload is not valid UTF-8 or cannot be
            parsed as CSV, 500 when ``save_file`` reports failure.
    """
    contents = await file.read()

    # Bad encoding or malformed/empty CSV is a client error, not a server
    # fault, so report it as 400 instead of letting it surface as a 500.
    try:
        df = pd.read_csv(io.StringIO(contents.decode("utf-8")))
    except (
        UnicodeDecodeError,
        pd.errors.EmptyDataError,
        pd.errors.ParserError,
    ) as exc:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded file is not a valid UTF-8 CSV",
        ) from exc

    # Round-trip through pandas' JSON encoder so each row becomes a plain
    # dict of JSON-native values.
    records = json.loads(df.to_json(orient="records"))

    saved = await save_file(current_user.username, records, file.filename)
    if not saved:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="Could not save file",
        )

    return {"message": "File uploaded successfully"}
|
|
|
@app.get("/api/opportunities")
async def get_opportunities(request: Request, current_user: User = Depends(get_current_user)) -> dict:
    """Return every stored record from all files uploaded by the current user.

    Args:
        request: Incoming request (not used by the handler body).
        current_user: Authenticated user whose files are looked up.

    Returns:
        ``{"records": [...], "success": bool}`` where ``success`` is true
        only when at least one record was found.
    """
    uploads = await get_user_files(current_user.username)

    # Flatten the per-file ``content`` lists into one list of records.
    # The stored records are deliberately not printed or logged here.
    all_records = []
    for upload in uploads:
        all_records.extend(upload.content)

    return {"records": all_records, "success": bool(all_records)}
|
|
|
@app.websocket("/ws")
async def websocket_endpoint(websocket: WebSocket) -> None:
    """Hand the incoming WebSocket connection to ``handle_websocket``."""
    await handle_websocket(websocket)
|
|
|
@app.post("/api/message")
async def message(obj: dict, current_user: User = Depends(get_current_user)) -> JSONResponse:
    """Endpoint to handle general incoming messages from the frontend.

    Args:
        obj: JSON request body; must contain a ``"message"`` entry.
        current_user: Authenticated user (required, but not otherwise used).

    Returns:
        A JSON response ``{"message": <model answer serialized as JSON>}``.

    Raises:
        HTTPException: 400 when the body has no ``"message"`` entry.
    """
    # A missing key is a client error; report it instead of a 500 KeyError.
    try:
        prompt = obj["message"]
    except KeyError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must contain a 'message' field",
        ) from None

    # NOTE(review): this call is synchronous inside an async handler; if it
    # performs network I/O it blocks the event loop while it runs.
    answer = invoke_general_model(prompt)
    return JSONResponse(content={"message": answer.model_dump_json()})
|
|
|
|
|
@app.post("/api/customer_insights")
async def customer_insights(obj: dict) -> JSONResponse:
    """Endpoint to launch a customer insight search.

    Args:
        obj: JSON request body; must contain a ``"message"`` entry.

    Returns:
        A JSON response ``{"AIMessage": <search answer serialized as JSON>}``.

    Raises:
        HTTPException: 400 when the body has no ``"message"`` entry.
    """
    # NOTE(review): unlike /api/message, this route declares no
    # get_current_user dependency; confirm it is meant to be public.

    # A missing key is a client error; report it instead of a 500 KeyError.
    try:
        query = obj["message"]
    except KeyError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Request body must contain a 'message' field",
        ) from None

    answer = invoke_customer_search(query)
    return JSONResponse(content={"AIMessage": answer.model_dump_json()})
|
|
|
# Serve the built frontend's static assets under /assets.
app.mount(
    "/assets",
    StaticFiles(directory=os.path.join(frontend_dir, "assets")),
    name="static",
)
|
|
|
if __name__ == "__main__":
    # Ad-hoc smoke test, executed only when this module is run directly.
    from fastapi.testclient import TestClient

    client = TestClient(app)

    def test_message_endpoint():
        """Post a question to /api/message and check the response shape."""
        # The route depends on get_current_user and this request carries no
        # credentials, so bypass authentication for the duration of the call.
        # The handler never reads the user object, so None is sufficient.
        app.dependency_overrides[get_current_user] = lambda: None
        try:
            response = client.post(
                "/api/message", json={"message": "What is MEDDPICC?"}
            )
        finally:
            app.dependency_overrides.pop(get_current_user, None)

        print(response.json())
        assert response.status_code == 200
        # /api/message answers under the "message" key; "AIMessage" is the
        # key used by /api/customer_insights.
        assert "message" in response.json()

    test_message_endpoint()