"""SQLAlchemy models for datasets, documents, segments, and related bindings."""
import base64 | |
import enum | |
import hashlib | |
import hmac | |
import json | |
import logging | |
import os | |
import pickle | |
import re | |
import time | |
from json import JSONDecodeError | |
from sqlalchemy import func | |
from sqlalchemy.dialects.postgresql import JSONB | |
from configs import dify_config | |
from core.rag.retrieval.retrieval_methods import RetrievalMethod | |
from extensions.ext_database import db | |
from extensions.ext_storage import storage | |
from .account import Account | |
from .model import App, Tag, TagBinding, UploadFile | |
from .types import StringUUID | |
class DatasetPermissionEnum(str, enum.Enum):
    """Visibility levels controlling who in a tenant may use a dataset."""

    ONLY_ME = "only_me"
    ALL_TEAM = "all_team_members"
    PARTIAL_TEAM = "partial_members"
class Dataset(db.Model):
    """A tenant-scoped knowledge base: documents plus indexing/retrieval config.

    Computed attributes are exposed as properties: `to_dict`-style consumers
    and `average_segment_length` read them as plain attributes, so the
    `@property` decorators are required for correct behavior.
    """

    __tablename__ = "datasets"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_pkey"),
        db.Index("dataset_tenant_idx", "tenant_id"),
        db.Index("retrieval_model_idx", "retrieval_model", postgresql_using="gin"),
    )

    # Allowed values for the corresponding columns; None means "not set".
    INDEXING_TECHNIQUE_LIST = ["high_quality", "economy", None]
    PROVIDER_LIST = ["vendor", "external", None]

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.Text, nullable=True)
    provider = db.Column(db.String(255), nullable=False, server_default=db.text("'vendor'::character varying"))
    permission = db.Column(db.String(255), nullable=False, server_default=db.text("'only_me'::character varying"))
    data_source_type = db.Column(db.String(255))
    indexing_technique = db.Column(db.String(255), nullable=True)
    index_struct = db.Column(db.Text, nullable=True)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    embedding_model = db.Column(db.String(255), nullable=True)
    embedding_model_provider = db.Column(db.String(255), nullable=True)
    collection_binding_id = db.Column(StringUUID, nullable=True)
    retrieval_model = db.Column(JSONB, nullable=True)

    @property
    def dataset_keyword_table(self):
        """Keyword-table row for this dataset, or None when none exists."""
        # .first() already returns None on no match; no extra check needed.
        return (
            db.session.query(DatasetKeywordTable).filter(DatasetKeywordTable.dataset_id == self.id).first()
        )

    @property
    def index_struct_dict(self):
        """Parsed `index_struct` JSON, or None when unset."""
        return json.loads(self.index_struct) if self.index_struct else None

    @property
    def external_retrieval_model(self):
        """Retrieval settings for external providers, with safe defaults."""
        default_retrieval_model = {
            "top_k": 2,
            "score_threshold": 0.0,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def created_by_account(self):
        """The Account that created this dataset (by primary key)."""
        return db.session.get(Account, self.created_by)

    @property
    def latest_process_rule(self):
        """Most recently created DatasetProcessRule for this dataset."""
        return (
            DatasetProcessRule.query.filter(DatasetProcessRule.dataset_id == self.id)
            .order_by(DatasetProcessRule.created_at.desc())
            .first()
        )

    @property
    def app_count(self):
        """Number of app-dataset joins whose app still exists."""
        return (
            db.session.query(func.count(AppDatasetJoin.id))
            .filter(AppDatasetJoin.dataset_id == self.id, App.id == AppDatasetJoin.app_id)
            .scalar()
        )

    @property
    def document_count(self):
        """Total number of documents in this dataset."""
        return db.session.query(func.count(Document.id)).filter(Document.dataset_id == self.id).scalar()

    @property
    def available_document_count(self):
        """Count of completed, enabled, non-archived documents."""
        return (
            db.session.query(func.count(Document.id))
            .filter(
                Document.dataset_id == self.id,
                Document.indexing_status == "completed",
                Document.enabled == True,
                Document.archived == False,
            )
            .scalar()
        )

    @property
    def available_segment_count(self):
        """Count of completed, enabled segments across all documents."""
        return (
            db.session.query(func.count(DocumentSegment.id))
            .filter(
                DocumentSegment.dataset_id == self.id,
                DocumentSegment.status == "completed",
                DocumentSegment.enabled == True,
            )
            .scalar()
        )

    @property
    def word_count(self):
        """Summed word count over all documents (may be None for empty sets)."""
        return (
            Document.query.with_entities(func.coalesce(func.sum(Document.word_count)))
            .filter(Document.dataset_id == self.id)
            .scalar()
        )

    @property
    def doc_form(self):
        """`doc_form` of the first document, or None when the dataset is empty."""
        document = db.session.query(Document).filter(Document.dataset_id == self.id).first()
        return document.doc_form if document else None

    @property
    def retrieval_model_dict(self):
        """Stored retrieval settings, falling back to semantic-search defaults."""
        default_retrieval_model = {
            "search_method": RetrievalMethod.SEMANTIC_SEARCH.value,
            "reranking_enable": False,
            "reranking_model": {"reranking_provider_name": "", "reranking_model_name": ""},
            "top_k": 2,
            "score_threshold_enabled": False,
        }
        return self.retrieval_model or default_retrieval_model

    @property
    def tags(self):
        """Knowledge-type tags bound to this dataset within the same tenant."""
        tags = (
            db.session.query(Tag)
            .join(TagBinding, Tag.id == TagBinding.tag_id)
            .filter(
                TagBinding.target_id == self.id,
                TagBinding.tenant_id == self.tenant_id,
                Tag.tenant_id == self.tenant_id,
                Tag.type == "knowledge",
            )
            .all()
        )
        return tags or []

    @property
    def external_knowledge_info(self):
        """Binding + API details for external datasets; None otherwise."""
        if self.provider != "external":
            return None
        external_knowledge_binding = (
            db.session.query(ExternalKnowledgeBindings).filter(ExternalKnowledgeBindings.dataset_id == self.id).first()
        )
        if not external_knowledge_binding:
            return None
        external_knowledge_api = (
            db.session.query(ExternalKnowledgeApis)
            .filter(ExternalKnowledgeApis.id == external_knowledge_binding.external_knowledge_api_id)
            .first()
        )
        if not external_knowledge_api:
            return None
        return {
            "external_knowledge_id": external_knowledge_binding.external_knowledge_id,
            "external_knowledge_api_id": external_knowledge_api.id,
            "external_knowledge_api_name": external_knowledge_api.name,
            "external_knowledge_api_endpoint": json.loads(external_knowledge_api.settings).get("endpoint", ""),
        }

    @staticmethod
    def gen_collection_name_by_id(dataset_id: str) -> str:
        """Derive the vector-store collection name from a dataset id."""
        normalized_dataset_id = dataset_id.replace("-", "_")
        return f"Vector_index_{normalized_dataset_id}_Node"
class DatasetProcessRule(db.Model):
    """Cleaning/segmentation rules applied when indexing a dataset's documents."""

    __tablename__ = "dataset_process_rules"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_process_rule_pkey"),
        db.Index("dataset_process_rule_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    mode = db.Column(db.String(255), nullable=False, server_default=db.text("'automatic'::character varying"))
    rules = db.Column(db.Text, nullable=True)  # JSON-encoded rule payload
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    MODES = ["automatic", "custom"]
    PRE_PROCESSING_RULES = ["remove_stopwords", "remove_extra_spaces", "remove_urls_emails"]
    # Default rule payload used when mode == "automatic".
    AUTOMATIC_RULES = {
        "pre_processing_rules": [
            {"id": "remove_extra_spaces", "enabled": True},
            {"id": "remove_urls_emails", "enabled": False},
        ],
        "segmentation": {"delimiter": "\n", "max_tokens": 500, "chunk_overlap": 50},
    }

    def to_dict(self):
        """Serializable representation with `rules` already parsed."""
        return {
            "id": self.id,
            "dataset_id": self.dataset_id,
            "mode": self.mode,
            "rules": self.rules_dict,
            "created_by": self.created_by,
            "created_at": self.created_at,
        }

    @property
    def rules_dict(self):
        """Parsed `rules` JSON; None when unset or malformed."""
        try:
            return json.loads(self.rules) if self.rules else None
        except JSONDecodeError:
            return None
class Document(db.Model):
    """A source document within a dataset, tracked through its indexing lifecycle.

    Computed attributes (`display_status`, `segment_count`, ...) are
    properties: `to_dict` and `average_segment_length` read them as plain
    attributes, so the `@property` decorators are required.
    """

    __tablename__ = "documents"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_pkey"),
        db.Index("document_dataset_id_idx", "dataset_id"),
        db.Index("document_is_paused_idx", "is_paused"),
        db.Index("document_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    data_source_type = db.Column(db.String(255), nullable=False)
    data_source_info = db.Column(db.Text, nullable=True)  # JSON-encoded source details
    dataset_process_rule_id = db.Column(StringUUID, nullable=True)
    batch = db.Column(db.String(255), nullable=False)
    name = db.Column(db.String(255), nullable=False)
    created_from = db.Column(db.String(255), nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_api_request_id = db.Column(StringUUID, nullable=True)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    # start processing
    processing_started_at = db.Column(db.DateTime, nullable=True)

    # parsing
    file_id = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=True)
    parsing_completed_at = db.Column(db.DateTime, nullable=True)

    # cleaning
    cleaning_completed_at = db.Column(db.DateTime, nullable=True)

    # split
    splitting_completed_at = db.Column(db.DateTime, nullable=True)

    # indexing
    tokens = db.Column(db.Integer, nullable=True)
    indexing_latency = db.Column(db.Float, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)

    # pause
    is_paused = db.Column(db.Boolean, nullable=True, server_default=db.text("false"))
    paused_by = db.Column(StringUUID, nullable=True)
    paused_at = db.Column(db.DateTime, nullable=True)

    # error
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    # basic fields
    indexing_status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    archived = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    archived_reason = db.Column(db.String(255), nullable=True)
    archived_by = db.Column(StringUUID, nullable=True)
    archived_at = db.Column(db.DateTime, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    doc_type = db.Column(db.String(40), nullable=True)
    doc_metadata = db.Column(db.JSON, nullable=True)
    doc_form = db.Column(db.String(255), nullable=False, server_default=db.text("'text_model'::character varying"))
    doc_language = db.Column(db.String(255), nullable=True)

    DATA_SOURCES = ["upload_file", "notion_import", "website_crawl"]

    @property
    def display_status(self):
        """UI-facing status derived from indexing state, pause/enable/archive flags."""
        status = None
        if self.indexing_status == "waiting":
            status = "queuing"
        elif self.indexing_status not in {"completed", "error", "waiting"} and self.is_paused:
            status = "paused"
        elif self.indexing_status in {"parsing", "cleaning", "splitting", "indexing"}:
            status = "indexing"
        elif self.indexing_status == "error":
            status = "error"
        elif self.indexing_status == "completed" and not self.archived and self.enabled:
            status = "available"
        elif self.indexing_status == "completed" and not self.archived and not self.enabled:
            status = "disabled"
        elif self.indexing_status == "completed" and self.archived:
            status = "archived"
        return status

    @property
    def data_source_info_dict(self):
        """Parsed `data_source_info` JSON ({} when malformed, None when unset)."""
        if self.data_source_info:
            try:
                data_source_info_dict = json.loads(self.data_source_info)
            except JSONDecodeError:
                data_source_info_dict = {}
            return data_source_info_dict
        return None

    @property
    def data_source_detail_dict(self):
        """Source-specific detail: upload-file metadata, or the raw info dict."""
        if self.data_source_info:
            if self.data_source_type == "upload_file":
                data_source_info_dict = json.loads(self.data_source_info)
                file_detail = (
                    db.session.query(UploadFile)
                    .filter(UploadFile.id == data_source_info_dict["upload_file_id"])
                    .one_or_none()
                )
                if file_detail:
                    return {
                        "upload_file": {
                            "id": file_detail.id,
                            "name": file_detail.name,
                            "size": file_detail.size,
                            "extension": file_detail.extension,
                            "mime_type": file_detail.mime_type,
                            "created_by": file_detail.created_by,
                            "created_at": file_detail.created_at.timestamp(),
                        }
                    }
            elif self.data_source_type in {"notion_import", "website_crawl"}:
                return json.loads(self.data_source_info)
        return {}

    @property
    def average_segment_length(self):
        """Mean words per segment (integer division); 0 when either count is 0."""
        if self.word_count and self.word_count != 0 and self.segment_count and self.segment_count != 0:
            return self.word_count // self.segment_count
        return 0

    @property
    def dataset_process_rule(self):
        """The DatasetProcessRule applied to this document, or None."""
        if self.dataset_process_rule_id:
            return db.session.get(DatasetProcessRule, self.dataset_process_rule_id)
        return None

    @property
    def dataset(self):
        """Owning Dataset, or None if it no longer exists."""
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).one_or_none()

    @property
    def segment_count(self):
        """Number of segments split from this document."""
        return DocumentSegment.query.filter(DocumentSegment.document_id == self.id).count()

    @property
    def hit_count(self):
        """Summed retrieval hit count over all segments (may be None when empty)."""
        return (
            DocumentSegment.query.with_entities(func.coalesce(func.sum(DocumentSegment.hit_count)))
            .filter(DocumentSegment.document_id == self.id)
            .scalar()
        )

    def to_dict(self):
        """Full serializable representation including derived properties."""
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "dataset_id": self.dataset_id,
            "position": self.position,
            "data_source_type": self.data_source_type,
            "data_source_info": self.data_source_info,
            "dataset_process_rule_id": self.dataset_process_rule_id,
            "batch": self.batch,
            "name": self.name,
            "created_from": self.created_from,
            "created_by": self.created_by,
            "created_api_request_id": self.created_api_request_id,
            "created_at": self.created_at,
            "processing_started_at": self.processing_started_at,
            "file_id": self.file_id,
            "word_count": self.word_count,
            "parsing_completed_at": self.parsing_completed_at,
            "cleaning_completed_at": self.cleaning_completed_at,
            "splitting_completed_at": self.splitting_completed_at,
            "tokens": self.tokens,
            "indexing_latency": self.indexing_latency,
            "completed_at": self.completed_at,
            "is_paused": self.is_paused,
            "paused_by": self.paused_by,
            "paused_at": self.paused_at,
            "error": self.error,
            "stopped_at": self.stopped_at,
            "indexing_status": self.indexing_status,
            "enabled": self.enabled,
            "disabled_at": self.disabled_at,
            "disabled_by": self.disabled_by,
            "archived": self.archived,
            "archived_reason": self.archived_reason,
            "archived_by": self.archived_by,
            "archived_at": self.archived_at,
            "updated_at": self.updated_at,
            "doc_type": self.doc_type,
            "doc_metadata": self.doc_metadata,
            "doc_form": self.doc_form,
            "doc_language": self.doc_language,
            "display_status": self.display_status,
            "data_source_info_dict": self.data_source_info_dict,
            "average_segment_length": self.average_segment_length,
            "dataset_process_rule": self.dataset_process_rule.to_dict() if self.dataset_process_rule else None,
            "dataset": self.dataset.to_dict() if self.dataset else None,
            "segment_count": self.segment_count,
            "hit_count": self.hit_count,
        }

    @classmethod
    def from_dict(cls, data: dict):
        """Build a Document from a dict; missing keys default to None."""
        field_names = (
            "id", "tenant_id", "dataset_id", "position", "data_source_type",
            "data_source_info", "dataset_process_rule_id", "batch", "name",
            "created_from", "created_by", "created_api_request_id", "created_at",
            "processing_started_at", "file_id", "word_count", "parsing_completed_at",
            "cleaning_completed_at", "splitting_completed_at", "tokens",
            "indexing_latency", "completed_at", "is_paused", "paused_by",
            "paused_at", "error", "stopped_at", "indexing_status", "enabled",
            "disabled_at", "disabled_by", "archived", "archived_reason",
            "archived_by", "archived_at", "updated_at", "doc_type",
            "doc_metadata", "doc_form", "doc_language",
        )
        return cls(**{name: data.get(name) for name in field_names})
class DocumentSegment(db.Model):
    """A chunk of a document: text content, index metadata, and lifecycle state."""

    __tablename__ = "document_segments"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="document_segment_pkey"),
        db.Index("document_segment_dataset_id_idx", "dataset_id"),
        db.Index("document_segment_document_id_idx", "document_id"),
        db.Index("document_segment_tenant_dataset_idx", "dataset_id", "tenant_id"),
        db.Index("document_segment_tenant_document_idx", "document_id", "tenant_id"),
        db.Index("document_segment_dataset_node_idx", "dataset_id", "index_node_id"),
        db.Index("document_segment_tenant_idx", "tenant_id"),
    )

    # initial fields
    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    document_id = db.Column(StringUUID, nullable=False)
    position = db.Column(db.Integer, nullable=False)
    content = db.Column(db.Text, nullable=False)
    answer = db.Column(db.Text, nullable=True)
    word_count = db.Column(db.Integer, nullable=False)
    tokens = db.Column(db.Integer, nullable=False)

    # indexing fields
    keywords = db.Column(db.JSON, nullable=True)
    index_node_id = db.Column(db.String(255), nullable=True)
    index_node_hash = db.Column(db.String(255), nullable=True)

    # basic fields
    hit_count = db.Column(db.Integer, nullable=False, default=0)
    enabled = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    disabled_at = db.Column(db.DateTime, nullable=True)
    disabled_by = db.Column(StringUUID, nullable=True)
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'waiting'::character varying"))
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    indexing_at = db.Column(db.DateTime, nullable=True)
    completed_at = db.Column(db.DateTime, nullable=True)
    error = db.Column(db.Text, nullable=True)
    stopped_at = db.Column(db.DateTime, nullable=True)

    @property
    def dataset(self):
        """Owning Dataset, or None."""
        return db.session.query(Dataset).filter(Dataset.id == self.dataset_id).first()

    @property
    def document(self):
        """Owning Document, or None."""
        return db.session.query(Document).filter(Document.id == self.document_id).first()

    @property
    def previous_segment(self):
        """Segment immediately before this one in the same document, or None."""
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position - 1)
            .first()
        )

    @property
    def next_segment(self):
        """Segment immediately after this one in the same document, or None."""
        return (
            db.session.query(DocumentSegment)
            .filter(DocumentSegment.document_id == self.document_id, DocumentSegment.position == self.position + 1)
            .first()
        )

    @staticmethod
    def _sign_query_params(route: str, upload_file_id: str) -> str:
        """Build the timestamp/nonce/sign query string for a file-preview URL."""
        nonce = os.urandom(16).hex()
        timestamp = str(int(time.time()))
        data_to_sign = f"{route}|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = dify_config.SECRET_KEY.encode() if dify_config.SECRET_KEY else b""
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()
        return f"timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    def get_sign_content(self):
        """Return `content` with /files/<id>/*-preview URLs signed.

        "image-preview" paths cover data written before v0.10.0;
        "file-preview" paths cover data written afterwards.
        """
        signed_urls = []
        text = self.content
        for route in ("image-preview", "file-preview"):
            pattern = rf"/files/([a-f0-9\-]+)/{route}"
            for match in re.finditer(pattern, text):
                params = self._sign_query_params(route, match.group(1))
                signed_urls.append((match.start(), match.end(), f"{match.group(0)}?{params}"))

        # Reconstruct the text with signed URLs, tracking the cumulative
        # length drift introduced by earlier replacements.
        offset = 0
        for start, end, signed_url in signed_urls:
            text = text[: start + offset] + signed_url + text[end + offset :]
            offset += len(signed_url) - (end - start)
        return text
