Spaces:
Running
on
A10G
Running
on
A10G
import sqlite3 | |
from pathlib import Path | |
class Database:
    """SQLite-backed cache mapping (prompt, negative_prompt, seed) to image paths.

    The database file is ``<db_path>/cache.db``. When that file does not exist,
    it is created and initialised by running a schema SQL script.
    """

    def __init__(self, db_path=None, *, schema_path="schema.sql"):
        """Open the cache database under ``db_path``, creating it if missing.

        Args:
            db_path: Directory holding ``cache.db``; a ``str`` or ``Path``.
                The directory (and its parents) is created when a new
                database has to be initialised.
            schema_path: SQL script executed to initialise a new database.
                Defaults to ``schema.sql`` relative to the current working
                directory, which was the previously hard-coded location.

        Raises:
            ValueError: if ``db_path`` is not provided.
            OSError: if a new database is needed and the schema file cannot be
                read (e.g. ``FileNotFoundError``) or the directory cannot be
                created.
            sqlite3.Error: if the schema script fails. The partially created
                ``cache.db`` is removed so the next start retries cleanly.
        """
        if db_path is None:
            raise ValueError("db_path must be provided")
        # Normalise to Path: the `/` join below fails on a plain str.
        self.db_path = Path(db_path)
        self.db_file = self.db_path / "cache.db"
        if not self.db_file.exists():
            print("Creating database")
            print("DB_FILE", self.db_file)
            self._create(Path(schema_path))

    def _create(self, schema_file):
        """Create ``self.db_file`` and run the SQL script in ``schema_file``."""
        # Read the schema BEFORE connecting. sqlite3.connect creates the file
        # immediately, and a leftover empty cache.db would make later runs
        # skip initialisation because the file "exists".
        schema = schema_file.read_text(encoding="utf-8")
        self.db_path.mkdir(parents=True, exist_ok=True)
        db = sqlite3.connect(self.db_file)
        try:
            db.executescript(schema)
            db.commit()
        except Exception:
            db.close()
            # Remove the half-initialised file so the next start retries.
            if self.db_file.exists():
                self.db_file.unlink()
            raise
        else:
            db.close()

    def get_db(self):
        """Return a new connection whose rows are ``sqlite3.Row`` objects.

        ``check_same_thread=False`` allows the connection to be used from a
        thread other than the one that created it. The caller must close it.
        """
        db = sqlite3.connect(self.db_file, check_same_thread=False)
        db.row_factory = sqlite3.Row
        return db

    def __enter__(self):
        """Open a connection, remember it on ``self.db`` and return it.

        NOTE: the connection is stored on the instance, so one instance must
        not be entered concurrently or in nested ``with`` blocks. ``insert``
        and ``check`` use their own private connections for that reason.
        """
        self.db = self.get_db()
        return self.db

    def __exit__(self, exc_type, exc_value, traceback):
        """Close the connection opened by ``__enter__`` (no implicit commit)."""
        self.db.close()

    def __call__(self):
        """Return ``self`` so ``with database() as db:`` works."""
        return self

    def insert(self, prompt: str, negative_prompt: str, image_path: str, seed: int):
        """Store a cache row mapping (prompt, negative_prompt, seed) to image_path.

        The row is committed before returning. On error the transaction is
        rolled back and the exception propagates.
        """
        db = self.get_db()
        try:
            # Connection context manager: commit on success, rollback on error.
            with db:
                db.execute(
                    "INSERT INTO cache (prompt, negative_prompt, image_path, seed) VALUES (?, ?, ?, ?)",
                    (prompt, negative_prompt, image_path, seed),
                )
        finally:
            db.close()

    def check(self, prompt: str, negative_prompt: str, seed: int):
        """Look up a cached image for (prompt, negative_prompt, seed).

        Returns:
            A ``sqlite3.Row`` whose single column is ``image_path`` (accessible
            as ``row[0]`` or ``row["image_path"]``), or ``False`` when nothing
            matches. If several rows match, one is picked at random
            (``ORDER BY RANDOM()``).
        """
        db = self.get_db()
        try:
            row = db.execute(
                "SELECT image_path FROM cache WHERE prompt = ? AND negative_prompt = ? AND seed = ? ORDER BY RANDOM() LIMIT 1",
                (prompt, negative_prompt, seed),
            ).fetchone()
        finally:
            db.close()
        return row if row is not None else False