import base64
import enum
import hashlib
import hmac
import json
import logging
import os
import pickle
import re
import time
from json import JSONDecodeError

from sqlalchemy import func
from sqlalchemy.dialects.postgresql import JSONB

from configs import dify_config
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from extensions.ext_database import db
from extensions.ext_storage import storage

from .account import Account
from .model import App, Tag, TagBinding, UploadFile
from .types import StringUUID


class DatasetPermissionEnum(str, enum.Enum):
    """Access levels for a dataset.

    Members subclass ``str`` so they compare equal to their raw string values.
    ``ONLY_ME`` matches the ``'only_me'`` server default of ``Dataset.permission``;
    presumably the other values are stored in that same column (TODO confirm).
    """

    ONLY_ME = "only_me"
    ALL_TEAM = "all_team_members"
    # NOTE(review): presumably paired with per-account rows in DatasetPermission — confirm.
    PARTIAL_TEAM = "partial_members"


class Dataset(db.Model):
    """A knowledge base ("dataset") belonging to a tenant.

    Most properties below issue their own database query on every access;
    callers that need a value several times should keep it in a local.
    """

    __tablename__ = "datasets"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_pkey"),
        db.Index("dataset_tenant_idx", "tenant_id"),
        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
    )

    # Accepted values for ``indexing_technique`` and ``provider`` (``None`` included).
    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
    PROVIDER_LIST = ["vendor", "external", None]

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        """Return the ``DatasetKeywordTable`` row for this dataset, or ``None``."""
        return db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()

    @property
    def index_struct_dict(self):
        """Return ``index_struct`` decoded from JSON, or ``None`` when it is empty."""
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def external_retrieval_model(self):
        """Return ``retrieval_model``, or a default ``top_k``/``score_threshold`` config when unset."""
        default_retrieval_model = {
            "top_k": 2,
            "score_threshold": 0.0,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def created_by_account(self):
        """Return the ``Account`` that created the dataset, or ``None`` if it no longer exists."""
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        """Return the most recently created ``DatasetProcessRule`` for this dataset, or ``None``."""
        return (
            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
            .order_by(DatasetProcessRule.created_at.desc())
            .first()
        )

    @property
    def app_count(self):
        """Count app bindings of this dataset whose ``App`` row still exists."""
        # The App.id condition turns the query into an implicit join, so joins that
        # point at deleted apps are not counted.
        return (
            db.session.query(func.count(AppDatasetJoin.id))
            .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
            .scalar()
        )

    @property
    def document_count(self):
        """Count all documents in the dataset, regardless of status."""
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        """Count documents that finished indexing and are enabled and not archived."""
        # ``== True``/``== False`` are intentional: SQLAlchemy overloads ``==`` to build SQL.
        return (
            db.session.query(func.count(Document.id))
            .filter(
                Document.dataset_id == self.id,
                Document.indexing_status == "completed",
                Document.enabled == True,  # noqa: E712
                Document.archived == False,  # noqa: E712
            )
            .scalar()
        )

    @property
    def available_segment_count(self):
        """Count segments that finished indexing and are enabled."""
        return (
            db.session.query(func.count(DocumentSegment.id))
            .filter(
                DocumentSegment.dataset_id == self.id,
                DocumentSegment.status == "completed",
                DocumentSegment.enabled == True,  # noqa: E712
            )
            .scalar()
        )

    @property
    def word_count(self):
        """Return the total ``word_count`` of all documents in the dataset (0 when there are none)."""
        # SUM over zero rows is NULL; the second coalesce argument maps that to 0.
        return (
            Document.query.with_entities(func.coalesce(func.sum(Document.word_count), 0))
            .filter(Document.dataset_id == self.id)
            .scalar()
        )

    @property
    def doc_form(self):
        """Return ``doc_form`` of a document in this dataset (first row returned), or ``None``."""
        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
        if document:
            return document.doc_form
        return None

    @property
    def retrieval_model_dict(self):
        """Return ``retrieval_model``, or a default semantic-search config when unset."""
        default_retrieval_model = {
            "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
            "reranking_enable": False,
            "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
            "top_k": 2,
            "score_threshold_enabled": False,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def tags(self):
        """Return the tenant's ``knowledge``-type tags bound to this dataset (possibly empty)."""
        tags = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == self.id,
                TagBinding.tenant_id == self.tenant_id,
                Tag.tenant_id == self.tenant_id,
                Tag.type == "knowledge",
            )
            .all()
        )

        return tags or []

    @property
    def external_knowledge_info(self):
        """Describe the external knowledge source of an ``external``-provider dataset.

        Returns:
            A dict with the external knowledge id and the bound API's id, name and
            ``endpoint`` setting (``""`` when unset), or ``None`` when the dataset is
            not external or its binding / API row is missing.
        """
        if self.provider != "external":
            return None
        external_knowledge_binding = (
            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
        )
        if not external_knowledge_binding:
            return None
        external_knowledge_api = (
            db.session.query(ExternalKnowledgeApis)
            .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
            .first()
        )
        if not external_knowledge_api:
            return None
        # ``settings`` is nullable and may hold invalid JSON; settings_dict maps both to None
        # instead of raising, and a non-object JSON value has no endpoint either.
        settings = external_knowledge_api.settings_dict
        endpoint = settings.get("endpoint", "") if isinstance(settings, dict) else ""
        return {
            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,
            "external_knowledge_api_id": external_knowledge_api.id,
            "external_knowledge_api_name": external_knowledge_api.name,
            "external_knowledge_api_endpoint": endpoint,
        }

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        """Build the vector collection name for a dataset id (hyphens become underscores)."""
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized_dataset_id}_Node"


