Spaces:
Running
on
T4
Running
on
T4
File size: 4,468 Bytes
8ce4d25 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 |
import asyncio
import json
import threading

from fasthtml.common import *
from shad4fast import *
from vespa.application import Vespa

from backend.colpali import load_model, get_result_from_query
from backend.vespa_app import get_vespa_app
from frontend.app import Home, Search, SearchResult, SearchBox
from frontend.layout import Layout
# highlight.js setup: an (initially empty) theme <link> with a fixed id, a
# local theme script, and the highlight.js bundle for the languages we render.
# NOTE(review): the empty href is presumably filled in by
# highlightjs-theme.js via the "highlight-theme" id -- confirm.
highlight_js_theme_link = Link(id="highlight-theme", rel="stylesheet", href="")
highlight_js_theme = Script(src="/static/js/highlightjs-theme.js")
highlight_js = HighlightJS(
    langs=["python", "javascript", "java", "json", "xml"],
    dark="github-dark",
    light="github",
)

# Extra <head> content for every page, in the order it is emitted.
_page_headers = (
    ShadHead(tw_cdn=False, theme_handle=True),
    highlight_js,
    highlight_js_theme_link,
    highlight_js_theme,
)

app, rt = fast_app(htmlkw={"cls": "h-full"}, pico=False, hdrs=_page_headers)

# Vespa client used by the route handlers; created once, at import time.
vespa_app: Vespa = get_vespa_app()
class ModelManager:
    """Lazily loaded, process-wide holder for the model and its processor.

    Obtain it via ``ModelManager.get_instance()`` rather than constructing it
    directly. The model is loaded on first access. That first load is
    serialised with a lock because the synchronous ``def get`` route handlers
    may be run concurrently by the framework (e.g. in a worker thread pool),
    and loading the model twice would waste memory.
    """

    _instance = None
    _lock = threading.Lock()  # guards creation and (re)loading of _instance
    model = None
    processor = None

    @classmethod
    def get_instance(cls):
        """Return the shared instance, loading the model if it is not loaded yet.

        If an earlier load attempt failed, loading is retried on the next call.
        Failure means ``load_model`` raised, or returned ``None`` for the model
        or the processor. The retry avoids caching an unusable instance forever.

        Returns:
            The shared ``ModelManager`` instance.
        """
        instance = cls._instance
        if (
            instance is not None
            and instance.model is not None
            and instance.processor is not None
        ):
            return instance  # fast path: already loaded, no locking needed
        with cls._lock:
            if cls._instance is None:
                cls._instance = cls()
            # Idempotent: only calls load_model while something is still missing.
            cls._instance.initialize_model_and_processor()
            return cls._instance

    def initialize_model_and_processor(self):
        """Populate ``model``/``processor`` via ``load_model`` unless both are set.

        Prints whether the load succeeded. A ``None`` result is reported, not
        raised; exceptions from ``load_model`` propagate to the caller.
        """
        if self.model is not None and self.processor is not None:
            return  # already loaded; never reload the weights
        self.model, self.processor = load_model()
        if self.model is None or self.processor is None:
            print("Failed to initialize model or processor at startup")
        else:
            print("Model and processor loaded at startup")
@rt("/static/{filepath:path}")
def serve_static(filepath: str):
    """Serve ``filepath`` from the local ``./static`` directory.

    ``filepath`` comes straight from the URL. It is therefore resolved and
    rejected unless it still points inside ``./static``. Without this check,
    ``..`` segments or an absolute path would let a client read arbitrary files.
    Paths that are not existing regular files also get a 404, instead of
    failing inside ``FileResponse``.

    Args:
        filepath: path below ``/static/`` as captured by the route.

    Returns:
        A ``FileResponse`` for the file, or a 404 ``Response``.
    """
    from pathlib import Path

    static_root = Path("./static").resolve()
    target = (static_root / filepath).resolve()
    if not target.is_relative_to(static_root) or not target.is_file():
        return Response("Not Found", status_code=404)
    return FileResponse(target)
@rt("/")
def get():
    """Render the landing page: the ``Home`` view inside the shared ``Layout``."""
    home_view = Home()
    return Layout(home_view)
@rt("/search")
def get(request):
    """Render the search page.

    If the ``query`` URL parameter is missing or blank, show the search box
    followed by a prompt to enter a query. Otherwise render ``Search(request)``.
    Per the original author's note, that shows the search box plus a loading
    message; it is handed the whole request, and presumably reads the query
    from it itself.
    """
    query_value = request.query_params.get("query", "").strip()

    if query_value:
        return Layout(Search(request))

    # Blank query: keep the search box visible and explain what is missing.
    search_box = SearchBox(query_value=query_value)
    notice = Div(
        P(
            "No query provided. Please enter a query.",
            cls="text-center text-muted-foreground",
        ),
        cls="p-10",
    )
    return Layout(Div(search_box, notice, cls="grid"))
@rt("/fetch_results")
def get(request, query: str, nn: bool = True):
    """Run ``query`` against Vespa and return only the rendered result list.

    This is meant to be requested by HTMX and returns a fragment, not a full
    page. A request without the ``hx-request`` header is redirected to the
    full ``/search`` page. The query is carried along in that redirect, so the
    user lands on their own search rather than on the empty-query prompt.

    Args:
        request: incoming request; only its headers are inspected here.
        query: search text, forwarded to ``get_result_from_query``.
        nn: forwarded unchanged to ``get_result_from_query``.

    Returns:
        ``SearchResult`` built from the hits, or a ``RedirectResponse``.
    """
    from urllib.parse import urlencode

    if "hx-request" not in request.headers:
        return RedirectResponse(f"/search?{urlencode({'query': query})}")

    manager = ModelManager.get_instance()
    # The handler is synchronous, so drive the async query to completion here.
    result = asyncio.run(
        get_result_from_query(
            vespa_app,
            processor=manager.processor,
            model=manager.model,
            query=query,
            nn=nn,
            gen_sim_map=True,
        )
    )

    # The response may lack "root" or "root.children"; treat that as no hits.
    search_results = result.get("root", {}).get("children", [])
    return SearchResult(search_results)
@rt("/app")
def get():
    """Render a small page that displays the Vespa application URL."""
    status = P(f"Connected to Vespa at {vespa_app.url}")
    return Layout(Div(status, cls="p-4"))
@rt("/run_query")
def get(query: str, nn: bool = False):
    """Debug endpoint: run ``query`` and show the raw result as formatted JSON.

    The model and processor come from the shared ``ModelManager``, as in
    ``/fetch_results``. The previous code called ``get_model_and_processor()``,
    which is neither defined nor imported in this module, so every request
    failed with a NameError.

    Args:
        query: search text, forwarded to ``get_result_from_query``.
        nn: forwarded unchanged to ``get_result_from_query``.
    """
    manager = ModelManager.get_instance()
    # The handler is synchronous, so drive the async query to completion here.
    result = asyncio.run(
        get_result_from_query(
            vespa_app,
            processor=manager.processor,
            model=manager.model,
            query=query,
            nn=nn,
        )
    )
    return Layout(Div(H1("Result"), Pre(Code(json.dumps(result, indent=2))), cls="p-4"))
if __name__ == "__main__":
    # The model is not preloaded here. ModelManager.get_instance() loads it on
    # the first request that needs it; an eager call at this point was disabled.
    serve()
|