Spaces:
Running
Running
File size: 4,674 Bytes
6842c08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import logging
import os
import time
import docker
import pytest
from docker import DockerClient
from pytest_docker.plugin import get_docker_ip
from fastapi.testclient import TestClient
from sqlalchemy import text, create_engine
log = logging.getLogger(__name__)
def get_fast_api_client():
    """Build a ``TestClient`` for the application and return it.

    ``main.app`` is imported inside the function rather than at module
    level, so the import happens only when a client is requested.

    The client is entered as a context manager and the context has
    already been exited by the time the caller receives the object.
    NOTE(review): with Starlette/FastAPI this presumably means the app's
    startup and shutdown handlers have both run before the first request
    is made -- confirm this is intended.
    """
    from main import app

    with TestClient(app) as client:
        # Nothing to do inside the context; entering and leaving it is
        # the whole point.
        pass
    return client
class AbstractIntegrationTest:
    """Base class for integration tests that address a single API prefix.

    Subclasses set ``BASE_PATH`` to the prefix of the endpoints they
    exercise and use :meth:`create_url` to build request paths below it.
    The setup/teardown hooks are no-ops here and exist so subclasses can
    chain to them with ``super()``.
    """

    # Path prefix shared by every URL built through create_url(); must be
    # overridden by subclasses.
    BASE_PATH = None

    def create_url(self, path="", query_params=None):
        """Join ``BASE_PATH`` and ``path`` and append an encoded query string.

        Empty segments and surrounding whitespace are dropped, so leading,
        trailing and doubled slashes in either part are harmless. The
        result carries no leading slash.

        Args:
            path: Path relative to ``BASE_PATH``; may be empty.
            query_params: Optional mapping of query parameter names to
                values. Keys and values are percent-encoded.

        Returns:
            The joined path, followed by ``?key=value&...`` when
            ``query_params`` is non-empty.

        Raises:
            Exception: If ``BASE_PATH`` has not been set.
        """
        from urllib.parse import urlencode

        if self.BASE_PATH is None:
            raise Exception("BASE_PATH is not set")

        raw_segments = self.BASE_PATH.split("/") + path.split("/")
        segments = [
            segment.strip() for segment in raw_segments if segment.strip()
        ]

        # urlencode escapes characters such as "&", "=", "#" and spaces
        # that would otherwise corrupt the query string.
        query = f"?{urlencode(query_params)}" if query_params else ""
        return "/".join(segments) + query

    @classmethod
    def setup_class(cls):
        """Class-level setup hook; does nothing in the base class."""

    def setup_method(self):
        """Per-test setup hook; does nothing in the base class."""

    @classmethod
    def teardown_class(cls):
        """Class-level teardown hook; does nothing in the base class."""

    def teardown_method(self):
        """Per-test teardown hook; does nothing in the base class."""
