Spaces:
Sleeping
Sleeping
thomasht86
commited on
Commit
•
6a8b62e
1
Parent(s):
8996eb9
deploy at 2024-08-24 17:41:48.872316
Browse files
main.py
CHANGED
@@ -45,6 +45,7 @@ from fasthtml.pico import Search, Grid, Fieldset, Label
|
|
45 |
from starlette.middleware import Middleware
|
46 |
from starlette.middleware.base import BaseHTTPMiddleware
|
47 |
from starlette.middleware.sessions import SessionMiddleware
|
|
|
48 |
from vespa.application import Vespa
|
49 |
import json
|
50 |
import os
|
@@ -57,6 +58,7 @@ import tempfile
|
|
57 |
from enum import Enum
|
58 |
from typing import Tuple as T
|
59 |
from urllib.parse import quote
|
|
|
60 |
|
61 |
DEV_MODE = False
|
62 |
|
@@ -164,6 +166,18 @@ class XFrameOptionsMiddleware(BaseHTTPMiddleware):
|
|
164 |
response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
|
165 |
return response
|
166 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
167 |
|
168 |
middlewares = [
|
169 |
Middleware(
|
@@ -172,6 +186,8 @@ middlewares = [
|
|
172 |
max_age=3600,
|
173 |
),
|
174 |
Middleware(XFrameOptionsMiddleware),
|
|
|
|
|
175 |
]
|
176 |
bware = Beforeware(
|
177 |
user_auth_before,
|
@@ -295,7 +311,6 @@ def get(sess):
|
|
295 |
queries = [
|
296 |
"Breast Cancer Cells Feed on Cholesterol",
|
297 |
"Treating Asthma With Plants vs. Pills",
|
298 |
-
"Alkylphenol Endocrine Disruptors",
|
299 |
"Testing Turmeric on Smokers",
|
300 |
"The Role of Pesticides in Parkinson's Disease",
|
301 |
]
|
@@ -424,9 +439,10 @@ def post(login: Login, sess):
|
|
424 |
if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
|
425 |
# Incorrect password - add error message
|
426 |
return RedirectResponse("/login?error=True", status_code=303)
|
427 |
-
|
428 |
-
|
429 |
-
|
|
|
430 |
|
431 |
|
432 |
@app.get("/logout")
|
@@ -452,9 +468,26 @@ def replace_hi_with_strong(text):
|
|
452 |
|
453 |
|
454 |
def log_query_to_db(query, ranking, sess):
|
455 |
-
|
456 |
Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
|
457 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
458 |
|
459 |
|
460 |
def parse_results(records):
|
@@ -544,12 +577,7 @@ def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
|
|
544 |
@app.get("/search")
|
545 |
async def search(userquery: str, ranking: str, sess):
|
546 |
print(sess)
|
547 |
-
if "queries" not in sess:
|
548 |
-
sess["queries"] = []
|
549 |
quoted = quote(userquery) + "&ranking=" + ranking
|
550 |
-
sess["queries"].append(quoted)
|
551 |
-
print(f"Searching for: {userquery}")
|
552 |
-
print(f"Ranking: {ranking}")
|
553 |
log_query_to_db(userquery, ranking, sess)
|
554 |
yql, body = get_yql(ranking, userquery)
|
555 |
async with vespa_app.asyncio() as session:
|
@@ -806,12 +834,13 @@ def get_document(docid: str, sess):
|
|
806 |
resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
|
807 |
doc = resp.json
|
808 |
# Link with Back to search results at top of page
|
|
|
809 |
return Main(
|
810 |
Div(
|
811 |
A(
|
812 |
I(cls="fa fa-arrow-left"),
|
813 |
"Back to search results",
|
814 |
-
hx_get=f"/search?userquery={
|
815 |
hx_target="#results",
|
816 |
style="margin: 10px;",
|
817 |
),
|
|
|
45 |
from starlette.middleware import Middleware
|
46 |
from starlette.middleware.base import BaseHTTPMiddleware
|
47 |
from starlette.middleware.sessions import SessionMiddleware
|
48 |
+
from starlette.middleware.cors import CORSMiddleware
|
49 |
from vespa.application import Vespa
|
50 |
import json
|
51 |
import os
|
|
|
58 |
from enum import Enum
|
59 |
from typing import Tuple as T
|
60 |
from urllib.parse import quote
|
61 |
+
import uuid
|
62 |
|
63 |
DEV_MODE = False
|
64 |
|
|
|
166 |
response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
|
167 |
return response
|
168 |
|
169 |
+
class SessionLoggingMiddleware(BaseHTTPMiddleware):
|
170 |
+
async def dispatch(self, request, call_next):
|
171 |
+
print(f"Before request: Session data: {request.session}")
|
172 |
+
response = await call_next(request)
|
173 |
+
print(f"After request: Session data: {request.session}")
|
174 |
+
return response
|
175 |
+
|
176 |
+
class DebugSessionMiddleware(SessionMiddleware):
|
177 |
+
async def __call__(self, scope, receive, send):
|
178 |
+
print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
|
179 |
+
await super().__call__(scope, receive, send)
|
180 |
+
print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
|
181 |
|
182 |
middlewares = [
|
183 |
Middleware(
|
|
|
186 |
max_age=3600,
|
187 |
),
|
188 |
Middleware(XFrameOptionsMiddleware),
|
189 |
+
Middleware(SessionLoggingMiddleware),
|
190 |
+
#Middleware(DebugSessionMiddleware, secret_key=get_key(fname=sess_key_path)),
|
191 |
]
|
192 |
bware = Beforeware(
|
193 |
user_auth_before,
|
|
|
311 |
queries = [
|
312 |
"Breast Cancer Cells Feed on Cholesterol",
|
313 |
"Treating Asthma With Plants vs. Pills",
|
|
|
314 |
"Testing Turmeric on Smokers",
|
315 |
"The Role of Pesticides in Parkinson's Disease",
|
316 |
]
|
|
|
439 |
if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
|
440 |
# Incorrect password - add error message
|
441 |
return RedirectResponse("/login?error=True", status_code=303)
|
442 |
+
print(f"Session after setting auth: {sess}")
|
443 |
+
response = RedirectResponse("/admin", status_code=303)
|
444 |
+
print(f"Cookies being set: {response.headers.get('Set-Cookie')}")
|
445 |
+
return response
|
446 |
|
447 |
|
448 |
@app.get("/logout")
|
|
|
468 |
|
469 |
|
470 |
def log_query_to_db(query, ranking, sess):
|
471 |
+
queries.insert(
|
472 |
Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
|
473 |
)
|
474 |
+
if 'user_id' not in sess:
|
475 |
+
sess['user_id'] = str(uuid.uuid4())
|
476 |
+
|
477 |
+
if 'queries' not in sess:
|
478 |
+
sess['queries'] = []
|
479 |
+
|
480 |
+
query_data = {
|
481 |
+
'query': query,
|
482 |
+
'ranking': ranking,
|
483 |
+
'timestamp': int(time.time())
|
484 |
+
}
|
485 |
+
sess['queries'].append(query_data)
|
486 |
+
|
487 |
+
# Limit the number of queries stored in the session to prevent it from growing too large
|
488 |
+
sess['queries'] = sess['queries'][-100:] # Keep only the last 100 queries
|
489 |
+
|
490 |
+
return query_data
|
491 |
|
492 |
|
493 |
def parse_results(records):
|
|
|
577 |
@app.get("/search")
|
578 |
async def search(userquery: str, ranking: str, sess):
|
579 |
print(sess)
|
|
|
|
|
580 |
quoted = quote(userquery) + "&ranking=" + ranking
|
|
|
|
|
|
|
581 |
log_query_to_db(userquery, ranking, sess)
|
582 |
yql, body = get_yql(ranking, userquery)
|
583 |
async with vespa_app.asyncio() as session:
|
|
|
834 |
resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
|
835 |
doc = resp.json
|
836 |
# Link with Back to search results at top of page
|
837 |
+
last_query = sess.get('queries', [{}])[-1].get('query', '')
|
838 |
return Main(
|
839 |
Div(
|
840 |
A(
|
841 |
I(cls="fa fa-arrow-left"),
|
842 |
"Back to search results",
|
843 |
+
hx_get=f"/search?userquery={last_query}",
|
844 |
hx_target="#results",
|
845 |
style="margin: 10px;",
|
846 |
),
|