Spaces:
Runtime error
Runtime error
from searcher import Searcher | |
from trie import Trie | |
from helper import parse_query, make_response, download_from_bucket | |
from fastapi import FastAPI, Request, status | |
from fastapi.exceptions import RequestValidationError | |
from fastapi.middleware.cors import CORSMiddleware | |
from fastapi.responses import RedirectResponse | |
from pydantic import BaseModel | |
import settings | |
import os | |
import time | |
# Shared search state. ``startup_event`` fills ``trie`` (from a cache file or
# a downloaded data directory); ``searcher`` answers queries against it.
trie = Trie()
searcher = Searcher(trie)

app = FastAPI(title="Object Search")

# NOTE(review): a wildcard origin together with allow_credentials=True
# presumably lets any site make credentialed cross-origin requests; narrow
# allow_origins if cookies or auth headers are ever relied on.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "allow_credentials": True,
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Convert request-validation errors into the app's standard response.

    Args:
        request: The failing request. It is unused, but exception handlers
            receive it as their first argument.
        exc: The validation error. Each entry of ``exc.errors()`` becomes one
            ``{"error": "<msg> <loc>"}`` item in the response data.

    Returns:
        Whatever ``make_response`` builds, with message "Bad Request".
    """
    # NOTE(review): no ``@app.exception_handler(RequestValidationError)``
    # decorator is visible next to this function; confirm it is registered.
    # NOTE(review): the message says "Bad Request" but status=200 is passed.
    # If make_response uses this as the HTTP status code, clients cannot tell
    # failure from success by status. Consider 400/422 after checking consumers.
    error_details = [
        {"error": f"{error['msg']} {error['loc']}"} for error in exc.errors()
    ]
    return make_response(status=200, message="Bad Request", data=error_details)
async def root() -> RedirectResponse:
    """Redirect to the interactive API documentation at /docs."""
    # NOTE(review): no route decorator is visible for this function; confirm
    # it is actually registered on ``app``.
    return RedirectResponse("/docs")
async def perform_healthcheck():
    """Return a fixed OK response (status 200, message "OK")."""
    # NOTE(review): no route decorator is visible for this function; confirm
    # it is actually registered on ``app``.
    # No return annotation: the type produced by make_response is not visible
    # from this module, and ``-> None`` was incorrect for a value-returning
    # function.
    return make_response(status=200, message="OK")
# Request body for ``search``: free-text ``query_text`` plus ``topk``
# (presumably the maximum number of results -- confirm against Searcher).
class Query(BaseModel):
    query_text: str  # parsed with parse_query() before searching
    topk: int  # forwarded unchanged to searcher.search()


async def search(query: Query):
    """Run ``query_text`` through the searcher and return serialized candidates."""
    # NOTE(review): no route decorator is visible for this function; confirm
    # it is actually registered on ``app``.
    # No return annotation: the type produced by make_response is not visible
    # from this module, and ``-> None`` was incorrect for a value-returning
    # function.
    # Keep the request model (``query``) and the parsed form in separate names
    # rather than rebinding the parameter partway through.
    parsed_query = parse_query(query.query_text)
    candidates = searcher.search(parsed_query, query.topk)
    data = [candidate.serialize() for candidate in candidates]
    return make_response(status=200, message="OK", data=data)
def _timed(label, func, *args):
    """Call ``func(*args)``, print "<label> took N.NN seconds", return its result."""
    # perf_counter is monotonic, so the duration is unaffected by wall-clock
    # adjustments (time.time is not).
    started = time.perf_counter()
    result = func(*args)
    print(f"{label} took {time.perf_counter() - started:.2f} seconds")
    return result


async def startup_event():
    """Populate the module-level ``trie``.

    If ``cache.json`` exists, the trie is loaded from it. Otherwise the bucket
    contents are downloaded into ``data/``, loaded into the trie, and then
    saved to ``cache.json`` so a later start can take the cache path. Each
    step prints how long it took.
    """
    # NOTE(review): no ``@app.on_event("startup")`` decorator is visible next
    # to this function; confirm it is actually registered.
    if os.path.exists("cache.json"):
        _timed("Load from cache", trie.load_from_cache, "cache.json")
        return

    # exist_ok avoids the check-then-create race of exists() followed by mkdir().
    os.makedirs("data", exist_ok=True)
    _timed("Download from bucket", download_from_bucket, "data")
    _timed("Load from directory", trie.load_from_dir, "data")
    # Timed on its own so the "Load from directory" figure no longer silently
    # includes the cost of writing the cache.
    _timed("Save to cache", trie.save_to_cache, "cache.json")
if __name__ == "__main__":
    # Local entry point: serve ``app`` with uvicorn, auto-reloading on code
    # changes. The import string "main:app" assumes this file is importable
    # as ``main`` -- TODO confirm.
    from uvicorn import run as serve_app

    serve_app(
        "main:app",
        reload=True,
        port=settings.PORT,
        host=settings.HOST,
    )