"""FastAPI service that indexes Hugging Face dataset metadata in SQLite and exposes a column-based dataset search API."""

import asyncio
import concurrent.futures
import json
import logging
import os
import sqlite3
from contextlib import asynccontextmanager
from typing import List

import numpy as np
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from apscheduler.triggers.cron import CronTrigger
from cashews import NOT_NONE, cache
from fastapi import FastAPI, HTTPException, Query
from huggingface_hub import login, upload_file
from pandas import Timestamp
from pydantic import BaseModel
from starlette.responses import RedirectResponse

from create_collections import collections, update_collection_for_dataset
from data_loader import refresh_data

# Authenticate with the Hugging Face Hub so the refreshed database can be uploaded later.
login(token=os.getenv("HF_TOKEN"))
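
# Cron-style fields for APScheduler: the dataset index is refreshed every six hours by
# default (override with UPDATE_INTERVAL_HOURS) and collections are rebuilt daily at midnight.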
UPDATE_SCHEDULE = {"hour": os.getenv("UPDATE_INTERVAL_HOURS", "*/6")}
COLLECTION_UPDATE_SCHEDULE = {"hour": "0"}

# In-memory cashews cache backend; /search responses are cached for an hour (see the @cache decorator below).
cache.setup("mem://?check_interval=10&size=10000")
logger = logging.getLogger(__name__)


def get_db_connection():
    conn = sqlite3.connect("datasets.db")
    conn.row_factory = sqlite3.Row  # rows support dict-style access by column name
    # WAL mode lets the scheduled refresh write while search queries keep reading.
    conn.execute("PRAGMA journal_mode = WAL")
    conn.execute("PRAGMA synchronous = NORMAL")
    return conn



def setup_database():
    conn = get_db_connection()
    c = conn.cursor()
    c.execute(
        """CREATE TABLE IF NOT EXISTS datasets
               (hub_id TEXT PRIMARY KEY,
                likes INTEGER,
                downloads INTEGER,
                tags JSON,
                created_at INTEGER,
                last_modified INTEGER,
                license JSON,
                language JSON,
                config_name TEXT,
                column_names JSON,
                features JSON)"""
    )
    c.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_column_names
        ON datasets(column_names)
        """
    )
    c.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_downloads_likes
        ON datasets(downloads DESC, likes DESC)
        """
    )
    conn.commit()
    c.execute("ANALYZE")
    conn.close()
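

# json.dumps fallback used below: converts numpy arrays/scalars and pandas Timestamps
# (values that may appear in the refreshed metadata) into JSON-serializable Python values.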
def serialize_numpy(obj):
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.integer):
        return int(obj)
    if isinstance(obj, np.floating):
        return float(obj)
    if isinstance(obj, Timestamp):
        return int(obj.timestamp())
    logger.error(f"Object of type {type(obj)} is not JSON serializable")
    raise TypeError(f"Object of type {type(obj)} is not JSON serializable")


def background_refresh_data():
    logger.info("Starting background data refresh")
    try:
        return refresh_data()
    except Exception as e:
        logger.error(f"Error in background data refresh: {str(e)}")
        return None
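

# refresh_data() is a blocking call, so it is pushed to a worker thread and awaited via
# an executor to keep the FastAPI event loop responsive during the refresh.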
async def update_database():
    logger.info("Starting scheduled data refresh")

    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(background_refresh_data)

        try:
            # Wait for the refresh without blocking the event loop.
            datasets = await asyncio.get_event_loop().run_in_executor(
                None, future.result
            )
        except asyncio.CancelledError:
            future.cancel()
            logger.info("Data refresh cancelled")
            return

    if datasets is None:
        logger.error("Data refresh failed, skipping database update")
        return

    conn = get_db_connection()
    try:
        c = conn.cursor()
        c.executemany(
            """
            INSERT OR REPLACE INTO datasets
            (hub_id, likes, downloads, tags, created_at, last_modified, license, language, config_name, column_names, features)
            VALUES (?, ?, ?, json(?), ?, ?, json(?), json(?), ?, json(?), json(?))
            """,
            [
                (
                    data["hub_id"],
                    data.get("likes", 0),
                    data.get("downloads", 0),
                    json.dumps(data.get("tags", []), default=serialize_numpy),
                    int(data["created_at"].timestamp())
                    if isinstance(data["created_at"], Timestamp)
                    else data.get("created_at", 0),
                    int(data["last_modified"].timestamp())
                    if isinstance(data["last_modified"], Timestamp)
                    else data.get("last_modified", 0),
                    json.dumps(data.get("license", []), default=serialize_numpy),
                    json.dumps(data.get("language", []), default=serialize_numpy),
                    data.get("config_name", ""),
                    json.dumps(data.get("column_names", []), default=serialize_numpy),
                    json.dumps(data.get("features", []), default=serialize_numpy),
                )
                for data in datasets
            ],
        )
        conn.commit()
        logger.info("Scheduled data refresh completed")
    except Exception as e:
        logger.error(f"Error during database update: {str(e)}")
        conn.rollback()
    finally:
        conn.close()

    # Publish the refreshed SQLite file so it can be reused outside this app.
    try:
        upload_file(
            path_or_fileobj="datasets.db",
            path_in_repo="datasets.db",
            repo_id="librarian-bots/column-db",
            repo_type="dataset",
        )
        logger.info("Database file uploaded to Hugging Face Hub successfully")
    except Exception as e:
        logger.error(f"Error uploading database file to Hugging Face Hub: {str(e)}")
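

# Rebuild the Hub collections defined in create_collections.py; each collection is driven
# by a set of dataset columns. Runs daily via the scheduler and once at startup.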
async def update_collections():
    logger.info("Starting scheduled collection update")
    try:
        for collection in collections:
            result = await asyncio.get_event_loop().run_in_executor(
                None,
                update_collection_for_dataset,
                collection["collection_name"],
                collection["dataset_columns"],
                collection["collection_description"],
                "librarian-bots",
            )
            logger.info(f"Updated collection: {result}")
    except Exception as e:
        logger.error(f"Error during collection update: {str(e)}")
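

# Application lifespan: create the schema, run an initial refresh, then hand periodic
# refreshes and collection updates to APScheduler for the life of the app.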
@asynccontextmanager
async def lifespan(app: FastAPI):
    setup_database()
    logger.info("Performing initial data refresh")
    await update_database()

    scheduler = AsyncIOScheduler()
    scheduler.add_job(update_database, CronTrigger(**UPDATE_SCHEDULE))
    scheduler.add_job(update_collections, CronTrigger(**COLLECTION_UPDATE_SCHEDULE))
    scheduler.start()

    await update_collections()

    yield

    scheduler.shutdown()


app = FastAPI(lifespan=lifespan)


@app.get("/", include_in_schema=False)
def root():
    return RedirectResponse(url="/docs")


class SearchResponse(BaseModel):
    total: int
    page: int
    page_size: int
    results: List[dict]
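

# Search datasets by column name, e.g.
#   GET /search?columns=question&columns=answer&match_all=true&page=1&page_size=10
# match_all=true requires every requested column to be present; otherwise any match qualifies.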
@app.get("/search", response_model=SearchResponse)
@cache(ttl="1h", condition=NOT_NONE)  # keep @cache below @app.get so FastAPI registers the cached wrapper
async def search_datasets(
    columns: List[str] = Query(...),
    match_all: bool = Query(False),
    page: int = Query(1, ge=1),
    page_size: int = Query(10, ge=1, le=1000),
):
    offset = (page - 1) * page_size
    conn = get_db_connection()
    c = conn.cursor()

    try:
        if match_all:
            # Require the number of matching column names to equal the number requested.
            query = """
                SELECT *, (
                    SELECT COUNT(*)
                    FROM json_each(column_names)
                    WHERE json_each.value IN ({})
                ) as match_count
                FROM datasets
                WHERE match_count = ?
                ORDER BY downloads DESC, likes DESC
                LIMIT ? OFFSET ?
            """.format(",".join("?" * len(columns)))
            c.execute(query, (*columns, len(columns), page_size, offset))
        else:
            # Keep datasets containing at least one of the requested columns.
            query = """
                SELECT * FROM datasets
                WHERE EXISTS (
                    SELECT 1
                    FROM json_each(column_names)
                    WHERE json_each.value IN ({})
                )
                ORDER BY downloads DESC, likes DESC
                LIMIT ? OFFSET ?
            """.format(",".join("?" * len(columns)))
            c.execute(query, (*columns, page_size, offset))

        results = [dict(row) for row in c.fetchall()]

        if match_all:
            count_query = """
                SELECT COUNT(*) as total FROM datasets
                WHERE (
                    SELECT COUNT(*)
                    FROM json_each(column_names)
                    WHERE json_each.value IN ({})
                ) = ?
            """.format(",".join("?" * len(columns)))
            c.execute(count_query, (*columns, len(columns)))
        else:
            count_query = """
                SELECT COUNT(*) as total FROM datasets
                WHERE EXISTS (
                    SELECT 1
                    FROM json_each(column_names)
                    WHERE json_each.value IN ({})
                )
            """.format(",".join("?" * len(columns)))
            c.execute(count_query, columns)

        total = c.fetchone()["total"]

        # JSON columns come back as strings; decode them before returning.
        for result in results:
            result["tags"] = json.loads(result["tags"])
            result["license"] = json.loads(result["license"])
            result["language"] = json.loads(result["language"])
            result["column_names"] = json.loads(result["column_names"])
            result["features"] = json.loads(result["features"])

        return SearchResponse(
            total=total, page=page, page_size=page_size, results=results
        )

    except sqlite3.Error as e:
        logger.error(f"Database error: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}") from e
    finally:
        conn.close()


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)