thomasht86 commited on
Commit
6a8b62e
1 Parent(s): 8996eb9

deploy at 2024-08-24 17:41:48.872316

Browse files
Files changed (1) hide show
  1. main.py +40 -11
main.py CHANGED
@@ -45,6 +45,7 @@ from fasthtml.pico import Search, Grid, Fieldset, Label
45
  from starlette.middleware import Middleware
46
  from starlette.middleware.base import BaseHTTPMiddleware
47
  from starlette.middleware.sessions import SessionMiddleware
 
48
  from vespa.application import Vespa
49
  import json
50
  import os
@@ -57,6 +58,7 @@ import tempfile
57
  from enum import Enum
58
  from typing import Tuple as T
59
  from urllib.parse import quote
 
60
 
61
  DEV_MODE = False
62
 
@@ -164,6 +166,18 @@ class XFrameOptionsMiddleware(BaseHTTPMiddleware):
164
  response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
165
  return response
166
 
 
 
 
 
 
 
 
 
 
 
 
 
167
 
168
  middlewares = [
169
  Middleware(
@@ -172,6 +186,8 @@ middlewares = [
172
  max_age=3600,
173
  ),
174
  Middleware(XFrameOptionsMiddleware),
 
 
175
  ]
176
  bware = Beforeware(
177
  user_auth_before,
@@ -295,7 +311,6 @@ def get(sess):
295
  queries = [
296
  "Breast Cancer Cells Feed on Cholesterol",
297
  "Treating Asthma With Plants vs. Pills",
298
- "Alkylphenol Endocrine Disruptors",
299
  "Testing Turmeric on Smokers",
300
  "The Role of Pesticides in Parkinson's Disease",
301
  ]
@@ -424,9 +439,10 @@ def post(login: Login, sess):
424
  if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
425
  # Incorrect password - add error message
426
  return RedirectResponse("/login?error=True", status_code=303)
427
- sess["auth"] = True
428
- print(f"Sess after login: {sess}")
429
- return RedirectResponse("/admin", status_code=303)
 
430
 
431
 
432
  @app.get("/logout")
@@ -452,9 +468,26 @@ def replace_hi_with_strong(text):
452
 
453
 
454
  def log_query_to_db(query, ranking, sess):
455
- return queries.insert(
456
  Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
457
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
458
 
459
 
460
  def parse_results(records):
@@ -544,12 +577,7 @@ def get_yql(ranking: RankProfile, userquery: str) -> T[str, dict]:
544
  @app.get("/search")
545
  async def search(userquery: str, ranking: str, sess):
546
  print(sess)
547
- if "queries" not in sess:
548
- sess["queries"] = []
549
  quoted = quote(userquery) + "&ranking=" + ranking
550
- sess["queries"].append(quoted)
551
- print(f"Searching for: {userquery}")
552
- print(f"Ranking: {ranking}")
553
  log_query_to_db(userquery, ranking, sess)
554
  yql, body = get_yql(ranking, userquery)
555
  async with vespa_app.asyncio() as session:
@@ -806,12 +834,13 @@ def get_document(docid: str, sess):
806
  resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
807
  doc = resp.json
808
  # Link with Back to search results at top of page
 
809
  return Main(
810
  Div(
811
  A(
812
  I(cls="fa fa-arrow-left"),
813
  "Back to search results",
814
- hx_get=f"/search?userquery={sess['queries'][-1]}",
815
  hx_target="#results",
816
  style="margin: 10px;",
817
  ),
 
45
  from starlette.middleware import Middleware
46
  from starlette.middleware.base import BaseHTTPMiddleware
47
  from starlette.middleware.sessions import SessionMiddleware
48
+ from starlette.middleware.cors import CORSMiddleware
49
  from vespa.application import Vespa
50
  import json
51
  import os
 
58
  from enum import Enum
59
  from typing import Tuple as T
60
  from urllib.parse import quote
61
+ import uuid
62
 
63
  DEV_MODE = False
64
 
 
166
  response.headers["X-Frame-Options"] = "ALLOW-FROM https://huggingface.co/"
167
  return response
168
 
169
+ class SessionLoggingMiddleware(BaseHTTPMiddleware):
170
+ async def dispatch(self, request, call_next):
171
+ print(f"Before request: Session data: {request.session}")
172
+ response = await call_next(request)
173
+ print(f"After request: Session data: {request.session}")
174
+ return response
175
+
176
+ class DebugSessionMiddleware(SessionMiddleware):
177
+ async def __call__(self, scope, receive, send):
178
+ print(f"DebugSessionMiddleware: Before processing - Scope: {scope}")
179
+ await super().__call__(scope, receive, send)
180
+ print(f"DebugSessionMiddleware: After processing - Scope: {scope}")
181
 
182
  middlewares = [
183
  Middleware(
 
186
  max_age=3600,
187
  ),
188
  Middleware(XFrameOptionsMiddleware),
189
+ Middleware(SessionLoggingMiddleware),
190
+ #Middleware(DebugSessionMiddleware, secret_key=get_key(fname=sess_key_path)),
191
  ]
192
  bware = Beforeware(
193
  user_auth_before,
 
311
  queries = [
312
  "Breast Cancer Cells Feed on Cholesterol",
313
  "Treating Asthma With Plants vs. Pills",
 
314
  "Testing Turmeric on Smokers",
315
  "The Role of Pesticides in Parkinson's Disease",
316
  ]
 
439
  if not compare_digest(ADMIN_PWD.encode("utf-8"), login.pwd.encode("utf-8")):
440
  # Incorrect password - add error message
441
  return RedirectResponse("/login?error=True", status_code=303)
442
+ print(f"Session after setting auth: {sess}")
443
+ response = RedirectResponse("/admin", status_code=303)
444
+ print(f"Cookies being set: {response.headers.get('Set-Cookie')}")
445
+ return response
446
 
447
 
448
  @app.get("/logout")
 
468
 
469
 
470
  def log_query_to_db(query, ranking, sess):
471
+ queries.insert(
472
  Query(query=query, ranking=ranking, sess_id=sesskey, timestamp=int(time.time()))
473
  )
474
+ if 'user_id' not in sess:
475
+ sess['user_id'] = str(uuid.uuid4())
476
+
477
+ if 'queries' not in sess:
478
+ sess['queries'] = []
479
+
480
+ query_data = {
481
+ 'query': query,
482
+ 'ranking': ranking,
483
+ 'timestamp': int(time.time())
484
+ }
485
+ sess['queries'].append(query_data)
486
+
487
+ # Limit the number of queries stored in the session to prevent it from growing too large
488
+ sess['queries'] = sess['queries'][-100:] # Keep only the last 100 queries
489
+
490
+ return query_data
491
 
492
 
493
  def parse_results(records):
 
577
  @app.get("/search")
578
  async def search(userquery: str, ranking: str, sess):
579
  print(sess)
 
 
580
  quoted = quote(userquery) + "&ranking=" + ranking
 
 
 
581
  log_query_to_db(userquery, ranking, sess)
582
  yql, body = get_yql(ranking, userquery)
583
  async with vespa_app.asyncio() as session:
 
834
  resp = vespa_app.get_data(data_id=docid, schema="doc", namespace="tutorial")
835
  doc = resp.json
836
  # Link with Back to search results at top of page
837
+ last_query = sess.get('queries', [{}])[-1].get('query', '')
838
  return Main(
839
  Div(
840
  A(
841
  I(cls="fa fa-arrow-left"),
842
  "Back to search results",
843
+ hx_get=f"/search?userquery={last_query}",
844
  hx_target="#results",
845
  style="margin: 10px;",
846
  ),