gruhit-patel's picture
Updated backend to handle text-generation parameters
7bbb8c5
raw
history blame
1.89 kB
from fastapi import FastAPI, Request, Depends, HTTPException, status
from fastapi.security import OAuth2PasswordBearer
from QuoteGenerator import QuoteGenerator
from typing import Union
from pydantic import BaseModel
import time
import os
# API key that callers must present as a bearer token (checked in
# api_key_auth); None when the API_KEY environment variable is unset.
API_KEY = os.getenv('API_KEY')
# Security scheme that api_key_auth depends on to obtain the caller's token.
# NOTE(review): tokenUrl="token" is declared, but no "token" route is visible
# in this section of the file — confirm whether one is expected to exist.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
# Dependency that checks whether the incoming API call carries the expected API key
def api_key_auth(api_key: str = Depends(oauth2_scheme)):
    """Reject the request unless its bearer token matches ``API_KEY``.

    Intended to be used as a FastAPI dependency on protected routes.

    Args:
        api_key: Token supplied by the ``oauth2_scheme`` dependency.

    Raises:
        HTTPException: 401 when ``API_KEY`` is not configured or the
            supplied token does not match it.
    """
    # Function-scope import: the module does not import ``secrets`` itself.
    import secrets

    # Fail closed when no key is configured, and compare in constant time so
    # response latency does not leak how much of a guessed key was correct.
    # Encoding to bytes lets compare_digest accept non-ASCII input.
    if API_KEY is None or not secrets.compare_digest(
        api_key.encode("utf-8"), API_KEY.encode("utf-8")
    ):
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Forbidden Access",
            # A 401 response should advertise the expected auth scheme.
            headers={"WWW-Authenticate": "Bearer"},
        )
class QuoteRequest(BaseModel):
    """Request body accepted by the quote endpoints.

    Every field has a default, so an empty JSON object is a valid body.
    The ``generate_quote`` endpoint forwards each field unchanged as a
    keyword argument to ``quote_generator.generate_quote``.
    """
    # Optional tag string; None when the client sends no tags.
    # NOTE(review): the expected tag format is not visible here — confirm
    # against QuoteGenerator.
    tags: Union[None, str] = None
    # NOTE(review): the remaining fields look like text-generation settings
    # (sampling switch, output length, beam count, top-k / top-p / temperature);
    # their exact semantics and valid ranges are defined by QuoteGenerator —
    # confirm there. No range validation is applied in this model.
    do_sample: bool = False
    max_new_tokens: int = 16
    num_beams: int = 1
    top_k: int = 50
    top_p: float = 1.0
    temperature: float = 1.0
# Application object; the middleware and routes below are registered on it.
app = FastAPI()
#Middleware to note time
@app.middleware("http")
async def note_response_time(request: Request, call_next):
    """Print how long the rest of the stack took to produce a response.

    Args:
        request: Incoming HTTP request.
        call_next: Callable that passes the request onward and returns the
            response.

    Returns:
        The response returned by ``call_next``, unmodified.
    """
    # perf_counter is monotonic, so the measured span cannot be skewed by
    # system clock adjustments the way a difference of time.time() values can.
    started = time.perf_counter()
    response = await call_next(request)
    elapsed = time.perf_counter() - started
    print(f'Time taken = {elapsed:.1f}s')
    return response
# Module-level generator instance used by the generate_quote endpoint.
quote_generator = QuoteGenerator()
# Runs once at import time, before any request is handled.
# NOTE(review): presumably loads the underlying model — confirm in QuoteGenerator.
quote_generator.load_generator()
@app.post("/", dependencies=[Depends(api_key_auth)])
def root(request: QuoteRequest):
    """Print the parsed request body and reply with a fixed placeholder quote.

    Args:
        request: Validated request body; it is only printed and does not
            influence the reply.

    Returns:
        A dict with a single ``"quote"`` key holding a hard-coded string.
    """
    payload_fields = vars(request)
    print("Incoming request\n", payload_fields)
    placeholder = {"quote": "<bot>:A beautiful quote generated by bot"}
    return placeholder
@app.post("/generate_quote", dependencies=[Depends(api_key_auth)])
def generate_quote(req: QuoteRequest):
    """Generate a quote using the settings supplied in the request body.

    Args:
        req: Validated request body; every field is forwarded as a keyword
            argument to ``quote_generator.generate_quote``.

    Returns:
        A dict with a single ``'quote'`` key holding whatever the generator
        returned.
    """
    # Collect the generation settings first, then hand them over in one call.
    generation_kwargs = {
        "tags": req.tags,
        "max_new_tokens": req.max_new_tokens,
        "num_beams": req.num_beams,
        "temperature": req.temperature,
        "top_k": req.top_k,
        "top_p": req.top_p,
        "do_sample": req.do_sample,
    }
    quote_text = quote_generator.generate_quote(**generation_kwargs)
    return {'quote': quote_text}