import os
import secrets
import time
from typing import Union

from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from pydantic import BaseModel, Field

from QuoteGenerator import QuoteGenerator

# Shared secret that clients must send as their bearer token. It is read once,
# at import time; the value is None when the API_KEY variable is not set.
API_KEY = os.environ.get("API_KEY")

# FastAPI security dependency that extracts the token from an
# "Authorization: Bearer <token>" header (and answers 401 when it is absent).
# NOTE(review): no "token" route is defined in this file; tokenUrl presumably
# only feeds the OpenAPI docs — confirm.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")


def api_key_auth(api_key:str = Depends(oauth2_scheme)): |
|
if api_key != API_KEY: |
|
raise HTTPException( |
|
status_code = status.HTTP_401_UNAUTHORIZED, |
|
detail="Forbidden Access" |
|
) |
|
|
|
class QuoteRequest(BaseModel):
    """JSON body accepted by the quote routes.

    Every field has a default, so an empty object ``{}`` is a valid body.
    Numeric fields are bounded so that nonsensical settings (for example
    ``max_new_tokens=-1`` or ``top_p=3``) are rejected by FastAPI with a 422
    validation error rather than being handed to the generator.

    NOTE(review): the field names match Hugging Face ``generate()`` keyword
    arguments; the bounds assume ``QuoteGenerator`` forwards them there —
    confirm. ``max_new_tokens`` and ``num_beams`` still have no upper limit, so
    an authenticated caller can request very expensive generations; consider
    adding ``le=`` caps.
    """

    # Optional tags string; None when the client omits it.
    tags: Union[None, str] = None
    do_sample: bool = False
    # Must be at least 1.
    max_new_tokens: int = Field(default=16, ge=1)
    # Must be at least 1.
    num_beams: int = Field(default=1, ge=1)
    # Must not be negative.
    top_k: int = Field(default=50, ge=0)
    # Must lie in [0, 1].
    top_p: float = Field(default=1.0, ge=0.0, le=1.0)
    # Must not be negative.
    temperature: float = Field(default=1.0, ge=0.0)


app = FastAPI()


@app.middleware("http")
async def note_response_time(request: Request, call_next):
    """Print how long each HTTP request took to handle.

    Args:
        request: The incoming request.
        call_next: Callable that passes ``request`` down the middleware/route
            chain and returns the response.

    Returns:
        The downstream response, unchanged.
    """
    # perf_counter is monotonic and high-resolution. time.time() follows the
    # wall clock, which can jump (NTP, manual changes) and skew the duration.
    start = time.perf_counter()
    try:
        return await call_next(request)
    finally:
        # Runs on the error path too, so failing requests are still timed.
        elapsed = time.perf_counter() - start
        print(f'Time taken = {elapsed:.1f}s')


def _build_quote_generator():
    """Create the shared QuoteGenerator and call its ``load_generator`` step."""
    generator = QuoteGenerator()
    # NOTE(review): presumably loads model weights; this runs once at import,
    # so module import (server startup) waits for it — confirm.
    generator.load_generator()
    return generator


# Single instance shared by every request handler in this module.
quote_generator = _build_quote_generator()


@app.post("/", dependencies=[Depends(api_key_auth)])
def root(request: QuoteRequest):
    """Placeholder endpoint: print the request body and return a fixed quote.

    Requires a valid bearer token via ``api_key_auth``. The generation
    settings in ``request`` are printed but otherwise ignored.
    """
    payload = request.__dict__
    print("Incoming request\n", payload)
    return dict(quote="<bot>:A beautiful quote generated by bot")


@app.post("/generate_quote", dependencies=[Depends(api_key_auth)])
def generate_quote(req: QuoteRequest):
    """Generate a quote using the settings in ``req``.

    Requires a valid bearer token via ``api_key_auth``. Each request field is
    passed by name to ``quote_generator.generate_quote``. The route is a plain
    ``def``, so FastAPI runs it in its threadpool rather than on the event loop.

    Returns:
        ``{"quote": <value returned by the generator>}``.
    """
    generation_kwargs = {
        "tags": req.tags,
        "max_new_tokens": req.max_new_tokens,
        "num_beams": req.num_beams,
        "temperature": req.temperature,
        "top_k": req.top_k,
        "top_p": req.top_p,
        "do_sample": req.do_sample,
    }
    quote = quote_generator.generate_quote(**generation_kwargs)
    return {'quote': quote}