class DatasetProcessRule(db.Model):
    """A processing rule (mode plus JSON-encoded rules) attached to a dataset."""

    __tablename__ = "dataset_process_rules"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
        db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    MODES = ["automatic", "custom"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
    AUTOMATIC_RULES = {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
    }

    def to_dict(self):
        """Serialize this rule; ``rules`` is emitted in decoded form (see ``rules_dict``)."""
        serialized = dict(
            id=self.id,
            dataset_id=self.dataset_id,
            mode=self.mode,
            rules=self.rules_dict,
            created_by=self.created_by,
            created_at=self.created_at,
        )
        return serialized

    @property
    def rules_dict(self):
        """Decode ``rules`` from JSON; ``None`` when it is empty or not valid JSON."""
        raw_rules = self.rules
        if not raw_rules:
            return None
        try:
            return json.loads(raw_rules)
        except JSONDecodeError:
            return None


class Document(db.Model):
    """A single source document inside a dataset, plus its indexing lifecycle state.

    Several properties below run their own database query on each access.
    """

    __tablename__ = "documents"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_pkey"),
        db.Index("document_dataset_id_idx", "dataset_id"),
        db.Index("document_is_paused_idx", "is_paused"),
        db.Index("document_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

    @property
    def display_status(self):
        """Map ``indexing_status``/``is_paused``/``enabled``/``archived`` to a UI status.

        Returns one of ``queuing``, ``paused``, ``indexing``, ``error``, ``available``,
        ``disabled``, ``archived``, or ``None`` when no rule matches.
        """
        status = None
        if self.indexing_status == "waiting":
            status = "queuing"
        # A pause only applies while the document is still being processed.
        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
            status = "paused"
        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
            status = "indexing"
        elif self.indexing_status == "error":
            status = "error"
        elif self.indexing_status == "completed" and not self.archived and self.enabled:
            status = "available"
        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
            status = "disabled"
        elif self.indexing_status == "completed" and self.archived:
            status = "archived"
        return status

    @property
    def data_source_info_dict(self):
        """Decode ``data_source_info``; ``{}`` on invalid JSON, ``None`` when it is empty."""
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}

            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        """Return details about the document's data source.

        For ``upload_file`` sources this is ``{"upload_file": {...}}`` describing the
        referenced ``UploadFile``; for ``notion_import``/``website_crawl`` it is the
        decoded ``data_source_info``. Returns ``{}`` when there is no info, the upload
        file id is missing, or the file row does not exist.
        """
        if not self.data_source_info:
            return {}
        if self.data_source_type == "upload_file":
            info = self.data_source_info_dict
            upload_file_id = info.get("upload_file_id") if isinstance(info, dict) else None
            if not upload_file_id:
                return {}
            file_detail = db.session.query(UploadFile).filter(UploadFile.id == upload_file_id).one_or_none()
            if not file_detail:
                return {}
            return {
                "upload_file": {
                    "id": file_detail.id,
                    "name": file_detail.name,
                    "size": file_detail.size,
                    "extension": file_detail.extension,
                    "mime_type": file_detail.mime_type,
                    "created_by": file_detail.created_by,
                    "created_at": file_detail.created_at.timestamp(),
                }
            }
        if self.data_source_type in {"notion_import", "website_crawl"}:
            return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        """Return ``word_count // segment_count``, or 0 when either is zero/unset."""
        if not self.word_count:
            return 0
        # segment_count runs a COUNT query; evaluate it only once.
        segment_count = self.segment_count
        return self.word_count // segment_count if segment_count else 0

    @property
    def dataset_process_rule(self):
        """Return the ``DatasetProcessRule`` referenced by this document, or ``None``."""
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        """Return the owning ``Dataset``, or ``None`` if it does not exist."""
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        """Count all segments of this document."""
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        """Return the summed ``hit_count`` of this document's segments (0 when there are none)."""
        # SUM over zero rows is NULL; the second coalesce argument maps that to 0.
        return (
            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count), 0))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )

    def to_dict(self):
        """Serialize the document's columns plus several derived properties.

        Each derived property runs its own query, so related rows are fetched once
        and reused below.
        """
        dataset_process_rule = self.dataset_process_rule
        dataset = self.dataset
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "dataset_id": self.dataset_id,
            "position": self.position,
            "data_source_type": self.data_source_type,
            "data_source_info": self.data_source_info,
            "dataset_process_rule_id": self.dataset_process_rule_id,
            "batch": self.batch,
            "name": self.name,
            "created_from": self.created_from,
            "created_by": self.created_by,
            "created_api_request_id": self.created_api_request_id,
            "created_at": self.created_at,
            "processing_started_at": self.processing_started_at,
            "file_id": self.file_id,
            "word_count": self.word_count,
            "parsing_completed_at": self.parsing_completed_at,
            "cleaning_completed_at": self.cleaning_completed_at,
            "splitting_completed_at": self.splitting_completed_at,
            "tokens": self.tokens,
            "indexing_latency": self.indexing_latency,
            "completed_at": self.completed_at,
            "is_paused": self.is_paused,
            "paused_by": self.paused_by,
            "paused_at": self.paused_at,
            "error": self.error,
            "stopped_at": self.stopped_at,
            "indexing_status": self.indexing_status,
            "enabled": self.enabled,
            "disabled_at": self.disabled_at,
            "disabled_by": self.disabled_by,
            "archived": self.archived,
            "archived_reason": self.archived_reason,
            "archived_by": self.archived_by,
            "archived_at": self.archived_at,
            "updated_at": self.updated_at,
            "doc_type": self.doc_type,
            "doc_metadata": self.doc_metadata,
            "doc_form": self.doc_form,
            "doc_language": self.doc_language,
            "display_status": self.display_status,
            "data_source_info_dict": self.data_source_info_dict,
            "average_segment_length": self.average_segment_length,
            "dataset_process_rule": dataset_process_rule.to_dict() if dataset_process_rule else None,
            # Dataset defines no to_dict(); serialize its mapped table columns directly.
            "dataset": (
                {column.name: getattr(dataset, column.name) for column in dataset.__table__.columns}
                if dataset
                else None
            ),
            "segment_count": self.segment_count,
            "hit_count": self.hit_count,
        }

    @classmethod
    def from_dict(cls, data: dict):
        """Build an unsaved ``Document`` from a dict of column values (missing keys become ``None``)."""
        return cls(
            id=data.get("id"),
            tenant_id=data.get("tenant_id"),
            dataset_id=data.get("dataset_id"),
            position=data.get("position"),
            data_source_type=data.get("data_source_type"),
            data_source_info=data.get("data_source_info"),
            dataset_process_rule_id=data.get("dataset_process_rule_id"),
            batch=data.get("batch"),
            name=data.get("name"),
            created_from=data.get("created_from"),
            created_by=data.get("created_by"),
            created_api_request_id=data.get("created_api_request_id"),
            created_at=data.get("created_at"),
            processing_started_at=data.get("processing_started_at"),
            file_id=data.get("file_id"),
            word_count=data.get("word_count"),
            parsing_completed_at=data.get("parsing_completed_at"),
            cleaning_completed_at=data.get("cleaning_completed_at"),
            splitting_completed_at=data.get("splitting_completed_at"),
            tokens=data.get("tokens"),
            indexing_latency=data.get("indexing_latency"),
            completed_at=data.get("completed_at"),
            is_paused=data.get("is_paused"),
            paused_by=data.get("paused_by"),
            paused_at=data.get("paused_at"),
            error=data.get("error"),
            stopped_at=data.get("stopped_at"),
            indexing_status=data.get("indexing_status"),
            enabled=data.get("enabled"),
            disabled_at=data.get("disabled_at"),
            disabled_by=data.get("disabled_by"),
            archived=data.get("archived"),
            archived_reason=data.get("archived_reason"),
            archived_by=data.get("archived_by"),
            archived_at=data.get("archived_at"),
            updated_at=data.get("updated_at"),
            doc_type=data.get("doc_type"),
            doc_metadata=data.get("doc_metadata"),
            doc_form=data.get("doc_form"),
            doc_language=data.get("doc_language"),
        )