class AppDatasetJoin(db.Model):
    """Many-to-many join between an App and a Dataset."""

    __tablename__ = "app_dataset_joins"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="app_dataset_join_pkey"),
        db.Index("app_dataset_join_app_dataset_idx", "dataset_id", "app_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    app_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

    @property
    def app(self):
        """The joined App (by primary key), or None if it no longer exists."""
        return db.session.get(App, self.app_id)
class DatasetQuery(db.Model):
    """Audit record of a single retrieval query executed against a dataset."""

    __tablename__ = "dataset_queries"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_query_pkey"),
        db.Index("dataset_query_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, nullable=False, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False)
    content = db.Column(db.Text, nullable=False)  # the query text itself
    source = db.Column(db.String(255), nullable=False)
    source_app_id = db.Column(StringUUID, nullable=True)
    created_by_role = db.Column(db.String, nullable=False)
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
class DatasetKeywordTable(db.Model):
    """Keyword index for a dataset, stored either inline or in object storage."""

    __tablename__ = "dataset_keyword_tables"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_keyword_table_pkey"),
        db.Index("dataset_keyword_table_dataset_id_idx", "dataset_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    dataset_id = db.Column(StringUUID, nullable=False, unique=True)
    keyword_table = db.Column(db.Text, nullable=False)
    # "database" keeps the JSON inline in this row; any other value means the
    # payload lives in object storage under keyword_files/<tenant>/<dataset>.txt
    data_source_type = db.Column(
        db.String(255), nullable=False, server_default=db.text("'database'::character varying")
    )

    @property
    def keyword_table_dict(self):
        """Keyword table as {keyword: set(node_ids)}, or None when unavailable."""

        class SetDecoder(json.JSONDecoder):
            # Re-hydrate list values into sets (they were sets before dumping).
            def __init__(self, *args, **kwargs):
                super().__init__(object_hook=self.object_hook, *args, **kwargs)

            def object_hook(self, dct):
                if isinstance(dct, dict):
                    for keyword, node_idxs in dct.items():
                        if isinstance(node_idxs, list):
                            dct[keyword] = set(node_idxs)
                return dct

        # The tenant id is needed to locate the storage file; bail out if the
        # owning dataset is gone.
        dataset = Dataset.query.filter_by(id=self.dataset_id).first()
        if not dataset:
            return None
        if self.data_source_type == "database":
            return json.loads(self.keyword_table, cls=SetDecoder) if self.keyword_table else None
        else:
            file_key = "keyword_files/" + dataset.tenant_id + "/" + self.dataset_id + ".txt"
            try:
                keyword_table_text = storage.load_once(file_key)
                if keyword_table_text:
                    return json.loads(keyword_table_text.decode("utf-8"), cls=SetDecoder)
                return None
            except Exception:
                # Treat storage failures as "no keyword table" but keep the trace.
                logging.exception("Failed to load keyword table from storage: %s", file_key)
                return None
class Embedding(db.Model):
    """Cached embedding vector keyed by (provider, model, content hash)."""

    __tablename__ = "embeddings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="embedding_pkey"),
        db.UniqueConstraint("model_name", "hash", "provider_name", name="embedding_hash_idx"),
        db.Index("created_at_idx", "created_at"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    model_name = db.Column(
        db.String(255), nullable=False, server_default=db.text("'text-embedding-ada-002'::character varying")
    )
    hash = db.Column(db.String(64), nullable=False)
    embedding = db.Column(db.LargeBinary, nullable=False)  # pickled list[float]
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    provider_name = db.Column(db.String(255), nullable=False, server_default=db.text("''::character varying"))

    def set_embedding(self, embedding_data: list[float]):
        """Serialize and store an embedding vector."""
        self.embedding = pickle.dumps(embedding_data, protocol=pickle.HIGHEST_PROTOCOL)

    def get_embedding(self) -> list[float]:
        """Deserialize the stored embedding vector.

        NOTE(review): pickle.loads is only safe because the blob is written by
        set_embedding; never feed this column data from an untrusted source.
        """
        return pickle.loads(self.embedding)
class DatasetCollectionBinding(db.Model):
    """Maps an embedding (provider, model) pair to a vector-store collection."""

    __tablename__ = "dataset_collection_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_collection_bindings_pkey"),
        db.Index("provider_model_name_idx", "provider_name", "model_name"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    provider_name = db.Column(db.String(40), nullable=False)
    model_name = db.Column(db.String(255), nullable=False)
    type = db.Column(db.String(40), server_default=db.text("'dataset'::character varying"), nullable=False)
    collection_name = db.Column(db.String(64), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class TidbAuthBinding(db.Model):
    """Credentials and lifecycle state for a tenant's TiDB serverless cluster."""

    __tablename__ = "tidb_auth_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="tidb_auth_bindings_pkey"),
        db.Index("tidb_auth_bindings_tenant_idx", "tenant_id"),
        db.Index("tidb_auth_bindings_active_idx", "active"),
        db.Index("tidb_auth_bindings_created_at_idx", "created_at"),
        db.Index("tidb_auth_bindings_status_idx", "status"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    cluster_id = db.Column(db.String(255), nullable=False)
    cluster_name = db.Column(db.String(255), nullable=False)
    active = db.Column(db.Boolean, nullable=False, server_default=db.text("false"))
    # BUG FIX: the default must be the quoted string literal 'CREATING';
    # db.text("CREATING") would be emitted as a bare SQL identifier.
    status = db.Column(db.String(255), nullable=False, server_default=db.text("'CREATING'::character varying"))
    account = db.Column(db.String(255), nullable=False)
    password = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class Whitelist(db.Model):
    """Per-tenant feature whitelist entry, grouped by category."""

    __tablename__ = "whitelists"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="whitelists_pkey"),
        db.Index("whitelists_tenant_idx", "tenant_id"),
    )

    id = db.Column(StringUUID, primary_key=True, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=True)
    category = db.Column(db.String(255), nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class DatasetPermission(db.Model):
    """Grants (or revokes) a specific account's access to a dataset."""

    __tablename__ = "dataset_permissions"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="dataset_permission_pkey"),
        db.Index("idx_dataset_permissions_dataset_id", "dataset_id"),
        db.Index("idx_dataset_permissions_account_id", "account_id"),
        db.Index("idx_dataset_permissions_tenant_id", "tenant_id"),
    )

    id = db.Column(StringUUID, server_default=db.text("uuid_generate_v4()"), primary_key=True)
    dataset_id = db.Column(StringUUID, nullable=False)
    account_id = db.Column(StringUUID, nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    has_permission = db.Column(db.Boolean, nullable=False, server_default=db.text("true"))
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
class ExternalKnowledgeApis(db.Model):
    """Configuration for an external knowledge-base API endpoint.

    `settings_dict` and `dataset_bindings` are properties: `to_dict` reads
    them as plain attributes, so the `@property` decorators are required.
    """

    __tablename__ = "external_knowledge_apis"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_apis_pkey"),
        db.Index("external_knowledge_apis_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_apis_name_idx", "name"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    name = db.Column(db.String(255), nullable=False)
    description = db.Column(db.String(255), nullable=False)
    tenant_id = db.Column(StringUUID, nullable=False)
    settings = db.Column(db.Text, nullable=True)  # JSON, includes "endpoint"
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))

    def to_dict(self):
        """Serializable representation with parsed settings and bound datasets."""
        return {
            "id": self.id,
            "tenant_id": self.tenant_id,
            "name": self.name,
            "description": self.description,
            "settings": self.settings_dict,
            "dataset_bindings": self.dataset_bindings,
            "created_by": self.created_by,
            "created_at": self.created_at.isoformat(),
        }

    @property
    def settings_dict(self):
        """Parsed `settings` JSON; None when unset or malformed."""
        try:
            return json.loads(self.settings) if self.settings else None
        except JSONDecodeError:
            return None

    @property
    def dataset_bindings(self):
        """[{id, name}] for each dataset currently bound to this API."""
        external_knowledge_bindings = (
            db.session.query(ExternalKnowledgeBindings)
            .filter(ExternalKnowledgeBindings.external_knowledge_api_id == self.id)
            .all()
        )
        dataset_ids = [binding.dataset_id for binding in external_knowledge_bindings]
        datasets = db.session.query(Dataset).filter(Dataset.id.in_(dataset_ids)).all()
        return [{"id": dataset.id, "name": dataset.name} for dataset in datasets]
class ExternalKnowledgeBindings(db.Model):
    """Links a local dataset to a knowledge base behind an external API."""

    __tablename__ = "external_knowledge_bindings"
    __table_args__ = (
        db.PrimaryKeyConstraint("id", name="external_knowledge_bindings_pkey"),
        db.Index("external_knowledge_bindings_tenant_idx", "tenant_id"),
        db.Index("external_knowledge_bindings_dataset_idx", "dataset_id"),
        db.Index("external_knowledge_bindings_external_knowledge_idx", "external_knowledge_id"),
        db.Index("external_knowledge_bindings_external_knowledge_api_idx", "external_knowledge_api_id"),
    )

    id = db.Column(StringUUID, nullable=False, server_default=db.text("uuid_generate_v4()"))
    tenant_id = db.Column(StringUUID, nullable=False)
    external_knowledge_api_id = db.Column(StringUUID, nullable=False)
    dataset_id = db.Column(StringUUID, nullable=False)
    external_knowledge_id = db.Column(db.Text, nullable=False)  # id on the remote system
    created_by = db.Column(StringUUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))
    updated_by = db.Column(StringUUID, nullable=True)
    updated_at = db.Column(db.DateTime, nullable=False, server_default=db.text("CURRENT_TIMESTAMP(0)"))