class AbstractPostgresTest(AbstractIntegrationTest):
    """Integration-test base that runs against a throwaway Postgres container.

    ``setup_class`` starts a ``postgres:16.2`` container, exports its URL
    as ``DATABASE_URL``, waits until a connection succeeds and then
    creates ``cls.fast_api_client``. ``teardown_method`` truncates the
    application tables after every test and ``teardown_class`` removes
    the container.
    """

    DOCKER_CONTAINER_NAME = "postgres-test-container-will-get-deleted"
    # Host port the container's 5432 is published on; also used when
    # building the database URL.
    POSTGRES_HOST_PORT = 8081
    docker_client: DockerClient

    @classmethod
    def _create_db_url(cls, env_vars_postgres: dict) -> str:
        """Return the SQLAlchemy URL for the test container.

        Args:
            env_vars_postgres: Mapping holding ``POSTGRES_USER``,
                ``POSTGRES_PASSWORD`` and ``POSTGRES_DB``.
        """
        host = get_docker_ip()
        user = env_vars_postgres["POSTGRES_USER"]
        pw = env_vars_postgres["POSTGRES_PASSWORD"]
        port = cls.POSTGRES_HOST_PORT
        db = env_vars_postgres["POSTGRES_DB"]
        return f"postgresql://{user}:{pw}@{host}:{port}/{db}"

    @classmethod
    def setup_class(cls):
        """Start Postgres, wait for it to accept connections, build the client.

        Any failure is logged, the container is cleaned up, and the test
        run is failed through ``pytest.fail``.
        """
        super().setup_class()
        try:
            env_vars_postgres = {
                "POSTGRES_USER": "user",
                "POSTGRES_PASSWORD": "example",
                "POSTGRES_DB": "openwebui",
            }
            cls.docker_client = docker.from_env()
            cls.docker_client.containers.run(
                "postgres:16.2",
                detach=True,
                environment=env_vars_postgres,
                name=cls.DOCKER_CONTAINER_NAME,
                ports={5432: ("0.0.0.0", cls.POSTGRES_HOST_PORT)},
                command="postgres -c log_statement=all",
            )
            time.sleep(0.5)

            database_url = cls._create_db_url(env_vars_postgres)
            os.environ["DATABASE_URL"] = database_url

            # Engine and connection are tracked separately so that a
            # failed connect() can never be mistaken for success.
            engine = None
            connection = None
            retries = 10
            while retries > 0:
                try:
                    # NOTE(review): the imported name is unused here;
                    # presumably kept for the import's side effects -- confirm.
                    from open_webui.config import OPEN_WEBUI_DIR  # noqa: F401

                    engine = create_engine(database_url, pool_pre_ping=True)
                    connection = engine.connect()
                    log.info("postgres is ready!")
                    break
                except Exception as e:
                    log.warning(e)
                    if engine is not None:
                        # Do not leak a pool for every failed attempt.
                        engine.dispose()
                        engine = None
                    time.sleep(3)
                    retries -= 1

            if connection is None:
                raise Exception("Could not connect to Postgres")

            try:
                # import must be after setting env!
                cls.fast_api_client = get_fast_api_client()
            finally:
                connection.close()
                engine.dispose()
        except Exception as ex:
            log.error(ex)
            cls.teardown_class()
            pytest.fail(f"Could not setup test environment: {ex}")

    def _check_db_connection(self):
        """Probe the application's DB session with ``SELECT 1``, retrying.

        Makes up to 10 attempts, rolling back and sleeping 3 seconds
        after each failure. Exhausting the retries is logged but does not
        raise.
        """
        from open_webui.internal.db import Session

        retries = 10
        while retries > 0:
            try:
                Session.execute(text("SELECT 1"))
                Session.commit()
                break
            except Exception as e:
                Session.rollback()
                log.warning(e)
                time.sleep(3)
                retries -= 1
        else:
            log.error("database connection check failed after all retries")

    def setup_method(self):
        """Verify the database is reachable before each test."""
        super().setup_method()
        self._check_db_connection()

    @classmethod
    def teardown_class(cls) -> None:
        """Force-remove the Postgres container if it exists.

        Safe to call when setup failed early: a missing Docker client or
        a container that was never created is tolerated, so the original
        setup error is not masked.
        """
        super().teardown_class()
        client = getattr(cls, "docker_client", None)
        if client is None:
            return
        try:
            client.containers.get(cls.DOCKER_CONTAINER_NAME).remove(force=True)
        except docker.errors.NotFound:
            log.warning(
                "container %s not found during teardown",
                cls.DOCKER_CONTAINER_NAME,
            )

    def teardown_method(self):
        """Reset database state after each test by truncating all tables."""
        from open_webui.internal.db import Session

        # Finish the current transaction (this commits, it does not roll
        # back) before truncating.
        Session.commit()

        # truncate all tables
        tables = [
            "auth",
            "chat",
            "chatidtag",
            "document",
            "memory",
            "model",
            "prompt",
            "tag",
            '"user"',
        ]
        for table in tables:
            Session.execute(text(f"TRUNCATE TABLE {table}"))
        Session.commit()
|