import os

import torch
import torch.nn.functional as F
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPBasic, HTTPBasicCredentials
from pydantic import BaseModel
from typing_extensions import Annotated

from Sentiment import app
from .model import model, tokenizer

app_security = HTTPBasic()

api_key = os.getenv("API_KEY")
if not api_key:
    raise ValueError("API_KEY is missing.")

# Index-to-label mapping for the classifier's output logits.
i2l = {0: "positive", 1: "neutral", 2: "negative"}


class AnalyzeRequest(BaseModel):
    text: str


class AnalyzeResponse(BaseModel):
    text: str
    label: str
    score: float


@app.get("/checkhealth", tags=["CheckHealth"])
def checkhealth():
    return "Sentiment API is running."


@app.post("/predict", tags=["Analyze"], summary="Analyze text from prompt", response_model=AnalyzeResponse)
def predict(creds: Annotated[HTTPBasicCredentials, Depends(app_security)], data: AnalyzeRequest):
    # Reject requests whose Basic-auth password does not match the configured API key.
    if creds.password != api_key:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Incorrect Password",
            headers={"WWW-Authenticate": "Basic"},
        )

    text = data.text
    # Tokenize the input and run it through the model without tracking gradients.
    test_sample = tokenizer.encode(text)
    test_sample = torch.LongTensor(test_sample).view(1, -1).to(model.device)
    with torch.no_grad():
        logits = model(test_sample)[0]

    # Pick the highest-scoring class and report its softmax probability as a float.
    label_index = torch.topk(logits, k=1, dim=-1)[1].squeeze().item()
    label = i2l[label_index]
    score = round(F.softmax(logits, dim=-1).squeeze()[label_index].item(), 3)
    return {"text": text, "label": label, "score": score}
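

# A minimal usage sketch, not part of the API itself: exercising both endpoints with
# FastAPI's TestClient. It assumes the API_KEY environment variable is set and that this
# module is run with `python -m Sentiment.<module>` (module name assumed) so the relative
# import above resolves; the Basic-auth username is arbitrary, since only the password is
# checked in predict().
if __name__ == "__main__":
    from fastapi.testclient import TestClient

    client = TestClient(app)

    # Health check requires no credentials.
    print(client.get("/checkhealth").json())

    # Prediction requires the API key as the Basic-auth password.
    resp = client.post(
        "/predict",
        json={"text": "I really enjoyed this movie."},
        auth=("user", api_key),
    )
    print(resp.status_code, resp.json())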