class DocumentSegment(db.Model):
    """One chunk (segment) of a document, with its indexing state and counters."""

    __tablename__ = "document_segments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
        db.Index("document_segment_dataset_id_idx", "dataset_id"),
        db.Index("document_segment_document_id_idx", "document_id"),
        db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
        db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),
        db.Index("document_segment_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        """Return the owning ``Dataset``, or ``None``."""
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        """Return the owning ``Document``, or ``None``."""
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        """Return the segment of the same document at ``position - 1``, or ``None``."""
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
            .first()
        )

    @property
    def next_segment(self):
        """Return the segment of the same document at ``position + 1``, or ``None``."""
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
            .first()
        )

    def get_sign_content(self):
        """Return ``content`` with a signed query string appended to each file preview link.

        Every ``/files/<upload_file_id>/image-preview`` link (data from before v0.10.0)
        and ``/files/<upload_file_id>/file-preview`` link (data from after v0.10.0) gets
        ``?timestamp=...&nonce=...&sign=...`` appended. ``sign`` is the URL-safe base64
        HMAC-SHA256 of ``"<preview kind>|<upload_file_id>|<timestamp>|<nonce>"`` keyed
        with ``dify_config.SECRET_KEY`` (an empty key when it is unset).
        """
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""

        def _sign(match):
            # Build the signed replacement for one matched preview link.
            upload_file_id, preview_kind = match.group(1), match.group(2)
            nonce = os.urandom(16).hex()
            timestamp = str(int(time.time()))
            data_to_sign = f"{preview_kind}|{upload_file_id}|{timestamp}|{nonce}"
            sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
            encoded_sign = base64.urlsafe_b64encode(sign).decode()

            params = f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
            return f"{match.group(0)}?{params}"

        # A single combined pattern replaces links strictly in text order. Splicing
        # matches gathered by two separate per-kind passes with a running offset
        # misplaces replacements whenever a file-preview link precedes an
        # image-preview link.
        return re.sub(r"/files/([a-f0-9\-]+)/(image-preview|file-preview)", _sign, self.content)


class AppDatasetJoin(db.Model):
    """Association row binding an ``App`` (``app_id``) to a ``Dataset`` (``dataset_id``)."""

    __tablename__ = "app_dataset_joins"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
        db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        """Return the bound ``App`` by primary key, or ``None`` if it no longer exists."""
        return db.session.get(App, self.app_id)


class DatasetQuery(db.Model):
    """A query text (``content``) recorded against a dataset.

    NOTE(review): ``source``/``source_app_id`` presumably identify where the query
    originated and ``created_by_role`` which kind of actor ``created_by`` refers to;
    the allowed values are not visible here — confirm with the writers of this table.
    """

    __tablename__ = "dataset_queries"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
        db.Index("dataset_query_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())


class DatasetKeywordTable(db.Model):
    """Keyword index of a dataset, kept either inline (``keyword_table``) or in storage."""

    __tablename__ = "dataset_keyword_tables"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
        db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    data_source_type = db.Column(
        db.String(255), nullable=False, server_default=db.text("'database'::character varying")
    )

    @property
    def keyword_table_dict(self):
        """Load the keyword table, turning every JSON list value into a ``set``.

        The table is read from ``keyword_table`` when ``data_source_type`` is
        ``"database"``, otherwise from ``keyword_files/<tenant_id>/<dataset_id>.txt``
        in storage. Returns ``None`` when the owning dataset is missing, the source is
        empty, or (storage only) loading/decoding fails — that failure is logged.
        """

        def _lists_to_sets(obj):
            # json invokes this for every decoded JSON object.
            for key, value in obj.items():
                if isinstance(value, list):
                    obj[key] = set(value)
            return obj

        dataset = Dataset.query.filter_by(id=self.dataset_id).first()
        if not dataset:
            return None

        if self.data_source_type == "database":
            if not self.keyword_table:
                return None
            return json.loads(self.keyword_table, object_hook=_lists_to_sets)

        file_key = "/".join(("keyword_files", dataset.tenant_id, self.dataset_id + ".txt"))
        try:
            raw_bytes = storage.load_once(file_key)
            if not raw_bytes:
                return None
            return json.loads(raw_bytes.decode("utf-8"), object_hook=_lists_to_sets)
        except Exception as e:
            logging.exception(str(e))
            return None


class Embedding(db.Model):
    """Cached embedding vector keyed by (``model_name``, ``hash``, ``provider_name``)."""

    __tablename__ = "embeddings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="embedding_pkey"),
        db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
        db.Index("created_at_idx", "created_at"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    model_name = db.Column(
        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
    )
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        """Pickle ``embedding_data`` (highest protocol) into the ``embedding`` column."""
        serialized = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)
        self.embedding = serialized

    def get_embedding(self) -> list[float]:
        """Unpickle and return the vector stored in the ``embedding`` column."""
        # NOTE(review): pickle.loads can execute arbitrary code for crafted bytes; this is
        # only safe while the column is written exclusively through set_embedding.
        serialized = self.embedding
        return pickle.loads(serialized)


class DatasetCollectionBinding(db.Model):
    """Associates a (``provider_name``, ``model_name``) pair with a ``collection_name``.

    NOTE(review): presumably the collection is a vector-store collection shared by
    datasets using that embedding model, and ``type`` defaults to ``'dataset'`` to
    distinguish other users — confirm against the code that creates these rows.
    """

    __tablename__ = "dataset_collection_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
        db.Index("provider_model_name_idx", "provider_name", "model_name"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class TidbAuthBinding(db.Model):
    """A TiDB cluster with its account/password, optionally bound to a tenant.

    ``tenant_id`` is nullable, so rows may exist without a tenant.
    """

    __tablename__ = "tidb_auth_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
        db.Index("tidb_auth_bindings_active_idx", "active"),
        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
        db.Index("tidb_auth_bindings_status_idx", "status"),
    )
    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    cluster_id = db.Column(db.String(255), nullable=False)
    cluster_name = db.Column(db.String(255), nullable=False)
    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    # db.text() is emitted verbatim into DDL, so the default must be a quoted SQL string
    # literal; a bare CREATING would be parsed as a column/identifier reference.
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class Whitelist(db.Model):
    """A whitelist entry for a tenant in a given ``category``.

    NOTE(review): ``tenant_id`` is nullable and the meaning of ``category`` values is
    not defined here — confirm with the feature that reads this table.
    """

    __tablename__ = "whitelists"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
        db.Index("whitelists_tenant_idx", "tenant_id"),
    )
    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    category = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class DatasetPermission(db.Model):
    """Per-account access flag (``has_permission``) for a dataset within a tenant.

    NOTE(review): presumably consulted when ``Dataset.permission`` is
    ``DatasetPermissionEnum.PARTIAL_TEAM`` — confirm in the permission service.
    """

    __tablename__ = "dataset_permissions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
        db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
        db.Index("idx_dataset_permissions_account_id", "account_id"),
        db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))


class ExternalKnowledgeApis(db.Model):
    """An external knowledge API registered by a tenant; ``settings`` holds JSON config (e.g. ``endpoint``)."""

    __tablename__ = "external_knowledge_apis"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
        db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_apis_name_idx", "name"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.String(255), nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    settings = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    def to_dict(self):
        """Serialize the API with decoded settings and the datasets bound to it."""
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "name": self.name,
            "description": self.description,
            "settings": self.settings_dict,
            "dataset_bindings": self.dataset_bindings,
            "created_by": self.created_by,
            # created_at comes from a server default and is None until the row is
            # flushed/refreshed; avoid AttributeError on such instances.
            "created_at": self.created_at.isoformat() if self.created_at else None,
        }

    @property
    def settings_dict(self):
        """Decode ``settings`` from JSON; ``None`` when it is empty or invalid."""
        try:
            return json.loads(self.settings) if self.settings else None
        except JSONDecodeError:
            return None

    @property
    def dataset_bindings(self):
        """Return ``[{"id": ..., "name": ...}]`` for every dataset bound to this API."""
        external_knowledge_bindings = (
            db.session.query(ExternalKnowledgeBindings)
            .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
            .all()
        )
        dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
        if not dataset_ids:
            # Nothing bound: skip a second query with an empty IN clause.
            return []
        datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
        return [{"id": dataset.id, "name": dataset.name} for dataset in datasets]


class ExternalKnowledgeBindings(db.Model):
    """Links a ``Dataset`` (``dataset_id``) to an ``ExternalKnowledgeApis`` row.

    ``external_knowledge_id`` is the knowledge identifier on the external side; it is
    surfaced by ``Dataset.external_knowledge_info``.
    """

    __tablename__ = "external_knowledge_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
        db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    external_knowledge_api_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    external_knowledge_id = db.Column(db.Text, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))